blob: e81299e4f8cf90d1c5590bb26c77adbd725ebb6b [file] [log] [blame]
Sean McCullough96b60dd2025-04-30 09:49:10 -07001package loop
2
3import (
4 "context"
5 "fmt"
6 "log/slog"
7 "sync"
8 "time"
9)
10
// State identifies one step of the Agent's turn-processing state machine.
// Values are assigned with iota, so the order of the constants below is part
// of the contract: the String method generated by stringer (see the
// go:generate directive) is keyed on these numeric values.
type State int

//go:generate go tool golang.org/x/tools/cmd/stringer -type=State -trimprefix=State
const (
	// StateUnknown is the zero value of State.
	StateUnknown State = iota
	// StateReady is where a freshly initialized agent starts, ready to operate.
	StateReady
	// StateWaitingForUserInput means the agent is idle until a user message starts a turn.
	StateWaitingForUserInput
	// StateSendingToLLM means message(s) are being sent to the LLM.
	StateSendingToLLM
	// StateProcessingLLMResponse means a reply from the LLM is being handled.
	StateProcessingLLMResponse
	// StateEndOfTurn means the reply needed no tool use and the turn is over.
	StateEndOfTurn
	// StateToolUseRequested means the LLM asked for a tool to be run.
	StateToolUseRequested
	// StateCheckingForCancellation means the agent is checking whether the user cancelled.
	StateCheckingForCancellation
	// StateRunningTool means the requested tool is executing.
	StateRunningTool
	// StateCheckingGitCommits means the agent is looking for new git commits after a tool ran.
	StateCheckingGitCommits
	// StateRunningAutoformatters means code formatters are running over new commits.
	StateRunningAutoformatters
	// StateCheckingBudget means the agent is verifying that budget limits still hold.
	StateCheckingBudget
	// StateGatheringAdditionalMessages means user messages that arrived during tool execution are being collected.
	StateGatheringAdditionalMessages
	// StateSendingToolResults means tool output is being sent back to the LLM.
	StateSendingToolResults
	// StateCancelled means the user cancelled the in-flight operation.
	StateCancelled
	// StateBudgetExceeded means a budget limit was hit.
	StateBudgetExceeded
	// StateError means processing failed with an error.
	StateError
	// StateCompacting means the conversation is being compacted.
	StateCompacting
)
53
// TransitionEvent describes what caused a state transition.
//
// Field order is kept stable so any positional composite literals elsewhere
// keep compiling.
type TransitionEvent struct {
	// Description is a human-readable explanation of the event.
	Description string
	// Data optionally carries arbitrary extra information about the event.
	Data any
	// Timestamp is when the event occurred.
	Timestamp time.Time
}
63
64// StateTransition represents a transition from one state to another
65type StateTransition struct {
66 From State
67 To State
68 Event TransitionEvent
69}
70
71// StateMachine manages the Agent's states and transitions
72type StateMachine struct {
73 // mu protects all fields of the StateMachine from concurrent access
74 mu sync.RWMutex
75 // currentState is the current state of the state machine
76 currentState State
77 // previousState is the previous state of the state machine
78 previousState State
79 // stateEnteredAt is when the current state was entered
80 stateEnteredAt time.Time
81 // transitions maps from states to the states they can transition to
82 transitions map[State]map[State]bool
83 // history records the history of state transitions
84 history []StateTransition
85 // maxHistorySize limits the number of transitions to keep in history
86 maxHistorySize int
87 // eventListeners are notified when state transitions occur
88 eventListeners []chan<- StateTransition
89 // onTransition is a callback function that's called when a transition occurs
90 onTransition func(ctx context.Context, from, to State, event TransitionEvent)
91}
92
93// NewStateMachine creates a new state machine initialized to StateReady
94func NewStateMachine() *StateMachine {
95 sm := &StateMachine{
96 currentState: StateReady,
97 previousState: StateUnknown,
98 stateEnteredAt: time.Now(),
99 transitions: make(map[State]map[State]bool),
100 maxHistorySize: 100,
101 eventListeners: make([]chan<- StateTransition, 0),
102 }
103
104 // Initialize valid transitions
105 sm.initTransitions()
106
107 return sm
108}
109
110// SetMaxHistorySize sets the maximum number of transitions to keep in history
111func (sm *StateMachine) SetMaxHistorySize(size int) {
112 if size < 1 {
113 size = 1 // Ensure we keep at least one entry
114 }
115
116 sm.mu.Lock()
117 defer sm.mu.Unlock()
118
119 sm.maxHistorySize = size
120
121 // Trim history if needed
122 if len(sm.history) > sm.maxHistorySize {
123 sm.history = sm.history[len(sm.history)-sm.maxHistorySize:]
124 }
125}
126
127// AddTransitionListener adds a listener channel that will be notified of state transitions
128// Returns a function that can be called to remove the listener
129func (sm *StateMachine) AddTransitionListener(listener chan<- StateTransition) func() {
130 sm.mu.Lock()
131 defer sm.mu.Unlock()
132
133 sm.eventListeners = append(sm.eventListeners, listener)
134
135 // Return a function to remove this listener
136 return func() {
137 sm.RemoveTransitionListener(listener)
138 }
139}
140
141// RemoveTransitionListener removes a previously added listener
142func (sm *StateMachine) RemoveTransitionListener(listener chan<- StateTransition) {
143 sm.mu.Lock()
144 defer sm.mu.Unlock()
145
146 for i, l := range sm.eventListeners {
147 if l == listener {
148 // Remove by swapping with the last element and then truncating
149 lastIdx := len(sm.eventListeners) - 1
150 sm.eventListeners[i] = sm.eventListeners[lastIdx]
151 sm.eventListeners = sm.eventListeners[:lastIdx]
152 break
153 }
154 }
155}
156
157// SetTransitionCallback sets a function to be called on every state transition
158func (sm *StateMachine) SetTransitionCallback(callback func(ctx context.Context, from, to State, event TransitionEvent)) {
159 sm.mu.Lock()
160 defer sm.mu.Unlock()
161
162 sm.onTransition = callback
163}
164
165// ClearTransitionCallback removes any previously set transition callback
166func (sm *StateMachine) ClearTransitionCallback() {
167 sm.mu.Lock()
168 defer sm.mu.Unlock()
169
170 sm.onTransition = nil
171}
172
173// initTransitions initializes the map of valid state transitions
174func (sm *StateMachine) initTransitions() {
175 // Helper function to add transitions
176 addTransition := func(from State, to ...State) {
177 // Initialize the map for this state if it doesn't exist
178 if _, exists := sm.transitions[from]; !exists {
179 sm.transitions[from] = make(map[State]bool)
180 }
181
182 // Add all of the 'to' states
183 for _, toState := range to {
184 sm.transitions[from][toState] = true
185 }
186 }
187
188 // Define valid transitions based on the state machine diagram
189
190 // Initial state
191 addTransition(StateReady, StateWaitingForUserInput)
192
193 // Main flow
Philip Zeyligerb8a8f352025-06-02 07:39:37 -0700194 addTransition(StateWaitingForUserInput, StateSendingToLLM, StateCompacting, StateError)
Sean McCullough96b60dd2025-04-30 09:49:10 -0700195 addTransition(StateSendingToLLM, StateProcessingLLMResponse, StateError)
196 addTransition(StateProcessingLLMResponse, StateEndOfTurn, StateToolUseRequested, StateError)
197 addTransition(StateEndOfTurn, StateWaitingForUserInput)
198
199 // Tool use flow
200 addTransition(StateToolUseRequested, StateCheckingForCancellation)
201 addTransition(StateCheckingForCancellation, StateRunningTool, StateCancelled)
202 addTransition(StateRunningTool, StateCheckingGitCommits, StateError)
203 addTransition(StateCheckingGitCommits, StateRunningAutoformatters, StateCheckingBudget)
204 addTransition(StateRunningAutoformatters, StateCheckingBudget)
205 addTransition(StateCheckingBudget, StateGatheringAdditionalMessages, StateBudgetExceeded)
206 addTransition(StateGatheringAdditionalMessages, StateSendingToolResults, StateError)
207 addTransition(StateSendingToolResults, StateProcessingLLMResponse, StateError)
208
Philip Zeyligerb8a8f352025-06-02 07:39:37 -0700209 // Compaction flow
210 addTransition(StateCompacting, StateWaitingForUserInput, StateError)
211
Sean McCullough96b60dd2025-04-30 09:49:10 -0700212 // Terminal states to new turn
213 addTransition(StateCancelled, StateWaitingForUserInput)
214 addTransition(StateBudgetExceeded, StateWaitingForUserInput)
215 addTransition(StateError, StateWaitingForUserInput)
216}
217
218// Transition attempts to transition from the current state to the given state
219func (sm *StateMachine) Transition(ctx context.Context, newState State, event string) error {
220 if sm == nil {
221 return fmt.Errorf("nil StateMachine pointer")
222 }
223 transitionEvent := TransitionEvent{
224 Description: event,
225 Timestamp: time.Now(),
226 }
227 return sm.TransitionWithEvent(ctx, newState, transitionEvent)
228}
229
230// TransitionWithEvent attempts to transition from the current state to the given state
231// with the provided event information
232func (sm *StateMachine) TransitionWithEvent(ctx context.Context, newState State, event TransitionEvent) error {
233 // First check if the transition is valid without holding the write lock
234 sm.mu.RLock()
235 currentState := sm.currentState
236 canTransition := false
237 if validToStates, exists := sm.transitions[currentState]; exists {
238 canTransition = validToStates[newState]
239 }
240 sm.mu.RUnlock()
241
242 if !canTransition {
243 return fmt.Errorf("invalid transition from %s to %s", currentState, newState)
244 }
245
246 // Acquire write lock for the actual transition
247 sm.mu.Lock()
248 defer sm.mu.Unlock()
249
250 // Double-check that the state hasn't changed since we checked
251 if sm.currentState != currentState {
252 // State changed between our check and lock acquisition
253 // Re-check if the transition is still valid
254 if validToStates, exists := sm.transitions[sm.currentState]; !exists || !validToStates[newState] {
255 return fmt.Errorf("concurrent state change detected: invalid transition from current %s to %s",
256 sm.currentState, newState)
257 }
258 }
259
260 // Calculate duration in current state
261 duration := time.Since(sm.stateEnteredAt)
262
263 // Record the transition
264 transition := StateTransition{
265 From: sm.currentState,
266 To: newState,
267 Event: event,
268 }
269
270 // Update state
271 sm.previousState = sm.currentState
272 sm.currentState = newState
273 sm.stateEnteredAt = time.Now()
274
275 // Add to history
276 sm.history = append(sm.history, transition)
277
278 // Trim history if it exceeds maximum size
279 if len(sm.history) > sm.maxHistorySize {
280 sm.history = sm.history[len(sm.history)-sm.maxHistorySize:]
281 }
282
283 // Make a local copy of any callback functions to invoke outside the lock
284 var onTransition func(ctx context.Context, from, to State, event TransitionEvent)
285 var eventListenersCopy []chan<- StateTransition
286 if sm.onTransition != nil {
287 onTransition = sm.onTransition
288 }
289 if len(sm.eventListeners) > 0 {
290 eventListenersCopy = make([]chan<- StateTransition, len(sm.eventListeners))
291 copy(eventListenersCopy, sm.eventListeners)
292 }
293
294 // Log the transition
Philip Zeyliger0df32aa2025-06-04 16:47:30 +0000295 slog.DebugContext(ctx, "State transition",
Josh Bleecher Snyder4d4e8072025-05-05 15:00:59 -0700296 "from", sm.previousState.String(),
297 "to", sm.currentState.String(),
Sean McCullough96b60dd2025-04-30 09:49:10 -0700298 "event", event.Description,
299 "duration", duration)
300
301 // Release the lock before notifying listeners to avoid deadlocks
302 sm.mu.Unlock()
303
304 // Notify listeners if any
305 if onTransition != nil {
306 onTransition(ctx, sm.previousState, sm.currentState, event)
307 }
308
309 for _, ch := range eventListenersCopy {
310 select {
311 case ch <- transition:
312 // Successfully sent
313 default:
314 // Channel buffer full or no receiver, log and continue
315 slog.WarnContext(ctx, "Failed to notify state transition listener",
316 "from", sm.previousState, "to", sm.currentState)
317 }
318 }
319
320 // Re-acquire the lock that we explicitly released above
321 sm.mu.Lock()
322 return nil
323}
324
325// CurrentState returns the current state
326func (sm *StateMachine) CurrentState() State {
327 sm.mu.RLock()
328 defer sm.mu.RUnlock()
329 return sm.currentState
330}
331
332// PreviousState returns the previous state
333func (sm *StateMachine) PreviousState() State {
334 sm.mu.RLock()
335 defer sm.mu.RUnlock()
336 return sm.previousState
337}
338
339// TimeInState returns how long the machine has been in the current state
340func (sm *StateMachine) TimeInState() time.Duration {
341 sm.mu.RLock()
342 enteredAt := sm.stateEnteredAt
343 sm.mu.RUnlock()
344 return time.Since(enteredAt)
345}
346
347// CanTransition returns true if a transition from the from state to the to state is valid
348func (sm *StateMachine) CanTransition(from, to State) bool {
349 sm.mu.RLock()
350 defer sm.mu.RUnlock()
351 if validToStates, exists := sm.transitions[from]; exists {
352 return validToStates[to]
353 }
354 return false
355}
356
357// History returns the transition history of the state machine
358func (sm *StateMachine) History() []StateTransition {
359 sm.mu.RLock()
360 defer sm.mu.RUnlock()
361
362 // Return a copy to prevent modification
363 historyCopy := make([]StateTransition, len(sm.history))
364 copy(historyCopy, sm.history)
365 return historyCopy
366}
367
368// Reset resets the state machine to the initial ready state
369func (sm *StateMachine) Reset() {
370 sm.mu.Lock()
371 defer sm.mu.Unlock()
372
373 sm.currentState = StateReady
374 sm.previousState = StateUnknown
375 sm.stateEnteredAt = time.Now()
376}
377
378// IsInTerminalState returns whether the current state is a terminal state
379func (sm *StateMachine) IsInTerminalState() bool {
380 sm.mu.RLock()
381 defer sm.mu.RUnlock()
382
383 switch sm.currentState {
384 case StateEndOfTurn, StateCancelled, StateBudgetExceeded, StateError:
385 return true
386 default:
387 return false
388 }
389}
390
391// IsInErrorState returns whether the current state is an error state
392func (sm *StateMachine) IsInErrorState() bool {
393 sm.mu.RLock()
394 defer sm.mu.RUnlock()
395
396 switch sm.currentState {
397 case StateError, StateCancelled, StateBudgetExceeded:
398 return true
399 default:
400 return false
401 }
402}
403
404// ForceTransition forces a transition regardless of whether it's valid according to the state machine rules
405// This should be used only in critical situations like cancellation or error recovery
406func (sm *StateMachine) ForceTransition(ctx context.Context, newState State, reason string) {
407 event := TransitionEvent{
408 Description: fmt.Sprintf("Forced transition: %s", reason),
409 Timestamp: time.Now(),
410 }
411
412 sm.mu.Lock()
413
414 // Calculate duration in current state
415 duration := time.Since(sm.stateEnteredAt)
416
417 // Record the transition
418 transition := StateTransition{
419 From: sm.currentState,
420 To: newState,
421 Event: event,
422 }
423
424 // Update state
425 sm.previousState = sm.currentState
426 sm.currentState = newState
427 sm.stateEnteredAt = time.Now()
428
429 // Add to history
430 sm.history = append(sm.history, transition)
431
432 // Trim history if it exceeds maximum size
433 if len(sm.history) > sm.maxHistorySize {
434 sm.history = sm.history[len(sm.history)-sm.maxHistorySize:]
435 }
436
437 // Make a local copy of any callback functions to invoke outside the lock
438 var onTransition func(ctx context.Context, from, to State, event TransitionEvent)
439 var eventListenersCopy []chan<- StateTransition
440 if sm.onTransition != nil {
441 onTransition = sm.onTransition
442 }
443 if len(sm.eventListeners) > 0 {
444 eventListenersCopy = make([]chan<- StateTransition, len(sm.eventListeners))
445 copy(eventListenersCopy, sm.eventListeners)
446 }
447
448 // Log the transition
449 slog.WarnContext(ctx, "Forced state transition",
Josh Bleecher Snyder4d4e8072025-05-05 15:00:59 -0700450 "from", sm.previousState.String(),
451 "to", sm.currentState.String(),
Sean McCullough96b60dd2025-04-30 09:49:10 -0700452 "reason", reason,
453 "duration", duration)
454
455 // Release the lock before notifying listeners to avoid deadlocks
456 sm.mu.Unlock()
457
458 // Notify listeners if any
459 if onTransition != nil {
460 onTransition(ctx, sm.previousState, sm.currentState, event)
461 }
462
463 for _, ch := range eventListenersCopy {
464 select {
465 case ch <- transition:
466 // Successfully sent
467 default:
468 // Channel buffer full or no receiver, log and continue
469 slog.WarnContext(ctx, "Failed to notify state transition listener for forced transition",
470 "from", sm.previousState, "to", sm.currentState)
471 }
472 }
473
474 // Re-acquire the lock
475 sm.mu.Lock()
476 defer sm.mu.Unlock()
477}