blob: 0829286aa8e5c10b5a0aa267c57b250ff8068669 [file] [log] [blame]
Earl Lee2e463fb2025-04-17 11:22:22 -07001package ant
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "log/slog"
11 "maps"
12 "math/rand/v2"
13 "net/http"
14 "slices"
15 "strings"
16 "sync"
17 "testing"
18 "time"
19
Philip Zeyliger99a9a022025-04-27 15:15:25 +000020 "github.com/oklog/ulid/v2"
Earl Lee2e463fb2025-04-17 11:22:22 -070021 "github.com/richardlehane/crock32"
22 "sketch.dev/skribe"
23)
24
// Defaults applied by NewConvo.
const (
	// DefaultModel is the model used when none is configured.
	DefaultModel = Claude37Sonnet
	// DefaultMaxTokens is the default per-response output token limit.
	// See https://docs.anthropic.com/en/docs/about-claude/models/all-models for
	// current maximums. There's currently a flag to enable 128k output (output-128k-2025-02-19)
	DefaultMaxTokens = 8192
	// DefaultURL is the default messages endpoint.
	DefaultURL = "https://api.anthropic.com/v1/messages"
)

// Model identifiers known to this package (see modelCost for pricing).
const (
	Claude35Sonnet = "claude-3-5-sonnet-20241022"
	Claude35Haiku  = "claude-3-5-haiku-20241022"
	Claude37Sonnet = "claude-3-7-sonnet-20250219"
)

// String constants used on the wire.
const (
	// Values for Message.Role.
	MessageRoleUser      = "user"
	MessageRoleAssistant = "assistant"

	// Values for Content.Type.
	ContentTypeText             = "text"
	ContentTypeThinking         = "thinking"
	ContentTypeRedactedThinking = "redacted_thinking"
	ContentTypeToolUse          = "tool_use"
	ContentTypeToolResult       = "tool_result"

	// Values for MessageResponse.StopReason.
	StopReasonStopSequence = "stop_sequence"
	StopReasonMaxTokens    = "max_tokens"
	StopReasonEndTurn      = "end_turn"
	StopReasonToolUse      = "tool_use"
)
54
55type Listener interface {
56 // TODO: Content is leaking an anthropic API; should we avoid it?
57 // TODO: Where should we include start/end time and usage?
Philip Zeyliger99a9a022025-04-27 15:15:25 +000058 OnToolCall(ctx context.Context, convo *Convo, toolCallID string, toolName string, toolInput json.RawMessage, content Content)
59 OnToolResult(ctx context.Context, convo *Convo, toolCallID string, toolName string, toolInput json.RawMessage, content Content, result *string, err error)
60 OnRequest(ctx context.Context, convo *Convo, requestID string, msg *Message)
61 OnResponse(ctx context.Context, convo *Convo, requestID string, msg *MessageResponse)
Earl Lee2e463fb2025-04-17 11:22:22 -070062}
63
64type NoopListener struct{}
65
Philip Zeyliger99a9a022025-04-27 15:15:25 +000066func (n *NoopListener) OnToolCall(ctx context.Context, convo *Convo, id string, toolName string, toolInput json.RawMessage, content Content) {
Earl Lee2e463fb2025-04-17 11:22:22 -070067}
Philip Zeyliger99a9a022025-04-27 15:15:25 +000068
69func (n *NoopListener) OnToolResult(ctx context.Context, convo *Convo, id string, toolName string, toolInput json.RawMessage, content Content, result *string, err error) {
70}
71
72func (n *NoopListener) OnResponse(ctx context.Context, convo *Convo, id string, msg *MessageResponse) {
73}
74func (n *NoopListener) OnRequest(ctx context.Context, convo *Convo, id string, msg *Message) {}
Earl Lee2e463fb2025-04-17 11:22:22 -070075
76type Content struct {
77 // TODO: image support?
78 // https://docs.anthropic.com/en/api/messages
79 ID string `json:"id,omitempty"`
80 Type string `json:"type,omitempty"`
81 Text string `json:"text,omitempty"`
82
83 // for thinking
84 Thinking string `json:"thinking,omitempty"`
85 Data string `json:"data,omitempty"` // for redacted_thinking
86 Signature string `json:"signature,omitempty"` // for thinking
87
88 // for tool_use
89 ToolName string `json:"name,omitempty"`
90 ToolInput json.RawMessage `json:"input,omitempty"`
91
92 // for tool_result
93 ToolUseID string `json:"tool_use_id,omitempty"`
94 ToolError bool `json:"is_error,omitempty"`
95 ToolResult string `json:"content,omitempty"`
96
97 // timing information for tool_result; not sent to Claude
98 StartTime *time.Time `json:"-"`
99 EndTime *time.Time `json:"-"`
100
101 CacheControl json.RawMessage `json:"cache_control,omitempty"`
102}
103
104func StringContent(s string) Content {
105 return Content{Type: ContentTypeText, Text: s}
106}
107
108// Message represents a message in the conversation.
109type Message struct {
110 Role string `json:"role"`
111 Content []Content `json:"content"`
112 ToolUse *ToolUse `json:"tool_use,omitempty"` // use to control whether/which tool to use
113}
114
115// ToolUse represents a tool use in the message content.
116type ToolUse struct {
117 ID string `json:"id"`
118 Name string `json:"name"`
119}
120
// Tool describes a tool that Claude may call during a conversation.
type Tool struct {
	Name string `json:"name"`
	// Type is used by the text editor tool; see
	// https://docs.anthropic.com/en/docs/build-with-claude/tool-use/text-editor-tool
	Type        string          `json:"type,omitempty"`
	Description string          `json:"description,omitempty"`
	InputSchema json.RawMessage `json:"input_schema,omitempty"`

	// Run is invoked automatically whenever the tool is used.
	// Run functions may be called concurrently with each other and themselves.
	// input is the tool input exactly as provided by Claude, in compliance
	// with the input schema. The outputs from Run are sent back to Claude.
	// To leave a tool call request from Claude unanswered, return ErrDoNotRespond.
	// ctx contains extra (rarely used) tool call information; retrieve it
	// with ToolCallInfoFromContext.
	// Run is never serialized.
	Run func(ctx context.Context, input json.RawMessage) (string, error) `json:"-"`
}

// ErrDoNotRespond is returned by a Tool's Run function to indicate that no
// tool result should be produced for the call.
var ErrDoNotRespond = errors.New("do not respond")
140
// Usage holds the token counts reported for a response together with the
// dollar cost derived from them.
type Usage struct {
	InputTokens              uint64  `json:"input_tokens"`
	CacheCreationInputTokens uint64  `json:"cache_creation_input_tokens"`
	CacheReadInputTokens     uint64  `json:"cache_read_input_tokens"`
	OutputTokens             uint64  `json:"output_tokens"`
	CostUSD                  float64 `json:"cost_usd"`
}

// Add accumulates every field of other into u.
func (u *Usage) Add(other Usage) {
	u.InputTokens = u.InputTokens + other.InputTokens
	u.CacheCreationInputTokens = u.CacheCreationInputTokens + other.CacheCreationInputTokens
	u.CacheReadInputTokens = u.CacheReadInputTokens + other.CacheReadInputTokens
	u.OutputTokens = u.OutputTokens + other.OutputTokens
	u.CostUSD = u.CostUSD + other.CostUSD
}

// String returns a short summary of the input and output token counts.
func (u *Usage) String() string {
	in, out := u.InputTokens, u.OutputTokens
	return fmt.Sprintf("in: %d, out: %d", in, out)
}

// IsZero reports whether every field of u has its zero value.
func (u *Usage) IsZero() bool {
	var zero Usage
	return *u == zero
}

// Attr returns the token counts (not the cost) as a slog group with key "usage".
func (u *Usage) Attr() slog.Attr {
	input := slog.Uint64("input_tokens", u.InputTokens)
	output := slog.Uint64("output_tokens", u.OutputTokens)
	cacheCreation := slog.Uint64("cache_creation_input_tokens", u.CacheCreationInputTokens)
	cacheRead := slog.Uint64("cache_read_input_tokens", u.CacheReadInputTokens)
	return slog.Group("usage", input, output, cacheCreation, cacheRead)
}
174
// ErrorResponse carries the type and human-readable message of an error
// payload. It is not yet used by createMessage (see the TODO there).
type ErrorResponse struct {
	Type    string `json:"type"`
	Message string `json:"message"`
}
179
180// MessageResponse represents the response from the message API.
181type MessageResponse struct {
182 ID string `json:"id"`
183 Type string `json:"type"`
184 Role string `json:"role"`
185 Model string `json:"model"`
186 Content []Content `json:"content"`
187 StopReason string `json:"stop_reason"`
188 StopSequence *string `json:"stop_sequence,omitempty"`
189 Usage Usage `json:"usage"`
190 StartTime *time.Time `json:"start_time,omitempty"`
191 EndTime *time.Time `json:"end_time,omitempty"`
192}
193
194func (m *MessageResponse) ToMessage() Message {
195 return Message{
196 Role: m.Role,
197 Content: m.Content,
198 }
199}
200
201func (m *MessageResponse) StopSequenceString() string {
202 if m.StopSequence == nil {
203 return ""
204 }
205 return *m.StopSequence
206}
207
// Values for ToolChoice.Type.
const (
	ToolChoiceTypeAuto = "auto" // default
	ToolChoiceTypeAny  = "any"  // any tool, but must use one
	ToolChoiceTypeNone = "none" // no tools allowed
	ToolChoiceTypeTool = "tool" // must use the tool specified in the Name field
)

// ToolChoice is the tool_choice field of a MessageRequest.
// Name is only meaningful when Type is ToolChoiceTypeTool.
type ToolChoice struct {
	Type string `json:"type"`
	Name string `json:"name,omitempty"`
}
219
// SystemContent is one block of the request's system prompt.
// See https://docs.anthropic.com/en/api/messages#body-system
type SystemContent struct {
	Text         string          `json:"text,omitempty"`
	Type         string          `json:"type,omitempty"`
	CacheControl json.RawMessage `json:"cache_control,omitempty"`
}
226
227// MessageRequest represents the request payload for creating a message.
228type MessageRequest struct {
229 Model string `json:"model"`
230 Messages []Message `json:"messages"`
231 ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
232 MaxTokens int `json:"max_tokens"`
233 Tools []*Tool `json:"tools,omitempty"`
234 Stream bool `json:"stream,omitempty"`
235 System []SystemContent `json:"system,omitempty"`
236 Temperature float64 `json:"temperature,omitempty"`
237 TopK int `json:"top_k,omitempty"`
238 TopP float64 `json:"top_p,omitempty"`
239 StopSequences []string `json:"stop_sequences,omitempty"`
240
241 TokenEfficientToolUse bool `json:"-"` // DO NOT USE, broken on Anthropic's side as of 2025-02-28
242}
243
244const dumpText = false // debugging toggle to see raw communications with Claude
245
246// createMessage sends a request to the Anthropic message API to create a message.
247func createMessage(ctx context.Context, httpc *http.Client, url, apiKey string, request *MessageRequest) (*MessageResponse, error) {
248 var payload []byte
249 var err error
250 if dumpText || testing.Testing() {
251 payload, err = json.MarshalIndent(request, "", " ")
252 } else {
253 payload, err = json.Marshal(request)
254 payload = append(payload, '\n')
255 }
256 if err != nil {
257 return nil, err
258 }
259
260 if false {
261 fmt.Printf("claude request payload:\n%s\n", payload)
262 }
263
264 backoff := []time.Duration{15 * time.Second, 30 * time.Second, time.Minute}
265 largerMaxTokens := false
266 var partialUsage Usage
267
268 // retry loop
269 for attempts := 0; ; attempts++ {
270 if dumpText {
271 fmt.Printf("RAW REQUEST:\n%s\n\n", payload)
272 }
273 req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(payload))
274 if err != nil {
275 return nil, err
276 }
277
278 req.Header.Set("Content-Type", "application/json")
279 req.Header.Set("X-API-Key", apiKey)
280 req.Header.Set("Anthropic-Version", "2023-06-01")
281
282 features := []string{}
283
284 if request.TokenEfficientToolUse {
285 features = append(features, "token-efficient-tool-use-2025-02-19")
286 }
287 if largerMaxTokens {
288 features = append(features, "output-128k-2025-02-19")
289 request.MaxTokens = 128 * 1024
290 }
291 if len(features) > 0 {
292 req.Header.Set("anthropic-beta", strings.Join(features, ","))
293 }
294
295 resp, err := httpc.Do(req)
296 if err != nil {
297 return nil, err
298 }
299 buf, _ := io.ReadAll(resp.Body)
300 resp.Body.Close()
301
302 switch {
303 case resp.StatusCode == http.StatusOK:
304 if dumpText {
305 fmt.Printf("RAW RESPONSE:\n%s\n\n", buf)
306 }
307 var response MessageResponse
308 err = json.NewDecoder(bytes.NewReader(buf)).Decode(&response)
309 if err != nil {
310 return nil, err
311 }
312 if response.StopReason == StopReasonMaxTokens && !largerMaxTokens {
313 fmt.Printf("Retrying Anthropic API call with larger max tokens size.")
314 // Retry with more output tokens.
315 largerMaxTokens = true
316 response.Usage.CostUSD = response.TotalDollars()
317 partialUsage = response.Usage
318 continue
319 }
320
321 // Calculate and set the cost_usd field
322 if largerMaxTokens {
323 response.Usage.Add(partialUsage)
324 }
325 response.Usage.CostUSD = response.TotalDollars()
326
327 return &response, nil
328 case resp.StatusCode >= 500 && resp.StatusCode < 600:
329 // overloaded or unhappy, in one form or another
330 sleep := backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
331 slog.WarnContext(ctx, "anthropic_request_failed", "response", string(buf), "status_code", resp.StatusCode, "sleep", sleep)
332 time.Sleep(sleep)
333 case resp.StatusCode == 429:
334 // rate limited. wait 1 minute as a starting point, because that's the rate limiting window.
335 // and then add some additional time for backoff.
336 sleep := time.Minute + backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
337 slog.WarnContext(ctx, "anthropic_request_rate_limited", "response", string(buf), "sleep", sleep)
Josh Bleecher Snyder21cbc452025-05-02 10:54:54 -0700338 time.Sleep(sleep)
Earl Lee2e463fb2025-04-17 11:22:22 -0700339 // case resp.StatusCode == 400:
340 // TODO: parse ErrorResponse, make (*ErrorResponse) implement error
341 default:
342 return nil, fmt.Errorf("API request failed with status %s\n%s", resp.Status, buf)
343 }
344 }
345}
346
347// A Convo is a managed conversation with Claude.
348// It automatically manages the state of the conversation,
349// including appending messages send/received,
350// calling tools and sending their results,
351// tracking usage, etc.
352//
353// Exported fields must not be altered concurrently with calling any method on Convo.
354// Typical usage is to configure a Convo once before using it.
355type Convo struct {
356 // ID is a unique ID for the conversation
357 ID string
358 // Ctx is the context for the entire conversation.
359 Ctx context.Context
360 // HTTPC is the HTTP client for the conversation.
361 HTTPC *http.Client
362 // URL is the remote messages URL to dial.
363 URL string
364 // APIKey is the API key for the conversation.
365 APIKey string
366 // Model is the model for the conversation.
367 Model string
368 // MaxTokens is the max tokens for each response in the conversation.
369 MaxTokens int
370 // Tools are the tools available during the conversation.
371 Tools []*Tool
372 // SystemPrompt is the system prompt for the conversation.
373 SystemPrompt string
374 // PromptCaching indicates whether to use Anthropic's prompt caching.
375 // See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#continuing-a-multi-turn-conversation
376 // for the documentation. At request send time, we set the cache_control field on the
377 // last message. We also cache the system prompt.
378 // Default: true.
379 PromptCaching bool
380 // ToolUseOnly indicates whether Claude may only use tools during this conversation.
381 // TODO: add more fine-grained control over tool use?
382 ToolUseOnly bool
383 // Parent is the parent conversation, if any.
384 // It is non-nil for "subagent" calls.
385 // It is set automatically when calling SubConvo,
386 // and usually should not be set manually.
387 Parent *Convo
388 // Budget is the budget for this conversation (and all sub-conversations).
389 // The Conversation DOES NOT automatically enforce the budget.
390 // It is up to the caller to call OverBudget() as appropriate.
391 Budget Budget
392
393 // messages tracks the messages so far in the conversation.
394 messages []Message
395
396 // Listener receives messages being sent.
397 Listener Listener
398
399 muToolUseCancel *sync.Mutex
400 toolUseCancel map[string]context.CancelCauseFunc
401
402 // Protects usage. This is used for subconversations (that share part of CumulativeUsage) as well.
403 mu *sync.Mutex
404 // usage tracks usage for this conversation and all sub-conversations.
405 usage *CumulativeUsage
406}
407
408// newConvoID generates a new 8-byte random id.
409// The uniqueness/collision requirements here are very low.
410// They are not global identifiers,
411// just enough to distinguish different convos in a single session.
412func newConvoID() string {
413 u1 := rand.Uint32()
414 s := crock32.Encode(uint64(u1))
415 if len(s) < 7 {
416 s += strings.Repeat("0", 7-len(s))
417 }
418 return s[:3] + "-" + s[3:]
419}
420
421// NewConvo creates a new conversation with Claude with sensible defaults.
422// ctx is the context for the entire conversation.
423func NewConvo(ctx context.Context, apiKey string) *Convo {
424 id := newConvoID()
425 return &Convo{
426 Ctx: skribe.ContextWithAttr(ctx, slog.String("convo_id", id)),
427 HTTPC: http.DefaultClient,
428 URL: DefaultURL,
429 APIKey: apiKey,
430 Model: DefaultModel,
431 MaxTokens: DefaultMaxTokens,
432 PromptCaching: true,
433 usage: newUsage(),
434 Listener: &NoopListener{},
435 ID: id,
436 muToolUseCancel: &sync.Mutex{},
437 toolUseCancel: map[string]context.CancelCauseFunc{},
438 mu: &sync.Mutex{},
439 }
440}
441
442// SubConvo creates a sub-conversation with the same configuration as the parent conversation.
443// (This propagates context for cancellation, HTTP client, API key, etc.)
444// The sub-conversation shares no messages with the parent conversation.
445// It does not inherit tools from the parent conversation.
446func (c *Convo) SubConvo() *Convo {
447 id := newConvoID()
448 return &Convo{
449 Ctx: skribe.ContextWithAttr(c.Ctx, slog.String("convo_id", id), slog.String("parent_convo_id", c.ID)),
450 HTTPC: c.HTTPC,
Josh Bleecher Snyder6a50b182025-04-24 11:07:41 -0700451 URL: c.URL,
Earl Lee2e463fb2025-04-17 11:22:22 -0700452 APIKey: c.APIKey,
453 Model: c.Model,
454 MaxTokens: c.MaxTokens,
455 PromptCaching: c.PromptCaching,
456 Parent: c,
457 // For convenience, sub-convo usage shares tool uses map with parent,
458 // all other fields separate, propagated in AddResponse
459 usage: newUsageWithSharedToolUses(c.usage),
460 mu: c.mu,
461 Listener: c.Listener,
462 ID: id,
463 // Do not copy Budget. Each budget is independent,
464 // and OverBudget checks whether any ancestor is over budget.
465 }
466}
467
Philip Zeyliger2c4db092025-04-28 16:57:50 -0700468func (c *Convo) SubConvoWithHistory() *Convo {
469 id := newConvoID()
470 return &Convo{
471 Ctx: skribe.ContextWithAttr(c.Ctx, slog.String("convo_id", id), slog.String("parent_convo_id", c.ID)),
472 HTTPC: c.HTTPC,
473 URL: c.URL,
474 APIKey: c.APIKey,
475 Model: c.Model,
476 MaxTokens: c.MaxTokens,
477 PromptCaching: c.PromptCaching,
478 Parent: c,
479 // For convenience, sub-convo usage shares tool uses map with parent,
480 // all other fields separate, propagated in AddResponse
481 usage: newUsageWithSharedToolUses(c.usage),
482 mu: c.mu,
483 Listener: c.Listener,
484 ID: id,
485 // Do not copy Budget. Each budget is independent,
486 // and OverBudget checks whether any ancestor is over budget.
487 messages: slices.Clone(c.messages),
488 }
489}
490
Earl Lee2e463fb2025-04-17 11:22:22 -0700491// Depth reports how many "sub-conversations" deep this conversation is.
492// That it, it walks up parents until it finds a root.
493func (c *Convo) Depth() int {
494 x := c
495 var depth int
496 for x.Parent != nil {
497 x = x.Parent
498 depth++
499 }
500 return depth
501}
502
503// SendUserTextMessage sends a text message to Claude in this conversation.
504// otherContents contains additional contents to send with the message, usually tool results.
505func (c *Convo) SendUserTextMessage(s string, otherContents ...Content) (*MessageResponse, error) {
506 contents := slices.Clone(otherContents)
507 if s != "" {
Josh Bleecher Snydera3dcd862025-04-30 19:47:16 +0000508 contents = append(contents, StringContent(s))
Earl Lee2e463fb2025-04-17 11:22:22 -0700509 }
510 msg := Message{
511 Role: MessageRoleUser,
512 Content: contents,
513 }
514 return c.SendMessage(msg)
515}
516
517func (c *Convo) messageRequest(msg Message) *MessageRequest {
518 system := []SystemContent{}
519 if c.SystemPrompt != "" {
520 var d SystemContent
Josh Bleecher Snydera3dcd862025-04-30 19:47:16 +0000521 d = SystemContent{Type: ContentTypeText, Text: c.SystemPrompt}
Earl Lee2e463fb2025-04-17 11:22:22 -0700522 if c.PromptCaching {
523 d.CacheControl = json.RawMessage(`{"type":"ephemeral"}`)
524 }
525 system = []SystemContent{d}
526 }
527
528 // Claude is happy to return an empty response in response to our Done() call,
529 // and, if so, you'll see something like:
530 // API request failed with status 400 Bad Request
531 // {"type":"error","error": {"type":"invalid_request_error",
532 // "message":"messages.5: all messages must have non-empty content except for the optional final assistant message"}}
533 // So, we filter out those empty messages.
534 var nonEmptyMessages []Message
535 for _, m := range c.messages {
536 if len(m.Content) > 0 {
537 nonEmptyMessages = append(nonEmptyMessages, m)
538 }
539 }
540
541 mr := &MessageRequest{
542 Model: c.Model,
543 Messages: append(nonEmptyMessages, msg), // not yet committed to keeping msg
544 System: system,
545 Tools: c.Tools,
546 MaxTokens: c.MaxTokens,
547 }
548 if c.ToolUseOnly {
549 mr.ToolChoice = &ToolChoice{Type: ToolChoiceTypeAny}
550 }
551 return mr
552}
553
554func (c *Convo) findTool(name string) (*Tool, error) {
555 for _, tool := range c.Tools {
556 if tool.Name == name {
557 return tool, nil
558 }
559 }
560 return nil, fmt.Errorf("tool %q not found", name)
561}
562
563// insertMissingToolResults adds error results for tool uses that were requested
564// but not included in the message, which can happen in error paths like "out of budget."
565// We only insert these if there were no tool responses at all, since an incorrect
566// number of tool results would be a programmer error. Mutates inputs.
567func (c *Convo) insertMissingToolResults(mr *MessageRequest, msg *Message) {
568 if len(mr.Messages) < 2 {
569 return
570 }
571 prev := mr.Messages[len(mr.Messages)-2]
572 var toolUsePrev int
573 for _, c := range prev.Content {
574 if c.Type == ContentTypeToolUse {
575 toolUsePrev++
576 }
577 }
578 if toolUsePrev == 0 {
579 return
580 }
581 var toolUseCurrent int
582 for _, c := range msg.Content {
583 if c.Type == ContentTypeToolResult {
584 toolUseCurrent++
585 }
586 }
587 if toolUseCurrent != 0 {
588 return
589 }
590 var prefix []Content
591 for _, part := range prev.Content {
592 if part.Type != ContentTypeToolUse {
593 continue
594 }
595 content := Content{
596 Type: ContentTypeToolResult,
597 ToolUseID: part.ID,
598 ToolError: true,
599 ToolResult: "not executed; retry possible",
600 }
601 prefix = append(prefix, content)
602 msg.Content = append(prefix, msg.Content...)
603 mr.Messages[len(mr.Messages)-1].Content = msg.Content
604 }
605 slog.DebugContext(c.Ctx, "inserted missing tool results")
606}
607
608// SendMessage sends a message to Claude.
609// The conversation records (internally) all messages succesfully sent and received.
610func (c *Convo) SendMessage(msg Message) (*MessageResponse, error) {
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000611 id := ulid.Make().String()
Earl Lee2e463fb2025-04-17 11:22:22 -0700612 mr := c.messageRequest(msg)
613 var lastMessage *Message
614 if c.PromptCaching {
615 lastMessage = &mr.Messages[len(mr.Messages)-1]
616 if len(lastMessage.Content) > 0 {
617 lastMessage.Content[len(lastMessage.Content)-1].CacheControl = json.RawMessage(`{"type":"ephemeral"}`)
618 }
619 }
620 defer func() {
621 if lastMessage == nil {
622 return
623 }
624 if len(lastMessage.Content) > 0 {
625 lastMessage.Content[len(lastMessage.Content)-1].CacheControl = []byte{}
626 }
627 }()
628 c.insertMissingToolResults(mr, &msg)
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000629 c.Listener.OnRequest(c.Ctx, c, id, &msg)
Earl Lee2e463fb2025-04-17 11:22:22 -0700630
631 startTime := time.Now()
632 resp, err := createMessage(c.Ctx, c.HTTPC, c.URL, c.APIKey, mr)
633 if resp != nil {
634 resp.StartTime = &startTime
635 endTime := time.Now()
636 resp.EndTime = &endTime
637 }
638
639 if err != nil {
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000640 c.Listener.OnResponse(c.Ctx, c, id, nil)
Earl Lee2e463fb2025-04-17 11:22:22 -0700641 return nil, err
642 }
643 c.messages = append(c.messages, msg, resp.ToMessage())
644 // Propagate usage to all ancestors (including us).
645 for x := c; x != nil; x = x.Parent {
646 x.usage.AddResponse(resp)
647 }
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000648 c.Listener.OnResponse(c.Ctx, c, id, resp)
Earl Lee2e463fb2025-04-17 11:22:22 -0700649 return resp, err
650}
651
652type toolCallInfoKeyType string
653
654var toolCallInfoKey toolCallInfoKeyType
655
656type ToolCallInfo struct {
657 ToolUseID string
658 Convo *Convo
659}
660
661func ToolCallInfoFromContext(ctx context.Context) ToolCallInfo {
662 v := ctx.Value(toolCallInfoKey)
663 i, _ := v.(ToolCallInfo)
664 return i
665}
666
667func (c *Convo) ToolResultCancelContents(resp *MessageResponse) ([]Content, error) {
668 if resp.StopReason != StopReasonToolUse {
669 return nil, nil
670 }
671 var toolResults []Content
672
673 for _, part := range resp.Content {
674 if part.Type != ContentTypeToolUse {
675 continue
676 }
677 c.incrementToolUse(part.ToolName)
678
679 content := Content{
680 Type: ContentTypeToolResult,
681 ToolUseID: part.ID,
682 }
683
684 content.ToolError = true
685 content.ToolResult = "user canceled this too_use"
686 toolResults = append(toolResults, content)
687 }
688 return toolResults, nil
689}
690
Philip Zeyliger2c4db092025-04-28 16:57:50 -0700691// GetID returns the conversation ID
692func (c *Convo) GetID() string {
693 return c.ID
694}
695
Earl Lee2e463fb2025-04-17 11:22:22 -0700696func (c *Convo) CancelToolUse(toolUseID string, err error) error {
697 c.muToolUseCancel.Lock()
698 defer c.muToolUseCancel.Unlock()
699 cancel, ok := c.toolUseCancel[toolUseID]
700 if !ok {
701 return fmt.Errorf("cannot cancel %s: no cancel function registered for this tool_use_id. All I have is %+v", toolUseID, c.toolUseCancel)
702 }
703 delete(c.toolUseCancel, toolUseID)
704 cancel(err)
705 return nil
706}
707
708func (c *Convo) newToolUseContext(ctx context.Context, toolUseID string) (context.Context, context.CancelFunc) {
709 c.muToolUseCancel.Lock()
710 defer c.muToolUseCancel.Unlock()
711 ctx, cancel := context.WithCancelCause(ctx)
712 c.toolUseCancel[toolUseID] = cancel
713 return ctx, func() { c.CancelToolUse(toolUseID, nil) }
714}
715
716// ToolResultContents runs all tool uses requested by the response and returns their results.
717// Cancelling ctx will cancel any running tool calls.
718func (c *Convo) ToolResultContents(ctx context.Context, resp *MessageResponse) ([]Content, error) {
719 if resp.StopReason != StopReasonToolUse {
720 return nil, nil
721 }
722 // Extract all tool calls from the response, call the tools, and gather the results.
723 var wg sync.WaitGroup
724 toolResultC := make(chan Content, len(resp.Content))
725 for _, part := range resp.Content {
726 if part.Type != ContentTypeToolUse {
727 continue
728 }
729 c.incrementToolUse(part.ToolName)
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000730 startTime := time.Now()
731
732 c.Listener.OnToolCall(ctx, c, part.ID, part.ToolName, part.ToolInput, Content{
733 Type: ContentTypeToolUse,
734 ToolUseID: part.ID,
735 StartTime: &startTime,
736 })
737
Earl Lee2e463fb2025-04-17 11:22:22 -0700738 wg.Add(1)
739 go func() {
740 defer wg.Done()
741
Earl Lee2e463fb2025-04-17 11:22:22 -0700742 content := Content{
743 Type: ContentTypeToolResult,
744 ToolUseID: part.ID,
745 StartTime: &startTime,
746 }
747 sendErr := func(err error) {
748 // Record end time
749 endTime := time.Now()
750 content.EndTime = &endTime
751
752 content.ToolError = true
753 content.ToolResult = err.Error()
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000754 c.Listener.OnToolResult(ctx, c, part.ID, part.ToolName, part.ToolInput, content, nil, err)
Earl Lee2e463fb2025-04-17 11:22:22 -0700755 toolResultC <- content
756 }
757 sendRes := func(res string) {
758 // Record end time
759 endTime := time.Now()
760 content.EndTime = &endTime
761
762 content.ToolResult = res
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000763 c.Listener.OnToolResult(ctx, c, part.ID, part.ToolName, part.ToolInput, content, &res, nil)
Earl Lee2e463fb2025-04-17 11:22:22 -0700764 toolResultC <- content
765 }
766
767 tool, err := c.findTool(part.ToolName)
768 if err != nil {
769 sendErr(err)
770 return
771 }
772 // Create a new context for just this tool_use call, and register its
773 // cancel function so that it can be canceled individually.
774 toolUseCtx, cancel := c.newToolUseContext(ctx, part.ID)
775 defer cancel()
776 // TODO: move this into newToolUseContext?
777 toolUseCtx = context.WithValue(toolUseCtx, toolCallInfoKey, ToolCallInfo{ToolUseID: part.ID, Convo: c})
778 toolResult, err := tool.Run(toolUseCtx, part.ToolInput)
779 if errors.Is(err, ErrDoNotRespond) {
780 return
781 }
782 if toolUseCtx.Err() != nil {
783 sendErr(context.Cause(toolUseCtx))
784 return
785 }
786
787 if err != nil {
788 sendErr(err)
789 return
790 }
791 sendRes(toolResult)
792 }()
793 }
794 wg.Wait()
795 close(toolResultC)
796 var toolResults []Content
797 for toolResult := range toolResultC {
798 toolResults = append(toolResults, toolResult)
799 }
800 if ctx.Err() != nil {
801 return nil, ctx.Err()
802 }
803 return toolResults, nil
804}
805
806func (c *Convo) incrementToolUse(name string) {
807 c.mu.Lock()
808 defer c.mu.Unlock()
809
810 c.usage.ToolUses[name]++
811}
812
813// ContentsAttr returns contents as a slog.Attr.
814// It is meant for logging.
815func ContentsAttr(contents []Content) slog.Attr {
816 var contentAttrs []any // slog.Attr
817 for _, content := range contents {
818 var attrs []any // slog.Attr
819 switch content.Type {
820 case ContentTypeText:
821 attrs = append(attrs, slog.String("text", content.Text))
822 case ContentTypeToolUse:
823 attrs = append(attrs, slog.String("tool_name", content.ToolName))
824 attrs = append(attrs, slog.String("tool_input", string(content.ToolInput)))
825 case ContentTypeToolResult:
826 attrs = append(attrs, slog.String("tool_result", content.ToolResult))
827 attrs = append(attrs, slog.Bool("tool_error", content.ToolError))
828 case ContentTypeThinking:
829 attrs = append(attrs, slog.String("thinking", content.Text))
830 default:
831 attrs = append(attrs, slog.String("unknown_content_type", content.Type))
832 attrs = append(attrs, slog.Any("text", content)) // just log it all raw, better to have too much than not enough
833 }
834 contentAttrs = append(contentAttrs, slog.Group(content.ID, attrs...))
835 }
836 return slog.Group("contents", contentAttrs...)
837}
838
// MustSchema returns schema, with surrounding whitespace trimmed, as a
// json.RawMessage. It panics if the trimmed text is not valid JSON.
//
// TODO: validate schema; for now this only makes sure it's valid JSON.
func MustSchema(schema string) json.RawMessage {
	schema = strings.TrimSpace(schema)
	raw := []byte(schema)
	if !json.Valid(raw) {
		panic("invalid JSON schema: " + schema)
	}
	return json.RawMessage(raw)
}
850
// centsPer1MTokens is a price list expressed in cents per million tokens,
// one price per token category.
// (Cents rather than dollars, to keep prices in integers instead of using
// floats for money.)
type centsPer1MTokens struct {
	Input         uint64
	Output        uint64
	CacheRead     uint64
	CacheCreation uint64
}
859
860// https://www.anthropic.com/pricing#anthropic-api
861var modelCost = map[string]centsPer1MTokens{
862 Claude37Sonnet: {
863 Input: 300, // $3
864 Output: 1500, // $15
865 CacheRead: 30, // $0.30
866 CacheCreation: 375, // $3.75
867 },
868 Claude35Haiku: {
869 Input: 80, // $0.80
870 Output: 400, // $4.00
871 CacheRead: 8, // $0.08
872 CacheCreation: 100, // $1.00
873 },
874 Claude35Sonnet: {
875 Input: 300, // $3
876 Output: 1500, // $15
877 CacheRead: 30, // $0.30
878 CacheCreation: 375, // $3.75
879 },
880}
881
882// TotalDollars returns the total cost to obtain this response, in dollars.
883func (mr *MessageResponse) TotalDollars() float64 {
884 cpm, ok := modelCost[mr.Model]
885 if !ok {
886 panic(fmt.Sprintf("no pricing info for model: %s", mr.Model))
887 }
888 use := mr.Usage
889 megaCents := use.InputTokens*cpm.Input +
890 use.OutputTokens*cpm.Output +
891 use.CacheReadInputTokens*cpm.CacheRead +
892 use.CacheCreationInputTokens*cpm.CacheCreation
893 cents := float64(megaCents) / 1_000_000.0
894 return cents / 100.0
895}
896
// newUsage returns an empty CumulativeUsage whose clock starts now.
func newUsage() *CumulativeUsage {
	u := new(CumulativeUsage)
	u.ToolUses = make(map[string]int)
	u.StartTime = time.Now()
	return u
}

// newUsageWithSharedToolUses returns a CumulativeUsage whose clock starts
// now and whose ToolUses map is the very same map as parent's (not a copy).
func newUsageWithSharedToolUses(parent *CumulativeUsage) *CumulativeUsage {
	u := new(CumulativeUsage)
	u.ToolUses = parent.ToolUses
	u.StartTime = time.Now()
	return u
}

// CumulativeUsage represents cumulative usage across a Convo, including all sub-conversations.
type CumulativeUsage struct {
	StartTime                time.Time      `json:"start_time"`
	Responses                uint64         `json:"messages"` // count of responses
	InputTokens              uint64         `json:"input_tokens"`
	OutputTokens             uint64         `json:"output_tokens"`
	CacheReadInputTokens     uint64         `json:"cache_read_input_tokens"`
	CacheCreationInputTokens uint64         `json:"cache_creation_input_tokens"`
	TotalCostUSD             float64        `json:"total_cost_usd"`
	ToolUses                 map[string]int `json:"tool_uses"` // tool name -> number of uses
}

// Clone returns a copy of u with its own ToolUses map.
func (u *CumulativeUsage) Clone() CumulativeUsage {
	clone := *u
	clone.ToolUses = maps.Clone(u.ToolUses)
	return clone
}
922
923func (c *Convo) CumulativeUsage() CumulativeUsage {
924 if c == nil {
925 return CumulativeUsage{}
926 }
927 c.mu.Lock()
928 defer c.mu.Unlock()
929 return c.usage.Clone()
930}
931
932func (u *CumulativeUsage) WallTime() time.Duration {
933 return time.Since(u.StartTime)
934}
935
936func (u *CumulativeUsage) DollarsPerHour() float64 {
937 hours := u.WallTime().Hours()
938 if hours == 0 {
939 return 0
940 }
941 return u.TotalCostUSD / hours
942}
943
944func (u *CumulativeUsage) AddResponse(resp *MessageResponse) {
945 usage := resp.Usage
946 u.Responses++
947 u.InputTokens += usage.InputTokens
948 u.OutputTokens += usage.OutputTokens
949 u.CacheReadInputTokens += usage.CacheReadInputTokens
950 u.CacheCreationInputTokens += usage.CacheCreationInputTokens
951 u.TotalCostUSD += resp.TotalDollars()
952}
953
Josh Bleecher Snyder35889972025-04-24 20:48:16 +0000954// TotalInputTokens returns the grand total cumulative input tokens in u.
955func (u *CumulativeUsage) TotalInputTokens() uint64 {
956 return u.InputTokens + u.CacheReadInputTokens + u.CacheCreationInputTokens
957}
958
Earl Lee2e463fb2025-04-17 11:22:22 -0700959// Attr returns the cumulative usage as a slog.Attr with key "usage".
960func (u CumulativeUsage) Attr() slog.Attr {
961 elapsed := time.Since(u.StartTime)
962 return slog.Group("usage",
963 slog.Duration("wall_time", elapsed),
964 slog.Uint64("responses", u.Responses),
965 slog.Uint64("input_tokens", u.InputTokens),
966 slog.Uint64("output_tokens", u.OutputTokens),
967 slog.Uint64("cache_read_input_tokens", u.CacheReadInputTokens),
968 slog.Uint64("cache_creation_input_tokens", u.CacheCreationInputTokens),
969 slog.Float64("total_cost_usd", u.TotalCostUSD),
970 slog.Float64("dollars_per_hour", u.TotalCostUSD/elapsed.Hours()),
971 slog.Any("tool_uses", maps.Clone(u.ToolUses)),
972 )
973}
974
// A Budget caps the resources that may be spent on a conversation.
// Each limit applies only when it is greater than zero, so the zero Budget
// is unlimited.
type Budget struct {
	// MaxResponses, if > 0, is the max number of iterations (= responses).
	MaxResponses uint64
	// MaxDollars, if > 0, is the max dollars that may be spent.
	MaxDollars float64
	// MaxWallTime, if > 0, is the max wall time that may be spent.
	MaxWallTime time.Duration
}
982
983// OverBudget returns an error if the convo (or any of its parents) has exceeded its budget.
984// TODO: document parent vs sub budgets, multiple errors, etc, once we know the desired behavior.
985func (c *Convo) OverBudget() error {
986 for x := c; x != nil; x = x.Parent {
987 if err := x.overBudget(); err != nil {
988 return err
989 }
990 }
991 return nil
992}
993
994// ResetBudget sets the budget to the passed in budget and
995// adjusts it by what's been used so far.
996func (c *Convo) ResetBudget(budget Budget) {
997 c.Budget = budget
998 if c.Budget.MaxDollars > 0 {
999 c.Budget.MaxDollars += c.CumulativeUsage().TotalCostUSD
1000 }
1001 if c.Budget.MaxResponses > 0 {
1002 c.Budget.MaxResponses += c.CumulativeUsage().Responses
1003 }
1004 if c.Budget.MaxWallTime > 0 {
1005 c.Budget.MaxWallTime += c.usage.WallTime()
1006 }
1007}
1008
1009func (c *Convo) overBudget() error {
1010 usage := c.CumulativeUsage()
1011 // TODO: stop before we exceed the budget instead of after?
1012 // Top priority is money, then time, then response count.
1013 var err error
1014 cont := "Continuing to chat will reset the budget."
1015 if c.Budget.MaxDollars > 0 && usage.TotalCostUSD >= c.Budget.MaxDollars {
1016 err = errors.Join(err, fmt.Errorf("$%.2f spent, budget is $%.2f. %s", usage.TotalCostUSD, c.Budget.MaxDollars, cont))
1017 }
1018 if c.Budget.MaxWallTime > 0 && usage.WallTime() >= c.Budget.MaxWallTime {
1019 err = errors.Join(err, fmt.Errorf("%v elapsed, budget is %v. %s", usage.WallTime().Truncate(time.Second), c.Budget.MaxWallTime.Truncate(time.Second), cont))
1020 }
1021 if c.Budget.MaxResponses > 0 && usage.Responses >= c.Budget.MaxResponses {
1022 err = errors.Join(err, fmt.Errorf("%d responses received, budget is %d. %s", usage.Responses, c.Budget.MaxResponses, cont))
1023 }
1024 return err
1025}
Josh Bleecher Snydera3dcd862025-04-30 19:47:16 +00001026
1027// UserStringMessage creates a user message with a single text content item.
1028func UserStringMessage(text string) Message {
1029 return Message{
1030 Role: MessageRoleUser,
1031 Content: []Content{StringContent(text)},
1032 }
1033}