SSHTheater: fix known_hosts update issue
diff --git a/dockerimg/sshtheater.go b/dockerimg/sshtheater.go
index 31d1849..35433dd 100644
--- a/dockerimg/sshtheater.go
+++ b/dockerimg/sshtheater.go
@@ -5,7 +5,6 @@
 	"crypto/rand"
 	"crypto/rsa"
 	"crypto/x509"
-	"encoding/base64"
 	"encoding/pem"
 	"fmt"
 	"os"
@@ -14,6 +13,7 @@
 
 	"github.com/kevinburke/ssh_config"
 	"golang.org/x/crypto/ssh"
+	"golang.org/x/crypto/ssh/knownhosts"
 )
 
 const keyBitSize = 2048
@@ -58,6 +58,12 @@
 // to manually accept changes to your known_hosts file etc.
 func NewSSHTheather(cntrName, sshHost, sshPort string) (*SSHTheater, error) {
 	base := filepath.Join(os.Getenv("HOME"), ".sketch")
+	if _, err := os.Stat(base); err != nil {
+		if err := os.Mkdir(base, 0o777); err != nil {
+			return nil, fmt.Errorf("couldn't create %s: %w", base, err)
+		}
+	}
+
 	cst := &SSHTheater{
 		cntrName:           cntrName,
 		sshHost:            sshHost,
@@ -67,8 +73,11 @@
 		serverIdentityPath: filepath.Join(base, "container_server_identity"),
 		sshConfigPath:      filepath.Join(base, "ssh_config"),
 	}
-	if err := cst.addContainerToSSHConfig(); err != nil {
-		return nil, fmt.Errorf("couldn't add container to ssh_config: %w", err)
+	if _, err := createKeyPairIfMissing(cst.serverIdentityPath); err != nil {
+		return nil, fmt.Errorf("couldn't create server identity: %w", err)
+	}
+	if _, err := createKeyPairIfMissing(cst.userIdentityPath); err != nil {
+		return nil, fmt.Errorf("couldn't create user identity: %w", err)
 	}
 
 	serverIdentity, err := os.ReadFile(cst.serverIdentityPath)
@@ -90,6 +99,14 @@
 	}
 	cst.userIdentity = userIdentity
 
+	if err := cst.addContainerToSSHConfig(); err != nil {
+		return nil, fmt.Errorf("couldn't add container to ssh_config: %w", err)
+	}
+
+	if err := cst.addContainerToKnownHosts(); err != nil {
+		return nil, fmt.Errorf("couldn't update known hosts: %w", err)
+	}
+
 	return cst, nil
 }
 
@@ -194,21 +211,7 @@
 		hostCfg := &ssh_config.Host{Patterns: []*ssh_config.Pattern{hostPattern}}
 		hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "UserKnownHostsFile", Value: c.knownHostsPath})
 
-		if _, err := createKeyPairIfMissing(c.userIdentityPath); err != nil {
-			return fmt.Errorf("couldn't create user identity: %w", err)
-		}
 		hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "IdentityFile", Value: c.userIdentityPath})
-
-		serverPubKey, err := createKeyPairIfMissing(c.serverIdentityPath)
-		if err != nil {
-			return fmt.Errorf("couldn't create server identity: %w", err)
-		}
-		if serverPubKey != nil {
-			c.serverPublicKey = serverPubKey
-			if err := c.addContainerToKnownHosts(); err != nil {
-				return fmt.Errorf("couldn't update known hosts: %w", err)
-			}
-		}
 		hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.Empty{})
 
 		cfg.Hosts = append([]*ssh_config.Host{hostCfg}, cfg.Hosts...)
@@ -217,13 +220,6 @@
 }
 
 func (c *SSHTheater) addContainerToSSHConfig() error {
-	dotSketchPath := filepath.Join(os.Getenv("HOME"), ".sketch")
-	if _, err := os.Stat(dotSketchPath); err != nil {
-		if err := os.Mkdir(dotSketchPath, 0o777); err != nil {
-			return fmt.Errorf("couldn't create %s: %w", dotSketchPath, err)
-		}
-	}
-
 	f, err := os.OpenFile(c.sshConfigPath, os.O_RDWR|os.O_CREATE, 0o644)
 	if err != nil {
 		return fmt.Errorf("couldn't open ssh_config: %w", err)
@@ -279,15 +275,26 @@
 		return fmt.Errorf("couldn't open %s: %w", c.knownHostsPath, err)
 	}
 	defer f.Close()
+	pkBytes := c.serverPublicKey.Marshal()
+	if len(pkBytes) == 0 {
+		return fmt.Errorf("empty serverPublicKey. This is a bug")
+	}
+	newHostLine := knownhosts.Line([]string{c.sshHost + ":" + c.sshPort}, c.serverPublicKey)
 
-	line := strings.Join(
-		[]string{
-			fmt.Sprintf("[%s]:%s", c.sshHost, c.sshPort),
-			c.serverPublicKey.Type(),
-			base64.StdEncoding.EncodeToString(c.serverPublicKey.Marshal()),
-		}, " ")
-	if _, err := f.Write([]byte(line + "\n")); err != nil {
-		return fmt.Errorf("couldn't write new known_host entry to to %s: %w", c.knownHostsPath, err)
+	outputLines := []string{}
+	scanner := bufio.NewScanner(f)
+	for scanner.Scan() {
+		outputLines = append(outputLines, scanner.Text())
+	}
+	outputLines = append(outputLines, newHostLine)
+	if err := f.Truncate(0); err != nil {
+		return fmt.Errorf("couldn't truncate known_hosts: %w", err)
+	}
+	if _, err := f.Seek(0, 0); err != nil {
+		return fmt.Errorf("couldn't seek to beginning of known_hosts: %w", err)
+	}
+	if _, err := f.Write([]byte(strings.Join(outputLines, "\n"))); err != nil {
+		return fmt.Errorf("couldn't write updated known_hosts to to %s: %w", c.knownHostsPath, err)
 	}
 
 	return nil
@@ -300,13 +307,7 @@
 	}
 	defer f.Close()
 	scanner := bufio.NewScanner(f)
-	lineToRemove := strings.Join(
-		[]string{
-			fmt.Sprintf("[%s]:%s", c.sshHost, c.sshPort),
-			c.serverPublicKey.Type(),
-			base64.StdEncoding.EncodeToString(c.serverPublicKey.Marshal()),
-		}, " ")
-
+	lineToRemove := knownhosts.Line([]string{c.sshHost + ":" + c.sshPort}, c.serverPublicKey)
 	outputLines := []string{}
 	for scanner.Scan() {
 		if scanner.Text() == lineToRemove {