blob: 5e20137efb598f3e096304202d02bd3e636293ec [file] [log] [blame]
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -07001package conversation
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "log/slog"
9 "maps"
10 "math/rand/v2"
11 "slices"
12 "strings"
13 "sync"
14 "time"
15
16 "github.com/oklog/ulid/v2"
17 "github.com/richardlehane/crock32"
18 "sketch.dev/llm"
19 "sketch.dev/skribe"
20)
21
22type Listener interface {
23 // TODO: Content is leaking an anthropic API; should we avoid it?
24 // TODO: Where should we include start/end time and usage?
25 OnToolCall(ctx context.Context, convo *Convo, toolCallID string, toolName string, toolInput json.RawMessage, content llm.Content)
26 OnToolResult(ctx context.Context, convo *Convo, toolCallID string, toolName string, toolInput json.RawMessage, content llm.Content, result *string, err error)
27 OnRequest(ctx context.Context, convo *Convo, requestID string, msg *llm.Message)
28 OnResponse(ctx context.Context, convo *Convo, requestID string, msg *llm.Response)
29}
30
31type NoopListener struct{}
32
33func (n *NoopListener) OnToolCall(ctx context.Context, convo *Convo, id string, toolName string, toolInput json.RawMessage, content llm.Content) {
34}
35
36func (n *NoopListener) OnToolResult(ctx context.Context, convo *Convo, id string, toolName string, toolInput json.RawMessage, content llm.Content, result *string, err error) {
37}
38
39func (n *NoopListener) OnResponse(ctx context.Context, convo *Convo, id string, msg *llm.Response) {
40}
41func (n *NoopListener) OnRequest(ctx context.Context, convo *Convo, id string, msg *llm.Message) {}
42
// ErrDoNotRespond is a sentinel error. When a tool's output carries it as its
// error, ToolResultContents emits no tool result for that tool call.
var ErrDoNotRespond = errors.New("do not respond")
44
// A Convo is a managed conversation with Claude.
// It automatically manages the state of the conversation,
// including appending messages sent/received,
// calling tools and sending their results,
// tracking usage, etc.
//
// Exported fields must not be altered concurrently with calling any method on Convo.
// Typical usage is to configure a Convo once before using it.
type Convo struct {
	// ID is a unique ID for the conversation
	ID string
	// Ctx is the context for the entire conversation.
	Ctx context.Context
	// Service is the LLM service to use.
	Service llm.Service
	// Tools are the tools available during the conversation.
	Tools []*llm.Tool
	// SystemPrompt is the system prompt for the conversation.
	SystemPrompt string
	// PromptCaching indicates whether to use Anthropic's prompt caching.
	// See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#continuing-a-multi-turn-conversation
	// for the documentation. At request send time, we set the cache_control field on the
	// last message. We also cache the system prompt.
	// Default: true.
	PromptCaching bool
	// ToolUseOnly indicates whether Claude may only use tools during this conversation.
	// TODO: add more fine-grained control over tool use?
	ToolUseOnly bool
	// Parent is the parent conversation, if any.
	// It is non-nil for "subagent" calls.
	// It is set automatically when calling SubConvo,
	// and usually should not be set manually.
	Parent *Convo
	// Budget is the budget for this conversation (and all sub-conversations).
	// The Conversation DOES NOT automatically enforce the budget.
	// It is up to the caller to call OverBudget() as appropriate.
	Budget Budget
	// Hidden indicates that the output of this conversation should be hidden in the UI.
	// This is useful for subconversations that can generate noisy, uninteresting output.
	Hidden bool
	// ExtraData is extra data to make available to all tool calls.
	ExtraData map[string]any

	// messages tracks the messages so far in the conversation.
	messages []llm.Message

	// Listener receives messages being sent.
	Listener Listener

	// toolUseCancelMu guards toolUseCancel.
	toolUseCancelMu sync.Mutex
	// toolUseCancel maps the ID of each in-flight tool use to the function
	// that cancels it. Entries are removed when the tool use is canceled.
	toolUseCancel map[string]context.CancelCauseFunc

	// Protects usage. This is used for subconversations (that share part of CumulativeUsage) as well.
	mu *sync.Mutex
	// usage tracks usage for this conversation and all sub-conversations.
	usage *CumulativeUsage
	// lastUsage tracks the usage from the most recent API call
	lastUsage llm.Usage
}
104
105// newConvoID generates a new 8-byte random id.
106// The uniqueness/collision requirements here are very low.
107// They are not global identifiers,
108// just enough to distinguish different convos in a single session.
109func newConvoID() string {
110 u1 := rand.Uint32()
111 s := crock32.Encode(uint64(u1))
112 if len(s) < 7 {
113 s += strings.Repeat("0", 7-len(s))
114 }
115 return s[:3] + "-" + s[3:]
116}
117
118// New creates a new conversation with Claude with sensible defaults.
119// ctx is the context for the entire conversation.
philip.zeyliger882e7ea2025-06-20 14:31:16 +0000120func New(ctx context.Context, srv llm.Service, usage *CumulativeUsage) *Convo {
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700121 id := newConvoID()
philip.zeyliger882e7ea2025-06-20 14:31:16 +0000122 if usage == nil {
123 usage = newUsage()
124 }
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700125 return &Convo{
Josh Bleecher Snyder495c1fa2025-05-29 00:37:22 +0000126 Ctx: skribe.ContextWithAttr(ctx, slog.String("convo_id", id)),
127 Service: srv,
128 PromptCaching: true,
philip.zeyliger882e7ea2025-06-20 14:31:16 +0000129 usage: usage,
Josh Bleecher Snyder495c1fa2025-05-29 00:37:22 +0000130 Listener: &NoopListener{},
131 ID: id,
132 toolUseCancel: map[string]context.CancelCauseFunc{},
133 mu: &sync.Mutex{},
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700134 }
135}
136
137// SubConvo creates a sub-conversation with the same configuration as the parent conversation.
138// (This propagates context for cancellation, HTTP client, API key, etc.)
139// The sub-conversation shares no messages with the parent conversation.
140// It does not inherit tools from the parent conversation.
141func (c *Convo) SubConvo() *Convo {
142 id := newConvoID()
143 return &Convo{
144 Ctx: skribe.ContextWithAttr(c.Ctx, slog.String("convo_id", id), slog.String("parent_convo_id", c.ID)),
145 Service: c.Service,
146 PromptCaching: c.PromptCaching,
147 Parent: c,
148 // For convenience, sub-convo usage shares tool uses map with parent,
149 // all other fields separate, propagated in AddResponse
Josh Bleecher Snyder495c1fa2025-05-29 00:37:22 +0000150 usage: newUsageWithSharedToolUses(c.usage),
151 mu: c.mu,
152 Listener: c.Listener,
153 ID: id,
154 toolUseCancel: map[string]context.CancelCauseFunc{},
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700155 // Do not copy Budget. Each budget is independent,
156 // and OverBudget checks whether any ancestor is over budget.
157 }
158}
159
160func (c *Convo) SubConvoWithHistory() *Convo {
161 id := newConvoID()
162 return &Convo{
163 Ctx: skribe.ContextWithAttr(c.Ctx, slog.String("convo_id", id), slog.String("parent_convo_id", c.ID)),
164 Service: c.Service,
165 PromptCaching: c.PromptCaching,
166 Parent: c,
167 // For convenience, sub-convo usage shares tool uses map with parent,
168 // all other fields separate, propagated in AddResponse
169 usage: newUsageWithSharedToolUses(c.usage),
170 mu: c.mu,
171 Listener: c.Listener,
172 ID: id,
173 // Do not copy Budget. Each budget is independent,
174 // and OverBudget checks whether any ancestor is over budget.
175 messages: slices.Clone(c.messages),
176 }
177}
178
179// Depth reports how many "sub-conversations" deep this conversation is.
180// That it, it walks up parents until it finds a root.
181func (c *Convo) Depth() int {
182 x := c
183 var depth int
184 for x.Parent != nil {
185 x = x.Parent
186 depth++
187 }
188 return depth
189}
190
191// SendUserTextMessage sends a text message to the LLM in this conversation.
192// otherContents contains additional contents to send with the message, usually tool results.
193func (c *Convo) SendUserTextMessage(s string, otherContents ...llm.Content) (*llm.Response, error) {
194 contents := slices.Clone(otherContents)
195 if s != "" {
196 contents = append(contents, llm.Content{Type: llm.ContentTypeText, Text: s})
197 }
198 msg := llm.Message{
199 Role: llm.MessageRoleUser,
200 Content: contents,
201 }
202 return c.SendMessage(msg)
203}
204
205func (c *Convo) messageRequest(msg llm.Message) *llm.Request {
206 system := []llm.SystemContent{}
207 if c.SystemPrompt != "" {
Josh Bleecher Snyder495c1fa2025-05-29 00:37:22 +0000208 d := llm.SystemContent{Type: "text", Text: c.SystemPrompt}
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700209 if c.PromptCaching {
210 d.Cache = true
211 }
212 system = []llm.SystemContent{d}
213 }
214
215 // Claude is happy to return an empty response in response to our Done() call,
216 // and, if so, you'll see something like:
217 // API request failed with status 400 Bad Request
218 // {"type":"error","error": {"type":"invalid_request_error",
219 // "message":"messages.5: all messages must have non-empty content except for the optional final assistant message"}}
220 // So, we filter out those empty messages.
221 var nonEmptyMessages []llm.Message
222 for _, m := range c.messages {
223 if len(m.Content) > 0 {
224 nonEmptyMessages = append(nonEmptyMessages, m)
225 }
226 }
227
228 mr := &llm.Request{
229 Messages: append(nonEmptyMessages, msg), // not yet committed to keeping msg
230 System: system,
231 Tools: c.Tools,
232 }
233 if c.ToolUseOnly {
234 mr.ToolChoice = &llm.ToolChoice{Type: llm.ToolChoiceTypeAny}
235 }
236 return mr
237}
238
239func (c *Convo) findTool(name string) (*llm.Tool, error) {
240 for _, tool := range c.Tools {
241 if tool.Name == name {
242 return tool, nil
243 }
244 }
245 return nil, fmt.Errorf("tool %q not found", name)
246}
247
248// insertMissingToolResults adds error results for tool uses that were requested
249// but not included in the message, which can happen in error paths like "out of budget."
250// We only insert these if there were no tool responses at all, since an incorrect
251// number of tool results would be a programmer error. Mutates inputs.
252func (c *Convo) insertMissingToolResults(mr *llm.Request, msg *llm.Message) {
253 if len(mr.Messages) < 2 {
254 return
255 }
256 prev := mr.Messages[len(mr.Messages)-2]
257 var toolUsePrev int
258 for _, c := range prev.Content {
259 if c.Type == llm.ContentTypeToolUse {
260 toolUsePrev++
261 }
262 }
263 if toolUsePrev == 0 {
264 return
265 }
266 var toolUseCurrent int
267 for _, c := range msg.Content {
268 if c.Type == llm.ContentTypeToolResult {
269 toolUseCurrent++
270 }
271 }
272 if toolUseCurrent != 0 {
273 return
274 }
275 var prefix []llm.Content
276 for _, part := range prev.Content {
277 if part.Type != llm.ContentTypeToolUse {
278 continue
279 }
280 content := llm.Content{
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700281 Type: llm.ContentTypeToolResult,
282 ToolUseID: part.ID,
283 ToolError: true,
284 ToolResult: []llm.Content{{
285 Type: llm.ContentTypeText,
286 Text: "not executed; retry possible",
287 }},
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700288 }
289 prefix = append(prefix, content)
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700290 }
Josh Bleecher Snydera3e28fb2025-05-30 15:53:23 +0000291 msg.Content = append(prefix, msg.Content...)
292 mr.Messages[len(mr.Messages)-1].Content = msg.Content
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700293 slog.DebugContext(c.Ctx, "inserted missing tool results")
294}
295
296// SendMessage sends a message to Claude.
297// The conversation records (internally) all messages succesfully sent and received.
298func (c *Convo) SendMessage(msg llm.Message) (*llm.Response, error) {
299 id := ulid.Make().String()
300 mr := c.messageRequest(msg)
301 var lastMessage *llm.Message
302 if c.PromptCaching {
303 lastMessage = &mr.Messages[len(mr.Messages)-1]
304 if len(lastMessage.Content) > 0 {
305 lastMessage.Content[len(lastMessage.Content)-1].Cache = true
306 }
307 }
308 defer func() {
309 if lastMessage == nil {
310 return
311 }
312 if len(lastMessage.Content) > 0 {
313 lastMessage.Content[len(lastMessage.Content)-1].Cache = false
314 }
315 }()
316 c.insertMissingToolResults(mr, &msg)
317 c.Listener.OnRequest(c.Ctx, c, id, &msg)
318
319 startTime := time.Now()
320 resp, err := c.Service.Do(c.Ctx, mr)
321 if resp != nil {
322 resp.StartTime = &startTime
323 endTime := time.Now()
324 resp.EndTime = &endTime
325 }
326
327 if err != nil {
328 c.Listener.OnResponse(c.Ctx, c, id, nil)
329 return nil, err
330 }
331 c.messages = append(c.messages, msg, resp.ToMessage())
332 // Propagate usage to all ancestors (including us).
333 for x := c; x != nil; x = x.Parent {
334 x.usage.Add(resp.Usage)
Philip Zeyligerb8a8f352025-06-02 07:39:37 -0700335 // Store the most recent usage (only on the current conversation, not ancestors)
336 if x == c {
337 x.lastUsage = resp.Usage
338 }
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700339 }
340 c.Listener.OnResponse(c.Ctx, c, id, resp)
341 return resp, err
342}
343
// toolCallInfoKeyType is the unexported type of the context key under which
// a ToolCallInfo is stored.
type toolCallInfoKeyType string

// toolCallInfoKey is the context key used to store and look up a ToolCallInfo.
var toolCallInfoKey toolCallInfoKeyType
347
// ToolCallInfo identifies the tool call a tool is running on behalf of.
// ToolResultContents places one in the context passed to each tool.
type ToolCallInfo struct {
	// ToolUseID is the ID of the tool use being executed.
	ToolUseID string
	// Convo is the conversation that is running the tool.
	Convo *Convo
}
352
353func ToolCallInfoFromContext(ctx context.Context) ToolCallInfo {
354 v := ctx.Value(toolCallInfoKey)
355 i, _ := v.(ToolCallInfo)
356 return i
357}
358
359func (c *Convo) ToolResultCancelContents(resp *llm.Response) ([]llm.Content, error) {
360 if resp.StopReason != llm.StopReasonToolUse {
361 return nil, nil
362 }
363 var toolResults []llm.Content
364
365 for _, part := range resp.Content {
366 if part.Type != llm.ContentTypeToolUse {
367 continue
368 }
369 c.incrementToolUse(part.ToolName)
370
371 content := llm.Content{
372 Type: llm.ContentTypeToolResult,
373 ToolUseID: part.ID,
374 }
375
376 content.ToolError = true
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700377 content.ToolResult = []llm.Content{{
378 Type: llm.ContentTypeText,
Josh Bleecher Snydera3e28fb2025-05-30 15:53:23 +0000379 Text: "user canceled this tool_use",
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700380 }}
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700381 toolResults = append(toolResults, content)
382 }
383 return toolResults, nil
384}
385
// GetID returns the conversation ID, i.e. the ID field.
func (c *Convo) GetID() string {
	return c.ID
}
390
391func (c *Convo) CancelToolUse(toolUseID string, err error) error {
Josh Bleecher Snyder495c1fa2025-05-29 00:37:22 +0000392 c.toolUseCancelMu.Lock()
393 defer c.toolUseCancelMu.Unlock()
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700394 cancel, ok := c.toolUseCancel[toolUseID]
395 if !ok {
396 return fmt.Errorf("cannot cancel %s: no cancel function registered for this tool_use_id. All I have is %+v", toolUseID, c.toolUseCancel)
397 }
398 delete(c.toolUseCancel, toolUseID)
399 cancel(err)
400 return nil
401}
402
403func (c *Convo) newToolUseContext(ctx context.Context, toolUseID string) (context.Context, context.CancelFunc) {
Josh Bleecher Snyder495c1fa2025-05-29 00:37:22 +0000404 c.toolUseCancelMu.Lock()
405 defer c.toolUseCancelMu.Unlock()
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700406 ctx, cancel := context.WithCancelCause(ctx)
407 c.toolUseCancel[toolUseID] = cancel
408 return ctx, func() { c.CancelToolUse(toolUseID, nil) }
409}
410
// ToolResultContents runs all tool uses requested by the response and returns their results.
// Cancelling ctx will cancel any running tool calls.
// The boolean return value indicates whether any of the executed tools should end the turn.
//
// If resp did not stop for tool use, it returns (nil, false, nil).
// Each tool use runs in its own goroutine; results are collected in the order
// in which the tools finish, which need not be the order in which they were requested.
// A tool whose output error is ErrDoNotRespond contributes no result.
// If ctx is done once all tools have returned, the results are discarded and
// ctx.Err() is returned.
func (c *Convo) ToolResultContents(ctx context.Context, resp *llm.Response) ([]llm.Content, bool, error) {
	if resp.StopReason != llm.StopReasonToolUse {
		return nil, false, nil
	}
	// Extract all tool calls from the response, call the tools, and gather the results.
	var wg sync.WaitGroup
	// Buffered for one result per content part, so sends from the goroutines never block.
	toolResultC := make(chan llm.Content, len(resp.Content))

	endsTurn := false
	for _, part := range resp.Content {
		if part.Type != llm.ContentTypeToolUse {
			continue
		}
		// Decide up front whether this tool ends the turn. An unknown tool is
		// ignored here; the lookup failure is reported as the tool's result below.
		tool, err := c.findTool(part.ToolName)
		if err == nil && tool.EndsTurn {
			endsTurn = true
		}
		c.incrementToolUse(part.ToolName)
		startTime := time.Now()

		c.Listener.OnToolCall(ctx, c, part.ID, part.ToolName, part.ToolInput, llm.Content{
			Type:             llm.ContentTypeToolUse,
			ToolUseID:        part.ID,
			ToolUseStartTime: &startTime,
		})

		wg.Add(1)
		// NOTE(review): the goroutine captures part and startTime from the loop;
		// this assumes per-iteration loop variables (Go 1.22+ semantics) — confirm
		// against the module's go directive.
		go func() {
			defer wg.Done()

			content := llm.Content{
				Type:             llm.ContentTypeToolResult,
				ToolUseID:        part.ID,
				ToolUseStartTime: &startTime,
			}
			// sendErr finishes this tool use as a failure: it records the end
			// time, marks the result as an error carrying err's text, notifies
			// the listener, and delivers the result.
			sendErr := func(err error) {
				// Record end time
				endTime := time.Now()
				content.ToolUseEndTime = &endTime

				content.ToolError = true
				content.ToolResult = []llm.Content{{
					Type: llm.ContentTypeText,
					Text: err.Error(),
				}}
				c.Listener.OnToolResult(ctx, c, part.ID, part.ToolName, part.ToolInput, content, nil, err)
				toolResultC <- content
			}
			// sendRes finishes this tool use as a success: it records the end
			// time, attaches the tool output, notifies the listener (passing the
			// text of the first output content, or "" if there is none), and
			// delivers the result.
			sendRes := func(toolOut llm.ToolOut) {
				// Record end time
				endTime := time.Now()
				content.ToolUseEndTime = &endTime

				content.ToolResult = toolOut.LLMContent
				content.Display = toolOut.Display
				var firstText string
				if len(toolOut.LLMContent) > 0 {
					firstText = toolOut.LLMContent[0].Text
				}
				c.Listener.OnToolResult(ctx, c, part.ID, part.ToolName, part.ToolInput, content, &firstText, nil)
				toolResultC <- content
			}

			tool, err := c.findTool(part.ToolName)
			if err != nil {
				sendErr(err)
				return
			}
			// Create a new context for just this tool_use call, and register its
			// cancel function so that it can be canceled individually.
			toolUseCtx, cancel := c.newToolUseContext(ctx, part.ID)
			defer cancel()
			// TODO: move this into newToolUseContext?
			toolUseCtx = context.WithValue(toolUseCtx, toolCallInfoKey, ToolCallInfo{ToolUseID: part.ID, Convo: c})
			toolOut := tool.Run(toolUseCtx, part.ToolInput)
			// The tool asked that no result be produced for this call.
			if errors.Is(toolOut.Error, ErrDoNotRespond) {
				return
			}
			// A canceled tool use reports the cancellation cause, whatever the tool returned.
			if toolUseCtx.Err() != nil {
				sendErr(context.Cause(toolUseCtx))
				return
			}

			if toolOut.Error != nil {
				sendErr(toolOut.Error)
				return
			}
			sendRes(toolOut)
		}()
	}
	wg.Wait()
	// All senders are done, so closing lets the range below terminate.
	close(toolResultC)
	var toolResults []llm.Content
	for toolResult := range toolResultC {
		toolResults = append(toolResults, toolResult)
	}
	if ctx.Err() != nil {
		return nil, false, ctx.Err()
	}
	return toolResults, endsTurn, nil
}
515
// incrementToolUse adds one to the use count of the named tool, holding c.mu.
// Sub-conversations created by SubConvo share the parent's ToolUses map,
// so the count is visible across the whole conversation tree.
func (c *Convo) incrementToolUse(name string) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.usage.ToolUses[name]++
}
522
// CumulativeUsage represents cumulative usage across a Convo, including all sub-conversations.
type CumulativeUsage struct {
	// StartTime is when this accumulator was created; wall time is measured from it.
	StartTime time.Time `json:"start_time"`
	// Responses is serialized under the JSON key "messages".
	Responses                uint64         `json:"messages"` // count of responses
	InputTokens              uint64         `json:"input_tokens"`
	OutputTokens             uint64         `json:"output_tokens"`
	CacheReadInputTokens     uint64         `json:"cache_read_input_tokens"`
	CacheCreationInputTokens uint64         `json:"cache_creation_input_tokens"`
	TotalCostUSD             float64        `json:"total_cost_usd"`
	ToolUses                 map[string]int `json:"tool_uses"` // tool name -> number of uses
}
534
535func newUsage() *CumulativeUsage {
536 return &CumulativeUsage{ToolUses: make(map[string]int), StartTime: time.Now()}
537}
538
539func newUsageWithSharedToolUses(parent *CumulativeUsage) *CumulativeUsage {
540 return &CumulativeUsage{ToolUses: parent.ToolUses, StartTime: time.Now()}
541}
542
543func (u *CumulativeUsage) Clone() CumulativeUsage {
544 v := *u
545 v.ToolUses = maps.Clone(u.ToolUses)
546 return v
547}
548
549func (c *Convo) CumulativeUsage() CumulativeUsage {
550 if c == nil {
551 return CumulativeUsage{}
552 }
553 c.mu.Lock()
554 defer c.mu.Unlock()
555 return c.usage.Clone()
556}
557
Philip Zeyligerb8a8f352025-06-02 07:39:37 -0700558// LastUsage returns the usage from the most recent API call
559func (c *Convo) LastUsage() llm.Usage {
560 if c == nil {
561 return llm.Usage{}
562 }
563 c.mu.Lock()
564 defer c.mu.Unlock()
565 return c.lastUsage
566}
567
// WallTime returns the time elapsed since u.StartTime.
func (u *CumulativeUsage) WallTime() time.Duration {
	return time.Since(u.StartTime)
}
571
572func (u *CumulativeUsage) DollarsPerHour() float64 {
573 hours := u.WallTime().Hours()
574 // Prevent division by very small numbers that could cause issues
575 if hours < 1e-6 {
576 return 0
577 }
578 return u.TotalCostUSD / hours
579}
580
581func (u *CumulativeUsage) Add(usage llm.Usage) {
582 u.Responses++
583 u.InputTokens += usage.InputTokens
584 u.OutputTokens += usage.OutputTokens
585 u.CacheReadInputTokens += usage.CacheReadInputTokens
586 u.CacheCreationInputTokens += usage.CacheCreationInputTokens
587 u.TotalCostUSD += usage.CostUSD
588}
589
590// TotalInputTokens returns the grand total cumulative input tokens in u.
591func (u *CumulativeUsage) TotalInputTokens() uint64 {
592 return u.InputTokens + u.CacheReadInputTokens + u.CacheCreationInputTokens
593}
594
595// Attr returns the cumulative usage as a slog.Attr with key "usage".
596func (u CumulativeUsage) Attr() slog.Attr {
597 elapsed := time.Since(u.StartTime)
598 return slog.Group("usage",
599 slog.Duration("wall_time", elapsed),
600 slog.Uint64("responses", u.Responses),
601 slog.Uint64("input_tokens", u.InputTokens),
602 slog.Uint64("output_tokens", u.OutputTokens),
603 slog.Uint64("cache_read_input_tokens", u.CacheReadInputTokens),
604 slog.Uint64("cache_creation_input_tokens", u.CacheCreationInputTokens),
605 slog.Float64("total_cost_usd", u.TotalCostUSD),
606 slog.Float64("dollars_per_hour", u.TotalCostUSD/elapsed.Hours()),
607 slog.Any("tool_uses", maps.Clone(u.ToolUses)),
608 )
609}
610
// A Budget represents the maximum amount of resources that may be spent on a conversation.
// Note that the default (zero) budget is unlimited.
type Budget struct {
	// MaxDollars is the spending cap in US dollars; values <= 0 mean no cap.
	MaxDollars float64 // if > 0, max dollars that may be spent
}
616
617// OverBudget returns an error if the convo (or any of its parents) has exceeded its budget.
618// TODO: document parent vs sub budgets, multiple errors, etc, once we know the desired behavior.
619func (c *Convo) OverBudget() error {
620 for x := c; x != nil; x = x.Parent {
621 if err := x.overBudget(); err != nil {
622 return err
623 }
624 }
625 return nil
626}
627
628// ResetBudget sets the budget to the passed in budget and
629// adjusts it by what's been used so far.
630func (c *Convo) ResetBudget(budget Budget) {
631 c.Budget = budget
632 if c.Budget.MaxDollars > 0 {
633 c.Budget.MaxDollars += c.CumulativeUsage().TotalCostUSD
634 }
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700635}
636
637func (c *Convo) overBudget() error {
638 usage := c.CumulativeUsage()
639 // TODO: stop before we exceed the budget instead of after?
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700640 var err error
641 cont := "Continuing to chat will reset the budget."
642 if c.Budget.MaxDollars > 0 && usage.TotalCostUSD >= c.Budget.MaxDollars {
643 err = errors.Join(err, fmt.Errorf("$%.2f spent, budget is $%.2f. %s", usage.TotalCostUSD, c.Budget.MaxDollars, cont))
644 }
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700645 return err
646}
Philip Zeyliger43a0bfc2025-07-14 14:54:27 -0700647
// DebugJSON returns the conversation history as JSON for debugging purposes.
// The output is the indented JSON encoding of the recorded messages.
func (c *Convo) DebugJSON() ([]byte, error) {
	return json.MarshalIndent(c.messages, "", " ")
}