blob: e54124c492301463e546ff2d8e027c304de35206 [file] [log] [blame]
Sean McCullough4854c652025-04-24 18:37:02 -07001package dockerimg
2
3import (
4 "bufio"
5 "crypto/rand"
6 "crypto/rsa"
7 "crypto/x509"
Sean McCullough4854c652025-04-24 18:37:02 -07008 "encoding/pem"
9 "fmt"
10 "os"
11 "path/filepath"
12 "strings"
13
14 "github.com/kevinburke/ssh_config"
15 "golang.org/x/crypto/ssh"
Sean McCullough7d5a6302025-04-24 21:27:51 -070016 "golang.org/x/crypto/ssh/knownhosts"
Sean McCullough4854c652025-04-24 18:37:02 -070017)
18
// keyBitSize is the RSA modulus length, in bits, of the key pairs generated
// for the container's ssh server identity and the user identity.
const keyBitSize = 2048
20
21// SSHTheater does the necessary key pair generation, known_hosts updates, ssh_config file updates etc steps
22// so that ssh can connect to a locally running sketch container to other local processes like vscode without
23// the user having to run the usual ssh obstacle course.
24//
25// SSHTheater does not modify your default .ssh/config, or known_hosts files. However, in order for you
26// to be able to use it properly you will have to make a one-time edit to your ~/.ssh/config file.
27//
28// In your ~/.ssh/config file, add the following line:
29//
30// Include $HOME/.sketch/ssh_config
31//
32// where $HOME is your home directory.
33type SSHTheater struct {
34 cntrName string
35 sshHost string
36 sshPort string
37
38 knownHostsPath string
39 userIdentityPath string
40 sshConfigPath string
41 serverIdentityPath string
42
43 serverPublicKey ssh.PublicKey
44 serverIdentity []byte
45 userIdentity []byte
46}
47
48// NewSSHTheather will set up everything so that you can use ssh on localhost to connect to
49// the sketch container. Call #Clean when you are done with the container to remove the
50// various entries it created in its known_hosts and ssh_config files. Also note that
51// this will generate key pairs for both the ssh server identity and the user identity, if
52// these files do not already exist. These key pair files are not deleted by #Cleanup,
53// so they can be re-used across invocations of sketch. This means every sketch container
54// that runs on this host will use the same ssh server identity.
55//
56// If this doesn't return an error, you should be able to run "ssh <cntrName>"
57// in a terminal on your host machine to open a shell into the container without having
58// to manually accept changes to your known_hosts file etc.
59func NewSSHTheather(cntrName, sshHost, sshPort string) (*SSHTheater, error) {
60 base := filepath.Join(os.Getenv("HOME"), ".sketch")
Sean McCullough7d5a6302025-04-24 21:27:51 -070061 if _, err := os.Stat(base); err != nil {
62 if err := os.Mkdir(base, 0o777); err != nil {
63 return nil, fmt.Errorf("couldn't create %s: %w", base, err)
64 }
65 }
66
Sean McCullough4854c652025-04-24 18:37:02 -070067 cst := &SSHTheater{
68 cntrName: cntrName,
69 sshHost: sshHost,
70 sshPort: sshPort,
71 knownHostsPath: filepath.Join(base, "known_hosts"),
72 userIdentityPath: filepath.Join(base, "container_user_identity"),
73 serverIdentityPath: filepath.Join(base, "container_server_identity"),
74 sshConfigPath: filepath.Join(base, "ssh_config"),
75 }
Sean McCullough7d5a6302025-04-24 21:27:51 -070076 if _, err := createKeyPairIfMissing(cst.serverIdentityPath); err != nil {
77 return nil, fmt.Errorf("couldn't create server identity: %w", err)
78 }
79 if _, err := createKeyPairIfMissing(cst.userIdentityPath); err != nil {
80 return nil, fmt.Errorf("couldn't create user identity: %w", err)
Sean McCullough4854c652025-04-24 18:37:02 -070081 }
82
83 serverIdentity, err := os.ReadFile(cst.serverIdentityPath)
84 if err != nil {
85 return nil, fmt.Errorf("couldn't read container's ssh server identity: %w", err)
86 }
87 cst.serverIdentity = serverIdentity
88
89 serverPubKeyBytes, err := os.ReadFile(cst.serverIdentityPath + ".pub")
90 serverPubKey, _, _, _, err := ssh.ParseAuthorizedKey(serverPubKeyBytes)
91 if err != nil {
92 return nil, fmt.Errorf("couldn't read ssh server public key: %w", err)
93 }
94 cst.serverPublicKey = serverPubKey
95
96 userIdentity, err := os.ReadFile(cst.userIdentityPath + ".pub")
97 if err != nil {
98 return nil, fmt.Errorf("couldn't read ssh user identity: %w", err)
99 }
100 cst.userIdentity = userIdentity
101
Sean McCullough7d5a6302025-04-24 21:27:51 -0700102 if err := cst.addContainerToSSHConfig(); err != nil {
103 return nil, fmt.Errorf("couldn't add container to ssh_config: %w", err)
104 }
105
106 if err := cst.addContainerToKnownHosts(); err != nil {
107 return nil, fmt.Errorf("couldn't update known hosts: %w", err)
108 }
109
Sean McCullough4854c652025-04-24 18:37:02 -0700110 return cst, nil
111}
112
Sean McCulloughf5e28f62025-04-25 10:48:00 -0700113func CheckForInclude() error {
114 sketchSSHPathInclude := "Include " + filepath.Join(os.Getenv("HOME"), ".sketch", "ssh_config")
115 defaultSSHPath := filepath.Join(os.Getenv("HOME"), ".ssh", "config")
116 f, _ := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "config"))
117 cfg, _ := ssh_config.Decode(f)
118 var sketchInludePos *ssh_config.Position
119 var firstNonIncludePos *ssh_config.Position
120 for _, host := range cfg.Hosts {
121 for _, node := range host.Nodes {
122 inc, ok := node.(*ssh_config.Include)
123 if ok {
124 if strings.TrimSpace(inc.String()) == sketchSSHPathInclude {
125 pos := inc.Pos()
126 sketchInludePos = &pos
127 }
128 } else if firstNonIncludePos == nil && !strings.HasPrefix(strings.TrimSpace(node.String()), "#") {
129 pos := node.Pos()
130 firstNonIncludePos = &pos
131 }
132 }
133 }
134
135 if sketchInludePos == nil {
Philip Zeyliger6f2bf8a2025-04-25 15:29:46 -0700136 return fmt.Errorf("⚠️ SSH connections are disabled. To enable them, add the line %q to the top of %s before any 'Host' lines", sketchSSHPathInclude, defaultSSHPath)
Sean McCulloughf5e28f62025-04-25 10:48:00 -0700137 }
138
139 if firstNonIncludePos != nil && firstNonIncludePos.Line < sketchInludePos.Line {
Philip Zeyliger6f2bf8a2025-04-25 15:29:46 -0700140 fmt.Printf("⚠️ SSH confg warning: The location of the Include statement for sketch's ssh config on line %d of %s may prevent ssh from working with sketch containers. Try moving it to the top of the file (before any 'Host' lines) if ssh isn't working for you.\n", sketchInludePos.Line, defaultSSHPath)
Sean McCulloughf5e28f62025-04-25 10:48:00 -0700141 }
142 return nil
143}
144
Sean McCullough4854c652025-04-24 18:37:02 -0700145func removeFromHosts(cntrName string, cfgHosts []*ssh_config.Host) []*ssh_config.Host {
146 hosts := []*ssh_config.Host{}
147 for _, host := range cfgHosts {
148 if host.Matches(cntrName) || strings.Contains(host.String(), cntrName) {
149 continue
150 }
151 patMatch := false
152 for _, pat := range host.Patterns {
153 if strings.Contains(pat.String(), cntrName) {
154 patMatch = true
155 }
156 }
157 if patMatch {
158 continue
159 }
160
161 hosts = append(hosts, host)
162 }
163 return hosts
164}
165
// generatePrivateKey creates a new RSA private key whose modulus is bitSize
// bits long, using crypto/rand as the entropy source. On failure it returns a
// nil key and the error from rsa.GenerateKey.
func generatePrivateKey(bitSize int) (*rsa.PrivateKey, error) {
	return rsa.GenerateKey(rand.Reader, bitSize)
}
173
174// generatePublicKey take a rsa.PublicKey and return bytes suitable for writing to .pub file
175// returns in the format "ssh-rsa ..."
176func generatePublicKey(privatekey *rsa.PublicKey) (ssh.PublicKey, error) {
177 publicRsaKey, err := ssh.NewPublicKey(privatekey)
178 if err != nil {
179 return nil, err
180 }
181
182 return publicRsaKey, nil
183}
184
// encodePrivateKeyToPEM serializes privateKey as PKCS #1 DER and wraps the
// result in a single PEM block of type "RSA PRIVATE KEY", which is what gets
// written to the identity file.
func encodePrivateKeyToPEM(privateKey *rsa.PrivateKey) []byte {
	return pem.EncodeToMemory(&pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
	})
}
193
// writeKeyToFile stores keyBytes at filename. A new file is created with
// permission bits 0600 (before umask). An existing file is truncated and keeps
// its current permissions, per os.WriteFile.
func writeKeyToFile(keyBytes []byte, filename string) error {
	return os.WriteFile(filename, keyBytes, 0o600)
}
198
199func createKeyPairIfMissing(idPath string) (ssh.PublicKey, error) {
200 if _, err := os.Stat(idPath); err == nil {
201 return nil, nil
202 }
203
204 privateKey, err := generatePrivateKey(keyBitSize)
205 if err != nil {
206 return nil, fmt.Errorf("Error generating private key: %w", err)
207 }
208
209 publicRsaKey, err := generatePublicKey(&privateKey.PublicKey)
210 if err != nil {
211 return nil, fmt.Errorf("Error generating public key: %w", err)
212 }
213
214 privateKeyPEM := encodePrivateKeyToPEM(privateKey)
215
216 err = writeKeyToFile(privateKeyPEM, idPath)
217 if err != nil {
218 return nil, fmt.Errorf("Error writing private key to file %w", err)
219 }
220 pubKeyBytes := ssh.MarshalAuthorizedKey(publicRsaKey)
221
222 err = writeKeyToFile([]byte(pubKeyBytes), idPath+".pub")
223 if err != nil {
224 return nil, fmt.Errorf("Error writing public key to file %w", err)
225 }
226 return publicRsaKey, nil
227}
228
229func (c *SSHTheater) addSketchHostMatchIfMissing(cfg *ssh_config.Config) error {
230 found := false
231 for _, host := range cfg.Hosts {
232 if strings.Contains(host.String(), "host=\"sketch-*\"") {
233 found = true
234 break
235 }
236 }
237 if !found {
238 hostPattern, err := ssh_config.NewPattern("host=\"sketch-*\"")
239 if err != nil {
240 return fmt.Errorf("couldn't add pattern to ssh_config: %w", err)
241 }
242
243 hostCfg := &ssh_config.Host{Patterns: []*ssh_config.Pattern{hostPattern}}
244 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "UserKnownHostsFile", Value: c.knownHostsPath})
245
Sean McCullough4854c652025-04-24 18:37:02 -0700246 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "IdentityFile", Value: c.userIdentityPath})
Sean McCullough4854c652025-04-24 18:37:02 -0700247 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.Empty{})
248
249 cfg.Hosts = append([]*ssh_config.Host{hostCfg}, cfg.Hosts...)
250 }
251 return nil
252}
253
254func (c *SSHTheater) addContainerToSSHConfig() error {
Sean McCullough4854c652025-04-24 18:37:02 -0700255 f, err := os.OpenFile(c.sshConfigPath, os.O_RDWR|os.O_CREATE, 0o644)
256 if err != nil {
257 return fmt.Errorf("couldn't open ssh_config: %w", err)
258 }
259 defer f.Close()
260
261 cfg, err := ssh_config.Decode(f)
262 if err != nil {
263 return fmt.Errorf("couldn't decode ssh_config: %w", err)
264 }
265 cntrPattern, err := ssh_config.NewPattern(c.cntrName)
266 if err != nil {
267 return fmt.Errorf("couldn't add pattern to ssh_config: %w", err)
268 }
269
270 // Remove any matches for this container if they already exist.
271 cfg.Hosts = removeFromHosts(c.cntrName, cfg.Hosts)
272
273 hostCfg := &ssh_config.Host{Patterns: []*ssh_config.Pattern{cntrPattern}}
274 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "HostName", Value: c.sshHost})
275 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "User", Value: "root"})
276 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "Port", Value: c.sshPort})
277 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "IdentityFile", Value: c.userIdentityPath})
278 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "UserKnownHostsFile", Value: c.knownHostsPath})
279
280 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.Empty{})
281 cfg.Hosts = append(cfg.Hosts, hostCfg)
282
283 if err := c.addSketchHostMatchIfMissing(cfg); err != nil {
284 return fmt.Errorf("couldn't add missing host match: %w", err)
285 }
286
287 cfgBytes, err := cfg.MarshalText()
288 if err != nil {
289 return fmt.Errorf("couldn't marshal ssh_config: %w", err)
290 }
291 if err := f.Truncate(0); err != nil {
292 return fmt.Errorf("couldn't truncate ssh_config: %w", err)
293 }
294 if _, err := f.Seek(0, 0); err != nil {
295 return fmt.Errorf("couldn't seek to beginning of ssh_config: %w", err)
296 }
297 if _, err := f.Write(cfgBytes); err != nil {
298 return fmt.Errorf("couldn't write ssh_config: %w", err)
299 }
300
301 return nil
302}
303
304func (c *SSHTheater) addContainerToKnownHosts() error {
305 f, err := os.OpenFile(c.knownHostsPath, os.O_RDWR|os.O_CREATE, 0o644)
306 if err != nil {
307 return fmt.Errorf("couldn't open %s: %w", c.knownHostsPath, err)
308 }
309 defer f.Close()
Sean McCullough7d5a6302025-04-24 21:27:51 -0700310 pkBytes := c.serverPublicKey.Marshal()
311 if len(pkBytes) == 0 {
312 return fmt.Errorf("empty serverPublicKey. This is a bug")
313 }
314 newHostLine := knownhosts.Line([]string{c.sshHost + ":" + c.sshPort}, c.serverPublicKey)
Sean McCullough4854c652025-04-24 18:37:02 -0700315
Sean McCullough7d5a6302025-04-24 21:27:51 -0700316 outputLines := []string{}
317 scanner := bufio.NewScanner(f)
318 for scanner.Scan() {
319 outputLines = append(outputLines, scanner.Text())
320 }
321 outputLines = append(outputLines, newHostLine)
322 if err := f.Truncate(0); err != nil {
323 return fmt.Errorf("couldn't truncate known_hosts: %w", err)
324 }
325 if _, err := f.Seek(0, 0); err != nil {
326 return fmt.Errorf("couldn't seek to beginning of known_hosts: %w", err)
327 }
328 if _, err := f.Write([]byte(strings.Join(outputLines, "\n"))); err != nil {
329 return fmt.Errorf("couldn't write updated known_hosts to to %s: %w", c.knownHostsPath, err)
Sean McCullough4854c652025-04-24 18:37:02 -0700330 }
331
332 return nil
333}
334
335func (c *SSHTheater) removeContainerFromKnownHosts() error {
336 f, err := os.OpenFile(c.knownHostsPath, os.O_RDWR|os.O_CREATE, 0o644)
337 if err != nil {
338 return fmt.Errorf("couldn't open ssh_config: %w", err)
339 }
340 defer f.Close()
341 scanner := bufio.NewScanner(f)
Sean McCullough7d5a6302025-04-24 21:27:51 -0700342 lineToRemove := knownhosts.Line([]string{c.sshHost + ":" + c.sshPort}, c.serverPublicKey)
Sean McCullough4854c652025-04-24 18:37:02 -0700343 outputLines := []string{}
344 for scanner.Scan() {
345 if scanner.Text() == lineToRemove {
346 continue
347 }
348 outputLines = append(outputLines, scanner.Text())
349 }
350 if err := f.Truncate(0); err != nil {
351 return fmt.Errorf("couldn't truncate known_hosts: %w", err)
352 }
353 if _, err := f.Seek(0, 0); err != nil {
354 return fmt.Errorf("couldn't seek to beginning of known_hosts: %w", err)
355 }
356 if _, err := f.Write([]byte(strings.Join(outputLines, "\n"))); err != nil {
357 return fmt.Errorf("couldn't write updated known_hosts to to %s: %w", c.knownHostsPath, err)
358 }
359
360 return nil
361}
362
363func (c *SSHTheater) Cleanup() error {
364 if err := c.removeContainerFromSSHConfig(); err != nil {
365 return fmt.Errorf("couldn't remove container from ssh_config: %v\n", err)
366 }
367 if err := c.removeContainerFromKnownHosts(); err != nil {
368 return fmt.Errorf("couldn't remove container from ssh_config: %v\n", err)
369 }
370
371 return nil
372}
373
374func (c *SSHTheater) removeContainerFromSSHConfig() error {
375 f, err := os.OpenFile(c.sshConfigPath, os.O_RDWR|os.O_CREATE, 0o644)
376 if err != nil {
377 return fmt.Errorf("couldn't open ssh_config: %w", err)
378 }
379 defer f.Close()
380
381 cfg, err := ssh_config.Decode(f)
382 if err != nil {
383 return fmt.Errorf("couldn't decode ssh_config: %w", err)
384 }
385 cfg.Hosts = removeFromHosts(c.cntrName, cfg.Hosts)
386
387 if err := c.addSketchHostMatchIfMissing(cfg); err != nil {
388 return fmt.Errorf("couldn't add missing host match: %w", err)
389 }
390
391 cfgBytes, err := cfg.MarshalText()
392 if err != nil {
393 return fmt.Errorf("couldn't marshal ssh_config: %w", err)
394 }
395 if err := f.Truncate(0); err != nil {
396 return fmt.Errorf("couldn't truncate ssh_config: %w", err)
397 }
398 if _, err := f.Seek(0, 0); err != nil {
399 return fmt.Errorf("couldn't seek to beginning of ssh_config: %w", err)
400 }
401 if _, err := f.Write(cfgBytes); err != nil {
402 return fmt.Errorf("couldn't write ssh_config: %w", err)
403 }
404 return nil
405}