| Sean McCullough | baa2b59 | 2025-04-23 10:40:08 -0700 | [diff] [blame] | 1 | package server |
| 2 | |
| 3 | import ( |
| Sean McCullough | ae3480f | 2025-04-23 15:28:20 -0700 | [diff] [blame] | 4 | "bytes" |
| Sean McCullough | baa2b59 | 2025-04-23 10:40:08 -0700 | [diff] [blame] | 5 | "context" |
| 6 | "fmt" |
| 7 | "io" |
| Sean McCullough | 1d06132 | 2025-04-24 09:52:56 -0700 | [diff] [blame] | 8 | "log/slog" |
| Sean McCullough | baa2b59 | 2025-04-23 10:40:08 -0700 | [diff] [blame] | 9 | "os" |
| 10 | "os/exec" |
| 11 | "syscall" |
| 12 | "unsafe" |
| 13 | |
| 14 | "github.com/creack/pty" |
| 15 | "github.com/gliderlabs/ssh" |
| Sean McCullough | 1d06132 | 2025-04-24 09:52:56 -0700 | [diff] [blame] | 16 | gossh "golang.org/x/crypto/ssh" |
| Sean McCullough | baa2b59 | 2025-04-23 10:40:08 -0700 | [diff] [blame] | 17 | ) |
| 18 | |
// setWinsize asks the kernel to resize the terminal referred to by f to
// w columns by h rows using the TIOCSWINSZ ioctl. The pixel dimensions are
// left at zero. The ioctl's result is deliberately discarded; callers get no
// indication of failure.
func setWinsize(f *os.File, w, h int) {
	// Field order matches the kernel's struct winsize:
	// ws_row, ws_col, ws_xpixel, ws_ypixel.
	size := struct{ rows, cols, xpix, ypix uint16 }{
		rows: uint16(h),
		cols: uint16(w),
	}
	// The unsafe.Pointer -> uintptr conversion must stay inside the Syscall
	// argument list so the struct is kept alive for the duration of the call.
	syscall.Syscall(
		syscall.SYS_IOCTL,
		f.Fd(),
		uintptr(syscall.TIOCSWINSZ),
		uintptr(unsafe.Pointer(&size)),
	)
}
| 23 | |
| 24 | func (s *Server) ServeSSH(ctx context.Context, hostKey, authorizedKeys []byte) error { |
| Sean McCullough | ae3480f | 2025-04-23 15:28:20 -0700 | [diff] [blame] | 25 | // Parse all authorized keys |
| 26 | allowedKeys := make([]ssh.PublicKey, 0) |
| 27 | rest := authorizedKeys |
| 28 | var err error |
| 29 | |
| 30 | // Continue parsing as long as there are bytes left |
| 31 | for len(rest) > 0 { |
| 32 | var key ssh.PublicKey |
| 33 | key, _, _, rest, err = ssh.ParseAuthorizedKey(rest) |
| 34 | if err != nil { |
| 35 | // If we hit an error, check if we have more lines to try |
| 36 | if i := bytes.IndexByte(rest, '\n'); i >= 0 { |
| 37 | // Skip to the next line and continue |
| 38 | rest = rest[i+1:] |
| 39 | continue |
| 40 | } |
| 41 | // No more lines and we hit an error, so stop parsing |
| 42 | break |
| 43 | } |
| 44 | allowedKeys = append(allowedKeys, key) |
| Sean McCullough | baa2b59 | 2025-04-23 10:40:08 -0700 | [diff] [blame] | 45 | } |
| Sean McCullough | ae3480f | 2025-04-23 15:28:20 -0700 | [diff] [blame] | 46 | if len(allowedKeys) == 0 { |
| 47 | return fmt.Errorf("ServeSSH: no valid authorized keys found") |
| 48 | } |
| 49 | |
| Sean McCullough | 1d06132 | 2025-04-24 09:52:56 -0700 | [diff] [blame] | 50 | signer, err := gossh.ParsePrivateKey(hostKey) |
| 51 | if err != nil { |
| 52 | return fmt.Errorf("ServeSSH: failed to parse host private key, err: %w", err) |
| 53 | } |
| 54 | |
| 55 | server := ssh.Server{ |
| 56 | LocalPortForwardingCallback: ssh.LocalPortForwardingCallback(func(ctx ssh.Context, dhost string, dport uint32) bool { |
| 57 | slog.DebugContext(ctx, "Accepted forward", slog.Any("dhost", dhost), slog.Any("dport", dport)) |
| 58 | return true |
| 59 | }), |
| 60 | Addr: ":22", |
| 61 | ChannelHandlers: ssh.DefaultChannelHandlers, |
| 62 | Handler: ssh.Handler(func(s ssh.Session) { |
| 63 | ptyReq, winCh, isPty := s.Pty() |
| 64 | if isPty { |
| 65 | handlePTYSession(ctx, s, ptyReq, winCh) |
| 66 | } else { |
| 67 | handleSession(ctx, s) |
| 68 | } |
| 69 | }), |
| 70 | HostSigners: []ssh.Signer{signer}, |
| 71 | PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool { |
| Sean McCullough | ae3480f | 2025-04-23 15:28:20 -0700 | [diff] [blame] | 72 | // Check if the provided key matches any of our allowed keys |
| 73 | for _, allowedKey := range allowedKeys { |
| 74 | if ssh.KeysEqual(key, allowedKey) { |
| Sean McCullough | 1d06132 | 2025-04-24 09:52:56 -0700 | [diff] [blame] | 75 | slog.DebugContext(ctx, "ServeSSH: allow key", slog.String("key", string(key.Marshal()))) |
| Sean McCullough | ae3480f | 2025-04-23 15:28:20 -0700 | [diff] [blame] | 76 | return true |
| 77 | } |
| 78 | } |
| 79 | return false |
| Sean McCullough | 1d06132 | 2025-04-24 09:52:56 -0700 | [diff] [blame] | 80 | }, |
| 81 | } |
| 82 | |
| 83 | // This ChannelHandler is necessary for vscode's Remote-SSH connections to work. |
| 84 | // Without it the new VSC window will open, but you'll get an error that says something |
| 85 | // like "Failed to set up dynamic port forwarding connection over SSH to the VS Code Server." |
| 86 | server.ChannelHandlers["direct-tcpip"] = ssh.DirectTCPIPHandler |
| 87 | |
| 88 | return server.ListenAndServe() |
| Sean McCullough | baa2b59 | 2025-04-23 10:40:08 -0700 | [diff] [blame] | 89 | } |
| 90 | |
| Sean McCullough | 1d06132 | 2025-04-24 09:52:56 -0700 | [diff] [blame] | 91 | func handlePTYSession(ctx context.Context, s ssh.Session, ptyReq ssh.Pty, winCh <-chan ssh.Window) { |
| Sean McCullough | baa2b59 | 2025-04-23 10:40:08 -0700 | [diff] [blame] | 92 | cmd := exec.CommandContext(ctx, "/bin/bash") |
| Sean McCullough | 1d06132 | 2025-04-24 09:52:56 -0700 | [diff] [blame] | 93 | slog.DebugContext(ctx, "handlePTYSession", slog.Any("ptyReq", ptyReq)) |
| 94 | |
| 95 | cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term)) |
| 96 | f, err := pty.Start(cmd) |
| 97 | if err != nil { |
| 98 | fmt.Fprintf(s, "PTY requested, but unable to start due to error: %v", err) |
| 99 | s.Exit(1) |
| 100 | return |
| 101 | } |
| 102 | |
| 103 | go func() { |
| 104 | for win := range winCh { |
| 105 | setWinsize(f, win.Width, win.Height) |
| Sean McCullough | baa2b59 | 2025-04-23 10:40:08 -0700 | [diff] [blame] | 106 | } |
| Sean McCullough | 1d06132 | 2025-04-24 09:52:56 -0700 | [diff] [blame] | 107 | }() |
| 108 | go func() { |
| 109 | io.Copy(f, s) // stdin |
| 110 | }() |
| 111 | io.Copy(s, f) // stdout |
| 112 | |
| 113 | // TODO: double check, do we need a sync.WaitGroup here, to make sure we finish |
| 114 | // the pipe I/O before we call cmd.Wait? |
| 115 | if err := cmd.Wait(); err != nil { |
| 116 | slog.ErrorContext(ctx, "handlePTYSession: cmd.Wait", slog.String("err", err.Error())) |
| 117 | s.Exit(1) |
| 118 | } |
| 119 | } |
| 120 | |
| 121 | func handleSession(ctx context.Context, s ssh.Session) { |
| 122 | var cmd *exec.Cmd |
| 123 | slog.DebugContext(ctx, "handleSession", slog.Any("s.Command", s.Command())) |
| 124 | if len(s.Command()) == 0 { |
| 125 | cmd = exec.CommandContext(ctx, "/bin/bash") |
| Sean McCullough | baa2b59 | 2025-04-23 10:40:08 -0700 | [diff] [blame] | 126 | } else { |
| Sean McCullough | 1d06132 | 2025-04-24 09:52:56 -0700 | [diff] [blame] | 127 | cmd = exec.CommandContext(ctx, s.Command()[0], s.Command()[1:]...) |
| 128 | } |
| 129 | stdinPipe, err := cmd.StdinPipe() |
| 130 | if err != nil { |
| 131 | slog.ErrorContext(ctx, "handleSession: cmd.StdinPipe", slog.Any("err", err.Error())) |
| 132 | fmt.Fprintf(s, "cmd.StdinPipe error: %v", err) |
| 133 | |
| 134 | s.Exit(1) |
| 135 | return |
| 136 | } |
| 137 | defer stdinPipe.Close() |
| 138 | |
| 139 | stdoutPipe, err := cmd.StdoutPipe() |
| 140 | if err != nil { |
| 141 | slog.ErrorContext(ctx, "handleSession: cmd.StdoutPipe", slog.Any("err", err.Error())) |
| 142 | fmt.Fprintf(s, "cmd.StdoutPipe error: %v", err) |
| 143 | s.Exit(1) |
| 144 | return |
| 145 | } |
| 146 | defer stdoutPipe.Close() |
| 147 | |
| 148 | stderrPipe, err := cmd.StderrPipe() |
| 149 | if err != nil { |
| 150 | slog.ErrorContext(ctx, "handleSession: cmd.StderrPipe", slog.Any("err", err.Error())) |
| 151 | fmt.Fprintf(s, "cmd.StderrPipe error: %v", err) |
| 152 | s.Exit(1) |
| 153 | return |
| 154 | } |
| 155 | defer stderrPipe.Close() |
| 156 | |
| 157 | if err := cmd.Start(); err != nil { |
| 158 | slog.ErrorContext(ctx, "handleSession: cmd.Start", slog.Any("err", err.Error())) |
| 159 | fmt.Fprintf(s, "cmd.Start error: %v", err) |
| 160 | s.Exit(1) |
| 161 | return |
| 162 | } |
| 163 | |
| 164 | // TODO: double check, do we need a sync.WaitGroup here, to make sure we finish |
| 165 | // the pipe I/O before we call cmd.Wait? |
| 166 | go func() { |
| 167 | io.Copy(s, stderrPipe) |
| 168 | }() |
| 169 | go func() { |
| 170 | io.Copy(s, stdoutPipe) |
| 171 | }() |
| 172 | io.Copy(stdinPipe, s) |
| 173 | |
| 174 | if err := cmd.Wait(); err != nil { |
| 175 | slog.ErrorContext(ctx, "handleSession: cmd.Wait", slog.Any("err", err.Error())) |
| 176 | fmt.Fprintf(s, "cmd.Wait error: %v", err) |
| Sean McCullough | baa2b59 | 2025-04-23 10:40:08 -0700 | [diff] [blame] | 177 | s.Exit(1) |
| 178 | } |
| 179 | } |