blob: 1b060cc94987ab48eab1c4721d62b215c754da75 [file] [log] [blame]
Sean McCullough4854c652025-04-24 18:37:02 -07001package dockerimg
2
3import (
4 "bufio"
Sean McCullough0d95d3a2025-04-30 16:22:28 +00005 "bytes"
Sean McCullough4854c652025-04-24 18:37:02 -07006 "crypto/rand"
7 "crypto/rsa"
8 "crypto/x509"
Sean McCullough4854c652025-04-24 18:37:02 -07009 "encoding/pem"
10 "fmt"
Sean McCullough2cba6952025-04-25 20:32:10 +000011 "io/fs"
Sean McCullough4854c652025-04-24 18:37:02 -070012 "os"
Sean McCullough078e85a2025-05-08 17:28:34 -070013 "os/exec"
Sean McCullough4854c652025-04-24 18:37:02 -070014 "path/filepath"
15 "strings"
16
17 "github.com/kevinburke/ssh_config"
18 "golang.org/x/crypto/ssh"
Sean McCullough7d5a6302025-04-24 21:27:51 -070019 "golang.org/x/crypto/ssh/knownhosts"
Sean McCullough4854c652025-04-24 18:37:02 -070020)
21
22const keyBitSize = 2048
23
24// SSHTheater does the necessary key pair generation, known_hosts updates, ssh_config file updates etc steps
25// so that ssh can connect to a locally running sketch container to other local processes like vscode without
26// the user having to run the usual ssh obstacle course.
27//
28// SSHTheater does not modify your default .ssh/config, or known_hosts files. However, in order for you
29// to be able to use it properly you will have to make a one-time edit to your ~/.ssh/config file.
30//
31// In your ~/.ssh/config file, add the following line:
32//
Sean McCullough74b01212025-04-29 18:40:53 -070033// Include $HOME/.config/sketch/ssh_config
Sean McCullough4854c652025-04-24 18:37:02 -070034//
35// where $HOME is your home directory.
36type SSHTheater struct {
37 cntrName string
38 sshHost string
39 sshPort string
40
41 knownHostsPath string
42 userIdentityPath string
43 sshConfigPath string
44 serverIdentityPath string
45
46 serverPublicKey ssh.PublicKey
47 serverIdentity []byte
48 userIdentity []byte
Sean McCullough2cba6952025-04-25 20:32:10 +000049
50 fs FileSystem
51 kg KeyGenerator
Sean McCullough4854c652025-04-24 18:37:02 -070052}
53
Josh Bleecher Snyder50608b12025-05-03 22:55:49 +000054// NewSSHTheater will set up everything so that you can use ssh on localhost to connect to
Sean McCullough4854c652025-04-24 18:37:02 -070055// the sketch container. Call #Clean when you are done with the container to remove the
56// various entries it created in its known_hosts and ssh_config files. Also note that
57// this will generate key pairs for both the ssh server identity and the user identity, if
58// these files do not already exist. These key pair files are not deleted by #Cleanup,
59// so they can be re-used across invocations of sketch. This means every sketch container
60// that runs on this host will use the same ssh server identity.
61//
62// If this doesn't return an error, you should be able to run "ssh <cntrName>"
63// in a terminal on your host machine to open a shell into the container without having
64// to manually accept changes to your known_hosts file etc.
Josh Bleecher Snyder50608b12025-05-03 22:55:49 +000065func NewSSHTheater(cntrName, sshHost, sshPort string) (*SSHTheater, error) {
Sean McCullough2cba6952025-04-25 20:32:10 +000066 return newSSHTheatherWithDeps(cntrName, sshHost, sshPort, &RealFileSystem{}, &RealKeyGenerator{})
67}
68
69// newSSHTheatherWithDeps creates a new SSHTheater with the specified dependencies
70func newSSHTheatherWithDeps(cntrName, sshHost, sshPort string, fs FileSystem, kg KeyGenerator) (*SSHTheater, error) {
Sean McCullough74b01212025-04-29 18:40:53 -070071 base := filepath.Join(os.Getenv("HOME"), ".config", "sketch")
Sean McCullough2cba6952025-04-25 20:32:10 +000072 if _, err := fs.Stat(base); err != nil {
Sean McCulloughc796e7f2025-04-30 08:44:06 -070073 if err := fs.MkdirAll(base, 0o777); err != nil {
Sean McCullough7d5a6302025-04-24 21:27:51 -070074 return nil, fmt.Errorf("couldn't create %s: %w", base, err)
75 }
76 }
77
Sean McCullough4854c652025-04-24 18:37:02 -070078 cst := &SSHTheater{
79 cntrName: cntrName,
80 sshHost: sshHost,
81 sshPort: sshPort,
82 knownHostsPath: filepath.Join(base, "known_hosts"),
83 userIdentityPath: filepath.Join(base, "container_user_identity"),
84 serverIdentityPath: filepath.Join(base, "container_server_identity"),
85 sshConfigPath: filepath.Join(base, "ssh_config"),
Sean McCullough2cba6952025-04-25 20:32:10 +000086 fs: fs,
87 kg: kg,
Sean McCullough4854c652025-04-24 18:37:02 -070088 }
Sean McCullough2cba6952025-04-25 20:32:10 +000089 if _, err := cst.createKeyPairIfMissing(cst.serverIdentityPath); err != nil {
Sean McCullough7d5a6302025-04-24 21:27:51 -070090 return nil, fmt.Errorf("couldn't create server identity: %w", err)
91 }
Sean McCullough2cba6952025-04-25 20:32:10 +000092 if _, err := cst.createKeyPairIfMissing(cst.userIdentityPath); err != nil {
Sean McCullough7d5a6302025-04-24 21:27:51 -070093 return nil, fmt.Errorf("couldn't create user identity: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -070094 }
95
Sean McCullough2cba6952025-04-25 20:32:10 +000096 serverIdentity, err := fs.ReadFile(cst.serverIdentityPath)
Sean McCullough4854c652025-04-24 18:37:02 -070097 if err != nil {
98 return nil, fmt.Errorf("couldn't read container's ssh server identity: %w", err)
99 }
100 cst.serverIdentity = serverIdentity
101
Sean McCullough2cba6952025-04-25 20:32:10 +0000102 serverPubKeyBytes, err := fs.ReadFile(cst.serverIdentityPath + ".pub")
103 if err != nil {
104 return nil, fmt.Errorf("couldn't read ssh server public key file: %w", err)
105 }
Sean McCullough4854c652025-04-24 18:37:02 -0700106 serverPubKey, _, _, _, err := ssh.ParseAuthorizedKey(serverPubKeyBytes)
107 if err != nil {
Sean McCullough2cba6952025-04-25 20:32:10 +0000108 return nil, fmt.Errorf("couldn't parse ssh server public key: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700109 }
110 cst.serverPublicKey = serverPubKey
111
Sean McCullough2cba6952025-04-25 20:32:10 +0000112 userIdentity, err := fs.ReadFile(cst.userIdentityPath + ".pub")
Sean McCullough4854c652025-04-24 18:37:02 -0700113 if err != nil {
114 return nil, fmt.Errorf("couldn't read ssh user identity: %w", err)
115 }
116 cst.userIdentity = userIdentity
117
Sean McCullough7d5a6302025-04-24 21:27:51 -0700118 if err := cst.addContainerToSSHConfig(); err != nil {
119 return nil, fmt.Errorf("couldn't add container to ssh_config: %w", err)
120 }
121
122 if err := cst.addContainerToKnownHosts(); err != nil {
123 return nil, fmt.Errorf("couldn't update known hosts: %w", err)
124 }
125
Sean McCullough4854c652025-04-24 18:37:02 -0700126 return cst, nil
127}
128
// checkSSHResolve runs "ssh -T hostname" and returns a non-nil error if ssh reports that
// it could not resolve hostname. CheckSSHReachability treats such an error as a sign that
// the user's ~/.ssh/config does not Include sketch's ssh_config. Any other outcome
// returns nil, including connection or authentication failures, because only name
// resolution is being checked.
func checkSSHResolve(hostname string) error {
	// BatchMode=yes keeps ssh from blocking on an interactive prompt (password, key
	// passphrase, or unknown-host-key confirmation). Such a prompt would otherwise hang
	// this check, which only needs to know whether the name resolves.
	cmd := exec.Command("ssh", "-o", "BatchMode=yes", "-T", hostname)
	out, err := cmd.CombinedOutput()
	return sshResolveFailure(out, err)
}

// sshResolveFailure inspects the combined output of an ssh invocation. It returns an
// error if the output contains ssh's "Could not resolve hostname" diagnostic, and nil
// otherwise. The diagnostic may appear after other lines (e.g. warnings), so the whole
// output is searched rather than only its beginning. runErr, the error from running
// ssh, is wrapped into the returned error when non-nil.
func sshResolveFailure(out []byte, runErr error) error {
	const marker = "ssh: Could not resolve hostname"
	text := string(out)
	idx := strings.Index(text, marker)
	if idx < 0 {
		return nil
	}
	// Report just the diagnostic line, e.g.
	// "ssh: Could not resolve hostname foo: Name or service not known".
	diag := text[idx:]
	if nl := strings.IndexByte(diag, '\n'); nl >= 0 {
		diag = diag[:nl]
	}
	diag = strings.TrimSpace(diag)
	if runErr != nil {
		return fmt.Errorf("%s: %w", diag, runErr)
	}
	return fmt.Errorf("%s", diag)
}
137
Sean McCullough15c95282025-05-08 16:48:38 -0700138func CheckForIncludeWithFS(fs FileSystem, stdinReader bufio.Reader) error {
Sean McCullough74b01212025-04-29 18:40:53 -0700139 sketchSSHPathInclude := "Include " + filepath.Join(os.Getenv("HOME"), ".config", "sketch", "ssh_config")
Sean McCulloughf5e28f62025-04-25 10:48:00 -0700140 defaultSSHPath := filepath.Join(os.Getenv("HOME"), ".ssh", "config")
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000141
142 // Read the existing SSH config file
143 existingContent, err := fs.ReadFile(defaultSSHPath)
144 if err != nil {
145 // If the file doesn't exist, create a new one with just the include line
146 if os.IsNotExist(err) {
147 return fs.SafeWriteFile(defaultSSHPath, []byte(sketchSSHPathInclude+"\n"), 0o644)
148 }
149 return fmt.Errorf("⚠️ SSH connections are disabled. cannot open SSH config file: %s: %w", defaultSSHPath, err)
Sean McCullough2cba6952025-04-25 20:32:10 +0000150 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000151
152 // Parse the config file
153 cfg, err := ssh_config.Decode(bytes.NewReader(existingContent))
154 if err != nil {
155 return fmt.Errorf("couldn't decode ssh_config: %w", err)
156 }
157
Sean McCulloughf5e28f62025-04-25 10:48:00 -0700158 var sketchInludePos *ssh_config.Position
159 var firstNonIncludePos *ssh_config.Position
160 for _, host := range cfg.Hosts {
161 for _, node := range host.Nodes {
162 inc, ok := node.(*ssh_config.Include)
163 if ok {
164 if strings.TrimSpace(inc.String()) == sketchSSHPathInclude {
165 pos := inc.Pos()
166 sketchInludePos = &pos
167 }
168 } else if firstNonIncludePos == nil && !strings.HasPrefix(strings.TrimSpace(node.String()), "#") {
169 pos := node.Pos()
170 firstNonIncludePos = &pos
171 }
172 }
173 }
174
175 if sketchInludePos == nil {
Sean McCullough15c95282025-05-08 16:48:38 -0700176 fmt.Printf("\nTo enable you to use ssh to connect to local sketch containers: \nAdd %q to the top of %s [y/N]? ", sketchSSHPathInclude, defaultSSHPath)
177 char, _, err := stdinReader.ReadRune()
178 if err != nil {
179 return fmt.Errorf("couldn't read from stdin: %w", err)
180 }
181 if char != 'y' && char != 'Y' {
182 return fmt.Errorf("User declined to edit ssh config file")
183 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000184 // Include line not found, add it to the top of the file
Sean McCullough3b0795b2025-04-29 19:09:23 -0700185 cfgBytes, err := cfg.MarshalText()
186 if err != nil {
187 return fmt.Errorf("couldn't marshal ssh_config: %w", err)
188 }
Sean McCullough3b0795b2025-04-29 19:09:23 -0700189
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000190 // Add the include line to the beginning
191 cfgBytes = append([]byte(sketchSSHPathInclude+"\n"), cfgBytes...)
192
193 // Safely write the updated config back to the file
194 if err := fs.SafeWriteFile(defaultSSHPath, cfgBytes, 0o644); err != nil {
195 return fmt.Errorf("couldn't safely write ssh_config: %w", err)
196 }
Sean McCullough3b0795b2025-04-29 19:09:23 -0700197 return nil
Sean McCulloughf5e28f62025-04-25 10:48:00 -0700198 }
199
200 if firstNonIncludePos != nil && firstNonIncludePos.Line < sketchInludePos.Line {
Sean McCullough2cba6952025-04-25 20:32:10 +0000201 fmt.Printf("⚠️ SSH confg warning: the location of the Include statement for sketch's ssh config on line %d of %s may prevent ssh from working with sketch containers. try moving it to the top of the file (before any 'Host' lines) if ssh isn't working for you.\n", sketchInludePos.Line, defaultSSHPath)
Sean McCulloughf5e28f62025-04-25 10:48:00 -0700202 }
203 return nil
204}
205
Sean McCullough4854c652025-04-24 18:37:02 -0700206func removeFromHosts(cntrName string, cfgHosts []*ssh_config.Host) []*ssh_config.Host {
207 hosts := []*ssh_config.Host{}
208 for _, host := range cfgHosts {
209 if host.Matches(cntrName) || strings.Contains(host.String(), cntrName) {
210 continue
211 }
212 patMatch := false
213 for _, pat := range host.Patterns {
214 if strings.Contains(pat.String(), cntrName) {
215 patMatch = true
216 }
217 }
218 if patMatch {
219 continue
220 }
221
222 hosts = append(hosts, host)
223 }
224 return hosts
225}
226
// encodePrivateKeyToPEM serializes privateKey as PKCS #1 DER and wraps it in a single
// PEM block of type "RSA PRIVATE KEY".
func encodePrivateKeyToPEM(privateKey *rsa.PrivateKey) []byte {
	der := x509.MarshalPKCS1PrivateKey(privateKey)
	return pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: der})
}
235
Sean McCullough2cba6952025-04-25 20:32:10 +0000236func (c *SSHTheater) writeKeyToFile(keyBytes []byte, filename string) error {
237 err := c.fs.WriteFile(filename, keyBytes, 0o600)
Sean McCullough4854c652025-04-24 18:37:02 -0700238 return err
239}
240
Sean McCullough2cba6952025-04-25 20:32:10 +0000241func (c *SSHTheater) createKeyPairIfMissing(idPath string) (ssh.PublicKey, error) {
242 if _, err := c.fs.Stat(idPath); err == nil {
Sean McCullough4854c652025-04-24 18:37:02 -0700243 return nil, nil
244 }
245
Sean McCullough2cba6952025-04-25 20:32:10 +0000246 privateKey, err := c.kg.GeneratePrivateKey(keyBitSize)
Sean McCullough4854c652025-04-24 18:37:02 -0700247 if err != nil {
Sean McCullough2cba6952025-04-25 20:32:10 +0000248 return nil, fmt.Errorf("error generating private key: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700249 }
250
Sean McCullough2cba6952025-04-25 20:32:10 +0000251 publicRsaKey, err := c.kg.GeneratePublicKey(&privateKey.PublicKey)
Sean McCullough4854c652025-04-24 18:37:02 -0700252 if err != nil {
Sean McCullough2cba6952025-04-25 20:32:10 +0000253 return nil, fmt.Errorf("error generating public key: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700254 }
255
256 privateKeyPEM := encodePrivateKeyToPEM(privateKey)
257
Sean McCullough2cba6952025-04-25 20:32:10 +0000258 err = c.writeKeyToFile(privateKeyPEM, idPath)
Sean McCullough4854c652025-04-24 18:37:02 -0700259 if err != nil {
Sean McCullough2cba6952025-04-25 20:32:10 +0000260 return nil, fmt.Errorf("error writing private key to file %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700261 }
262 pubKeyBytes := ssh.MarshalAuthorizedKey(publicRsaKey)
263
Sean McCullough2cba6952025-04-25 20:32:10 +0000264 err = c.writeKeyToFile([]byte(pubKeyBytes), idPath+".pub")
Sean McCullough4854c652025-04-24 18:37:02 -0700265 if err != nil {
Sean McCullough2cba6952025-04-25 20:32:10 +0000266 return nil, fmt.Errorf("error writing public key to file %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700267 }
268 return publicRsaKey, nil
269}
270
271func (c *SSHTheater) addSketchHostMatchIfMissing(cfg *ssh_config.Config) error {
272 found := false
273 for _, host := range cfg.Hosts {
274 if strings.Contains(host.String(), "host=\"sketch-*\"") {
275 found = true
276 break
277 }
278 }
279 if !found {
280 hostPattern, err := ssh_config.NewPattern("host=\"sketch-*\"")
281 if err != nil {
282 return fmt.Errorf("couldn't add pattern to ssh_config: %w", err)
283 }
284
285 hostCfg := &ssh_config.Host{Patterns: []*ssh_config.Pattern{hostPattern}}
286 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "UserKnownHostsFile", Value: c.knownHostsPath})
287
Sean McCullough4854c652025-04-24 18:37:02 -0700288 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "IdentityFile", Value: c.userIdentityPath})
Sean McCullough4854c652025-04-24 18:37:02 -0700289 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.Empty{})
290
291 cfg.Hosts = append([]*ssh_config.Host{hostCfg}, cfg.Hosts...)
292 }
293 return nil
294}
295
296func (c *SSHTheater) addContainerToSSHConfig() error {
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000297 // Read the existing file contents or start with an empty config if file doesn't exist
298 var configData []byte
299 var cfg *ssh_config.Config
300 var err error
Sean McCullough4854c652025-04-24 18:37:02 -0700301
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000302 configData, err = c.fs.ReadFile(c.sshConfigPath)
Sean McCullough4854c652025-04-24 18:37:02 -0700303 if err != nil {
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000304 // If the file doesn't exist, create an empty config
305 if os.IsNotExist(err) {
306 cfg = &ssh_config.Config{}
307 } else {
308 return fmt.Errorf("couldn't read ssh_config: %w", err)
309 }
310 } else {
311 // Parse the existing config
312 cfg, err = ssh_config.Decode(bytes.NewReader(configData))
313 if err != nil {
314 return fmt.Errorf("couldn't decode ssh_config: %w", err)
315 }
Sean McCullough4854c652025-04-24 18:37:02 -0700316 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000317
Sean McCullough4854c652025-04-24 18:37:02 -0700318 cntrPattern, err := ssh_config.NewPattern(c.cntrName)
319 if err != nil {
320 return fmt.Errorf("couldn't add pattern to ssh_config: %w", err)
321 }
322
323 // Remove any matches for this container if they already exist.
324 cfg.Hosts = removeFromHosts(c.cntrName, cfg.Hosts)
325
326 hostCfg := &ssh_config.Host{Patterns: []*ssh_config.Pattern{cntrPattern}}
327 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "HostName", Value: c.sshHost})
328 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "User", Value: "root"})
329 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "Port", Value: c.sshPort})
330 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "IdentityFile", Value: c.userIdentityPath})
331 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "UserKnownHostsFile", Value: c.knownHostsPath})
332
333 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.Empty{})
334 cfg.Hosts = append(cfg.Hosts, hostCfg)
335
336 if err := c.addSketchHostMatchIfMissing(cfg); err != nil {
337 return fmt.Errorf("couldn't add missing host match: %w", err)
338 }
339
340 cfgBytes, err := cfg.MarshalText()
341 if err != nil {
342 return fmt.Errorf("couldn't marshal ssh_config: %w", err)
343 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000344
345 // Safely write the updated configuration to file
346 if err := c.fs.SafeWriteFile(c.sshConfigPath, cfgBytes, 0o644); err != nil {
347 return fmt.Errorf("couldn't safely write ssh_config: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700348 }
349
350 return nil
351}
352
353func (c *SSHTheater) addContainerToKnownHosts() error {
Sean McCullough7d5a6302025-04-24 21:27:51 -0700354 pkBytes := c.serverPublicKey.Marshal()
355 if len(pkBytes) == 0 {
Sean McCullough2cba6952025-04-25 20:32:10 +0000356 return fmt.Errorf("empty serverPublicKey, this is a bug")
Sean McCullough7d5a6302025-04-24 21:27:51 -0700357 }
358 newHostLine := knownhosts.Line([]string{c.sshHost + ":" + c.sshPort}, c.serverPublicKey)
Sean McCullough4854c652025-04-24 18:37:02 -0700359
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000360 // Read existing known_hosts content or start with empty if the file doesn't exist
Sean McCullough7d5a6302025-04-24 21:27:51 -0700361 outputLines := []string{}
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000362 existingContent, err := c.fs.ReadFile(c.knownHostsPath)
363 if err == nil {
364 scanner := bufio.NewScanner(bytes.NewReader(existingContent))
365 for scanner.Scan() {
366 outputLines = append(outputLines, scanner.Text())
367 }
368 } else if !os.IsNotExist(err) {
369 return fmt.Errorf("couldn't read known_hosts file: %w", err)
Sean McCullough7d5a6302025-04-24 21:27:51 -0700370 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000371
372 // Add the new host line
Sean McCullough7d5a6302025-04-24 21:27:51 -0700373 outputLines = append(outputLines, newHostLine)
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000374
375 // Safely write the updated content to the file
376 if err := c.fs.SafeWriteFile(c.knownHostsPath, []byte(strings.Join(outputLines, "\n")), 0o644); err != nil {
377 return fmt.Errorf("couldn't safely write updated known_hosts to %s: %w", c.knownHostsPath, err)
Sean McCullough4854c652025-04-24 18:37:02 -0700378 }
379
380 return nil
381}
382
383func (c *SSHTheater) removeContainerFromKnownHosts() error {
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000384 // Read the existing known_hosts file
385 existingContent, err := c.fs.ReadFile(c.knownHostsPath)
Sean McCullough4854c652025-04-24 18:37:02 -0700386 if err != nil {
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000387 // If the file doesn't exist, there's nothing to do
388 if os.IsNotExist(err) {
389 return nil
390 }
391 return fmt.Errorf("couldn't read known_hosts file: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700392 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000393
394 // Line we want to remove
Sean McCullough7d5a6302025-04-24 21:27:51 -0700395 lineToRemove := knownhosts.Line([]string{c.sshHost + ":" + c.sshPort}, c.serverPublicKey)
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000396
397 // Filter out the line we want to remove
Sean McCullough4854c652025-04-24 18:37:02 -0700398 outputLines := []string{}
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000399 scanner := bufio.NewScanner(bytes.NewReader(existingContent))
Sean McCullough4854c652025-04-24 18:37:02 -0700400 for scanner.Scan() {
401 if scanner.Text() == lineToRemove {
402 continue
403 }
404 outputLines = append(outputLines, scanner.Text())
405 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000406
407 // Safely write the updated content back to the file
408 if err := c.fs.SafeWriteFile(c.knownHostsPath, []byte(strings.Join(outputLines, "\n")), 0o644); err != nil {
409 return fmt.Errorf("couldn't safely write updated known_hosts to %s: %w", c.knownHostsPath, err)
Sean McCullough4854c652025-04-24 18:37:02 -0700410 }
411
412 return nil
413}
414
415func (c *SSHTheater) Cleanup() error {
416 if err := c.removeContainerFromSSHConfig(); err != nil {
417 return fmt.Errorf("couldn't remove container from ssh_config: %v\n", err)
418 }
419 if err := c.removeContainerFromKnownHosts(); err != nil {
420 return fmt.Errorf("couldn't remove container from ssh_config: %v\n", err)
421 }
422
423 return nil
424}
425
426func (c *SSHTheater) removeContainerFromSSHConfig() error {
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000427 // Read the existing file contents
428 configData, err := c.fs.ReadFile(c.sshConfigPath)
Sean McCullough4854c652025-04-24 18:37:02 -0700429 if err != nil {
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000430 return fmt.Errorf("couldn't read ssh_config: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700431 }
Sean McCullough4854c652025-04-24 18:37:02 -0700432
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000433 cfg, err := ssh_config.Decode(bytes.NewReader(configData))
Sean McCullough4854c652025-04-24 18:37:02 -0700434 if err != nil {
435 return fmt.Errorf("couldn't decode ssh_config: %w", err)
436 }
437 cfg.Hosts = removeFromHosts(c.cntrName, cfg.Hosts)
438
439 if err := c.addSketchHostMatchIfMissing(cfg); err != nil {
440 return fmt.Errorf("couldn't add missing host match: %w", err)
441 }
442
443 cfgBytes, err := cfg.MarshalText()
444 if err != nil {
445 return fmt.Errorf("couldn't marshal ssh_config: %w", err)
446 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000447
448 // Safely write the updated configuration to file
449 if err := c.fs.SafeWriteFile(c.sshConfigPath, cfgBytes, 0o644); err != nil {
450 return fmt.Errorf("couldn't safely write ssh_config: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700451 }
452 return nil
453}
Sean McCullough2cba6952025-04-25 20:32:10 +0000454
// FileSystem is the set of filesystem operations used by SSHTheater and
// CheckForIncludeWithFS. It exists so tests can replace the real filesystem.
type FileSystem interface {
	// Basic operations. RealFileSystem implements each with the os function of the
	// same name.
	Stat(name string) (fs.FileInfo, error)
	Mkdir(name string, perm fs.FileMode) error
	MkdirAll(name string, perm fs.FileMode) error
	ReadFile(name string) ([]byte, error)
	WriteFile(name string, data []byte, perm fs.FileMode) error
	OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error)

	// TempFile creates a new temporary file in dir; RealFileSystem uses os.CreateTemp.
	TempFile(dir, pattern string) (*os.File, error)

	// Rename moves oldpath to newpath; RealFileSystem uses os.Rename.
	Rename(oldpath, newpath string) error

	// SafeWriteFile replaces the contents of name with data. RealFileSystem writes a
	// temporary file, keeps a .bak copy of the previous contents, and renames the
	// temporary file into place.
	SafeWriteFile(name string, data []byte, perm fs.FileMode) error
}
467
// RealFileSystem is the default implementation of FileSystem, backed by the os package.
// Every method except SafeWriteFile delegates directly to the corresponding os function.
type RealFileSystem struct{}

// Stat calls os.Stat.
func (r *RealFileSystem) Stat(name string) (fs.FileInfo, error) {
	return os.Stat(name)
}

// Mkdir calls os.Mkdir.
func (r *RealFileSystem) Mkdir(name string, perm fs.FileMode) error {
	return os.Mkdir(name, perm)
}

// MkdirAll calls os.MkdirAll.
func (r *RealFileSystem) MkdirAll(name string, perm fs.FileMode) error {
	return os.MkdirAll(name, perm)
}

// ReadFile calls os.ReadFile.
func (r *RealFileSystem) ReadFile(name string) ([]byte, error) {
	return os.ReadFile(name)
}

// WriteFile calls os.WriteFile.
func (r *RealFileSystem) WriteFile(name string, data []byte, perm fs.FileMode) error {
	return os.WriteFile(name, data, perm)
}

// OpenFile calls os.OpenFile.
func (r *RealFileSystem) OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
	return os.OpenFile(name, flag, perm)
}

// TempFile calls os.CreateTemp.
func (r *RealFileSystem) TempFile(dir, pattern string) (*os.File, error) {
	return os.CreateTemp(dir, pattern)
}

// Rename calls os.Rename.
func (r *RealFileSystem) Rename(oldpath, newpath string) error {
	return os.Rename(oldpath, newpath)
}

// SafeWriteFile replaces the contents of name with data and gives it mode perm.
//
// The data is written to a temporary file in the target's directory. That file is
// chmod'ed, synced to disk, and then renamed over the target. If the target already
// exists, its previous contents are preserved as name+".bak", replacing any older
// backup. If name is a symbolic link, the file it points to is updated and the link
// itself is left in place.
func (r *RealFileSystem) SafeWriteFile(name string, data []byte, perm fs.FileMode) error {
	// Write through symlinks (e.g. a ~/.ssh/config managed from a dotfiles repository)
	// rather than replacing the link with a regular file. If resolution fails, most
	// commonly because name does not exist yet, write to name as given.
	if resolved, err := filepath.EvalSymlinks(name); err == nil {
		name = resolved
	}
	dir := filepath.Dir(name)

	tmpFile, err := r.TempFile(dir, filepath.Base(name)+".*.tmp")
	if err != nil {
		return fmt.Errorf("couldn't create temporary file: %w", err)
	}
	tmpFilename := tmpFile.Name()
	defer os.Remove(tmpFilename) // cleans up on failure; the name is gone after a successful rename

	if _, err := tmpFile.Write(data); err != nil {
		tmpFile.Close()
		return fmt.Errorf("couldn't write to temporary file: %w", err)
	}

	// Apply the final mode before the file becomes visible under its real name.
	if err := tmpFile.Chmod(perm); err != nil {
		tmpFile.Close()
		return fmt.Errorf("couldn't set permissions on file: %w", err)
	}

	if err := tmpFile.Sync(); err != nil {
		tmpFile.Close()
		return fmt.Errorf("couldn't sync temporary file: %w", err)
	}

	if err := tmpFile.Close(); err != nil {
		return fmt.Errorf("couldn't close temporary file: %w", err)
	}

	if _, err := r.Stat(name); err == nil {
		backupName := name + ".bak"
		_ = os.Remove(backupName) // a missing backup is fine
		// A hard link keeps name in place until the rename below swaps in the new
		// contents. Moving name aside would leave a window with no file at all. Fall
		// back to a rename where hard links are unsupported.
		if err := os.Link(name, backupName); err != nil {
			if err := r.Rename(name, backupName); err != nil {
				return fmt.Errorf("couldn't create backup file: %w", err)
			}
		}
	}

	if err := r.Rename(tmpFilename, name); err != nil {
		return fmt.Errorf("couldn't rename temporary file to target: %w", err)
	}
	return nil
}
558
Sean McCullough2cba6952025-04-25 20:32:10 +0000559// KeyGenerator represents an interface for generating SSH keys for testability
560type KeyGenerator interface {
561 GeneratePrivateKey(bitSize int) (*rsa.PrivateKey, error)
562 GeneratePublicKey(privateKey *rsa.PublicKey) (ssh.PublicKey, error)
563}
564
565// RealKeyGenerator is the default implementation of KeyGenerator
566type RealKeyGenerator struct{}
567
568func (kg *RealKeyGenerator) GeneratePrivateKey(bitSize int) (*rsa.PrivateKey, error) {
569 return rsa.GenerateKey(rand.Reader, bitSize)
570}
571
572func (kg *RealKeyGenerator) GeneratePublicKey(privateKey *rsa.PublicKey) (ssh.PublicKey, error) {
573 return ssh.NewPublicKey(privateKey)
574}
575
Sean McCullough078e85a2025-05-08 17:28:34 -0700576// CheckSSHReachability checks if the user's SSH config includes the Sketch SSH config file
577func CheckSSHReachability(cntrName string) error {
578 if err := checkSSHResolve(cntrName); err != nil {
579 return CheckForIncludeWithFS(&RealFileSystem{}, *bufio.NewReader(os.Stdin))
580 }
581 return nil
Sean McCullough2cba6952025-04-25 20:32:10 +0000582}