blob: 284c13ff86d69c40664e7b84ddc56c004c3ff15c [file] [log] [blame]
Sean McCullough96b60dd2025-04-30 09:49:10 -07001package loop
2
3import (
4 "context"
5 "fmt"
6 "log/slog"
7 "sync"
8 "time"
9)
10
// State enumerates the phases the Agent state machine can be in.
type State int

//go:generate go tool golang.org/x/tools/cmd/stringer -type=State -trimprefix=State
const (
	// StateUnknown is the zero value and therefore the default state.
	StateUnknown State = iota
	// StateReady is the state a freshly initialized agent starts in.
	StateReady
	// StateWaitingForUserInput means the agent is idle until a user message
	// starts a turn.
	StateWaitingForUserInput
	// StateSendingToLLM means message(s) are being sent to the LLM.
	StateSendingToLLM
	// StateProcessingLLMResponse means an LLM response is being processed.
	StateProcessingLLMResponse
	// StateEndOfTurn means processing finished without tool use and the turn
	// is over.
	StateEndOfTurn
	// StateToolUseRequested means the LLM asked for a tool to be run.
	StateToolUseRequested
	// StateCheckingForCancellation means the agent is checking whether the
	// user requested cancellation.
	StateCheckingForCancellation
	// StateRunningTool means the requested tool is being executed.
	StateRunningTool
	// StateCheckingGitCommits means the agent is looking for new git commits
	// after a tool ran.
	StateCheckingGitCommits
	// StateRunningAutoformatters means code formatters are running on new
	// commits.
	StateRunningAutoformatters
	// StateCheckingBudget means the agent is verifying whether budget limits
	// have been exceeded.
	StateCheckingBudget
	// StateGatheringAdditionalMessages means the agent is collecting user
	// messages that arrived while a tool was executing.
	StateGatheringAdditionalMessages
	// StateSendingToolResults means tool results are being sent back to the
	// LLM.
	StateSendingToolResults
	// StateCancelled means an operation was cancelled by the user.
	StateCancelled
	// StateBudgetExceeded means the budget limit was reached.
	StateBudgetExceeded
	// StateError means an error occurred during processing.
	StateError
)
51
// TransitionEvent describes what triggered a state transition.
type TransitionEvent struct {
	// Description is a human-readable summary of the event.
	Description string
	// Data optionally carries arbitrary extra information about the event.
	Data any
	// Timestamp records when the event occurred.
	Timestamp time.Time
}
61
62// StateTransition represents a transition from one state to another
63type StateTransition struct {
64 From State
65 To State
66 Event TransitionEvent
67}
68
69// StateMachine manages the Agent's states and transitions
70type StateMachine struct {
71 // mu protects all fields of the StateMachine from concurrent access
72 mu sync.RWMutex
73 // currentState is the current state of the state machine
74 currentState State
75 // previousState is the previous state of the state machine
76 previousState State
77 // stateEnteredAt is when the current state was entered
78 stateEnteredAt time.Time
79 // transitions maps from states to the states they can transition to
80 transitions map[State]map[State]bool
81 // history records the history of state transitions
82 history []StateTransition
83 // maxHistorySize limits the number of transitions to keep in history
84 maxHistorySize int
85 // eventListeners are notified when state transitions occur
86 eventListeners []chan<- StateTransition
87 // onTransition is a callback function that's called when a transition occurs
88 onTransition func(ctx context.Context, from, to State, event TransitionEvent)
89}
90
91// NewStateMachine creates a new state machine initialized to StateReady
92func NewStateMachine() *StateMachine {
93 sm := &StateMachine{
94 currentState: StateReady,
95 previousState: StateUnknown,
96 stateEnteredAt: time.Now(),
97 transitions: make(map[State]map[State]bool),
98 maxHistorySize: 100,
99 eventListeners: make([]chan<- StateTransition, 0),
100 }
101
102 // Initialize valid transitions
103 sm.initTransitions()
104
105 return sm
106}
107
108// SetMaxHistorySize sets the maximum number of transitions to keep in history
109func (sm *StateMachine) SetMaxHistorySize(size int) {
110 if size < 1 {
111 size = 1 // Ensure we keep at least one entry
112 }
113
114 sm.mu.Lock()
115 defer sm.mu.Unlock()
116
117 sm.maxHistorySize = size
118
119 // Trim history if needed
120 if len(sm.history) > sm.maxHistorySize {
121 sm.history = sm.history[len(sm.history)-sm.maxHistorySize:]
122 }
123}
124
125// AddTransitionListener adds a listener channel that will be notified of state transitions
126// Returns a function that can be called to remove the listener
127func (sm *StateMachine) AddTransitionListener(listener chan<- StateTransition) func() {
128 sm.mu.Lock()
129 defer sm.mu.Unlock()
130
131 sm.eventListeners = append(sm.eventListeners, listener)
132
133 // Return a function to remove this listener
134 return func() {
135 sm.RemoveTransitionListener(listener)
136 }
137}
138
139// RemoveTransitionListener removes a previously added listener
140func (sm *StateMachine) RemoveTransitionListener(listener chan<- StateTransition) {
141 sm.mu.Lock()
142 defer sm.mu.Unlock()
143
144 for i, l := range sm.eventListeners {
145 if l == listener {
146 // Remove by swapping with the last element and then truncating
147 lastIdx := len(sm.eventListeners) - 1
148 sm.eventListeners[i] = sm.eventListeners[lastIdx]
149 sm.eventListeners = sm.eventListeners[:lastIdx]
150 break
151 }
152 }
153}
154
155// SetTransitionCallback sets a function to be called on every state transition
156func (sm *StateMachine) SetTransitionCallback(callback func(ctx context.Context, from, to State, event TransitionEvent)) {
157 sm.mu.Lock()
158 defer sm.mu.Unlock()
159
160 sm.onTransition = callback
161}
162
163// ClearTransitionCallback removes any previously set transition callback
164func (sm *StateMachine) ClearTransitionCallback() {
165 sm.mu.Lock()
166 defer sm.mu.Unlock()
167
168 sm.onTransition = nil
169}
170
171// initTransitions initializes the map of valid state transitions
172func (sm *StateMachine) initTransitions() {
173 // Helper function to add transitions
174 addTransition := func(from State, to ...State) {
175 // Initialize the map for this state if it doesn't exist
176 if _, exists := sm.transitions[from]; !exists {
177 sm.transitions[from] = make(map[State]bool)
178 }
179
180 // Add all of the 'to' states
181 for _, toState := range to {
182 sm.transitions[from][toState] = true
183 }
184 }
185
186 // Define valid transitions based on the state machine diagram
187
188 // Initial state
189 addTransition(StateReady, StateWaitingForUserInput)
190
191 // Main flow
192 addTransition(StateWaitingForUserInput, StateSendingToLLM, StateError)
193 addTransition(StateSendingToLLM, StateProcessingLLMResponse, StateError)
194 addTransition(StateProcessingLLMResponse, StateEndOfTurn, StateToolUseRequested, StateError)
195 addTransition(StateEndOfTurn, StateWaitingForUserInput)
196
197 // Tool use flow
198 addTransition(StateToolUseRequested, StateCheckingForCancellation)
199 addTransition(StateCheckingForCancellation, StateRunningTool, StateCancelled)
200 addTransition(StateRunningTool, StateCheckingGitCommits, StateError)
201 addTransition(StateCheckingGitCommits, StateRunningAutoformatters, StateCheckingBudget)
202 addTransition(StateRunningAutoformatters, StateCheckingBudget)
203 addTransition(StateCheckingBudget, StateGatheringAdditionalMessages, StateBudgetExceeded)
204 addTransition(StateGatheringAdditionalMessages, StateSendingToolResults, StateError)
205 addTransition(StateSendingToolResults, StateProcessingLLMResponse, StateError)
206
207 // Terminal states to new turn
208 addTransition(StateCancelled, StateWaitingForUserInput)
209 addTransition(StateBudgetExceeded, StateWaitingForUserInput)
210 addTransition(StateError, StateWaitingForUserInput)
211}
212
213// Transition attempts to transition from the current state to the given state
214func (sm *StateMachine) Transition(ctx context.Context, newState State, event string) error {
215 if sm == nil {
216 return fmt.Errorf("nil StateMachine pointer")
217 }
218 transitionEvent := TransitionEvent{
219 Description: event,
220 Timestamp: time.Now(),
221 }
222 return sm.TransitionWithEvent(ctx, newState, transitionEvent)
223}
224
225// TransitionWithEvent attempts to transition from the current state to the given state
226// with the provided event information
227func (sm *StateMachine) TransitionWithEvent(ctx context.Context, newState State, event TransitionEvent) error {
228 // First check if the transition is valid without holding the write lock
229 sm.mu.RLock()
230 currentState := sm.currentState
231 canTransition := false
232 if validToStates, exists := sm.transitions[currentState]; exists {
233 canTransition = validToStates[newState]
234 }
235 sm.mu.RUnlock()
236
237 if !canTransition {
238 return fmt.Errorf("invalid transition from %s to %s", currentState, newState)
239 }
240
241 // Acquire write lock for the actual transition
242 sm.mu.Lock()
243 defer sm.mu.Unlock()
244
245 // Double-check that the state hasn't changed since we checked
246 if sm.currentState != currentState {
247 // State changed between our check and lock acquisition
248 // Re-check if the transition is still valid
249 if validToStates, exists := sm.transitions[sm.currentState]; !exists || !validToStates[newState] {
250 return fmt.Errorf("concurrent state change detected: invalid transition from current %s to %s",
251 sm.currentState, newState)
252 }
253 }
254
255 // Calculate duration in current state
256 duration := time.Since(sm.stateEnteredAt)
257
258 // Record the transition
259 transition := StateTransition{
260 From: sm.currentState,
261 To: newState,
262 Event: event,
263 }
264
265 // Update state
266 sm.previousState = sm.currentState
267 sm.currentState = newState
268 sm.stateEnteredAt = time.Now()
269
270 // Add to history
271 sm.history = append(sm.history, transition)
272
273 // Trim history if it exceeds maximum size
274 if len(sm.history) > sm.maxHistorySize {
275 sm.history = sm.history[len(sm.history)-sm.maxHistorySize:]
276 }
277
278 // Make a local copy of any callback functions to invoke outside the lock
279 var onTransition func(ctx context.Context, from, to State, event TransitionEvent)
280 var eventListenersCopy []chan<- StateTransition
281 if sm.onTransition != nil {
282 onTransition = sm.onTransition
283 }
284 if len(sm.eventListeners) > 0 {
285 eventListenersCopy = make([]chan<- StateTransition, len(sm.eventListeners))
286 copy(eventListenersCopy, sm.eventListeners)
287 }
288
289 // Log the transition
290 slog.InfoContext(ctx, "State transition",
Josh Bleecher Snyder4d4e8072025-05-05 15:00:59 -0700291 "from", sm.previousState.String(),
292 "to", sm.currentState.String(),
Sean McCullough96b60dd2025-04-30 09:49:10 -0700293 "event", event.Description,
294 "duration", duration)
295
296 // Release the lock before notifying listeners to avoid deadlocks
297 sm.mu.Unlock()
298
299 // Notify listeners if any
300 if onTransition != nil {
301 onTransition(ctx, sm.previousState, sm.currentState, event)
302 }
303
304 for _, ch := range eventListenersCopy {
305 select {
306 case ch <- transition:
307 // Successfully sent
308 default:
309 // Channel buffer full or no receiver, log and continue
310 slog.WarnContext(ctx, "Failed to notify state transition listener",
311 "from", sm.previousState, "to", sm.currentState)
312 }
313 }
314
315 // Re-acquire the lock that we explicitly released above
316 sm.mu.Lock()
317 return nil
318}
319
320// CurrentState returns the current state
321func (sm *StateMachine) CurrentState() State {
322 sm.mu.RLock()
323 defer sm.mu.RUnlock()
324 return sm.currentState
325}
326
327// PreviousState returns the previous state
328func (sm *StateMachine) PreviousState() State {
329 sm.mu.RLock()
330 defer sm.mu.RUnlock()
331 return sm.previousState
332}
333
334// TimeInState returns how long the machine has been in the current state
335func (sm *StateMachine) TimeInState() time.Duration {
336 sm.mu.RLock()
337 enteredAt := sm.stateEnteredAt
338 sm.mu.RUnlock()
339 return time.Since(enteredAt)
340}
341
342// CanTransition returns true if a transition from the from state to the to state is valid
343func (sm *StateMachine) CanTransition(from, to State) bool {
344 sm.mu.RLock()
345 defer sm.mu.RUnlock()
346 if validToStates, exists := sm.transitions[from]; exists {
347 return validToStates[to]
348 }
349 return false
350}
351
352// History returns the transition history of the state machine
353func (sm *StateMachine) History() []StateTransition {
354 sm.mu.RLock()
355 defer sm.mu.RUnlock()
356
357 // Return a copy to prevent modification
358 historyCopy := make([]StateTransition, len(sm.history))
359 copy(historyCopy, sm.history)
360 return historyCopy
361}
362
363// Reset resets the state machine to the initial ready state
364func (sm *StateMachine) Reset() {
365 sm.mu.Lock()
366 defer sm.mu.Unlock()
367
368 sm.currentState = StateReady
369 sm.previousState = StateUnknown
370 sm.stateEnteredAt = time.Now()
371}
372
373// IsInTerminalState returns whether the current state is a terminal state
374func (sm *StateMachine) IsInTerminalState() bool {
375 sm.mu.RLock()
376 defer sm.mu.RUnlock()
377
378 switch sm.currentState {
379 case StateEndOfTurn, StateCancelled, StateBudgetExceeded, StateError:
380 return true
381 default:
382 return false
383 }
384}
385
386// IsInErrorState returns whether the current state is an error state
387func (sm *StateMachine) IsInErrorState() bool {
388 sm.mu.RLock()
389 defer sm.mu.RUnlock()
390
391 switch sm.currentState {
392 case StateError, StateCancelled, StateBudgetExceeded:
393 return true
394 default:
395 return false
396 }
397}
398
399// ForceTransition forces a transition regardless of whether it's valid according to the state machine rules
400// This should be used only in critical situations like cancellation or error recovery
401func (sm *StateMachine) ForceTransition(ctx context.Context, newState State, reason string) {
402 event := TransitionEvent{
403 Description: fmt.Sprintf("Forced transition: %s", reason),
404 Timestamp: time.Now(),
405 }
406
407 sm.mu.Lock()
408
409 // Calculate duration in current state
410 duration := time.Since(sm.stateEnteredAt)
411
412 // Record the transition
413 transition := StateTransition{
414 From: sm.currentState,
415 To: newState,
416 Event: event,
417 }
418
419 // Update state
420 sm.previousState = sm.currentState
421 sm.currentState = newState
422 sm.stateEnteredAt = time.Now()
423
424 // Add to history
425 sm.history = append(sm.history, transition)
426
427 // Trim history if it exceeds maximum size
428 if len(sm.history) > sm.maxHistorySize {
429 sm.history = sm.history[len(sm.history)-sm.maxHistorySize:]
430 }
431
432 // Make a local copy of any callback functions to invoke outside the lock
433 var onTransition func(ctx context.Context, from, to State, event TransitionEvent)
434 var eventListenersCopy []chan<- StateTransition
435 if sm.onTransition != nil {
436 onTransition = sm.onTransition
437 }
438 if len(sm.eventListeners) > 0 {
439 eventListenersCopy = make([]chan<- StateTransition, len(sm.eventListeners))
440 copy(eventListenersCopy, sm.eventListeners)
441 }
442
443 // Log the transition
444 slog.WarnContext(ctx, "Forced state transition",
Josh Bleecher Snyder4d4e8072025-05-05 15:00:59 -0700445 "from", sm.previousState.String(),
446 "to", sm.currentState.String(),
Sean McCullough96b60dd2025-04-30 09:49:10 -0700447 "reason", reason,
448 "duration", duration)
449
450 // Release the lock before notifying listeners to avoid deadlocks
451 sm.mu.Unlock()
452
453 // Notify listeners if any
454 if onTransition != nil {
455 onTransition(ctx, sm.previousState, sm.currentState, event)
456 }
457
458 for _, ch := range eventListenersCopy {
459 select {
460 case ch <- transition:
461 // Successfully sent
462 default:
463 // Channel buffer full or no receiver, log and continue
464 slog.WarnContext(ctx, "Failed to notify state transition listener for forced transition",
465 "from", sm.previousState, "to", sm.currentState)
466 }
467 }
468
469 // Re-acquire the lock
470 sm.mu.Lock()
471 defer sm.mu.Unlock()
472}