blob: 18837a5f04fa6034f53435114d39fc4324446d94 [file] [log] [blame]
Sean McCulloughbaa2b592025-04-23 10:40:08 -07001package server
2
3import (
Sean McCulloughae3480f2025-04-23 15:28:20 -07004 "bytes"
Sean McCulloughbaa2b592025-04-23 10:40:08 -07005 "context"
6 "fmt"
7 "io"
Sean McCullough1d061322025-04-24 09:52:56 -07008 "log/slog"
Sean McCulloughbaa2b592025-04-23 10:40:08 -07009 "os"
10 "os/exec"
11 "syscall"
12 "unsafe"
13
14 "github.com/creack/pty"
15 "github.com/gliderlabs/ssh"
Sean McCulloughcf291fa2025-05-03 17:55:48 -070016 "github.com/pkg/sftp"
Sean McCullough1d061322025-04-24 09:52:56 -070017 gossh "golang.org/x/crypto/ssh"
Sean McCulloughbaa2b592025-04-23 10:40:08 -070018)
19
// setWinsize asks the kernel (TIOCSWINSZ ioctl) to set the terminal behind f
// to w columns by h rows; the pixel dimensions are left at zero. Values are
// truncated to uint16 and any error reported by the ioctl is ignored.
func setWinsize(f *os.File, w, h int) {
	// winsize has the field layout of the kernel's struct winsize.
	type winsize struct {
		rows, cols, xpixel, ypixel uint16
	}
	ws := winsize{rows: uint16(h), cols: uint16(w)}
	syscall.Syscall(
		syscall.SYS_IOCTL,
		f.Fd(),
		uintptr(syscall.TIOCSWINSZ),
		uintptr(unsafe.Pointer(&ws)),
	)
}
24
25func (s *Server) ServeSSH(ctx context.Context, hostKey, authorizedKeys []byte) error {
Sean McCulloughae3480f2025-04-23 15:28:20 -070026 // Parse all authorized keys
27 allowedKeys := make([]ssh.PublicKey, 0)
28 rest := authorizedKeys
29 var err error
30
31 // Continue parsing as long as there are bytes left
32 for len(rest) > 0 {
33 var key ssh.PublicKey
34 key, _, _, rest, err = ssh.ParseAuthorizedKey(rest)
35 if err != nil {
36 // If we hit an error, check if we have more lines to try
37 if i := bytes.IndexByte(rest, '\n'); i >= 0 {
38 // Skip to the next line and continue
39 rest = rest[i+1:]
40 continue
41 }
42 // No more lines and we hit an error, so stop parsing
43 break
44 }
45 allowedKeys = append(allowedKeys, key)
Sean McCulloughbaa2b592025-04-23 10:40:08 -070046 }
Sean McCulloughae3480f2025-04-23 15:28:20 -070047 if len(allowedKeys) == 0 {
48 return fmt.Errorf("ServeSSH: no valid authorized keys found")
49 }
50
Sean McCullough1d061322025-04-24 09:52:56 -070051 signer, err := gossh.ParsePrivateKey(hostKey)
52 if err != nil {
53 return fmt.Errorf("ServeSSH: failed to parse host private key, err: %w", err)
54 }
Sean McCullough01ed5be2025-04-24 22:46:53 -070055 forwardHandler := &ssh.ForwardedTCPHandler{}
Sean McCullough1d061322025-04-24 09:52:56 -070056
57 server := ssh.Server{
58 LocalPortForwardingCallback: ssh.LocalPortForwardingCallback(func(ctx ssh.Context, dhost string, dport uint32) bool {
59 slog.DebugContext(ctx, "Accepted forward", slog.Any("dhost", dhost), slog.Any("dport", dport))
60 return true
61 }),
62 Addr: ":22",
63 ChannelHandlers: ssh.DefaultChannelHandlers,
64 Handler: ssh.Handler(func(s ssh.Session) {
65 ptyReq, winCh, isPty := s.Pty()
66 if isPty {
67 handlePTYSession(ctx, s, ptyReq, winCh)
68 } else {
69 handleSession(ctx, s)
70 }
71 }),
Sean McCullough01ed5be2025-04-24 22:46:53 -070072 RequestHandlers: map[string]ssh.RequestHandler{
73 "tcpip-forward": forwardHandler.HandleSSHRequest,
74 "cancel-tcpip-forward": forwardHandler.HandleSSHRequest,
75 },
Sean McCulloughcf291fa2025-05-03 17:55:48 -070076 SubsystemHandlers: map[string]ssh.SubsystemHandler{
Sean McCulloughbdfb1262025-05-03 20:15:41 -070077 "sftp": func(s ssh.Session) {
78 handleSftp(ctx, s)
79 },
Sean McCulloughcf291fa2025-05-03 17:55:48 -070080 },
Sean McCullough1d061322025-04-24 09:52:56 -070081 HostSigners: []ssh.Signer{signer},
82 PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
Sean McCulloughae3480f2025-04-23 15:28:20 -070083 // Check if the provided key matches any of our allowed keys
84 for _, allowedKey := range allowedKeys {
85 if ssh.KeysEqual(key, allowedKey) {
Sean McCullough1d061322025-04-24 09:52:56 -070086 slog.DebugContext(ctx, "ServeSSH: allow key", slog.String("key", string(key.Marshal())))
Sean McCulloughae3480f2025-04-23 15:28:20 -070087 return true
88 }
89 }
90 return false
Sean McCullough1d061322025-04-24 09:52:56 -070091 },
92 }
93
94 // This ChannelHandler is necessary for vscode's Remote-SSH connections to work.
95 // Without it the new VSC window will open, but you'll get an error that says something
96 // like "Failed to set up dynamic port forwarding connection over SSH to the VS Code Server."
97 server.ChannelHandlers["direct-tcpip"] = ssh.DirectTCPIPHandler
98
99 return server.ListenAndServe()
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700100}
101
Sean McCulloughbdfb1262025-05-03 20:15:41 -0700102func handleSftp(ctx context.Context, sess ssh.Session) {
Sean McCulloughcf291fa2025-05-03 17:55:48 -0700103 debugStream := io.Discard
104 serverOptions := []sftp.ServerOption{
105 sftp.WithDebug(debugStream),
106 }
107 server, err := sftp.NewServer(
108 sess,
109 serverOptions...,
110 )
111 if err != nil {
Sean McCulloughbdfb1262025-05-03 20:15:41 -0700112 slog.ErrorContext(ctx, "sftp server init error", slog.Any("err", err))
Sean McCulloughcf291fa2025-05-03 17:55:48 -0700113 return
114 }
115 if err := server.Serve(); err == io.EOF {
116 server.Close()
Sean McCulloughcf291fa2025-05-03 17:55:48 -0700117 } else if err != nil {
Sean McCulloughbdfb1262025-05-03 20:15:41 -0700118 slog.ErrorContext(ctx, "sftp server completed with error", slog.Any("err", err))
Sean McCulloughcf291fa2025-05-03 17:55:48 -0700119 }
120}
121
Sean McCullough1d061322025-04-24 09:52:56 -0700122func handlePTYSession(ctx context.Context, s ssh.Session, ptyReq ssh.Pty, winCh <-chan ssh.Window) {
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700123 cmd := exec.CommandContext(ctx, "/bin/bash")
Sean McCullough1d061322025-04-24 09:52:56 -0700124 slog.DebugContext(ctx, "handlePTYSession", slog.Any("ptyReq", ptyReq))
125
Sean McCullough22bd8eb2025-04-28 10:36:37 -0700126 cmd.Env = append(os.Environ(), fmt.Sprintf("TERM=%s", ptyReq.Term))
Sean McCullough1d061322025-04-24 09:52:56 -0700127 f, err := pty.Start(cmd)
128 if err != nil {
129 fmt.Fprintf(s, "PTY requested, but unable to start due to error: %v", err)
130 s.Exit(1)
131 return
132 }
133
134 go func() {
135 for win := range winCh {
136 setWinsize(f, win.Width, win.Height)
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700137 }
Sean McCullough1d061322025-04-24 09:52:56 -0700138 }()
139 go func() {
140 io.Copy(f, s) // stdin
141 }()
142 io.Copy(s, f) // stdout
143
144 // TODO: double check, do we need a sync.WaitGroup here, to make sure we finish
145 // the pipe I/O before we call cmd.Wait?
146 if err := cmd.Wait(); err != nil {
147 slog.ErrorContext(ctx, "handlePTYSession: cmd.Wait", slog.String("err", err.Error()))
148 s.Exit(1)
149 }
150}
151
152func handleSession(ctx context.Context, s ssh.Session) {
153 var cmd *exec.Cmd
154 slog.DebugContext(ctx, "handleSession", slog.Any("s.Command", s.Command()))
155 if len(s.Command()) == 0 {
156 cmd = exec.CommandContext(ctx, "/bin/bash")
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700157 } else {
Sean McCullough1d061322025-04-24 09:52:56 -0700158 cmd = exec.CommandContext(ctx, s.Command()[0], s.Command()[1:]...)
159 }
160 stdinPipe, err := cmd.StdinPipe()
161 if err != nil {
162 slog.ErrorContext(ctx, "handleSession: cmd.StdinPipe", slog.Any("err", err.Error()))
163 fmt.Fprintf(s, "cmd.StdinPipe error: %v", err)
164
165 s.Exit(1)
166 return
167 }
168 defer stdinPipe.Close()
169
170 stdoutPipe, err := cmd.StdoutPipe()
171 if err != nil {
172 slog.ErrorContext(ctx, "handleSession: cmd.StdoutPipe", slog.Any("err", err.Error()))
173 fmt.Fprintf(s, "cmd.StdoutPipe error: %v", err)
174 s.Exit(1)
175 return
176 }
177 defer stdoutPipe.Close()
178
179 stderrPipe, err := cmd.StderrPipe()
180 if err != nil {
181 slog.ErrorContext(ctx, "handleSession: cmd.StderrPipe", slog.Any("err", err.Error()))
182 fmt.Fprintf(s, "cmd.StderrPipe error: %v", err)
183 s.Exit(1)
184 return
185 }
186 defer stderrPipe.Close()
187
188 if err := cmd.Start(); err != nil {
189 slog.ErrorContext(ctx, "handleSession: cmd.Start", slog.Any("err", err.Error()))
190 fmt.Fprintf(s, "cmd.Start error: %v", err)
191 s.Exit(1)
192 return
193 }
194
195 // TODO: double check, do we need a sync.WaitGroup here, to make sure we finish
196 // the pipe I/O before we call cmd.Wait?
197 go func() {
198 io.Copy(s, stderrPipe)
199 }()
200 go func() {
201 io.Copy(s, stdoutPipe)
202 }()
203 io.Copy(stdinPipe, s)
204
205 if err := cmd.Wait(); err != nil {
206 slog.ErrorContext(ctx, "handleSession: cmd.Wait", slog.Any("err", err.Error()))
207 fmt.Fprintf(s, "cmd.Wait error: %v", err)
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700208 s.Exit(1)
209 }
210}