blob: 56fd67928a1693f1733bbed0749d70d0352ed68a [file] [log] [blame]
Sean McCulloughbaa2b592025-04-23 10:40:08 -07001package server
2
3import (
Sean McCulloughae3480f2025-04-23 15:28:20 -07004 "bytes"
Sean McCulloughbaa2b592025-04-23 10:40:08 -07005 "context"
6 "fmt"
7 "io"
8 "os"
9 "os/exec"
10 "syscall"
11 "unsafe"
12
13 "github.com/creack/pty"
14 "github.com/gliderlabs/ssh"
15)
16
// setWinsize resizes the terminal behind f (the PTY master) to w columns by
// h rows using the TIOCSWINSZ ioctl.
//
// The kernel's winsize struct stores each dimension as a uint16. Values
// outside that range are clamped instead of being wrapped, so a negative or
// huge size cannot turn into a nonsensical one.
//
// The ioctl's errno is returned as an error (nil on success). Callers that
// treat resizing as best-effort may simply ignore the result.
func setWinsize(f *os.File, w, h int) error {
	// Field order mirrors struct winsize: ws_row, ws_col, ws_xpixel, ws_ypixel.
	ws := struct{ rows, cols, xpix, ypix uint16 }{
		clampWinsizeDim(h), clampWinsizeDim(w), 0, 0,
	}
	_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ),
		uintptr(unsafe.Pointer(&ws)))
	if errno != 0 {
		return errno
	}
	return nil
}

// clampWinsizeDim converts a terminal dimension to the uint16 stored in the
// winsize struct. It saturates at 0 and 0xFFFF instead of wrapping around.
func clampWinsizeDim(v int) uint16 {
	switch {
	case v < 0:
		return 0
	case v > 0xFFFF:
		return 0xFFFF
	default:
		return uint16(v)
	}
}
21
22func (s *Server) ServeSSH(ctx context.Context, hostKey, authorizedKeys []byte) error {
Sean McCulloughae3480f2025-04-23 15:28:20 -070023 // Parse all authorized keys
24 allowedKeys := make([]ssh.PublicKey, 0)
25 rest := authorizedKeys
26 var err error
27
28 // Continue parsing as long as there are bytes left
29 for len(rest) > 0 {
30 var key ssh.PublicKey
31 key, _, _, rest, err = ssh.ParseAuthorizedKey(rest)
32 if err != nil {
33 // If we hit an error, check if we have more lines to try
34 if i := bytes.IndexByte(rest, '\n'); i >= 0 {
35 // Skip to the next line and continue
36 rest = rest[i+1:]
37 continue
38 }
39 // No more lines and we hit an error, so stop parsing
40 break
41 }
42 allowedKeys = append(allowedKeys, key)
Sean McCulloughbaa2b592025-04-23 10:40:08 -070043 }
Sean McCulloughae3480f2025-04-23 15:28:20 -070044 if len(allowedKeys) == 0 {
45 return fmt.Errorf("ServeSSH: no valid authorized keys found")
46 }
47
48 return ssh.ListenAndServe(":22",
Sean McCulloughbaa2b592025-04-23 10:40:08 -070049 func(s ssh.Session) {
50 handleSessionfunc(ctx, s)
51 },
52 ssh.HostKeyPEM(hostKey),
53 ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
Sean McCulloughae3480f2025-04-23 15:28:20 -070054 // Check if the provided key matches any of our allowed keys
55 for _, allowedKey := range allowedKeys {
56 if ssh.KeysEqual(key, allowedKey) {
57 return true
58 }
59 }
60 return false
Sean McCulloughbaa2b592025-04-23 10:40:08 -070061 }),
62 )
63}
64
65func handleSessionfunc(ctx context.Context, s ssh.Session) {
66 cmd := exec.CommandContext(ctx, "/bin/bash")
67 ptyReq, winCh, isPty := s.Pty()
68 if isPty {
69 cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
70 f, err := pty.Start(cmd)
71 if err != nil {
72 panic(err)
73 }
74 go func() {
75 for win := range winCh {
76 setWinsize(f, win.Width, win.Height)
77 }
78 }()
79 go func() {
80 io.Copy(f, s) // stdin
81 }()
82 io.Copy(s, f) // stdout
83 cmd.Wait()
84 } else {
85 io.WriteString(s, "No PTY requested.\n")
86 s.Exit(1)
87 }
88}