SSHTheater: fix known_hosts update issue
diff --git a/dockerimg/sshtheater.go b/dockerimg/sshtheater.go
index 31d1849..35433dd 100644
--- a/dockerimg/sshtheater.go
+++ b/dockerimg/sshtheater.go
@@ -5,7 +5,6 @@
"crypto/rand"
"crypto/rsa"
"crypto/x509"
- "encoding/base64"
"encoding/pem"
"fmt"
"os"
@@ -14,6 +13,7 @@
"github.com/kevinburke/ssh_config"
"golang.org/x/crypto/ssh"
+ "golang.org/x/crypto/ssh/knownhosts"
)
const keyBitSize = 2048
@@ -58,6 +58,12 @@
// to manually accept changes to your known_hosts file etc.
func NewSSHTheather(cntrName, sshHost, sshPort string) (*SSHTheater, error) {
base := filepath.Join(os.Getenv("HOME"), ".sketch")
+ if _, err := os.Stat(base); err != nil {
+ if err := os.Mkdir(base, 0o777); err != nil {
+ return nil, fmt.Errorf("couldn't create %s: %w", base, err)
+ }
+ }
+
cst := &SSHTheater{
cntrName: cntrName,
sshHost: sshHost,
@@ -67,8 +73,11 @@
serverIdentityPath: filepath.Join(base, "container_server_identity"),
sshConfigPath: filepath.Join(base, "ssh_config"),
}
- if err := cst.addContainerToSSHConfig(); err != nil {
- return nil, fmt.Errorf("couldn't add container to ssh_config: %w", err)
+ if _, err := createKeyPairIfMissing(cst.serverIdentityPath); err != nil {
+ return nil, fmt.Errorf("couldn't create server identity: %w", err)
+ }
+ if _, err := createKeyPairIfMissing(cst.userIdentityPath); err != nil {
+ return nil, fmt.Errorf("couldn't create user identity: %w", err)
}
serverIdentity, err := os.ReadFile(cst.serverIdentityPath)
@@ -90,6 +99,14 @@
}
cst.userIdentity = userIdentity
+ if err := cst.addContainerToSSHConfig(); err != nil {
+ return nil, fmt.Errorf("couldn't add container to ssh_config: %w", err)
+ }
+
+ if err := cst.addContainerToKnownHosts(); err != nil {
+ return nil, fmt.Errorf("couldn't update known hosts: %w", err)
+ }
+
return cst, nil
}
@@ -194,21 +211,7 @@
hostCfg := &ssh_config.Host{Patterns: []*ssh_config.Pattern{hostPattern}}
hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "UserKnownHostsFile", Value: c.knownHostsPath})
- if _, err := createKeyPairIfMissing(c.userIdentityPath); err != nil {
- return fmt.Errorf("couldn't create user identity: %w", err)
- }
hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "IdentityFile", Value: c.userIdentityPath})
-
- serverPubKey, err := createKeyPairIfMissing(c.serverIdentityPath)
- if err != nil {
- return fmt.Errorf("couldn't create server identity: %w", err)
- }
- if serverPubKey != nil {
- c.serverPublicKey = serverPubKey
- if err := c.addContainerToKnownHosts(); err != nil {
- return fmt.Errorf("couldn't update known hosts: %w", err)
- }
- }
hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.Empty{})
cfg.Hosts = append([]*ssh_config.Host{hostCfg}, cfg.Hosts...)
@@ -217,13 +220,6 @@
}
func (c *SSHTheater) addContainerToSSHConfig() error {
- dotSketchPath := filepath.Join(os.Getenv("HOME"), ".sketch")
- if _, err := os.Stat(dotSketchPath); err != nil {
- if err := os.Mkdir(dotSketchPath, 0o777); err != nil {
- return fmt.Errorf("couldn't create %s: %w", dotSketchPath, err)
- }
- }
-
f, err := os.OpenFile(c.sshConfigPath, os.O_RDWR|os.O_CREATE, 0o644)
if err != nil {
return fmt.Errorf("couldn't open ssh_config: %w", err)
@@ -279,15 +275,26 @@
return fmt.Errorf("couldn't open %s: %w", c.knownHostsPath, err)
}
defer f.Close()
+ pkBytes := c.serverPublicKey.Marshal()
+ if len(pkBytes) == 0 {
+ return fmt.Errorf("empty serverPublicKey. This is a bug")
+ }
+ newHostLine := knownhosts.Line([]string{c.sshHost + ":" + c.sshPort}, c.serverPublicKey)
- line := strings.Join(
- []string{
- fmt.Sprintf("[%s]:%s", c.sshHost, c.sshPort),
- c.serverPublicKey.Type(),
- base64.StdEncoding.EncodeToString(c.serverPublicKey.Marshal()),
- }, " ")
- if _, err := f.Write([]byte(line + "\n")); err != nil {
- return fmt.Errorf("couldn't write new known_host entry to to %s: %w", c.knownHostsPath, err)
+ outputLines := []string{}
+ scanner := bufio.NewScanner(f)
+ for scanner.Scan() {
+ outputLines = append(outputLines, scanner.Text())
+ }
+ outputLines = append(outputLines, newHostLine)
+ if err := f.Truncate(0); err != nil {
+ return fmt.Errorf("couldn't truncate known_hosts: %w", err)
+ }
+ if _, err := f.Seek(0, 0); err != nil {
+ return fmt.Errorf("couldn't seek to beginning of known_hosts: %w", err)
+ }
+ if _, err := f.Write([]byte(strings.Join(outputLines, "\n"))); err != nil {
+ return fmt.Errorf("couldn't write updated known_hosts to to %s: %w", c.knownHostsPath, err)
}
return nil
@@ -300,13 +307,7 @@
}
defer f.Close()
scanner := bufio.NewScanner(f)
- lineToRemove := strings.Join(
- []string{
- fmt.Sprintf("[%s]:%s", c.sshHost, c.sshPort),
- c.serverPublicKey.Type(),
- base64.StdEncoding.EncodeToString(c.serverPublicKey.Marshal()),
- }, " ")
-
+ lineToRemove := knownhosts.Line([]string{c.sshHost + ":" + c.sshPort}, c.serverPublicKey)
outputLines := []string{}
for scanner.Scan() {
if scanner.Text() == lineToRemove {