blob: 7860a07172c29cdfc44e0ea4bd5e06f61e9fb3b8 [file] [log] [blame]
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -07001package conversation
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "log/slog"
9 "maps"
10 "math/rand/v2"
11 "slices"
12 "strings"
13 "sync"
14 "time"
15
16 "github.com/oklog/ulid/v2"
17 "github.com/richardlehane/crock32"
18 "sketch.dev/llm"
19 "sketch.dev/skribe"
20)
21
// Listener is notified as a Convo makes progress: before each request is sent,
// after each response (or failed request), and around each tool call.
// The *Convo argument is the conversation that produced the event.
// OnToolResult is called from the per-tool goroutines started by
// Convo.ToolResultContents, so implementations may receive concurrent calls.
type Listener interface {
	// TODO: Content is leaking an anthropic API; should we avoid it?
	// TODO: Where should we include start/end time and usage?

	// OnToolCall is called for each tool_use block before the tool is looked up
	// and run. content carries the tool_use ID and the call's start time.
	OnToolCall(ctx context.Context, convo *Convo, toolCallID string, toolName string, toolInput json.RawMessage, content llm.Content)
	// OnToolResult is called when a tool call finishes. On success result is
	// non-nil and err is nil; on failure result is nil and err is the failure.
	// content is the tool_result that ToolResultContents will return.
	OnToolResult(ctx context.Context, convo *Convo, toolCallID string, toolName string, toolInput json.RawMessage, content llm.Content, result *string, err error)
	// OnRequest is called just before a request is sent. msg is the new message
	// being sent (not the full history); requestID matches the later OnResponse.
	OnRequest(ctx context.Context, convo *Convo, requestID string, msg *llm.Message)
	// OnResponse is called after the service replies. msg is nil if the request
	// failed.
	OnResponse(ctx context.Context, convo *Convo, requestID string, msg *llm.Response)
}
30
31type NoopListener struct{}
32
33func (n *NoopListener) OnToolCall(ctx context.Context, convo *Convo, id string, toolName string, toolInput json.RawMessage, content llm.Content) {
34}
35
36func (n *NoopListener) OnToolResult(ctx context.Context, convo *Convo, id string, toolName string, toolInput json.RawMessage, content llm.Content, result *string, err error) {
37}
38
39func (n *NoopListener) OnResponse(ctx context.Context, convo *Convo, id string, msg *llm.Response) {
40}
41func (n *NoopListener) OnRequest(ctx context.Context, convo *Convo, id string, msg *llm.Message) {}
42
// ErrDoNotRespond may be returned (possibly wrapped) by a tool's Run function
// to indicate that no tool_result should be produced for that call:
// ToolResultContents checks for it with errors.Is and emits nothing, not even
// a Listener.OnToolResult notification.
var ErrDoNotRespond = errors.New("do not respond")
44
// A Convo is a managed conversation with an LLM (typically Claude).
// It automatically manages the state of the conversation,
// including appending messages sent/received,
// calling tools and sending their results,
// tracking usage, etc.
//
// Exported fields must not be altered concurrently with calling any method on Convo.
// Typical usage is to configure a Convo once before using it.
// Create a root Convo with New: most methods rely on the unexported mutexes,
// maps, and usage that New initializes, so the zero value is not usable.
type Convo struct {
	// ID is a unique ID for the conversation
	// (unique enough to tell convos apart within a session; see newConvoID).
	ID string
	// Ctx is the context for the entire conversation.
	Ctx context.Context
	// Service is the LLM service to use.
	Service llm.Service
	// Tools are the tools available during the conversation.
	Tools []*llm.Tool
	// SystemPrompt is the system prompt for the conversation.
	SystemPrompt string
	// PromptCaching indicates whether to use Anthropic's prompt caching.
	// See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#continuing-a-multi-turn-conversation
	// for the documentation. At request send time, we set the cache_control field on the
	// last message. We also cache the system prompt.
	// Default: true.
	PromptCaching bool
	// ToolUseOnly indicates whether Claude may only use tools during this conversation.
	// TODO: add more fine-grained control over tool use?
	ToolUseOnly bool
	// Parent is the parent conversation, if any.
	// It is non-nil for "subagent" calls.
	// It is set automatically when calling SubConvo,
	// and usually should not be set manually.
	Parent *Convo
	// Budget is the budget for this conversation (and all sub-conversations).
	// The Conversation DOES NOT automatically enforce the budget.
	// It is up to the caller to call OverBudget() as appropriate.
	Budget Budget
	// Hidden indicates that the output of this conversation should be hidden in the UI.
	// This is useful for subconversations that can generate noisy, uninteresting output.
	Hidden bool

	// messages tracks the messages so far in the conversation.
	// SendMessage appends a sent message and its response only after a
	// successful reply.
	messages []llm.Message

	// Listener is notified of requests, responses, tool calls and tool results.
	// New installs a NoopListener; sub-conversations inherit the parent's.
	Listener Listener

	// muToolUseCancel guards toolUseCancel.
	muToolUseCancel *sync.Mutex
	// toolUseCancel maps the tool_use ID of each in-flight tool call to the
	// function that cancels it (see newToolUseContext and CancelToolUse).
	toolUseCancel map[string]context.CancelCauseFunc

	// Protects usage. This is used for subconversations (that share part of CumulativeUsage) as well.
	mu *sync.Mutex
	// usage tracks usage for this conversation and all sub-conversations.
	usage *CumulativeUsage
}
100
101// newConvoID generates a new 8-byte random id.
102// The uniqueness/collision requirements here are very low.
103// They are not global identifiers,
104// just enough to distinguish different convos in a single session.
105func newConvoID() string {
106 u1 := rand.Uint32()
107 s := crock32.Encode(uint64(u1))
108 if len(s) < 7 {
109 s += strings.Repeat("0", 7-len(s))
110 }
111 return s[:3] + "-" + s[3:]
112}
113
114// New creates a new conversation with Claude with sensible defaults.
115// ctx is the context for the entire conversation.
116func New(ctx context.Context, srv llm.Service) *Convo {
117 id := newConvoID()
118 return &Convo{
119 Ctx: skribe.ContextWithAttr(ctx, slog.String("convo_id", id)),
120 Service: srv,
121 PromptCaching: true,
122 usage: newUsage(),
123 Listener: &NoopListener{},
124 ID: id,
125 muToolUseCancel: &sync.Mutex{},
126 toolUseCancel: map[string]context.CancelCauseFunc{},
127 mu: &sync.Mutex{},
128 }
129}
130
131// SubConvo creates a sub-conversation with the same configuration as the parent conversation.
132// (This propagates context for cancellation, HTTP client, API key, etc.)
133// The sub-conversation shares no messages with the parent conversation.
134// It does not inherit tools from the parent conversation.
135func (c *Convo) SubConvo() *Convo {
136 id := newConvoID()
137 return &Convo{
138 Ctx: skribe.ContextWithAttr(c.Ctx, slog.String("convo_id", id), slog.String("parent_convo_id", c.ID)),
139 Service: c.Service,
140 PromptCaching: c.PromptCaching,
141 Parent: c,
142 // For convenience, sub-convo usage shares tool uses map with parent,
143 // all other fields separate, propagated in AddResponse
144 usage: newUsageWithSharedToolUses(c.usage),
145 mu: c.mu,
146 Listener: c.Listener,
147 ID: id,
148 // Do not copy Budget. Each budget is independent,
149 // and OverBudget checks whether any ancestor is over budget.
150 }
151}
152
153func (c *Convo) SubConvoWithHistory() *Convo {
154 id := newConvoID()
155 return &Convo{
156 Ctx: skribe.ContextWithAttr(c.Ctx, slog.String("convo_id", id), slog.String("parent_convo_id", c.ID)),
157 Service: c.Service,
158 PromptCaching: c.PromptCaching,
159 Parent: c,
160 // For convenience, sub-convo usage shares tool uses map with parent,
161 // all other fields separate, propagated in AddResponse
162 usage: newUsageWithSharedToolUses(c.usage),
163 mu: c.mu,
164 Listener: c.Listener,
165 ID: id,
166 // Do not copy Budget. Each budget is independent,
167 // and OverBudget checks whether any ancestor is over budget.
168 messages: slices.Clone(c.messages),
169 }
170}
171
172// Depth reports how many "sub-conversations" deep this conversation is.
173// That it, it walks up parents until it finds a root.
174func (c *Convo) Depth() int {
175 x := c
176 var depth int
177 for x.Parent != nil {
178 x = x.Parent
179 depth++
180 }
181 return depth
182}
183
184// SendUserTextMessage sends a text message to the LLM in this conversation.
185// otherContents contains additional contents to send with the message, usually tool results.
186func (c *Convo) SendUserTextMessage(s string, otherContents ...llm.Content) (*llm.Response, error) {
187 contents := slices.Clone(otherContents)
188 if s != "" {
189 contents = append(contents, llm.Content{Type: llm.ContentTypeText, Text: s})
190 }
191 msg := llm.Message{
192 Role: llm.MessageRoleUser,
193 Content: contents,
194 }
195 return c.SendMessage(msg)
196}
197
198func (c *Convo) messageRequest(msg llm.Message) *llm.Request {
199 system := []llm.SystemContent{}
200 if c.SystemPrompt != "" {
201 var d llm.SystemContent
202 d = llm.SystemContent{Type: "text", Text: c.SystemPrompt}
203 if c.PromptCaching {
204 d.Cache = true
205 }
206 system = []llm.SystemContent{d}
207 }
208
209 // Claude is happy to return an empty response in response to our Done() call,
210 // and, if so, you'll see something like:
211 // API request failed with status 400 Bad Request
212 // {"type":"error","error": {"type":"invalid_request_error",
213 // "message":"messages.5: all messages must have non-empty content except for the optional final assistant message"}}
214 // So, we filter out those empty messages.
215 var nonEmptyMessages []llm.Message
216 for _, m := range c.messages {
217 if len(m.Content) > 0 {
218 nonEmptyMessages = append(nonEmptyMessages, m)
219 }
220 }
221
222 mr := &llm.Request{
223 Messages: append(nonEmptyMessages, msg), // not yet committed to keeping msg
224 System: system,
225 Tools: c.Tools,
226 }
227 if c.ToolUseOnly {
228 mr.ToolChoice = &llm.ToolChoice{Type: llm.ToolChoiceTypeAny}
229 }
230 return mr
231}
232
233func (c *Convo) findTool(name string) (*llm.Tool, error) {
234 for _, tool := range c.Tools {
235 if tool.Name == name {
236 return tool, nil
237 }
238 }
239 return nil, fmt.Errorf("tool %q not found", name)
240}
241
242// insertMissingToolResults adds error results for tool uses that were requested
243// but not included in the message, which can happen in error paths like "out of budget."
244// We only insert these if there were no tool responses at all, since an incorrect
245// number of tool results would be a programmer error. Mutates inputs.
246func (c *Convo) insertMissingToolResults(mr *llm.Request, msg *llm.Message) {
247 if len(mr.Messages) < 2 {
248 return
249 }
250 prev := mr.Messages[len(mr.Messages)-2]
251 var toolUsePrev int
252 for _, c := range prev.Content {
253 if c.Type == llm.ContentTypeToolUse {
254 toolUsePrev++
255 }
256 }
257 if toolUsePrev == 0 {
258 return
259 }
260 var toolUseCurrent int
261 for _, c := range msg.Content {
262 if c.Type == llm.ContentTypeToolResult {
263 toolUseCurrent++
264 }
265 }
266 if toolUseCurrent != 0 {
267 return
268 }
269 var prefix []llm.Content
270 for _, part := range prev.Content {
271 if part.Type != llm.ContentTypeToolUse {
272 continue
273 }
274 content := llm.Content{
275 Type: llm.ContentTypeToolResult,
276 ToolUseID: part.ID,
277 ToolError: true,
278 ToolResult: "not executed; retry possible",
279 }
280 prefix = append(prefix, content)
281 msg.Content = append(prefix, msg.Content...)
282 mr.Messages[len(mr.Messages)-1].Content = msg.Content
283 }
284 slog.DebugContext(c.Ctx, "inserted missing tool results")
285}
286
287// SendMessage sends a message to Claude.
288// The conversation records (internally) all messages succesfully sent and received.
289func (c *Convo) SendMessage(msg llm.Message) (*llm.Response, error) {
290 id := ulid.Make().String()
291 mr := c.messageRequest(msg)
292 var lastMessage *llm.Message
293 if c.PromptCaching {
294 lastMessage = &mr.Messages[len(mr.Messages)-1]
295 if len(lastMessage.Content) > 0 {
296 lastMessage.Content[len(lastMessage.Content)-1].Cache = true
297 }
298 }
299 defer func() {
300 if lastMessage == nil {
301 return
302 }
303 if len(lastMessage.Content) > 0 {
304 lastMessage.Content[len(lastMessage.Content)-1].Cache = false
305 }
306 }()
307 c.insertMissingToolResults(mr, &msg)
308 c.Listener.OnRequest(c.Ctx, c, id, &msg)
309
310 startTime := time.Now()
311 resp, err := c.Service.Do(c.Ctx, mr)
312 if resp != nil {
313 resp.StartTime = &startTime
314 endTime := time.Now()
315 resp.EndTime = &endTime
316 }
317
318 if err != nil {
319 c.Listener.OnResponse(c.Ctx, c, id, nil)
320 return nil, err
321 }
322 c.messages = append(c.messages, msg, resp.ToMessage())
323 // Propagate usage to all ancestors (including us).
324 for x := c; x != nil; x = x.Parent {
325 x.usage.Add(resp.Usage)
326 }
327 c.Listener.OnResponse(c.Ctx, c, id, resp)
328 return resp, err
329}
330
331type toolCallInfoKeyType string
332
333var toolCallInfoKey toolCallInfoKeyType
334
335type ToolCallInfo struct {
336 ToolUseID string
337 Convo *Convo
338}
339
340func ToolCallInfoFromContext(ctx context.Context) ToolCallInfo {
341 v := ctx.Value(toolCallInfoKey)
342 i, _ := v.(ToolCallInfo)
343 return i
344}
345
346func (c *Convo) ToolResultCancelContents(resp *llm.Response) ([]llm.Content, error) {
347 if resp.StopReason != llm.StopReasonToolUse {
348 return nil, nil
349 }
350 var toolResults []llm.Content
351
352 for _, part := range resp.Content {
353 if part.Type != llm.ContentTypeToolUse {
354 continue
355 }
356 c.incrementToolUse(part.ToolName)
357
358 content := llm.Content{
359 Type: llm.ContentTypeToolResult,
360 ToolUseID: part.ID,
361 }
362
363 content.ToolError = true
364 content.ToolResult = "user canceled this too_use"
365 toolResults = append(toolResults, content)
366 }
367 return toolResults, nil
368}
369
370// GetID returns the conversation ID
371func (c *Convo) GetID() string {
372 return c.ID
373}
374
375func (c *Convo) CancelToolUse(toolUseID string, err error) error {
376 c.muToolUseCancel.Lock()
377 defer c.muToolUseCancel.Unlock()
378 cancel, ok := c.toolUseCancel[toolUseID]
379 if !ok {
380 return fmt.Errorf("cannot cancel %s: no cancel function registered for this tool_use_id. All I have is %+v", toolUseID, c.toolUseCancel)
381 }
382 delete(c.toolUseCancel, toolUseID)
383 cancel(err)
384 return nil
385}
386
387func (c *Convo) newToolUseContext(ctx context.Context, toolUseID string) (context.Context, context.CancelFunc) {
388 c.muToolUseCancel.Lock()
389 defer c.muToolUseCancel.Unlock()
390 ctx, cancel := context.WithCancelCause(ctx)
391 c.toolUseCancel[toolUseID] = cancel
392 return ctx, func() { c.CancelToolUse(toolUseID, nil) }
393}
394
// ToolResultContents runs all tool uses requested by the response and returns their results.
// Cancelling ctx will cancel any running tool calls.
//
// It returns (nil, nil) if resp.StopReason is not llm.StopReasonToolUse.
// Otherwise each tool_use block in resp.Content runs in its own goroutine, and
// ToolResultContents waits for all of them. Results come back in the order the
// tools finished, not in tool_use order.
//
// For each call:
//   - an unknown tool name yields an error result;
//   - an error matching ErrDoNotRespond (errors.Is) yields no result at all;
//   - if the call's context was canceled (via ctx or CancelToolUse), the result
//     is an error carrying the cancellation cause, whatever the tool returned;
//   - otherwise the tool's error or output becomes the result.
//
// If ctx is done once all tools have returned, the results are discarded and
// ctx.Err() is returned. Listener.OnToolCall runs on the calling goroutine
// before each tool starts. Listener.OnToolResult runs on the tool's own
// goroutine, so it may be called concurrently.
func (c *Convo) ToolResultContents(ctx context.Context, resp *llm.Response) ([]llm.Content, error) {
	if resp.StopReason != llm.StopReasonToolUse {
		return nil, nil
	}
	// Extract all tool calls from the response, call the tools, and gather the results.
	var wg sync.WaitGroup
	// Buffered to len(resp.Content), an upper bound on the number of results, so
	// no goroutine ever blocks on send and wg.Wait below always completes.
	toolResultC := make(chan llm.Content, len(resp.Content))
	for _, part := range resp.Content {
		if part.Type != llm.ContentTypeToolUse {
			continue
		}
		c.incrementToolUse(part.ToolName)
		startTime := time.Now()

		c.Listener.OnToolCall(ctx, c, part.ID, part.ToolName, part.ToolInput, llm.Content{
			Type:             llm.ContentTypeToolUse,
			ToolUseID:        part.ID,
			ToolUseStartTime: &startTime,
		})

		wg.Add(1)
		// NOTE(review): the goroutine captures the loop variable part, which relies
		// on per-iteration loop variables (go.mod declaring Go 1.22+) — confirm.
		go func() {
			defer wg.Done()

			content := llm.Content{
				Type:             llm.ContentTypeToolResult,
				ToolUseID:        part.ID,
				ToolUseStartTime: &startTime,
			}
			// sendErr finalizes content as an error result, notifies the listener, and queues it.
			sendErr := func(err error) {
				// Record end time
				endTime := time.Now()
				content.ToolUseEndTime = &endTime

				content.ToolError = true
				content.ToolResult = err.Error()
				c.Listener.OnToolResult(ctx, c, part.ID, part.ToolName, part.ToolInput, content, nil, err)
				toolResultC <- content
			}
			// sendRes finalizes content as a successful result, notifies the listener, and queues it.
			sendRes := func(res string) {
				// Record end time
				endTime := time.Now()
				content.ToolUseEndTime = &endTime

				content.ToolResult = res
				c.Listener.OnToolResult(ctx, c, part.ID, part.ToolName, part.ToolInput, content, &res, nil)
				toolResultC <- content
			}

			tool, err := c.findTool(part.ToolName)
			if err != nil {
				sendErr(err)
				return
			}
			// Create a new context for just this tool_use call, and register its
			// cancel function so that it can be canceled individually.
			toolUseCtx, cancel := c.newToolUseContext(ctx, part.ID)
			defer cancel()
			// TODO: move this into newToolUseContext?
			toolUseCtx = context.WithValue(toolUseCtx, toolCallInfoKey, ToolCallInfo{ToolUseID: part.ID, Convo: c})
			toolResult, err := tool.Run(toolUseCtx, part.ToolInput)
			// The tool asked for no tool_result (and no listener notification).
			if errors.Is(err, ErrDoNotRespond) {
				return
			}
			// Cancellation takes precedence over whatever the tool returned, so the
			// result explains why the call was stopped.
			if toolUseCtx.Err() != nil {
				sendErr(context.Cause(toolUseCtx))
				return
			}

			if err != nil {
				sendErr(err)
				return
			}
			sendRes(toolResult)
		}()
	}
	wg.Wait()
	close(toolResultC)
	var toolResults []llm.Content
	for toolResult := range toolResultC {
		toolResults = append(toolResults, toolResult)
	}
	if ctx.Err() != nil {
		return nil, ctx.Err()
	}
	return toolResults, nil
}
484
485func (c *Convo) incrementToolUse(name string) {
486 c.mu.Lock()
487 defer c.mu.Unlock()
488
489 c.usage.ToolUses[name]++
490}
491
// CumulativeUsage represents cumulative usage across a Convo, including all sub-conversations.
// SendMessage adds each response's usage to the Convo and all of its ancestors.
// For convos created via SubConvo, ToolUses is the same map as the parent's.
type CumulativeUsage struct {
	// StartTime is when tracking began; WallTime measures from here.
	StartTime                time.Time      `json:"start_time"`
	Responses                uint64         `json:"messages"` // count of responses
	InputTokens              uint64         `json:"input_tokens"`
	OutputTokens             uint64         `json:"output_tokens"`
	CacheReadInputTokens     uint64         `json:"cache_read_input_tokens"`
	CacheCreationInputTokens uint64         `json:"cache_creation_input_tokens"`
	TotalCostUSD             float64        `json:"total_cost_usd"`
	ToolUses                 map[string]int `json:"tool_uses"` // tool name -> number of uses
}
503
504func newUsage() *CumulativeUsage {
505 return &CumulativeUsage{ToolUses: make(map[string]int), StartTime: time.Now()}
506}
507
508func newUsageWithSharedToolUses(parent *CumulativeUsage) *CumulativeUsage {
509 return &CumulativeUsage{ToolUses: parent.ToolUses, StartTime: time.Now()}
510}
511
512func (u *CumulativeUsage) Clone() CumulativeUsage {
513 v := *u
514 v.ToolUses = maps.Clone(u.ToolUses)
515 return v
516}
517
518func (c *Convo) CumulativeUsage() CumulativeUsage {
519 if c == nil {
520 return CumulativeUsage{}
521 }
522 c.mu.Lock()
523 defer c.mu.Unlock()
524 return c.usage.Clone()
525}
526
527func (u *CumulativeUsage) WallTime() time.Duration {
528 return time.Since(u.StartTime)
529}
530
531func (u *CumulativeUsage) DollarsPerHour() float64 {
532 hours := u.WallTime().Hours()
533 // Prevent division by very small numbers that could cause issues
534 if hours < 1e-6 {
535 return 0
536 }
537 return u.TotalCostUSD / hours
538}
539
540func (u *CumulativeUsage) Add(usage llm.Usage) {
541 u.Responses++
542 u.InputTokens += usage.InputTokens
543 u.OutputTokens += usage.OutputTokens
544 u.CacheReadInputTokens += usage.CacheReadInputTokens
545 u.CacheCreationInputTokens += usage.CacheCreationInputTokens
546 u.TotalCostUSD += usage.CostUSD
547}
548
549// TotalInputTokens returns the grand total cumulative input tokens in u.
550func (u *CumulativeUsage) TotalInputTokens() uint64 {
551 return u.InputTokens + u.CacheReadInputTokens + u.CacheCreationInputTokens
552}
553
554// Attr returns the cumulative usage as a slog.Attr with key "usage".
555func (u CumulativeUsage) Attr() slog.Attr {
556 elapsed := time.Since(u.StartTime)
557 return slog.Group("usage",
558 slog.Duration("wall_time", elapsed),
559 slog.Uint64("responses", u.Responses),
560 slog.Uint64("input_tokens", u.InputTokens),
561 slog.Uint64("output_tokens", u.OutputTokens),
562 slog.Uint64("cache_read_input_tokens", u.CacheReadInputTokens),
563 slog.Uint64("cache_creation_input_tokens", u.CacheCreationInputTokens),
564 slog.Float64("total_cost_usd", u.TotalCostUSD),
565 slog.Float64("dollars_per_hour", u.TotalCostUSD/elapsed.Hours()),
566 slog.Any("tool_uses", maps.Clone(u.ToolUses)),
567 )
568}
569
// A Budget represents the maximum amount of resources that may be spent on a conversation.
// Note that the default (zero) budget is unlimited: each limit applies only when positive.
// Budgets are checked, not enforced, by Convo.OverBudget.
type Budget struct {
	MaxResponses uint64        // if > 0, max number of iterations (=responses)
	MaxDollars   float64       // if > 0, max dollars that may be spent
	MaxWallTime  time.Duration // if > 0, max wall time that may be spent
}
577
578// OverBudget returns an error if the convo (or any of its parents) has exceeded its budget.
579// TODO: document parent vs sub budgets, multiple errors, etc, once we know the desired behavior.
580func (c *Convo) OverBudget() error {
581 for x := c; x != nil; x = x.Parent {
582 if err := x.overBudget(); err != nil {
583 return err
584 }
585 }
586 return nil
587}
588
589// ResetBudget sets the budget to the passed in budget and
590// adjusts it by what's been used so far.
591func (c *Convo) ResetBudget(budget Budget) {
592 c.Budget = budget
593 if c.Budget.MaxDollars > 0 {
594 c.Budget.MaxDollars += c.CumulativeUsage().TotalCostUSD
595 }
596 if c.Budget.MaxResponses > 0 {
597 c.Budget.MaxResponses += c.CumulativeUsage().Responses
598 }
599 if c.Budget.MaxWallTime > 0 {
600 c.Budget.MaxWallTime += c.usage.WallTime()
601 }
602}
603
604func (c *Convo) overBudget() error {
605 usage := c.CumulativeUsage()
606 // TODO: stop before we exceed the budget instead of after?
607 // Top priority is money, then time, then response count.
608 var err error
609 cont := "Continuing to chat will reset the budget."
610 if c.Budget.MaxDollars > 0 && usage.TotalCostUSD >= c.Budget.MaxDollars {
611 err = errors.Join(err, fmt.Errorf("$%.2f spent, budget is $%.2f. %s", usage.TotalCostUSD, c.Budget.MaxDollars, cont))
612 }
613 if c.Budget.MaxWallTime > 0 && usage.WallTime() >= c.Budget.MaxWallTime {
614 err = errors.Join(err, fmt.Errorf("%v elapsed, budget is %v. %s", usage.WallTime().Truncate(time.Second), c.Budget.MaxWallTime.Truncate(time.Second), cont))
615 }
616 if c.Budget.MaxResponses > 0 && usage.Responses >= c.Budget.MaxResponses {
617 err = errors.Join(err, fmt.Errorf("%d responses received, budget is %d. %s", usage.Responses, c.Budget.MaxResponses, cont))
618 }
619 return err
620}