blob: 85f9cbf34032145e509c32d0f27746088b097aa5 [file] [log] [blame]
package skabandclient
import (
"bufio"
"context"
"crypto/ed25519"
crand "crypto/rand"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"io"
"log/slog"
"math/rand/v2"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/richardlehane/crock32"
"golang.org/x/net/http2"
)
// SketchSession represents a sketch session with metadata
type SketchSession struct {
ID string `json:"id"`
Title string `json:"title"`
FirstMessage string `json:"first_message"`
LastMessage string `json:"last_message"`
FirstMessageDate time.Time `json:"first_message_date"`
LastMessageDate time.Time `json:"last_message_date"`
}
// SkabandClient provides HTTP client functionality for skaband server
type SkabandClient struct {
addr string
publicKey string
client *http.Client
}
func DialAndServe(ctx context.Context, hostURL, sessionID, clientPubKey string, sessionSecret string, h http.Handler) (err error) {
// Connect to the server.
var conn net.Conn
if strings.HasPrefix(hostURL, "https://") {
u, err := url.Parse(hostURL)
if err != nil {
return err
}
port := u.Port()
if port == "" {
port = "443"
}
dialer := tls.Dialer{}
conn, err = dialer.DialContext(ctx, "tcp4", u.Host+":"+port)
} else if strings.HasPrefix(hostURL, "http://") {
dialer := net.Dialer{}
conn, err = dialer.DialContext(ctx, "tcp4", strings.TrimPrefix(hostURL, "http://"))
} else {
return fmt.Errorf("skabandclient.Dial: bad url, needs to be http or https: %s", hostURL)
}
if err != nil {
return fmt.Errorf("skabandclient: %w", err)
}
if conn == nil {
return fmt.Errorf("skabandclient: nil connection")
}
defer conn.Close()
// "Upgrade" our connection, like a WebSocket does.
req, err := http.NewRequest("POST", hostURL+"/attach", nil)
if err != nil {
return fmt.Errorf("skabandclient.Dial: /attach: %w", err)
}
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "ska")
req.Header.Set("Session-ID", sessionID)
req.Header.Set("Public-Key", clientPubKey)
req.Header.Set("Session-Secret", sessionSecret)
if err := req.Write(conn); err != nil {
return fmt.Errorf("skabandclient.Dial: write upgrade request: %w", err)
}
reader := bufio.NewReader(conn)
resp, err := http.ReadResponse(reader, req)
if err != nil {
if resp != nil {
b, _ := io.ReadAll(resp.Body)
return fmt.Errorf("skabandclient.Dial: read upgrade response: %w: %s", err, b)
} else {
return fmt.Errorf("skabandclient.Dial: read upgrade response: %w", err)
}
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusSwitchingProtocols {
b, _ := io.ReadAll(resp.Body)
return fmt.Errorf("skabandclient.Dial: unexpected status code: %d: %s", resp.StatusCode, b)
}
if !strings.Contains(resp.Header.Get("Upgrade"), "ska") {
return errors.New("skabandclient.Dial: server did not upgrade to ska protocol")
}
if buf := reader.Buffered(); buf > 0 {
peek, _ := reader.Peek(buf)
return fmt.Errorf("skabandclient.Dial: buffered read after upgrade response: %d: %q", buf, string(peek))
}
// Send Magic.
const magic = "skaband\n"
if _, err := conn.Write([]byte(magic)); err != nil {
return fmt.Errorf("skabandclient.Dial: failed to send upgrade init message: %w", err)
}
// We have a TCP connection to the server and have been through the upgrade dance.
// Now we can run an HTTP server over that connection ("inverting" the HTTP flow).
// Skaband is expected to heartbeat within 60 seconds.
lastHeartbeat := time.Now()
mu := sync.Mutex{}
go func() {
for {
time.Sleep(5 * time.Second)
mu.Lock()
if time.Since(lastHeartbeat) > 60*time.Second {
mu.Unlock()
conn.Close()
slog.Info("skaband heartbeat timeout")
return
}
mu.Unlock()
}
}()
server := &http2.Server{}
h2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/skabandheartbeat" {
w.WriteHeader(http.StatusOK)
mu.Lock()
defer mu.Unlock()
lastHeartbeat = time.Now()
}
h.ServeHTTP(w, r)
})
server.ServeConn(conn, &http2.ServeConnOpts{
Handler: h2,
})
return nil
}
func decodePrivKey(privData []byte) (ed25519.PrivateKey, error) {
privBlock, _ := pem.Decode(privData)
if privBlock == nil || privBlock.Type != "PRIVATE KEY" {
return nil, fmt.Errorf("no valid private key block found")
}
parsedPriv, err := x509.ParsePKCS8PrivateKey(privBlock.Bytes)
if err != nil {
return nil, err
}
return parsedPriv.(ed25519.PrivateKey), nil
}
func encodePrivateKey(privKey ed25519.PrivateKey) ([]byte, error) {
privBytes, err := x509.MarshalPKCS8PrivateKey(privKey)
if err != nil {
return nil, err
}
return pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}), nil
}
func LoadOrCreatePrivateKey(path string) (ed25519.PrivateKey, error) {
privData, err := os.ReadFile(path)
if os.IsNotExist(err) {
_, privKey, err := ed25519.GenerateKey(crand.Reader)
if err != nil {
return nil, err
}
b, err := encodePrivateKey(privKey)
if err != nil {
return nil, err
}
if err := os.WriteFile(path, b, 0o600); err != nil {
return nil, err
}
return privKey, nil
} else if err != nil {
return nil, fmt.Errorf("read key failed: %w", err)
}
key, err := decodePrivKey(privData)
if err != nil {
return nil, fmt.Errorf("%s: %w", path, err)
}
return key, nil
}
// Login connects to skaband and authenticates the user.
// If skabandAddr is empty, it returns the public key without contacting a server.
// It is the caller's responsibility to set the API URL and key in this case.
func Login(stdout io.Writer, privKey ed25519.PrivateKey, skabandAddr, sessionID, model string) (pubKey, apiURL, apiKey string, err error) {
sig := ed25519.Sign(privKey, []byte(sessionID))
pubKey = hex.EncodeToString(privKey.Public().(ed25519.PublicKey))
if skabandAddr == "" {
return pubKey, "", "", nil
}
req, err := http.NewRequest("POST", skabandAddr+"/authclient", nil)
if err != nil {
return "", "", "", err
}
req.Header.Set("Public-Key", pubKey)
req.Header.Set("Session-ID", sessionID)
req.Header.Set("Session-ID-Sig", hex.EncodeToString(sig))
req.Header.Set("X-Model", model)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", "", "", fmt.Errorf("skaband login: %w", err)
}
apiURL = resp.Header.Get("X-API-URL")
apiKey = resp.Header.Get("X-API-Key")
defer resp.Body.Close()
_, err = io.Copy(stdout, resp.Body)
if err != nil {
return "", "", "", fmt.Errorf("skaband login: %w", err)
}
if resp.StatusCode != 200 {
return "", "", "", fmt.Errorf("skaband login failed: %d", resp.StatusCode)
}
if apiURL == "" {
return "", "", "", fmt.Errorf("skaband returned no api url")
}
if apiKey == "" {
return "", "", "", fmt.Errorf("skaband returned no api key")
}
return pubKey, apiURL, apiKey, nil
}
func DefaultKeyPath(skabandAddr string) string {
homeDir, err := os.UserHomeDir()
if err != nil {
panic(err)
}
cacheDir := filepath.Join(homeDir, ".cache", "sketch")
if skabandAddr != "https://sketch.dev" { // main server gets "root" cache dir, for backwards compatibility
h := sha256.Sum256([]byte(skabandAddr))
cacheDir = filepath.Join(cacheDir, hex.EncodeToString(h[:8]))
}
os.MkdirAll(cacheDir, 0o777)
return filepath.Join(cacheDir, "sketch.ed25519")
}
func LocalhostToDockerInternal(skabandURL string) (string, error) {
u, err := url.Parse(skabandURL)
if err != nil {
return "", fmt.Errorf("localhostToDockerInternal: %w", err)
}
switch u.Hostname() {
case "localhost", "127.0.0.1":
host := "host.docker.internal"
if port := u.Port(); port != "" {
host += ":" + port
}
u.Host = host
return u.String(), nil
}
return skabandURL, nil
}
// NewSessionID generates a new 10-byte random Session ID.
func NewSessionID() string {
u1, u2 := rand.Uint64(), rand.Uint64N(1<<16)
s := crock32.Encode(u1) + crock32.Encode(uint64(u2))
if len(s) < 16 {
s += strings.Repeat("0", 16-len(s))
}
return s[0:4] + "-" + s[4:8] + "-" + s[8:12] + "-" + s[12:16]
}
// Regex pattern for SessionID format: xxxx-xxxx-xxxx-xxxx
// Where x is a valid Crockford Base32 character (0-9, A-H, J-N, P-Z)
// Case-insensitive match
var sessionIdRegexp = regexp.MustCompile(
"^[0-9A-HJ-NP-Za-hj-np-z]{4}-[0-9A-HJ-NP-Za-hj-np-z]{4}-[0-9A-HJ-NP-Za-hj-np-z]{4}-[0-9A-HJ-NP-Za-hj-np-z]{4}")
func ValidateSessionID(sessionID string) bool {
return sessionIdRegexp.MatchString(sessionID)
}
// Addr returns the skaband server address
func (c *SkabandClient) Addr() string {
if c == nil {
return ""
}
return c.addr
}
// NewSkabandClient creates a new skaband client
func NewSkabandClient(addr, publicKey string) *SkabandClient {
// Apply localhost-to-docker-internal transformation if needed
if _, err := os.Stat("/.dockerenv"); err == nil { // inDocker
if newAddr, err := LocalhostToDockerInternal(addr); err == nil {
addr = newAddr
}
}
return &SkabandClient{
addr: addr,
publicKey: publicKey,
client: &http.Client{Timeout: 30 * time.Second},
}
}
// ListRecentSketchSessionsMarkdown returns recent sessions as a markdown table
func (c *SkabandClient) ListRecentSketchSessionsMarkdown(ctx context.Context, currentRepo, sessionID string) (string, error) {
if c == nil {
return "", fmt.Errorf("SkabandClient is nil")
}
// Build URL with query parameters
baseURL := c.addr + "/api/sessions/recent"
if currentRepo != "" {
baseURL += "?repo=" + url.QueryEscape(currentRepo)
}
req, err := http.NewRequestWithContext(ctx, "GET", baseURL, nil)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
// Add headers
req.Header.Set("Public-Key", c.publicKey)
req.Header.Set("Session-ID", sessionID)
resp, err := c.client.Do(req)
if err != nil {
return "", fmt.Errorf("failed to make request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(body))
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response: %w", err)
}
return string(body), nil
}
// ReadSketchSession reads the full details of a specific session and returns formatted text
func (c *SkabandClient) ReadSketchSession(ctx context.Context, targetSessionID, originSessionID string) (*string, error) {
if c == nil {
return nil, fmt.Errorf("SkabandClient is nil")
}
req, err := http.NewRequestWithContext(ctx, "GET", c.addr+"/api/sessions/"+targetSessionID, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
// Add headers
req.Header.Set("Public-Key", c.publicKey)
req.Header.Set("Session-ID", originSessionID)
resp, err := c.client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to make request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(body))
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
response := string(body)
return &response, nil
}
// DialAndServeLoop is a redial loop around DialAndServe.
func (c *SkabandClient) DialAndServeLoop(ctx context.Context, sessionID string, sessionSecret string, srv http.Handler, connectFn func(connected bool)) {
skabandAddr := c.addr
clientPubKey := c.publicKey
if _, err := os.Stat("/.dockerenv"); err == nil { // inDocker
if addr, err := LocalhostToDockerInternal(skabandAddr); err == nil {
skabandAddr = addr
}
}
var skabandConnected atomic.Bool
skabandHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/skabandinit" {
b, err := io.ReadAll(r.Body)
if err != nil {
fmt.Printf("skabandinit failed: %v\n", err)
return
}
m := map[string]string{}
if err := json.Unmarshal(b, &m); err != nil {
fmt.Printf("skabandinit failed: %v\n", err)
return
}
skabandConnected.Store(true)
if connectFn != nil {
connectFn(true)
}
return
}
srv.ServeHTTP(w, r)
})
var lastErrLog time.Time
for {
if err := DialAndServe(ctx, skabandAddr, sessionID, clientPubKey, sessionSecret, skabandHandler); err != nil {
// NOTE: *just* backoff the logging. Backing off dialing
// is bad UX. Doing so saves negligible CPU and doing so
// without hurting UX requires interrupting the backoff with
// wake-from-sleep and network-up events from the OS,
// which are a pain to plumb.
if time.Since(lastErrLog) > 1*time.Minute {
slog.DebugContext(ctx, "skaband connection failed", "err", err)
lastErrLog = time.Now()
}
}
if skabandConnected.CompareAndSwap(true, false) {
if connectFn != nil {
connectFn(false)
}
}
time.Sleep(200 * time.Millisecond)
}
}