Improve SSH config file safety in dockerimg/ssh_theater.go
- Add SafeWriteFile method to the FileSystem interface that:
  - Writes data to a temporary file first
  - Syncs to disk to ensure data is written
  - Creates a backup of the original file (if it exists)
  - Safely renames the temporary file to the target
- Update all file modification operations to use SafeWriteFile:
  - addContainerToSSHConfig
  - removeContainerFromSSHConfig
  - addContainerToKnownHosts
  - removeContainerFromKnownHosts
  - CheckForIncludeWithFS
This change reduces the risk of corruption if the process is interrupted
while modifying configuration files.
Co-Authored-By: sketch <hello@sketch.dev>
diff --git a/dockerimg/ssh_theater_test.go b/dockerimg/ssh_theater_test.go
index 7027742..3731018 100644
--- a/dockerimg/ssh_theater_test.go
+++ b/dockerimg/ssh_theater_test.go
@@ -20,6 +20,7 @@
CreatedDirs map[string]bool
OpenedFiles map[string]*MockFile
StatCalledWith []string
+ TempFiles []string
FailOn map[string]error // Map of function name to error to simulate failures
}
@@ -28,6 +29,7 @@
Files: make(map[string][]byte),
CreatedDirs: make(map[string]bool),
OpenedFiles: make(map[string]*MockFile),
+ TempFiles: []string{},
FailOn: make(map[string]error),
}
}
@@ -133,6 +135,54 @@
return tmpFile, nil
}
+func (m *MockFileSystem) TempFile(dir, pattern string) (*os.File, error) {
+ if err, ok := m.FailOn["TempFile"]; ok {
+ return nil, err
+ }
+
+ // Create an actual temporary file for testing purposes
+ tmpFile, err := os.CreateTemp(dir, pattern)
+ if err != nil {
+ return nil, err
+ }
+
+ // Record the temp file path
+ m.TempFiles = append(m.TempFiles, tmpFile.Name())
+
+ return tmpFile, nil
+}
+
+func (m *MockFileSystem) Rename(oldpath, newpath string) error {
+ if err, ok := m.FailOn["Rename"]; ok {
+ return err
+ }
+
+ // If the old path exists in our mock file system, move its contents
+ if data, exists := m.Files[oldpath]; exists {
+ m.Files[newpath] = data
+ delete(m.Files, oldpath)
+ }
+
+ return nil
+}
+
+func (m *MockFileSystem) SafeWriteFile(name string, data []byte, perm fs.FileMode) error {
+ if err, ok := m.FailOn["SafeWriteFile"]; ok {
+ return err
+ }
+
+ // For the mock, we'll create a backup if the file exists
+ if existingData, exists := m.Files[name]; exists {
+ backupName := name + ".bak"
+ m.Files[backupName] = existingData
+ }
+
+ // Write the new data
+ m.Files[name] = data
+
+ return nil
+}
+
// MockKeyGenerator implements KeyGenerator interface for testing
type MockKeyGenerator struct {
privateKey *rsa.PrivateKey
@@ -192,6 +242,12 @@
sketchDir := filepath.Join(homePath, ".config/sketch")
mockFS.CreatedDirs[sketchDir] = true
+ // Create empty files so the tests don't fail
+ sketchConfigPath := filepath.Join(sketchDir, "ssh_config")
+ mockFS.Files[sketchConfigPath] = []byte("")
+ knownHostsPath := filepath.Join(sketchDir, "known_hosts")
+ mockFS.Files[knownHostsPath] = []byte("")
+
// Set HOME environment variable for the test
oldHome := os.Getenv("HOME")
os.Setenv("HOME", homePath)
@@ -214,6 +270,13 @@
os.Setenv("HOME", "/home/testuser")
defer func() { os.Setenv("HOME", oldHome) }()
+ // Create empty files so the test doesn't fail
+ sketchDir := "/home/testuser/.config/sketch"
+ sketchConfigPath := filepath.Join(sketchDir, "ssh_config")
+ mockFS.Files[sketchConfigPath] = []byte("")
+ knownHostsPath := filepath.Join(sketchDir, "known_hosts")
+ mockFS.Files[knownHostsPath] = []byte("")
+
// Create theater
_, err := newSSHTheatherWithDeps("test-container", "localhost", "2222", mockFS, mockKG)
if err != nil {