blob: ad612750f1345a2344f7397c859342871b42d8b9 [file] [log] [blame]
Sean McCullough4854c652025-04-24 18:37:02 -07001package dockerimg
2
3import (
4 "bufio"
5 "crypto/rand"
6 "crypto/rsa"
7 "crypto/x509"
Sean McCullough4854c652025-04-24 18:37:02 -07008 "encoding/pem"
9 "fmt"
Sean McCullough2cba6952025-04-25 20:32:10 +000010 "io/fs"
Sean McCullough4854c652025-04-24 18:37:02 -070011 "os"
12 "path/filepath"
13 "strings"
14
15 "github.com/kevinburke/ssh_config"
16 "golang.org/x/crypto/ssh"
Sean McCullough7d5a6302025-04-24 21:27:51 -070017 "golang.org/x/crypto/ssh/knownhosts"
Sean McCullough4854c652025-04-24 18:37:02 -070018)
19
// keyBitSize is the RSA modulus size, in bits, of the server and user key
// pairs that createKeyPairIfMissing generates.
const (
	keyBitSize = 2048
)
21
22// SSHTheater does the necessary key pair generation, known_hosts updates, ssh_config file updates etc steps
23// so that ssh can connect to a locally running sketch container to other local processes like vscode without
24// the user having to run the usual ssh obstacle course.
25//
26// SSHTheater does not modify your default .ssh/config, or known_hosts files. However, in order for you
27// to be able to use it properly you will have to make a one-time edit to your ~/.ssh/config file.
28//
29// In your ~/.ssh/config file, add the following line:
30//
31// Include $HOME/.sketch/ssh_config
32//
33// where $HOME is your home directory.
34type SSHTheater struct {
35 cntrName string
36 sshHost string
37 sshPort string
38
39 knownHostsPath string
40 userIdentityPath string
41 sshConfigPath string
42 serverIdentityPath string
43
44 serverPublicKey ssh.PublicKey
45 serverIdentity []byte
46 userIdentity []byte
Sean McCullough2cba6952025-04-25 20:32:10 +000047
48 fs FileSystem
49 kg KeyGenerator
Sean McCullough4854c652025-04-24 18:37:02 -070050}
51
52// NewSSHTheather will set up everything so that you can use ssh on localhost to connect to
53// the sketch container. Call #Clean when you are done with the container to remove the
54// various entries it created in its known_hosts and ssh_config files. Also note that
55// this will generate key pairs for both the ssh server identity and the user identity, if
56// these files do not already exist. These key pair files are not deleted by #Cleanup,
57// so they can be re-used across invocations of sketch. This means every sketch container
58// that runs on this host will use the same ssh server identity.
59//
60// If this doesn't return an error, you should be able to run "ssh <cntrName>"
61// in a terminal on your host machine to open a shell into the container without having
62// to manually accept changes to your known_hosts file etc.
63func NewSSHTheather(cntrName, sshHost, sshPort string) (*SSHTheater, error) {
Sean McCullough2cba6952025-04-25 20:32:10 +000064 return newSSHTheatherWithDeps(cntrName, sshHost, sshPort, &RealFileSystem{}, &RealKeyGenerator{})
65}
66
67// newSSHTheatherWithDeps creates a new SSHTheater with the specified dependencies
68func newSSHTheatherWithDeps(cntrName, sshHost, sshPort string, fs FileSystem, kg KeyGenerator) (*SSHTheater, error) {
Sean McCullough4854c652025-04-24 18:37:02 -070069 base := filepath.Join(os.Getenv("HOME"), ".sketch")
Sean McCullough2cba6952025-04-25 20:32:10 +000070 if _, err := fs.Stat(base); err != nil {
71 if err := fs.Mkdir(base, 0o777); err != nil {
Sean McCullough7d5a6302025-04-24 21:27:51 -070072 return nil, fmt.Errorf("couldn't create %s: %w", base, err)
73 }
74 }
75
Sean McCullough4854c652025-04-24 18:37:02 -070076 cst := &SSHTheater{
77 cntrName: cntrName,
78 sshHost: sshHost,
79 sshPort: sshPort,
80 knownHostsPath: filepath.Join(base, "known_hosts"),
81 userIdentityPath: filepath.Join(base, "container_user_identity"),
82 serverIdentityPath: filepath.Join(base, "container_server_identity"),
83 sshConfigPath: filepath.Join(base, "ssh_config"),
Sean McCullough2cba6952025-04-25 20:32:10 +000084 fs: fs,
85 kg: kg,
Sean McCullough4854c652025-04-24 18:37:02 -070086 }
Sean McCullough2cba6952025-04-25 20:32:10 +000087 if _, err := cst.createKeyPairIfMissing(cst.serverIdentityPath); err != nil {
Sean McCullough7d5a6302025-04-24 21:27:51 -070088 return nil, fmt.Errorf("couldn't create server identity: %w", err)
89 }
Sean McCullough2cba6952025-04-25 20:32:10 +000090 if _, err := cst.createKeyPairIfMissing(cst.userIdentityPath); err != nil {
Sean McCullough7d5a6302025-04-24 21:27:51 -070091 return nil, fmt.Errorf("couldn't create user identity: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -070092 }
93
Sean McCullough2cba6952025-04-25 20:32:10 +000094 serverIdentity, err := fs.ReadFile(cst.serverIdentityPath)
Sean McCullough4854c652025-04-24 18:37:02 -070095 if err != nil {
96 return nil, fmt.Errorf("couldn't read container's ssh server identity: %w", err)
97 }
98 cst.serverIdentity = serverIdentity
99
Sean McCullough2cba6952025-04-25 20:32:10 +0000100 serverPubKeyBytes, err := fs.ReadFile(cst.serverIdentityPath + ".pub")
101 if err != nil {
102 return nil, fmt.Errorf("couldn't read ssh server public key file: %w", err)
103 }
Sean McCullough4854c652025-04-24 18:37:02 -0700104 serverPubKey, _, _, _, err := ssh.ParseAuthorizedKey(serverPubKeyBytes)
105 if err != nil {
Sean McCullough2cba6952025-04-25 20:32:10 +0000106 return nil, fmt.Errorf("couldn't parse ssh server public key: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700107 }
108 cst.serverPublicKey = serverPubKey
109
Sean McCullough2cba6952025-04-25 20:32:10 +0000110 userIdentity, err := fs.ReadFile(cst.userIdentityPath + ".pub")
Sean McCullough4854c652025-04-24 18:37:02 -0700111 if err != nil {
112 return nil, fmt.Errorf("couldn't read ssh user identity: %w", err)
113 }
114 cst.userIdentity = userIdentity
115
Sean McCullough7d5a6302025-04-24 21:27:51 -0700116 if err := cst.addContainerToSSHConfig(); err != nil {
117 return nil, fmt.Errorf("couldn't add container to ssh_config: %w", err)
118 }
119
120 if err := cst.addContainerToKnownHosts(); err != nil {
121 return nil, fmt.Errorf("couldn't update known hosts: %w", err)
122 }
123
Sean McCullough4854c652025-04-24 18:37:02 -0700124 return cst, nil
125}
126
Sean McCullough2cba6952025-04-25 20:32:10 +0000127func CheckForIncludeWithFS(fs FileSystem) error {
Sean McCulloughf5e28f62025-04-25 10:48:00 -0700128 sketchSSHPathInclude := "Include " + filepath.Join(os.Getenv("HOME"), ".sketch", "ssh_config")
129 defaultSSHPath := filepath.Join(os.Getenv("HOME"), ".ssh", "config")
Sean McCullough2cba6952025-04-25 20:32:10 +0000130 f, _ := fs.OpenFile(filepath.Join(os.Getenv("HOME"), ".ssh", "config"), os.O_RDONLY, 0)
131 if f == nil {
132 return fmt.Errorf("⚠️ SSH connections are disabled. cannot open SSH config file: %s", defaultSSHPath)
133 }
134 defer f.Close()
Sean McCulloughf5e28f62025-04-25 10:48:00 -0700135 cfg, _ := ssh_config.Decode(f)
136 var sketchInludePos *ssh_config.Position
137 var firstNonIncludePos *ssh_config.Position
138 for _, host := range cfg.Hosts {
139 for _, node := range host.Nodes {
140 inc, ok := node.(*ssh_config.Include)
141 if ok {
142 if strings.TrimSpace(inc.String()) == sketchSSHPathInclude {
143 pos := inc.Pos()
144 sketchInludePos = &pos
145 }
146 } else if firstNonIncludePos == nil && !strings.HasPrefix(strings.TrimSpace(node.String()), "#") {
147 pos := node.Pos()
148 firstNonIncludePos = &pos
149 }
150 }
151 }
152
153 if sketchInludePos == nil {
Sean McCullough2cba6952025-04-25 20:32:10 +0000154 return fmt.Errorf("⚠️ SSH connections are disabled. to enable them, add the line %q to the top of %s before any 'Host' lines", sketchSSHPathInclude, defaultSSHPath)
Sean McCulloughf5e28f62025-04-25 10:48:00 -0700155 }
156
157 if firstNonIncludePos != nil && firstNonIncludePos.Line < sketchInludePos.Line {
Sean McCullough2cba6952025-04-25 20:32:10 +0000158 fmt.Printf("⚠️ SSH confg warning: the location of the Include statement for sketch's ssh config on line %d of %s may prevent ssh from working with sketch containers. try moving it to the top of the file (before any 'Host' lines) if ssh isn't working for you.\n", sketchInludePos.Line, defaultSSHPath)
Sean McCulloughf5e28f62025-04-25 10:48:00 -0700159 }
160 return nil
161}
162
Sean McCullough4854c652025-04-24 18:37:02 -0700163func removeFromHosts(cntrName string, cfgHosts []*ssh_config.Host) []*ssh_config.Host {
164 hosts := []*ssh_config.Host{}
165 for _, host := range cfgHosts {
166 if host.Matches(cntrName) || strings.Contains(host.String(), cntrName) {
167 continue
168 }
169 patMatch := false
170 for _, pat := range host.Patterns {
171 if strings.Contains(pat.String(), cntrName) {
172 patMatch = true
173 }
174 }
175 if patMatch {
176 continue
177 }
178
179 hosts = append(hosts, host)
180 }
181 return hosts
182}
183
// encodePrivateKeyToPEM serializes privateKey as PKCS #1 DER and wraps the
// result in a single PEM block of type "RSA PRIVATE KEY".
func encodePrivateKeyToPEM(privateKey *rsa.PrivateKey) []byte {
	der := x509.MarshalPKCS1PrivateKey(privateKey)
	return pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: der})
}
192
Sean McCullough2cba6952025-04-25 20:32:10 +0000193func (c *SSHTheater) writeKeyToFile(keyBytes []byte, filename string) error {
194 err := c.fs.WriteFile(filename, keyBytes, 0o600)
Sean McCullough4854c652025-04-24 18:37:02 -0700195 return err
196}
197
Sean McCullough2cba6952025-04-25 20:32:10 +0000198func (c *SSHTheater) createKeyPairIfMissing(idPath string) (ssh.PublicKey, error) {
199 if _, err := c.fs.Stat(idPath); err == nil {
Sean McCullough4854c652025-04-24 18:37:02 -0700200 return nil, nil
201 }
202
Sean McCullough2cba6952025-04-25 20:32:10 +0000203 privateKey, err := c.kg.GeneratePrivateKey(keyBitSize)
Sean McCullough4854c652025-04-24 18:37:02 -0700204 if err != nil {
Sean McCullough2cba6952025-04-25 20:32:10 +0000205 return nil, fmt.Errorf("error generating private key: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700206 }
207
Sean McCullough2cba6952025-04-25 20:32:10 +0000208 publicRsaKey, err := c.kg.GeneratePublicKey(&privateKey.PublicKey)
Sean McCullough4854c652025-04-24 18:37:02 -0700209 if err != nil {
Sean McCullough2cba6952025-04-25 20:32:10 +0000210 return nil, fmt.Errorf("error generating public key: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700211 }
212
213 privateKeyPEM := encodePrivateKeyToPEM(privateKey)
214
Sean McCullough2cba6952025-04-25 20:32:10 +0000215 err = c.writeKeyToFile(privateKeyPEM, idPath)
Sean McCullough4854c652025-04-24 18:37:02 -0700216 if err != nil {
Sean McCullough2cba6952025-04-25 20:32:10 +0000217 return nil, fmt.Errorf("error writing private key to file %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700218 }
219 pubKeyBytes := ssh.MarshalAuthorizedKey(publicRsaKey)
220
Sean McCullough2cba6952025-04-25 20:32:10 +0000221 err = c.writeKeyToFile([]byte(pubKeyBytes), idPath+".pub")
Sean McCullough4854c652025-04-24 18:37:02 -0700222 if err != nil {
Sean McCullough2cba6952025-04-25 20:32:10 +0000223 return nil, fmt.Errorf("error writing public key to file %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -0700224 }
225 return publicRsaKey, nil
226}
227
228func (c *SSHTheater) addSketchHostMatchIfMissing(cfg *ssh_config.Config) error {
229 found := false
230 for _, host := range cfg.Hosts {
231 if strings.Contains(host.String(), "host=\"sketch-*\"") {
232 found = true
233 break
234 }
235 }
236 if !found {
237 hostPattern, err := ssh_config.NewPattern("host=\"sketch-*\"")
238 if err != nil {
239 return fmt.Errorf("couldn't add pattern to ssh_config: %w", err)
240 }
241
242 hostCfg := &ssh_config.Host{Patterns: []*ssh_config.Pattern{hostPattern}}
243 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "UserKnownHostsFile", Value: c.knownHostsPath})
244
Sean McCullough4854c652025-04-24 18:37:02 -0700245 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "IdentityFile", Value: c.userIdentityPath})
Sean McCullough4854c652025-04-24 18:37:02 -0700246 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.Empty{})
247
248 cfg.Hosts = append([]*ssh_config.Host{hostCfg}, cfg.Hosts...)
249 }
250 return nil
251}
252
253func (c *SSHTheater) addContainerToSSHConfig() error {
Sean McCullough2cba6952025-04-25 20:32:10 +0000254 f, err := c.fs.OpenFile(c.sshConfigPath, os.O_RDWR|os.O_CREATE, 0o644)
Sean McCullough4854c652025-04-24 18:37:02 -0700255 if err != nil {
256 return fmt.Errorf("couldn't open ssh_config: %w", err)
257 }
258 defer f.Close()
259
260 cfg, err := ssh_config.Decode(f)
261 if err != nil {
262 return fmt.Errorf("couldn't decode ssh_config: %w", err)
263 }
264 cntrPattern, err := ssh_config.NewPattern(c.cntrName)
265 if err != nil {
266 return fmt.Errorf("couldn't add pattern to ssh_config: %w", err)
267 }
268
269 // Remove any matches for this container if they already exist.
270 cfg.Hosts = removeFromHosts(c.cntrName, cfg.Hosts)
271
272 hostCfg := &ssh_config.Host{Patterns: []*ssh_config.Pattern{cntrPattern}}
273 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "HostName", Value: c.sshHost})
274 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "User", Value: "root"})
275 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "Port", Value: c.sshPort})
276 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "IdentityFile", Value: c.userIdentityPath})
277 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "UserKnownHostsFile", Value: c.knownHostsPath})
278
279 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.Empty{})
280 cfg.Hosts = append(cfg.Hosts, hostCfg)
281
282 if err := c.addSketchHostMatchIfMissing(cfg); err != nil {
283 return fmt.Errorf("couldn't add missing host match: %w", err)
284 }
285
286 cfgBytes, err := cfg.MarshalText()
287 if err != nil {
288 return fmt.Errorf("couldn't marshal ssh_config: %w", err)
289 }
290 if err := f.Truncate(0); err != nil {
291 return fmt.Errorf("couldn't truncate ssh_config: %w", err)
292 }
293 if _, err := f.Seek(0, 0); err != nil {
294 return fmt.Errorf("couldn't seek to beginning of ssh_config: %w", err)
295 }
296 if _, err := f.Write(cfgBytes); err != nil {
297 return fmt.Errorf("couldn't write ssh_config: %w", err)
298 }
299
300 return nil
301}
302
303func (c *SSHTheater) addContainerToKnownHosts() error {
Sean McCullough2cba6952025-04-25 20:32:10 +0000304 f, err := c.fs.OpenFile(c.knownHostsPath, os.O_RDWR|os.O_CREATE, 0o644)
Sean McCullough4854c652025-04-24 18:37:02 -0700305 if err != nil {
306 return fmt.Errorf("couldn't open %s: %w", c.knownHostsPath, err)
307 }
308 defer f.Close()
Sean McCullough7d5a6302025-04-24 21:27:51 -0700309 pkBytes := c.serverPublicKey.Marshal()
310 if len(pkBytes) == 0 {
Sean McCullough2cba6952025-04-25 20:32:10 +0000311 return fmt.Errorf("empty serverPublicKey, this is a bug")
Sean McCullough7d5a6302025-04-24 21:27:51 -0700312 }
313 newHostLine := knownhosts.Line([]string{c.sshHost + ":" + c.sshPort}, c.serverPublicKey)
Sean McCullough4854c652025-04-24 18:37:02 -0700314
Sean McCullough7d5a6302025-04-24 21:27:51 -0700315 outputLines := []string{}
316 scanner := bufio.NewScanner(f)
317 for scanner.Scan() {
318 outputLines = append(outputLines, scanner.Text())
319 }
320 outputLines = append(outputLines, newHostLine)
321 if err := f.Truncate(0); err != nil {
322 return fmt.Errorf("couldn't truncate known_hosts: %w", err)
323 }
324 if _, err := f.Seek(0, 0); err != nil {
325 return fmt.Errorf("couldn't seek to beginning of known_hosts: %w", err)
326 }
327 if _, err := f.Write([]byte(strings.Join(outputLines, "\n"))); err != nil {
328 return fmt.Errorf("couldn't write updated known_hosts to to %s: %w", c.knownHostsPath, err)
Sean McCullough4854c652025-04-24 18:37:02 -0700329 }
330
331 return nil
332}
333
334func (c *SSHTheater) removeContainerFromKnownHosts() error {
Sean McCullough2cba6952025-04-25 20:32:10 +0000335 f, err := c.fs.OpenFile(c.knownHostsPath, os.O_RDWR|os.O_CREATE, 0o644)
Sean McCullough4854c652025-04-24 18:37:02 -0700336 if err != nil {
337 return fmt.Errorf("couldn't open ssh_config: %w", err)
338 }
339 defer f.Close()
340 scanner := bufio.NewScanner(f)
Sean McCullough7d5a6302025-04-24 21:27:51 -0700341 lineToRemove := knownhosts.Line([]string{c.sshHost + ":" + c.sshPort}, c.serverPublicKey)
Sean McCullough4854c652025-04-24 18:37:02 -0700342 outputLines := []string{}
343 for scanner.Scan() {
344 if scanner.Text() == lineToRemove {
345 continue
346 }
347 outputLines = append(outputLines, scanner.Text())
348 }
349 if err := f.Truncate(0); err != nil {
350 return fmt.Errorf("couldn't truncate known_hosts: %w", err)
351 }
352 if _, err := f.Seek(0, 0); err != nil {
353 return fmt.Errorf("couldn't seek to beginning of known_hosts: %w", err)
354 }
355 if _, err := f.Write([]byte(strings.Join(outputLines, "\n"))); err != nil {
356 return fmt.Errorf("couldn't write updated known_hosts to to %s: %w", c.knownHostsPath, err)
357 }
358
359 return nil
360}
361
362func (c *SSHTheater) Cleanup() error {
363 if err := c.removeContainerFromSSHConfig(); err != nil {
364 return fmt.Errorf("couldn't remove container from ssh_config: %v\n", err)
365 }
366 if err := c.removeContainerFromKnownHosts(); err != nil {
367 return fmt.Errorf("couldn't remove container from ssh_config: %v\n", err)
368 }
369
370 return nil
371}
372
373func (c *SSHTheater) removeContainerFromSSHConfig() error {
Sean McCullough2cba6952025-04-25 20:32:10 +0000374 f, err := c.fs.OpenFile(c.sshConfigPath, os.O_RDWR|os.O_CREATE, 0o644)
Sean McCullough4854c652025-04-24 18:37:02 -0700375 if err != nil {
376 return fmt.Errorf("couldn't open ssh_config: %w", err)
377 }
378 defer f.Close()
379
380 cfg, err := ssh_config.Decode(f)
381 if err != nil {
382 return fmt.Errorf("couldn't decode ssh_config: %w", err)
383 }
384 cfg.Hosts = removeFromHosts(c.cntrName, cfg.Hosts)
385
386 if err := c.addSketchHostMatchIfMissing(cfg); err != nil {
387 return fmt.Errorf("couldn't add missing host match: %w", err)
388 }
389
390 cfgBytes, err := cfg.MarshalText()
391 if err != nil {
392 return fmt.Errorf("couldn't marshal ssh_config: %w", err)
393 }
394 if err := f.Truncate(0); err != nil {
395 return fmt.Errorf("couldn't truncate ssh_config: %w", err)
396 }
397 if _, err := f.Seek(0, 0); err != nil {
398 return fmt.Errorf("couldn't seek to beginning of ssh_config: %w", err)
399 }
400 if _, err := f.Write(cfgBytes); err != nil {
401 return fmt.Errorf("couldn't write ssh_config: %w", err)
402 }
403 return nil
404}
Sean McCullough2cba6952025-04-25 20:32:10 +0000405
// FileSystem is the set of file operations SSHTheater performs. It is
// abstracted so that tests can substitute a fake. The methods mirror the os
// functions of the same name; RealFileSystem delegates to them directly.
type FileSystem interface {
	// Metadata and directory creation.
	Stat(name string) (fs.FileInfo, error)
	Mkdir(name string, perm fs.FileMode) error

	// Whole-file reads and writes.
	ReadFile(name string) ([]byte, error)
	WriteFile(name string, data []byte, perm fs.FileMode) error

	// Handle-based access, used for in-place rewrites of config files.
	OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error)
}
414
// RealFileSystem is the default FileSystem. Each method forwards directly to
// the os function of the same name. It holds no state.
type RealFileSystem struct{}

// Stat forwards to os.Stat.
func (*RealFileSystem) Stat(name string) (fs.FileInfo, error) {
	return os.Stat(name)
}

// Mkdir forwards to os.Mkdir.
func (*RealFileSystem) Mkdir(name string, perm fs.FileMode) error {
	return os.Mkdir(name, perm)
}

// ReadFile forwards to os.ReadFile.
func (*RealFileSystem) ReadFile(name string) ([]byte, error) {
	return os.ReadFile(name)
}

// WriteFile forwards to os.WriteFile.
func (*RealFileSystem) WriteFile(name string, data []byte, perm fs.FileMode) error {
	return os.WriteFile(name, data, perm)
}

// OpenFile forwards to os.OpenFile.
func (*RealFileSystem) OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
	return os.OpenFile(name, flag, perm)
}
437
438// KeyGenerator represents an interface for generating SSH keys for testability
439type KeyGenerator interface {
440 GeneratePrivateKey(bitSize int) (*rsa.PrivateKey, error)
441 GeneratePublicKey(privateKey *rsa.PublicKey) (ssh.PublicKey, error)
442}
443
444// RealKeyGenerator is the default implementation of KeyGenerator
445type RealKeyGenerator struct{}
446
447func (kg *RealKeyGenerator) GeneratePrivateKey(bitSize int) (*rsa.PrivateKey, error) {
448 return rsa.GenerateKey(rand.Reader, bitSize)
449}
450
451func (kg *RealKeyGenerator) GeneratePublicKey(privateKey *rsa.PublicKey) (ssh.PublicKey, error) {
452 return ssh.NewPublicKey(privateKey)
453}
454
455// CheckForInclude checks if the user's SSH config includes the Sketch SSH config file
456func CheckForInclude() error {
457 return CheckForIncludeWithFS(&RealFileSystem{})
458}