blob: b6c40bf068c69a0e0bc5a902a96d4106b63bd68a [file] [log] [blame]
Sean McCulloughbaa2b592025-04-23 10:40:08 -07001package server
2
3import (
Sean McCulloughae3480f2025-04-23 15:28:20 -07004 "bytes"
Sean McCulloughbaa2b592025-04-23 10:40:08 -07005 "context"
6 "fmt"
7 "io"
Sean McCullough1d061322025-04-24 09:52:56 -07008 "log/slog"
Sean McCulloughbaa2b592025-04-23 10:40:08 -07009 "os"
10 "os/exec"
Sean McCullough7013e9e2025-05-14 02:03:58 +000011 "slices"
Sean McCulloughbaa2b592025-04-23 10:40:08 -070012 "syscall"
Sean McCullough7013e9e2025-05-14 02:03:58 +000013 "time"
Sean McCulloughbaa2b592025-04-23 10:40:08 -070014 "unsafe"
15
16 "github.com/creack/pty"
17 "github.com/gliderlabs/ssh"
Sean McCulloughcf291fa2025-05-03 17:55:48 -070018 "github.com/pkg/sftp"
Sean McCullough1d061322025-04-24 09:52:56 -070019 gossh "golang.org/x/crypto/ssh"
Sean McCulloughbaa2b592025-04-23 10:40:08 -070020)
21
// setWinsize applies a terminal size to the pty f using the TIOCSWINSZ ioctl.
// w is the width in columns and h the height in rows; both are truncated to
// uint16 by conversion. Any error returned by the ioctl is ignored.
func setWinsize(f *os.File, w, h int) {
	// Field order mirrors the kernel's struct winsize: rows, cols, xpixel, ypixel.
	type winsize struct {
		rows, cols, xpixel, ypixel uint16
	}
	ws := winsize{rows: uint16(h), cols: uint16(w)}
	syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ), uintptr(unsafe.Pointer(&ws)))
}
26
Sean McCullough7013e9e2025-05-14 02:03:58 +000027func (s *Server) ServeSSH(ctx context.Context, hostKey, authorizedKeys []byte, containerCAKey, hostCertificate []byte) error {
Sean McCulloughae3480f2025-04-23 15:28:20 -070028 // Parse all authorized keys
29 allowedKeys := make([]ssh.PublicKey, 0)
30 rest := authorizedKeys
31 var err error
32
33 // Continue parsing as long as there are bytes left
34 for len(rest) > 0 {
35 var key ssh.PublicKey
36 key, _, _, rest, err = ssh.ParseAuthorizedKey(rest)
37 if err != nil {
38 // If we hit an error, check if we have more lines to try
39 if i := bytes.IndexByte(rest, '\n'); i >= 0 {
40 // Skip to the next line and continue
41 rest = rest[i+1:]
42 continue
43 }
44 // No more lines and we hit an error, so stop parsing
45 break
46 }
47 allowedKeys = append(allowedKeys, key)
Sean McCulloughbaa2b592025-04-23 10:40:08 -070048 }
Sean McCulloughae3480f2025-04-23 15:28:20 -070049 if len(allowedKeys) == 0 {
50 return fmt.Errorf("ServeSSH: no valid authorized keys found")
51 }
52
Sean McCullough7013e9e2025-05-14 02:03:58 +000053 // Set up the certificate verifier if containerCAKey is provided
54 var certChecker *gossh.CertChecker
55 var containerCA gossh.PublicKey
56 hasMutualAuth := false
57
58 if len(containerCAKey) > 0 {
59 // Parse container CA public key
60 containerCA, _, _, _, err = ssh.ParseAuthorizedKey(containerCAKey)
61 if err != nil {
62 slog.WarnContext(ctx, "Failed to parse container CA key", slog.String("err", err.Error()))
63 } else {
64 certChecker = &gossh.CertChecker{
65 // Verify if the certificate was signed by our CA
66 IsUserAuthority: func(auth gossh.PublicKey) bool {
67 return bytes.Equal(auth.Marshal(), containerCA.Marshal())
68 },
69 // Check if a certificate has been revoked
70 IsRevoked: func(cert *gossh.Certificate) bool {
71 // We don't implement certificate revocation yet, so no certificates are revoked
72 return false
73 },
74 }
75 slog.InfoContext(ctx, "SSH server configured for mutual authentication with container CA")
76 hasMutualAuth = true
77 }
78 }
79
Sean McCullough1d061322025-04-24 09:52:56 -070080 signer, err := gossh.ParsePrivateKey(hostKey)
81 if err != nil {
82 return fmt.Errorf("ServeSSH: failed to parse host private key, err: %w", err)
83 }
Sean McCullough01ed5be2025-04-24 22:46:53 -070084 forwardHandler := &ssh.ForwardedTCPHandler{}
Sean McCullough1d061322025-04-24 09:52:56 -070085
86 server := ssh.Server{
87 LocalPortForwardingCallback: ssh.LocalPortForwardingCallback(func(ctx ssh.Context, dhost string, dport uint32) bool {
88 slog.DebugContext(ctx, "Accepted forward", slog.Any("dhost", dhost), slog.Any("dport", dport))
89 return true
90 }),
Philip Zeyligere84d5c72025-05-30 09:32:55 -070091 ReversePortForwardingCallback: ssh.ReversePortForwardingCallback(func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
92 slog.DebugContext(ctx, "Accepted reverse forward", slog.Any("bindHost", bindHost), slog.Any("bindPort", bindPort))
93 return true
94 }),
Sean McCullough1d061322025-04-24 09:52:56 -070095 Addr: ":22",
96 ChannelHandlers: ssh.DefaultChannelHandlers,
97 Handler: ssh.Handler(func(s ssh.Session) {
98 ptyReq, winCh, isPty := s.Pty()
99 if isPty {
100 handlePTYSession(ctx, s, ptyReq, winCh)
101 } else {
102 handleSession(ctx, s)
103 }
104 }),
Sean McCullough01ed5be2025-04-24 22:46:53 -0700105 RequestHandlers: map[string]ssh.RequestHandler{
106 "tcpip-forward": forwardHandler.HandleSSHRequest,
107 "cancel-tcpip-forward": forwardHandler.HandleSSHRequest,
108 },
Sean McCulloughcf291fa2025-05-03 17:55:48 -0700109 SubsystemHandlers: map[string]ssh.SubsystemHandler{
Sean McCulloughbdfb1262025-05-03 20:15:41 -0700110 "sftp": func(s ssh.Session) {
111 handleSftp(ctx, s)
112 },
Sean McCulloughcf291fa2025-05-03 17:55:48 -0700113 },
Sean McCullough1d061322025-04-24 09:52:56 -0700114 HostSigners: []ssh.Signer{signer},
115 PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
Sean McCullough7013e9e2025-05-14 02:03:58 +0000116 // First, check if it's a certificate signed by our container CA
117 if hasMutualAuth {
118 if cert, ok := key.(*gossh.Certificate); ok {
119 // It's a certificate, check if it's valid and signed by our CA
120 if certChecker.IsUserAuthority(cert.SignatureKey) {
121 // Verify the certificate validity time
122 now := time.Now()
123 if now.Before(time.Unix(int64(cert.ValidBefore), 0)) &&
124 now.After(time.Unix(int64(cert.ValidAfter), 0)) {
125 // Check if the certificate has the right principal
126 if sliceContains(cert.ValidPrincipals, "root") {
127 slog.InfoContext(ctx, "SSH client authenticated with valid certificate")
128 return true
129 }
130 slog.WarnContext(ctx, "Certificate lacks root principal",
131 slog.Any("principals", cert.ValidPrincipals))
132 } else {
133 slog.WarnContext(ctx, "Certificate time validation failed",
134 slog.Time("now", now),
135 slog.Time("valid_after", time.Unix(int64(cert.ValidAfter), 0)),
136 slog.Time("valid_before", time.Unix(int64(cert.ValidBefore), 0)))
137 }
138 }
139 }
140 }
141
142 // Standard key-based authentication fallback
Sean McCulloughae3480f2025-04-23 15:28:20 -0700143 for _, allowedKey := range allowedKeys {
144 if ssh.KeysEqual(key, allowedKey) {
Sean McCullough1d061322025-04-24 09:52:56 -0700145 slog.DebugContext(ctx, "ServeSSH: allow key", slog.String("key", string(key.Marshal())))
Sean McCulloughae3480f2025-04-23 15:28:20 -0700146 return true
147 }
148 }
149 return false
Sean McCullough1d061322025-04-24 09:52:56 -0700150 },
151 }
152
153 // This ChannelHandler is necessary for vscode's Remote-SSH connections to work.
154 // Without it the new VSC window will open, but you'll get an error that says something
155 // like "Failed to set up dynamic port forwarding connection over SSH to the VS Code Server."
156 server.ChannelHandlers["direct-tcpip"] = ssh.DirectTCPIPHandler
157
158 return server.ListenAndServe()
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700159}
160
Sean McCulloughbdfb1262025-05-03 20:15:41 -0700161func handleSftp(ctx context.Context, sess ssh.Session) {
Sean McCulloughcf291fa2025-05-03 17:55:48 -0700162 debugStream := io.Discard
163 serverOptions := []sftp.ServerOption{
164 sftp.WithDebug(debugStream),
165 }
166 server, err := sftp.NewServer(
167 sess,
168 serverOptions...,
169 )
170 if err != nil {
Sean McCulloughbdfb1262025-05-03 20:15:41 -0700171 slog.ErrorContext(ctx, "sftp server init error", slog.Any("err", err))
Sean McCulloughcf291fa2025-05-03 17:55:48 -0700172 return
173 }
174 if err := server.Serve(); err == io.EOF {
175 server.Close()
Sean McCulloughcf291fa2025-05-03 17:55:48 -0700176 } else if err != nil {
Sean McCulloughbdfb1262025-05-03 20:15:41 -0700177 slog.ErrorContext(ctx, "sftp server completed with error", slog.Any("err", err))
Sean McCulloughcf291fa2025-05-03 17:55:48 -0700178 }
179}
180
Sean McCullough1d061322025-04-24 09:52:56 -0700181func handlePTYSession(ctx context.Context, s ssh.Session, ptyReq ssh.Pty, winCh <-chan ssh.Window) {
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700182 cmd := exec.CommandContext(ctx, "/bin/bash")
Sean McCullough1d061322025-04-24 09:52:56 -0700183 slog.DebugContext(ctx, "handlePTYSession", slog.Any("ptyReq", ptyReq))
184
Sean McCullough22bd8eb2025-04-28 10:36:37 -0700185 cmd.Env = append(os.Environ(), fmt.Sprintf("TERM=%s", ptyReq.Term))
Sean McCullough1d061322025-04-24 09:52:56 -0700186 f, err := pty.Start(cmd)
187 if err != nil {
188 fmt.Fprintf(s, "PTY requested, but unable to start due to error: %v", err)
189 s.Exit(1)
190 return
191 }
192
193 go func() {
194 for win := range winCh {
195 setWinsize(f, win.Width, win.Height)
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700196 }
Sean McCullough1d061322025-04-24 09:52:56 -0700197 }()
198 go func() {
199 io.Copy(f, s) // stdin
200 }()
201 io.Copy(s, f) // stdout
202
203 // TODO: double check, do we need a sync.WaitGroup here, to make sure we finish
204 // the pipe I/O before we call cmd.Wait?
205 if err := cmd.Wait(); err != nil {
206 slog.ErrorContext(ctx, "handlePTYSession: cmd.Wait", slog.String("err", err.Error()))
207 s.Exit(1)
208 }
209}
210
211func handleSession(ctx context.Context, s ssh.Session) {
212 var cmd *exec.Cmd
213 slog.DebugContext(ctx, "handleSession", slog.Any("s.Command", s.Command()))
214 if len(s.Command()) == 0 {
215 cmd = exec.CommandContext(ctx, "/bin/bash")
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700216 } else {
Sean McCullough1d061322025-04-24 09:52:56 -0700217 cmd = exec.CommandContext(ctx, s.Command()[0], s.Command()[1:]...)
218 }
219 stdinPipe, err := cmd.StdinPipe()
220 if err != nil {
221 slog.ErrorContext(ctx, "handleSession: cmd.StdinPipe", slog.Any("err", err.Error()))
222 fmt.Fprintf(s, "cmd.StdinPipe error: %v", err)
223
224 s.Exit(1)
225 return
226 }
227 defer stdinPipe.Close()
228
229 stdoutPipe, err := cmd.StdoutPipe()
230 if err != nil {
231 slog.ErrorContext(ctx, "handleSession: cmd.StdoutPipe", slog.Any("err", err.Error()))
232 fmt.Fprintf(s, "cmd.StdoutPipe error: %v", err)
233 s.Exit(1)
234 return
235 }
236 defer stdoutPipe.Close()
237
238 stderrPipe, err := cmd.StderrPipe()
239 if err != nil {
240 slog.ErrorContext(ctx, "handleSession: cmd.StderrPipe", slog.Any("err", err.Error()))
241 fmt.Fprintf(s, "cmd.StderrPipe error: %v", err)
242 s.Exit(1)
243 return
244 }
245 defer stderrPipe.Close()
246
247 if err := cmd.Start(); err != nil {
248 slog.ErrorContext(ctx, "handleSession: cmd.Start", slog.Any("err", err.Error()))
249 fmt.Fprintf(s, "cmd.Start error: %v", err)
250 s.Exit(1)
251 return
252 }
253
254 // TODO: double check, do we need a sync.WaitGroup here, to make sure we finish
255 // the pipe I/O before we call cmd.Wait?
256 go func() {
257 io.Copy(s, stderrPipe)
258 }()
259 go func() {
260 io.Copy(s, stdoutPipe)
261 }()
262 io.Copy(stdinPipe, s)
263
264 if err := cmd.Wait(); err != nil {
265 slog.ErrorContext(ctx, "handleSession: cmd.Wait", slog.Any("err", err.Error()))
266 fmt.Fprintf(s, "cmd.Wait error: %v", err)
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700267 s.Exit(1)
268 }
269}
Sean McCullough7013e9e2025-05-14 02:03:58 +0000270
// sliceContains reports whether value is an element of slice.
// A nil or empty slice contains nothing.
func sliceContains(slice []string, value string) bool {
	return slices.Index(slice, value) >= 0
}