blob: bf9d6a95f57520c68d0603bd51fb0c9a698a8b71 [file] [log] [blame]
Earl Lee2e463fb2025-04-17 11:22:22 -07001package loop
2
3import (
Josh Bleecher Snyder4d5e9972025-05-01 15:56:37 -07004 "cmp"
Earl Lee2e463fb2025-04-17 11:22:22 -07005 "context"
Sean McCullough9f4b8082025-04-30 17:34:07 +00006 "fmt"
Earl Lee2e463fb2025-04-17 11:22:22 -07007 "net/http"
8 "os"
Sean McCullough96b60dd2025-04-30 09:49:10 -07009 "slices"
Earl Lee2e463fb2025-04-17 11:22:22 -070010 "strings"
11 "testing"
12 "time"
13
Earl Lee2e463fb2025-04-17 11:22:22 -070014 "sketch.dev/httprr"
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -070015 "sketch.dev/llm"
16 "sketch.dev/llm/ant"
17 "sketch.dev/llm/conversation"
Earl Lee2e463fb2025-04-17 11:22:22 -070018)
19
20// TestAgentLoop tests that the Agent loop functionality works correctly.
21// It uses the httprr package to record HTTP interactions for replay in tests.
22// When failing, rebuild with "go test ./sketch/loop -run TestAgentLoop -httprecord .*agent_loop.*"
23// as necessary.
24func TestAgentLoop(t *testing.T) {
25 ctx := context.Background()
26
27 // Setup httprr recorder
28 rrPath := "testdata/agent_loop.httprr"
29 rr, err := httprr.Open(rrPath, http.DefaultTransport)
30 if err != nil && !os.IsNotExist(err) {
31 t.Fatal(err)
32 }
33
34 if rr.Recording() {
35 // Skip the test if API key is not available
36 if os.Getenv("ANTHROPIC_API_KEY") == "" {
37 t.Fatal("ANTHROPIC_API_KEY not set, required for HTTP recording")
38 }
39 }
40
41 // Create HTTP client
42 var client *http.Client
43 if rr != nil {
44 // Scrub API keys from requests for security
45 rr.ScrubReq(func(req *http.Request) error {
46 req.Header.Del("x-api-key")
47 req.Header.Del("anthropic-api-key")
48 return nil
49 })
50 client = rr.Client()
51 } else {
52 client = &http.Client{Transport: http.DefaultTransport}
53 }
54
55 // Create a new agent with the httprr client
56 origWD, err := os.Getwd()
57 if err != nil {
58 t.Fatal(err)
59 }
60 if err := os.Chdir("/"); err != nil {
61 t.Fatal(err)
62 }
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -070063 budget := conversation.Budget{MaxResponses: 100}
Earl Lee2e463fb2025-04-17 11:22:22 -070064 wd, err := os.Getwd()
65 if err != nil {
66 t.Fatal(err)
67 }
68
David Crawshaw3659d872025-05-05 17:52:23 -070069 apiKey := cmp.Or(os.Getenv("OUTER_SKETCH_MODEL_API_KEY"), os.Getenv("ANTHROPIC_API_KEY"))
Earl Lee2e463fb2025-04-17 11:22:22 -070070 cfg := AgentConfig{
Philip Zeyligerbc8c8dc2025-05-21 13:19:13 -070071 Context: ctx,
72 WorkingDir: wd,
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -070073 Service: &ant.Service{
74 APIKey: apiKey,
75 HTTPC: client,
76 },
Earl Lee2e463fb2025-04-17 11:22:22 -070077 Budget: budget,
78 GitUsername: "Test Agent",
79 GitEmail: "totallyhuman@sketch.dev",
80 SessionID: "test-session-id",
81 ClientGOOS: "linux",
82 ClientGOARCH: "amd64",
83 }
84 agent := NewAgent(cfg)
85 if err := os.Chdir(origWD); err != nil {
86 t.Fatal(err)
87 }
Philip Zeyligerbc8c8dc2025-05-21 13:19:13 -070088 err = agent.Init(AgentInit{NoGit: true})
Earl Lee2e463fb2025-04-17 11:22:22 -070089 if err != nil {
90 t.Fatal(err)
91 }
92
93 // Setup a test message that will trigger a simple, predictable response
Josh Bleecher Snyderd2f54c22025-05-07 18:38:07 -070094 userMessage := "What tools are available to you? Please just list them briefly. (Do not call the title tool.)"
Earl Lee2e463fb2025-04-17 11:22:22 -070095
96 // Send the message to the agent
97 agent.UserMessage(ctx, userMessage)
98
99 // Process a single loop iteration to avoid long-running tests
Sean McCullough885a16a2025-04-30 02:49:25 +0000100 agent.processTurn(ctx)
Earl Lee2e463fb2025-04-17 11:22:22 -0700101
102 // Collect responses with a timeout
103 var responses []AgentMessage
Philip Zeyliger9373c072025-05-01 10:27:01 -0700104 ctx2, cancel := context.WithDeadline(ctx, time.Now().Add(10*time.Second))
105 defer cancel()
Earl Lee2e463fb2025-04-17 11:22:22 -0700106 done := false
Philip Zeyligerb7c58752025-05-01 10:10:17 -0700107 it := agent.NewIterator(ctx2, 0)
Earl Lee2e463fb2025-04-17 11:22:22 -0700108
109 for !done {
Philip Zeyligerb7c58752025-05-01 10:10:17 -0700110 msg := it.Next()
111 t.Logf("Received message: Type=%s, EndOfTurn=%v, Content=%q", msg.Type, msg.EndOfTurn, msg.Content)
112 responses = append(responses, *msg)
113 if msg.EndOfTurn {
Earl Lee2e463fb2025-04-17 11:22:22 -0700114 done = true
Earl Lee2e463fb2025-04-17 11:22:22 -0700115 }
116 }
117
118 // Verify we got at least one response
119 if len(responses) == 0 {
120 t.Fatal("No responses received from agent")
121 }
122
123 // Log the received responses for debugging
124 t.Logf("Received %d responses", len(responses))
125
126 // Find the final agent response (with EndOfTurn=true)
127 var finalResponse *AgentMessage
128 for i := range responses {
129 if responses[i].Type == AgentMessageType && responses[i].EndOfTurn {
130 finalResponse = &responses[i]
131 break
132 }
133 }
134
135 // Verify we got a final agent response
136 if finalResponse == nil {
137 t.Fatal("No final agent response received")
138 }
139
140 // Check that the response contains tools information
141 if !strings.Contains(strings.ToLower(finalResponse.Content), "tool") {
142 t.Error("Expected response to mention tools")
143 }
144
145 // Count how many tool use messages we received
146 toolUseCount := 0
147 for _, msg := range responses {
148 if msg.Type == ToolUseMessageType {
149 toolUseCount++
150 }
151 }
152
153 t.Logf("Agent used %d tools in its response", toolUseCount)
154}
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000155
156func TestAgentTracksOutstandingCalls(t *testing.T) {
157 agent := &Agent{
158 outstandingLLMCalls: make(map[string]struct{}),
159 outstandingToolCalls: make(map[string]string),
Sean McCullough96b60dd2025-04-30 09:49:10 -0700160 stateMachine: NewStateMachine(),
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000161 }
162
163 // Check initial state
164 if count := agent.OutstandingLLMCallCount(); count != 0 {
165 t.Errorf("Expected 0 outstanding LLM calls, got %d", count)
166 }
167
168 if tools := agent.OutstandingToolCalls(); len(tools) != 0 {
169 t.Errorf("Expected 0 outstanding tool calls, got %d", len(tools))
170 }
171
172 // Add some calls
173 agent.mu.Lock()
174 agent.outstandingLLMCalls["llm1"] = struct{}{}
175 agent.outstandingToolCalls["tool1"] = "bash"
176 agent.outstandingToolCalls["tool2"] = "think"
177 agent.mu.Unlock()
178
179 // Check tracking works
180 if count := agent.OutstandingLLMCallCount(); count != 1 {
181 t.Errorf("Expected 1 outstanding LLM call, got %d", count)
182 }
183
184 tools := agent.OutstandingToolCalls()
185 if len(tools) != 2 {
186 t.Errorf("Expected 2 outstanding tool calls, got %d", len(tools))
187 }
188
189 // Check removal
190 agent.mu.Lock()
191 delete(agent.outstandingLLMCalls, "llm1")
192 delete(agent.outstandingToolCalls, "tool1")
193 agent.mu.Unlock()
194
195 if count := agent.OutstandingLLMCallCount(); count != 0 {
196 t.Errorf("Expected 0 outstanding LLM calls after removal, got %d", count)
197 }
198
199 tools = agent.OutstandingToolCalls()
200 if len(tools) != 1 {
201 t.Errorf("Expected 1 outstanding tool call after removal, got %d", len(tools))
202 }
203
204 if tools[0] != "think" {
205 t.Errorf("Expected 'think' tool remaining, got %s", tools[0])
206 }
207}
Sean McCullough9f4b8082025-04-30 17:34:07 +0000208
209// TestAgentProcessTurnWithNilResponse tests the scenario where Agent.processTurn receives
210// a nil value for initialResp from processUserMessage.
211func TestAgentProcessTurnWithNilResponse(t *testing.T) {
212 // Create a mock conversation that will return nil and error
213 mockConvo := &MockConvoInterface{
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700214 sendMessageFunc: func(message llm.Message) (*llm.Response, error) {
Sean McCullough9f4b8082025-04-30 17:34:07 +0000215 return nil, fmt.Errorf("test error: simulating nil response")
216 },
217 }
218
219 // Create a minimal Agent instance for testing
220 agent := &Agent{
221 convo: mockConvo,
222 inbox: make(chan string, 10),
Philip Zeyliger9373c072025-05-01 10:27:01 -0700223 subscribers: []chan *AgentMessage{},
Sean McCullough9f4b8082025-04-30 17:34:07 +0000224 outstandingLLMCalls: make(map[string]struct{}),
225 outstandingToolCalls: make(map[string]string),
226 }
227
228 // Create a test context
229 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
230 defer cancel()
231
232 // Push a test message to the inbox so that processUserMessage will try to process it
233 agent.inbox <- "Test message"
234
235 // Call processTurn - it should exit early without panic when initialResp is nil
236 agent.processTurn(ctx)
237
Philip Zeyliger9373c072025-05-01 10:27:01 -0700238 // Verify error message was added to history
239 agent.mu.Lock()
240 defer agent.mu.Unlock()
241
242 // There should be exactly one message
243 if len(agent.history) != 1 {
244 t.Errorf("Expected exactly one message, got %d", len(agent.history))
245 } else {
246 msg := agent.history[0]
Sean McCullough9f4b8082025-04-30 17:34:07 +0000247 if msg.Type != ErrorMessageType {
248 t.Errorf("Expected error message, got message type: %s", msg.Type)
249 }
250 if !strings.Contains(msg.Content, "simulating nil response") {
251 t.Errorf("Expected error message to contain 'simulating nil response', got: %s", msg.Content)
252 }
Sean McCullough9f4b8082025-04-30 17:34:07 +0000253 }
254}
255
256// MockConvoInterface implements the ConvoInterface for testing
257type MockConvoInterface struct {
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700258 sendMessageFunc func(message llm.Message) (*llm.Response, error)
259 sendUserTextMessageFunc func(s string, otherContents ...llm.Content) (*llm.Response, error)
Josh Bleecher Snyder64f2aa82025-05-14 18:31:05 +0000260 toolResultContentsFunc func(ctx context.Context, resp *llm.Response) ([]llm.Content, bool, error)
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700261 toolResultCancelContentsFunc func(resp *llm.Response) ([]llm.Content, error)
Sean McCullough9f4b8082025-04-30 17:34:07 +0000262 cancelToolUseFunc func(toolUseID string, cause error) error
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700263 cumulativeUsageFunc func() conversation.CumulativeUsage
264 resetBudgetFunc func(conversation.Budget)
Sean McCullough9f4b8082025-04-30 17:34:07 +0000265 overBudgetFunc func() error
Philip Zeyliger2c4db092025-04-28 16:57:50 -0700266 getIDFunc func() string
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700267 subConvoWithHistoryFunc func() *conversation.Convo
Sean McCullough9f4b8082025-04-30 17:34:07 +0000268}
269
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700270func (m *MockConvoInterface) SendMessage(message llm.Message) (*llm.Response, error) {
Sean McCullough9f4b8082025-04-30 17:34:07 +0000271 if m.sendMessageFunc != nil {
272 return m.sendMessageFunc(message)
273 }
274 return nil, nil
275}
276
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700277func (m *MockConvoInterface) SendUserTextMessage(s string, otherContents ...llm.Content) (*llm.Response, error) {
Sean McCullough9f4b8082025-04-30 17:34:07 +0000278 if m.sendUserTextMessageFunc != nil {
279 return m.sendUserTextMessageFunc(s, otherContents...)
280 }
281 return nil, nil
282}
283
Josh Bleecher Snyder64f2aa82025-05-14 18:31:05 +0000284func (m *MockConvoInterface) ToolResultContents(ctx context.Context, resp *llm.Response) ([]llm.Content, bool, error) {
Sean McCullough9f4b8082025-04-30 17:34:07 +0000285 if m.toolResultContentsFunc != nil {
286 return m.toolResultContentsFunc(ctx, resp)
287 }
Josh Bleecher Snyder64f2aa82025-05-14 18:31:05 +0000288 return nil, false, nil
Sean McCullough9f4b8082025-04-30 17:34:07 +0000289}
290
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700291func (m *MockConvoInterface) ToolResultCancelContents(resp *llm.Response) ([]llm.Content, error) {
Sean McCullough9f4b8082025-04-30 17:34:07 +0000292 if m.toolResultCancelContentsFunc != nil {
293 return m.toolResultCancelContentsFunc(resp)
294 }
295 return nil, nil
296}
297
298func (m *MockConvoInterface) CancelToolUse(toolUseID string, cause error) error {
299 if m.cancelToolUseFunc != nil {
300 return m.cancelToolUseFunc(toolUseID, cause)
301 }
302 return nil
303}
304
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700305func (m *MockConvoInterface) CumulativeUsage() conversation.CumulativeUsage {
Sean McCullough9f4b8082025-04-30 17:34:07 +0000306 if m.cumulativeUsageFunc != nil {
307 return m.cumulativeUsageFunc()
308 }
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700309 return conversation.CumulativeUsage{}
Sean McCullough9f4b8082025-04-30 17:34:07 +0000310}
311
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700312func (m *MockConvoInterface) ResetBudget(budget conversation.Budget) {
Sean McCullough9f4b8082025-04-30 17:34:07 +0000313 if m.resetBudgetFunc != nil {
314 m.resetBudgetFunc(budget)
315 }
316}
317
318func (m *MockConvoInterface) OverBudget() error {
319 if m.overBudgetFunc != nil {
320 return m.overBudgetFunc()
321 }
322 return nil
323}
324
Philip Zeyliger2c4db092025-04-28 16:57:50 -0700325func (m *MockConvoInterface) GetID() string {
326 if m.getIDFunc != nil {
327 return m.getIDFunc()
328 }
329 return "mock-convo-id"
330}
331
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700332func (m *MockConvoInterface) SubConvoWithHistory() *conversation.Convo {
Philip Zeyliger2c4db092025-04-28 16:57:50 -0700333 if m.subConvoWithHistoryFunc != nil {
334 return m.subConvoWithHistoryFunc()
335 }
336 return nil
337}
338
Sean McCullough9f4b8082025-04-30 17:34:07 +0000339// TestAgentProcessTurnWithNilResponseNilError tests the scenario where Agent.processTurn receives
340// a nil value for initialResp and nil error from processUserMessage.
341// This test verifies that the implementation properly handles this edge case.
342func TestAgentProcessTurnWithNilResponseNilError(t *testing.T) {
343 // Create a mock conversation that will return nil response and nil error
344 mockConvo := &MockConvoInterface{
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700345 sendMessageFunc: func(message llm.Message) (*llm.Response, error) {
Sean McCullough9f4b8082025-04-30 17:34:07 +0000346 return nil, nil // This is unusual but now handled gracefully
347 },
348 }
349
350 // Create a minimal Agent instance for testing
351 agent := &Agent{
352 convo: mockConvo,
353 inbox: make(chan string, 10),
Philip Zeyliger9373c072025-05-01 10:27:01 -0700354 subscribers: []chan *AgentMessage{},
Sean McCullough9f4b8082025-04-30 17:34:07 +0000355 outstandingLLMCalls: make(map[string]struct{}),
356 outstandingToolCalls: make(map[string]string),
357 }
358
359 // Create a test context
360 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
361 defer cancel()
362
363 // Push a test message to the inbox so that processUserMessage will try to process it
364 agent.inbox <- "Test message"
365
366 // Call processTurn - it should handle nil initialResp with a descriptive error
367 err := agent.processTurn(ctx)
368
369 // Verify we get the expected error
370 if err == nil {
371 t.Error("Expected processTurn to return an error for nil initialResp, but got nil")
372 } else if !strings.Contains(err.Error(), "unexpected nil response") {
373 t.Errorf("Expected error about nil response, got: %v", err)
374 } else {
375 t.Logf("As expected, processTurn returned error: %v", err)
376 }
377
Philip Zeyliger9373c072025-05-01 10:27:01 -0700378 // Verify error message was added to history
379 agent.mu.Lock()
380 defer agent.mu.Unlock()
381
382 // There should be exactly one message
383 if len(agent.history) != 1 {
384 t.Errorf("Expected exactly one message, got %d", len(agent.history))
385 } else {
386 msg := agent.history[0]
Sean McCullough9f4b8082025-04-30 17:34:07 +0000387 if msg.Type != ErrorMessageType {
388 t.Errorf("Expected error message type, got: %s", msg.Type)
389 }
390 if !strings.Contains(msg.Content, "unexpected nil response") {
391 t.Errorf("Expected error about nil response, got: %s", msg.Content)
392 }
Sean McCullough9f4b8082025-04-30 17:34:07 +0000393 }
394}
Sean McCullough96b60dd2025-04-30 09:49:10 -0700395
396func TestAgentStateMachine(t *testing.T) {
397 // Create a simplified test for the state machine functionality
398 agent := &Agent{
399 stateMachine: NewStateMachine(),
400 }
401
402 // Initially the state should be Ready
403 if state := agent.CurrentState(); state != StateReady {
404 t.Errorf("Expected initial state to be StateReady, got %s", state)
405 }
406
407 // Test manual transitions to verify state tracking
408 ctx := context.Background()
409
410 // Track transitions
411 var transitions []State
412 agent.stateMachine.SetTransitionCallback(func(ctx context.Context, from, to State, event TransitionEvent) {
413 transitions = append(transitions, to)
414 t.Logf("State transition: %s -> %s (%s)", from, to, event.Description)
415 })
416
417 // Perform a valid sequence of transitions (based on the state machine rules)
418 expectedStates := []State{
419 StateWaitingForUserInput,
420 StateSendingToLLM,
421 StateProcessingLLMResponse,
422 StateToolUseRequested,
423 StateCheckingForCancellation,
424 StateRunningTool,
425 StateCheckingGitCommits,
426 StateRunningAutoformatters,
427 StateCheckingBudget,
428 StateGatheringAdditionalMessages,
429 StateSendingToolResults,
430 StateProcessingLLMResponse,
431 StateEndOfTurn,
432 }
433
434 // Manually perform each transition
435 for _, state := range expectedStates {
436 err := agent.stateMachine.Transition(ctx, state, "Test transition to "+state.String())
437 if err != nil {
438 t.Errorf("Failed to transition to %s: %v", state, err)
439 }
440 }
441
442 // Check if we recorded the right number of transitions
443 if len(transitions) != len(expectedStates) {
444 t.Errorf("Expected %d state transitions, got %d", len(expectedStates), len(transitions))
445 }
446
447 // Check each transition matched what we expected
448 for i, expected := range expectedStates {
449 if i < len(transitions) {
450 if transitions[i] != expected {
451 t.Errorf("Transition %d: expected %s, got %s", i, expected, transitions[i])
452 }
453 }
454 }
455
456 // Verify the current state is the last one we transitioned to
457 if state := agent.CurrentState(); state != expectedStates[len(expectedStates)-1] {
458 t.Errorf("Expected current state to be %s, got %s", expectedStates[len(expectedStates)-1], state)
459 }
460
461 // Test force transition
462 agent.stateMachine.ForceTransition(ctx, StateCancelled, "Testing force transition")
463
464 // Verify current state was updated
465 if state := agent.CurrentState(); state != StateCancelled {
466 t.Errorf("Expected forced state to be StateCancelled, got %s", state)
467 }
468}
469
470// mockConvoInterface is a mock implementation of ConvoInterface for testing
471type mockConvoInterface struct {
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700472 SendMessageFunc func(message llm.Message) (*llm.Response, error)
Josh Bleecher Snyder64f2aa82025-05-14 18:31:05 +0000473 ToolResultContentsFunc func(ctx context.Context, resp *llm.Response) ([]llm.Content, bool, error)
Sean McCullough96b60dd2025-04-30 09:49:10 -0700474}
475
476func (c *mockConvoInterface) GetID() string {
477 return "mockConvoInterface-id"
478}
479
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700480func (c *mockConvoInterface) SubConvoWithHistory() *conversation.Convo {
Sean McCullough96b60dd2025-04-30 09:49:10 -0700481 return nil
482}
483
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700484func (m *mockConvoInterface) CumulativeUsage() conversation.CumulativeUsage {
485 return conversation.CumulativeUsage{}
Sean McCullough96b60dd2025-04-30 09:49:10 -0700486}
487
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700488func (m *mockConvoInterface) ResetBudget(conversation.Budget) {}
Sean McCullough96b60dd2025-04-30 09:49:10 -0700489
490func (m *mockConvoInterface) OverBudget() error {
491 return nil
492}
493
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700494func (m *mockConvoInterface) SendMessage(message llm.Message) (*llm.Response, error) {
Sean McCullough96b60dd2025-04-30 09:49:10 -0700495 if m.SendMessageFunc != nil {
496 return m.SendMessageFunc(message)
497 }
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700498 return &llm.Response{StopReason: llm.StopReasonEndTurn}, nil
Sean McCullough96b60dd2025-04-30 09:49:10 -0700499}
500
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700501func (m *mockConvoInterface) SendUserTextMessage(s string, otherContents ...llm.Content) (*llm.Response, error) {
502 return m.SendMessage(llm.UserStringMessage(s))
Sean McCullough96b60dd2025-04-30 09:49:10 -0700503}
504
Josh Bleecher Snyder64f2aa82025-05-14 18:31:05 +0000505func (m *mockConvoInterface) ToolResultContents(ctx context.Context, resp *llm.Response) ([]llm.Content, bool, error) {
Sean McCullough96b60dd2025-04-30 09:49:10 -0700506 if m.ToolResultContentsFunc != nil {
507 return m.ToolResultContentsFunc(ctx, resp)
508 }
Josh Bleecher Snyder64f2aa82025-05-14 18:31:05 +0000509 return []llm.Content{}, false, nil
Sean McCullough96b60dd2025-04-30 09:49:10 -0700510}
511
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700512func (m *mockConvoInterface) ToolResultCancelContents(resp *llm.Response) ([]llm.Content, error) {
513 return []llm.Content{llm.StringContent("Tool use cancelled")}, nil
Sean McCullough96b60dd2025-04-30 09:49:10 -0700514}
515
516func (m *mockConvoInterface) CancelToolUse(toolUseID string, cause error) error {
517 return nil
518}
519
520func TestAgentProcessTurnStateTransitions(t *testing.T) {
521 // Create a mock ConvoInterface for testing
522 mockConvo := &mockConvoInterface{}
523
524 // Use the testing context
525 ctx := t.Context()
526
527 // Create an agent with the state machine
528 agent := &Agent{
529 convo: mockConvo,
530 config: AgentConfig{Context: ctx},
531 inbox: make(chan string, 10),
Sean McCullough96b60dd2025-04-30 09:49:10 -0700532 ready: make(chan struct{}),
533 seenCommits: make(map[string]bool),
534 outstandingLLMCalls: make(map[string]struct{}),
535 outstandingToolCalls: make(map[string]string),
536 stateMachine: NewStateMachine(),
537 startOfTurn: time.Now(),
Philip Zeyliger9373c072025-05-01 10:27:01 -0700538 subscribers: []chan *AgentMessage{},
Sean McCullough96b60dd2025-04-30 09:49:10 -0700539 }
540
541 // Verify initial state
542 if state := agent.CurrentState(); state != StateReady {
543 t.Errorf("Expected initial state to be StateReady, got %s", state)
544 }
545
546 // Add a message to the inbox so we don't block in GatherMessages
547 agent.inbox <- "Test message"
548
549 // Setup the mock to simulate a model response with end of turn
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700550 mockConvo.SendMessageFunc = func(message llm.Message) (*llm.Response, error) {
551 return &llm.Response{
552 StopReason: llm.StopReasonEndTurn,
553 Content: []llm.Content{
554 llm.StringContent("This is a test response"),
Sean McCullough96b60dd2025-04-30 09:49:10 -0700555 },
556 }, nil
557 }
558
559 // Track state transitions
560 var transitions []State
561 agent.stateMachine.SetTransitionCallback(func(ctx context.Context, from, to State, event TransitionEvent) {
562 transitions = append(transitions, to)
563 t.Logf("State transition: %s -> %s (%s)", from, to, event.Description)
564 })
565
566 // Process a turn, which should trigger state transitions
567 agent.processTurn(ctx)
568
569 // The minimum expected states for a simple end-of-turn response
570 minExpectedStates := []State{
571 StateWaitingForUserInput,
572 StateSendingToLLM,
573 StateProcessingLLMResponse,
574 StateEndOfTurn,
575 }
576
577 // Verify we have at least the minimum expected states
578 if len(transitions) < len(minExpectedStates) {
579 t.Errorf("Expected at least %d state transitions, got %d", len(minExpectedStates), len(transitions))
580 }
581
582 // Check that the transitions follow the expected sequence
583 for i, expected := range minExpectedStates {
584 if i < len(transitions) {
585 if transitions[i] != expected {
586 t.Errorf("Transition %d: expected %s, got %s", i, expected, transitions[i])
587 }
588 }
589 }
590
591 // Verify the final state is EndOfTurn
592 if state := agent.CurrentState(); state != StateEndOfTurn {
593 t.Errorf("Expected final state to be StateEndOfTurn, got %s", state)
594 }
595}
596
597func TestAgentProcessTurnWithToolUse(t *testing.T) {
598 // Create a mock ConvoInterface for testing
599 mockConvo := &mockConvoInterface{}
600
601 // Setup a test context
602 ctx := context.Background()
603
604 // Create an agent with the state machine
605 agent := &Agent{
606 convo: mockConvo,
607 config: AgentConfig{Context: ctx},
608 inbox: make(chan string, 10),
Sean McCullough96b60dd2025-04-30 09:49:10 -0700609 ready: make(chan struct{}),
610 seenCommits: make(map[string]bool),
611 outstandingLLMCalls: make(map[string]struct{}),
612 outstandingToolCalls: make(map[string]string),
613 stateMachine: NewStateMachine(),
614 startOfTurn: time.Now(),
Philip Zeyliger9373c072025-05-01 10:27:01 -0700615 subscribers: []chan *AgentMessage{},
Sean McCullough96b60dd2025-04-30 09:49:10 -0700616 }
617
618 // Add a message to the inbox so we don't block in GatherMessages
619 agent.inbox <- "Test message"
620
621 // First response requests a tool
622 firstResponseDone := false
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700623 mockConvo.SendMessageFunc = func(message llm.Message) (*llm.Response, error) {
Sean McCullough96b60dd2025-04-30 09:49:10 -0700624 if !firstResponseDone {
625 firstResponseDone = true
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700626 return &llm.Response{
627 StopReason: llm.StopReasonToolUse,
628 Content: []llm.Content{
629 llm.StringContent("I'll use a tool"),
630 {Type: llm.ContentTypeToolUse, ToolName: "test_tool", ToolInput: []byte("{}"), ID: "test_id"},
Sean McCullough96b60dd2025-04-30 09:49:10 -0700631 },
632 }, nil
633 }
634 // Second response ends the turn
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700635 return &llm.Response{
636 StopReason: llm.StopReasonEndTurn,
637 Content: []llm.Content{
638 llm.StringContent("Finished using the tool"),
Sean McCullough96b60dd2025-04-30 09:49:10 -0700639 },
640 }, nil
641 }
642
643 // Tool result content handler
Josh Bleecher Snyder64f2aa82025-05-14 18:31:05 +0000644 mockConvo.ToolResultContentsFunc = func(ctx context.Context, resp *llm.Response) ([]llm.Content, bool, error) {
645 return []llm.Content{llm.StringContent("Tool executed successfully")}, false, nil
Sean McCullough96b60dd2025-04-30 09:49:10 -0700646 }
647
648 // Track state transitions
649 var transitions []State
650 agent.stateMachine.SetTransitionCallback(func(ctx context.Context, from, to State, event TransitionEvent) {
651 transitions = append(transitions, to)
652 t.Logf("State transition: %s -> %s (%s)", from, to, event.Description)
653 })
654
655 // Process a turn with tool use
656 agent.processTurn(ctx)
657
658 // Define expected states for a tool use flow
659 expectedToolStates := []State{
660 StateWaitingForUserInput,
661 StateSendingToLLM,
662 StateProcessingLLMResponse,
663 StateToolUseRequested,
664 StateCheckingForCancellation,
665 StateRunningTool,
666 }
667
668 // Verify that these states are present in order
669 for i, expectedState := range expectedToolStates {
670 if i >= len(transitions) {
671 t.Errorf("Missing expected transition to %s; only got %d transitions", expectedState, len(transitions))
672 continue
673 }
674 if transitions[i] != expectedState {
675 t.Errorf("Expected transition %d to be %s, got %s", i, expectedState, transitions[i])
676 }
677 }
678
679 // Also verify we eventually reached EndOfTurn
680 if !slices.Contains(transitions, StateEndOfTurn) {
681 t.Errorf("Expected to eventually reach StateEndOfTurn, but never did")
682 }
683}
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700684
685func TestContentToString(t *testing.T) {
686 tests := []struct {
687 name string
688 contents []llm.Content
689 want string
690 }{
691 {
692 name: "empty",
693 contents: []llm.Content{},
694 want: "",
695 },
696 {
697 name: "single text content",
698 contents: []llm.Content{
699 {Type: llm.ContentTypeText, Text: "hello world"},
700 },
701 want: "hello world",
702 },
703 {
704 name: "multiple text content",
705 contents: []llm.Content{
706 {Type: llm.ContentTypeText, Text: "hello "},
707 {Type: llm.ContentTypeText, Text: "world"},
708 },
709 want: "hello world",
710 },
711 {
712 name: "mixed content types",
713 contents: []llm.Content{
714 {Type: llm.ContentTypeText, Text: "hello "},
715 {Type: llm.ContentTypeText, MediaType: "image/png", Data: "base64data"},
716 {Type: llm.ContentTypeText, Text: "world"},
717 },
718 want: "hello world",
719 },
720 {
721 name: "non-text content only",
722 contents: []llm.Content{
723 {Type: llm.ContentTypeToolUse, ToolName: "example"},
724 },
725 want: "",
726 },
727 {
728 name: "nested tool result",
729 contents: []llm.Content{
730 {Type: llm.ContentTypeText, Text: "outer "},
731 {Type: llm.ContentTypeToolResult, ToolResult: []llm.Content{
732 {Type: llm.ContentTypeText, Text: "inner"},
733 }},
734 },
735 want: "outer inner",
736 },
737 {
738 name: "deeply nested tool result",
739 contents: []llm.Content{
740 {Type: llm.ContentTypeToolResult, ToolResult: []llm.Content{
741 {Type: llm.ContentTypeToolResult, ToolResult: []llm.Content{
742 {Type: llm.ContentTypeText, Text: "deeply nested"},
743 }},
744 }},
745 },
746 want: "deeply nested",
747 },
748 }
749
750 for _, tt := range tests {
751 t.Run(tt.name, func(t *testing.T) {
752 if got := contentToString(tt.contents); got != tt.want {
753 t.Errorf("contentToString() = %v, want %v", got, tt.want)
754 }
755 })
756 }
757}
758
759func TestPushToOutbox(t *testing.T) {
760 // Create a new agent
761 a := &Agent{
762 outstandingLLMCalls: make(map[string]struct{}),
763 outstandingToolCalls: make(map[string]string),
764 stateMachine: NewStateMachine(),
765 subscribers: make([]chan *AgentMessage, 0),
766 }
767
768 // Create a channel to receive messages
769 messageCh := make(chan *AgentMessage, 1)
770
771 // Add the channel to the subscribers list
772 a.mu.Lock()
773 a.subscribers = append(a.subscribers, messageCh)
774 a.mu.Unlock()
775
776 // We need to set the text that would be produced by our modified contentToString function
777 resultText := "test resultnested result" // Directly set the expected output
778
779 // In a real-world scenario, this would be coming from a toolResult that contained nested content
780
781 m := AgentMessage{
782 Type: ToolUseMessageType,
783 ToolResult: resultText,
784 }
785
786 // Push the message to the outbox
787 a.pushToOutbox(context.Background(), m)
788
789 // Receive the message from the subscriber
790 received := <-messageCh
791
792 // Check that the Content field contains the concatenated text from ToolResult
793 expected := "test resultnested result"
794 if received.Content != expected {
795 t.Errorf("Expected Content to be %q, got %q", expected, received.Content)
796 }
797}