blob: 70b5c9b4d735aa480b33940f4670eedb8eac26e1 [file] [log] [blame]
Sean McCullough96b60dd2025-04-30 09:49:10 -07001package loop
2
3import (
4 "context"
5 "fmt"
6 "strings"
7 "sync"
8 "testing"
9 "time"
10)
11
12func TestStateMachine(t *testing.T) {
13 ctx := context.Background()
14 sm := NewStateMachine()
15
16 // Check initial state
17 if sm.CurrentState() != StateReady {
18 t.Errorf("Initial state should be StateReady, got %s", sm.CurrentState())
19 }
20
21 // Test valid transition
22 err := sm.Transition(ctx, StateWaitingForUserInput, "Starting inner loop")
23 if err != nil {
24 t.Errorf("Error transitioning to StateWaitingForUserInput: %v", err)
25 }
26 if sm.CurrentState() != StateWaitingForUserInput {
27 t.Errorf("Current state should be StateWaitingForUserInput, got %s", sm.CurrentState())
28 }
29 if sm.PreviousState() != StateReady {
30 t.Errorf("Previous state should be StateReady, got %s", sm.PreviousState())
31 }
32
33 // Test invalid transition
34 err = sm.Transition(ctx, StateRunningAutoformatters, "Invalid transition")
35 if err == nil {
36 t.Error("Expected error for invalid transition but got nil")
37 }
38
39 // Verify state didn't change after invalid transition
40 if sm.CurrentState() != StateWaitingForUserInput {
41 t.Errorf("State should not have changed after invalid transition, got %s", sm.CurrentState())
42 }
43
44 // Test complete flow
45 transitions := []struct {
46 state State
47 event string
48 }{
49 {StateSendingToLLM, "Sending user message to LLM"},
50 {StateProcessingLLMResponse, "Processing LLM response"},
51 {StateToolUseRequested, "LLM requested tool use"},
52 {StateCheckingForCancellation, "Checking for user cancellation"},
53 {StateRunningTool, "Running tool"},
54 {StateCheckingGitCommits, "Checking for git commits"},
55 {StateCheckingBudget, "Checking budget"},
56 {StateGatheringAdditionalMessages, "Gathering additional messages"},
57 {StateSendingToolResults, "Sending tool results"},
58 {StateProcessingLLMResponse, "Processing LLM response"},
59 {StateEndOfTurn, "End of turn"},
60 {StateWaitingForUserInput, "Waiting for next user input"},
61 }
62
63 for i, tt := range transitions {
64 err := sm.Transition(ctx, tt.state, tt.event)
65 if err != nil {
66 t.Errorf("[%d] Error transitioning to %s: %v", i, tt.state, err)
67 }
68 if sm.CurrentState() != tt.state {
69 t.Errorf("[%d] Current state should be %s, got %s", i, tt.state, sm.CurrentState())
70 }
71 }
72
73 // Check if history was recorded correctly
74 history := sm.History()
75 expectedHistoryLen := len(transitions) + 1 // +1 for the initial transition
76 if len(history) != expectedHistoryLen {
77 t.Errorf("Expected history length %d, got %d", expectedHistoryLen, len(history))
78 }
79
80 // Check error state detection
81 err = sm.Transition(ctx, StateError, "An error occurred")
82 if err != nil {
83 t.Errorf("Error transitioning to StateError: %v", err)
84 }
85 if !sm.IsInErrorState() {
86 t.Error("IsInErrorState() should return true when in StateError")
87 }
88 if !sm.IsInTerminalState() {
89 t.Error("IsInTerminalState() should return true when in StateError")
90 }
91
92 // Test reset
93 sm.Reset()
94 if sm.CurrentState() != StateReady {
95 t.Errorf("After reset, state should be StateReady, got %s", sm.CurrentState())
96 }
97}
98
99func TestTimeInState(t *testing.T) {
100 sm := NewStateMachine()
101
102 // Ensure time in state increases
103 time.Sleep(50 * time.Millisecond)
104 timeInState := sm.TimeInState()
105 if timeInState < 50*time.Millisecond {
106 t.Errorf("Expected TimeInState() > 50ms, got %v", timeInState)
107 }
108}
109
110func TestTransitionEvent(t *testing.T) {
111 ctx := context.Background()
112 sm := NewStateMachine()
113
114 // Test transition with custom event
115 event := TransitionEvent{
116 Description: "Test event",
117 Data: map[string]string{"key": "value"},
118 Timestamp: time.Now(),
119 }
120
121 err := sm.TransitionWithEvent(ctx, StateWaitingForUserInput, event)
122 if err != nil {
123 t.Errorf("Error in TransitionWithEvent: %v", err)
124 }
125
126 // Check the event was recorded in history
127 history := sm.History()
128 if len(history) != 1 {
129 t.Fatalf("Expected history length 1, got %d", len(history))
130 }
131 if history[0].Event.Description != "Test event" {
132 t.Errorf("Expected event description 'Test event', got '%s'", history[0].Event.Description)
133 }
134}
135
136func TestConcurrentTransitions(t *testing.T) {
137 sm := NewStateMachine()
138 ctx := context.Background()
139
140 // Start with waiting for user input
141 sm.Transition(ctx, StateWaitingForUserInput, "Initial state")
142
143 // Set up a channel to receive transition events
144 events := make(chan StateTransition, 100)
145 removeListener := sm.AddTransitionListener(events)
146 defer removeListener()
147
148 // Launch goroutines to perform concurrent transitions
149 done := make(chan struct{})
150 var wg sync.WaitGroup
151 wg.Add(10)
152
153 go func() {
154 wg.Wait()
155 close(done)
156 }()
157
158 // Launch 10 goroutines that attempt to transition the state machine
159 for i := 0; i < 10; i++ {
160 go func(idx int) {
161 defer wg.Done()
162
163 // Each goroutine tries to make a valid transition from the current state
164 for j := 0; j < 10; j++ {
165 currentState := sm.CurrentState()
166 var nextState State
167
168 // Choose a valid next state based on current state
169 switch currentState {
170 case StateWaitingForUserInput:
171 nextState = StateSendingToLLM
172 case StateSendingToLLM:
173 nextState = StateProcessingLLMResponse
174 case StateProcessingLLMResponse:
175 nextState = StateToolUseRequested
176 case StateToolUseRequested:
177 nextState = StateCheckingForCancellation
178 case StateCheckingForCancellation:
179 nextState = StateRunningTool
180 case StateRunningTool:
181 nextState = StateCheckingGitCommits
182 case StateCheckingGitCommits:
183 nextState = StateCheckingBudget
184 case StateCheckingBudget:
185 nextState = StateGatheringAdditionalMessages
186 case StateGatheringAdditionalMessages:
187 nextState = StateSendingToolResults
188 case StateSendingToolResults:
189 nextState = StateProcessingLLMResponse
190 default:
191 // If in a state we don't know how to handle, reset to a known state
192 sm.ForceTransition(ctx, StateWaitingForUserInput, "Reset for test")
193 continue
194 }
195
196 // Try to transition and record success/failure
197 err := sm.Transition(ctx, nextState, fmt.Sprintf("Transition from goroutine %d", idx))
198 if err != nil {
199 // This is expected in concurrent scenarios - another goroutine might have
200 // changed the state between our check and transition attempt
201 time.Sleep(5 * time.Millisecond) // Back off a bit
202 }
203 }
204 }(i)
205 }
206
207 // Collect events until all goroutines are done
208 transitions := make([]StateTransition, 0)
209loop:
210 for {
211 select {
212 case evt := <-events:
213 transitions = append(transitions, evt)
214 case <-done:
215 // Collect any remaining events
216 for len(events) > 0 {
217 transitions = append(transitions, <-events)
218 }
219 break loop
220 }
221 }
222
223 // Get final history from state machine
224 history := sm.History()
225
226 // We may have missed some events due to channel buffer size and race conditions
227 // That's okay for this test - the main point is to verify thread safety
228 t.Logf("Collected %d events, history contains %d transitions",
229 len(transitions), len(history))
230
231 // Verify that all transitions in history are valid
232 for i := 1; i < len(history); i++ {
233 prev := history[i-1]
234 curr := history[i]
235
236 // Skip validating transitions if they're forced
237 if strings.HasPrefix(curr.Event.Description, "Forced transition") {
238 continue
239 }
240
241 if prev.To != curr.From {
242 t.Errorf("Invalid transition chain at index %d: %s->%s followed by %s->%s",
243 i, prev.From, prev.To, curr.From, curr.To)
244 }
245 }
246}
247
248func TestForceTransition(t *testing.T) {
249 sm := NewStateMachine()
250 ctx := context.Background()
251
252 // Set to a regular state
253 sm.Transition(ctx, StateWaitingForUserInput, "Initial state")
254
255 // Force transition to a state that would normally be invalid
256 sm.ForceTransition(ctx, StateError, "Testing force transition")
257
258 // Check that the transition happened despite being invalid
259 if sm.CurrentState() != StateError {
260 t.Errorf("Force transition failed, state is %s instead of %s",
261 sm.CurrentState(), StateError)
262 }
263
264 // Check that it was recorded in history
265 history := sm.History()
266 lastTransition := history[len(history)-1]
267
268 if lastTransition.From != StateWaitingForUserInput || lastTransition.To != StateError {
269 t.Errorf("Force transition not properly recorded in history: %v", lastTransition)
270 }
271}
272
273func TestTransitionListeners(t *testing.T) {
274 sm := NewStateMachine()
275 ctx := context.Background()
276
277 // Create a channel to receive transitions
278 events := make(chan StateTransition, 10)
279
280 // Add a listener
281 removeListener := sm.AddTransitionListener(events)
282
283 // Make a transition
284 sm.Transition(ctx, StateWaitingForUserInput, "Testing listeners")
285
286 // Check that the event was received
287 select {
288 case evt := <-events:
289 if evt.To != StateWaitingForUserInput {
290 t.Errorf("Received wrong transition: %v", evt)
291 }
292 case <-time.After(100 * time.Millisecond):
293 t.Error("Timeout waiting for transition event")
294 }
295
296 // Remove the listener
297 removeListener()
298
299 // Make another transition
300 sm.Transition(ctx, StateSendingToLLM, "After removing listener")
301
302 // Verify no event was received
303 select {
304 case evt := <-events:
305 t.Errorf("Received transition after removing listener: %v", evt)
306 case <-time.After(100 * time.Millisecond):
307 // This is expected
308 }
309}