blob: 6afcfc9fffd48a565e75ee26368e46a7e53ed10c [file] [log] [blame]
Sean McCullough4854c652025-04-24 18:37:02 -07001package dockerimg
2
3import (
4 "bufio"
Sean McCullough0d95d3a2025-04-30 16:22:28 +00005 "bytes"
Sean McCullough4854c652025-04-24 18:37:02 -07006 "crypto/rand"
7 "crypto/rsa"
8 "crypto/x509"
Sean McCullough4854c652025-04-24 18:37:02 -07009 "encoding/pem"
10 "fmt"
Sean McCullough2cba6952025-04-25 20:32:10 +000011 "io/fs"
Sean McCullough4854c652025-04-24 18:37:02 -070012 "os"
13 "path/filepath"
14 "strings"
15
16 "github.com/kevinburke/ssh_config"
17 "golang.org/x/crypto/ssh"
Sean McCullough7d5a6302025-04-24 21:27:51 -070018 "golang.org/x/crypto/ssh/knownhosts"
Sean McCullough4854c652025-04-24 18:37:02 -070019)
20
// keyBitSize is the RSA modulus size, in bits, used when generating both the
// container's ssh server identity and the user identity key pairs.
const (
	keyBitSize = 2048
)
22
23// SSHTheater does the necessary key pair generation, known_hosts updates, ssh_config file updates etc steps
24// so that ssh can connect to a locally running sketch container to other local processes like vscode without
25// the user having to run the usual ssh obstacle course.
26//
27// SSHTheater does not modify your default .ssh/config, or known_hosts files. However, in order for you
28// to be able to use it properly you will have to make a one-time edit to your ~/.ssh/config file.
29//
30// In your ~/.ssh/config file, add the following line:
31//
Sean McCullough74b01212025-04-29 18:40:53 -070032// Include $HOME/.config/sketch/ssh_config
Sean McCullough4854c652025-04-24 18:37:02 -070033//
34// where $HOME is your home directory.
35type SSHTheater struct {
36 cntrName string
37 sshHost string
38 sshPort string
39
40 knownHostsPath string
41 userIdentityPath string
42 sshConfigPath string
43 serverIdentityPath string
44
45 serverPublicKey ssh.PublicKey
46 serverIdentity []byte
47 userIdentity []byte
Sean McCullough2cba6952025-04-25 20:32:10 +000048
49 fs FileSystem
50 kg KeyGenerator
Sean McCullough4854c652025-04-24 18:37:02 -070051}
52
Josh Bleecher Snyder50608b12025-05-03 22:55:49 +000053// NewSSHTheater will set up everything so that you can use ssh on localhost to connect to
Sean McCullough4854c652025-04-24 18:37:02 -070054// the sketch container. Call #Clean when you are done with the container to remove the
55// various entries it created in its known_hosts and ssh_config files. Also note that
56// this will generate key pairs for both the ssh server identity and the user identity, if
57// these files do not already exist. These key pair files are not deleted by #Cleanup,
58// so they can be re-used across invocations of sketch. This means every sketch container
59// that runs on this host will use the same ssh server identity.
60//
61// If this doesn't return an error, you should be able to run "ssh <cntrName>"
62// in a terminal on your host machine to open a shell into the container without having
63// to manually accept changes to your known_hosts file etc.
Josh Bleecher Snyder50608b12025-05-03 22:55:49 +000064func NewSSHTheater(cntrName, sshHost, sshPort string) (*SSHTheater, error) {
Sean McCullough2cba6952025-04-25 20:32:10 +000065 return newSSHTheatherWithDeps(cntrName, sshHost, sshPort, &RealFileSystem{}, &RealKeyGenerator{})
66}
67
68// newSSHTheatherWithDeps creates a new SSHTheater with the specified dependencies
69func newSSHTheatherWithDeps(cntrName, sshHost, sshPort string, fs FileSystem, kg KeyGenerator) (*SSHTheater, error) {
Sean McCullough74b01212025-04-29 18:40:53 -070070 base := filepath.Join(os.Getenv("HOME"), ".config", "sketch")
Sean McCullough2cba6952025-04-25 20:32:10 +000071 if _, err := fs.Stat(base); err != nil {
Sean McCulloughc796e7f2025-04-30 08:44:06 -070072 if err := fs.MkdirAll(base, 0o777); err != nil {
Sean McCullough7d5a6302025-04-24 21:27:51 -070073 return nil, fmt.Errorf("couldn't create %s: %w", base, err)
74 }
75 }
76
Sean McCullough4854c652025-04-24 18:37:02 -070077 cst := &SSHTheater{
78 cntrName: cntrName,
79 sshHost: sshHost,
80 sshPort: sshPort,
81 knownHostsPath: filepath.Join(base, "known_hosts"),
82 userIdentityPath: filepath.Join(base, "container_user_identity"),
83 serverIdentityPath: filepath.Join(base, "container_server_identity"),
84 sshConfigPath: filepath.Join(base, "ssh_config"),
Sean McCullough2cba6952025-04-25 20:32:10 +000085 fs: fs,
86 kg: kg,
Sean McCullough4854c652025-04-24 18:37:02 -070087 }
Sean McCullough2cba6952025-04-25 20:32:10 +000088 if _, err := cst.createKeyPairIfMissing(cst.serverIdentityPath); err != nil {
Sean McCullough7d5a6302025-04-24 21:27:51 -070089 return nil, fmt.Errorf("couldn't create server identity: %w", err)
90 }
Sean McCullough2cba6952025-04-25 20:32:10 +000091 if _, err := cst.createKeyPairIfMissing(cst.userIdentityPath); err != nil {
Sean McCullough7d5a6302025-04-24 21:27:51 -070092 return nil, fmt.Errorf("couldn't create user identity: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -070093 }
94
Sean McCullough2cba6952025-04-25 20:32:10 +000095 serverIdentity, err := fs.ReadFile(cst.serverIdentityPath)
Sean McCullough4854c652025-04-24 18:37:02 -070096 if err != nil {
97 return nil, fmt.Errorf("couldn't read container's ssh server identity: %w", err)
98 }
99 cst.serverIdentity = serverIdentity
100
Sean McCullough2cba6952025-04-25 20:32:10 +0000101 serverPubKeyBytes, err := fs.ReadFile(cst.serverIdentityPath + ".pub")
102 if err != nil {
103 return nil, fmt.Errorf("couldn't read ssh server public key file: %w", err)
104 }
Sean McCullough4854c652025-04-24 18:37:02 -0700105 serverPubKey, _, _, _, err := ssh.ParseAuthorizedKey(serverPubKeyBytes)
106 if err != nil {
Sean McCullough2cba6952025-04-25 20:32:10 +0000107 return nil, fmt.Errorf("couldn't parse ssh server public key: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700108 }
109 cst.serverPublicKey = serverPubKey
110
Sean McCullough2cba6952025-04-25 20:32:10 +0000111 userIdentity, err := fs.ReadFile(cst.userIdentityPath + ".pub")
Sean McCullough4854c652025-04-24 18:37:02 -0700112 if err != nil {
113 return nil, fmt.Errorf("couldn't read ssh user identity: %w", err)
114 }
115 cst.userIdentity = userIdentity
116
Sean McCullough7d5a6302025-04-24 21:27:51 -0700117 if err := cst.addContainerToSSHConfig(); err != nil {
118 return nil, fmt.Errorf("couldn't add container to ssh_config: %w", err)
119 }
120
121 if err := cst.addContainerToKnownHosts(); err != nil {
122 return nil, fmt.Errorf("couldn't update known hosts: %w", err)
123 }
124
Sean McCullough4854c652025-04-24 18:37:02 -0700125 return cst, nil
126}
127
Sean McCullough15c95282025-05-08 16:48:38 -0700128func CheckForIncludeWithFS(fs FileSystem, stdinReader bufio.Reader) error {
Sean McCullough74b01212025-04-29 18:40:53 -0700129 sketchSSHPathInclude := "Include " + filepath.Join(os.Getenv("HOME"), ".config", "sketch", "ssh_config")
Sean McCulloughf5e28f62025-04-25 10:48:00 -0700130 defaultSSHPath := filepath.Join(os.Getenv("HOME"), ".ssh", "config")
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000131
132 // Read the existing SSH config file
133 existingContent, err := fs.ReadFile(defaultSSHPath)
134 if err != nil {
135 // If the file doesn't exist, create a new one with just the include line
136 if os.IsNotExist(err) {
137 return fs.SafeWriteFile(defaultSSHPath, []byte(sketchSSHPathInclude+"\n"), 0o644)
138 }
139 return fmt.Errorf("⚠️ SSH connections are disabled. cannot open SSH config file: %s: %w", defaultSSHPath, err)
Sean McCullough2cba6952025-04-25 20:32:10 +0000140 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000141
142 // Parse the config file
143 cfg, err := ssh_config.Decode(bytes.NewReader(existingContent))
144 if err != nil {
145 return fmt.Errorf("couldn't decode ssh_config: %w", err)
146 }
147
Sean McCulloughf5e28f62025-04-25 10:48:00 -0700148 var sketchInludePos *ssh_config.Position
149 var firstNonIncludePos *ssh_config.Position
150 for _, host := range cfg.Hosts {
151 for _, node := range host.Nodes {
152 inc, ok := node.(*ssh_config.Include)
153 if ok {
154 if strings.TrimSpace(inc.String()) == sketchSSHPathInclude {
155 pos := inc.Pos()
156 sketchInludePos = &pos
157 }
158 } else if firstNonIncludePos == nil && !strings.HasPrefix(strings.TrimSpace(node.String()), "#") {
159 pos := node.Pos()
160 firstNonIncludePos = &pos
161 }
162 }
163 }
164
165 if sketchInludePos == nil {
Sean McCullough15c95282025-05-08 16:48:38 -0700166 fmt.Printf("\nTo enable you to use ssh to connect to local sketch containers: \nAdd %q to the top of %s [y/N]? ", sketchSSHPathInclude, defaultSSHPath)
167 char, _, err := stdinReader.ReadRune()
168 if err != nil {
169 return fmt.Errorf("couldn't read from stdin: %w", err)
170 }
171 if char != 'y' && char != 'Y' {
172 return fmt.Errorf("User declined to edit ssh config file")
173 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000174 // Include line not found, add it to the top of the file
Sean McCullough3b0795b2025-04-29 19:09:23 -0700175 cfgBytes, err := cfg.MarshalText()
176 if err != nil {
177 return fmt.Errorf("couldn't marshal ssh_config: %w", err)
178 }
Sean McCullough3b0795b2025-04-29 19:09:23 -0700179
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000180 // Add the include line to the beginning
181 cfgBytes = append([]byte(sketchSSHPathInclude+"\n"), cfgBytes...)
182
183 // Safely write the updated config back to the file
184 if err := fs.SafeWriteFile(defaultSSHPath, cfgBytes, 0o644); err != nil {
185 return fmt.Errorf("couldn't safely write ssh_config: %w", err)
186 }
Sean McCullough3b0795b2025-04-29 19:09:23 -0700187 return nil
Sean McCulloughf5e28f62025-04-25 10:48:00 -0700188 }
189
190 if firstNonIncludePos != nil && firstNonIncludePos.Line < sketchInludePos.Line {
Sean McCullough2cba6952025-04-25 20:32:10 +0000191 fmt.Printf("⚠️ SSH confg warning: the location of the Include statement for sketch's ssh config on line %d of %s may prevent ssh from working with sketch containers. try moving it to the top of the file (before any 'Host' lines) if ssh isn't working for you.\n", sketchInludePos.Line, defaultSSHPath)
Sean McCulloughf5e28f62025-04-25 10:48:00 -0700192 }
193 return nil
194}
195
Sean McCullough4854c652025-04-24 18:37:02 -0700196func removeFromHosts(cntrName string, cfgHosts []*ssh_config.Host) []*ssh_config.Host {
197 hosts := []*ssh_config.Host{}
198 for _, host := range cfgHosts {
199 if host.Matches(cntrName) || strings.Contains(host.String(), cntrName) {
200 continue
201 }
202 patMatch := false
203 for _, pat := range host.Patterns {
204 if strings.Contains(pat.String(), cntrName) {
205 patMatch = true
206 }
207 }
208 if patMatch {
209 continue
210 }
211
212 hosts = append(hosts, host)
213 }
214 return hosts
215}
216
// encodePrivateKeyToPEM serializes privateKey in PKCS #1 ASN.1 DER form and
// wraps the result in a PEM block of type "RSA PRIVATE KEY".
func encodePrivateKeyToPEM(privateKey *rsa.PrivateKey) []byte {
	der := x509.MarshalPKCS1PrivateKey(privateKey)
	return pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: der})
}
225
Sean McCullough2cba6952025-04-25 20:32:10 +0000226func (c *SSHTheater) writeKeyToFile(keyBytes []byte, filename string) error {
227 err := c.fs.WriteFile(filename, keyBytes, 0o600)
Sean McCullough4854c652025-04-24 18:37:02 -0700228 return err
229}
230
Sean McCullough2cba6952025-04-25 20:32:10 +0000231func (c *SSHTheater) createKeyPairIfMissing(idPath string) (ssh.PublicKey, error) {
232 if _, err := c.fs.Stat(idPath); err == nil {
Sean McCullough4854c652025-04-24 18:37:02 -0700233 return nil, nil
234 }
235
Sean McCullough2cba6952025-04-25 20:32:10 +0000236 privateKey, err := c.kg.GeneratePrivateKey(keyBitSize)
Sean McCullough4854c652025-04-24 18:37:02 -0700237 if err != nil {
Sean McCullough2cba6952025-04-25 20:32:10 +0000238 return nil, fmt.Errorf("error generating private key: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700239 }
240
Sean McCullough2cba6952025-04-25 20:32:10 +0000241 publicRsaKey, err := c.kg.GeneratePublicKey(&privateKey.PublicKey)
Sean McCullough4854c652025-04-24 18:37:02 -0700242 if err != nil {
Sean McCullough2cba6952025-04-25 20:32:10 +0000243 return nil, fmt.Errorf("error generating public key: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700244 }
245
246 privateKeyPEM := encodePrivateKeyToPEM(privateKey)
247
Sean McCullough2cba6952025-04-25 20:32:10 +0000248 err = c.writeKeyToFile(privateKeyPEM, idPath)
Sean McCullough4854c652025-04-24 18:37:02 -0700249 if err != nil {
Sean McCullough2cba6952025-04-25 20:32:10 +0000250 return nil, fmt.Errorf("error writing private key to file %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700251 }
252 pubKeyBytes := ssh.MarshalAuthorizedKey(publicRsaKey)
253
Sean McCullough2cba6952025-04-25 20:32:10 +0000254 err = c.writeKeyToFile([]byte(pubKeyBytes), idPath+".pub")
Sean McCullough4854c652025-04-24 18:37:02 -0700255 if err != nil {
Sean McCullough2cba6952025-04-25 20:32:10 +0000256 return nil, fmt.Errorf("error writing public key to file %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700257 }
258 return publicRsaKey, nil
259}
260
261func (c *SSHTheater) addSketchHostMatchIfMissing(cfg *ssh_config.Config) error {
262 found := false
263 for _, host := range cfg.Hosts {
264 if strings.Contains(host.String(), "host=\"sketch-*\"") {
265 found = true
266 break
267 }
268 }
269 if !found {
270 hostPattern, err := ssh_config.NewPattern("host=\"sketch-*\"")
271 if err != nil {
272 return fmt.Errorf("couldn't add pattern to ssh_config: %w", err)
273 }
274
275 hostCfg := &ssh_config.Host{Patterns: []*ssh_config.Pattern{hostPattern}}
276 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "UserKnownHostsFile", Value: c.knownHostsPath})
277
Sean McCullough4854c652025-04-24 18:37:02 -0700278 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "IdentityFile", Value: c.userIdentityPath})
Sean McCullough4854c652025-04-24 18:37:02 -0700279 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.Empty{})
280
281 cfg.Hosts = append([]*ssh_config.Host{hostCfg}, cfg.Hosts...)
282 }
283 return nil
284}
285
286func (c *SSHTheater) addContainerToSSHConfig() error {
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000287 // Read the existing file contents or start with an empty config if file doesn't exist
288 var configData []byte
289 var cfg *ssh_config.Config
290 var err error
Sean McCullough4854c652025-04-24 18:37:02 -0700291
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000292 configData, err = c.fs.ReadFile(c.sshConfigPath)
Sean McCullough4854c652025-04-24 18:37:02 -0700293 if err != nil {
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000294 // If the file doesn't exist, create an empty config
295 if os.IsNotExist(err) {
296 cfg = &ssh_config.Config{}
297 } else {
298 return fmt.Errorf("couldn't read ssh_config: %w", err)
299 }
300 } else {
301 // Parse the existing config
302 cfg, err = ssh_config.Decode(bytes.NewReader(configData))
303 if err != nil {
304 return fmt.Errorf("couldn't decode ssh_config: %w", err)
305 }
Sean McCullough4854c652025-04-24 18:37:02 -0700306 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000307
Sean McCullough4854c652025-04-24 18:37:02 -0700308 cntrPattern, err := ssh_config.NewPattern(c.cntrName)
309 if err != nil {
310 return fmt.Errorf("couldn't add pattern to ssh_config: %w", err)
311 }
312
313 // Remove any matches for this container if they already exist.
314 cfg.Hosts = removeFromHosts(c.cntrName, cfg.Hosts)
315
316 hostCfg := &ssh_config.Host{Patterns: []*ssh_config.Pattern{cntrPattern}}
317 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "HostName", Value: c.sshHost})
318 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "User", Value: "root"})
319 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "Port", Value: c.sshPort})
320 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "IdentityFile", Value: c.userIdentityPath})
321 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "UserKnownHostsFile", Value: c.knownHostsPath})
322
323 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.Empty{})
324 cfg.Hosts = append(cfg.Hosts, hostCfg)
325
326 if err := c.addSketchHostMatchIfMissing(cfg); err != nil {
327 return fmt.Errorf("couldn't add missing host match: %w", err)
328 }
329
330 cfgBytes, err := cfg.MarshalText()
331 if err != nil {
332 return fmt.Errorf("couldn't marshal ssh_config: %w", err)
333 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000334
335 // Safely write the updated configuration to file
336 if err := c.fs.SafeWriteFile(c.sshConfigPath, cfgBytes, 0o644); err != nil {
337 return fmt.Errorf("couldn't safely write ssh_config: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700338 }
339
340 return nil
341}
342
343func (c *SSHTheater) addContainerToKnownHosts() error {
Sean McCullough7d5a6302025-04-24 21:27:51 -0700344 pkBytes := c.serverPublicKey.Marshal()
345 if len(pkBytes) == 0 {
Sean McCullough2cba6952025-04-25 20:32:10 +0000346 return fmt.Errorf("empty serverPublicKey, this is a bug")
Sean McCullough7d5a6302025-04-24 21:27:51 -0700347 }
348 newHostLine := knownhosts.Line([]string{c.sshHost + ":" + c.sshPort}, c.serverPublicKey)
Sean McCullough4854c652025-04-24 18:37:02 -0700349
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000350 // Read existing known_hosts content or start with empty if the file doesn't exist
Sean McCullough7d5a6302025-04-24 21:27:51 -0700351 outputLines := []string{}
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000352 existingContent, err := c.fs.ReadFile(c.knownHostsPath)
353 if err == nil {
354 scanner := bufio.NewScanner(bytes.NewReader(existingContent))
355 for scanner.Scan() {
356 outputLines = append(outputLines, scanner.Text())
357 }
358 } else if !os.IsNotExist(err) {
359 return fmt.Errorf("couldn't read known_hosts file: %w", err)
Sean McCullough7d5a6302025-04-24 21:27:51 -0700360 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000361
362 // Add the new host line
Sean McCullough7d5a6302025-04-24 21:27:51 -0700363 outputLines = append(outputLines, newHostLine)
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000364
365 // Safely write the updated content to the file
366 if err := c.fs.SafeWriteFile(c.knownHostsPath, []byte(strings.Join(outputLines, "\n")), 0o644); err != nil {
367 return fmt.Errorf("couldn't safely write updated known_hosts to %s: %w", c.knownHostsPath, err)
Sean McCullough4854c652025-04-24 18:37:02 -0700368 }
369
370 return nil
371}
372
373func (c *SSHTheater) removeContainerFromKnownHosts() error {
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000374 // Read the existing known_hosts file
375 existingContent, err := c.fs.ReadFile(c.knownHostsPath)
Sean McCullough4854c652025-04-24 18:37:02 -0700376 if err != nil {
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000377 // If the file doesn't exist, there's nothing to do
378 if os.IsNotExist(err) {
379 return nil
380 }
381 return fmt.Errorf("couldn't read known_hosts file: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700382 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000383
384 // Line we want to remove
Sean McCullough7d5a6302025-04-24 21:27:51 -0700385 lineToRemove := knownhosts.Line([]string{c.sshHost + ":" + c.sshPort}, c.serverPublicKey)
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000386
387 // Filter out the line we want to remove
Sean McCullough4854c652025-04-24 18:37:02 -0700388 outputLines := []string{}
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000389 scanner := bufio.NewScanner(bytes.NewReader(existingContent))
Sean McCullough4854c652025-04-24 18:37:02 -0700390 for scanner.Scan() {
391 if scanner.Text() == lineToRemove {
392 continue
393 }
394 outputLines = append(outputLines, scanner.Text())
395 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000396
397 // Safely write the updated content back to the file
398 if err := c.fs.SafeWriteFile(c.knownHostsPath, []byte(strings.Join(outputLines, "\n")), 0o644); err != nil {
399 return fmt.Errorf("couldn't safely write updated known_hosts to %s: %w", c.knownHostsPath, err)
Sean McCullough4854c652025-04-24 18:37:02 -0700400 }
401
402 return nil
403}
404
405func (c *SSHTheater) Cleanup() error {
406 if err := c.removeContainerFromSSHConfig(); err != nil {
407 return fmt.Errorf("couldn't remove container from ssh_config: %v\n", err)
408 }
409 if err := c.removeContainerFromKnownHosts(); err != nil {
410 return fmt.Errorf("couldn't remove container from ssh_config: %v\n", err)
411 }
412
413 return nil
414}
415
416func (c *SSHTheater) removeContainerFromSSHConfig() error {
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000417 // Read the existing file contents
418 configData, err := c.fs.ReadFile(c.sshConfigPath)
Sean McCullough4854c652025-04-24 18:37:02 -0700419 if err != nil {
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000420 return fmt.Errorf("couldn't read ssh_config: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700421 }
Sean McCullough4854c652025-04-24 18:37:02 -0700422
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000423 cfg, err := ssh_config.Decode(bytes.NewReader(configData))
Sean McCullough4854c652025-04-24 18:37:02 -0700424 if err != nil {
425 return fmt.Errorf("couldn't decode ssh_config: %w", err)
426 }
427 cfg.Hosts = removeFromHosts(c.cntrName, cfg.Hosts)
428
429 if err := c.addSketchHostMatchIfMissing(cfg); err != nil {
430 return fmt.Errorf("couldn't add missing host match: %w", err)
431 }
432
433 cfgBytes, err := cfg.MarshalText()
434 if err != nil {
435 return fmt.Errorf("couldn't marshal ssh_config: %w", err)
436 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000437
438 // Safely write the updated configuration to file
439 if err := c.fs.SafeWriteFile(c.sshConfigPath, cfgBytes, 0o644); err != nil {
440 return fmt.Errorf("couldn't safely write ssh_config: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700441 }
442 return nil
443}
Sean McCullough2cba6952025-04-25 20:32:10 +0000444
// FileSystem abstracts the file operations performed by SSHTheater and
// CheckForIncludeWithFS, so tests can substitute their own implementation for
// the real operating system.
type FileSystem interface {
	// Metadata and directory creation.
	Stat(name string) (fs.FileInfo, error)
	Mkdir(name string, perm fs.FileMode) error
	MkdirAll(name string, perm fs.FileMode) error

	// Whole-file reads and writes.
	ReadFile(name string) ([]byte, error)
	WriteFile(name string, data []byte, perm fs.FileMode) error
	// SafeWriteFile is expected to replace name with data without leaving a
	// partially written file behind if the write fails.
	SafeWriteFile(name string, data []byte, perm fs.FileMode) error

	// Lower-level building blocks.
	OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error)
	TempFile(dir, pattern string) (*os.File, error)
	Rename(oldpath, newpath string) error
}
457
// RealFileSystem is the default FileSystem implementation. Each method below
// delegates directly to package os (TempFile to os.CreateTemp).
type RealFileSystem struct{}

// Stat reports information about the named file via os.Stat.
func (r *RealFileSystem) Stat(name string) (fs.FileInfo, error) {
	return os.Stat(name)
}

// Mkdir creates a single directory via os.Mkdir.
func (r *RealFileSystem) Mkdir(name string, perm fs.FileMode) error {
	return os.Mkdir(name, perm)
}

// MkdirAll creates a directory and any missing parents via os.MkdirAll.
func (r *RealFileSystem) MkdirAll(name string, perm fs.FileMode) error {
	return os.MkdirAll(name, perm)
}

// ReadFile returns the whole contents of the named file via os.ReadFile.
func (r *RealFileSystem) ReadFile(name string) ([]byte, error) {
	return os.ReadFile(name)
}

// WriteFile writes data to the named file via os.WriteFile.
func (r *RealFileSystem) WriteFile(name string, data []byte, perm fs.FileMode) error {
	return os.WriteFile(name, data, perm)
}

// OpenFile opens the named file with the given flags via os.OpenFile.
func (r *RealFileSystem) OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
	return os.OpenFile(name, flag, perm)
}

// TempFile creates a new temporary file in dir via os.CreateTemp.
func (r *RealFileSystem) TempFile(dir, pattern string) (*os.File, error) {
	return os.CreateTemp(dir, pattern)
}

// Rename moves oldpath to newpath via os.Rename.
func (r *RealFileSystem) Rename(oldpath, newpath string) error {
	return os.Rename(oldpath, newpath)
}
492
493// SafeWriteFile writes data to a temporary file, syncs to disk, creates a backup of the existing file if it exists,
494// and then renames the temporary file to the target file name.
495func (fs *RealFileSystem) SafeWriteFile(name string, data []byte, perm fs.FileMode) error {
496 // Get the directory from the target filename
497 dir := filepath.Dir(name)
498
499 // Create a temporary file in the same directory
500 tmpFile, err := fs.TempFile(dir, filepath.Base(name)+".*.tmp")
501 if err != nil {
502 return fmt.Errorf("couldn't create temporary file: %w", err)
503 }
504 tmpFilename := tmpFile.Name()
505 defer os.Remove(tmpFilename) // Clean up if we fail
506
507 // Write data to the temporary file
508 if _, err := tmpFile.Write(data); err != nil {
509 tmpFile.Close()
510 return fmt.Errorf("couldn't write to temporary file: %w", err)
511 }
512
513 // Sync to disk to ensure data is written
514 if err := tmpFile.Sync(); err != nil {
515 tmpFile.Close()
516 return fmt.Errorf("couldn't sync temporary file: %w", err)
517 }
518
519 // Close the temporary file
520 if err := tmpFile.Close(); err != nil {
521 return fmt.Errorf("couldn't close temporary file: %w", err)
522 }
523
524 // If the original file exists, create a backup
525 if _, err := fs.Stat(name); err == nil {
526 backupName := name + ".bak"
527 // Remove any existing backup
528 _ = os.Remove(backupName) // Ignore errors if the backup doesn't exist
529
530 // Create the backup
531 if err := fs.Rename(name, backupName); err != nil {
532 return fmt.Errorf("couldn't create backup file: %w", err)
533 }
534 }
535
536 // Rename the temporary file to the target file
537 if err := fs.Rename(tmpFilename, name); err != nil {
538 return fmt.Errorf("couldn't rename temporary file to target: %w", err)
539 }
540
541 // Set permissions on the new file
542 if err := os.Chmod(name, perm); err != nil {
543 return fmt.Errorf("couldn't set permissions on file: %w", err)
544 }
545
546 return nil
547}
548
Sean McCullough2cba6952025-04-25 20:32:10 +0000549// KeyGenerator represents an interface for generating SSH keys for testability
550type KeyGenerator interface {
551 GeneratePrivateKey(bitSize int) (*rsa.PrivateKey, error)
552 GeneratePublicKey(privateKey *rsa.PublicKey) (ssh.PublicKey, error)
553}
554
555// RealKeyGenerator is the default implementation of KeyGenerator
556type RealKeyGenerator struct{}
557
558func (kg *RealKeyGenerator) GeneratePrivateKey(bitSize int) (*rsa.PrivateKey, error) {
559 return rsa.GenerateKey(rand.Reader, bitSize)
560}
561
562func (kg *RealKeyGenerator) GeneratePublicKey(privateKey *rsa.PublicKey) (ssh.PublicKey, error) {
563 return ssh.NewPublicKey(privateKey)
564}
565
566// CheckForInclude checks if the user's SSH config includes the Sketch SSH config file
567func CheckForInclude() error {
Sean McCullough15c95282025-05-08 16:48:38 -0700568 return CheckForIncludeWithFS(&RealFileSystem{}, *bufio.NewReader(os.Stdin))
Sean McCullough2cba6952025-04-25 20:32:10 +0000569}