blob: 12c334f5edbcfcef3e4fb6aa52ab98af17d3cf48 [file] [log] [blame]
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -07001package conversation
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "log/slog"
9 "maps"
10 "math/rand/v2"
11 "slices"
12 "strings"
13 "sync"
14 "time"
15
16 "github.com/oklog/ulid/v2"
17 "github.com/richardlehane/crock32"
18 "sketch.dev/llm"
19 "sketch.dev/skribe"
20)
21
22type Listener interface {
23 // TODO: Content is leaking an anthropic API; should we avoid it?
24 // TODO: Where should we include start/end time and usage?
25 OnToolCall(ctx context.Context, convo *Convo, toolCallID string, toolName string, toolInput json.RawMessage, content llm.Content)
26 OnToolResult(ctx context.Context, convo *Convo, toolCallID string, toolName string, toolInput json.RawMessage, content llm.Content, result *string, err error)
27 OnRequest(ctx context.Context, convo *Convo, requestID string, msg *llm.Message)
28 OnResponse(ctx context.Context, convo *Convo, requestID string, msg *llm.Response)
29}
30
31type NoopListener struct{}
32
33func (n *NoopListener) OnToolCall(ctx context.Context, convo *Convo, id string, toolName string, toolInput json.RawMessage, content llm.Content) {
34}
35
36func (n *NoopListener) OnToolResult(ctx context.Context, convo *Convo, id string, toolName string, toolInput json.RawMessage, content llm.Content, result *string, err error) {
37}
38
39func (n *NoopListener) OnResponse(ctx context.Context, convo *Convo, id string, msg *llm.Response) {
40}
41func (n *NoopListener) OnRequest(ctx context.Context, convo *Convo, id string, msg *llm.Message) {}
42
// ErrDoNotRespond is a sentinel error that a tool's Run function may return
// (directly or wrapped). When ToolResultContents sees it, no tool result is
// produced for that tool call and no OnToolResult notification is sent.
var ErrDoNotRespond = errors.New("do not respond")
44
// A Convo is a managed conversation with Claude.
// It automatically manages the state of the conversation,
// including appending messages sent/received,
// calling tools and sending their results,
// tracking usage, etc.
//
// Exported fields must not be altered concurrently with calling any method on Convo.
// Typical usage is to configure a Convo once before using it.
type Convo struct {
	// ID is a unique ID for the conversation.
	// It is generated by New / SubConvo / SubConvoWithHistory.
	ID string
	// Ctx is the context for the entire conversation.
	// It is used for requests to Service and for logging.
	Ctx context.Context
	// Service is the LLM service to use.
	Service llm.Service
	// Tools are the tools available during the conversation.
	// Sub-conversations do not inherit them.
	Tools []*llm.Tool
	// SystemPrompt is the system prompt for the conversation.
	// When empty, no system content is sent.
	SystemPrompt string
	// PromptCaching indicates whether to use Anthropic's prompt caching.
	// See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#continuing-a-multi-turn-conversation
	// for the documentation. At request send time, we set the cache_control field on the
	// last message. We also cache the system prompt.
	// Default: true.
	PromptCaching bool
	// ToolUseOnly indicates whether Claude may only use tools during this conversation.
	// TODO: add more fine-grained control over tool use?
	ToolUseOnly bool
	// Parent is the parent conversation, if any.
	// It is non-nil for "subagent" calls.
	// It is set automatically when calling SubConvo,
	// and usually should not be set manually.
	Parent *Convo
	// Budget is the budget for this conversation (and all sub-conversations).
	// The Conversation DOES NOT automatically enforce the budget.
	// It is up to the caller to call OverBudget() as appropriate.
	Budget Budget
	// Hidden indicates that the output of this conversation should be hidden in the UI.
	// This is useful for subconversations that can generate noisy, uninteresting output.
	Hidden bool
	// ExtraData is extra data to make available to all tool calls.
	ExtraData map[string]any

	// messages tracks the messages so far in the conversation.
	// SendMessage appends each request/response pair after a successful exchange.
	messages []llm.Message

	// Listener receives messages being sent.
	// It is also notified of responses, tool calls and tool results.
	Listener Listener

	// toolUseCancelMu guards toolUseCancel.
	toolUseCancelMu sync.Mutex
	// toolUseCancel maps a tool_use ID to the function that cancels that tool call.
	toolUseCancel map[string]context.CancelCauseFunc

	// Protects usage. This is used for subconversations (that share part of CumulativeUsage) as well.
	// It is a pointer so that a parent and its sub-conversations share a single mutex.
	mu *sync.Mutex
	// usage tracks usage for this conversation and all sub-conversations.
	usage *CumulativeUsage
}
102
103// newConvoID generates a new 8-byte random id.
104// The uniqueness/collision requirements here are very low.
105// They are not global identifiers,
106// just enough to distinguish different convos in a single session.
107func newConvoID() string {
108 u1 := rand.Uint32()
109 s := crock32.Encode(uint64(u1))
110 if len(s) < 7 {
111 s += strings.Repeat("0", 7-len(s))
112 }
113 return s[:3] + "-" + s[3:]
114}
115
116// New creates a new conversation with Claude with sensible defaults.
117// ctx is the context for the entire conversation.
118func New(ctx context.Context, srv llm.Service) *Convo {
119 id := newConvoID()
120 return &Convo{
Josh Bleecher Snyder495c1fa2025-05-29 00:37:22 +0000121 Ctx: skribe.ContextWithAttr(ctx, slog.String("convo_id", id)),
122 Service: srv,
123 PromptCaching: true,
124 usage: newUsage(),
125 Listener: &NoopListener{},
126 ID: id,
127 toolUseCancel: map[string]context.CancelCauseFunc{},
128 mu: &sync.Mutex{},
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700129 }
130}
131
132// SubConvo creates a sub-conversation with the same configuration as the parent conversation.
133// (This propagates context for cancellation, HTTP client, API key, etc.)
134// The sub-conversation shares no messages with the parent conversation.
135// It does not inherit tools from the parent conversation.
136func (c *Convo) SubConvo() *Convo {
137 id := newConvoID()
138 return &Convo{
139 Ctx: skribe.ContextWithAttr(c.Ctx, slog.String("convo_id", id), slog.String("parent_convo_id", c.ID)),
140 Service: c.Service,
141 PromptCaching: c.PromptCaching,
142 Parent: c,
143 // For convenience, sub-convo usage shares tool uses map with parent,
144 // all other fields separate, propagated in AddResponse
Josh Bleecher Snyder495c1fa2025-05-29 00:37:22 +0000145 usage: newUsageWithSharedToolUses(c.usage),
146 mu: c.mu,
147 Listener: c.Listener,
148 ID: id,
149 toolUseCancel: map[string]context.CancelCauseFunc{},
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700150 // Do not copy Budget. Each budget is independent,
151 // and OverBudget checks whether any ancestor is over budget.
152 }
153}
154
155func (c *Convo) SubConvoWithHistory() *Convo {
156 id := newConvoID()
157 return &Convo{
158 Ctx: skribe.ContextWithAttr(c.Ctx, slog.String("convo_id", id), slog.String("parent_convo_id", c.ID)),
159 Service: c.Service,
160 PromptCaching: c.PromptCaching,
161 Parent: c,
162 // For convenience, sub-convo usage shares tool uses map with parent,
163 // all other fields separate, propagated in AddResponse
164 usage: newUsageWithSharedToolUses(c.usage),
165 mu: c.mu,
166 Listener: c.Listener,
167 ID: id,
168 // Do not copy Budget. Each budget is independent,
169 // and OverBudget checks whether any ancestor is over budget.
170 messages: slices.Clone(c.messages),
171 }
172}
173
174// Depth reports how many "sub-conversations" deep this conversation is.
175// That it, it walks up parents until it finds a root.
176func (c *Convo) Depth() int {
177 x := c
178 var depth int
179 for x.Parent != nil {
180 x = x.Parent
181 depth++
182 }
183 return depth
184}
185
186// SendUserTextMessage sends a text message to the LLM in this conversation.
187// otherContents contains additional contents to send with the message, usually tool results.
188func (c *Convo) SendUserTextMessage(s string, otherContents ...llm.Content) (*llm.Response, error) {
189 contents := slices.Clone(otherContents)
190 if s != "" {
191 contents = append(contents, llm.Content{Type: llm.ContentTypeText, Text: s})
192 }
193 msg := llm.Message{
194 Role: llm.MessageRoleUser,
195 Content: contents,
196 }
197 return c.SendMessage(msg)
198}
199
200func (c *Convo) messageRequest(msg llm.Message) *llm.Request {
201 system := []llm.SystemContent{}
202 if c.SystemPrompt != "" {
Josh Bleecher Snyder495c1fa2025-05-29 00:37:22 +0000203 d := llm.SystemContent{Type: "text", Text: c.SystemPrompt}
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700204 if c.PromptCaching {
205 d.Cache = true
206 }
207 system = []llm.SystemContent{d}
208 }
209
210 // Claude is happy to return an empty response in response to our Done() call,
211 // and, if so, you'll see something like:
212 // API request failed with status 400 Bad Request
213 // {"type":"error","error": {"type":"invalid_request_error",
214 // "message":"messages.5: all messages must have non-empty content except for the optional final assistant message"}}
215 // So, we filter out those empty messages.
216 var nonEmptyMessages []llm.Message
217 for _, m := range c.messages {
218 if len(m.Content) > 0 {
219 nonEmptyMessages = append(nonEmptyMessages, m)
220 }
221 }
222
223 mr := &llm.Request{
224 Messages: append(nonEmptyMessages, msg), // not yet committed to keeping msg
225 System: system,
226 Tools: c.Tools,
227 }
228 if c.ToolUseOnly {
229 mr.ToolChoice = &llm.ToolChoice{Type: llm.ToolChoiceTypeAny}
230 }
231 return mr
232}
233
234func (c *Convo) findTool(name string) (*llm.Tool, error) {
235 for _, tool := range c.Tools {
236 if tool.Name == name {
237 return tool, nil
238 }
239 }
240 return nil, fmt.Errorf("tool %q not found", name)
241}
242
243// insertMissingToolResults adds error results for tool uses that were requested
244// but not included in the message, which can happen in error paths like "out of budget."
245// We only insert these if there were no tool responses at all, since an incorrect
246// number of tool results would be a programmer error. Mutates inputs.
247func (c *Convo) insertMissingToolResults(mr *llm.Request, msg *llm.Message) {
248 if len(mr.Messages) < 2 {
249 return
250 }
251 prev := mr.Messages[len(mr.Messages)-2]
252 var toolUsePrev int
253 for _, c := range prev.Content {
254 if c.Type == llm.ContentTypeToolUse {
255 toolUsePrev++
256 }
257 }
258 if toolUsePrev == 0 {
259 return
260 }
261 var toolUseCurrent int
262 for _, c := range msg.Content {
263 if c.Type == llm.ContentTypeToolResult {
264 toolUseCurrent++
265 }
266 }
267 if toolUseCurrent != 0 {
268 return
269 }
270 var prefix []llm.Content
271 for _, part := range prev.Content {
272 if part.Type != llm.ContentTypeToolUse {
273 continue
274 }
275 content := llm.Content{
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700276 Type: llm.ContentTypeToolResult,
277 ToolUseID: part.ID,
278 ToolError: true,
279 ToolResult: []llm.Content{{
280 Type: llm.ContentTypeText,
281 Text: "not executed; retry possible",
282 }},
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700283 }
284 prefix = append(prefix, content)
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700285 }
Josh Bleecher Snydera3e28fb2025-05-30 15:53:23 +0000286 msg.Content = append(prefix, msg.Content...)
287 mr.Messages[len(mr.Messages)-1].Content = msg.Content
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700288 slog.DebugContext(c.Ctx, "inserted missing tool results")
289}
290
291// SendMessage sends a message to Claude.
292// The conversation records (internally) all messages succesfully sent and received.
293func (c *Convo) SendMessage(msg llm.Message) (*llm.Response, error) {
294 id := ulid.Make().String()
295 mr := c.messageRequest(msg)
296 var lastMessage *llm.Message
297 if c.PromptCaching {
298 lastMessage = &mr.Messages[len(mr.Messages)-1]
299 if len(lastMessage.Content) > 0 {
300 lastMessage.Content[len(lastMessage.Content)-1].Cache = true
301 }
302 }
303 defer func() {
304 if lastMessage == nil {
305 return
306 }
307 if len(lastMessage.Content) > 0 {
308 lastMessage.Content[len(lastMessage.Content)-1].Cache = false
309 }
310 }()
311 c.insertMissingToolResults(mr, &msg)
312 c.Listener.OnRequest(c.Ctx, c, id, &msg)
313
314 startTime := time.Now()
315 resp, err := c.Service.Do(c.Ctx, mr)
316 if resp != nil {
317 resp.StartTime = &startTime
318 endTime := time.Now()
319 resp.EndTime = &endTime
320 }
321
322 if err != nil {
323 c.Listener.OnResponse(c.Ctx, c, id, nil)
324 return nil, err
325 }
326 c.messages = append(c.messages, msg, resp.ToMessage())
327 // Propagate usage to all ancestors (including us).
328 for x := c; x != nil; x = x.Parent {
329 x.usage.Add(resp.Usage)
330 }
331 c.Listener.OnResponse(c.Ctx, c, id, resp)
332 return resp, err
333}
334
335type toolCallInfoKeyType string
336
337var toolCallInfoKey toolCallInfoKeyType
338
339type ToolCallInfo struct {
340 ToolUseID string
341 Convo *Convo
342}
343
344func ToolCallInfoFromContext(ctx context.Context) ToolCallInfo {
345 v := ctx.Value(toolCallInfoKey)
346 i, _ := v.(ToolCallInfo)
347 return i
348}
349
350func (c *Convo) ToolResultCancelContents(resp *llm.Response) ([]llm.Content, error) {
351 if resp.StopReason != llm.StopReasonToolUse {
352 return nil, nil
353 }
354 var toolResults []llm.Content
355
356 for _, part := range resp.Content {
357 if part.Type != llm.ContentTypeToolUse {
358 continue
359 }
360 c.incrementToolUse(part.ToolName)
361
362 content := llm.Content{
363 Type: llm.ContentTypeToolResult,
364 ToolUseID: part.ID,
365 }
366
367 content.ToolError = true
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700368 content.ToolResult = []llm.Content{{
369 Type: llm.ContentTypeText,
Josh Bleecher Snydera3e28fb2025-05-30 15:53:23 +0000370 Text: "user canceled this tool_use",
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700371 }}
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700372 toolResults = append(toolResults, content)
373 }
374 return toolResults, nil
375}
376
// GetID returns the conversation ID.
// It is an accessor for the exported ID field.
func (c *Convo) GetID() string {
	return c.ID
}
381
382func (c *Convo) CancelToolUse(toolUseID string, err error) error {
Josh Bleecher Snyder495c1fa2025-05-29 00:37:22 +0000383 c.toolUseCancelMu.Lock()
384 defer c.toolUseCancelMu.Unlock()
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700385 cancel, ok := c.toolUseCancel[toolUseID]
386 if !ok {
387 return fmt.Errorf("cannot cancel %s: no cancel function registered for this tool_use_id. All I have is %+v", toolUseID, c.toolUseCancel)
388 }
389 delete(c.toolUseCancel, toolUseID)
390 cancel(err)
391 return nil
392}
393
394func (c *Convo) newToolUseContext(ctx context.Context, toolUseID string) (context.Context, context.CancelFunc) {
Josh Bleecher Snyder495c1fa2025-05-29 00:37:22 +0000395 c.toolUseCancelMu.Lock()
396 defer c.toolUseCancelMu.Unlock()
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700397 ctx, cancel := context.WithCancelCause(ctx)
398 c.toolUseCancel[toolUseID] = cancel
399 return ctx, func() { c.CancelToolUse(toolUseID, nil) }
400}
401
// ToolResultContents runs all tool uses requested by the response and returns their results.
// Cancelling ctx will cancel any running tool calls.
// The boolean return value indicates whether any of the executed tools should end the turn.
//
// Each tool runs in its own goroutine; the call returns once all of them have
// finished. Results are collected in completion order, which need not match the
// order of the tool uses in resp. A tool that fails, is unknown, or is canceled
// yields an error tool result rather than an error return; a tool that returns
// ErrDoNotRespond yields no result at all.
//
// It returns (nil, false, nil) if resp did not stop for tool use, and
// (nil, false, ctx.Err()) if ctx is done by the time all tools have finished.
func (c *Convo) ToolResultContents(ctx context.Context, resp *llm.Response) ([]llm.Content, bool, error) {
	if resp.StopReason != llm.StopReasonToolUse {
		return nil, false, nil
	}
	// Extract all tool calls from the response, call the tools, and gather the results.
	var wg sync.WaitGroup
	// Buffered for the worst case (every part is a tool use) so that a
	// goroutine never blocks on sending its single result.
	toolResultC := make(chan llm.Content, len(resp.Content))

	endsTurn := false
	for _, part := range resp.Content {
		if part.Type != llm.ContentTypeToolUse {
			continue
		}
		// An unknown tool is not an error here; it is reported as a tool
		// error result by the goroutine below.
		tool, err := c.findTool(part.ToolName)
		if err == nil && tool.EndsTurn {
			endsTurn = true
		}
		c.incrementToolUse(part.ToolName)
		startTime := time.Now()

		c.Listener.OnToolCall(ctx, c, part.ID, part.ToolName, part.ToolInput, llm.Content{
			Type:             llm.ContentTypeToolUse,
			ToolUseID:        part.ID,
			ToolUseStartTime: &startTime,
		})

		wg.Add(1)
		go func() {
			defer wg.Done()

			content := llm.Content{
				Type:             llm.ContentTypeToolResult,
				ToolUseID:        part.ID,
				ToolUseStartTime: &startTime,
			}
			// sendErr finishes the call as a failure: it notifies the
			// listener and emits an error tool result carrying err's text.
			sendErr := func(err error) {
				// Record end time
				endTime := time.Now()
				content.ToolUseEndTime = &endTime

				content.ToolError = true
				content.ToolResult = []llm.Content{{
					Type: llm.ContentTypeText,
					Text: err.Error(),
				}}
				c.Listener.OnToolResult(ctx, c, part.ID, part.ToolName, part.ToolInput, content, nil, err)
				toolResultC <- content
			}
			// sendRes finishes the call as a success: it notifies the
			// listener (passing the text of the first result content, or ""
			// if there is none) and emits the tool result.
			sendRes := func(toolResult []llm.Content) {
				// Record end time
				endTime := time.Now()
				content.ToolUseEndTime = &endTime

				content.ToolResult = toolResult
				var firstText string
				if len(toolResult) > 0 {
					firstText = toolResult[0].Text
				}
				c.Listener.OnToolResult(ctx, c, part.ID, part.ToolName, part.ToolInput, content, &firstText, nil)
				toolResultC <- content
			}

			tool, err := c.findTool(part.ToolName)
			if err != nil {
				sendErr(err)
				return
			}
			// Create a new context for just this tool_use call, and register its
			// cancel function so that it can be canceled individually.
			toolUseCtx, cancel := c.newToolUseContext(ctx, part.ID)
			defer cancel()
			// TODO: move this into newToolUseContext?
			toolUseCtx = context.WithValue(toolUseCtx, toolCallInfoKey, ToolCallInfo{ToolUseID: part.ID, Convo: c})
			toolResult, err := tool.Run(toolUseCtx, part.ToolInput)
			// The tool asked for no response: emit nothing for this call.
			if errors.Is(err, ErrDoNotRespond) {
				return
			}
			// A canceled call reports the cancellation cause, regardless of
			// what the tool itself returned.
			if toolUseCtx.Err() != nil {
				sendErr(context.Cause(toolUseCtx))
				return
			}

			if err != nil {
				sendErr(err)
				return
			}
			sendRes(toolResult)
		}()
	}
	wg.Wait()
	// All senders are done, so the channel can be closed and drained.
	close(toolResultC)
	var toolResults []llm.Content
	for toolResult := range toolResultC {
		toolResults = append(toolResults, toolResult)
	}
	if ctx.Err() != nil {
		return nil, false, ctx.Err()
	}
	return toolResults, endsTurn, nil
}
505
// incrementToolUse counts one more use of the tool called name.
// The counter map is shared with parent and sub-conversations
// (see newUsageWithSharedToolUses), hence the shared mutex.
func (c *Convo) incrementToolUse(name string) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.usage.ToolUses[name]++
}
512
// CumulativeUsage represents cumulative usage across a Convo, including all sub-conversations.
// StartTime is set when the usage is created; the counters are advanced by Add,
// and ToolUses by the owning Convo.
// Note that Responses is serialized under the JSON key "messages".
type CumulativeUsage struct {
	StartTime                time.Time      `json:"start_time"`
	Responses                uint64         `json:"messages"` // count of responses
	InputTokens              uint64         `json:"input_tokens"`
	OutputTokens             uint64         `json:"output_tokens"`
	CacheReadInputTokens     uint64         `json:"cache_read_input_tokens"`
	CacheCreationInputTokens uint64         `json:"cache_creation_input_tokens"`
	TotalCostUSD             float64        `json:"total_cost_usd"`
	ToolUses                 map[string]int `json:"tool_uses"` // tool name -> number of uses
}
524
525func newUsage() *CumulativeUsage {
526 return &CumulativeUsage{ToolUses: make(map[string]int), StartTime: time.Now()}
527}
528
529func newUsageWithSharedToolUses(parent *CumulativeUsage) *CumulativeUsage {
530 return &CumulativeUsage{ToolUses: parent.ToolUses, StartTime: time.Now()}
531}
532
533func (u *CumulativeUsage) Clone() CumulativeUsage {
534 v := *u
535 v.ToolUses = maps.Clone(u.ToolUses)
536 return v
537}
538
539func (c *Convo) CumulativeUsage() CumulativeUsage {
540 if c == nil {
541 return CumulativeUsage{}
542 }
543 c.mu.Lock()
544 defer c.mu.Unlock()
545 return c.usage.Clone()
546}
547
// WallTime returns the time elapsed since u.StartTime.
// It is measured against the current time on every call.
func (u *CumulativeUsage) WallTime() time.Duration {
	return time.Since(u.StartTime)
}
551
552func (u *CumulativeUsage) DollarsPerHour() float64 {
553 hours := u.WallTime().Hours()
554 // Prevent division by very small numbers that could cause issues
555 if hours < 1e-6 {
556 return 0
557 }
558 return u.TotalCostUSD / hours
559}
560
561func (u *CumulativeUsage) Add(usage llm.Usage) {
562 u.Responses++
563 u.InputTokens += usage.InputTokens
564 u.OutputTokens += usage.OutputTokens
565 u.CacheReadInputTokens += usage.CacheReadInputTokens
566 u.CacheCreationInputTokens += usage.CacheCreationInputTokens
567 u.TotalCostUSD += usage.CostUSD
568}
569
570// TotalInputTokens returns the grand total cumulative input tokens in u.
571func (u *CumulativeUsage) TotalInputTokens() uint64 {
572 return u.InputTokens + u.CacheReadInputTokens + u.CacheCreationInputTokens
573}
574
575// Attr returns the cumulative usage as a slog.Attr with key "usage".
576func (u CumulativeUsage) Attr() slog.Attr {
577 elapsed := time.Since(u.StartTime)
578 return slog.Group("usage",
579 slog.Duration("wall_time", elapsed),
580 slog.Uint64("responses", u.Responses),
581 slog.Uint64("input_tokens", u.InputTokens),
582 slog.Uint64("output_tokens", u.OutputTokens),
583 slog.Uint64("cache_read_input_tokens", u.CacheReadInputTokens),
584 slog.Uint64("cache_creation_input_tokens", u.CacheCreationInputTokens),
585 slog.Float64("total_cost_usd", u.TotalCostUSD),
586 slog.Float64("dollars_per_hour", u.TotalCostUSD/elapsed.Hours()),
587 slog.Any("tool_uses", maps.Clone(u.ToolUses)),
588 )
589}
590
// A Budget represents the maximum amount of resources that may be spent on a conversation.
// Note that the default (zero) budget is unlimited.
// Each limit is checked independently, and only when it is greater than zero.
type Budget struct {
	MaxResponses uint64        // if > 0, max number of iterations (=responses)
	MaxDollars   float64       // if > 0, max dollars that may be spent
	MaxWallTime  time.Duration // if > 0, max wall time that may be spent
}
598
599// OverBudget returns an error if the convo (or any of its parents) has exceeded its budget.
600// TODO: document parent vs sub budgets, multiple errors, etc, once we know the desired behavior.
601func (c *Convo) OverBudget() error {
602 for x := c; x != nil; x = x.Parent {
603 if err := x.overBudget(); err != nil {
604 return err
605 }
606 }
607 return nil
608}
609
610// ResetBudget sets the budget to the passed in budget and
611// adjusts it by what's been used so far.
612func (c *Convo) ResetBudget(budget Budget) {
613 c.Budget = budget
614 if c.Budget.MaxDollars > 0 {
615 c.Budget.MaxDollars += c.CumulativeUsage().TotalCostUSD
616 }
617 if c.Budget.MaxResponses > 0 {
618 c.Budget.MaxResponses += c.CumulativeUsage().Responses
619 }
620 if c.Budget.MaxWallTime > 0 {
621 c.Budget.MaxWallTime += c.usage.WallTime()
622 }
623}
624
625func (c *Convo) overBudget() error {
626 usage := c.CumulativeUsage()
627 // TODO: stop before we exceed the budget instead of after?
628 // Top priority is money, then time, then response count.
629 var err error
630 cont := "Continuing to chat will reset the budget."
631 if c.Budget.MaxDollars > 0 && usage.TotalCostUSD >= c.Budget.MaxDollars {
632 err = errors.Join(err, fmt.Errorf("$%.2f spent, budget is $%.2f. %s", usage.TotalCostUSD, c.Budget.MaxDollars, cont))
633 }
634 if c.Budget.MaxWallTime > 0 && usage.WallTime() >= c.Budget.MaxWallTime {
635 err = errors.Join(err, fmt.Errorf("%v elapsed, budget is %v. %s", usage.WallTime().Truncate(time.Second), c.Budget.MaxWallTime.Truncate(time.Second), cont))
636 }
637 if c.Budget.MaxResponses > 0 && usage.Responses >= c.Budget.MaxResponses {
638 err = errors.Join(err, fmt.Errorf("%d responses received, budget is %d. %s", usage.Responses, c.Budget.MaxResponses, cont))
639 }
640 return err
641}