blob: efebf26766ffa93f77a49e932bc1198f3cbe9cf6 [file] [log] [blame]
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -07001package ant
2
3import (
4 "bytes"
5 "cmp"
6 "context"
7 "encoding/json"
8 "fmt"
9 "io"
10 "log/slog"
11 "math/rand/v2"
12 "net/http"
13 "strings"
14 "testing"
15 "time"
16
17 "sketch.dev/llm"
18)
19
const (
	// DefaultModel is used by Service when Model is empty.
	DefaultModel = Claude37Sonnet
	// See https://docs.anthropic.com/en/docs/about-claude/models/all-models for
	// current maximums. There's currently a flag to enable 128k output (output-128k-2025-02-19)
	DefaultMaxTokens = 8192
	// DefaultURL is the Anthropic messages endpoint, used when Service.URL is empty.
	DefaultURL = "https://api.anthropic.com/v1/messages"
)
27
// Model identifiers accepted by the Anthropic API.
// Per-model pricing lives in modelCost.
const (
	Claude35Sonnet = "claude-3-5-sonnet-20241022"
	Claude35Haiku  = "claude-3-5-haiku-20241022"
	Claude37Sonnet = "claude-3-7-sonnet-20250219"
)
33
// Service provides Claude completions.
// Fields should not be altered concurrently with calling any method on Service.
// The zero value is not usable as-is: APIKey must be set before calling Do.
type Service struct {
	HTTPC     *http.Client // defaults to http.DefaultClient if nil
	URL       string       // defaults to DefaultURL if empty
	APIKey    string       // must be non-empty
	Model     string       // defaults to DefaultModel if empty
	MaxTokens int          // defaults to DefaultMaxTokens if zero
}
43
Philip Zeyliger022b3632025-05-10 06:14:21 -070044func (s *Service) ModelName() string {
45 return cmp.Or(s.Model, DefaultModel)
46}
47
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -070048var _ llm.Service = (*Service)(nil)
49
// content is one element of a message's content array, mirroring an
// Anthropic Messages API content block. Which fields are populated
// depends on Type (text, thinking, tool_use, tool_result, ...).
type content struct {
	// TODO: image support?
	// https://docs.anthropic.com/en/api/messages
	ID   string `json:"id,omitempty"`
	Type string `json:"type,omitempty"`
	Text string `json:"text,omitempty"`

	// for thinking
	Thinking  string `json:"thinking,omitempty"`
	Data      string `json:"data,omitempty"`      // for redacted_thinking
	Signature string `json:"signature,omitempty"` // for thinking

	// for tool_use
	ToolName  string          `json:"name,omitempty"`
	ToolInput json.RawMessage `json:"input,omitempty"`

	// for tool_result
	ToolUseID string `json:"tool_use_id,omitempty"`
	ToolError bool   `json:"is_error,omitempty"`
	ToolResult string `json:"content,omitempty"`

	// timing information for tool_result; not sent to Claude
	StartTime *time.Time `json:"-"`
	EndTime   *time.Time `json:"-"`

	// CacheControl, when non-nil, marks a prompt-caching breakpoint
	// (see fromLLMCache for the payload).
	CacheControl json.RawMessage `json:"cache_control,omitempty"`
}
77
// message represents a message in the conversation.
type message struct {
	Role    string    `json:"role"` // "user" or "assistant" (see fromLLMRole)
	Content []content `json:"content"`
	ToolUse *toolUse  `json:"tool_use,omitempty"` // use to control whether/which tool to use
}
84
// toolUse represents a tool use in the message content.
type toolUse struct {
	ID   string `json:"id"`
	Name string `json:"name"` // name of the tool being invoked
}
90
// tool represents a tool available to Claude.
type tool struct {
	Name string `json:"name"`
	// Type is used by the text editor tool; see
	// https://docs.anthropic.com/en/docs/build-with-claude/tool-use/text-editor-tool
	Type        string `json:"type,omitempty"`
	Description string `json:"description,omitempty"`
	// InputSchema is the JSON schema for the tool's input object,
	// passed through verbatim from llm.Tool.
	InputSchema json.RawMessage `json:"input_schema,omitempty"`
}
100
// usage represents the billing and rate-limit usage.
type usage struct {
	InputTokens              uint64 `json:"input_tokens"`
	CacheCreationInputTokens uint64 `json:"cache_creation_input_tokens"`
	CacheReadInputTokens     uint64 `json:"cache_read_input_tokens"`
	OutputTokens             uint64 `json:"output_tokens"`
	// CostUSD is computed locally from the token counts (see
	// response.TotalDollars); Anthropic does not send it.
	CostUSD float64 `json:"cost_usd"`
}
109
110func (u *usage) Add(other usage) {
111 u.InputTokens += other.InputTokens
112 u.CacheCreationInputTokens += other.CacheCreationInputTokens
113 u.CacheReadInputTokens += other.CacheReadInputTokens
114 u.OutputTokens += other.OutputTokens
115 u.CostUSD += other.CostUSD
116}
117
// errorResponse is the error payload from the message API.
// NOTE(review): not yet unmarshaled anywhere; Do has a TODO to parse it
// for 400 responses.
type errorResponse struct {
	Type    string `json:"type"`
	Message string `json:"message"`
}
122
// response represents the response from the message API.
type response struct {
	ID         string    `json:"id"`
	Type       string    `json:"type"`
	Role       string    `json:"role"`
	Model      string    `json:"model"`
	Content    []content `json:"content"`
	// StopReason is e.g. "end_turn", "max_tokens", "tool_use",
	// "stop_sequence" (see toLLMStopReason).
	StopReason   string  `json:"stop_reason"`
	StopSequence *string `json:"stop_sequence,omitempty"`
	Usage        usage   `json:"usage"`
}
134
// toolChoice controls whether and which tool Claude may use.
// Type is one of "auto", "any", "none", "tool" (see fromLLMToolChoiceType).
type toolChoice struct {
	Type string `json:"type"`
	Name string `json:"name,omitempty"`
}
139
// systemContent is one block of the system prompt.
// https://docs.anthropic.com/en/api/messages#body-system
type systemContent struct {
	Text string `json:"text,omitempty"`
	Type string `json:"type,omitempty"`
	// CacheControl marks a prompt-caching breakpoint (see fromLLMCache).
	CacheControl json.RawMessage `json:"cache_control,omitempty"`
}
146
// request represents the request payload for creating a message.
// See https://docs.anthropic.com/en/api/messages for field semantics.
type request struct {
	Model         string          `json:"model"`
	Messages      []message       `json:"messages"`
	ToolChoice    *toolChoice     `json:"tool_choice,omitempty"`
	MaxTokens     int             `json:"max_tokens"`
	Tools         []*tool         `json:"tools,omitempty"`
	Stream        bool            `json:"stream,omitempty"`
	System        []systemContent `json:"system,omitempty"`
	Temperature   float64         `json:"temperature,omitempty"`
	TopK          int             `json:"top_k,omitempty"`
	TopP          float64         `json:"top_p,omitempty"`
	StopSequences []string        `json:"stop_sequences,omitempty"`

	// TokenEfficientToolUse is deliberately not serialized; Do only uses it
	// to decide whether to send the corresponding beta header.
	TokenEfficientToolUse bool `json:"-"` // DO NOT USE, broken on Anthropic's side as of 2025-02-28
}
163
164const dumpText = false // debugging toggle to see raw communications with Claude
165
// mapped returns a new slice holding f applied to each element of s, in order.
// The result is always non-nil, with the same length as s.
func mapped[Slice ~[]E, E, T any](s Slice, f func(E) T) []T {
	result := make([]T, 0, len(s))
	for _, elem := range s {
		result = append(result, f(elem))
	}
	return result
}
173
// inverted returns the inverse of m, mapping each value back to its key.
// It panics if two keys share a value, since the inverse would be ambiguous.
//
// Only comparable is required of both type parameters (V must serve as a
// map key, K only needs equality for the map); the previous cmp.Ordered
// constraint was stricter than necessary.
func inverted[K, V comparable](m map[K]V) map[V]K {
	inv := make(map[V]K, len(m))
	for k, v := range m {
		if _, ok := inv[v]; ok {
			panic(fmt.Errorf("inverted map has multiple keys for value %v", v))
		}
		inv[v] = k
	}
	return inv
}
184
// Translation tables between the provider-neutral llm package enums and the
// strings used on the Anthropic wire. The to* directions are derived by
// inverting the corresponding from* table, so the pairs cannot drift apart.
var (
	fromLLMRole = map[llm.MessageRole]string{
		llm.MessageRoleAssistant: "assistant",
		llm.MessageRoleUser:      "user",
	}
	toLLMRole = inverted(fromLLMRole)

	fromLLMContentType = map[llm.ContentType]string{
		llm.ContentTypeText:             "text",
		llm.ContentTypeThinking:         "thinking",
		llm.ContentTypeRedactedThinking: "redacted_thinking",
		llm.ContentTypeToolUse:          "tool_use",
		llm.ContentTypeToolResult:       "tool_result",
	}
	toLLMContentType = inverted(fromLLMContentType)

	fromLLMToolChoiceType = map[llm.ToolChoiceType]string{
		llm.ToolChoiceTypeAuto: "auto",
		llm.ToolChoiceTypeAny:  "any",
		llm.ToolChoiceTypeNone: "none",
		llm.ToolChoiceTypeTool: "tool",
	}

	// toLLMStopReason has no from* counterpart: stop reasons only flow
	// from Anthropic to us.
	toLLMStopReason = map[string]llm.StopReason{
		"stop_sequence": llm.StopReasonStopSequence,
		"max_tokens":    llm.StopReasonMaxTokens,
		"end_turn":      llm.StopReasonEndTurn,
		"tool_use":      llm.StopReasonToolUse,
	}
)
215
// fromLLMCache converts a boolean cache flag into the Anthropic
// cache_control payload; it returns nil when caching is not requested,
// so the field is omitted from the JSON entirely.
func fromLLMCache(c bool) json.RawMessage {
	if c {
		return json.RawMessage(`{"type":"ephemeral"}`)
	}
	return nil
}
222
223func fromLLMContent(c llm.Content) content {
224 return content{
225 ID: c.ID,
226 Type: fromLLMContentType[c.Type],
227 Text: c.Text,
228 Thinking: c.Thinking,
229 Data: c.Data,
230 Signature: c.Signature,
231 ToolName: c.ToolName,
232 ToolInput: c.ToolInput,
233 ToolUseID: c.ToolUseID,
234 ToolError: c.ToolError,
235 ToolResult: c.ToolResult,
236 CacheControl: fromLLMCache(c.Cache),
237 }
238}
239
240func fromLLMToolUse(tu *llm.ToolUse) *toolUse {
241 if tu == nil {
242 return nil
243 }
244 return &toolUse{
245 ID: tu.ID,
246 Name: tu.Name,
247 }
248}
249
250func fromLLMMessage(msg llm.Message) message {
251 return message{
252 Role: fromLLMRole[msg.Role],
253 Content: mapped(msg.Content, fromLLMContent),
254 ToolUse: fromLLMToolUse(msg.ToolUse),
255 }
256}
257
258func fromLLMToolChoice(tc *llm.ToolChoice) *toolChoice {
259 if tc == nil {
260 return nil
261 }
262 return &toolChoice{
263 Type: fromLLMToolChoiceType[tc.Type],
264 Name: tc.Name,
265 }
266}
267
268func fromLLMTool(t *llm.Tool) *tool {
269 return &tool{
270 Name: t.Name,
271 Type: t.Type,
272 Description: t.Description,
273 InputSchema: t.InputSchema,
274 }
275}
276
277func fromLLMSystem(s llm.SystemContent) systemContent {
278 return systemContent{
279 Text: s.Text,
280 Type: s.Type,
281 CacheControl: fromLLMCache(s.Cache),
282 }
283}
284
285func (s *Service) fromLLMRequest(r *llm.Request) *request {
286 return &request{
287 Model: cmp.Or(s.Model, DefaultModel),
288 Messages: mapped(r.Messages, fromLLMMessage),
289 MaxTokens: cmp.Or(s.MaxTokens, DefaultMaxTokens),
290 ToolChoice: fromLLMToolChoice(r.ToolChoice),
291 Tools: mapped(r.Tools, fromLLMTool),
292 System: mapped(r.System, fromLLMSystem),
293 }
294}
295
296func toLLMUsage(u usage) llm.Usage {
297 return llm.Usage{
298 InputTokens: u.InputTokens,
299 CacheCreationInputTokens: u.CacheCreationInputTokens,
300 CacheReadInputTokens: u.CacheReadInputTokens,
301 OutputTokens: u.OutputTokens,
302 CostUSD: u.CostUSD,
303 }
304}
305
306func toLLMContent(c content) llm.Content {
307 return llm.Content{
308 ID: c.ID,
309 Type: toLLMContentType[c.Type],
310 Text: c.Text,
311 Thinking: c.Thinking,
312 Data: c.Data,
313 Signature: c.Signature,
314 ToolName: c.ToolName,
315 ToolInput: c.ToolInput,
316 ToolUseID: c.ToolUseID,
317 ToolError: c.ToolError,
318 ToolResult: c.ToolResult,
319 }
320}
321
322func toLLMResponse(r *response) *llm.Response {
323 return &llm.Response{
324 ID: r.ID,
325 Type: r.Type,
326 Role: toLLMRole[r.Role],
327 Model: r.Model,
328 Content: mapped(r.Content, toLLMContent),
329 StopReason: toLLMStopReason[r.StopReason],
330 StopSequence: r.StopSequence,
331 Usage: toLLMUsage(r.Usage),
332 }
333}
334
335// Do sends a request to Anthropic.
336func (s *Service) Do(ctx context.Context, ir *llm.Request) (*llm.Response, error) {
337 request := s.fromLLMRequest(ir)
338
339 var payload []byte
340 var err error
341 if dumpText || testing.Testing() {
342 payload, err = json.MarshalIndent(request, "", " ")
343 } else {
344 payload, err = json.Marshal(request)
345 payload = append(payload, '\n')
346 }
347 if err != nil {
348 return nil, err
349 }
350
351 if false {
352 fmt.Printf("claude request payload:\n%s\n", payload)
353 }
354
355 backoff := []time.Duration{15 * time.Second, 30 * time.Second, time.Minute}
356 largerMaxTokens := false
357 var partialUsage usage
358
359 url := cmp.Or(s.URL, DefaultURL)
360 httpc := cmp.Or(s.HTTPC, http.DefaultClient)
361
362 // retry loop
363 for attempts := 0; ; attempts++ {
364 if dumpText {
365 fmt.Printf("RAW REQUEST:\n%s\n\n", payload)
366 }
367 req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(payload))
368 if err != nil {
369 return nil, err
370 }
371
372 req.Header.Set("Content-Type", "application/json")
373 req.Header.Set("X-API-Key", s.APIKey)
374 req.Header.Set("Anthropic-Version", "2023-06-01")
375
376 var features []string
377 if request.TokenEfficientToolUse {
378 features = append(features, "token-efficient-tool-use-2025-02-19")
379 }
380 if largerMaxTokens {
381 features = append(features, "output-128k-2025-02-19")
382 request.MaxTokens = 128 * 1024
383 }
384 if len(features) > 0 {
385 req.Header.Set("anthropic-beta", strings.Join(features, ","))
386 }
387
388 resp, err := httpc.Do(req)
389 if err != nil {
390 return nil, err
391 }
392 buf, _ := io.ReadAll(resp.Body)
393 resp.Body.Close()
394
395 switch {
396 case resp.StatusCode == http.StatusOK:
397 if dumpText {
398 fmt.Printf("RAW RESPONSE:\n%s\n\n", buf)
399 }
400 var response response
401 err = json.NewDecoder(bytes.NewReader(buf)).Decode(&response)
402 if err != nil {
403 return nil, err
404 }
405 if response.StopReason == "max_tokens" && !largerMaxTokens {
Josh Bleecher Snyder29fea842025-05-06 01:51:09 +0000406 slog.InfoContext(ctx, "anthropic_retrying_with_larger_tokens", "message", "Retrying Anthropic API call with larger max tokens size")
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700407 // Retry with more output tokens.
408 largerMaxTokens = true
409 response.Usage.CostUSD = response.TotalDollars()
410 partialUsage = response.Usage
411 continue
412 }
413
414 // Calculate and set the cost_usd field
415 if largerMaxTokens {
416 response.Usage.Add(partialUsage)
417 }
418 response.Usage.CostUSD = response.TotalDollars()
419
420 return toLLMResponse(&response), nil
421 case resp.StatusCode >= 500 && resp.StatusCode < 600:
422 // overloaded or unhappy, in one form or another
423 sleep := backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
424 slog.WarnContext(ctx, "anthropic_request_failed", "response", string(buf), "status_code", resp.StatusCode, "sleep", sleep)
425 time.Sleep(sleep)
426 case resp.StatusCode == 429:
427 // rate limited. wait 1 minute as a starting point, because that's the rate limiting window.
428 // and then add some additional time for backoff.
429 sleep := time.Minute + backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
430 slog.WarnContext(ctx, "anthropic_request_rate_limited", "response", string(buf), "sleep", sleep)
431 time.Sleep(sleep)
432 // case resp.StatusCode == 400:
433 // TODO: parse ErrorResponse, make (*ErrorResponse) implement error
434 default:
435 return nil, fmt.Errorf("API request failed with status %s\n%s", resp.Status, buf)
436 }
437 }
438}
439
// centsPer1MTokens is a price table entry: cents per million tokens
// (not dollars because i'm twitchy about using floats for money).
type centsPer1MTokens struct {
	Input         uint64 // regular input tokens
	Output        uint64 // output tokens
	CacheRead     uint64 // cached-prompt reads
	CacheCreation uint64 // cached-prompt writes
}
448
// modelCost maps model identifiers to their pricing.
// https://www.anthropic.com/pricing#anthropic-api
// Every model constant above must have an entry here: TotalDollars
// panics on a missing model.
var modelCost = map[string]centsPer1MTokens{
	Claude37Sonnet: {
		Input:         300,  // $3
		Output:        1500, // $15
		CacheRead:     30,   // $0.30
		CacheCreation: 375,  // $3.75
	},
	Claude35Haiku: {
		Input:         80,  // $0.80
		Output:        400, // $4.00
		CacheRead:     8,   // $0.08
		CacheCreation: 100, // $1.00
	},
	Claude35Sonnet: {
		Input:         300,  // $3
		Output:        1500, // $15
		CacheRead:     30,   // $0.30
		CacheCreation: 375,  // $3.75
	},
}
470
471// TotalDollars returns the total cost to obtain this response, in dollars.
472func (mr *response) TotalDollars() float64 {
473 cpm, ok := modelCost[mr.Model]
474 if !ok {
475 panic(fmt.Sprintf("no pricing info for model: %s", mr.Model))
476 }
477 use := mr.Usage
478 megaCents := use.InputTokens*cpm.Input +
479 use.OutputTokens*cpm.Output +
480 use.CacheReadInputTokens*cpm.CacheRead +
481 use.CacheCreationInputTokens*cpm.CacheCreation
482 cents := float64(megaCents) / 1_000_000.0
483 return cents / 100.0
484}