blob: bde0d2041da406538ec0a333ad529e817bab2d50 [file] [log] [blame]
Earl Lee2e463fb2025-04-17 11:22:22 -07001package loop
2
3import (
4 "context"
Sean McCullough9f4b8082025-04-30 17:34:07 +00005 "fmt"
Earl Lee2e463fb2025-04-17 11:22:22 -07006 "net/http"
7 "os"
8 "strings"
9 "testing"
10 "time"
11
12 "sketch.dev/ant"
13 "sketch.dev/httprr"
14)
15
16// TestAgentLoop tests that the Agent loop functionality works correctly.
17// It uses the httprr package to record HTTP interactions for replay in tests.
18// When failing, rebuild with "go test ./sketch/loop -run TestAgentLoop -httprecord .*agent_loop.*"
19// as necessary.
20func TestAgentLoop(t *testing.T) {
21 ctx := context.Background()
22
23 // Setup httprr recorder
24 rrPath := "testdata/agent_loop.httprr"
25 rr, err := httprr.Open(rrPath, http.DefaultTransport)
26 if err != nil && !os.IsNotExist(err) {
27 t.Fatal(err)
28 }
29
30 if rr.Recording() {
31 // Skip the test if API key is not available
32 if os.Getenv("ANTHROPIC_API_KEY") == "" {
33 t.Fatal("ANTHROPIC_API_KEY not set, required for HTTP recording")
34 }
35 }
36
37 // Create HTTP client
38 var client *http.Client
39 if rr != nil {
40 // Scrub API keys from requests for security
41 rr.ScrubReq(func(req *http.Request) error {
42 req.Header.Del("x-api-key")
43 req.Header.Del("anthropic-api-key")
44 return nil
45 })
46 client = rr.Client()
47 } else {
48 client = &http.Client{Transport: http.DefaultTransport}
49 }
50
51 // Create a new agent with the httprr client
52 origWD, err := os.Getwd()
53 if err != nil {
54 t.Fatal(err)
55 }
56 if err := os.Chdir("/"); err != nil {
57 t.Fatal(err)
58 }
59 budget := ant.Budget{MaxResponses: 100}
60 wd, err := os.Getwd()
61 if err != nil {
62 t.Fatal(err)
63 }
64
65 cfg := AgentConfig{
66 Context: ctx,
67 APIKey: os.Getenv("ANTHROPIC_API_KEY"),
68 HTTPC: client,
69 Budget: budget,
70 GitUsername: "Test Agent",
71 GitEmail: "totallyhuman@sketch.dev",
72 SessionID: "test-session-id",
73 ClientGOOS: "linux",
74 ClientGOARCH: "amd64",
75 }
76 agent := NewAgent(cfg)
77 if err := os.Chdir(origWD); err != nil {
78 t.Fatal(err)
79 }
80 err = agent.Init(AgentInit{WorkingDir: wd, NoGit: true})
81 if err != nil {
82 t.Fatal(err)
83 }
84
85 // Setup a test message that will trigger a simple, predictable response
86 userMessage := "What tools are available to you? Please just list them briefly."
87
88 // Send the message to the agent
89 agent.UserMessage(ctx, userMessage)
90
91 // Process a single loop iteration to avoid long-running tests
Sean McCullough885a16a2025-04-30 02:49:25 +000092 agent.processTurn(ctx)
Earl Lee2e463fb2025-04-17 11:22:22 -070093
94 // Collect responses with a timeout
95 var responses []AgentMessage
96 timeout := time.After(10 * time.Second)
97 done := false
98
99 for !done {
100 select {
101 case <-timeout:
102 t.Log("Timeout reached while waiting for agent responses")
103 done = true
104 default:
105 select {
106 case msg := <-agent.outbox:
107 t.Logf("Received message: Type=%s, EndOfTurn=%v, Content=%q", msg.Type, msg.EndOfTurn, msg.Content)
108 responses = append(responses, msg)
109 if msg.EndOfTurn {
110 done = true
111 }
112 default:
113 // No more messages available right now
114 time.Sleep(100 * time.Millisecond)
115 }
116 }
117 }
118
119 // Verify we got at least one response
120 if len(responses) == 0 {
121 t.Fatal("No responses received from agent")
122 }
123
124 // Log the received responses for debugging
125 t.Logf("Received %d responses", len(responses))
126
127 // Find the final agent response (with EndOfTurn=true)
128 var finalResponse *AgentMessage
129 for i := range responses {
130 if responses[i].Type == AgentMessageType && responses[i].EndOfTurn {
131 finalResponse = &responses[i]
132 break
133 }
134 }
135
136 // Verify we got a final agent response
137 if finalResponse == nil {
138 t.Fatal("No final agent response received")
139 }
140
141 // Check that the response contains tools information
142 if !strings.Contains(strings.ToLower(finalResponse.Content), "tool") {
143 t.Error("Expected response to mention tools")
144 }
145
146 // Count how many tool use messages we received
147 toolUseCount := 0
148 for _, msg := range responses {
149 if msg.Type == ToolUseMessageType {
150 toolUseCount++
151 }
152 }
153
154 t.Logf("Agent used %d tools in its response", toolUseCount)
155}
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000156
157func TestAgentTracksOutstandingCalls(t *testing.T) {
158 agent := &Agent{
159 outstandingLLMCalls: make(map[string]struct{}),
160 outstandingToolCalls: make(map[string]string),
161 }
162
163 // Check initial state
164 if count := agent.OutstandingLLMCallCount(); count != 0 {
165 t.Errorf("Expected 0 outstanding LLM calls, got %d", count)
166 }
167
168 if tools := agent.OutstandingToolCalls(); len(tools) != 0 {
169 t.Errorf("Expected 0 outstanding tool calls, got %d", len(tools))
170 }
171
172 // Add some calls
173 agent.mu.Lock()
174 agent.outstandingLLMCalls["llm1"] = struct{}{}
175 agent.outstandingToolCalls["tool1"] = "bash"
176 agent.outstandingToolCalls["tool2"] = "think"
177 agent.mu.Unlock()
178
179 // Check tracking works
180 if count := agent.OutstandingLLMCallCount(); count != 1 {
181 t.Errorf("Expected 1 outstanding LLM call, got %d", count)
182 }
183
184 tools := agent.OutstandingToolCalls()
185 if len(tools) != 2 {
186 t.Errorf("Expected 2 outstanding tool calls, got %d", len(tools))
187 }
188
189 // Check removal
190 agent.mu.Lock()
191 delete(agent.outstandingLLMCalls, "llm1")
192 delete(agent.outstandingToolCalls, "tool1")
193 agent.mu.Unlock()
194
195 if count := agent.OutstandingLLMCallCount(); count != 0 {
196 t.Errorf("Expected 0 outstanding LLM calls after removal, got %d", count)
197 }
198
199 tools = agent.OutstandingToolCalls()
200 if len(tools) != 1 {
201 t.Errorf("Expected 1 outstanding tool call after removal, got %d", len(tools))
202 }
203
204 if tools[0] != "think" {
205 t.Errorf("Expected 'think' tool remaining, got %s", tools[0])
206 }
207}
Sean McCullough9f4b8082025-04-30 17:34:07 +0000208
209// TestAgentProcessTurnWithNilResponse tests the scenario where Agent.processTurn receives
210// a nil value for initialResp from processUserMessage.
211func TestAgentProcessTurnWithNilResponse(t *testing.T) {
212 // Create a mock conversation that will return nil and error
213 mockConvo := &MockConvoInterface{
214 sendMessageFunc: func(message ant.Message) (*ant.MessageResponse, error) {
215 return nil, fmt.Errorf("test error: simulating nil response")
216 },
217 }
218
219 // Create a minimal Agent instance for testing
220 agent := &Agent{
221 convo: mockConvo,
222 inbox: make(chan string, 10),
223 outbox: make(chan AgentMessage, 10),
224 outstandingLLMCalls: make(map[string]struct{}),
225 outstandingToolCalls: make(map[string]string),
226 }
227
228 // Create a test context
229 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
230 defer cancel()
231
232 // Push a test message to the inbox so that processUserMessage will try to process it
233 agent.inbox <- "Test message"
234
235 // Call processTurn - it should exit early without panic when initialResp is nil
236 agent.processTurn(ctx)
237
238 // Verify the error message was added to outbox
239 select {
240 case msg := <-agent.outbox:
241 if msg.Type != ErrorMessageType {
242 t.Errorf("Expected error message, got message type: %s", msg.Type)
243 }
244 if !strings.Contains(msg.Content, "simulating nil response") {
245 t.Errorf("Expected error message to contain 'simulating nil response', got: %s", msg.Content)
246 }
247 case <-time.After(time.Second):
248 t.Error("Timed out waiting for error message in outbox")
249 }
250
251 // No more messages should be in the outbox since processTurn should exit early
252 select {
253 case msg := <-agent.outbox:
254 t.Errorf("Expected no more messages in outbox, but got: %+v", msg)
255 case <-time.After(100 * time.Millisecond):
256 // This is the expected outcome - no more messages
257 }
258}
259
260// MockConvoInterface implements the ConvoInterface for testing
261type MockConvoInterface struct {
262 sendMessageFunc func(message ant.Message) (*ant.MessageResponse, error)
263 sendUserTextMessageFunc func(s string, otherContents ...ant.Content) (*ant.MessageResponse, error)
264 toolResultContentsFunc func(ctx context.Context, resp *ant.MessageResponse) ([]ant.Content, error)
265 toolResultCancelContentsFunc func(resp *ant.MessageResponse) ([]ant.Content, error)
266 cancelToolUseFunc func(toolUseID string, cause error) error
267 cumulativeUsageFunc func() ant.CumulativeUsage
268 resetBudgetFunc func(ant.Budget)
269 overBudgetFunc func() error
Philip Zeyliger2c4db092025-04-28 16:57:50 -0700270 getIDFunc func() string
271 subConvoWithHistoryFunc func() *ant.Convo
Sean McCullough9f4b8082025-04-30 17:34:07 +0000272}
273
274func (m *MockConvoInterface) SendMessage(message ant.Message) (*ant.MessageResponse, error) {
275 if m.sendMessageFunc != nil {
276 return m.sendMessageFunc(message)
277 }
278 return nil, nil
279}
280
281func (m *MockConvoInterface) SendUserTextMessage(s string, otherContents ...ant.Content) (*ant.MessageResponse, error) {
282 if m.sendUserTextMessageFunc != nil {
283 return m.sendUserTextMessageFunc(s, otherContents...)
284 }
285 return nil, nil
286}
287
288func (m *MockConvoInterface) ToolResultContents(ctx context.Context, resp *ant.MessageResponse) ([]ant.Content, error) {
289 if m.toolResultContentsFunc != nil {
290 return m.toolResultContentsFunc(ctx, resp)
291 }
292 return nil, nil
293}
294
295func (m *MockConvoInterface) ToolResultCancelContents(resp *ant.MessageResponse) ([]ant.Content, error) {
296 if m.toolResultCancelContentsFunc != nil {
297 return m.toolResultCancelContentsFunc(resp)
298 }
299 return nil, nil
300}
301
302func (m *MockConvoInterface) CancelToolUse(toolUseID string, cause error) error {
303 if m.cancelToolUseFunc != nil {
304 return m.cancelToolUseFunc(toolUseID, cause)
305 }
306 return nil
307}
308
309func (m *MockConvoInterface) CumulativeUsage() ant.CumulativeUsage {
310 if m.cumulativeUsageFunc != nil {
311 return m.cumulativeUsageFunc()
312 }
313 return ant.CumulativeUsage{}
314}
315
316func (m *MockConvoInterface) ResetBudget(budget ant.Budget) {
317 if m.resetBudgetFunc != nil {
318 m.resetBudgetFunc(budget)
319 }
320}
321
322func (m *MockConvoInterface) OverBudget() error {
323 if m.overBudgetFunc != nil {
324 return m.overBudgetFunc()
325 }
326 return nil
327}
328
Philip Zeyliger2c4db092025-04-28 16:57:50 -0700329func (m *MockConvoInterface) GetID() string {
330 if m.getIDFunc != nil {
331 return m.getIDFunc()
332 }
333 return "mock-convo-id"
334}
335
336func (m *MockConvoInterface) SubConvoWithHistory() *ant.Convo {
337 if m.subConvoWithHistoryFunc != nil {
338 return m.subConvoWithHistoryFunc()
339 }
340 return nil
341}
342
Sean McCullough9f4b8082025-04-30 17:34:07 +0000343// TestAgentProcessTurnWithNilResponseNilError tests the scenario where Agent.processTurn receives
344// a nil value for initialResp and nil error from processUserMessage.
345// This test verifies that the implementation properly handles this edge case.
346func TestAgentProcessTurnWithNilResponseNilError(t *testing.T) {
347 // Create a mock conversation that will return nil response and nil error
348 mockConvo := &MockConvoInterface{
349 sendMessageFunc: func(message ant.Message) (*ant.MessageResponse, error) {
350 return nil, nil // This is unusual but now handled gracefully
351 },
352 }
353
354 // Create a minimal Agent instance for testing
355 agent := &Agent{
356 convo: mockConvo,
357 inbox: make(chan string, 10),
358 outbox: make(chan AgentMessage, 10),
359 outstandingLLMCalls: make(map[string]struct{}),
360 outstandingToolCalls: make(map[string]string),
361 }
362
363 // Create a test context
364 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
365 defer cancel()
366
367 // Push a test message to the inbox so that processUserMessage will try to process it
368 agent.inbox <- "Test message"
369
370 // Call processTurn - it should handle nil initialResp with a descriptive error
371 err := agent.processTurn(ctx)
372
373 // Verify we get the expected error
374 if err == nil {
375 t.Error("Expected processTurn to return an error for nil initialResp, but got nil")
376 } else if !strings.Contains(err.Error(), "unexpected nil response") {
377 t.Errorf("Expected error about nil response, got: %v", err)
378 } else {
379 t.Logf("As expected, processTurn returned error: %v", err)
380 }
381
382 // Verify an error message was sent to the outbox
383 select {
384 case msg := <-agent.outbox:
385 if msg.Type != ErrorMessageType {
386 t.Errorf("Expected error message type, got: %s", msg.Type)
387 }
388 if !strings.Contains(msg.Content, "unexpected nil response") {
389 t.Errorf("Expected error about nil response, got: %s", msg.Content)
390 }
391 case <-time.After(time.Second):
392 t.Error("Timed out waiting for error message in outbox")
393 }
394}