blob: d1c366a9bd955a978cc4b1a639b363554a6c1694 [file] [log] [blame]
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -07001package ant
2
3import (
4 "bytes"
5 "cmp"
6 "context"
7 "encoding/json"
8 "fmt"
9 "io"
10 "log/slog"
11 "math/rand/v2"
12 "net/http"
13 "strings"
14 "testing"
15 "time"
16
17 "sketch.dev/llm"
18)
19
// Defaults applied by Service when the corresponding field is left zero.
const (
	DefaultModel = Claude37Sonnet
	// DefaultMaxTokens is the default cap on output tokens.
	// See https://docs.anthropic.com/en/docs/about-claude/models/all-models for
	// current maximums. There's currently a flag to enable 128k output (output-128k-2025-02-19);
	// Do turns it on automatically when a response is truncated at max_tokens.
	DefaultMaxTokens = 8192
	DefaultURL       = "https://api.anthropic.com/v1/messages"
)
27
// Claude model identifiers (dated snapshot names accepted by the API).
const (
	Claude35Sonnet = "claude-3-5-sonnet-20241022"
	Claude35Haiku  = "claude-3-5-haiku-20241022"
	Claude37Sonnet = "claude-3-7-sonnet-20250219"
)
33
// Service provides Claude completions.
// Fields should not be altered concurrently with calling any method on Service.
// The zero value is not usable: APIKey must be set.
type Service struct {
	HTTPC     *http.Client // defaults to http.DefaultClient if nil
	URL       string       // defaults to DefaultURL if empty
	APIKey    string       // must be non-empty
	Model     string       // defaults to DefaultModel if empty
	MaxTokens int          // defaults to DefaultMaxTokens if zero
}
43
// Compile-time check that *Service implements llm.Service.
var _ llm.Service = (*Service)(nil)
45
// content is a single content block in a message or response, mirroring
// the Anthropic Messages API content schema. Which fields are meaningful
// depends on Type ("text", "thinking", "redacted_thinking", "tool_use",
// or "tool_result" — see fromLLMContentType).
type content struct {
	// TODO: image support?
	// https://docs.anthropic.com/en/api/messages
	ID   string `json:"id,omitempty"`
	Type string `json:"type,omitempty"`
	Text string `json:"text,omitempty"`

	// for thinking
	Thinking  string `json:"thinking,omitempty"`
	Data      string `json:"data,omitempty"`      // for redacted_thinking
	Signature string `json:"signature,omitempty"` // for thinking

	// for tool_use
	ToolName  string          `json:"name,omitempty"`
	ToolInput json.RawMessage `json:"input,omitempty"`

	// for tool_result
	ToolUseID  string `json:"tool_use_id,omitempty"`
	ToolError  bool   `json:"is_error,omitempty"`
	ToolResult string `json:"content,omitempty"`

	// timing information for tool_result; not sent to Claude
	StartTime *time.Time `json:"-"`
	EndTime   *time.Time `json:"-"`

	// CacheControl, when non-nil, marks this block as a prompt-cache
	// breakpoint (see fromLLMCache).
	CacheControl json.RawMessage `json:"cache_control,omitempty"`
}
73
// message represents a message in the conversation.
type message struct {
	Role    string    `json:"role"` // "assistant" or "user" (see fromLLMRole)
	Content []content `json:"content"`
	ToolUse *toolUse  `json:"tool_use,omitempty"` // use to control whether/which tool to use
}
80
// toolUse represents a tool use in the message content.
type toolUse struct {
	ID   string `json:"id"`
	Name string `json:"name"` // name of the tool being invoked
}
86
// tool represents a tool available to Claude.
type tool struct {
	Name string `json:"name"`
	// Type is used by the text editor tool; see
	// https://docs.anthropic.com/en/docs/build-with-claude/tool-use/text-editor-tool
	Type        string          `json:"type,omitempty"`
	Description string          `json:"description,omitempty"`
	InputSchema json.RawMessage `json:"input_schema,omitempty"` // JSON Schema for the tool's input
}
96
// usage represents the billing and rate-limit usage.
type usage struct {
	InputTokens              uint64 `json:"input_tokens"`
	CacheCreationInputTokens uint64 `json:"cache_creation_input_tokens"`
	CacheReadInputTokens     uint64 `json:"cache_read_input_tokens"`
	OutputTokens             uint64 `json:"output_tokens"`
	// CostUSD is computed locally from modelCost (see TotalDollars),
	// not decoded from the API response.
	CostUSD float64 `json:"cost_usd"`
}
105
106func (u *usage) Add(other usage) {
107 u.InputTokens += other.InputTokens
108 u.CacheCreationInputTokens += other.CacheCreationInputTokens
109 u.CacheReadInputTokens += other.CacheReadInputTokens
110 u.OutputTokens += other.OutputTokens
111 u.CostUSD += other.CostUSD
112}
113
// errorResponse is the error payload the API returns for failed requests.
// NOTE(review): appears unused so far — see the TODO in Do about 400 handling.
type errorResponse struct {
	Type    string `json:"type"`
	Message string `json:"message"`
}
118
// response represents the response from the message API.
type response struct {
	ID      string    `json:"id"`
	Type    string    `json:"type"`
	Role    string    `json:"role"`
	Model   string    `json:"model"`
	Content []content `json:"content"`
	// StopReason is one of "stop_sequence", "max_tokens", "end_turn",
	// or "tool_use" (see toLLMStopReason).
	StopReason   string  `json:"stop_reason"`
	StopSequence *string `json:"stop_sequence,omitempty"`
	Usage        usage   `json:"usage"`
}
130
// toolChoice controls whether and which tool Claude may use.
type toolChoice struct {
	Type string `json:"type"`           // "auto", "any", "none", or "tool" (see fromLLMToolChoiceType)
	Name string `json:"name,omitempty"` // tool name, meaningful when Type is "tool"
}
135
// systemContent is one block of the system prompt.
// https://docs.anthropic.com/en/api/messages#body-system
type systemContent struct {
	Text         string          `json:"text,omitempty"`
	Type         string          `json:"type,omitempty"`
	CacheControl json.RawMessage `json:"cache_control,omitempty"` // see fromLLMCache
}
142
// request represents the request payload for creating a message.
// https://docs.anthropic.com/en/api/messages
type request struct {
	Model         string          `json:"model"`
	Messages      []message       `json:"messages"`
	ToolChoice    *toolChoice     `json:"tool_choice,omitempty"`
	MaxTokens     int             `json:"max_tokens"`
	Tools         []*tool         `json:"tools,omitempty"`
	Stream        bool            `json:"stream,omitempty"`
	System        []systemContent `json:"system,omitempty"`
	Temperature   float64         `json:"temperature,omitempty"`
	TopK          int             `json:"top_k,omitempty"`
	TopP          float64         `json:"top_p,omitempty"`
	StopSequences []string        `json:"stop_sequences,omitempty"`

	// TokenEfficientToolUse requests the token-efficient-tool-use beta
	// (sent as an anthropic-beta header, not in the JSON body).
	// DO NOT USE, broken on Anthropic's side as of 2025-02-28
	TokenEfficientToolUse bool `json:"-"`
}
159
const dumpText = false // debugging toggle to see raw communications with Claude
161
// mapped returns a new slice containing f applied to each element of s,
// preserving order. The result always has len(s) elements.
func mapped[Slice ~[]E, E, T any](s Slice, f func(E) T) []T {
	result := make([]T, len(s))
	for i := range s {
		result[i] = f(s[i])
	}
	return result
}
169
// inverted returns the value-to-key inversion of m. It panics if two
// keys map to the same value, since the inversion would be ambiguous.
func inverted[K, V cmp.Ordered](m map[K]V) map[V]K {
	out := make(map[V]K, len(m))
	for key, val := range m {
		if _, dup := out[val]; dup {
			panic(fmt.Errorf("inverted map has multiple keys for value %v", val))
		}
		out[val] = key
	}
	return out
}
180
// Translation tables between the provider-independent llm enums and the
// strings used on the Anthropic wire. The to* tables are derived by
// inverting the from* tables, so each pair stays in sync by construction.
var (
	fromLLMRole = map[llm.MessageRole]string{
		llm.MessageRoleAssistant: "assistant",
		llm.MessageRoleUser:      "user",
	}
	toLLMRole = inverted(fromLLMRole)

	fromLLMContentType = map[llm.ContentType]string{
		llm.ContentTypeText:             "text",
		llm.ContentTypeThinking:         "thinking",
		llm.ContentTypeRedactedThinking: "redacted_thinking",
		llm.ContentTypeToolUse:          "tool_use",
		llm.ContentTypeToolResult:       "tool_result",
	}
	toLLMContentType = inverted(fromLLMContentType)

	fromLLMToolChoiceType = map[llm.ToolChoiceType]string{
		llm.ToolChoiceTypeAuto: "auto",
		llm.ToolChoiceTypeAny:  "any",
		llm.ToolChoiceTypeNone: "none",
		llm.ToolChoiceTypeTool: "tool",
	}

	toLLMStopReason = map[string]llm.StopReason{
		"stop_sequence": llm.StopReasonStopSequence,
		"max_tokens":    llm.StopReasonMaxTokens,
		"end_turn":      llm.StopReasonEndTurn,
		"tool_use":      llm.StopReasonToolUse,
	}
)
211
// fromLLMCache translates the llm cache flag into Anthropic's
// cache_control payload. A false flag yields nil so the JSON field
// is omitted entirely.
func fromLLMCache(c bool) json.RawMessage {
	if c {
		return json.RawMessage(`{"type":"ephemeral"}`)
	}
	return nil
}
218
219func fromLLMContent(c llm.Content) content {
220 return content{
221 ID: c.ID,
222 Type: fromLLMContentType[c.Type],
223 Text: c.Text,
224 Thinking: c.Thinking,
225 Data: c.Data,
226 Signature: c.Signature,
227 ToolName: c.ToolName,
228 ToolInput: c.ToolInput,
229 ToolUseID: c.ToolUseID,
230 ToolError: c.ToolError,
231 ToolResult: c.ToolResult,
232 CacheControl: fromLLMCache(c.Cache),
233 }
234}
235
236func fromLLMToolUse(tu *llm.ToolUse) *toolUse {
237 if tu == nil {
238 return nil
239 }
240 return &toolUse{
241 ID: tu.ID,
242 Name: tu.Name,
243 }
244}
245
246func fromLLMMessage(msg llm.Message) message {
247 return message{
248 Role: fromLLMRole[msg.Role],
249 Content: mapped(msg.Content, fromLLMContent),
250 ToolUse: fromLLMToolUse(msg.ToolUse),
251 }
252}
253
254func fromLLMToolChoice(tc *llm.ToolChoice) *toolChoice {
255 if tc == nil {
256 return nil
257 }
258 return &toolChoice{
259 Type: fromLLMToolChoiceType[tc.Type],
260 Name: tc.Name,
261 }
262}
263
264func fromLLMTool(t *llm.Tool) *tool {
265 return &tool{
266 Name: t.Name,
267 Type: t.Type,
268 Description: t.Description,
269 InputSchema: t.InputSchema,
270 }
271}
272
273func fromLLMSystem(s llm.SystemContent) systemContent {
274 return systemContent{
275 Text: s.Text,
276 Type: s.Type,
277 CacheControl: fromLLMCache(s.Cache),
278 }
279}
280
281func (s *Service) fromLLMRequest(r *llm.Request) *request {
282 return &request{
283 Model: cmp.Or(s.Model, DefaultModel),
284 Messages: mapped(r.Messages, fromLLMMessage),
285 MaxTokens: cmp.Or(s.MaxTokens, DefaultMaxTokens),
286 ToolChoice: fromLLMToolChoice(r.ToolChoice),
287 Tools: mapped(r.Tools, fromLLMTool),
288 System: mapped(r.System, fromLLMSystem),
289 }
290}
291
292func toLLMUsage(u usage) llm.Usage {
293 return llm.Usage{
294 InputTokens: u.InputTokens,
295 CacheCreationInputTokens: u.CacheCreationInputTokens,
296 CacheReadInputTokens: u.CacheReadInputTokens,
297 OutputTokens: u.OutputTokens,
298 CostUSD: u.CostUSD,
299 }
300}
301
302func toLLMContent(c content) llm.Content {
303 return llm.Content{
304 ID: c.ID,
305 Type: toLLMContentType[c.Type],
306 Text: c.Text,
307 Thinking: c.Thinking,
308 Data: c.Data,
309 Signature: c.Signature,
310 ToolName: c.ToolName,
311 ToolInput: c.ToolInput,
312 ToolUseID: c.ToolUseID,
313 ToolError: c.ToolError,
314 ToolResult: c.ToolResult,
315 }
316}
317
318func toLLMResponse(r *response) *llm.Response {
319 return &llm.Response{
320 ID: r.ID,
321 Type: r.Type,
322 Role: toLLMRole[r.Role],
323 Model: r.Model,
324 Content: mapped(r.Content, toLLMContent),
325 StopReason: toLLMStopReason[r.StopReason],
326 StopSequence: r.StopSequence,
327 Usage: toLLMUsage(r.Usage),
328 }
329}
330
331// Do sends a request to Anthropic.
332func (s *Service) Do(ctx context.Context, ir *llm.Request) (*llm.Response, error) {
333 request := s.fromLLMRequest(ir)
334
335 var payload []byte
336 var err error
337 if dumpText || testing.Testing() {
338 payload, err = json.MarshalIndent(request, "", " ")
339 } else {
340 payload, err = json.Marshal(request)
341 payload = append(payload, '\n')
342 }
343 if err != nil {
344 return nil, err
345 }
346
347 if false {
348 fmt.Printf("claude request payload:\n%s\n", payload)
349 }
350
351 backoff := []time.Duration{15 * time.Second, 30 * time.Second, time.Minute}
352 largerMaxTokens := false
353 var partialUsage usage
354
355 url := cmp.Or(s.URL, DefaultURL)
356 httpc := cmp.Or(s.HTTPC, http.DefaultClient)
357
358 // retry loop
359 for attempts := 0; ; attempts++ {
360 if dumpText {
361 fmt.Printf("RAW REQUEST:\n%s\n\n", payload)
362 }
363 req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(payload))
364 if err != nil {
365 return nil, err
366 }
367
368 req.Header.Set("Content-Type", "application/json")
369 req.Header.Set("X-API-Key", s.APIKey)
370 req.Header.Set("Anthropic-Version", "2023-06-01")
371
372 var features []string
373 if request.TokenEfficientToolUse {
374 features = append(features, "token-efficient-tool-use-2025-02-19")
375 }
376 if largerMaxTokens {
377 features = append(features, "output-128k-2025-02-19")
378 request.MaxTokens = 128 * 1024
379 }
380 if len(features) > 0 {
381 req.Header.Set("anthropic-beta", strings.Join(features, ","))
382 }
383
384 resp, err := httpc.Do(req)
385 if err != nil {
386 return nil, err
387 }
388 buf, _ := io.ReadAll(resp.Body)
389 resp.Body.Close()
390
391 switch {
392 case resp.StatusCode == http.StatusOK:
393 if dumpText {
394 fmt.Printf("RAW RESPONSE:\n%s\n\n", buf)
395 }
396 var response response
397 err = json.NewDecoder(bytes.NewReader(buf)).Decode(&response)
398 if err != nil {
399 return nil, err
400 }
401 if response.StopReason == "max_tokens" && !largerMaxTokens {
Josh Bleecher Snyder29fea842025-05-06 01:51:09 +0000402 slog.InfoContext(ctx, "anthropic_retrying_with_larger_tokens", "message", "Retrying Anthropic API call with larger max tokens size")
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700403 // Retry with more output tokens.
404 largerMaxTokens = true
405 response.Usage.CostUSD = response.TotalDollars()
406 partialUsage = response.Usage
407 continue
408 }
409
410 // Calculate and set the cost_usd field
411 if largerMaxTokens {
412 response.Usage.Add(partialUsage)
413 }
414 response.Usage.CostUSD = response.TotalDollars()
415
416 return toLLMResponse(&response), nil
417 case resp.StatusCode >= 500 && resp.StatusCode < 600:
418 // overloaded or unhappy, in one form or another
419 sleep := backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
420 slog.WarnContext(ctx, "anthropic_request_failed", "response", string(buf), "status_code", resp.StatusCode, "sleep", sleep)
421 time.Sleep(sleep)
422 case resp.StatusCode == 429:
423 // rate limited. wait 1 minute as a starting point, because that's the rate limiting window.
424 // and then add some additional time for backoff.
425 sleep := time.Minute + backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
426 slog.WarnContext(ctx, "anthropic_request_rate_limited", "response", string(buf), "sleep", sleep)
427 time.Sleep(sleep)
428 // case resp.StatusCode == 400:
429 // TODO: parse ErrorResponse, make (*ErrorResponse) implement error
430 default:
431 return nil, fmt.Errorf("API request failed with status %s\n%s", resp.Status, buf)
432 }
433 }
434}
435
// centsPer1MTokens holds pricing as cents per million tokens
// (not dollars because i'm twitchy about using floats for money).
type centsPer1MTokens struct {
	Input         uint64 // regular (uncached) input tokens
	Output        uint64 // generated output tokens
	CacheRead     uint64 // input tokens read from the prompt cache
	CacheCreation uint64 // input tokens written to the prompt cache
}
444
// modelCost holds per-model pricing, keyed by model identifier.
// https://www.anthropic.com/pricing#anthropic-api
// NOTE: TotalDollars panics for any model not listed here, so new model
// constants must get an entry.
var modelCost = map[string]centsPer1MTokens{
	Claude37Sonnet: {
		Input:         300,  // $3
		Output:        1500, // $15
		CacheRead:     30,   // $0.30
		CacheCreation: 375,  // $3.75
	},
	Claude35Haiku: {
		Input:         80,  // $0.80
		Output:        400, // $4.00
		CacheRead:     8,   // $0.08
		CacheCreation: 100, // $1.00
	},
	Claude35Sonnet: {
		Input:         300,  // $3
		Output:        1500, // $15
		CacheRead:     30,   // $0.30
		CacheCreation: 375,  // $3.75
	},
}
466
467// TotalDollars returns the total cost to obtain this response, in dollars.
468func (mr *response) TotalDollars() float64 {
469 cpm, ok := modelCost[mr.Model]
470 if !ok {
471 panic(fmt.Sprintf("no pricing info for model: %s", mr.Model))
472 }
473 use := mr.Usage
474 megaCents := use.InputTokens*cpm.Input +
475 use.OutputTokens*cpm.Output +
476 use.CacheReadInputTokens*cpm.CacheRead +
477 use.CacheCreationInputTokens*cpm.CacheCreation
478 cents := float64(megaCents) / 1_000_000.0
479 return cents / 100.0
480}