Implement tracking of outstanding LLM and tool calls
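This commit implements a listener pattern between ant.Convo and the Agent for tracking outstanding calls. The ant.Listener interface gains an OnToolCall hook, and every callback now receives an ID so that start and end events can be paired: requests get a fresh ULID, and tool calls use the tool-use ID. OnResponse is now also called, with a nil response, when a request fails.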
* Added fields to the Agent struct to track outstanding LLM calls and tool calls
* Implemented the listener methods to properly track and update these fields
* Added methods to retrieve the counts and names
* Updated the State struct in loophttp.go to expose this information
* Added a unit test to verify the tracking functionality
* Created UI components with lightbulb and wrench icons to display call status
* Added numeric counters that are shown whenever there are active calls
Co-Authored-By: sketch <hello@sketch.dev>
diff --git a/ant/ant.go b/ant/ant.go
index 463d5fb..69c17c2 100644
--- a/ant/ant.go
+++ b/ant/ant.go
@@ -17,6 +17,7 @@
"testing"
"time"
+ "github.com/oklog/ulid/v2"
"github.com/richardlehane/crock32"
"sketch.dev/skribe"
)
@@ -54,17 +55,23 @@
type Listener interface {
// TODO: Content is leaking an anthropic API; should we avoid it?
// TODO: Where should we include start/end time and usage?
- OnToolResult(ctx context.Context, convo *Convo, toolName string, toolInput json.RawMessage, content Content, result *string, err error)
- OnResponse(ctx context.Context, convo *Convo, msg *MessageResponse)
- OnRequest(ctx context.Context, convo *Convo, msg *Message)
+ OnToolCall(ctx context.Context, convo *Convo, toolCallID string, toolName string, toolInput json.RawMessage, content Content)
+ OnToolResult(ctx context.Context, convo *Convo, toolCallID string, toolName string, toolInput json.RawMessage, content Content, result *string, err error)
+ OnRequest(ctx context.Context, convo *Convo, requestID string, msg *Message)
+ OnResponse(ctx context.Context, convo *Convo, requestID string, msg *MessageResponse)
}
type NoopListener struct{}
-func (n *NoopListener) OnToolResult(ctx context.Context, convo *Convo, toolName string, toolInput json.RawMessage, content Content, result *string, err error) {
+func (n *NoopListener) OnToolCall(ctx context.Context, convo *Convo, id string, toolName string, toolInput json.RawMessage, content Content) {
}
-func (n *NoopListener) OnResponse(ctx context.Context, convo *Convo, msg *MessageResponse) {}
-func (n *NoopListener) OnRequest(ctx context.Context, convo *Convo, msg *Message) {}
+
+func (n *NoopListener) OnToolResult(ctx context.Context, convo *Convo, id string, toolName string, toolInput json.RawMessage, content Content, result *string, err error) {
+}
+
+func (n *NoopListener) OnResponse(ctx context.Context, convo *Convo, id string, msg *MessageResponse) {
+}
+func (n *NoopListener) OnRequest(ctx context.Context, convo *Convo, id string, msg *Message) {}
type Content struct {
// TODO: image support?
@@ -577,6 +584,7 @@
// SendMessage sends a message to Claude.
// The conversation records (internally) all messages succesfully sent and received.
func (c *Convo) SendMessage(msg Message) (*MessageResponse, error) {
+ id := ulid.Make().String()
mr := c.messageRequest(msg)
var lastMessage *Message
if c.PromptCaching {
@@ -594,7 +602,7 @@
}
}()
c.insertMissingToolResults(mr, &msg)
- c.Listener.OnRequest(c.Ctx, c, &msg)
+ c.Listener.OnRequest(c.Ctx, c, id, &msg)
startTime := time.Now()
resp, err := createMessage(c.Ctx, c.HTTPC, c.URL, c.APIKey, mr)
@@ -605,6 +613,7 @@
}
if err != nil {
+ c.Listener.OnResponse(c.Ctx, c, id, nil)
return nil, err
}
c.messages = append(c.messages, msg, resp.ToMessage())
@@ -612,7 +621,7 @@
for x := c; x != nil; x = x.Parent {
x.usage.AddResponse(resp)
}
- c.Listener.OnResponse(c.Ctx, c, resp)
+ c.Listener.OnResponse(c.Ctx, c, id, resp)
return resp, err
}
@@ -689,13 +698,18 @@
continue
}
c.incrementToolUse(part.ToolName)
+ startTime := time.Now()
+
+ c.Listener.OnToolCall(ctx, c, part.ID, part.ToolName, part.ToolInput, Content{
+ Type: ContentTypeToolUse,
+ ToolUseID: part.ID,
+ StartTime: &startTime,
+ })
+
wg.Add(1)
go func() {
defer wg.Done()
- // Record start time
- startTime := time.Now()
-
content := Content{
Type: ContentTypeToolResult,
ToolUseID: part.ID,
@@ -708,7 +722,7 @@
content.ToolError = true
content.ToolResult = err.Error()
- c.Listener.OnToolResult(ctx, c, part.ToolName, part.ToolInput, content, nil, err)
+ c.Listener.OnToolResult(ctx, c, part.ID, part.ToolName, part.ToolInput, content, nil, err)
toolResultC <- content
}
sendRes := func(res string) {
@@ -717,7 +731,7 @@
content.EndTime = &endTime
content.ToolResult = res
- c.Listener.OnToolResult(ctx, c, part.ToolName, part.ToolInput, content, &res, nil)
+ c.Listener.OnToolResult(ctx, c, part.ID, part.ToolName, part.ToolInput, content, &res, nil)
toolResultC <- content
}