blob: 65eb54bb0b5e32b07ad6a9d83e1d783484ce7cd4 [file] [log] [blame]
Sean McCulloughbaa2b592025-04-23 10:40:08 -07001package server
2
3import (
Sean McCulloughae3480f2025-04-23 15:28:20 -07004 "bytes"
Sean McCulloughbaa2b592025-04-23 10:40:08 -07005 "context"
6 "fmt"
7 "io"
Sean McCullough1d061322025-04-24 09:52:56 -07008 "log/slog"
Sean McCulloughbaa2b592025-04-23 10:40:08 -07009 "os"
10 "os/exec"
Sean McCullough7013e9e2025-05-14 02:03:58 +000011 "slices"
Sean McCulloughbaa2b592025-04-23 10:40:08 -070012 "syscall"
Sean McCullough7013e9e2025-05-14 02:03:58 +000013 "time"
Sean McCulloughbaa2b592025-04-23 10:40:08 -070014 "unsafe"
15
16 "github.com/creack/pty"
17 "github.com/gliderlabs/ssh"
Sean McCulloughcf291fa2025-05-03 17:55:48 -070018 "github.com/pkg/sftp"
Sean McCullough1d061322025-04-24 09:52:56 -070019 gossh "golang.org/x/crypto/ssh"
Sean McCulloughbaa2b592025-04-23 10:40:08 -070020)
21
// setWinsize resizes the terminal behind f to w columns by h rows by issuing
// a TIOCSWINSZ ioctl on f's file descriptor. Pixel dimensions are set to
// zero. Values are truncated to uint16. Any ioctl failure (for example, f is
// not a terminal) is deliberately ignored.
func setWinsize(f *os.File, w, h int) {
	// Field order follows struct winsize: rows, cols, xpixel, ypixel.
	size := struct{ rows, cols, xpixel, ypixel uint16 }{
		rows: uint16(h),
		cols: uint16(w),
	}
	_, _, _ = syscall.Syscall(
		syscall.SYS_IOCTL,
		f.Fd(),
		uintptr(syscall.TIOCSWINSZ),
		uintptr(unsafe.Pointer(&size)),
	)
}
26
Sean McCullough7013e9e2025-05-14 02:03:58 +000027func (s *Server) ServeSSH(ctx context.Context, hostKey, authorizedKeys []byte, containerCAKey, hostCertificate []byte) error {
Sean McCulloughae3480f2025-04-23 15:28:20 -070028 // Parse all authorized keys
29 allowedKeys := make([]ssh.PublicKey, 0)
30 rest := authorizedKeys
31 var err error
32
33 // Continue parsing as long as there are bytes left
34 for len(rest) > 0 {
35 var key ssh.PublicKey
36 key, _, _, rest, err = ssh.ParseAuthorizedKey(rest)
37 if err != nil {
38 // If we hit an error, check if we have more lines to try
39 if i := bytes.IndexByte(rest, '\n'); i >= 0 {
40 // Skip to the next line and continue
41 rest = rest[i+1:]
42 continue
43 }
44 // No more lines and we hit an error, so stop parsing
45 break
46 }
47 allowedKeys = append(allowedKeys, key)
Sean McCulloughbaa2b592025-04-23 10:40:08 -070048 }
Sean McCulloughae3480f2025-04-23 15:28:20 -070049 if len(allowedKeys) == 0 {
50 return fmt.Errorf("ServeSSH: no valid authorized keys found")
51 }
52
Sean McCullough7013e9e2025-05-14 02:03:58 +000053 // Set up the certificate verifier if containerCAKey is provided
54 var certChecker *gossh.CertChecker
55 var containerCA gossh.PublicKey
56 hasMutualAuth := false
57
58 if len(containerCAKey) > 0 {
59 // Parse container CA public key
60 containerCA, _, _, _, err = ssh.ParseAuthorizedKey(containerCAKey)
61 if err != nil {
62 slog.WarnContext(ctx, "Failed to parse container CA key", slog.String("err", err.Error()))
63 } else {
64 certChecker = &gossh.CertChecker{
65 // Verify if the certificate was signed by our CA
66 IsUserAuthority: func(auth gossh.PublicKey) bool {
67 return bytes.Equal(auth.Marshal(), containerCA.Marshal())
68 },
69 // Check if a certificate has been revoked
70 IsRevoked: func(cert *gossh.Certificate) bool {
71 // We don't implement certificate revocation yet, so no certificates are revoked
72 return false
73 },
74 }
75 slog.InfoContext(ctx, "SSH server configured for mutual authentication with container CA")
76 hasMutualAuth = true
77 }
78 }
79
Sean McCullough1d061322025-04-24 09:52:56 -070080 signer, err := gossh.ParsePrivateKey(hostKey)
81 if err != nil {
82 return fmt.Errorf("ServeSSH: failed to parse host private key, err: %w", err)
83 }
Sean McCullough01ed5be2025-04-24 22:46:53 -070084 forwardHandler := &ssh.ForwardedTCPHandler{}
Sean McCullough1d061322025-04-24 09:52:56 -070085
86 server := ssh.Server{
87 LocalPortForwardingCallback: ssh.LocalPortForwardingCallback(func(ctx ssh.Context, dhost string, dport uint32) bool {
88 slog.DebugContext(ctx, "Accepted forward", slog.Any("dhost", dhost), slog.Any("dport", dport))
89 return true
90 }),
91 Addr: ":22",
92 ChannelHandlers: ssh.DefaultChannelHandlers,
93 Handler: ssh.Handler(func(s ssh.Session) {
94 ptyReq, winCh, isPty := s.Pty()
95 if isPty {
96 handlePTYSession(ctx, s, ptyReq, winCh)
97 } else {
98 handleSession(ctx, s)
99 }
100 }),
Sean McCullough01ed5be2025-04-24 22:46:53 -0700101 RequestHandlers: map[string]ssh.RequestHandler{
102 "tcpip-forward": forwardHandler.HandleSSHRequest,
103 "cancel-tcpip-forward": forwardHandler.HandleSSHRequest,
104 },
Sean McCulloughcf291fa2025-05-03 17:55:48 -0700105 SubsystemHandlers: map[string]ssh.SubsystemHandler{
Sean McCulloughbdfb1262025-05-03 20:15:41 -0700106 "sftp": func(s ssh.Session) {
107 handleSftp(ctx, s)
108 },
Sean McCulloughcf291fa2025-05-03 17:55:48 -0700109 },
Sean McCullough1d061322025-04-24 09:52:56 -0700110 HostSigners: []ssh.Signer{signer},
111 PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
Sean McCullough7013e9e2025-05-14 02:03:58 +0000112 // First, check if it's a certificate signed by our container CA
113 if hasMutualAuth {
114 if cert, ok := key.(*gossh.Certificate); ok {
115 // It's a certificate, check if it's valid and signed by our CA
116 if certChecker.IsUserAuthority(cert.SignatureKey) {
117 // Verify the certificate validity time
118 now := time.Now()
119 if now.Before(time.Unix(int64(cert.ValidBefore), 0)) &&
120 now.After(time.Unix(int64(cert.ValidAfter), 0)) {
121 // Check if the certificate has the right principal
122 if sliceContains(cert.ValidPrincipals, "root") {
123 slog.InfoContext(ctx, "SSH client authenticated with valid certificate")
124 return true
125 }
126 slog.WarnContext(ctx, "Certificate lacks root principal",
127 slog.Any("principals", cert.ValidPrincipals))
128 } else {
129 slog.WarnContext(ctx, "Certificate time validation failed",
130 slog.Time("now", now),
131 slog.Time("valid_after", time.Unix(int64(cert.ValidAfter), 0)),
132 slog.Time("valid_before", time.Unix(int64(cert.ValidBefore), 0)))
133 }
134 }
135 }
136 }
137
138 // Standard key-based authentication fallback
Sean McCulloughae3480f2025-04-23 15:28:20 -0700139 for _, allowedKey := range allowedKeys {
140 if ssh.KeysEqual(key, allowedKey) {
Sean McCullough1d061322025-04-24 09:52:56 -0700141 slog.DebugContext(ctx, "ServeSSH: allow key", slog.String("key", string(key.Marshal())))
Sean McCulloughae3480f2025-04-23 15:28:20 -0700142 return true
143 }
144 }
145 return false
Sean McCullough1d061322025-04-24 09:52:56 -0700146 },
147 }
148
149 // This ChannelHandler is necessary for vscode's Remote-SSH connections to work.
150 // Without it the new VSC window will open, but you'll get an error that says something
151 // like "Failed to set up dynamic port forwarding connection over SSH to the VS Code Server."
152 server.ChannelHandlers["direct-tcpip"] = ssh.DirectTCPIPHandler
153
154 return server.ListenAndServe()
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700155}
156
Sean McCulloughbdfb1262025-05-03 20:15:41 -0700157func handleSftp(ctx context.Context, sess ssh.Session) {
Sean McCulloughcf291fa2025-05-03 17:55:48 -0700158 debugStream := io.Discard
159 serverOptions := []sftp.ServerOption{
160 sftp.WithDebug(debugStream),
161 }
162 server, err := sftp.NewServer(
163 sess,
164 serverOptions...,
165 )
166 if err != nil {
Sean McCulloughbdfb1262025-05-03 20:15:41 -0700167 slog.ErrorContext(ctx, "sftp server init error", slog.Any("err", err))
Sean McCulloughcf291fa2025-05-03 17:55:48 -0700168 return
169 }
170 if err := server.Serve(); err == io.EOF {
171 server.Close()
Sean McCulloughcf291fa2025-05-03 17:55:48 -0700172 } else if err != nil {
Sean McCulloughbdfb1262025-05-03 20:15:41 -0700173 slog.ErrorContext(ctx, "sftp server completed with error", slog.Any("err", err))
Sean McCulloughcf291fa2025-05-03 17:55:48 -0700174 }
175}
176
Sean McCullough1d061322025-04-24 09:52:56 -0700177func handlePTYSession(ctx context.Context, s ssh.Session, ptyReq ssh.Pty, winCh <-chan ssh.Window) {
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700178 cmd := exec.CommandContext(ctx, "/bin/bash")
Sean McCullough1d061322025-04-24 09:52:56 -0700179 slog.DebugContext(ctx, "handlePTYSession", slog.Any("ptyReq", ptyReq))
180
Sean McCullough22bd8eb2025-04-28 10:36:37 -0700181 cmd.Env = append(os.Environ(), fmt.Sprintf("TERM=%s", ptyReq.Term))
Sean McCullough1d061322025-04-24 09:52:56 -0700182 f, err := pty.Start(cmd)
183 if err != nil {
184 fmt.Fprintf(s, "PTY requested, but unable to start due to error: %v", err)
185 s.Exit(1)
186 return
187 }
188
189 go func() {
190 for win := range winCh {
191 setWinsize(f, win.Width, win.Height)
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700192 }
Sean McCullough1d061322025-04-24 09:52:56 -0700193 }()
194 go func() {
195 io.Copy(f, s) // stdin
196 }()
197 io.Copy(s, f) // stdout
198
199 // TODO: double check, do we need a sync.WaitGroup here, to make sure we finish
200 // the pipe I/O before we call cmd.Wait?
201 if err := cmd.Wait(); err != nil {
202 slog.ErrorContext(ctx, "handlePTYSession: cmd.Wait", slog.String("err", err.Error()))
203 s.Exit(1)
204 }
205}
206
207func handleSession(ctx context.Context, s ssh.Session) {
208 var cmd *exec.Cmd
209 slog.DebugContext(ctx, "handleSession", slog.Any("s.Command", s.Command()))
210 if len(s.Command()) == 0 {
211 cmd = exec.CommandContext(ctx, "/bin/bash")
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700212 } else {
Sean McCullough1d061322025-04-24 09:52:56 -0700213 cmd = exec.CommandContext(ctx, s.Command()[0], s.Command()[1:]...)
214 }
215 stdinPipe, err := cmd.StdinPipe()
216 if err != nil {
217 slog.ErrorContext(ctx, "handleSession: cmd.StdinPipe", slog.Any("err", err.Error()))
218 fmt.Fprintf(s, "cmd.StdinPipe error: %v", err)
219
220 s.Exit(1)
221 return
222 }
223 defer stdinPipe.Close()
224
225 stdoutPipe, err := cmd.StdoutPipe()
226 if err != nil {
227 slog.ErrorContext(ctx, "handleSession: cmd.StdoutPipe", slog.Any("err", err.Error()))
228 fmt.Fprintf(s, "cmd.StdoutPipe error: %v", err)
229 s.Exit(1)
230 return
231 }
232 defer stdoutPipe.Close()
233
234 stderrPipe, err := cmd.StderrPipe()
235 if err != nil {
236 slog.ErrorContext(ctx, "handleSession: cmd.StderrPipe", slog.Any("err", err.Error()))
237 fmt.Fprintf(s, "cmd.StderrPipe error: %v", err)
238 s.Exit(1)
239 return
240 }
241 defer stderrPipe.Close()
242
243 if err := cmd.Start(); err != nil {
244 slog.ErrorContext(ctx, "handleSession: cmd.Start", slog.Any("err", err.Error()))
245 fmt.Fprintf(s, "cmd.Start error: %v", err)
246 s.Exit(1)
247 return
248 }
249
250 // TODO: double check, do we need a sync.WaitGroup here, to make sure we finish
251 // the pipe I/O before we call cmd.Wait?
252 go func() {
253 io.Copy(s, stderrPipe)
254 }()
255 go func() {
256 io.Copy(s, stdoutPipe)
257 }()
258 io.Copy(stdinPipe, s)
259
260 if err := cmd.Wait(); err != nil {
261 slog.ErrorContext(ctx, "handleSession: cmd.Wait", slog.Any("err", err.Error()))
262 fmt.Fprintf(s, "cmd.Wait error: %v", err)
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700263 s.Exit(1)
264 }
265}
Sean McCullough7013e9e2025-05-14 02:03:58 +0000266
// sliceContains reports whether value occurs in slice, using exact,
// case-sensitive string equality. A nil or empty slice contains nothing.
func sliceContains(slice []string, value string) bool {
	return slices.Index(slice, value) >= 0
}