blob: 31d1849eae23d50e59c39e36480bfaf25ab7368f [file] [log] [blame]
Sean McCullough4854c652025-04-24 18:37:02 -07001package dockerimg
2
3import (
4 "bufio"
5 "crypto/rand"
6 "crypto/rsa"
7 "crypto/x509"
8 "encoding/base64"
9 "encoding/pem"
10 "fmt"
11 "os"
12 "path/filepath"
13 "strings"
14
15 "github.com/kevinburke/ssh_config"
16 "golang.org/x/crypto/ssh"
17)
18
// keyBitSize is the RSA key size, in bits, used when generating key pairs.
const keyBitSize = 2048
20
21// SSHTheater does the necessary key pair generation, known_hosts updates, ssh_config file updates etc steps
22// so that ssh can connect to a locally running sketch container to other local processes like vscode without
23// the user having to run the usual ssh obstacle course.
24//
25// SSHTheater does not modify your default .ssh/config, or known_hosts files. However, in order for you
26// to be able to use it properly you will have to make a one-time edit to your ~/.ssh/config file.
27//
28// In your ~/.ssh/config file, add the following line:
29//
30// Include $HOME/.sketch/ssh_config
31//
32// where $HOME is your home directory.
33type SSHTheater struct {
34 cntrName string
35 sshHost string
36 sshPort string
37
38 knownHostsPath string
39 userIdentityPath string
40 sshConfigPath string
41 serverIdentityPath string
42
43 serverPublicKey ssh.PublicKey
44 serverIdentity []byte
45 userIdentity []byte
46}
47
48// NewSSHTheather will set up everything so that you can use ssh on localhost to connect to
49// the sketch container. Call #Clean when you are done with the container to remove the
50// various entries it created in its known_hosts and ssh_config files. Also note that
51// this will generate key pairs for both the ssh server identity and the user identity, if
52// these files do not already exist. These key pair files are not deleted by #Cleanup,
53// so they can be re-used across invocations of sketch. This means every sketch container
54// that runs on this host will use the same ssh server identity.
55//
56// If this doesn't return an error, you should be able to run "ssh <cntrName>"
57// in a terminal on your host machine to open a shell into the container without having
58// to manually accept changes to your known_hosts file etc.
59func NewSSHTheather(cntrName, sshHost, sshPort string) (*SSHTheater, error) {
60 base := filepath.Join(os.Getenv("HOME"), ".sketch")
61 cst := &SSHTheater{
62 cntrName: cntrName,
63 sshHost: sshHost,
64 sshPort: sshPort,
65 knownHostsPath: filepath.Join(base, "known_hosts"),
66 userIdentityPath: filepath.Join(base, "container_user_identity"),
67 serverIdentityPath: filepath.Join(base, "container_server_identity"),
68 sshConfigPath: filepath.Join(base, "ssh_config"),
69 }
70 if err := cst.addContainerToSSHConfig(); err != nil {
71 return nil, fmt.Errorf("couldn't add container to ssh_config: %w", err)
72 }
73
74 serverIdentity, err := os.ReadFile(cst.serverIdentityPath)
75 if err != nil {
76 return nil, fmt.Errorf("couldn't read container's ssh server identity: %w", err)
77 }
78 cst.serverIdentity = serverIdentity
79
80 serverPubKeyBytes, err := os.ReadFile(cst.serverIdentityPath + ".pub")
81 serverPubKey, _, _, _, err := ssh.ParseAuthorizedKey(serverPubKeyBytes)
82 if err != nil {
83 return nil, fmt.Errorf("couldn't read ssh server public key: %w", err)
84 }
85 cst.serverPublicKey = serverPubKey
86
87 userIdentity, err := os.ReadFile(cst.userIdentityPath + ".pub")
88 if err != nil {
89 return nil, fmt.Errorf("couldn't read ssh user identity: %w", err)
90 }
91 cst.userIdentity = userIdentity
92
93 return cst, nil
94}
95
96func removeFromHosts(cntrName string, cfgHosts []*ssh_config.Host) []*ssh_config.Host {
97 hosts := []*ssh_config.Host{}
98 for _, host := range cfgHosts {
99 if host.Matches(cntrName) || strings.Contains(host.String(), cntrName) {
100 continue
101 }
102 patMatch := false
103 for _, pat := range host.Patterns {
104 if strings.Contains(pat.String(), cntrName) {
105 patMatch = true
106 }
107 }
108 if patMatch {
109 continue
110 }
111
112 hosts = append(hosts, host)
113 }
114 return hosts
115}
116
// generatePrivateKey creates a new RSA private key of bitSize bits using
// crypto/rand as the entropy source. On failure it returns a nil key together
// with the error from rsa.GenerateKey.
func generatePrivateKey(bitSize int) (*rsa.PrivateKey, error) {
	key, err := rsa.GenerateKey(rand.Reader, bitSize)
	if err == nil {
		return key, nil
	}
	return nil, err
}
124
125// generatePublicKey take a rsa.PublicKey and return bytes suitable for writing to .pub file
126// returns in the format "ssh-rsa ..."
127func generatePublicKey(privatekey *rsa.PublicKey) (ssh.PublicKey, error) {
128 publicRsaKey, err := ssh.NewPublicKey(privatekey)
129 if err != nil {
130 return nil, err
131 }
132
133 return publicRsaKey, nil
134}
135
// encodePrivateKeyToPEM serializes privateKey in PKCS #1 form and returns it
// wrapped in a PEM block of type "RSA PRIVATE KEY".
func encodePrivateKeyToPEM(privateKey *rsa.PrivateKey) []byte {
	der := x509.MarshalPKCS1PrivateKey(privateKey)
	return pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: der})
}
144
// writeKeyToFile writes keyBytes to filename, passing mode 0o600 to
// os.WriteFile so a newly created file is readable and writable by its owner
// only.
func writeKeyToFile(keyBytes []byte, filename string) error {
	return os.WriteFile(filename, keyBytes, 0o600)
}
149
150func createKeyPairIfMissing(idPath string) (ssh.PublicKey, error) {
151 if _, err := os.Stat(idPath); err == nil {
152 return nil, nil
153 }
154
155 privateKey, err := generatePrivateKey(keyBitSize)
156 if err != nil {
157 return nil, fmt.Errorf("Error generating private key: %w", err)
158 }
159
160 publicRsaKey, err := generatePublicKey(&privateKey.PublicKey)
161 if err != nil {
162 return nil, fmt.Errorf("Error generating public key: %w", err)
163 }
164
165 privateKeyPEM := encodePrivateKeyToPEM(privateKey)
166
167 err = writeKeyToFile(privateKeyPEM, idPath)
168 if err != nil {
169 return nil, fmt.Errorf("Error writing private key to file %w", err)
170 }
171 pubKeyBytes := ssh.MarshalAuthorizedKey(publicRsaKey)
172
173 err = writeKeyToFile([]byte(pubKeyBytes), idPath+".pub")
174 if err != nil {
175 return nil, fmt.Errorf("Error writing public key to file %w", err)
176 }
177 return publicRsaKey, nil
178}
179
180func (c *SSHTheater) addSketchHostMatchIfMissing(cfg *ssh_config.Config) error {
181 found := false
182 for _, host := range cfg.Hosts {
183 if strings.Contains(host.String(), "host=\"sketch-*\"") {
184 found = true
185 break
186 }
187 }
188 if !found {
189 hostPattern, err := ssh_config.NewPattern("host=\"sketch-*\"")
190 if err != nil {
191 return fmt.Errorf("couldn't add pattern to ssh_config: %w", err)
192 }
193
194 hostCfg := &ssh_config.Host{Patterns: []*ssh_config.Pattern{hostPattern}}
195 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "UserKnownHostsFile", Value: c.knownHostsPath})
196
197 if _, err := createKeyPairIfMissing(c.userIdentityPath); err != nil {
198 return fmt.Errorf("couldn't create user identity: %w", err)
199 }
200 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "IdentityFile", Value: c.userIdentityPath})
201
202 serverPubKey, err := createKeyPairIfMissing(c.serverIdentityPath)
203 if err != nil {
204 return fmt.Errorf("couldn't create server identity: %w", err)
205 }
206 if serverPubKey != nil {
207 c.serverPublicKey = serverPubKey
208 if err := c.addContainerToKnownHosts(); err != nil {
209 return fmt.Errorf("couldn't update known hosts: %w", err)
210 }
211 }
212 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.Empty{})
213
214 cfg.Hosts = append([]*ssh_config.Host{hostCfg}, cfg.Hosts...)
215 }
216 return nil
217}
218
219func (c *SSHTheater) addContainerToSSHConfig() error {
220 dotSketchPath := filepath.Join(os.Getenv("HOME"), ".sketch")
221 if _, err := os.Stat(dotSketchPath); err != nil {
222 if err := os.Mkdir(dotSketchPath, 0o777); err != nil {
223 return fmt.Errorf("couldn't create %s: %w", dotSketchPath, err)
224 }
225 }
226
227 f, err := os.OpenFile(c.sshConfigPath, os.O_RDWR|os.O_CREATE, 0o644)
228 if err != nil {
229 return fmt.Errorf("couldn't open ssh_config: %w", err)
230 }
231 defer f.Close()
232
233 cfg, err := ssh_config.Decode(f)
234 if err != nil {
235 return fmt.Errorf("couldn't decode ssh_config: %w", err)
236 }
237 cntrPattern, err := ssh_config.NewPattern(c.cntrName)
238 if err != nil {
239 return fmt.Errorf("couldn't add pattern to ssh_config: %w", err)
240 }
241
242 // Remove any matches for this container if they already exist.
243 cfg.Hosts = removeFromHosts(c.cntrName, cfg.Hosts)
244
245 hostCfg := &ssh_config.Host{Patterns: []*ssh_config.Pattern{cntrPattern}}
246 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "HostName", Value: c.sshHost})
247 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "User", Value: "root"})
248 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "Port", Value: c.sshPort})
249 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "IdentityFile", Value: c.userIdentityPath})
250 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "UserKnownHostsFile", Value: c.knownHostsPath})
251
252 hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.Empty{})
253 cfg.Hosts = append(cfg.Hosts, hostCfg)
254
255 if err := c.addSketchHostMatchIfMissing(cfg); err != nil {
256 return fmt.Errorf("couldn't add missing host match: %w", err)
257 }
258
259 cfgBytes, err := cfg.MarshalText()
260 if err != nil {
261 return fmt.Errorf("couldn't marshal ssh_config: %w", err)
262 }
263 if err := f.Truncate(0); err != nil {
264 return fmt.Errorf("couldn't truncate ssh_config: %w", err)
265 }
266 if _, err := f.Seek(0, 0); err != nil {
267 return fmt.Errorf("couldn't seek to beginning of ssh_config: %w", err)
268 }
269 if _, err := f.Write(cfgBytes); err != nil {
270 return fmt.Errorf("couldn't write ssh_config: %w", err)
271 }
272
273 return nil
274}
275
276func (c *SSHTheater) addContainerToKnownHosts() error {
277 f, err := os.OpenFile(c.knownHostsPath, os.O_RDWR|os.O_CREATE, 0o644)
278 if err != nil {
279 return fmt.Errorf("couldn't open %s: %w", c.knownHostsPath, err)
280 }
281 defer f.Close()
282
283 line := strings.Join(
284 []string{
285 fmt.Sprintf("[%s]:%s", c.sshHost, c.sshPort),
286 c.serverPublicKey.Type(),
287 base64.StdEncoding.EncodeToString(c.serverPublicKey.Marshal()),
288 }, " ")
289 if _, err := f.Write([]byte(line + "\n")); err != nil {
290 return fmt.Errorf("couldn't write new known_host entry to to %s: %w", c.knownHostsPath, err)
291 }
292
293 return nil
294}
295
296func (c *SSHTheater) removeContainerFromKnownHosts() error {
297 f, err := os.OpenFile(c.knownHostsPath, os.O_RDWR|os.O_CREATE, 0o644)
298 if err != nil {
299 return fmt.Errorf("couldn't open ssh_config: %w", err)
300 }
301 defer f.Close()
302 scanner := bufio.NewScanner(f)
303 lineToRemove := strings.Join(
304 []string{
305 fmt.Sprintf("[%s]:%s", c.sshHost, c.sshPort),
306 c.serverPublicKey.Type(),
307 base64.StdEncoding.EncodeToString(c.serverPublicKey.Marshal()),
308 }, " ")
309
310 outputLines := []string{}
311 for scanner.Scan() {
312 if scanner.Text() == lineToRemove {
313 continue
314 }
315 outputLines = append(outputLines, scanner.Text())
316 }
317 if err := f.Truncate(0); err != nil {
318 return fmt.Errorf("couldn't truncate known_hosts: %w", err)
319 }
320 if _, err := f.Seek(0, 0); err != nil {
321 return fmt.Errorf("couldn't seek to beginning of known_hosts: %w", err)
322 }
323 if _, err := f.Write([]byte(strings.Join(outputLines, "\n"))); err != nil {
324 return fmt.Errorf("couldn't write updated known_hosts to to %s: %w", c.knownHostsPath, err)
325 }
326
327 return nil
328}
329
330func (c *SSHTheater) Cleanup() error {
331 if err := c.removeContainerFromSSHConfig(); err != nil {
332 return fmt.Errorf("couldn't remove container from ssh_config: %v\n", err)
333 }
334 if err := c.removeContainerFromKnownHosts(); err != nil {
335 return fmt.Errorf("couldn't remove container from ssh_config: %v\n", err)
336 }
337
338 return nil
339}
340
341func (c *SSHTheater) removeContainerFromSSHConfig() error {
342 f, err := os.OpenFile(c.sshConfigPath, os.O_RDWR|os.O_CREATE, 0o644)
343 if err != nil {
344 return fmt.Errorf("couldn't open ssh_config: %w", err)
345 }
346 defer f.Close()
347
348 cfg, err := ssh_config.Decode(f)
349 if err != nil {
350 return fmt.Errorf("couldn't decode ssh_config: %w", err)
351 }
352 cfg.Hosts = removeFromHosts(c.cntrName, cfg.Hosts)
353
354 if err := c.addSketchHostMatchIfMissing(cfg); err != nil {
355 return fmt.Errorf("couldn't add missing host match: %w", err)
356 }
357
358 cfgBytes, err := cfg.MarshalText()
359 if err != nil {
360 return fmt.Errorf("couldn't marshal ssh_config: %w", err)
361 }
362 if err := f.Truncate(0); err != nil {
363 return fmt.Errorf("couldn't truncate ssh_config: %w", err)
364 }
365 if _, err := f.Seek(0, 0); err != nil {
366 return fmt.Errorf("couldn't seek to beginning of ssh_config: %w", err)
367 }
368 if _, err := f.Write(cfgBytes); err != nil {
369 return fmt.Errorf("couldn't write ssh_config: %w", err)
370 }
371 return nil
372}