blob: 463d5fb3cdc071f41ee1c5d7d33c16fb6336535e [file] [log] [blame]
Earl Lee2e463fb2025-04-17 11:22:22 -07001package ant
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "log/slog"
11 "maps"
12 "math/rand/v2"
13 "net/http"
14 "slices"
15 "strings"
16 "sync"
17 "testing"
18 "time"
19
20 "github.com/richardlehane/crock32"
21 "sketch.dev/skribe"
22)
23
// Model identifiers understood by the Anthropic Messages API.
const (
	Claude35Haiku  = "claude-3-5-haiku-20241022"
	Claude35Sonnet = "claude-3-5-sonnet-20241022"
	Claude37Sonnet = "claude-3-7-sonnet-20250219"
)

// Defaults used by NewConvo.
const (
	// DefaultURL is the Messages API endpoint.
	DefaultURL = "https://api.anthropic.com/v1/messages"
	// DefaultModel is the model a new Convo talks to.
	DefaultModel = Claude37Sonnet
	// DefaultMaxTokens caps each response. See
	// https://docs.anthropic.com/en/docs/about-claude/models/all-models for
	// current maximums. There's currently a flag to enable 128k output (output-128k-2025-02-19)
	DefaultMaxTokens = 8192
)

// Message roles.
const (
	MessageRoleUser      = "user"
	MessageRoleAssistant = "assistant"
)

// Content block types (Content.Type).
const (
	ContentTypeText             = "text"
	ContentTypeThinking         = "thinking"
	ContentTypeRedactedThinking = "redacted_thinking"
	ContentTypeToolUse          = "tool_use"
	ContentTypeToolResult       = "tool_result"
)

// Stop reasons reported in MessageResponse.StopReason.
const (
	StopReasonStopSequence = "stop_sequence"
	StopReasonMaxTokens    = "max_tokens"
	StopReasonEndTurn      = "end_turn"
	StopReasonToolUse      = "tool_use"
)
53
54type Listener interface {
55 // TODO: Content is leaking an anthropic API; should we avoid it?
56 // TODO: Where should we include start/end time and usage?
57 OnToolResult(ctx context.Context, convo *Convo, toolName string, toolInput json.RawMessage, content Content, result *string, err error)
58 OnResponse(ctx context.Context, convo *Convo, msg *MessageResponse)
59 OnRequest(ctx context.Context, convo *Convo, msg *Message)
60}
61
62type NoopListener struct{}
63
64func (n *NoopListener) OnToolResult(ctx context.Context, convo *Convo, toolName string, toolInput json.RawMessage, content Content, result *string, err error) {
65}
66func (n *NoopListener) OnResponse(ctx context.Context, convo *Convo, msg *MessageResponse) {}
67func (n *NoopListener) OnRequest(ctx context.Context, convo *Convo, msg *Message) {}
68
// Content is a single content block within a Message or MessageResponse.
// Which fields are meaningful depends on Type (one of the ContentType* constants);
// every field is tagged omitempty or "-", so unused zero fields are left out of the JSON.
type Content struct {
	// TODO: image support?
	// https://docs.anthropic.com/en/api/messages
	ID   string `json:"id,omitempty"`
	Type string `json:"type,omitempty"`
	Text string `json:"text,omitempty"`

	// for thinking
	Thinking  string `json:"thinking,omitempty"`
	Data      string `json:"data,omitempty"`      // for redacted_thinking
	Signature string `json:"signature,omitempty"` // for thinking

	// for tool_use
	ToolName  string          `json:"name,omitempty"`
	ToolInput json.RawMessage `json:"input,omitempty"`

	// for tool_result
	ToolUseID  string `json:"tool_use_id,omitempty"`
	ToolError  bool   `json:"is_error,omitempty"`
	ToolResult string `json:"content,omitempty"`

	// timing information for tool_result; not sent to Claude
	StartTime *time.Time `json:"-"`
	EndTime   *time.Time `json:"-"`

	// CacheControl marks a prompt-caching breakpoint. When PromptCaching is on,
	// SendMessage sets it to {"type":"ephemeral"} on the last block of the outgoing message.
	CacheControl json.RawMessage `json:"cache_control,omitempty"`
}
96
97func StringContent(s string) Content {
98 return Content{Type: ContentTypeText, Text: s}
99}
100
// Message represents a message in the conversation.
// Role is one of the MessageRole* constants. Content has no omitempty,
// so a nil slice is encoded as JSON null.
type Message struct {
	Role    string    `json:"role"`
	Content []Content `json:"content"`
	// NOTE(review): nothing in this file sets ToolUse; confirm the API accepts it on messages.
	ToolUse *ToolUse `json:"tool_use,omitempty"` // use to control whether/which tool to use
}

// ToolUse represents a tool use in the message content.
type ToolUse struct {
	ID   string `json:"id"`
	Name string `json:"name"`
}

// Tool represents a tool available to Claude.
type Tool struct {
	Name string `json:"name"`
	// Type is used by the text editor tool; see
	// https://docs.anthropic.com/en/docs/build-with-claude/tool-use/text-editor-tool
	Type        string          `json:"type,omitempty"`
	Description string          `json:"description,omitempty"`
	InputSchema json.RawMessage `json:"input_schema,omitempty"`

	// The Run function is automatically called when the tool is used.
	// Run functions may be called concurrently with each other and themselves.
	// The input to Run function is the input to the tool, as provided by Claude, in compliance with the input schema.
	// The outputs from Run will be sent back to Claude.
	// If you do not want to respond to the tool call request from Claude, return ErrDoNotRespond.
	// ctx contains extra (rarely used) tool call information; retrieve it with ToolCallInfoFromContext.
	Run func(ctx context.Context, input json.RawMessage) (string, error) `json:"-"`
}

// ErrDoNotRespond may be returned (possibly wrapped) by Tool.Run to suppress the
// tool_result for that call: ToolResultContents then produces no result for it.
var ErrDoNotRespond = errors.New("do not respond")
133
// Usage represents the billing and rate-limit usage.
type Usage struct {
	InputTokens              uint64  `json:"input_tokens"`
	CacheCreationInputTokens uint64  `json:"cache_creation_input_tokens"`
	CacheReadInputTokens     uint64  `json:"cache_read_input_tokens"`
	OutputTokens             uint64  `json:"output_tokens"`
	CostUSD                  float64 `json:"cost_usd"`
}

// Add accumulates every counter of other, including CostUSD, into u.
func (u *Usage) Add(other Usage) {
	u.CostUSD += other.CostUSD
	u.OutputTokens += other.OutputTokens
	u.CacheReadInputTokens += other.CacheReadInputTokens
	u.CacheCreationInputTokens += other.CacheCreationInputTokens
	u.InputTokens += other.InputTokens
}

// String summarizes u as "in: N, out: M" (uncached input and output tokens only).
func (u *Usage) String() string {
	return fmt.Sprint("in: ", u.InputTokens, ", out: ", u.OutputTokens)
}

// IsZero reports whether every field of u is zero.
func (u *Usage) IsZero() bool {
	return u.InputTokens == 0 &&
		u.CacheCreationInputTokens == 0 &&
		u.CacheReadInputTokens == 0 &&
		u.OutputTokens == 0 &&
		u.CostUSD == 0
}

// Attr returns the token counts of u as a slog group keyed "usage".
// CostUSD is not included.
func (u *Usage) Attr() slog.Attr {
	fields := []any{
		slog.Uint64("input_tokens", u.InputTokens),
		slog.Uint64("output_tokens", u.OutputTokens),
		slog.Uint64("cache_creation_input_tokens", u.CacheCreationInputTokens),
		slog.Uint64("cache_read_input_tokens", u.CacheReadInputTokens),
	}
	return slog.Group("usage", fields...)
}
167
// ErrorResponse mirrors the error object in an API error body.
// NOTE(review): nothing in this file decodes into it yet; createMessage returns
// non-retryable failures as a formatted error containing the raw body.
type ErrorResponse struct {
	Type    string `json:"type"`
	Message string `json:"message"`
}

// MessageResponse represents the response from the message API.
type MessageResponse struct {
	ID           string    `json:"id"`
	Type         string    `json:"type"`
	Role         string    `json:"role"`
	Model        string    `json:"model"`
	Content      []Content `json:"content"`
	StopReason   string    `json:"stop_reason"` // one of the StopReason* constants
	StopSequence *string   `json:"stop_sequence,omitempty"`
	Usage        Usage     `json:"usage"`
	// StartTime and EndTime bracket the API call; SendMessage fills them in locally.
	StartTime *time.Time `json:"start_time,omitempty"`
	EndTime   *time.Time `json:"end_time,omitempty"`
}
186
187func (m *MessageResponse) ToMessage() Message {
188 return Message{
189 Role: m.Role,
190 Content: m.Content,
191 }
192}
193
194func (m *MessageResponse) StopSequenceString() string {
195 if m.StopSequence == nil {
196 return ""
197 }
198 return *m.StopSequence
199}
200
// Values for ToolChoice.Type.
const (
	ToolChoiceTypeAuto = "auto" // default
	ToolChoiceTypeAny  = "any"  // any tool, but must use one
	ToolChoiceTypeNone = "none" // no tools allowed
	ToolChoiceTypeTool = "tool" // must use the tool specified in the Name field
)

// ToolChoice is the tool_choice request parameter.
type ToolChoice struct {
	Type string `json:"type"`           // one of the ToolChoiceType* constants
	Name string `json:"name,omitempty"` // only meaningful with ToolChoiceTypeTool
}

// SystemContent is one block of the system prompt.
// https://docs.anthropic.com/en/api/messages#body-system
type SystemContent struct {
	Text         string          `json:"text,omitempty"`
	Type         string          `json:"type,omitempty"`
	CacheControl json.RawMessage `json:"cache_control,omitempty"`
}

// MessageRequest represents the request payload for creating a message.
// Optional fields are tagged omitempty. As a consequence, Temperature, TopK and TopP
// cannot be sent as an explicit 0.
type MessageRequest struct {
	Model      string      `json:"model"`
	Messages   []Message   `json:"messages"`
	ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
	MaxTokens  int         `json:"max_tokens"`
	Tools      []*Tool     `json:"tools,omitempty"`
	// NOTE(review): createMessage decodes the whole response body as a single
	// MessageResponse, so Stream=true is presumably unsupported there.
	Stream        bool            `json:"stream,omitempty"`
	System        []SystemContent `json:"system,omitempty"`
	Temperature   float64         `json:"temperature,omitempty"`
	TopK          int             `json:"top_k,omitempty"`
	TopP          float64         `json:"top_p,omitempty"`
	StopSequences []string        `json:"stop_sequences,omitempty"`

	// Not serialized into the body; createMessage turns it into an anthropic-beta header.
	TokenEfficientToolUse bool `json:"-"` // DO NOT USE, broken on Anthropic's side as of 2025-02-28
}

const dumpText = false // debugging toggle to see raw communications with Claude
238
239// createMessage sends a request to the Anthropic message API to create a message.
240func createMessage(ctx context.Context, httpc *http.Client, url, apiKey string, request *MessageRequest) (*MessageResponse, error) {
241 var payload []byte
242 var err error
243 if dumpText || testing.Testing() {
244 payload, err = json.MarshalIndent(request, "", " ")
245 } else {
246 payload, err = json.Marshal(request)
247 payload = append(payload, '\n')
248 }
249 if err != nil {
250 return nil, err
251 }
252
253 if false {
254 fmt.Printf("claude request payload:\n%s\n", payload)
255 }
256
257 backoff := []time.Duration{15 * time.Second, 30 * time.Second, time.Minute}
258 largerMaxTokens := false
259 var partialUsage Usage
260
261 // retry loop
262 for attempts := 0; ; attempts++ {
263 if dumpText {
264 fmt.Printf("RAW REQUEST:\n%s\n\n", payload)
265 }
266 req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(payload))
267 if err != nil {
268 return nil, err
269 }
270
271 req.Header.Set("Content-Type", "application/json")
272 req.Header.Set("X-API-Key", apiKey)
273 req.Header.Set("Anthropic-Version", "2023-06-01")
274
275 features := []string{}
276
277 if request.TokenEfficientToolUse {
278 features = append(features, "token-efficient-tool-use-2025-02-19")
279 }
280 if largerMaxTokens {
281 features = append(features, "output-128k-2025-02-19")
282 request.MaxTokens = 128 * 1024
283 }
284 if len(features) > 0 {
285 req.Header.Set("anthropic-beta", strings.Join(features, ","))
286 }
287
288 resp, err := httpc.Do(req)
289 if err != nil {
290 return nil, err
291 }
292 buf, _ := io.ReadAll(resp.Body)
293 resp.Body.Close()
294
295 switch {
296 case resp.StatusCode == http.StatusOK:
297 if dumpText {
298 fmt.Printf("RAW RESPONSE:\n%s\n\n", buf)
299 }
300 var response MessageResponse
301 err = json.NewDecoder(bytes.NewReader(buf)).Decode(&response)
302 if err != nil {
303 return nil, err
304 }
305 if response.StopReason == StopReasonMaxTokens && !largerMaxTokens {
306 fmt.Printf("Retrying Anthropic API call with larger max tokens size.")
307 // Retry with more output tokens.
308 largerMaxTokens = true
309 response.Usage.CostUSD = response.TotalDollars()
310 partialUsage = response.Usage
311 continue
312 }
313
314 // Calculate and set the cost_usd field
315 if largerMaxTokens {
316 response.Usage.Add(partialUsage)
317 }
318 response.Usage.CostUSD = response.TotalDollars()
319
320 return &response, nil
321 case resp.StatusCode >= 500 && resp.StatusCode < 600:
322 // overloaded or unhappy, in one form or another
323 sleep := backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
324 slog.WarnContext(ctx, "anthropic_request_failed", "response", string(buf), "status_code", resp.StatusCode, "sleep", sleep)
325 time.Sleep(sleep)
326 case resp.StatusCode == 429:
327 // rate limited. wait 1 minute as a starting point, because that's the rate limiting window.
328 // and then add some additional time for backoff.
329 sleep := time.Minute + backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
330 slog.WarnContext(ctx, "anthropic_request_rate_limited", "response", string(buf), "sleep", sleep)
331 // case resp.StatusCode == 400:
332 // TODO: parse ErrorResponse, make (*ErrorResponse) implement error
333 default:
334 return nil, fmt.Errorf("API request failed with status %s\n%s", resp.Status, buf)
335 }
336 }
337}
338
// A Convo is a managed conversation with Claude.
// It automatically manages the state of the conversation,
// including appending messages sent/received,
// calling tools and sending their results,
// tracking usage, etc.
//
// Exported fields must not be altered concurrently with calling any method on Convo.
// Typical usage is to configure a Convo once before using it.
type Convo struct {
	// ID identifies the conversation, e.g. in log attributes.
	// It comes from newConvoID, so uniqueness is best-effort only.
	ID string
	// Ctx is the context for the entire conversation.
	Ctx context.Context
	// HTTPC is the HTTP client for the conversation.
	HTTPC *http.Client
	// URL is the remote messages URL to dial.
	URL string
	// APIKey is the API key for the conversation.
	APIKey string
	// Model is the model for the conversation.
	Model string
	// MaxTokens is the max tokens for each response in the conversation.
	MaxTokens int
	// Tools are the tools available during the conversation.
	// SubConvo does not copy them.
	Tools []*Tool
	// SystemPrompt is the system prompt for the conversation.
	SystemPrompt string
	// PromptCaching indicates whether to use Anthropic's prompt caching.
	// See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#continuing-a-multi-turn-conversation
	// for the documentation. At request send time, we set the cache_control field on the
	// last message. We also cache the system prompt.
	// Default: true.
	PromptCaching bool
	// ToolUseOnly indicates whether Claude may only use tools during this conversation.
	// TODO: add more fine-grained control over tool use?
	ToolUseOnly bool
	// Parent is the parent conversation, if any.
	// It is non-nil for "subagent" calls.
	// It is set automatically when calling SubConvo,
	// and usually should not be set manually.
	Parent *Convo
	// Budget is the budget for this conversation (and all sub-conversations).
	// The Conversation DOES NOT automatically enforce the budget.
	// It is up to the caller to call OverBudget() as appropriate.
	Budget Budget

	// messages tracks the messages so far in the conversation.
	messages []Message

	// Listener is notified of outgoing requests, responses, and tool results.
	Listener Listener

	// muToolUseCancel guards toolUseCancel. toolUseCancel maps the ID of each
	// in-flight tool_use to the cancel function registered by newToolUseContext.
	muToolUseCancel *sync.Mutex
	toolUseCancel   map[string]context.CancelCauseFunc

	// Protects usage. This is used for subconversations (that share part of CumulativeUsage) as well.
	// SubConvo hands the parent's mu to the child, so one mutex covers the whole tree.
	mu *sync.Mutex
	// usage tracks usage for this conversation and all sub-conversations.
	usage *CumulativeUsage
}
399
400// newConvoID generates a new 8-byte random id.
401// The uniqueness/collision requirements here are very low.
402// They are not global identifiers,
403// just enough to distinguish different convos in a single session.
404func newConvoID() string {
405 u1 := rand.Uint32()
406 s := crock32.Encode(uint64(u1))
407 if len(s) < 7 {
408 s += strings.Repeat("0", 7-len(s))
409 }
410 return s[:3] + "-" + s[3:]
411}
412
413// NewConvo creates a new conversation with Claude with sensible defaults.
414// ctx is the context for the entire conversation.
415func NewConvo(ctx context.Context, apiKey string) *Convo {
416 id := newConvoID()
417 return &Convo{
418 Ctx: skribe.ContextWithAttr(ctx, slog.String("convo_id", id)),
419 HTTPC: http.DefaultClient,
420 URL: DefaultURL,
421 APIKey: apiKey,
422 Model: DefaultModel,
423 MaxTokens: DefaultMaxTokens,
424 PromptCaching: true,
425 usage: newUsage(),
426 Listener: &NoopListener{},
427 ID: id,
428 muToolUseCancel: &sync.Mutex{},
429 toolUseCancel: map[string]context.CancelCauseFunc{},
430 mu: &sync.Mutex{},
431 }
432}
433
434// SubConvo creates a sub-conversation with the same configuration as the parent conversation.
435// (This propagates context for cancellation, HTTP client, API key, etc.)
436// The sub-conversation shares no messages with the parent conversation.
437// It does not inherit tools from the parent conversation.
438func (c *Convo) SubConvo() *Convo {
439 id := newConvoID()
440 return &Convo{
441 Ctx: skribe.ContextWithAttr(c.Ctx, slog.String("convo_id", id), slog.String("parent_convo_id", c.ID)),
442 HTTPC: c.HTTPC,
Josh Bleecher Snyder6a50b182025-04-24 11:07:41 -0700443 URL: c.URL,
Earl Lee2e463fb2025-04-17 11:22:22 -0700444 APIKey: c.APIKey,
445 Model: c.Model,
446 MaxTokens: c.MaxTokens,
447 PromptCaching: c.PromptCaching,
448 Parent: c,
449 // For convenience, sub-convo usage shares tool uses map with parent,
450 // all other fields separate, propagated in AddResponse
451 usage: newUsageWithSharedToolUses(c.usage),
452 mu: c.mu,
453 Listener: c.Listener,
454 ID: id,
455 // Do not copy Budget. Each budget is independent,
456 // and OverBudget checks whether any ancestor is over budget.
457 }
458}
459
460// Depth reports how many "sub-conversations" deep this conversation is.
461// That it, it walks up parents until it finds a root.
462func (c *Convo) Depth() int {
463 x := c
464 var depth int
465 for x.Parent != nil {
466 x = x.Parent
467 depth++
468 }
469 return depth
470}
471
472// SendUserTextMessage sends a text message to Claude in this conversation.
473// otherContents contains additional contents to send with the message, usually tool results.
474func (c *Convo) SendUserTextMessage(s string, otherContents ...Content) (*MessageResponse, error) {
475 contents := slices.Clone(otherContents)
476 if s != "" {
477 contents = append(contents, Content{Type: ContentTypeText, Text: s})
478 }
479 msg := Message{
480 Role: MessageRoleUser,
481 Content: contents,
482 }
483 return c.SendMessage(msg)
484}
485
486func (c *Convo) messageRequest(msg Message) *MessageRequest {
487 system := []SystemContent{}
488 if c.SystemPrompt != "" {
489 var d SystemContent
490 d = SystemContent{Type: "text", Text: c.SystemPrompt}
491 if c.PromptCaching {
492 d.CacheControl = json.RawMessage(`{"type":"ephemeral"}`)
493 }
494 system = []SystemContent{d}
495 }
496
497 // Claude is happy to return an empty response in response to our Done() call,
498 // and, if so, you'll see something like:
499 // API request failed with status 400 Bad Request
500 // {"type":"error","error": {"type":"invalid_request_error",
501 // "message":"messages.5: all messages must have non-empty content except for the optional final assistant message"}}
502 // So, we filter out those empty messages.
503 var nonEmptyMessages []Message
504 for _, m := range c.messages {
505 if len(m.Content) > 0 {
506 nonEmptyMessages = append(nonEmptyMessages, m)
507 }
508 }
509
510 mr := &MessageRequest{
511 Model: c.Model,
512 Messages: append(nonEmptyMessages, msg), // not yet committed to keeping msg
513 System: system,
514 Tools: c.Tools,
515 MaxTokens: c.MaxTokens,
516 }
517 if c.ToolUseOnly {
518 mr.ToolChoice = &ToolChoice{Type: ToolChoiceTypeAny}
519 }
520 return mr
521}
522
523func (c *Convo) findTool(name string) (*Tool, error) {
524 for _, tool := range c.Tools {
525 if tool.Name == name {
526 return tool, nil
527 }
528 }
529 return nil, fmt.Errorf("tool %q not found", name)
530}
531
532// insertMissingToolResults adds error results for tool uses that were requested
533// but not included in the message, which can happen in error paths like "out of budget."
534// We only insert these if there were no tool responses at all, since an incorrect
535// number of tool results would be a programmer error. Mutates inputs.
536func (c *Convo) insertMissingToolResults(mr *MessageRequest, msg *Message) {
537 if len(mr.Messages) < 2 {
538 return
539 }
540 prev := mr.Messages[len(mr.Messages)-2]
541 var toolUsePrev int
542 for _, c := range prev.Content {
543 if c.Type == ContentTypeToolUse {
544 toolUsePrev++
545 }
546 }
547 if toolUsePrev == 0 {
548 return
549 }
550 var toolUseCurrent int
551 for _, c := range msg.Content {
552 if c.Type == ContentTypeToolResult {
553 toolUseCurrent++
554 }
555 }
556 if toolUseCurrent != 0 {
557 return
558 }
559 var prefix []Content
560 for _, part := range prev.Content {
561 if part.Type != ContentTypeToolUse {
562 continue
563 }
564 content := Content{
565 Type: ContentTypeToolResult,
566 ToolUseID: part.ID,
567 ToolError: true,
568 ToolResult: "not executed; retry possible",
569 }
570 prefix = append(prefix, content)
571 msg.Content = append(prefix, msg.Content...)
572 mr.Messages[len(mr.Messages)-1].Content = msg.Content
573 }
574 slog.DebugContext(c.Ctx, "inserted missing tool results")
575}
576
577// SendMessage sends a message to Claude.
578// The conversation records (internally) all messages succesfully sent and received.
579func (c *Convo) SendMessage(msg Message) (*MessageResponse, error) {
580 mr := c.messageRequest(msg)
581 var lastMessage *Message
582 if c.PromptCaching {
583 lastMessage = &mr.Messages[len(mr.Messages)-1]
584 if len(lastMessage.Content) > 0 {
585 lastMessage.Content[len(lastMessage.Content)-1].CacheControl = json.RawMessage(`{"type":"ephemeral"}`)
586 }
587 }
588 defer func() {
589 if lastMessage == nil {
590 return
591 }
592 if len(lastMessage.Content) > 0 {
593 lastMessage.Content[len(lastMessage.Content)-1].CacheControl = []byte{}
594 }
595 }()
596 c.insertMissingToolResults(mr, &msg)
597 c.Listener.OnRequest(c.Ctx, c, &msg)
598
599 startTime := time.Now()
600 resp, err := createMessage(c.Ctx, c.HTTPC, c.URL, c.APIKey, mr)
601 if resp != nil {
602 resp.StartTime = &startTime
603 endTime := time.Now()
604 resp.EndTime = &endTime
605 }
606
607 if err != nil {
608 return nil, err
609 }
610 c.messages = append(c.messages, msg, resp.ToMessage())
611 // Propagate usage to all ancestors (including us).
612 for x := c; x != nil; x = x.Parent {
613 x.usage.AddResponse(resp)
614 }
615 c.Listener.OnResponse(c.Ctx, c, resp)
616 return resp, err
617}
618
619type toolCallInfoKeyType string
620
621var toolCallInfoKey toolCallInfoKeyType
622
623type ToolCallInfo struct {
624 ToolUseID string
625 Convo *Convo
626}
627
628func ToolCallInfoFromContext(ctx context.Context) ToolCallInfo {
629 v := ctx.Value(toolCallInfoKey)
630 i, _ := v.(ToolCallInfo)
631 return i
632}
633
634func (c *Convo) ToolResultCancelContents(resp *MessageResponse) ([]Content, error) {
635 if resp.StopReason != StopReasonToolUse {
636 return nil, nil
637 }
638 var toolResults []Content
639
640 for _, part := range resp.Content {
641 if part.Type != ContentTypeToolUse {
642 continue
643 }
644 c.incrementToolUse(part.ToolName)
645
646 content := Content{
647 Type: ContentTypeToolResult,
648 ToolUseID: part.ID,
649 }
650
651 content.ToolError = true
652 content.ToolResult = "user canceled this too_use"
653 toolResults = append(toolResults, content)
654 }
655 return toolResults, nil
656}
657
658func (c *Convo) CancelToolUse(toolUseID string, err error) error {
659 c.muToolUseCancel.Lock()
660 defer c.muToolUseCancel.Unlock()
661 cancel, ok := c.toolUseCancel[toolUseID]
662 if !ok {
663 return fmt.Errorf("cannot cancel %s: no cancel function registered for this tool_use_id. All I have is %+v", toolUseID, c.toolUseCancel)
664 }
665 delete(c.toolUseCancel, toolUseID)
666 cancel(err)
667 return nil
668}
669
670func (c *Convo) newToolUseContext(ctx context.Context, toolUseID string) (context.Context, context.CancelFunc) {
671 c.muToolUseCancel.Lock()
672 defer c.muToolUseCancel.Unlock()
673 ctx, cancel := context.WithCancelCause(ctx)
674 c.toolUseCancel[toolUseID] = cancel
675 return ctx, func() { c.CancelToolUse(toolUseID, nil) }
676}
677
// ToolResultContents runs all tool uses requested by the response and returns their results.
// Cancelling ctx will cancel any running tool calls.
//
// Behavior:
//   - It returns (nil, nil) unless resp.StopReason is StopReasonToolUse.
//   - Each tool_use block runs in its own goroutine, so Tool.Run and
//     Listener.OnToolResult may be called concurrently.
//   - Results are gathered through a channel, so their order need not match
//     the order of the tool_use blocks in resp.
//   - A tool whose Run returns ErrDoNotRespond contributes no result.
//   - If ctx is done once every tool has finished, the gathered results are
//     discarded and ctx.Err() is returned.
func (c *Convo) ToolResultContents(ctx context.Context, resp *MessageResponse) ([]Content, error) {
	if resp.StopReason != StopReasonToolUse {
		return nil, nil
	}
	// Extract all tool calls from the response, call the tools, and gather the results.
	var wg sync.WaitGroup
	// Buffered to len(resp.Content): each goroutine sends at most one result, so
	// no send can block before wg.Wait returns.
	toolResultC := make(chan Content, len(resp.Content))
	for _, part := range resp.Content {
		if part.Type != ContentTypeToolUse {
			continue
		}
		c.incrementToolUse(part.ToolName)
		wg.Add(1)
		// NOTE(review): the goroutine captures the loop variable part, which relies
		// on Go 1.22+ per-iteration loop variables.
		go func() {
			defer wg.Done()

			// Record start time
			startTime := time.Now()

			content := Content{
				Type:      ContentTypeToolResult,
				ToolUseID: part.ID,
				StartTime: &startTime,
			}
			// sendErr reports err as an error tool_result.
			sendErr := func(err error) {
				// Record end time
				endTime := time.Now()
				content.EndTime = &endTime

				content.ToolError = true
				content.ToolResult = err.Error()
				c.Listener.OnToolResult(ctx, c, part.ToolName, part.ToolInput, content, nil, err)
				toolResultC <- content
			}
			// sendRes reports res as a successful tool_result.
			sendRes := func(res string) {
				// Record end time
				endTime := time.Now()
				content.EndTime = &endTime

				content.ToolResult = res
				c.Listener.OnToolResult(ctx, c, part.ToolName, part.ToolInput, content, &res, nil)
				toolResultC <- content
			}

			tool, err := c.findTool(part.ToolName)
			if err != nil {
				sendErr(err)
				return
			}
			// Create a new context for just this tool_use call, and register its
			// cancel function so that it can be canceled individually.
			toolUseCtx, cancel := c.newToolUseContext(ctx, part.ID)
			defer cancel()
			// TODO: move this into newToolUseContext?
			toolUseCtx = context.WithValue(toolUseCtx, toolCallInfoKey, ToolCallInfo{ToolUseID: part.ID, Convo: c})
			toolResult, err := tool.Run(toolUseCtx, part.ToolInput)
			if errors.Is(err, ErrDoNotRespond) {
				return
			}
			// Cancellation takes precedence over whatever Run returned: report the cause
			// (from CancelToolUse or the parent ctx) rather than Run's own error.
			if toolUseCtx.Err() != nil {
				sendErr(context.Cause(toolUseCtx))
				return
			}

			if err != nil {
				sendErr(err)
				return
			}
			sendRes(toolResult)
		}()
	}
	wg.Wait()
	close(toolResultC)
	var toolResults []Content
	for toolResult := range toolResultC {
		toolResults = append(toolResults, toolResult)
	}
	if ctx.Err() != nil {
		return nil, ctx.Err()
	}
	return toolResults, nil
}
762
763func (c *Convo) incrementToolUse(name string) {
764 c.mu.Lock()
765 defer c.mu.Unlock()
766
767 c.usage.ToolUses[name]++
768}
769
770// ContentsAttr returns contents as a slog.Attr.
771// It is meant for logging.
772func ContentsAttr(contents []Content) slog.Attr {
773 var contentAttrs []any // slog.Attr
774 for _, content := range contents {
775 var attrs []any // slog.Attr
776 switch content.Type {
777 case ContentTypeText:
778 attrs = append(attrs, slog.String("text", content.Text))
779 case ContentTypeToolUse:
780 attrs = append(attrs, slog.String("tool_name", content.ToolName))
781 attrs = append(attrs, slog.String("tool_input", string(content.ToolInput)))
782 case ContentTypeToolResult:
783 attrs = append(attrs, slog.String("tool_result", content.ToolResult))
784 attrs = append(attrs, slog.Bool("tool_error", content.ToolError))
785 case ContentTypeThinking:
786 attrs = append(attrs, slog.String("thinking", content.Text))
787 default:
788 attrs = append(attrs, slog.String("unknown_content_type", content.Type))
789 attrs = append(attrs, slog.Any("text", content)) // just log it all raw, better to have too much than not enough
790 }
791 contentAttrs = append(contentAttrs, slog.Group(content.ID, attrs...))
792 }
793 return slog.Group("contents", contentAttrs...)
794}
795
// MustSchema validates that schema is a valid JSON schema and returns it as a json.RawMessage.
// Surrounding whitespace is trimmed first. It panics if the schema is invalid.
// TODO: validate against JSON Schema; for now only JSON well-formedness is checked.
func MustSchema(schema string) json.RawMessage {
	trimmed := strings.TrimSpace(schema)
	raw := json.RawMessage(trimmed) // named "raw" so the imported bytes package is not shadowed
	if json.Valid(raw) {
		return raw
	}
	panic("invalid JSON schema: " + trimmed)
}
807
// centsPer1MTokens is the price of one million tokens of each kind, in US cents.
// (not dollars because i'm twitchy about using floats for money)
type centsPer1MTokens struct {
	Input         uint64
	Output        uint64
	CacheRead     uint64
	CacheCreation uint64
}

// modelCost maps a model ID (as reported in MessageResponse.Model) to its prices.
// TotalDollars panics for a model missing from this table, so add new models here.
// https://www.anthropic.com/pricing#anthropic-api
var modelCost = map[string]centsPer1MTokens{
	Claude37Sonnet: {
		Input:         300,  // $3
		Output:        1500, // $15
		CacheRead:     30,   // $0.30
		CacheCreation: 375,  // $3.75
	},
	Claude35Haiku: {
		Input:         80,  // $0.80
		Output:        400, // $4.00
		CacheRead:     8,   // $0.08
		CacheCreation: 100, // $1.00
	},
	Claude35Sonnet: {
		Input:         300,  // $3
		Output:        1500, // $15
		CacheRead:     30,   // $0.30
		CacheCreation: 375,  // $3.75
	},
}
838
839// TotalDollars returns the total cost to obtain this response, in dollars.
840func (mr *MessageResponse) TotalDollars() float64 {
841 cpm, ok := modelCost[mr.Model]
842 if !ok {
843 panic(fmt.Sprintf("no pricing info for model: %s", mr.Model))
844 }
845 use := mr.Usage
846 megaCents := use.InputTokens*cpm.Input +
847 use.OutputTokens*cpm.Output +
848 use.CacheReadInputTokens*cpm.CacheRead +
849 use.CacheCreationInputTokens*cpm.CacheCreation
850 cents := float64(megaCents) / 1_000_000.0
851 return cents / 100.0
852}
853
// newUsage returns an empty CumulativeUsage whose clock starts now.
func newUsage() *CumulativeUsage {
	u := new(CumulativeUsage)
	u.StartTime = time.Now()
	u.ToolUses = map[string]int{}
	return u
}

// newUsageWithSharedToolUses returns an empty CumulativeUsage whose clock starts now.
// It aliases parent's ToolUses map, so tool counts recorded on either side are
// visible to both. All other counters start at zero.
func newUsageWithSharedToolUses(parent *CumulativeUsage) *CumulativeUsage {
	u := new(CumulativeUsage)
	u.StartTime = time.Now()
	u.ToolUses = parent.ToolUses
	return u
}

// CumulativeUsage represents cumulative usage across a Convo, including all sub-conversations.
type CumulativeUsage struct {
	StartTime                time.Time      `json:"start_time"`
	Responses                uint64         `json:"messages"` // count of responses
	InputTokens              uint64         `json:"input_tokens"`
	OutputTokens             uint64         `json:"output_tokens"`
	CacheReadInputTokens     uint64         `json:"cache_read_input_tokens"`
	CacheCreationInputTokens uint64         `json:"cache_creation_input_tokens"`
	TotalCostUSD             float64        `json:"total_cost_usd"`
	ToolUses                 map[string]int `json:"tool_uses"` // tool name -> number of uses
}

// Clone returns a copy of u with its own ToolUses map, safe to read after the lock is released.
func (u *CumulativeUsage) Clone() CumulativeUsage {
	snapshot := *u
	snapshot.ToolUses = maps.Clone(snapshot.ToolUses)
	return snapshot
}
879
880func (c *Convo) CumulativeUsage() CumulativeUsage {
881 if c == nil {
882 return CumulativeUsage{}
883 }
884 c.mu.Lock()
885 defer c.mu.Unlock()
886 return c.usage.Clone()
887}
888
889func (u *CumulativeUsage) WallTime() time.Duration {
890 return time.Since(u.StartTime)
891}
892
893func (u *CumulativeUsage) DollarsPerHour() float64 {
894 hours := u.WallTime().Hours()
895 if hours == 0 {
896 return 0
897 }
898 return u.TotalCostUSD / hours
899}
900
901func (u *CumulativeUsage) AddResponse(resp *MessageResponse) {
902 usage := resp.Usage
903 u.Responses++
904 u.InputTokens += usage.InputTokens
905 u.OutputTokens += usage.OutputTokens
906 u.CacheReadInputTokens += usage.CacheReadInputTokens
907 u.CacheCreationInputTokens += usage.CacheCreationInputTokens
908 u.TotalCostUSD += resp.TotalDollars()
909}
910
Josh Bleecher Snyder35889972025-04-24 20:48:16 +0000911// TotalInputTokens returns the grand total cumulative input tokens in u.
912func (u *CumulativeUsage) TotalInputTokens() uint64 {
913 return u.InputTokens + u.CacheReadInputTokens + u.CacheCreationInputTokens
914}
915
Earl Lee2e463fb2025-04-17 11:22:22 -0700916// Attr returns the cumulative usage as a slog.Attr with key "usage".
917func (u CumulativeUsage) Attr() slog.Attr {
918 elapsed := time.Since(u.StartTime)
919 return slog.Group("usage",
920 slog.Duration("wall_time", elapsed),
921 slog.Uint64("responses", u.Responses),
922 slog.Uint64("input_tokens", u.InputTokens),
923 slog.Uint64("output_tokens", u.OutputTokens),
924 slog.Uint64("cache_read_input_tokens", u.CacheReadInputTokens),
925 slog.Uint64("cache_creation_input_tokens", u.CacheCreationInputTokens),
926 slog.Float64("total_cost_usd", u.TotalCostUSD),
927 slog.Float64("dollars_per_hour", u.TotalCostUSD/elapsed.Hours()),
928 slog.Any("tool_uses", maps.Clone(u.ToolUses)),
929 )
930}
931
932// A Budget represents the maximum amount of resources that may be spent on a conversation.
933// Note that the default (zero) budget is unlimited.
934type Budget struct {
935 MaxResponses uint64 // if > 0, max number of iterations (=responses)
936 MaxDollars float64 // if > 0, max dollars that may be spent
937 MaxWallTime time.Duration // if > 0, max wall time that may be spent
938}
939
940// OverBudget returns an error if the convo (or any of its parents) has exceeded its budget.
941// TODO: document parent vs sub budgets, multiple errors, etc, once we know the desired behavior.
942func (c *Convo) OverBudget() error {
943 for x := c; x != nil; x = x.Parent {
944 if err := x.overBudget(); err != nil {
945 return err
946 }
947 }
948 return nil
949}
950
951// ResetBudget sets the budget to the passed in budget and
952// adjusts it by what's been used so far.
953func (c *Convo) ResetBudget(budget Budget) {
954 c.Budget = budget
955 if c.Budget.MaxDollars > 0 {
956 c.Budget.MaxDollars += c.CumulativeUsage().TotalCostUSD
957 }
958 if c.Budget.MaxResponses > 0 {
959 c.Budget.MaxResponses += c.CumulativeUsage().Responses
960 }
961 if c.Budget.MaxWallTime > 0 {
962 c.Budget.MaxWallTime += c.usage.WallTime()
963 }
964}
965
966func (c *Convo) overBudget() error {
967 usage := c.CumulativeUsage()
968 // TODO: stop before we exceed the budget instead of after?
969 // Top priority is money, then time, then response count.
970 var err error
971 cont := "Continuing to chat will reset the budget."
972 if c.Budget.MaxDollars > 0 && usage.TotalCostUSD >= c.Budget.MaxDollars {
973 err = errors.Join(err, fmt.Errorf("$%.2f spent, budget is $%.2f. %s", usage.TotalCostUSD, c.Budget.MaxDollars, cont))
974 }
975 if c.Budget.MaxWallTime > 0 && usage.WallTime() >= c.Budget.MaxWallTime {
976 err = errors.Join(err, fmt.Errorf("%v elapsed, budget is %v. %s", usage.WallTime().Truncate(time.Second), c.Budget.MaxWallTime.Truncate(time.Second), cont))
977 }
978 if c.Budget.MaxResponses > 0 && usage.Responses >= c.Budget.MaxResponses {
979 err = errors.Join(err, fmt.Errorf("%d responses received, budget is %d. %s", usage.Responses, c.Budget.MaxResponses, cont))
980 }
981 return err
982}