blob: cea4fbe746ce6bad2a30fef413a28c07810a1fa3 [file] [log] [blame]
Earl Lee2e463fb2025-04-17 11:22:22 -07001package skabandclient
2
3import (
4 "bufio"
5 "context"
6 "crypto/ed25519"
7 crand "crypto/rand"
8 "crypto/tls"
9 "crypto/x509"
10 "encoding/hex"
11 "encoding/json"
12 "encoding/pem"
13 "errors"
14 "fmt"
15 "io"
16 "log/slog"
17 "net"
18 "net/http"
19 "net/url"
20 "os"
21 "path/filepath"
22 "strings"
23 "sync/atomic"
24 "time"
25
26 "golang.org/x/net/http2"
27)
28
29// DialAndServeLoop is a redial loop around DialAndServe.
30func DialAndServeLoop(ctx context.Context, skabandAddr, sessionID, clientPubKey string, srv http.Handler, connectFn func(connected bool)) {
31 if _, err := os.Stat("/.dockerenv"); err == nil { // inDocker
32 if addr, err := LocalhostToDockerInternal(skabandAddr); err == nil {
33 skabandAddr = addr
34 }
35 }
36
37 var skabandConnected atomic.Bool
38 skabandHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
39 if r.URL.Path == "/skabandinit" {
40 b, err := io.ReadAll(r.Body)
41 if err != nil {
42 fmt.Printf("skabandinit failed: %v\n", err)
43 return
44 }
45 m := map[string]string{}
46 if err := json.Unmarshal(b, &m); err != nil {
47 fmt.Printf("skabandinit failed: %v\n", err)
48 return
49 }
50 skabandConnected.Store(true)
51 if connectFn != nil {
52 connectFn(true)
53 }
54 return
55 }
56 srv.ServeHTTP(w, r)
57 })
58
59 var lastErrLog time.Time
60 for {
61 if err := DialAndServe(ctx, skabandAddr, sessionID, clientPubKey, skabandHandler); err != nil {
62 // NOTE: *just* backoff the logging. Backing off dialing
63 // is bad UX. Doing so saves negligble CPU and doing so
64 // without huring UX requires interrupting the backoff with
65 // wake-from-sleep and network-up events from the OS,
66 // which are a pain to plumb.
67 if time.Since(lastErrLog) > 1*time.Minute {
68 slog.DebugContext(ctx, "skaband connection failed", "err", err)
69 lastErrLog = time.Now()
70 }
71 }
72 if skabandConnected.CompareAndSwap(true, false) {
73 if connectFn != nil {
74 connectFn(false)
75 }
76 }
77 time.Sleep(200 * time.Millisecond)
78 }
79}
80
81func DialAndServe(ctx context.Context, hostURL, sessionID, clientPubKey string, h http.Handler) (err error) {
82 // Connect to the server.
83 var conn net.Conn
84 if strings.HasPrefix(hostURL, "https://") {
85 u, err := url.Parse(hostURL)
86 if err != nil {
87 return err
88 }
89 port := u.Port()
90 if port == "" {
91 port = "443"
92 }
93 dialer := tls.Dialer{}
94 conn, err = dialer.DialContext(ctx, "tcp4", u.Host+":"+port)
95 } else if strings.HasPrefix(hostURL, "http://") {
96 dialer := net.Dialer{}
97 conn, err = dialer.DialContext(ctx, "tcp4", strings.TrimPrefix(hostURL, "http://"))
98 } else {
99 return fmt.Errorf("skabandclient.Dial: bad url, needs to be http or https: %s", hostURL)
100 }
101 if err != nil {
102 return fmt.Errorf("skabandclient: %w", err)
103 }
Philip Zeyligerfe3e9f72025-04-24 09:02:05 -0700104 if conn == nil {
105 return fmt.Errorf("skabandclient: nil connection")
106 }
Earl Lee2e463fb2025-04-17 11:22:22 -0700107 defer conn.Close()
108
109 // "Upgrade" our connection, like a WebSocket does.
110 req, err := http.NewRequest("POST", hostURL+"/attach", nil)
111 if err != nil {
112 return fmt.Errorf("skabandclient.Dial: /attach: %w", err)
113 }
114 req.Header.Set("Connection", "Upgrade")
115 req.Header.Set("Upgrade", "ska")
116 req.Header.Set("Session-ID", sessionID)
117 req.Header.Set("Public-Key", clientPubKey)
118
119 if err := req.Write(conn); err != nil {
120 return fmt.Errorf("skabandclient.Dial: write upgrade request: %w", err)
121 }
122 reader := bufio.NewReader(conn)
123 resp, err := http.ReadResponse(reader, req)
124 if err != nil {
Philip Zeyligerfe3e9f72025-04-24 09:02:05 -0700125 if resp != nil {
126 b, _ := io.ReadAll(resp.Body)
127 return fmt.Errorf("skabandclient.Dial: read upgrade response: %w: %s", err, b)
128 } else {
129 return fmt.Errorf("skabandclient.Dial: read upgrade response: %w", err)
130 }
Earl Lee2e463fb2025-04-17 11:22:22 -0700131 }
132 defer resp.Body.Close()
133 if resp.StatusCode != http.StatusSwitchingProtocols {
134 b, _ := io.ReadAll(resp.Body)
135 return fmt.Errorf("skabandclient.Dial: unexpected status code: %d: %s", resp.StatusCode, b)
136 }
137 if !strings.Contains(resp.Header.Get("Upgrade"), "ska") {
138 return errors.New("skabandclient.Dial: server did not upgrade to ska protocol")
139 }
140 if buf := reader.Buffered(); buf > 0 {
141 peek, _ := reader.Peek(buf)
142 return fmt.Errorf("skabandclient.Dial: buffered read after upgrade response: %d: %q", buf, string(peek))
143 }
144
145 // Send Magic.
146 const magic = "skaband\n"
147 if _, err := conn.Write([]byte(magic)); err != nil {
148 return fmt.Errorf("skabandclient.Dial: failed to send upgrade init message: %w", err)
149 }
150
151 // We have a TCP connection to the server and have been through the upgrade dance.
152 // Now we can run an HTTP server over that connection ("inverting" the HTTP flow).
153 server := &http2.Server{}
154 server.ServeConn(conn, &http2.ServeConnOpts{
155 Handler: h,
156 })
157
158 return nil
159}
160
// decodePrivKey parses the first PEM block in privData as a PKCS #8 encoded
// ed25519 private key.
//
// It returns an error if privData contains no PEM block, the block's type is
// not "PRIVATE KEY", the PKCS #8 payload cannot be parsed, or the parsed key
// is not an ed25519 key.
func decodePrivKey(privData []byte) (ed25519.PrivateKey, error) {
	privBlock, _ := pem.Decode(privData)
	if privBlock == nil || privBlock.Type != "PRIVATE KEY" {
		return nil, fmt.Errorf("no valid private key block found")
	}
	parsedPriv, err := x509.ParsePKCS8PrivateKey(privBlock.Bytes)
	if err != nil {
		return nil, err
	}
	// ParsePKCS8PrivateKey also yields RSA, ECDSA and X25519 keys. Reject them
	// with an error rather than panicking on an unchecked type assertion.
	key, ok := parsedPriv.(ed25519.PrivateKey)
	if !ok {
		return nil, fmt.Errorf("private key is %T, want ed25519", parsedPriv)
	}
	return key, nil
}
172
// encodePrivateKey marshals privKey to PKCS #8 DER and wraps the result in a
// PEM block of type "PRIVATE KEY", the format decodePrivKey reads back.
func encodePrivateKey(privKey ed25519.PrivateKey) ([]byte, error) {
	der, err := x509.MarshalPKCS8PrivateKey(privKey)
	if err != nil {
		return nil, err
	}
	block := pem.Block{
		Type:  "PRIVATE KEY",
		Bytes: der,
	}
	return pem.EncodeToMemory(&block), nil
}
180
181func LoadOrCreatePrivateKey(path string) (ed25519.PrivateKey, error) {
182 privData, err := os.ReadFile(path)
183 if os.IsNotExist(err) {
184 _, privKey, err := ed25519.GenerateKey(crand.Reader)
185 if err != nil {
186 return nil, err
187 }
188 b, err := encodePrivateKey(privKey)
David Crawshaw961cc9e2025-05-05 14:33:33 -0700189 if err != nil {
190 return nil, err
191 }
Earl Lee2e463fb2025-04-17 11:22:22 -0700192 if err := os.WriteFile(path, b, 0o600); err != nil {
193 return nil, err
194 }
195 return privKey, nil
196 } else if err != nil {
197 return nil, fmt.Errorf("read key failed: %w", err)
198 }
199 key, err := decodePrivKey(privData)
200 if err != nil {
201 return nil, fmt.Errorf("%s: %w", path, err)
202 }
203 return key, nil
204}
205
// Login authenticates this client with skaband.
//
// It signs sessionID with privKey and POSTs to skabandAddr+"/authclient".
// The request carries these headers:
//   - Public-Key: the hex-encoded public key.
//   - Session-ID: the session ID.
//   - Session-ID-Sig: the hex-encoded signature.
//   - X-Model: the requested model.
//
// The response body is copied to stdout before the status is checked, so it
// is shown even for non-200 responses. Login succeeds on a 200 response that
// has both X-API-URL and X-API-Key headers. It then returns the hex-encoded
// public key together with those two header values. On failure it returns
// empty strings and an error.
func Login(stdout io.Writer, privKey ed25519.PrivateKey, skabandAddr, sessionID, model string) (pubKey, apiURL, apiKey string, err error) {
	signature := ed25519.Sign(privKey, []byte(sessionID))

	req, reqErr := http.NewRequest("POST", skabandAddr+"/authclient", nil)
	if reqErr != nil {
		return "", "", "", reqErr
	}
	encodedPub := hex.EncodeToString(privKey.Public().(ed25519.PublicKey))
	hdr := req.Header
	hdr.Set("Public-Key", encodedPub)
	hdr.Set("Session-ID", sessionID)
	hdr.Set("Session-ID-Sig", hex.EncodeToString(signature))
	hdr.Set("X-Model", model)

	resp, doErr := http.DefaultClient.Do(req)
	if doErr != nil {
		return "", "", "", fmt.Errorf("skaband login: %w", doErr)
	}
	defer resp.Body.Close()
	gotURL, gotKey := resp.Header.Get("X-API-URL"), resp.Header.Get("X-API-Key")

	if _, copyErr := io.Copy(stdout, resp.Body); copyErr != nil {
		return "", "", "", fmt.Errorf("skaband login: %w", copyErr)
	}
	switch {
	case resp.StatusCode != http.StatusOK:
		return "", "", "", fmt.Errorf("skaband login failed: %d", resp.StatusCode)
	case gotURL == "":
		return "", "", "", fmt.Errorf("skaband returned no api url")
	case gotKey == "":
		return "", "", "", fmt.Errorf("skaband returned no api key")
	}
	return encodedPub, gotURL, gotKey, nil
}
240
// DefaultKeyPath returns the default location of the client's ed25519 private
// key, <home>/.cache/sketch/sketch.ed25519, where <home> is the user's home
// directory. It tries to create the containing directory; an error from that
// attempt is ignored. It panics if the home directory cannot be determined.
func DefaultKeyPath() string {
	home, err := os.UserHomeDir()
	if err != nil {
		panic(err)
	}
	dir := filepath.Join(home, ".cache", "sketch")
	_ = os.MkdirAll(dir, 0o777) // error ignored, as in the original best-effort create
	return filepath.Join(dir, "sketch.ed25519")
}
250
// LocalhostToDockerInternal rewrites a URL whose host is "localhost" or
// "127.0.0.1" to use "host.docker.internal" instead, keeping any port and the
// rest of the URL. Other URLs are returned unchanged. An error is returned
// only if skabandURL cannot be parsed.
func LocalhostToDockerInternal(skabandURL string) (string, error) {
	parsed, err := url.Parse(skabandURL)
	if err != nil {
		return "", fmt.Errorf("localhostToDockerInternal: %w", err)
	}
	if h := parsed.Hostname(); h != "localhost" && h != "127.0.0.1" {
		return skabandURL, nil
	}
	newHost := "host.docker.internal"
	if p := parsed.Port(); p != "" {
		newHost = newHost + ":" + p
	}
	parsed.Host = newHost
	return parsed.String(), nil
}