blob: 39efb522e66f47f1845967c814dddda20ea709ab [file] [log] [blame]
Earl Lee2e463fb2025-04-17 11:22:22 -07001package skabandclient
2
3import (
4 "bufio"
5 "context"
6 "crypto/ed25519"
7 crand "crypto/rand"
8 "crypto/tls"
9 "crypto/x509"
10 "encoding/hex"
11 "encoding/json"
12 "encoding/pem"
13 "errors"
14 "fmt"
15 "io"
16 "log/slog"
David Crawshaw0ead54d2025-05-16 13:58:36 -070017 "math/rand/v2"
Earl Lee2e463fb2025-04-17 11:22:22 -070018 "net"
19 "net/http"
20 "net/url"
21 "os"
22 "path/filepath"
23 "strings"
24 "sync/atomic"
25 "time"
26
David Crawshaw0ead54d2025-05-16 13:58:36 -070027 "github.com/richardlehane/crock32"
Earl Lee2e463fb2025-04-17 11:22:22 -070028 "golang.org/x/net/http2"
29)
30
31// DialAndServeLoop is a redial loop around DialAndServe.
32func DialAndServeLoop(ctx context.Context, skabandAddr, sessionID, clientPubKey string, srv http.Handler, connectFn func(connected bool)) {
33 if _, err := os.Stat("/.dockerenv"); err == nil { // inDocker
34 if addr, err := LocalhostToDockerInternal(skabandAddr); err == nil {
35 skabandAddr = addr
36 }
37 }
38
39 var skabandConnected atomic.Bool
40 skabandHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
41 if r.URL.Path == "/skabandinit" {
42 b, err := io.ReadAll(r.Body)
43 if err != nil {
44 fmt.Printf("skabandinit failed: %v\n", err)
45 return
46 }
47 m := map[string]string{}
48 if err := json.Unmarshal(b, &m); err != nil {
49 fmt.Printf("skabandinit failed: %v\n", err)
50 return
51 }
52 skabandConnected.Store(true)
53 if connectFn != nil {
54 connectFn(true)
55 }
56 return
57 }
58 srv.ServeHTTP(w, r)
59 })
60
61 var lastErrLog time.Time
62 for {
63 if err := DialAndServe(ctx, skabandAddr, sessionID, clientPubKey, skabandHandler); err != nil {
64 // NOTE: *just* backoff the logging. Backing off dialing
65 // is bad UX. Doing so saves negligble CPU and doing so
66 // without huring UX requires interrupting the backoff with
67 // wake-from-sleep and network-up events from the OS,
68 // which are a pain to plumb.
69 if time.Since(lastErrLog) > 1*time.Minute {
70 slog.DebugContext(ctx, "skaband connection failed", "err", err)
71 lastErrLog = time.Now()
72 }
73 }
74 if skabandConnected.CompareAndSwap(true, false) {
75 if connectFn != nil {
76 connectFn(false)
77 }
78 }
79 time.Sleep(200 * time.Millisecond)
80 }
81}
82
83func DialAndServe(ctx context.Context, hostURL, sessionID, clientPubKey string, h http.Handler) (err error) {
84 // Connect to the server.
85 var conn net.Conn
86 if strings.HasPrefix(hostURL, "https://") {
87 u, err := url.Parse(hostURL)
88 if err != nil {
89 return err
90 }
91 port := u.Port()
92 if port == "" {
93 port = "443"
94 }
95 dialer := tls.Dialer{}
96 conn, err = dialer.DialContext(ctx, "tcp4", u.Host+":"+port)
97 } else if strings.HasPrefix(hostURL, "http://") {
98 dialer := net.Dialer{}
99 conn, err = dialer.DialContext(ctx, "tcp4", strings.TrimPrefix(hostURL, "http://"))
100 } else {
101 return fmt.Errorf("skabandclient.Dial: bad url, needs to be http or https: %s", hostURL)
102 }
103 if err != nil {
104 return fmt.Errorf("skabandclient: %w", err)
105 }
Philip Zeyligerfe3e9f72025-04-24 09:02:05 -0700106 if conn == nil {
107 return fmt.Errorf("skabandclient: nil connection")
108 }
Earl Lee2e463fb2025-04-17 11:22:22 -0700109 defer conn.Close()
110
111 // "Upgrade" our connection, like a WebSocket does.
112 req, err := http.NewRequest("POST", hostURL+"/attach", nil)
113 if err != nil {
114 return fmt.Errorf("skabandclient.Dial: /attach: %w", err)
115 }
116 req.Header.Set("Connection", "Upgrade")
117 req.Header.Set("Upgrade", "ska")
118 req.Header.Set("Session-ID", sessionID)
119 req.Header.Set("Public-Key", clientPubKey)
120
121 if err := req.Write(conn); err != nil {
122 return fmt.Errorf("skabandclient.Dial: write upgrade request: %w", err)
123 }
124 reader := bufio.NewReader(conn)
125 resp, err := http.ReadResponse(reader, req)
126 if err != nil {
Philip Zeyligerfe3e9f72025-04-24 09:02:05 -0700127 if resp != nil {
128 b, _ := io.ReadAll(resp.Body)
129 return fmt.Errorf("skabandclient.Dial: read upgrade response: %w: %s", err, b)
130 } else {
131 return fmt.Errorf("skabandclient.Dial: read upgrade response: %w", err)
132 }
Earl Lee2e463fb2025-04-17 11:22:22 -0700133 }
134 defer resp.Body.Close()
135 if resp.StatusCode != http.StatusSwitchingProtocols {
136 b, _ := io.ReadAll(resp.Body)
137 return fmt.Errorf("skabandclient.Dial: unexpected status code: %d: %s", resp.StatusCode, b)
138 }
139 if !strings.Contains(resp.Header.Get("Upgrade"), "ska") {
140 return errors.New("skabandclient.Dial: server did not upgrade to ska protocol")
141 }
142 if buf := reader.Buffered(); buf > 0 {
143 peek, _ := reader.Peek(buf)
144 return fmt.Errorf("skabandclient.Dial: buffered read after upgrade response: %d: %q", buf, string(peek))
145 }
146
147 // Send Magic.
148 const magic = "skaband\n"
149 if _, err := conn.Write([]byte(magic)); err != nil {
150 return fmt.Errorf("skabandclient.Dial: failed to send upgrade init message: %w", err)
151 }
152
153 // We have a TCP connection to the server and have been through the upgrade dance.
154 // Now we can run an HTTP server over that connection ("inverting" the HTTP flow).
155 server := &http2.Server{}
156 server.ServeConn(conn, &http2.ServeConnOpts{
157 Handler: h,
158 })
159
160 return nil
161}
162
// decodePrivKey parses a PEM block of type "PRIVATE KEY" holding a PKCS #8
// DER key (the format encodePrivateKey writes) and returns it as an Ed25519
// private key.
//
// It returns an error if privData has no such PEM block, if the block is not
// valid PKCS #8, or if the key inside is not an Ed25519 key.
func decodePrivKey(privData []byte) (ed25519.PrivateKey, error) {
	privBlock, _ := pem.Decode(privData)
	if privBlock == nil || privBlock.Type != "PRIVATE KEY" {
		return nil, fmt.Errorf("no valid private key block found")
	}
	parsedPriv, err := x509.ParsePKCS8PrivateKey(privBlock.Bytes)
	if err != nil {
		return nil, err
	}
	// PKCS #8 can also carry RSA, ECDSA and ECDH keys; an unchecked type
	// assertion would panic on those instead of reporting an error.
	key, ok := parsedPriv.(ed25519.PrivateKey)
	if !ok {
		return nil, fmt.Errorf("private key is %T, want ed25519.PrivateKey", parsedPriv)
	}
	return key, nil
}
174
// encodePrivateKey serializes privKey as PKCS #8 DER wrapped in a PEM block
// of type "PRIVATE KEY", the format decodePrivKey reads back.
func encodePrivateKey(privKey ed25519.PrivateKey) ([]byte, error) {
	der, err := x509.MarshalPKCS8PrivateKey(privKey)
	if err != nil {
		return nil, err
	}
	block := pem.Block{
		Type:  "PRIVATE KEY",
		Bytes: der,
	}
	return pem.EncodeToMemory(&block), nil
}
182
183func LoadOrCreatePrivateKey(path string) (ed25519.PrivateKey, error) {
184 privData, err := os.ReadFile(path)
185 if os.IsNotExist(err) {
186 _, privKey, err := ed25519.GenerateKey(crand.Reader)
187 if err != nil {
188 return nil, err
189 }
190 b, err := encodePrivateKey(privKey)
David Crawshaw961cc9e2025-05-05 14:33:33 -0700191 if err != nil {
192 return nil, err
193 }
Earl Lee2e463fb2025-04-17 11:22:22 -0700194 if err := os.WriteFile(path, b, 0o600); err != nil {
195 return nil, err
196 }
197 return privKey, nil
198 } else if err != nil {
199 return nil, fmt.Errorf("read key failed: %w", err)
200 }
201 key, err := decodePrivKey(privData)
202 if err != nil {
203 return nil, fmt.Errorf("%s: %w", path, err)
204 }
205 return key, nil
206}
207
// Login authenticates with the skaband server at skabandAddr using privKey.
//
// It signs sessionID with privKey and POSTs to skabandAddr+"/authclient",
// sending these headers:
//   - Public-Key: the hex-encoded public key
//   - Session-ID: sessionID
//   - Session-ID-Sig: the hex-encoded signature
//   - X-Model: model
//
// The response body is copied to stdout regardless of status. On a 200
// response carrying non-empty X-API-URL and X-API-Key headers, it returns the
// hex-encoded public key and those two header values. Otherwise it returns
// empty strings and an error. The request uses http.DefaultClient and so has
// no timeout.
func Login(stdout io.Writer, privKey ed25519.PrivateKey, skabandAddr, sessionID, model string) (pubKey, apiURL, apiKey string, err error) {
	fail := func(e error) (string, string, string, error) {
		return "", "", "", e
	}

	sigHex := hex.EncodeToString(ed25519.Sign(privKey, []byte(sessionID)))
	req, err := http.NewRequest(http.MethodPost, skabandAddr+"/authclient", nil)
	if err != nil {
		return fail(err)
	}
	pubKey = hex.EncodeToString(privKey.Public().(ed25519.PublicKey))
	req.Header.Set("Public-Key", pubKey)
	req.Header.Set("Session-ID", sessionID)
	req.Header.Set("Session-ID-Sig", sigHex)
	req.Header.Set("X-Model", model)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return fail(fmt.Errorf("skaband login: %w", err))
	}
	defer resp.Body.Close()

	// Relay whatever the server says to the user before judging the status.
	if _, err := io.Copy(stdout, resp.Body); err != nil {
		return fail(fmt.Errorf("skaband login: %w", err))
	}
	apiURL, apiKey = resp.Header.Get("X-API-URL"), resp.Header.Get("X-API-Key")
	switch {
	case resp.StatusCode != http.StatusOK:
		return fail(fmt.Errorf("skaband login failed: %d", resp.StatusCode))
	case apiURL == "":
		return fail(fmt.Errorf("skaband returned no api url"))
	case apiKey == "":
		return fail(fmt.Errorf("skaband returned no api key"))
	}
	return pubKey, apiURL, apiKey, nil
}
242
// DefaultKeyPath returns the default location of the client's private key:
// sketch.ed25519 inside .cache/sketch under the user's home directory (as
// reported by os.UserHomeDir).
//
// As a side effect it tries to create the containing directory (mode 0o777
// before umask). It panics if the home directory cannot be determined.
func DefaultKeyPath() string {
	home, err := os.UserHomeDir()
	if err != nil {
		panic(err)
	}
	keyPath := filepath.Join(home, ".cache", "sketch", "sketch.ed25519")
	// Deliberately best effort: if the directory can't be created, a later
	// read or write of the key file at this path reports the failure.
	_ = os.MkdirAll(filepath.Dir(keyPath), 0o777)
	return keyPath
}
252
// LocalhostToDockerInternal rewrites skabandURL so that it is reachable from
// inside a Docker container.
//
// If the URL's host is a loopback name or address, the host is replaced by
// host.docker.internal, keeping any port and the rest of the URL. Loopback
// here means "localhost" in any letter case, any 127.0.0.0/8 address, or ::1.
// Other URLs are returned unchanged. An error is returned only if skabandURL
// cannot be parsed.
func LocalhostToDockerInternal(skabandURL string) (string, error) {
	u, err := url.Parse(skabandURL)
	if err != nil {
		return "", fmt.Errorf("localhostToDockerInternal: %w", err)
	}
	hostname := u.Hostname()
	ip := net.ParseIP(hostname)
	// Inside a container every loopback address is the container itself, so
	// any loopback spelling must have meant the Docker host.
	isLoopback := strings.EqualFold(hostname, "localhost") || (ip != nil && ip.IsLoopback())
	if !isLoopback {
		return skabandURL, nil
	}
	host := "host.docker.internal"
	if port := u.Port(); port != "" {
		host = net.JoinHostPort(host, port)
	}
	u.Host = host
	return u.String(), nil
}
David Crawshaw0ead54d2025-05-16 13:58:36 -0700269
270// NewSessionID generates a new 10-byte random Session ID.
271func NewSessionID() string {
272 u1, u2 := rand.Uint64(), rand.Uint64N(1<<16)
273 s := crock32.Encode(u1) + crock32.Encode(uint64(u2))
274 if len(s) < 16 {
275 s += strings.Repeat("0", 16-len(s))
276 }
277 return s[0:4] + "-" + s[4:8] + "-" + s[8:12] + "-" + s[12:16]
278}