loop: make multiplechoice tool calls end the turn
Co-Authored-By: sketch <hello@sketch.dev>
Change-ID: s8d507faf9c095824sk
diff --git a/llm/conversation/convo.go b/llm/conversation/convo.go
index ba6d2d9..4740f22 100644
--- a/llm/conversation/convo.go
+++ b/llm/conversation/convo.go
@@ -402,17 +402,24 @@
// ToolResultContents runs all tool uses requested by the response and returns their results.
// Cancelling ctx will cancel any running tool calls.
-func (c *Convo) ToolResultContents(ctx context.Context, resp *llm.Response) ([]llm.Content, error) {
+// The boolean return value indicates whether any of the executed tools should end the turn.
+func (c *Convo) ToolResultContents(ctx context.Context, resp *llm.Response) ([]llm.Content, bool, error) {
if resp.StopReason != llm.StopReasonToolUse {
- return nil, nil
+ return nil, false, nil
}
// Extract all tool calls from the response, call the tools, and gather the results.
var wg sync.WaitGroup
toolResultC := make(chan llm.Content, len(resp.Content))
+
+ endsTurn := false
for _, part := range resp.Content {
if part.Type != llm.ContentTypeToolUse {
continue
}
+ tool, err := c.findTool(part.ToolName)
+ if err == nil && tool.EndsTurn {
+ endsTurn = true
+ }
c.incrementToolUse(part.ToolName)
startTime := time.Now()
@@ -492,9 +499,9 @@
toolResults = append(toolResults, toolResult)
}
if ctx.Err() != nil {
- return nil, ctx.Err()
+ return nil, false, ctx.Err()
}
- return toolResults, nil
+ return toolResults, endsTurn, nil
}
func (c *Convo) incrementToolUse(name string) {