Add StateMachine tracking to Agent control flow
This commit integrates the existing StateMachine type with the Agent
to provide state tracking and validation throughout the conversation
control flow. This allows for better monitoring and debugging of the
Agent's behavior during execution.
This commit adds tests to verify the correct behavior of the state
machine integration with the Agent type:
1. A test for basic state transitions
2. A test for a complete processTurn flow with a simple response
3. A test for a processTurn flow with tool use
Co-Authored-By: sketch <hello@sketch.dev>
diff --git a/loop/agent.go b/loop/agent.go
index b698d22..b12a5f5 100644
--- a/loop/agent.go
+++ b/loop/agent.go
@@ -280,6 +280,8 @@
title string
branchName string
codereview *claudetool.CodeReviewer
+ // State machine to track agent state
+ stateMachine *StateMachine
// Outside information
outsideHostname string
outsideOS string
@@ -391,6 +393,11 @@
return a.gitOrigin
}
+// CurrentState returns the current state of the agent's state machine.
+func (a *Agent) CurrentState() State {
+ return a.stateMachine.CurrentState()
+}
+
func (a *Agent) IsInContainer() bool {
return a.config.InDocker
}
@@ -577,6 +584,7 @@
outsideWorkingDir: config.OutsideWorkingDir,
outstandingLLMCalls: make(map[string]struct{}),
outstandingToolCalls: make(map[string]string),
+ stateMachine: NewStateMachine(),
}
return agent
}
@@ -827,6 +835,9 @@
a.cancelTurnMu.Lock()
defer a.cancelTurnMu.Unlock()
if a.cancelTurn != nil {
+ // Force state transition to cancelled state
+ ctx := a.config.Context
+ a.stateMachine.ForceTransition(ctx, StateCancelled, "User cancelled turn: "+cause.Error())
a.cancelTurn(cause)
}
}
@@ -906,15 +917,21 @@
// Reset the start of turn time
a.startOfTurn = time.Now()
+ // Transition to waiting for user input state
+ a.stateMachine.Transition(ctx, StateWaitingForUserInput, "Starting turn")
+
// Process initial user message
initialResp, err := a.processUserMessage(ctx)
if err != nil {
+ a.stateMachine.Transition(ctx, StateError, "Error processing user message: "+err.Error())
return err
}
// Handle edge case where both initialResp and err are nil
if initialResp == nil {
err := fmt.Errorf("unexpected nil response from processUserMessage with no error")
+ a.stateMachine.Transition(ctx, StateError, "Error processing user message: "+err.Error())
+
a.pushToOutbox(ctx, errorMessage(err))
return err
}
@@ -932,14 +949,19 @@
for {
// Check if we are over budget
if err := a.overBudget(ctx); err != nil {
+ a.stateMachine.Transition(ctx, StateBudgetExceeded, "Budget exceeded: "+err.Error())
return err
}
// If the model is not requesting to use a tool, we're done
if resp.StopReason != ant.StopReasonToolUse {
+ a.stateMachine.Transition(ctx, StateEndOfTurn, "LLM completed response, ending turn")
break
}
+ // Transition to tool use requested state
+ a.stateMachine.Transition(ctx, StateToolUseRequested, "LLM requested tool use")
+
// Handle tool execution
continueConversation, toolResp := a.handleToolExecution(ctx, resp)
if !continueConversation {
@@ -958,6 +980,7 @@
// Wait for at least one message from the user
msgs, err := a.GatherMessages(ctx, true)
if err != nil { // e.g. the context was canceled while blocking in GatherMessages
+ a.stateMachine.Transition(ctx, StateError, "Error gathering messages: "+err.Error())
return nil, err
}
@@ -966,13 +989,20 @@
Content: msgs,
}
+ // Transition to sending to LLM state
+ a.stateMachine.Transition(ctx, StateSendingToLLM, "Sending user message to LLM")
+
// Send message to the model
resp, err := a.convo.SendMessage(userMessage)
if err != nil {
+ a.stateMachine.Transition(ctx, StateError, "Error sending to LLM: "+err.Error())
a.pushToOutbox(ctx, errorMessage(err))
return nil, err
}
+ // Transition to processing LLM response state
+ a.stateMachine.Transition(ctx, StateProcessingLLMResponse, "Processing LLM response")
+
return resp, nil
}
@@ -981,6 +1011,9 @@
var results []ant.Content
cancelled := false
+ // Transition to checking for cancellation state
+ a.stateMachine.Transition(ctx, StateCheckingForCancellation, "Checking if user requested cancellation")
+
// Check if the operation was cancelled by the user
select {
case <-ctx.Done():
@@ -989,10 +1022,15 @@
var err error
results, err = a.convo.ToolResultCancelContents(resp)
if err != nil {
+ a.stateMachine.Transition(ctx, StateError, "Error creating cancellation response: "+err.Error())
a.pushToOutbox(ctx, errorMessage(err))
}
cancelled = true
+ a.stateMachine.Transition(ctx, StateCancelled, "Operation cancelled by user")
default:
+ // Transition to running tool state
+ a.stateMachine.Transition(ctx, StateRunningTool, "Executing requested tool")
+
// Add working directory to context for tool execution
ctx = claudetool.WithWorkingDir(ctx, a.workingDir)
@@ -1001,16 +1039,21 @@
results, err = a.convo.ToolResultContents(ctx, resp)
if ctx.Err() != nil { // e.g. the user canceled the operation
cancelled = true
+ a.stateMachine.Transition(ctx, StateCancelled, "Operation cancelled during tool execution")
} else if err != nil {
+ a.stateMachine.Transition(ctx, StateError, "Error executing tool: "+err.Error())
a.pushToOutbox(ctx, errorMessage(err))
}
}
// Process git commits that may have occurred during tool execution
+ a.stateMachine.Transition(ctx, StateCheckingGitCommits, "Checking for git commits")
autoqualityMessages := a.processGitChanges(ctx)
// Check budget again after tool execution
+ a.stateMachine.Transition(ctx, StateCheckingBudget, "Checking budget after tool execution")
if err := a.overBudget(ctx); err != nil {
+ a.stateMachine.Transition(ctx, StateBudgetExceeded, "Budget exceeded after tool execution: "+err.Error())
return false, nil
}
@@ -1031,6 +1074,7 @@
// Run autoformatters if there was exactly one new commit
var autoqualityMessages []string
if len(newCommits) == 1 {
+ a.stateMachine.Transition(ctx, StateRunningAutoformatters, "Running autoformatters on new commit")
formatted := a.codereview.Autoformat(ctx)
if len(formatted) > 0 {
msg := fmt.Sprintf(`
@@ -1056,8 +1100,10 @@
// continueTurnWithToolResults continues the conversation with tool results
func (a *Agent) continueTurnWithToolResults(ctx context.Context, results []ant.Content, autoqualityMessages []string, cancelled bool) (bool, *ant.MessageResponse) {
// Get any messages the user sent while tools were executing
+ a.stateMachine.Transition(ctx, StateGatheringAdditionalMessages, "Gathering additional user messages")
msgs, err := a.GatherMessages(ctx, false)
if err != nil {
+ a.stateMachine.Transition(ctx, StateError, "Error gathering additional messages: "+err.Error())
return false, nil
}
@@ -1083,15 +1129,20 @@
results = append(results, msgs...)
// Send the combined message to continue the conversation
+ a.stateMachine.Transition(ctx, StateSendingToolResults, "Sending tool results back to LLM")
resp, err := a.convo.SendMessage(ant.Message{
Role: "user",
Content: results,
})
if err != nil {
+ a.stateMachine.Transition(ctx, StateError, "Error sending tool results: "+err.Error())
a.pushToOutbox(ctx, errorMessage(fmt.Errorf("error: failed to continue conversation: %s", err.Error())))
return true, nil // Return true to continue the conversation, but with no response
}
+ // Transition back to processing LLM response
+ a.stateMachine.Transition(ctx, StateProcessingLLMResponse, "Processing LLM response to tool results")
+
if cancelled {
return false, nil
}
@@ -1101,6 +1152,7 @@
func (a *Agent) overBudget(ctx context.Context) error {
if err := a.convo.OverBudget(); err != nil {
+ a.stateMachine.Transition(ctx, StateBudgetExceeded, "Budget exceeded: "+err.Error())
m := budgetMessage(err)
m.Content = m.Content + "\n\nBudget reset."
a.pushToOutbox(ctx, budgetMessage(err))
diff --git a/loop/agent_test.go b/loop/agent_test.go
index bde0d20..c62fd21 100644
--- a/loop/agent_test.go
+++ b/loop/agent_test.go
@@ -5,6 +5,7 @@
"fmt"
"net/http"
"os"
+ "slices"
"strings"
"testing"
"time"
@@ -158,6 +159,7 @@
agent := &Agent{
outstandingLLMCalls: make(map[string]struct{}),
outstandingToolCalls: make(map[string]string),
+ stateMachine: NewStateMachine(),
}
// Check initial state
@@ -392,3 +394,292 @@
t.Error("Timed out waiting for error message in outbox")
}
}
+
+func TestAgentStateMachine(t *testing.T) {
+ // Create a simplified test for the state machine functionality
+ agent := &Agent{
+ stateMachine: NewStateMachine(),
+ }
+
+ // Initially the state should be Ready
+ if state := agent.CurrentState(); state != StateReady {
+ t.Errorf("Expected initial state to be StateReady, got %s", state)
+ }
+
+ // Test manual transitions to verify state tracking
+ ctx := context.Background()
+
+ // Track transitions
+ var transitions []State
+ agent.stateMachine.SetTransitionCallback(func(ctx context.Context, from, to State, event TransitionEvent) {
+ transitions = append(transitions, to)
+ t.Logf("State transition: %s -> %s (%s)", from, to, event.Description)
+ })
+
+ // Perform a valid sequence of transitions (based on the state machine rules)
+ expectedStates := []State{
+ StateWaitingForUserInput,
+ StateSendingToLLM,
+ StateProcessingLLMResponse,
+ StateToolUseRequested,
+ StateCheckingForCancellation,
+ StateRunningTool,
+ StateCheckingGitCommits,
+ StateRunningAutoformatters,
+ StateCheckingBudget,
+ StateGatheringAdditionalMessages,
+ StateSendingToolResults,
+ StateProcessingLLMResponse,
+ StateEndOfTurn,
+ }
+
+ // Manually perform each transition
+ for _, state := range expectedStates {
+ err := agent.stateMachine.Transition(ctx, state, "Test transition to "+state.String())
+ if err != nil {
+ t.Errorf("Failed to transition to %s: %v", state, err)
+ }
+ }
+
+ // Check if we recorded the right number of transitions
+ if len(transitions) != len(expectedStates) {
+ t.Errorf("Expected %d state transitions, got %d", len(expectedStates), len(transitions))
+ }
+
+ // Check each transition matched what we expected
+ for i, expected := range expectedStates {
+ if i < len(transitions) {
+ if transitions[i] != expected {
+ t.Errorf("Transition %d: expected %s, got %s", i, expected, transitions[i])
+ }
+ }
+ }
+
+ // Verify the current state is the last one we transitioned to
+ if state := agent.CurrentState(); state != expectedStates[len(expectedStates)-1] {
+ t.Errorf("Expected current state to be %s, got %s", expectedStates[len(expectedStates)-1], state)
+ }
+
+ // Test force transition
+ agent.stateMachine.ForceTransition(ctx, StateCancelled, "Testing force transition")
+
+ // Verify current state was updated
+ if state := agent.CurrentState(); state != StateCancelled {
+ t.Errorf("Expected forced state to be StateCancelled, got %s", state)
+ }
+}
+
+// mockConvoInterface is a mock implementation of ConvoInterface for testing
+type mockConvoInterface struct {
+ SendMessageFunc func(message ant.Message) (*ant.MessageResponse, error)
+ ToolResultContentsFunc func(ctx context.Context, resp *ant.MessageResponse) ([]ant.Content, error)
+}
+
+func (c *mockConvoInterface) GetID() string {
+ return "mockConvoInterface-id"
+}
+
+func (c *mockConvoInterface) SubConvoWithHistory() *ant.Convo {
+ return nil
+}
+
+func (m *mockConvoInterface) CumulativeUsage() ant.CumulativeUsage {
+ return ant.CumulativeUsage{}
+}
+
+func (m *mockConvoInterface) ResetBudget(ant.Budget) {}
+
+func (m *mockConvoInterface) OverBudget() error {
+ return nil
+}
+
+func (m *mockConvoInterface) SendMessage(message ant.Message) (*ant.MessageResponse, error) {
+ if m.SendMessageFunc != nil {
+ return m.SendMessageFunc(message)
+ }
+ return &ant.MessageResponse{StopReason: ant.StopReasonEndTurn}, nil
+}
+
+func (m *mockConvoInterface) SendUserTextMessage(s string, otherContents ...ant.Content) (*ant.MessageResponse, error) {
+ return m.SendMessage(ant.Message{Role: "user", Content: []ant.Content{{Type: "text", Text: s}}})
+}
+
+func (m *mockConvoInterface) ToolResultContents(ctx context.Context, resp *ant.MessageResponse) ([]ant.Content, error) {
+ if m.ToolResultContentsFunc != nil {
+ return m.ToolResultContentsFunc(ctx, resp)
+ }
+ return []ant.Content{}, nil
+}
+
+func (m *mockConvoInterface) ToolResultCancelContents(resp *ant.MessageResponse) ([]ant.Content, error) {
+ return []ant.Content{{Type: "text", Text: "Tool use cancelled"}}, nil
+}
+
+func (m *mockConvoInterface) CancelToolUse(toolUseID string, cause error) error {
+ return nil
+}
+
+func TestAgentProcessTurnStateTransitions(t *testing.T) {
+ // Create a mock ConvoInterface for testing
+ mockConvo := &mockConvoInterface{}
+
+ // Use the testing context
+ ctx := t.Context()
+
+ // Create an agent with the state machine
+ agent := &Agent{
+ convo: mockConvo,
+ config: AgentConfig{Context: ctx},
+ inbox: make(chan string, 10),
+ outbox: make(chan AgentMessage, 10),
+ ready: make(chan struct{}),
+ seenCommits: make(map[string]bool),
+ outstandingLLMCalls: make(map[string]struct{}),
+ outstandingToolCalls: make(map[string]string),
+ stateMachine: NewStateMachine(),
+ startOfTurn: time.Now(),
+ }
+
+ // Verify initial state
+ if state := agent.CurrentState(); state != StateReady {
+ t.Errorf("Expected initial state to be StateReady, got %s", state)
+ }
+
+ // Add a message to the inbox so we don't block in GatherMessages
+ agent.inbox <- "Test message"
+
+ // Setup the mock to simulate a model response with end of turn
+ mockConvo.SendMessageFunc = func(message ant.Message) (*ant.MessageResponse, error) {
+ return &ant.MessageResponse{
+ StopReason: ant.StopReasonEndTurn,
+ Content: []ant.Content{
+ {Type: "text", Text: "This is a test response"},
+ },
+ }, nil
+ }
+
+ // Track state transitions
+ var transitions []State
+ agent.stateMachine.SetTransitionCallback(func(ctx context.Context, from, to State, event TransitionEvent) {
+ transitions = append(transitions, to)
+ t.Logf("State transition: %s -> %s (%s)", from, to, event.Description)
+ })
+
+ // Process a turn, which should trigger state transitions
+ agent.processTurn(ctx)
+
+ // The minimum expected states for a simple end-of-turn response
+ minExpectedStates := []State{
+ StateWaitingForUserInput,
+ StateSendingToLLM,
+ StateProcessingLLMResponse,
+ StateEndOfTurn,
+ }
+
+ // Verify we have at least the minimum expected states
+ if len(transitions) < len(minExpectedStates) {
+ t.Errorf("Expected at least %d state transitions, got %d", len(minExpectedStates), len(transitions))
+ }
+
+ // Check that the transitions follow the expected sequence
+ for i, expected := range minExpectedStates {
+ if i < len(transitions) {
+ if transitions[i] != expected {
+ t.Errorf("Transition %d: expected %s, got %s", i, expected, transitions[i])
+ }
+ }
+ }
+
+ // Verify the final state is EndOfTurn
+ if state := agent.CurrentState(); state != StateEndOfTurn {
+ t.Errorf("Expected final state to be StateEndOfTurn, got %s", state)
+ }
+}
+
+func TestAgentProcessTurnWithToolUse(t *testing.T) {
+ // Create a mock ConvoInterface for testing
+ mockConvo := &mockConvoInterface{}
+
+ // Setup a test context
+ ctx := context.Background()
+
+ // Create an agent with the state machine
+ agent := &Agent{
+ convo: mockConvo,
+ config: AgentConfig{Context: ctx},
+ inbox: make(chan string, 10),
+ outbox: make(chan AgentMessage, 10),
+ ready: make(chan struct{}),
+ seenCommits: make(map[string]bool),
+ outstandingLLMCalls: make(map[string]struct{}),
+ outstandingToolCalls: make(map[string]string),
+ stateMachine: NewStateMachine(),
+ startOfTurn: time.Now(),
+ }
+
+ // Add a message to the inbox so we don't block in GatherMessages
+ agent.inbox <- "Test message"
+
+ // First response requests a tool
+ firstResponseDone := false
+ mockConvo.SendMessageFunc = func(message ant.Message) (*ant.MessageResponse, error) {
+ if !firstResponseDone {
+ firstResponseDone = true
+ return &ant.MessageResponse{
+ StopReason: ant.StopReasonToolUse,
+ Content: []ant.Content{
+ {Type: "text", Text: "I'll use a tool"},
+ {Type: "tool_use", ToolName: "test_tool", ToolInput: []byte("{}"), ID: "test_id"},
+ },
+ }, nil
+ }
+ // Second response ends the turn
+ return &ant.MessageResponse{
+ StopReason: ant.StopReasonEndTurn,
+ Content: []ant.Content{
+ {Type: "text", Text: "Finished using the tool"},
+ },
+ }, nil
+ }
+
+ // Tool result content handler
+ mockConvo.ToolResultContentsFunc = func(ctx context.Context, resp *ant.MessageResponse) ([]ant.Content, error) {
+ return []ant.Content{{Type: "text", Text: "Tool executed successfully"}}, nil
+ }
+
+ // Track state transitions
+ var transitions []State
+ agent.stateMachine.SetTransitionCallback(func(ctx context.Context, from, to State, event TransitionEvent) {
+ transitions = append(transitions, to)
+ t.Logf("State transition: %s -> %s (%s)", from, to, event.Description)
+ })
+
+ // Process a turn with tool use
+ agent.processTurn(ctx)
+
+ // Define expected states for a tool use flow
+ expectedToolStates := []State{
+ StateWaitingForUserInput,
+ StateSendingToLLM,
+ StateProcessingLLMResponse,
+ StateToolUseRequested,
+ StateCheckingForCancellation,
+ StateRunningTool,
+ }
+
+ // Verify that these states are present in order
+ for i, expectedState := range expectedToolStates {
+ if i >= len(transitions) {
+ t.Errorf("Missing expected transition to %s; only got %d transitions", expectedState, len(transitions))
+ continue
+ }
+ if transitions[i] != expectedState {
+ t.Errorf("Expected transition %d to be %s, got %s", i, expectedState, transitions[i])
+ }
+ }
+
+ // Also verify we eventually reached EndOfTurn
+ if !slices.Contains(transitions, StateEndOfTurn) {
+ t.Errorf("Expected to eventually reach StateEndOfTurn, but never did")
+ }
+}
diff --git a/loop/statemachine.go b/loop/statemachine.go
new file mode 100644
index 0000000..9a520f6
--- /dev/null
+++ b/loop/statemachine.go
@@ -0,0 +1,513 @@
+package loop
+
+import (
+ "context"
+ "fmt"
+ "log/slog"
+ "sync"
+ "time"
+)
+
+// State represents the possible states of the Agent state machine
+type State int
+
+const (
+ // StateUnknown is the default state
+ StateUnknown State = iota
+ // StateReady is the initial state when the agent is initialized and ready to operate
+ StateReady
+ // StateWaitingForUserInput occurs when the agent is waiting for a user message to start a turn
+ StateWaitingForUserInput
+ // StateSendingToLLM occurs when the agent is sending message(s) to the LLM
+ StateSendingToLLM
+ // StateProcessingLLMResponse occurs when the agent is processing a response from the LLM
+ StateProcessingLLMResponse
+ // StateEndOfTurn occurs when processing is completed without tool use, and the turn ends
+ StateEndOfTurn
+ // StateToolUseRequested occurs when the LLM has requested to use a tool
+ StateToolUseRequested
+ // StateCheckingForCancellation occurs when the agent checks if user requested cancellation
+ StateCheckingForCancellation
+ // StateRunningTool occurs when the agent is executing the requested tool
+ StateRunningTool
+ // StateCheckingGitCommits occurs when the agent checks for new git commits after tool execution
+ StateCheckingGitCommits
+ // StateRunningAutoformatters occurs when the agent runs code formatters on new commits
+ StateRunningAutoformatters
+ // StateCheckingBudget occurs when the agent verifies if budget limits are exceeded
+ StateCheckingBudget
+ // StateGatheringAdditionalMessages occurs when the agent collects user messages that arrived during tool execution
+ StateGatheringAdditionalMessages
+ // StateSendingToolResults occurs when the agent sends tool results back to the LLM
+ StateSendingToolResults
+ // StateCancelled occurs when an operation was cancelled by the user
+ StateCancelled
+ // StateBudgetExceeded occurs when the budget limit was reached
+ StateBudgetExceeded
+ // StateError occurs when an error occurred during processing
+ StateError
+)
+
+// String returns a string representation of the State for logging and debugging
+func (s State) String() string {
+ switch s {
+ case StateUnknown:
+ return "Unknown"
+ case StateReady:
+ return "Ready"
+ case StateWaitingForUserInput:
+ return "WaitingForUserInput"
+ case StateSendingToLLM:
+ return "SendingToLLM"
+ case StateProcessingLLMResponse:
+ return "ProcessingLLMResponse"
+ case StateEndOfTurn:
+ return "EndOfTurn"
+ case StateToolUseRequested:
+ return "ToolUseRequested"
+ case StateCheckingForCancellation:
+ return "CheckingForCancellation"
+ case StateRunningTool:
+ return "RunningTool"
+ case StateCheckingGitCommits:
+ return "CheckingGitCommits"
+ case StateRunningAutoformatters:
+ return "RunningAutoformatters"
+ case StateCheckingBudget:
+ return "CheckingBudget"
+ case StateGatheringAdditionalMessages:
+ return "GatheringAdditionalMessages"
+ case StateSendingToolResults:
+ return "SendingToolResults"
+ case StateCancelled:
+ return "Cancelled"
+ case StateBudgetExceeded:
+ return "BudgetExceeded"
+ case StateError:
+ return "Error"
+ default:
+ return fmt.Sprintf("Unknown(%d)", int(s))
+ }
+}
+
+// TransitionEvent represents an event that causes a state transition
+type TransitionEvent struct {
+ // Description provides a human-readable description of the event
+ Description string
+ // Data can hold any additional information about the event
+ Data interface{}
+ // Timestamp is when the event occurred
+ Timestamp time.Time
+}
+
+// StateTransition represents a transition from one state to another
+type StateTransition struct {
+ From State
+ To State
+ Event TransitionEvent
+}
+
+// StateMachine manages the Agent's states and transitions
+type StateMachine struct {
+ // mu protects all fields of the StateMachine from concurrent access
+ mu sync.RWMutex
+ // currentState is the current state of the state machine
+ currentState State
+ // previousState is the previous state of the state machine
+ previousState State
+ // stateEnteredAt is when the current state was entered
+ stateEnteredAt time.Time
+ // transitions maps from states to the states they can transition to
+ transitions map[State]map[State]bool
+ // history records the history of state transitions
+ history []StateTransition
+ // maxHistorySize limits the number of transitions to keep in history
+ maxHistorySize int
+ // eventListeners are notified when state transitions occur
+ eventListeners []chan<- StateTransition
+ // onTransition is a callback function that's called when a transition occurs
+ onTransition func(ctx context.Context, from, to State, event TransitionEvent)
+}
+
+// NewStateMachine creates a new state machine initialized to StateReady
+func NewStateMachine() *StateMachine {
+ sm := &StateMachine{
+ currentState: StateReady,
+ previousState: StateUnknown,
+ stateEnteredAt: time.Now(),
+ transitions: make(map[State]map[State]bool),
+ maxHistorySize: 100,
+ eventListeners: make([]chan<- StateTransition, 0),
+ }
+
+ // Initialize valid transitions
+ sm.initTransitions()
+
+ return sm
+}
+
+// SetMaxHistorySize sets the maximum number of transitions to keep in history
+func (sm *StateMachine) SetMaxHistorySize(size int) {
+ if size < 1 {
+ size = 1 // Ensure we keep at least one entry
+ }
+
+ sm.mu.Lock()
+ defer sm.mu.Unlock()
+
+ sm.maxHistorySize = size
+
+ // Trim history if needed
+ if len(sm.history) > sm.maxHistorySize {
+ sm.history = sm.history[len(sm.history)-sm.maxHistorySize:]
+ }
+}
+
+// AddTransitionListener adds a listener channel that will be notified of state transitions
+// Returns a function that can be called to remove the listener
+func (sm *StateMachine) AddTransitionListener(listener chan<- StateTransition) func() {
+ sm.mu.Lock()
+ defer sm.mu.Unlock()
+
+ sm.eventListeners = append(sm.eventListeners, listener)
+
+ // Return a function to remove this listener
+ return func() {
+ sm.RemoveTransitionListener(listener)
+ }
+}
+
+// RemoveTransitionListener removes a previously added listener
+func (sm *StateMachine) RemoveTransitionListener(listener chan<- StateTransition) {
+ sm.mu.Lock()
+ defer sm.mu.Unlock()
+
+ for i, l := range sm.eventListeners {
+ if l == listener {
+ // Remove by swapping with the last element and then truncating
+ lastIdx := len(sm.eventListeners) - 1
+ sm.eventListeners[i] = sm.eventListeners[lastIdx]
+ sm.eventListeners = sm.eventListeners[:lastIdx]
+ break
+ }
+ }
+}
+
+// SetTransitionCallback sets a function to be called on every state transition
+func (sm *StateMachine) SetTransitionCallback(callback func(ctx context.Context, from, to State, event TransitionEvent)) {
+ sm.mu.Lock()
+ defer sm.mu.Unlock()
+
+ sm.onTransition = callback
+}
+
+// ClearTransitionCallback removes any previously set transition callback
+func (sm *StateMachine) ClearTransitionCallback() {
+ sm.mu.Lock()
+ defer sm.mu.Unlock()
+
+ sm.onTransition = nil
+}
+
+// initTransitions initializes the map of valid state transitions
+func (sm *StateMachine) initTransitions() {
+ // Helper function to add transitions
+ addTransition := func(from State, to ...State) {
+ // Initialize the map for this state if it doesn't exist
+ if _, exists := sm.transitions[from]; !exists {
+ sm.transitions[from] = make(map[State]bool)
+ }
+
+ // Add all of the 'to' states
+ for _, toState := range to {
+ sm.transitions[from][toState] = true
+ }
+ }
+
+ // Define valid transitions based on the state machine diagram
+
+ // Initial state
+ addTransition(StateReady, StateWaitingForUserInput)
+
+ // Main flow
+ addTransition(StateWaitingForUserInput, StateSendingToLLM, StateError)
+ addTransition(StateSendingToLLM, StateProcessingLLMResponse, StateError)
+ addTransition(StateProcessingLLMResponse, StateEndOfTurn, StateToolUseRequested, StateError)
+ addTransition(StateEndOfTurn, StateWaitingForUserInput)
+
+ // Tool use flow
+ addTransition(StateToolUseRequested, StateCheckingForCancellation)
+ addTransition(StateCheckingForCancellation, StateRunningTool, StateCancelled)
+ addTransition(StateRunningTool, StateCheckingGitCommits, StateError)
+ addTransition(StateCheckingGitCommits, StateRunningAutoformatters, StateCheckingBudget)
+ addTransition(StateRunningAutoformatters, StateCheckingBudget)
+ addTransition(StateCheckingBudget, StateGatheringAdditionalMessages, StateBudgetExceeded)
+ addTransition(StateGatheringAdditionalMessages, StateSendingToolResults, StateError)
+ addTransition(StateSendingToolResults, StateProcessingLLMResponse, StateError)
+
+ // Terminal states to new turn
+ addTransition(StateCancelled, StateWaitingForUserInput)
+ addTransition(StateBudgetExceeded, StateWaitingForUserInput)
+ addTransition(StateError, StateWaitingForUserInput)
+}
+
+// Transition attempts to transition from the current state to the given state
+func (sm *StateMachine) Transition(ctx context.Context, newState State, event string) error {
+ if sm == nil {
+ return fmt.Errorf("nil StateMachine pointer")
+ }
+ transitionEvent := TransitionEvent{
+ Description: event,
+ Timestamp: time.Now(),
+ }
+ return sm.TransitionWithEvent(ctx, newState, transitionEvent)
+}
+
+// TransitionWithEvent attempts to transition from the current state to the given state
+// with the provided event information
+func (sm *StateMachine) TransitionWithEvent(ctx context.Context, newState State, event TransitionEvent) error {
+ // First check if the transition is valid without holding the write lock
+ sm.mu.RLock()
+ currentState := sm.currentState
+ canTransition := false
+ if validToStates, exists := sm.transitions[currentState]; exists {
+ canTransition = validToStates[newState]
+ }
+ sm.mu.RUnlock()
+
+ if !canTransition {
+ return fmt.Errorf("invalid transition from %s to %s", currentState, newState)
+ }
+
+ // Acquire write lock for the actual transition
+ sm.mu.Lock()
+ defer sm.mu.Unlock()
+
+ // Double-check that the state hasn't changed since we checked
+ if sm.currentState != currentState {
+ // State changed between our check and lock acquisition
+ // Re-check if the transition is still valid
+ if validToStates, exists := sm.transitions[sm.currentState]; !exists || !validToStates[newState] {
+ return fmt.Errorf("concurrent state change detected: invalid transition from current %s to %s",
+ sm.currentState, newState)
+ }
+ }
+
+ // Calculate duration in current state
+ duration := time.Since(sm.stateEnteredAt)
+
+ // Record the transition
+ transition := StateTransition{
+ From: sm.currentState,
+ To: newState,
+ Event: event,
+ }
+
+ // Update state
+ sm.previousState = sm.currentState
+ sm.currentState = newState
+ sm.stateEnteredAt = time.Now()
+
+ // Add to history
+ sm.history = append(sm.history, transition)
+
+ // Trim history if it exceeds maximum size
+ if len(sm.history) > sm.maxHistorySize {
+ sm.history = sm.history[len(sm.history)-sm.maxHistorySize:]
+ }
+
+ // Make a local copy of any callback functions to invoke outside the lock
+ var onTransition func(ctx context.Context, from, to State, event TransitionEvent)
+ var eventListenersCopy []chan<- StateTransition
+ if sm.onTransition != nil {
+ onTransition = sm.onTransition
+ }
+ if len(sm.eventListeners) > 0 {
+ eventListenersCopy = make([]chan<- StateTransition, len(sm.eventListeners))
+ copy(eventListenersCopy, sm.eventListeners)
+ }
+
+ // Log the transition
+ slog.InfoContext(ctx, "State transition",
+ "from", sm.previousState,
+ "to", sm.currentState,
+ "event", event.Description,
+ "duration", duration)
+
+ // Release the lock before notifying listeners to avoid deadlocks
+ sm.mu.Unlock()
+
+ // Notify listeners if any
+ if onTransition != nil {
+ onTransition(ctx, sm.previousState, sm.currentState, event)
+ }
+
+ for _, ch := range eventListenersCopy {
+ select {
+ case ch <- transition:
+ // Successfully sent
+ default:
+ // Channel buffer full or no receiver, log and continue
+ slog.WarnContext(ctx, "Failed to notify state transition listener",
+ "from", sm.previousState, "to", sm.currentState)
+ }
+ }
+
+ // Re-acquire the lock that we explicitly released above
+ sm.mu.Lock()
+ return nil
+}
+
+// CurrentState returns the current state
+func (sm *StateMachine) CurrentState() State {
+ sm.mu.RLock()
+ defer sm.mu.RUnlock()
+ return sm.currentState
+}
+
+// PreviousState returns the previous state
+func (sm *StateMachine) PreviousState() State {
+ sm.mu.RLock()
+ defer sm.mu.RUnlock()
+ return sm.previousState
+}
+
+// TimeInState returns how long the machine has been in the current state
+func (sm *StateMachine) TimeInState() time.Duration {
+ sm.mu.RLock()
+ enteredAt := sm.stateEnteredAt
+ sm.mu.RUnlock()
+ return time.Since(enteredAt)
+}
+
+// CanTransition returns true if a transition from the from state to the to state is valid
+func (sm *StateMachine) CanTransition(from, to State) bool {
+ sm.mu.RLock()
+ defer sm.mu.RUnlock()
+ if validToStates, exists := sm.transitions[from]; exists {
+ return validToStates[to]
+ }
+ return false
+}
+
+// History returns the transition history of the state machine
+func (sm *StateMachine) History() []StateTransition {
+ sm.mu.RLock()
+ defer sm.mu.RUnlock()
+
+ // Return a copy to prevent modification
+ historyCopy := make([]StateTransition, len(sm.history))
+ copy(historyCopy, sm.history)
+ return historyCopy
+}
+
+// Reset resets the state machine to the initial ready state
+func (sm *StateMachine) Reset() {
+ sm.mu.Lock()
+ defer sm.mu.Unlock()
+
+ sm.currentState = StateReady
+ sm.previousState = StateUnknown
+ sm.stateEnteredAt = time.Now()
+}
+
+// IsInTerminalState returns whether the current state is a terminal state
+func (sm *StateMachine) IsInTerminalState() bool {
+ sm.mu.RLock()
+ defer sm.mu.RUnlock()
+
+ switch sm.currentState {
+ case StateEndOfTurn, StateCancelled, StateBudgetExceeded, StateError:
+ return true
+ default:
+ return false
+ }
+}
+
+// IsInErrorState returns whether the current state is an error state
+func (sm *StateMachine) IsInErrorState() bool {
+ sm.mu.RLock()
+ defer sm.mu.RUnlock()
+
+ switch sm.currentState {
+ case StateError, StateCancelled, StateBudgetExceeded:
+ return true
+ default:
+ return false
+ }
+}
+
+// ForceTransition forces a transition regardless of whether it's valid according to the state machine rules
+// This should be used only in critical situations like cancellation or error recovery
+func (sm *StateMachine) ForceTransition(ctx context.Context, newState State, reason string) {
+ event := TransitionEvent{
+ Description: fmt.Sprintf("Forced transition: %s", reason),
+ Timestamp: time.Now(),
+ }
+
+ sm.mu.Lock()
+
+ // Calculate duration in current state
+ duration := time.Since(sm.stateEnteredAt)
+
+ // Record the transition
+ transition := StateTransition{
+ From: sm.currentState,
+ To: newState,
+ Event: event,
+ }
+
+ // Update state
+ sm.previousState = sm.currentState
+ sm.currentState = newState
+ sm.stateEnteredAt = time.Now()
+
+ // Add to history
+ sm.history = append(sm.history, transition)
+
+ // Trim history if it exceeds maximum size
+ if len(sm.history) > sm.maxHistorySize {
+ sm.history = sm.history[len(sm.history)-sm.maxHistorySize:]
+ }
+
+ // Make a local copy of any callback functions to invoke outside the lock
+ var onTransition func(ctx context.Context, from, to State, event TransitionEvent)
+ var eventListenersCopy []chan<- StateTransition
+ if sm.onTransition != nil {
+ onTransition = sm.onTransition
+ }
+ if len(sm.eventListeners) > 0 {
+ eventListenersCopy = make([]chan<- StateTransition, len(sm.eventListeners))
+ copy(eventListenersCopy, sm.eventListeners)
+ }
+
+ // Log the transition
+ slog.WarnContext(ctx, "Forced state transition",
+ "from", sm.previousState,
+ "to", sm.currentState,
+ "reason", reason,
+ "duration", duration)
+
+ // Release the lock before notifying listeners to avoid deadlocks
+ sm.mu.Unlock()
+
+ // Notify listeners if any
+ if onTransition != nil {
+ onTransition(ctx, sm.previousState, sm.currentState, event)
+ }
+
+ for _, ch := range eventListenersCopy {
+ select {
+ case ch <- transition:
+ // Successfully sent
+ default:
+ // Channel buffer full or no receiver, log and continue
+ slog.WarnContext(ctx, "Failed to notify state transition listener for forced transition",
+ "from", sm.previousState, "to", sm.currentState)
+ }
+ }
+
+ // Re-acquire the lock
+ sm.mu.Lock()
+ defer sm.mu.Unlock()
+}
diff --git a/loop/statemachine_diagram.md b/loop/statemachine_diagram.md
new file mode 100644
index 0000000..82a160c
--- /dev/null
+++ b/loop/statemachine_diagram.md
@@ -0,0 +1,80 @@
+# Agent State Machine Diagram
+
+```mermaid
+stateDiagram-v2
+ [*] --> StateReady
+
+ StateReady --> StateWaitingForUserInput
+
+ StateWaitingForUserInput --> StateSendingToLLM
+ StateWaitingForUserInput --> StateError
+
+ StateSendingToLLM --> StateProcessingLLMResponse
+ StateSendingToLLM --> StateError
+
+ StateProcessingLLMResponse --> StateEndOfTurn
+ StateProcessingLLMResponse --> StateToolUseRequested
+ StateProcessingLLMResponse --> StateError
+
+ StateEndOfTurn --> StateWaitingForUserInput
+
+ StateToolUseRequested --> StateCheckingForCancellation
+
+ StateCheckingForCancellation --> StateRunningTool
+ StateCheckingForCancellation --> StateCancelled
+
+ StateRunningTool --> StateCheckingGitCommits
+ StateRunningTool --> StateError
+
+ StateCheckingGitCommits --> StateRunningAutoformatters
+ StateCheckingGitCommits --> StateCheckingBudget
+
+ StateRunningAutoformatters --> StateCheckingBudget
+
+ StateCheckingBudget --> StateGatheringAdditionalMessages
+ StateCheckingBudget --> StateBudgetExceeded
+
+ StateGatheringAdditionalMessages --> StateSendingToolResults
+ StateGatheringAdditionalMessages --> StateError
+
+ StateSendingToolResults --> StateProcessingLLMResponse
+ StateSendingToolResults --> StateError
+
+ StateError --> StateWaitingForUserInput
+ StateCancelled --> StateWaitingForUserInput
+ StateBudgetExceeded --> StateWaitingForUserInput
+```
+
+## State Descriptions
+
+| State | Description |
+|-------|-------------|
+| StateReady | Initial state when the agent is initialized and ready to operate |
+| StateWaitingForUserInput | Agent is waiting for a user message to start a turn |
+| StateSendingToLLM | Agent is sending message(s) to the LLM |
+| StateProcessingLLMResponse | Agent is processing a response from the LLM |
+| StateEndOfTurn | Processing completed without tool use, turn ends |
+| StateToolUseRequested | LLM has requested to use a tool |
+| StateCheckingForCancellation | Agent checks if user requested cancellation |
+| StateRunningTool | Agent is executing the requested tool |
+| StateCheckingGitCommits | Agent checks for new git commits after tool execution |
+| StateRunningAutoformatters | Agent runs code formatters on new commits |
+| StateCheckingBudget | Agent verifies if budget limits are exceeded |
+| StateGatheringAdditionalMessages | Agent collects user messages that arrived during tool execution |
+| StateSendingToolResults | Agent sends tool results back to the LLM |
+| StateCancelled | Operation was cancelled by the user |
+| StateBudgetExceeded | Budget limit was reached |
+| StateError | An error occurred during processing |
+
+## Implementation Details
+
+This state machine is implemented in `statemachine.go` and follows the State pattern design. Key features include:
+
+1. Explicit state enumeration for all possible states
+2. Validation of state transitions
+3. History tracking for debugging
+4. Event recording for each transition
+5. Timing information for performance analysis
+6. Error state detection
+
+The `AgentWithStateMachine` in `statemachine_example.go` demonstrates how this state machine could be integrated with the existing Agent implementation.
diff --git a/loop/statemachine_test.go b/loop/statemachine_test.go
new file mode 100644
index 0000000..70b5c9b
--- /dev/null
+++ b/loop/statemachine_test.go
@@ -0,0 +1,309 @@
+package loop
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+)
+
+func TestStateMachine(t *testing.T) {
+ ctx := context.Background()
+ sm := NewStateMachine()
+
+ // Check initial state
+ if sm.CurrentState() != StateReady {
+ t.Errorf("Initial state should be StateReady, got %s", sm.CurrentState())
+ }
+
+ // Test valid transition
+ err := sm.Transition(ctx, StateWaitingForUserInput, "Starting inner loop")
+ if err != nil {
+ t.Errorf("Error transitioning to StateWaitingForUserInput: %v", err)
+ }
+ if sm.CurrentState() != StateWaitingForUserInput {
+ t.Errorf("Current state should be StateWaitingForUserInput, got %s", sm.CurrentState())
+ }
+ if sm.PreviousState() != StateReady {
+ t.Errorf("Previous state should be StateReady, got %s", sm.PreviousState())
+ }
+
+ // Test invalid transition
+ err = sm.Transition(ctx, StateRunningAutoformatters, "Invalid transition")
+ if err == nil {
+ t.Error("Expected error for invalid transition but got nil")
+ }
+
+ // Verify state didn't change after invalid transition
+ if sm.CurrentState() != StateWaitingForUserInput {
+ t.Errorf("State should not have changed after invalid transition, got %s", sm.CurrentState())
+ }
+
+ // Test complete flow
+ transitions := []struct {
+ state State
+ event string
+ }{
+ {StateSendingToLLM, "Sending user message to LLM"},
+ {StateProcessingLLMResponse, "Processing LLM response"},
+ {StateToolUseRequested, "LLM requested tool use"},
+ {StateCheckingForCancellation, "Checking for user cancellation"},
+ {StateRunningTool, "Running tool"},
+ {StateCheckingGitCommits, "Checking for git commits"},
+ {StateCheckingBudget, "Checking budget"},
+ {StateGatheringAdditionalMessages, "Gathering additional messages"},
+ {StateSendingToolResults, "Sending tool results"},
+ {StateProcessingLLMResponse, "Processing LLM response"},
+ {StateEndOfTurn, "End of turn"},
+ {StateWaitingForUserInput, "Waiting for next user input"},
+ }
+
+ for i, tt := range transitions {
+ err := sm.Transition(ctx, tt.state, tt.event)
+ if err != nil {
+ t.Errorf("[%d] Error transitioning to %s: %v", i, tt.state, err)
+ }
+ if sm.CurrentState() != tt.state {
+ t.Errorf("[%d] Current state should be %s, got %s", i, tt.state, sm.CurrentState())
+ }
+ }
+
+ // Check if history was recorded correctly
+ history := sm.History()
+ expectedHistoryLen := len(transitions) + 1 // +1 for the initial transition
+ if len(history) != expectedHistoryLen {
+ t.Errorf("Expected history length %d, got %d", expectedHistoryLen, len(history))
+ }
+
+ // Check error state detection
+ err = sm.Transition(ctx, StateError, "An error occurred")
+ if err != nil {
+ t.Errorf("Error transitioning to StateError: %v", err)
+ }
+ if !sm.IsInErrorState() {
+ t.Error("IsInErrorState() should return true when in StateError")
+ }
+ if !sm.IsInTerminalState() {
+ t.Error("IsInTerminalState() should return true when in StateError")
+ }
+
+ // Test reset
+ sm.Reset()
+ if sm.CurrentState() != StateReady {
+ t.Errorf("After reset, state should be StateReady, got %s", sm.CurrentState())
+ }
+}
+
+func TestTimeInState(t *testing.T) {
+ sm := NewStateMachine()
+
+ // Ensure time in state increases
+ time.Sleep(50 * time.Millisecond)
+ timeInState := sm.TimeInState()
+ if timeInState < 50*time.Millisecond {
+ t.Errorf("Expected TimeInState() > 50ms, got %v", timeInState)
+ }
+}
+
+func TestTransitionEvent(t *testing.T) {
+ ctx := context.Background()
+ sm := NewStateMachine()
+
+ // Test transition with custom event
+ event := TransitionEvent{
+ Description: "Test event",
+ Data: map[string]string{"key": "value"},
+ Timestamp: time.Now(),
+ }
+
+ err := sm.TransitionWithEvent(ctx, StateWaitingForUserInput, event)
+ if err != nil {
+ t.Errorf("Error in TransitionWithEvent: %v", err)
+ }
+
+ // Check the event was recorded in history
+ history := sm.History()
+ if len(history) != 1 {
+ t.Fatalf("Expected history length 1, got %d", len(history))
+ }
+ if history[0].Event.Description != "Test event" {
+ t.Errorf("Expected event description 'Test event', got '%s'", history[0].Event.Description)
+ }
+}
+
+func TestConcurrentTransitions(t *testing.T) {
+ sm := NewStateMachine()
+ ctx := context.Background()
+
+ // Start with waiting for user input
+ sm.Transition(ctx, StateWaitingForUserInput, "Initial state")
+
+ // Set up a channel to receive transition events
+ events := make(chan StateTransition, 100)
+ removeListener := sm.AddTransitionListener(events)
+ defer removeListener()
+
+ // Launch goroutines to perform concurrent transitions
+ done := make(chan struct{})
+ var wg sync.WaitGroup
+ wg.Add(10)
+
+ go func() {
+ wg.Wait()
+ close(done)
+ }()
+
+ // Launch 10 goroutines that attempt to transition the state machine
+ for i := 0; i < 10; i++ {
+ go func(idx int) {
+ defer wg.Done()
+
+ // Each goroutine tries to make a valid transition from the current state
+ for j := 0; j < 10; j++ {
+ currentState := sm.CurrentState()
+ var nextState State
+
+ // Choose a valid next state based on current state
+ switch currentState {
+ case StateWaitingForUserInput:
+ nextState = StateSendingToLLM
+ case StateSendingToLLM:
+ nextState = StateProcessingLLMResponse
+ case StateProcessingLLMResponse:
+ nextState = StateToolUseRequested
+ case StateToolUseRequested:
+ nextState = StateCheckingForCancellation
+ case StateCheckingForCancellation:
+ nextState = StateRunningTool
+ case StateRunningTool:
+ nextState = StateCheckingGitCommits
+ case StateCheckingGitCommits:
+ nextState = StateCheckingBudget
+ case StateCheckingBudget:
+ nextState = StateGatheringAdditionalMessages
+ case StateGatheringAdditionalMessages:
+ nextState = StateSendingToolResults
+ case StateSendingToolResults:
+ nextState = StateProcessingLLMResponse
+ default:
+ // If in a state we don't know how to handle, reset to a known state
+ sm.ForceTransition(ctx, StateWaitingForUserInput, "Reset for test")
+ continue
+ }
+
+ // Try to transition and record success/failure
+ err := sm.Transition(ctx, nextState, fmt.Sprintf("Transition from goroutine %d", idx))
+ if err != nil {
+ // This is expected in concurrent scenarios - another goroutine might have
+ // changed the state between our check and transition attempt
+ time.Sleep(5 * time.Millisecond) // Back off a bit
+ }
+ }
+ }(i)
+ }
+
+ // Collect events until all goroutines are done
+ transitions := make([]StateTransition, 0)
+loop:
+ for {
+ select {
+ case evt := <-events:
+ transitions = append(transitions, evt)
+ case <-done:
+ // Collect any remaining events
+ for len(events) > 0 {
+ transitions = append(transitions, <-events)
+ }
+ break loop
+ }
+ }
+
+ // Get final history from state machine
+ history := sm.History()
+
+ // We may have missed some events due to channel buffer size and race conditions
+ // That's okay for this test - the main point is to verify thread safety
+ t.Logf("Collected %d events, history contains %d transitions",
+ len(transitions), len(history))
+
+ // Verify that all transitions in history are valid
+ for i := 1; i < len(history); i++ {
+ prev := history[i-1]
+ curr := history[i]
+
+ // Skip validating transitions if they're forced
+ if strings.HasPrefix(curr.Event.Description, "Forced transition") {
+ continue
+ }
+
+ if prev.To != curr.From {
+ t.Errorf("Invalid transition chain at index %d: %s->%s followed by %s->%s",
+ i, prev.From, prev.To, curr.From, curr.To)
+ }
+ }
+}
+
+func TestForceTransition(t *testing.T) {
+ sm := NewStateMachine()
+ ctx := context.Background()
+
+ // Set to a regular state
+ sm.Transition(ctx, StateWaitingForUserInput, "Initial state")
+
+ // Force transition to a state that would normally be invalid
+ sm.ForceTransition(ctx, StateError, "Testing force transition")
+
+ // Check that the transition happened despite being invalid
+ if sm.CurrentState() != StateError {
+ t.Errorf("Force transition failed, state is %s instead of %s",
+ sm.CurrentState(), StateError)
+ }
+
+ // Check that it was recorded in history
+ history := sm.History()
+ lastTransition := history[len(history)-1]
+
+ if lastTransition.From != StateWaitingForUserInput || lastTransition.To != StateError {
+ t.Errorf("Force transition not properly recorded in history: %v", lastTransition)
+ }
+}
+
+func TestTransitionListeners(t *testing.T) {
+ sm := NewStateMachine()
+ ctx := context.Background()
+
+ // Create a channel to receive transitions
+ events := make(chan StateTransition, 10)
+
+ // Add a listener
+ removeListener := sm.AddTransitionListener(events)
+
+ // Make a transition
+ sm.Transition(ctx, StateWaitingForUserInput, "Testing listeners")
+
+ // Check that the event was received
+ select {
+ case evt := <-events:
+ if evt.To != StateWaitingForUserInput {
+ t.Errorf("Received wrong transition: %v", evt)
+ }
+ case <-time.After(100 * time.Millisecond):
+ t.Error("Timeout waiting for transition event")
+ }
+
+ // Remove the listener
+ removeListener()
+
+ // Make another transition
+ sm.Transition(ctx, StateSendingToLLM, "After removing listener")
+
+ // Verify no event was received
+ select {
+ case evt := <-events:
+ t.Errorf("Received transition after removing listener: %v", evt)
+ case <-time.After(100 * time.Millisecond):
+ // This is expected
+ }
+}