blob: 1f1e92b3189820b654957b50e633f0e7ed3277f4 [file] [log] [blame]
Earl Lee2e463fb2025-04-17 11:22:22 -07001package skabandclient
2
3import (
4 "bufio"
5 "context"
6 "crypto/ed25519"
7 crand "crypto/rand"
8 "crypto/tls"
9 "crypto/x509"
10 "encoding/hex"
11 "encoding/json"
12 "encoding/pem"
13 "errors"
14 "fmt"
15 "io"
16 "log/slog"
17 "net"
18 "net/http"
19 "net/url"
20 "os"
21 "path/filepath"
22 "strings"
23 "sync/atomic"
24 "time"
25
26 "golang.org/x/net/http2"
27)
28
29// DialAndServeLoop is a redial loop around DialAndServe.
30func DialAndServeLoop(ctx context.Context, skabandAddr, sessionID, clientPubKey string, srv http.Handler, connectFn func(connected bool)) {
31 if _, err := os.Stat("/.dockerenv"); err == nil { // inDocker
32 if addr, err := LocalhostToDockerInternal(skabandAddr); err == nil {
33 skabandAddr = addr
34 }
35 }
36
37 var skabandConnected atomic.Bool
38 skabandHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
39 if r.URL.Path == "/skabandinit" {
40 b, err := io.ReadAll(r.Body)
41 if err != nil {
42 fmt.Printf("skabandinit failed: %v\n", err)
43 return
44 }
45 m := map[string]string{}
46 if err := json.Unmarshal(b, &m); err != nil {
47 fmt.Printf("skabandinit failed: %v\n", err)
48 return
49 }
50 skabandConnected.Store(true)
51 if connectFn != nil {
52 connectFn(true)
53 }
54 return
55 }
56 srv.ServeHTTP(w, r)
57 })
58
59 var lastErrLog time.Time
60 for {
61 if err := DialAndServe(ctx, skabandAddr, sessionID, clientPubKey, skabandHandler); err != nil {
62 // NOTE: *just* backoff the logging. Backing off dialing
63 // is bad UX. Doing so saves negligble CPU and doing so
64 // without huring UX requires interrupting the backoff with
65 // wake-from-sleep and network-up events from the OS,
66 // which are a pain to plumb.
67 if time.Since(lastErrLog) > 1*time.Minute {
68 slog.DebugContext(ctx, "skaband connection failed", "err", err)
69 lastErrLog = time.Now()
70 }
71 }
72 if skabandConnected.CompareAndSwap(true, false) {
73 if connectFn != nil {
74 connectFn(false)
75 }
76 }
77 time.Sleep(200 * time.Millisecond)
78 }
79}
80
81func DialAndServe(ctx context.Context, hostURL, sessionID, clientPubKey string, h http.Handler) (err error) {
82 // Connect to the server.
83 var conn net.Conn
84 if strings.HasPrefix(hostURL, "https://") {
85 u, err := url.Parse(hostURL)
86 if err != nil {
87 return err
88 }
89 port := u.Port()
90 if port == "" {
91 port = "443"
92 }
93 dialer := tls.Dialer{}
94 conn, err = dialer.DialContext(ctx, "tcp4", u.Host+":"+port)
95 } else if strings.HasPrefix(hostURL, "http://") {
96 dialer := net.Dialer{}
97 conn, err = dialer.DialContext(ctx, "tcp4", strings.TrimPrefix(hostURL, "http://"))
98 } else {
99 return fmt.Errorf("skabandclient.Dial: bad url, needs to be http or https: %s", hostURL)
100 }
101 if err != nil {
102 return fmt.Errorf("skabandclient: %w", err)
103 }
104 defer conn.Close()
105
106 // "Upgrade" our connection, like a WebSocket does.
107 req, err := http.NewRequest("POST", hostURL+"/attach", nil)
108 if err != nil {
109 return fmt.Errorf("skabandclient.Dial: /attach: %w", err)
110 }
111 req.Header.Set("Connection", "Upgrade")
112 req.Header.Set("Upgrade", "ska")
113 req.Header.Set("Session-ID", sessionID)
114 req.Header.Set("Public-Key", clientPubKey)
115
116 if err := req.Write(conn); err != nil {
117 return fmt.Errorf("skabandclient.Dial: write upgrade request: %w", err)
118 }
119 reader := bufio.NewReader(conn)
120 resp, err := http.ReadResponse(reader, req)
121 if err != nil {
122 b, _ := io.ReadAll(resp.Body)
123 return fmt.Errorf("skabandclient.Dial: read upgrade response: %w: %s", err, b)
124 }
125 defer resp.Body.Close()
126 if resp.StatusCode != http.StatusSwitchingProtocols {
127 b, _ := io.ReadAll(resp.Body)
128 return fmt.Errorf("skabandclient.Dial: unexpected status code: %d: %s", resp.StatusCode, b)
129 }
130 if !strings.Contains(resp.Header.Get("Upgrade"), "ska") {
131 return errors.New("skabandclient.Dial: server did not upgrade to ska protocol")
132 }
133 if buf := reader.Buffered(); buf > 0 {
134 peek, _ := reader.Peek(buf)
135 return fmt.Errorf("skabandclient.Dial: buffered read after upgrade response: %d: %q", buf, string(peek))
136 }
137
138 // Send Magic.
139 const magic = "skaband\n"
140 if _, err := conn.Write([]byte(magic)); err != nil {
141 return fmt.Errorf("skabandclient.Dial: failed to send upgrade init message: %w", err)
142 }
143
144 // We have a TCP connection to the server and have been through the upgrade dance.
145 // Now we can run an HTTP server over that connection ("inverting" the HTTP flow).
146 server := &http2.Server{}
147 server.ServeConn(conn, &http2.ServeConnOpts{
148 Handler: h,
149 })
150
151 return nil
152}
153
// decodePrivKey parses a PEM-encoded PKCS#8 ed25519 private key.
//
// It returns an error if privData does not start with a "PRIVATE KEY" PEM
// block, if the block is not valid PKCS#8, or if the key it holds is not an
// ed25519 key.
func decodePrivKey(privData []byte) (ed25519.PrivateKey, error) {
	privBlock, _ := pem.Decode(privData)
	if privBlock == nil || privBlock.Type != "PRIVATE KEY" {
		return nil, fmt.Errorf("no valid private key block found")
	}
	parsedPriv, err := x509.ParsePKCS8PrivateKey(privBlock.Bytes)
	if err != nil {
		return nil, err
	}
	// PKCS#8 can also carry RSA, ECDSA, and other key types, so check the
	// concrete type instead of asserting it and panicking on a mismatch.
	key, ok := parsedPriv.(ed25519.PrivateKey)
	if !ok {
		return nil, fmt.Errorf("unexpected private key type %T, want ed25519", parsedPriv)
	}
	return key, nil
}
165
// encodePrivateKey serializes privKey as PKCS#8 and wraps the result in a
// single "PRIVATE KEY" PEM block.
func encodePrivateKey(privKey ed25519.PrivateKey) ([]byte, error) {
	der, marshalErr := x509.MarshalPKCS8PrivateKey(privKey)
	if marshalErr != nil {
		return nil, marshalErr
	}
	block := pem.Block{
		Type:  "PRIVATE KEY",
		Bytes: der,
	}
	return pem.EncodeToMemory(&block), nil
}
173
174func LoadOrCreatePrivateKey(path string) (ed25519.PrivateKey, error) {
175 privData, err := os.ReadFile(path)
176 if os.IsNotExist(err) {
177 _, privKey, err := ed25519.GenerateKey(crand.Reader)
178 if err != nil {
179 return nil, err
180 }
181 b, err := encodePrivateKey(privKey)
182 if err := os.WriteFile(path, b, 0o600); err != nil {
183 return nil, err
184 }
185 return privKey, nil
186 } else if err != nil {
187 return nil, fmt.Errorf("read key failed: %w", err)
188 }
189 key, err := decodePrivKey(privData)
190 if err != nil {
191 return nil, fmt.Errorf("%s: %w", path, err)
192 }
193 return key, nil
194}
195
// Login authenticates this client against the skaband server at skabandAddr.
//
// It sends a POST to skabandAddr+"/authclient" carrying three headers:
// Public-Key (the hex-encoded public half of privKey), Session-ID, and
// Session-ID-Sig (the hex-encoded ed25519 signature of sessionID). The
// response body is copied to stdout regardless of the status code.
//
// On success it returns the hex-encoded public key together with the values
// of the X-API-URL and X-API-Key response headers. It returns an error, and
// empty strings, if the request cannot be built or sent, the body cannot be
// copied, the status is not 200, or either header is missing.
func Login(stdout io.Writer, privKey ed25519.PrivateKey, skabandAddr, sessionID string) (pubKey, apiURL, apiKey string, err error) {
	signature := ed25519.Sign(privKey, []byte(sessionID))

	authReq, err := http.NewRequest(http.MethodPost, skabandAddr+"/authclient", nil)
	if err != nil {
		return "", "", "", err
	}
	pubKey = hex.EncodeToString(privKey.Public().(ed25519.PublicKey))
	authReq.Header.Set("Public-Key", pubKey)
	authReq.Header.Set("Session-ID", sessionID)
	authReq.Header.Set("Session-ID-Sig", hex.EncodeToString(signature))

	authResp, err := http.DefaultClient.Do(authReq)
	if err != nil {
		return "", "", "", fmt.Errorf("skaband login: %w", err)
	}
	defer authResp.Body.Close()

	// The body is forwarded to the caller's writer before the status is
	// inspected, so whatever the server sent is shown even on failure.
	if _, copyErr := io.Copy(stdout, authResp.Body); copyErr != nil {
		return "", "", "", fmt.Errorf("skaband login: %w", copyErr)
	}

	apiURL = authResp.Header.Get("X-API-URL")
	apiKey = authResp.Header.Get("X-API-Key")
	switch {
	case authResp.StatusCode != http.StatusOK:
		return "", "", "", fmt.Errorf("skaband login failed: %d", authResp.StatusCode)
	case apiURL == "":
		return "", "", "", fmt.Errorf("skaband returned no api url")
	case apiKey == "":
		return "", "", "", fmt.Errorf("skaband returned no api key")
	}
	return pubKey, apiURL, apiKey, nil
}
229
// DefaultKeyPath returns the default private key location,
// <home>/.cache/sketch/sketch.ed25519, where <home> is the directory reported
// by os.UserHomeDir. It attempts to create the containing directory.
//
// It panics if the home directory cannot be determined.
func DefaultKeyPath() string {
	homeDir, err := os.UserHomeDir()
	if err != nil {
		panic(err)
	}
	cacheDir := filepath.Join(homeDir, ".cache", "sketch")
	// The directory is named as the home of a private key file, so create it
	// accessible to the owner only. Creation is best effort: this function
	// has no error return, and a missing directory will surface as an error
	// as soon as the caller tries to read or write the returned path.
	_ = os.MkdirAll(cacheDir, 0o700)
	return filepath.Join(cacheDir, "sketch.ed25519")
}
239
// LocalhostToDockerInternal rewrites a URL whose hostname is exactly
// "localhost" or "127.0.0.1" so that it points at "host.docker.internal"
// instead, keeping the port (if any) and the rest of the URL.
//
// URLs with any other hostname are returned unchanged, exactly as given. An
// error is returned if skabandURL cannot be parsed.
func LocalhostToDockerInternal(skabandURL string) (string, error) {
	parsed, err := url.Parse(skabandURL)
	if err != nil {
		return "", fmt.Errorf("localhostToDockerInternal: %w", err)
	}

	// Anything that is not a loopback name passes through verbatim.
	if name := parsed.Hostname(); name != "localhost" && name != "127.0.0.1" {
		return skabandURL, nil
	}

	const dockerHost = "host.docker.internal"
	if port := parsed.Port(); port == "" {
		parsed.Host = dockerHost
	} else {
		parsed.Host = dockerHost + ":" + port
	}
	return parsed.String(), nil
}