blob: d13fb777fa5bd826bb3f594d33d21f2a3efd7f6b [file] [log] [blame]
Earl Lee2e463fb2025-04-17 11:22:22 -07001package skabandclient
2
3import (
4 "bufio"
5 "context"
6 "crypto/ed25519"
7 crand "crypto/rand"
8 "crypto/tls"
9 "crypto/x509"
10 "encoding/hex"
11 "encoding/json"
12 "encoding/pem"
13 "errors"
14 "fmt"
15 "io"
16 "log/slog"
David Crawshaw0ead54d2025-05-16 13:58:36 -070017 "math/rand/v2"
Earl Lee2e463fb2025-04-17 11:22:22 -070018 "net"
19 "net/http"
20 "net/url"
21 "os"
22 "path/filepath"
23 "strings"
Philip Zeyligere9eaf6c2025-05-19 16:14:39 -070024 "sync"
Earl Lee2e463fb2025-04-17 11:22:22 -070025 "sync/atomic"
26 "time"
27
David Crawshaw0ead54d2025-05-16 13:58:36 -070028 "github.com/richardlehane/crock32"
Earl Lee2e463fb2025-04-17 11:22:22 -070029 "golang.org/x/net/http2"
30)
31
32// DialAndServeLoop is a redial loop around DialAndServe.
33func DialAndServeLoop(ctx context.Context, skabandAddr, sessionID, clientPubKey string, srv http.Handler, connectFn func(connected bool)) {
34 if _, err := os.Stat("/.dockerenv"); err == nil { // inDocker
35 if addr, err := LocalhostToDockerInternal(skabandAddr); err == nil {
36 skabandAddr = addr
37 }
38 }
39
40 var skabandConnected atomic.Bool
41 skabandHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
42 if r.URL.Path == "/skabandinit" {
43 b, err := io.ReadAll(r.Body)
44 if err != nil {
45 fmt.Printf("skabandinit failed: %v\n", err)
46 return
47 }
48 m := map[string]string{}
49 if err := json.Unmarshal(b, &m); err != nil {
50 fmt.Printf("skabandinit failed: %v\n", err)
51 return
52 }
53 skabandConnected.Store(true)
54 if connectFn != nil {
55 connectFn(true)
56 }
57 return
58 }
59 srv.ServeHTTP(w, r)
60 })
61
62 var lastErrLog time.Time
63 for {
64 if err := DialAndServe(ctx, skabandAddr, sessionID, clientPubKey, skabandHandler); err != nil {
65 // NOTE: *just* backoff the logging. Backing off dialing
Josh Bleecher Snyder574eda82025-05-28 10:30:43 -070066 // is bad UX. Doing so saves negligible CPU and doing so
67 // without hurting UX requires interrupting the backoff with
Earl Lee2e463fb2025-04-17 11:22:22 -070068 // wake-from-sleep and network-up events from the OS,
69 // which are a pain to plumb.
70 if time.Since(lastErrLog) > 1*time.Minute {
71 slog.DebugContext(ctx, "skaband connection failed", "err", err)
72 lastErrLog = time.Now()
73 }
74 }
75 if skabandConnected.CompareAndSwap(true, false) {
76 if connectFn != nil {
77 connectFn(false)
78 }
79 }
80 time.Sleep(200 * time.Millisecond)
81 }
82}
83
84func DialAndServe(ctx context.Context, hostURL, sessionID, clientPubKey string, h http.Handler) (err error) {
85 // Connect to the server.
86 var conn net.Conn
87 if strings.HasPrefix(hostURL, "https://") {
88 u, err := url.Parse(hostURL)
89 if err != nil {
90 return err
91 }
92 port := u.Port()
93 if port == "" {
94 port = "443"
95 }
96 dialer := tls.Dialer{}
97 conn, err = dialer.DialContext(ctx, "tcp4", u.Host+":"+port)
98 } else if strings.HasPrefix(hostURL, "http://") {
99 dialer := net.Dialer{}
100 conn, err = dialer.DialContext(ctx, "tcp4", strings.TrimPrefix(hostURL, "http://"))
101 } else {
102 return fmt.Errorf("skabandclient.Dial: bad url, needs to be http or https: %s", hostURL)
103 }
104 if err != nil {
105 return fmt.Errorf("skabandclient: %w", err)
106 }
Philip Zeyligerfe3e9f72025-04-24 09:02:05 -0700107 if conn == nil {
108 return fmt.Errorf("skabandclient: nil connection")
109 }
Earl Lee2e463fb2025-04-17 11:22:22 -0700110 defer conn.Close()
111
112 // "Upgrade" our connection, like a WebSocket does.
113 req, err := http.NewRequest("POST", hostURL+"/attach", nil)
114 if err != nil {
115 return fmt.Errorf("skabandclient.Dial: /attach: %w", err)
116 }
117 req.Header.Set("Connection", "Upgrade")
118 req.Header.Set("Upgrade", "ska")
119 req.Header.Set("Session-ID", sessionID)
120 req.Header.Set("Public-Key", clientPubKey)
121
122 if err := req.Write(conn); err != nil {
123 return fmt.Errorf("skabandclient.Dial: write upgrade request: %w", err)
124 }
125 reader := bufio.NewReader(conn)
126 resp, err := http.ReadResponse(reader, req)
127 if err != nil {
Philip Zeyligerfe3e9f72025-04-24 09:02:05 -0700128 if resp != nil {
129 b, _ := io.ReadAll(resp.Body)
130 return fmt.Errorf("skabandclient.Dial: read upgrade response: %w: %s", err, b)
131 } else {
132 return fmt.Errorf("skabandclient.Dial: read upgrade response: %w", err)
133 }
Earl Lee2e463fb2025-04-17 11:22:22 -0700134 }
135 defer resp.Body.Close()
136 if resp.StatusCode != http.StatusSwitchingProtocols {
137 b, _ := io.ReadAll(resp.Body)
138 return fmt.Errorf("skabandclient.Dial: unexpected status code: %d: %s", resp.StatusCode, b)
139 }
140 if !strings.Contains(resp.Header.Get("Upgrade"), "ska") {
141 return errors.New("skabandclient.Dial: server did not upgrade to ska protocol")
142 }
143 if buf := reader.Buffered(); buf > 0 {
144 peek, _ := reader.Peek(buf)
145 return fmt.Errorf("skabandclient.Dial: buffered read after upgrade response: %d: %q", buf, string(peek))
146 }
147
148 // Send Magic.
149 const magic = "skaband\n"
150 if _, err := conn.Write([]byte(magic)); err != nil {
151 return fmt.Errorf("skabandclient.Dial: failed to send upgrade init message: %w", err)
152 }
153
154 // We have a TCP connection to the server and have been through the upgrade dance.
155 // Now we can run an HTTP server over that connection ("inverting" the HTTP flow).
Philip Zeyligere9eaf6c2025-05-19 16:14:39 -0700156 // Skaband is expected to heartbeat within 60 seconds.
157 lastHeartbeat := time.Now()
158 mu := sync.Mutex{}
159 go func() {
160 for {
161 time.Sleep(5 * time.Second)
162 mu.Lock()
163 if time.Since(lastHeartbeat) > 60*time.Second {
164 mu.Unlock()
165 conn.Close()
166 slog.Info("skaband heartbeat timeout")
167 return
168 }
169 mu.Unlock()
170 }
171 }()
Earl Lee2e463fb2025-04-17 11:22:22 -0700172 server := &http2.Server{}
Philip Zeyligere9eaf6c2025-05-19 16:14:39 -0700173 h2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
174 if r.URL.Path == "/skabandheartbeat" {
175 w.WriteHeader(http.StatusOK)
176 mu.Lock()
177 defer mu.Unlock()
178 lastHeartbeat = time.Now()
179 }
180 h.ServeHTTP(w, r)
181 })
Earl Lee2e463fb2025-04-17 11:22:22 -0700182 server.ServeConn(conn, &http2.ServeConnOpts{
Philip Zeyligere9eaf6c2025-05-19 16:14:39 -0700183 Handler: h2,
Earl Lee2e463fb2025-04-17 11:22:22 -0700184 })
185
186 return nil
187}
188
// decodePrivKey parses a PEM block of type "PRIVATE KEY" holding a PKCS #8
// encoded ed25519 private key (the format written by encodePrivateKey).
//
// It returns an error, rather than panicking, when privData contains no such
// PEM block, when the PKCS #8 payload is malformed, or when the key uses a
// different algorithm (e.g. RSA, ECDSA or X25519).
func decodePrivKey(privData []byte) (ed25519.PrivateKey, error) {
	privBlock, _ := pem.Decode(privData)
	if privBlock == nil || privBlock.Type != "PRIVATE KEY" {
		return nil, fmt.Errorf("no valid private key block found")
	}
	parsedPriv, err := x509.ParsePKCS8PrivateKey(privBlock.Bytes)
	if err != nil {
		return nil, err
	}
	// PKCS #8 can carry many key types; only ed25519 is usable here.
	key, ok := parsedPriv.(ed25519.PrivateKey)
	if !ok {
		return nil, fmt.Errorf("unsupported private key type %T, want ed25519", parsedPriv)
	}
	return key, nil
}
200
// encodePrivateKey serializes privKey as PKCS #8 DER and wraps it in a PEM
// block of type "PRIVATE KEY", the same format decodePrivKey accepts.
// Marshaling errors are returned unchanged.
func encodePrivateKey(privKey ed25519.PrivateKey) ([]byte, error) {
	der, marshalErr := x509.MarshalPKCS8PrivateKey(privKey)
	if marshalErr != nil {
		return nil, marshalErr
	}
	block := &pem.Block{
		Type:  "PRIVATE KEY",
		Bytes: der,
	}
	return pem.EncodeToMemory(block), nil
}
208
209func LoadOrCreatePrivateKey(path string) (ed25519.PrivateKey, error) {
210 privData, err := os.ReadFile(path)
211 if os.IsNotExist(err) {
212 _, privKey, err := ed25519.GenerateKey(crand.Reader)
213 if err != nil {
214 return nil, err
215 }
216 b, err := encodePrivateKey(privKey)
David Crawshaw961cc9e2025-05-05 14:33:33 -0700217 if err != nil {
218 return nil, err
219 }
Earl Lee2e463fb2025-04-17 11:22:22 -0700220 if err := os.WriteFile(path, b, 0o600); err != nil {
221 return nil, err
222 }
223 return privKey, nil
224 } else if err != nil {
225 return nil, fmt.Errorf("read key failed: %w", err)
226 }
227 key, err := decodePrivKey(privData)
228 if err != nil {
229 return nil, fmt.Errorf("%s: %w", path, err)
230 }
231 return key, nil
232}
233
// Login authenticates this client with skaband at skabandAddr.
//
// It POSTs to skabandAddr+"/authclient" with headers carrying the hex public
// key of privKey, the session ID, a hex ed25519 signature of the session ID,
// and the requested model. The response body is copied to stdout whatever the
// status. On a 200 response that carries both X-API-URL and X-API-Key headers
// it returns the hex public key and those two header values; otherwise all
// string results are empty and err describes the failure.
func Login(stdout io.Writer, privKey ed25519.PrivateKey, skabandAddr, sessionID, model string) (pubKey, apiURL, apiKey string, err error) {
	// Prove possession of privKey by signing the session ID.
	signature := hex.EncodeToString(ed25519.Sign(privKey, []byte(sessionID)))

	req, err := http.NewRequest("POST", skabandAddr+"/authclient", nil)
	if err != nil {
		return "", "", "", err
	}
	pub := hex.EncodeToString(privKey.Public().(ed25519.PublicKey))
	req.Header.Set("Public-Key", pub)
	req.Header.Set("Session-ID", sessionID)
	req.Header.Set("Session-ID-Sig", signature)
	req.Header.Set("X-Model", model)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", "", "", fmt.Errorf("skaband login: %w", err)
	}
	defer resp.Body.Close()

	// Relay the server's message to the user before judging the outcome.
	if _, copyErr := io.Copy(stdout, resp.Body); copyErr != nil {
		return "", "", "", fmt.Errorf("skaband login: %w", copyErr)
	}

	switch gotURL, gotKey := resp.Header.Get("X-API-URL"), resp.Header.Get("X-API-Key"); {
	case resp.StatusCode != http.StatusOK:
		return "", "", "", fmt.Errorf("skaband login failed: %d", resp.StatusCode)
	case gotURL == "":
		return "", "", "", fmt.Errorf("skaband returned no api url")
	case gotKey == "":
		return "", "", "", fmt.Errorf("skaband returned no api key")
	default:
		return pub, gotURL, gotKey, nil
	}
}
268
// DefaultKeyPath returns $HOME/.cache/sketch/sketch.ed25519, making a
// best-effort attempt to create its parent directory (a creation error is
// ignored). It panics if the user's home directory cannot be determined.
func DefaultKeyPath() string {
	home, homeErr := os.UserHomeDir()
	if homeErr != nil {
		panic(homeErr)
	}
	keyPath := filepath.Join(home, ".cache", "sketch", "sketch.ed25519")
	os.MkdirAll(filepath.Dir(keyPath), 0o777)
	return keyPath
}
278
// LocalhostToDockerInternal rewrites skabandURL so that a loopback host
// ("localhost" in any letter case, "127.0.0.1", or the IPv6 loopback "::1")
// becomes "host.docker.internal". The port, if any, and every other URL
// component are kept. URLs with any other host are returned unchanged.
// An error is returned only when skabandURL cannot be parsed.
//
// DialAndServeLoop applies this when it detects it is running in Docker,
// where a loopback address would refer to the container itself.
func LocalhostToDockerInternal(skabandURL string) (string, error) {
	u, err := url.Parse(skabandURL)
	if err != nil {
		return "", fmt.Errorf("localhostToDockerInternal: %w", err)
	}
	switch hostname := u.Hostname(); {
	case strings.EqualFold(hostname, "localhost"), hostname == "127.0.0.1", hostname == "::1":
		// Loopback: rewrite below.
	default:
		return skabandURL, nil
	}
	// Read the port before overwriting u.Host.
	host := "host.docker.internal"
	if port := u.Port(); port != "" {
		host += ":" + port
	}
	u.Host = host
	return u.String(), nil
}
David Crawshaw0ead54d2025-05-16 13:58:36 -0700295
296// NewSessionID generates a new 10-byte random Session ID.
297func NewSessionID() string {
298 u1, u2 := rand.Uint64(), rand.Uint64N(1<<16)
299 s := crock32.Encode(u1) + crock32.Encode(uint64(u2))
300 if len(s) < 16 {
301 s += strings.Repeat("0", 16-len(s))
302 }
303 return s[0:4] + "-" + s[4:8] + "-" + s[8:12] + "-" + s[12:16]
304}