blob: 0924b39df5feca8f804f02574f0d71c68d81118e [file] [log] [blame]
Earl Lee2e463fb2025-04-17 11:22:22 -07001package loop
2
3import (
Josh Bleecher Snyder4d5e9972025-05-01 15:56:37 -07004 "cmp"
Earl Lee2e463fb2025-04-17 11:22:22 -07005 "context"
Sean McCullough9f4b8082025-04-30 17:34:07 +00006 "fmt"
Earl Lee2e463fb2025-04-17 11:22:22 -07007 "net/http"
8 "os"
Sean McCullough96b60dd2025-04-30 09:49:10 -07009 "slices"
Earl Lee2e463fb2025-04-17 11:22:22 -070010 "strings"
11 "testing"
12 "time"
13
14 "sketch.dev/ant"
15 "sketch.dev/httprr"
16)
17
18// TestAgentLoop tests that the Agent loop functionality works correctly.
19// It uses the httprr package to record HTTP interactions for replay in tests.
20// When failing, rebuild with "go test ./sketch/loop -run TestAgentLoop -httprecord .*agent_loop.*"
21// as necessary.
22func TestAgentLoop(t *testing.T) {
23 ctx := context.Background()
24
25 // Setup httprr recorder
26 rrPath := "testdata/agent_loop.httprr"
27 rr, err := httprr.Open(rrPath, http.DefaultTransport)
28 if err != nil && !os.IsNotExist(err) {
29 t.Fatal(err)
30 }
31
32 if rr.Recording() {
33 // Skip the test if API key is not available
34 if os.Getenv("ANTHROPIC_API_KEY") == "" {
35 t.Fatal("ANTHROPIC_API_KEY not set, required for HTTP recording")
36 }
37 }
38
39 // Create HTTP client
40 var client *http.Client
41 if rr != nil {
42 // Scrub API keys from requests for security
43 rr.ScrubReq(func(req *http.Request) error {
44 req.Header.Del("x-api-key")
45 req.Header.Del("anthropic-api-key")
46 return nil
47 })
48 client = rr.Client()
49 } else {
50 client = &http.Client{Transport: http.DefaultTransport}
51 }
52
53 // Create a new agent with the httprr client
54 origWD, err := os.Getwd()
55 if err != nil {
56 t.Fatal(err)
57 }
58 if err := os.Chdir("/"); err != nil {
59 t.Fatal(err)
60 }
61 budget := ant.Budget{MaxResponses: 100}
62 wd, err := os.Getwd()
63 if err != nil {
64 t.Fatal(err)
65 }
66
Josh Bleecher Snyder4d5e9972025-05-01 15:56:37 -070067 apiKey := cmp.Or(os.Getenv("OUTER_SKETCH_ANTHROPIC_API_KEY"), os.Getenv("ANTHROPIC_API_KEY"))
Earl Lee2e463fb2025-04-17 11:22:22 -070068 cfg := AgentConfig{
69 Context: ctx,
Josh Bleecher Snyder4d5e9972025-05-01 15:56:37 -070070 APIKey: apiKey,
Earl Lee2e463fb2025-04-17 11:22:22 -070071 HTTPC: client,
72 Budget: budget,
73 GitUsername: "Test Agent",
74 GitEmail: "totallyhuman@sketch.dev",
75 SessionID: "test-session-id",
76 ClientGOOS: "linux",
77 ClientGOARCH: "amd64",
78 }
79 agent := NewAgent(cfg)
80 if err := os.Chdir(origWD); err != nil {
81 t.Fatal(err)
82 }
83 err = agent.Init(AgentInit{WorkingDir: wd, NoGit: true})
84 if err != nil {
85 t.Fatal(err)
86 }
87
88 // Setup a test message that will trigger a simple, predictable response
89 userMessage := "What tools are available to you? Please just list them briefly."
90
91 // Send the message to the agent
92 agent.UserMessage(ctx, userMessage)
93
94 // Process a single loop iteration to avoid long-running tests
Sean McCullough885a16a2025-04-30 02:49:25 +000095 agent.processTurn(ctx)
Earl Lee2e463fb2025-04-17 11:22:22 -070096
97 // Collect responses with a timeout
98 var responses []AgentMessage
Philip Zeyliger9373c072025-05-01 10:27:01 -070099 ctx2, cancel := context.WithDeadline(ctx, time.Now().Add(10*time.Second))
100 defer cancel()
Earl Lee2e463fb2025-04-17 11:22:22 -0700101 done := false
Philip Zeyligerb7c58752025-05-01 10:10:17 -0700102 it := agent.NewIterator(ctx2, 0)
Earl Lee2e463fb2025-04-17 11:22:22 -0700103
104 for !done {
Philip Zeyligerb7c58752025-05-01 10:10:17 -0700105 msg := it.Next()
106 t.Logf("Received message: Type=%s, EndOfTurn=%v, Content=%q", msg.Type, msg.EndOfTurn, msg.Content)
107 responses = append(responses, *msg)
108 if msg.EndOfTurn {
Earl Lee2e463fb2025-04-17 11:22:22 -0700109 done = true
Earl Lee2e463fb2025-04-17 11:22:22 -0700110 }
111 }
112
113 // Verify we got at least one response
114 if len(responses) == 0 {
115 t.Fatal("No responses received from agent")
116 }
117
118 // Log the received responses for debugging
119 t.Logf("Received %d responses", len(responses))
120
121 // Find the final agent response (with EndOfTurn=true)
122 var finalResponse *AgentMessage
123 for i := range responses {
124 if responses[i].Type == AgentMessageType && responses[i].EndOfTurn {
125 finalResponse = &responses[i]
126 break
127 }
128 }
129
130 // Verify we got a final agent response
131 if finalResponse == nil {
132 t.Fatal("No final agent response received")
133 }
134
135 // Check that the response contains tools information
136 if !strings.Contains(strings.ToLower(finalResponse.Content), "tool") {
137 t.Error("Expected response to mention tools")
138 }
139
140 // Count how many tool use messages we received
141 toolUseCount := 0
142 for _, msg := range responses {
143 if msg.Type == ToolUseMessageType {
144 toolUseCount++
145 }
146 }
147
148 t.Logf("Agent used %d tools in its response", toolUseCount)
149}
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000150
151func TestAgentTracksOutstandingCalls(t *testing.T) {
152 agent := &Agent{
153 outstandingLLMCalls: make(map[string]struct{}),
154 outstandingToolCalls: make(map[string]string),
Sean McCullough96b60dd2025-04-30 09:49:10 -0700155 stateMachine: NewStateMachine(),
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000156 }
157
158 // Check initial state
159 if count := agent.OutstandingLLMCallCount(); count != 0 {
160 t.Errorf("Expected 0 outstanding LLM calls, got %d", count)
161 }
162
163 if tools := agent.OutstandingToolCalls(); len(tools) != 0 {
164 t.Errorf("Expected 0 outstanding tool calls, got %d", len(tools))
165 }
166
167 // Add some calls
168 agent.mu.Lock()
169 agent.outstandingLLMCalls["llm1"] = struct{}{}
170 agent.outstandingToolCalls["tool1"] = "bash"
171 agent.outstandingToolCalls["tool2"] = "think"
172 agent.mu.Unlock()
173
174 // Check tracking works
175 if count := agent.OutstandingLLMCallCount(); count != 1 {
176 t.Errorf("Expected 1 outstanding LLM call, got %d", count)
177 }
178
179 tools := agent.OutstandingToolCalls()
180 if len(tools) != 2 {
181 t.Errorf("Expected 2 outstanding tool calls, got %d", len(tools))
182 }
183
184 // Check removal
185 agent.mu.Lock()
186 delete(agent.outstandingLLMCalls, "llm1")
187 delete(agent.outstandingToolCalls, "tool1")
188 agent.mu.Unlock()
189
190 if count := agent.OutstandingLLMCallCount(); count != 0 {
191 t.Errorf("Expected 0 outstanding LLM calls after removal, got %d", count)
192 }
193
194 tools = agent.OutstandingToolCalls()
195 if len(tools) != 1 {
196 t.Errorf("Expected 1 outstanding tool call after removal, got %d", len(tools))
197 }
198
199 if tools[0] != "think" {
200 t.Errorf("Expected 'think' tool remaining, got %s", tools[0])
201 }
202}
Sean McCullough9f4b8082025-04-30 17:34:07 +0000203
204// TestAgentProcessTurnWithNilResponse tests the scenario where Agent.processTurn receives
205// a nil value for initialResp from processUserMessage.
206func TestAgentProcessTurnWithNilResponse(t *testing.T) {
207 // Create a mock conversation that will return nil and error
208 mockConvo := &MockConvoInterface{
209 sendMessageFunc: func(message ant.Message) (*ant.MessageResponse, error) {
210 return nil, fmt.Errorf("test error: simulating nil response")
211 },
212 }
213
214 // Create a minimal Agent instance for testing
215 agent := &Agent{
216 convo: mockConvo,
217 inbox: make(chan string, 10),
Philip Zeyliger9373c072025-05-01 10:27:01 -0700218 subscribers: []chan *AgentMessage{},
Sean McCullough9f4b8082025-04-30 17:34:07 +0000219 outstandingLLMCalls: make(map[string]struct{}),
220 outstandingToolCalls: make(map[string]string),
221 }
222
223 // Create a test context
224 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
225 defer cancel()
226
227 // Push a test message to the inbox so that processUserMessage will try to process it
228 agent.inbox <- "Test message"
229
230 // Call processTurn - it should exit early without panic when initialResp is nil
231 agent.processTurn(ctx)
232
Philip Zeyliger9373c072025-05-01 10:27:01 -0700233 // Verify error message was added to history
234 agent.mu.Lock()
235 defer agent.mu.Unlock()
236
237 // There should be exactly one message
238 if len(agent.history) != 1 {
239 t.Errorf("Expected exactly one message, got %d", len(agent.history))
240 } else {
241 msg := agent.history[0]
Sean McCullough9f4b8082025-04-30 17:34:07 +0000242 if msg.Type != ErrorMessageType {
243 t.Errorf("Expected error message, got message type: %s", msg.Type)
244 }
245 if !strings.Contains(msg.Content, "simulating nil response") {
246 t.Errorf("Expected error message to contain 'simulating nil response', got: %s", msg.Content)
247 }
Sean McCullough9f4b8082025-04-30 17:34:07 +0000248 }
249}
250
251// MockConvoInterface implements the ConvoInterface for testing
252type MockConvoInterface struct {
253 sendMessageFunc func(message ant.Message) (*ant.MessageResponse, error)
254 sendUserTextMessageFunc func(s string, otherContents ...ant.Content) (*ant.MessageResponse, error)
255 toolResultContentsFunc func(ctx context.Context, resp *ant.MessageResponse) ([]ant.Content, error)
256 toolResultCancelContentsFunc func(resp *ant.MessageResponse) ([]ant.Content, error)
257 cancelToolUseFunc func(toolUseID string, cause error) error
258 cumulativeUsageFunc func() ant.CumulativeUsage
259 resetBudgetFunc func(ant.Budget)
260 overBudgetFunc func() error
Philip Zeyliger2c4db092025-04-28 16:57:50 -0700261 getIDFunc func() string
262 subConvoWithHistoryFunc func() *ant.Convo
Sean McCullough9f4b8082025-04-30 17:34:07 +0000263}
264
265func (m *MockConvoInterface) SendMessage(message ant.Message) (*ant.MessageResponse, error) {
266 if m.sendMessageFunc != nil {
267 return m.sendMessageFunc(message)
268 }
269 return nil, nil
270}
271
272func (m *MockConvoInterface) SendUserTextMessage(s string, otherContents ...ant.Content) (*ant.MessageResponse, error) {
273 if m.sendUserTextMessageFunc != nil {
274 return m.sendUserTextMessageFunc(s, otherContents...)
275 }
276 return nil, nil
277}
278
279func (m *MockConvoInterface) ToolResultContents(ctx context.Context, resp *ant.MessageResponse) ([]ant.Content, error) {
280 if m.toolResultContentsFunc != nil {
281 return m.toolResultContentsFunc(ctx, resp)
282 }
283 return nil, nil
284}
285
286func (m *MockConvoInterface) ToolResultCancelContents(resp *ant.MessageResponse) ([]ant.Content, error) {
287 if m.toolResultCancelContentsFunc != nil {
288 return m.toolResultCancelContentsFunc(resp)
289 }
290 return nil, nil
291}
292
293func (m *MockConvoInterface) CancelToolUse(toolUseID string, cause error) error {
294 if m.cancelToolUseFunc != nil {
295 return m.cancelToolUseFunc(toolUseID, cause)
296 }
297 return nil
298}
299
300func (m *MockConvoInterface) CumulativeUsage() ant.CumulativeUsage {
301 if m.cumulativeUsageFunc != nil {
302 return m.cumulativeUsageFunc()
303 }
304 return ant.CumulativeUsage{}
305}
306
307func (m *MockConvoInterface) ResetBudget(budget ant.Budget) {
308 if m.resetBudgetFunc != nil {
309 m.resetBudgetFunc(budget)
310 }
311}
312
313func (m *MockConvoInterface) OverBudget() error {
314 if m.overBudgetFunc != nil {
315 return m.overBudgetFunc()
316 }
317 return nil
318}
319
Philip Zeyliger2c4db092025-04-28 16:57:50 -0700320func (m *MockConvoInterface) GetID() string {
321 if m.getIDFunc != nil {
322 return m.getIDFunc()
323 }
324 return "mock-convo-id"
325}
326
327func (m *MockConvoInterface) SubConvoWithHistory() *ant.Convo {
328 if m.subConvoWithHistoryFunc != nil {
329 return m.subConvoWithHistoryFunc()
330 }
331 return nil
332}
333
Sean McCullough9f4b8082025-04-30 17:34:07 +0000334// TestAgentProcessTurnWithNilResponseNilError tests the scenario where Agent.processTurn receives
335// a nil value for initialResp and nil error from processUserMessage.
336// This test verifies that the implementation properly handles this edge case.
337func TestAgentProcessTurnWithNilResponseNilError(t *testing.T) {
338 // Create a mock conversation that will return nil response and nil error
339 mockConvo := &MockConvoInterface{
340 sendMessageFunc: func(message ant.Message) (*ant.MessageResponse, error) {
341 return nil, nil // This is unusual but now handled gracefully
342 },
343 }
344
345 // Create a minimal Agent instance for testing
346 agent := &Agent{
347 convo: mockConvo,
348 inbox: make(chan string, 10),
Philip Zeyliger9373c072025-05-01 10:27:01 -0700349 subscribers: []chan *AgentMessage{},
Sean McCullough9f4b8082025-04-30 17:34:07 +0000350 outstandingLLMCalls: make(map[string]struct{}),
351 outstandingToolCalls: make(map[string]string),
352 }
353
354 // Create a test context
355 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
356 defer cancel()
357
358 // Push a test message to the inbox so that processUserMessage will try to process it
359 agent.inbox <- "Test message"
360
361 // Call processTurn - it should handle nil initialResp with a descriptive error
362 err := agent.processTurn(ctx)
363
364 // Verify we get the expected error
365 if err == nil {
366 t.Error("Expected processTurn to return an error for nil initialResp, but got nil")
367 } else if !strings.Contains(err.Error(), "unexpected nil response") {
368 t.Errorf("Expected error about nil response, got: %v", err)
369 } else {
370 t.Logf("As expected, processTurn returned error: %v", err)
371 }
372
Philip Zeyliger9373c072025-05-01 10:27:01 -0700373 // Verify error message was added to history
374 agent.mu.Lock()
375 defer agent.mu.Unlock()
376
377 // There should be exactly one message
378 if len(agent.history) != 1 {
379 t.Errorf("Expected exactly one message, got %d", len(agent.history))
380 } else {
381 msg := agent.history[0]
Sean McCullough9f4b8082025-04-30 17:34:07 +0000382 if msg.Type != ErrorMessageType {
383 t.Errorf("Expected error message type, got: %s", msg.Type)
384 }
385 if !strings.Contains(msg.Content, "unexpected nil response") {
386 t.Errorf("Expected error about nil response, got: %s", msg.Content)
387 }
Sean McCullough9f4b8082025-04-30 17:34:07 +0000388 }
389}
Sean McCullough96b60dd2025-04-30 09:49:10 -0700390
391func TestAgentStateMachine(t *testing.T) {
392 // Create a simplified test for the state machine functionality
393 agent := &Agent{
394 stateMachine: NewStateMachine(),
395 }
396
397 // Initially the state should be Ready
398 if state := agent.CurrentState(); state != StateReady {
399 t.Errorf("Expected initial state to be StateReady, got %s", state)
400 }
401
402 // Test manual transitions to verify state tracking
403 ctx := context.Background()
404
405 // Track transitions
406 var transitions []State
407 agent.stateMachine.SetTransitionCallback(func(ctx context.Context, from, to State, event TransitionEvent) {
408 transitions = append(transitions, to)
409 t.Logf("State transition: %s -> %s (%s)", from, to, event.Description)
410 })
411
412 // Perform a valid sequence of transitions (based on the state machine rules)
413 expectedStates := []State{
414 StateWaitingForUserInput,
415 StateSendingToLLM,
416 StateProcessingLLMResponse,
417 StateToolUseRequested,
418 StateCheckingForCancellation,
419 StateRunningTool,
420 StateCheckingGitCommits,
421 StateRunningAutoformatters,
422 StateCheckingBudget,
423 StateGatheringAdditionalMessages,
424 StateSendingToolResults,
425 StateProcessingLLMResponse,
426 StateEndOfTurn,
427 }
428
429 // Manually perform each transition
430 for _, state := range expectedStates {
431 err := agent.stateMachine.Transition(ctx, state, "Test transition to "+state.String())
432 if err != nil {
433 t.Errorf("Failed to transition to %s: %v", state, err)
434 }
435 }
436
437 // Check if we recorded the right number of transitions
438 if len(transitions) != len(expectedStates) {
439 t.Errorf("Expected %d state transitions, got %d", len(expectedStates), len(transitions))
440 }
441
442 // Check each transition matched what we expected
443 for i, expected := range expectedStates {
444 if i < len(transitions) {
445 if transitions[i] != expected {
446 t.Errorf("Transition %d: expected %s, got %s", i, expected, transitions[i])
447 }
448 }
449 }
450
451 // Verify the current state is the last one we transitioned to
452 if state := agent.CurrentState(); state != expectedStates[len(expectedStates)-1] {
453 t.Errorf("Expected current state to be %s, got %s", expectedStates[len(expectedStates)-1], state)
454 }
455
456 // Test force transition
457 agent.stateMachine.ForceTransition(ctx, StateCancelled, "Testing force transition")
458
459 // Verify current state was updated
460 if state := agent.CurrentState(); state != StateCancelled {
461 t.Errorf("Expected forced state to be StateCancelled, got %s", state)
462 }
463}
464
465// mockConvoInterface is a mock implementation of ConvoInterface for testing
466type mockConvoInterface struct {
467 SendMessageFunc func(message ant.Message) (*ant.MessageResponse, error)
468 ToolResultContentsFunc func(ctx context.Context, resp *ant.MessageResponse) ([]ant.Content, error)
469}
470
471func (c *mockConvoInterface) GetID() string {
472 return "mockConvoInterface-id"
473}
474
475func (c *mockConvoInterface) SubConvoWithHistory() *ant.Convo {
476 return nil
477}
478
479func (m *mockConvoInterface) CumulativeUsage() ant.CumulativeUsage {
480 return ant.CumulativeUsage{}
481}
482
483func (m *mockConvoInterface) ResetBudget(ant.Budget) {}
484
485func (m *mockConvoInterface) OverBudget() error {
486 return nil
487}
488
489func (m *mockConvoInterface) SendMessage(message ant.Message) (*ant.MessageResponse, error) {
490 if m.SendMessageFunc != nil {
491 return m.SendMessageFunc(message)
492 }
493 return &ant.MessageResponse{StopReason: ant.StopReasonEndTurn}, nil
494}
495
496func (m *mockConvoInterface) SendUserTextMessage(s string, otherContents ...ant.Content) (*ant.MessageResponse, error) {
Josh Bleecher Snydera3dcd862025-04-30 19:47:16 +0000497 return m.SendMessage(ant.UserStringMessage(s))
Sean McCullough96b60dd2025-04-30 09:49:10 -0700498}
499
500func (m *mockConvoInterface) ToolResultContents(ctx context.Context, resp *ant.MessageResponse) ([]ant.Content, error) {
501 if m.ToolResultContentsFunc != nil {
502 return m.ToolResultContentsFunc(ctx, resp)
503 }
504 return []ant.Content{}, nil
505}
506
507func (m *mockConvoInterface) ToolResultCancelContents(resp *ant.MessageResponse) ([]ant.Content, error) {
Josh Bleecher Snydera3dcd862025-04-30 19:47:16 +0000508 return []ant.Content{ant.StringContent("Tool use cancelled")}, nil
Sean McCullough96b60dd2025-04-30 09:49:10 -0700509}
510
511func (m *mockConvoInterface) CancelToolUse(toolUseID string, cause error) error {
512 return nil
513}
514
515func TestAgentProcessTurnStateTransitions(t *testing.T) {
516 // Create a mock ConvoInterface for testing
517 mockConvo := &mockConvoInterface{}
518
519 // Use the testing context
520 ctx := t.Context()
521
522 // Create an agent with the state machine
523 agent := &Agent{
524 convo: mockConvo,
525 config: AgentConfig{Context: ctx},
526 inbox: make(chan string, 10),
Sean McCullough96b60dd2025-04-30 09:49:10 -0700527 ready: make(chan struct{}),
528 seenCommits: make(map[string]bool),
529 outstandingLLMCalls: make(map[string]struct{}),
530 outstandingToolCalls: make(map[string]string),
531 stateMachine: NewStateMachine(),
532 startOfTurn: time.Now(),
Philip Zeyliger9373c072025-05-01 10:27:01 -0700533 subscribers: []chan *AgentMessage{},
Sean McCullough96b60dd2025-04-30 09:49:10 -0700534 }
535
536 // Verify initial state
537 if state := agent.CurrentState(); state != StateReady {
538 t.Errorf("Expected initial state to be StateReady, got %s", state)
539 }
540
541 // Add a message to the inbox so we don't block in GatherMessages
542 agent.inbox <- "Test message"
543
544 // Setup the mock to simulate a model response with end of turn
545 mockConvo.SendMessageFunc = func(message ant.Message) (*ant.MessageResponse, error) {
546 return &ant.MessageResponse{
547 StopReason: ant.StopReasonEndTurn,
548 Content: []ant.Content{
Josh Bleecher Snydera3dcd862025-04-30 19:47:16 +0000549 ant.StringContent("This is a test response"),
Sean McCullough96b60dd2025-04-30 09:49:10 -0700550 },
551 }, nil
552 }
553
554 // Track state transitions
555 var transitions []State
556 agent.stateMachine.SetTransitionCallback(func(ctx context.Context, from, to State, event TransitionEvent) {
557 transitions = append(transitions, to)
558 t.Logf("State transition: %s -> %s (%s)", from, to, event.Description)
559 })
560
561 // Process a turn, which should trigger state transitions
562 agent.processTurn(ctx)
563
564 // The minimum expected states for a simple end-of-turn response
565 minExpectedStates := []State{
566 StateWaitingForUserInput,
567 StateSendingToLLM,
568 StateProcessingLLMResponse,
569 StateEndOfTurn,
570 }
571
572 // Verify we have at least the minimum expected states
573 if len(transitions) < len(minExpectedStates) {
574 t.Errorf("Expected at least %d state transitions, got %d", len(minExpectedStates), len(transitions))
575 }
576
577 // Check that the transitions follow the expected sequence
578 for i, expected := range minExpectedStates {
579 if i < len(transitions) {
580 if transitions[i] != expected {
581 t.Errorf("Transition %d: expected %s, got %s", i, expected, transitions[i])
582 }
583 }
584 }
585
586 // Verify the final state is EndOfTurn
587 if state := agent.CurrentState(); state != StateEndOfTurn {
588 t.Errorf("Expected final state to be StateEndOfTurn, got %s", state)
589 }
590}
591
592func TestAgentProcessTurnWithToolUse(t *testing.T) {
593 // Create a mock ConvoInterface for testing
594 mockConvo := &mockConvoInterface{}
595
596 // Setup a test context
597 ctx := context.Background()
598
599 // Create an agent with the state machine
600 agent := &Agent{
601 convo: mockConvo,
602 config: AgentConfig{Context: ctx},
603 inbox: make(chan string, 10),
Sean McCullough96b60dd2025-04-30 09:49:10 -0700604 ready: make(chan struct{}),
605 seenCommits: make(map[string]bool),
606 outstandingLLMCalls: make(map[string]struct{}),
607 outstandingToolCalls: make(map[string]string),
608 stateMachine: NewStateMachine(),
609 startOfTurn: time.Now(),
Philip Zeyliger9373c072025-05-01 10:27:01 -0700610 subscribers: []chan *AgentMessage{},
Sean McCullough96b60dd2025-04-30 09:49:10 -0700611 }
612
613 // Add a message to the inbox so we don't block in GatherMessages
614 agent.inbox <- "Test message"
615
616 // First response requests a tool
617 firstResponseDone := false
618 mockConvo.SendMessageFunc = func(message ant.Message) (*ant.MessageResponse, error) {
619 if !firstResponseDone {
620 firstResponseDone = true
621 return &ant.MessageResponse{
622 StopReason: ant.StopReasonToolUse,
623 Content: []ant.Content{
Josh Bleecher Snydera3dcd862025-04-30 19:47:16 +0000624 ant.StringContent("I'll use a tool"),
625 {Type: ant.ContentTypeToolUse, ToolName: "test_tool", ToolInput: []byte("{}"), ID: "test_id"},
Sean McCullough96b60dd2025-04-30 09:49:10 -0700626 },
627 }, nil
628 }
629 // Second response ends the turn
630 return &ant.MessageResponse{
631 StopReason: ant.StopReasonEndTurn,
632 Content: []ant.Content{
Josh Bleecher Snydera3dcd862025-04-30 19:47:16 +0000633 ant.StringContent("Finished using the tool"),
Sean McCullough96b60dd2025-04-30 09:49:10 -0700634 },
635 }, nil
636 }
637
638 // Tool result content handler
639 mockConvo.ToolResultContentsFunc = func(ctx context.Context, resp *ant.MessageResponse) ([]ant.Content, error) {
Josh Bleecher Snydera3dcd862025-04-30 19:47:16 +0000640 return []ant.Content{ant.StringContent("Tool executed successfully")}, nil
Sean McCullough96b60dd2025-04-30 09:49:10 -0700641 }
642
643 // Track state transitions
644 var transitions []State
645 agent.stateMachine.SetTransitionCallback(func(ctx context.Context, from, to State, event TransitionEvent) {
646 transitions = append(transitions, to)
647 t.Logf("State transition: %s -> %s (%s)", from, to, event.Description)
648 })
649
650 // Process a turn with tool use
651 agent.processTurn(ctx)
652
653 // Define expected states for a tool use flow
654 expectedToolStates := []State{
655 StateWaitingForUserInput,
656 StateSendingToLLM,
657 StateProcessingLLMResponse,
658 StateToolUseRequested,
659 StateCheckingForCancellation,
660 StateRunningTool,
661 }
662
663 // Verify that these states are present in order
664 for i, expectedState := range expectedToolStates {
665 if i >= len(transitions) {
666 t.Errorf("Missing expected transition to %s; only got %d transitions", expectedState, len(transitions))
667 continue
668 }
669 if transitions[i] != expectedState {
670 t.Errorf("Expected transition %d to be %s, got %s", i, expectedState, transitions[i])
671 }
672 }
673
674 // Also verify we eventually reached EndOfTurn
675 if !slices.Contains(transitions, StateEndOfTurn) {
676 t.Errorf("Expected to eventually reach StateEndOfTurn, but never did")
677 }
678}