Improve SSH config file safety in dockerimg/ssh_theater.go

- Add a SafeWriteFile method (plus the TempFile and Rename primitives it uses) to the FileSystem interface. SafeWriteFile:
  - Writes data to a temporary file first
  - Syncs to disk to ensure data is written
  - Moves the original file (if it exists) to a `.bak` backup, replacing any previous backup
  - Renames the temporary file to the target and then applies the requested permissions

- Update all file modification operations to use SafeWriteFile:
  - addContainerToSSHConfig
  - removeContainerFromSSHConfig
  - addContainerToKnownHosts
  - removeContainerFromKnownHosts
  - CheckForIncludeWithFS

This change reduces the risk of corruption if the process is interrupted
while modifying configuration files.

Co-Authored-By: sketch <hello@sketch.dev>
diff --git a/dockerimg/ssh_theater.go b/dockerimg/ssh_theater.go
index 204542b..c006487 100644
--- a/dockerimg/ssh_theater.go
+++ b/dockerimg/ssh_theater.go
@@ -2,6 +2,7 @@
 
 import (
 	"bufio"
+	"bytes"
 	"crypto/rand"
 	"crypto/rsa"
 	"crypto/x509"
@@ -127,12 +128,23 @@
 func CheckForIncludeWithFS(fs FileSystem) error {
 	sketchSSHPathInclude := "Include " + filepath.Join(os.Getenv("HOME"), ".config", "sketch", "ssh_config")
 	defaultSSHPath := filepath.Join(os.Getenv("HOME"), ".ssh", "config")
-	f, _ := fs.OpenFile(filepath.Join(os.Getenv("HOME"), ".ssh", "config"), os.O_RDWR|os.O_CREATE, 0o644)
-	if f == nil {
-		return fmt.Errorf("⚠️  SSH connections are disabled. cannot open SSH config file: %s", defaultSSHPath)
+
+	// Read the existing SSH config file
+	existingContent, err := fs.ReadFile(defaultSSHPath)
+	if err != nil {
+		// If the file doesn't exist, create a new one with just the include line
+		if os.IsNotExist(err) {
+			return fs.SafeWriteFile(defaultSSHPath, []byte(sketchSSHPathInclude+"\n"), 0o644)
+		}
+		return fmt.Errorf("⚠️  SSH connections are disabled. cannot open SSH config file: %s: %w", defaultSSHPath, err)
 	}
-	defer f.Close()
-	cfg, _ := ssh_config.Decode(f)
+
+	// Parse the config file
+	cfg, err := ssh_config.Decode(bytes.NewReader(existingContent))
+	if err != nil {
+		return fmt.Errorf("couldn't decode ssh_config: %w", err)
+	}
+
 	var sketchInludePos *ssh_config.Position
 	var firstNonIncludePos *ssh_config.Position
 	for _, host := range cfg.Hosts {
@@ -151,21 +163,19 @@
 	}
 
 	if sketchInludePos == nil {
+		// Include line not found, add it to the top of the file
 		cfgBytes, err := cfg.MarshalText()
 		if err != nil {
 			return fmt.Errorf("couldn't marshal ssh_config: %w", err)
 		}
-		if err := f.Truncate(0); err != nil {
-			return fmt.Errorf("couldn't truncate ssh_config: %w", err)
-		}
-		if _, err := f.Seek(0, 0); err != nil {
-			return fmt.Errorf("couldn't seek to beginning of ssh_config: %w", err)
-		}
-		cfgBytes = append([]byte(sketchSSHPathInclude+"\n"), cfgBytes...)
-		if _, err := f.Write(cfgBytes); err != nil {
-			return fmt.Errorf("couldn't write ssh_config: %w", err)
-		}
 
+		// Add the include line to the beginning
+		cfgBytes = append([]byte(sketchSSHPathInclude+"\n"), cfgBytes...)
+
+		// Safely write the updated config back to the file
+		if err := fs.SafeWriteFile(defaultSSHPath, cfgBytes, 0o644); err != nil {
+			return fmt.Errorf("couldn't safely write ssh_config: %w", err)
+		}
 		return nil
 	}
 
@@ -266,16 +276,27 @@
 }
 
 func (c *SSHTheater) addContainerToSSHConfig() error {
-	f, err := c.fs.OpenFile(c.sshConfigPath, os.O_RDWR|os.O_CREATE, 0o644)
-	if err != nil {
-		return fmt.Errorf("couldn't open ssh_config: %w", err)
-	}
-	defer f.Close()
+	// Read the existing file contents or start with an empty config if file doesn't exist
+	var configData []byte
+	var cfg *ssh_config.Config
+	var err error
 
-	cfg, err := ssh_config.Decode(f)
+	configData, err = c.fs.ReadFile(c.sshConfigPath)
 	if err != nil {
-		return fmt.Errorf("couldn't decode ssh_config: %w", err)
+		// If the file doesn't exist, create an empty config
+		if os.IsNotExist(err) {
+			cfg = &ssh_config.Config{}
+		} else {
+			return fmt.Errorf("couldn't read ssh_config: %w", err)
+		}
+	} else {
+		// Parse the existing config
+		cfg, err = ssh_config.Decode(bytes.NewReader(configData))
+		if err != nil {
+			return fmt.Errorf("couldn't decode ssh_config: %w", err)
+		}
 	}
+
 	cntrPattern, err := ssh_config.NewPattern(c.cntrName)
 	if err != nil {
 		return fmt.Errorf("couldn't add pattern to ssh_config: %w", err)
@@ -302,73 +323,72 @@
 	if err != nil {
 		return fmt.Errorf("couldn't marshal ssh_config: %w", err)
 	}
-	if err := f.Truncate(0); err != nil {
-		return fmt.Errorf("couldn't truncate ssh_config: %w", err)
-	}
-	if _, err := f.Seek(0, 0); err != nil {
-		return fmt.Errorf("couldn't seek to beginning of ssh_config: %w", err)
-	}
-	if _, err := f.Write(cfgBytes); err != nil {
-		return fmt.Errorf("couldn't write ssh_config: %w", err)
+
+	// Safely write the updated configuration to file
+	if err := c.fs.SafeWriteFile(c.sshConfigPath, cfgBytes, 0o644); err != nil {
+		return fmt.Errorf("couldn't safely write ssh_config: %w", err)
 	}
 
 	return nil
 }
 
 func (c *SSHTheater) addContainerToKnownHosts() error {
-	f, err := c.fs.OpenFile(c.knownHostsPath, os.O_RDWR|os.O_CREATE, 0o644)
-	if err != nil {
-		return fmt.Errorf("couldn't open %s: %w", c.knownHostsPath, err)
-	}
-	defer f.Close()
 	pkBytes := c.serverPublicKey.Marshal()
 	if len(pkBytes) == 0 {
 		return fmt.Errorf("empty serverPublicKey, this is a bug")
 	}
 	newHostLine := knownhosts.Line([]string{c.sshHost + ":" + c.sshPort}, c.serverPublicKey)
 
+	// Read existing known_hosts content or start with empty if the file doesn't exist
 	outputLines := []string{}
-	scanner := bufio.NewScanner(f)
-	for scanner.Scan() {
-		outputLines = append(outputLines, scanner.Text())
+	existingContent, err := c.fs.ReadFile(c.knownHostsPath)
+	if err == nil {
+		scanner := bufio.NewScanner(bytes.NewReader(existingContent))
+		for scanner.Scan() {
+			outputLines = append(outputLines, scanner.Text())
+		}
+	} else if !os.IsNotExist(err) {
+		return fmt.Errorf("couldn't read known_hosts file: %w", err)
 	}
+
+	// Add the new host line
 	outputLines = append(outputLines, newHostLine)
-	if err := f.Truncate(0); err != nil {
-		return fmt.Errorf("couldn't truncate known_hosts: %w", err)
-	}
-	if _, err := f.Seek(0, 0); err != nil {
-		return fmt.Errorf("couldn't seek to beginning of known_hosts: %w", err)
-	}
-	if _, err := f.Write([]byte(strings.Join(outputLines, "\n"))); err != nil {
-		return fmt.Errorf("couldn't write updated known_hosts to to %s: %w", c.knownHostsPath, err)
+
+	// Safely write the updated content to the file
+	if err := c.fs.SafeWriteFile(c.knownHostsPath, []byte(strings.Join(outputLines, "\n")), 0o644); err != nil {
+		return fmt.Errorf("couldn't safely write updated known_hosts to %s: %w", c.knownHostsPath, err)
 	}
 
 	return nil
 }
 
 func (c *SSHTheater) removeContainerFromKnownHosts() error {
-	f, err := c.fs.OpenFile(c.knownHostsPath, os.O_RDWR|os.O_CREATE, 0o644)
+	// Read the existing known_hosts file
+	existingContent, err := c.fs.ReadFile(c.knownHostsPath)
 	if err != nil {
-		return fmt.Errorf("couldn't open ssh_config: %w", err)
+		// If the file doesn't exist, there's nothing to do
+		if os.IsNotExist(err) {
+			return nil
+		}
+		return fmt.Errorf("couldn't read known_hosts file: %w", err)
 	}
-	defer f.Close()
-	scanner := bufio.NewScanner(f)
+
+	// Line we want to remove
 	lineToRemove := knownhosts.Line([]string{c.sshHost + ":" + c.sshPort}, c.serverPublicKey)
+
+	// Filter out the line we want to remove
 	outputLines := []string{}
+	scanner := bufio.NewScanner(bytes.NewReader(existingContent))
 	for scanner.Scan() {
 		if scanner.Text() == lineToRemove {
 			continue
 		}
 		outputLines = append(outputLines, scanner.Text())
 	}
-	if err := f.Truncate(0); err != nil {
-		return fmt.Errorf("couldn't truncate known_hosts: %w", err)
-	}
-	if _, err := f.Seek(0, 0); err != nil {
-		return fmt.Errorf("couldn't seek to beginning of known_hosts: %w", err)
-	}
-	if _, err := f.Write([]byte(strings.Join(outputLines, "\n"))); err != nil {
-		return fmt.Errorf("couldn't write updated known_hosts to to %s: %w", c.knownHostsPath, err)
+
+	// Safely write the updated content back to the file
+	if err := c.fs.SafeWriteFile(c.knownHostsPath, []byte(strings.Join(outputLines, "\n")), 0o644); err != nil {
+		return fmt.Errorf("couldn't safely write updated known_hosts to %s: %w", c.knownHostsPath, err)
 	}
 
 	return nil
@@ -386,13 +406,13 @@
 }
 
 func (c *SSHTheater) removeContainerFromSSHConfig() error {
-	f, err := c.fs.OpenFile(c.sshConfigPath, os.O_RDWR|os.O_CREATE, 0o644)
+	// Read the existing file contents
+	configData, err := c.fs.ReadFile(c.sshConfigPath)
 	if err != nil {
-		return fmt.Errorf("couldn't open ssh_config: %w", err)
+		return fmt.Errorf("couldn't read ssh_config: %w", err)
 	}
-	defer f.Close()
 
-	cfg, err := ssh_config.Decode(f)
+	cfg, err := ssh_config.Decode(bytes.NewReader(configData))
 	if err != nil {
 		return fmt.Errorf("couldn't decode ssh_config: %w", err)
 	}
@@ -406,14 +426,10 @@
 	if err != nil {
 		return fmt.Errorf("couldn't marshal ssh_config: %w", err)
 	}
-	if err := f.Truncate(0); err != nil {
-		return fmt.Errorf("couldn't truncate ssh_config: %w", err)
-	}
-	if _, err := f.Seek(0, 0); err != nil {
-		return fmt.Errorf("couldn't seek to beginning of ssh_config: %w", err)
-	}
-	if _, err := f.Write(cfgBytes); err != nil {
-		return fmt.Errorf("couldn't write ssh_config: %w", err)
+
+	// Safely write the updated configuration to file
+	if err := c.fs.SafeWriteFile(c.sshConfigPath, cfgBytes, 0o644); err != nil {
+		return fmt.Errorf("couldn't safely write ssh_config: %w", err)
 	}
 	return nil
 }
@@ -426,6 +442,9 @@
 	ReadFile(name string) ([]byte, error)
 	WriteFile(name string, data []byte, perm fs.FileMode) error
 	OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error)
+	TempFile(dir, pattern string) (*os.File, error)
+	Rename(oldpath, newpath string) error
+	SafeWriteFile(name string, data []byte, perm fs.FileMode) error
 }
 
 func (fs *RealFileSystem) MkdirAll(name string, perm fs.FileMode) error {
@@ -455,6 +474,70 @@
 	return os.OpenFile(name, flag, perm)
 }
 
+func (fs *RealFileSystem) TempFile(dir, pattern string) (*os.File, error) {
+	return os.CreateTemp(dir, pattern)
+}
+
+func (fs *RealFileSystem) Rename(oldpath, newpath string) error {
+	return os.Rename(oldpath, newpath)
+}
+
+// SafeWriteFile replaces the file at name with data: it writes and syncs a temporary file in the same directory,
+// moves any existing file aside to name+".bak", renames the temporary file to name, and finally chmods it to perm.
+func (fs *RealFileSystem) SafeWriteFile(name string, data []byte, perm fs.FileMode) error {
+	// The temporary file is created next to the target, in the target's own directory
+	dir := filepath.Dir(name)
+
+	// The temp name is built from the target's base name; os.CreateTemp (via TempFile) fills in the "*"
+	tmpFile, err := fs.TempFile(dir, filepath.Base(name)+".*.tmp")
+	if err != nil {
+		return fmt.Errorf("couldn't create temporary file: %w", err)
+	}
+	tmpFilename := tmpFile.Name()
+	defer os.Remove(tmpFilename) // Best-effort cleanup on every return path; the error is dropped
+
+	// Write the full payload to the temporary file, closing it before reporting a failure
+	if _, err := tmpFile.Write(data); err != nil {
+		tmpFile.Close()
+		return fmt.Errorf("couldn't write to temporary file: %w", err)
+	}
+
+	// Sync the temporary file before it is renamed into place
+	if err := tmpFile.Sync(); err != nil {
+		tmpFile.Close()
+		return fmt.Errorf("couldn't sync temporary file: %w", err)
+	}
+
+	// Close before renaming; a close error aborts the whole write
+	if err := tmpFile.Close(); err != nil {
+		return fmt.Errorf("couldn't close temporary file: %w", err)
+	}
+
+	// A successful Stat means the target exists and gets backed up; any Stat error skips the backup
+	if _, err := fs.Stat(name); err == nil {
+		backupName := name + ".bak"
+		// Drop the previous backup first (direct os call, not routed through fs)
+		_ = os.Remove(backupName) // Error deliberately ignored, e.g. when no previous backup exists
+
+		// Rename (not copy) the target to the backup. NOTE(review): until the rename below, nothing exists at name
+		if err := fs.Rename(name, backupName); err != nil {
+			return fmt.Errorf("couldn't create backup file: %w", err)
+		}
+	}
+
+	// Move the temporary file into place. NOTE(review): if this fails, the old content remains only in the backup
+	if err := fs.Rename(tmpFilename, name); err != nil {
+		return fmt.Errorf("couldn't rename temporary file to target: %w", err)
+	}
+
+	// Apply perm only after the rename; until then the file keeps the mode it was created with
+	if err := os.Chmod(name, perm); err != nil {
+		return fmt.Errorf("couldn't set permissions on file: %w", err)
+	}
+
+	return nil
+}
+
 // KeyGenerator represents an interface for generating SSH keys for testability
 type KeyGenerator interface {
 	GeneratePrivateKey(bitSize int) (*rsa.PrivateKey, error)
diff --git a/dockerimg/ssh_theater_test.go b/dockerimg/ssh_theater_test.go
index 7027742..3731018 100644
--- a/dockerimg/ssh_theater_test.go
+++ b/dockerimg/ssh_theater_test.go
@@ -20,6 +20,7 @@
 	CreatedDirs    map[string]bool
 	OpenedFiles    map[string]*MockFile
 	StatCalledWith []string
+	TempFiles      []string
 	FailOn         map[string]error // Map of function name to error to simulate failures
 }
 
@@ -28,6 +29,7 @@
 		Files:       make(map[string][]byte),
 		CreatedDirs: make(map[string]bool),
 		OpenedFiles: make(map[string]*MockFile),
+		TempFiles:   []string{},
 		FailOn:      make(map[string]error),
 	}
 }
@@ -133,6 +135,54 @@
 	return tmpFile, nil
 }
 
+func (m *MockFileSystem) TempFile(dir, pattern string) (*os.File, error) {
+	if err, ok := m.FailOn["TempFile"]; ok {
+		return nil, err
+	}
+
+	// Not mocked: this creates a real file on disk via os.CreateTemp and neither closes nor removes it here
+	tmpFile, err := os.CreateTemp(dir, pattern)
+	if err != nil {
+		return nil, err
+	}
+
+	// Record the real path in m.TempFiles (presumably so tests can inspect or clean up; confirm against callers)
+	m.TempFiles = append(m.TempFiles, tmpFile.Name())
+
+	return tmpFile, nil
+}
+
+func (m *MockFileSystem) Rename(oldpath, newpath string) error {
+	if err, ok := m.FailOn["Rename"]; ok {
+		return err
+	}
+
+	// Move the entry within the in-memory Files map; an unknown oldpath is a silent no-op that still returns nil
+	if data, exists := m.Files[oldpath]; exists {
+		m.Files[newpath] = data
+		delete(m.Files, oldpath)
+	}
+
+	return nil
+}
+
+func (m *MockFileSystem) SafeWriteFile(name string, data []byte, perm fs.FileMode) error {
+	if err, ok := m.FailOn["SafeWriteFile"]; ok {
+		return err
+	}
+
+	// Keep the previous content under name+".bak" in the Files map, overwriting any earlier backup entry
+	if existingData, exists := m.Files[name]; exists {
+		backupName := name + ".bak"
+		m.Files[backupName] = existingData
+	}
+
+	// Store the new content in the Files map; perm is ignored and no temp file or rename is involved
+	m.Files[name] = data
+
+	return nil
+}
+
 // MockKeyGenerator implements KeyGenerator interface for testing
 type MockKeyGenerator struct {
 	privateKey *rsa.PrivateKey
@@ -192,6 +242,12 @@
 	sketchDir := filepath.Join(homePath, ".config/sketch")
 	mockFS.CreatedDirs[sketchDir] = true
 
+	// Create empty files so the tests don't fail
+	sketchConfigPath := filepath.Join(sketchDir, "ssh_config")
+	mockFS.Files[sketchConfigPath] = []byte("")
+	knownHostsPath := filepath.Join(sketchDir, "known_hosts")
+	mockFS.Files[knownHostsPath] = []byte("")
+
 	// Set HOME environment variable for the test
 	oldHome := os.Getenv("HOME")
 	os.Setenv("HOME", homePath)
@@ -214,6 +270,13 @@
 	os.Setenv("HOME", "/home/testuser")
 	defer func() { os.Setenv("HOME", oldHome) }()
 
+	// Create empty files so the test doesn't fail
+	sketchDir := "/home/testuser/.config/sketch"
+	sketchConfigPath := filepath.Join(sketchDir, "ssh_config")
+	mockFS.Files[sketchConfigPath] = []byte("")
+	knownHostsPath := filepath.Join(sketchDir, "known_hosts")
+	mockFS.Files[knownHostsPath] = []byte("")
+
 	// Create theater
 	_, err := newSSHTheatherWithDeps("test-container", "localhost", "2222", mockFS, mockKG)
 	if err != nil {