| package loop |
| |
| import ( |
| "context" |
| "fmt" |
| "log/slog" |
| "sync" |
| "time" |
| ) |
| |
| // State represents the possible states of the Agent state machine |
| type State int |
| |
| //go:generate go tool golang.org/x/tools/cmd/stringer -type=State -trimprefix=State |
| const ( |
| // StateUnknown is the default state |
| StateUnknown State = iota |
| // StateReady is the initial state when the agent is initialized and ready to operate |
| StateReady |
| // StateWaitingForUserInput occurs when the agent is waiting for a user message to start a turn |
| StateWaitingForUserInput |
| // StateSendingToLLM occurs when the agent is sending message(s) to the LLM |
| StateSendingToLLM |
| // StateProcessingLLMResponse occurs when the agent is processing a response from the LLM |
| StateProcessingLLMResponse |
| // StateEndOfTurn occurs when processing is completed without tool use, and the turn ends |
| StateEndOfTurn |
| // StateToolUseRequested occurs when the LLM has requested to use a tool |
| StateToolUseRequested |
| // StateCheckingForCancellation occurs when the agent checks if user requested cancellation |
| StateCheckingForCancellation |
| // StateRunningTool occurs when the agent is executing the requested tool |
| StateRunningTool |
| // StateCheckingGitCommits occurs when the agent checks for new git commits after tool execution |
| StateCheckingGitCommits |
| // StateRunningAutoformatters occurs when the agent runs code formatters on new commits |
| StateRunningAutoformatters |
| // StateCheckingBudget occurs when the agent verifies if budget limits are exceeded |
| StateCheckingBudget |
| // StateGatheringAdditionalMessages occurs when the agent collects user messages that arrived during tool execution |
| StateGatheringAdditionalMessages |
| // StateSendingToolResults occurs when the agent sends tool results back to the LLM |
| StateSendingToolResults |
| // StateCancelled occurs when an operation was cancelled by the user |
| StateCancelled |
| // StateBudgetExceeded occurs when the budget limit was reached |
| StateBudgetExceeded |
| // StateError occurs when an error occurred during processing |
| StateError |
| // StateCompacting occurs when the agent is compacting the conversation |
| StateCompacting |
| ) |
| |
| // TransitionEvent represents an event that causes a state transition |
| type TransitionEvent struct { |
| // Description provides a human-readable description of the event |
| Description string |
| // Data can hold any additional information about the event |
| Data interface{} |
| // Timestamp is when the event occurred |
| Timestamp time.Time |
| } |
| |
| // StateTransition represents a transition from one state to another |
| type StateTransition struct { |
| From State |
| To State |
| Event TransitionEvent |
| } |
| |
| // StateMachine manages the Agent's states and transitions |
| type StateMachine struct { |
| // mu protects all fields of the StateMachine from concurrent access |
| mu sync.RWMutex |
| // currentState is the current state of the state machine |
| currentState State |
| // previousState is the previous state of the state machine |
| previousState State |
| // stateEnteredAt is when the current state was entered |
| stateEnteredAt time.Time |
| // transitions maps from states to the states they can transition to |
| transitions map[State]map[State]bool |
| // history records the history of state transitions |
| history []StateTransition |
| // maxHistorySize limits the number of transitions to keep in history |
| maxHistorySize int |
| // eventListeners are notified when state transitions occur |
| eventListeners []chan<- StateTransition |
| // onTransition is a callback function that's called when a transition occurs |
| onTransition func(ctx context.Context, from, to State, event TransitionEvent) |
| } |
| |
| // NewStateMachine creates a new state machine initialized to StateReady |
| func NewStateMachine() *StateMachine { |
| sm := &StateMachine{ |
| currentState: StateReady, |
| previousState: StateUnknown, |
| stateEnteredAt: time.Now(), |
| transitions: make(map[State]map[State]bool), |
| maxHistorySize: 100, |
| eventListeners: make([]chan<- StateTransition, 0), |
| } |
| |
| // Initialize valid transitions |
| sm.initTransitions() |
| |
| return sm |
| } |
| |
| // SetMaxHistorySize sets the maximum number of transitions to keep in history |
| func (sm *StateMachine) SetMaxHistorySize(size int) { |
| if size < 1 { |
| size = 1 // Ensure we keep at least one entry |
| } |
| |
| sm.mu.Lock() |
| defer sm.mu.Unlock() |
| |
| sm.maxHistorySize = size |
| |
| // Trim history if needed |
| if len(sm.history) > sm.maxHistorySize { |
| sm.history = sm.history[len(sm.history)-sm.maxHistorySize:] |
| } |
| } |
| |
| // AddTransitionListener adds a listener channel that will be notified of state transitions |
| // Returns a function that can be called to remove the listener |
| func (sm *StateMachine) AddTransitionListener(listener chan<- StateTransition) func() { |
| sm.mu.Lock() |
| defer sm.mu.Unlock() |
| |
| sm.eventListeners = append(sm.eventListeners, listener) |
| |
| // Return a function to remove this listener |
| return func() { |
| sm.RemoveTransitionListener(listener) |
| } |
| } |
| |
| // RemoveTransitionListener removes a previously added listener |
| func (sm *StateMachine) RemoveTransitionListener(listener chan<- StateTransition) { |
| sm.mu.Lock() |
| defer sm.mu.Unlock() |
| |
| for i, l := range sm.eventListeners { |
| if l == listener { |
| // Remove by swapping with the last element and then truncating |
| lastIdx := len(sm.eventListeners) - 1 |
| sm.eventListeners[i] = sm.eventListeners[lastIdx] |
| sm.eventListeners = sm.eventListeners[:lastIdx] |
| break |
| } |
| } |
| } |
| |
| // SetTransitionCallback sets a function to be called on every state transition |
| func (sm *StateMachine) SetTransitionCallback(callback func(ctx context.Context, from, to State, event TransitionEvent)) { |
| sm.mu.Lock() |
| defer sm.mu.Unlock() |
| |
| sm.onTransition = callback |
| } |
| |
| // ClearTransitionCallback removes any previously set transition callback |
| func (sm *StateMachine) ClearTransitionCallback() { |
| sm.mu.Lock() |
| defer sm.mu.Unlock() |
| |
| sm.onTransition = nil |
| } |
| |
| // initTransitions initializes the map of valid state transitions |
| func (sm *StateMachine) initTransitions() { |
| // Helper function to add transitions |
| addTransition := func(from State, to ...State) { |
| // Initialize the map for this state if it doesn't exist |
| if _, exists := sm.transitions[from]; !exists { |
| sm.transitions[from] = make(map[State]bool) |
| } |
| |
| // Add all of the 'to' states |
| for _, toState := range to { |
| sm.transitions[from][toState] = true |
| } |
| } |
| |
| // Define valid transitions based on the state machine diagram |
| |
| // Initial state |
| addTransition(StateReady, StateWaitingForUserInput) |
| |
| // Main flow |
| addTransition(StateWaitingForUserInput, StateSendingToLLM, StateCompacting, StateError) |
| addTransition(StateSendingToLLM, StateProcessingLLMResponse, StateError) |
| addTransition(StateProcessingLLMResponse, StateEndOfTurn, StateToolUseRequested, StateError) |
| addTransition(StateEndOfTurn, StateWaitingForUserInput) |
| |
| // Tool use flow |
| addTransition(StateToolUseRequested, StateCheckingForCancellation) |
| addTransition(StateCheckingForCancellation, StateRunningTool, StateCancelled) |
| addTransition(StateRunningTool, StateCheckingGitCommits, StateError) |
| addTransition(StateCheckingGitCommits, StateRunningAutoformatters, StateCheckingBudget) |
| addTransition(StateRunningAutoformatters, StateCheckingBudget) |
| addTransition(StateCheckingBudget, StateGatheringAdditionalMessages, StateBudgetExceeded) |
| addTransition(StateGatheringAdditionalMessages, StateSendingToolResults, StateError) |
| addTransition(StateSendingToolResults, StateProcessingLLMResponse, StateError) |
| |
| // Compaction flow |
| addTransition(StateCompacting, StateWaitingForUserInput, StateError) |
| |
| // Terminal states to new turn |
| addTransition(StateCancelled, StateWaitingForUserInput) |
| addTransition(StateBudgetExceeded, StateWaitingForUserInput) |
| addTransition(StateError, StateWaitingForUserInput) |
| } |
| |
| // Transition attempts to transition from the current state to the given state |
| func (sm *StateMachine) Transition(ctx context.Context, newState State, event string) error { |
| if sm == nil { |
| return fmt.Errorf("nil StateMachine pointer") |
| } |
| transitionEvent := TransitionEvent{ |
| Description: event, |
| Timestamp: time.Now(), |
| } |
| return sm.TransitionWithEvent(ctx, newState, transitionEvent) |
| } |
| |
| // TransitionWithEvent attempts to transition from the current state to the given state |
| // with the provided event information |
| func (sm *StateMachine) TransitionWithEvent(ctx context.Context, newState State, event TransitionEvent) error { |
| // First check if the transition is valid without holding the write lock |
| sm.mu.RLock() |
| currentState := sm.currentState |
| canTransition := false |
| if validToStates, exists := sm.transitions[currentState]; exists { |
| canTransition = validToStates[newState] |
| } |
| sm.mu.RUnlock() |
| |
| if !canTransition { |
| return fmt.Errorf("invalid transition from %s to %s", currentState, newState) |
| } |
| |
| // Acquire write lock for the actual transition |
| sm.mu.Lock() |
| defer sm.mu.Unlock() |
| |
| // Double-check that the state hasn't changed since we checked |
| if sm.currentState != currentState { |
| // State changed between our check and lock acquisition |
| // Re-check if the transition is still valid |
| if validToStates, exists := sm.transitions[sm.currentState]; !exists || !validToStates[newState] { |
| return fmt.Errorf("concurrent state change detected: invalid transition from current %s to %s", |
| sm.currentState, newState) |
| } |
| } |
| |
| // Calculate duration in current state |
| duration := time.Since(sm.stateEnteredAt) |
| |
| // Record the transition |
| transition := StateTransition{ |
| From: sm.currentState, |
| To: newState, |
| Event: event, |
| } |
| |
| // Update state |
| sm.previousState = sm.currentState |
| sm.currentState = newState |
| sm.stateEnteredAt = time.Now() |
| |
| // Add to history |
| sm.history = append(sm.history, transition) |
| |
| // Trim history if it exceeds maximum size |
| if len(sm.history) > sm.maxHistorySize { |
| sm.history = sm.history[len(sm.history)-sm.maxHistorySize:] |
| } |
| |
| // Make a local copy of any callback functions to invoke outside the lock |
| var onTransition func(ctx context.Context, from, to State, event TransitionEvent) |
| var eventListenersCopy []chan<- StateTransition |
| if sm.onTransition != nil { |
| onTransition = sm.onTransition |
| } |
| if len(sm.eventListeners) > 0 { |
| eventListenersCopy = make([]chan<- StateTransition, len(sm.eventListeners)) |
| copy(eventListenersCopy, sm.eventListeners) |
| } |
| |
| // Log the transition |
| slog.DebugContext(ctx, "State transition", |
| "from", sm.previousState.String(), |
| "to", sm.currentState.String(), |
| "event", event.Description, |
| "duration", duration) |
| |
| // Release the lock before notifying listeners to avoid deadlocks |
| sm.mu.Unlock() |
| |
| // Notify listeners if any |
| if onTransition != nil { |
| onTransition(ctx, sm.previousState, sm.currentState, event) |
| } |
| |
| for _, ch := range eventListenersCopy { |
| select { |
| case ch <- transition: |
| // Successfully sent |
| default: |
| // Channel buffer full or no receiver, log and continue |
| slog.WarnContext(ctx, "Failed to notify state transition listener", |
| "from", sm.previousState, "to", sm.currentState) |
| } |
| } |
| |
| // Re-acquire the lock that we explicitly released above |
| sm.mu.Lock() |
| return nil |
| } |
| |
| // CurrentState returns the current state |
| func (sm *StateMachine) CurrentState() State { |
| sm.mu.RLock() |
| defer sm.mu.RUnlock() |
| return sm.currentState |
| } |
| |
| // PreviousState returns the previous state |
| func (sm *StateMachine) PreviousState() State { |
| sm.mu.RLock() |
| defer sm.mu.RUnlock() |
| return sm.previousState |
| } |
| |
| // TimeInState returns how long the machine has been in the current state |
| func (sm *StateMachine) TimeInState() time.Duration { |
| sm.mu.RLock() |
| enteredAt := sm.stateEnteredAt |
| sm.mu.RUnlock() |
| return time.Since(enteredAt) |
| } |
| |
| // CanTransition returns true if a transition from the from state to the to state is valid |
| func (sm *StateMachine) CanTransition(from, to State) bool { |
| sm.mu.RLock() |
| defer sm.mu.RUnlock() |
| if validToStates, exists := sm.transitions[from]; exists { |
| return validToStates[to] |
| } |
| return false |
| } |
| |
| // History returns the transition history of the state machine |
| func (sm *StateMachine) History() []StateTransition { |
| sm.mu.RLock() |
| defer sm.mu.RUnlock() |
| |
| // Return a copy to prevent modification |
| historyCopy := make([]StateTransition, len(sm.history)) |
| copy(historyCopy, sm.history) |
| return historyCopy |
| } |
| |
| // Reset resets the state machine to the initial ready state |
| func (sm *StateMachine) Reset() { |
| sm.mu.Lock() |
| defer sm.mu.Unlock() |
| |
| sm.currentState = StateReady |
| sm.previousState = StateUnknown |
| sm.stateEnteredAt = time.Now() |
| } |
| |
| // IsInTerminalState returns whether the current state is a terminal state |
| func (sm *StateMachine) IsInTerminalState() bool { |
| sm.mu.RLock() |
| defer sm.mu.RUnlock() |
| |
| switch sm.currentState { |
| case StateEndOfTurn, StateCancelled, StateBudgetExceeded, StateError: |
| return true |
| default: |
| return false |
| } |
| } |
| |
| // IsInErrorState returns whether the current state is an error state |
| func (sm *StateMachine) IsInErrorState() bool { |
| sm.mu.RLock() |
| defer sm.mu.RUnlock() |
| |
| switch sm.currentState { |
| case StateError, StateCancelled, StateBudgetExceeded: |
| return true |
| default: |
| return false |
| } |
| } |
| |
| // ForceTransition forces a transition regardless of whether it's valid according to the state machine rules |
| // This should be used only in critical situations like cancellation or error recovery |
| func (sm *StateMachine) ForceTransition(ctx context.Context, newState State, reason string) { |
| event := TransitionEvent{ |
| Description: fmt.Sprintf("Forced transition: %s", reason), |
| Timestamp: time.Now(), |
| } |
| |
| sm.mu.Lock() |
| |
| // Calculate duration in current state |
| duration := time.Since(sm.stateEnteredAt) |
| |
| // Record the transition |
| transition := StateTransition{ |
| From: sm.currentState, |
| To: newState, |
| Event: event, |
| } |
| |
| // Update state |
| sm.previousState = sm.currentState |
| sm.currentState = newState |
| sm.stateEnteredAt = time.Now() |
| |
| // Add to history |
| sm.history = append(sm.history, transition) |
| |
| // Trim history if it exceeds maximum size |
| if len(sm.history) > sm.maxHistorySize { |
| sm.history = sm.history[len(sm.history)-sm.maxHistorySize:] |
| } |
| |
| // Make a local copy of any callback functions to invoke outside the lock |
| var onTransition func(ctx context.Context, from, to State, event TransitionEvent) |
| var eventListenersCopy []chan<- StateTransition |
| if sm.onTransition != nil { |
| onTransition = sm.onTransition |
| } |
| if len(sm.eventListeners) > 0 { |
| eventListenersCopy = make([]chan<- StateTransition, len(sm.eventListeners)) |
| copy(eventListenersCopy, sm.eventListeners) |
| } |
| |
| // Log the transition |
| slog.WarnContext(ctx, "Forced state transition", |
| "from", sm.previousState.String(), |
| "to", sm.currentState.String(), |
| "reason", reason, |
| "duration", duration) |
| |
| // Release the lock before notifying listeners to avoid deadlocks |
| sm.mu.Unlock() |
| |
| // Notify listeners if any |
| if onTransition != nil { |
| onTransition(ctx, sm.previousState, sm.currentState, event) |
| } |
| |
| for _, ch := range eventListenersCopy { |
| select { |
| case ch <- transition: |
| // Successfully sent |
| default: |
| // Channel buffer full or no receiver, log and continue |
| slog.WarnContext(ctx, "Failed to notify state transition listener for forced transition", |
| "from", sm.previousState, "to", sm.currentState) |
| } |
| } |
| |
| // Re-acquire the lock |
| sm.mu.Lock() |
| defer sm.mu.Unlock() |
| } |