ssh_theater: ask the user before editing config
diff --git a/dockerimg/ssh_theater_test.go b/dockerimg/ssh_theater_test.go
index 5674d91..a4e41a4 100644
--- a/dockerimg/ssh_theater_test.go
+++ b/dockerimg/ssh_theater_test.go
@@ -1,6 +1,7 @@
package dockerimg
import (
+ "bufio"
"bytes"
"crypto/rand"
"crypto/rsa"
@@ -597,7 +598,7 @@
}
}
-func TestCheckForInclude(t *testing.T) {
+func TestCheckForInclude_userAccepts(t *testing.T) {
mockFS := NewMockFileSystem()
// Set HOME environment variable for the test
@@ -612,9 +613,9 @@
// Add the config to the mock filesystem
sshConfigPath := "/home/testuser/.ssh/config"
mockFS.Files[sshConfigPath] = []byte(initialConfig)
-
+ stdinReader := bufio.NewReader(strings.NewReader("y\n"))
// Test the function with our mock
- err := CheckForIncludeWithFS(mockFS)
+ err := CheckForIncludeWithFS(mockFS, *stdinReader)
if err != nil {
t.Fatalf("CheckForInclude failed with proper include: %v", err)
}
@@ -622,12 +623,49 @@
// Now test with config missing the include
mockFS.Files[sshConfigPath] = []byte("Host example\n HostName example.com\n")
- err = CheckForIncludeWithFS(mockFS)
+ stdinReader = bufio.NewReader(strings.NewReader("y\n"))
+ err = CheckForIncludeWithFS(mockFS, *stdinReader)
if err != nil {
t.Fatalf("CheckForInclude should have created the Include line without an error")
}
}
+// TestCheckForInclude_userDeclines verifies that CheckForIncludeWithFS never
+// edits the ssh config when the user answers "n" to the prompt: a config that
+// already has the Include line is accepted without error and left as-is, and a
+// config missing it yields an error instead of being modified.
+func TestCheckForInclude_userDeclines(t *testing.T) {
+	mockFS := NewMockFileSystem()
+
+	// t.Setenv restores HOME when the test finishes (no ignored os.Setenv
+	// error, no manual defer) and fails loudly if the test is made parallel.
+	t.Setenv("HOME", "/home/testuser")
+
+	// Create a mock ssh config that already has the expected include.
+	includeLine := "Include /home/testuser/.config/sketch/ssh_config"
+	initialConfig := fmt.Sprintf("%s\nHost example\n HostName example.com\n", includeLine)
+
+	// Add the config to the mock filesystem. A fresh []byte is stored so that
+	// any in-place mutation by the mock cannot hide behind slice aliasing.
+	sshConfigPath := "/home/testuser/.ssh/config"
+	mockFS.Files[sshConfigPath] = []byte(initialConfig)
+
+	// The include is already present, so declining must not be an error.
+	stdinReader := bufio.NewReader(strings.NewReader("n\n"))
+	err := CheckForIncludeWithFS(mockFS, *stdinReader)
+	if err != nil {
+		t.Fatalf("CheckForInclude failed with proper include: %v", err)
+	}
+	if string(mockFS.Files[sshConfigPath]) != initialConfig {
+		t.Errorf("ssh config that already has the include should not have been edited")
+	}
+
+	// Now test with config missing the include: declining must return an
+	// error and leave the file byte-for-byte as it was.
+	const missingInclude = "Host example\n HostName example.com\n"
+	mockFS.Files[sshConfigPath] = []byte(missingInclude)
+
+	stdinReader = bufio.NewReader(strings.NewReader("n\n"))
+	err = CheckForIncludeWithFS(mockFS, *stdinReader)
+	if err == nil {
+		t.Errorf("CheckForInclude should have returned an error")
+	}
+	if string(mockFS.Files[sshConfigPath]) != missingInclude {
+		t.Errorf("ssh config should not have been edited")
+	}
+}
+
func TestSSHTheaterWithErrors(t *testing.T) {
// Test directory creation failure
mockFS := NewMockFileSystem()