Add StateMachine tracking to Agent control flow
This commit integrates the existing StateMachine type with the Agent
to provide state tracking and validation throughout the conversation
control flow. This allows for better monitoring and debugging of the
Agent's behavior during execution.
It also adds tests that verify the state machine's integration with the
Agent type:
1. A test for basic state transitions
2. A test for a complete processTurn flow with a simple response
3. A test for a processTurn flow with tool use
Co-Authored-By: sketch <hello@sketch.dev>
diff --git a/loop/agent_test.go b/loop/agent_test.go
index bde0d20..c62fd21 100644
--- a/loop/agent_test.go
+++ b/loop/agent_test.go
@@ -5,6 +5,7 @@
"fmt"
"net/http"
"os"
+ "slices"
"strings"
"testing"
"time"
@@ -158,6 +159,7 @@
agent := &Agent{
outstandingLLMCalls: make(map[string]struct{}),
outstandingToolCalls: make(map[string]string),
+ stateMachine: NewStateMachine(),
}
// Check initial state
@@ -392,3 +394,292 @@
t.Error("Timed out waiting for error message in outbox")
}
}
+
+// TestAgentStateMachine walks the Agent's state machine by hand through a
+// complete tool-use turn and checks that each step is reported and observable.
+func TestAgentStateMachine(t *testing.T) {
+	// Only the state machine field is populated on this Agent.
+	agent := &Agent{stateMachine: NewStateMachine()}
+	ctx := context.Background()
+
+	// A freshly built state machine must report StateReady.
+	if got := agent.CurrentState(); got != StateReady {
+		t.Errorf("Expected initial state to be StateReady, got %s", got)
+	}
+
+	// Record the destination of every transition the callback reports. The
+	// slice is read only once the transition loop below has finished.
+	var seen []State
+	agent.stateMachine.SetTransitionCallback(func(ctx context.Context, from, to State, event TransitionEvent) {
+		seen = append(seen, to)
+		t.Logf("State transition: %s -> %s (%s)", from, to, event.Description)
+	})
+
+	// One complete turn: user input, an LLM reply that asks for a tool, the
+	// checks that follow the tool run, the results going back, and the end.
+	want := []State{
+		StateWaitingForUserInput,
+		StateSendingToLLM,
+		StateProcessingLLMResponse,
+		StateToolUseRequested,
+		StateCheckingForCancellation,
+		StateRunningTool,
+		StateCheckingGitCommits,
+		StateRunningAutoformatters,
+		StateCheckingBudget,
+		StateGatheringAdditionalMessages,
+		StateSendingToolResults,
+		StateProcessingLLMResponse,
+		StateEndOfTurn,
+	}
+
+	// Request the steps one at a time; Transition must not return an error for
+	// any of them.
+	for _, next := range want {
+		if err := agent.stateMachine.Transition(ctx, next, "Test transition to "+next.String()); err != nil {
+			t.Errorf("Failed to transition to %s: %v", next, err)
+		}
+	}
+
+	// The callback must have fired once for each requested transition.
+	if len(seen) != len(want) {
+		t.Errorf("Expected %d state transitions, got %d", len(want), len(seen))
+	}
+
+	// Compare position by position over the prefix both slices share, so a
+	// count mismatch reported above cannot cause an out-of-range index here.
+	for i := 0; i < len(want) && i < len(seen); i++ {
+		if seen[i] != want[i] {
+			t.Errorf("Transition %d: expected %s, got %s", i, want[i], seen[i])
+		}
+	}
+
+	// CurrentState must agree with the last transition requested.
+	last := want[len(want)-1]
+	if got := agent.CurrentState(); got != last {
+		t.Errorf("Expected current state to be %s, got %s", last, got)
+	}
+
+	// Finally, force a jump to StateCancelled.
+	agent.stateMachine.ForceTransition(ctx, StateCancelled, "Testing force transition")
+
+	// CurrentState must reflect the forced state.
+	if got := agent.CurrentState(); got != StateCancelled {
+		t.Errorf("Expected forced state to be StateCancelled, got %s", got)
+	}
+}
+
+// mockConvoInterface is a ConvoInterface test double; its Func fields, when set, script the replies.
+type mockConvoInterface struct {
+	SendMessageFunc        func(message ant.Message) (*ant.MessageResponse, error)
+	ToolResultContentsFunc func(ctx context.Context, resp *ant.MessageResponse) ([]ant.Content, error)
+}
+
+// GetID returns a constant identifier.
+func (m *mockConvoInterface) GetID() string { return "mockConvoInterface-id" }
+
+// SubConvoWithHistory always returns nil.
+func (m *mockConvoInterface) SubConvoWithHistory() *ant.Convo { return nil }
+
+// CumulativeUsage reports a zero-valued usage.
+func (m *mockConvoInterface) CumulativeUsage() ant.CumulativeUsage { return ant.CumulativeUsage{} }
+
+// ResetBudget is a no-op.
+func (m *mockConvoInterface) ResetBudget(ant.Budget) {}
+
+// OverBudget never reports an exceeded budget.
+func (m *mockConvoInterface) OverBudget() error { return nil }
+
+// SendMessage delegates to SendMessageFunc when set; otherwise it returns an end-of-turn response.
+func (m *mockConvoInterface) SendMessage(message ant.Message) (*ant.MessageResponse, error) {
+	if m.SendMessageFunc != nil {
+		return m.SendMessageFunc(message)
+	}
+	return &ant.MessageResponse{StopReason: ant.StopReasonEndTurn}, nil
+}
+
+// SendUserTextMessage sends s, followed by otherContents, as one user message via SendMessage.
+func (m *mockConvoInterface) SendUserTextMessage(s string, otherContents ...ant.Content) (*ant.MessageResponse, error) {
+	return m.SendMessage(ant.Message{Role: "user", Content: append([]ant.Content{{Type: "text", Text: s}}, otherContents...)})
+}
+
+// ToolResultContents delegates to ToolResultContentsFunc when set; otherwise it returns no contents.
+func (m *mockConvoInterface) ToolResultContents(ctx context.Context, resp *ant.MessageResponse) ([]ant.Content, error) {
+	if m.ToolResultContentsFunc != nil {
+		return m.ToolResultContentsFunc(ctx, resp)
+	}
+	return []ant.Content{}, nil
+}
+
+// ToolResultCancelContents returns a single text item saying the tool use was cancelled.
+func (m *mockConvoInterface) ToolResultCancelContents(resp *ant.MessageResponse) ([]ant.Content, error) {
+	return []ant.Content{{Type: "text", Text: "Tool use cancelled"}}, nil
+}
+
+// CancelToolUse ignores its arguments and reports success.
+func (m *mockConvoInterface) CancelToolUse(toolUseID string, cause error) error { return nil }
+
+// TestAgentProcessTurnStateTransitions runs one processTurn against a mock
+// conversation whose reply ends the turn, and checks the states reported.
+func TestAgentProcessTurnStateTransitions(t *testing.T) {
+	ctx := t.Context()
+
+	// The mock answers every message with a plain text reply that ends the
+	// turn, so no tool-use path is exercised.
+	convo := &mockConvoInterface{
+		SendMessageFunc: func(message ant.Message) (*ant.MessageResponse, error) {
+			reply := &ant.MessageResponse{
+				StopReason: ant.StopReasonEndTurn,
+				Content:    []ant.Content{{Type: "text", Text: "This is a test response"}},
+			}
+			return reply, nil
+		},
+	}
+
+	// Wire the mock and a fresh state machine into an Agent; the channels are
+	// buffered so the test can fill the inbox without a reader.
+	agent := &Agent{
+		convo:                convo,
+		config:               AgentConfig{Context: ctx},
+		inbox:                make(chan string, 10),
+		outbox:               make(chan AgentMessage, 10),
+		ready:                make(chan struct{}),
+		seenCommits:          make(map[string]bool),
+		outstandingLLMCalls:  make(map[string]struct{}),
+		outstandingToolCalls: make(map[string]string),
+		stateMachine:         NewStateMachine(),
+		startOfTurn:          time.Now(),
+	}
+
+	// Nothing has run yet, so the machine must still be in StateReady.
+	if got := agent.CurrentState(); got != StateReady {
+		t.Errorf("Expected initial state to be StateReady, got %s", got)
+	}
+
+	// Queue one user message up front so gathering input does not block.
+	agent.inbox <- "Test message"
+
+	// Record the destination of every transition the callback reports.
+	var seen []State
+	agent.stateMachine.SetTransitionCallback(func(ctx context.Context, from, to State, event TransitionEvent) {
+		seen = append(seen, to)
+		t.Logf("State transition: %s -> %s (%s)", from, to, event.Description)
+	})
+
+	// Run a single turn; every assertion below inspects what it recorded.
+	agent.processTurn(ctx)
+
+	// A turn without tool use has to begin with at least these four states.
+	wantPrefix := []State{
+		StateWaitingForUserInput,
+		StateSendingToLLM,
+		StateProcessingLLMResponse,
+		StateEndOfTurn,
+	}
+
+	// Fewer transitions than that means part of the flow was never reported.
+	if len(seen) < len(wantPrefix) {
+		t.Errorf("Expected at least %d state transitions, got %d", len(wantPrefix), len(seen))
+	}
+
+	// Compare position by position over the prefix both slices share, so a
+	// short recording reported above cannot cause an out-of-range index here.
+	for i := 0; i < len(wantPrefix) && i < len(seen); i++ {
+		if seen[i] != wantPrefix[i] {
+			t.Errorf("Transition %d: expected %s, got %s", i, wantPrefix[i], seen[i])
+		}
+	}
+
+	// Whatever else was recorded, the machine must come to rest in StateEndOfTurn.
+	if got := agent.CurrentState(); got != StateEndOfTurn {
+		t.Errorf("Expected final state to be StateEndOfTurn, got %s", got)
+	}
+}
+
+// TestAgentProcessTurnWithToolUse runs one processTurn in which the mock's first
+// reply requests a tool and its second ends the turn, then checks the states seen.
+func TestAgentProcessTurnWithToolUse(t *testing.T) {
+	mockConvo := &mockConvoInterface{}
+
+	// Use the test-scoped context, as TestAgentProcessTurnStateTransitions does,
+	// so it is canceled when this test finishes instead of never being canceled.
+	ctx := t.Context()
+
+	// Wire the mock and a fresh state machine into an otherwise minimal Agent.
+	agent := &Agent{
+		convo:                mockConvo,
+		config:               AgentConfig{Context: ctx},
+		inbox:                make(chan string, 10),
+		outbox:               make(chan AgentMessage, 10),
+		ready:                make(chan struct{}),
+		seenCommits:          make(map[string]bool),
+		outstandingLLMCalls:  make(map[string]struct{}),
+		outstandingToolCalls: make(map[string]string),
+		stateMachine:         NewStateMachine(),
+		startOfTurn:          time.Now(),
+	}
+
+	// Queue one user message up front so gathering input does not block.
+	agent.inbox <- "Test message"
+
+	// Script the replies: the first asks for a tool, every later one ends the turn.
+	firstResponseDone := false
+	mockConvo.SendMessageFunc = func(message ant.Message) (*ant.MessageResponse, error) {
+		if !firstResponseDone {
+			firstResponseDone = true
+			return &ant.MessageResponse{
+				StopReason: ant.StopReasonToolUse,
+				Content: []ant.Content{
+					{Type: "text", Text: "I'll use a tool"},
+					{Type: "tool_use", ToolName: "test_tool", ToolInput: []byte("{}"), ID: "test_id"},
+				},
+			}, nil
+		}
+		return &ant.MessageResponse{
+			StopReason: ant.StopReasonEndTurn,
+			Content:    []ant.Content{{Type: "text", Text: "Finished using the tool"}},
+		}, nil
+	}
+
+	// The mock reports a successful result for whatever response it is handed.
+	mockConvo.ToolResultContentsFunc = func(ctx context.Context, resp *ant.MessageResponse) ([]ant.Content, error) {
+		return []ant.Content{{Type: "text", Text: "Tool executed successfully"}}, nil
+	}
+
+	// Record the destination of every transition the callback reports.
+	var transitions []State
+	agent.stateMachine.SetTransitionCallback(func(ctx context.Context, from, to State, event TransitionEvent) {
+		transitions = append(transitions, to)
+		t.Logf("State transition: %s -> %s (%s)", from, to, event.Description)
+	})
+
+	// Run a single turn; every assertion below inspects what it recorded.
+	agent.processTurn(ctx)
+
+	// A tool-use turn has to begin with exactly these states, in this order.
+	expectedToolStates := []State{
+		StateWaitingForUserInput,
+		StateSendingToLLM,
+		StateProcessingLLMResponse,
+		StateToolUseRequested,
+		StateCheckingForCancellation,
+		StateRunningTool,
+	}
+
+	// Check that prefix position by position. A recording that is too short is
+	// reported once per missing state rather than indexed out of range.
+	for i, expectedState := range expectedToolStates {
+		if i >= len(transitions) {
+			t.Errorf("Missing expected transition to %s; only got %d transitions", expectedState, len(transitions))
+			continue
+		}
+		if transitions[i] != expectedState {
+			t.Errorf("Expected transition %d to be %s, got %s", i, expectedState, transitions[i])
+		}
+	}
+
+	// What follows the tool run is not pinned down here, only that the turn ended.
+	if !slices.Contains(transitions, StateEndOfTurn) {
+		t.Errorf("Expected to eventually reach StateEndOfTurn, but never did")
+	}
+}