blob: 56fd67928a1693f1733bbed0749d70d0352ed68a [file] [log] [blame]
package server
import (
"bytes"
"context"
"fmt"
"io"
"os"
"os/exec"
"syscall"
"unsafe"
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
)
// setWinsize sets the terminal behind f to w columns by h rows via the
// TIOCSWINSZ ioctl, leaving the pixel dimensions at zero. The ioctl's
// result is discarded, so a failure (e.g. f is not a terminal) is silent.
func setWinsize(f *os.File, w, h int) {
	// Field order matches the kernel's struct winsize:
	// rows, columns, x pixels, y pixels.
	type winsize struct {
		rows, cols, xpixel, ypixel uint16
	}
	ws := winsize{rows: uint16(h), cols: uint16(w)}
	// The unsafe.Pointer -> uintptr conversion must stay inside the
	// Syscall argument list so the runtime keeps ws alive for the call.
	syscall.Syscall(
		syscall.SYS_IOCTL,
		f.Fd(),
		uintptr(syscall.TIOCSWINSZ),
		uintptr(unsafe.Pointer(&ws)),
	)
}
// ServeSSH runs an SSH server listening on ":22" that admits only clients
// presenting one of the public keys in authorizedKeys, and hands every
// session to handleSessionfunc.
//
// hostKey is the PEM-encoded host key passed to ssh.HostKeyPEM.
// authorizedKeys is in OpenSSH authorized_keys format; entries that do not
// parse as keys are skipped. If no key can be parsed at all, ServeSSH returns
// an error without listening. Otherwise it returns whatever
// ssh.ListenAndServe returns.
//
// ctx is passed to each session handler.
// NOTE(review): cancelling ctx does not close the listener; new connections
// are still accepted afterwards. Confirm whether that is intended.
func (s *Server) ServeSSH(ctx context.Context, hostKey, authorizedKeys []byte) error {
	var allowedKeys []ssh.PublicKey
	// ssh.ParseAuthorizedKey itself skips blank, comment and malformed lines
	// and only fails (returning a nil rest) once no further key exists, so an
	// error simply means the scan is finished.
	for rest := authorizedKeys; len(bytes.TrimSpace(rest)) > 0; {
		key, _, _, next, err := ssh.ParseAuthorizedKey(rest)
		if err != nil {
			break
		}
		allowedKeys = append(allowedKeys, key)
		rest = next
	}
	if len(allowedKeys) == 0 {
		return fmt.Errorf("ServeSSH: no valid authorized keys found")
	}
	return ssh.ListenAndServe(":22",
		func(sess ssh.Session) {
			handleSessionfunc(ctx, sess)
		},
		ssh.HostKeyPEM(hostKey),
		ssh.PublicKeyAuth(func(_ ssh.Context, key ssh.PublicKey) bool {
			// Accept the client only if its key matches an authorized one.
			for _, allowed := range allowedKeys {
				if ssh.KeysEqual(key, allowed) {
					return true
				}
			}
			return false
		}),
	)
}
// handleSessionfunc serves one SSH session by running /bin/bash on a
// pseudo-terminal and relaying I/O between that terminal and the session.
//
// A session that did not request a PTY gets a short message and exit
// status 1. Otherwise the shell's exit status is sent to the client (1 if the
// shell was terminated by a signal). The shell is killed when ctx is
// cancelled or when the session's connection context is done.
func handleSessionfunc(ctx context.Context, s ssh.Session) {
	ptyReq, winCh, isPty := s.Pty()
	if !isPty {
		io.WriteString(s, "No PTY requested.\n")
		s.Exit(1)
		return
	}

	// Tie the shell's lifetime to the session as well as to ctx. Without
	// this, a client that disconnects leaves bash, and this handler, blocked
	// indefinitely on a terminal nobody reads from.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()
	go func() {
		select {
		case <-s.Context().Done():
			cancel()
		case <-ctx.Done():
		}
	}()

	cmd := exec.CommandContext(ctx, "/bin/bash")
	// A non-nil cmd.Env replaces the child's entire environment, so start
	// from this process's environment and override TERM with the client's
	// terminal type (the last duplicate key wins).
	cmd.Env = append(os.Environ(), "TERM="+ptyReq.Term)

	f, err := pty.Start(cmd)
	if err != nil {
		// Report the failure to this client instead of panicking in the
		// handler goroutine.
		fmt.Fprintf(s, "failed to start shell: %v\n", err)
		s.Exit(1)
		return
	}
	defer f.Close()

	go func() {
		for win := range winCh {
			setWinsize(f, win.Width, win.Height)
		}
	}()
	go func() {
		io.Copy(f, s) // client stdin -> shell
	}()
	io.Copy(s, f) // shell output -> client, until the terminal is closed

	// The exit status is read from ProcessState below, so Wait's error adds
	// nothing here.
	_ = cmd.Wait()
	code := cmd.ProcessState.ExitCode() // -1 if terminated by a signal
	if code < 0 {
		code = 1
	}
	s.Exit(code)
}