blob: f1af3301060e58a401115261551e6de8dd98f63a [file] [log] [blame]
Sean McCulloughbaa2b592025-04-23 10:40:08 -07001package server
2
3import (
Sean McCulloughae3480f2025-04-23 15:28:20 -07004 "bytes"
Sean McCulloughbaa2b592025-04-23 10:40:08 -07005 "context"
6 "fmt"
7 "io"
Sean McCullough1d061322025-04-24 09:52:56 -07008 "log/slog"
Sean McCulloughbaa2b592025-04-23 10:40:08 -07009 "os"
10 "os/exec"
11 "syscall"
12 "unsafe"
13
14 "github.com/creack/pty"
15 "github.com/gliderlabs/ssh"
Sean McCullough1d061322025-04-24 09:52:56 -070016 gossh "golang.org/x/crypto/ssh"
Sean McCulloughbaa2b592025-04-23 10:40:08 -070017)
18
// setWinsize asks the kernel to set the window size of the terminal behind f
// to w columns by h rows using the TIOCSWINSZ ioctl.
//
// The dimensions are clamped to the range a uint16 can hold (0..65535) so
// that negative or oversized values saturate instead of wrapping around.
// The call is best effort: the ioctl result is intentionally discarded, so a
// file that is not a terminal is silently left unchanged.
func setWinsize(f *os.File, w, h int) {
	clamp := func(v int) uint16 {
		switch {
		case v < 0:
			return 0
		case v > 0xFFFF:
			return 0xFFFF
		}
		return uint16(v)
	}

	// Field order is rows, columns, then two pixel-size fields left at zero.
	ws := struct{ rows, cols, xpixel, ypixel uint16 }{
		rows: clamp(h),
		cols: clamp(w),
	}

	// The unsafe.Pointer -> uintptr conversion must stay inside the call
	// expression so the pointer remains valid for the duration of the syscall.
	_, _, _ = syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ),
		uintptr(unsafe.Pointer(&ws)))
}
23
24func (s *Server) ServeSSH(ctx context.Context, hostKey, authorizedKeys []byte) error {
Sean McCulloughae3480f2025-04-23 15:28:20 -070025 // Parse all authorized keys
26 allowedKeys := make([]ssh.PublicKey, 0)
27 rest := authorizedKeys
28 var err error
29
30 // Continue parsing as long as there are bytes left
31 for len(rest) > 0 {
32 var key ssh.PublicKey
33 key, _, _, rest, err = ssh.ParseAuthorizedKey(rest)
34 if err != nil {
35 // If we hit an error, check if we have more lines to try
36 if i := bytes.IndexByte(rest, '\n'); i >= 0 {
37 // Skip to the next line and continue
38 rest = rest[i+1:]
39 continue
40 }
41 // No more lines and we hit an error, so stop parsing
42 break
43 }
44 allowedKeys = append(allowedKeys, key)
Sean McCulloughbaa2b592025-04-23 10:40:08 -070045 }
Sean McCulloughae3480f2025-04-23 15:28:20 -070046 if len(allowedKeys) == 0 {
47 return fmt.Errorf("ServeSSH: no valid authorized keys found")
48 }
49
Sean McCullough1d061322025-04-24 09:52:56 -070050 signer, err := gossh.ParsePrivateKey(hostKey)
51 if err != nil {
52 return fmt.Errorf("ServeSSH: failed to parse host private key, err: %w", err)
53 }
Sean McCullough01ed5be2025-04-24 22:46:53 -070054 forwardHandler := &ssh.ForwardedTCPHandler{}
Sean McCullough1d061322025-04-24 09:52:56 -070055
56 server := ssh.Server{
57 LocalPortForwardingCallback: ssh.LocalPortForwardingCallback(func(ctx ssh.Context, dhost string, dport uint32) bool {
58 slog.DebugContext(ctx, "Accepted forward", slog.Any("dhost", dhost), slog.Any("dport", dport))
59 return true
60 }),
61 Addr: ":22",
62 ChannelHandlers: ssh.DefaultChannelHandlers,
63 Handler: ssh.Handler(func(s ssh.Session) {
64 ptyReq, winCh, isPty := s.Pty()
65 if isPty {
66 handlePTYSession(ctx, s, ptyReq, winCh)
67 } else {
68 handleSession(ctx, s)
69 }
70 }),
Sean McCullough01ed5be2025-04-24 22:46:53 -070071 RequestHandlers: map[string]ssh.RequestHandler{
72 "tcpip-forward": forwardHandler.HandleSSHRequest,
73 "cancel-tcpip-forward": forwardHandler.HandleSSHRequest,
74 },
Sean McCullough1d061322025-04-24 09:52:56 -070075 HostSigners: []ssh.Signer{signer},
76 PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
Sean McCulloughae3480f2025-04-23 15:28:20 -070077 // Check if the provided key matches any of our allowed keys
78 for _, allowedKey := range allowedKeys {
79 if ssh.KeysEqual(key, allowedKey) {
Sean McCullough1d061322025-04-24 09:52:56 -070080 slog.DebugContext(ctx, "ServeSSH: allow key", slog.String("key", string(key.Marshal())))
Sean McCulloughae3480f2025-04-23 15:28:20 -070081 return true
82 }
83 }
84 return false
Sean McCullough1d061322025-04-24 09:52:56 -070085 },
86 }
87
88 // This ChannelHandler is necessary for vscode's Remote-SSH connections to work.
89 // Without it the new VSC window will open, but you'll get an error that says something
90 // like "Failed to set up dynamic port forwarding connection over SSH to the VS Code Server."
91 server.ChannelHandlers["direct-tcpip"] = ssh.DirectTCPIPHandler
92
93 return server.ListenAndServe()
Sean McCulloughbaa2b592025-04-23 10:40:08 -070094}
95
Sean McCullough1d061322025-04-24 09:52:56 -070096func handlePTYSession(ctx context.Context, s ssh.Session, ptyReq ssh.Pty, winCh <-chan ssh.Window) {
Sean McCulloughbaa2b592025-04-23 10:40:08 -070097 cmd := exec.CommandContext(ctx, "/bin/bash")
Sean McCullough1d061322025-04-24 09:52:56 -070098 slog.DebugContext(ctx, "handlePTYSession", slog.Any("ptyReq", ptyReq))
99
Sean McCullough22bd8eb2025-04-28 10:36:37 -0700100 cmd.Env = append(os.Environ(), fmt.Sprintf("TERM=%s", ptyReq.Term))
Sean McCullough1d061322025-04-24 09:52:56 -0700101 f, err := pty.Start(cmd)
102 if err != nil {
103 fmt.Fprintf(s, "PTY requested, but unable to start due to error: %v", err)
104 s.Exit(1)
105 return
106 }
107
108 go func() {
109 for win := range winCh {
110 setWinsize(f, win.Width, win.Height)
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700111 }
Sean McCullough1d061322025-04-24 09:52:56 -0700112 }()
113 go func() {
114 io.Copy(f, s) // stdin
115 }()
116 io.Copy(s, f) // stdout
117
118 // TODO: double check, do we need a sync.WaitGroup here, to make sure we finish
119 // the pipe I/O before we call cmd.Wait?
120 if err := cmd.Wait(); err != nil {
121 slog.ErrorContext(ctx, "handlePTYSession: cmd.Wait", slog.String("err", err.Error()))
122 s.Exit(1)
123 }
124}
125
126func handleSession(ctx context.Context, s ssh.Session) {
127 var cmd *exec.Cmd
128 slog.DebugContext(ctx, "handleSession", slog.Any("s.Command", s.Command()))
129 if len(s.Command()) == 0 {
130 cmd = exec.CommandContext(ctx, "/bin/bash")
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700131 } else {
Sean McCullough1d061322025-04-24 09:52:56 -0700132 cmd = exec.CommandContext(ctx, s.Command()[0], s.Command()[1:]...)
133 }
134 stdinPipe, err := cmd.StdinPipe()
135 if err != nil {
136 slog.ErrorContext(ctx, "handleSession: cmd.StdinPipe", slog.Any("err", err.Error()))
137 fmt.Fprintf(s, "cmd.StdinPipe error: %v", err)
138
139 s.Exit(1)
140 return
141 }
142 defer stdinPipe.Close()
143
144 stdoutPipe, err := cmd.StdoutPipe()
145 if err != nil {
146 slog.ErrorContext(ctx, "handleSession: cmd.StdoutPipe", slog.Any("err", err.Error()))
147 fmt.Fprintf(s, "cmd.StdoutPipe error: %v", err)
148 s.Exit(1)
149 return
150 }
151 defer stdoutPipe.Close()
152
153 stderrPipe, err := cmd.StderrPipe()
154 if err != nil {
155 slog.ErrorContext(ctx, "handleSession: cmd.StderrPipe", slog.Any("err", err.Error()))
156 fmt.Fprintf(s, "cmd.StderrPipe error: %v", err)
157 s.Exit(1)
158 return
159 }
160 defer stderrPipe.Close()
161
162 if err := cmd.Start(); err != nil {
163 slog.ErrorContext(ctx, "handleSession: cmd.Start", slog.Any("err", err.Error()))
164 fmt.Fprintf(s, "cmd.Start error: %v", err)
165 s.Exit(1)
166 return
167 }
168
169 // TODO: double check, do we need a sync.WaitGroup here, to make sure we finish
170 // the pipe I/O before we call cmd.Wait?
171 go func() {
172 io.Copy(s, stderrPipe)
173 }()
174 go func() {
175 io.Copy(s, stdoutPipe)
176 }()
177 io.Copy(stdinPipe, s)
178
179 if err := cmd.Wait(); err != nil {
180 slog.ErrorContext(ctx, "handleSession: cmd.Wait", slog.Any("err", err.Error()))
181 fmt.Fprintf(s, "cmd.Wait error: %v", err)
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700182 s.Exit(1)
183 }
184}