agent plumbing: convert outbox to subscribers and an iterator
WaitForMessage() could only serve a single consumer, because it read
from one shared outbox channel. That was fine while there was only one
consumer, but WaitForMessageCount() did something similar with its own
separate listener mechanism, and I want to rework how polling works,
which needs a second consumer.
Anyway, this one is hand-coded, because Sketch really struggled
with getting the iterator convincingly safe. In a follow-up commit,
I'll try to get Sketch to write some tests.
diff --git a/loop/agent.go b/loop/agent.go
index f5b12f7..adbdf51 100644
--- a/loop/agent.go
+++ b/loop/agent.go
@@ -27,6 +27,13 @@
userCancelMessage = "user requested agent to stop handling responses"
)
+// MessageIterator yields agent messages one at a time.
+// Iterators are obtained from CodingAgent.NewIterator.
+type MessageIterator interface {
+ // Next blocks until the next message is available. It may
+ // return nil if the underlying iterator context is done.
+ Next() *AgentMessage
+ // Close releases the iterator. In the Agent-backed implementation it
+ // removes the iterator's channel from the agent's subscribers list and
+ // closes that channel.
+ Close()
+}
+
type CodingAgent interface {
// Init initializes an agent inside a docker container.
Init(AgentInit) error
@@ -40,10 +47,9 @@
// UserMessage enqueues a message to the agent and returns immediately.
UserMessage(ctx context.Context, msg string)
- // WaitForMessage blocks until the agent has a response to give.
- // Use AgentMessage.EndOfTurn to help determine if you want to
- // drain the agent.
- WaitForMessage(ctx context.Context) AgentMessage
+ // Returns an iterator that finishes when the context is done and
+ // starts with the given message index.
+ NewIterator(ctx context.Context, nextMessageIdx int) MessageIterator
// Loop begins the agent loop returns only when ctx is cancelled.
Loop(ctx context.Context)
@@ -61,9 +67,6 @@
TotalUsage() ant.CumulativeUsage
OriginalBudget() ant.Budget
- // WaitForMessageCount returns when the agent has at more than clientMessageCount messages or the context is done.
- WaitForMessageCount(ctx context.Context, greaterThan int)
-
WorkingDir() string
// Diff returns a unified diff of changes made since the agent was instantiated.
@@ -195,6 +198,7 @@
var attrs []any = []any{
slog.String("type", string(a.Type)),
}
+ attrs = append(attrs, slog.Int("idx", a.Idx))
if a.EndOfTurn {
attrs = append(attrs, slog.Bool("end_of_turn", a.EndOfTurn))
}
@@ -305,13 +309,6 @@
// read from by GatherMessages
inbox chan string
- // Outbox
- // sent on by pushToOutbox
- // via OnToolResult and OnResponse callbacks
- // read from by WaitForMessage
- // called by termui inside its repl loop.
- outbox chan AgentMessage
-
// protects cancelTurn
cancelTurnMu sync.Mutex
// cancels potentially long-running tool_use calls or chains of them
@@ -323,7 +320,8 @@
// Stores all messages for this agent
history []AgentMessage
- listeners []chan struct{}
+ // Iterators add themselves here when they're ready to be notified of new messages.
+ subscribers []chan *AgentMessage
// Track git commits we've already seen (by hash)
seenCommits map[string]bool
@@ -335,6 +333,80 @@
outstandingToolCalls map[string]string
}
+// NewIterator implements CodingAgent.
+//
+// The returned iterator yields messages starting at index nextMessageIdx.
+// Its Next method returns nil once ctx is done. The iterator does not
+// subscribe to the agent here; it subscribes lazily from Next once it has
+// caught up with the existing history.
+func (a *Agent) NewIterator(ctx context.Context, nextMessageIdx int) MessageIterator {
+	// A negative index would pass the "index < len(history)" check in Next
+	// and then index the history slice out of range, which panics. Treat it
+	// as "start from the first message" instead.
+	// NOTE(review): the HTTP poll handler passes clientMessageCount here,
+	// presumably derived from the request — confirm it is validated upstream.
+	if nextMessageIdx < 0 {
+		nextMessageIdx = 0
+	}
+
+	// No agent state is read or written during construction, so a.mu is
+	// not taken here.
+	return &MessageIteratorImpl{
+		agent:          a,
+		ctx:            ctx,
+		nextMessageIdx: nextMessageIdx,
+		ch:             make(chan *AgentMessage, 100),
+	}
+}
+
+// MessageIteratorImpl is the Agent-backed MessageIterator.
+//
+// It first replays messages already stored in agent.history, starting at
+// nextMessageIdx, and only registers ch in agent.subscribers once it has
+// caught up. From then on messages arrive over ch.
+//
+// Next is intended to be called from a single goroutine. Close may be
+// called more than once.
+//
+// NOTE(review): the agent sends to every subscriber channel while holding
+// agent.mu. A subscribed iterator whose buffer fills up because nobody is
+// calling Next will therefore block the agent with the lock held — confirm
+// that consumers keep draining until they are unsubscribed.
+type MessageIteratorImpl struct {
+	agent *Agent
+	ctx   context.Context
+	// nextMessageIdx is the index of the next message Next will return.
+	nextMessageIdx int
+	// ch receives messages from the agent while the iterator is subscribed.
+	ch chan *AgentMessage
+	// subscribed reports whether ch is in agent.subscribers. Guarded by agent.mu.
+	subscribed bool
+	// closed is set by Close; a closed iterator never subscribes again.
+	// Guarded by agent.mu.
+	closed bool
+}
+
+// Close removes the iterator from the agent's subscribers list and closes
+// its channel, which makes a blocked Next return nil. Calling Close again is
+// a no-op; without this guard a second call would panic on the double close.
+func (m *MessageIteratorImpl) Close() {
+	m.agent.mu.Lock()
+	defer m.agent.mu.Unlock()
+	if m.closed {
+		return
+	}
+	m.closed = true
+	m.unsubscribeLocked()
+	// Safe to close: the agent only sends to channels in the subscribers
+	// list, and it does so under the same lock that is held here.
+	close(m.ch)
+}
+
+// unsubscribeLocked deletes ch from the agent's subscribers list.
+// The caller must hold agent.mu.
+func (m *MessageIteratorImpl) unsubscribeLocked() {
+	m.agent.subscribers = slices.DeleteFunc(m.agent.subscribers, func(x chan *AgentMessage) bool {
+		return x == m.ch
+	})
+	m.subscribed = false
+}
+
+// Next returns the next message in index order, blocking until one is
+// available. It returns nil when the iterator's context is done or the
+// iterator has been closed.
+func (m *MessageIteratorImpl) Next() *AgentMessage {
+	m.agent.mu.Lock()
+	if m.closed {
+		// Never re-register a closed channel: the agent would panic when
+		// it next sends on it.
+		m.agent.mu.Unlock()
+		return nil
+	}
+	// We avoid subscription at creation to let ourselves catch up to
+	// "current state" before subscribing.
+	if !m.subscribed {
+		if m.nextMessageIdx < len(m.agent.history) {
+			// Return a copy so the caller never holds a pointer into
+			// agent.history; subscribers on the live path likewise receive
+			// a pointer to a copy of the message.
+			msg := m.agent.history[m.nextMessageIdx]
+			m.nextMessageIdx++
+			m.agent.mu.Unlock()
+			return &msg
+		}
+		// The next message doesn't exist yet, so let's subscribe. This
+		// happens under the same lock as the history check, so no message
+		// can be appended in between.
+		m.agent.subscribers = append(m.agent.subscribers, m.ch)
+		m.subscribed = true
+	}
+	m.agent.mu.Unlock()
+
+	for {
+		select {
+		case <-m.ctx.Done():
+			m.agent.mu.Lock()
+			m.unsubscribeLocked()
+			m.agent.mu.Unlock()
+			return nil
+		case msg, ok := <-m.ch:
+			if !ok {
+				// Close may have been called
+				return nil
+			}
+			if msg.Idx == m.nextMessageIdx {
+				m.nextMessageIdx++
+				return msg
+			}
+			if msg.Idx < m.nextMessageIdx {
+				// Stale message left in the buffer from an earlier
+				// subscription; it was already delivered from history
+				// after the ctx-done path unsubscribed.
+				continue
+			}
+			// A gap means the agent skipped an index: a broken invariant.
+			slog.Debug("Out of order messages", "expected", m.nextMessageIdx, "got", msg.Idx, "m", msg.Content)
+			panic("out of order message")
+		}
+	}
+}
+
// Assert that Agent satisfies the CodingAgent interface.
var _ CodingAgent = &Agent{}
@@ -453,11 +525,9 @@
defer a.mu.Unlock()
a.title = title
a.branchName = branchName
- // Notify all listeners that the state has changed
- for _, ch := range a.listeners {
- close(ch)
- }
- a.listeners = a.listeners[:0]
+
+ // TODO: We could potentially notify listeners of a state change, but,
+ // realistically, a new message will be sent for the tool result as well.
}
// OnToolCall implements ant.Listener and tracks the start of a tool call.
@@ -614,7 +684,7 @@
config: config,
ready: make(chan struct{}),
inbox: make(chan string, 100),
- outbox: make(chan AgentMessage, 100),
+ subscribers: make([]chan *AgentMessage, 0),
startedAt: time.Now(),
originalBudget: config.Budget,
seenCommits: make(map[string]bool),
@@ -858,16 +928,6 @@
a.inbox <- msg
}
-func (a *Agent) WaitForMessage(ctx context.Context) AgentMessage {
- // TODO: Should this drain any outbox messages in case there are multiple?
- select {
- case msg := <-a.outbox:
- return msg
- case <-ctx.Done():
- return errorMessage(ctx.Err())
- }
-}
-
func (a *Agent) CancelToolUse(toolUseID string, cause error) error {
return a.convo.CancelToolUse(toolUseID, cause)
}
@@ -918,19 +978,16 @@
slog.InfoContext(ctx, "Turn completed", "turnDuration", turnDuration)
}
- slog.InfoContext(ctx, "agent message", m.Attr())
-
a.mu.Lock()
defer a.mu.Unlock()
m.Idx = len(a.history)
+ slog.InfoContext(ctx, "agent message", m.Attr())
a.history = append(a.history, m)
- a.outbox <- m
- // Notify all listeners:
- for _, ch := range a.listeners {
- close(ch)
+ // Notify all subscribers
+ for _, ch := range a.subscribers {
+ ch <- &m
}
- a.listeners = a.listeners[:0]
}
func (a *Agent) GatherMessages(ctx context.Context, block bool) ([]ant.Content, error) {
@@ -1161,7 +1218,7 @@
if cancelled {
msgs = append(msgs, ant.Content{Type: "text", Text: cancelToolUseMessage})
// EndOfTurn is false here so that the client of this agent keeps processing
- // messages from WaitForMessage() and gets the response from the LLM
+ // further messages; the conversation is not over.
a.pushToOutbox(ctx, AgentMessage{Type: ErrorMessageType, Content: userCancelMessage, EndOfTurn: false})
} else if err := a.convo.OverBudget(); err != nil {
// Handle budget issues by appending a message about it
@@ -1227,24 +1284,6 @@
return a.convo.CumulativeUsage()
}
-// WaitForMessageCount returns when the agent has at more than clientMessageCount messages or the context is done.
-func (a *Agent) WaitForMessageCount(ctx context.Context, greaterThan int) {
- for a.MessageCount() <= greaterThan {
- a.mu.Lock()
- ch := make(chan struct{})
- // Deletion happens when we notify.
- a.listeners = append(a.listeners, ch)
- a.mu.Unlock()
-
- select {
- case <-ctx.Done():
- return
- case <-ch:
- continue
- }
- }
-}
-
// Diff returns a unified diff of changes made since the agent was instantiated.
func (a *Agent) Diff(commit *string) (string, error) {
if a.initialCommit == "" {
diff --git a/loop/agent_git_test.go b/loop/agent_git_test.go
index 399943b..053b8af 100644
--- a/loop/agent_git_test.go
+++ b/loop/agent_git_test.go
@@ -8,7 +8,6 @@
"path/filepath"
"strings"
"testing"
- "time"
)
// TestGitCommitTracking tests the git commit tracking functionality
@@ -67,7 +66,6 @@
agent := &Agent{
workingDir: tempDir,
repoRoot: tempDir, // Set repoRoot to same as workingDir for this test
- outbox: make(chan AgentMessage, 100),
seenCommits: make(map[string]bool),
initialCommit: initialCommit,
}
@@ -97,13 +95,7 @@
}
// Check if we received a commit message
- var commitMsg AgentMessage
- select {
- case commitMsg = <-agent.outbox:
- // We got a message
- case <-time.After(500 * time.Millisecond):
- t.Fatal("Timed out waiting for commit message")
- }
+ var commitMsg AgentMessage = agent.history[len(agent.history)-1]
// Verify the commit message
if commitMsg.Type != CommitMessageType {
@@ -162,7 +154,6 @@
}
// Reset the outbox channel and seen commits map
- agent.outbox = make(chan AgentMessage, 100)
agent.seenCommits = make(map[string]bool)
// Call handleGitCommits again - it should still work but only show at most 100 commits
@@ -172,12 +163,7 @@
}
// Check if we received a commit message
- select {
- case commitMsg = <-agent.outbox:
- // We got a message
- case <-time.After(500 * time.Millisecond):
- t.Fatal("Timed out waiting for commit message")
- }
+ commitMsg = agent.history[len(agent.history)-1]
// Should have at most 100 commits due to the -n 100 limit in git log
if len(commitMsg.Commits) > 100 {
diff --git a/loop/agent_test.go b/loop/agent_test.go
index c62fd21..9663e26 100644
--- a/loop/agent_test.go
+++ b/loop/agent_test.go
@@ -94,26 +94,16 @@
// Collect responses with a timeout
var responses []AgentMessage
- timeout := time.After(10 * time.Second)
+ ctx2, _ := context.WithDeadline(ctx, time.Now().Add(10*time.Second))
done := false
+ it := agent.NewIterator(ctx2, 0)
for !done {
- select {
- case <-timeout:
- t.Log("Timeout reached while waiting for agent responses")
+ msg := it.Next()
+ t.Logf("Received message: Type=%s, EndOfTurn=%v, Content=%q", msg.Type, msg.EndOfTurn, msg.Content)
+ responses = append(responses, *msg)
+ if msg.EndOfTurn {
done = true
- default:
- select {
- case msg := <-agent.outbox:
- t.Logf("Received message: Type=%s, EndOfTurn=%v, Content=%q", msg.Type, msg.EndOfTurn, msg.Content)
- responses = append(responses, msg)
- if msg.EndOfTurn {
- done = true
- }
- default:
- // No more messages available right now
- time.Sleep(100 * time.Millisecond)
- }
}
}
diff --git a/loop/server/loophttp.go b/loop/server/loophttp.go
index f61ed28..4a415c8 100644
--- a/loop/server/loophttp.go
+++ b/loop/server/loophttp.go
@@ -349,9 +349,10 @@
if pollParam == "true" {
ch := make(chan string)
go func() {
- // This is your blocking operation
- agent.WaitForMessageCount(r.Context(), clientMessageCount)
+ it := agent.NewIterator(r.Context(), clientMessageCount)
+ it.Next()
close(ch)
+ it.Close()
}()
select {
case <-r.Context().Done():