blob: 9a520f670a1c7dc3dd31e15f7f50e3b329c6e1a4 [file] [log] [blame]
Sean McCullough96b60dd2025-04-30 09:49:10 -07001package loop
2
3import (
4 "context"
5 "fmt"
6 "log/slog"
7 "sync"
8 "time"
9)
10
// State identifies one node of the Agent state machine.
type State int

// The states an Agent can be in. Values are assigned sequentially starting
// at zero, so StateUnknown is the zero value of State.
const (
	// StateUnknown is the default (zero) state.
	StateUnknown State = iota
	// StateReady is the initial state: the agent is initialized and ready to operate.
	StateReady
	// StateWaitingForUserInput means the agent is waiting for a user message to start a turn.
	StateWaitingForUserInput
	// StateSendingToLLM means the agent is sending message(s) to the LLM.
	StateSendingToLLM
	// StateProcessingLLMResponse means the agent is processing a response from the LLM.
	StateProcessingLLMResponse
	// StateEndOfTurn means processing completed without tool use and the turn has ended.
	StateEndOfTurn
	// StateToolUseRequested means the LLM has requested to use a tool.
	StateToolUseRequested
	// StateCheckingForCancellation means the agent is checking whether the user requested cancellation.
	StateCheckingForCancellation
	// StateRunningTool means the agent is executing the requested tool.
	StateRunningTool
	// StateCheckingGitCommits means the agent is checking for new git commits after tool execution.
	StateCheckingGitCommits
	// StateRunningAutoformatters means the agent is running code formatters on new commits.
	StateRunningAutoformatters
	// StateCheckingBudget means the agent is verifying whether budget limits are exceeded.
	StateCheckingBudget
	// StateGatheringAdditionalMessages means the agent is collecting user messages that arrived during tool execution.
	StateGatheringAdditionalMessages
	// StateSendingToolResults means the agent is sending tool results back to the LLM.
	StateSendingToolResults
	// StateCancelled means an operation was cancelled by the user.
	StateCancelled
	// StateBudgetExceeded means the budget limit was reached.
	StateBudgetExceeded
	// StateError means an error occurred during processing.
	StateError
)

// stateNames holds the display name of every declared State, indexed by the
// State's numeric value.
var stateNames = [...]string{
	StateUnknown:                     "Unknown",
	StateReady:                       "Ready",
	StateWaitingForUserInput:         "WaitingForUserInput",
	StateSendingToLLM:                "SendingToLLM",
	StateProcessingLLMResponse:       "ProcessingLLMResponse",
	StateEndOfTurn:                   "EndOfTurn",
	StateToolUseRequested:            "ToolUseRequested",
	StateCheckingForCancellation:     "CheckingForCancellation",
	StateRunningTool:                 "RunningTool",
	StateCheckingGitCommits:          "CheckingGitCommits",
	StateRunningAutoformatters:       "RunningAutoformatters",
	StateCheckingBudget:              "CheckingBudget",
	StateGatheringAdditionalMessages: "GatheringAdditionalMessages",
	StateSendingToolResults:          "SendingToolResults",
	StateCancelled:                   "Cancelled",
	StateBudgetExceeded:              "BudgetExceeded",
	StateError:                       "Error",
}

// String returns the name of the state for logging and debugging. Values
// outside the declared range are rendered as "Unknown(<n>)".
func (s State) String() string {
	if s < 0 || int(s) >= len(stateNames) {
		return fmt.Sprintf("Unknown(%d)", int(s))
	}
	return stateNames[s]
}
92
// TransitionEvent describes what caused a state transition.
type TransitionEvent struct {
	Description string    // human-readable description of the event
	Data        any       // optional extra information about the event
	Timestamp   time.Time // when the event occurred
}
102
103// StateTransition represents a transition from one state to another
104type StateTransition struct {
105 From State
106 To State
107 Event TransitionEvent
108}
109
110// StateMachine manages the Agent's states and transitions
111type StateMachine struct {
112 // mu protects all fields of the StateMachine from concurrent access
113 mu sync.RWMutex
114 // currentState is the current state of the state machine
115 currentState State
116 // previousState is the previous state of the state machine
117 previousState State
118 // stateEnteredAt is when the current state was entered
119 stateEnteredAt time.Time
120 // transitions maps from states to the states they can transition to
121 transitions map[State]map[State]bool
122 // history records the history of state transitions
123 history []StateTransition
124 // maxHistorySize limits the number of transitions to keep in history
125 maxHistorySize int
126 // eventListeners are notified when state transitions occur
127 eventListeners []chan<- StateTransition
128 // onTransition is a callback function that's called when a transition occurs
129 onTransition func(ctx context.Context, from, to State, event TransitionEvent)
130}
131
132// NewStateMachine creates a new state machine initialized to StateReady
133func NewStateMachine() *StateMachine {
134 sm := &StateMachine{
135 currentState: StateReady,
136 previousState: StateUnknown,
137 stateEnteredAt: time.Now(),
138 transitions: make(map[State]map[State]bool),
139 maxHistorySize: 100,
140 eventListeners: make([]chan<- StateTransition, 0),
141 }
142
143 // Initialize valid transitions
144 sm.initTransitions()
145
146 return sm
147}
148
149// SetMaxHistorySize sets the maximum number of transitions to keep in history
150func (sm *StateMachine) SetMaxHistorySize(size int) {
151 if size < 1 {
152 size = 1 // Ensure we keep at least one entry
153 }
154
155 sm.mu.Lock()
156 defer sm.mu.Unlock()
157
158 sm.maxHistorySize = size
159
160 // Trim history if needed
161 if len(sm.history) > sm.maxHistorySize {
162 sm.history = sm.history[len(sm.history)-sm.maxHistorySize:]
163 }
164}
165
166// AddTransitionListener adds a listener channel that will be notified of state transitions
167// Returns a function that can be called to remove the listener
168func (sm *StateMachine) AddTransitionListener(listener chan<- StateTransition) func() {
169 sm.mu.Lock()
170 defer sm.mu.Unlock()
171
172 sm.eventListeners = append(sm.eventListeners, listener)
173
174 // Return a function to remove this listener
175 return func() {
176 sm.RemoveTransitionListener(listener)
177 }
178}
179
180// RemoveTransitionListener removes a previously added listener
181func (sm *StateMachine) RemoveTransitionListener(listener chan<- StateTransition) {
182 sm.mu.Lock()
183 defer sm.mu.Unlock()
184
185 for i, l := range sm.eventListeners {
186 if l == listener {
187 // Remove by swapping with the last element and then truncating
188 lastIdx := len(sm.eventListeners) - 1
189 sm.eventListeners[i] = sm.eventListeners[lastIdx]
190 sm.eventListeners = sm.eventListeners[:lastIdx]
191 break
192 }
193 }
194}
195
196// SetTransitionCallback sets a function to be called on every state transition
197func (sm *StateMachine) SetTransitionCallback(callback func(ctx context.Context, from, to State, event TransitionEvent)) {
198 sm.mu.Lock()
199 defer sm.mu.Unlock()
200
201 sm.onTransition = callback
202}
203
204// ClearTransitionCallback removes any previously set transition callback
205func (sm *StateMachine) ClearTransitionCallback() {
206 sm.mu.Lock()
207 defer sm.mu.Unlock()
208
209 sm.onTransition = nil
210}
211
212// initTransitions initializes the map of valid state transitions
213func (sm *StateMachine) initTransitions() {
214 // Helper function to add transitions
215 addTransition := func(from State, to ...State) {
216 // Initialize the map for this state if it doesn't exist
217 if _, exists := sm.transitions[from]; !exists {
218 sm.transitions[from] = make(map[State]bool)
219 }
220
221 // Add all of the 'to' states
222 for _, toState := range to {
223 sm.transitions[from][toState] = true
224 }
225 }
226
227 // Define valid transitions based on the state machine diagram
228
229 // Initial state
230 addTransition(StateReady, StateWaitingForUserInput)
231
232 // Main flow
233 addTransition(StateWaitingForUserInput, StateSendingToLLM, StateError)
234 addTransition(StateSendingToLLM, StateProcessingLLMResponse, StateError)
235 addTransition(StateProcessingLLMResponse, StateEndOfTurn, StateToolUseRequested, StateError)
236 addTransition(StateEndOfTurn, StateWaitingForUserInput)
237
238 // Tool use flow
239 addTransition(StateToolUseRequested, StateCheckingForCancellation)
240 addTransition(StateCheckingForCancellation, StateRunningTool, StateCancelled)
241 addTransition(StateRunningTool, StateCheckingGitCommits, StateError)
242 addTransition(StateCheckingGitCommits, StateRunningAutoformatters, StateCheckingBudget)
243 addTransition(StateRunningAutoformatters, StateCheckingBudget)
244 addTransition(StateCheckingBudget, StateGatheringAdditionalMessages, StateBudgetExceeded)
245 addTransition(StateGatheringAdditionalMessages, StateSendingToolResults, StateError)
246 addTransition(StateSendingToolResults, StateProcessingLLMResponse, StateError)
247
248 // Terminal states to new turn
249 addTransition(StateCancelled, StateWaitingForUserInput)
250 addTransition(StateBudgetExceeded, StateWaitingForUserInput)
251 addTransition(StateError, StateWaitingForUserInput)
252}
253
254// Transition attempts to transition from the current state to the given state
255func (sm *StateMachine) Transition(ctx context.Context, newState State, event string) error {
256 if sm == nil {
257 return fmt.Errorf("nil StateMachine pointer")
258 }
259 transitionEvent := TransitionEvent{
260 Description: event,
261 Timestamp: time.Now(),
262 }
263 return sm.TransitionWithEvent(ctx, newState, transitionEvent)
264}
265
266// TransitionWithEvent attempts to transition from the current state to the given state
267// with the provided event information
268func (sm *StateMachine) TransitionWithEvent(ctx context.Context, newState State, event TransitionEvent) error {
269 // First check if the transition is valid without holding the write lock
270 sm.mu.RLock()
271 currentState := sm.currentState
272 canTransition := false
273 if validToStates, exists := sm.transitions[currentState]; exists {
274 canTransition = validToStates[newState]
275 }
276 sm.mu.RUnlock()
277
278 if !canTransition {
279 return fmt.Errorf("invalid transition from %s to %s", currentState, newState)
280 }
281
282 // Acquire write lock for the actual transition
283 sm.mu.Lock()
284 defer sm.mu.Unlock()
285
286 // Double-check that the state hasn't changed since we checked
287 if sm.currentState != currentState {
288 // State changed between our check and lock acquisition
289 // Re-check if the transition is still valid
290 if validToStates, exists := sm.transitions[sm.currentState]; !exists || !validToStates[newState] {
291 return fmt.Errorf("concurrent state change detected: invalid transition from current %s to %s",
292 sm.currentState, newState)
293 }
294 }
295
296 // Calculate duration in current state
297 duration := time.Since(sm.stateEnteredAt)
298
299 // Record the transition
300 transition := StateTransition{
301 From: sm.currentState,
302 To: newState,
303 Event: event,
304 }
305
306 // Update state
307 sm.previousState = sm.currentState
308 sm.currentState = newState
309 sm.stateEnteredAt = time.Now()
310
311 // Add to history
312 sm.history = append(sm.history, transition)
313
314 // Trim history if it exceeds maximum size
315 if len(sm.history) > sm.maxHistorySize {
316 sm.history = sm.history[len(sm.history)-sm.maxHistorySize:]
317 }
318
319 // Make a local copy of any callback functions to invoke outside the lock
320 var onTransition func(ctx context.Context, from, to State, event TransitionEvent)
321 var eventListenersCopy []chan<- StateTransition
322 if sm.onTransition != nil {
323 onTransition = sm.onTransition
324 }
325 if len(sm.eventListeners) > 0 {
326 eventListenersCopy = make([]chan<- StateTransition, len(sm.eventListeners))
327 copy(eventListenersCopy, sm.eventListeners)
328 }
329
330 // Log the transition
331 slog.InfoContext(ctx, "State transition",
332 "from", sm.previousState,
333 "to", sm.currentState,
334 "event", event.Description,
335 "duration", duration)
336
337 // Release the lock before notifying listeners to avoid deadlocks
338 sm.mu.Unlock()
339
340 // Notify listeners if any
341 if onTransition != nil {
342 onTransition(ctx, sm.previousState, sm.currentState, event)
343 }
344
345 for _, ch := range eventListenersCopy {
346 select {
347 case ch <- transition:
348 // Successfully sent
349 default:
350 // Channel buffer full or no receiver, log and continue
351 slog.WarnContext(ctx, "Failed to notify state transition listener",
352 "from", sm.previousState, "to", sm.currentState)
353 }
354 }
355
356 // Re-acquire the lock that we explicitly released above
357 sm.mu.Lock()
358 return nil
359}
360
361// CurrentState returns the current state
362func (sm *StateMachine) CurrentState() State {
363 sm.mu.RLock()
364 defer sm.mu.RUnlock()
365 return sm.currentState
366}
367
368// PreviousState returns the previous state
369func (sm *StateMachine) PreviousState() State {
370 sm.mu.RLock()
371 defer sm.mu.RUnlock()
372 return sm.previousState
373}
374
375// TimeInState returns how long the machine has been in the current state
376func (sm *StateMachine) TimeInState() time.Duration {
377 sm.mu.RLock()
378 enteredAt := sm.stateEnteredAt
379 sm.mu.RUnlock()
380 return time.Since(enteredAt)
381}
382
383// CanTransition returns true if a transition from the from state to the to state is valid
384func (sm *StateMachine) CanTransition(from, to State) bool {
385 sm.mu.RLock()
386 defer sm.mu.RUnlock()
387 if validToStates, exists := sm.transitions[from]; exists {
388 return validToStates[to]
389 }
390 return false
391}
392
393// History returns the transition history of the state machine
394func (sm *StateMachine) History() []StateTransition {
395 sm.mu.RLock()
396 defer sm.mu.RUnlock()
397
398 // Return a copy to prevent modification
399 historyCopy := make([]StateTransition, len(sm.history))
400 copy(historyCopy, sm.history)
401 return historyCopy
402}
403
404// Reset resets the state machine to the initial ready state
405func (sm *StateMachine) Reset() {
406 sm.mu.Lock()
407 defer sm.mu.Unlock()
408
409 sm.currentState = StateReady
410 sm.previousState = StateUnknown
411 sm.stateEnteredAt = time.Now()
412}
413
414// IsInTerminalState returns whether the current state is a terminal state
415func (sm *StateMachine) IsInTerminalState() bool {
416 sm.mu.RLock()
417 defer sm.mu.RUnlock()
418
419 switch sm.currentState {
420 case StateEndOfTurn, StateCancelled, StateBudgetExceeded, StateError:
421 return true
422 default:
423 return false
424 }
425}
426
427// IsInErrorState returns whether the current state is an error state
428func (sm *StateMachine) IsInErrorState() bool {
429 sm.mu.RLock()
430 defer sm.mu.RUnlock()
431
432 switch sm.currentState {
433 case StateError, StateCancelled, StateBudgetExceeded:
434 return true
435 default:
436 return false
437 }
438}
439
440// ForceTransition forces a transition regardless of whether it's valid according to the state machine rules
441// This should be used only in critical situations like cancellation or error recovery
442func (sm *StateMachine) ForceTransition(ctx context.Context, newState State, reason string) {
443 event := TransitionEvent{
444 Description: fmt.Sprintf("Forced transition: %s", reason),
445 Timestamp: time.Now(),
446 }
447
448 sm.mu.Lock()
449
450 // Calculate duration in current state
451 duration := time.Since(sm.stateEnteredAt)
452
453 // Record the transition
454 transition := StateTransition{
455 From: sm.currentState,
456 To: newState,
457 Event: event,
458 }
459
460 // Update state
461 sm.previousState = sm.currentState
462 sm.currentState = newState
463 sm.stateEnteredAt = time.Now()
464
465 // Add to history
466 sm.history = append(sm.history, transition)
467
468 // Trim history if it exceeds maximum size
469 if len(sm.history) > sm.maxHistorySize {
470 sm.history = sm.history[len(sm.history)-sm.maxHistorySize:]
471 }
472
473 // Make a local copy of any callback functions to invoke outside the lock
474 var onTransition func(ctx context.Context, from, to State, event TransitionEvent)
475 var eventListenersCopy []chan<- StateTransition
476 if sm.onTransition != nil {
477 onTransition = sm.onTransition
478 }
479 if len(sm.eventListeners) > 0 {
480 eventListenersCopy = make([]chan<- StateTransition, len(sm.eventListeners))
481 copy(eventListenersCopy, sm.eventListeners)
482 }
483
484 // Log the transition
485 slog.WarnContext(ctx, "Forced state transition",
486 "from", sm.previousState,
487 "to", sm.currentState,
488 "reason", reason,
489 "duration", duration)
490
491 // Release the lock before notifying listeners to avoid deadlocks
492 sm.mu.Unlock()
493
494 // Notify listeners if any
495 if onTransition != nil {
496 onTransition(ctx, sm.previousState, sm.currentState, event)
497 }
498
499 for _, ch := range eventListenersCopy {
500 select {
501 case ch <- transition:
502 // Successfully sent
503 default:
504 // Channel buffer full or no receiver, log and continue
505 slog.WarnContext(ctx, "Failed to notify state transition listener for forced transition",
506 "from", sm.previousState, "to", sm.currentState)
507 }
508 }
509
510 // Re-acquire the lock
511 sm.mu.Lock()
512 defer sm.mu.Unlock()
513}