sshserver: add direct-tcpip channel, allow non-PTY sessions
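This change enables VS Code (Remote-SSH) to connect to local sketch
containers over SSH: it adds a direct-tcpip channel handler, which VS Code
needs for its port forwarding, and serves sessions that do not request a PTY.
The VS Code docs describe how to use this feature: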
https://code.visualstudio.com/docs/remote/ssh
diff --git a/loop/server/sshserver.go b/loop/server/sshserver.go
index 56fd679..9946ac9 100644
--- a/loop/server/sshserver.go
+++ b/loop/server/sshserver.go
@@ -5,6 +5,7 @@
"context"
"fmt"
"io"
+ "log/slog"
"os"
"os/exec"
"syscall"
@@ -12,6 +13,7 @@
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
+ gossh "golang.org/x/crypto/ssh"
)
func setWinsize(f *os.File, w, h int) {
@@ -45,44 +47,133 @@
return fmt.Errorf("ServeSSH: no valid authorized keys found")
}
- return ssh.ListenAndServe(":22",
- func(s ssh.Session) {
- handleSessionfunc(ctx, s)
- },
- ssh.HostKeyPEM(hostKey),
- ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
+ signer, err := gossh.ParsePrivateKey(hostKey)
+ if err != nil {
+ return fmt.Errorf("ServeSSH: failed to parse host private key, err: %w", err)
+ }
+
+ server := ssh.Server{
+ LocalPortForwardingCallback: ssh.LocalPortForwardingCallback(func(ctx ssh.Context, dhost string, dport uint32) bool {
+ slog.DebugContext(ctx, "Accepted forward", slog.Any("dhost", dhost), slog.Any("dport", dport))
+ return true
+ }),
+ Addr: ":22",
+ ChannelHandlers: ssh.DefaultChannelHandlers,
+ Handler: ssh.Handler(func(s ssh.Session) {
+ ptyReq, winCh, isPty := s.Pty()
+ if isPty {
+ handlePTYSession(ctx, s, ptyReq, winCh)
+ } else {
+ handleSession(ctx, s)
+ }
+ }),
+ HostSigners: []ssh.Signer{signer},
+ PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
// Check if the provided key matches any of our allowed keys
for _, allowedKey := range allowedKeys {
if ssh.KeysEqual(key, allowedKey) {
+ // Log the key's SHA256 fingerprint: key.Marshal() is the binary SSH wire
+ // encoding and would put unreadable bytes into the log.
+ slog.DebugContext(ctx, "ServeSSH: allow key", slog.String("key", gossh.FingerprintSHA256(key)))
return true
}
}
return false
- }),
- )
+ },
+ }
+
+ // The direct-tcpip handler is necessary for VS Code's Remote-SSH connections to work.
+ // Without it the new VS Code window will open, but you'll get an error that says something
+ // like "Failed to set up dynamic port forwarding connection over SSH to the VS Code Server."
+ //
+ // server.ChannelHandlers was set to ssh.DefaultChannelHandlers, a shared package-level
+ // map; copy it before adding an entry so the library's global default is not mutated.
+ channelHandlers := make(map[string]ssh.ChannelHandler, len(server.ChannelHandlers)+1)
+ for name, handler := range server.ChannelHandlers {
+ channelHandlers[name] = handler
+ }
+ channelHandlers["direct-tcpip"] = ssh.DirectTCPIPHandler
+ server.ChannelHandlers = channelHandlers
+
+ return server.ListenAndServe()
}
-func handleSessionfunc(ctx context.Context, s ssh.Session) {
+// handlePTYSession runs /bin/bash attached to a pseudo-terminal for a session
+// that requested a PTY. It applies window-size changes received on winCh and
+// reports bash's exit status as the session's exit status.
+func handlePTYSession(ctx context.Context, s ssh.Session, ptyReq ssh.Pty, winCh <-chan ssh.Window) {
cmd := exec.CommandContext(ctx, "/bin/bash")
- ptyReq, winCh, isPty := s.Pty()
- if isPty {
- cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
- f, err := pty.Start(cmd)
- if err != nil {
- panic(err)
+ slog.DebugContext(ctx, "handlePTYSession", slog.Any("ptyReq", ptyReq))
+
+ // A non-nil cmd.Env replaces the whole environment, so start from the
+ // server's environment (as non-PTY sessions implicitly do) instead of
+ // giving bash nothing but TERM (no PATH, HOME, ...).
+ cmd.Env = append(os.Environ(), fmt.Sprintf("TERM=%s", ptyReq.Term))
+ f, err := pty.Start(cmd)
+ if err != nil {
+ fmt.Fprintf(s, "PTY requested, but unable to start due to error: %v", err)
+ s.Exit(1)
+ return
+ }
+ // Release the PTY master when the session ends so each session does not
+ // leak a file descriptor.
+ defer f.Close()
+
+ go func() {
+ for win := range winCh {
+ setWinsize(f, win.Width, win.Height)
}
- go func() {
- for win := range winCh {
- setWinsize(f, win.Width, win.Height)
- }
- }()
- go func() {
- io.Copy(f, s) // stdin
- }()
- io.Copy(s, f) // stdout
- cmd.Wait()
+ }()
+ go func() {
+ io.Copy(f, s) // stdin
+ }()
+ // The output copy runs on this goroutine, so it has finished before Wait
+ // is called; no extra synchronization is needed.
+ io.Copy(s, f) // stdout
+
+ if err := cmd.Wait(); err != nil {
+ // A non-zero exit status from bash is its result, not a server error:
+ // pass it through to the client as the session's exit status.
+ if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() >= 0 {
+ s.Exit(exitErr.ExitCode())
+ return
+ }
+ slog.ErrorContext(ctx, "handlePTYSession: cmd.Wait", slog.String("err", err.Error()))
+ s.Exit(1)
+ }
+}
+
+func handleSession(ctx context.Context, s ssh.Session) {
+ var cmd *exec.Cmd
+ slog.DebugContext(ctx, "handleSession", slog.Any("s.Command", s.Command()))
+ if len(s.Command()) == 0 {
+ cmd = exec.CommandContext(ctx, "/bin/bash")
} else {
- io.WriteString(s, "No PTY requested.\n")
+ cmd = exec.CommandContext(ctx, s.Command()[0], s.Command()[1:]...)
+ }
+ stdinPipe, err := cmd.StdinPipe()
+ if err != nil {
+ slog.ErrorContext(ctx, "handleSession: cmd.StdinPipe", slog.Any("err", err.Error()))
+ fmt.Fprintf(s, "cmd.StdinPipe error: %v", err)
+
+ s.Exit(1)
+ return
+ }
+ defer stdinPipe.Close()
+
+ stdoutPipe, err := cmd.StdoutPipe()
+ if err != nil {
+ slog.ErrorContext(ctx, "handleSession: cmd.StdoutPipe", slog.Any("err", err.Error()))
+ fmt.Fprintf(s, "cmd.StdoutPipe error: %v", err)
+ s.Exit(1)
+ return
+ }
+ defer stdoutPipe.Close()
+
+ stderrPipe, err := cmd.StderrPipe()
+ if err != nil {
+ slog.ErrorContext(ctx, "handleSession: cmd.StderrPipe", slog.Any("err", err.Error()))
+ fmt.Fprintf(s, "cmd.StderrPipe error: %v", err)
+ s.Exit(1)
+ return
+ }
+ defer stderrPipe.Close()
+
+ if err := cmd.Start(); err != nil {
+ slog.ErrorContext(ctx, "handleSession: cmd.Start", slog.Any("err", err.Error()))
+ fmt.Fprintf(s, "cmd.Start error: %v", err)
+ s.Exit(1)
+ return
+ }
+
+ // Forward client input in the background and close the pipe as soon as the
+ // client sends EOF, so commands that read stdin until EOF can finish. (The
+ // deferred stdinPipe.Close only runs after Wait, which would deadlock them.)
+ go func() {
+ io.Copy(stdinPipe, s)
+ stdinPipe.Close()
+ }()
+
+ // os/exec requires every read from StdoutPipe/StderrPipe to complete before
+ // Wait is called (Wait closes the pipes and would otherwise drop output).
+ // The command's stderr goes to the session's stderr stream, not its stdout.
+ copied := make(chan struct{}, 2)
+ go func() {
+ io.Copy(s.Stderr(), stderrPipe)
+ copied <- struct{}{}
+ }()
+ go func() {
+ io.Copy(s, stdoutPipe)
+ copied <- struct{}{}
+ }()
+ <-copied
+ <-copied
+
+ if err := cmd.Wait(); err != nil {
+ // A non-zero exit status is the command's result, not a server error: pass
+ // it through rather than printing into stdout (which would corrupt the
+ // command's output) and collapsing it to 1.
+ if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() >= 0 {
+ s.Exit(exitErr.ExitCode())
+ return
+ }
+ slog.ErrorContext(ctx, "handleSession: cmd.Wait", slog.Any("err", err.Error()))
+ fmt.Fprintf(s.Stderr(), "cmd.Wait error: %v", err)
s.Exit(1)
}
}