blob: ce44352dfefce201980c0b4bfedcbf9a27dce481 [file] [log] [blame]
Earl Lee2e463fb2025-04-17 11:22:22 -07001package loop
2
3import (
Josh Bleecher Snyder4d5e9972025-05-01 15:56:37 -07004 "cmp"
Earl Lee2e463fb2025-04-17 11:22:22 -07005 "context"
Sean McCullough9f4b8082025-04-30 17:34:07 +00006 "fmt"
Earl Lee2e463fb2025-04-17 11:22:22 -07007 "net/http"
8 "os"
Sean McCullough96b60dd2025-04-30 09:49:10 -07009 "slices"
Earl Lee2e463fb2025-04-17 11:22:22 -070010 "strings"
11 "testing"
12 "time"
13
Earl Lee2e463fb2025-04-17 11:22:22 -070014 "sketch.dev/httprr"
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -070015 "sketch.dev/llm"
16 "sketch.dev/llm/ant"
17 "sketch.dev/llm/conversation"
Earl Lee2e463fb2025-04-17 11:22:22 -070018)
19
20// TestAgentLoop tests that the Agent loop functionality works correctly.
21// It uses the httprr package to record HTTP interactions for replay in tests.
22// When failing, rebuild with "go test ./sketch/loop -run TestAgentLoop -httprecord .*agent_loop.*"
23// as necessary.
24func TestAgentLoop(t *testing.T) {
25 ctx := context.Background()
26
27 // Setup httprr recorder
28 rrPath := "testdata/agent_loop.httprr"
29 rr, err := httprr.Open(rrPath, http.DefaultTransport)
30 if err != nil && !os.IsNotExist(err) {
31 t.Fatal(err)
32 }
33
34 if rr.Recording() {
35 // Skip the test if API key is not available
36 if os.Getenv("ANTHROPIC_API_KEY") == "" {
37 t.Fatal("ANTHROPIC_API_KEY not set, required for HTTP recording")
38 }
39 }
40
41 // Create HTTP client
42 var client *http.Client
43 if rr != nil {
44 // Scrub API keys from requests for security
45 rr.ScrubReq(func(req *http.Request) error {
46 req.Header.Del("x-api-key")
47 req.Header.Del("anthropic-api-key")
48 return nil
49 })
50 client = rr.Client()
51 } else {
52 client = &http.Client{Transport: http.DefaultTransport}
53 }
54
55 // Create a new agent with the httprr client
56 origWD, err := os.Getwd()
57 if err != nil {
58 t.Fatal(err)
59 }
60 if err := os.Chdir("/"); err != nil {
61 t.Fatal(err)
62 }
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -070063 budget := conversation.Budget{MaxResponses: 100}
Earl Lee2e463fb2025-04-17 11:22:22 -070064 wd, err := os.Getwd()
65 if err != nil {
66 t.Fatal(err)
67 }
68
David Crawshaw3659d872025-05-05 17:52:23 -070069 apiKey := cmp.Or(os.Getenv("OUTER_SKETCH_MODEL_API_KEY"), os.Getenv("ANTHROPIC_API_KEY"))
Earl Lee2e463fb2025-04-17 11:22:22 -070070 cfg := AgentConfig{
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -070071 Context: ctx,
72 Service: &ant.Service{
73 APIKey: apiKey,
74 HTTPC: client,
75 },
Earl Lee2e463fb2025-04-17 11:22:22 -070076 Budget: budget,
77 GitUsername: "Test Agent",
78 GitEmail: "totallyhuman@sketch.dev",
79 SessionID: "test-session-id",
80 ClientGOOS: "linux",
81 ClientGOARCH: "amd64",
82 }
83 agent := NewAgent(cfg)
84 if err := os.Chdir(origWD); err != nil {
85 t.Fatal(err)
86 }
87 err = agent.Init(AgentInit{WorkingDir: wd, NoGit: true})
88 if err != nil {
89 t.Fatal(err)
90 }
91
92 // Setup a test message that will trigger a simple, predictable response
Josh Bleecher Snyderd2f54c22025-05-07 18:38:07 -070093 userMessage := "What tools are available to you? Please just list them briefly. (Do not call the title tool.)"
Earl Lee2e463fb2025-04-17 11:22:22 -070094
95 // Send the message to the agent
96 agent.UserMessage(ctx, userMessage)
97
98 // Process a single loop iteration to avoid long-running tests
Sean McCullough885a16a2025-04-30 02:49:25 +000099 agent.processTurn(ctx)
Earl Lee2e463fb2025-04-17 11:22:22 -0700100
101 // Collect responses with a timeout
102 var responses []AgentMessage
Philip Zeyliger9373c072025-05-01 10:27:01 -0700103 ctx2, cancel := context.WithDeadline(ctx, time.Now().Add(10*time.Second))
104 defer cancel()
Earl Lee2e463fb2025-04-17 11:22:22 -0700105 done := false
Philip Zeyligerb7c58752025-05-01 10:10:17 -0700106 it := agent.NewIterator(ctx2, 0)
Earl Lee2e463fb2025-04-17 11:22:22 -0700107
108 for !done {
Philip Zeyligerb7c58752025-05-01 10:10:17 -0700109 msg := it.Next()
110 t.Logf("Received message: Type=%s, EndOfTurn=%v, Content=%q", msg.Type, msg.EndOfTurn, msg.Content)
111 responses = append(responses, *msg)
112 if msg.EndOfTurn {
Earl Lee2e463fb2025-04-17 11:22:22 -0700113 done = true
Earl Lee2e463fb2025-04-17 11:22:22 -0700114 }
115 }
116
117 // Verify we got at least one response
118 if len(responses) == 0 {
119 t.Fatal("No responses received from agent")
120 }
121
122 // Log the received responses for debugging
123 t.Logf("Received %d responses", len(responses))
124
125 // Find the final agent response (with EndOfTurn=true)
126 var finalResponse *AgentMessage
127 for i := range responses {
128 if responses[i].Type == AgentMessageType && responses[i].EndOfTurn {
129 finalResponse = &responses[i]
130 break
131 }
132 }
133
134 // Verify we got a final agent response
135 if finalResponse == nil {
136 t.Fatal("No final agent response received")
137 }
138
139 // Check that the response contains tools information
140 if !strings.Contains(strings.ToLower(finalResponse.Content), "tool") {
141 t.Error("Expected response to mention tools")
142 }
143
144 // Count how many tool use messages we received
145 toolUseCount := 0
146 for _, msg := range responses {
147 if msg.Type == ToolUseMessageType {
148 toolUseCount++
149 }
150 }
151
152 t.Logf("Agent used %d tools in its response", toolUseCount)
153}
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000154
155func TestAgentTracksOutstandingCalls(t *testing.T) {
156 agent := &Agent{
157 outstandingLLMCalls: make(map[string]struct{}),
158 outstandingToolCalls: make(map[string]string),
Sean McCullough96b60dd2025-04-30 09:49:10 -0700159 stateMachine: NewStateMachine(),
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000160 }
161
162 // Check initial state
163 if count := agent.OutstandingLLMCallCount(); count != 0 {
164 t.Errorf("Expected 0 outstanding LLM calls, got %d", count)
165 }
166
167 if tools := agent.OutstandingToolCalls(); len(tools) != 0 {
168 t.Errorf("Expected 0 outstanding tool calls, got %d", len(tools))
169 }
170
171 // Add some calls
172 agent.mu.Lock()
173 agent.outstandingLLMCalls["llm1"] = struct{}{}
174 agent.outstandingToolCalls["tool1"] = "bash"
175 agent.outstandingToolCalls["tool2"] = "think"
176 agent.mu.Unlock()
177
178 // Check tracking works
179 if count := agent.OutstandingLLMCallCount(); count != 1 {
180 t.Errorf("Expected 1 outstanding LLM call, got %d", count)
181 }
182
183 tools := agent.OutstandingToolCalls()
184 if len(tools) != 2 {
185 t.Errorf("Expected 2 outstanding tool calls, got %d", len(tools))
186 }
187
188 // Check removal
189 agent.mu.Lock()
190 delete(agent.outstandingLLMCalls, "llm1")
191 delete(agent.outstandingToolCalls, "tool1")
192 agent.mu.Unlock()
193
194 if count := agent.OutstandingLLMCallCount(); count != 0 {
195 t.Errorf("Expected 0 outstanding LLM calls after removal, got %d", count)
196 }
197
198 tools = agent.OutstandingToolCalls()
199 if len(tools) != 1 {
200 t.Errorf("Expected 1 outstanding tool call after removal, got %d", len(tools))
201 }
202
203 if tools[0] != "think" {
204 t.Errorf("Expected 'think' tool remaining, got %s", tools[0])
205 }
206}
Sean McCullough9f4b8082025-04-30 17:34:07 +0000207
208// TestAgentProcessTurnWithNilResponse tests the scenario where Agent.processTurn receives
209// a nil value for initialResp from processUserMessage.
210func TestAgentProcessTurnWithNilResponse(t *testing.T) {
211 // Create a mock conversation that will return nil and error
212 mockConvo := &MockConvoInterface{
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700213 sendMessageFunc: func(message llm.Message) (*llm.Response, error) {
Sean McCullough9f4b8082025-04-30 17:34:07 +0000214 return nil, fmt.Errorf("test error: simulating nil response")
215 },
216 }
217
218 // Create a minimal Agent instance for testing
219 agent := &Agent{
220 convo: mockConvo,
221 inbox: make(chan string, 10),
Philip Zeyliger9373c072025-05-01 10:27:01 -0700222 subscribers: []chan *AgentMessage{},
Sean McCullough9f4b8082025-04-30 17:34:07 +0000223 outstandingLLMCalls: make(map[string]struct{}),
224 outstandingToolCalls: make(map[string]string),
225 }
226
227 // Create a test context
228 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
229 defer cancel()
230
231 // Push a test message to the inbox so that processUserMessage will try to process it
232 agent.inbox <- "Test message"
233
234 // Call processTurn - it should exit early without panic when initialResp is nil
235 agent.processTurn(ctx)
236
Philip Zeyliger9373c072025-05-01 10:27:01 -0700237 // Verify error message was added to history
238 agent.mu.Lock()
239 defer agent.mu.Unlock()
240
241 // There should be exactly one message
242 if len(agent.history) != 1 {
243 t.Errorf("Expected exactly one message, got %d", len(agent.history))
244 } else {
245 msg := agent.history[0]
Sean McCullough9f4b8082025-04-30 17:34:07 +0000246 if msg.Type != ErrorMessageType {
247 t.Errorf("Expected error message, got message type: %s", msg.Type)
248 }
249 if !strings.Contains(msg.Content, "simulating nil response") {
250 t.Errorf("Expected error message to contain 'simulating nil response', got: %s", msg.Content)
251 }
Sean McCullough9f4b8082025-04-30 17:34:07 +0000252 }
253}
254
255// MockConvoInterface implements the ConvoInterface for testing
256type MockConvoInterface struct {
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700257 sendMessageFunc func(message llm.Message) (*llm.Response, error)
258 sendUserTextMessageFunc func(s string, otherContents ...llm.Content) (*llm.Response, error)
259 toolResultContentsFunc func(ctx context.Context, resp *llm.Response) ([]llm.Content, error)
260 toolResultCancelContentsFunc func(resp *llm.Response) ([]llm.Content, error)
Sean McCullough9f4b8082025-04-30 17:34:07 +0000261 cancelToolUseFunc func(toolUseID string, cause error) error
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700262 cumulativeUsageFunc func() conversation.CumulativeUsage
263 resetBudgetFunc func(conversation.Budget)
Sean McCullough9f4b8082025-04-30 17:34:07 +0000264 overBudgetFunc func() error
Philip Zeyliger2c4db092025-04-28 16:57:50 -0700265 getIDFunc func() string
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700266 subConvoWithHistoryFunc func() *conversation.Convo
Sean McCullough9f4b8082025-04-30 17:34:07 +0000267}
268
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700269func (m *MockConvoInterface) SendMessage(message llm.Message) (*llm.Response, error) {
Sean McCullough9f4b8082025-04-30 17:34:07 +0000270 if m.sendMessageFunc != nil {
271 return m.sendMessageFunc(message)
272 }
273 return nil, nil
274}
275
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700276func (m *MockConvoInterface) SendUserTextMessage(s string, otherContents ...llm.Content) (*llm.Response, error) {
Sean McCullough9f4b8082025-04-30 17:34:07 +0000277 if m.sendUserTextMessageFunc != nil {
278 return m.sendUserTextMessageFunc(s, otherContents...)
279 }
280 return nil, nil
281}
282
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700283func (m *MockConvoInterface) ToolResultContents(ctx context.Context, resp *llm.Response) ([]llm.Content, error) {
Sean McCullough9f4b8082025-04-30 17:34:07 +0000284 if m.toolResultContentsFunc != nil {
285 return m.toolResultContentsFunc(ctx, resp)
286 }
287 return nil, nil
288}
289
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700290func (m *MockConvoInterface) ToolResultCancelContents(resp *llm.Response) ([]llm.Content, error) {
Sean McCullough9f4b8082025-04-30 17:34:07 +0000291 if m.toolResultCancelContentsFunc != nil {
292 return m.toolResultCancelContentsFunc(resp)
293 }
294 return nil, nil
295}
296
297func (m *MockConvoInterface) CancelToolUse(toolUseID string, cause error) error {
298 if m.cancelToolUseFunc != nil {
299 return m.cancelToolUseFunc(toolUseID, cause)
300 }
301 return nil
302}
303
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700304func (m *MockConvoInterface) CumulativeUsage() conversation.CumulativeUsage {
Sean McCullough9f4b8082025-04-30 17:34:07 +0000305 if m.cumulativeUsageFunc != nil {
306 return m.cumulativeUsageFunc()
307 }
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700308 return conversation.CumulativeUsage{}
Sean McCullough9f4b8082025-04-30 17:34:07 +0000309}
310
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700311func (m *MockConvoInterface) ResetBudget(budget conversation.Budget) {
Sean McCullough9f4b8082025-04-30 17:34:07 +0000312 if m.resetBudgetFunc != nil {
313 m.resetBudgetFunc(budget)
314 }
315}
316
317func (m *MockConvoInterface) OverBudget() error {
318 if m.overBudgetFunc != nil {
319 return m.overBudgetFunc()
320 }
321 return nil
322}
323
Philip Zeyliger2c4db092025-04-28 16:57:50 -0700324func (m *MockConvoInterface) GetID() string {
325 if m.getIDFunc != nil {
326 return m.getIDFunc()
327 }
328 return "mock-convo-id"
329}
330
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700331func (m *MockConvoInterface) SubConvoWithHistory() *conversation.Convo {
Philip Zeyliger2c4db092025-04-28 16:57:50 -0700332 if m.subConvoWithHistoryFunc != nil {
333 return m.subConvoWithHistoryFunc()
334 }
335 return nil
336}
337
Sean McCullough9f4b8082025-04-30 17:34:07 +0000338// TestAgentProcessTurnWithNilResponseNilError tests the scenario where Agent.processTurn receives
339// a nil value for initialResp and nil error from processUserMessage.
340// This test verifies that the implementation properly handles this edge case.
341func TestAgentProcessTurnWithNilResponseNilError(t *testing.T) {
342 // Create a mock conversation that will return nil response and nil error
343 mockConvo := &MockConvoInterface{
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700344 sendMessageFunc: func(message llm.Message) (*llm.Response, error) {
Sean McCullough9f4b8082025-04-30 17:34:07 +0000345 return nil, nil // This is unusual but now handled gracefully
346 },
347 }
348
349 // Create a minimal Agent instance for testing
350 agent := &Agent{
351 convo: mockConvo,
352 inbox: make(chan string, 10),
Philip Zeyliger9373c072025-05-01 10:27:01 -0700353 subscribers: []chan *AgentMessage{},
Sean McCullough9f4b8082025-04-30 17:34:07 +0000354 outstandingLLMCalls: make(map[string]struct{}),
355 outstandingToolCalls: make(map[string]string),
356 }
357
358 // Create a test context
359 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
360 defer cancel()
361
362 // Push a test message to the inbox so that processUserMessage will try to process it
363 agent.inbox <- "Test message"
364
365 // Call processTurn - it should handle nil initialResp with a descriptive error
366 err := agent.processTurn(ctx)
367
368 // Verify we get the expected error
369 if err == nil {
370 t.Error("Expected processTurn to return an error for nil initialResp, but got nil")
371 } else if !strings.Contains(err.Error(), "unexpected nil response") {
372 t.Errorf("Expected error about nil response, got: %v", err)
373 } else {
374 t.Logf("As expected, processTurn returned error: %v", err)
375 }
376
Philip Zeyliger9373c072025-05-01 10:27:01 -0700377 // Verify error message was added to history
378 agent.mu.Lock()
379 defer agent.mu.Unlock()
380
381 // There should be exactly one message
382 if len(agent.history) != 1 {
383 t.Errorf("Expected exactly one message, got %d", len(agent.history))
384 } else {
385 msg := agent.history[0]
Sean McCullough9f4b8082025-04-30 17:34:07 +0000386 if msg.Type != ErrorMessageType {
387 t.Errorf("Expected error message type, got: %s", msg.Type)
388 }
389 if !strings.Contains(msg.Content, "unexpected nil response") {
390 t.Errorf("Expected error about nil response, got: %s", msg.Content)
391 }
Sean McCullough9f4b8082025-04-30 17:34:07 +0000392 }
393}
Sean McCullough96b60dd2025-04-30 09:49:10 -0700394
395func TestAgentStateMachine(t *testing.T) {
396 // Create a simplified test for the state machine functionality
397 agent := &Agent{
398 stateMachine: NewStateMachine(),
399 }
400
401 // Initially the state should be Ready
402 if state := agent.CurrentState(); state != StateReady {
403 t.Errorf("Expected initial state to be StateReady, got %s", state)
404 }
405
406 // Test manual transitions to verify state tracking
407 ctx := context.Background()
408
409 // Track transitions
410 var transitions []State
411 agent.stateMachine.SetTransitionCallback(func(ctx context.Context, from, to State, event TransitionEvent) {
412 transitions = append(transitions, to)
413 t.Logf("State transition: %s -> %s (%s)", from, to, event.Description)
414 })
415
416 // Perform a valid sequence of transitions (based on the state machine rules)
417 expectedStates := []State{
418 StateWaitingForUserInput,
419 StateSendingToLLM,
420 StateProcessingLLMResponse,
421 StateToolUseRequested,
422 StateCheckingForCancellation,
423 StateRunningTool,
424 StateCheckingGitCommits,
425 StateRunningAutoformatters,
426 StateCheckingBudget,
427 StateGatheringAdditionalMessages,
428 StateSendingToolResults,
429 StateProcessingLLMResponse,
430 StateEndOfTurn,
431 }
432
433 // Manually perform each transition
434 for _, state := range expectedStates {
435 err := agent.stateMachine.Transition(ctx, state, "Test transition to "+state.String())
436 if err != nil {
437 t.Errorf("Failed to transition to %s: %v", state, err)
438 }
439 }
440
441 // Check if we recorded the right number of transitions
442 if len(transitions) != len(expectedStates) {
443 t.Errorf("Expected %d state transitions, got %d", len(expectedStates), len(transitions))
444 }
445
446 // Check each transition matched what we expected
447 for i, expected := range expectedStates {
448 if i < len(transitions) {
449 if transitions[i] != expected {
450 t.Errorf("Transition %d: expected %s, got %s", i, expected, transitions[i])
451 }
452 }
453 }
454
455 // Verify the current state is the last one we transitioned to
456 if state := agent.CurrentState(); state != expectedStates[len(expectedStates)-1] {
457 t.Errorf("Expected current state to be %s, got %s", expectedStates[len(expectedStates)-1], state)
458 }
459
460 // Test force transition
461 agent.stateMachine.ForceTransition(ctx, StateCancelled, "Testing force transition")
462
463 // Verify current state was updated
464 if state := agent.CurrentState(); state != StateCancelled {
465 t.Errorf("Expected forced state to be StateCancelled, got %s", state)
466 }
467}
468
469// mockConvoInterface is a mock implementation of ConvoInterface for testing
470type mockConvoInterface struct {
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700471 SendMessageFunc func(message llm.Message) (*llm.Response, error)
472 ToolResultContentsFunc func(ctx context.Context, resp *llm.Response) ([]llm.Content, error)
Sean McCullough96b60dd2025-04-30 09:49:10 -0700473}
474
475func (c *mockConvoInterface) GetID() string {
476 return "mockConvoInterface-id"
477}
478
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700479func (c *mockConvoInterface) SubConvoWithHistory() *conversation.Convo {
Sean McCullough96b60dd2025-04-30 09:49:10 -0700480 return nil
481}
482
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700483func (m *mockConvoInterface) CumulativeUsage() conversation.CumulativeUsage {
484 return conversation.CumulativeUsage{}
Sean McCullough96b60dd2025-04-30 09:49:10 -0700485}
486
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700487func (m *mockConvoInterface) ResetBudget(conversation.Budget) {}
Sean McCullough96b60dd2025-04-30 09:49:10 -0700488
489func (m *mockConvoInterface) OverBudget() error {
490 return nil
491}
492
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700493func (m *mockConvoInterface) SendMessage(message llm.Message) (*llm.Response, error) {
Sean McCullough96b60dd2025-04-30 09:49:10 -0700494 if m.SendMessageFunc != nil {
495 return m.SendMessageFunc(message)
496 }
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700497 return &llm.Response{StopReason: llm.StopReasonEndTurn}, nil
Sean McCullough96b60dd2025-04-30 09:49:10 -0700498}
499
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700500func (m *mockConvoInterface) SendUserTextMessage(s string, otherContents ...llm.Content) (*llm.Response, error) {
501 return m.SendMessage(llm.UserStringMessage(s))
Sean McCullough96b60dd2025-04-30 09:49:10 -0700502}
503
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700504func (m *mockConvoInterface) ToolResultContents(ctx context.Context, resp *llm.Response) ([]llm.Content, error) {
Sean McCullough96b60dd2025-04-30 09:49:10 -0700505 if m.ToolResultContentsFunc != nil {
506 return m.ToolResultContentsFunc(ctx, resp)
507 }
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700508 return []llm.Content{}, nil
Sean McCullough96b60dd2025-04-30 09:49:10 -0700509}
510
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700511func (m *mockConvoInterface) ToolResultCancelContents(resp *llm.Response) ([]llm.Content, error) {
512 return []llm.Content{llm.StringContent("Tool use cancelled")}, nil
Sean McCullough96b60dd2025-04-30 09:49:10 -0700513}
514
515func (m *mockConvoInterface) CancelToolUse(toolUseID string, cause error) error {
516 return nil
517}
518
519func TestAgentProcessTurnStateTransitions(t *testing.T) {
520 // Create a mock ConvoInterface for testing
521 mockConvo := &mockConvoInterface{}
522
523 // Use the testing context
524 ctx := t.Context()
525
526 // Create an agent with the state machine
527 agent := &Agent{
528 convo: mockConvo,
529 config: AgentConfig{Context: ctx},
530 inbox: make(chan string, 10),
Sean McCullough96b60dd2025-04-30 09:49:10 -0700531 ready: make(chan struct{}),
532 seenCommits: make(map[string]bool),
533 outstandingLLMCalls: make(map[string]struct{}),
534 outstandingToolCalls: make(map[string]string),
535 stateMachine: NewStateMachine(),
536 startOfTurn: time.Now(),
Philip Zeyliger9373c072025-05-01 10:27:01 -0700537 subscribers: []chan *AgentMessage{},
Sean McCullough96b60dd2025-04-30 09:49:10 -0700538 }
539
540 // Verify initial state
541 if state := agent.CurrentState(); state != StateReady {
542 t.Errorf("Expected initial state to be StateReady, got %s", state)
543 }
544
545 // Add a message to the inbox so we don't block in GatherMessages
546 agent.inbox <- "Test message"
547
548 // Setup the mock to simulate a model response with end of turn
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700549 mockConvo.SendMessageFunc = func(message llm.Message) (*llm.Response, error) {
550 return &llm.Response{
551 StopReason: llm.StopReasonEndTurn,
552 Content: []llm.Content{
553 llm.StringContent("This is a test response"),
Sean McCullough96b60dd2025-04-30 09:49:10 -0700554 },
555 }, nil
556 }
557
558 // Track state transitions
559 var transitions []State
560 agent.stateMachine.SetTransitionCallback(func(ctx context.Context, from, to State, event TransitionEvent) {
561 transitions = append(transitions, to)
562 t.Logf("State transition: %s -> %s (%s)", from, to, event.Description)
563 })
564
565 // Process a turn, which should trigger state transitions
566 agent.processTurn(ctx)
567
568 // The minimum expected states for a simple end-of-turn response
569 minExpectedStates := []State{
570 StateWaitingForUserInput,
571 StateSendingToLLM,
572 StateProcessingLLMResponse,
573 StateEndOfTurn,
574 }
575
576 // Verify we have at least the minimum expected states
577 if len(transitions) < len(minExpectedStates) {
578 t.Errorf("Expected at least %d state transitions, got %d", len(minExpectedStates), len(transitions))
579 }
580
581 // Check that the transitions follow the expected sequence
582 for i, expected := range minExpectedStates {
583 if i < len(transitions) {
584 if transitions[i] != expected {
585 t.Errorf("Transition %d: expected %s, got %s", i, expected, transitions[i])
586 }
587 }
588 }
589
590 // Verify the final state is EndOfTurn
591 if state := agent.CurrentState(); state != StateEndOfTurn {
592 t.Errorf("Expected final state to be StateEndOfTurn, got %s", state)
593 }
594}
595
596func TestAgentProcessTurnWithToolUse(t *testing.T) {
597 // Create a mock ConvoInterface for testing
598 mockConvo := &mockConvoInterface{}
599
600 // Setup a test context
601 ctx := context.Background()
602
603 // Create an agent with the state machine
604 agent := &Agent{
605 convo: mockConvo,
606 config: AgentConfig{Context: ctx},
607 inbox: make(chan string, 10),
Sean McCullough96b60dd2025-04-30 09:49:10 -0700608 ready: make(chan struct{}),
609 seenCommits: make(map[string]bool),
610 outstandingLLMCalls: make(map[string]struct{}),
611 outstandingToolCalls: make(map[string]string),
612 stateMachine: NewStateMachine(),
613 startOfTurn: time.Now(),
Philip Zeyliger9373c072025-05-01 10:27:01 -0700614 subscribers: []chan *AgentMessage{},
Sean McCullough96b60dd2025-04-30 09:49:10 -0700615 }
616
617 // Add a message to the inbox so we don't block in GatherMessages
618 agent.inbox <- "Test message"
619
620 // First response requests a tool
621 firstResponseDone := false
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700622 mockConvo.SendMessageFunc = func(message llm.Message) (*llm.Response, error) {
Sean McCullough96b60dd2025-04-30 09:49:10 -0700623 if !firstResponseDone {
624 firstResponseDone = true
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700625 return &llm.Response{
626 StopReason: llm.StopReasonToolUse,
627 Content: []llm.Content{
628 llm.StringContent("I'll use a tool"),
629 {Type: llm.ContentTypeToolUse, ToolName: "test_tool", ToolInput: []byte("{}"), ID: "test_id"},
Sean McCullough96b60dd2025-04-30 09:49:10 -0700630 },
631 }, nil
632 }
633 // Second response ends the turn
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700634 return &llm.Response{
635 StopReason: llm.StopReasonEndTurn,
636 Content: []llm.Content{
637 llm.StringContent("Finished using the tool"),
Sean McCullough96b60dd2025-04-30 09:49:10 -0700638 },
639 }, nil
640 }
641
642 // Tool result content handler
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700643 mockConvo.ToolResultContentsFunc = func(ctx context.Context, resp *llm.Response) ([]llm.Content, error) {
644 return []llm.Content{llm.StringContent("Tool executed successfully")}, nil
Sean McCullough96b60dd2025-04-30 09:49:10 -0700645 }
646
647 // Track state transitions
648 var transitions []State
649 agent.stateMachine.SetTransitionCallback(func(ctx context.Context, from, to State, event TransitionEvent) {
650 transitions = append(transitions, to)
651 t.Logf("State transition: %s -> %s (%s)", from, to, event.Description)
652 })
653
654 // Process a turn with tool use
655 agent.processTurn(ctx)
656
657 // Define expected states for a tool use flow
658 expectedToolStates := []State{
659 StateWaitingForUserInput,
660 StateSendingToLLM,
661 StateProcessingLLMResponse,
662 StateToolUseRequested,
663 StateCheckingForCancellation,
664 StateRunningTool,
665 }
666
667 // Verify that these states are present in order
668 for i, expectedState := range expectedToolStates {
669 if i >= len(transitions) {
670 t.Errorf("Missing expected transition to %s; only got %d transitions", expectedState, len(transitions))
671 continue
672 }
673 if transitions[i] != expectedState {
674 t.Errorf("Expected transition %d to be %s, got %s", i, expectedState, transitions[i])
675 }
676 }
677
678 // Also verify we eventually reached EndOfTurn
679 if !slices.Contains(transitions, StateEndOfTurn) {
680 t.Errorf("Expected to eventually reach StateEndOfTurn, but never did")
681 }
682}
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700683
684func TestContentToString(t *testing.T) {
685 tests := []struct {
686 name string
687 contents []llm.Content
688 want string
689 }{
690 {
691 name: "empty",
692 contents: []llm.Content{},
693 want: "",
694 },
695 {
696 name: "single text content",
697 contents: []llm.Content{
698 {Type: llm.ContentTypeText, Text: "hello world"},
699 },
700 want: "hello world",
701 },
702 {
703 name: "multiple text content",
704 contents: []llm.Content{
705 {Type: llm.ContentTypeText, Text: "hello "},
706 {Type: llm.ContentTypeText, Text: "world"},
707 },
708 want: "hello world",
709 },
710 {
711 name: "mixed content types",
712 contents: []llm.Content{
713 {Type: llm.ContentTypeText, Text: "hello "},
714 {Type: llm.ContentTypeText, MediaType: "image/png", Data: "base64data"},
715 {Type: llm.ContentTypeText, Text: "world"},
716 },
717 want: "hello world",
718 },
719 {
720 name: "non-text content only",
721 contents: []llm.Content{
722 {Type: llm.ContentTypeToolUse, ToolName: "example"},
723 },
724 want: "",
725 },
726 {
727 name: "nested tool result",
728 contents: []llm.Content{
729 {Type: llm.ContentTypeText, Text: "outer "},
730 {Type: llm.ContentTypeToolResult, ToolResult: []llm.Content{
731 {Type: llm.ContentTypeText, Text: "inner"},
732 }},
733 },
734 want: "outer inner",
735 },
736 {
737 name: "deeply nested tool result",
738 contents: []llm.Content{
739 {Type: llm.ContentTypeToolResult, ToolResult: []llm.Content{
740 {Type: llm.ContentTypeToolResult, ToolResult: []llm.Content{
741 {Type: llm.ContentTypeText, Text: "deeply nested"},
742 }},
743 }},
744 },
745 want: "deeply nested",
746 },
747 }
748
749 for _, tt := range tests {
750 t.Run(tt.name, func(t *testing.T) {
751 if got := contentToString(tt.contents); got != tt.want {
752 t.Errorf("contentToString() = %v, want %v", got, tt.want)
753 }
754 })
755 }
756}
757
758func TestPushToOutbox(t *testing.T) {
759 // Create a new agent
760 a := &Agent{
761 outstandingLLMCalls: make(map[string]struct{}),
762 outstandingToolCalls: make(map[string]string),
763 stateMachine: NewStateMachine(),
764 subscribers: make([]chan *AgentMessage, 0),
765 }
766
767 // Create a channel to receive messages
768 messageCh := make(chan *AgentMessage, 1)
769
770 // Add the channel to the subscribers list
771 a.mu.Lock()
772 a.subscribers = append(a.subscribers, messageCh)
773 a.mu.Unlock()
774
775 // We need to set the text that would be produced by our modified contentToString function
776 resultText := "test resultnested result" // Directly set the expected output
777
778 // In a real-world scenario, this would be coming from a toolResult that contained nested content
779
780 m := AgentMessage{
781 Type: ToolUseMessageType,
782 ToolResult: resultText,
783 }
784
785 // Push the message to the outbox
786 a.pushToOutbox(context.Background(), m)
787
788 // Receive the message from the subscriber
789 received := <-messageCh
790
791 // Check that the Content field contains the concatenated text from ToolResult
792 expected := "test resultnested result"
793 if received.Content != expected {
794 t.Errorf("Expected Content to be %q, got %q", expected, received.Content)
795 }
796}