blob: 91d3c21c718691306a380e4df561e81f79aa8f62 [file] [log] [blame]
Sean McCullough4854c652025-04-24 18:37:02 -07001package dockerimg
2
3import (
4 "bufio"
Sean McCullough0d95d3a2025-04-30 16:22:28 +00005 "bytes"
Sean McCullough3e9d80c2025-05-13 23:35:23 +00006 "crypto/ed25519"
Sean McCullough4854c652025-04-24 18:37:02 -07007 "crypto/rand"
Sean McCullough4854c652025-04-24 18:37:02 -07008 "encoding/pem"
9 "fmt"
Sean McCullough2cba6952025-04-25 20:32:10 +000010 "io/fs"
Sean McCullough4854c652025-04-24 18:37:02 -070011 "os"
Sean McCullough078e85a2025-05-08 17:28:34 -070012 "os/exec"
Sean McCullough4854c652025-04-24 18:37:02 -070013 "path/filepath"
14 "strings"
15
16 "github.com/kevinburke/ssh_config"
17 "golang.org/x/crypto/ssh"
Sean McCullough7d5a6302025-04-24 21:27:51 -070018 "golang.org/x/crypto/ssh/knownhosts"
Sean McCullough4854c652025-04-24 18:37:02 -070019)
20
Sean McCullough3e9d80c2025-05-13 23:35:23 +000021// Ed25519 has a fixed key size, no bit size constant needed
Sean McCullough4854c652025-04-24 18:37:02 -070022
23// SSHTheater does the necessary key pair generation, known_hosts updates, ssh_config file updates etc steps
24// so that ssh can connect to a locally running sketch container to other local processes like vscode without
25// the user having to run the usual ssh obstacle course.
26//
27// SSHTheater does not modify your default .ssh/config, or known_hosts files. However, in order for you
28// to be able to use it properly you will have to make a one-time edit to your ~/.ssh/config file.
29//
30// In your ~/.ssh/config file, add the following line:
31//
Sean McCullough74b01212025-04-29 18:40:53 -070032// Include $HOME/.config/sketch/ssh_config
Sean McCullough4854c652025-04-24 18:37:02 -070033//
34// where $HOME is your home directory.
Sean McCullough3e9d80c2025-05-13 23:35:23 +000035//
36// SSHTheater uses Ed25519 keys for improved security and performance.
Sean McCullough4854c652025-04-24 18:37:02 -070037type SSHTheater struct {
38 cntrName string
39 sshHost string
40 sshPort string
41
42 knownHostsPath string
43 userIdentityPath string
44 sshConfigPath string
45 serverIdentityPath string
46
47 serverPublicKey ssh.PublicKey
48 serverIdentity []byte
49 userIdentity []byte
Sean McCullough2cba6952025-04-25 20:32:10 +000050
51 fs FileSystem
52 kg KeyGenerator
Sean McCullough4854c652025-04-24 18:37:02 -070053}
54
Josh Bleecher Snyder50608b12025-05-03 22:55:49 +000055// NewSSHTheater will set up everything so that you can use ssh on localhost to connect to
Sean McCullough4854c652025-04-24 18:37:02 -070056// the sketch container. Call #Clean when you are done with the container to remove the
57// various entries it created in its known_hosts and ssh_config files. Also note that
58// this will generate key pairs for both the ssh server identity and the user identity, if
59// these files do not already exist. These key pair files are not deleted by #Cleanup,
60// so they can be re-used across invocations of sketch. This means every sketch container
61// that runs on this host will use the same ssh server identity.
Sean McCullough3e9d80c2025-05-13 23:35:23 +000062// The system uses Ed25519 keys for better security and performance.
Sean McCullough4854c652025-04-24 18:37:02 -070063//
64// If this doesn't return an error, you should be able to run "ssh <cntrName>"
65// in a terminal on your host machine to open a shell into the container without having
66// to manually accept changes to your known_hosts file etc.
Josh Bleecher Snyder50608b12025-05-03 22:55:49 +000067func NewSSHTheater(cntrName, sshHost, sshPort string) (*SSHTheater, error) {
Sean McCullough2cba6952025-04-25 20:32:10 +000068 return newSSHTheatherWithDeps(cntrName, sshHost, sshPort, &RealFileSystem{}, &RealKeyGenerator{})
69}
70
71// newSSHTheatherWithDeps creates a new SSHTheater with the specified dependencies
72func newSSHTheatherWithDeps(cntrName, sshHost, sshPort string, fs FileSystem, kg KeyGenerator) (*SSHTheater, error) {
Sean McCullough74b01212025-04-29 18:40:53 -070073 base := filepath.Join(os.Getenv("HOME"), ".config", "sketch")
Sean McCullough2cba6952025-04-25 20:32:10 +000074 if _, err := fs.Stat(base); err != nil {
Sean McCulloughc796e7f2025-04-30 08:44:06 -070075 if err := fs.MkdirAll(base, 0o777); err != nil {
Sean McCullough7d5a6302025-04-24 21:27:51 -070076 return nil, fmt.Errorf("couldn't create %s: %w", base, err)
77 }
78 }
79
Sean McCullough4854c652025-04-24 18:37:02 -070080 cst := &SSHTheater{
81 cntrName: cntrName,
82 sshHost: sshHost,
83 sshPort: sshPort,
84 knownHostsPath: filepath.Join(base, "known_hosts"),
85 userIdentityPath: filepath.Join(base, "container_user_identity"),
86 serverIdentityPath: filepath.Join(base, "container_server_identity"),
87 sshConfigPath: filepath.Join(base, "ssh_config"),
Sean McCullough2cba6952025-04-25 20:32:10 +000088 fs: fs,
89 kg: kg,
Sean McCullough4854c652025-04-24 18:37:02 -070090 }
Sean McCullough2cba6952025-04-25 20:32:10 +000091 if _, err := cst.createKeyPairIfMissing(cst.serverIdentityPath); err != nil {
Sean McCullough7d5a6302025-04-24 21:27:51 -070092 return nil, fmt.Errorf("couldn't create server identity: %w", err)
93 }
Sean McCullough2cba6952025-04-25 20:32:10 +000094 if _, err := cst.createKeyPairIfMissing(cst.userIdentityPath); err != nil {
Sean McCullough7d5a6302025-04-24 21:27:51 -070095 return nil, fmt.Errorf("couldn't create user identity: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -070096 }
97
Sean McCullough2cba6952025-04-25 20:32:10 +000098 serverIdentity, err := fs.ReadFile(cst.serverIdentityPath)
Sean McCullough4854c652025-04-24 18:37:02 -070099 if err != nil {
100 return nil, fmt.Errorf("couldn't read container's ssh server identity: %w", err)
101 }
102 cst.serverIdentity = serverIdentity
103
Sean McCullough2cba6952025-04-25 20:32:10 +0000104 serverPubKeyBytes, err := fs.ReadFile(cst.serverIdentityPath + ".pub")
105 if err != nil {
106 return nil, fmt.Errorf("couldn't read ssh server public key file: %w", err)
107 }
Sean McCullough4854c652025-04-24 18:37:02 -0700108 serverPubKey, _, _, _, err := ssh.ParseAuthorizedKey(serverPubKeyBytes)
109 if err != nil {
Sean McCullough2cba6952025-04-25 20:32:10 +0000110 return nil, fmt.Errorf("couldn't parse ssh server public key: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700111 }
112 cst.serverPublicKey = serverPubKey
113
Sean McCullough2cba6952025-04-25 20:32:10 +0000114 userIdentity, err := fs.ReadFile(cst.userIdentityPath + ".pub")
Sean McCullough4854c652025-04-24 18:37:02 -0700115 if err != nil {
116 return nil, fmt.Errorf("couldn't read ssh user identity: %w", err)
117 }
118 cst.userIdentity = userIdentity
119
Sean McCullough7d5a6302025-04-24 21:27:51 -0700120 if err := cst.addContainerToSSHConfig(); err != nil {
121 return nil, fmt.Errorf("couldn't add container to ssh_config: %w", err)
122 }
123
124 if err := cst.addContainerToKnownHosts(); err != nil {
125 return nil, fmt.Errorf("couldn't update known hosts: %w", err)
126 }
127
Sean McCullough4854c652025-04-24 18:37:02 -0700128 return cst, nil
129}
130
// checkSSHResolve runs "ssh -T hostname" and reports whether ssh could resolve
// hostname. It returns a non-nil error only when ssh's output contains its
// "Could not resolve hostname" diagnostic. Any other outcome, including
// connection or authentication failures, returns nil, because those show the
// name itself resolved.
func checkSSHResolve(hostname string) error {
	// BatchMode disables interactive prompts (passwords, host key confirmation)
	// so this probe can never stop and wait for input on the terminal.
	cmd := exec.Command("ssh", "-o", "BatchMode=yes", "-T", hostname)
	out, err := cmd.CombinedOutput()
	if !sshCouldNotResolve(out) {
		return nil
	}
	if err == nil {
		// Never report a resolution failure as success, even if ssh exited 0.
		err = fmt.Errorf("ssh: could not resolve hostname %q", hostname)
	}
	return err
}

// sshCouldNotResolve reports whether ssh's output contains its "Could not
// resolve hostname" diagnostic. It searches the whole output rather than only
// the start, because ssh may print warnings before the diagnostic.
func sshCouldNotResolve(out []byte) bool {
	return strings.Contains(string(out), "ssh: Could not resolve hostname")
}
139
Sean McCullough15c95282025-05-08 16:48:38 -0700140func CheckForIncludeWithFS(fs FileSystem, stdinReader bufio.Reader) error {
Sean McCullough74b01212025-04-29 18:40:53 -0700141 sketchSSHPathInclude := "Include " + filepath.Join(os.Getenv("HOME"), ".config", "sketch", "ssh_config")
Sean McCulloughf5e28f62025-04-25 10:48:00 -0700142 defaultSSHPath := filepath.Join(os.Getenv("HOME"), ".ssh", "config")
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000143
144 // Read the existing SSH config file
145 existingContent, err := fs.ReadFile(defaultSSHPath)
146 if err != nil {
147 // If the file doesn't exist, create a new one with just the include line
148 if os.IsNotExist(err) {
149 return fs.SafeWriteFile(defaultSSHPath, []byte(sketchSSHPathInclude+"\n"), 0o644)
150 }
151 return fmt.Errorf("⚠️ SSH connections are disabled. cannot open SSH config file: %s: %w", defaultSSHPath, err)
Sean McCullough2cba6952025-04-25 20:32:10 +0000152 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000153
154 // Parse the config file
155 cfg, err := ssh_config.Decode(bytes.NewReader(existingContent))
156 if err != nil {
157 return fmt.Errorf("couldn't decode ssh_config: %w", err)
158 }
159
Sean McCulloughf5e28f62025-04-25 10:48:00 -0700160 var sketchInludePos *ssh_config.Position
161 var firstNonIncludePos *ssh_config.Position
162 for _, host := range cfg.Hosts {
163 for _, node := range host.Nodes {
164 inc, ok := node.(*ssh_config.Include)
165 if ok {
166 if strings.TrimSpace(inc.String()) == sketchSSHPathInclude {
167 pos := inc.Pos()
168 sketchInludePos = &pos
169 }
170 } else if firstNonIncludePos == nil && !strings.HasPrefix(strings.TrimSpace(node.String()), "#") {
171 pos := node.Pos()
172 firstNonIncludePos = &pos
173 }
174 }
175 }
176
177 if sketchInludePos == nil {
Sean McCullough15c95282025-05-08 16:48:38 -0700178 fmt.Printf("\nTo enable you to use ssh to connect to local sketch containers: \nAdd %q to the top of %s [y/N]? ", sketchSSHPathInclude, defaultSSHPath)
179 char, _, err := stdinReader.ReadRune()
180 if err != nil {
181 return fmt.Errorf("couldn't read from stdin: %w", err)
182 }
183 if char != 'y' && char != 'Y' {
184 return fmt.Errorf("User declined to edit ssh config file")
185 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000186 // Include line not found, add it to the top of the file
Sean McCullough3b0795b2025-04-29 19:09:23 -0700187 cfgBytes, err := cfg.MarshalText()
188 if err != nil {
189 return fmt.Errorf("couldn't marshal ssh_config: %w", err)
190 }
Sean McCullough3b0795b2025-04-29 19:09:23 -0700191
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000192 // Add the include line to the beginning
193 cfgBytes = append([]byte(sketchSSHPathInclude+"\n"), cfgBytes...)
194
195 // Safely write the updated config back to the file
196 if err := fs.SafeWriteFile(defaultSSHPath, cfgBytes, 0o644); err != nil {
197 return fmt.Errorf("couldn't safely write ssh_config: %w", err)
198 }
Sean McCullough3b0795b2025-04-29 19:09:23 -0700199 return nil
Sean McCulloughf5e28f62025-04-25 10:48:00 -0700200 }
201
202 if firstNonIncludePos != nil && firstNonIncludePos.Line < sketchInludePos.Line {
Sean McCullough2cba6952025-04-25 20:32:10 +0000203 fmt.Printf("⚠️ SSH confg warning: the location of the Include statement for sketch's ssh config on line %d of %s may prevent ssh from working with sketch containers. try moving it to the top of the file (before any 'Host' lines) if ssh isn't working for you.\n", sketchInludePos.Line, defaultSSHPath)
Sean McCulloughf5e28f62025-04-25 10:48:00 -0700204 }
205 return nil
206}
207
Sean McCullough4854c652025-04-24 18:37:02 -0700208func removeFromHosts(cntrName string, cfgHosts []*ssh_config.Host) []*ssh_config.Host {
209 hosts := []*ssh_config.Host{}
210 for _, host := range cfgHosts {
211 if host.Matches(cntrName) || strings.Contains(host.String(), cntrName) {
212 continue
213 }
214 patMatch := false
215 for _, pat := range host.Patterns {
216 if strings.Contains(pat.String(), cntrName) {
217 patMatch = true
218 }
219 }
220 if patMatch {
221 continue
222 }
223
224 hosts = append(hosts, host)
225 }
226 return hosts
227}
228
Sean McCullough3e9d80c2025-05-13 23:35:23 +0000229func encodePrivateKeyToPEM(privateKey ed25519.PrivateKey) []byte {
Sean McCullough4854c652025-04-24 18:37:02 -0700230 pemBlock := &pem.Block{
Sean McCullough3e9d80c2025-05-13 23:35:23 +0000231 Type: "OPENSSH PRIVATE KEY",
232 Bytes: MarshalED25519PrivateKey(privateKey),
Sean McCullough4854c652025-04-24 18:37:02 -0700233 }
234 pemBytes := pem.EncodeToMemory(pemBlock)
235 return pemBytes
236}
237
Sean McCullough3e9d80c2025-05-13 23:35:23 +0000238// MarshalED25519PrivateKey encodes an Ed25519 private key in the OpenSSH private key format
239func MarshalED25519PrivateKey(key ed25519.PrivateKey) []byte {
240 // Marshal the private key using the SSH library
241 pkBytes, err := ssh.MarshalPrivateKey(key, "")
242 if err != nil {
243 panic(fmt.Sprintf("failed to marshal private key: %v", err))
244 }
245 return pem.EncodeToMemory(pkBytes)
246}
247
Sean McCullough2cba6952025-04-25 20:32:10 +0000248func (c *SSHTheater) writeKeyToFile(keyBytes []byte, filename string) error {
249 err := c.fs.WriteFile(filename, keyBytes, 0o600)
Sean McCullough4854c652025-04-24 18:37:02 -0700250 return err
251}
252
Sean McCullough2cba6952025-04-25 20:32:10 +0000253func (c *SSHTheater) createKeyPairIfMissing(idPath string) (ssh.PublicKey, error) {
254 if _, err := c.fs.Stat(idPath); err == nil {
Sean McCullough4854c652025-04-24 18:37:02 -0700255 return nil, nil
256 }
257
Sean McCullough3e9d80c2025-05-13 23:35:23 +0000258 privateKey, publicKey, err := c.kg.GenerateKeyPair()
Sean McCullough4854c652025-04-24 18:37:02 -0700259 if err != nil {
Sean McCullough3e9d80c2025-05-13 23:35:23 +0000260 return nil, fmt.Errorf("error generating key pair: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700261 }
262
Sean McCullough3e9d80c2025-05-13 23:35:23 +0000263 sshPublicKey, err := c.kg.ConvertToSSHPublicKey(publicKey)
Sean McCullough4854c652025-04-24 18:37:02 -0700264 if err != nil {
Sean McCullough3e9d80c2025-05-13 23:35:23 +0000265 return nil, fmt.Errorf("error converting to SSH public key: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700266 }
267
268 privateKeyPEM := encodePrivateKeyToPEM(privateKey)
269
Sean McCullough2cba6952025-04-25 20:32:10 +0000270 err = c.writeKeyToFile(privateKeyPEM, idPath)
Sean McCullough4854c652025-04-24 18:37:02 -0700271 if err != nil {
Sean McCullough2cba6952025-04-25 20:32:10 +0000272 return nil, fmt.Errorf("error writing private key to file %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700273 }
Sean McCullough3e9d80c2025-05-13 23:35:23 +0000274 pubKeyBytes := ssh.MarshalAuthorizedKey(sshPublicKey)
Sean McCullough4854c652025-04-24 18:37:02 -0700275
Sean McCullough2cba6952025-04-25 20:32:10 +0000276 err = c.writeKeyToFile([]byte(pubKeyBytes), idPath+".pub")
Sean McCullough4854c652025-04-24 18:37:02 -0700277 if err != nil {
Sean McCullough2cba6952025-04-25 20:32:10 +0000278 return nil, fmt.Errorf("error writing public key to file %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700279 }
Sean McCullough3e9d80c2025-05-13 23:35:23 +0000280 return sshPublicKey, nil
Sean McCullough4854c652025-04-24 18:37:02 -0700281}
282
283func (c *SSHTheater) addSketchHostMatchIfMissing(cfg *ssh_config.Config) error {
284 found := false
285 for _, host := range cfg.Hosts {
286 if strings.Contains(host.String(), "host=\"sketch-*\"") {
287 found = true
288 break
289 }
290 }
291 if !found {
292 hostPattern, err := ssh_config.NewPattern("host=\"sketch-*\"")
293 if err != nil {
294 return fmt.Errorf("couldn't add pattern to ssh_config: %w", err)
295 }
296
297 hostCfg := &ssh_config.Host{Patterns: []*ssh_config.Pattern{hostPattern}}
298 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "UserKnownHostsFile", Value: c.knownHostsPath})
299
Sean McCullough4854c652025-04-24 18:37:02 -0700300 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "IdentityFile", Value: c.userIdentityPath})
Sean McCullough4854c652025-04-24 18:37:02 -0700301 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.Empty{})
302
303 cfg.Hosts = append([]*ssh_config.Host{hostCfg}, cfg.Hosts...)
304 }
305 return nil
306}
307
308func (c *SSHTheater) addContainerToSSHConfig() error {
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000309 // Read the existing file contents or start with an empty config if file doesn't exist
310 var configData []byte
311 var cfg *ssh_config.Config
312 var err error
Sean McCullough4854c652025-04-24 18:37:02 -0700313
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000314 configData, err = c.fs.ReadFile(c.sshConfigPath)
Sean McCullough4854c652025-04-24 18:37:02 -0700315 if err != nil {
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000316 // If the file doesn't exist, create an empty config
317 if os.IsNotExist(err) {
318 cfg = &ssh_config.Config{}
319 } else {
320 return fmt.Errorf("couldn't read ssh_config: %w", err)
321 }
322 } else {
323 // Parse the existing config
324 cfg, err = ssh_config.Decode(bytes.NewReader(configData))
325 if err != nil {
326 return fmt.Errorf("couldn't decode ssh_config: %w", err)
327 }
Sean McCullough4854c652025-04-24 18:37:02 -0700328 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000329
Sean McCullough4854c652025-04-24 18:37:02 -0700330 cntrPattern, err := ssh_config.NewPattern(c.cntrName)
331 if err != nil {
332 return fmt.Errorf("couldn't add pattern to ssh_config: %w", err)
333 }
334
335 // Remove any matches for this container if they already exist.
336 cfg.Hosts = removeFromHosts(c.cntrName, cfg.Hosts)
337
338 hostCfg := &ssh_config.Host{Patterns: []*ssh_config.Pattern{cntrPattern}}
339 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "HostName", Value: c.sshHost})
340 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "User", Value: "root"})
341 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "Port", Value: c.sshPort})
342 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "IdentityFile", Value: c.userIdentityPath})
343 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "UserKnownHostsFile", Value: c.knownHostsPath})
344
345 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.Empty{})
346 cfg.Hosts = append(cfg.Hosts, hostCfg)
347
348 if err := c.addSketchHostMatchIfMissing(cfg); err != nil {
349 return fmt.Errorf("couldn't add missing host match: %w", err)
350 }
351
352 cfgBytes, err := cfg.MarshalText()
353 if err != nil {
354 return fmt.Errorf("couldn't marshal ssh_config: %w", err)
355 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000356
357 // Safely write the updated configuration to file
358 if err := c.fs.SafeWriteFile(c.sshConfigPath, cfgBytes, 0o644); err != nil {
359 return fmt.Errorf("couldn't safely write ssh_config: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700360 }
361
362 return nil
363}
364
365func (c *SSHTheater) addContainerToKnownHosts() error {
Sean McCullough7d5a6302025-04-24 21:27:51 -0700366 pkBytes := c.serverPublicKey.Marshal()
367 if len(pkBytes) == 0 {
Sean McCullough2cba6952025-04-25 20:32:10 +0000368 return fmt.Errorf("empty serverPublicKey, this is a bug")
Sean McCullough7d5a6302025-04-24 21:27:51 -0700369 }
370 newHostLine := knownhosts.Line([]string{c.sshHost + ":" + c.sshPort}, c.serverPublicKey)
Sean McCullough4854c652025-04-24 18:37:02 -0700371
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000372 // Read existing known_hosts content or start with empty if the file doesn't exist
Sean McCullough7d5a6302025-04-24 21:27:51 -0700373 outputLines := []string{}
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000374 existingContent, err := c.fs.ReadFile(c.knownHostsPath)
375 if err == nil {
376 scanner := bufio.NewScanner(bytes.NewReader(existingContent))
377 for scanner.Scan() {
378 outputLines = append(outputLines, scanner.Text())
379 }
380 } else if !os.IsNotExist(err) {
381 return fmt.Errorf("couldn't read known_hosts file: %w", err)
Sean McCullough7d5a6302025-04-24 21:27:51 -0700382 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000383
384 // Add the new host line
Sean McCullough7d5a6302025-04-24 21:27:51 -0700385 outputLines = append(outputLines, newHostLine)
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000386
387 // Safely write the updated content to the file
388 if err := c.fs.SafeWriteFile(c.knownHostsPath, []byte(strings.Join(outputLines, "\n")), 0o644); err != nil {
389 return fmt.Errorf("couldn't safely write updated known_hosts to %s: %w", c.knownHostsPath, err)
Sean McCullough4854c652025-04-24 18:37:02 -0700390 }
391
392 return nil
393}
394
395func (c *SSHTheater) removeContainerFromKnownHosts() error {
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000396 // Read the existing known_hosts file
397 existingContent, err := c.fs.ReadFile(c.knownHostsPath)
Sean McCullough4854c652025-04-24 18:37:02 -0700398 if err != nil {
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000399 // If the file doesn't exist, there's nothing to do
400 if os.IsNotExist(err) {
401 return nil
402 }
403 return fmt.Errorf("couldn't read known_hosts file: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700404 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000405
406 // Line we want to remove
Sean McCullough7d5a6302025-04-24 21:27:51 -0700407 lineToRemove := knownhosts.Line([]string{c.sshHost + ":" + c.sshPort}, c.serverPublicKey)
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000408
409 // Filter out the line we want to remove
Sean McCullough4854c652025-04-24 18:37:02 -0700410 outputLines := []string{}
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000411 scanner := bufio.NewScanner(bytes.NewReader(existingContent))
Sean McCullough4854c652025-04-24 18:37:02 -0700412 for scanner.Scan() {
413 if scanner.Text() == lineToRemove {
414 continue
415 }
416 outputLines = append(outputLines, scanner.Text())
417 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000418
419 // Safely write the updated content back to the file
420 if err := c.fs.SafeWriteFile(c.knownHostsPath, []byte(strings.Join(outputLines, "\n")), 0o644); err != nil {
421 return fmt.Errorf("couldn't safely write updated known_hosts to %s: %w", c.knownHostsPath, err)
Sean McCullough4854c652025-04-24 18:37:02 -0700422 }
423
424 return nil
425}
426
427func (c *SSHTheater) Cleanup() error {
428 if err := c.removeContainerFromSSHConfig(); err != nil {
429 return fmt.Errorf("couldn't remove container from ssh_config: %v\n", err)
430 }
431 if err := c.removeContainerFromKnownHosts(); err != nil {
432 return fmt.Errorf("couldn't remove container from ssh_config: %v\n", err)
433 }
434
435 return nil
436}
437
438func (c *SSHTheater) removeContainerFromSSHConfig() error {
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000439 // Read the existing file contents
440 configData, err := c.fs.ReadFile(c.sshConfigPath)
Sean McCullough4854c652025-04-24 18:37:02 -0700441 if err != nil {
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000442 return fmt.Errorf("couldn't read ssh_config: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700443 }
Sean McCullough4854c652025-04-24 18:37:02 -0700444
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000445 cfg, err := ssh_config.Decode(bytes.NewReader(configData))
Sean McCullough4854c652025-04-24 18:37:02 -0700446 if err != nil {
447 return fmt.Errorf("couldn't decode ssh_config: %w", err)
448 }
449 cfg.Hosts = removeFromHosts(c.cntrName, cfg.Hosts)
450
451 if err := c.addSketchHostMatchIfMissing(cfg); err != nil {
452 return fmt.Errorf("couldn't add missing host match: %w", err)
453 }
454
455 cfgBytes, err := cfg.MarshalText()
456 if err != nil {
457 return fmt.Errorf("couldn't marshal ssh_config: %w", err)
458 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000459
460 // Safely write the updated configuration to file
461 if err := c.fs.SafeWriteFile(c.sshConfigPath, cfgBytes, 0o644); err != nil {
462 return fmt.Errorf("couldn't safely write ssh_config: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700463 }
464 return nil
465}
Sean McCullough2cba6952025-04-25 20:32:10 +0000466
// FileSystem is the set of filesystem operations used by SSHTheater and
// CheckForIncludeWithFS. It exists for testability: RealFileSystem is the
// OS-backed implementation.
type FileSystem interface {
	// Stat returns metadata for the named file.
	Stat(name string) (fs.FileInfo, error)
	// Mkdir creates a single directory.
	Mkdir(name string, perm fs.FileMode) error
	// MkdirAll creates a directory along with any missing parents.
	MkdirAll(name string, perm fs.FileMode) error
	// ReadFile returns the full contents of the named file.
	ReadFile(name string) ([]byte, error)
	// WriteFile writes data to the named file.
	WriteFile(name string, data []byte, perm fs.FileMode) error
	// OpenFile opens the named file with the given flags and mode.
	OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error)
	// TempFile creates a new temporary file in dir.
	TempFile(dir, pattern string) (*os.File, error)
	// Rename moves oldpath to newpath.
	Rename(oldpath, newpath string) error
	// SafeWriteFile writes data to name; RealFileSystem does so via a
	// temporary file plus rename, keeping a ".bak" copy of the old contents.
	SafeWriteFile(name string, data []byte, perm fs.FileMode) error
}
479
// RealFileSystem is the default implementation of FileSystem that uses the OS.
type RealFileSystem struct{}

// Stat calls os.Stat.
func (*RealFileSystem) Stat(name string) (fs.FileInfo, error) {
	return os.Stat(name)
}

// Mkdir calls os.Mkdir.
func (*RealFileSystem) Mkdir(name string, perm fs.FileMode) error {
	return os.Mkdir(name, perm)
}

// MkdirAll calls os.MkdirAll.
func (*RealFileSystem) MkdirAll(name string, perm fs.FileMode) error {
	return os.MkdirAll(name, perm)
}

// ReadFile calls os.ReadFile.
func (*RealFileSystem) ReadFile(name string) ([]byte, error) {
	return os.ReadFile(name)
}

// WriteFile calls os.WriteFile.
func (*RealFileSystem) WriteFile(name string, data []byte, perm fs.FileMode) error {
	return os.WriteFile(name, data, perm)
}

// OpenFile calls os.OpenFile.
func (*RealFileSystem) OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
	return os.OpenFile(name, flag, perm)
}

// TempFile calls os.CreateTemp.
func (*RealFileSystem) TempFile(dir, pattern string) (*os.File, error) {
	return os.CreateTemp(dir, pattern)
}

// Rename calls os.Rename.
func (*RealFileSystem) Rename(oldpath, newpath string) error {
	return os.Rename(oldpath, newpath)
}

// SafeWriteFile replaces the contents of name with data. The steps are:
//
//  1. Write data to a temporary file in the same directory, set its mode to
//     perm, and sync it to disk.
//  2. If name already exists, copy its current contents to name+".bak",
//     replacing any previous backup.
//  3. Rename the temporary file over name.
//
// Because the original is backed up by copying rather than renaming, name
// keeps existing until step 3 replaces it. On POSIX filesystems that rename is
// atomic, so readers see either the old or the new contents. The temporary
// file is removed if any step fails.
func (r *RealFileSystem) SafeWriteFile(name string, data []byte, perm fs.FileMode) error {
	tmpFile, err := r.TempFile(filepath.Dir(name), filepath.Base(name)+".*.tmp")
	if err != nil {
		return fmt.Errorf("couldn't create temporary file: %w", err)
	}
	tmpFilename := tmpFile.Name()
	// Clean up on failure; after a successful rename this is a harmless no-op.
	defer os.Remove(tmpFilename)

	if _, err := tmpFile.Write(data); err != nil {
		tmpFile.Close()
		return fmt.Errorf("couldn't write to temporary file: %w", err)
	}
	// Apply the final mode before the rename so name never appears with the
	// temporary file's default permissions.
	if err := tmpFile.Chmod(perm); err != nil {
		tmpFile.Close()
		return fmt.Errorf("couldn't set permissions on file: %w", err)
	}
	if err := tmpFile.Sync(); err != nil {
		tmpFile.Close()
		return fmt.Errorf("couldn't sync temporary file: %w", err)
	}
	if err := tmpFile.Close(); err != nil {
		return fmt.Errorf("couldn't close temporary file: %w", err)
	}

	if info, err := r.Stat(name); err == nil {
		old, err := r.ReadFile(name)
		if err != nil {
			return fmt.Errorf("couldn't create backup file: %w", err)
		}
		backupName := name + ".bak"
		// Remove any previous backup so the new one is created with the
		// original file's mode; ignore the error if there is none.
		_ = os.Remove(backupName)
		if err := r.WriteFile(backupName, old, info.Mode().Perm()); err != nil {
			return fmt.Errorf("couldn't create backup file: %w", err)
		}
	}

	if err := r.Rename(tmpFilename, name); err != nil {
		return fmt.Errorf("couldn't rename temporary file to target: %w", err)
	}
	return nil
}
570
Sean McCullough2cba6952025-04-25 20:32:10 +0000571// KeyGenerator represents an interface for generating SSH keys for testability
572type KeyGenerator interface {
Sean McCullough3e9d80c2025-05-13 23:35:23 +0000573 GenerateKeyPair() (ed25519.PrivateKey, ed25519.PublicKey, error)
574 ConvertToSSHPublicKey(publicKey ed25519.PublicKey) (ssh.PublicKey, error)
Sean McCullough2cba6952025-04-25 20:32:10 +0000575}
576
577// RealKeyGenerator is the default implementation of KeyGenerator
578type RealKeyGenerator struct{}
579
Sean McCullough3e9d80c2025-05-13 23:35:23 +0000580func (kg *RealKeyGenerator) GenerateKeyPair() (ed25519.PrivateKey, ed25519.PublicKey, error) {
581 publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader)
582 return privateKey, publicKey, err
Sean McCullough2cba6952025-04-25 20:32:10 +0000583}
584
Sean McCullough3e9d80c2025-05-13 23:35:23 +0000585func (kg *RealKeyGenerator) ConvertToSSHPublicKey(publicKey ed25519.PublicKey) (ssh.PublicKey, error) {
586 return ssh.NewPublicKey(publicKey)
Sean McCullough2cba6952025-04-25 20:32:10 +0000587}
588
Sean McCullough078e85a2025-05-08 17:28:34 -0700589// CheckSSHReachability checks if the user's SSH config includes the Sketch SSH config file
590func CheckSSHReachability(cntrName string) error {
591 if err := checkSSHResolve(cntrName); err != nil {
592 return CheckForIncludeWithFS(&RealFileSystem{}, *bufio.NewReader(os.Stdin))
593 }
594 return nil
Sean McCullough2cba6952025-04-25 20:32:10 +0000595}