blob: d0b6bb43515617d97a08978161484551605edb68 [file] [log] [blame]
Earl Lee2e463fb2025-04-17 11:22:22 -07001package skabandclient
2
3import (
4 "bufio"
5 "context"
6 "crypto/ed25519"
7 crand "crypto/rand"
Josh Bleecher Snyder75b45f52025-07-17 15:47:32 -07008 "crypto/sha256"
Earl Lee2e463fb2025-04-17 11:22:22 -07009 "crypto/tls"
10 "crypto/x509"
11 "encoding/hex"
12 "encoding/json"
13 "encoding/pem"
14 "errors"
15 "fmt"
16 "io"
17 "log/slog"
David Crawshaw0ead54d2025-05-16 13:58:36 -070018 "math/rand/v2"
Earl Lee2e463fb2025-04-17 11:22:22 -070019 "net"
20 "net/http"
21 "net/url"
22 "os"
23 "path/filepath"
Philip Zeyliger59789952025-06-28 20:02:23 -070024 "regexp"
Earl Lee2e463fb2025-04-17 11:22:22 -070025 "strings"
Philip Zeyligere9eaf6c2025-05-19 16:14:39 -070026 "sync"
Earl Lee2e463fb2025-04-17 11:22:22 -070027 "sync/atomic"
28 "time"
29
David Crawshaw0ead54d2025-05-16 13:58:36 -070030 "github.com/richardlehane/crock32"
Earl Lee2e463fb2025-04-17 11:22:22 -070031 "golang.org/x/net/http2"
32)
33
// SkabandClient provides HTTP client functionality for a skaband server.
// Construct it with NewSkabandClient.
type SkabandClient struct {
	// addr is the skaband server URL (http:// or https://).
	addr string
	// publicKey identifies this client; DialAndServeLoop sends it as the
	// Public-Key header when attaching.
	publicKey string
	// client is the HTTP client used for requests to skaband.
	client *http.Client
}
40
Philip Zeyligerf2814ea2025-06-30 10:16:50 -070041func DialAndServe(ctx context.Context, hostURL, sessionID, clientPubKey string, sessionSecret string, h http.Handler) (err error) {
Earl Lee2e463fb2025-04-17 11:22:22 -070042 // Connect to the server.
43 var conn net.Conn
44 if strings.HasPrefix(hostURL, "https://") {
45 u, err := url.Parse(hostURL)
46 if err != nil {
47 return err
48 }
49 port := u.Port()
50 if port == "" {
51 port = "443"
52 }
53 dialer := tls.Dialer{}
54 conn, err = dialer.DialContext(ctx, "tcp4", u.Host+":"+port)
55 } else if strings.HasPrefix(hostURL, "http://") {
56 dialer := net.Dialer{}
57 conn, err = dialer.DialContext(ctx, "tcp4", strings.TrimPrefix(hostURL, "http://"))
58 } else {
59 return fmt.Errorf("skabandclient.Dial: bad url, needs to be http or https: %s", hostURL)
60 }
61 if err != nil {
62 return fmt.Errorf("skabandclient: %w", err)
63 }
Philip Zeyligerfe3e9f72025-04-24 09:02:05 -070064 if conn == nil {
65 return fmt.Errorf("skabandclient: nil connection")
66 }
Earl Lee2e463fb2025-04-17 11:22:22 -070067 defer conn.Close()
68
69 // "Upgrade" our connection, like a WebSocket does.
70 req, err := http.NewRequest("POST", hostURL+"/attach", nil)
71 if err != nil {
72 return fmt.Errorf("skabandclient.Dial: /attach: %w", err)
73 }
74 req.Header.Set("Connection", "Upgrade")
75 req.Header.Set("Upgrade", "ska")
76 req.Header.Set("Session-ID", sessionID)
77 req.Header.Set("Public-Key", clientPubKey)
Philip Zeyligerf2814ea2025-06-30 10:16:50 -070078 req.Header.Set("Session-Secret", sessionSecret)
Earl Lee2e463fb2025-04-17 11:22:22 -070079
80 if err := req.Write(conn); err != nil {
81 return fmt.Errorf("skabandclient.Dial: write upgrade request: %w", err)
82 }
83 reader := bufio.NewReader(conn)
84 resp, err := http.ReadResponse(reader, req)
85 if err != nil {
Philip Zeyligerfe3e9f72025-04-24 09:02:05 -070086 if resp != nil {
87 b, _ := io.ReadAll(resp.Body)
88 return fmt.Errorf("skabandclient.Dial: read upgrade response: %w: %s", err, b)
89 } else {
90 return fmt.Errorf("skabandclient.Dial: read upgrade response: %w", err)
91 }
Earl Lee2e463fb2025-04-17 11:22:22 -070092 }
93 defer resp.Body.Close()
94 if resp.StatusCode != http.StatusSwitchingProtocols {
95 b, _ := io.ReadAll(resp.Body)
96 return fmt.Errorf("skabandclient.Dial: unexpected status code: %d: %s", resp.StatusCode, b)
97 }
98 if !strings.Contains(resp.Header.Get("Upgrade"), "ska") {
99 return errors.New("skabandclient.Dial: server did not upgrade to ska protocol")
100 }
101 if buf := reader.Buffered(); buf > 0 {
102 peek, _ := reader.Peek(buf)
103 return fmt.Errorf("skabandclient.Dial: buffered read after upgrade response: %d: %q", buf, string(peek))
104 }
105
106 // Send Magic.
107 const magic = "skaband\n"
108 if _, err := conn.Write([]byte(magic)); err != nil {
109 return fmt.Errorf("skabandclient.Dial: failed to send upgrade init message: %w", err)
110 }
111
112 // We have a TCP connection to the server and have been through the upgrade dance.
113 // Now we can run an HTTP server over that connection ("inverting" the HTTP flow).
Philip Zeyligere9eaf6c2025-05-19 16:14:39 -0700114 // Skaband is expected to heartbeat within 60 seconds.
115 lastHeartbeat := time.Now()
116 mu := sync.Mutex{}
117 go func() {
118 for {
119 time.Sleep(5 * time.Second)
120 mu.Lock()
121 if time.Since(lastHeartbeat) > 60*time.Second {
122 mu.Unlock()
123 conn.Close()
124 slog.Info("skaband heartbeat timeout")
125 return
126 }
127 mu.Unlock()
128 }
129 }()
Earl Lee2e463fb2025-04-17 11:22:22 -0700130 server := &http2.Server{}
Philip Zeyligere9eaf6c2025-05-19 16:14:39 -0700131 h2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
132 if r.URL.Path == "/skabandheartbeat" {
133 w.WriteHeader(http.StatusOK)
134 mu.Lock()
135 defer mu.Unlock()
136 lastHeartbeat = time.Now()
137 }
138 h.ServeHTTP(w, r)
139 })
Earl Lee2e463fb2025-04-17 11:22:22 -0700140 server.ServeConn(conn, &http2.ServeConnOpts{
Philip Zeyligere9eaf6c2025-05-19 16:14:39 -0700141 Handler: h2,
Earl Lee2e463fb2025-04-17 11:22:22 -0700142 })
143
144 return nil
145}
146
// decodePrivKey parses the first PEM block in privData, which must be of type
// "PRIVATE KEY" and hold a PKCS#8-encoded ed25519 private key. It returns an
// error in three cases:
//   - no such block is found;
//   - the PKCS#8 data is malformed;
//   - the key is of another algorithm (e.g. RSA or ECDSA).
func decodePrivKey(privData []byte) (ed25519.PrivateKey, error) {
	privBlock, _ := pem.Decode(privData)
	if privBlock == nil || privBlock.Type != "PRIVATE KEY" {
		return nil, fmt.Errorf("no valid private key block found")
	}
	parsedPriv, err := x509.ParsePKCS8PrivateKey(privBlock.Bytes)
	if err != nil {
		return nil, err
	}
	// PKCS#8 can carry RSA/ECDSA/X25519 keys too; a bare type assertion would
	// panic on those instead of reporting an error.
	key, ok := parsedPriv.(ed25519.PrivateKey)
	if !ok {
		return nil, fmt.Errorf("private key is %T, want ed25519", parsedPriv)
	}
	return key, nil
}
158
// encodePrivateKey serializes privKey as PKCS#8 DER wrapped in a PEM block of
// type "PRIVATE KEY", which is the format decodePrivKey reads back.
func encodePrivateKey(privKey ed25519.PrivateKey) ([]byte, error) {
	der, marshalErr := x509.MarshalPKCS8PrivateKey(privKey)
	if marshalErr != nil {
		return nil, marshalErr
	}
	block := &pem.Block{
		Type:  "PRIVATE KEY",
		Bytes: der,
	}
	encoded := pem.EncodeToMemory(block)
	return encoded, nil
}
166
167func LoadOrCreatePrivateKey(path string) (ed25519.PrivateKey, error) {
168 privData, err := os.ReadFile(path)
169 if os.IsNotExist(err) {
170 _, privKey, err := ed25519.GenerateKey(crand.Reader)
171 if err != nil {
172 return nil, err
173 }
174 b, err := encodePrivateKey(privKey)
David Crawshaw961cc9e2025-05-05 14:33:33 -0700175 if err != nil {
176 return nil, err
177 }
Earl Lee2e463fb2025-04-17 11:22:22 -0700178 if err := os.WriteFile(path, b, 0o600); err != nil {
179 return nil, err
180 }
181 return privKey, nil
182 } else if err != nil {
183 return nil, fmt.Errorf("read key failed: %w", err)
184 }
185 key, err := decodePrivKey(privData)
186 if err != nil {
187 return nil, fmt.Errorf("%s: %w", path, err)
188 }
189 return key, nil
190}
191
// Login authenticates this client with the skaband server at skabandAddr.
//
// It signs sessionID with privKey and POSTs to skabandAddr+"/authclient". The
// request carries these headers:
//   - Public-Key: the hex-encoded public key;
//   - Session-ID: sessionID;
//   - Session-ID-Sig: the hex-encoded signature;
//   - X-Model: model.
//
// The response body is copied to stdout regardless of status. Login succeeds
// only on a 200 response that carries both X-API-URL and X-API-Key headers.
// It then returns the hex public key, those two values, and the X-OAI-Model
// header (which may be empty).
//
// If skabandAddr is empty, no request is made. Only pubKey is returned, and it
// is the caller's responsibility to set the API URL and key in this case.
//
// The request goes through http.DefaultClient, which has no timeout by default.
func Login(stdout io.Writer, privKey ed25519.PrivateKey, skabandAddr, sessionID, model string) (pubKey, apiURL, oaiModelName, apiKey string, err error) {
	signature := ed25519.Sign(privKey, []byte(sessionID))
	pubKey = hex.EncodeToString(privKey.Public().(ed25519.PublicKey))
	if skabandAddr == "" {
		return pubKey, "", "", "", nil
	}

	req, err := http.NewRequest(http.MethodPost, skabandAddr+"/authclient", nil)
	if err != nil {
		return "", "", "", "", err
	}
	hdr := req.Header
	hdr.Set("Public-Key", pubKey)
	hdr.Set("Session-ID", sessionID)
	hdr.Set("Session-ID-Sig", hex.EncodeToString(signature))
	hdr.Set("X-Model", model)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", "", "", "", fmt.Errorf("skaband login: %w", err)
	}
	defer resp.Body.Close()

	// The body is shown before the status is checked, presumably so that any
	// server-provided explanation reaches the user even on failure.
	if _, copyErr := io.Copy(stdout, resp.Body); copyErr != nil {
		return "", "", "", "", fmt.Errorf("skaband login: %w", copyErr)
	}

	gotURL := resp.Header.Get("X-API-URL")
	gotKey := resp.Header.Get("X-API-Key")
	gotModel := resp.Header.Get("X-OAI-Model")
	switch {
	case resp.StatusCode != http.StatusOK:
		return "", "", "", "", fmt.Errorf("skaband login failed: %d", resp.StatusCode)
	case gotURL == "":
		return "", "", "", "", fmt.Errorf("skaband returned no api url")
	case gotKey == "":
		return "", "", "", "", fmt.Errorf("skaband returned no api key")
	}
	return pubKey, gotURL, gotModel, gotKey, nil
}
233
// DefaultKeyPath returns the path of the client's ed25519 key file for the
// skaband server at skabandAddr, creating the containing directory if needed.
//
// Keys live under ~/.cache/sketch. The main server, "https://sketch.dev", uses
// that directory itself, for backwards compatibility. Any other address gets a
// subdirectory named by the hex of the first 8 bytes of the address's SHA-256.
//
// It panics if the home directory cannot be determined. A failure to create
// the directory is ignored here.
func DefaultKeyPath(skabandAddr string) string {
	home, homeErr := os.UserHomeDir()
	if homeErr != nil {
		panic(homeErr)
	}

	const mainServer = "https://sketch.dev"
	dir := filepath.Join(home, ".cache", "sketch")
	if skabandAddr != mainServer {
		digest := sha256.Sum256([]byte(skabandAddr))
		dir = filepath.Join(dir, hex.EncodeToString(digest[:8]))
	}

	// Best effort: problems presumably surface when the key is read or written.
	_ = os.MkdirAll(dir, 0o777)
	return filepath.Join(dir, "sketch.ed25519")
}
247
// LocalhostToDockerInternal rewrites skabandURL so that, from inside a Docker
// container, it reaches a server listening on the host's loopback interface.
//
// A host of "localhost" (any case) or any loopback IP literal (127.0.0.0/8,
// ::1) is replaced with "host.docker.internal". The port, scheme, path and
// query are preserved. Any other URL is returned unchanged. An error is
// returned only if skabandURL cannot be parsed.
func LocalhostToDockerInternal(skabandURL string) (string, error) {
	u, err := url.Parse(skabandURL)
	if err != nil {
		return "", fmt.Errorf("localhostToDockerInternal: %w", err)
	}

	hostname := u.Hostname()
	ip := net.ParseIP(hostname)
	isLoopback := strings.EqualFold(hostname, "localhost") || (ip != nil && ip.IsLoopback())
	if !isLoopback {
		return skabandURL, nil
	}

	host := "host.docker.internal"
	if port := u.Port(); port != "" {
		host = net.JoinHostPort(host, port)
	}
	u.Host = host
	return u.String(), nil
}
David Crawshaw0ead54d2025-05-16 13:58:36 -0700264
265// NewSessionID generates a new 10-byte random Session ID.
266func NewSessionID() string {
267 u1, u2 := rand.Uint64(), rand.Uint64N(1<<16)
268 s := crock32.Encode(u1) + crock32.Encode(uint64(u2))
269 if len(s) < 16 {
270 s += strings.Repeat("0", 16-len(s))
271 }
272 return s[0:4] + "-" + s[4:8] + "-" + s[8:12] + "-" + s[12:16]
273}
Philip Zeyligerc17ffe32025-06-05 19:49:13 -0700274
// sessionIdRegexp matches an entire session ID of the form xxxx-xxxx-xxxx-xxxx,
// where each x is a digit or an ASCII letter other than I or O (either case).
// Both ends are anchored, so leading or trailing characters cause a mismatch.
var sessionIdRegexp = regexp.MustCompile(
	"^[0-9A-HJ-NP-Za-hj-np-z]{4}-[0-9A-HJ-NP-Za-hj-np-z]{4}-[0-9A-HJ-NP-Za-hj-np-z]{4}-[0-9A-HJ-NP-Za-hj-np-z]{4}$")

// ValidateSessionID reports whether sessionID is exactly one well-formed
// session ID: four dash-separated groups of four characters, each a digit or
// a letter other than I or O. Strings with anything before or after the ID
// are rejected.
func ValidateSessionID(sessionID string) bool {
	return sessionIdRegexp.MatchString(sessionID)
}
284
Philip Zeyliger0113be52025-06-07 23:53:41 +0000285// Addr returns the skaband server address
286func (c *SkabandClient) Addr() string {
287 if c == nil {
288 return ""
289 }
290 return c.addr
291}
292
Philip Zeyligerc17ffe32025-06-05 19:49:13 -0700293// NewSkabandClient creates a new skaband client
294func NewSkabandClient(addr, publicKey string) *SkabandClient {
295 // Apply localhost-to-docker-internal transformation if needed
296 if _, err := os.Stat("/.dockerenv"); err == nil { // inDocker
297 if newAddr, err := LocalhostToDockerInternal(addr); err == nil {
298 addr = newAddr
299 }
300 }
301
302 return &SkabandClient{
303 addr: addr,
304 publicKey: publicKey,
305 client: &http.Client{Timeout: 30 * time.Second},
306 }
307}
308
Philip Zeyligerc17ffe32025-06-05 19:49:13 -0700309// DialAndServeLoop is a redial loop around DialAndServe.
Philip Zeyligerf2814ea2025-06-30 10:16:50 -0700310func (c *SkabandClient) DialAndServeLoop(ctx context.Context, sessionID string, sessionSecret string, srv http.Handler, connectFn func(connected bool)) {
Philip Zeyligerc17ffe32025-06-05 19:49:13 -0700311 skabandAddr := c.addr
312 clientPubKey := c.publicKey
313
314 if _, err := os.Stat("/.dockerenv"); err == nil { // inDocker
315 if addr, err := LocalhostToDockerInternal(skabandAddr); err == nil {
316 skabandAddr = addr
317 }
318 }
319
320 var skabandConnected atomic.Bool
321 skabandHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
322 if r.URL.Path == "/skabandinit" {
323 b, err := io.ReadAll(r.Body)
324 if err != nil {
325 fmt.Printf("skabandinit failed: %v\n", err)
326 return
327 }
328 m := map[string]string{}
329 if err := json.Unmarshal(b, &m); err != nil {
330 fmt.Printf("skabandinit failed: %v\n", err)
331 return
332 }
333 skabandConnected.Store(true)
334 if connectFn != nil {
335 connectFn(true)
336 }
337 return
338 }
339 srv.ServeHTTP(w, r)
340 })
341
342 var lastErrLog time.Time
343 for {
Philip Zeyligerf2814ea2025-06-30 10:16:50 -0700344 if err := DialAndServe(ctx, skabandAddr, sessionID, clientPubKey, sessionSecret, skabandHandler); err != nil {
Philip Zeyligerc17ffe32025-06-05 19:49:13 -0700345 // NOTE: *just* backoff the logging. Backing off dialing
346 // is bad UX. Doing so saves negligible CPU and doing so
347 // without hurting UX requires interrupting the backoff with
348 // wake-from-sleep and network-up events from the OS,
349 // which are a pain to plumb.
350 if time.Since(lastErrLog) > 1*time.Minute {
351 slog.DebugContext(ctx, "skaband connection failed", "err", err)
352 lastErrLog = time.Now()
353 }
354 }
355 if skabandConnected.CompareAndSwap(true, false) {
356 if connectFn != nil {
357 connectFn(false)
358 }
359 }
360 time.Sleep(200 * time.Millisecond)
361 }
362}