blob: 6714af6e3c28b014e4d4650b09ce332511835ba9 [file] [log] [blame]
Sean McCulloughbaa2b592025-04-23 10:40:08 -07001package server
2
3import (
Sean McCulloughae3480f2025-04-23 15:28:20 -07004 "bytes"
Sean McCulloughbaa2b592025-04-23 10:40:08 -07005 "context"
6 "fmt"
7 "io"
Sean McCullough1d061322025-04-24 09:52:56 -07008 "log/slog"
Sean McCulloughbaa2b592025-04-23 10:40:08 -07009 "os"
10 "os/exec"
Sean McCullough7013e9e2025-05-14 02:03:58 +000011 "slices"
Sean McCulloughbaa2b592025-04-23 10:40:08 -070012 "syscall"
Sean McCullough7013e9e2025-05-14 02:03:58 +000013 "time"
Sean McCulloughbaa2b592025-04-23 10:40:08 -070014 "unsafe"
15
16 "github.com/creack/pty"
17 "github.com/gliderlabs/ssh"
Sean McCulloughcf291fa2025-05-03 17:55:48 -070018 "github.com/pkg/sftp"
Sean McCullough1d061322025-04-24 09:52:56 -070019 gossh "golang.org/x/crypto/ssh"
Sean McCulloughbaa2b592025-04-23 10:40:08 -070020)
21
// setWinsize asks the kernel to resize the terminal behind f to w columns by
// h rows using the TIOCSWINSZ ioctl. The pixel dimensions are left at zero.
// This is best-effort: the ioctl's result is discarded, so calling it on a
// non-terminal (or nil) file is a silent no-op.
func setWinsize(f *os.File, w, h int) {
	// Field order follows the TIOCSWINSZ argument: rows, columns, x pixels, y pixels.
	type winsize struct{ rows, cols, xpixel, ypixel uint16 }
	size := winsize{rows: uint16(h), cols: uint16(w)}
	syscall.Syscall(
		syscall.SYS_IOCTL,
		f.Fd(),
		uintptr(syscall.TIOCSWINSZ),
		uintptr(unsafe.Pointer(&size)),
	)
}
26
Sean McCullough7013e9e2025-05-14 02:03:58 +000027func (s *Server) ServeSSH(ctx context.Context, hostKey, authorizedKeys []byte, containerCAKey, hostCertificate []byte) error {
Sean McCulloughae3480f2025-04-23 15:28:20 -070028 // Parse all authorized keys
29 allowedKeys := make([]ssh.PublicKey, 0)
30 rest := authorizedKeys
31 var err error
32
33 // Continue parsing as long as there are bytes left
34 for len(rest) > 0 {
35 var key ssh.PublicKey
36 key, _, _, rest, err = ssh.ParseAuthorizedKey(rest)
37 if err != nil {
38 // If we hit an error, check if we have more lines to try
39 if i := bytes.IndexByte(rest, '\n'); i >= 0 {
40 // Skip to the next line and continue
41 rest = rest[i+1:]
42 continue
43 }
44 // No more lines and we hit an error, so stop parsing
45 break
46 }
47 allowedKeys = append(allowedKeys, key)
Sean McCulloughbaa2b592025-04-23 10:40:08 -070048 }
Sean McCulloughae3480f2025-04-23 15:28:20 -070049 if len(allowedKeys) == 0 {
50 return fmt.Errorf("ServeSSH: no valid authorized keys found")
51 }
52
Sean McCullough7013e9e2025-05-14 02:03:58 +000053 // Set up the certificate verifier if containerCAKey is provided
54 var certChecker *gossh.CertChecker
55 var containerCA gossh.PublicKey
56 hasMutualAuth := false
57
58 if len(containerCAKey) > 0 {
59 // Parse container CA public key
60 containerCA, _, _, _, err = ssh.ParseAuthorizedKey(containerCAKey)
61 if err != nil {
62 slog.WarnContext(ctx, "Failed to parse container CA key", slog.String("err", err.Error()))
63 } else {
64 certChecker = &gossh.CertChecker{
65 // Verify if the certificate was signed by our CA
66 IsUserAuthority: func(auth gossh.PublicKey) bool {
67 return bytes.Equal(auth.Marshal(), containerCA.Marshal())
68 },
69 // Check if a certificate has been revoked
70 IsRevoked: func(cert *gossh.Certificate) bool {
71 // We don't implement certificate revocation yet, so no certificates are revoked
72 return false
73 },
74 }
75 slog.InfoContext(ctx, "SSH server configured for mutual authentication with container CA")
76 hasMutualAuth = true
77 }
78 }
79
Sean McCullough1d061322025-04-24 09:52:56 -070080 signer, err := gossh.ParsePrivateKey(hostKey)
81 if err != nil {
82 return fmt.Errorf("ServeSSH: failed to parse host private key, err: %w", err)
83 }
Sean McCullough01ed5be2025-04-24 22:46:53 -070084 forwardHandler := &ssh.ForwardedTCPHandler{}
Sean McCullough1d061322025-04-24 09:52:56 -070085
86 server := ssh.Server{
87 LocalPortForwardingCallback: ssh.LocalPortForwardingCallback(func(ctx ssh.Context, dhost string, dport uint32) bool {
Sean McCullough1d061322025-04-24 09:52:56 -070088 return true
89 }),
Philip Zeyligere84d5c72025-05-30 09:32:55 -070090 ReversePortForwardingCallback: ssh.ReversePortForwardingCallback(func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
91 slog.DebugContext(ctx, "Accepted reverse forward", slog.Any("bindHost", bindHost), slog.Any("bindPort", bindPort))
92 return true
93 }),
Sean McCullough1d061322025-04-24 09:52:56 -070094 Addr: ":22",
95 ChannelHandlers: ssh.DefaultChannelHandlers,
96 Handler: ssh.Handler(func(s ssh.Session) {
97 ptyReq, winCh, isPty := s.Pty()
98 if isPty {
99 handlePTYSession(ctx, s, ptyReq, winCh)
100 } else {
101 handleSession(ctx, s)
102 }
103 }),
Sean McCullough01ed5be2025-04-24 22:46:53 -0700104 RequestHandlers: map[string]ssh.RequestHandler{
105 "tcpip-forward": forwardHandler.HandleSSHRequest,
106 "cancel-tcpip-forward": forwardHandler.HandleSSHRequest,
107 },
Sean McCulloughcf291fa2025-05-03 17:55:48 -0700108 SubsystemHandlers: map[string]ssh.SubsystemHandler{
Sean McCulloughbdfb1262025-05-03 20:15:41 -0700109 "sftp": func(s ssh.Session) {
110 handleSftp(ctx, s)
111 },
Sean McCulloughcf291fa2025-05-03 17:55:48 -0700112 },
Sean McCullough1d061322025-04-24 09:52:56 -0700113 HostSigners: []ssh.Signer{signer},
114 PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
Sean McCullough7013e9e2025-05-14 02:03:58 +0000115 // First, check if it's a certificate signed by our container CA
116 if hasMutualAuth {
117 if cert, ok := key.(*gossh.Certificate); ok {
118 // It's a certificate, check if it's valid and signed by our CA
119 if certChecker.IsUserAuthority(cert.SignatureKey) {
120 // Verify the certificate validity time
121 now := time.Now()
122 if now.Before(time.Unix(int64(cert.ValidBefore), 0)) &&
123 now.After(time.Unix(int64(cert.ValidAfter), 0)) {
124 // Check if the certificate has the right principal
125 if sliceContains(cert.ValidPrincipals, "root") {
126 slog.InfoContext(ctx, "SSH client authenticated with valid certificate")
127 return true
128 }
129 slog.WarnContext(ctx, "Certificate lacks root principal",
130 slog.Any("principals", cert.ValidPrincipals))
131 } else {
132 slog.WarnContext(ctx, "Certificate time validation failed",
133 slog.Time("now", now),
134 slog.Time("valid_after", time.Unix(int64(cert.ValidAfter), 0)),
135 slog.Time("valid_before", time.Unix(int64(cert.ValidBefore), 0)))
136 }
137 }
138 }
139 }
140
141 // Standard key-based authentication fallback
Sean McCulloughae3480f2025-04-23 15:28:20 -0700142 for _, allowedKey := range allowedKeys {
143 if ssh.KeysEqual(key, allowedKey) {
Sean McCullough1d061322025-04-24 09:52:56 -0700144 slog.DebugContext(ctx, "ServeSSH: allow key", slog.String("key", string(key.Marshal())))
Sean McCulloughae3480f2025-04-23 15:28:20 -0700145 return true
146 }
147 }
148 return false
Sean McCullough1d061322025-04-24 09:52:56 -0700149 },
150 }
151
152 // This ChannelHandler is necessary for vscode's Remote-SSH connections to work.
153 // Without it the new VSC window will open, but you'll get an error that says something
154 // like "Failed to set up dynamic port forwarding connection over SSH to the VS Code Server."
155 server.ChannelHandlers["direct-tcpip"] = ssh.DirectTCPIPHandler
156
157 return server.ListenAndServe()
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700158}
159
Sean McCulloughbdfb1262025-05-03 20:15:41 -0700160func handleSftp(ctx context.Context, sess ssh.Session) {
Sean McCulloughcf291fa2025-05-03 17:55:48 -0700161 debugStream := io.Discard
162 serverOptions := []sftp.ServerOption{
163 sftp.WithDebug(debugStream),
164 }
165 server, err := sftp.NewServer(
166 sess,
167 serverOptions...,
168 )
169 if err != nil {
Sean McCulloughbdfb1262025-05-03 20:15:41 -0700170 slog.ErrorContext(ctx, "sftp server init error", slog.Any("err", err))
Sean McCulloughcf291fa2025-05-03 17:55:48 -0700171 return
172 }
173 if err := server.Serve(); err == io.EOF {
174 server.Close()
Sean McCulloughcf291fa2025-05-03 17:55:48 -0700175 } else if err != nil {
Sean McCulloughbdfb1262025-05-03 20:15:41 -0700176 slog.ErrorContext(ctx, "sftp server completed with error", slog.Any("err", err))
Sean McCulloughcf291fa2025-05-03 17:55:48 -0700177 }
178}
179
Sean McCullough1d061322025-04-24 09:52:56 -0700180func handlePTYSession(ctx context.Context, s ssh.Session, ptyReq ssh.Pty, winCh <-chan ssh.Window) {
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700181 cmd := exec.CommandContext(ctx, "/bin/bash")
Sean McCullough1d061322025-04-24 09:52:56 -0700182 slog.DebugContext(ctx, "handlePTYSession", slog.Any("ptyReq", ptyReq))
183
Sean McCullough22bd8eb2025-04-28 10:36:37 -0700184 cmd.Env = append(os.Environ(), fmt.Sprintf("TERM=%s", ptyReq.Term))
Sean McCullough1d061322025-04-24 09:52:56 -0700185 f, err := pty.Start(cmd)
186 if err != nil {
187 fmt.Fprintf(s, "PTY requested, but unable to start due to error: %v", err)
188 s.Exit(1)
189 return
190 }
191
192 go func() {
193 for win := range winCh {
194 setWinsize(f, win.Width, win.Height)
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700195 }
Sean McCullough1d061322025-04-24 09:52:56 -0700196 }()
197 go func() {
198 io.Copy(f, s) // stdin
199 }()
200 io.Copy(s, f) // stdout
201
202 // TODO: double check, do we need a sync.WaitGroup here, to make sure we finish
203 // the pipe I/O before we call cmd.Wait?
204 if err := cmd.Wait(); err != nil {
205 slog.ErrorContext(ctx, "handlePTYSession: cmd.Wait", slog.String("err", err.Error()))
206 s.Exit(1)
207 }
208}
209
210func handleSession(ctx context.Context, s ssh.Session) {
211 var cmd *exec.Cmd
212 slog.DebugContext(ctx, "handleSession", slog.Any("s.Command", s.Command()))
213 if len(s.Command()) == 0 {
214 cmd = exec.CommandContext(ctx, "/bin/bash")
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700215 } else {
Sean McCullough1d061322025-04-24 09:52:56 -0700216 cmd = exec.CommandContext(ctx, s.Command()[0], s.Command()[1:]...)
217 }
218 stdinPipe, err := cmd.StdinPipe()
219 if err != nil {
220 slog.ErrorContext(ctx, "handleSession: cmd.StdinPipe", slog.Any("err", err.Error()))
221 fmt.Fprintf(s, "cmd.StdinPipe error: %v", err)
222
223 s.Exit(1)
224 return
225 }
226 defer stdinPipe.Close()
227
228 stdoutPipe, err := cmd.StdoutPipe()
229 if err != nil {
230 slog.ErrorContext(ctx, "handleSession: cmd.StdoutPipe", slog.Any("err", err.Error()))
231 fmt.Fprintf(s, "cmd.StdoutPipe error: %v", err)
232 s.Exit(1)
233 return
234 }
235 defer stdoutPipe.Close()
236
237 stderrPipe, err := cmd.StderrPipe()
238 if err != nil {
239 slog.ErrorContext(ctx, "handleSession: cmd.StderrPipe", slog.Any("err", err.Error()))
240 fmt.Fprintf(s, "cmd.StderrPipe error: %v", err)
241 s.Exit(1)
242 return
243 }
244 defer stderrPipe.Close()
245
246 if err := cmd.Start(); err != nil {
247 slog.ErrorContext(ctx, "handleSession: cmd.Start", slog.Any("err", err.Error()))
248 fmt.Fprintf(s, "cmd.Start error: %v", err)
249 s.Exit(1)
250 return
251 }
252
253 // TODO: double check, do we need a sync.WaitGroup here, to make sure we finish
254 // the pipe I/O before we call cmd.Wait?
255 go func() {
256 io.Copy(s, stderrPipe)
257 }()
258 go func() {
259 io.Copy(s, stdoutPipe)
260 }()
261 io.Copy(stdinPipe, s)
262
263 if err := cmd.Wait(); err != nil {
264 slog.ErrorContext(ctx, "handleSession: cmd.Wait", slog.Any("err", err.Error()))
265 fmt.Fprintf(s, "cmd.Wait error: %v", err)
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700266 s.Exit(1)
267 }
268}
Sean McCullough7013e9e2025-05-14 02:03:58 +0000269
// sliceContains reports whether value occurs in slice. A nil or empty slice
// contains nothing.
func sliceContains(slice []string, value string) bool {
	idx := slices.Index(slice, value)
	return idx >= 0
}