blob: c006487788f10f2d97cc80c3f99f44ca2c8ae39e [file] [log] [blame]
Sean McCullough4854c652025-04-24 18:37:02 -07001package dockerimg
2
3import (
4 "bufio"
Sean McCullough0d95d3a2025-04-30 16:22:28 +00005 "bytes"
Sean McCullough4854c652025-04-24 18:37:02 -07006 "crypto/rand"
7 "crypto/rsa"
8 "crypto/x509"
Sean McCullough4854c652025-04-24 18:37:02 -07009 "encoding/pem"
10 "fmt"
Sean McCullough2cba6952025-04-25 20:32:10 +000011 "io/fs"
Sean McCullough4854c652025-04-24 18:37:02 -070012 "os"
13 "path/filepath"
14 "strings"
15
16 "github.com/kevinburke/ssh_config"
17 "golang.org/x/crypto/ssh"
Sean McCullough7d5a6302025-04-24 21:27:51 -070018 "golang.org/x/crypto/ssh/knownhosts"
Sean McCullough4854c652025-04-24 18:37:02 -070019)
20
const (
	// keyBitSize is the RSA modulus size, in bits, used when generating both the
	// container's ssh server identity and the user identity.
	keyBitSize = 2048
)
22
23// SSHTheater does the necessary key pair generation, known_hosts updates, ssh_config file updates etc steps
24// so that ssh can connect to a locally running sketch container to other local processes like vscode without
25// the user having to run the usual ssh obstacle course.
26//
27// SSHTheater does not modify your default .ssh/config, or known_hosts files. However, in order for you
28// to be able to use it properly you will have to make a one-time edit to your ~/.ssh/config file.
29//
30// In your ~/.ssh/config file, add the following line:
31//
Sean McCullough74b01212025-04-29 18:40:53 -070032// Include $HOME/.config/sketch/ssh_config
Sean McCullough4854c652025-04-24 18:37:02 -070033//
34// where $HOME is your home directory.
35type SSHTheater struct {
36 cntrName string
37 sshHost string
38 sshPort string
39
40 knownHostsPath string
41 userIdentityPath string
42 sshConfigPath string
43 serverIdentityPath string
44
45 serverPublicKey ssh.PublicKey
46 serverIdentity []byte
47 userIdentity []byte
Sean McCullough2cba6952025-04-25 20:32:10 +000048
49 fs FileSystem
50 kg KeyGenerator
Sean McCullough4854c652025-04-24 18:37:02 -070051}
52
53// NewSSHTheather will set up everything so that you can use ssh on localhost to connect to
54// the sketch container. Call #Clean when you are done with the container to remove the
55// various entries it created in its known_hosts and ssh_config files. Also note that
56// this will generate key pairs for both the ssh server identity and the user identity, if
57// these files do not already exist. These key pair files are not deleted by #Cleanup,
58// so they can be re-used across invocations of sketch. This means every sketch container
59// that runs on this host will use the same ssh server identity.
60//
61// If this doesn't return an error, you should be able to run "ssh <cntrName>"
62// in a terminal on your host machine to open a shell into the container without having
63// to manually accept changes to your known_hosts file etc.
64func NewSSHTheather(cntrName, sshHost, sshPort string) (*SSHTheater, error) {
Sean McCullough2cba6952025-04-25 20:32:10 +000065 return newSSHTheatherWithDeps(cntrName, sshHost, sshPort, &RealFileSystem{}, &RealKeyGenerator{})
66}
67
68// newSSHTheatherWithDeps creates a new SSHTheater with the specified dependencies
69func newSSHTheatherWithDeps(cntrName, sshHost, sshPort string, fs FileSystem, kg KeyGenerator) (*SSHTheater, error) {
Sean McCullough74b01212025-04-29 18:40:53 -070070 base := filepath.Join(os.Getenv("HOME"), ".config", "sketch")
Sean McCullough2cba6952025-04-25 20:32:10 +000071 if _, err := fs.Stat(base); err != nil {
Sean McCulloughc796e7f2025-04-30 08:44:06 -070072 if err := fs.MkdirAll(base, 0o777); err != nil {
Sean McCullough7d5a6302025-04-24 21:27:51 -070073 return nil, fmt.Errorf("couldn't create %s: %w", base, err)
74 }
75 }
76
Sean McCullough4854c652025-04-24 18:37:02 -070077 cst := &SSHTheater{
78 cntrName: cntrName,
79 sshHost: sshHost,
80 sshPort: sshPort,
81 knownHostsPath: filepath.Join(base, "known_hosts"),
82 userIdentityPath: filepath.Join(base, "container_user_identity"),
83 serverIdentityPath: filepath.Join(base, "container_server_identity"),
84 sshConfigPath: filepath.Join(base, "ssh_config"),
Sean McCullough2cba6952025-04-25 20:32:10 +000085 fs: fs,
86 kg: kg,
Sean McCullough4854c652025-04-24 18:37:02 -070087 }
Sean McCullough2cba6952025-04-25 20:32:10 +000088 if _, err := cst.createKeyPairIfMissing(cst.serverIdentityPath); err != nil {
Sean McCullough7d5a6302025-04-24 21:27:51 -070089 return nil, fmt.Errorf("couldn't create server identity: %w", err)
90 }
Sean McCullough2cba6952025-04-25 20:32:10 +000091 if _, err := cst.createKeyPairIfMissing(cst.userIdentityPath); err != nil {
Sean McCullough7d5a6302025-04-24 21:27:51 -070092 return nil, fmt.Errorf("couldn't create user identity: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -070093 }
94
Sean McCullough2cba6952025-04-25 20:32:10 +000095 serverIdentity, err := fs.ReadFile(cst.serverIdentityPath)
Sean McCullough4854c652025-04-24 18:37:02 -070096 if err != nil {
97 return nil, fmt.Errorf("couldn't read container's ssh server identity: %w", err)
98 }
99 cst.serverIdentity = serverIdentity
100
Sean McCullough2cba6952025-04-25 20:32:10 +0000101 serverPubKeyBytes, err := fs.ReadFile(cst.serverIdentityPath + ".pub")
102 if err != nil {
103 return nil, fmt.Errorf("couldn't read ssh server public key file: %w", err)
104 }
Sean McCullough4854c652025-04-24 18:37:02 -0700105 serverPubKey, _, _, _, err := ssh.ParseAuthorizedKey(serverPubKeyBytes)
106 if err != nil {
Sean McCullough2cba6952025-04-25 20:32:10 +0000107 return nil, fmt.Errorf("couldn't parse ssh server public key: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700108 }
109 cst.serverPublicKey = serverPubKey
110
Sean McCullough2cba6952025-04-25 20:32:10 +0000111 userIdentity, err := fs.ReadFile(cst.userIdentityPath + ".pub")
Sean McCullough4854c652025-04-24 18:37:02 -0700112 if err != nil {
113 return nil, fmt.Errorf("couldn't read ssh user identity: %w", err)
114 }
115 cst.userIdentity = userIdentity
116
Sean McCullough7d5a6302025-04-24 21:27:51 -0700117 if err := cst.addContainerToSSHConfig(); err != nil {
118 return nil, fmt.Errorf("couldn't add container to ssh_config: %w", err)
119 }
120
121 if err := cst.addContainerToKnownHosts(); err != nil {
122 return nil, fmt.Errorf("couldn't update known hosts: %w", err)
123 }
124
Sean McCullough4854c652025-04-24 18:37:02 -0700125 return cst, nil
126}
127
Sean McCullough2cba6952025-04-25 20:32:10 +0000128func CheckForIncludeWithFS(fs FileSystem) error {
Sean McCullough74b01212025-04-29 18:40:53 -0700129 sketchSSHPathInclude := "Include " + filepath.Join(os.Getenv("HOME"), ".config", "sketch", "ssh_config")
Sean McCulloughf5e28f62025-04-25 10:48:00 -0700130 defaultSSHPath := filepath.Join(os.Getenv("HOME"), ".ssh", "config")
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000131
132 // Read the existing SSH config file
133 existingContent, err := fs.ReadFile(defaultSSHPath)
134 if err != nil {
135 // If the file doesn't exist, create a new one with just the include line
136 if os.IsNotExist(err) {
137 return fs.SafeWriteFile(defaultSSHPath, []byte(sketchSSHPathInclude+"\n"), 0o644)
138 }
139 return fmt.Errorf("⚠️ SSH connections are disabled. cannot open SSH config file: %s: %w", defaultSSHPath, err)
Sean McCullough2cba6952025-04-25 20:32:10 +0000140 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000141
142 // Parse the config file
143 cfg, err := ssh_config.Decode(bytes.NewReader(existingContent))
144 if err != nil {
145 return fmt.Errorf("couldn't decode ssh_config: %w", err)
146 }
147
Sean McCulloughf5e28f62025-04-25 10:48:00 -0700148 var sketchInludePos *ssh_config.Position
149 var firstNonIncludePos *ssh_config.Position
150 for _, host := range cfg.Hosts {
151 for _, node := range host.Nodes {
152 inc, ok := node.(*ssh_config.Include)
153 if ok {
154 if strings.TrimSpace(inc.String()) == sketchSSHPathInclude {
155 pos := inc.Pos()
156 sketchInludePos = &pos
157 }
158 } else if firstNonIncludePos == nil && !strings.HasPrefix(strings.TrimSpace(node.String()), "#") {
159 pos := node.Pos()
160 firstNonIncludePos = &pos
161 }
162 }
163 }
164
165 if sketchInludePos == nil {
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000166 // Include line not found, add it to the top of the file
Sean McCullough3b0795b2025-04-29 19:09:23 -0700167 cfgBytes, err := cfg.MarshalText()
168 if err != nil {
169 return fmt.Errorf("couldn't marshal ssh_config: %w", err)
170 }
Sean McCullough3b0795b2025-04-29 19:09:23 -0700171
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000172 // Add the include line to the beginning
173 cfgBytes = append([]byte(sketchSSHPathInclude+"\n"), cfgBytes...)
174
175 // Safely write the updated config back to the file
176 if err := fs.SafeWriteFile(defaultSSHPath, cfgBytes, 0o644); err != nil {
177 return fmt.Errorf("couldn't safely write ssh_config: %w", err)
178 }
Sean McCullough3b0795b2025-04-29 19:09:23 -0700179 return nil
Sean McCulloughf5e28f62025-04-25 10:48:00 -0700180 }
181
182 if firstNonIncludePos != nil && firstNonIncludePos.Line < sketchInludePos.Line {
Sean McCullough2cba6952025-04-25 20:32:10 +0000183 fmt.Printf("⚠️ SSH confg warning: the location of the Include statement for sketch's ssh config on line %d of %s may prevent ssh from working with sketch containers. try moving it to the top of the file (before any 'Host' lines) if ssh isn't working for you.\n", sketchInludePos.Line, defaultSSHPath)
Sean McCulloughf5e28f62025-04-25 10:48:00 -0700184 }
185 return nil
186}
187
Sean McCullough4854c652025-04-24 18:37:02 -0700188func removeFromHosts(cntrName string, cfgHosts []*ssh_config.Host) []*ssh_config.Host {
189 hosts := []*ssh_config.Host{}
190 for _, host := range cfgHosts {
191 if host.Matches(cntrName) || strings.Contains(host.String(), cntrName) {
192 continue
193 }
194 patMatch := false
195 for _, pat := range host.Patterns {
196 if strings.Contains(pat.String(), cntrName) {
197 patMatch = true
198 }
199 }
200 if patMatch {
201 continue
202 }
203
204 hosts = append(hosts, host)
205 }
206 return hosts
207}
208
// encodePrivateKeyToPEM serializes privateKey in PKCS #1 DER form and wraps it in a PEM
// block of type "RSA PRIVATE KEY".
func encodePrivateKeyToPEM(privateKey *rsa.PrivateKey) []byte {
	der := x509.MarshalPKCS1PrivateKey(privateKey)
	return pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: der})
}
217
Sean McCullough2cba6952025-04-25 20:32:10 +0000218func (c *SSHTheater) writeKeyToFile(keyBytes []byte, filename string) error {
219 err := c.fs.WriteFile(filename, keyBytes, 0o600)
Sean McCullough4854c652025-04-24 18:37:02 -0700220 return err
221}
222
Sean McCullough2cba6952025-04-25 20:32:10 +0000223func (c *SSHTheater) createKeyPairIfMissing(idPath string) (ssh.PublicKey, error) {
224 if _, err := c.fs.Stat(idPath); err == nil {
Sean McCullough4854c652025-04-24 18:37:02 -0700225 return nil, nil
226 }
227
Sean McCullough2cba6952025-04-25 20:32:10 +0000228 privateKey, err := c.kg.GeneratePrivateKey(keyBitSize)
Sean McCullough4854c652025-04-24 18:37:02 -0700229 if err != nil {
Sean McCullough2cba6952025-04-25 20:32:10 +0000230 return nil, fmt.Errorf("error generating private key: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700231 }
232
Sean McCullough2cba6952025-04-25 20:32:10 +0000233 publicRsaKey, err := c.kg.GeneratePublicKey(&privateKey.PublicKey)
Sean McCullough4854c652025-04-24 18:37:02 -0700234 if err != nil {
Sean McCullough2cba6952025-04-25 20:32:10 +0000235 return nil, fmt.Errorf("error generating public key: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700236 }
237
238 privateKeyPEM := encodePrivateKeyToPEM(privateKey)
239
Sean McCullough2cba6952025-04-25 20:32:10 +0000240 err = c.writeKeyToFile(privateKeyPEM, idPath)
Sean McCullough4854c652025-04-24 18:37:02 -0700241 if err != nil {
Sean McCullough2cba6952025-04-25 20:32:10 +0000242 return nil, fmt.Errorf("error writing private key to file %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700243 }
244 pubKeyBytes := ssh.MarshalAuthorizedKey(publicRsaKey)
245
Sean McCullough2cba6952025-04-25 20:32:10 +0000246 err = c.writeKeyToFile([]byte(pubKeyBytes), idPath+".pub")
Sean McCullough4854c652025-04-24 18:37:02 -0700247 if err != nil {
Sean McCullough2cba6952025-04-25 20:32:10 +0000248 return nil, fmt.Errorf("error writing public key to file %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700249 }
250 return publicRsaKey, nil
251}
252
253func (c *SSHTheater) addSketchHostMatchIfMissing(cfg *ssh_config.Config) error {
254 found := false
255 for _, host := range cfg.Hosts {
256 if strings.Contains(host.String(), "host=\"sketch-*\"") {
257 found = true
258 break
259 }
260 }
261 if !found {
262 hostPattern, err := ssh_config.NewPattern("host=\"sketch-*\"")
263 if err != nil {
264 return fmt.Errorf("couldn't add pattern to ssh_config: %w", err)
265 }
266
267 hostCfg := &ssh_config.Host{Patterns: []*ssh_config.Pattern{hostPattern}}
268 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "UserKnownHostsFile", Value: c.knownHostsPath})
269
Sean McCullough4854c652025-04-24 18:37:02 -0700270 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "IdentityFile", Value: c.userIdentityPath})
Sean McCullough4854c652025-04-24 18:37:02 -0700271 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.Empty{})
272
273 cfg.Hosts = append([]*ssh_config.Host{hostCfg}, cfg.Hosts...)
274 }
275 return nil
276}
277
278func (c *SSHTheater) addContainerToSSHConfig() error {
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000279 // Read the existing file contents or start with an empty config if file doesn't exist
280 var configData []byte
281 var cfg *ssh_config.Config
282 var err error
Sean McCullough4854c652025-04-24 18:37:02 -0700283
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000284 configData, err = c.fs.ReadFile(c.sshConfigPath)
Sean McCullough4854c652025-04-24 18:37:02 -0700285 if err != nil {
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000286 // If the file doesn't exist, create an empty config
287 if os.IsNotExist(err) {
288 cfg = &ssh_config.Config{}
289 } else {
290 return fmt.Errorf("couldn't read ssh_config: %w", err)
291 }
292 } else {
293 // Parse the existing config
294 cfg, err = ssh_config.Decode(bytes.NewReader(configData))
295 if err != nil {
296 return fmt.Errorf("couldn't decode ssh_config: %w", err)
297 }
Sean McCullough4854c652025-04-24 18:37:02 -0700298 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000299
Sean McCullough4854c652025-04-24 18:37:02 -0700300 cntrPattern, err := ssh_config.NewPattern(c.cntrName)
301 if err != nil {
302 return fmt.Errorf("couldn't add pattern to ssh_config: %w", err)
303 }
304
305 // Remove any matches for this container if they already exist.
306 cfg.Hosts = removeFromHosts(c.cntrName, cfg.Hosts)
307
308 hostCfg := &ssh_config.Host{Patterns: []*ssh_config.Pattern{cntrPattern}}
309 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "HostName", Value: c.sshHost})
310 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "User", Value: "root"})
311 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "Port", Value: c.sshPort})
312 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "IdentityFile", Value: c.userIdentityPath})
313 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "UserKnownHostsFile", Value: c.knownHostsPath})
314
315 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.Empty{})
316 cfg.Hosts = append(cfg.Hosts, hostCfg)
317
318 if err := c.addSketchHostMatchIfMissing(cfg); err != nil {
319 return fmt.Errorf("couldn't add missing host match: %w", err)
320 }
321
322 cfgBytes, err := cfg.MarshalText()
323 if err != nil {
324 return fmt.Errorf("couldn't marshal ssh_config: %w", err)
325 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000326
327 // Safely write the updated configuration to file
328 if err := c.fs.SafeWriteFile(c.sshConfigPath, cfgBytes, 0o644); err != nil {
329 return fmt.Errorf("couldn't safely write ssh_config: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700330 }
331
332 return nil
333}
334
335func (c *SSHTheater) addContainerToKnownHosts() error {
Sean McCullough7d5a6302025-04-24 21:27:51 -0700336 pkBytes := c.serverPublicKey.Marshal()
337 if len(pkBytes) == 0 {
Sean McCullough2cba6952025-04-25 20:32:10 +0000338 return fmt.Errorf("empty serverPublicKey, this is a bug")
Sean McCullough7d5a6302025-04-24 21:27:51 -0700339 }
340 newHostLine := knownhosts.Line([]string{c.sshHost + ":" + c.sshPort}, c.serverPublicKey)
Sean McCullough4854c652025-04-24 18:37:02 -0700341
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000342 // Read existing known_hosts content or start with empty if the file doesn't exist
Sean McCullough7d5a6302025-04-24 21:27:51 -0700343 outputLines := []string{}
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000344 existingContent, err := c.fs.ReadFile(c.knownHostsPath)
345 if err == nil {
346 scanner := bufio.NewScanner(bytes.NewReader(existingContent))
347 for scanner.Scan() {
348 outputLines = append(outputLines, scanner.Text())
349 }
350 } else if !os.IsNotExist(err) {
351 return fmt.Errorf("couldn't read known_hosts file: %w", err)
Sean McCullough7d5a6302025-04-24 21:27:51 -0700352 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000353
354 // Add the new host line
Sean McCullough7d5a6302025-04-24 21:27:51 -0700355 outputLines = append(outputLines, newHostLine)
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000356
357 // Safely write the updated content to the file
358 if err := c.fs.SafeWriteFile(c.knownHostsPath, []byte(strings.Join(outputLines, "\n")), 0o644); err != nil {
359 return fmt.Errorf("couldn't safely write updated known_hosts to %s: %w", c.knownHostsPath, err)
Sean McCullough4854c652025-04-24 18:37:02 -0700360 }
361
362 return nil
363}
364
365func (c *SSHTheater) removeContainerFromKnownHosts() error {
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000366 // Read the existing known_hosts file
367 existingContent, err := c.fs.ReadFile(c.knownHostsPath)
Sean McCullough4854c652025-04-24 18:37:02 -0700368 if err != nil {
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000369 // If the file doesn't exist, there's nothing to do
370 if os.IsNotExist(err) {
371 return nil
372 }
373 return fmt.Errorf("couldn't read known_hosts file: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700374 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000375
376 // Line we want to remove
Sean McCullough7d5a6302025-04-24 21:27:51 -0700377 lineToRemove := knownhosts.Line([]string{c.sshHost + ":" + c.sshPort}, c.serverPublicKey)
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000378
379 // Filter out the line we want to remove
Sean McCullough4854c652025-04-24 18:37:02 -0700380 outputLines := []string{}
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000381 scanner := bufio.NewScanner(bytes.NewReader(existingContent))
Sean McCullough4854c652025-04-24 18:37:02 -0700382 for scanner.Scan() {
383 if scanner.Text() == lineToRemove {
384 continue
385 }
386 outputLines = append(outputLines, scanner.Text())
387 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000388
389 // Safely write the updated content back to the file
390 if err := c.fs.SafeWriteFile(c.knownHostsPath, []byte(strings.Join(outputLines, "\n")), 0o644); err != nil {
391 return fmt.Errorf("couldn't safely write updated known_hosts to %s: %w", c.knownHostsPath, err)
Sean McCullough4854c652025-04-24 18:37:02 -0700392 }
393
394 return nil
395}
396
397func (c *SSHTheater) Cleanup() error {
398 if err := c.removeContainerFromSSHConfig(); err != nil {
399 return fmt.Errorf("couldn't remove container from ssh_config: %v\n", err)
400 }
401 if err := c.removeContainerFromKnownHosts(); err != nil {
402 return fmt.Errorf("couldn't remove container from ssh_config: %v\n", err)
403 }
404
405 return nil
406}
407
408func (c *SSHTheater) removeContainerFromSSHConfig() error {
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000409 // Read the existing file contents
410 configData, err := c.fs.ReadFile(c.sshConfigPath)
Sean McCullough4854c652025-04-24 18:37:02 -0700411 if err != nil {
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000412 return fmt.Errorf("couldn't read ssh_config: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700413 }
Sean McCullough4854c652025-04-24 18:37:02 -0700414
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000415 cfg, err := ssh_config.Decode(bytes.NewReader(configData))
Sean McCullough4854c652025-04-24 18:37:02 -0700416 if err != nil {
417 return fmt.Errorf("couldn't decode ssh_config: %w", err)
418 }
419 cfg.Hosts = removeFromHosts(c.cntrName, cfg.Hosts)
420
421 if err := c.addSketchHostMatchIfMissing(cfg); err != nil {
422 return fmt.Errorf("couldn't add missing host match: %w", err)
423 }
424
425 cfgBytes, err := cfg.MarshalText()
426 if err != nil {
427 return fmt.Errorf("couldn't marshal ssh_config: %w", err)
428 }
Sean McCullough0d95d3a2025-04-30 16:22:28 +0000429
430 // Safely write the updated configuration to file
431 if err := c.fs.SafeWriteFile(c.sshConfigPath, cfgBytes, 0o644); err != nil {
432 return fmt.Errorf("couldn't safely write ssh_config: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700433 }
434 return nil
435}
Sean McCullough2cba6952025-04-25 20:32:10 +0000436
// FileSystem abstracts the file operations performed by SSHTheater and
// CheckForIncludeWithFS, so tests can substitute their own implementation.
type FileSystem interface {
	// Metadata and directory creation.
	Stat(name string) (fs.FileInfo, error)
	Mkdir(name string, perm fs.FileMode) error
	MkdirAll(name string, perm fs.FileMode) error

	// Whole-file and handle-based access.
	ReadFile(name string) ([]byte, error)
	WriteFile(name string, data []byte, perm fs.FileMode) error
	OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error)

	// Building blocks for replacing a file safely, and the operation built on them.
	TempFile(dir, pattern string) (*os.File, error)
	Rename(oldpath, newpath string) error
	SafeWriteFile(name string, data []byte, perm fs.FileMode) error
}
449
// RealFileSystem is the default implementation of FileSystem, backed by the OS. Every
// method except SafeWriteFile is a thin wrapper around the matching os function.
type RealFileSystem struct{}

// Stat returns file info for name, following symlinks (os.Stat).
func (r *RealFileSystem) Stat(name string) (fs.FileInfo, error) {
	return os.Stat(name)
}

// Mkdir creates a single directory (os.Mkdir).
func (r *RealFileSystem) Mkdir(name string, perm fs.FileMode) error {
	return os.Mkdir(name, perm)
}

// MkdirAll creates name along with any missing parents (os.MkdirAll).
func (r *RealFileSystem) MkdirAll(name string, perm fs.FileMode) error {
	return os.MkdirAll(name, perm)
}

// ReadFile returns the contents of name (os.ReadFile).
func (r *RealFileSystem) ReadFile(name string) ([]byte, error) {
	return os.ReadFile(name)
}

// WriteFile writes data to name, creating it with perm if needed (os.WriteFile).
func (r *RealFileSystem) WriteFile(name string, data []byte, perm fs.FileMode) error {
	return os.WriteFile(name, data, perm)
}

// OpenFile opens name with the given flags and permissions (os.OpenFile).
func (r *RealFileSystem) OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
	return os.OpenFile(name, flag, perm)
}

// TempFile creates a new uniquely named file in dir (os.CreateTemp).
func (r *RealFileSystem) TempFile(dir, pattern string) (*os.File, error) {
	return os.CreateTemp(dir, pattern)
}

// Rename moves oldpath to newpath (os.Rename).
func (r *RealFileSystem) Rename(oldpath, newpath string) error {
	return os.Rename(oldpath, newpath)
}

// SafeWriteFile replaces the contents of name with data, without ever leaving name
// missing or partially written:
//  1. data is written to a temporary file in the same directory and synced to disk;
//  2. if name already exists, a copy of its current contents is saved alongside it with
//     a ".bak" suffix, replacing any earlier backup;
//  3. the temporary file is renamed over name in a single step.
//
// As with os.WriteFile, perm is only applied when the file does not exist yet. An
// existing file keeps its permission bits, so e.g. a 0600 ~/.ssh/config is not loosened.
// If name is a symlink (as with dotfile managers), the file it points to is updated and
// backed up instead, so the link itself survives.
func (r *RealFileSystem) SafeWriteFile(name string, data []byte, perm fs.FileMode) error {
	target := name
	if resolved, err := filepath.EvalSymlinks(name); err == nil {
		target = resolved
	}

	mode := perm
	info, statErr := r.Stat(target)
	exists := statErr == nil
	if exists {
		mode = info.Mode().Perm()
	}

	// The temp file lives next to the target so the final rename stays within one
	// directory (and therefore one filesystem).
	tmpFile, err := r.TempFile(filepath.Dir(target), filepath.Base(target)+".*.tmp")
	if err != nil {
		return fmt.Errorf("couldn't create temporary file: %w", err)
	}
	tmpFilename := tmpFile.Name()
	// Clean up if we fail; after a successful rename this is a harmless no-op.
	defer os.Remove(tmpFilename)

	if _, err := tmpFile.Write(data); err != nil {
		tmpFile.Close()
		return fmt.Errorf("couldn't write to temporary file: %w", err)
	}
	// Apply the final permissions before the file becomes visible under its real name.
	if err := tmpFile.Chmod(mode); err != nil {
		tmpFile.Close()
		return fmt.Errorf("couldn't set permissions on file: %w", err)
	}
	if err := tmpFile.Sync(); err != nil {
		tmpFile.Close()
		return fmt.Errorf("couldn't sync temporary file: %w", err)
	}
	if err := tmpFile.Close(); err != nil {
		return fmt.Errorf("couldn't close temporary file: %w", err)
	}

	// Back up by copying rather than renaming, so target stays in place until the
	// rename below replaces it.
	if exists {
		if err := writeBackupCopy(target, target+".bak", mode); err != nil {
			return fmt.Errorf("couldn't create backup file: %w", err)
		}
	}

	if err := r.Rename(tmpFilename, target); err != nil {
		return fmt.Errorf("couldn't rename temporary file to target: %w", err)
	}
	return nil
}

// writeBackupCopy copies the contents of src to dst with permissions perm, replacing
// any previous dst and leaving src untouched.
func writeBackupCopy(src, dst string, perm fs.FileMode) error {
	contents, err := os.ReadFile(src)
	if err != nil {
		return err
	}
	// Drop any stale backup so os.WriteFile creates dst with perm, rather than reusing
	// the old file's mode. A missing backup is the common case, so errors are ignored.
	_ = os.Remove(dst)
	return os.WriteFile(dst, contents, perm)
}
540
Sean McCullough2cba6952025-04-25 20:32:10 +0000541// KeyGenerator represents an interface for generating SSH keys for testability
542type KeyGenerator interface {
543 GeneratePrivateKey(bitSize int) (*rsa.PrivateKey, error)
544 GeneratePublicKey(privateKey *rsa.PublicKey) (ssh.PublicKey, error)
545}
546
547// RealKeyGenerator is the default implementation of KeyGenerator
548type RealKeyGenerator struct{}
549
550func (kg *RealKeyGenerator) GeneratePrivateKey(bitSize int) (*rsa.PrivateKey, error) {
551 return rsa.GenerateKey(rand.Reader, bitSize)
552}
553
554func (kg *RealKeyGenerator) GeneratePublicKey(privateKey *rsa.PublicKey) (ssh.PublicKey, error) {
555 return ssh.NewPublicKey(privateKey)
556}
557
558// CheckForInclude checks if the user's SSH config includes the Sketch SSH config file
559func CheckForInclude() error {
560 return CheckForIncludeWithFS(&RealFileSystem{})
561}