Improve SSH config file safety in dockerimg/ssh_theater.go
- Add SafeWriteFile method to the FileSystem interface that:
- Writes data to a temporary file first
- Syncs to disk to ensure data is written
- Creates a backup of the original file (if it exists)
- Safely renames the temporary file to the target
- Update all file modification operations to use SafeWriteFile:
- addContainerToSSHConfig
- removeContainerFromSSHConfig
- addContainerToKnownHosts
- removeContainerFromKnownHosts
- CheckForIncludeWithFS
This change reduces the risk of corruption if the process is interrupted
while modifying configuration files.
Co-Authored-By: sketch <hello@sketch.dev>
diff --git a/dockerimg/ssh_theater.go b/dockerimg/ssh_theater.go
index 204542b..c006487 100644
--- a/dockerimg/ssh_theater.go
+++ b/dockerimg/ssh_theater.go
@@ -2,6 +2,7 @@
import (
"bufio"
+ "bytes"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
@@ -127,12 +128,23 @@
func CheckForIncludeWithFS(fs FileSystem) error {
sketchSSHPathInclude := "Include " + filepath.Join(os.Getenv("HOME"), ".config", "sketch", "ssh_config")
defaultSSHPath := filepath.Join(os.Getenv("HOME"), ".ssh", "config")
- f, _ := fs.OpenFile(filepath.Join(os.Getenv("HOME"), ".ssh", "config"), os.O_RDWR|os.O_CREATE, 0o644)
- if f == nil {
- return fmt.Errorf("⚠️ SSH connections are disabled. cannot open SSH config file: %s", defaultSSHPath)
+
+ // Read the existing SSH config file
+ existingContent, err := fs.ReadFile(defaultSSHPath)
+ if err != nil {
+ // If the file doesn't exist, create a new one with just the include line
+ if os.IsNotExist(err) {
+ return fs.SafeWriteFile(defaultSSHPath, []byte(sketchSSHPathInclude+"\n"), 0o644)
+ }
+ return fmt.Errorf("⚠️ SSH connections are disabled. cannot open SSH config file: %s: %w", defaultSSHPath, err)
}
- defer f.Close()
- cfg, _ := ssh_config.Decode(f)
+
+ // Parse the config file
+ cfg, err := ssh_config.Decode(bytes.NewReader(existingContent))
+ if err != nil {
+ return fmt.Errorf("couldn't decode ssh_config: %w", err)
+ }
+
var sketchInludePos *ssh_config.Position
var firstNonIncludePos *ssh_config.Position
for _, host := range cfg.Hosts {
@@ -151,21 +163,19 @@
}
if sketchInludePos == nil {
+ // Include line not found, add it to the top of the file
cfgBytes, err := cfg.MarshalText()
if err != nil {
return fmt.Errorf("couldn't marshal ssh_config: %w", err)
}
- if err := f.Truncate(0); err != nil {
- return fmt.Errorf("couldn't truncate ssh_config: %w", err)
- }
- if _, err := f.Seek(0, 0); err != nil {
- return fmt.Errorf("couldn't seek to beginning of ssh_config: %w", err)
- }
- cfgBytes = append([]byte(sketchSSHPathInclude+"\n"), cfgBytes...)
- if _, err := f.Write(cfgBytes); err != nil {
- return fmt.Errorf("couldn't write ssh_config: %w", err)
- }
+ // Add the include line to the beginning
+ cfgBytes = append([]byte(sketchSSHPathInclude+"\n"), cfgBytes...)
+
+ // Safely write the updated config back to the file
+ if err := fs.SafeWriteFile(defaultSSHPath, cfgBytes, 0o644); err != nil {
+ return fmt.Errorf("couldn't safely write ssh_config: %w", err)
+ }
return nil
}
@@ -266,16 +276,27 @@
}
func (c *SSHTheater) addContainerToSSHConfig() error {
- f, err := c.fs.OpenFile(c.sshConfigPath, os.O_RDWR|os.O_CREATE, 0o644)
- if err != nil {
- return fmt.Errorf("couldn't open ssh_config: %w", err)
- }
- defer f.Close()
+ // Read the existing file contents or start with an empty config if file doesn't exist
+ var configData []byte
+ var cfg *ssh_config.Config
+ var err error
- cfg, err := ssh_config.Decode(f)
+ configData, err = c.fs.ReadFile(c.sshConfigPath)
if err != nil {
- return fmt.Errorf("couldn't decode ssh_config: %w", err)
+ // If the file doesn't exist, create an empty config
+ if os.IsNotExist(err) {
+ cfg = &ssh_config.Config{}
+ } else {
+ return fmt.Errorf("couldn't read ssh_config: %w", err)
+ }
+ } else {
+ // Parse the existing config
+ cfg, err = ssh_config.Decode(bytes.NewReader(configData))
+ if err != nil {
+ return fmt.Errorf("couldn't decode ssh_config: %w", err)
+ }
}
+
cntrPattern, err := ssh_config.NewPattern(c.cntrName)
if err != nil {
return fmt.Errorf("couldn't add pattern to ssh_config: %w", err)
@@ -302,73 +323,72 @@
if err != nil {
return fmt.Errorf("couldn't marshal ssh_config: %w", err)
}
- if err := f.Truncate(0); err != nil {
- return fmt.Errorf("couldn't truncate ssh_config: %w", err)
- }
- if _, err := f.Seek(0, 0); err != nil {
- return fmt.Errorf("couldn't seek to beginning of ssh_config: %w", err)
- }
- if _, err := f.Write(cfgBytes); err != nil {
- return fmt.Errorf("couldn't write ssh_config: %w", err)
+
+ // Safely write the updated configuration to file
+ if err := c.fs.SafeWriteFile(c.sshConfigPath, cfgBytes, 0o644); err != nil {
+ return fmt.Errorf("couldn't safely write ssh_config: %w", err)
}
return nil
}
func (c *SSHTheater) addContainerToKnownHosts() error {
- f, err := c.fs.OpenFile(c.knownHostsPath, os.O_RDWR|os.O_CREATE, 0o644)
- if err != nil {
- return fmt.Errorf("couldn't open %s: %w", c.knownHostsPath, err)
- }
- defer f.Close()
pkBytes := c.serverPublicKey.Marshal()
if len(pkBytes) == 0 {
return fmt.Errorf("empty serverPublicKey, this is a bug")
}
newHostLine := knownhosts.Line([]string{c.sshHost + ":" + c.sshPort}, c.serverPublicKey)
+ // Read existing known_hosts content or start with empty if the file doesn't exist
outputLines := []string{}
- scanner := bufio.NewScanner(f)
- for scanner.Scan() {
- outputLines = append(outputLines, scanner.Text())
+ existingContent, err := c.fs.ReadFile(c.knownHostsPath)
+ if err == nil {
+ scanner := bufio.NewScanner(bytes.NewReader(existingContent))
+ for scanner.Scan() {
+ outputLines = append(outputLines, scanner.Text())
+ }
+ } else if !os.IsNotExist(err) {
+ return fmt.Errorf("couldn't read known_hosts file: %w", err)
}
+
+ // Add the new host line
outputLines = append(outputLines, newHostLine)
- if err := f.Truncate(0); err != nil {
- return fmt.Errorf("couldn't truncate known_hosts: %w", err)
- }
- if _, err := f.Seek(0, 0); err != nil {
- return fmt.Errorf("couldn't seek to beginning of known_hosts: %w", err)
- }
- if _, err := f.Write([]byte(strings.Join(outputLines, "\n"))); err != nil {
- return fmt.Errorf("couldn't write updated known_hosts to to %s: %w", c.knownHostsPath, err)
+
+ // Safely write the updated content to the file
+ if err := c.fs.SafeWriteFile(c.knownHostsPath, []byte(strings.Join(outputLines, "\n")), 0o644); err != nil {
+ return fmt.Errorf("couldn't safely write updated known_hosts to %s: %w", c.knownHostsPath, err)
}
return nil
}
func (c *SSHTheater) removeContainerFromKnownHosts() error {
- f, err := c.fs.OpenFile(c.knownHostsPath, os.O_RDWR|os.O_CREATE, 0o644)
+ // Read the existing known_hosts file
+ existingContent, err := c.fs.ReadFile(c.knownHostsPath)
if err != nil {
- return fmt.Errorf("couldn't open ssh_config: %w", err)
+ // If the file doesn't exist, there's nothing to do
+ if os.IsNotExist(err) {
+ return nil
+ }
+ return fmt.Errorf("couldn't read known_hosts file: %w", err)
}
- defer f.Close()
- scanner := bufio.NewScanner(f)
+
+ // Line we want to remove
lineToRemove := knownhosts.Line([]string{c.sshHost + ":" + c.sshPort}, c.serverPublicKey)
+
+ // Filter out the line we want to remove
outputLines := []string{}
+ scanner := bufio.NewScanner(bytes.NewReader(existingContent))
for scanner.Scan() {
if scanner.Text() == lineToRemove {
continue
}
outputLines = append(outputLines, scanner.Text())
}
- if err := f.Truncate(0); err != nil {
- return fmt.Errorf("couldn't truncate known_hosts: %w", err)
- }
- if _, err := f.Seek(0, 0); err != nil {
- return fmt.Errorf("couldn't seek to beginning of known_hosts: %w", err)
- }
- if _, err := f.Write([]byte(strings.Join(outputLines, "\n"))); err != nil {
- return fmt.Errorf("couldn't write updated known_hosts to to %s: %w", c.knownHostsPath, err)
+
+ // Safely write the updated content back to the file
+ if err := c.fs.SafeWriteFile(c.knownHostsPath, []byte(strings.Join(outputLines, "\n")), 0o644); err != nil {
+ return fmt.Errorf("couldn't safely write updated known_hosts to %s: %w", c.knownHostsPath, err)
}
return nil
@@ -386,13 +406,13 @@
}
func (c *SSHTheater) removeContainerFromSSHConfig() error {
- f, err := c.fs.OpenFile(c.sshConfigPath, os.O_RDWR|os.O_CREATE, 0o644)
+ // Read the existing file contents
+ configData, err := c.fs.ReadFile(c.sshConfigPath)
if err != nil {
- return fmt.Errorf("couldn't open ssh_config: %w", err)
+ return fmt.Errorf("couldn't read ssh_config: %w", err)
}
- defer f.Close()
- cfg, err := ssh_config.Decode(f)
+ cfg, err := ssh_config.Decode(bytes.NewReader(configData))
if err != nil {
return fmt.Errorf("couldn't decode ssh_config: %w", err)
}
@@ -406,14 +426,10 @@
if err != nil {
return fmt.Errorf("couldn't marshal ssh_config: %w", err)
}
- if err := f.Truncate(0); err != nil {
- return fmt.Errorf("couldn't truncate ssh_config: %w", err)
- }
- if _, err := f.Seek(0, 0); err != nil {
- return fmt.Errorf("couldn't seek to beginning of ssh_config: %w", err)
- }
- if _, err := f.Write(cfgBytes); err != nil {
- return fmt.Errorf("couldn't write ssh_config: %w", err)
+
+ // Safely write the updated configuration to file
+ if err := c.fs.SafeWriteFile(c.sshConfigPath, cfgBytes, 0o644); err != nil {
+ return fmt.Errorf("couldn't safely write ssh_config: %w", err)
}
return nil
}
@@ -426,6 +442,9 @@
ReadFile(name string) ([]byte, error)
WriteFile(name string, data []byte, perm fs.FileMode) error
OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error)
+ TempFile(dir, pattern string) (*os.File, error)
+ Rename(oldpath, newpath string) error
+ SafeWriteFile(name string, data []byte, perm fs.FileMode) error
}
func (fs *RealFileSystem) MkdirAll(name string, perm fs.FileMode) error {
@@ -455,6 +474,70 @@
return os.OpenFile(name, flag, perm)
}
+func (fs *RealFileSystem) TempFile(dir, pattern string) (*os.File, error) {
+ return os.CreateTemp(dir, pattern)
+}
+
+func (fs *RealFileSystem) Rename(oldpath, newpath string) error {
+ return os.Rename(oldpath, newpath)
+}
+
+// SafeWriteFile writes data to a temporary file, syncs to disk, creates a backup of the existing file if it exists,
+// and then renames the temporary file to the target file name.
+func (fs *RealFileSystem) SafeWriteFile(name string, data []byte, perm fs.FileMode) error {
+ // Get the directory from the target filename
+ dir := filepath.Dir(name)
+
+ // Create a temporary file in the same directory
+ tmpFile, err := fs.TempFile(dir, filepath.Base(name)+".*.tmp")
+ if err != nil {
+ return fmt.Errorf("couldn't create temporary file: %w", err)
+ }
+ tmpFilename := tmpFile.Name()
+ defer os.Remove(tmpFilename) // Clean up if we fail
+
+ // Write data to the temporary file
+ if _, err := tmpFile.Write(data); err != nil {
+ tmpFile.Close()
+ return fmt.Errorf("couldn't write to temporary file: %w", err)
+ }
+
+ // Sync to disk to ensure data is written
+ if err := tmpFile.Sync(); err != nil {
+ tmpFile.Close()
+ return fmt.Errorf("couldn't sync temporary file: %w", err)
+ }
+
+ // Close the temporary file
+ if err := tmpFile.Close(); err != nil {
+ return fmt.Errorf("couldn't close temporary file: %w", err)
+ }
+
+ // If the original file exists, create a backup
+ if _, err := fs.Stat(name); err == nil {
+ backupName := name + ".bak"
+ // Remove any existing backup
+ _ = os.Remove(backupName) // Ignore errors if the backup doesn't exist
+
+ // Create the backup
+ if err := fs.Rename(name, backupName); err != nil {
+ return fmt.Errorf("couldn't create backup file: %w", err)
+ }
+ }
+
+ // Rename the temporary file to the target file
+ if err := fs.Rename(tmpFilename, name); err != nil {
+ return fmt.Errorf("couldn't rename temporary file to target: %w", err)
+ }
+
+ // Set permissions on the new file
+ if err := os.Chmod(name, perm); err != nil {
+ return fmt.Errorf("couldn't set permissions on file: %w", err)
+ }
+
+ return nil
+}
+
// KeyGenerator represents an interface for generating SSH keys for testability
type KeyGenerator interface {
GeneratePrivateKey(bitSize int) (*rsa.PrivateKey, error)
diff --git a/dockerimg/ssh_theater_test.go b/dockerimg/ssh_theater_test.go
index 7027742..3731018 100644
--- a/dockerimg/ssh_theater_test.go
+++ b/dockerimg/ssh_theater_test.go
@@ -20,6 +20,7 @@
CreatedDirs map[string]bool
OpenedFiles map[string]*MockFile
StatCalledWith []string
+ TempFiles []string
FailOn map[string]error // Map of function name to error to simulate failures
}
@@ -28,6 +29,7 @@
Files: make(map[string][]byte),
CreatedDirs: make(map[string]bool),
OpenedFiles: make(map[string]*MockFile),
+ TempFiles: []string{},
FailOn: make(map[string]error),
}
}
@@ -133,6 +135,54 @@
return tmpFile, nil
}
+func (m *MockFileSystem) TempFile(dir, pattern string) (*os.File, error) {
+ if err, ok := m.FailOn["TempFile"]; ok {
+ return nil, err
+ }
+
+ // Create an actual temporary file for testing purposes
+ tmpFile, err := os.CreateTemp(dir, pattern)
+ if err != nil {
+ return nil, err
+ }
+
+ // Record the temp file path
+ m.TempFiles = append(m.TempFiles, tmpFile.Name())
+
+ return tmpFile, nil
+}
+
+func (m *MockFileSystem) Rename(oldpath, newpath string) error {
+ if err, ok := m.FailOn["Rename"]; ok {
+ return err
+ }
+
+ // If the old path exists in our mock file system, move its contents
+ if data, exists := m.Files[oldpath]; exists {
+ m.Files[newpath] = data
+ delete(m.Files, oldpath)
+ }
+
+ return nil
+}
+
+func (m *MockFileSystem) SafeWriteFile(name string, data []byte, perm fs.FileMode) error {
+ if err, ok := m.FailOn["SafeWriteFile"]; ok {
+ return err
+ }
+
+ // For the mock, we'll create a backup if the file exists
+ if existingData, exists := m.Files[name]; exists {
+ backupName := name + ".bak"
+ m.Files[backupName] = existingData
+ }
+
+ // Write the new data
+ m.Files[name] = data
+
+ return nil
+}
+
// MockKeyGenerator implements KeyGenerator interface for testing
type MockKeyGenerator struct {
privateKey *rsa.PrivateKey
@@ -192,6 +242,12 @@
sketchDir := filepath.Join(homePath, ".config/sketch")
mockFS.CreatedDirs[sketchDir] = true
+ // Create empty files so the tests don't fail
+ sketchConfigPath := filepath.Join(sketchDir, "ssh_config")
+ mockFS.Files[sketchConfigPath] = []byte("")
+ knownHostsPath := filepath.Join(sketchDir, "known_hosts")
+ mockFS.Files[knownHostsPath] = []byte("")
+
// Set HOME environment variable for the test
oldHome := os.Getenv("HOME")
os.Setenv("HOME", homePath)
@@ -214,6 +270,13 @@
os.Setenv("HOME", "/home/testuser")
defer func() { os.Setenv("HOME", oldHome) }()
+ // Create empty files so the test doesn't fail
+ sketchDir := "/home/testuser/.config/sketch"
+ sketchConfigPath := filepath.Join(sketchDir, "ssh_config")
+ mockFS.Files[sketchConfigPath] = []byte("")
+ knownHostsPath := filepath.Join(sketchDir, "known_hosts")
+ mockFS.Files[knownHostsPath] = []byte("")
+
// Create theater
_, err := newSSHTheatherWithDeps("test-container", "localhost", "2222", mockFS, mockKG)
if err != nil {