blob: cfef14b840808c06f85f6592dc8dfbfbb26aceb7 [file] [log] [blame]
Sean McCulloughbaa2b592025-04-23 10:40:08 -07001package server
2
3import (
Sean McCulloughae3480f2025-04-23 15:28:20 -07004 "bytes"
Sean McCulloughbaa2b592025-04-23 10:40:08 -07005 "context"
6 "fmt"
7 "io"
Sean McCulloughcf291fa2025-05-03 17:55:48 -07008 "log"
Sean McCullough1d061322025-04-24 09:52:56 -07009 "log/slog"
Sean McCulloughbaa2b592025-04-23 10:40:08 -070010 "os"
11 "os/exec"
12 "syscall"
13 "unsafe"
14
15 "github.com/creack/pty"
16 "github.com/gliderlabs/ssh"
Sean McCulloughcf291fa2025-05-03 17:55:48 -070017 "github.com/pkg/sftp"
Sean McCullough1d061322025-04-24 09:52:56 -070018 gossh "golang.org/x/crypto/ssh"
Sean McCulloughbaa2b592025-04-23 10:40:08 -070019)
20
// setWinsize sets the window size of the terminal referred to by f to w
// columns by h rows using the TIOCSWINSZ ioctl.
//
// The kernel's winsize fields are 16-bit, so out-of-range values (which come
// from the SSH client) are clamped rather than silently wrapped. An ioctl
// failure is logged instead of returned so existing callers need no changes.
func setWinsize(f *os.File, w, h int) {
	clamp := func(v int) uint16 {
		switch {
		case v < 0:
			return 0
		case v > 0xFFFF:
			return 0xFFFF
		}
		return uint16(v)
	}
	// Field order follows struct winsize: rows, cols, xpixel, ypixel.
	ws := struct{ h, w, x, y uint16 }{clamp(h), clamp(w), 0, 0}
	if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ),
		uintptr(unsafe.Pointer(&ws))); errno != 0 {
		slog.Warn("setWinsize: ioctl TIOCSWINSZ failed", slog.String("err", errno.Error()))
	}
}
25
26func (s *Server) ServeSSH(ctx context.Context, hostKey, authorizedKeys []byte) error {
Sean McCulloughae3480f2025-04-23 15:28:20 -070027 // Parse all authorized keys
28 allowedKeys := make([]ssh.PublicKey, 0)
29 rest := authorizedKeys
30 var err error
31
32 // Continue parsing as long as there are bytes left
33 for len(rest) > 0 {
34 var key ssh.PublicKey
35 key, _, _, rest, err = ssh.ParseAuthorizedKey(rest)
36 if err != nil {
37 // If we hit an error, check if we have more lines to try
38 if i := bytes.IndexByte(rest, '\n'); i >= 0 {
39 // Skip to the next line and continue
40 rest = rest[i+1:]
41 continue
42 }
43 // No more lines and we hit an error, so stop parsing
44 break
45 }
46 allowedKeys = append(allowedKeys, key)
Sean McCulloughbaa2b592025-04-23 10:40:08 -070047 }
Sean McCulloughae3480f2025-04-23 15:28:20 -070048 if len(allowedKeys) == 0 {
49 return fmt.Errorf("ServeSSH: no valid authorized keys found")
50 }
51
Sean McCullough1d061322025-04-24 09:52:56 -070052 signer, err := gossh.ParsePrivateKey(hostKey)
53 if err != nil {
54 return fmt.Errorf("ServeSSH: failed to parse host private key, err: %w", err)
55 }
Sean McCullough01ed5be2025-04-24 22:46:53 -070056 forwardHandler := &ssh.ForwardedTCPHandler{}
Sean McCullough1d061322025-04-24 09:52:56 -070057
58 server := ssh.Server{
59 LocalPortForwardingCallback: ssh.LocalPortForwardingCallback(func(ctx ssh.Context, dhost string, dport uint32) bool {
60 slog.DebugContext(ctx, "Accepted forward", slog.Any("dhost", dhost), slog.Any("dport", dport))
61 return true
62 }),
63 Addr: ":22",
64 ChannelHandlers: ssh.DefaultChannelHandlers,
65 Handler: ssh.Handler(func(s ssh.Session) {
66 ptyReq, winCh, isPty := s.Pty()
67 if isPty {
68 handlePTYSession(ctx, s, ptyReq, winCh)
69 } else {
70 handleSession(ctx, s)
71 }
72 }),
Sean McCullough01ed5be2025-04-24 22:46:53 -070073 RequestHandlers: map[string]ssh.RequestHandler{
74 "tcpip-forward": forwardHandler.HandleSSHRequest,
75 "cancel-tcpip-forward": forwardHandler.HandleSSHRequest,
76 },
Sean McCulloughcf291fa2025-05-03 17:55:48 -070077 SubsystemHandlers: map[string]ssh.SubsystemHandler{
78 "sftp": handleSftp,
79 },
Sean McCullough1d061322025-04-24 09:52:56 -070080 HostSigners: []ssh.Signer{signer},
81 PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
Sean McCulloughae3480f2025-04-23 15:28:20 -070082 // Check if the provided key matches any of our allowed keys
83 for _, allowedKey := range allowedKeys {
84 if ssh.KeysEqual(key, allowedKey) {
Sean McCullough1d061322025-04-24 09:52:56 -070085 slog.DebugContext(ctx, "ServeSSH: allow key", slog.String("key", string(key.Marshal())))
Sean McCulloughae3480f2025-04-23 15:28:20 -070086 return true
87 }
88 }
89 return false
Sean McCullough1d061322025-04-24 09:52:56 -070090 },
91 }
92
93 // This ChannelHandler is necessary for vscode's Remote-SSH connections to work.
94 // Without it the new VSC window will open, but you'll get an error that says something
95 // like "Failed to set up dynamic port forwarding connection over SSH to the VS Code Server."
96 server.ChannelHandlers["direct-tcpip"] = ssh.DirectTCPIPHandler
97
98 return server.ListenAndServe()
Sean McCulloughbaa2b592025-04-23 10:40:08 -070099}
100
Sean McCulloughcf291fa2025-05-03 17:55:48 -0700101func handleSftp(sess ssh.Session) {
102 debugStream := io.Discard
103 serverOptions := []sftp.ServerOption{
104 sftp.WithDebug(debugStream),
105 }
106 server, err := sftp.NewServer(
107 sess,
108 serverOptions...,
109 )
110 if err != nil {
111 log.Printf("sftp server init error: %s\n", err)
112 return
113 }
114 if err := server.Serve(); err == io.EOF {
115 server.Close()
116 fmt.Println("sftp client exited session.")
117 } else if err != nil {
118 fmt.Println("sftp server completed with error:", err)
119 }
120}
121
Sean McCullough1d061322025-04-24 09:52:56 -0700122func handlePTYSession(ctx context.Context, s ssh.Session, ptyReq ssh.Pty, winCh <-chan ssh.Window) {
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700123 cmd := exec.CommandContext(ctx, "/bin/bash")
Sean McCullough1d061322025-04-24 09:52:56 -0700124 slog.DebugContext(ctx, "handlePTYSession", slog.Any("ptyReq", ptyReq))
125
Sean McCullough22bd8eb2025-04-28 10:36:37 -0700126 cmd.Env = append(os.Environ(), fmt.Sprintf("TERM=%s", ptyReq.Term))
Sean McCullough1d061322025-04-24 09:52:56 -0700127 f, err := pty.Start(cmd)
128 if err != nil {
129 fmt.Fprintf(s, "PTY requested, but unable to start due to error: %v", err)
130 s.Exit(1)
131 return
132 }
133
134 go func() {
135 for win := range winCh {
136 setWinsize(f, win.Width, win.Height)
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700137 }
Sean McCullough1d061322025-04-24 09:52:56 -0700138 }()
139 go func() {
140 io.Copy(f, s) // stdin
141 }()
142 io.Copy(s, f) // stdout
143
144 // TODO: double check, do we need a sync.WaitGroup here, to make sure we finish
145 // the pipe I/O before we call cmd.Wait?
146 if err := cmd.Wait(); err != nil {
147 slog.ErrorContext(ctx, "handlePTYSession: cmd.Wait", slog.String("err", err.Error()))
148 s.Exit(1)
149 }
150}
151
152func handleSession(ctx context.Context, s ssh.Session) {
153 var cmd *exec.Cmd
154 slog.DebugContext(ctx, "handleSession", slog.Any("s.Command", s.Command()))
155 if len(s.Command()) == 0 {
156 cmd = exec.CommandContext(ctx, "/bin/bash")
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700157 } else {
Sean McCullough1d061322025-04-24 09:52:56 -0700158 cmd = exec.CommandContext(ctx, s.Command()[0], s.Command()[1:]...)
159 }
160 stdinPipe, err := cmd.StdinPipe()
161 if err != nil {
162 slog.ErrorContext(ctx, "handleSession: cmd.StdinPipe", slog.Any("err", err.Error()))
163 fmt.Fprintf(s, "cmd.StdinPipe error: %v", err)
164
165 s.Exit(1)
166 return
167 }
168 defer stdinPipe.Close()
169
170 stdoutPipe, err := cmd.StdoutPipe()
171 if err != nil {
172 slog.ErrorContext(ctx, "handleSession: cmd.StdoutPipe", slog.Any("err", err.Error()))
173 fmt.Fprintf(s, "cmd.StdoutPipe error: %v", err)
174 s.Exit(1)
175 return
176 }
177 defer stdoutPipe.Close()
178
179 stderrPipe, err := cmd.StderrPipe()
180 if err != nil {
181 slog.ErrorContext(ctx, "handleSession: cmd.StderrPipe", slog.Any("err", err.Error()))
182 fmt.Fprintf(s, "cmd.StderrPipe error: %v", err)
183 s.Exit(1)
184 return
185 }
186 defer stderrPipe.Close()
187
188 if err := cmd.Start(); err != nil {
189 slog.ErrorContext(ctx, "handleSession: cmd.Start", slog.Any("err", err.Error()))
190 fmt.Fprintf(s, "cmd.Start error: %v", err)
191 s.Exit(1)
192 return
193 }
194
195 // TODO: double check, do we need a sync.WaitGroup here, to make sure we finish
196 // the pipe I/O before we call cmd.Wait?
197 go func() {
198 io.Copy(s, stderrPipe)
199 }()
200 go func() {
201 io.Copy(s, stdoutPipe)
202 }()
203 io.Copy(stdinPipe, s)
204
205 if err := cmd.Wait(); err != nil {
206 slog.ErrorContext(ctx, "handleSession: cmd.Wait", slog.Any("err", err.Error()))
207 fmt.Fprintf(s, "cmd.Wait error: %v", err)
Sean McCulloughbaa2b592025-04-23 10:40:08 -0700208 s.Exit(1)
209 }
210}