claudetool: make it easier to parameterize patch tool
diff --git a/claudetool/bash.go b/claudetool/bash.go
index 5c1eee0..72c4ecc 100644
--- a/claudetool/bash.go
+++ b/claudetool/bash.go
@@ -25,7 +25,7 @@
// PermissionCallback is a function type for checking if a command is allowed to run
type PermissionCallback func(command string) error
-// BashTool specifies a llm.Tool for executing shell commands.
+// BashTool specifies an llm.Tool for executing shell commands.
type BashTool struct {
// CheckPermission is called before running any command, if set
CheckPermission PermissionCallback
diff --git a/claudetool/patch.go b/claudetool/patch.go
index 04ea942..b40396a 100644
--- a/claudetool/patch.go
+++ b/claudetool/patch.go
@@ -25,8 +25,13 @@
// and returns a new, possibly altered tool output.
type PatchCallback func(input PatchInput, output llm.ToolOut) llm.ToolOut
-// Patch creates a patch tool. The callback may be nil.
-func Patch(callback PatchCallback) *llm.Tool {
+// PatchTool specifies an llm.Tool for patching files.
+type PatchTool struct {
+ Callback PatchCallback // may be nil
+}
+
+// Tool returns an llm.Tool based on p.
+func (p *PatchTool) Tool() *llm.Tool {
return &llm.Tool{
Name: PatchName,
Description: strings.TrimSpace(PatchDescription),
@@ -34,8 +39,8 @@
Run: func(ctx context.Context, m json.RawMessage) llm.ToolOut {
var input PatchInput
output := patchRun(ctx, m, &input)
- if callback != nil {
- return callback(input, output)
+ if p.Callback != nil {
+ return p.Callback(input, output)
}
return output
},
diff --git a/loop/agent.go b/loop/agent.go
index 0877a5a..0953a56 100644
--- a/loop/agent.go
+++ b/loop/agent.go
@@ -1390,6 +1390,9 @@
Timeouts: a.config.BashTimeouts,
Pwd: a.workingDir,
}
+ patchTool := &claudetool.PatchTool{
+ Callback: a.patchCallback,
+ }
// Register all tools with the conversation
// When adding, removing, or modifying tools here, double-check that the termui tool display
@@ -1411,7 +1414,7 @@
convo.Tools = []*llm.Tool{
bashTool.Tool(),
claudetool.Keyword,
- claudetool.Patch(a.patchCallback),
+ patchTool.Tool(),
claudetool.Think,
claudetool.TodoRead,
claudetool.TodoWrite,