ssh: use local CA, add mutual container/host auth
See loop/server/local_ssh.md for a detailed description of how sketch uses
ssh. Sketch now uses a local CA to sign each container's certificate instead
of adding a new entry to known_hosts for each container.
This also adds another layer of security: the container's ssh server now
verifies that incoming ssh connections from the host present valid
certificates signed by the container CA. Prior to this change, authentication
was only one-way: it verified only that the sketch container you are
ssh'ing into really is the one you think it is.
This is somewhat inspired by https://github.com/FiloSottile/mkcert, which
plays a role similar to the one ssh_theater.go plays for local ssh connections:
mkcert also uses a local CA to address local development use cases, but for
TLS/https rather than for ssh.
Co-Authored-By: sketch <hello@sketch.dev>
Change-ID: sc7b3928295277d5dk
diff --git a/loop/server/sshserver.go b/loop/server/sshserver.go
index 18837a5..65eb54b 100644
--- a/loop/server/sshserver.go
+++ b/loop/server/sshserver.go
@@ -8,7 +8,9 @@
"log/slog"
"os"
"os/exec"
+ "slices"
"syscall"
+ "time"
"unsafe"
"github.com/creack/pty"
@@ -22,7 +24,7 @@
uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(h), uint16(w), 0, 0})))
}
-func (s *Server) ServeSSH(ctx context.Context, hostKey, authorizedKeys []byte) error {
+func (s *Server) ServeSSH(ctx context.Context, hostKey, authorizedKeys []byte, containerCAKey, hostCertificate []byte) error {
// Parse all authorized keys
allowedKeys := make([]ssh.PublicKey, 0)
rest := authorizedKeys
@@ -48,6 +50,33 @@
return fmt.Errorf("ServeSSH: no valid authorized keys found")
}
+ // If containerCAKey is provided, trust it as the CA for client certificates; a parse failure is only logged and leaves hasMutualAuth false.
+ var certChecker *gossh.CertChecker
+ var containerCA gossh.PublicKey
+ hasMutualAuth := false
+
+ if len(containerCAKey) > 0 {
+ // Parse container CA public key (authorized_keys format; only the first key is used, the rest is discarded)
+ containerCA, _, _, _, err = ssh.ParseAuthorizedKey(containerCAKey)
+ if err != nil {
+ slog.WarnContext(ctx, "Failed to parse container CA key", slog.String("err", err.Error()))
+ } else {
+ certChecker = &gossh.CertChecker{
+ // Accept a certificate's signing key only if it is byte-for-byte our container CA key
+ IsUserAuthority: func(auth gossh.PublicKey) bool {
+ return bytes.Equal(auth.Marshal(), containerCA.Marshal())
+ },
+ // Revocation hook; only consulted when CertChecker.CheckCert/Authenticate is called
+ IsRevoked: func(cert *gossh.Certificate) bool {
+ // We don't implement certificate revocation yet, so no certificates are revoked
+ return false
+ },
+ }
+ slog.InfoContext(ctx, "SSH server configured for mutual authentication with container CA")
+ hasMutualAuth = true
+ }
+ }
+
signer, err := gossh.ParsePrivateKey(hostKey)
if err != nil {
return fmt.Errorf("ServeSSH: failed to parse host private key, err: %w", err)
@@ -80,7 +109,33 @@
},
HostSigners: []ssh.Signer{signer},
PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
- // Check if the provided key matches any of our allowed keys
+ // First, check whether the key is a certificate issued by our container CA.
+ if hasMutualAuth {
+ if cert, ok := key.(*gossh.Certificate); ok {
+ // Only user certificates may log in; a host certificate from the same CA must not.
+ if cert.CertType == gossh.UserCert && certChecker.IsUserAuthority(cert.SignatureKey) {
+ // Demand an explicit "root" principal: CheckCert on its own treats a
+ // certificate with no principals as valid for every user.
+ if !sliceContains(cert.ValidPrincipals, "root") {
+ slog.WarnContext(ctx, "Certificate lacks root principal",
+ slog.Any("principals", cert.ValidPrincipals))
+ } else if err := certChecker.CheckCert("root", cert); err != nil {
+ // CheckCert verifies the CA signature (a matching SignatureKey proves
+ // nothing by itself), the validity window (incl. CertTimeInfinity),
+ // revocation and critical options. NOTE(review): it skips source-address;
+ // confirm nothing relies on that option being enforced here.
+ slog.WarnContext(ctx, "Certificate validation failed",
+ slog.String("err", err.Error()),
+ slog.Time("now", time.Now()))
+ } else {
+ slog.InfoContext(ctx, "SSH client authenticated with valid certificate")
+ return true
+ }
+ }
+ }
+ }
+
+ // Standard key-based authentication fallback
for _, allowedKey := range allowedKeys {
if ssh.KeysEqual(key, allowedKey) {
slog.DebugContext(ctx, "ServeSSH: allow key", slog.String("key", string(key.Marshal())))
@@ -208,3 +263,8 @@
s.Exit(1)
}
}
+
+// sliceContains reports whether value is an element of slice; it is a thin wrapper around slices.Contains.
+func sliceContains(slice []string, value string) bool {
+ return slices.Contains(slice, value)
+}