package ant

import (
	"bytes"
	"cmp"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"math/rand/v2"
	"net/http"
	"strings"
	"testing"
	"time"

	"sketch.dev/llm"
)

const (
	DefaultModel = Claude37Sonnet
	// See https://docs.anthropic.com/en/docs/about-claude/models/all-models for
	// current maximums. There's currently a flag to enable 128k output (output-128k-2025-02-19).
	DefaultMaxTokens = 8192
	DefaultURL       = "https://api.anthropic.com/v1/messages"
)

const (
	Claude35Sonnet = "claude-3-5-sonnet-20241022"
	Claude35Haiku  = "claude-3-5-haiku-20241022"
	Claude37Sonnet = "claude-3-7-sonnet-20250219"
)

// Service provides Claude completions.
// Fields should not be altered concurrently with calling any method on Service.
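//
// A minimal usage sketch (the llm.Request literal below is illustrative; see
// the fromLLM* conversion helpers in this file for the full set of fields):
//
//	svc := &Service{APIKey: apiKey}
//	resp, err := svc.Do(ctx, &llm.Request{
//		Messages: []llm.Message{{
//			Role:    llm.MessageRoleUser,
//			Content: []llm.Content{{Type: llm.ContentTypeText, Text: "Hello, Claude"}},
//		}},
//	})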
type Service struct {
	HTTPC     *http.Client // defaults to http.DefaultClient if nil
	URL       string       // defaults to DefaultURL if empty
	APIKey    string       // must be non-empty
	Model     string       // defaults to DefaultModel if empty
	MaxTokens int          // defaults to DefaultMaxTokens if zero
}

var _ llm.Service = (*Service)(nil)

type content struct {
	// TODO: image support?
	// https://docs.anthropic.com/en/api/messages
	ID   string `json:"id,omitempty"`
	Type string `json:"type,omitempty"`
	Text string `json:"text,omitempty"`

	// for thinking
	Thinking  string `json:"thinking,omitempty"`
	Data      string `json:"data,omitempty"`      // for redacted_thinking
	Signature string `json:"signature,omitempty"` // for thinking

	// for tool_use
	ToolName  string          `json:"name,omitempty"`
	ToolInput json.RawMessage `json:"input,omitempty"`

	// for tool_result
	ToolUseID  string `json:"tool_use_id,omitempty"`
	ToolError  bool   `json:"is_error,omitempty"`
	ToolResult string `json:"content,omitempty"`

	// timing information for tool_result; not sent to Claude
	StartTime *time.Time `json:"-"`
	EndTime   *time.Time `json:"-"`

	CacheControl json.RawMessage `json:"cache_control,omitempty"`
}

// message represents a message in the conversation.
type message struct {
	Role    string    `json:"role"`
	Content []content `json:"content"`
	ToolUse *toolUse  `json:"tool_use,omitempty"` // use to control whether/which tool to use
}

// toolUse represents a tool use in the message content.
type toolUse struct {
	ID   string `json:"id"`
	Name string `json:"name"`
}

// tool represents a tool available to Claude.
type tool struct {
	Name string `json:"name"`
	// Type is used by the text editor tool; see
	// https://docs.anthropic.com/en/docs/build-with-claude/tool-use/text-editor-tool
	Type        string          `json:"type,omitempty"`
	Description string          `json:"description,omitempty"`
	InputSchema json.RawMessage `json:"input_schema,omitempty"`
}

// usage represents the billing and rate-limit usage.
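// CostUSD is not returned by the API; it is computed locally from the token
// counts and per-model pricing (see TotalDollars).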
type usage struct {
	InputTokens              uint64  `json:"input_tokens"`
	CacheCreationInputTokens uint64  `json:"cache_creation_input_tokens"`
	CacheReadInputTokens     uint64  `json:"cache_read_input_tokens"`
	OutputTokens             uint64  `json:"output_tokens"`
	CostUSD                  float64 `json:"cost_usd"`
}

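// Add accumulates other into u, summing token counts and cost.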
func (u *usage) Add(other usage) {
	u.InputTokens += other.InputTokens
	u.CacheCreationInputTokens += other.CacheCreationInputTokens
	u.CacheReadInputTokens += other.CacheReadInputTokens
	u.OutputTokens += other.OutputTokens
	u.CostUSD += other.CostUSD
}

type errorResponse struct {
	Type    string `json:"type"`
	Message string `json:"message"`
}

// response represents the response from the message API.
type response struct {
	ID           string    `json:"id"`
	Type         string    `json:"type"`
	Role         string    `json:"role"`
	Model        string    `json:"model"`
	Content      []content `json:"content"`
	StopReason   string    `json:"stop_reason"`
	StopSequence *string   `json:"stop_sequence,omitempty"`
	Usage        usage     `json:"usage"`
}

type toolChoice struct {
	Type string `json:"type"`
	Name string `json:"name,omitempty"`
}

// https://docs.anthropic.com/en/api/messages#body-system
type systemContent struct {
	Text         string          `json:"text,omitempty"`
	Type         string          `json:"type,omitempty"`
	CacheControl json.RawMessage `json:"cache_control,omitempty"`
}

// request represents the request payload for creating a message.
type request struct {
	Model         string          `json:"model"`
	Messages      []message       `json:"messages"`
	ToolChoice    *toolChoice     `json:"tool_choice,omitempty"`
	MaxTokens     int             `json:"max_tokens"`
	Tools         []*tool         `json:"tools,omitempty"`
	Stream        bool            `json:"stream,omitempty"`
	System        []systemContent `json:"system,omitempty"`
	Temperature   float64         `json:"temperature,omitempty"`
	TopK          int             `json:"top_k,omitempty"`
	TopP          float64         `json:"top_p,omitempty"`
	StopSequences []string        `json:"stop_sequences,omitempty"`

	TokenEfficientToolUse bool `json:"-"` // DO NOT USE, broken on Anthropic's side as of 2025-02-28
}
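
// For reference, a request with only the required fields marshals to
// something roughly like this (an illustrative sketch, not captured traffic):
//
//	{"model":"claude-3-7-sonnet-20250219","messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}],"max_tokens":8192}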

const dumpText = false // debugging toggle to see raw communications with Claude

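// mapped returns a new slice containing f applied to each element of s.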
func mapped[Slice ~[]E, E, T any](s Slice, f func(E) T) []T {
	out := make([]T, len(s))
	for i, v := range s {
		out[i] = f(v)
	}
	return out
}

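// inverted returns the inverse map of m (values mapped back to their keys);
// it panics if two keys share a value.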
func inverted[K, V cmp.Ordered](m map[K]V) map[V]K {
	inv := make(map[V]K)
	for k, v := range m {
		if _, ok := inv[v]; ok {
			panic(fmt.Errorf("inverted map has multiple keys for value %v", v))
		}
		inv[v] = k
	}
	return inv
}

var (
	fromLLMRole = map[llm.MessageRole]string{
		llm.MessageRoleAssistant: "assistant",
		llm.MessageRoleUser:      "user",
	}
	toLLMRole = inverted(fromLLMRole)

	fromLLMContentType = map[llm.ContentType]string{
		llm.ContentTypeText:             "text",
		llm.ContentTypeThinking:         "thinking",
		llm.ContentTypeRedactedThinking: "redacted_thinking",
		llm.ContentTypeToolUse:          "tool_use",
		llm.ContentTypeToolResult:       "tool_result",
	}
	toLLMContentType = inverted(fromLLMContentType)

	fromLLMToolChoiceType = map[llm.ToolChoiceType]string{
		llm.ToolChoiceTypeAuto: "auto",
		llm.ToolChoiceTypeAny:  "any",
		llm.ToolChoiceTypeNone: "none",
		llm.ToolChoiceTypeTool: "tool",
	}

	toLLMStopReason = map[string]llm.StopReason{
		"stop_sequence": llm.StopReasonStopSequence,
		"max_tokens":    llm.StopReasonMaxTokens,
		"end_turn":      llm.StopReasonEndTurn,
		"tool_use":      llm.StopReasonToolUse,
	}
)

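// fromLLMCache converts the llm cache flag into Anthropic's cache_control
// JSON ({"type":"ephemeral"}), or nil when caching is not requested.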
func fromLLMCache(c bool) json.RawMessage {
	if !c {
		return nil
	}
	return json.RawMessage(`{"type":"ephemeral"}`)
}

func fromLLMContent(c llm.Content) content {
	return content{
		ID:           c.ID,
		Type:         fromLLMContentType[c.Type],
		Text:         c.Text,
		Thinking:     c.Thinking,
		Data:         c.Data,
		Signature:    c.Signature,
		ToolName:     c.ToolName,
		ToolInput:    c.ToolInput,
		ToolUseID:    c.ToolUseID,
		ToolError:    c.ToolError,
		ToolResult:   c.ToolResult,
		CacheControl: fromLLMCache(c.Cache),
	}
}

func fromLLMToolUse(tu *llm.ToolUse) *toolUse {
	if tu == nil {
		return nil
	}
	return &toolUse{
		ID:   tu.ID,
		Name: tu.Name,
	}
}

func fromLLMMessage(msg llm.Message) message {
	return message{
		Role:    fromLLMRole[msg.Role],
		Content: mapped(msg.Content, fromLLMContent),
		ToolUse: fromLLMToolUse(msg.ToolUse),
	}
}

func fromLLMToolChoice(tc *llm.ToolChoice) *toolChoice {
	if tc == nil {
		return nil
	}
	return &toolChoice{
		Type: fromLLMToolChoiceType[tc.Type],
		Name: tc.Name,
	}
}

func fromLLMTool(t *llm.Tool) *tool {
	return &tool{
		Name:        t.Name,
		Type:        t.Type,
		Description: t.Description,
		InputSchema: t.InputSchema,
	}
}

func fromLLMSystem(s llm.SystemContent) systemContent {
	return systemContent{
		Text:         s.Text,
		Type:         s.Type,
		CacheControl: fromLLMCache(s.Cache),
	}
}

func (s *Service) fromLLMRequest(r *llm.Request) *request {
	return &request{
		Model:      cmp.Or(s.Model, DefaultModel),
		Messages:   mapped(r.Messages, fromLLMMessage),
		MaxTokens:  cmp.Or(s.MaxTokens, DefaultMaxTokens),
		ToolChoice: fromLLMToolChoice(r.ToolChoice),
		Tools:      mapped(r.Tools, fromLLMTool),
		System:     mapped(r.System, fromLLMSystem),
	}
}

func toLLMUsage(u usage) llm.Usage {
	return llm.Usage{
		InputTokens:              u.InputTokens,
		CacheCreationInputTokens: u.CacheCreationInputTokens,
		CacheReadInputTokens:     u.CacheReadInputTokens,
		OutputTokens:             u.OutputTokens,
		CostUSD:                  u.CostUSD,
	}
}

func toLLMContent(c content) llm.Content {
	return llm.Content{
		ID:         c.ID,
		Type:       toLLMContentType[c.Type],
		Text:       c.Text,
		Thinking:   c.Thinking,
		Data:       c.Data,
		Signature:  c.Signature,
		ToolName:   c.ToolName,
		ToolInput:  c.ToolInput,
		ToolUseID:  c.ToolUseID,
		ToolError:  c.ToolError,
		ToolResult: c.ToolResult,
	}
}

func toLLMResponse(r *response) *llm.Response {
	return &llm.Response{
		ID:           r.ID,
		Type:         r.Type,
		Role:         toLLMRole[r.Role],
		Model:        r.Model,
		Content:      mapped(r.Content, toLLMContent),
		StopReason:   toLLMStopReason[r.StopReason],
		StopSequence: r.StopSequence,
		Usage:        toLLMUsage(r.Usage),
	}
}

// Do sends a request to Anthropic.
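// It retries automatically: 5xx responses and 429 rate limits are retried
// after a backoff, and a response truncated by max_tokens is retried once
// with the output-128k-2025-02-19 beta feature and a larger max_tokens.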
func (s *Service) Do(ctx context.Context, ir *llm.Request) (*llm.Response, error) {
	request := s.fromLLMRequest(ir)

	var payload []byte
	var err error
	if dumpText || testing.Testing() {
		payload, err = json.MarshalIndent(request, "", " ")
	} else {
		payload, err = json.Marshal(request)
		payload = append(payload, '\n')
	}
	if err != nil {
		return nil, err
	}

	if false {
		fmt.Printf("claude request payload:\n%s\n", payload)
	}

	backoff := []time.Duration{15 * time.Second, 30 * time.Second, time.Minute}
	largerMaxTokens := false
	var partialUsage usage

	url := cmp.Or(s.URL, DefaultURL)
	httpc := cmp.Or(s.HTTPC, http.DefaultClient)

	// retry loop
	for attempts := 0; ; attempts++ {
		if dumpText {
			fmt.Printf("RAW REQUEST:\n%s\n\n", payload)
		}
		req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(payload))
		if err != nil {
			return nil, err
		}

		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("X-API-Key", s.APIKey)
		req.Header.Set("Anthropic-Version", "2023-06-01")

		var features []string
		if request.TokenEfficientToolUse {
			features = append(features, "token-efficient-tool-use-2025-02-19")
		}
		if largerMaxTokens {
			features = append(features, "output-128k-2025-02-19")
			request.MaxTokens = 128 * 1024
		}
		if len(features) > 0 {
			req.Header.Set("anthropic-beta", strings.Join(features, ","))
		}

		resp, err := httpc.Do(req)
		if err != nil {
			return nil, err
		}
		buf, _ := io.ReadAll(resp.Body)
		resp.Body.Close()

		switch {
		case resp.StatusCode == http.StatusOK:
			if dumpText {
				fmt.Printf("RAW RESPONSE:\n%s\n\n", buf)
			}
			var response response
			err = json.NewDecoder(bytes.NewReader(buf)).Decode(&response)
			if err != nil {
				return nil, err
			}
			if response.StopReason == "max_tokens" && !largerMaxTokens {
				slog.InfoContext(ctx, "anthropic_retrying_with_larger_tokens", "message", "Retrying Anthropic API call with larger max tokens size")
				// Retry with more output tokens. The payload was marshaled
				// before the retry loop, so re-marshal it here so that the
				// larger max_tokens value is actually sent.
				largerMaxTokens = true
				request.MaxTokens = 128 * 1024
				if payload, err = json.Marshal(request); err != nil {
					return nil, err
				}
				response.Usage.CostUSD = response.TotalDollars()
				partialUsage = response.Usage
				continue
			}

			// Calculate and set the cost_usd field
			if largerMaxTokens {
				response.Usage.Add(partialUsage)
			}
			response.Usage.CostUSD = response.TotalDollars()

			return toLLMResponse(&response), nil
		case resp.StatusCode >= 500 && resp.StatusCode < 600:
			// overloaded or unhappy, in one form or another
			sleep := backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
			slog.WarnContext(ctx, "anthropic_request_failed", "response", string(buf), "status_code", resp.StatusCode, "sleep", sleep)
			time.Sleep(sleep)
		case resp.StatusCode == 429:
			// rate limited. wait 1 minute as a starting point, because that's the rate limiting window,
			// and then add some additional time for backoff.
			sleep := time.Minute + backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
			slog.WarnContext(ctx, "anthropic_request_rate_limited", "response", string(buf), "sleep", sleep)
			time.Sleep(sleep)
		// case resp.StatusCode == 400:
		// TODO: parse errorResponse, make (*errorResponse) implement error
		default:
			return nil, fmt.Errorf("API request failed with status %s\n%s", resp.Status, buf)
		}
	}
}

// centsPer1MTokens holds pricing in cents per million tokens
// (not dollars because I'm twitchy about using floats for money).
type centsPer1MTokens struct {
	Input         uint64
	Output        uint64
	CacheRead     uint64
	CacheCreation uint64
}

// https://www.anthropic.com/pricing#anthropic-api
var modelCost = map[string]centsPer1MTokens{
	Claude37Sonnet: {
		Input:         300,  // $3
		Output:        1500, // $15
		CacheRead:     30,   // $0.30
		CacheCreation: 375,  // $3.75
	},
	Claude35Haiku: {
		Input:         80,  // $0.80
		Output:        400, // $4.00
		CacheRead:     8,   // $0.08
		CacheCreation: 100, // $1.00
	},
	Claude35Sonnet: {
		Input:         300,  // $3
		Output:        1500, // $15
		CacheRead:     30,   // $0.30
		CacheCreation: 375,  // $3.75
	},
}

// TotalDollars returns the total cost to obtain this response, in dollars.
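// For example, a Claude 3.7 Sonnet response using 10,000 input tokens and
// 2,000 output tokens (no caching) costs (10,000*300 + 2,000*1500)/1e6 = 6
// cents, i.e. $0.06.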
func (mr *response) TotalDollars() float64 {
	cpm, ok := modelCost[mr.Model]
	if !ok {
		panic(fmt.Sprintf("no pricing info for model: %s", mr.Model))
	}
	use := mr.Usage
	megaCents := use.InputTokens*cpm.Input +
		use.OutputTokens*cpm.Output +
		use.CacheReadInputTokens*cpm.CacheRead +
		use.CacheCreationInputTokens*cpm.CacheCreation
	cents := float64(megaCents) / 1_000_000.0
	return cents / 100.0
}