blob: 69c17c25c721afac1028bbd51907616cc6f1f451 [file] [log] [blame]
Earl Lee2e463fb2025-04-17 11:22:22 -07001package ant
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "log/slog"
11 "maps"
12 "math/rand/v2"
13 "net/http"
14 "slices"
15 "strings"
16 "sync"
17 "testing"
18 "time"
19
Philip Zeyliger99a9a022025-04-27 15:15:25 +000020 "github.com/oklog/ulid/v2"
Earl Lee2e463fb2025-04-17 11:22:22 -070021 "github.com/richardlehane/crock32"
22 "sketch.dev/skribe"
23)
24
25const (
26 DefaultModel = Claude37Sonnet
27 // See https://docs.anthropic.com/en/docs/about-claude/models/all-models for
28 // current maximums. There's currently a flag to enable 128k output (output-128k-2025-02-19)
29 DefaultMaxTokens = 8192
30 DefaultURL = "https://api.anthropic.com/v1/messages"
31)
32
// Model identifiers, as sent in the "model" field of a request.
const (
	Claude35Sonnet = "claude-3-5-sonnet-20241022"
	Claude35Haiku  = "claude-3-5-haiku-20241022"
	Claude37Sonnet = "claude-3-7-sonnet-20250219"
)
38
// String values used on the wire for message roles, content block types
// and stop reasons.
const (
	// Message roles.
	MessageRoleUser      = "user"
	MessageRoleAssistant = "assistant"

	// Content block types (Content.Type).
	ContentTypeText             = "text"
	ContentTypeThinking         = "thinking"
	ContentTypeRedactedThinking = "redacted_thinking"
	ContentTypeToolUse          = "tool_use"
	ContentTypeToolResult       = "tool_result"

	// Stop reasons (MessageResponse.StopReason).
	StopReasonStopSequence = "stop_sequence"
	StopReasonMaxTokens    = "max_tokens"
	StopReasonEndTurn      = "end_turn"
	StopReasonToolUse      = "tool_use"
)
54
55type Listener interface {
56 // TODO: Content is leaking an anthropic API; should we avoid it?
57 // TODO: Where should we include start/end time and usage?
Philip Zeyliger99a9a022025-04-27 15:15:25 +000058 OnToolCall(ctx context.Context, convo *Convo, toolCallID string, toolName string, toolInput json.RawMessage, content Content)
59 OnToolResult(ctx context.Context, convo *Convo, toolCallID string, toolName string, toolInput json.RawMessage, content Content, result *string, err error)
60 OnRequest(ctx context.Context, convo *Convo, requestID string, msg *Message)
61 OnResponse(ctx context.Context, convo *Convo, requestID string, msg *MessageResponse)
Earl Lee2e463fb2025-04-17 11:22:22 -070062}
63
64type NoopListener struct{}
65
Philip Zeyliger99a9a022025-04-27 15:15:25 +000066func (n *NoopListener) OnToolCall(ctx context.Context, convo *Convo, id string, toolName string, toolInput json.RawMessage, content Content) {
Earl Lee2e463fb2025-04-17 11:22:22 -070067}
Philip Zeyliger99a9a022025-04-27 15:15:25 +000068
69func (n *NoopListener) OnToolResult(ctx context.Context, convo *Convo, id string, toolName string, toolInput json.RawMessage, content Content, result *string, err error) {
70}
71
72func (n *NoopListener) OnResponse(ctx context.Context, convo *Convo, id string, msg *MessageResponse) {
73}
74func (n *NoopListener) OnRequest(ctx context.Context, convo *Convo, id string, msg *Message) {}
Earl Lee2e463fb2025-04-17 11:22:22 -070075
// Content is a single content block of a message. Which fields are
// meaningful depends on Type; unused fields are omitted from the JSON.
type Content struct {
	// TODO: image support?
	// https://docs.anthropic.com/en/api/messages
	ID   string `json:"id,omitempty"`
	Type string `json:"type,omitempty"`
	Text string `json:"text,omitempty"`

	// Fields for thinking / redacted_thinking blocks.
	Thinking  string `json:"thinking,omitempty"`
	Data      string `json:"data,omitempty"`      // for redacted_thinking
	Signature string `json:"signature,omitempty"` // for thinking

	// Fields for tool_use blocks.
	ToolName  string          `json:"name,omitempty"`
	ToolInput json.RawMessage `json:"input,omitempty"`

	// Fields for tool_result blocks.
	ToolUseID  string `json:"tool_use_id,omitempty"`
	ToolError  bool   `json:"is_error,omitempty"`
	ToolResult string `json:"content,omitempty"`

	// Timing information for tool_result; never serialized, so not sent to Claude.
	StartTime *time.Time `json:"-"`
	EndTime   *time.Time `json:"-"`

	CacheControl json.RawMessage `json:"cache_control,omitempty"`
}
103
104func StringContent(s string) Content {
105 return Content{Type: ContentTypeText, Text: s}
106}
107
108// Message represents a message in the conversation.
109type Message struct {
110 Role string `json:"role"`
111 Content []Content `json:"content"`
112 ToolUse *ToolUse `json:"tool_use,omitempty"` // use to control whether/which tool to use
113}
114
// ToolUse identifies a tool use (by ID and tool name) in the message content.
type ToolUse struct {
	ID   string `json:"id"`
	Name string `json:"name"`
}
120
// Tool describes a tool offered to Claude, together with the function
// that implements it locally.
type Tool struct {
	Name string `json:"name"`
	// Type is used by the text editor tool; see
	// https://docs.anthropic.com/en/docs/build-with-claude/tool-use/text-editor-tool
	Type        string          `json:"type,omitempty"`
	Description string          `json:"description,omitempty"`
	InputSchema json.RawMessage `json:"input_schema,omitempty"`

	// Run is called automatically when the tool is used. It is never serialized.
	// Run functions may be called concurrently with each other and themselves.
	// input is the tool input as provided by Claude, in compliance with the input schema.
	// The outputs from Run will be sent back to Claude.
	// To leave a tool call request from Claude unanswered, return ErrDoNotRespond.
	// ctx contains extra (rarely used) tool call information; retrieve it with ToolCallInfoFromContext.
	Run func(ctx context.Context, input json.RawMessage) (string, error) `json:"-"`
}
138
// ErrDoNotRespond is returned by a Tool's Run function to signal that no
// tool result should be produced for that call.
var ErrDoNotRespond = errors.New("do not respond")
140
// Usage holds the token counts and dollar cost reported for a response.
type Usage struct {
	InputTokens              uint64  `json:"input_tokens"`
	CacheCreationInputTokens uint64  `json:"cache_creation_input_tokens"`
	CacheReadInputTokens     uint64  `json:"cache_read_input_tokens"`
	OutputTokens             uint64  `json:"output_tokens"`
	CostUSD                  float64 `json:"cost_usd"`
}

// Add accumulates every counter of other, including CostUSD, into u.
func (u *Usage) Add(other Usage) {
	u.InputTokens = u.InputTokens + other.InputTokens
	u.CacheCreationInputTokens = u.CacheCreationInputTokens + other.CacheCreationInputTokens
	u.CacheReadInputTokens = u.CacheReadInputTokens + other.CacheReadInputTokens
	u.OutputTokens = u.OutputTokens + other.OutputTokens
	u.CostUSD = u.CostUSD + other.CostUSD
}

// String gives a short summary of the plain input and output token counts.
func (u *Usage) String() string {
	in, out := u.InputTokens, u.OutputTokens
	return fmt.Sprintf("in: %d, out: %d", in, out)
}

// IsZero reports whether every field of u has its zero value.
func (u *Usage) IsZero() bool {
	var zero Usage
	return *u == zero
}

// Attr returns the token counts as a slog group with key "usage".
// The cost is not part of the group.
func (u *Usage) Attr() slog.Attr {
	fields := []any{
		slog.Uint64("input_tokens", u.InputTokens),
		slog.Uint64("output_tokens", u.OutputTokens),
		slog.Uint64("cache_creation_input_tokens", u.CacheCreationInputTokens),
		slog.Uint64("cache_read_input_tokens", u.CacheReadInputTokens),
	}
	return slog.Group("usage", fields...)
}
174
// ErrorResponse mirrors the type/message pair of an API error object.
type ErrorResponse struct {
	Type    string `json:"type"`
	Message string `json:"message"`
}
179
180// MessageResponse represents the response from the message API.
181type MessageResponse struct {
182 ID string `json:"id"`
183 Type string `json:"type"`
184 Role string `json:"role"`
185 Model string `json:"model"`
186 Content []Content `json:"content"`
187 StopReason string `json:"stop_reason"`
188 StopSequence *string `json:"stop_sequence,omitempty"`
189 Usage Usage `json:"usage"`
190 StartTime *time.Time `json:"start_time,omitempty"`
191 EndTime *time.Time `json:"end_time,omitempty"`
192}
193
194func (m *MessageResponse) ToMessage() Message {
195 return Message{
196 Role: m.Role,
197 Content: m.Content,
198 }
199}
200
201func (m *MessageResponse) StopSequenceString() string {
202 if m.StopSequence == nil {
203 return ""
204 }
205 return *m.StopSequence
206}
207
// Values for ToolChoice.Type.
const (
	ToolChoiceTypeAuto = "auto" // default
	ToolChoiceTypeAny  = "any"  // any tool, but must use one
	ToolChoiceTypeNone = "none" // no tools allowed
	ToolChoiceTypeTool = "tool" // must use the tool specified in the Name field
)
214
// ToolChoice is the "tool_choice" field of a request. Type is one of the
// ToolChoiceType* values; Name is left out of the JSON when empty.
type ToolChoice struct {
	Type string `json:"type"`
	Name string `json:"name,omitempty"`
}
219
// SystemContent is one block of the request's "system" field.
// https://docs.anthropic.com/en/api/messages#body-system
type SystemContent struct {
	Text         string          `json:"text,omitempty"`
	Type         string          `json:"type,omitempty"`
	CacheControl json.RawMessage `json:"cache_control,omitempty"`
}
226
227// MessageRequest represents the request payload for creating a message.
228type MessageRequest struct {
229 Model string `json:"model"`
230 Messages []Message `json:"messages"`
231 ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
232 MaxTokens int `json:"max_tokens"`
233 Tools []*Tool `json:"tools,omitempty"`
234 Stream bool `json:"stream,omitempty"`
235 System []SystemContent `json:"system,omitempty"`
236 Temperature float64 `json:"temperature,omitempty"`
237 TopK int `json:"top_k,omitempty"`
238 TopP float64 `json:"top_p,omitempty"`
239 StopSequences []string `json:"stop_sequences,omitempty"`
240
241 TokenEfficientToolUse bool `json:"-"` // DO NOT USE, broken on Anthropic's side as of 2025-02-28
242}
243
// dumpText is a debugging toggle: when true, the raw request and response
// bodies exchanged with Claude are printed to stdout.
const dumpText = false
245
246// createMessage sends a request to the Anthropic message API to create a message.
247func createMessage(ctx context.Context, httpc *http.Client, url, apiKey string, request *MessageRequest) (*MessageResponse, error) {
248 var payload []byte
249 var err error
250 if dumpText || testing.Testing() {
251 payload, err = json.MarshalIndent(request, "", " ")
252 } else {
253 payload, err = json.Marshal(request)
254 payload = append(payload, '\n')
255 }
256 if err != nil {
257 return nil, err
258 }
259
260 if false {
261 fmt.Printf("claude request payload:\n%s\n", payload)
262 }
263
264 backoff := []time.Duration{15 * time.Second, 30 * time.Second, time.Minute}
265 largerMaxTokens := false
266 var partialUsage Usage
267
268 // retry loop
269 for attempts := 0; ; attempts++ {
270 if dumpText {
271 fmt.Printf("RAW REQUEST:\n%s\n\n", payload)
272 }
273 req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(payload))
274 if err != nil {
275 return nil, err
276 }
277
278 req.Header.Set("Content-Type", "application/json")
279 req.Header.Set("X-API-Key", apiKey)
280 req.Header.Set("Anthropic-Version", "2023-06-01")
281
282 features := []string{}
283
284 if request.TokenEfficientToolUse {
285 features = append(features, "token-efficient-tool-use-2025-02-19")
286 }
287 if largerMaxTokens {
288 features = append(features, "output-128k-2025-02-19")
289 request.MaxTokens = 128 * 1024
290 }
291 if len(features) > 0 {
292 req.Header.Set("anthropic-beta", strings.Join(features, ","))
293 }
294
295 resp, err := httpc.Do(req)
296 if err != nil {
297 return nil, err
298 }
299 buf, _ := io.ReadAll(resp.Body)
300 resp.Body.Close()
301
302 switch {
303 case resp.StatusCode == http.StatusOK:
304 if dumpText {
305 fmt.Printf("RAW RESPONSE:\n%s\n\n", buf)
306 }
307 var response MessageResponse
308 err = json.NewDecoder(bytes.NewReader(buf)).Decode(&response)
309 if err != nil {
310 return nil, err
311 }
312 if response.StopReason == StopReasonMaxTokens && !largerMaxTokens {
313 fmt.Printf("Retrying Anthropic API call with larger max tokens size.")
314 // Retry with more output tokens.
315 largerMaxTokens = true
316 response.Usage.CostUSD = response.TotalDollars()
317 partialUsage = response.Usage
318 continue
319 }
320
321 // Calculate and set the cost_usd field
322 if largerMaxTokens {
323 response.Usage.Add(partialUsage)
324 }
325 response.Usage.CostUSD = response.TotalDollars()
326
327 return &response, nil
328 case resp.StatusCode >= 500 && resp.StatusCode < 600:
329 // overloaded or unhappy, in one form or another
330 sleep := backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
331 slog.WarnContext(ctx, "anthropic_request_failed", "response", string(buf), "status_code", resp.StatusCode, "sleep", sleep)
332 time.Sleep(sleep)
333 case resp.StatusCode == 429:
334 // rate limited. wait 1 minute as a starting point, because that's the rate limiting window.
335 // and then add some additional time for backoff.
336 sleep := time.Minute + backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
337 slog.WarnContext(ctx, "anthropic_request_rate_limited", "response", string(buf), "sleep", sleep)
338 // case resp.StatusCode == 400:
339 // TODO: parse ErrorResponse, make (*ErrorResponse) implement error
340 default:
341 return nil, fmt.Errorf("API request failed with status %s\n%s", resp.Status, buf)
342 }
343 }
344}
345
346// A Convo is a managed conversation with Claude.
347// It automatically manages the state of the conversation,
348// including appending messages send/received,
349// calling tools and sending their results,
350// tracking usage, etc.
351//
352// Exported fields must not be altered concurrently with calling any method on Convo.
353// Typical usage is to configure a Convo once before using it.
354type Convo struct {
355 // ID is a unique ID for the conversation
356 ID string
357 // Ctx is the context for the entire conversation.
358 Ctx context.Context
359 // HTTPC is the HTTP client for the conversation.
360 HTTPC *http.Client
361 // URL is the remote messages URL to dial.
362 URL string
363 // APIKey is the API key for the conversation.
364 APIKey string
365 // Model is the model for the conversation.
366 Model string
367 // MaxTokens is the max tokens for each response in the conversation.
368 MaxTokens int
369 // Tools are the tools available during the conversation.
370 Tools []*Tool
371 // SystemPrompt is the system prompt for the conversation.
372 SystemPrompt string
373 // PromptCaching indicates whether to use Anthropic's prompt caching.
374 // See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#continuing-a-multi-turn-conversation
375 // for the documentation. At request send time, we set the cache_control field on the
376 // last message. We also cache the system prompt.
377 // Default: true.
378 PromptCaching bool
379 // ToolUseOnly indicates whether Claude may only use tools during this conversation.
380 // TODO: add more fine-grained control over tool use?
381 ToolUseOnly bool
382 // Parent is the parent conversation, if any.
383 // It is non-nil for "subagent" calls.
384 // It is set automatically when calling SubConvo,
385 // and usually should not be set manually.
386 Parent *Convo
387 // Budget is the budget for this conversation (and all sub-conversations).
388 // The Conversation DOES NOT automatically enforce the budget.
389 // It is up to the caller to call OverBudget() as appropriate.
390 Budget Budget
391
392 // messages tracks the messages so far in the conversation.
393 messages []Message
394
395 // Listener receives messages being sent.
396 Listener Listener
397
398 muToolUseCancel *sync.Mutex
399 toolUseCancel map[string]context.CancelCauseFunc
400
401 // Protects usage. This is used for subconversations (that share part of CumulativeUsage) as well.
402 mu *sync.Mutex
403 // usage tracks usage for this conversation and all sub-conversations.
404 usage *CumulativeUsage
405}
406
407// newConvoID generates a new 8-byte random id.
408// The uniqueness/collision requirements here are very low.
409// They are not global identifiers,
410// just enough to distinguish different convos in a single session.
411func newConvoID() string {
412 u1 := rand.Uint32()
413 s := crock32.Encode(uint64(u1))
414 if len(s) < 7 {
415 s += strings.Repeat("0", 7-len(s))
416 }
417 return s[:3] + "-" + s[3:]
418}
419
420// NewConvo creates a new conversation with Claude with sensible defaults.
421// ctx is the context for the entire conversation.
422func NewConvo(ctx context.Context, apiKey string) *Convo {
423 id := newConvoID()
424 return &Convo{
425 Ctx: skribe.ContextWithAttr(ctx, slog.String("convo_id", id)),
426 HTTPC: http.DefaultClient,
427 URL: DefaultURL,
428 APIKey: apiKey,
429 Model: DefaultModel,
430 MaxTokens: DefaultMaxTokens,
431 PromptCaching: true,
432 usage: newUsage(),
433 Listener: &NoopListener{},
434 ID: id,
435 muToolUseCancel: &sync.Mutex{},
436 toolUseCancel: map[string]context.CancelCauseFunc{},
437 mu: &sync.Mutex{},
438 }
439}
440
441// SubConvo creates a sub-conversation with the same configuration as the parent conversation.
442// (This propagates context for cancellation, HTTP client, API key, etc.)
443// The sub-conversation shares no messages with the parent conversation.
444// It does not inherit tools from the parent conversation.
445func (c *Convo) SubConvo() *Convo {
446 id := newConvoID()
447 return &Convo{
448 Ctx: skribe.ContextWithAttr(c.Ctx, slog.String("convo_id", id), slog.String("parent_convo_id", c.ID)),
449 HTTPC: c.HTTPC,
Josh Bleecher Snyder6a50b182025-04-24 11:07:41 -0700450 URL: c.URL,
Earl Lee2e463fb2025-04-17 11:22:22 -0700451 APIKey: c.APIKey,
452 Model: c.Model,
453 MaxTokens: c.MaxTokens,
454 PromptCaching: c.PromptCaching,
455 Parent: c,
456 // For convenience, sub-convo usage shares tool uses map with parent,
457 // all other fields separate, propagated in AddResponse
458 usage: newUsageWithSharedToolUses(c.usage),
459 mu: c.mu,
460 Listener: c.Listener,
461 ID: id,
462 // Do not copy Budget. Each budget is independent,
463 // and OverBudget checks whether any ancestor is over budget.
464 }
465}
466
467// Depth reports how many "sub-conversations" deep this conversation is.
468// That it, it walks up parents until it finds a root.
469func (c *Convo) Depth() int {
470 x := c
471 var depth int
472 for x.Parent != nil {
473 x = x.Parent
474 depth++
475 }
476 return depth
477}
478
479// SendUserTextMessage sends a text message to Claude in this conversation.
480// otherContents contains additional contents to send with the message, usually tool results.
481func (c *Convo) SendUserTextMessage(s string, otherContents ...Content) (*MessageResponse, error) {
482 contents := slices.Clone(otherContents)
483 if s != "" {
484 contents = append(contents, Content{Type: ContentTypeText, Text: s})
485 }
486 msg := Message{
487 Role: MessageRoleUser,
488 Content: contents,
489 }
490 return c.SendMessage(msg)
491}
492
493func (c *Convo) messageRequest(msg Message) *MessageRequest {
494 system := []SystemContent{}
495 if c.SystemPrompt != "" {
496 var d SystemContent
497 d = SystemContent{Type: "text", Text: c.SystemPrompt}
498 if c.PromptCaching {
499 d.CacheControl = json.RawMessage(`{"type":"ephemeral"}`)
500 }
501 system = []SystemContent{d}
502 }
503
504 // Claude is happy to return an empty response in response to our Done() call,
505 // and, if so, you'll see something like:
506 // API request failed with status 400 Bad Request
507 // {"type":"error","error": {"type":"invalid_request_error",
508 // "message":"messages.5: all messages must have non-empty content except for the optional final assistant message"}}
509 // So, we filter out those empty messages.
510 var nonEmptyMessages []Message
511 for _, m := range c.messages {
512 if len(m.Content) > 0 {
513 nonEmptyMessages = append(nonEmptyMessages, m)
514 }
515 }
516
517 mr := &MessageRequest{
518 Model: c.Model,
519 Messages: append(nonEmptyMessages, msg), // not yet committed to keeping msg
520 System: system,
521 Tools: c.Tools,
522 MaxTokens: c.MaxTokens,
523 }
524 if c.ToolUseOnly {
525 mr.ToolChoice = &ToolChoice{Type: ToolChoiceTypeAny}
526 }
527 return mr
528}
529
530func (c *Convo) findTool(name string) (*Tool, error) {
531 for _, tool := range c.Tools {
532 if tool.Name == name {
533 return tool, nil
534 }
535 }
536 return nil, fmt.Errorf("tool %q not found", name)
537}
538
539// insertMissingToolResults adds error results for tool uses that were requested
540// but not included in the message, which can happen in error paths like "out of budget."
541// We only insert these if there were no tool responses at all, since an incorrect
542// number of tool results would be a programmer error. Mutates inputs.
543func (c *Convo) insertMissingToolResults(mr *MessageRequest, msg *Message) {
544 if len(mr.Messages) < 2 {
545 return
546 }
547 prev := mr.Messages[len(mr.Messages)-2]
548 var toolUsePrev int
549 for _, c := range prev.Content {
550 if c.Type == ContentTypeToolUse {
551 toolUsePrev++
552 }
553 }
554 if toolUsePrev == 0 {
555 return
556 }
557 var toolUseCurrent int
558 for _, c := range msg.Content {
559 if c.Type == ContentTypeToolResult {
560 toolUseCurrent++
561 }
562 }
563 if toolUseCurrent != 0 {
564 return
565 }
566 var prefix []Content
567 for _, part := range prev.Content {
568 if part.Type != ContentTypeToolUse {
569 continue
570 }
571 content := Content{
572 Type: ContentTypeToolResult,
573 ToolUseID: part.ID,
574 ToolError: true,
575 ToolResult: "not executed; retry possible",
576 }
577 prefix = append(prefix, content)
578 msg.Content = append(prefix, msg.Content...)
579 mr.Messages[len(mr.Messages)-1].Content = msg.Content
580 }
581 slog.DebugContext(c.Ctx, "inserted missing tool results")
582}
583
584// SendMessage sends a message to Claude.
585// The conversation records (internally) all messages succesfully sent and received.
586func (c *Convo) SendMessage(msg Message) (*MessageResponse, error) {
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000587 id := ulid.Make().String()
Earl Lee2e463fb2025-04-17 11:22:22 -0700588 mr := c.messageRequest(msg)
589 var lastMessage *Message
590 if c.PromptCaching {
591 lastMessage = &mr.Messages[len(mr.Messages)-1]
592 if len(lastMessage.Content) > 0 {
593 lastMessage.Content[len(lastMessage.Content)-1].CacheControl = json.RawMessage(`{"type":"ephemeral"}`)
594 }
595 }
596 defer func() {
597 if lastMessage == nil {
598 return
599 }
600 if len(lastMessage.Content) > 0 {
601 lastMessage.Content[len(lastMessage.Content)-1].CacheControl = []byte{}
602 }
603 }()
604 c.insertMissingToolResults(mr, &msg)
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000605 c.Listener.OnRequest(c.Ctx, c, id, &msg)
Earl Lee2e463fb2025-04-17 11:22:22 -0700606
607 startTime := time.Now()
608 resp, err := createMessage(c.Ctx, c.HTTPC, c.URL, c.APIKey, mr)
609 if resp != nil {
610 resp.StartTime = &startTime
611 endTime := time.Now()
612 resp.EndTime = &endTime
613 }
614
615 if err != nil {
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000616 c.Listener.OnResponse(c.Ctx, c, id, nil)
Earl Lee2e463fb2025-04-17 11:22:22 -0700617 return nil, err
618 }
619 c.messages = append(c.messages, msg, resp.ToMessage())
620 // Propagate usage to all ancestors (including us).
621 for x := c; x != nil; x = x.Parent {
622 x.usage.AddResponse(resp)
623 }
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000624 c.Listener.OnResponse(c.Ctx, c, id, resp)
Earl Lee2e463fb2025-04-17 11:22:22 -0700625 return resp, err
626}
627
628type toolCallInfoKeyType string
629
630var toolCallInfoKey toolCallInfoKeyType
631
632type ToolCallInfo struct {
633 ToolUseID string
634 Convo *Convo
635}
636
637func ToolCallInfoFromContext(ctx context.Context) ToolCallInfo {
638 v := ctx.Value(toolCallInfoKey)
639 i, _ := v.(ToolCallInfo)
640 return i
641}
642
643func (c *Convo) ToolResultCancelContents(resp *MessageResponse) ([]Content, error) {
644 if resp.StopReason != StopReasonToolUse {
645 return nil, nil
646 }
647 var toolResults []Content
648
649 for _, part := range resp.Content {
650 if part.Type != ContentTypeToolUse {
651 continue
652 }
653 c.incrementToolUse(part.ToolName)
654
655 content := Content{
656 Type: ContentTypeToolResult,
657 ToolUseID: part.ID,
658 }
659
660 content.ToolError = true
661 content.ToolResult = "user canceled this too_use"
662 toolResults = append(toolResults, content)
663 }
664 return toolResults, nil
665}
666
667func (c *Convo) CancelToolUse(toolUseID string, err error) error {
668 c.muToolUseCancel.Lock()
669 defer c.muToolUseCancel.Unlock()
670 cancel, ok := c.toolUseCancel[toolUseID]
671 if !ok {
672 return fmt.Errorf("cannot cancel %s: no cancel function registered for this tool_use_id. All I have is %+v", toolUseID, c.toolUseCancel)
673 }
674 delete(c.toolUseCancel, toolUseID)
675 cancel(err)
676 return nil
677}
678
679func (c *Convo) newToolUseContext(ctx context.Context, toolUseID string) (context.Context, context.CancelFunc) {
680 c.muToolUseCancel.Lock()
681 defer c.muToolUseCancel.Unlock()
682 ctx, cancel := context.WithCancelCause(ctx)
683 c.toolUseCancel[toolUseID] = cancel
684 return ctx, func() { c.CancelToolUse(toolUseID, nil) }
685}
686
687// ToolResultContents runs all tool uses requested by the response and returns their results.
688// Cancelling ctx will cancel any running tool calls.
689func (c *Convo) ToolResultContents(ctx context.Context, resp *MessageResponse) ([]Content, error) {
690 if resp.StopReason != StopReasonToolUse {
691 return nil, nil
692 }
693 // Extract all tool calls from the response, call the tools, and gather the results.
694 var wg sync.WaitGroup
695 toolResultC := make(chan Content, len(resp.Content))
696 for _, part := range resp.Content {
697 if part.Type != ContentTypeToolUse {
698 continue
699 }
700 c.incrementToolUse(part.ToolName)
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000701 startTime := time.Now()
702
703 c.Listener.OnToolCall(ctx, c, part.ID, part.ToolName, part.ToolInput, Content{
704 Type: ContentTypeToolUse,
705 ToolUseID: part.ID,
706 StartTime: &startTime,
707 })
708
Earl Lee2e463fb2025-04-17 11:22:22 -0700709 wg.Add(1)
710 go func() {
711 defer wg.Done()
712
Earl Lee2e463fb2025-04-17 11:22:22 -0700713 content := Content{
714 Type: ContentTypeToolResult,
715 ToolUseID: part.ID,
716 StartTime: &startTime,
717 }
718 sendErr := func(err error) {
719 // Record end time
720 endTime := time.Now()
721 content.EndTime = &endTime
722
723 content.ToolError = true
724 content.ToolResult = err.Error()
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000725 c.Listener.OnToolResult(ctx, c, part.ID, part.ToolName, part.ToolInput, content, nil, err)
Earl Lee2e463fb2025-04-17 11:22:22 -0700726 toolResultC <- content
727 }
728 sendRes := func(res string) {
729 // Record end time
730 endTime := time.Now()
731 content.EndTime = &endTime
732
733 content.ToolResult = res
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000734 c.Listener.OnToolResult(ctx, c, part.ID, part.ToolName, part.ToolInput, content, &res, nil)
Earl Lee2e463fb2025-04-17 11:22:22 -0700735 toolResultC <- content
736 }
737
738 tool, err := c.findTool(part.ToolName)
739 if err != nil {
740 sendErr(err)
741 return
742 }
743 // Create a new context for just this tool_use call, and register its
744 // cancel function so that it can be canceled individually.
745 toolUseCtx, cancel := c.newToolUseContext(ctx, part.ID)
746 defer cancel()
747 // TODO: move this into newToolUseContext?
748 toolUseCtx = context.WithValue(toolUseCtx, toolCallInfoKey, ToolCallInfo{ToolUseID: part.ID, Convo: c})
749 toolResult, err := tool.Run(toolUseCtx, part.ToolInput)
750 if errors.Is(err, ErrDoNotRespond) {
751 return
752 }
753 if toolUseCtx.Err() != nil {
754 sendErr(context.Cause(toolUseCtx))
755 return
756 }
757
758 if err != nil {
759 sendErr(err)
760 return
761 }
762 sendRes(toolResult)
763 }()
764 }
765 wg.Wait()
766 close(toolResultC)
767 var toolResults []Content
768 for toolResult := range toolResultC {
769 toolResults = append(toolResults, toolResult)
770 }
771 if ctx.Err() != nil {
772 return nil, ctx.Err()
773 }
774 return toolResults, nil
775}
776
777func (c *Convo) incrementToolUse(name string) {
778 c.mu.Lock()
779 defer c.mu.Unlock()
780
781 c.usage.ToolUses[name]++
782}
783
784// ContentsAttr returns contents as a slog.Attr.
785// It is meant for logging.
786func ContentsAttr(contents []Content) slog.Attr {
787 var contentAttrs []any // slog.Attr
788 for _, content := range contents {
789 var attrs []any // slog.Attr
790 switch content.Type {
791 case ContentTypeText:
792 attrs = append(attrs, slog.String("text", content.Text))
793 case ContentTypeToolUse:
794 attrs = append(attrs, slog.String("tool_name", content.ToolName))
795 attrs = append(attrs, slog.String("tool_input", string(content.ToolInput)))
796 case ContentTypeToolResult:
797 attrs = append(attrs, slog.String("tool_result", content.ToolResult))
798 attrs = append(attrs, slog.Bool("tool_error", content.ToolError))
799 case ContentTypeThinking:
800 attrs = append(attrs, slog.String("thinking", content.Text))
801 default:
802 attrs = append(attrs, slog.String("unknown_content_type", content.Type))
803 attrs = append(attrs, slog.Any("text", content)) // just log it all raw, better to have too much than not enough
804 }
805 contentAttrs = append(contentAttrs, slog.Group(content.ID, attrs...))
806 }
807 return slog.Group("contents", contentAttrs...)
808}
809
// MustSchema returns schema, with surrounding whitespace removed, as a json.RawMessage.
// It panics if schema is not valid JSON.
func MustSchema(schema string) json.RawMessage {
	// TODO: validate schema, for now just make sure it's valid JSON
	schema = strings.TrimSpace(schema)
	raw := json.RawMessage(schema)
	if json.Valid(raw) {
		return raw
	}
	panic("invalid JSON schema: " + schema)
}
821
// centsPer1MTokens is a price list: cents per million tokens, per token kind.
// (not dollars because i'm twitchy about using floats for money)
type centsPer1MTokens struct {
	Input         uint64
	Output        uint64
	CacheRead     uint64
	CacheCreation uint64
}
830
831// https://www.anthropic.com/pricing#anthropic-api
832var modelCost = map[string]centsPer1MTokens{
833 Claude37Sonnet: {
834 Input: 300, // $3
835 Output: 1500, // $15
836 CacheRead: 30, // $0.30
837 CacheCreation: 375, // $3.75
838 },
839 Claude35Haiku: {
840 Input: 80, // $0.80
841 Output: 400, // $4.00
842 CacheRead: 8, // $0.08
843 CacheCreation: 100, // $1.00
844 },
845 Claude35Sonnet: {
846 Input: 300, // $3
847 Output: 1500, // $15
848 CacheRead: 30, // $0.30
849 CacheCreation: 375, // $3.75
850 },
851}
852
853// TotalDollars returns the total cost to obtain this response, in dollars.
854func (mr *MessageResponse) TotalDollars() float64 {
855 cpm, ok := modelCost[mr.Model]
856 if !ok {
857 panic(fmt.Sprintf("no pricing info for model: %s", mr.Model))
858 }
859 use := mr.Usage
860 megaCents := use.InputTokens*cpm.Input +
861 use.OutputTokens*cpm.Output +
862 use.CacheReadInputTokens*cpm.CacheRead +
863 use.CacheCreationInputTokens*cpm.CacheCreation
864 cents := float64(megaCents) / 1_000_000.0
865 return cents / 100.0
866}
867
868func newUsage() *CumulativeUsage {
869 return &CumulativeUsage{ToolUses: make(map[string]int), StartTime: time.Now()}
870}
871
872func newUsageWithSharedToolUses(parent *CumulativeUsage) *CumulativeUsage {
873 return &CumulativeUsage{ToolUses: parent.ToolUses, StartTime: time.Now()}
874}
875
// CumulativeUsage is the running total of usage across a Convo, including all sub-conversations.
type CumulativeUsage struct {
	StartTime                time.Time      `json:"start_time"`
	Responses                uint64         `json:"messages"` // count of responses
	InputTokens              uint64         `json:"input_tokens"`
	OutputTokens             uint64         `json:"output_tokens"`
	CacheReadInputTokens     uint64         `json:"cache_read_input_tokens"`
	CacheCreationInputTokens uint64         `json:"cache_creation_input_tokens"`
	TotalCostUSD             float64        `json:"total_cost_usd"`
	ToolUses                 map[string]int `json:"tool_uses"` // tool name -> number of uses
}
887
888func (u *CumulativeUsage) Clone() CumulativeUsage {
889 v := *u
890 v.ToolUses = maps.Clone(u.ToolUses)
891 return v
892}
893
894func (c *Convo) CumulativeUsage() CumulativeUsage {
895 if c == nil {
896 return CumulativeUsage{}
897 }
898 c.mu.Lock()
899 defer c.mu.Unlock()
900 return c.usage.Clone()
901}
902
903func (u *CumulativeUsage) WallTime() time.Duration {
904 return time.Since(u.StartTime)
905}
906
907func (u *CumulativeUsage) DollarsPerHour() float64 {
908 hours := u.WallTime().Hours()
909 if hours == 0 {
910 return 0
911 }
912 return u.TotalCostUSD / hours
913}
914
915func (u *CumulativeUsage) AddResponse(resp *MessageResponse) {
916 usage := resp.Usage
917 u.Responses++
918 u.InputTokens += usage.InputTokens
919 u.OutputTokens += usage.OutputTokens
920 u.CacheReadInputTokens += usage.CacheReadInputTokens
921 u.CacheCreationInputTokens += usage.CacheCreationInputTokens
922 u.TotalCostUSD += resp.TotalDollars()
923}
924
Josh Bleecher Snyder35889972025-04-24 20:48:16 +0000925// TotalInputTokens returns the grand total cumulative input tokens in u.
926func (u *CumulativeUsage) TotalInputTokens() uint64 {
927 return u.InputTokens + u.CacheReadInputTokens + u.CacheCreationInputTokens
928}
929
Earl Lee2e463fb2025-04-17 11:22:22 -0700930// Attr returns the cumulative usage as a slog.Attr with key "usage".
931func (u CumulativeUsage) Attr() slog.Attr {
932 elapsed := time.Since(u.StartTime)
933 return slog.Group("usage",
934 slog.Duration("wall_time", elapsed),
935 slog.Uint64("responses", u.Responses),
936 slog.Uint64("input_tokens", u.InputTokens),
937 slog.Uint64("output_tokens", u.OutputTokens),
938 slog.Uint64("cache_read_input_tokens", u.CacheReadInputTokens),
939 slog.Uint64("cache_creation_input_tokens", u.CacheCreationInputTokens),
940 slog.Float64("total_cost_usd", u.TotalCostUSD),
941 slog.Float64("dollars_per_hour", u.TotalCostUSD/elapsed.Hours()),
942 slog.Any("tool_uses", maps.Clone(u.ToolUses)),
943 )
944}
945
// A Budget caps the resources that may be spent on a conversation.
// A limit applies only when it is greater than zero, so the default
// (zero) budget is unlimited.
type Budget struct {
	MaxResponses uint64        // if > 0, max number of iterations (=responses)
	MaxDollars   float64       // if > 0, max dollars that may be spent
	MaxWallTime  time.Duration // if > 0, max wall time that may be spent
}
953
954// OverBudget returns an error if the convo (or any of its parents) has exceeded its budget.
955// TODO: document parent vs sub budgets, multiple errors, etc, once we know the desired behavior.
956func (c *Convo) OverBudget() error {
957 for x := c; x != nil; x = x.Parent {
958 if err := x.overBudget(); err != nil {
959 return err
960 }
961 }
962 return nil
963}
964
965// ResetBudget sets the budget to the passed in budget and
966// adjusts it by what's been used so far.
967func (c *Convo) ResetBudget(budget Budget) {
968 c.Budget = budget
969 if c.Budget.MaxDollars > 0 {
970 c.Budget.MaxDollars += c.CumulativeUsage().TotalCostUSD
971 }
972 if c.Budget.MaxResponses > 0 {
973 c.Budget.MaxResponses += c.CumulativeUsage().Responses
974 }
975 if c.Budget.MaxWallTime > 0 {
976 c.Budget.MaxWallTime += c.usage.WallTime()
977 }
978}
979
980func (c *Convo) overBudget() error {
981 usage := c.CumulativeUsage()
982 // TODO: stop before we exceed the budget instead of after?
983 // Top priority is money, then time, then response count.
984 var err error
985 cont := "Continuing to chat will reset the budget."
986 if c.Budget.MaxDollars > 0 && usage.TotalCostUSD >= c.Budget.MaxDollars {
987 err = errors.Join(err, fmt.Errorf("$%.2f spent, budget is $%.2f. %s", usage.TotalCostUSD, c.Budget.MaxDollars, cont))
988 }
989 if c.Budget.MaxWallTime > 0 && usage.WallTime() >= c.Budget.MaxWallTime {
990 err = errors.Join(err, fmt.Errorf("%v elapsed, budget is %v. %s", usage.WallTime().Truncate(time.Second), c.Budget.MaxWallTime.Truncate(time.Second), cont))
991 }
992 if c.Budget.MaxResponses > 0 && usage.Responses >= c.Budget.MaxResponses {
993 err = errors.Join(err, fmt.Errorf("%d responses received, budget is %d. %s", usage.Responses, c.Budget.MaxResponses, cont))
994 }
995 return err
996}