blob: 95e4fbaa02dd439aefb54587b6d606191f6a4db3 [file] [log] [blame]
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -07001package conversation
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "log/slog"
9 "maps"
10 "math/rand/v2"
11 "slices"
12 "strings"
13 "sync"
14 "time"
15
16 "github.com/oklog/ulid/v2"
17 "github.com/richardlehane/crock32"
18 "sketch.dev/llm"
19 "sketch.dev/skribe"
20)
21
22type Listener interface {
23 // TODO: Content is leaking an anthropic API; should we avoid it?
24 // TODO: Where should we include start/end time and usage?
25 OnToolCall(ctx context.Context, convo *Convo, toolCallID string, toolName string, toolInput json.RawMessage, content llm.Content)
26 OnToolResult(ctx context.Context, convo *Convo, toolCallID string, toolName string, toolInput json.RawMessage, content llm.Content, result *string, err error)
27 OnRequest(ctx context.Context, convo *Convo, requestID string, msg *llm.Message)
28 OnResponse(ctx context.Context, convo *Convo, requestID string, msg *llm.Response)
29}
30
31type NoopListener struct{}
32
33func (n *NoopListener) OnToolCall(ctx context.Context, convo *Convo, id string, toolName string, toolInput json.RawMessage, content llm.Content) {
34}
35
36func (n *NoopListener) OnToolResult(ctx context.Context, convo *Convo, id string, toolName string, toolInput json.RawMessage, content llm.Content, result *string, err error) {
37}
38
39func (n *NoopListener) OnResponse(ctx context.Context, convo *Convo, id string, msg *llm.Response) {
40}
41func (n *NoopListener) OnRequest(ctx context.Context, convo *Convo, id string, msg *llm.Message) {}
42
// ErrDoNotRespond may be returned by a tool's Run to indicate that no
// tool_result should be produced (or reported to the Listener) for that call.
var ErrDoNotRespond = errors.New("do not respond")
44
45// A Convo is a managed conversation with Claude.
46// It automatically manages the state of the conversation,
47// including appending messages send/received,
48// calling tools and sending their results,
49// tracking usage, etc.
50//
51// Exported fields must not be altered concurrently with calling any method on Convo.
52// Typical usage is to configure a Convo once before using it.
53type Convo struct {
54 // ID is a unique ID for the conversation
55 ID string
56 // Ctx is the context for the entire conversation.
57 Ctx context.Context
58 // Service is the LLM service to use.
59 Service llm.Service
60 // Tools are the tools available during the conversation.
61 Tools []*llm.Tool
62 // SystemPrompt is the system prompt for the conversation.
63 SystemPrompt string
64 // PromptCaching indicates whether to use Anthropic's prompt caching.
65 // See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#continuing-a-multi-turn-conversation
66 // for the documentation. At request send time, we set the cache_control field on the
67 // last message. We also cache the system prompt.
68 // Default: true.
69 PromptCaching bool
70 // ToolUseOnly indicates whether Claude may only use tools during this conversation.
71 // TODO: add more fine-grained control over tool use?
72 ToolUseOnly bool
73 // Parent is the parent conversation, if any.
74 // It is non-nil for "subagent" calls.
75 // It is set automatically when calling SubConvo,
76 // and usually should not be set manually.
77 Parent *Convo
78 // Budget is the budget for this conversation (and all sub-conversations).
79 // The Conversation DOES NOT automatically enforce the budget.
80 // It is up to the caller to call OverBudget() as appropriate.
81 Budget Budget
Josh Bleecher Snyder4d544932025-05-07 13:33:53 +000082 // Hidden indicates that the output of this conversation should be hidden in the UI.
83 // This is useful for subconversations that can generate noisy, uninteresting output.
84 Hidden bool
Josh Bleecher Snyder31785ae2025-05-06 01:50:58 +000085 // ExtraData is extra data to make available to all tool calls.
86 ExtraData map[string]any
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -070087
88 // messages tracks the messages so far in the conversation.
89 messages []llm.Message
90
91 // Listener receives messages being sent.
92 Listener Listener
93
Josh Bleecher Snyder495c1fa2025-05-29 00:37:22 +000094 toolUseCancelMu sync.Mutex
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -070095 toolUseCancel map[string]context.CancelCauseFunc
96
97 // Protects usage. This is used for subconversations (that share part of CumulativeUsage) as well.
98 mu *sync.Mutex
99 // usage tracks usage for this conversation and all sub-conversations.
100 usage *CumulativeUsage
Philip Zeyligerb8a8f352025-06-02 07:39:37 -0700101 // lastUsage tracks the usage from the most recent API call
102 lastUsage llm.Usage
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700103}
104
105// newConvoID generates a new 8-byte random id.
106// The uniqueness/collision requirements here are very low.
107// They are not global identifiers,
108// just enough to distinguish different convos in a single session.
109func newConvoID() string {
110 u1 := rand.Uint32()
111 s := crock32.Encode(uint64(u1))
112 if len(s) < 7 {
113 s += strings.Repeat("0", 7-len(s))
114 }
115 return s[:3] + "-" + s[3:]
116}
117
118// New creates a new conversation with Claude with sensible defaults.
119// ctx is the context for the entire conversation.
120func New(ctx context.Context, srv llm.Service) *Convo {
121 id := newConvoID()
122 return &Convo{
Josh Bleecher Snyder495c1fa2025-05-29 00:37:22 +0000123 Ctx: skribe.ContextWithAttr(ctx, slog.String("convo_id", id)),
124 Service: srv,
125 PromptCaching: true,
126 usage: newUsage(),
127 Listener: &NoopListener{},
128 ID: id,
129 toolUseCancel: map[string]context.CancelCauseFunc{},
130 mu: &sync.Mutex{},
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700131 }
132}
133
134// SubConvo creates a sub-conversation with the same configuration as the parent conversation.
135// (This propagates context for cancellation, HTTP client, API key, etc.)
136// The sub-conversation shares no messages with the parent conversation.
137// It does not inherit tools from the parent conversation.
138func (c *Convo) SubConvo() *Convo {
139 id := newConvoID()
140 return &Convo{
141 Ctx: skribe.ContextWithAttr(c.Ctx, slog.String("convo_id", id), slog.String("parent_convo_id", c.ID)),
142 Service: c.Service,
143 PromptCaching: c.PromptCaching,
144 Parent: c,
145 // For convenience, sub-convo usage shares tool uses map with parent,
146 // all other fields separate, propagated in AddResponse
Josh Bleecher Snyder495c1fa2025-05-29 00:37:22 +0000147 usage: newUsageWithSharedToolUses(c.usage),
148 mu: c.mu,
149 Listener: c.Listener,
150 ID: id,
151 toolUseCancel: map[string]context.CancelCauseFunc{},
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700152 // Do not copy Budget. Each budget is independent,
153 // and OverBudget checks whether any ancestor is over budget.
154 }
155}
156
157func (c *Convo) SubConvoWithHistory() *Convo {
158 id := newConvoID()
159 return &Convo{
160 Ctx: skribe.ContextWithAttr(c.Ctx, slog.String("convo_id", id), slog.String("parent_convo_id", c.ID)),
161 Service: c.Service,
162 PromptCaching: c.PromptCaching,
163 Parent: c,
164 // For convenience, sub-convo usage shares tool uses map with parent,
165 // all other fields separate, propagated in AddResponse
166 usage: newUsageWithSharedToolUses(c.usage),
167 mu: c.mu,
168 Listener: c.Listener,
169 ID: id,
170 // Do not copy Budget. Each budget is independent,
171 // and OverBudget checks whether any ancestor is over budget.
172 messages: slices.Clone(c.messages),
173 }
174}
175
176// Depth reports how many "sub-conversations" deep this conversation is.
177// That it, it walks up parents until it finds a root.
178func (c *Convo) Depth() int {
179 x := c
180 var depth int
181 for x.Parent != nil {
182 x = x.Parent
183 depth++
184 }
185 return depth
186}
187
188// SendUserTextMessage sends a text message to the LLM in this conversation.
189// otherContents contains additional contents to send with the message, usually tool results.
190func (c *Convo) SendUserTextMessage(s string, otherContents ...llm.Content) (*llm.Response, error) {
191 contents := slices.Clone(otherContents)
192 if s != "" {
193 contents = append(contents, llm.Content{Type: llm.ContentTypeText, Text: s})
194 }
195 msg := llm.Message{
196 Role: llm.MessageRoleUser,
197 Content: contents,
198 }
199 return c.SendMessage(msg)
200}
201
202func (c *Convo) messageRequest(msg llm.Message) *llm.Request {
203 system := []llm.SystemContent{}
204 if c.SystemPrompt != "" {
Josh Bleecher Snyder495c1fa2025-05-29 00:37:22 +0000205 d := llm.SystemContent{Type: "text", Text: c.SystemPrompt}
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700206 if c.PromptCaching {
207 d.Cache = true
208 }
209 system = []llm.SystemContent{d}
210 }
211
212 // Claude is happy to return an empty response in response to our Done() call,
213 // and, if so, you'll see something like:
214 // API request failed with status 400 Bad Request
215 // {"type":"error","error": {"type":"invalid_request_error",
216 // "message":"messages.5: all messages must have non-empty content except for the optional final assistant message"}}
217 // So, we filter out those empty messages.
218 var nonEmptyMessages []llm.Message
219 for _, m := range c.messages {
220 if len(m.Content) > 0 {
221 nonEmptyMessages = append(nonEmptyMessages, m)
222 }
223 }
224
225 mr := &llm.Request{
226 Messages: append(nonEmptyMessages, msg), // not yet committed to keeping msg
227 System: system,
228 Tools: c.Tools,
229 }
230 if c.ToolUseOnly {
231 mr.ToolChoice = &llm.ToolChoice{Type: llm.ToolChoiceTypeAny}
232 }
233 return mr
234}
235
236func (c *Convo) findTool(name string) (*llm.Tool, error) {
237 for _, tool := range c.Tools {
238 if tool.Name == name {
239 return tool, nil
240 }
241 }
242 return nil, fmt.Errorf("tool %q not found", name)
243}
244
245// insertMissingToolResults adds error results for tool uses that were requested
246// but not included in the message, which can happen in error paths like "out of budget."
247// We only insert these if there were no tool responses at all, since an incorrect
248// number of tool results would be a programmer error. Mutates inputs.
249func (c *Convo) insertMissingToolResults(mr *llm.Request, msg *llm.Message) {
250 if len(mr.Messages) < 2 {
251 return
252 }
253 prev := mr.Messages[len(mr.Messages)-2]
254 var toolUsePrev int
255 for _, c := range prev.Content {
256 if c.Type == llm.ContentTypeToolUse {
257 toolUsePrev++
258 }
259 }
260 if toolUsePrev == 0 {
261 return
262 }
263 var toolUseCurrent int
264 for _, c := range msg.Content {
265 if c.Type == llm.ContentTypeToolResult {
266 toolUseCurrent++
267 }
268 }
269 if toolUseCurrent != 0 {
270 return
271 }
272 var prefix []llm.Content
273 for _, part := range prev.Content {
274 if part.Type != llm.ContentTypeToolUse {
275 continue
276 }
277 content := llm.Content{
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700278 Type: llm.ContentTypeToolResult,
279 ToolUseID: part.ID,
280 ToolError: true,
281 ToolResult: []llm.Content{{
282 Type: llm.ContentTypeText,
283 Text: "not executed; retry possible",
284 }},
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700285 }
286 prefix = append(prefix, content)
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700287 }
Josh Bleecher Snydera3e28fb2025-05-30 15:53:23 +0000288 msg.Content = append(prefix, msg.Content...)
289 mr.Messages[len(mr.Messages)-1].Content = msg.Content
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700290 slog.DebugContext(c.Ctx, "inserted missing tool results")
291}
292
293// SendMessage sends a message to Claude.
294// The conversation records (internally) all messages succesfully sent and received.
295func (c *Convo) SendMessage(msg llm.Message) (*llm.Response, error) {
296 id := ulid.Make().String()
297 mr := c.messageRequest(msg)
298 var lastMessage *llm.Message
299 if c.PromptCaching {
300 lastMessage = &mr.Messages[len(mr.Messages)-1]
301 if len(lastMessage.Content) > 0 {
302 lastMessage.Content[len(lastMessage.Content)-1].Cache = true
303 }
304 }
305 defer func() {
306 if lastMessage == nil {
307 return
308 }
309 if len(lastMessage.Content) > 0 {
310 lastMessage.Content[len(lastMessage.Content)-1].Cache = false
311 }
312 }()
313 c.insertMissingToolResults(mr, &msg)
314 c.Listener.OnRequest(c.Ctx, c, id, &msg)
315
316 startTime := time.Now()
317 resp, err := c.Service.Do(c.Ctx, mr)
318 if resp != nil {
319 resp.StartTime = &startTime
320 endTime := time.Now()
321 resp.EndTime = &endTime
322 }
323
324 if err != nil {
325 c.Listener.OnResponse(c.Ctx, c, id, nil)
326 return nil, err
327 }
328 c.messages = append(c.messages, msg, resp.ToMessage())
329 // Propagate usage to all ancestors (including us).
330 for x := c; x != nil; x = x.Parent {
331 x.usage.Add(resp.Usage)
Philip Zeyligerb8a8f352025-06-02 07:39:37 -0700332 // Store the most recent usage (only on the current conversation, not ancestors)
333 if x == c {
334 x.lastUsage = resp.Usage
335 }
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700336 }
337 c.Listener.OnResponse(c.Ctx, c, id, resp)
338 return resp, err
339}
340
// toolCallInfoKeyType is an unexported context-key type, so the key below can
// never collide with context keys defined in other packages.
type toolCallInfoKeyType string

// toolCallInfoKey is the context key under which a ToolCallInfo is stored.
var toolCallInfoKey = toolCallInfoKeyType("")
344
345type ToolCallInfo struct {
346 ToolUseID string
347 Convo *Convo
348}
349
350func ToolCallInfoFromContext(ctx context.Context) ToolCallInfo {
351 v := ctx.Value(toolCallInfoKey)
352 i, _ := v.(ToolCallInfo)
353 return i
354}
355
356func (c *Convo) ToolResultCancelContents(resp *llm.Response) ([]llm.Content, error) {
357 if resp.StopReason != llm.StopReasonToolUse {
358 return nil, nil
359 }
360 var toolResults []llm.Content
361
362 for _, part := range resp.Content {
363 if part.Type != llm.ContentTypeToolUse {
364 continue
365 }
366 c.incrementToolUse(part.ToolName)
367
368 content := llm.Content{
369 Type: llm.ContentTypeToolResult,
370 ToolUseID: part.ID,
371 }
372
373 content.ToolError = true
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700374 content.ToolResult = []llm.Content{{
375 Type: llm.ContentTypeText,
Josh Bleecher Snydera3e28fb2025-05-30 15:53:23 +0000376 Text: "user canceled this tool_use",
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700377 }}
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700378 toolResults = append(toolResults, content)
379 }
380 return toolResults, nil
381}
382
383// GetID returns the conversation ID
384func (c *Convo) GetID() string {
385 return c.ID
386}
387
388func (c *Convo) CancelToolUse(toolUseID string, err error) error {
Josh Bleecher Snyder495c1fa2025-05-29 00:37:22 +0000389 c.toolUseCancelMu.Lock()
390 defer c.toolUseCancelMu.Unlock()
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700391 cancel, ok := c.toolUseCancel[toolUseID]
392 if !ok {
393 return fmt.Errorf("cannot cancel %s: no cancel function registered for this tool_use_id. All I have is %+v", toolUseID, c.toolUseCancel)
394 }
395 delete(c.toolUseCancel, toolUseID)
396 cancel(err)
397 return nil
398}
399
400func (c *Convo) newToolUseContext(ctx context.Context, toolUseID string) (context.Context, context.CancelFunc) {
Josh Bleecher Snyder495c1fa2025-05-29 00:37:22 +0000401 c.toolUseCancelMu.Lock()
402 defer c.toolUseCancelMu.Unlock()
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700403 ctx, cancel := context.WithCancelCause(ctx)
404 c.toolUseCancel[toolUseID] = cancel
405 return ctx, func() { c.CancelToolUse(toolUseID, nil) }
406}
407
408// ToolResultContents runs all tool uses requested by the response and returns their results.
409// Cancelling ctx will cancel any running tool calls.
Josh Bleecher Snyder64f2aa82025-05-14 18:31:05 +0000410// The boolean return value indicates whether any of the executed tools should end the turn.
411func (c *Convo) ToolResultContents(ctx context.Context, resp *llm.Response) ([]llm.Content, bool, error) {
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700412 if resp.StopReason != llm.StopReasonToolUse {
Josh Bleecher Snyder64f2aa82025-05-14 18:31:05 +0000413 return nil, false, nil
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700414 }
415 // Extract all tool calls from the response, call the tools, and gather the results.
416 var wg sync.WaitGroup
417 toolResultC := make(chan llm.Content, len(resp.Content))
Josh Bleecher Snyder64f2aa82025-05-14 18:31:05 +0000418
419 endsTurn := false
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700420 for _, part := range resp.Content {
421 if part.Type != llm.ContentTypeToolUse {
422 continue
423 }
Josh Bleecher Snyder64f2aa82025-05-14 18:31:05 +0000424 tool, err := c.findTool(part.ToolName)
425 if err == nil && tool.EndsTurn {
426 endsTurn = true
427 }
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700428 c.incrementToolUse(part.ToolName)
429 startTime := time.Now()
430
431 c.Listener.OnToolCall(ctx, c, part.ID, part.ToolName, part.ToolInput, llm.Content{
432 Type: llm.ContentTypeToolUse,
433 ToolUseID: part.ID,
434 ToolUseStartTime: &startTime,
435 })
436
437 wg.Add(1)
438 go func() {
439 defer wg.Done()
440
441 content := llm.Content{
442 Type: llm.ContentTypeToolResult,
443 ToolUseID: part.ID,
444 ToolUseStartTime: &startTime,
445 }
446 sendErr := func(err error) {
447 // Record end time
448 endTime := time.Now()
449 content.ToolUseEndTime = &endTime
450
451 content.ToolError = true
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700452 content.ToolResult = []llm.Content{{
453 Type: llm.ContentTypeText,
454 Text: err.Error(),
455 }}
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700456 c.Listener.OnToolResult(ctx, c, part.ID, part.ToolName, part.ToolInput, content, nil, err)
457 toolResultC <- content
458 }
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700459 sendRes := func(toolResult []llm.Content) {
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700460 // Record end time
461 endTime := time.Now()
462 content.ToolUseEndTime = &endTime
463
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700464 content.ToolResult = toolResult
465 var firstText string
466 if len(toolResult) > 0 {
467 firstText = toolResult[0].Text
468 }
469 c.Listener.OnToolResult(ctx, c, part.ID, part.ToolName, part.ToolInput, content, &firstText, nil)
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700470 toolResultC <- content
471 }
472
473 tool, err := c.findTool(part.ToolName)
474 if err != nil {
475 sendErr(err)
476 return
477 }
478 // Create a new context for just this tool_use call, and register its
479 // cancel function so that it can be canceled individually.
480 toolUseCtx, cancel := c.newToolUseContext(ctx, part.ID)
481 defer cancel()
482 // TODO: move this into newToolUseContext?
483 toolUseCtx = context.WithValue(toolUseCtx, toolCallInfoKey, ToolCallInfo{ToolUseID: part.ID, Convo: c})
484 toolResult, err := tool.Run(toolUseCtx, part.ToolInput)
485 if errors.Is(err, ErrDoNotRespond) {
486 return
487 }
488 if toolUseCtx.Err() != nil {
489 sendErr(context.Cause(toolUseCtx))
490 return
491 }
492
493 if err != nil {
494 sendErr(err)
495 return
496 }
497 sendRes(toolResult)
498 }()
499 }
500 wg.Wait()
501 close(toolResultC)
502 var toolResults []llm.Content
503 for toolResult := range toolResultC {
504 toolResults = append(toolResults, toolResult)
505 }
506 if ctx.Err() != nil {
Josh Bleecher Snyder64f2aa82025-05-14 18:31:05 +0000507 return nil, false, ctx.Err()
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700508 }
Josh Bleecher Snyder64f2aa82025-05-14 18:31:05 +0000509 return toolResults, endsTurn, nil
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700510}
511
512func (c *Convo) incrementToolUse(name string) {
513 c.mu.Lock()
514 defer c.mu.Unlock()
515
516 c.usage.ToolUses[name]++
517}
518
// CumulativeUsage totals LLM usage across a Convo, including all of its
// sub-conversations. The JSON tags define its serialized form.
type CumulativeUsage struct {
	StartTime                time.Time      `json:"start_time"`                  // when tracking began; WallTime is measured from it
	Responses                uint64         `json:"messages"`                    // number of LLM responses received
	InputTokens              uint64         `json:"input_tokens"`                // input tokens outside the prompt cache (see TotalInputTokens)
	OutputTokens             uint64         `json:"output_tokens"`               // output tokens generated
	CacheReadInputTokens     uint64         `json:"cache_read_input_tokens"`     // input tokens read from the prompt cache
	CacheCreationInputTokens uint64         `json:"cache_creation_input_tokens"` // input tokens written to the prompt cache
	TotalCostUSD             float64        `json:"total_cost_usd"`              // total cost in US dollars
	ToolUses                 map[string]int `json:"tool_uses"`                   // tool name -> number of uses
}
530
531func newUsage() *CumulativeUsage {
532 return &CumulativeUsage{ToolUses: make(map[string]int), StartTime: time.Now()}
533}
534
535func newUsageWithSharedToolUses(parent *CumulativeUsage) *CumulativeUsage {
536 return &CumulativeUsage{ToolUses: parent.ToolUses, StartTime: time.Now()}
537}
538
539func (u *CumulativeUsage) Clone() CumulativeUsage {
540 v := *u
541 v.ToolUses = maps.Clone(u.ToolUses)
542 return v
543}
544
545func (c *Convo) CumulativeUsage() CumulativeUsage {
546 if c == nil {
547 return CumulativeUsage{}
548 }
549 c.mu.Lock()
550 defer c.mu.Unlock()
551 return c.usage.Clone()
552}
553
Philip Zeyligerb8a8f352025-06-02 07:39:37 -0700554// LastUsage returns the usage from the most recent API call
555func (c *Convo) LastUsage() llm.Usage {
556 if c == nil {
557 return llm.Usage{}
558 }
559 c.mu.Lock()
560 defer c.mu.Unlock()
561 return c.lastUsage
562}
563
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700564func (u *CumulativeUsage) WallTime() time.Duration {
565 return time.Since(u.StartTime)
566}
567
568func (u *CumulativeUsage) DollarsPerHour() float64 {
569 hours := u.WallTime().Hours()
570 // Prevent division by very small numbers that could cause issues
571 if hours < 1e-6 {
572 return 0
573 }
574 return u.TotalCostUSD / hours
575}
576
577func (u *CumulativeUsage) Add(usage llm.Usage) {
578 u.Responses++
579 u.InputTokens += usage.InputTokens
580 u.OutputTokens += usage.OutputTokens
581 u.CacheReadInputTokens += usage.CacheReadInputTokens
582 u.CacheCreationInputTokens += usage.CacheCreationInputTokens
583 u.TotalCostUSD += usage.CostUSD
584}
585
586// TotalInputTokens returns the grand total cumulative input tokens in u.
587func (u *CumulativeUsage) TotalInputTokens() uint64 {
588 return u.InputTokens + u.CacheReadInputTokens + u.CacheCreationInputTokens
589}
590
591// Attr returns the cumulative usage as a slog.Attr with key "usage".
592func (u CumulativeUsage) Attr() slog.Attr {
593 elapsed := time.Since(u.StartTime)
594 return slog.Group("usage",
595 slog.Duration("wall_time", elapsed),
596 slog.Uint64("responses", u.Responses),
597 slog.Uint64("input_tokens", u.InputTokens),
598 slog.Uint64("output_tokens", u.OutputTokens),
599 slog.Uint64("cache_read_input_tokens", u.CacheReadInputTokens),
600 slog.Uint64("cache_creation_input_tokens", u.CacheCreationInputTokens),
601 slog.Float64("total_cost_usd", u.TotalCostUSD),
602 slog.Float64("dollars_per_hour", u.TotalCostUSD/elapsed.Hours()),
603 slog.Any("tool_uses", maps.Clone(u.ToolUses)),
604 )
605}
606
// A Budget caps the resources a conversation may consume.
// The zero Budget places no limit.
type Budget struct {
	// MaxDollars is the spending limit in USD; values <= 0 mean unlimited.
	MaxDollars float64
}
612
613// OverBudget returns an error if the convo (or any of its parents) has exceeded its budget.
614// TODO: document parent vs sub budgets, multiple errors, etc, once we know the desired behavior.
615func (c *Convo) OverBudget() error {
616 for x := c; x != nil; x = x.Parent {
617 if err := x.overBudget(); err != nil {
618 return err
619 }
620 }
621 return nil
622}
623
624// ResetBudget sets the budget to the passed in budget and
625// adjusts it by what's been used so far.
626func (c *Convo) ResetBudget(budget Budget) {
627 c.Budget = budget
628 if c.Budget.MaxDollars > 0 {
629 c.Budget.MaxDollars += c.CumulativeUsage().TotalCostUSD
630 }
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700631}
632
633func (c *Convo) overBudget() error {
634 usage := c.CumulativeUsage()
635 // TODO: stop before we exceed the budget instead of after?
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700636 var err error
637 cont := "Continuing to chat will reset the budget."
638 if c.Budget.MaxDollars > 0 && usage.TotalCostUSD >= c.Budget.MaxDollars {
639 err = errors.Join(err, fmt.Errorf("$%.2f spent, budget is $%.2f. %s", usage.TotalCostUSD, c.Budget.MaxDollars, cont))
640 }
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700641 return err
642}