blob: 8759883655394d91f58e99c8019856906d16634c [file] [log] [blame]
Earl Lee2e463fb2025-04-17 11:22:22 -07001package ant
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "log/slog"
11 "maps"
12 "math/rand/v2"
13 "net/http"
14 "slices"
15 "strings"
16 "sync"
17 "testing"
18 "time"
19
Philip Zeyliger99a9a022025-04-27 15:15:25 +000020 "github.com/oklog/ulid/v2"
Earl Lee2e463fb2025-04-17 11:22:22 -070021 "github.com/richardlehane/crock32"
22 "sketch.dev/skribe"
23)
24
// Model identifiers accepted by the Anthropic messages API.
const (
	Claude35Sonnet = "claude-3-5-sonnet-20241022"
	Claude35Haiku  = "claude-3-5-haiku-20241022"
	Claude37Sonnet = "claude-3-7-sonnet-20250219"
)

// Defaults applied by NewConvo.
const (
	// DefaultModel is the model used when none is configured.
	DefaultModel = Claude37Sonnet
	// DefaultMaxTokens is the default per-response output token limit.
	// See https://docs.anthropic.com/en/docs/about-claude/models/all-models for
	// current maximums. There's currently a flag to enable 128k output (output-128k-2025-02-19).
	DefaultMaxTokens = 8192
	// DefaultURL is the messages endpoint requests are sent to.
	DefaultURL = "https://api.anthropic.com/v1/messages"
)
38
// Message roles.
const (
	MessageRoleUser      = "user"
	MessageRoleAssistant = "assistant"
)

// Content block types (the "type" field of a Content).
const (
	ContentTypeText             = "text"
	ContentTypeThinking         = "thinking"
	ContentTypeRedactedThinking = "redacted_thinking"
	ContentTypeToolUse          = "tool_use"
	ContentTypeToolResult       = "tool_result"
)

// Stop reasons reported in MessageResponse.StopReason.
const (
	StopReasonStopSequence = "stop_sequence"
	StopReasonMaxTokens    = "max_tokens"
	StopReasonEndTurn      = "end_turn"
	StopReasonToolUse      = "tool_use"
)
54
55type Listener interface {
56 // TODO: Content is leaking an anthropic API; should we avoid it?
57 // TODO: Where should we include start/end time and usage?
Philip Zeyliger99a9a022025-04-27 15:15:25 +000058 OnToolCall(ctx context.Context, convo *Convo, toolCallID string, toolName string, toolInput json.RawMessage, content Content)
59 OnToolResult(ctx context.Context, convo *Convo, toolCallID string, toolName string, toolInput json.RawMessage, content Content, result *string, err error)
60 OnRequest(ctx context.Context, convo *Convo, requestID string, msg *Message)
61 OnResponse(ctx context.Context, convo *Convo, requestID string, msg *MessageResponse)
Earl Lee2e463fb2025-04-17 11:22:22 -070062}
63
64type NoopListener struct{}
65
Philip Zeyliger99a9a022025-04-27 15:15:25 +000066func (n *NoopListener) OnToolCall(ctx context.Context, convo *Convo, id string, toolName string, toolInput json.RawMessage, content Content) {
Earl Lee2e463fb2025-04-17 11:22:22 -070067}
Philip Zeyliger99a9a022025-04-27 15:15:25 +000068
69func (n *NoopListener) OnToolResult(ctx context.Context, convo *Convo, id string, toolName string, toolInput json.RawMessage, content Content, result *string, err error) {
70}
71
72func (n *NoopListener) OnResponse(ctx context.Context, convo *Convo, id string, msg *MessageResponse) {
73}
74func (n *NoopListener) OnRequest(ctx context.Context, convo *Convo, id string, msg *Message) {}
Earl Lee2e463fb2025-04-17 11:22:22 -070075
// Content is a single content block of a message. One struct covers every
// block type; which fields are meaningful depends on Type.
//
// TODO: image support?
// https://docs.anthropic.com/en/api/messages
type Content struct {
	ID   string `json:"id,omitempty"`
	Type string `json:"type,omitempty"`
	Text string `json:"text,omitempty"`

	// Fields used by thinking blocks.
	Thinking  string `json:"thinking,omitempty"`
	Data      string `json:"data,omitempty"`      // for redacted_thinking
	Signature string `json:"signature,omitempty"` // for thinking

	// Fields used by tool_use blocks.
	ToolName  string          `json:"name,omitempty"`
	ToolInput json.RawMessage `json:"input,omitempty"`

	// Fields used by tool_result blocks.
	ToolUseID  string `json:"tool_use_id,omitempty"`
	ToolError  bool   `json:"is_error,omitempty"`
	ToolResult string `json:"content,omitempty"`

	// Timing information for tool_result blocks. Excluded from JSON, so it is
	// never sent to Claude.
	StartTime *time.Time `json:"-"`
	EndTime   *time.Time `json:"-"`

	CacheControl json.RawMessage `json:"cache_control,omitempty"`
}
103
104func StringContent(s string) Content {
105 return Content{Type: ContentTypeText, Text: s}
106}
107
108// Message represents a message in the conversation.
109type Message struct {
110 Role string `json:"role"`
111 Content []Content `json:"content"`
112 ToolUse *ToolUse `json:"tool_use,omitempty"` // use to control whether/which tool to use
113}
114
// ToolUse identifies a tool use (by ID and tool name) in the message content.
type ToolUse struct {
	ID   string `json:"id"`
	Name string `json:"name"`
}
120
// Tool describes a tool that Claude may call, together with the Go function
// that implements it.
type Tool struct {
	Name string `json:"name"`
	// Type is used by the text editor tool; see
	// https://docs.anthropic.com/en/docs/build-with-claude/tool-use/text-editor-tool
	Type        string          `json:"type,omitempty"`
	Description string          `json:"description,omitempty"`
	InputSchema json.RawMessage `json:"input_schema,omitempty"`

	// Run is invoked automatically whenever the tool is used.
	// Run functions may be called concurrently with each other and themselves.
	// input is the tool input as provided by Claude, in compliance with the input schema.
	// The outputs from Run are sent back to Claude.
	// Return ErrDoNotRespond to skip responding to the tool call request.
	// ctx carries extra (rarely used) tool call information; retrieve it with ToolCallInfoFromContext.
	Run func(ctx context.Context, input json.RawMessage) (string, error) `json:"-"`
}

// ErrDoNotRespond is returned by a Tool's Run function to indicate that no
// tool result should be produced for the call.
var ErrDoNotRespond = errors.New("do not respond")
140
// Usage holds the token counts and dollar cost reported for a response.
type Usage struct {
	InputTokens              uint64  `json:"input_tokens"`
	CacheCreationInputTokens uint64  `json:"cache_creation_input_tokens"`
	CacheReadInputTokens     uint64  `json:"cache_read_input_tokens"`
	OutputTokens             uint64  `json:"output_tokens"`
	CostUSD                  float64 `json:"cost_usd"`
}

// Add accumulates every counter of other, including CostUSD, into u.
func (u *Usage) Add(other Usage) {
	u.CostUSD += other.CostUSD
	u.OutputTokens += other.OutputTokens
	u.CacheReadInputTokens += other.CacheReadInputTokens
	u.CacheCreationInputTokens += other.CacheCreationInputTokens
	u.InputTokens += other.InputTokens
}

// String returns a short summary of the input and output token counts.
func (u *Usage) String() string {
	in, out := u.InputTokens, u.OutputTokens
	return fmt.Sprintf("in: %d, out: %d", in, out)
}

// IsZero reports whether every field of u holds its zero value.
func (u *Usage) IsZero() bool {
	var zero Usage
	return *u == zero
}

// Attr returns the token counts as a slog group with key "usage".
// CostUSD is not included.
func (u *Usage) Attr() slog.Attr {
	fields := []any{
		slog.Uint64("input_tokens", u.InputTokens),
		slog.Uint64("output_tokens", u.OutputTokens),
		slog.Uint64("cache_creation_input_tokens", u.CacheCreationInputTokens),
		slog.Uint64("cache_read_input_tokens", u.CacheReadInputTokens),
	}
	return slog.Group("usage", fields...)
}
174
// ErrorResponse holds the type and message of an error reported by the API.
type ErrorResponse struct {
	Type    string `json:"type"`
	Message string `json:"message"`
}
179
180// MessageResponse represents the response from the message API.
181type MessageResponse struct {
182 ID string `json:"id"`
183 Type string `json:"type"`
184 Role string `json:"role"`
185 Model string `json:"model"`
186 Content []Content `json:"content"`
187 StopReason string `json:"stop_reason"`
188 StopSequence *string `json:"stop_sequence,omitempty"`
189 Usage Usage `json:"usage"`
190 StartTime *time.Time `json:"start_time,omitempty"`
191 EndTime *time.Time `json:"end_time,omitempty"`
192}
193
194func (m *MessageResponse) ToMessage() Message {
195 return Message{
196 Role: m.Role,
197 Content: m.Content,
198 }
199}
200
201func (m *MessageResponse) StopSequenceString() string {
202 if m.StopSequence == nil {
203 return ""
204 }
205 return *m.StopSequence
206}
207
// Values for ToolChoice.Type.
const (
	ToolChoiceTypeAuto = "auto" // default
	ToolChoiceTypeAny  = "any"  // any tool, but must use one
	ToolChoiceTypeNone = "none" // no tools allowed
	ToolChoiceTypeTool = "tool" // must use the tool specified in the Name field
)

// ToolChoice tells the API how the model may pick tools.
// Name is only relevant when Type is ToolChoiceTypeTool.
type ToolChoice struct {
	Type string `json:"type"`
	Name string `json:"name,omitempty"`
}
219
// SystemContent is one block of the request's system prompt.
// See https://docs.anthropic.com/en/api/messages#body-system
type SystemContent struct {
	Text         string          `json:"text,omitempty"`
	Type         string          `json:"type,omitempty"`
	CacheControl json.RawMessage `json:"cache_control,omitempty"`
}
226
227// MessageRequest represents the request payload for creating a message.
228type MessageRequest struct {
229 Model string `json:"model"`
230 Messages []Message `json:"messages"`
231 ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
232 MaxTokens int `json:"max_tokens"`
233 Tools []*Tool `json:"tools,omitempty"`
234 Stream bool `json:"stream,omitempty"`
235 System []SystemContent `json:"system,omitempty"`
236 Temperature float64 `json:"temperature,omitempty"`
237 TopK int `json:"top_k,omitempty"`
238 TopP float64 `json:"top_p,omitempty"`
239 StopSequences []string `json:"stop_sequences,omitempty"`
240
241 TokenEfficientToolUse bool `json:"-"` // DO NOT USE, broken on Anthropic's side as of 2025-02-28
242}
243
244const dumpText = false // debugging toggle to see raw communications with Claude
245
246// createMessage sends a request to the Anthropic message API to create a message.
247func createMessage(ctx context.Context, httpc *http.Client, url, apiKey string, request *MessageRequest) (*MessageResponse, error) {
248 var payload []byte
249 var err error
250 if dumpText || testing.Testing() {
251 payload, err = json.MarshalIndent(request, "", " ")
252 } else {
253 payload, err = json.Marshal(request)
254 payload = append(payload, '\n')
255 }
256 if err != nil {
257 return nil, err
258 }
259
260 if false {
261 fmt.Printf("claude request payload:\n%s\n", payload)
262 }
263
264 backoff := []time.Duration{15 * time.Second, 30 * time.Second, time.Minute}
265 largerMaxTokens := false
266 var partialUsage Usage
267
268 // retry loop
269 for attempts := 0; ; attempts++ {
270 if dumpText {
271 fmt.Printf("RAW REQUEST:\n%s\n\n", payload)
272 }
273 req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(payload))
274 if err != nil {
275 return nil, err
276 }
277
278 req.Header.Set("Content-Type", "application/json")
279 req.Header.Set("X-API-Key", apiKey)
280 req.Header.Set("Anthropic-Version", "2023-06-01")
281
282 features := []string{}
283
284 if request.TokenEfficientToolUse {
285 features = append(features, "token-efficient-tool-use-2025-02-19")
286 }
287 if largerMaxTokens {
288 features = append(features, "output-128k-2025-02-19")
289 request.MaxTokens = 128 * 1024
290 }
291 if len(features) > 0 {
292 req.Header.Set("anthropic-beta", strings.Join(features, ","))
293 }
294
295 resp, err := httpc.Do(req)
296 if err != nil {
297 return nil, err
298 }
299 buf, _ := io.ReadAll(resp.Body)
300 resp.Body.Close()
301
302 switch {
303 case resp.StatusCode == http.StatusOK:
304 if dumpText {
305 fmt.Printf("RAW RESPONSE:\n%s\n\n", buf)
306 }
307 var response MessageResponse
308 err = json.NewDecoder(bytes.NewReader(buf)).Decode(&response)
309 if err != nil {
310 return nil, err
311 }
312 if response.StopReason == StopReasonMaxTokens && !largerMaxTokens {
313 fmt.Printf("Retrying Anthropic API call with larger max tokens size.")
314 // Retry with more output tokens.
315 largerMaxTokens = true
316 response.Usage.CostUSD = response.TotalDollars()
317 partialUsage = response.Usage
318 continue
319 }
320
321 // Calculate and set the cost_usd field
322 if largerMaxTokens {
323 response.Usage.Add(partialUsage)
324 }
325 response.Usage.CostUSD = response.TotalDollars()
326
327 return &response, nil
328 case resp.StatusCode >= 500 && resp.StatusCode < 600:
329 // overloaded or unhappy, in one form or another
330 sleep := backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
331 slog.WarnContext(ctx, "anthropic_request_failed", "response", string(buf), "status_code", resp.StatusCode, "sleep", sleep)
332 time.Sleep(sleep)
333 case resp.StatusCode == 429:
334 // rate limited. wait 1 minute as a starting point, because that's the rate limiting window.
335 // and then add some additional time for backoff.
336 sleep := time.Minute + backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
337 slog.WarnContext(ctx, "anthropic_request_rate_limited", "response", string(buf), "sleep", sleep)
338 // case resp.StatusCode == 400:
339 // TODO: parse ErrorResponse, make (*ErrorResponse) implement error
340 default:
341 return nil, fmt.Errorf("API request failed with status %s\n%s", resp.Status, buf)
342 }
343 }
344}
345
346// A Convo is a managed conversation with Claude.
347// It automatically manages the state of the conversation,
348// including appending messages send/received,
349// calling tools and sending their results,
350// tracking usage, etc.
351//
352// Exported fields must not be altered concurrently with calling any method on Convo.
353// Typical usage is to configure a Convo once before using it.
354type Convo struct {
355 // ID is a unique ID for the conversation
356 ID string
357 // Ctx is the context for the entire conversation.
358 Ctx context.Context
359 // HTTPC is the HTTP client for the conversation.
360 HTTPC *http.Client
361 // URL is the remote messages URL to dial.
362 URL string
363 // APIKey is the API key for the conversation.
364 APIKey string
365 // Model is the model for the conversation.
366 Model string
367 // MaxTokens is the max tokens for each response in the conversation.
368 MaxTokens int
369 // Tools are the tools available during the conversation.
370 Tools []*Tool
371 // SystemPrompt is the system prompt for the conversation.
372 SystemPrompt string
373 // PromptCaching indicates whether to use Anthropic's prompt caching.
374 // See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#continuing-a-multi-turn-conversation
375 // for the documentation. At request send time, we set the cache_control field on the
376 // last message. We also cache the system prompt.
377 // Default: true.
378 PromptCaching bool
379 // ToolUseOnly indicates whether Claude may only use tools during this conversation.
380 // TODO: add more fine-grained control over tool use?
381 ToolUseOnly bool
382 // Parent is the parent conversation, if any.
383 // It is non-nil for "subagent" calls.
384 // It is set automatically when calling SubConvo,
385 // and usually should not be set manually.
386 Parent *Convo
387 // Budget is the budget for this conversation (and all sub-conversations).
388 // The Conversation DOES NOT automatically enforce the budget.
389 // It is up to the caller to call OverBudget() as appropriate.
390 Budget Budget
391
392 // messages tracks the messages so far in the conversation.
393 messages []Message
394
395 // Listener receives messages being sent.
396 Listener Listener
397
398 muToolUseCancel *sync.Mutex
399 toolUseCancel map[string]context.CancelCauseFunc
400
401 // Protects usage. This is used for subconversations (that share part of CumulativeUsage) as well.
402 mu *sync.Mutex
403 // usage tracks usage for this conversation and all sub-conversations.
404 usage *CumulativeUsage
405}
406
407// newConvoID generates a new 8-byte random id.
408// The uniqueness/collision requirements here are very low.
409// They are not global identifiers,
410// just enough to distinguish different convos in a single session.
411func newConvoID() string {
412 u1 := rand.Uint32()
413 s := crock32.Encode(uint64(u1))
414 if len(s) < 7 {
415 s += strings.Repeat("0", 7-len(s))
416 }
417 return s[:3] + "-" + s[3:]
418}
419
420// NewConvo creates a new conversation with Claude with sensible defaults.
421// ctx is the context for the entire conversation.
422func NewConvo(ctx context.Context, apiKey string) *Convo {
423 id := newConvoID()
424 return &Convo{
425 Ctx: skribe.ContextWithAttr(ctx, slog.String("convo_id", id)),
426 HTTPC: http.DefaultClient,
427 URL: DefaultURL,
428 APIKey: apiKey,
429 Model: DefaultModel,
430 MaxTokens: DefaultMaxTokens,
431 PromptCaching: true,
432 usage: newUsage(),
433 Listener: &NoopListener{},
434 ID: id,
435 muToolUseCancel: &sync.Mutex{},
436 toolUseCancel: map[string]context.CancelCauseFunc{},
437 mu: &sync.Mutex{},
438 }
439}
440
441// SubConvo creates a sub-conversation with the same configuration as the parent conversation.
442// (This propagates context for cancellation, HTTP client, API key, etc.)
443// The sub-conversation shares no messages with the parent conversation.
444// It does not inherit tools from the parent conversation.
445func (c *Convo) SubConvo() *Convo {
446 id := newConvoID()
447 return &Convo{
448 Ctx: skribe.ContextWithAttr(c.Ctx, slog.String("convo_id", id), slog.String("parent_convo_id", c.ID)),
449 HTTPC: c.HTTPC,
Josh Bleecher Snyder6a50b182025-04-24 11:07:41 -0700450 URL: c.URL,
Earl Lee2e463fb2025-04-17 11:22:22 -0700451 APIKey: c.APIKey,
452 Model: c.Model,
453 MaxTokens: c.MaxTokens,
454 PromptCaching: c.PromptCaching,
455 Parent: c,
456 // For convenience, sub-convo usage shares tool uses map with parent,
457 // all other fields separate, propagated in AddResponse
458 usage: newUsageWithSharedToolUses(c.usage),
459 mu: c.mu,
460 Listener: c.Listener,
461 ID: id,
462 // Do not copy Budget. Each budget is independent,
463 // and OverBudget checks whether any ancestor is over budget.
464 }
465}
466
Philip Zeyliger2c4db092025-04-28 16:57:50 -0700467func (c *Convo) SubConvoWithHistory() *Convo {
468 id := newConvoID()
469 return &Convo{
470 Ctx: skribe.ContextWithAttr(c.Ctx, slog.String("convo_id", id), slog.String("parent_convo_id", c.ID)),
471 HTTPC: c.HTTPC,
472 URL: c.URL,
473 APIKey: c.APIKey,
474 Model: c.Model,
475 MaxTokens: c.MaxTokens,
476 PromptCaching: c.PromptCaching,
477 Parent: c,
478 // For convenience, sub-convo usage shares tool uses map with parent,
479 // all other fields separate, propagated in AddResponse
480 usage: newUsageWithSharedToolUses(c.usage),
481 mu: c.mu,
482 Listener: c.Listener,
483 ID: id,
484 // Do not copy Budget. Each budget is independent,
485 // and OverBudget checks whether any ancestor is over budget.
486 messages: slices.Clone(c.messages),
487 }
488}
489
Earl Lee2e463fb2025-04-17 11:22:22 -0700490// Depth reports how many "sub-conversations" deep this conversation is.
491// That it, it walks up parents until it finds a root.
492func (c *Convo) Depth() int {
493 x := c
494 var depth int
495 for x.Parent != nil {
496 x = x.Parent
497 depth++
498 }
499 return depth
500}
501
502// SendUserTextMessage sends a text message to Claude in this conversation.
503// otherContents contains additional contents to send with the message, usually tool results.
504func (c *Convo) SendUserTextMessage(s string, otherContents ...Content) (*MessageResponse, error) {
505 contents := slices.Clone(otherContents)
506 if s != "" {
507 contents = append(contents, Content{Type: ContentTypeText, Text: s})
508 }
509 msg := Message{
510 Role: MessageRoleUser,
511 Content: contents,
512 }
513 return c.SendMessage(msg)
514}
515
516func (c *Convo) messageRequest(msg Message) *MessageRequest {
517 system := []SystemContent{}
518 if c.SystemPrompt != "" {
519 var d SystemContent
520 d = SystemContent{Type: "text", Text: c.SystemPrompt}
521 if c.PromptCaching {
522 d.CacheControl = json.RawMessage(`{"type":"ephemeral"}`)
523 }
524 system = []SystemContent{d}
525 }
526
527 // Claude is happy to return an empty response in response to our Done() call,
528 // and, if so, you'll see something like:
529 // API request failed with status 400 Bad Request
530 // {"type":"error","error": {"type":"invalid_request_error",
531 // "message":"messages.5: all messages must have non-empty content except for the optional final assistant message"}}
532 // So, we filter out those empty messages.
533 var nonEmptyMessages []Message
534 for _, m := range c.messages {
535 if len(m.Content) > 0 {
536 nonEmptyMessages = append(nonEmptyMessages, m)
537 }
538 }
539
540 mr := &MessageRequest{
541 Model: c.Model,
542 Messages: append(nonEmptyMessages, msg), // not yet committed to keeping msg
543 System: system,
544 Tools: c.Tools,
545 MaxTokens: c.MaxTokens,
546 }
547 if c.ToolUseOnly {
548 mr.ToolChoice = &ToolChoice{Type: ToolChoiceTypeAny}
549 }
550 return mr
551}
552
553func (c *Convo) findTool(name string) (*Tool, error) {
554 for _, tool := range c.Tools {
555 if tool.Name == name {
556 return tool, nil
557 }
558 }
559 return nil, fmt.Errorf("tool %q not found", name)
560}
561
562// insertMissingToolResults adds error results for tool uses that were requested
563// but not included in the message, which can happen in error paths like "out of budget."
564// We only insert these if there were no tool responses at all, since an incorrect
565// number of tool results would be a programmer error. Mutates inputs.
566func (c *Convo) insertMissingToolResults(mr *MessageRequest, msg *Message) {
567 if len(mr.Messages) < 2 {
568 return
569 }
570 prev := mr.Messages[len(mr.Messages)-2]
571 var toolUsePrev int
572 for _, c := range prev.Content {
573 if c.Type == ContentTypeToolUse {
574 toolUsePrev++
575 }
576 }
577 if toolUsePrev == 0 {
578 return
579 }
580 var toolUseCurrent int
581 for _, c := range msg.Content {
582 if c.Type == ContentTypeToolResult {
583 toolUseCurrent++
584 }
585 }
586 if toolUseCurrent != 0 {
587 return
588 }
589 var prefix []Content
590 for _, part := range prev.Content {
591 if part.Type != ContentTypeToolUse {
592 continue
593 }
594 content := Content{
595 Type: ContentTypeToolResult,
596 ToolUseID: part.ID,
597 ToolError: true,
598 ToolResult: "not executed; retry possible",
599 }
600 prefix = append(prefix, content)
601 msg.Content = append(prefix, msg.Content...)
602 mr.Messages[len(mr.Messages)-1].Content = msg.Content
603 }
604 slog.DebugContext(c.Ctx, "inserted missing tool results")
605}
606
607// SendMessage sends a message to Claude.
608// The conversation records (internally) all messages succesfully sent and received.
609func (c *Convo) SendMessage(msg Message) (*MessageResponse, error) {
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000610 id := ulid.Make().String()
Earl Lee2e463fb2025-04-17 11:22:22 -0700611 mr := c.messageRequest(msg)
612 var lastMessage *Message
613 if c.PromptCaching {
614 lastMessage = &mr.Messages[len(mr.Messages)-1]
615 if len(lastMessage.Content) > 0 {
616 lastMessage.Content[len(lastMessage.Content)-1].CacheControl = json.RawMessage(`{"type":"ephemeral"}`)
617 }
618 }
619 defer func() {
620 if lastMessage == nil {
621 return
622 }
623 if len(lastMessage.Content) > 0 {
624 lastMessage.Content[len(lastMessage.Content)-1].CacheControl = []byte{}
625 }
626 }()
627 c.insertMissingToolResults(mr, &msg)
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000628 c.Listener.OnRequest(c.Ctx, c, id, &msg)
Earl Lee2e463fb2025-04-17 11:22:22 -0700629
630 startTime := time.Now()
631 resp, err := createMessage(c.Ctx, c.HTTPC, c.URL, c.APIKey, mr)
632 if resp != nil {
633 resp.StartTime = &startTime
634 endTime := time.Now()
635 resp.EndTime = &endTime
636 }
637
638 if err != nil {
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000639 c.Listener.OnResponse(c.Ctx, c, id, nil)
Earl Lee2e463fb2025-04-17 11:22:22 -0700640 return nil, err
641 }
642 c.messages = append(c.messages, msg, resp.ToMessage())
643 // Propagate usage to all ancestors (including us).
644 for x := c; x != nil; x = x.Parent {
645 x.usage.AddResponse(resp)
646 }
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000647 c.Listener.OnResponse(c.Ctx, c, id, resp)
Earl Lee2e463fb2025-04-17 11:22:22 -0700648 return resp, err
649}
650
651type toolCallInfoKeyType string
652
653var toolCallInfoKey toolCallInfoKeyType
654
655type ToolCallInfo struct {
656 ToolUseID string
657 Convo *Convo
658}
659
660func ToolCallInfoFromContext(ctx context.Context) ToolCallInfo {
661 v := ctx.Value(toolCallInfoKey)
662 i, _ := v.(ToolCallInfo)
663 return i
664}
665
666func (c *Convo) ToolResultCancelContents(resp *MessageResponse) ([]Content, error) {
667 if resp.StopReason != StopReasonToolUse {
668 return nil, nil
669 }
670 var toolResults []Content
671
672 for _, part := range resp.Content {
673 if part.Type != ContentTypeToolUse {
674 continue
675 }
676 c.incrementToolUse(part.ToolName)
677
678 content := Content{
679 Type: ContentTypeToolResult,
680 ToolUseID: part.ID,
681 }
682
683 content.ToolError = true
684 content.ToolResult = "user canceled this too_use"
685 toolResults = append(toolResults, content)
686 }
687 return toolResults, nil
688}
689
Philip Zeyliger2c4db092025-04-28 16:57:50 -0700690// GetID returns the conversation ID
691func (c *Convo) GetID() string {
692 return c.ID
693}
694
Earl Lee2e463fb2025-04-17 11:22:22 -0700695func (c *Convo) CancelToolUse(toolUseID string, err error) error {
696 c.muToolUseCancel.Lock()
697 defer c.muToolUseCancel.Unlock()
698 cancel, ok := c.toolUseCancel[toolUseID]
699 if !ok {
700 return fmt.Errorf("cannot cancel %s: no cancel function registered for this tool_use_id. All I have is %+v", toolUseID, c.toolUseCancel)
701 }
702 delete(c.toolUseCancel, toolUseID)
703 cancel(err)
704 return nil
705}
706
707func (c *Convo) newToolUseContext(ctx context.Context, toolUseID string) (context.Context, context.CancelFunc) {
708 c.muToolUseCancel.Lock()
709 defer c.muToolUseCancel.Unlock()
710 ctx, cancel := context.WithCancelCause(ctx)
711 c.toolUseCancel[toolUseID] = cancel
712 return ctx, func() { c.CancelToolUse(toolUseID, nil) }
713}
714
715// ToolResultContents runs all tool uses requested by the response and returns their results.
716// Cancelling ctx will cancel any running tool calls.
717func (c *Convo) ToolResultContents(ctx context.Context, resp *MessageResponse) ([]Content, error) {
718 if resp.StopReason != StopReasonToolUse {
719 return nil, nil
720 }
721 // Extract all tool calls from the response, call the tools, and gather the results.
722 var wg sync.WaitGroup
723 toolResultC := make(chan Content, len(resp.Content))
724 for _, part := range resp.Content {
725 if part.Type != ContentTypeToolUse {
726 continue
727 }
728 c.incrementToolUse(part.ToolName)
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000729 startTime := time.Now()
730
731 c.Listener.OnToolCall(ctx, c, part.ID, part.ToolName, part.ToolInput, Content{
732 Type: ContentTypeToolUse,
733 ToolUseID: part.ID,
734 StartTime: &startTime,
735 })
736
Earl Lee2e463fb2025-04-17 11:22:22 -0700737 wg.Add(1)
738 go func() {
739 defer wg.Done()
740
Earl Lee2e463fb2025-04-17 11:22:22 -0700741 content := Content{
742 Type: ContentTypeToolResult,
743 ToolUseID: part.ID,
744 StartTime: &startTime,
745 }
746 sendErr := func(err error) {
747 // Record end time
748 endTime := time.Now()
749 content.EndTime = &endTime
750
751 content.ToolError = true
752 content.ToolResult = err.Error()
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000753 c.Listener.OnToolResult(ctx, c, part.ID, part.ToolName, part.ToolInput, content, nil, err)
Earl Lee2e463fb2025-04-17 11:22:22 -0700754 toolResultC <- content
755 }
756 sendRes := func(res string) {
757 // Record end time
758 endTime := time.Now()
759 content.EndTime = &endTime
760
761 content.ToolResult = res
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000762 c.Listener.OnToolResult(ctx, c, part.ID, part.ToolName, part.ToolInput, content, &res, nil)
Earl Lee2e463fb2025-04-17 11:22:22 -0700763 toolResultC <- content
764 }
765
766 tool, err := c.findTool(part.ToolName)
767 if err != nil {
768 sendErr(err)
769 return
770 }
771 // Create a new context for just this tool_use call, and register its
772 // cancel function so that it can be canceled individually.
773 toolUseCtx, cancel := c.newToolUseContext(ctx, part.ID)
774 defer cancel()
775 // TODO: move this into newToolUseContext?
776 toolUseCtx = context.WithValue(toolUseCtx, toolCallInfoKey, ToolCallInfo{ToolUseID: part.ID, Convo: c})
777 toolResult, err := tool.Run(toolUseCtx, part.ToolInput)
778 if errors.Is(err, ErrDoNotRespond) {
779 return
780 }
781 if toolUseCtx.Err() != nil {
782 sendErr(context.Cause(toolUseCtx))
783 return
784 }
785
786 if err != nil {
787 sendErr(err)
788 return
789 }
790 sendRes(toolResult)
791 }()
792 }
793 wg.Wait()
794 close(toolResultC)
795 var toolResults []Content
796 for toolResult := range toolResultC {
797 toolResults = append(toolResults, toolResult)
798 }
799 if ctx.Err() != nil {
800 return nil, ctx.Err()
801 }
802 return toolResults, nil
803}
804
805func (c *Convo) incrementToolUse(name string) {
806 c.mu.Lock()
807 defer c.mu.Unlock()
808
809 c.usage.ToolUses[name]++
810}
811
812// ContentsAttr returns contents as a slog.Attr.
813// It is meant for logging.
814func ContentsAttr(contents []Content) slog.Attr {
815 var contentAttrs []any // slog.Attr
816 for _, content := range contents {
817 var attrs []any // slog.Attr
818 switch content.Type {
819 case ContentTypeText:
820 attrs = append(attrs, slog.String("text", content.Text))
821 case ContentTypeToolUse:
822 attrs = append(attrs, slog.String("tool_name", content.ToolName))
823 attrs = append(attrs, slog.String("tool_input", string(content.ToolInput)))
824 case ContentTypeToolResult:
825 attrs = append(attrs, slog.String("tool_result", content.ToolResult))
826 attrs = append(attrs, slog.Bool("tool_error", content.ToolError))
827 case ContentTypeThinking:
828 attrs = append(attrs, slog.String("thinking", content.Text))
829 default:
830 attrs = append(attrs, slog.String("unknown_content_type", content.Type))
831 attrs = append(attrs, slog.Any("text", content)) // just log it all raw, better to have too much than not enough
832 }
833 contentAttrs = append(contentAttrs, slog.Group(content.ID, attrs...))
834 }
835 return slog.Group("contents", contentAttrs...)
836}
837
// MustSchema returns schema, with surrounding whitespace trimmed, as a json.RawMessage.
// It panics if the trimmed schema is not valid JSON.
// TODO: validate that it is a JSON schema; for now only JSON validity is checked.
func MustSchema(schema string) json.RawMessage {
	trimmed := strings.TrimSpace(schema)
	raw := json.RawMessage(trimmed)
	if json.Valid(raw) {
		return raw
	}
	panic("invalid JSON schema: " + trimmed)
}
849
// centsPer1MTokens is a price list in cents per million tokens, one price
// per token category.
// (Cents, not dollars, because i'm twitchy about using floats for money.)
type centsPer1MTokens struct {
	Input         uint64
	Output        uint64
	CacheRead     uint64
	CacheCreation uint64
}
858
859// https://www.anthropic.com/pricing#anthropic-api
860var modelCost = map[string]centsPer1MTokens{
861 Claude37Sonnet: {
862 Input: 300, // $3
863 Output: 1500, // $15
864 CacheRead: 30, // $0.30
865 CacheCreation: 375, // $3.75
866 },
867 Claude35Haiku: {
868 Input: 80, // $0.80
869 Output: 400, // $4.00
870 CacheRead: 8, // $0.08
871 CacheCreation: 100, // $1.00
872 },
873 Claude35Sonnet: {
874 Input: 300, // $3
875 Output: 1500, // $15
876 CacheRead: 30, // $0.30
877 CacheCreation: 375, // $3.75
878 },
879}
880
881// TotalDollars returns the total cost to obtain this response, in dollars.
882func (mr *MessageResponse) TotalDollars() float64 {
883 cpm, ok := modelCost[mr.Model]
884 if !ok {
885 panic(fmt.Sprintf("no pricing info for model: %s", mr.Model))
886 }
887 use := mr.Usage
888 megaCents := use.InputTokens*cpm.Input +
889 use.OutputTokens*cpm.Output +
890 use.CacheReadInputTokens*cpm.CacheRead +
891 use.CacheCreationInputTokens*cpm.CacheCreation
892 cents := float64(megaCents) / 1_000_000.0
893 return cents / 100.0
894}
895
896func newUsage() *CumulativeUsage {
897 return &CumulativeUsage{ToolUses: make(map[string]int), StartTime: time.Now()}
898}
899
900func newUsageWithSharedToolUses(parent *CumulativeUsage) *CumulativeUsage {
901 return &CumulativeUsage{ToolUses: parent.ToolUses, StartTime: time.Now()}
902}
903
// CumulativeUsage is the running usage total for a Convo, including all of
// its sub-conversations.
type CumulativeUsage struct {
	StartTime                time.Time      `json:"start_time"`
	Responses                uint64         `json:"messages"` // count of responses
	InputTokens              uint64         `json:"input_tokens"`
	OutputTokens             uint64         `json:"output_tokens"`
	CacheReadInputTokens     uint64         `json:"cache_read_input_tokens"`
	CacheCreationInputTokens uint64         `json:"cache_creation_input_tokens"`
	TotalCostUSD             float64        `json:"total_cost_usd"`
	ToolUses                 map[string]int `json:"tool_uses"` // tool name -> number of uses
}
915
916func (u *CumulativeUsage) Clone() CumulativeUsage {
917 v := *u
918 v.ToolUses = maps.Clone(u.ToolUses)
919 return v
920}
921
922func (c *Convo) CumulativeUsage() CumulativeUsage {
923 if c == nil {
924 return CumulativeUsage{}
925 }
926 c.mu.Lock()
927 defer c.mu.Unlock()
928 return c.usage.Clone()
929}
930
931func (u *CumulativeUsage) WallTime() time.Duration {
932 return time.Since(u.StartTime)
933}
934
935func (u *CumulativeUsage) DollarsPerHour() float64 {
936 hours := u.WallTime().Hours()
937 if hours == 0 {
938 return 0
939 }
940 return u.TotalCostUSD / hours
941}
942
943func (u *CumulativeUsage) AddResponse(resp *MessageResponse) {
944 usage := resp.Usage
945 u.Responses++
946 u.InputTokens += usage.InputTokens
947 u.OutputTokens += usage.OutputTokens
948 u.CacheReadInputTokens += usage.CacheReadInputTokens
949 u.CacheCreationInputTokens += usage.CacheCreationInputTokens
950 u.TotalCostUSD += resp.TotalDollars()
951}
952
Josh Bleecher Snyder35889972025-04-24 20:48:16 +0000953// TotalInputTokens returns the grand total cumulative input tokens in u.
954func (u *CumulativeUsage) TotalInputTokens() uint64 {
955 return u.InputTokens + u.CacheReadInputTokens + u.CacheCreationInputTokens
956}
957
Earl Lee2e463fb2025-04-17 11:22:22 -0700958// Attr returns the cumulative usage as a slog.Attr with key "usage".
959func (u CumulativeUsage) Attr() slog.Attr {
960 elapsed := time.Since(u.StartTime)
961 return slog.Group("usage",
962 slog.Duration("wall_time", elapsed),
963 slog.Uint64("responses", u.Responses),
964 slog.Uint64("input_tokens", u.InputTokens),
965 slog.Uint64("output_tokens", u.OutputTokens),
966 slog.Uint64("cache_read_input_tokens", u.CacheReadInputTokens),
967 slog.Uint64("cache_creation_input_tokens", u.CacheCreationInputTokens),
968 slog.Float64("total_cost_usd", u.TotalCostUSD),
969 slog.Float64("dollars_per_hour", u.TotalCostUSD/elapsed.Hours()),
970 slog.Any("tool_uses", maps.Clone(u.ToolUses)),
971 )
972}
973
// A Budget represents the maximum amount of resources that may be spent on a conversation.
// Note that the default (zero) budget is unlimited.
type Budget struct {
	// MaxResponses, if > 0, is the max number of iterations (= responses).
	MaxResponses uint64

	// MaxDollars, if > 0, is the max dollars that may be spent.
	MaxDollars float64

	// MaxWallTime, if > 0, is the max wall time that may be spent.
	MaxWallTime time.Duration
}
981
982// OverBudget returns an error if the convo (or any of its parents) has exceeded its budget.
983// TODO: document parent vs sub budgets, multiple errors, etc, once we know the desired behavior.
984func (c *Convo) OverBudget() error {
985 for x := c; x != nil; x = x.Parent {
986 if err := x.overBudget(); err != nil {
987 return err
988 }
989 }
990 return nil
991}
992
993// ResetBudget sets the budget to the passed in budget and
994// adjusts it by what's been used so far.
995func (c *Convo) ResetBudget(budget Budget) {
996 c.Budget = budget
997 if c.Budget.MaxDollars > 0 {
998 c.Budget.MaxDollars += c.CumulativeUsage().TotalCostUSD
999 }
1000 if c.Budget.MaxResponses > 0 {
1001 c.Budget.MaxResponses += c.CumulativeUsage().Responses
1002 }
1003 if c.Budget.MaxWallTime > 0 {
1004 c.Budget.MaxWallTime += c.usage.WallTime()
1005 }
1006}
1007
1008func (c *Convo) overBudget() error {
1009 usage := c.CumulativeUsage()
1010 // TODO: stop before we exceed the budget instead of after?
1011 // Top priority is money, then time, then response count.
1012 var err error
1013 cont := "Continuing to chat will reset the budget."
1014 if c.Budget.MaxDollars > 0 && usage.TotalCostUSD >= c.Budget.MaxDollars {
1015 err = errors.Join(err, fmt.Errorf("$%.2f spent, budget is $%.2f. %s", usage.TotalCostUSD, c.Budget.MaxDollars, cont))
1016 }
1017 if c.Budget.MaxWallTime > 0 && usage.WallTime() >= c.Budget.MaxWallTime {
1018 err = errors.Join(err, fmt.Errorf("%v elapsed, budget is %v. %s", usage.WallTime().Truncate(time.Second), c.Budget.MaxWallTime.Truncate(time.Second), cont))
1019 }
1020 if c.Budget.MaxResponses > 0 && usage.Responses >= c.Budget.MaxResponses {
1021 err = errors.Join(err, fmt.Errorf("%d responses received, budget is %d. %s", usage.Responses, c.Budget.MaxResponses, cont))
1022 }
1023 return err
1024}