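loop: make multiplechoice tool calls end the turn

ConvoInterface.ToolResultContents now returns an additional bool that
reports whether a tool asked for the turn to end. handleToolExecution
still sends the tool results back via continueTurnWithToolResults, but
when that bool is set it reports that the turn should not continue.
All mock implementations are updated for the new signature and return
false for the new value.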
Co-Authored-By: sketch <hello@sketch.dev>
Change-ID: s8d507faf9c095824sk
diff --git a/loop/agent.go b/loop/agent.go
index 05179eb..be5c8a2 100644
--- a/loop/agent.go
+++ b/loop/agent.go
@@ -286,7 +286,7 @@
SendMessage(message llm.Message) (*llm.Response, error)
SendUserTextMessage(s string, otherContents ...llm.Content) (*llm.Response, error)
GetID() string
- ToolResultContents(ctx context.Context, resp *llm.Response) ([]llm.Content, error)
+ ToolResultContents(ctx context.Context, resp *llm.Response) ([]llm.Content, bool, error)
ToolResultCancelContents(resp *llm.Response) ([]llm.Content, error)
CancelToolUse(toolUseID string, cause error) error
SubConvoWithHistory() *conversation.Convo
@@ -1339,6 +1339,7 @@
func (a *Agent) handleToolExecution(ctx context.Context, resp *llm.Response) (bool, *llm.Response) {
var results []llm.Content
cancelled := false
+ toolEndsTurn := false
// Transition to checking for cancellation state
a.stateMachine.Transition(ctx, StateCheckingForCancellation, "Checking if user requested cancellation")
@@ -1365,7 +1366,7 @@
// Execute the tools
var err error
- results, err = a.convo.ToolResultContents(ctx, resp)
+ results, toolEndsTurn, err = a.convo.ToolResultContents(ctx, resp)
if ctx.Err() != nil { // e.g. the user canceled the operation
cancelled = true
a.stateMachine.Transition(ctx, StateCancelled, "Operation cancelled during tool execution")
@@ -1387,7 +1388,8 @@
}
// Continue the conversation with tool results and any user messages
- return a.continueTurnWithToolResults(ctx, results, autoqualityMessages, cancelled)
+ shouldContinue, resp := a.continueTurnWithToolResults(ctx, results, autoqualityMessages, cancelled)
+ return shouldContinue && !toolEndsTurn, resp
}
// processGitChanges checks for new git commits and runs autoformatters if needed
diff --git a/loop/agent_test.go b/loop/agent_test.go
index ce44352..31e9664 100644
--- a/loop/agent_test.go
+++ b/loop/agent_test.go
@@ -256,7 +256,7 @@
type MockConvoInterface struct {
sendMessageFunc func(message llm.Message) (*llm.Response, error)
sendUserTextMessageFunc func(s string, otherContents ...llm.Content) (*llm.Response, error)
- toolResultContentsFunc func(ctx context.Context, resp *llm.Response) ([]llm.Content, error)
+ toolResultContentsFunc func(ctx context.Context, resp *llm.Response) ([]llm.Content, bool, error)
toolResultCancelContentsFunc func(resp *llm.Response) ([]llm.Content, error)
cancelToolUseFunc func(toolUseID string, cause error) error
cumulativeUsageFunc func() conversation.CumulativeUsage
@@ -280,11 +280,11 @@
return nil, nil
}
-func (m *MockConvoInterface) ToolResultContents(ctx context.Context, resp *llm.Response) ([]llm.Content, error) {
+func (m *MockConvoInterface) ToolResultContents(ctx context.Context, resp *llm.Response) ([]llm.Content, bool, error) {
if m.toolResultContentsFunc != nil {
return m.toolResultContentsFunc(ctx, resp)
}
- return nil, nil
+ return nil, false, nil
}
func (m *MockConvoInterface) ToolResultCancelContents(resp *llm.Response) ([]llm.Content, error) {
@@ -469,7 +469,7 @@
// mockConvoInterface is a mock implementation of ConvoInterface for testing
type mockConvoInterface struct {
SendMessageFunc func(message llm.Message) (*llm.Response, error)
- ToolResultContentsFunc func(ctx context.Context, resp *llm.Response) ([]llm.Content, error)
+ ToolResultContentsFunc func(ctx context.Context, resp *llm.Response) ([]llm.Content, bool, error)
}
func (c *mockConvoInterface) GetID() string {
@@ -501,11 +501,11 @@
return m.SendMessage(llm.UserStringMessage(s))
}
-func (m *mockConvoInterface) ToolResultContents(ctx context.Context, resp *llm.Response) ([]llm.Content, error) {
+func (m *mockConvoInterface) ToolResultContents(ctx context.Context, resp *llm.Response) ([]llm.Content, bool, error) {
if m.ToolResultContentsFunc != nil {
return m.ToolResultContentsFunc(ctx, resp)
}
- return []llm.Content{}, nil
+ return []llm.Content{}, false, nil
}
func (m *mockConvoInterface) ToolResultCancelContents(resp *llm.Response) ([]llm.Content, error) {
@@ -640,8 +640,8 @@
}
// Tool result content handler
- mockConvo.ToolResultContentsFunc = func(ctx context.Context, resp *llm.Response) ([]llm.Content, error) {
- return []llm.Content{llm.StringContent("Tool executed successfully")}, nil
+ mockConvo.ToolResultContentsFunc = func(ctx context.Context, resp *llm.Response) ([]llm.Content, bool, error) {
+ return []llm.Content{llm.StringContent("Tool executed successfully")}, false, nil
}
// Track state transitions
diff --git a/loop/agent_user_cancel_test.go b/loop/agent_user_cancel_test.go
index ed7df1b..32e6241 100644
--- a/loop/agent_user_cancel_test.go
+++ b/loop/agent_user_cancel_test.go
@@ -104,7 +104,7 @@
// Set up the mock response for tool results
mockConvo.ExpectCall("SendMessage", userMsg).Return(userMsgResponse, nil)
- mockConvo.ExpectCall("ToolResultContents", userMsgResponse).Return(toolUseContents, nil)
+ mockConvo.ExpectCall("ToolResultContents", userMsgResponse).Return(toolUseContents, false, nil)
mockConvo.ExpectCall("SendMessage", toolUseResultsMsg).Return(toolUseResponse, nil)
ctx, cancel := context.WithCancel(context.Background())
@@ -458,7 +458,7 @@
defer cancel()
// Setting up the mock response for tool results
- mockConvo.ExpectCall("ToolResultContents", initialResponse).Return(toolUseContents, nil)
+ mockConvo.ExpectCall("ToolResultContents", initialResponse).Return(toolUseContents, false, nil)
mockConvo.ExpectCall("SendMessage", nil).Return(toolUseResponse, nil)
// mockConvo, as a mock, isn't able to run the loop in conversation.Convo that makes this agent.OnToolResult callback.
// So we "mock" it out here by calling it explicitly, in order to make sure it calls .pushToOutbox with this message.
diff --git a/loop/mocks.go b/loop/mocks.go
index 014cd46..016c021 100644
--- a/loop/mocks.go
+++ b/loop/mocks.go
@@ -149,7 +149,7 @@
return exp.result[0].(*llm.Response), retErr
}
-func (m *MockConvo) ToolResultContents(ctx context.Context, resp *llm.Response) ([]llm.Content, error) {
+func (m *MockConvo) ToolResultContents(ctx context.Context, resp *llm.Response) ([]llm.Content, bool, error) {
m.recordCall("ToolResultContents", resp)
exp, ok := m.findMatchingExpectation("ToolResultContents", resp)
if !ok {
@@ -163,7 +163,7 @@
retErr = err
}
- return exp.result[0].([]llm.Content), retErr
+ return exp.result[0].([]llm.Content), false, retErr
}
func (m *MockConvo) ToolResultCancelContents(resp *llm.Response) ([]llm.Content, error) {