blob: 30affbe5e3bffd3179cf986440e8000dc4ab45d5 [file] [log] [blame]
Earl Lee2e463fb2025-04-17 11:22:22 -07001package loop
2
3import (
Josh Bleecher Snyder4d5e9972025-05-01 15:56:37 -07004 "cmp"
Earl Lee2e463fb2025-04-17 11:22:22 -07005 "context"
Sean McCullough9f4b8082025-04-30 17:34:07 +00006 "fmt"
Earl Lee2e463fb2025-04-17 11:22:22 -07007 "net/http"
8 "os"
Sean McCullough96b60dd2025-04-30 09:49:10 -07009 "slices"
Earl Lee2e463fb2025-04-17 11:22:22 -070010 "strings"
11 "testing"
12 "time"
13
Earl Lee2e463fb2025-04-17 11:22:22 -070014 "sketch.dev/httprr"
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -070015 "sketch.dev/llm"
16 "sketch.dev/llm/ant"
17 "sketch.dev/llm/conversation"
Earl Lee2e463fb2025-04-17 11:22:22 -070018)
19
20// TestAgentLoop tests that the Agent loop functionality works correctly.
21// It uses the httprr package to record HTTP interactions for replay in tests.
22// When failing, rebuild with "go test ./sketch/loop -run TestAgentLoop -httprecord .*agent_loop.*"
23// as necessary.
24func TestAgentLoop(t *testing.T) {
25 ctx := context.Background()
26
27 // Setup httprr recorder
28 rrPath := "testdata/agent_loop.httprr"
29 rr, err := httprr.Open(rrPath, http.DefaultTransport)
30 if err != nil && !os.IsNotExist(err) {
31 t.Fatal(err)
32 }
33
34 if rr.Recording() {
35 // Skip the test if API key is not available
36 if os.Getenv("ANTHROPIC_API_KEY") == "" {
37 t.Fatal("ANTHROPIC_API_KEY not set, required for HTTP recording")
38 }
39 }
40
41 // Create HTTP client
42 var client *http.Client
43 if rr != nil {
44 // Scrub API keys from requests for security
45 rr.ScrubReq(func(req *http.Request) error {
46 req.Header.Del("x-api-key")
47 req.Header.Del("anthropic-api-key")
48 return nil
49 })
50 client = rr.Client()
51 } else {
52 client = &http.Client{Transport: http.DefaultTransport}
53 }
54
55 // Create a new agent with the httprr client
56 origWD, err := os.Getwd()
57 if err != nil {
58 t.Fatal(err)
59 }
60 if err := os.Chdir("/"); err != nil {
61 t.Fatal(err)
62 }
Philip Zeyligere6c294d2025-06-04 16:55:21 +000063 budget := conversation.Budget{MaxDollars: 10.0}
Earl Lee2e463fb2025-04-17 11:22:22 -070064 wd, err := os.Getwd()
65 if err != nil {
66 t.Fatal(err)
67 }
68
David Crawshaw3659d872025-05-05 17:52:23 -070069 apiKey := cmp.Or(os.Getenv("OUTER_SKETCH_MODEL_API_KEY"), os.Getenv("ANTHROPIC_API_KEY"))
Earl Lee2e463fb2025-04-17 11:22:22 -070070 cfg := AgentConfig{
Philip Zeyligerbc8c8dc2025-05-21 13:19:13 -070071 Context: ctx,
72 WorkingDir: wd,
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -070073 Service: &ant.Service{
74 APIKey: apiKey,
75 HTTPC: client,
76 },
Earl Lee2e463fb2025-04-17 11:22:22 -070077 Budget: budget,
78 GitUsername: "Test Agent",
79 GitEmail: "totallyhuman@sketch.dev",
80 SessionID: "test-session-id",
81 ClientGOOS: "linux",
82 ClientGOARCH: "amd64",
83 }
84 agent := NewAgent(cfg)
85 if err := os.Chdir(origWD); err != nil {
86 t.Fatal(err)
87 }
Philip Zeyligerbc8c8dc2025-05-21 13:19:13 -070088 err = agent.Init(AgentInit{NoGit: true})
Earl Lee2e463fb2025-04-17 11:22:22 -070089 if err != nil {
90 t.Fatal(err)
91 }
92
93 // Setup a test message that will trigger a simple, predictable response
Josh Bleecher Snyderd2f54c22025-05-07 18:38:07 -070094 userMessage := "What tools are available to you? Please just list them briefly. (Do not call the title tool.)"
Earl Lee2e463fb2025-04-17 11:22:22 -070095
96 // Send the message to the agent
97 agent.UserMessage(ctx, userMessage)
98
99 // Process a single loop iteration to avoid long-running tests
Sean McCullough885a16a2025-04-30 02:49:25 +0000100 agent.processTurn(ctx)
Earl Lee2e463fb2025-04-17 11:22:22 -0700101
102 // Collect responses with a timeout
103 var responses []AgentMessage
Philip Zeyliger9373c072025-05-01 10:27:01 -0700104 ctx2, cancel := context.WithDeadline(ctx, time.Now().Add(10*time.Second))
105 defer cancel()
Earl Lee2e463fb2025-04-17 11:22:22 -0700106 done := false
Philip Zeyligerb7c58752025-05-01 10:10:17 -0700107 it := agent.NewIterator(ctx2, 0)
Earl Lee2e463fb2025-04-17 11:22:22 -0700108
109 for !done {
Philip Zeyligerb7c58752025-05-01 10:10:17 -0700110 msg := it.Next()
111 t.Logf("Received message: Type=%s, EndOfTurn=%v, Content=%q", msg.Type, msg.EndOfTurn, msg.Content)
112 responses = append(responses, *msg)
113 if msg.EndOfTurn {
Earl Lee2e463fb2025-04-17 11:22:22 -0700114 done = true
Earl Lee2e463fb2025-04-17 11:22:22 -0700115 }
116 }
117
118 // Verify we got at least one response
119 if len(responses) == 0 {
120 t.Fatal("No responses received from agent")
121 }
122
123 // Log the received responses for debugging
124 t.Logf("Received %d responses", len(responses))
125
126 // Find the final agent response (with EndOfTurn=true)
127 var finalResponse *AgentMessage
128 for i := range responses {
129 if responses[i].Type == AgentMessageType && responses[i].EndOfTurn {
130 finalResponse = &responses[i]
131 break
132 }
133 }
134
135 // Verify we got a final agent response
136 if finalResponse == nil {
137 t.Fatal("No final agent response received")
138 }
139
140 // Check that the response contains tools information
141 if !strings.Contains(strings.ToLower(finalResponse.Content), "tool") {
142 t.Error("Expected response to mention tools")
143 }
144
145 // Count how many tool use messages we received
146 toolUseCount := 0
147 for _, msg := range responses {
148 if msg.Type == ToolUseMessageType {
149 toolUseCount++
150 }
151 }
152
153 t.Logf("Agent used %d tools in its response", toolUseCount)
154}
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000155
156func TestAgentTracksOutstandingCalls(t *testing.T) {
157 agent := &Agent{
158 outstandingLLMCalls: make(map[string]struct{}),
159 outstandingToolCalls: make(map[string]string),
Sean McCullough96b60dd2025-04-30 09:49:10 -0700160 stateMachine: NewStateMachine(),
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000161 }
162
163 // Check initial state
164 if count := agent.OutstandingLLMCallCount(); count != 0 {
165 t.Errorf("Expected 0 outstanding LLM calls, got %d", count)
166 }
167
168 if tools := agent.OutstandingToolCalls(); len(tools) != 0 {
169 t.Errorf("Expected 0 outstanding tool calls, got %d", len(tools))
170 }
171
172 // Add some calls
173 agent.mu.Lock()
174 agent.outstandingLLMCalls["llm1"] = struct{}{}
175 agent.outstandingToolCalls["tool1"] = "bash"
176 agent.outstandingToolCalls["tool2"] = "think"
177 agent.mu.Unlock()
178
179 // Check tracking works
180 if count := agent.OutstandingLLMCallCount(); count != 1 {
181 t.Errorf("Expected 1 outstanding LLM call, got %d", count)
182 }
183
184 tools := agent.OutstandingToolCalls()
185 if len(tools) != 2 {
186 t.Errorf("Expected 2 outstanding tool calls, got %d", len(tools))
187 }
188
189 // Check removal
190 agent.mu.Lock()
191 delete(agent.outstandingLLMCalls, "llm1")
192 delete(agent.outstandingToolCalls, "tool1")
193 agent.mu.Unlock()
194
195 if count := agent.OutstandingLLMCallCount(); count != 0 {
196 t.Errorf("Expected 0 outstanding LLM calls after removal, got %d", count)
197 }
198
199 tools = agent.OutstandingToolCalls()
200 if len(tools) != 1 {
201 t.Errorf("Expected 1 outstanding tool call after removal, got %d", len(tools))
202 }
203
204 if tools[0] != "think" {
205 t.Errorf("Expected 'think' tool remaining, got %s", tools[0])
206 }
207}
Sean McCullough9f4b8082025-04-30 17:34:07 +0000208
209// TestAgentProcessTurnWithNilResponse tests the scenario where Agent.processTurn receives
210// a nil value for initialResp from processUserMessage.
211func TestAgentProcessTurnWithNilResponse(t *testing.T) {
212 // Create a mock conversation that will return nil and error
213 mockConvo := &MockConvoInterface{
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700214 sendMessageFunc: func(message llm.Message) (*llm.Response, error) {
Sean McCullough9f4b8082025-04-30 17:34:07 +0000215 return nil, fmt.Errorf("test error: simulating nil response")
216 },
217 }
218
219 // Create a minimal Agent instance for testing
220 agent := &Agent{
221 convo: mockConvo,
222 inbox: make(chan string, 10),
Philip Zeyliger9373c072025-05-01 10:27:01 -0700223 subscribers: []chan *AgentMessage{},
Sean McCullough9f4b8082025-04-30 17:34:07 +0000224 outstandingLLMCalls: make(map[string]struct{}),
225 outstandingToolCalls: make(map[string]string),
226 }
227
228 // Create a test context
229 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
230 defer cancel()
231
232 // Push a test message to the inbox so that processUserMessage will try to process it
233 agent.inbox <- "Test message"
234
235 // Call processTurn - it should exit early without panic when initialResp is nil
236 agent.processTurn(ctx)
237
Philip Zeyliger9373c072025-05-01 10:27:01 -0700238 // Verify error message was added to history
239 agent.mu.Lock()
240 defer agent.mu.Unlock()
241
242 // There should be exactly one message
243 if len(agent.history) != 1 {
244 t.Errorf("Expected exactly one message, got %d", len(agent.history))
245 } else {
246 msg := agent.history[0]
Sean McCullough9f4b8082025-04-30 17:34:07 +0000247 if msg.Type != ErrorMessageType {
248 t.Errorf("Expected error message, got message type: %s", msg.Type)
249 }
250 if !strings.Contains(msg.Content, "simulating nil response") {
251 t.Errorf("Expected error message to contain 'simulating nil response', got: %s", msg.Content)
252 }
Sean McCullough9f4b8082025-04-30 17:34:07 +0000253 }
254}
255
256// MockConvoInterface implements the ConvoInterface for testing
257type MockConvoInterface struct {
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700258 sendMessageFunc func(message llm.Message) (*llm.Response, error)
259 sendUserTextMessageFunc func(s string, otherContents ...llm.Content) (*llm.Response, error)
Josh Bleecher Snyder64f2aa82025-05-14 18:31:05 +0000260 toolResultContentsFunc func(ctx context.Context, resp *llm.Response) ([]llm.Content, bool, error)
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700261 toolResultCancelContentsFunc func(resp *llm.Response) ([]llm.Content, error)
Sean McCullough9f4b8082025-04-30 17:34:07 +0000262 cancelToolUseFunc func(toolUseID string, cause error) error
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700263 cumulativeUsageFunc func() conversation.CumulativeUsage
Philip Zeyligerb8a8f352025-06-02 07:39:37 -0700264 lastUsageFunc func() llm.Usage
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700265 resetBudgetFunc func(conversation.Budget)
Sean McCullough9f4b8082025-04-30 17:34:07 +0000266 overBudgetFunc func() error
Philip Zeyliger2c4db092025-04-28 16:57:50 -0700267 getIDFunc func() string
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700268 subConvoWithHistoryFunc func() *conversation.Convo
Sean McCullough9f4b8082025-04-30 17:34:07 +0000269}
270
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700271func (m *MockConvoInterface) SendMessage(message llm.Message) (*llm.Response, error) {
Sean McCullough9f4b8082025-04-30 17:34:07 +0000272 if m.sendMessageFunc != nil {
273 return m.sendMessageFunc(message)
274 }
275 return nil, nil
276}
277
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700278func (m *MockConvoInterface) SendUserTextMessage(s string, otherContents ...llm.Content) (*llm.Response, error) {
Sean McCullough9f4b8082025-04-30 17:34:07 +0000279 if m.sendUserTextMessageFunc != nil {
280 return m.sendUserTextMessageFunc(s, otherContents...)
281 }
282 return nil, nil
283}
284
Josh Bleecher Snyder64f2aa82025-05-14 18:31:05 +0000285func (m *MockConvoInterface) ToolResultContents(ctx context.Context, resp *llm.Response) ([]llm.Content, bool, error) {
Sean McCullough9f4b8082025-04-30 17:34:07 +0000286 if m.toolResultContentsFunc != nil {
287 return m.toolResultContentsFunc(ctx, resp)
288 }
Josh Bleecher Snyder64f2aa82025-05-14 18:31:05 +0000289 return nil, false, nil
Sean McCullough9f4b8082025-04-30 17:34:07 +0000290}
291
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700292func (m *MockConvoInterface) ToolResultCancelContents(resp *llm.Response) ([]llm.Content, error) {
Sean McCullough9f4b8082025-04-30 17:34:07 +0000293 if m.toolResultCancelContentsFunc != nil {
294 return m.toolResultCancelContentsFunc(resp)
295 }
296 return nil, nil
297}
298
299func (m *MockConvoInterface) CancelToolUse(toolUseID string, cause error) error {
300 if m.cancelToolUseFunc != nil {
301 return m.cancelToolUseFunc(toolUseID, cause)
302 }
303 return nil
304}
305
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700306func (m *MockConvoInterface) CumulativeUsage() conversation.CumulativeUsage {
Sean McCullough9f4b8082025-04-30 17:34:07 +0000307 if m.cumulativeUsageFunc != nil {
308 return m.cumulativeUsageFunc()
309 }
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700310 return conversation.CumulativeUsage{}
Sean McCullough9f4b8082025-04-30 17:34:07 +0000311}
312
Philip Zeyligerb8a8f352025-06-02 07:39:37 -0700313func (m *MockConvoInterface) LastUsage() llm.Usage {
314 if m.lastUsageFunc != nil {
315 return m.lastUsageFunc()
316 }
317 return llm.Usage{}
318}
319
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700320func (m *MockConvoInterface) ResetBudget(budget conversation.Budget) {
Sean McCullough9f4b8082025-04-30 17:34:07 +0000321 if m.resetBudgetFunc != nil {
322 m.resetBudgetFunc(budget)
323 }
324}
325
326func (m *MockConvoInterface) OverBudget() error {
327 if m.overBudgetFunc != nil {
328 return m.overBudgetFunc()
329 }
330 return nil
331}
332
Philip Zeyliger2c4db092025-04-28 16:57:50 -0700333func (m *MockConvoInterface) GetID() string {
334 if m.getIDFunc != nil {
335 return m.getIDFunc()
336 }
337 return "mock-convo-id"
338}
339
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700340func (m *MockConvoInterface) SubConvoWithHistory() *conversation.Convo {
Philip Zeyliger2c4db092025-04-28 16:57:50 -0700341 if m.subConvoWithHistoryFunc != nil {
342 return m.subConvoWithHistoryFunc()
343 }
344 return nil
345}
346
Sean McCullough9f4b8082025-04-30 17:34:07 +0000347// TestAgentProcessTurnWithNilResponseNilError tests the scenario where Agent.processTurn receives
348// a nil value for initialResp and nil error from processUserMessage.
349// This test verifies that the implementation properly handles this edge case.
350func TestAgentProcessTurnWithNilResponseNilError(t *testing.T) {
351 // Create a mock conversation that will return nil response and nil error
352 mockConvo := &MockConvoInterface{
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700353 sendMessageFunc: func(message llm.Message) (*llm.Response, error) {
Sean McCullough9f4b8082025-04-30 17:34:07 +0000354 return nil, nil // This is unusual but now handled gracefully
355 },
356 }
357
358 // Create a minimal Agent instance for testing
359 agent := &Agent{
360 convo: mockConvo,
361 inbox: make(chan string, 10),
Philip Zeyliger9373c072025-05-01 10:27:01 -0700362 subscribers: []chan *AgentMessage{},
Sean McCullough9f4b8082025-04-30 17:34:07 +0000363 outstandingLLMCalls: make(map[string]struct{}),
364 outstandingToolCalls: make(map[string]string),
365 }
366
367 // Create a test context
368 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
369 defer cancel()
370
371 // Push a test message to the inbox so that processUserMessage will try to process it
372 agent.inbox <- "Test message"
373
374 // Call processTurn - it should handle nil initialResp with a descriptive error
375 err := agent.processTurn(ctx)
376
377 // Verify we get the expected error
378 if err == nil {
379 t.Error("Expected processTurn to return an error for nil initialResp, but got nil")
380 } else if !strings.Contains(err.Error(), "unexpected nil response") {
381 t.Errorf("Expected error about nil response, got: %v", err)
382 } else {
383 t.Logf("As expected, processTurn returned error: %v", err)
384 }
385
Philip Zeyliger9373c072025-05-01 10:27:01 -0700386 // Verify error message was added to history
387 agent.mu.Lock()
388 defer agent.mu.Unlock()
389
390 // There should be exactly one message
391 if len(agent.history) != 1 {
392 t.Errorf("Expected exactly one message, got %d", len(agent.history))
393 } else {
394 msg := agent.history[0]
Sean McCullough9f4b8082025-04-30 17:34:07 +0000395 if msg.Type != ErrorMessageType {
396 t.Errorf("Expected error message type, got: %s", msg.Type)
397 }
398 if !strings.Contains(msg.Content, "unexpected nil response") {
399 t.Errorf("Expected error about nil response, got: %s", msg.Content)
400 }
Sean McCullough9f4b8082025-04-30 17:34:07 +0000401 }
402}
Sean McCullough96b60dd2025-04-30 09:49:10 -0700403
404func TestAgentStateMachine(t *testing.T) {
405 // Create a simplified test for the state machine functionality
406 agent := &Agent{
407 stateMachine: NewStateMachine(),
408 }
409
410 // Initially the state should be Ready
411 if state := agent.CurrentState(); state != StateReady {
412 t.Errorf("Expected initial state to be StateReady, got %s", state)
413 }
414
415 // Test manual transitions to verify state tracking
416 ctx := context.Background()
417
418 // Track transitions
419 var transitions []State
420 agent.stateMachine.SetTransitionCallback(func(ctx context.Context, from, to State, event TransitionEvent) {
421 transitions = append(transitions, to)
422 t.Logf("State transition: %s -> %s (%s)", from, to, event.Description)
423 })
424
425 // Perform a valid sequence of transitions (based on the state machine rules)
426 expectedStates := []State{
427 StateWaitingForUserInput,
428 StateSendingToLLM,
429 StateProcessingLLMResponse,
430 StateToolUseRequested,
431 StateCheckingForCancellation,
432 StateRunningTool,
433 StateCheckingGitCommits,
434 StateRunningAutoformatters,
435 StateCheckingBudget,
436 StateGatheringAdditionalMessages,
437 StateSendingToolResults,
438 StateProcessingLLMResponse,
439 StateEndOfTurn,
440 }
441
442 // Manually perform each transition
443 for _, state := range expectedStates {
444 err := agent.stateMachine.Transition(ctx, state, "Test transition to "+state.String())
445 if err != nil {
446 t.Errorf("Failed to transition to %s: %v", state, err)
447 }
448 }
449
450 // Check if we recorded the right number of transitions
451 if len(transitions) != len(expectedStates) {
452 t.Errorf("Expected %d state transitions, got %d", len(expectedStates), len(transitions))
453 }
454
455 // Check each transition matched what we expected
456 for i, expected := range expectedStates {
457 if i < len(transitions) {
458 if transitions[i] != expected {
459 t.Errorf("Transition %d: expected %s, got %s", i, expected, transitions[i])
460 }
461 }
462 }
463
464 // Verify the current state is the last one we transitioned to
465 if state := agent.CurrentState(); state != expectedStates[len(expectedStates)-1] {
466 t.Errorf("Expected current state to be %s, got %s", expectedStates[len(expectedStates)-1], state)
467 }
468
469 // Test force transition
470 agent.stateMachine.ForceTransition(ctx, StateCancelled, "Testing force transition")
471
472 // Verify current state was updated
473 if state := agent.CurrentState(); state != StateCancelled {
474 t.Errorf("Expected forced state to be StateCancelled, got %s", state)
475 }
476}
477
478// mockConvoInterface is a mock implementation of ConvoInterface for testing
479type mockConvoInterface struct {
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700480 SendMessageFunc func(message llm.Message) (*llm.Response, error)
Josh Bleecher Snyder64f2aa82025-05-14 18:31:05 +0000481 ToolResultContentsFunc func(ctx context.Context, resp *llm.Response) ([]llm.Content, bool, error)
Sean McCullough96b60dd2025-04-30 09:49:10 -0700482}
483
484func (c *mockConvoInterface) GetID() string {
485 return "mockConvoInterface-id"
486}
487
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700488func (c *mockConvoInterface) SubConvoWithHistory() *conversation.Convo {
Sean McCullough96b60dd2025-04-30 09:49:10 -0700489 return nil
490}
491
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700492func (m *mockConvoInterface) CumulativeUsage() conversation.CumulativeUsage {
493 return conversation.CumulativeUsage{}
Sean McCullough96b60dd2025-04-30 09:49:10 -0700494}
495
Philip Zeyligerb8a8f352025-06-02 07:39:37 -0700496func (m *mockConvoInterface) LastUsage() llm.Usage {
497 return llm.Usage{}
498}
499
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700500func (m *mockConvoInterface) ResetBudget(conversation.Budget) {}
Sean McCullough96b60dd2025-04-30 09:49:10 -0700501
502func (m *mockConvoInterface) OverBudget() error {
503 return nil
504}
505
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700506func (m *mockConvoInterface) SendMessage(message llm.Message) (*llm.Response, error) {
Sean McCullough96b60dd2025-04-30 09:49:10 -0700507 if m.SendMessageFunc != nil {
508 return m.SendMessageFunc(message)
509 }
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700510 return &llm.Response{StopReason: llm.StopReasonEndTurn}, nil
Sean McCullough96b60dd2025-04-30 09:49:10 -0700511}
512
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700513func (m *mockConvoInterface) SendUserTextMessage(s string, otherContents ...llm.Content) (*llm.Response, error) {
514 return m.SendMessage(llm.UserStringMessage(s))
Sean McCullough96b60dd2025-04-30 09:49:10 -0700515}
516
Josh Bleecher Snyder64f2aa82025-05-14 18:31:05 +0000517func (m *mockConvoInterface) ToolResultContents(ctx context.Context, resp *llm.Response) ([]llm.Content, bool, error) {
Sean McCullough96b60dd2025-04-30 09:49:10 -0700518 if m.ToolResultContentsFunc != nil {
519 return m.ToolResultContentsFunc(ctx, resp)
520 }
Josh Bleecher Snyder64f2aa82025-05-14 18:31:05 +0000521 return []llm.Content{}, false, nil
Sean McCullough96b60dd2025-04-30 09:49:10 -0700522}
523
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700524func (m *mockConvoInterface) ToolResultCancelContents(resp *llm.Response) ([]llm.Content, error) {
525 return []llm.Content{llm.StringContent("Tool use cancelled")}, nil
Sean McCullough96b60dd2025-04-30 09:49:10 -0700526}
527
528func (m *mockConvoInterface) CancelToolUse(toolUseID string, cause error) error {
529 return nil
530}
531
532func TestAgentProcessTurnStateTransitions(t *testing.T) {
533 // Create a mock ConvoInterface for testing
534 mockConvo := &mockConvoInterface{}
535
536 // Use the testing context
537 ctx := t.Context()
538
539 // Create an agent with the state machine
540 agent := &Agent{
Philip Zeyligerf2872992025-05-22 10:35:28 -0700541 convo: mockConvo,
542 config: AgentConfig{Context: ctx},
543 inbox: make(chan string, 10),
544 ready: make(chan struct{}),
545
Sean McCullough96b60dd2025-04-30 09:49:10 -0700546 outstandingLLMCalls: make(map[string]struct{}),
547 outstandingToolCalls: make(map[string]string),
548 stateMachine: NewStateMachine(),
549 startOfTurn: time.Now(),
Philip Zeyliger9373c072025-05-01 10:27:01 -0700550 subscribers: []chan *AgentMessage{},
Sean McCullough96b60dd2025-04-30 09:49:10 -0700551 }
552
553 // Verify initial state
554 if state := agent.CurrentState(); state != StateReady {
555 t.Errorf("Expected initial state to be StateReady, got %s", state)
556 }
557
558 // Add a message to the inbox so we don't block in GatherMessages
559 agent.inbox <- "Test message"
560
561 // Setup the mock to simulate a model response with end of turn
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700562 mockConvo.SendMessageFunc = func(message llm.Message) (*llm.Response, error) {
563 return &llm.Response{
564 StopReason: llm.StopReasonEndTurn,
565 Content: []llm.Content{
566 llm.StringContent("This is a test response"),
Sean McCullough96b60dd2025-04-30 09:49:10 -0700567 },
568 }, nil
569 }
570
571 // Track state transitions
572 var transitions []State
573 agent.stateMachine.SetTransitionCallback(func(ctx context.Context, from, to State, event TransitionEvent) {
574 transitions = append(transitions, to)
575 t.Logf("State transition: %s -> %s (%s)", from, to, event.Description)
576 })
577
578 // Process a turn, which should trigger state transitions
579 agent.processTurn(ctx)
580
581 // The minimum expected states for a simple end-of-turn response
582 minExpectedStates := []State{
583 StateWaitingForUserInput,
584 StateSendingToLLM,
585 StateProcessingLLMResponse,
586 StateEndOfTurn,
587 }
588
589 // Verify we have at least the minimum expected states
590 if len(transitions) < len(minExpectedStates) {
591 t.Errorf("Expected at least %d state transitions, got %d", len(minExpectedStates), len(transitions))
592 }
593
594 // Check that the transitions follow the expected sequence
595 for i, expected := range minExpectedStates {
596 if i < len(transitions) {
597 if transitions[i] != expected {
598 t.Errorf("Transition %d: expected %s, got %s", i, expected, transitions[i])
599 }
600 }
601 }
602
603 // Verify the final state is EndOfTurn
604 if state := agent.CurrentState(); state != StateEndOfTurn {
605 t.Errorf("Expected final state to be StateEndOfTurn, got %s", state)
606 }
607}
608
609func TestAgentProcessTurnWithToolUse(t *testing.T) {
610 // Create a mock ConvoInterface for testing
611 mockConvo := &mockConvoInterface{}
612
613 // Setup a test context
614 ctx := context.Background()
615
616 // Create an agent with the state machine
617 agent := &Agent{
Philip Zeyligerf2872992025-05-22 10:35:28 -0700618 convo: mockConvo,
619 config: AgentConfig{Context: ctx},
620 inbox: make(chan string, 10),
621 ready: make(chan struct{}),
622
Sean McCullough96b60dd2025-04-30 09:49:10 -0700623 outstandingLLMCalls: make(map[string]struct{}),
624 outstandingToolCalls: make(map[string]string),
625 stateMachine: NewStateMachine(),
626 startOfTurn: time.Now(),
Philip Zeyliger9373c072025-05-01 10:27:01 -0700627 subscribers: []chan *AgentMessage{},
Sean McCullough96b60dd2025-04-30 09:49:10 -0700628 }
629
630 // Add a message to the inbox so we don't block in GatherMessages
631 agent.inbox <- "Test message"
632
633 // First response requests a tool
634 firstResponseDone := false
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700635 mockConvo.SendMessageFunc = func(message llm.Message) (*llm.Response, error) {
Sean McCullough96b60dd2025-04-30 09:49:10 -0700636 if !firstResponseDone {
637 firstResponseDone = true
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700638 return &llm.Response{
639 StopReason: llm.StopReasonToolUse,
640 Content: []llm.Content{
641 llm.StringContent("I'll use a tool"),
642 {Type: llm.ContentTypeToolUse, ToolName: "test_tool", ToolInput: []byte("{}"), ID: "test_id"},
Sean McCullough96b60dd2025-04-30 09:49:10 -0700643 },
644 }, nil
645 }
646 // Second response ends the turn
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700647 return &llm.Response{
648 StopReason: llm.StopReasonEndTurn,
649 Content: []llm.Content{
650 llm.StringContent("Finished using the tool"),
Sean McCullough96b60dd2025-04-30 09:49:10 -0700651 },
652 }, nil
653 }
654
655 // Tool result content handler
Josh Bleecher Snyder64f2aa82025-05-14 18:31:05 +0000656 mockConvo.ToolResultContentsFunc = func(ctx context.Context, resp *llm.Response) ([]llm.Content, bool, error) {
657 return []llm.Content{llm.StringContent("Tool executed successfully")}, false, nil
Sean McCullough96b60dd2025-04-30 09:49:10 -0700658 }
659
660 // Track state transitions
661 var transitions []State
662 agent.stateMachine.SetTransitionCallback(func(ctx context.Context, from, to State, event TransitionEvent) {
663 transitions = append(transitions, to)
664 t.Logf("State transition: %s -> %s (%s)", from, to, event.Description)
665 })
666
667 // Process a turn with tool use
668 agent.processTurn(ctx)
669
670 // Define expected states for a tool use flow
671 expectedToolStates := []State{
672 StateWaitingForUserInput,
673 StateSendingToLLM,
674 StateProcessingLLMResponse,
675 StateToolUseRequested,
676 StateCheckingForCancellation,
677 StateRunningTool,
678 }
679
680 // Verify that these states are present in order
681 for i, expectedState := range expectedToolStates {
682 if i >= len(transitions) {
683 t.Errorf("Missing expected transition to %s; only got %d transitions", expectedState, len(transitions))
684 continue
685 }
686 if transitions[i] != expectedState {
687 t.Errorf("Expected transition %d to be %s, got %s", i, expectedState, transitions[i])
688 }
689 }
690
691 // Also verify we eventually reached EndOfTurn
692 if !slices.Contains(transitions, StateEndOfTurn) {
693 t.Errorf("Expected to eventually reach StateEndOfTurn, but never did")
694 }
695}
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700696
697func TestContentToString(t *testing.T) {
698 tests := []struct {
699 name string
700 contents []llm.Content
701 want string
702 }{
703 {
704 name: "empty",
705 contents: []llm.Content{},
706 want: "",
707 },
708 {
709 name: "single text content",
710 contents: []llm.Content{
711 {Type: llm.ContentTypeText, Text: "hello world"},
712 },
713 want: "hello world",
714 },
715 {
716 name: "multiple text content",
717 contents: []llm.Content{
718 {Type: llm.ContentTypeText, Text: "hello "},
719 {Type: llm.ContentTypeText, Text: "world"},
720 },
721 want: "hello world",
722 },
723 {
724 name: "mixed content types",
725 contents: []llm.Content{
726 {Type: llm.ContentTypeText, Text: "hello "},
727 {Type: llm.ContentTypeText, MediaType: "image/png", Data: "base64data"},
728 {Type: llm.ContentTypeText, Text: "world"},
729 },
730 want: "hello world",
731 },
732 {
733 name: "non-text content only",
734 contents: []llm.Content{
735 {Type: llm.ContentTypeToolUse, ToolName: "example"},
736 },
737 want: "",
738 },
739 {
740 name: "nested tool result",
741 contents: []llm.Content{
742 {Type: llm.ContentTypeText, Text: "outer "},
743 {Type: llm.ContentTypeToolResult, ToolResult: []llm.Content{
744 {Type: llm.ContentTypeText, Text: "inner"},
745 }},
746 },
747 want: "outer inner",
748 },
749 {
750 name: "deeply nested tool result",
751 contents: []llm.Content{
752 {Type: llm.ContentTypeToolResult, ToolResult: []llm.Content{
753 {Type: llm.ContentTypeToolResult, ToolResult: []llm.Content{
754 {Type: llm.ContentTypeText, Text: "deeply nested"},
755 }},
756 }},
757 },
758 want: "deeply nested",
759 },
760 }
761
762 for _, tt := range tests {
763 t.Run(tt.name, func(t *testing.T) {
764 if got := contentToString(tt.contents); got != tt.want {
765 t.Errorf("contentToString() = %v, want %v", got, tt.want)
766 }
767 })
768 }
769}
770
771func TestPushToOutbox(t *testing.T) {
772 // Create a new agent
773 a := &Agent{
774 outstandingLLMCalls: make(map[string]struct{}),
775 outstandingToolCalls: make(map[string]string),
776 stateMachine: NewStateMachine(),
777 subscribers: make([]chan *AgentMessage, 0),
778 }
779
780 // Create a channel to receive messages
781 messageCh := make(chan *AgentMessage, 1)
782
783 // Add the channel to the subscribers list
784 a.mu.Lock()
785 a.subscribers = append(a.subscribers, messageCh)
786 a.mu.Unlock()
787
788 // We need to set the text that would be produced by our modified contentToString function
789 resultText := "test resultnested result" // Directly set the expected output
790
791 // In a real-world scenario, this would be coming from a toolResult that contained nested content
792
793 m := AgentMessage{
794 Type: ToolUseMessageType,
795 ToolResult: resultText,
796 }
797
798 // Push the message to the outbox
799 a.pushToOutbox(context.Background(), m)
800
801 // Receive the message from the subscriber
802 received := <-messageCh
803
804 // Check that the Content field contains the concatenated text from ToolResult
805 expected := "test resultnested result"
806 if received.Content != expected {
807 t.Errorf("Expected Content to be %q, got %q", expected, received.Content)
808 }
809}
Philip Zeyliger59e1c162025-06-02 12:54:34 +0000810
811func TestBranchNamingIncrement(t *testing.T) {
812 testCases := []struct {
813 name string
814 originalBranch string
815 expectedBranches []string
816 }{
817 {
818 name: "base branch without number",
819 originalBranch: "sketch/test-branch",
820 expectedBranches: []string{
821 "sketch/test-branch", // retries = 0
822 "sketch/test-branch1", // retries = 1
823 "sketch/test-branch2", // retries = 2
824 "sketch/test-branch3", // retries = 3
825 },
826 },
827 {
828 name: "branch already has number",
829 originalBranch: "sketch/test-branch1",
830 expectedBranches: []string{
831 "sketch/test-branch1", // retries = 0
832 "sketch/test-branch2", // retries = 1
833 "sketch/test-branch3", // retries = 2
834 "sketch/test-branch4", // retries = 3
835 },
836 },
837 {
838 name: "branch with larger number",
839 originalBranch: "sketch/test-branch42",
840 expectedBranches: []string{
841 "sketch/test-branch42", // retries = 0
842 "sketch/test-branch43", // retries = 1
843 "sketch/test-branch44", // retries = 2
844 "sketch/test-branch45", // retries = 3
845 },
846 },
847 }
848
849 for _, tc := range testCases {
850 t.Run(tc.name, func(t *testing.T) {
851 // Parse the original branch name to extract base name and starting number
852 baseBranch, startNum := parseBranchNameAndNumber(tc.originalBranch)
853
854 // Simulate the retry logic
855 for retries := range len(tc.expectedBranches) {
856 var branch string
857 if retries > 0 {
858 // This is the same logic used in the actual code
859 branch = fmt.Sprintf("%s%d", baseBranch, startNum+retries)
860 } else {
861 branch = tc.originalBranch
862 }
863
864 if branch != tc.expectedBranches[retries] {
865 t.Errorf("Retry %d: expected %s, got %s", retries, tc.expectedBranches[retries], branch)
866 }
867 }
868 })
869 }
870}
871
872func TestParseBranchNameAndNumber(t *testing.T) {
873 testCases := []struct {
874 branchName string
875 expectedBase string
876 expectedNumber int
877 }{
878 {"sketch/test-branch", "sketch/test-branch", 0},
879 {"sketch/test-branch1", "sketch/test-branch", 1},
880 {"sketch/test-branch42", "sketch/test-branch", 42},
881 {"sketch/test-branch-foo", "sketch/test-branch-foo", 0},
882 {"sketch/test-branch-foo123", "sketch/test-branch-foo", 123},
883 {"main", "main", 0},
884 {"main2", "main", 2},
885 {"feature/abc123def", "feature/abc123def", 0}, // number in middle, not at end
886 {"feature/abc123def456", "feature/abc123def", 456}, // number at end
887 }
888
889 for _, tc := range testCases {
890 t.Run(tc.branchName, func(t *testing.T) {
891 base, num := parseBranchNameAndNumber(tc.branchName)
892 if base != tc.expectedBase {
893 t.Errorf("Base: expected %s, got %s", tc.expectedBase, base)
894 }
895 if num != tc.expectedNumber {
896 t.Errorf("Number: expected %d, got %d", tc.expectedNumber, num)
897 }
898 })
899 }
900}