ssh_theater: ask the user before editing config
diff --git a/dockerimg/dockerimg.go b/dockerimg/dockerimg.go
index b485f8d..834a769 100644
--- a/dockerimg/dockerimg.go
+++ b/dockerimg/dockerimg.go
@@ -283,15 +283,19 @@
var sshServerIdentity, sshUserIdentity []byte
- if err := CheckForInclude(); err != nil {
- fmt.Println(err.Error())
+ sshErr := CheckForInclude()
+ sshAvailable := false
+ sshErrMsg := ""
+ if sshErr != nil {
+ fmt.Println(sshErr.Error())
+ sshErrMsg = sshErr.Error()
// continue - ssh config is not required for the rest of sketch to function locally.
} else {
cst, err := NewSSHTheater(cntrName, sshHost, sshPort)
if err != nil {
return appendInternalErr(fmt.Errorf("NewContainerSSHTheather: %w", err))
}
-
+ sshAvailable = true
// Note: The vscode: link uses an undocumented request parameter that I really had to dig to find:
// https://github.com/microsoft/vscode/blob/2b9486161abaca59b5132ce3c59544f3cc7000f6/src/vs/code/electron-main/app.ts#L878
fmt.Printf(`Connect to this container via any of these methods:
@@ -315,7 +319,7 @@
// the scrollback (which is not good, but also not fatal). I can't see why it does this
// though, since none of the calls in postContainerInitConfig obviously write to stdout
// or stderr.
- if err := postContainerInitConfig(ctx, localAddr, commit, gitSrv.gitPort, gitSrv.pass, sshServerIdentity, sshUserIdentity); err != nil {
+ if err := postContainerInitConfig(ctx, localAddr, commit, gitSrv.gitPort, gitSrv.pass, sshAvailable, sshErrMsg, sshServerIdentity, sshUserIdentity); err != nil {
slog.ErrorContext(ctx, "LaunchContainer.postContainerInitConfig", slog.String("err", err.Error()))
errCh <- appendInternalErr(err)
}
@@ -566,17 +570,9 @@
}
// Contact the container and configure it.
-func postContainerInitConfig(ctx context.Context, localAddr, commit, gitPort, gitPass string, sshServerIdentity, sshAuthorizedKeys []byte) error {
+func postContainerInitConfig(ctx context.Context, localAddr, commit, gitPort, gitPass string, sshAvailable bool, sshError string, sshServerIdentity, sshAuthorizedKeys []byte) error {
localURL := "http://" + localAddr
- // Check if SSH is available by checking for the Include directive in ~/.ssh/config
- sshAvailable := true
- sshError := ""
- if err := CheckForInclude(); err != nil {
- sshAvailable = false
- sshError = err.Error()
- }
-
initMsg, err := json.Marshal(
server.InitRequest{
Commit: commit,
diff --git a/dockerimg/ssh_theater.go b/dockerimg/ssh_theater.go
index 69e63de..6afcfc9 100644
--- a/dockerimg/ssh_theater.go
+++ b/dockerimg/ssh_theater.go
@@ -125,7 +125,7 @@
return cst, nil
}
-func CheckForIncludeWithFS(fs FileSystem) error {
+func CheckForIncludeWithFS(fs FileSystem, stdinReader bufio.Reader) error {
sketchSSHPathInclude := "Include " + filepath.Join(os.Getenv("HOME"), ".config", "sketch", "ssh_config")
defaultSSHPath := filepath.Join(os.Getenv("HOME"), ".ssh", "config")
@@ -163,6 +163,14 @@
}
if sketchInludePos == nil {
+ fmt.Printf("\nTo enable you to use ssh to connect to local sketch containers: \nAdd %q to the top of %s [y/N]? ", sketchSSHPathInclude, defaultSSHPath)
+ char, _, err := stdinReader.ReadRune()
+ if err != nil {
+ return fmt.Errorf("couldn't read from stdin: %w", err)
+ }
+ if char != 'y' && char != 'Y' {
+ return fmt.Errorf("User declined to edit ssh config file")
+ }
// Include line not found, add it to the top of the file
cfgBytes, err := cfg.MarshalText()
if err != nil {
@@ -557,5 +565,5 @@
// CheckForInclude checks if the user's SSH config includes the Sketch SSH config file
func CheckForInclude() error {
- return CheckForIncludeWithFS(&RealFileSystem{})
+ return CheckForIncludeWithFS(&RealFileSystem{}, *bufio.NewReader(os.Stdin))
}
diff --git a/dockerimg/ssh_theater_test.go b/dockerimg/ssh_theater_test.go
index 5674d91..a4e41a4 100644
--- a/dockerimg/ssh_theater_test.go
+++ b/dockerimg/ssh_theater_test.go
@@ -1,6 +1,7 @@
package dockerimg
import (
+ "bufio"
"bytes"
"crypto/rand"
"crypto/rsa"
@@ -597,7 +598,7 @@
}
}
-func TestCheckForInclude(t *testing.T) {
+func TestCheckForInclude_userAccepts(t *testing.T) {
mockFS := NewMockFileSystem()
// Set HOME environment variable for the test
@@ -612,9 +613,9 @@
// Add the config to the mock filesystem
sshConfigPath := "/home/testuser/.ssh/config"
mockFS.Files[sshConfigPath] = []byte(initialConfig)
-
+ stdinReader := bufio.NewReader(strings.NewReader("y\n"))
// Test the function with our mock
- err := CheckForIncludeWithFS(mockFS)
+ err := CheckForIncludeWithFS(mockFS, *stdinReader)
if err != nil {
t.Fatalf("CheckForInclude failed with proper include: %v", err)
}
@@ -622,12 +623,49 @@
// Now test with config missing the include
mockFS.Files[sshConfigPath] = []byte("Host example\n HostName example.com\n")
- err = CheckForIncludeWithFS(mockFS)
+ stdinReader = bufio.NewReader(strings.NewReader("y\n"))
+ err = CheckForIncludeWithFS(mockFS, *stdinReader)
if err != nil {
t.Fatalf("CheckForInclude should have created the Include line without an error")
}
}
+func TestCheckForInclude_userDeclines(t *testing.T) {
+ mockFS := NewMockFileSystem()
+
+ // Set HOME environment variable for the test
+ oldHome := os.Getenv("HOME")
+ os.Setenv("HOME", "/home/testuser")
+ defer func() { os.Setenv("HOME", oldHome) }()
+
+ // Create a mock ssh config with the expected include
+ includeLine := "Include /home/testuser/.config/sketch/ssh_config"
+ initialConfig := fmt.Sprintf("%s\nHost example\n HostName example.com\n", includeLine)
+
+ // Add the config to the mock filesystem
+ sshConfigPath := "/home/testuser/.ssh/config"
+ mockFS.Files[sshConfigPath] = []byte(initialConfig)
+ stdinReader := bufio.NewReader(strings.NewReader("n\n"))
+ // Test the function with our mock
+ err := CheckForIncludeWithFS(mockFS, *stdinReader)
+ if err != nil {
+ t.Fatalf("CheckForInclude failed with proper include: %v", err)
+ }
+
+ // Now test with config missing the include
+ missingInclude := []byte("Host example\n HostName example.com\n")
+ mockFS.Files[sshConfigPath] = missingInclude
+
+ stdinReader = bufio.NewReader(strings.NewReader("n\n"))
+ err = CheckForIncludeWithFS(mockFS, *stdinReader)
+ if err == nil {
+ t.Errorf("CheckForInclude should have returned an error")
+ }
+ if !bytes.Equal(mockFS.Files[sshConfigPath], missingInclude) {
+ t.Errorf("ssh config should not have been edited")
+ }
+}
+
func TestSSHTheaterWithErrors(t *testing.T) {
// Test directory creation failure
mockFS := NewMockFileSystem()