| Sean McCullough | baa2b59 | 2025-04-23 10:40:08 -0700 | [diff] [blame] | 1 | package server |
| 2 | |
| 3 | import ( |
| Sean McCullough | ae3480f | 2025-04-23 15:28:20 -0700 | [diff] [blame] | 4 | "bytes" |
| Sean McCullough | baa2b59 | 2025-04-23 10:40:08 -0700 | [diff] [blame] | 5 | "context" |
| 6 | "fmt" |
| 7 | "io" |
| Sean McCullough | 1d06132 | 2025-04-24 09:52:56 -0700 | [diff] [blame] | 8 | "log/slog" |
| Sean McCullough | baa2b59 | 2025-04-23 10:40:08 -0700 | [diff] [blame] | 9 | "os" |
| 10 | "os/exec" |
| Sean McCullough | 7013e9e | 2025-05-14 02:03:58 +0000 | [diff] [blame^] | 11 | "slices" |
| Sean McCullough | baa2b59 | 2025-04-23 10:40:08 -0700 | [diff] [blame] | 12 | "syscall" |
| Sean McCullough | 7013e9e | 2025-05-14 02:03:58 +0000 | [diff] [blame^] | 13 | "time" |
| Sean McCullough | baa2b59 | 2025-04-23 10:40:08 -0700 | [diff] [blame] | 14 | "unsafe" |
| 15 | |
| 16 | "github.com/creack/pty" |
| 17 | "github.com/gliderlabs/ssh" |
| Sean McCullough | cf291fa | 2025-05-03 17:55:48 -0700 | [diff] [blame] | 18 | "github.com/pkg/sftp" |
| Sean McCullough | 1d06132 | 2025-04-24 09:52:56 -0700 | [diff] [blame] | 19 | gossh "golang.org/x/crypto/ssh" |
| Sean McCullough | baa2b59 | 2025-04-23 10:40:08 -0700 | [diff] [blame] | 20 | ) |
| 21 | |
// setWinsize asks the kernel to set the terminal window size of f to w
// columns by h rows using the TIOCSWINSZ ioctl. Pixel dimensions are left at
// zero. The ioctl's result is discarded, so a failure (for example, f not
// being a terminal) is silently ignored.
func setWinsize(f *os.File, w, h int) {
	// Field order mirrors the kernel's struct winsize: rows, columns, then the
	// x and y pixel sizes.
	type winsize struct{ rows, cols, xpixel, ypixel uint16 }
	ws := &winsize{rows: uint16(h), cols: uint16(w)}
	syscall.Syscall(
		syscall.SYS_IOCTL,
		f.Fd(),
		uintptr(syscall.TIOCSWINSZ),
		uintptr(unsafe.Pointer(ws)),
	)
}
| 26 | |
| Sean McCullough | 7013e9e | 2025-05-14 02:03:58 +0000 | [diff] [blame^] | 27 | func (s *Server) ServeSSH(ctx context.Context, hostKey, authorizedKeys []byte, containerCAKey, hostCertificate []byte) error { |
| Sean McCullough | ae3480f | 2025-04-23 15:28:20 -0700 | [diff] [blame] | 28 | // Parse all authorized keys |
| 29 | allowedKeys := make([]ssh.PublicKey, 0) |
| 30 | rest := authorizedKeys |
| 31 | var err error |
| 32 | |
| 33 | // Continue parsing as long as there are bytes left |
| 34 | for len(rest) > 0 { |
| 35 | var key ssh.PublicKey |
| 36 | key, _, _, rest, err = ssh.ParseAuthorizedKey(rest) |
| 37 | if err != nil { |
| 38 | // If we hit an error, check if we have more lines to try |
| 39 | if i := bytes.IndexByte(rest, '\n'); i >= 0 { |
| 40 | // Skip to the next line and continue |
| 41 | rest = rest[i+1:] |
| 42 | continue |
| 43 | } |
| 44 | // No more lines and we hit an error, so stop parsing |
| 45 | break |
| 46 | } |
| 47 | allowedKeys = append(allowedKeys, key) |
| Sean McCullough | baa2b59 | 2025-04-23 10:40:08 -0700 | [diff] [blame] | 48 | } |
| Sean McCullough | ae3480f | 2025-04-23 15:28:20 -0700 | [diff] [blame] | 49 | if len(allowedKeys) == 0 { |
| 50 | return fmt.Errorf("ServeSSH: no valid authorized keys found") |
| 51 | } |
| 52 | |
| Sean McCullough | 7013e9e | 2025-05-14 02:03:58 +0000 | [diff] [blame^] | 53 | // Set up the certificate verifier if containerCAKey is provided |
| 54 | var certChecker *gossh.CertChecker |
| 55 | var containerCA gossh.PublicKey |
| 56 | hasMutualAuth := false |
| 57 | |
| 58 | if len(containerCAKey) > 0 { |
| 59 | // Parse container CA public key |
| 60 | containerCA, _, _, _, err = ssh.ParseAuthorizedKey(containerCAKey) |
| 61 | if err != nil { |
| 62 | slog.WarnContext(ctx, "Failed to parse container CA key", slog.String("err", err.Error())) |
| 63 | } else { |
| 64 | certChecker = &gossh.CertChecker{ |
| 65 | // Verify if the certificate was signed by our CA |
| 66 | IsUserAuthority: func(auth gossh.PublicKey) bool { |
| 67 | return bytes.Equal(auth.Marshal(), containerCA.Marshal()) |
| 68 | }, |
| 69 | // Check if a certificate has been revoked |
| 70 | IsRevoked: func(cert *gossh.Certificate) bool { |
| 71 | // We don't implement certificate revocation yet, so no certificates are revoked |
| 72 | return false |
| 73 | }, |
| 74 | } |
| 75 | slog.InfoContext(ctx, "SSH server configured for mutual authentication with container CA") |
| 76 | hasMutualAuth = true |
| 77 | } |
| 78 | } |
| 79 | |
| Sean McCullough | 1d06132 | 2025-04-24 09:52:56 -0700 | [diff] [blame] | 80 | signer, err := gossh.ParsePrivateKey(hostKey) |
| 81 | if err != nil { |
| 82 | return fmt.Errorf("ServeSSH: failed to parse host private key, err: %w", err) |
| 83 | } |
| Sean McCullough | 01ed5be | 2025-04-24 22:46:53 -0700 | [diff] [blame] | 84 | forwardHandler := &ssh.ForwardedTCPHandler{} |
| Sean McCullough | 1d06132 | 2025-04-24 09:52:56 -0700 | [diff] [blame] | 85 | |
| 86 | server := ssh.Server{ |
| 87 | LocalPortForwardingCallback: ssh.LocalPortForwardingCallback(func(ctx ssh.Context, dhost string, dport uint32) bool { |
| 88 | slog.DebugContext(ctx, "Accepted forward", slog.Any("dhost", dhost), slog.Any("dport", dport)) |
| 89 | return true |
| 90 | }), |
| 91 | Addr: ":22", |
| 92 | ChannelHandlers: ssh.DefaultChannelHandlers, |
| 93 | Handler: ssh.Handler(func(s ssh.Session) { |
| 94 | ptyReq, winCh, isPty := s.Pty() |
| 95 | if isPty { |
| 96 | handlePTYSession(ctx, s, ptyReq, winCh) |
| 97 | } else { |
| 98 | handleSession(ctx, s) |
| 99 | } |
| 100 | }), |
| Sean McCullough | 01ed5be | 2025-04-24 22:46:53 -0700 | [diff] [blame] | 101 | RequestHandlers: map[string]ssh.RequestHandler{ |
| 102 | "tcpip-forward": forwardHandler.HandleSSHRequest, |
| 103 | "cancel-tcpip-forward": forwardHandler.HandleSSHRequest, |
| 104 | }, |
| Sean McCullough | cf291fa | 2025-05-03 17:55:48 -0700 | [diff] [blame] | 105 | SubsystemHandlers: map[string]ssh.SubsystemHandler{ |
| Sean McCullough | bdfb126 | 2025-05-03 20:15:41 -0700 | [diff] [blame] | 106 | "sftp": func(s ssh.Session) { |
| 107 | handleSftp(ctx, s) |
| 108 | }, |
| Sean McCullough | cf291fa | 2025-05-03 17:55:48 -0700 | [diff] [blame] | 109 | }, |
| Sean McCullough | 1d06132 | 2025-04-24 09:52:56 -0700 | [diff] [blame] | 110 | HostSigners: []ssh.Signer{signer}, |
| 111 | PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool { |
| Sean McCullough | 7013e9e | 2025-05-14 02:03:58 +0000 | [diff] [blame^] | 112 | // First, check if it's a certificate signed by our container CA |
| 113 | if hasMutualAuth { |
| 114 | if cert, ok := key.(*gossh.Certificate); ok { |
| 115 | // It's a certificate, check if it's valid and signed by our CA |
| 116 | if certChecker.IsUserAuthority(cert.SignatureKey) { |
| 117 | // Verify the certificate validity time |
| 118 | now := time.Now() |
| 119 | if now.Before(time.Unix(int64(cert.ValidBefore), 0)) && |
| 120 | now.After(time.Unix(int64(cert.ValidAfter), 0)) { |
| 121 | // Check if the certificate has the right principal |
| 122 | if sliceContains(cert.ValidPrincipals, "root") { |
| 123 | slog.InfoContext(ctx, "SSH client authenticated with valid certificate") |
| 124 | return true |
| 125 | } |
| 126 | slog.WarnContext(ctx, "Certificate lacks root principal", |
| 127 | slog.Any("principals", cert.ValidPrincipals)) |
| 128 | } else { |
| 129 | slog.WarnContext(ctx, "Certificate time validation failed", |
| 130 | slog.Time("now", now), |
| 131 | slog.Time("valid_after", time.Unix(int64(cert.ValidAfter), 0)), |
| 132 | slog.Time("valid_before", time.Unix(int64(cert.ValidBefore), 0))) |
| 133 | } |
| 134 | } |
| 135 | } |
| 136 | } |
| 137 | |
| 138 | // Standard key-based authentication fallback |
| Sean McCullough | ae3480f | 2025-04-23 15:28:20 -0700 | [diff] [blame] | 139 | for _, allowedKey := range allowedKeys { |
| 140 | if ssh.KeysEqual(key, allowedKey) { |
| Sean McCullough | 1d06132 | 2025-04-24 09:52:56 -0700 | [diff] [blame] | 141 | slog.DebugContext(ctx, "ServeSSH: allow key", slog.String("key", string(key.Marshal()))) |
| Sean McCullough | ae3480f | 2025-04-23 15:28:20 -0700 | [diff] [blame] | 142 | return true |
| 143 | } |
| 144 | } |
| 145 | return false |
| Sean McCullough | 1d06132 | 2025-04-24 09:52:56 -0700 | [diff] [blame] | 146 | }, |
| 147 | } |
| 148 | |
| 149 | // This ChannelHandler is necessary for vscode's Remote-SSH connections to work. |
| 150 | // Without it the new VSC window will open, but you'll get an error that says something |
| 151 | // like "Failed to set up dynamic port forwarding connection over SSH to the VS Code Server." |
| 152 | server.ChannelHandlers["direct-tcpip"] = ssh.DirectTCPIPHandler |
| 153 | |
| 154 | return server.ListenAndServe() |
| Sean McCullough | baa2b59 | 2025-04-23 10:40:08 -0700 | [diff] [blame] | 155 | } |
| 156 | |
| Sean McCullough | bdfb126 | 2025-05-03 20:15:41 -0700 | [diff] [blame] | 157 | func handleSftp(ctx context.Context, sess ssh.Session) { |
| Sean McCullough | cf291fa | 2025-05-03 17:55:48 -0700 | [diff] [blame] | 158 | debugStream := io.Discard |
| 159 | serverOptions := []sftp.ServerOption{ |
| 160 | sftp.WithDebug(debugStream), |
| 161 | } |
| 162 | server, err := sftp.NewServer( |
| 163 | sess, |
| 164 | serverOptions..., |
| 165 | ) |
| 166 | if err != nil { |
| Sean McCullough | bdfb126 | 2025-05-03 20:15:41 -0700 | [diff] [blame] | 167 | slog.ErrorContext(ctx, "sftp server init error", slog.Any("err", err)) |
| Sean McCullough | cf291fa | 2025-05-03 17:55:48 -0700 | [diff] [blame] | 168 | return |
| 169 | } |
| 170 | if err := server.Serve(); err == io.EOF { |
| 171 | server.Close() |
| Sean McCullough | cf291fa | 2025-05-03 17:55:48 -0700 | [diff] [blame] | 172 | } else if err != nil { |
| Sean McCullough | bdfb126 | 2025-05-03 20:15:41 -0700 | [diff] [blame] | 173 | slog.ErrorContext(ctx, "sftp server completed with error", slog.Any("err", err)) |
| Sean McCullough | cf291fa | 2025-05-03 17:55:48 -0700 | [diff] [blame] | 174 | } |
| 175 | } |
| 176 | |
| Sean McCullough | 1d06132 | 2025-04-24 09:52:56 -0700 | [diff] [blame] | 177 | func handlePTYSession(ctx context.Context, s ssh.Session, ptyReq ssh.Pty, winCh <-chan ssh.Window) { |
| Sean McCullough | baa2b59 | 2025-04-23 10:40:08 -0700 | [diff] [blame] | 178 | cmd := exec.CommandContext(ctx, "/bin/bash") |
| Sean McCullough | 1d06132 | 2025-04-24 09:52:56 -0700 | [diff] [blame] | 179 | slog.DebugContext(ctx, "handlePTYSession", slog.Any("ptyReq", ptyReq)) |
| 180 | |
| Sean McCullough | 22bd8eb | 2025-04-28 10:36:37 -0700 | [diff] [blame] | 181 | cmd.Env = append(os.Environ(), fmt.Sprintf("TERM=%s", ptyReq.Term)) |
| Sean McCullough | 1d06132 | 2025-04-24 09:52:56 -0700 | [diff] [blame] | 182 | f, err := pty.Start(cmd) |
| 183 | if err != nil { |
| 184 | fmt.Fprintf(s, "PTY requested, but unable to start due to error: %v", err) |
| 185 | s.Exit(1) |
| 186 | return |
| 187 | } |
| 188 | |
| 189 | go func() { |
| 190 | for win := range winCh { |
| 191 | setWinsize(f, win.Width, win.Height) |
| Sean McCullough | baa2b59 | 2025-04-23 10:40:08 -0700 | [diff] [blame] | 192 | } |
| Sean McCullough | 1d06132 | 2025-04-24 09:52:56 -0700 | [diff] [blame] | 193 | }() |
| 194 | go func() { |
| 195 | io.Copy(f, s) // stdin |
| 196 | }() |
| 197 | io.Copy(s, f) // stdout |
| 198 | |
| 199 | // TODO: double check, do we need a sync.WaitGroup here, to make sure we finish |
| 200 | // the pipe I/O before we call cmd.Wait? |
| 201 | if err := cmd.Wait(); err != nil { |
| 202 | slog.ErrorContext(ctx, "handlePTYSession: cmd.Wait", slog.String("err", err.Error())) |
| 203 | s.Exit(1) |
| 204 | } |
| 205 | } |
| 206 | |
| 207 | func handleSession(ctx context.Context, s ssh.Session) { |
| 208 | var cmd *exec.Cmd |
| 209 | slog.DebugContext(ctx, "handleSession", slog.Any("s.Command", s.Command())) |
| 210 | if len(s.Command()) == 0 { |
| 211 | cmd = exec.CommandContext(ctx, "/bin/bash") |
| Sean McCullough | baa2b59 | 2025-04-23 10:40:08 -0700 | [diff] [blame] | 212 | } else { |
| Sean McCullough | 1d06132 | 2025-04-24 09:52:56 -0700 | [diff] [blame] | 213 | cmd = exec.CommandContext(ctx, s.Command()[0], s.Command()[1:]...) |
| 214 | } |
| 215 | stdinPipe, err := cmd.StdinPipe() |
| 216 | if err != nil { |
| 217 | slog.ErrorContext(ctx, "handleSession: cmd.StdinPipe", slog.Any("err", err.Error())) |
| 218 | fmt.Fprintf(s, "cmd.StdinPipe error: %v", err) |
| 219 | |
| 220 | s.Exit(1) |
| 221 | return |
| 222 | } |
| 223 | defer stdinPipe.Close() |
| 224 | |
| 225 | stdoutPipe, err := cmd.StdoutPipe() |
| 226 | if err != nil { |
| 227 | slog.ErrorContext(ctx, "handleSession: cmd.StdoutPipe", slog.Any("err", err.Error())) |
| 228 | fmt.Fprintf(s, "cmd.StdoutPipe error: %v", err) |
| 229 | s.Exit(1) |
| 230 | return |
| 231 | } |
| 232 | defer stdoutPipe.Close() |
| 233 | |
| 234 | stderrPipe, err := cmd.StderrPipe() |
| 235 | if err != nil { |
| 236 | slog.ErrorContext(ctx, "handleSession: cmd.StderrPipe", slog.Any("err", err.Error())) |
| 237 | fmt.Fprintf(s, "cmd.StderrPipe error: %v", err) |
| 238 | s.Exit(1) |
| 239 | return |
| 240 | } |
| 241 | defer stderrPipe.Close() |
| 242 | |
| 243 | if err := cmd.Start(); err != nil { |
| 244 | slog.ErrorContext(ctx, "handleSession: cmd.Start", slog.Any("err", err.Error())) |
| 245 | fmt.Fprintf(s, "cmd.Start error: %v", err) |
| 246 | s.Exit(1) |
| 247 | return |
| 248 | } |
| 249 | |
| 250 | // TODO: double check, do we need a sync.WaitGroup here, to make sure we finish |
| 251 | // the pipe I/O before we call cmd.Wait? |
| 252 | go func() { |
| 253 | io.Copy(s, stderrPipe) |
| 254 | }() |
| 255 | go func() { |
| 256 | io.Copy(s, stdoutPipe) |
| 257 | }() |
| 258 | io.Copy(stdinPipe, s) |
| 259 | |
| 260 | if err := cmd.Wait(); err != nil { |
| 261 | slog.ErrorContext(ctx, "handleSession: cmd.Wait", slog.Any("err", err.Error())) |
| 262 | fmt.Fprintf(s, "cmd.Wait error: %v", err) |
| Sean McCullough | baa2b59 | 2025-04-23 10:40:08 -0700 | [diff] [blame] | 263 | s.Exit(1) |
| 264 | } |
| 265 | } |
| Sean McCullough | 7013e9e | 2025-05-14 02:03:58 +0000 | [diff] [blame^] | 266 | |
// sliceContains reports whether value occurs anywhere in slice. A nil or
// empty slice contains nothing.
func sliceContains(slice []string, value string) bool {
	return slices.Index(slice, value) != -1
}