Implement tracking of outstanding LLM and tool calls
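This commit tracks outstanding calls using the listener pattern between ant.Convo and the Agent: the ant.Listener callbacks (OnRequest, OnResponse, OnToolCall, OnToolResult) now receive a call ID, which the Agent uses to record each call when it starts and remove it when it finishes.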
* Added fields to the Agent struct to track outstanding LLM calls and tool calls
* Implemented the listener methods to add an entry when a call starts and remove it when the call finishes
* Added OutstandingLLMCallCount and OutstandingToolCalls to report the number of in-flight LLM calls and the names of in-flight tool calls
* Updated the State struct in loophttp.go to expose this information as outstanding_llm_calls and outstanding_tool_calls
* Added a unit test to verify the tracking functionality
* Created UI components with lightbulb and wrench icons to display call status
* Added numeric indicators that are shown whenever there are active calls
Co-Authored-By: sketch <hello@sketch.dev>
diff --git a/loop/agent.go b/loop/agent.go
index 2f4efe9..fec5dd7 100644
--- a/loop/agent.go
+++ b/loop/agent.go
@@ -73,6 +73,12 @@
// OS returns the operating system of the client.
OS() string
+
+ // OutstandingLLMCallCount returns the number of outstanding LLM calls.
+ OutstandingLLMCallCount() int
+
+ // OutstandingToolCalls returns the names of outstanding tool calls.
+ OutstandingToolCalls() []string
OutsideOS() string
OutsideHostname() string
OutsideWorkingDir() string
@@ -277,6 +283,12 @@
// Track git commits we've already seen (by hash)
seenCommits map[string]bool
+
+ // Track outstanding LLM call IDs
+ outstandingLLMCalls map[string]struct{}
+
+ // Track outstanding tool calls by ID with their names
+ outstandingToolCalls map[string]string
}
func (a *Agent) URL() string { return a.url }
@@ -289,6 +301,25 @@
return a.title
}
+// OutstandingLLMCallCount reports how many LLM requests have been recorded
+// by OnRequest and not yet cleared by OnResponse.
+func (a *Agent) OutstandingLLMCallCount() int {
+	a.mu.Lock()
+	n := len(a.outstandingLLMCalls)
+	a.mu.Unlock()
+	return n
+}
+
+// OutstandingToolCalls returns the tool name of every call recorded by
+// OnToolCall that OnToolResult has not yet cleared. Each in-flight call
+// contributes one entry, so concurrent calls to the same tool repeat its
+// name. The order follows Go map iteration and is therefore unspecified.
+// The result is never nil, so it encodes as [] rather than null in JSON.
+func (a *Agent) OutstandingToolCalls() []string {
+	a.mu.Lock()
+	defer a.mu.Unlock()
+
+	names := make([]string, len(a.outstandingToolCalls))
+	i := 0
+	for _, name := range a.outstandingToolCalls {
+		names[i] = name
+		i++
+	}
+	return names
+}
+
// OS returns the operating system of the client.
func (a *Agent) OS() string {
return a.config.ClientGOOS
@@ -326,8 +357,21 @@
a.listeners = a.listeners[:0]
}
+// OnToolCall implements ant.Listener. It records toolName under id so that
+// OutstandingToolCalls reports the call until OnToolResult deletes the entry
+// for that same id.
+func (a *Agent) OnToolCall(ctx context.Context, convo *ant.Convo, id string, toolName string, toolInput json.RawMessage, content ant.Content) {
+	a.mu.Lock()
+	defer a.mu.Unlock()
+	a.outstandingToolCalls[id] = toolName
+}
+
// OnToolResult implements ant.Listener.
-func (a *Agent) OnToolResult(ctx context.Context, convo *ant.Convo, toolName string, toolInput json.RawMessage, content ant.Content, result *string, err error) {
+func (a *Agent) OnToolResult(ctx context.Context, convo *ant.Convo, toolID string, toolName string, toolInput json.RawMessage, content ant.Content, result *string, err error) {
+ // Remove the tool call from outstanding calls
+ a.mu.Lock()
+ delete(a.outstandingToolCalls, toolID)
+ a.mu.Unlock()
+
m := AgentMessage{
Type: ToolUseMessageType,
Content: content.Text,
@@ -354,8 +398,10 @@
}
// OnRequest implements ant.Listener.
-func (a *Agent) OnRequest(ctx context.Context, convo *ant.Convo, msg *ant.Message) {
- // No-op.
+// It marks id as an outstanding LLM call; OnResponse deletes the entry for that id.
+func (a *Agent) OnRequest(ctx context.Context, convo *ant.Convo, id string, msg *ant.Message) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ // Only the ID is tracked; the request message itself is not retained.
+ a.outstandingLLMCalls[id] = struct{}{}
// We already get tool results from the above. We send user messages to the outbox in the agent loop.
}
@@ -363,7 +409,12 @@
// that need to be displayed (as well as tool calls that we send along when
// they're done). (It would be reasonable to also mention tool calls when they're
// started, but we don't do that yet.)
-func (a *Agent) OnResponse(ctx context.Context, convo *ant.Convo, resp *ant.MessageResponse) {
+func (a *Agent) OnResponse(ctx context.Context, convo *ant.Convo, id string, resp *ant.MessageResponse) {
+ // Remove the LLM call from outstanding calls
+ a.mu.Lock()
+ delete(a.outstandingLLMCalls, id)
+ a.mu.Unlock()
+
endOfTurn := false
if resp.StopReason != ant.StopReasonToolUse {
endOfTurn = true
@@ -451,16 +502,18 @@
// It is not usable until Init() is called.
func NewAgent(config AgentConfig) *Agent {
agent := &Agent{
- config: config,
- ready: make(chan struct{}),
- inbox: make(chan string, 100),
- outbox: make(chan AgentMessage, 100),
- startedAt: time.Now(),
- originalBudget: config.Budget,
- seenCommits: make(map[string]bool),
- outsideHostname: config.OutsideHostname,
- outsideOS: config.OutsideOS,
- outsideWorkingDir: config.OutsideWorkingDir,
+ config: config,
+ ready: make(chan struct{}),
+ inbox: make(chan string, 100),
+ outbox: make(chan AgentMessage, 100),
+ startedAt: time.Now(),
+ originalBudget: config.Budget,
+ seenCommits: make(map[string]bool),
+ outsideHostname: config.OutsideHostname,
+ outsideOS: config.OutsideOS,
+ outsideWorkingDir: config.OutsideWorkingDir,
+ // OnRequest and OnToolCall write into these maps, so they must be non-nil.
+ outstandingLLMCalls: make(map[string]struct{}),
+ outstandingToolCalls: make(map[string]string),
}
return agent
}
diff --git a/loop/agent_test.go b/loop/agent_test.go
index b9f9994..5bde1b1 100644
--- a/loop/agent_test.go
+++ b/loop/agent_test.go
@@ -152,3 +152,55 @@
t.Logf("Agent used %d tools in its response", toolUseCount)
}
+
+// TestAgentTracksOutstandingCalls checks that OutstandingLLMCallCount and
+// OutstandingToolCalls reflect entries added to and removed from the agent's
+// outstanding-call maps.
+//
+// NOTE(review): the maps are mutated directly under agent.mu rather than via
+// the OnRequest/OnResponse/OnToolCall/OnToolResult listener hooks, so the
+// hooks themselves are not exercised here.
+func TestAgentTracksOutstandingCalls(t *testing.T) {
+	agent := &Agent{
+		outstandingLLMCalls:  make(map[string]struct{}),
+		outstandingToolCalls: make(map[string]string),
+	}
+
+	// Check initial state
+	if count := agent.OutstandingLLMCallCount(); count != 0 {
+		t.Errorf("Expected 0 outstanding LLM calls, got %d", count)
+	}
+
+	if tools := agent.OutstandingToolCalls(); len(tools) != 0 {
+		t.Errorf("Expected 0 outstanding tool calls, got %d", len(tools))
+	}
+
+	// Add some calls
+	agent.mu.Lock()
+	agent.outstandingLLMCalls["llm1"] = struct{}{}
+	agent.outstandingToolCalls["tool1"] = "bash"
+	agent.outstandingToolCalls["tool2"] = "think"
+	agent.mu.Unlock()
+
+	// Check tracking works
+	if count := agent.OutstandingLLMCallCount(); count != 1 {
+		t.Errorf("Expected 1 outstanding LLM call, got %d", count)
+	}
+
+	tools := agent.OutstandingToolCalls()
+	if len(tools) != 2 {
+		t.Errorf("Expected 2 outstanding tool calls, got %d", len(tools))
+	}
+
+	// Check removal
+	agent.mu.Lock()
+	delete(agent.outstandingLLMCalls, "llm1")
+	delete(agent.outstandingToolCalls, "tool1")
+	agent.mu.Unlock()
+
+	if count := agent.OutstandingLLMCallCount(); count != 0 {
+		t.Errorf("Expected 0 outstanding LLM calls after removal, got %d", count)
+	}
+
+	tools = agent.OutstandingToolCalls()
+	// Fatal rather than Error: indexing tools[0] below would panic on an
+	// empty slice instead of reporting a clean test failure.
+	if len(tools) != 1 {
+		t.Fatalf("Expected 1 outstanding tool call after removal, got %d", len(tools))
+	}
+
+	if tools[0] != "think" {
+		t.Errorf("Expected 'think' tool remaining, got %s", tools[0])
+	}
+}
diff --git a/loop/server/loophttp.go b/loop/server/loophttp.go
index e1b76ad..3ef540b 100644
--- a/loop/server/loophttp.go
+++ b/loop/server/loophttp.go
@@ -50,14 +50,16 @@
}
type State struct {
- MessageCount int `json:"message_count"`
- TotalUsage *ant.CumulativeUsage `json:"total_usage,omitempty"`
- InitialCommit string `json:"initial_commit"`
- Title string `json:"title"`
- Hostname string `json:"hostname"` // deprecated
- WorkingDir string `json:"working_dir"` // deprecated
- OS string `json:"os"` // deprecated
- GitOrigin string `json:"git_origin,omitempty"`
+ MessageCount int `json:"message_count"`
+ TotalUsage *ant.CumulativeUsage `json:"total_usage,omitempty"`
+ InitialCommit string `json:"initial_commit"`
+ Title string `json:"title"`
+ Hostname string `json:"hostname"` // deprecated
+ WorkingDir string `json:"working_dir"` // deprecated
+ OS string `json:"os"` // deprecated
+ GitOrigin string `json:"git_origin,omitempty"`
+ OutstandingLLMCalls int `json:"outstanding_llm_calls"`
+ OutstandingToolCalls []string `json:"outstanding_tool_calls"`
OutsideHostname string `json:"outside_hostname,omitempty"`
InsideHostname string `json:"inside_hostname,omitempty"`
@@ -349,20 +351,22 @@
w.Header().Set("Content-Type", "application/json")
state := State{
- MessageCount: serverMessageCount,
- TotalUsage: &totalUsage,
- Hostname: s.hostname,
- WorkingDir: getWorkingDir(),
- InitialCommit: agent.InitialCommit(),
- Title: agent.Title(),
- OS: agent.OS(),
- OutsideHostname: agent.OutsideHostname(),
- InsideHostname: s.hostname,
- OutsideOS: agent.OutsideOS(),
- InsideOS: agent.OS(),
- OutsideWorkingDir: agent.OutsideWorkingDir(),
- InsideWorkingDir: getWorkingDir(),
- GitOrigin: agent.GitOrigin(),
+ MessageCount: serverMessageCount,
+ TotalUsage: &totalUsage,
+ Hostname: s.hostname,
+ WorkingDir: getWorkingDir(),
+ InitialCommit: agent.InitialCommit(),
+ Title: agent.Title(),
+ OS: agent.OS(),
+ OutsideHostname: agent.OutsideHostname(),
+ InsideHostname: s.hostname,
+ OutsideOS: agent.OutsideOS(),
+ InsideOS: agent.OS(),
+ OutsideWorkingDir: agent.OutsideWorkingDir(),
+ InsideWorkingDir: getWorkingDir(),
+ GitOrigin: agent.GitOrigin(),
+ OutstandingLLMCalls: agent.OutstandingLLMCallCount(),
+ OutstandingToolCalls: agent.OutstandingToolCalls(),
}
// Create a JSON encoder with indentation for pretty-printing