dockerimg: +SSHTheater, auto-configures local ssh
- removes the old cli flags for pointing sketch at existing ssh key files
- generates key pairs for both the container ssh server and the user if
these don't already exist in ~/.sketch/
- automatically edits ~/.sketch/known_hosts and ~/.sketch/ssh_config to
add entries when you start a sketch container and remove them when you're
done
- does not make any edits to default ~/.ssh/config files, so you'll have
to manually add a line like this to your default ssh config file, above
any Host or Match blocks (an Include inside one only applies conditionally):
`Include HOME_DIR/.sketch/ssh_config`
in order to have your ssh tools pick up what sketch is putting down.
diff --git a/dockerimg/dockerimg.go b/dockerimg/dockerimg.go
index d812987..a92fb92 100644
--- a/dockerimg/dockerimg.go
+++ b/dockerimg/dockerimg.go
@@ -72,12 +72,6 @@
// Host port for the container's ssh server
SSHPort int
- // Public keys authorized to connect to the container's ssh server
- SSHAuthorizedKeys []byte
-
- // Private key used to identify the container's ssh server
- SSHServerIdentity []byte
-
// Outside information to pass to the container
OutsideHostname string
OutsideOS string
@@ -262,9 +256,25 @@
}
sshHost, sshPort, err := net.SplitHostPort(localSSHAddr)
if err != nil {
- fmt.Println("Error splitting ssh host and port:", err)
+ return appendInternalErr(fmt.Errorf("Error splitting ssh host and port: %w", err))
}
- fmt.Printf("ssh into this container with: ssh root@%s -p %s\n", sshHost, sshPort)
+
+ cst, err := NewSSHTheather(cntrName, sshHost, sshPort)
+ if err != nil {
+ return appendInternalErr(fmt.Errorf("NewContainerSSHTheather: %w", err))
+ }
+
+ fmt.Printf(`Connect to this container via any of these methods:
+🖥️ ssh %s
+🖥️ code --remote ssh-remote+root@%s /app -n
+🔗 vscode://vscode-remote/ssh-remote+root@%s/app?n=true
+`, cntrName, cntrName, cntrName)
+
+ defer func() {
+ if err := cst.Cleanup(); err != nil {
+ appendInternalErr(err)
+ }
+ }()
// Tell the sketch container which git server port and commit to initialize with.
go func() {
@@ -273,7 +283,7 @@
// the scrollback (which is not good, but also not fatal). I can't see why it does this
// though, since none of the calls in postContainerInitConfig obviously write to stdout
// or stderr.
- if err := postContainerInitConfig(ctx, localAddr, commit, gitSrv.gitPort, gitSrv.pass, config.SSHServerIdentity, config.SSHAuthorizedKeys); err != nil {
+ if err := postContainerInitConfig(ctx, localAddr, commit, gitSrv.gitPort, gitSrv.pass, cst.serverIdentity, cst.userIdentity); err != nil {
slog.ErrorContext(ctx, "LaunchContainer.postContainerInitConfig", slog.String("err", err.Error()))
errCh <- appendInternalErr(err)
}
diff --git a/dockerimg/sshtheater.go b/dockerimg/sshtheater.go
new file mode 100644
index 0000000..31d1849
--- /dev/null
+++ b/dockerimg/sshtheater.go
@@ -0,0 +1,372 @@
+package dockerimg
+
+import (
+ "bufio"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/x509"
+ "encoding/base64"
+ "encoding/pem"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "github.com/kevinburke/ssh_config"
+ "golang.org/x/crypto/ssh"
+)
+
+const keyBitSize int = 2048 // RSA modulus length, in bits, of every key pair createKeyPairIfMissing generates
+
+// SSHTheater does the key pair generation, known_hosts updates and ssh_config updates
+// needed so that ssh, and tools built on it such as VS Code's Remote-SSH, can connect
+// to a locally running sketch container without the user running the usual ssh
+// obstacle course of generating keys and accepting host keys by hand.
+//
+// SSHTheater does not modify your default ~/.ssh/config or known_hosts files. To use it,
+// make a one-time edit near the top of ~/.ssh/config, before any Host or Match line
+// (an Include inside a Host or Match block only applies conditionally):
+//
+//	Include ~/.sketch/ssh_config
+//
+// ssh expands the leading ~ to your home directory.
+type SSHTheater struct {
+	cntrName string // container name, also used as the ssh_config Host alias
+	sshHost  string // HostName that ssh connects to
+	sshPort  string // Port that ssh connects to
+
+	knownHostsPath     string // known_hosts file holding the container's host key
+	userIdentityPath   string // private key ssh authenticates with (IdentityFile)
+	sshConfigPath      string // ssh_config file that Host blocks are written to
+	serverIdentityPath string // private key presented by the container's ssh server
+
+	serverPublicKey ssh.PublicKey // parsed from serverIdentityPath + ".pub"
+	serverIdentity  []byte        // contents of serverIdentityPath, sent to the container
+	userIdentity    []byte        // contents of userIdentityPath + ".pub" (authorized key)
+}
+
+// NewSSHTheather will set up everything so that you can use ssh on localhost to connect to
+// the sketch container. Call #Clean when you are done with the container to remove the
+// various entries it created in its known_hosts and ssh_config files. Also note that
+// this will generate key pairs for both the ssh server identity and the user identity, if
+// these files do not already exist. These key pair files are not deleted by #Cleanup,
+// so they can be re-used across invocations of sketch. This means every sketch container
+// that runs on this host will use the same ssh server identity.
+//
+// If this doesn't return an error, you should be able to run "ssh <cntrName>"
+// in a terminal on your host machine to open a shell into the container without having
+// to manually accept changes to your known_hosts file etc.
+func NewSSHTheather(cntrName, sshHost, sshPort string) (*SSHTheater, error) {
+ base := filepath.Join(os.Getenv("HOME"), ".sketch")
+ cst := &SSHTheater{
+ cntrName: cntrName,
+ sshHost: sshHost,
+ sshPort: sshPort,
+ knownHostsPath: filepath.Join(base, "known_hosts"),
+ userIdentityPath: filepath.Join(base, "container_user_identity"),
+ serverIdentityPath: filepath.Join(base, "container_server_identity"),
+ sshConfigPath: filepath.Join(base, "ssh_config"),
+ }
+ if err := cst.addContainerToSSHConfig(); err != nil {
+ return nil, fmt.Errorf("couldn't add container to ssh_config: %w", err)
+ }
+
+ serverIdentity, err := os.ReadFile(cst.serverIdentityPath)
+ if err != nil {
+ return nil, fmt.Errorf("couldn't read container's ssh server identity: %w", err)
+ }
+ cst.serverIdentity = serverIdentity
+
+ serverPubKeyBytes, err := os.ReadFile(cst.serverIdentityPath + ".pub")
+ serverPubKey, _, _, _, err := ssh.ParseAuthorizedKey(serverPubKeyBytes)
+ if err != nil {
+ return nil, fmt.Errorf("couldn't read ssh server public key: %w", err)
+ }
+ cst.serverPublicKey = serverPubKey
+
+ userIdentity, err := os.ReadFile(cst.userIdentityPath + ".pub")
+ if err != nil {
+ return nil, fmt.Errorf("couldn't read ssh user identity: %w", err)
+ }
+ cst.userIdentity = userIdentity
+
+ return cst, nil
+}
+
+// removeFromHosts returns cfgHosts without the Host blocks that list cntrName
+// exactly as one of their patterns.
+//
+// Matching is deliberately literal: Host.Matches would also select wildcard
+// blocks such as "sketch-*" or a catch-all "*" block (dropping global settings),
+// and a substring test would also remove other containers whose names merely
+// contain cntrName (e.g. "sketch-a" vs "sketch-ab").
+func removeFromHosts(cntrName string, cfgHosts []*ssh_config.Host) []*ssh_config.Host {
+	hosts := make([]*ssh_config.Host, 0, len(cfgHosts))
+nextHost:
+	for _, host := range cfgHosts {
+		for _, pat := range host.Patterns {
+			if pat.String() == cntrName {
+				continue nextHost
+			}
+		}
+		hosts = append(hosts, host)
+	}
+	return hosts
+}
+
+// generatePrivateKey returns a newly generated RSA private key whose modulus is
+// bitSize bits long, using crypto/rand.Reader as the randomness source.
+func generatePrivateKey(bitSize int) (*rsa.PrivateKey, error) {
+	// rsa.GenerateKey already pairs every error with a nil key, so its result
+	// can be returned as-is.
+	return rsa.GenerateKey(rand.Reader, bitSize)
+}
+
+// generatePublicKey wraps an RSA public key as an ssh.PublicKey. Pass the result to
+// ssh.MarshalAuthorizedKey to get the "ssh-rsa AAAA..." line stored in a .pub or
+// authorized_keys file.
+func generatePublicKey(pub *rsa.PublicKey) (ssh.PublicKey, error) {
+	sshPub, err := ssh.NewPublicKey(pub)
+	if err != nil {
+		return nil, err
+	}
+	return sshPub, nil
+}
+
+// encodePrivateKeyToPEM serializes privateKey as an unencrypted PKCS#1 PEM block of
+// type "RSA PRIVATE KEY", the traditional private key file format OpenSSH accepts.
+func encodePrivateKeyToPEM(privateKey *rsa.PrivateKey) []byte {
+	block := pem.Block{Type: "RSA PRIVATE KEY"}
+	block.Bytes = x509.MarshalPKCS1PrivateKey(privateKey)
+	// EncodeToMemory returns nil only for invalid headers; this block has none.
+	return pem.EncodeToMemory(&block)
+}
+
+// writeKeyToFile writes keyBytes to filename, creating it owner read/write only (0600) if absent.
+func writeKeyToFile(keyBytes []byte, filename string) error {
+	return os.WriteFile(filename, keyBytes, 0o600)
+}
+
+// createKeyPairIfMissing generates a new RSA key pair if nothing exists at idPath.
+// The PEM-encoded private key is written to idPath and the public key, in the
+// one-line authorized_keys format, to idPath+".pub" (mode 0600 when created).
+// It returns the new public key, or (nil, nil) when idPath already exists; in that
+// case nothing is written, even if the ".pub" file is missing.
+func createKeyPairIfMissing(idPath string) (ssh.PublicKey, error) {
+	if _, err := os.Stat(idPath); err == nil {
+		return nil, nil
+	}
+
+	privateKey, err := generatePrivateKey(keyBitSize)
+	if err != nil {
+		return nil, fmt.Errorf("error generating private key: %w", err)
+	}
+	publicKey, err := generatePublicKey(&privateKey.PublicKey)
+	if err != nil {
+		return nil, fmt.Errorf("error generating public key: %w", err)
+	}
+
+	if err := writeKeyToFile(encodePrivateKeyToPEM(privateKey), idPath); err != nil {
+		return nil, fmt.Errorf("error writing private key to file: %w", err)
+	}
+	// MarshalAuthorizedKey yields "ssh-rsa AAAA...\n", the form ssh.ParseAuthorizedKey
+	// reads back and that authorized_keys files use.
+	if err := writeKeyToFile(ssh.MarshalAuthorizedKey(publicKey), idPath+".pub"); err != nil {
+		return nil, fmt.Errorf("error writing public key to file: %w", err)
+	}
+	return publicKey, nil
+}
+
+func (c *SSHTheater) addSketchHostMatchIfMissing(cfg *ssh_config.Config) error {
+ found := false
+ for _, host := range cfg.Hosts {
+ if strings.Contains(host.String(), "host=\"sketch-*\"") {
+ found = true
+ break
+ }
+ }
+ if !found {
+ hostPattern, err := ssh_config.NewPattern("host=\"sketch-*\"")
+ if err != nil {
+ return fmt.Errorf("couldn't add pattern to ssh_config: %w", err)
+ }
+
+ hostCfg := &ssh_config.Host{Patterns: []*ssh_config.Pattern{hostPattern}}
+ hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "UserKnownHostsFile", Value: c.knownHostsPath})
+
+ if _, err := createKeyPairIfMissing(c.userIdentityPath); err != nil {
+ return fmt.Errorf("couldn't create user identity: %w", err)
+ }
+ hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "IdentityFile", Value: c.userIdentityPath})
+
+ serverPubKey, err := createKeyPairIfMissing(c.serverIdentityPath)
+ if err != nil {
+ return fmt.Errorf("couldn't create server identity: %w", err)
+ }
+ if serverPubKey != nil {
+ c.serverPublicKey = serverPubKey
+ if err := c.addContainerToKnownHosts(); err != nil {
+ return fmt.Errorf("couldn't update known hosts: %w", err)
+ }
+ }
+ hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.Empty{})
+
+ cfg.Hosts = append([]*ssh_config.Host{hostCfg}, cfg.Hosts...)
+ }
+ return nil
+}
+
+// addContainerToSSHConfig rewrites c.sshConfigPath so that it contains a Host block for
+// c.cntrName (HostName, User root, Port, IdentityFile, UserKnownHostsFile) in place of
+// any block removeFromHosts selects for that name, and ensures the shared wildcard
+// block exists. The file and its parent directory are created if they are missing.
+func (c *SSHTheater) addContainerToSSHConfig() error {
+	// The directory also holds private keys, so create it owner-only (not 0o777).
+	dotSketchPath := filepath.Dir(c.sshConfigPath)
+	if err := os.MkdirAll(dotSketchPath, 0o700); err != nil {
+		return fmt.Errorf("couldn't create %s: %w", dotSketchPath, err)
+	}
+
+	f, err := os.OpenFile(c.sshConfigPath, os.O_RDWR|os.O_CREATE, 0o644)
+	if err != nil {
+		return fmt.Errorf("couldn't open ssh_config: %w", err)
+	}
+	defer f.Close()
+
+	cfg, err := ssh_config.Decode(f)
+	if err != nil {
+		return fmt.Errorf("couldn't decode ssh_config: %w", err)
+	}
+	cntrPattern, err := ssh_config.NewPattern(c.cntrName)
+	if err != nil {
+		return fmt.Errorf("couldn't add pattern to ssh_config: %w", err)
+	}
+	// Remove any matches for this container if they already exist.
+	cfg.Hosts = removeFromHosts(c.cntrName, cfg.Hosts)
+	hostCfg := &ssh_config.Host{Patterns: []*ssh_config.Pattern{cntrPattern}}
+	hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "HostName", Value: c.sshHost})
+	hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "User", Value: "root"})
+	hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "Port", Value: c.sshPort})
+	hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "IdentityFile", Value: c.userIdentityPath})
+	hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.KV{Key: "UserKnownHostsFile", Value: c.knownHostsPath})
+	hostCfg.Nodes = append(hostCfg.Nodes, &ssh_config.Empty{})
+	cfg.Hosts = append(cfg.Hosts, hostCfg)
+
+	if err := c.addSketchHostMatchIfMissing(cfg); err != nil {
+		return fmt.Errorf("couldn't add missing host match: %w", err)
+	}
+
+	cfgBytes, err := cfg.MarshalText()
+	if err != nil {
+		return fmt.Errorf("couldn't marshal ssh_config: %w", err)
+	}
+	if err := f.Truncate(0); err != nil {
+		return fmt.Errorf("couldn't truncate ssh_config: %w", err)
+	}
+	if _, err := f.Seek(0, 0); err != nil {
+		return fmt.Errorf("couldn't seek to beginning of ssh_config: %w", err)
+	}
+	if _, err := f.Write(cfgBytes); err != nil {
+		return fmt.Errorf("couldn't write ssh_config: %w", err)
+	}
+
+	return nil
+}
+
+// addContainerToKnownHosts appends a "[sshHost]:sshPort <type> <base64 key>" line for
+// c.serverPublicKey to c.knownHostsPath, creating the file and a line break if needed.
+func (c *SSHTheater) addContainerToKnownHosts() error {
+	// O_APPEND: a plain O_RDWR write lands at offset 0 and clobbers existing entries.
+	f, err := os.OpenFile(c.knownHostsPath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0o644)
+	if err != nil {
+		return fmt.Errorf("couldn't open %s: %w", c.knownHostsPath, err)
+	}
+	defer f.Close()
+	entry := fmt.Sprintf("[%s]:%s %s %s\n", c.sshHost, c.sshPort,
+		c.serverPublicKey.Type(), base64.StdEncoding.EncodeToString(c.serverPublicKey.Marshal()))
+	if data, err := os.ReadFile(c.knownHostsPath); err == nil && len(data) > 0 && data[len(data)-1] != '\n' {
+		entry = "\n" + entry // don't glue the entry onto an unterminated last line
+	}
+	if _, err := f.WriteString(entry); err != nil {
+		return fmt.Errorf("couldn't write new known_host entry to %s: %w", c.knownHostsPath, err)
+	}
+	return nil
+}
+
+func (c *SSHTheater) removeContainerFromKnownHosts() error {
+ f, err := os.OpenFile(c.knownHostsPath, os.O_RDWR|os.O_CREATE, 0o644)
+ if err != nil {
+ return fmt.Errorf("couldn't open ssh_config: %w", err)
+ }
+ defer f.Close()
+ scanner := bufio.NewScanner(f)
+ lineToRemove := strings.Join(
+ []string{
+ fmt.Sprintf("[%s]:%s", c.sshHost, c.sshPort),
+ c.serverPublicKey.Type(),
+ base64.StdEncoding.EncodeToString(c.serverPublicKey.Marshal()),
+ }, " ")
+
+ outputLines := []string{}
+ for scanner.Scan() {
+ if scanner.Text() == lineToRemove {
+ continue
+ }
+ outputLines = append(outputLines, scanner.Text())
+ }
+ if err := f.Truncate(0); err != nil {
+ return fmt.Errorf("couldn't truncate known_hosts: %w", err)
+ }
+ if _, err := f.Seek(0, 0); err != nil {
+ return fmt.Errorf("couldn't seek to beginning of known_hosts: %w", err)
+ }
+ if _, err := f.Write([]byte(strings.Join(outputLines, "\n"))); err != nil {
+ return fmt.Errorf("couldn't write updated known_hosts to to %s: %w", c.knownHostsPath, err)
+ }
+
+ return nil
+}
+
+// Cleanup removes this container's ssh_config and known_hosts entries; key pairs are kept.
+func (c *SSHTheater) Cleanup() error {
+	if err := c.removeContainerFromSSHConfig(); err != nil {
+		return fmt.Errorf("couldn't remove container from ssh_config: %w", err)
+	}
+	if err := c.removeContainerFromKnownHosts(); err != nil {
+		return fmt.Errorf("couldn't remove container from known_hosts: %w", err)
+	}
+	return nil
+}
+
+// removeContainerFromSSHConfig rewrites c.sshConfigPath without the Host block(s) that
+// removeFromHosts selects for c.cntrName, then calls addSketchHostMatchIfMissing.
+func (c *SSHTheater) removeContainerFromSSHConfig() error {
+	cfgFile, err := os.OpenFile(c.sshConfigPath, os.O_RDWR|os.O_CREATE, 0o644)
+	if err != nil {
+		return fmt.Errorf("couldn't open ssh_config: %w", err)
+	}
+	defer cfgFile.Close()
+	cfg, err := ssh_config.Decode(cfgFile)
+	if err != nil {
+		return fmt.Errorf("couldn't decode ssh_config: %w", err)
+	}
+	cfg.Hosts = removeFromHosts(c.cntrName, cfg.Hosts)
+	if err = c.addSketchHostMatchIfMissing(cfg); err != nil {
+		return fmt.Errorf("couldn't add missing host match: %w", err)
+	}
+	out, err := cfg.MarshalText()
+	if err != nil {
+		return fmt.Errorf("couldn't marshal ssh_config: %w", err)
+	}
+	// Empty the file and rewind so the new contents replace the old ones entirely.
+	if err = cfgFile.Truncate(0); err != nil {
+		return fmt.Errorf("couldn't truncate ssh_config: %w", err)
+	}
+	if _, err = cfgFile.Seek(0, 0); err != nil {
+		return fmt.Errorf("couldn't seek to beginning of ssh_config: %w", err)
+	}
+	if _, err = cfgFile.Write(out); err != nil {
+		return fmt.Errorf("couldn't write ssh_config: %w", err)
+	}
+	return nil
+}