blob: 9946ac989666a699888efed420024bfffcbb49df [file] [log] [blame]
Sean McCulloughbaa2b592025-04-23 10:40:08 -07001package server
2
3import (
Sean McCulloughae3480f2025-04-23 15:28:20 -07004 "bytes"
Sean McCulloughbaa2b592025-04-23 10:40:08 -07005 "context"
6 "fmt"
7 "io"
Sean McCullough1d061322025-04-24 09:52:56 -07008 "log/slog"
Sean McCulloughbaa2b592025-04-23 10:40:08 -07009 "os"
10 "os/exec"
11 "syscall"
12 "unsafe"
13
14 "github.com/creack/pty"
15 "github.com/gliderlabs/ssh"
Sean McCullough1d061322025-04-24 09:52:56 -070016 gossh "golang.org/x/crypto/ssh"
Sean McCulloughbaa2b592025-04-23 10:40:08 -070017)
18
// setWinsize sets the window size of the terminal f (expected to be a
// pseudo-terminal) to w columns by h rows using the TIOCSWINSZ ioctl.
//
// The kernel's winsize fields are 16-bit, so w and h are clamped to
// [0, 65535]; a plain uint16 conversion would silently wrap oversized values
// (65536 would become 0). The ioctl's errno is returned (nil on success) so
// callers can decide whether a failed resize matters; call sites that use
// setWinsize as a statement may keep ignoring it.
func setWinsize(f *os.File, w, h int) error {
	clamp := func(v int) uint16 {
		switch {
		case v < 0:
			return 0
		case v > 0xFFFF:
			return 0xFFFF
		}
		return uint16(v)
	}

	// Field order mirrors struct winsize: rows, columns, then the pixel
	// dimensions, which are left at zero.
	ws := struct{ rows, cols, xpixel, ypixel uint16 }{
		rows: clamp(h),
		cols: clamp(w),
	}
	_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ),
		uintptr(unsafe.Pointer(&ws)))
	if errno != 0 {
		return errno
	}
	return nil
}
22}
23
24func (s *Server) ServeSSH(ctx context.Context, hostKey, authorizedKeys []byte) error {
Sean McCulloughae3480f2025-04-23 15:28:20 -070025 // Parse all authorized keys
26 allowedKeys := make([]ssh.PublicKey, 0)
27 rest := authorizedKeys
28 var err error
29
30 // Continue parsing as long as there are bytes left
31 for len(rest) > 0 {
32 var key ssh.PublicKey
33 key, _, _, rest, err = ssh.ParseAuthorizedKey(rest)
34 if err != nil {
35 // If we hit an error, check if we have more lines to try
36 if i := bytes.IndexByte(rest, '\n'); i >= 0 {
37 // Skip to the next line and continue
38 rest = rest[i+1:]
39 continue
40 }
41 // No more lines and we hit an error, so stop parsing
42 break
43 }
44 allowedKeys = append(allowedKeys, key)
Sean McCulloughbaa2b592025-04-23 10:40:08 -070045 }
Sean McCulloughae3480f2025-04-23 15:28:20 -070046 if len(allowedKeys) == 0 {
47 return fmt.Errorf("ServeSSH: no valid authorized keys found")
48 }
49
Sean McCullough1d061322025-04-24 09:52:56 -070050 signer, err := gossh.ParsePrivateKey(hostKey)
51 if err != nil {
52 return fmt.Errorf("ServeSSH: failed to parse host private key, err: %w", err)
53 }
54
55 server := ssh.Server{
56 LocalPortForwardingCallback: ssh.LocalPortForwardingCallback(func(ctx ssh.Context, dhost string, dport uint32) bool {
57 slog.DebugContext(ctx, "Accepted forward", slog.Any("dhost", dhost), slog.Any("dport", dport))
58 return true
59 }),
60 Addr: ":22",
61 ChannelHandlers: ssh.DefaultChannelHandlers,
62 Handler: ssh.Handler(func(s ssh.Session) {
63 ptyReq, winCh, isPty := s.Pty()
64 if isPty {
65 handlePTYSession(ctx, s, ptyReq, winCh)
66 } else {
67 handleSession(ctx, s)
68 }
69 }),
70 HostSigners: []ssh.Signer{signer},
71 PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
Sean McCulloughae3480f2025-04-23 15:28:20 -070072 // Check if the provided key matches any of our allowed keys
73 for _, allowedKey := range allowedKeys {
74 if ssh.KeysEqual(key, allowedKey) {
Sean McCullough1d061322025-04-24 09:52:56 -070075 slog.DebugContext(ctx, "ServeSSH: allow key", slog.String("key", string(key.Marshal())))
Sean McCulloughae3480f2025-04-23 15:28:20 -070076 return true
77 }
78 }
79 return false
Sean McCullough1d061322025-04-24 09:52:56 -070080 },
81 }
82
83 // This ChannelHandler is necessary for vscode's Remote-SSH connections to work.
84 // Without it the new VSC window will open, but you'll get an error that says something
85 // like "Failed to set up dynamic port forwarding connection over SSH to the VS Code Server."
86 server.ChannelHandlers["direct-tcpip"] = ssh.DirectTCPIPHandler
87
88 return server.ListenAndServe()
Sean McCulloughbaa2b592025-04-23 10:40:08 -070089}
90
Sean McCullough1d061322025-04-24 09:52:56 -070091func handlePTYSession(ctx context.Context, s ssh.Session, ptyReq ssh.Pty, winCh <-chan ssh.Window) {
Sean McCulloughbaa2b592025-04-23 10:40:08 -070092 cmd := exec.CommandContext(ctx, "/bin/bash")
Sean McCullough1d061322025-04-24 09:52:56 -070093 slog.DebugContext(ctx, "handlePTYSession", slog.Any("ptyReq", ptyReq))
94
95 cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
96 f, err := pty.Start(cmd)
97 if err != nil {
98 fmt.Fprintf(s, "PTY requested, but unable to start due to error: %v", err)
99 s.Exit(1)
100 return
101 }
102
103 go func() {
104 for win := range winCh {
105 setWinsize(f, win.Width, win.Height)
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700106 }
Sean McCullough1d061322025-04-24 09:52:56 -0700107 }()
108 go func() {
109 io.Copy(f, s) // stdin
110 }()
111 io.Copy(s, f) // stdout
112
113 // TODO: double check, do we need a sync.WaitGroup here, to make sure we finish
114 // the pipe I/O before we call cmd.Wait?
115 if err := cmd.Wait(); err != nil {
116 slog.ErrorContext(ctx, "handlePTYSession: cmd.Wait", slog.String("err", err.Error()))
117 s.Exit(1)
118 }
119}
120
121func handleSession(ctx context.Context, s ssh.Session) {
122 var cmd *exec.Cmd
123 slog.DebugContext(ctx, "handleSession", slog.Any("s.Command", s.Command()))
124 if len(s.Command()) == 0 {
125 cmd = exec.CommandContext(ctx, "/bin/bash")
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700126 } else {
Sean McCullough1d061322025-04-24 09:52:56 -0700127 cmd = exec.CommandContext(ctx, s.Command()[0], s.Command()[1:]...)
128 }
129 stdinPipe, err := cmd.StdinPipe()
130 if err != nil {
131 slog.ErrorContext(ctx, "handleSession: cmd.StdinPipe", slog.Any("err", err.Error()))
132 fmt.Fprintf(s, "cmd.StdinPipe error: %v", err)
133
134 s.Exit(1)
135 return
136 }
137 defer stdinPipe.Close()
138
139 stdoutPipe, err := cmd.StdoutPipe()
140 if err != nil {
141 slog.ErrorContext(ctx, "handleSession: cmd.StdoutPipe", slog.Any("err", err.Error()))
142 fmt.Fprintf(s, "cmd.StdoutPipe error: %v", err)
143 s.Exit(1)
144 return
145 }
146 defer stdoutPipe.Close()
147
148 stderrPipe, err := cmd.StderrPipe()
149 if err != nil {
150 slog.ErrorContext(ctx, "handleSession: cmd.StderrPipe", slog.Any("err", err.Error()))
151 fmt.Fprintf(s, "cmd.StderrPipe error: %v", err)
152 s.Exit(1)
153 return
154 }
155 defer stderrPipe.Close()
156
157 if err := cmd.Start(); err != nil {
158 slog.ErrorContext(ctx, "handleSession: cmd.Start", slog.Any("err", err.Error()))
159 fmt.Fprintf(s, "cmd.Start error: %v", err)
160 s.Exit(1)
161 return
162 }
163
164 // TODO: double check, do we need a sync.WaitGroup here, to make sure we finish
165 // the pipe I/O before we call cmd.Wait?
166 go func() {
167 io.Copy(s, stderrPipe)
168 }()
169 go func() {
170 io.Copy(s, stdoutPipe)
171 }()
172 io.Copy(stdinPipe, s)
173
174 if err := cmd.Wait(); err != nil {
175 slog.ErrorContext(ctx, "handleSession: cmd.Wait", slog.Any("err", err.Error()))
176 fmt.Fprintf(s, "cmd.Wait error: %v", err)
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700177 s.Exit(1)
178 }
179}