Initial commit
diff --git a/claudetool/bash.go b/claudetool/bash.go
new file mode 100644
index 0000000..d76d7f1
--- /dev/null
+++ b/claudetool/bash.go
@@ -0,0 +1,163 @@
+package claudetool
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "math"
+ "os/exec"
+ "strings"
+ "syscall"
+ "time"
+
+ "sketch.dev/ant"
+ "sketch.dev/claudetool/bashkit"
+)
+
+// Bash is the tool that runs a model-supplied shell script via `bash -c`,
+// optionally bounded by a caller-chosen timeout. See BashRun for execution details.
+var Bash = &ant.Tool{
+	Name:        bashName,
+	Run:         BashRun,
+	InputSchema: ant.MustSchema(bashInputSchema),
+	Description: strings.TrimSpace(bashDescription),
+}
+
+// Static metadata for the Bash tool.
+const (
+ // bashName is the tool name the model uses to invoke Bash.
+ bashName = "bash"
+ // bashDescription is used, whitespace-trimmed, as the Bash tool's Description.
+ bashDescription = `
+Executes a shell command using bash -c with an optional timeout, returning combined stdout and stderr.
+
+Executables pre-installed in this environment include:
+- standard unix tools
+- go
+- git
+- rg
+- jq
+- gopls
+- sqlite
+- fzf
+- gh
+- python3
+`
+ // If you modify this, update the termui template for prettier rendering.
+ // bashInputSchema is the JSON Schema for bashInput; the '1m' default it
+ // advertises is the fallback returned by bashInput.timeout.
+ bashInputSchema = `
+{
+ "type": "object",
+ "required": ["command"],
+ "properties": {
+ "command": {
+ "type": "string",
+ "description": "Shell script to execute"
+ },
+ "timeout": {
+ "type": "string",
+ "description": "Timeout as a Go duration string, defaults to '1m'"
+ }
+ }
+}
+`
+)
+
+// bashInput is the decoded input to the Bash tool.
+type bashInput struct {
+	Command string `json:"command"`           // script passed to `bash -c`
+	Timeout string `json:"timeout,omitempty"` // Go duration string, e.g. "30s"
+}
+
+// timeout returns how long the command may run.
+//
+// It falls back to one minute when Timeout is empty, unparseable, or not
+// positive. A zero or negative duration would give context.WithTimeout an
+// already-expired deadline, killing the command before it could do anything.
+func (i *bashInput) timeout() time.Duration {
+	dur, err := time.ParseDuration(i.Timeout)
+	if err != nil || dur <= 0 {
+		return 1 * time.Minute
+	}
+	return dur
+}
+
+// BashRun is the Run function of the Bash tool.
+// It decodes m into a bashInput, applies bashkit.Check, and then executes the
+// command. On any failure it returns an empty string and the error.
+func BashRun(ctx context.Context, m json.RawMessage) (string, error) {
+	var req bashInput
+	if err := json.Unmarshal(m, &req); err != nil {
+		return "", fmt.Errorf("failed to unmarshal bash command input: %w", err)
+	}
+
+	// Quick permissions check. This is NOT a security barrier; it only catches
+	// commands the model was told not to run.
+	if err := bashkit.Check(req.Command); err != nil {
+		return "", err
+	}
+
+	out, err := executeBash(ctx, req)
+	if err != nil {
+		return "", err
+	}
+	return out, nil
+}
+
+// maxBashOutputLength is the largest combined stdout+stderr size, in bytes,
+// that executeBash returns verbatim. Longer output becomes an error that
+// quotes only the first 1024 bytes.
+const maxBashOutputLength = 131072
+
+// executeBash runs req.Command with `bash -c` in WorkingDir(ctx), capturing
+// stdout and stderr into a single buffer in write order.
+//
+// bash is started in its own process group. Whenever the execution context
+// ends, whether by req.timeout() expiring or by ctx being cancelled, the whole
+// group is SIGKILLed. Killing only bash (the os/exec default) would leave any
+// children it spawned (e.g. the `sleep` in `sleep 600 && echo hi`) holding
+// the output pipe open, and Wait would block until they exited on their own.
+//
+// On success it returns the output. Otherwise it returns "" and an error:
+// "command timed out after <d>" on timeout, "command failed: ..." (with the
+// output appended) on a failed run, or an "output too long" error when the
+// output exceeds maxBashOutputLength.
+func executeBash(ctx context.Context, req bashInput) (string, error) {
+	timeout := req.timeout()
+	execCtx, cancel := context.WithTimeout(ctx, timeout)
+	defer cancel()
+
+	cmd := exec.CommandContext(execCtx, "bash", "-c", req.Command)
+	cmd.Dir = WorkingDir(ctx)
+	// New process group with pgid == bash's pid, so a negative-pid kill reaches
+	// every process bash starts (unless one deliberately leaves the group).
+	cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
+	// os/exec calls Cancel once execCtx is done, for any reason.
+	cmd.Cancel = func() error {
+		// An ESRCH (group already gone) is harmless, so the result is ignored.
+		_ = syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
+		return nil
+	}
+
+	var output bytes.Buffer
+	cmd.Stdin = nil
+	cmd.Stdout = &output
+	cmd.Stderr = &output
+	err := cmd.Run()
+
+	// Only call it a timeout if the command did not succeed. A command that
+	// exited cleanly right as the deadline passed still counts as a success.
+	if err != nil && execCtx.Err() == context.DeadlineExceeded {
+		return "", fmt.Errorf("command timed out after %s", timeout)
+	}
+
+	longOutput := output.Len() > maxBashOutputLength
+	outstr := output.String()
+	if longOutput {
+		outstr = fmt.Sprintf("output too long: got %v, max is %v\ninitial bytes of output:\n%s",
+			humanizeBytes(output.Len()), humanizeBytes(maxBashOutputLength),
+			output.Bytes()[:1024],
+		)
+	}
+
+	if err != nil {
+		return "", fmt.Errorf("command failed: %w\n%s", err, outstr)
+	}
+	if longOutput {
+		return "", fmt.Errorf("%s", outstr)
+	}
+	return outstr, nil
+}
+
+func humanizeBytes(bytes int) string {
+ switch {
+ case bytes < 4*1024:
+ return fmt.Sprintf("%dB", bytes)
+ case bytes < 1024*1024:
+ kb := int(math.Round(float64(bytes) / 1024.0))
+ return fmt.Sprintf("%dkB", kb)
+ case bytes < 1024*1024*1024:
+ mb := int(math.Round(float64(bytes) / (1024.0 * 1024.0)))
+ return fmt.Sprintf("%dMB", mb)
+ }
+ return "more than 1GB"
+}
diff --git a/claudetool/bash_test.go b/claudetool/bash_test.go
new file mode 100644
index 0000000..8fe4b5c
--- /dev/null
+++ b/claudetool/bash_test.go
@@ -0,0 +1,186 @@
+package claudetool
+
+import (
+ "context"
+ "encoding/json"
+ "strings"
+ "testing"
+ "time"
+)
+
+// TestBashRun drives BashRun end to end with JSON input, covering success,
+// an explicit timeout, a timeout that fires, a failing command, and bad input.
+func TestBashRun(t *testing.T) {
+ // A simple command: stdout is returned verbatim, including the trailing newline.
+ t.Run("Basic Command", func(t *testing.T) {
+ input := json.RawMessage(`{"command":"echo 'Hello, world!'"}`)
+
+ result, err := BashRun(context.Background(), input)
+ if err != nil {
+ t.Fatalf("Unexpected error: %v", err)
+ }
+
+ expected := "Hello, world!\n"
+ if result != expected {
+ t.Errorf("Expected %q, got %q", expected, result)
+ }
+ })
+
+ // A compound command (&&): the two `echo -n` outputs are concatenated with no newline.
+ t.Run("Command With Arguments", func(t *testing.T) {
+ input := json.RawMessage(`{"command":"echo -n foo && echo -n bar"}`)
+
+ result, err := BashRun(context.Background(), input)
+ if err != nil {
+ t.Fatalf("Unexpected error: %v", err)
+ }
+
+ expected := "foobar"
+ if result != expected {
+ t.Errorf("Expected %q, got %q", expected, result)
+ }
+ })
+
+ // An explicit timeout longer than the command's runtime lets it finish normally.
+ t.Run("With Timeout", func(t *testing.T) {
+ inputObj := struct {
+ Command string `json:"command"`
+ Timeout string `json:"timeout"`
+ }{
+ Command: "sleep 0.1 && echo 'Completed'",
+ Timeout: "5s",
+ }
+ inputJSON, err := json.Marshal(inputObj)
+ if err != nil {
+ t.Fatalf("Failed to marshal input: %v", err)
+ }
+
+ result, err := BashRun(context.Background(), inputJSON)
+ if err != nil {
+ t.Fatalf("Unexpected error: %v", err)
+ }
+
+ expected := "Completed\n"
+ if result != expected {
+ t.Errorf("Expected %q, got %q", expected, result)
+ }
+ })
+
+ // A command that outlives its timeout must fail with an error mentioning "timed out".
+ t.Run("Command Timeout", func(t *testing.T) {
+ inputObj := struct {
+ Command string `json:"command"`
+ Timeout string `json:"timeout"`
+ }{
+ Command: "sleep 0.5 && echo 'Should not see this'",
+ Timeout: "100ms",
+ }
+ inputJSON, err := json.Marshal(inputObj)
+ if err != nil {
+ t.Fatalf("Failed to marshal input: %v", err)
+ }
+
+ _, err = BashRun(context.Background(), inputJSON)
+ if err == nil {
+ t.Errorf("Expected timeout error, got none")
+ } else if !strings.Contains(err.Error(), "timed out") {
+ t.Errorf("Expected timeout error, got: %v", err)
+ }
+ })
+
+ // A non-zero exit status is reported as an error.
+ t.Run("Failed Command", func(t *testing.T) {
+ input := json.RawMessage(`{"command":"exit 1"}`)
+
+ _, err := BashRun(context.Background(), input)
+ if err == nil {
+ t.Errorf("Expected error for failed command, got none")
+ }
+ })
+
+ // Input that does not unmarshal into bashInput is rejected before anything runs.
+ t.Run("Invalid JSON Input", func(t *testing.T) {
+ input := json.RawMessage(`{"command":123}`) // well-formed JSON, but "command" must be a string
+
+ _, err := BashRun(context.Background(), input)
+ if err == nil {
+ t.Errorf("Expected error for invalid input, got none")
+ }
+ })
+}
+
+// TestExecuteBash calls executeBash directly with pre-decoded bashInput values,
+// checking output capture (stdout and stderr), failure reporting, and timeouts.
+func TestExecuteBash(t *testing.T) {
+ ctx := context.Background()
+
+ // A successful command returns what it wrote to stdout.
+ t.Run("Successful Command", func(t *testing.T) {
+ req := bashInput{
+ Command: "echo 'Success'",
+ Timeout: "5s",
+ }
+
+ output, err := executeBash(ctx, req)
+ if err != nil {
+ t.Fatalf("Unexpected error: %v", err)
+ }
+
+ want := "Success\n"
+ if output != want {
+ t.Errorf("Expected %q, got %q", want, output)
+ }
+ })
+
+ // stdout and stderr share one buffer, so both appear in the output in the order written.
+ t.Run("Command with stderr", func(t *testing.T) {
+ req := bashInput{
+ Command: "echo 'Error message' >&2 && echo 'Success'",
+ Timeout: "5s",
+ }
+
+ output, err := executeBash(ctx, req)
+ if err != nil {
+ t.Fatalf("Unexpected error: %v", err)
+ }
+
+ want := "Error message\nSuccess\n"
+ if output != want {
+ t.Errorf("Expected %q, got %q", want, output)
+ }
+ })
+
+ // On failure, the captured stderr is included in the returned error.
+ t.Run("Failed Command with stderr", func(t *testing.T) {
+ req := bashInput{
+ Command: "echo 'Error message' >&2 && exit 1",
+ Timeout: "5s",
+ }
+
+ _, err := executeBash(ctx, req)
+ if err == nil {
+ t.Errorf("Expected error for failed command, got none")
+ } else if !strings.Contains(err.Error(), "Error message") {
+ t.Errorf("Expected stderr in error message, got: %v", err)
+ }
+ })
+
+ // The timeout must cut the command short. The elapsed-time bound also catches a
+ // Wait that keeps blocking until the child `sleep 1` finishes on its own.
+ t.Run("Command Timeout", func(t *testing.T) {
+ req := bashInput{
+ Command: "sleep 1 && echo 'Should not see this'",
+ Timeout: "100ms",
+ }
+
+ start := time.Now()
+ _, err := executeBash(ctx, req)
+ elapsed := time.Since(start)
+
+ // Command should time out after ~100ms, not wait for full 1 second
+ if elapsed >= 1*time.Second {
+ t.Errorf("Command did not respect timeout, took %v", elapsed)
+ }
+
+ if err == nil {
+ t.Errorf("Expected timeout error, got none")
+ } else if !strings.Contains(err.Error(), "timed out") {
+ t.Errorf("Expected timeout error, got: %v", err)
+ }
+ })
+}
diff --git a/claudetool/bashkit/bashkit.go b/claudetool/bashkit/bashkit.go
new file mode 100644
index 0000000..a56eef0
--- /dev/null
+++ b/claudetool/bashkit/bashkit.go
@@ -0,0 +1,97 @@
+package bashkit
+
+import (
+ "fmt"
+ "strings"
+
+ "mvdan.cc/sh/v3/syntax"
+)
+
+// checks are applied, in order, to every command call (syntax.CallExpr) in a
+// script. The first non-nil error ends the scan and is returned by Check.
+var checks = []func(*syntax.CallExpr) error{
+	noGitConfigUsernameEmailChanges,
+}
+
+// Check inspects bashScript and returns an error if it ought not be executed.
+// Check DOES NOT PROVIDE SECURITY against malicious actors.
+// It is intended to catch straightforward mistakes in which a model
+// does things despite having been instructed not to do them.
+//
+// Scripts that fail to parse are let through (nil error).
+func Check(bashScript string) error {
+	file, parseErr := syntax.NewParser().Parse(strings.NewReader(bashScript), "")
+	if parseErr != nil {
+		// Execution will fail, but bash will give a better error message than we would.
+		// If this were security load bearing, this would be a terrible idea: stuff could
+		// be smuggled past Check by exploiting differences in what each parser considers
+		// syntactically valid. But it is not.
+		return nil
+	}
+
+	var violation error
+	syntax.Walk(file, func(node syntax.Node) bool {
+		if violation != nil {
+			return false // a problem was already found; stop exploring
+		}
+		call, isCall := node.(*syntax.CallExpr)
+		if !isCall {
+			return true
+		}
+		for _, check := range checks {
+			if violation = check(call); violation != nil {
+				return false
+			}
+		}
+		return true
+	})
+	return violation
+}
+
+// noGitConfigUsernameEmailChanges rejects a command that appears to set git's
+// user.name or user.email. Detection is heuristic (see
+// hasGitConfigUsernameEmailChanges) and can be wrong in either direction.
+func noGitConfigUsernameEmailChanges(cmd *syntax.CallExpr) error {
+	if !hasGitConfigUsernameEmailChanges(cmd) {
+		return nil
+	}
+	return fmt.Errorf("permission denied: changing git config username/email is not allowed, use env vars instead")
+}
+
+// hasGitConfigUsernameEmailChanges reports whether cmd looks like a git command
+// that writes user.name or user.email, e.g. `git [-C dir] config [--global] user.name VALUE`.
+//
+// The heuristic:
+//   - the first word must be literally "git";
+//   - some later word must be literally "config";
+//   - after it, some word must be user.name or user.email;
+//   - and at least one more argument (the value) must follow.
+//
+// A key with nothing after it (`git config user.name`) is a read and is allowed.
+// Keys are compared case-insensitively because git treats config section and
+// variable names that way: `git config User.Name x` sets the same value as
+// `git config user.name x`.
+// Words that are not plain literals (quoted strings, $VARS) never match
+// "git", "config", or the key, but can still serve as the value.
+func hasGitConfigUsernameEmailChanges(cmd *syntax.CallExpr) bool {
+	args := cmd.Args
+	if len(args) < 3 || args[0].Lit() != "git" {
+		return false
+	}
+
+	configIndex := -1
+	for i, arg := range args {
+		if arg.Lit() == "config" {
+			configIndex = i
+			break
+		}
+	}
+	if configIndex < 0 {
+		return false
+	}
+
+	for i := configIndex + 1; i < len(args); i++ {
+		key := args[i].Lit()
+		if strings.EqualFold(key, "user.name") || strings.EqualFold(key, "user.email") {
+			// Only the first matching key counts; it is a write iff a value follows it.
+			return i < len(args)-1
+		}
+	}
+	return false
+}
diff --git a/claudetool/bashkit/bashkit_test.go b/claudetool/bashkit/bashkit_test.go
new file mode 100644
index 0000000..8bcdd8f
--- /dev/null
+++ b/claudetool/bashkit/bashkit_test.go
@@ -0,0 +1,109 @@
+package bashkit
+
+import (
+ "strings"
+ "testing"
+)
+
+func TestCheck(t *testing.T) {
+ tests := []struct {
+ name string
+ script string
+ wantErr bool
+ errMatch string // string to match in error message, if wantErr is true
+ }{
+ {
+ name: "valid script",
+ script: "echo hello world",
+ wantErr: false,
+ errMatch: "",
+ },
+ {
+ name: "invalid syntax",
+ script: "echo 'unterminated string",
+ wantErr: false, // As per implementation, syntax errors are not flagged
+ errMatch: "",
+ },
+ {
+ name: "git config user.name",
+ script: "git config user.name 'John Doe'",
+ wantErr: true,
+ errMatch: "changing git config username/email is not allowed",
+ },
+ {
+ name: "git config user.email",
+ script: "git config user.email 'john@example.com'",
+ wantErr: true,
+ errMatch: "changing git config username/email is not allowed",
+ },
+ {
+ name: "git config with flag user.name",
+ script: "git config --global user.name 'John Doe'",
+ wantErr: true,
+ errMatch: "changing git config username/email is not allowed",
+ },
+ {
+ name: "git config with other setting",
+ script: "git config core.editor vim",
+ wantErr: false,
+ errMatch: "",
+ },
+ {
+ name: "git without config",
+ script: "git commit -m 'Add feature'",
+ wantErr: false,
+ errMatch: "",
+ },
+ {
+ name: "multiline script with proper escaped newlines",
+ script: "echo 'Setting up git...' && git config user.name 'John Doe' && echo 'Done!'",
+ wantErr: true,
+ errMatch: "changing git config username/email is not allowed",
+ },
+ {
+ name: "multiline script with backticks",
+ script: `echo 'Setting up git...'
+git config user.name 'John Doe'
+echo 'Done!'`,
+ wantErr: true,
+ errMatch: "changing git config username/email is not allowed",
+ },
+ {
+ name: "git config with variable",
+ script: "NAME='John Doe'\ngit config user.name $NAME",
+ wantErr: true,
+ errMatch: "changing git config username/email is not allowed",
+ },
+ {
+ name: "only git command",
+ script: "git",
+ wantErr: false,
+ errMatch: "",
+ },
+ {
+ name: "read git config",
+ script: "git config user.name",
+ wantErr: false,
+ errMatch: "",
+ },
+ {
+ name: "commented git config",
+ script: "# git config user.name 'John Doe'",
+ wantErr: false,
+ errMatch: "",
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ err := Check(tc.script)
+ if (err != nil) != tc.wantErr {
+ t.Errorf("Check() error = %v, wantErr %v", err, tc.wantErr)
+ return
+ }
+ if tc.wantErr && err != nil && !strings.Contains(err.Error(), tc.errMatch) {
+ t.Errorf("Check() error message = %v, want containing %v", err, tc.errMatch)
+ }
+ })
+ }
+}
diff --git a/claudetool/codereview.go b/claudetool/codereview.go
new file mode 100644
index 0000000..f305b4b
--- /dev/null
+++ b/claudetool/codereview.go
@@ -0,0 +1,347 @@
+package claudetool
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "log/slog"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strings"
+)
+
+// A CodeReviewer manages quality checks.
+// It records the repository root and the commit that review starts from, plus
+// the files already dirty at construction time so later checks can ignore them.
+type CodeReviewer struct {
+ repoRoot string // top-level directory of the git repo (verified by NewCodeReviewer)
+ initialCommit string // commit that changes are reviewed relative to
+ initialStatus []fileStatus // git status of files at initial commit, absolute paths
+ reviewed []string // history of all commits which have been reviewed
+ initialWorktree string // git worktree at initial commit, absolute path
+}
+
+// NewCodeReviewer returns a CodeReviewer for the git repository whose top-level
+// directory is repoRoot, reviewing work committed on top of initialCommit.
+//
+// repoRoot must be exactly the repo root as reported by findRepoRoot, not a
+// subdirectory. Files that are already dirty or untracked at this point are
+// recorded, so later cleanliness checks can ignore them.
+func NewCodeReviewer(ctx context.Context, repoRoot, initialCommit string) (*CodeReviewer, error) {
+	switch {
+	case repoRoot == "":
+		return nil, fmt.Errorf("NewCodeReviewer: repoRoot must be non-empty")
+	case initialCommit == "":
+		return nil, fmt.Errorf("NewCodeReviewer: initialCommit must be non-empty")
+	}
+
+	gitRoot, err := findRepoRoot(repoRoot)
+	if err != nil {
+		return nil, err
+	}
+	if gitRoot != repoRoot {
+		return nil, fmt.Errorf("NewCodeReviewer: repoRoot=%q but git repo root is %q", repoRoot, gitRoot)
+	}
+
+	r := &CodeReviewer{repoRoot: repoRoot, initialCommit: initialCommit}
+	// Snapshot pre-existing dirty/untracked files; they are filtered out later
+	// when deciding whether the worktree is clean.
+	initial, err := r.repoStatus(ctx)
+	if err != nil {
+		return nil, err
+	}
+	r.initialStatus = initial
+	return r, nil
+}
+
+// Autoformat formats the Go files that changed between the initial commit and HEAD.
+// It returns the absolute paths of the files it rewrote on disk.
+// It is best-effort only: problems are logged and the affected file is skipped.
+//
+// A file is considered only if it is a changed, non-deleted, non-generated .go
+// file whose on-disk content still matches HEAD. Files added since the initial
+// commit are formatted with gofumpt. Any other file gets the strictest formatter
+// that leaves its content at HEAD's parent unchanged (see pickFormatter), so its
+// existing style is not churned.
+func (r *CodeReviewer) Autoformat(ctx context.Context) []string {
+	head, err := r.CurrentCommit(ctx)
+	if err != nil {
+		slog.WarnContext(ctx, "CodeReviewer.Autoformat unable to get current commit", "err", err)
+		return nil
+	}
+	// Refuse to format when nothing has been committed since the initial commit.
+	// This check comes before the parent lookup, so a root initial commit (no
+	// parent) reports this reason instead of a spurious parent-lookup failure.
+	if head == r.initialCommit {
+		slog.WarnContext(ctx, "CodeReviewer.Autoformat refusing to format because HEAD == InitialCommit")
+		return nil
+	}
+	// Resolve the parent of the commit we just read, not "HEAD^1", so both refer
+	// to the same snapshot even if HEAD moves meanwhile.
+	parent, err := r.ResolveCommit(ctx, head+"^1")
+	if err != nil {
+		slog.WarnContext(ctx, "CodeReviewer.Autoformat unable to get parent commit", "err", err)
+		return nil
+	}
+	// TODO: instead of one git diff --name-only and then N --name-status, do one --name-status.
+	changedFiles, err := r.changedFiles(ctx, r.initialCommit, head)
+	if err != nil {
+		slog.WarnContext(ctx, "CodeReviewer.Autoformat unable to get changed files", "err", err)
+		return nil
+	}
+
+	// TODO: add non-Go formatters?
+	// TODO: at a minimum, for common file types, ensure trailing newlines and maybe trim trailing whitespace per line?
+	var fmtFiles []string
+	for _, file := range changedFiles {
+		if !strings.HasSuffix(file, ".go") {
+			continue
+		}
+		status, err := r.gitFileStatus(ctx, file)
+		if err != nil {
+			slog.WarnContext(ctx, "CodeReviewer.Autoformat unable to get file status", "file", file, "err", err)
+			continue
+		}
+		if status == "D" { // deleted, nothing to format
+			continue
+		}
+		code, err := r.getFileContentAtCommit(ctx, file, head)
+		if err != nil {
+			slog.WarnContext(ctx, "CodeReviewer.Autoformat unable to get file content at head", "file", file, "err", err)
+			continue
+		}
+		if isAutogeneratedGoFile(code) { // leave autogenerated files alone
+			continue
+		}
+		onDisk, err := os.ReadFile(file)
+		if err != nil {
+			slog.WarnContext(ctx, "CodeReviewer.Autoformat unable to read file", "file", file, "err", err)
+			continue
+		}
+		if !bytes.Equal(code, onDisk) { // uncommitted edits; don't clobber them
+			slog.WarnContext(ctx, "CodeReviewer.Autoformat file modified since HEAD", "file", file)
+			continue
+		}
+
+		var formatterToUse string
+		if status == "A" {
+			formatterToUse = "gofumpt" // newly added, so we can format how we please: use gofumpt
+		} else {
+			prev, err := r.getFileContentAtCommit(ctx, file, parent)
+			if err != nil {
+				slog.WarnContext(ctx, "CodeReviewer.Autoformat unable to get file content at parent", "file", file, "err", err)
+				continue
+			}
+			formatterToUse = r.pickFormatter(ctx, prev) // strictest formatter that passes on the previous version
+		}
+
+		newCode := r.runFormatter(ctx, formatterToUse, code)
+		if newCode == nil { // no changes, or the formatter failed
+			continue
+		}
+		// The file already exists, so this perm only matters if it were recreated.
+		if err := os.WriteFile(file, newCode, 0o600); err != nil {
+			slog.WarnContext(ctx, "CodeReviewer.Autoformat unable to write formatted file", "file", file, "err", err)
+			continue
+		}
+		fmtFiles = append(fmtFiles, file)
+	}
+	return fmtFiles
+}
+
+// RequireNormalGitState checks that the git repo state is pretty normal.
+func (r *CodeReviewer) RequireNormalGitState(_ context.Context) error {
+ rebaseDirs := []string{"rebase-merge", "rebase-apply"}
+ for _, dir := range rebaseDirs {
+ _, err := os.Stat(filepath.Join(r.repoRoot, dir))
+ if err == nil {
+ return fmt.Errorf("git repo is not clean: rebase in progress")
+ }
+ }
+ filesReason := map[string]string{
+ "MERGE_HEAD": "merge is in progress",
+ "CHERRY_PICK_HEAD": "cherry-pick is in progress",
+ "REVERT_HEAD": "revert is in progress",
+ "BISECT_LOG": "bisect is in progress",
+ }
+ for file, reason := range filesReason {
+ _, err := os.Stat(filepath.Join(r.repoRoot, file))
+ if err == nil {
+ return fmt.Errorf("git repo is not clean: %s", reason)
+ }
+ }
+ return nil
+}
+
+func (r *CodeReviewer) RequireNoUncommittedChanges(ctx context.Context) error {
+ // Check that there are no uncommitted changes, whether staged or not.
+ // (Changes in r.initialStatus are OK, no other changes are.)
+ statuses, err := r.repoStatus(ctx)
+ if err != nil {
+ return fmt.Errorf("unable to get repo status: %w", err)
+ }
+ uncommitted := new(strings.Builder)
+ for _, status := range statuses {
+ if !r.initialStatusesContainFile(status.Path) {
+ fmt.Fprintf(uncommitted, "%s %s\n", status.Path, status.RawStatus)
+ }
+ }
+ if uncommitted.Len() > 0 {
+ return fmt.Errorf("uncommitted changes in repo, please commit or revert:\n%s", uncommitted.String())
+ }
+ return nil
+}
+
+// initialStatusesContainFile reports whether file (compared exactly against the
+// stored absolute paths) was already dirty or untracked at construction time.
+func (r *CodeReviewer) initialStatusesContainFile(file string) bool {
+	for i := range r.initialStatus {
+		if r.initialStatus[i].Path == file {
+			return true
+		}
+	}
+	return false
+}
+
+// fileStatus is one entry of `git status --porcelain` output, as parsed by repoStatus.
+type fileStatus struct {
+ Path string // repoStatus stores an absolute path here (via absPath)
+ RawStatus string // always 2 characters
+}
+
+// repoStatus returns every dirty or untracked path in the repo, as reported by
+// `git status --porcelain -z`, with paths converted to absolute form.
+//
+// The -z form is used because it reports paths verbatim. The plain form quotes
+// paths that contain special characters and prints renames as "old -> new".
+// Also, splitting it with strings.Lines kept each trailing "\n" in the path,
+// so stored paths never matched those produced elsewhere. Only stdout is
+// parsed, so git warnings on stderr cannot be mistaken for entries.
+func (r *CodeReviewer) repoStatus(ctx context.Context) ([]fileStatus, error) {
+	cmd := exec.CommandContext(ctx, "git", "status", "--porcelain", "-z")
+	cmd.Dir = r.repoRoot
+	var stderr bytes.Buffer
+	cmd.Stderr = &stderr
+	out, err := cmd.Output()
+	if err != nil {
+		return nil, fmt.Errorf("failed to run git status: %w\n%s", err, stderr.Bytes())
+	}
+
+	var statuses []fileStatus
+	entries := strings.Split(string(out), "\x00")
+	for i := 0; i < len(entries); i++ {
+		entry := entries[i]
+		if entry == "" {
+			continue
+		}
+		// Each entry is "XY PATH": two status characters, a space, then the path.
+		if len(entry) < 4 {
+			return nil, fmt.Errorf("invalid status line: %q", entry)
+		}
+		status := entry[:2]
+		path := entry[3:]
+		// For renames and copies, -z emits the source path as a separate
+		// NUL-terminated field right after the entry; skip it.
+		if strings.ContainsAny(status, "RC") {
+			i++
+		}
+		statuses = append(statuses, fileStatus{Path: r.absPath(path), RawStatus: status})
+	}
+	return statuses, nil
+}
+
+// CurrentCommit retrieves the current git commit hash
+func (r *CodeReviewer) CurrentCommit(ctx context.Context) (string, error) {
+ return r.ResolveCommit(ctx, "HEAD")
+}
+
+func (r *CodeReviewer) ResolveCommit(ctx context.Context, ref string) (string, error) {
+ cmd := exec.CommandContext(ctx, "git", "rev-parse", ref)
+ cmd.Dir = r.repoRoot
+ out, err := cmd.CombinedOutput()
+ if err != nil {
+ return "", fmt.Errorf("failed to get current commit hash: %w\n%s", err, out)
+ }
+ return strings.TrimSpace(string(out)), nil
+}
+
+// absPath turns a path relative to the repo root into a cleaned absolute path.
+func (r *CodeReviewer) absPath(relPath string) string {
+	joined := filepath.Join(r.repoRoot, relPath)
+	return filepath.Clean(joined)
+}
+
+// gitFileStatus returns the status of a file (A for added, M for modified, D for deleted, etc.)
+func (r *CodeReviewer) gitFileStatus(ctx context.Context, file string) (string, error) {
+ cmd := exec.CommandContext(ctx, "git", "diff", "--name-status", r.initialCommit, "HEAD", "--", file)
+ cmd.Dir = r.repoRoot
+ out, err := cmd.CombinedOutput()
+ if err != nil {
+ return "", fmt.Errorf("failed to get file status: %w\n%s", err, out)
+ }
+ status := strings.TrimSpace(string(out))
+ if status == "" {
+ return "", fmt.Errorf("no status found for file: %s", file)
+ }
+ return string(status[0]), nil
+}
+
+// getFileContentAtCommit retrieves file content at a specific commit
+func (r *CodeReviewer) getFileContentAtCommit(ctx context.Context, file, commit string) ([]byte, error) {
+ relFile, err := filepath.Rel(r.repoRoot, file)
+ if err != nil {
+ slog.WarnContext(ctx, "CodeReviewer.getFileContentAtCommit: failed to get relative path", "repo_root", r.repoRoot, "file", file, "err", err)
+ file = relFile
+ }
+ cmd := exec.CommandContext(ctx, "git", "show", fmt.Sprintf("%s:%s", commit, relFile))
+ cmd.Dir = r.repoRoot
+ out, err := cmd.CombinedOutput()
+ if err != nil {
+ return nil, fmt.Errorf("failed to get file content at commit %s: %w\n%s", commit, err, out)
+ }
+ return out, nil
+}
+
+// runFormatter pipes content through formatter (run in the repo root with no
+// arguments, so it reads stdin) and returns the formatted result.
+// A nil result means no formatter was given, the content is already formatted,
+// or the formatter failed. In every such case the caller should leave the file alone.
+//
+// Only stdout is used as the result. Anything the formatter prints to stderr
+// must never end up in the source file the caller writes back to disk.
+func (r *CodeReviewer) runFormatter(ctx context.Context, formatter string, content []byte) []byte {
+	if formatter == "" {
+		return nil // no formatter
+	}
+	cmd := exec.CommandContext(ctx, formatter)
+	cmd.Dir = r.repoRoot
+	cmd.Stdin = bytes.NewReader(content)
+	out, err := cmd.Output()
+	if err != nil {
+		// probably a parse error, err on the side of safety
+		return nil
+	}
+	// Never replace non-empty source with nothing, whatever the exit status said.
+	if len(out) == 0 || bytes.Equal(content, out) {
+		return nil
+	}
+	return out
+}
+
+// formatterWouldChange reports whether formatter would modify content.
+// It pipes content to `formatter -l`, which lists inputs that are not already
+// formatted ("<standard input>" here) and prints nothing otherwise.
+// If the formatter fails (e.g. content does not parse), it reports false.
+func (r *CodeReviewer) formatterWouldChange(ctx context.Context, formatter string, content []byte) bool {
+	cmd := exec.CommandContext(ctx, formatter, "-l")
+	cmd.Stdin = bytes.NewReader(content)
+	cmd.Dir = r.repoRoot
+	listing, err := cmd.CombinedOutput()
+	return err == nil && len(bytes.TrimSpace(listing)) > 0
+}
+
+// pickFormatter picks the strictest Go formatter (gofumpt, then goimports,
+// then gofmt) that would leave code unchanged.
+// If something goes wrong, or none qualifies, it recommends no formatter (empty string).
+//
+// formatterWouldChange reports false both when a formatter accepts code as-is
+// and when it cannot run at all. Formatters missing from PATH are therefore
+// skipped explicitly. Otherwise a missing gofumpt would be "picked", and the
+// later formatting run would silently do nothing.
+func (r *CodeReviewer) pickFormatter(ctx context.Context, code []byte) string {
+	formatters := []string{"gofumpt", "goimports", "gofmt"}
+	for _, formatter := range formatters {
+		if _, err := exec.LookPath(formatter); err != nil {
+			continue // not installed
+		}
+		if r.formatterWouldChange(ctx, formatter, code) {
+			continue
+		}
+		return formatter
+	}
+	return "" // no safe formatter found
+}
+
+// changedFiles returns the absolute paths of all files changed between
+// fromCommit and toCommit, excluding files that were already dirty when the
+// CodeReviewer was created.
+//
+// -z makes git emit raw, NUL-separated paths. Without it, git quotes paths
+// with unusual characters (see core.quotePath), and such paths would not match
+// files on disk. The trailing "--" forces both arguments to be read as revisions.
+func (r *CodeReviewer) changedFiles(ctx context.Context, fromCommit, toCommit string) ([]string, error) {
+	cmd := exec.CommandContext(ctx, "git", "diff", "--name-only", "-z", fromCommit, toCommit, "--")
+	cmd.Dir = r.repoRoot
+	var stderr bytes.Buffer
+	cmd.Stderr = &stderr
+	out, err := cmd.Output()
+	if err != nil {
+		return nil, fmt.Errorf("failed to get changed files: %w\n%s", err, stderr.Bytes())
+	}
+	var files []string
+	for _, name := range strings.Split(string(out), "\x00") {
+		if name == "" {
+			continue
+		}
+		path := r.absPath(name)
+		if r.initialStatusesContainFile(path) {
+			continue
+		}
+		files = append(files, path)
+	}
+	return files, nil
+}
diff --git a/claudetool/differential.go b/claudetool/differential.go
new file mode 100644
index 0000000..14dce04
--- /dev/null
+++ b/claudetool/differential.go
@@ -0,0 +1,1125 @@
+package claudetool
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log/slog"
+ "maps"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "slices"
+ "strings"
+ "time"
+
+ "golang.org/x/tools/go/packages"
+ "sketch.dev/ant"
+)
+
+// This file does differential quality analysis of a commit relative to a base commit.
+
+// Tool returns a tool spec for a CodeReview tool backed by r.
+func (r *CodeReviewer) Tool() *ant.Tool {
+ spec := &ant.Tool{
+ Name: "codereview",
+ Description: `Run an automated code review.`,
+ // If you modify this, update the termui template for prettier rendering.
+ InputSchema: ant.MustSchema(`{"type": "object"}`),
+ Run: r.Run,
+ }
+ return spec
+}
+
+func (r *CodeReviewer) Run(ctx context.Context, m json.RawMessage) (string, error) {
+ if err := r.RequireNormalGitState(ctx); err != nil {
+ slog.DebugContext(ctx, "CodeReviewer.Run: failed to check for normal git state", "err", err)
+ return "", err
+ }
+ if err := r.RequireNoUncommittedChanges(ctx); err != nil {
+ slog.DebugContext(ctx, "CodeReviewer.Run: failed to check for uncommitted changes", "err", err)
+ return "", err
+ }
+
+ // Check that the current commit is not the initial commit
+ currentCommit, err := r.CurrentCommit(ctx)
+ if err != nil {
+ slog.DebugContext(ctx, "CodeReviewer.Run: failed to get current commit", "err", err)
+ return "", err
+ }
+ if r.IsInitialCommit(currentCommit) {
+ slog.DebugContext(ctx, "CodeReviewer.Run: current commit is initial commit, nothing to review")
+ return "", fmt.Errorf("no new commits have been added, nothing to review")
+ }
+
+ // No matter what failures happen from here out, we will declare this to have been reviewed.
+ // This should help avoid the model getting blocked by a broken code review tool.
+ r.reviewed = append(r.reviewed, currentCommit)
+
+ changedFiles, err := r.changedFiles(ctx, r.initialCommit, currentCommit)
+ if err != nil {
+ slog.DebugContext(ctx, "CodeReviewer.Run: failed to get changed files", "err", err)
+ return "", err
+ }
+
+ // Prepare to analyze before/after for the impacted files.
+ // We use the current commit to determine what packages exist and are impacted.
+ // The packages in the initial commit may be different.
+ // Good enough for now.
+ // TODO: do better
+ directPkgs, allPkgs, err := r.packagesForFiles(ctx, changedFiles)
+ if err != nil {
+ // TODO: log and skip to stuff that doesn't require packages
+ slog.DebugContext(ctx, "CodeReviewer.Run: failed to get packages for files", "err", err)
+ return "", err
+ }
+ allPkgList := slices.Collect(maps.Keys(allPkgs))
+ directPkgList := slices.Collect(maps.Keys(directPkgs))
+
+ var msgs []string
+
+ testMsg, err := r.checkTests(ctx, allPkgList)
+ if err != nil {
+ slog.DebugContext(ctx, "CodeReviewer.Run: failed to check tests", "err", err)
+ return "", err
+ }
+ if testMsg != "" {
+ msgs = append(msgs, testMsg)
+ }
+
+ vetMsg, err := r.checkVet(ctx, directPkgList)
+ if err != nil {
+ slog.DebugContext(ctx, "CodeReviewer.Run: failed to check vet", "err", err)
+ return "", err
+ }
+ if vetMsg != "" {
+ msgs = append(msgs, vetMsg)
+ }
+
+ goplsMsg, err := r.checkGopls(ctx, changedFiles)
+ if err != nil {
+ slog.DebugContext(ctx, "CodeReviewer.Run: failed to check gopls", "err", err)
+ return "", err
+ }
+ if goplsMsg != "" {
+ msgs = append(msgs, goplsMsg)
+ }
+
+ if len(msgs) == 0 {
+ slog.DebugContext(ctx, "CodeReviewer.Run: no issues found")
+ return "OK", nil
+ }
+ slog.DebugContext(ctx, "CodeReviewer.Run: found issues", "issues", msgs)
+ return strings.Join(msgs, "\n\n"), nil
+}
+
+func (r *CodeReviewer) initializeInitialCommitWorktree(ctx context.Context) error {
+ if r.initialWorktree != "" {
+ return nil
+ }
+ tmpDir, err := os.MkdirTemp("", "sketch-codereview-worktree")
+ if err != nil {
+ return err
+ }
+ worktreeCmd := exec.CommandContext(ctx, "git", "worktree", "add", "--detach", tmpDir, r.initialCommit)
+ worktreeCmd.Dir = r.repoRoot
+ out, err := worktreeCmd.CombinedOutput()
+ if err != nil {
+ return fmt.Errorf("unable to create worktree for initial commit: %w\n%s", err, out)
+ }
+ r.initialWorktree = tmpDir
+ return nil
+}
+
+func (r *CodeReviewer) checkTests(ctx context.Context, pkgList []string) (string, error) {
+ goTestArgs := []string{"test", "-json", "-v"}
+ goTestArgs = append(goTestArgs, pkgList...)
+
+ afterTestCmd := exec.CommandContext(ctx, "go", goTestArgs...)
+ afterTestCmd.Dir = r.repoRoot
+ afterTestOut, afterTestErr := afterTestCmd.Output()
+ if afterTestErr == nil {
+ return "", nil // all tests pass, we're good!
+ }
+
+ err := r.initializeInitialCommitWorktree(ctx)
+ if err != nil {
+ return "", err
+ }
+
+ beforeTestCmd := exec.CommandContext(ctx, "go", goTestArgs...)
+ beforeTestCmd.Dir = r.initialWorktree
+ beforeTestOut, _ := beforeTestCmd.Output() // ignore error, interesting info is in the output
+
+ // Parse the jsonl test results
+ beforeResults, beforeParseErr := parseTestResults(beforeTestOut)
+ if beforeParseErr != nil {
+ return "", fmt.Errorf("unable to parse test results for initial commit: %w\n%s", beforeParseErr, beforeTestOut)
+ }
+ afterResults, afterParseErr := parseTestResults(afterTestOut)
+ if afterParseErr != nil {
+ return "", fmt.Errorf("unable to parse test results for current commit: %w\n%s", afterParseErr, afterTestOut)
+ }
+
+ testRegressions, err := r.compareTestResults(beforeResults, afterResults)
+ if err != nil {
+ return "", fmt.Errorf("failed to compare test results: %w", err)
+ }
+ // TODO: better output formatting?
+ res := r.formatTestRegressions(testRegressions)
+ return res, nil
+}
+
// VetIssue represents a single issue found by go vet, as decoded from "go vet -json" output.
type VetIssue struct {
	Position string `json:"posn"`    // position of the issue, e.g. "/path/to/file.go:10:15"
	Message  string `json:"message"` // human-readable description of the issue
	// Ignoring suggested_fixes for now as we don't need them for comparison
}

// VetResult represents the JSON output of go vet -json for a single package:
// issues grouped by category (presumably the analyzer name, e.g. "printf" — TODO confirm).
type VetResult map[string][]VetIssue // category -> issues

// VetResults represents the full JSON output of go vet -json, keyed by package.
// NOTE(review): vet's keys may be package IDs (e.g. "p [p.test]") rather than plain import paths.
type VetResults map[string]VetResult // package path -> result
+
+// checkVet runs go vet on the provided packages in both the current and initial state,
+// compares the results, and reports any new vet issues introduced in the current state.
+func (r *CodeReviewer) checkVet(ctx context.Context, pkgList []string) (string, error) {
+ if len(pkgList) == 0 {
+ return "", nil // no packages to check
+ }
+
+ // Run vet on the current state with JSON output
+ goVetArgs := []string{"vet", "-json"}
+ goVetArgs = append(goVetArgs, pkgList...)
+
+ afterVetCmd := exec.CommandContext(ctx, "go", goVetArgs...)
+ afterVetCmd.Dir = r.repoRoot
+ afterVetOut, afterVetErr := afterVetCmd.CombinedOutput() // ignore error, we'll parse the output regar
+ if afterVetErr != nil {
+ slog.WarnContext(ctx, "CodeReviewer.checkVet: (after) go vet failed", "err", afterVetErr, "output", string(afterVetOut))
+ return "", nil // nothing more we can do here
+ }
+
+ // Parse the JSON output (even if vet returned an error, as it does when issues are found)
+ afterVetResults, err := parseVetJSON(afterVetOut)
+ if err != nil {
+ return "", fmt.Errorf("failed to parse vet output for current state: %w", err)
+ }
+
+ // If no issues were found, we're done
+ if len(afterVetResults) == 0 || !vetResultsHaveIssues(afterVetResults) {
+ return "", nil
+ }
+
+ // Vet detected issues in the current state, check if they existed in the initial state
+ err = r.initializeInitialCommitWorktree(ctx)
+ if err != nil {
+ return "", err
+ }
+
+ beforeVetCmd := exec.CommandContext(ctx, "go", goVetArgs...)
+ beforeVetCmd.Dir = r.initialWorktree
+ beforeVetOut, _ := beforeVetCmd.CombinedOutput() // ignore error, we'll parse the output anyway
+
+ // Parse the JSON output for the initial state
+ beforeVetResults, err := parseVetJSON(beforeVetOut)
+ if err != nil {
+ return "", fmt.Errorf("failed to parse vet output for initial state: %w", err)
+ }
+
+ // Find new issues that weren't present in the initial state
+ vetRegressions := findVetRegressions(beforeVetResults, afterVetResults)
+ if !vetResultsHaveIssues(vetRegressions) {
+ return "", nil // no new issues
+ }
+
+ // Format the results
+ return formatVetRegressions(vetRegressions), nil
+}
+
+// parseVetJSON parses the JSON output from go vet -json
+func parseVetJSON(output []byte) (VetResults, error) {
+ // The output contains multiple JSON objects, one per package
+ // We need to parse them separately
+ results := make(VetResults)
+
+ // Process the output by collecting JSON chunks between # comment lines
+ lines := strings.Split(string(output), "\n")
+ currentChunk := strings.Builder{}
+
+ // Helper function to process accumulated JSON chunks
+ processChunk := func() {
+ chunk := strings.TrimSpace(currentChunk.String())
+ if chunk == "" || !strings.HasPrefix(chunk, "{") {
+ return // Skip empty chunks or non-JSON chunks
+ }
+
+ // Try to parse the chunk as JSON
+ var result VetResults
+ if err := json.Unmarshal([]byte(chunk), &result); err != nil {
+ return // Skip invalid JSON
+ }
+
+ // Merge with our results
+ for pkg, issues := range result {
+ results[pkg] = issues
+ }
+
+ // Reset the chunk builder
+ currentChunk.Reset()
+ }
+
+ // Process lines
+ for _, line := range lines {
+ // If we hit a comment line, process the previous chunk and start a new one
+ if strings.HasPrefix(strings.TrimSpace(line), "#") {
+ processChunk()
+ continue
+ }
+
+ // Add the line to the current chunk
+ currentChunk.WriteString(line)
+ currentChunk.WriteString("\n")
+ }
+
+ // Process the final chunk
+ processChunk()
+
+ return results, nil
+}
+
+// vetResultsHaveIssues checks if there are any actual issues in the vet results
+func vetResultsHaveIssues(results VetResults) bool {
+ for _, pkgResult := range results {
+ for _, issues := range pkgResult {
+ if len(issues) > 0 {
+ return true
+ }
+ }
+ }
+ return false
+}
+
+// findVetRegressions identifies vet issues that are new in the after state
+func findVetRegressions(before, after VetResults) VetResults {
+ regressions := make(VetResults)
+
+ // Go through all packages in the after state
+ for pkgPath, afterPkgResults := range after {
+ beforePkgResults, pkgExistedBefore := before[pkgPath]
+
+ // Initialize package in regressions if it has issues
+ if !pkgExistedBefore {
+ // If the package didn't exist before, all issues are new
+ regressions[pkgPath] = afterPkgResults
+ continue
+ }
+
+ // Compare issues by category
+ for category, afterIssues := range afterPkgResults {
+ beforeIssues, categoryExistedBefore := beforePkgResults[category]
+
+ if !categoryExistedBefore {
+ // If this category didn't exist before, all issues are new
+ if regressions[pkgPath] == nil {
+ regressions[pkgPath] = make(VetResult)
+ }
+ regressions[pkgPath][category] = afterIssues
+ continue
+ }
+
+ // Compare individual issues
+ var newIssues []VetIssue
+ for _, afterIssue := range afterIssues {
+ if !issueExistsIn(afterIssue, beforeIssues) {
+ newIssues = append(newIssues, afterIssue)
+ }
+ }
+
+ // Add new issues to regressions
+ if len(newIssues) > 0 {
+ if regressions[pkgPath] == nil {
+ regressions[pkgPath] = make(VetResult)
+ }
+ regressions[pkgPath][category] = newIssues
+ }
+ }
+ }
+
+ return regressions
+}
+
+// issueExistsIn checks if an issue already exists in a list of issues
+// using a looser comparison that's resilient to position changes
+func issueExistsIn(issue VetIssue, issues []VetIssue) bool {
+ issueFile := extractFilePath(issue.Position)
+
+ for _, existing := range issues {
+ // Main comparison is by message content, which is likely stable
+ if issue.Message == existing.Message {
+ // If messages match exactly, consider it the same issue even if position changed
+ return true
+ }
+
+ // As a secondary check, if the issue is in the same file and has similar message,
+ // it's likely the same issue that might have been slightly reworded or relocated
+ existingFile := extractFilePath(existing.Position)
+ if issueFile == existingFile && messagesSimilar(issue.Message, existing.Message) {
+ return true
+ }
+ }
+ return false
+}
+
+// extractFilePath gets just the file path from a position string like "/path/to/file.go:10:15"
// extractFilePath returns the file path portion of a position string such as
// "/path/to/file.go:10:15", i.e. everything before the first colon.
// A position without any colon (including "") is returned unchanged.
func extractFilePath(position string) string {
	path, _, _ := strings.Cut(position, ":")
	return path
}
+
+// messagesSimilar checks if two messages are similar enough to be considered the same issue
+// This is a simple implementation that could be enhanced with more sophisticated text comparison
// messagesSimilar checks if two messages are similar enough to be considered the same issue.
// For now, two messages are similar if they are equal or one is a substring of the other.
// An empty message is only similar to another empty message: since "" is a substring
// of every string, it would otherwise match any issue.
func messagesSimilar(msg1, msg2 string) bool {
	if msg1 == msg2 {
		return true
	}
	if msg1 == "" || msg2 == "" {
		return false
	}
	return strings.Contains(msg1, msg2) || strings.Contains(msg2, msg1)
}
+
+// formatVetRegressions generates a human-readable summary of vet regressions
+func formatVetRegressions(regressions VetResults) string {
+ if !vetResultsHaveIssues(regressions) {
+ return ""
+ }
+
+ var sb strings.Builder
+ sb.WriteString("Go vet issues detected:\n\n")
+
+ // Get sorted list of packages for deterministic output
+ pkgPaths := make([]string, 0, len(regressions))
+ for pkgPath := range regressions {
+ pkgPaths = append(pkgPaths, pkgPath)
+ }
+ slices.Sort(pkgPaths)
+
+ issueCount := 1
+ for _, pkgPath := range pkgPaths {
+ pkgResult := regressions[pkgPath]
+
+ // Get sorted list of categories
+ categories := make([]string, 0, len(pkgResult))
+ for category := range pkgResult {
+ categories = append(categories, category)
+ }
+ slices.Sort(categories)
+
+ for _, category := range categories {
+ issues := pkgResult[category]
+
+ // Skip empty issue lists (shouldn't happen, but just in case)
+ if len(issues) == 0 {
+ continue
+ }
+
+ // Sort issues by position for deterministic output
+ slices.SortFunc(issues, func(a, b VetIssue) int {
+ return strings.Compare(a.Position, b.Position)
+ })
+
+ // Format each issue
+ for _, issue := range issues {
+ sb.WriteString(fmt.Sprintf("%d. [%s] %s: %s\n",
+ issueCount,
+ category,
+ issue.Position,
+ issue.Message))
+ issueCount++
+ }
+ }
+ }
+
+ sb.WriteString("\nPlease fix these issues before proceeding.")
+ return sb.String()
+}
+
// GoplsIssue represents a single issue reported by gopls check, as parsed by parseGoplsOutput.
type GoplsIssue struct {
	Position string // File position in format "file:line:col-range", e.g. "/path/to/file.go:448:22-26"
	Message  string // Description of the issue, e.g. "unused parameter: path"
}
+
+// checkGopls runs gopls check on the provided files in both the current and initial state,
+// compares the results, and reports any new issues introduced in the current state.
+func (r *CodeReviewer) checkGopls(ctx context.Context, changedFiles []string) (string, error) {
+ if len(changedFiles) == 0 {
+ return "", nil // no files to check
+ }
+
+ // Filter out non-Go files as gopls only works on Go files
+ // and verify they still exist (not deleted)
+ var goFiles []string
+ for _, file := range changedFiles {
+ if !strings.HasSuffix(file, ".go") {
+ continue // not a Go file
+ }
+
+ // Check if the file still exists (not deleted)
+ if _, err := os.Stat(file); os.IsNotExist(err) {
+ continue // file doesn't exist anymore (deleted)
+ }
+
+ goFiles = append(goFiles, file)
+ }
+
+ if len(goFiles) == 0 {
+ return "", nil // no Go files to check
+ }
+
+ // Run gopls check on the current state
+ goplsArgs := append([]string{"check"}, goFiles...)
+
+ afterGoplsCmd := exec.CommandContext(ctx, "gopls", goplsArgs...)
+ afterGoplsCmd.Dir = r.repoRoot
+ afterGoplsOut, err := afterGoplsCmd.CombinedOutput() // gopls returns non-zero if it finds issues
+ if err != nil {
+ // Check if the output looks like real gopls issues or if it's just error output
+ if !looksLikeGoplsIssues(afterGoplsOut) {
+ slog.WarnContext(ctx, "CodeReviewer.checkGopls: gopls check failed to run properly", "err", err, "output", string(afterGoplsOut))
+ return "", nil // Skip rather than failing the entire code review
+ }
+ // Otherwise, proceed with parsing - it's likely just the non-zero exit code due to found issues
+ }
+
+ // Parse the output
+ afterIssues := parseGoplsOutput(afterGoplsOut)
+
+ // If no issues were found, we're done
+ if len(afterIssues) == 0 {
+ return "", nil
+ }
+
+ // Gopls detected issues in the current state, check if they existed in the initial state
+ initErr := r.initializeInitialCommitWorktree(ctx)
+ if initErr != nil {
+ return "", err
+ }
+
+ // For each file that exists in the initial commit, run gopls check
+ var initialFilesToCheck []string
+ for _, file := range goFiles {
+ // Get relative path for git operations
+ relFile, err := filepath.Rel(r.repoRoot, file)
+ if err != nil {
+ slog.WarnContext(ctx, "CodeReviewer.checkGopls: failed to get relative path", "repo_root", r.repoRoot, "file", file, "err", err)
+ continue
+ }
+
+ // Check if the file exists in the initial commit
+ checkCmd := exec.CommandContext(ctx, "git", "cat-file", "-e", fmt.Sprintf("%s:%s", r.initialCommit, relFile))
+ checkCmd.Dir = r.repoRoot
+ if err := checkCmd.Run(); err == nil {
+ // File exists in initial commit
+ initialFilePath := filepath.Join(r.initialWorktree, relFile)
+ initialFilesToCheck = append(initialFilesToCheck, initialFilePath)
+ }
+ }
+
+ // Run gopls check on the files that existed in the initial commit
+ beforeIssues := []GoplsIssue{}
+ if len(initialFilesToCheck) > 0 {
+ beforeGoplsArgs := append([]string{"check"}, initialFilesToCheck...)
+ beforeGoplsCmd := exec.CommandContext(ctx, "gopls", beforeGoplsArgs...)
+ beforeGoplsCmd.Dir = r.initialWorktree
+ var beforeGoplsOut []byte
+ var beforeCmdErr error
+ beforeGoplsOut, beforeCmdErr = beforeGoplsCmd.CombinedOutput()
+ if beforeCmdErr != nil && !looksLikeGoplsIssues(beforeGoplsOut) {
+ // If gopls fails to run properly on the initial commit, log a warning and continue
+ // with empty before issues - this will be conservative and report more issues
+ slog.WarnContext(ctx, "CodeReviewer.checkGopls: gopls check failed on initial commit",
+ "err", err, "output", string(beforeGoplsOut))
+ // Continue with empty beforeIssues
+ } else {
+ beforeIssues = parseGoplsOutput(beforeGoplsOut)
+ }
+ }
+
+ // Find new issues that weren't present in the initial state
+ goplsRegressions := findGoplsRegressions(beforeIssues, afterIssues)
+ if len(goplsRegressions) == 0 {
+ return "", nil // no new issues
+ }
+
+ // Format the results
+ return formatGoplsRegressions(goplsRegressions), nil
+}
+
+// parseGoplsOutput parses the text output from gopls check
+// Each line has the format: '/path/to/file.go:448:22-26: unused parameter: path'
+func parseGoplsOutput(output []byte) []GoplsIssue {
+ var issues []GoplsIssue
+ lines := strings.Split(string(output), "\n")
+
+ for _, line := range lines {
+ line = strings.TrimSpace(line)
+ if line == "" {
+ continue
+ }
+
+ // Skip lines that look like error messages rather than gopls issues
+ if strings.HasPrefix(line, "Error:") ||
+ strings.HasPrefix(line, "Failed:") ||
+ strings.HasPrefix(line, "Warning:") ||
+ strings.HasPrefix(line, "gopls:") {
+ continue
+ }
+
+ // Find the first colon that separates the file path from the line number
+ firstColonIdx := strings.Index(line, ":")
+ if firstColonIdx < 0 {
+ continue // Invalid format
+ }
+
+ // Verify the part before the first colon looks like a file path
+ potentialPath := line[:firstColonIdx]
+ if !strings.HasSuffix(potentialPath, ".go") {
+ continue // Not a Go file path
+ }
+
+ // Find the position of the first message separator ': '
+ // This separates the position info from the message
+ messageStart := strings.Index(line, ": ")
+ if messageStart < 0 || messageStart <= firstColonIdx {
+ continue // Invalid format
+ }
+
+ // Extract position and message
+ position := line[:messageStart]
+ message := line[messageStart+2:] // Skip the ': ' separator
+
+ // Verify position has the expected format (at least 2 colons for line:col)
+ colonCount := strings.Count(position, ":")
+ if colonCount < 2 {
+ continue // Not enough position information
+ }
+
+ issues = append(issues, GoplsIssue{
+ Position: position,
+ Message: message,
+ })
+ }
+
+ return issues
+}
+
+// looksLikeGoplsIssues checks if the output appears to be actual gopls issues
+// rather than error messages about gopls itself failing
func looksLikeGoplsIssues(output []byte) bool {
	if len(output) == 0 {
		return false // nothing to look at
	}
	// A gopls issue looks like '/path/to/file.go:123:45-67: message': at least
	// two colons, a ": " separator, and a .go path before the first colon.
	for _, raw := range strings.Split(string(output), "\n") {
		line := strings.TrimSpace(raw)
		if line == "" || strings.Count(line, ":") < 2 || !strings.Contains(line, ": ") {
			continue
		}
		if path, _, _ := strings.Cut(line, ":"); strings.HasSuffix(path, ".go") {
			return true
		}
	}
	return false
}
+
+// normalizeGoplsPosition extracts just the file path from a position string
// normalizeGoplsPosition extracts just the file path from a position string
// such as "/path/to/file.go:448:22-26", i.e. everything before the first colon.
// A position without any colon is returned unchanged.
func normalizeGoplsPosition(position string) string {
	path, _, _ := strings.Cut(position, ":")
	return path
}
+
+// findGoplsRegressions identifies gopls issues that are new in the after state
+func findGoplsRegressions(before, after []GoplsIssue) []GoplsIssue {
+ var regressions []GoplsIssue
+
+ // Build map of before issues for easier lookup
+ beforeIssueMap := make(map[string]map[string]bool) // file -> message -> exists
+ for _, issue := range before {
+ file := normalizeGoplsPosition(issue.Position)
+ if _, exists := beforeIssueMap[file]; !exists {
+ beforeIssueMap[file] = make(map[string]bool)
+ }
+ // Store both the exact message and the general issue type for fuzzy matching
+ beforeIssueMap[file][issue.Message] = true
+
+ // Extract the general issue type (everything before the first ':' in the message)
+ generalIssue := issue.Message
+ if colonIdx := strings.Index(issue.Message, ":"); colonIdx > 0 {
+ generalIssue = issue.Message[:colonIdx]
+ }
+ beforeIssueMap[file][generalIssue] = true
+ }
+
+ // Check each after issue to see if it's new
+ for _, afterIssue := range after {
+ file := normalizeGoplsPosition(afterIssue.Position)
+ isNew := true
+
+ if fileIssues, fileExists := beforeIssueMap[file]; fileExists {
+ // Check for exact message match
+ if fileIssues[afterIssue.Message] {
+ isNew = false
+ } else {
+ // Check for general issue type match
+ generalIssue := afterIssue.Message
+ if colonIdx := strings.Index(afterIssue.Message, ":"); colonIdx > 0 {
+ generalIssue = afterIssue.Message[:colonIdx]
+ }
+ if fileIssues[generalIssue] {
+ isNew = false
+ }
+ }
+ }
+
+ if isNew {
+ regressions = append(regressions, afterIssue)
+ }
+ }
+
+ // Sort regressions for deterministic output
+ slices.SortFunc(regressions, func(a, b GoplsIssue) int {
+ return strings.Compare(a.Position, b.Position)
+ })
+
+ return regressions
+}
+
+// formatGoplsRegressions generates a human-readable summary of gopls check regressions
+func formatGoplsRegressions(regressions []GoplsIssue) string {
+ if len(regressions) == 0 {
+ return ""
+ }
+
+ var sb strings.Builder
+ sb.WriteString("Gopls check issues detected:\n\n")
+
+ // Format each issue
+ for i, issue := range regressions {
+ sb.WriteString(fmt.Sprintf("%d. %s: %s\n", i+1, issue.Position, issue.Message))
+ }
+
+ sb.WriteString("\nIMPORTANT: Only fix new gopls check issues in parts of the code that you have already edited. ")
+ sb.WriteString("Do not change existing code that was not part of your current edits.")
+ return sb.String()
+}
+
+func (r *CodeReviewer) HasReviewed(commit string) bool {
+ return slices.Contains(r.reviewed, commit)
+}
+
+func (r *CodeReviewer) IsInitialCommit(commit string) bool {
+ return commit == r.initialCommit
+}
+
+// packagesForFiles returns maps of packages related to the given files:
+// 1. directPkgs: packages that directly contain the changed files
+// 2. allPkgs: all packages that might be affected, including downstream packages that depend on the direct packages
+// It may include false positives.
+// Files must be absolute paths!
+func (r *CodeReviewer) packagesForFiles(ctx context.Context, files []string) (directPkgs, allPkgs map[string]*packages.Package, err error) {
+ for _, f := range files {
+ if !filepath.IsAbs(f) {
+ return nil, nil, fmt.Errorf("path %q is not absolute", f)
+ }
+ }
+ cfg := &packages.Config{
+ Mode: packages.LoadImports | packages.NeedEmbedFiles,
+ Context: ctx,
+ // Logf: func(msg string, args ...any) {
+ // slog.DebugContext(ctx, "loading go packages", "msg", fmt.Sprintf(msg, args...))
+ // },
+ // TODO: in theory, go.mod might not be in the repo root, and there might be multiple go.mod files.
+ // We can cross that bridge when we get there.
+ Dir: r.repoRoot,
+ Tests: true,
+ }
+ universe, err := packages.Load(cfg, "./...")
+ if err != nil {
+ return nil, nil, err
+ }
+ // Identify packages that directly contain the changed files
+ directPkgs = make(map[string]*packages.Package) // import path -> package
+ for _, pkg := range universe {
+ // fmt.Println("pkg:", pkg.PkgPath)
+ pkgFiles := allFiles(pkg)
+ // fmt.Println("pkgFiles:", pkgFiles)
+ for _, file := range files {
+ if pkgFiles[file] {
+ // prefer test packages, as they contain strictly more files (right?)
+ prev := directPkgs[pkg.PkgPath]
+ if prev == nil || prev.ForTest == "" {
+ directPkgs[pkg.PkgPath] = pkg
+ }
+ }
+ }
+ }
+
+ // Create a copy of directPkgs to expand with dependencies
+ allPkgs = make(map[string]*packages.Package)
+ for k, v := range directPkgs {
+ allPkgs[k] = v
+ }
+
+ // Add packages that depend on the direct packages
+ addDependentPackages(universe, allPkgs)
+ return directPkgs, allPkgs, nil
+}
+
+// allFiles returns all files that might be referenced by the package.
+// It may contain false positives.
+func allFiles(p *packages.Package) map[string]bool {
+ files := make(map[string]bool)
+ add := [][]string{p.GoFiles, p.CompiledGoFiles, p.OtherFiles, p.EmbedFiles, p.IgnoredFiles}
+ for _, extra := range add {
+ for _, file := range extra {
+ files[file] = true
+ }
+ }
+ return files
+}
+
+// addDependentPackages adds to pkgs all packages from universe
+// that directly or indirectly depend on any package already in pkgs.
+func addDependentPackages(universe []*packages.Package, pkgs map[string]*packages.Package) {
+ for {
+ changed := false
+ for _, p := range universe {
+ if _, ok := pkgs[p.PkgPath]; ok {
+ // already in pkgs
+ continue
+ }
+ for importPath := range p.Imports {
+ if _, ok := pkgs[importPath]; ok {
+ // imports a package dependent on pkgs, add it
+ pkgs[p.PkgPath] = p
+ changed = true
+ break
+ }
+ }
+ }
+ if !changed {
+ break
+ }
+ }
+}
+
// testJSON is a union of BuildEvent and TestEvent: one line of "go test -json"
// output, as decoded by parseTestResults. Fields that do not apply to a given
// event are left at their zero values.
type testJSON struct {
	// TestEvent only:
	// The Time field holds the time the event happened. It is conventionally omitted
	// for cached test results.
	Time time.Time `json:"Time"`
	// BuildEvent only:
	// The ImportPath field gives the package ID of the package being built.
	// This matches the Package.ImportPath field of go list -json and the
	// TestEvent.FailedBuild field of go test -json. Note that it does not
	// match TestEvent.Package.
	ImportPath string `json:"ImportPath"` // BuildEvent only
	// TestEvent only:
	// The Package field, if present, specifies the package being tested. When the
	// go command runs parallel tests in -json mode, events from different tests are
	// interlaced; the Package field allows readers to separate them.
	Package string `json:"Package"`
	// Action is used in both BuildEvent and TestEvent.
	// It is the key to distinguishing between them.
	// BuildEvent:
	// build-output or build-fail
	// TestEvent:
	// start, run, pause, cont, pass, bench, fail, output, skip
	Action string `json:"Action"`
	// TestEvent only:
	// The Test field, if present, specifies the test, example, or benchmark function
	// that caused the event. Events for the overall package test do not set Test.
	Test string `json:"Test"`
	// TestEvent only:
	// The Elapsed field is set for "pass" and "fail" events. It gives the time elapsed in seconds
	// for the specific test or the overall package test that passed or failed.
	// (No json tag: encoding/json matches the "Elapsed" key by field name.)
	Elapsed float64
	// TestEvent:
	// The Output field is set for Action == "output" and is a portion of the
	// test's output (standard output and standard error merged together). The
	// output is unmodified except that invalid UTF-8 output from a test is coerced
	// into valid UTF-8 by use of replacement characters. With that one exception,
	// the concatenation of the Output fields of all output events is the exact output
	// of the test execution.
	// BuildEvent:
	// The Output field is set for Action == "build-output" and is a portion of
	// the build's output. The concatenation of the Output fields of all output
	// events is the exact output of the build. A single event may contain one
	// or more lines of output and there may be more than one output event for
	// a given ImportPath. This matches the definition of the TestEvent.Output
	// field produced by go test -json.
	Output string `json:"Output"`
	// TestEvent only:
	// The FailedBuild field is set for Action == "fail" if the test failure was caused
	// by a build failure. It contains the package ID of the package that failed to
	// build. This matches the ImportPath field of the "go list" output, as well as the
	// BuildEvent.ImportPath field as emitted by "go build -json".
	FailedBuild string `json:"FailedBuild"`
}
+
+// parseTestResults converts test output in JSONL format into a slice of testJSON objects
+func parseTestResults(testOutput []byte) ([]testJSON, error) {
+ var results []testJSON
+ dec := json.NewDecoder(bytes.NewReader(testOutput))
+ for {
+ var event testJSON
+ if err := dec.Decode(&event); err != nil {
+ if err == io.EOF {
+ break
+ }
+ return nil, err
+ }
+ results = append(results, event)
+ }
+ return results, nil
+}
+
// testStatus represents the status of a test (or of a package's overall test run)
// in a given commit, as derived from go test -json events by collectTestStatuses.
type testStatus int

const (
	testStatusUnknown   testStatus = iota // zero value; compareTestResults uses it when a test has no recorded status
	testStatusPass                        // a "pass" event was seen
	testStatusFail                        // a "fail" event was seen
	testStatusBuildFail                   // the package's tests could not run because of a build failure
	testStatusSkip                        // a "skip" event was seen
)
+
// testInfo identifies a test within a package; an empty Test denotes the
// package-level test result.
type testInfo struct {
	Package string
	Test    string // empty for package tests
}

// String renders the test as "package.Test", or just "package" for a
// package-level result.
func (t testInfo) String() string {
	if t.Test != "" {
		return fmt.Sprintf("%s.%s", t.Package, t.Test)
	}
	return t.Package
}
+
// testRegression represents a test that regressed between commits,
// as collected by compareTestResults.
type testRegression struct {
	Info         testInfo   // which test (or package-level result) regressed
	BeforeStatus testStatus // status at the initial commit (testStatusUnknown if it had none)
	AfterStatus  testStatus // status at the current commit
	Output       string     // failure output in the after state
}
+
+// collectTestStatuses reduces a stream of `go test -json` events to the
+// final status of each test, and of each package under a testInfo whose Test
+// is empty.
+//
+// Build failures take precedence. Every non-output event belonging to a
+// package whose build failed is recorded as testStatusBuildFail. Otherwise the
+// last pass/fail/skip event for an entry wins. Other actions (run, pause, ...)
+// do not set a status.
+func collectTestStatuses(results []testJSON) map[testInfo]testStatus {
+	statuses := make(map[testInfo]testStatus)
+
+	// First pass: identify build failures. FailedBuild names the package that
+	// failed to build. That can be a dependency, or an ID such as
+	// "p [p.test]" that never equals an event's Package. So also record the
+	// reporting package itself: its tests could not run.
+	failedBuilds := make(map[string]bool)
+	for _, result := range results {
+		if result.Action == "fail" && result.FailedBuild != "" {
+			failedBuilds[result.FailedBuild] = true
+			failedBuilds[result.Package] = true
+		}
+	}
+
+	// Second pass: collect test statuses.
+	for _, result := range results {
+		switch result.Action {
+		case "output":
+			// Output carries no status; compareTestResults gathers it itself.
+			continue
+		case "build-fail":
+			// Build events name the package via ImportPath. Mark every entry
+			// already seen for that package as a build failure.
+			for ti := range statuses {
+				if ti.Package == result.ImportPath {
+					statuses[ti] = testStatusBuildFail
+				}
+			}
+			continue
+		}
+
+		info := testInfo{Package: result.Package, Test: result.Test}
+		if failedBuilds[result.Package] {
+			statuses[info] = testStatusBuildFail
+			continue
+		}
+
+		switch result.Action {
+		case "pass":
+			statuses[info] = testStatusPass
+		case "fail":
+			statuses[info] = testStatusFail
+		case "skip":
+			statuses[info] = testStatusSkip
+		}
+	}
+
+	return statuses
+}
+
+// compareTestResults identifies tests that have regressed between commits.
+//
+// beforeResults and afterResults are the `go test -json` events from the
+// initial commit and from HEAD. An entry is reported when isRegression says
+// its HEAD status is worse than its earlier one. Entries that pass or are
+// skipped at HEAD are never reported. Each regression carries the output
+// captured for it at HEAD. The result is sorted by package, then by test
+// name. The returned error is currently always nil.
+func (r *CodeReviewer) compareTestResults(beforeResults, afterResults []testJSON) ([]testRegression, error) {
+	beforeStatuses := collectTestStatuses(beforeResults)
+	afterStatuses := collectTestStatuses(afterResults)
+
+	// Gather output chunks per test, and join them only for tests that are
+	// reported. Appending with += for every event was quadratic in the size
+	// of a test's output, and built strings that were mostly discarded.
+	outputChunks := make(map[testInfo][]string)
+	for _, result := range afterResults {
+		if result.Action == "output" {
+			info := testInfo{Package: result.Package, Test: result.Test}
+			outputChunks[info] = append(outputChunks[info], result.Output)
+		}
+	}
+
+	var regressions []testRegression
+	for info, afterStatus := range afterStatuses {
+		// Tests that pass or are skipped at HEAD are not regressions.
+		if afterStatus == testStatusPass || afterStatus == testStatusSkip {
+			continue
+		}
+
+		// Tests with no result in the before run (e.g. newly added ones)
+		// count as unknown.
+		beforeStatus, ok := beforeStatuses[info]
+		if !ok {
+			beforeStatus = testStatusUnknown
+		}
+
+		// Log a warning if we encounter an unexpected unknown status at HEAD.
+		if afterStatus == testStatusUnknown {
+			slog.WarnContext(context.Background(), "Unexpected unknown test status encountered",
+				"package", info.Package, "test", info.Test)
+		}
+
+		if isRegression(beforeStatus, afterStatus) {
+			regressions = append(regressions, testRegression{
+				Info:         info,
+				BeforeStatus: beforeStatus,
+				AfterStatus:  afterStatus,
+				Output:       strings.Join(outputChunks[info], ""),
+			})
+		}
+	}
+
+	// Map iteration order is random; sort for consistent output.
+	slices.SortFunc(regressions, func(a, b testRegression) int {
+		if c := strings.Compare(a.Info.Package, b.Info.Package); c != 0 {
+			return c
+		}
+		return strings.Compare(a.Info.Test, b.Info.Test)
+	})
+
+	return regressions, nil
+}
+
+// badnessLevels ranks each testStatus by severity; a larger number is a worse
+// outcome. testStatusUnknown ranks lowest, for two reasons:
+//   - a test with no result at HEAD is never flagged as a regression;
+//   - a test with no earlier result is flagged as soon as it fails.
+var badnessLevels = map[testStatus]int{
+	testStatusUnknown:   0, // least bad - avoids false positives
+	testStatusPass:      1,
+	testStatusSkip:      2,
+	testStatusFail:      3,
+	testStatusBuildFail: 4, // worst
+}
+
+// isRegression reports whether going from status before to status after
+// makes things worse, i.e. whether after has a strictly higher rank in
+// badnessLevels. A status missing from the table ranks as 0.
+func isRegression(before, after testStatus) bool {
+	beforeRank, afterRank := badnessLevels[before], badnessLevels[after]
+	return afterRank > beforeRank
+}
+
+// formatTestRegressions generates a human-readable summary of test regressions.
+//
+// Each regression becomes a numbered entry that describes how its status
+// changed, followed by at most the first maxOutputLines lines of its captured
+// output. It returns "" when regressions is empty.
+func (r *CodeReviewer) formatTestRegressions(regressions []testRegression) string {
+	if len(regressions) == 0 {
+		return ""
+	}
+
+	// Limit output per regression to avoid overwhelming feedback.
+	const maxOutputLines = 10
+
+	var sb strings.Builder
+	fmt.Fprintf(&sb, "Test regressions detected between initial commit (%s) and HEAD:\n\n", r.initialCommit)
+
+	for i, reg := range regressions {
+		fmt.Fprintf(&sb, "%d. %s: ", i+1, reg.Info.String())
+
+		before, after := reg.BeforeStatus, reg.AfterStatus
+		switch {
+		case before == testStatusUnknown && after == testStatusFail:
+			sb.WriteString("New test is failing")
+		case before == testStatusUnknown && after == testStatusBuildFail:
+			sb.WriteString("New test has build errors")
+		case before == testStatusPass && after == testStatusFail:
+			sb.WriteString("Was passing, now failing")
+		case before == testStatusPass && after == testStatusBuildFail:
+			sb.WriteString("Was passing, now has build errors")
+		case before == testStatusSkip && after == testStatusFail:
+			sb.WriteString("Was skipped, now failing")
+		case before == testStatusSkip && after == testStatusBuildFail:
+			sb.WriteString("Was skipped, now has build errors")
+		case before == testStatusFail && after == testStatusBuildFail:
+			sb.WriteString("Was failing, now has build errors")
+		default:
+			sb.WriteString("Regression detected")
+		}
+		sb.WriteString("\n")
+
+		// Add failure output with indentation for readability.
+		if reg.Output != "" {
+			outputLines := strings.Split(strings.TrimSpace(reg.Output), "\n")
+			shownLines := min(len(outputLines), maxOutputLines)
+
+			sb.WriteString(" Output:\n")
+			for _, line := range outputLines[:shownLines] {
+				fmt.Fprintf(&sb, " | %s\n", line)
+			}
+			if shownLines < len(outputLines) {
+				fmt.Fprintf(&sb, " | ... (%d more lines)\n", len(outputLines)-shownLines)
+			}
+		}
+		sb.WriteString("\n")
+	}
+
+	sb.WriteString("Please fix these test failures before proceeding.")
+	return sb.String()
+}
diff --git a/claudetool/edit.go b/claudetool/edit.go
new file mode 100644
index 0000000..df83139
--- /dev/null
+++ b/claudetool/edit.go
@@ -0,0 +1,451 @@
+package claudetool
+
+/*
+
+Note: sketch wrote this based on translating https://raw.githubusercontent.com/anthropics/anthropic-quickstarts/refs/heads/main/computer-use-demo/computer_use_demo/tools/edit.py
+
+## Implementation Notes
+This tool is based on Anthropic's Python implementation of the `text_editor_20250124` tool. It maintains a history of file edits to support the undo functionality, and verifies text uniqueness for the str_replace operation to ensure safe edits.
+
+*/
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strings"
+
+ "sketch.dev/ant"
+)
+
+// Constants for the AnthropicEditTool
+const (
+ editName = "str_replace_editor"
+)
+
+// snippetLines is how many lines of context appear on each side of an edit
+// in the snippet echoed back by str_replace and insert.
+const snippetLines = 4
+
+// maxResponseLen is the byte budget for content passed through maybetruncate.
+const maxResponseLen = 16000
+
+// truncatedMessage is appended by maybetruncate when it had to cut content.
+const truncatedMessage = "<response clipped><NOTE>To save on context only part of this file has been shown to you. You should retry this tool after you have searched inside the file with `grep -n` in order to find the line numbers of what you are looking for.</NOTE>"
+
+// editCommand is the "command" field of an edit request. The supported values
+// are the constants below.
+type editCommand string
+
+const (
+	viewCommand       = editCommand("view")        // show a file, or list a directory
+	createCommand     = editCommand("create")      // write a file that does not exist yet
+	strReplaceCommand = editCommand("str_replace") // replace the single occurrence of old_str
+	insertCommand     = editCommand("insert")      // insert new_str after line insert_line
+	undoEditCommand   = editCommand("undo_edit")   // restore the previous content from fileHistory
+)
+
+// editInput represents the expected input format for the edit tool.
+// Optional parameters are pointers (or a slice) so that EditRun can tell an
+// omitted parameter apart from an explicitly empty one.
+type editInput struct {
+ // Command is one of the editCommand values.
+ Command string `json:"command"`
+ // Path must be absolute; validatePath enforces this.
+ Path string `json:"path"`
+ // FileText is the full content for "create" (required there).
+ FileText *string `json:"file_text,omitempty"`
+ // ViewRange is [first, last] (1-based, last may be -1 for end of file) for "view".
+ ViewRange []int `json:"view_range,omitempty"`
+ // OldStr is the text to replace for "str_replace" (required there).
+ OldStr *string `json:"old_str,omitempty"`
+ // NewStr is the replacement for "str_replace" (nil means "") and the text
+ // to insert for "insert" (required there).
+ NewStr *string `json:"new_str,omitempty"`
+ // InsertLine is the line after which NewStr is inserted for "insert"
+ // (0 inserts at the start; required there).
+ InsertLine *int `json:"insert_line,omitempty"`
+}
+
+// fileHistory maintains a history of edits for each file to support undo functionality.
+// Each value is a stack keyed by file path; its last element is what the next
+// undo_edit writes back.
+//   - str_replace and insert push the pre-edit content.
+//   - create pushes the newly written content, so undoing a create does not
+//     remove the file.
+// Entries are removed only by undo_edit, so the map keeps a full copy of
+// the content for every edit for the life of the process.
+// NOTE(review): there is no locking around this package-level map; it
+// assumes EditRun is never called concurrently — confirm with callers.
+var fileHistory = make(map[string][]string)
+
+// AnthropicEditTool is a tool for viewing, creating, and editing files.
+// Requests are handled by EditRun. It mirrors Anthropic's
+// text_editor_20250124 tool.
+var AnthropicEditTool = &ant.Tool{
+	Name: editName,
+	Run:  EditRun,
+	// Type is model-dependent, and would be different for Claude 3.5, for example.
+	Type: "text_editor_20250124",
+}
+
+// EditRun is the implementation of the edit tool.
+//
+// Steps:
+//   - decode input as an editInput;
+//   - reject unknown commands and paths unusable for the command (see
+//     validatePath);
+//   - check that each command's required parameters are present;
+//   - dispatch to the matching handler.
+//
+// For str_replace an omitted new_str means the empty string, i.e. old_str is
+// deleted.
+func EditRun(ctx context.Context, input json.RawMessage) (string, error) {
+	var editRequest editInput
+	if err := json.Unmarshal(input, &editRequest); err != nil {
+		// %w keeps the underlying *json.SyntaxError etc. reachable via errors.As.
+		return "", fmt.Errorf("failed to parse edit input: %w", err)
+	}
+
+	cmd := editCommand(editRequest.Command)
+	if !isValidCommand(cmd) {
+		return "", fmt.Errorf("unrecognized command %s. The allowed commands are: view, create, str_replace, insert, undo_edit", cmd)
+	}
+
+	path := editRequest.Path
+	if err := validatePath(cmd, path); err != nil {
+		return "", err
+	}
+
+	switch cmd {
+	case viewCommand:
+		return handleView(ctx, path, editRequest.ViewRange)
+	case createCommand:
+		if editRequest.FileText == nil {
+			return "", fmt.Errorf("parameter file_text is required for command: create")
+		}
+		return handleCreate(path, *editRequest.FileText)
+	case strReplaceCommand:
+		if editRequest.OldStr == nil {
+			return "", fmt.Errorf("parameter old_str is required for command: str_replace")
+		}
+		newStr := ""
+		if editRequest.NewStr != nil {
+			newStr = *editRequest.NewStr
+		}
+		return handleStrReplace(path, *editRequest.OldStr, newStr)
+	case insertCommand:
+		if editRequest.InsertLine == nil {
+			return "", fmt.Errorf("parameter insert_line is required for command: insert")
+		}
+		if editRequest.NewStr == nil {
+			return "", fmt.Errorf("parameter new_str is required for command: insert")
+		}
+		return handleInsert(path, *editRequest.InsertLine, *editRequest.NewStr)
+	case undoEditCommand:
+		return handleUndoEdit(path)
+	default:
+		// Unreachable while isValidCommand and this switch agree.
+		return "", fmt.Errorf("command %s is not implemented", cmd)
+	}
+}
+
+// isValidCommand reports whether cmd is one of the five supported edit commands.
+func isValidCommand(cmd editCommand) bool {
+	return cmd == viewCommand ||
+		cmd == createCommand ||
+		cmd == strReplaceCommand ||
+		cmd == insertCommand ||
+		cmd == undoEditCommand
+}
+
+// validatePath checks that path can be used with cmd before any handler runs:
+//   - path must be absolute;
+//   - it must exist, except for create, which requires that it does not;
+//   - directories are accepted only by view.
+// Any other stat failure (e.g. permission denied) is reported as an error.
+func validatePath(cmd editCommand, path string) error {
+	if !filepath.IsAbs(path) {
+		return fmt.Errorf("the path %s is not an absolute path, it should start with '/'. Maybe you meant %s?", path, "/"+path)
+	}
+
+	info, statErr := os.Stat(path)
+	switch {
+	case statErr == nil:
+		// The path exists.
+		if info.IsDir() && cmd != viewCommand {
+			return fmt.Errorf("the path %s is a directory and only the 'view' command can be used on directories", path)
+		}
+		if cmd == createCommand {
+			return fmt.Errorf("file already exists at: %s. Cannot overwrite files using command 'create'", path)
+		}
+		return nil
+	case !os.IsNotExist(statErr):
+		return fmt.Errorf("error accessing path %s: %v", path, statErr)
+	case cmd != createCommand:
+		return fmt.Errorf("the path %s does not exist. Please provide a valid path", path)
+	default:
+		// A missing path is exactly what create needs.
+		return nil
+	}
+}
+
+// handleView implements the view command
+func handleView(ctx context.Context, path string, viewRange []int) (string, error) {
+ info, err := os.Stat(path)
+ if err != nil {
+ return "", fmt.Errorf("error accessing path %s: %v", path, err)
+ }
+
+ // Handle directory view
+ if info.IsDir() {
+ if viewRange != nil {
+ return "", fmt.Errorf("the view_range parameter is not allowed when path points to a directory")
+ }
+
+ // List files in the directory (up to 2 levels deep)
+ return listDirectory(ctx, path)
+ }
+
+ // Handle file view
+ fileContent, err := readFile(path)
+ if err != nil {
+ return "", err
+ }
+
+ initLine := 1
+ if viewRange != nil {
+ if len(viewRange) != 2 {
+ return "", fmt.Errorf("invalid view_range. It should be a list of two integers")
+ }
+
+ fileLines := strings.Split(fileContent, "\n")
+ nLinesFile := len(fileLines)
+ initLine, finalLine := viewRange[0], viewRange[1]
+
+ if initLine < 1 || initLine > nLinesFile {
+ return "", fmt.Errorf("invalid view_range: %v. Its first element %d should be within the range of lines of the file: [1, %d]",
+ viewRange, initLine, nLinesFile)
+ }
+
+ if finalLine != -1 && finalLine < initLine {
+ return "", fmt.Errorf("invalid view_range: %v. Its second element %d should be larger or equal than its first %d",
+ viewRange, finalLine, initLine)
+ }
+
+ if finalLine > nLinesFile {
+ return "", fmt.Errorf("invalid view_range: %v. Its second element %d should be smaller than the number of lines in the file: %d",
+ viewRange, finalLine, nLinesFile)
+ }
+
+ if finalLine == -1 {
+ fileContent = strings.Join(fileLines[initLine-1:], "\n")
+ } else {
+ fileContent = strings.Join(fileLines[initLine-1:finalLine], "\n")
+ }
+ }
+
+ return makeOutput(fileContent, path, initLine), nil
+}
+
+// handleCreate implements the create command. It:
+//   - creates missing parent directories (mode 0755);
+//   - writes fileText to path;
+//   - records fileText in fileHistory.
+// Because the recorded entry is the new content, a later undo_edit rewrites
+// that same content rather than removing the file.
+func handleCreate(path string, fileText string) (string, error) {
+	parent := filepath.Dir(path)
+	if mkErr := os.MkdirAll(parent, 0o755); mkErr != nil {
+		return "", fmt.Errorf("failed to create directory %s: %v", parent, mkErr)
+	}
+
+	if wErr := writeFile(path, fileText); wErr != nil {
+		return "", wErr
+	}
+
+	fileHistory[path] = append(fileHistory[path], fileText)
+	return fmt.Sprintf("File created successfully at: %s", path), nil
+}
+
+// handleStrReplace implements the str_replace command. It replaces the single
+// occurrence of oldStr in the file at path with newStr. It returns a
+// numbered snippet of the edited region, with snippetLines lines of context
+// on each side.
+//
+// The file is left untouched, and an error returned, when oldStr:
+//   - does not occur;
+//   - occurs more than once (the error lists the line on which each
+//     occurrence starts);
+//   - is empty while the file is not.
+// The pre-edit content is pushed onto fileHistory only after the new content
+// has been written, so a failed write does not leave a bogus undo entry.
+func handleStrReplace(path, oldStr, newStr string) (string, error) {
+	fileContent, err := readFile(path)
+	if err != nil {
+		return "", err
+	}
+
+	// Tab-expansion hook (see maybeExpandTabs).
+	fileContent = maybeExpandTabs(path, fileContent)
+	oldStr = maybeExpandTabs(path, oldStr)
+	newStr = maybeExpandTabs(path, newStr)
+
+	// An empty old_str "matches" between every pair of bytes. Only in an empty
+	// file does that give one unambiguous position (the start).
+	if oldStr == "" && fileContent != "" {
+		return "", fmt.Errorf("no replacement was performed, old_str is empty but %s is not; old_str must match a unique span of the file", path)
+	}
+
+	occurrences := strings.Count(fileContent, oldStr)
+	if occurrences == 0 {
+		return "", fmt.Errorf("no replacement was performed, old_str %q did not appear verbatim in %s", oldStr, path)
+	}
+	if occurrences > 1 {
+		// Report the line on which each occurrence starts. Searching the whole
+		// content, not line by line, also locates a multi-line old_str.
+		// oldStr is non-empty here, so the scan always advances.
+		var lines []int
+		line, counted := 1, 0
+		for from := 0; ; {
+			idx := strings.Index(fileContent[from:], oldStr)
+			if idx < 0 {
+				break
+			}
+			start := from + idx
+			line += strings.Count(fileContent[counted:start], "\n")
+			counted = start
+			if len(lines) == 0 || lines[len(lines)-1] != line {
+				lines = append(lines, line)
+			}
+			from = start + len(oldStr)
+		}
+		return "", fmt.Errorf("no replacement was performed. Multiple occurrences of old_str %q in lines %v. Please ensure it is unique", oldStr, lines)
+	}
+
+	// Exactly one occurrence, at byte offset at.
+	at := strings.Index(fileContent, oldStr)
+	newFileContent := fileContent[:at] + newStr + fileContent[at+len(oldStr):]
+
+	if err := writeFile(path, newFileContent); err != nil {
+		return "", err
+	}
+	fileHistory[path] = append(fileHistory[path], fileContent)
+
+	// The snippet covers snippetLines lines before the edit through
+	// snippetLines lines after the end of the inserted text (0-based indexes).
+	replacementLine := strings.Count(fileContent[:at], "\n")
+	startLine := max(0, replacementLine-snippetLines)
+	endLine := replacementLine + snippetLines + strings.Count(newStr, "\n")
+	fileLines := strings.Split(newFileContent, "\n") // never empty
+	endLine = min(endLine+1, len(fileLines))
+	snippet := strings.Join(fileLines[startLine:endLine], "\n")
+
+	successMsg := fmt.Sprintf("The file %s has been edited. ", path)
+	successMsg += makeOutput(snippet, fmt.Sprintf("a snippet of %s", path), startLine+1)
+	successMsg += "Review the changes and make sure they are as expected. Edit the file again if necessary."
+
+	return successMsg, nil
+}
+
+// handleInsert implements the insert command. newStr is inserted after line
+// insertLine of the file; 0 inserts at the very beginning. insertLine must lie
+// in [0, number of lines].
+//
+// It returns a numbered snippet of the inserted lines, with up to
+// snippetLines original lines of context on either side. The pre-edit content
+// is pushed onto fileHistory only after the new content has been written.
+func handleInsert(path string, insertLine int, newStr string) (string, error) {
+	fileContent, err := readFile(path)
+	if err != nil {
+		return "", err
+	}
+
+	// Tab-expansion hook (see maybeExpandTabs).
+	fileContent = maybeExpandTabs(path, fileContent)
+	newStr = maybeExpandTabs(path, newStr)
+
+	fileTextLines := strings.Split(fileContent, "\n")
+	nLinesFile := len(fileTextLines)
+
+	if insertLine < 0 || insertLine > nLinesFile {
+		return "", fmt.Errorf("invalid insert_line parameter: %d. It should be within the range of lines of the file: [0, %d]",
+			insertLine, nLinesFile)
+	}
+
+	newStrLines := strings.Split(newStr, "\n")
+
+	newFileTextLines := make([]string, 0, nLinesFile+len(newStrLines))
+	newFileTextLines = append(newFileTextLines, fileTextLines[:insertLine]...)
+	newFileTextLines = append(newFileTextLines, newStrLines...)
+	newFileTextLines = append(newFileTextLines, fileTextLines[insertLine:]...)
+
+	if err := writeFile(path, strings.Join(newFileTextLines, "\n")); err != nil {
+		return "", err
+	}
+	fileHistory[path] = append(fileHistory[path], fileContent)
+
+	// Snippet: up to snippetLines original lines before and after the
+	// insertion point, around the inserted lines. The local slice is named
+	// snippetParts so the snippetLines constant stays visible.
+	snippetStart := max(0, insertLine-snippetLines)
+	snippetEnd := min(insertLine+snippetLines, nLinesFile)
+	snippetParts := make([]string, 0, snippetEnd-snippetStart+len(newStrLines))
+	snippetParts = append(snippetParts, fileTextLines[snippetStart:insertLine]...)
+	snippetParts = append(snippetParts, newStrLines...)
+	snippetParts = append(snippetParts, fileTextLines[insertLine:snippetEnd]...)
+	snippet := strings.Join(snippetParts, "\n")
+
+	successMsg := fmt.Sprintf("The file %s has been edited. ", path)
+	// snippetStart is 0-based; the first numbered line is snippetStart+1.
+	successMsg += makeOutput(snippet, "a snippet of the edited file", snippetStart+1)
+	successMsg += "Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary."
+
+	return successMsg, nil
+}
+
+// handleUndoEdit implements the undo_edit command. It writes the most recent
+// fileHistory entry for path back to disk, and returns the restored content
+// numbered like `cat -n`. The entry is removed only after the write succeeds,
+// so a failed undo can be retried.
+func handleUndoEdit(path string) (string, error) {
+	history := fileHistory[path]
+	if len(history) == 0 {
+		return "", fmt.Errorf("no edit history found for %s", path)
+	}
+
+	lastIdx := len(history) - 1
+	oldText := history[lastIdx]
+
+	if err := writeFile(path, oldText); err != nil {
+		return "", err
+	}
+	fileHistory[path] = history[:lastIdx]
+
+	return fmt.Sprintf("Last edit to %s undone successfully. %s", path, makeOutput(oldText, path, 1)), nil
+}
+
+// listDirectory lists files and directories up to 2 levels deep below path,
+// excluding hidden entries, by running `find` through executeCommand.
+//
+// path comes from the tool input and executeCommand runs a `bash -c` string.
+// The path is therefore single-quoted (embedded quotes become '\''), so
+// spaces or shell metacharacters in it cannot split the argument or inject
+// commands.
+func listDirectory(ctx context.Context, path string) (string, error) {
+	quotedPath := "'" + strings.ReplaceAll(path, "'", `'\''`) + "'"
+	cmd := fmt.Sprintf("find %s -maxdepth 2 -not -path '*/\\.*'", quotedPath)
+	output, err := executeCommand(ctx, cmd)
+	if err != nil {
+		return "", fmt.Errorf("failed to list directory: %w", err)
+	}
+
+	return fmt.Sprintf("Here's the files and directories up to 2 levels deep in %s, excluding hidden items:\n%s\n", path, output), nil
+}
+
+// executeCommand runs cmd with `bash -c` in WorkingDir(ctx). It returns the
+// combined stdout and stderr, passed through maybetruncate.
+//
+// The process is killed if ctx is done; there is no separate timeout. On
+// failure the error wraps the exec error and includes the output, also
+// truncated, so a noisy failure cannot produce an unbounded error string.
+func executeCommand(ctx context.Context, cmd string) (string, error) {
+	bash := exec.CommandContext(ctx, "bash", "-c", cmd)
+	bash.Dir = WorkingDir(ctx)
+	output, err := bash.CombinedOutput()
+	if err != nil {
+		return "", fmt.Errorf("command execution failed: %w: %s", err, maybetruncate(string(output)))
+	}
+	return maybetruncate(string(output)), nil
+}
+
+// readFile returns the entire content of the file at path as a string.
+// Errors are wrapped with %w, so callers can still test them with
+// errors.Is(err, fs.ErrNotExist) and similar checks.
+func readFile(path string) (string, error) {
+	content, err := os.ReadFile(path)
+	if err != nil {
+		return "", fmt.Errorf("failed to read file %s: %w", path, err)
+	}
+	return string(content), nil
+}
+
+// writeFile writes content to path, replacing any existing content. A new
+// file is created with mode 0644. Errors are wrapped with %w so callers can
+// inspect them with errors.Is / errors.As.
+func writeFile(path, content string) error {
+	if err := os.WriteFile(path, []byte(content), 0o644); err != nil {
+		return fmt.Errorf("failed to write to file %s: %w", path, err)
+	}
+	return nil
+}
+
+// makeOutput formats fileContent the way `cat -n` does, numbering lines from
+// initLine, under a header that names fileDescriptor. The content first goes
+// through maybetruncate, so an over-long input ends with truncatedMessage on
+// its last numbered line.
+func makeOutput(fileContent, fileDescriptor string, initLine int) string {
+	content := maybeExpandTabs(fileDescriptor, maybetruncate(fileContent))
+
+	var b strings.Builder
+	fmt.Fprintf(&b, "Here's the result of running `cat -n` on %s:\n", fileDescriptor)
+	for i, line := range strings.Split(content, "\n") {
+		if i > 0 {
+			b.WriteByte('\n')
+		}
+		fmt.Fprintf(&b, "%6d\t%s", initLine+i, line)
+	}
+	b.WriteByte('\n')
+	return b.String()
+}
+
+// maybetruncate returns content unchanged when it is at most maxResponseLen
+// bytes long. Otherwise it keeps the first maxResponseLen bytes and appends
+// truncatedMessage. If the cut would split a multi-byte UTF-8 character, it
+// backs off to the start of that character instead, so the result stays
+// valid UTF-8.
+func maybetruncate(content string) string {
+	if len(content) <= maxResponseLen {
+		return content
+	}
+	// content[cut] is the first byte dropped. UTF-8 continuation bytes have
+	// the form 10xxxxxx. If that byte is one, the character it belongs to
+	// straddles the cut, so drop its leading bytes as well. A character is at
+	// most 4 bytes, so at most 3 bytes are dropped; this also bounds the work
+	// on binary data.
+	cut := maxResponseLen
+	for back := 0; back < 3 && content[cut]&0xC0 == 0x80; back++ {
+		cut--
+	}
+	return content[:cut] + truncatedMessage
+}
+
+// maybeExpandTabs is a hook for tab expansion and currently returns s
+// unchanged; path is ignored. Anthropic's Python implementation replaces tabs
+// with spaces, but rewriting the user's indentation seems unwise for this
+// tool.
+func maybeExpandTabs(path, s string) string {
+	_ = path // unused until tab expansion is implemented
+	return s
+}
diff --git a/claudetool/edit_regression_test.go b/claudetool/edit_regression_test.go
new file mode 100644
index 0000000..cb859fe
--- /dev/null
+++ b/claudetool/edit_regression_test.go
@@ -0,0 +1,152 @@
+package claudetool
+
+import (
+ "context"
+ "encoding/json"
+ "strings"
+ "testing"
+)
+
+// TestEmptyContentHandling sends str_replace through EditRun against an empty
+// file. It guards against a past "index out of range [0]" panic. The call must
+// return an error saying old_str was not found, not panic.
+func TestEmptyContentHandling(t *testing.T) {
+	emptyFile := setupTestFile(t, "")
+
+	inputJSON, err := json.Marshal(map[string]any{
+		"command": "str_replace",
+		"path":    emptyFile,
+		"old_str": "nonexistent text",
+		"new_str": "new content",
+	})
+	if err != nil {
+		t.Fatalf("Failed to marshal input: %v", err)
+	}
+
+	_, err = EditRun(context.Background(), inputJSON)
+	switch {
+	case err == nil:
+		t.Fatalf("Expected error for empty file with str_replace but got none")
+	case !strings.Contains(err.Error(), "did not appear verbatim"):
+		t.Errorf("Expected error message to indicate missing string, got: %s", err.Error())
+	}
+}
+
+// TestNilParameterHandling checks how EditRun treats omitted parameters:
+//   - a missing required old_str (str_replace) or new_str (insert) is an error
+//     naming that parameter;
+//   - omitting the optional view_range (view) is accepted.
+func TestNilParameterHandling(t *testing.T) {
+	testFile := setupTestFile(t, "test content")
+
+	cases := []struct {
+		input   map[string]any
+		missing string // required parameter left out; "" when the call must succeed
+	}{
+		{map[string]any{"command": "str_replace", "path": testFile, "new_str": "replacement"}, "old_str"},
+		{map[string]any{"command": "insert", "path": testFile, "insert_line": 1}, "new_str"},
+		{map[string]any{"command": "view", "path": testFile}, ""}, // no view_range
+	}
+
+	for _, c := range cases {
+		inputJSON, err := json.Marshal(c.input)
+		if err != nil {
+			t.Fatalf("Failed to marshal input: %v", err)
+		}
+
+		_, err = EditRun(context.Background(), inputJSON)
+		if c.missing == "" {
+			if err != nil {
+				t.Fatalf("Unexpected error for nil view_range: %v", err)
+			}
+			continue
+		}
+		if err == nil {
+			t.Fatalf("Expected error for missing %s but got none", c.missing)
+		}
+		if !strings.Contains(err.Error(), "parameter "+c.missing+" is required") {
+			t.Errorf("Expected error message to indicate missing %s, got: %s", c.missing, err.Error())
+		}
+	}
+}
+
+// TestEmptySplitResult pins a strings.Split guarantee: splitting by a
+// separator that does not occur yields exactly one element, the original
+// string, even for empty input. It then repeats the split on the content of
+// an empty file read through readFile, and indexes element 0 to prove that
+// this cannot panic.
+func TestEmptySplitResult(t *testing.T) {
+	for _, tc := range []struct{ content, oldStr string }{
+		{"", "any string"},
+		{"content", "not in string"},
+		{"\n\n", "also not here"},
+	} {
+		parts := strings.Split(tc.content, tc.oldStr)
+		if n := len(parts); n != 1 {
+			t.Errorf("Expected strings.Split to return a slice with 1 element when separator isn't found, got %d elements", n)
+		}
+		if len(parts) > 0 && parts[0] != tc.content {
+			t.Errorf("Expected parts[0] to be original content %q, got %q", tc.content, parts[0])
+		}
+	}
+
+	// The read error is deliberately ignored: the file was just created.
+	content, _ := readFile(setupTestFile(t, ""))
+	parts := strings.Split(content, "nonexistent")
+	if len(parts) == 0 {
+		parts = []string{""} // defensive fallback
+	}
+	_ = strings.Count(parts[0], "\n") // must not panic
+}
diff --git a/claudetool/edit_test.go b/claudetool/edit_test.go
new file mode 100644
index 0000000..fe3d66c
--- /dev/null
+++ b/claudetool/edit_test.go
@@ -0,0 +1,399 @@
+package claudetool
+
+import (
+ "context"
+ "encoding/json"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+)
+
+// setupTestFile writes content to a file named test_file.txt in a fresh
+// per-test temporary directory, and returns the file's path. t.TempDir
+// removes the directory automatically when the test finishes, which replaces
+// the hand-rolled MkdirTemp/RemoveAll/Cleanup sequence.
+func setupTestFile(t *testing.T, content string) string {
+	t.Helper()
+
+	testFile := filepath.Join(t.TempDir(), "test_file.txt")
+	if err := os.WriteFile(testFile, []byte(content), 0o644); err != nil {
+		t.Fatalf("Failed to write test file: %v", err)
+	}
+	return testFile
+}
+
+// callEditTool marshals input to JSON, runs it through EditRun, and returns
+// the tool's result. Any marshalling or tool error fails the test immediately.
+func callEditTool(t *testing.T, input map[string]any) string {
+	t.Helper()
+
+	raw, err := json.Marshal(input)
+	if err != nil {
+		t.Fatalf("Failed to marshal input: %v", err)
+	}
+
+	out, err := EditRun(context.Background(), raw)
+	if err != nil {
+		t.Fatalf("Tool execution failed: %v", err)
+	}
+	return out
+}
+
+// TestEditToolView tests the view command:
+//   - a plain view shows every line of the file;
+//   - a view_range of [2, 4] shows lines 2 through 4 and nothing outside them.
+// The original range check only tested the lower bound.
+func TestEditToolView(t *testing.T) {
+	content := "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
+	testFile := setupTestFile(t, content)
+
+	result := callEditTool(t, map[string]any{
+		"command": "view",
+		"path":    testFile,
+	})
+	for _, want := range []string{"Line 1", "Line 5"} {
+		if !strings.Contains(result, want) {
+			t.Errorf("View result should contain the file content, got: %s", result)
+		}
+	}
+
+	result = callEditTool(t, map[string]any{
+		"command":    "view",
+		"path":       testFile,
+		"view_range": []int{2, 4},
+	})
+	for _, want := range []string{"Line 2", "Line 3", "Line 4"} {
+		if !strings.Contains(result, want) {
+			t.Errorf("View with range should include %q, got: %s", want, result)
+		}
+	}
+	for _, unwanted := range []string{"Line 1", "Line 5"} {
+		if strings.Contains(result, unwanted) {
+			t.Errorf("View with range should not include %q, got: %s", unwanted, result)
+		}
+	}
+}
+
+// TestEditToolStrReplace checks that str_replace rewrites the file and echoes
+// the edited region back in its result.
+func TestEditToolStrReplace(t *testing.T) {
+	testFile := setupTestFile(t, "Line 1\nLine 2\nLine 3\nLine 4\nLine 5")
+
+	result := callEditTool(t, map[string]any{
+		"command": "str_replace",
+		"path":    testFile,
+		"old_str": "Line 3",
+		"new_str": "Modified Line 3",
+	})
+
+	raw, err := os.ReadFile(testFile)
+	if err != nil {
+		t.Fatalf("Failed to read test file: %v", err)
+	}
+	if onDisk := string(raw); !strings.Contains(onDisk, "Modified Line 3") {
+		t.Errorf("File content should be modified, got: %s", onDisk)
+	}
+
+	if !strings.Contains(result, "Modified Line 3") {
+		t.Errorf("Result should contain the modified content, got: %s", result)
+	}
+}
+
+// TestEditToolInsert tests the insert command functionality
+func TestEditToolInsert(t *testing.T) {
+ content := "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
+ testFile := setupTestFile(t, content)
+
+ // Test the insert command
+ result := callEditTool(t, map[string]any{
+ "command": "insert",
+ "path": testFile,
+ "insert_line": 2,
+ "new_str": "Inserted Line",
+ })
+
+ // Verify the file was modified
+ modifiedContent, err := os.ReadFile(testFile)
+ if err != nil {
+ t.Fatalf("Failed to read test file: %v", err)
+ }
+
+ expected := "Line 1\nLine 2\nInserted Line\nLine 3\nLine 4\nLine 5"
+ if string(modifiedContent) != expected {
+ t.Errorf("File content incorrect after insert. Expected:\n%s\nGot:\n%s", expected, string(modifiedContent))
+ }
+
+ // Verify the result contains a snippet
+ if !strings.Contains(result, "Inserted Line") {
+ t.Errorf("Result should contain the inserted content, got: %s", result)
+ }
+}
+
+// TestEditToolCreate checks that create writes a new file containing exactly
+// the given text, and reports success.
+func TestEditToolCreate(t *testing.T) {
+	// t.TempDir is removed automatically when the test ends.
+	newFilePath := filepath.Join(t.TempDir(), "new_file.txt")
+	content := "This is a new file\nWith multiple lines"
+
+	result := callEditTool(t, map[string]any{
+		"command":   "create",
+		"path":      newFilePath,
+		"file_text": content,
+	})
+
+	createdContent, err := os.ReadFile(newFilePath)
+	if err != nil {
+		t.Fatalf("Failed to read created file: %v", err)
+	}
+	if string(createdContent) != content {
+		t.Errorf("Created file content incorrect. Expected:\n%s\nGot:\n%s", content, string(createdContent))
+	}
+
+	if !strings.Contains(result, "File created successfully") {
+		t.Errorf("Result should confirm file creation, got: %s", result)
+	}
+}
+
+// TestEditToolUndoEdit tests the undo_edit command functionality
+func TestEditToolUndoEdit(t *testing.T) {
+ originalContent := "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
+ testFile := setupTestFile(t, originalContent)
+
+ // First modify the file
+ callEditTool(t, map[string]any{
+ "command": "str_replace",
+ "path": testFile,
+ "old_str": "Line 3",
+ "new_str": "Modified Line 3",
+ })
+
+ // Then undo the edit
+ result := callEditTool(t, map[string]any{
+ "command": "undo_edit",
+ "path": testFile,
+ })
+
+ // Verify the file was restored to original content
+ restoredContent, err := os.ReadFile(testFile)
+ if err != nil {
+ t.Fatalf("Failed to read test file: %v", err)
+ }
+
+ if string(restoredContent) != originalContent {
+ t.Errorf("File content should be restored to original, got: %s", string(restoredContent))
+ }
+
+ // Verify the result message
+ if !strings.Contains(result, "undone successfully") {
+ t.Errorf("Result should confirm undo operation, got: %s", result)
+ }
+}
+
+// TestEditToolErrors feeds EditRun a table of invalid requests and checks that
+// each is rejected with an error containing the expected text.
+func TestEditToolErrors(t *testing.T) {
+ path := setupTestFile(t, "Line 1\nLine 2\nLine 3\nLine 4\nLine 5")
+
+ cases := []struct {
+ name string
+ input map[string]any
+ want string // substring the error message must contain
+ }{
+ {
+ name: "Invalid command",
+ input: map[string]any{
+ "command": "invalid_command",
+ "path": path,
+ },
+ want: "unrecognized command",
+ },
+ {
+ name: "Non-existent file",
+ input: map[string]any{
+ "command": "view",
+ "path": "/non/existent/file.txt",
+ },
+ want: "does not exist",
+ },
+ {
+ name: "Missing required parameter",
+ input: map[string]any{
+ "command": "str_replace",
+ "path": path,
+ // Missing old_str
+ },
+ want: "parameter old_str is required",
+ },
+ {
+ name: "Multiple occurrences in str_replace",
+ input: map[string]any{
+ "command": "str_replace",
+ "path": path,
+ "old_str": "Line", // Appears multiple times
+ "new_str": "Modified Line",
+ },
+ want: "Multiple occurrences",
+ },
+ {
+ name: "Invalid view range",
+ input: map[string]any{
+ "command": "view",
+ "path": path,
+ "view_range": []int{10, 20}, // Out of range
+ },
+ want: "invalid view_range",
+ },
+ }
+
+ for _, c := range cases {
+ t.Run(c.name, func(t *testing.T) {
+ raw, err := json.Marshal(c.input)
+ if err != nil {
+ t.Fatalf("Failed to marshal input: %v", err)
+ }
+
+ _, runErr := EditRun(context.Background(), raw)
+ // Every case must fail, and the error must say why.
+ switch {
+ case runErr == nil:
+ t.Fatalf("Expected error but got none")
+ case !strings.Contains(runErr.Error(), c.want):
+ t.Errorf("Error message does not contain expected text. Expected to contain: %q, Got: %q", c.want, runErr.Error())
+ }
+ })
+ }
+}
+
+// TestHandleStrReplaceEdgeCases tests the handleStrReplace function specifically for edge cases
+// that could cause panics like "index out of range [0] with length 0"
+func TestHandleStrReplaceEdgeCases(t *testing.T) {
+ // The issue was with strings.Split returning an empty slice when the separator wasn't found
+ // This test directly tests the internal implementation with conditions that might cause this
+
+ // Create a test file with empty content
+ emptyFile := setupTestFile(t, "")
+
+ // Test with empty file content and arbitrary oldStr
+ _, err := handleStrReplace(emptyFile, "some string that doesn't exist", "new content")
+ if err == nil {
+ t.Fatal("Expected error for empty file but got none")
+ }
+ if !strings.Contains(err.Error(), "did not appear verbatim") {
+ t.Errorf("Expected error message to indicate missing string, got: %s", err.Error())
+ }
+
+ // Create a file with content that doesn't match oldStr
+ nonMatchingFile := setupTestFile(t, "This is some content\nthat doesn't contain the target string")
+
+ // Test with content that doesn't contain oldStr
+ _, err = handleStrReplace(nonMatchingFile, "target string not present", "replacement")
+ if err == nil {
+ t.Fatal("Expected error for non-matching content but got none")
+ }
+ if !strings.Contains(err.Error(), "did not appear verbatim") {
+ t.Errorf("Expected error message to indicate missing string, got: %s", err.Error())
+ }
+
+ // Test handling of the edge case that could potentially cause the "index out of range" panic
+ // This directly verifies that the handleStrReplace function properly handles the case where
+ // strings.Split returns an empty or unexpected result
+
+ // Verify that the protection against empty parts slice works
+ fileContent := ""
+ oldStr := "some string"
+ parts := strings.Split(fileContent, oldStr)
+ if len(parts) == 0 {
+ // This should match the protection in the code
+ parts = []string{""}
+ }
+
+ // This should not panic with the fix in place
+ _ = strings.Count(parts[0], "\n") // This line would have panicked without the fix
+}
+
+// TestViewRangeWithStrReplace tests that the view_range parameter works correctly
+// with the str_replace command (tests the full workflow)
+func TestViewRangeWithStrReplace(t *testing.T) {
+ // Create test file with multiple lines
+ content := "Line 1: First line\nLine 2: Second line\nLine 3: Third line\nLine 4: Fourth line\nLine 5: Fifth line"
+ testFile := setupTestFile(t, content)
+
+ // First view a subset of the file using view_range
+ viewResult := callEditTool(t, map[string]any{
+ "command": "view",
+ "path": testFile,
+ "view_range": []int{2, 4}, // Only lines 2-4
+ })
+
+ // Verify that we only see the specified lines
+ if strings.Contains(viewResult, "Line 1:") || strings.Contains(viewResult, "Line 5:") {
+ t.Errorf("View with range should only show lines 2-4, got: %s", viewResult)
+ }
+ if !strings.Contains(viewResult, "Line 2:") || !strings.Contains(viewResult, "Line 4:") {
+ t.Errorf("View with range should show lines 2-4, got: %s", viewResult)
+ }
+
+ // Now perform a str_replace on one of the lines we viewed
+ replaceResult := callEditTool(t, map[string]any{
+ "command": "str_replace",
+ "path": testFile,
+ "old_str": "Line 3: Third line",
+ "new_str": "Line 3: MODIFIED Third line",
+ })
+
+ // Check that the replacement was successful
+ if !strings.Contains(replaceResult, "Line 3: MODIFIED Third line") {
+ t.Errorf("Replace result should contain the modified line, got: %s", replaceResult)
+ }
+
+ // Verify the file content was updated correctly
+ modifiedContent, err := os.ReadFile(testFile)
+ if err != nil {
+ t.Fatalf("Failed to read test file after modification: %v", err)
+ }
+
+ expectedContent := "Line 1: First line\nLine 2: Second line\nLine 3: MODIFIED Third line\nLine 4: Fourth line\nLine 5: Fifth line"
+ if string(modifiedContent) != expectedContent {
+ t.Errorf("File content after replacement is incorrect.\nExpected:\n%s\nGot:\n%s",
+ expectedContent, string(modifiedContent))
+ }
+
+ // View the modified file with a different view_range
+ finalViewResult := callEditTool(t, map[string]any{
+ "command": "view",
+ "path": testFile,
+ "view_range": []int{3, 3}, // Only the modified line
+ })
+
+ // Verify we can see only the modified line
+ if !strings.Contains(finalViewResult, "Line 3: MODIFIED Third line") {
+ t.Errorf("Final view should show the modified line, got: %s", finalViewResult)
+ }
+ if strings.Contains(finalViewResult, "Line 2:") || strings.Contains(finalViewResult, "Line 4:") {
+ t.Errorf("Final view should only show line 3, got: %s", finalViewResult)
+ }
+}
diff --git a/claudetool/editbuf/LICENSE b/claudetool/editbuf/LICENSE
new file mode 100644
index 0000000..ea5ea89
--- /dev/null
+++ b/claudetool/editbuf/LICENSE
@@ -0,0 +1,27 @@
+Copyright (c) 2009 The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
diff --git a/claudetool/editbuf/editbuf.go b/claudetool/editbuf/editbuf.go
new file mode 100644
index 0000000..6b04310
--- /dev/null
+++ b/claudetool/editbuf/editbuf.go
@@ -0,0 +1,92 @@
+// Modified from rsc.io/edit
+
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package editbuf implements buffered position-based editing of byte slices.
+package editbuf
+
+import (
+ "fmt"
+ "sort"
+)
+
+// A Buffer is a queue of edits to apply to a given byte slice.
+type Buffer struct {
+ old []byte // the original data; Buffer's methods read it but never write to it
+ q edits // queued edits; Bytes sorts them in place before applying them
+}
+
+// An edit records a single text modification: change the bytes in [start,end) to new.
+type edit struct {
+ start int // offset into the original data where the change begins (inclusive)
+ end int // offset where the change ends (exclusive); equal to start for an insertion
+ new string // replacement text; empty for a deletion
+}
+
+// edits implements sort.Interface, ordering edits by start offset and then by end offset.
+type edits []edit
+
+func (x edits) Len() int { return len(x) }
+func (x edits) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
+func (x edits) Less(i, j int) bool {
+ if x[i].start == x[j].start {
+ return x[i].end < x[j].end
+ }
+ return x[i].start < x[j].start
+}
+
+// NewBuffer returns a Buffer queuing edits against old. It retains old itself,
+// not a copy, so callers must not modify old until they are done with the Buffer.
+func NewBuffer(old []byte) *Buffer {
+ b := Buffer{old: old}
+ return &b
+}
+
+// Insert queues the insertion of new at old[pos:pos].
+// It panics unless 0 <= pos <= len(old).
+func (b *Buffer) Insert(pos int, new string) {
+ // An insertion is a replacement of the empty range [pos, pos); for that range
+ // Replace's bounds check reduces to exactly pos < 0 || pos > len(b.old).
+ b.Replace(pos, pos, new)
+}
+
+// Delete queues the removal of old[start:end]. It is a Replace with empty
+// replacement text, so it panics unless 0 <= start <= end <= len(old).
+func (b *Buffer) Delete(start, end int) {
+ // Replace applies exactly the bounds check (and panic message) Delete
+ // always had, then queues edit{start, end, ""}.
+ b.Replace(start, end, "")
+}
+
+// Replace queues old[start:end] to be replaced by new; it panics unless 0 <= start <= end <= len(old).
+func (b *Buffer) Replace(start, end int, new string) {
+ if !(0 <= start && start <= end && end <= len(b.old)) {
+ panic("invalid edit position")
+ }
+ b.q = append(b.q, edit{start: start, end: end, new: new})
+}
+
+// Bytes returns a new byte slice holding the original data with every queued
+// edit applied; the original slice is not modified. Edits are first sorted
+// (stably) by start offset, then end offset, so an insertion at x is applied
+// before a replacement of [x, y). Overlapping edits produce an error.
+func (b *Buffer) Bytes() ([]byte, error) {
+ sort.Stable(b.q)
+
+ var out []byte
+ pos := 0 // b.old[:pos] has already been consumed into out
+ for i := range b.q {
+ cur := b.q[i]
+ // Edits are sorted and earlier ones did not overlap, so any overlap is with the previous edit.
+ if cur.start < pos {
+ prev := b.q[i-1]
+ return nil, fmt.Errorf("overlapping edits: [%d,%d)->%q, [%d,%d)->%q", prev.start, prev.end, prev.new, cur.start, cur.end, cur.new)
+ }
+ out = append(out, b.old[pos:cur.start]...)
+ out = append(out, cur.new...)
+ pos = cur.end
+ }
+ return append(out, b.old[pos:]...), nil
+}
diff --git a/claudetool/keyword.go b/claudetool/keyword.go
new file mode 100644
index 0000000..2438275
--- /dev/null
+++ b/claudetool/keyword.go
@@ -0,0 +1,175 @@
+package claudetool
+
+import (
+ "context"
+ _ "embed"
+ "encoding/json"
+ "fmt"
+ "log/slog"
+ "os/exec"
+ "strings"
+
+ "sketch.dev/ant"
+)
+
+// Keyword is the keyword_search tool: ripgrep search whose hits a sub-conversation filters for relevance.
+// TODO: use an embedding model + re-ranker or otherwise do something nicer than this kludge.
+// TODO: if we can get this fast enough, do it on the fly while the user is typing their prompt.
+var Keyword = &ant.Tool{
+ Name: keywordName,
+ Description: strings.TrimSpace(keywordDescription), // trimmed like Bash and Patch; the raw const starts with "\n"
+ InputSchema: ant.MustSchema(keywordInputSchema),
+ Run: keywordRun,
+}
+
+const (
+ keywordName = "keyword_search"
+ keywordDescription = `
+keyword_search locates files with a search-and-filter approach.
+Use when navigating unfamiliar codebases with only conceptual understanding or vague user questions.
+
+Effective use:
+- Provide a detailed query for accurate relevance ranking
+- Include extensive but uncommon keywords to ensure comprehensive results
+- Order keywords by importance (most important first) - less important keywords may be dropped if there are too many results
+
+IMPORTANT: Do NOT use this tool if you have precise information like log lines, error messages, filenames, symbols, or package names. Use direct approaches (grep, cat, go doc, etc.) instead.
+`
+
+ // keywordInputSchema is the JSON Schema for keywordInput. If you modify it, update the termui template for prettier rendering.
+ keywordInputSchema = `
+{
+ "type": "object",
+ "required": [
+ "query",
+ "keywords"
+ ],
+ "properties": {
+ "query": {
+ "type": "string",
+ "description": "A detailed statement of what you're trying to find or learn."
+ },
+ "keywords": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ },
+ "description": "List of keywords in descending order of importance."
+ }
+ }
+}
+`
+)
+
+type keywordInput struct {
+ Query string `json:"query"` // what the caller is trying to find or learn; forwarded to the relevance filter
+ Keywords []string `json:"keywords"` // ripgrep terms, most important first; trailing ones may be dropped
+}
+
+//go:embed keyword_system_prompt.txt
+var keywordSystemPrompt string // system prompt for keywordRun's relevance-filtering sub-conversation
+
+// findRepoRoot returns the top-level directory of the git work tree that contains wd.
+func findRepoRoot(wd string) (string, error) {
+ git := exec.Command("git", "rev-parse", "--show-toplevel")
+ git.Dir = wd
+ // todo: cwd here and throughout
+ stdout, err := git.Output()
+ if err != nil {
+ return "", fmt.Errorf("failed to find git repository root: %w", err)
+ }
+ return strings.TrimSpace(string(stdout)), nil
+}
+
+// keywordRun implements keyword_search. It runs ripgrep from the repository root
+// (or the working directory if that fails), drops keywords that match too much,
+// and asks a sub-conversation to pick the files relevant to input.Query.
+func keywordRun(ctx context.Context, m json.RawMessage) (string, error) {
+ var input keywordInput
+ if err := json.Unmarshal(m, &input); err != nil {
+ return "", err
+ }
+ wd := WorkingDir(ctx)
+ root, err := findRepoRoot(wd)
+ if err == nil {
+ wd = root
+ }
+ slog.InfoContext(ctx, "keyword search input", "query", input.Query, "keywords", input.Keywords, "wd", wd)
+
+ // First drop "stopwords": terms that alone produce more than 64KiB of results.
+ var keep []string
+ for _, term := range input.Keywords {
+ out, err := ripgrep(ctx, wd, []string{term})
+ if err != nil {
+ return "", err
+ }
+ if len(out) > 64*1024 {
+ slog.InfoContext(ctx, "keyword search result too large", "term", term, "bytes", len(out))
+ continue
+ }
+ keep = append(keep, term)
+ }
+
+ // Then peel off the least important (last) keywords until the combined
+ // results fit in the query window (under 128KiB).
+ var out string
+ for {
+ // Nothing left to search with (no keywords given, or all too common):
+ // rg cannot run without a pattern, so return an actionable error instead.
+ if len(keep) == 0 {
+ return "", fmt.Errorf("no usable keywords: none were given or all matched too much of the codebase; retry with more specific, less common keywords")
+ }
+ out, err = ripgrep(ctx, wd, keep)
+ if err != nil {
+ return "", err
+ }
+ if len(out) < 128*1024 {
+ break
+ }
+ keep = keep[:len(keep)-1]
+ }
+
+ info := ant.ToolCallInfoFromContext(ctx)
+ convo := info.Convo.SubConvo()
+ convo.SystemPrompt = strings.TrimSpace(keywordSystemPrompt)
+ initialMessage := ant.Message{
+ Role: ant.MessageRoleUser,
+ Content: []ant.Content{
+ ant.StringContent("<pwd>\n" + wd + "\n</pwd>"),
+ ant.StringContent("<ripgrep_results>\n" + out + "\n</ripgrep_results>"),
+ ant.StringContent("<query>\n" + input.Query + "\n</query>"),
+ },
+ }
+ resp, err := convo.SendMessage(initialMessage)
+ if err != nil {
+ return "", fmt.Errorf("failed to send relevance filtering message: %w", err)
+ }
+ if len(resp.Content) != 1 {
+ return "", fmt.Errorf("unexpected number of messages in relevance filtering response: %d", len(resp.Content))
+ }
+ filtered := resp.Content[0].Text
+ slog.InfoContext(ctx, "keyword search results processed",
+ "bytes", len(out), "lines", strings.Count(out, "\n"), "files", strings.Count(out, "\n\n"),
+ "query", input.Query, "filtered", filtered,
+ )
+ return filtered, nil
+}
+
+// ripgrep runs a case-insensitive rg search for any of terms in wd, with 10 lines of context,
+// file names, and line numbers. rg's "no matches" exit status 1 yields "no matches found", not an error.
+func ripgrep(ctx context.Context, wd string, terms []string) (string, error) {
+ flags := []string{"-C", "10", "-i", "--line-number", "--with-filename"}
+ for i := range terms {
+ flags = append(flags, "-e", terms[i])
+ }
+ rg := exec.CommandContext(ctx, "rg", flags...)
+ rg.Dir = wd
+ combined, err := rg.CombinedOutput()
+ if err == nil {
+ return string(combined), nil
+ }
+ if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
+ return "no matches found", nil
+ }
+ return "", fmt.Errorf("search failed: %v\n%s", err, combined)
+}
diff --git a/claudetool/keyword_system_prompt.txt b/claudetool/keyword_system_prompt.txt
new file mode 100644
index 0000000..ac37acd
--- /dev/null
+++ b/claudetool/keyword_system_prompt.txt
@@ -0,0 +1,28 @@
+You are a code search relevance evaluator. Your task is to analyze ripgrep results and determine which files are most relevant to the user's query.
+
+INPUT FORMAT:
+- You will receive ripgrep output containing file matches for keywords with 10 lines of context
+- The working directory, the ripgrep output, and the original search query arrive wrapped in <pwd>, <ripgrep_results>, and <query> tags, in that order
+
+ANALYSIS INSTRUCTIONS:
+1. Examine each file match and its surrounding context
+2. Evaluate relevance to the query based on:
+ - Direct relevance to concepts in the query
+ - Implementation of functionality described in the query
+ - Evidence of patterns or systems related to the query
+3. Exercise strict judgment - only return files that are genuinely relevant
+
+OUTPUT FORMAT:
+Respond with a plain text list of the most relevant files in decreasing order of relevance:
+
+/path/to/most/relevant/file: Concise relevance explanation
+/path/to/second/file: Concise relevance explanation
+...
+
+IMPORTANT:
+- Only include files with meaningful relevance to the query
+- Keep it short, don't blather
+- Do NOT list all files that had keyword matches
+- Focus on quality over quantity
+- If no files are truly relevant, return "No relevant files found"
+- Use absolute file paths
diff --git a/claudetool/patch.go b/claudetool/patch.go
new file mode 100644
index 0000000..9254319
--- /dev/null
+++ b/claudetool/patch.go
@@ -0,0 +1,307 @@
+package claudetool
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "go/parser"
+ "go/token"
+ "log/slog"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "sketch.dev/ant"
+ "sketch.dev/claudetool/editbuf"
+ "sketch.dev/claudetool/patchkit"
+)
+
+// Patch is the tool for precise text edits (replace, append, prepend, overwrite) to a single file.
+var Patch = &ant.Tool{
+ Name: PatchName,
+ Run: PatchRun,
+ Description: strings.TrimSpace(PatchDescription),
+ InputSchema: ant.MustSchema(PatchInputSchema),
+}
+
+const (
+ PatchName = "patch"
+ PatchDescription = `
+File modification tool for precise text edits.
+
+Operations:
+- replace: Substitute text with new content
+- append_eof: Append new text at the end of the file
+- prepend_bof: Insert new text at the beginning of the file
+- overwrite: Replace the entire file with new content (automatically creates the file)
+
+Usage notes:
+- All inputs are interpreted literally (no automatic newline or whitespace handling)
+- For replace operations, oldText must appear EXACTLY ONCE in the file
+`
+
+ // PatchInputSchema is the JSON Schema for patchInput. If you modify it, update the termui template for prettier rendering.
+ PatchInputSchema = `
+{
+ "type": "object",
+ "required": ["path", "patches"],
+ "properties": {
+ "path": {
+ "type": "string",
+ "description": "Absolute path to the file to patch"
+ },
+ "patches": {
+ "type": "array",
+ "description": "List of patch requests to apply",
+ "items": {
+ "type": "object",
+ "required": ["operation", "newText"],
+ "properties": {
+ "operation": {
+ "type": "string",
+ "enum": ["replace", "append_eof", "prepend_bof", "overwrite"],
+ "description": "Type of operation to perform"
+ },
+ "oldText": {
+ "type": "string",
+ "description": "Text to locate for the operation (must be unique in file, required for replace)"
+ },
+ "newText": {
+ "type": "string",
+ "description": "The new text to use (empty for deletions)"
+ }
+ }
+ }
+ }
+ }
+}
+`
+)
+
+// TODO: maybe rename patchRequest to patchOperation or patchSpec or patchPart or just patch?
+
+type patchInput struct {
+ Path string `json:"path"` // file to patch; PatchRun rejects relative paths
+ Patches []patchRequest `json:"patches"` // applied all-or-nothing; PatchRun rejects an empty list
+}
+
+type patchRequest struct {
+ Operation string `json:"operation"` // one of "replace", "append_eof", "prepend_bof", "overwrite"
+ OldText string `json:"oldText,omitempty"` // text to find; must be non-empty for "replace", unused by other operations
+ NewText string `json:"newText,omitempty"` // text to insert or substitute; empty means delete for "replace"
+}
+
+// PatchRun is the entry point for the patch tool. It is all-or-nothing: if any patch
+// fails to match (exactly or via patchkit's whitespace-tolerant fallbacks), nothing is written.
+func PatchRun(ctx context.Context, m json.RawMessage) (string, error) {
+ var input patchInput
+ if err := json.Unmarshal(m, &input); err != nil {
+ return "", fmt.Errorf("failed to unmarshal user_patch input: %w", err)
+ }
+
+ // Validate the input
+ if !filepath.IsAbs(input.Path) {
+ return "", fmt.Errorf("path %q is not absolute", input.Path)
+ }
+ if len(input.Patches) == 0 {
+ return "", fmt.Errorf("no patches provided")
+ }
+ // TODO: check whether the file is autogenerated, and if so, require a "force" flag to modify it.
+
+ orig, err := os.ReadFile(input.Path)
+ // If the file doesn't exist, we can still apply patches
+ // that don't require finding existing text.
+ switch {
+ case errors.Is(err, os.ErrNotExist):
+ for _, patch := range input.Patches {
+ switch patch.Operation {
+ case "prepend_bof", "append_eof", "overwrite":
+ default:
+ return "", fmt.Errorf("file %q does not exist", input.Path)
+ }
+ }
+ case err != nil:
+ return "", fmt.Errorf("failed to read file %q: %w", input.Path, err)
+ }
+
+ likelyGoFile := strings.HasSuffix(input.Path, ".go")
+ autogenerated := likelyGoFile && isAutogeneratedGoFile(orig)
+ // Insist the result parses only if the original did (or was missing/empty), so patches aren't blamed for existing errors.
+ checkParse := likelyGoFile && (len(orig) == 0 || parseGo(orig) == nil)
+
+ origStr := string(orig)
+ // Process the patches "simultaneously", minimizing them along the way.
+ // Claude generates patches that interact with each other.
+ buf := editbuf.NewBuffer(orig)
+
+ // TODO: is it better to apply the patches that apply cleanly and report on the failures?
+ // or instead have it be all-or-nothing?
+ // For now, it is all-or-nothing.
+ // TODO: when the model gets into a "cannot apply patch" cycle of doom, how do we get it unstuck?
+ // Also: how do we detect that it's in a cycle?
+ var patchErr error
+ for i, patch := range input.Patches {
+ switch patch.Operation {
+ case "prepend_bof":
+ buf.Insert(0, patch.NewText)
+ case "append_eof":
+ buf.Insert(len(orig), patch.NewText)
+ case "overwrite":
+ buf.Replace(0, len(orig), patch.NewText)
+ case "replace":
+ if patch.OldText == "" {
+ return "", fmt.Errorf("patch %d: oldText cannot be empty for %s operation", i, patch.Operation)
+ }
+
+ // Attempt to apply the patch.
+ spec, count := patchkit.Unique(origStr, patch.OldText, patch.NewText)
+ switch count {
+ case 0:
+ // no matches, maybe recoverable, continued below
+ case 1:
+ // exact match, apply
+ slog.DebugContext(ctx, "patch_applied", "method", "unique")
+ spec.ApplyToEditBuf(buf)
+ continue
+ case 2:
+ // Multiple exact matches: the looser fallbacks below can't disambiguate, so stop here.
+ patchErr = errors.Join(patchErr, fmt.Errorf("old text not unique:\n%s", patch.OldText))
+ continue
+ default:
+ // Unique reports only 0, 1, or 2 matches; anything else is a patchkit bug.
+ slog.ErrorContext(ctx, "unique returned unexpected count", "count", count)
+ patchErr = errors.Join(patchErr, fmt.Errorf("internal error"))
+ continue
+ }
+
+ // The following recovery mechanisms are heuristic.
+ // They aren't perfect, but they appear safe,
+ // and the cases they cover appear with some regularity.
+
+ // Try adjusting the whitespace prefix.
+ spec, ok := patchkit.UniqueDedent(origStr, patch.OldText, patch.NewText)
+ if ok {
+ slog.DebugContext(ctx, "patch_applied", "method", "unique_dedent")
+ spec.ApplyToEditBuf(buf)
+ continue
+ }
+
+ // Try ignoring leading/trailing whitespace in a semantically safe way.
+ spec, ok = patchkit.UniqueInValidGo(origStr, patch.OldText, patch.NewText)
+ if ok {
+ slog.DebugContext(ctx, "patch_applied", "method", "unique_in_valid_go")
+ spec.ApplyToEditBuf(buf)
+ continue
+ }
+
+ // Try ignoring semantically insignificant whitespace.
+ spec, ok = patchkit.UniqueGoTokens(origStr, patch.OldText, patch.NewText)
+ if ok {
+ slog.DebugContext(ctx, "patch_applied", "method", "unique_go_tokens")
+ spec.ApplyToEditBuf(buf)
+ continue
+ }
+
+ // Try trimming the first line of the patch, if we can do so safely.
+ spec, ok = patchkit.UniqueTrim(origStr, patch.OldText, patch.NewText)
+ if ok {
+ slog.DebugContext(ctx, "patch_applied", "method", "unique_trim")
+ spec.ApplyToEditBuf(buf)
+ continue
+ }
+
+ // No dice.
+ patchErr = errors.Join(patchErr, fmt.Errorf("old text not found:\n%s", patch.OldText))
+ default:
+ return "", fmt.Errorf("unrecognized operation %q", patch.Operation)
+ }
+ }
+
+ if patchErr != nil {
+ sendTelemetry(ctx, "patch_error", map[string]any{
+ "orig": origStr,
+ "patches": input.Patches,
+ "errors": patchErr,
+ })
+ return "", patchErr
+ }
+
+ patched, err := buf.Bytes()
+ if err != nil {
+ return "", err
+ }
+ if err := os.MkdirAll(filepath.Dir(input.Path), 0o700); err != nil {
+ return "", fmt.Errorf("failed to create directory %q: %w", filepath.Dir(input.Path), err)
+ }
+ if err := os.WriteFile(input.Path, patched, 0o600); err != nil {
+ return "", fmt.Errorf("failed to write patched contents to file %q: %w", input.Path, err)
+ }
+
+ response := new(strings.Builder)
+ fmt.Fprintf(response, "- Applied all patches\n")
+
+ if checkParse {
+ if parseErr := parseGo(patched); parseErr != nil {
+ return "", fmt.Errorf("after applying all patches, the file no longer parses:\n%w", parseErr)
+ }
+ }
+
+ if autogenerated {
+ fmt.Fprintf(response, "- WARNING: %q appears to be autogenerated. Patches were applied anyway.\n", input.Path)
+ }
+
+ // TODO: maybe report the patch result to the model, i.e. some/all of the new code after the patches and formatting.
+ return response.String(), nil
+}
+
+// parseGo returns nil if buf is syntactically valid Go source, and the parse error otherwise.
+func parseGo(buf []byte) error {
+ _, err := parser.ParseFile(token.NewFileSet(), "", buf, parser.SkipObjectResolution)
+ return err
+}
+
+// isAutogeneratedGoFile reports whether buf looks like generated Go code: either a
+// telltale marker appears anywhere in it, or a comment near the top says so.
+func isAutogeneratedGoFile(buf []byte) bool {
+ for _, marker := range autogeneratedSignals {
+ if bytes.Contains(buf, marker) {
+ return true
+ }
+ }
+
+ // https://pkg.go.dev/cmd/go#hdr-Generate_Go_files_by_processing_source says the
+ // "generated" line must appear before the first non-comment, non-blank text, but
+ // it sometimes lands after the package clause instead, so accept any comment up
+ // to the end of the imports. Comments there are not part of the real code, which
+ // keeps a generator (whose source also contains these strings) from matching.
+ f, err := parser.ParseFile(token.NewFileSet(), "x.go", buf, parser.ImportsOnly|parser.ParseComments)
+ if err != nil {
+ return false
+ }
+
+ for _, group := range f.Comments {
+ text := strings.ToLower(group.Text())
+ for _, marker := range autogeneratedHeaderSignals {
+ if strings.Contains(text, marker) {
+ return true
+ }
+ }
+ }
+ return false
+}
+
+// Markers consulted by isAutogeneratedGoFile.
+var (
+ // autogeneratedSignals mark a file as autogenerated when present anywhere in it.
+ autogeneratedSignals = [][]byte{
+ []byte("\nfunc bindataRead("), // pre-embed bindata packed file
+ }
+
+ // autogeneratedHeaderSignals mark a file as autogenerated when found in a comment at
+ // the top of the file (compared lowercase). The canonical header would match
+ // `(?m)^// Code generated .* DO NOT EDIT\.$`, but people get it wrong often enough
+ // that these looser fragments are accepted.
+ autogeneratedHeaderSignals = []string{strings.ToLower("generate"), strings.ToLower("DO NOT EDIT"), strings.ToLower("export by")}
+)
diff --git a/claudetool/patchkit/patchkit.go b/claudetool/patchkit/patchkit.go
new file mode 100644
index 0000000..c7235e4
--- /dev/null
+++ b/claudetool/patchkit/patchkit.go
@@ -0,0 +1,415 @@
+package patchkit
+
+import (
+ "fmt"
+ "go/scanner"
+ "go/token"
+ "slices"
+ "strings"
+ "unicode"
+
+ "sketch.dev/claudetool/editbuf"
+)
+
+// A Spec describes a single patch: the Len bytes of Src starting at byte
+// offset Off (which hold Old) are to be replaced with New.
+type Spec struct {
+	Off int    // byte offset in Src where the replaced region starts
+	Len int    // length in bytes of the replaced region; normally len(Old)
+	Src string // the full original text (for debugging)
+	Old string // the text being replaced
+	New string // the replacement text
+}
+
+// Unique builds a Spec that replaces the sole occurrence of needle in haystack
+// with replace. Its second result counts occurrences of needle, capped at 2:
+// 0 means none, 1 means exactly one (and the Spec is non-nil), 2 means more
+// than one. An empty needle is always reported as ambiguous (2).
+func Unique(haystack, needle, replace string) (*Spec, int) {
+	off := strings.Index(haystack, needle)
+	if off < 0 {
+		return nil, 0
+	}
+	if strings.Contains(haystack[off+len(needle):], needle) {
+		return nil, 2
+	}
+	return &Spec{Off: off, Len: len(needle), Src: haystack, Old: needle, New: replace}, 1
+}
+
+// minimize shrinks the patch to the region that actually changes by dropping
+// any prefix and suffix that Old and New have in common. Off, Len, Old, and
+// New are all adjusted in place.
+func (s *Spec) minimize() {
+	head := commonPrefixLen(s.Old, s.New)
+	oldMid, newMid := s.Old[head:], s.New[head:]
+	// The suffix is measured only after the prefix is gone, so the two
+	// trimmed regions can never overlap.
+	tail := commonSuffixLen(oldMid, newMid)
+	s.Off += head
+	s.Len -= head + tail
+	s.Old = oldMid[:len(oldMid)-tail]
+	s.New = newMid[:len(newMid)-tail]
+}
+
+// ApplyToEditBuf applies the patch to buf by replacing the byte range
+// [Off, Off+Len) with New. It minimizes s first, so s itself is modified.
+func (s *Spec) ApplyToEditBuf(buf *editbuf.Buffer) {
+	s.minimize()
+	start := s.Off
+	end := start + s.Len
+	buf.Replace(start, end, s.New)
+}
+
+// UniqueDedent is like Unique, but tolerates a needle whose lines are all
+// indented by a consistently wrong amount. It is for LLMs, which have a hard
+// time reliably reproducing uniform whitespace prefixes: for example they may
+// emit 8 spaces where every relevant haystack line has 6.
+//
+// It first finds the unique run of haystack lines that equals needle line by
+// line once each line is trimmed of surrounding whitespace. It then rewrites
+// needle so that the indentation shared by its lines after the first is
+// replaced with the indentation shared by the matched haystack lines, and
+// hands the rewritten needle to Unique.
+//
+// Unlike Unique, which returns a count of matches, UniqueDedent returns a
+// boolean indicating whether a unique match was found.
+func UniqueDedent(haystack, needle, replace string) (*Spec, bool) {
+	// TODO: this all definitely admits of some optimization
+	hayLines := slices.Collect(strings.Lines(haystack))
+	ndlLines := slices.Collect(strings.Lines(needle))
+	start := uniqueTrimmedLineMatch(hayLines, ndlLines)
+	if start == -1 || len(ndlLines) == 0 {
+		return nil, false
+	}
+
+	// The first line's leading whitespace is irrelevant, and models often
+	// skip all or part of it. Drop it from the needle, and drop the same text
+	// from replace so the two stay aligned; if replace does not begin with it,
+	// leave both untouched. Unique happily matches a needle that begins partway
+	// through a line, so the haystack's own indentation there stays in place.
+	head := ndlLines[0]
+	headText := strings.TrimLeftFunc(head, unicode.IsSpace)
+	headIndent := head[:len(head)-len(headText)]
+	if r, ok := strings.CutPrefix(replace, headIndent); ok {
+		head, replace = headText, r
+	}
+
+	// For the remaining lines, swap the needle's shared indentation for the
+	// haystack's.
+	hayIndent := commonWhitespacePrefix(hayLines[start : start+len(ndlLines)])
+	ndlIndent := commonWhitespacePrefix(ndlLines[1:])
+	var b strings.Builder
+	b.WriteString(head)
+	for _, ln := range ndlLines[1:] {
+		// Lines holding nothing but a line ending are copied unprefixed.
+		if strings.TrimRight(ln, "\n\r") != "" {
+			b.WriteString(hayIndent)
+			ln = ln[len(ndlIndent):]
+		}
+		b.WriteString(ln)
+	}
+
+	// Do a replacement with the re-indented needle.
+	spec, n := Unique(haystack, b.String(), replace)
+	if n != 1 {
+		return nil, false
+	}
+	return spec, true
+}
+
+// tok is one lexical token of Go source as produced by go/scanner.
+type tok struct {
+	pos token.Position // where the token starts in its source
+	tok token.Token    // the token kind
+	lit string         // the literal text reported by the scanner, if any
+}
+
+// String renders the token for debugging: KIND("literal") when a literal is
+// present, otherwise just the kind's own spelling (e.g. "+").
+func (t tok) String() string {
+	if t.lit != "" {
+		return fmt.Sprintf("%s(%q)", t.tok, t.lit)
+	}
+	return t.tok.String()
+}
+
+// tokenize lexes code as Go, keeping comments as tokens. It reports false as
+// soon as the scanner records any error, which usually means code is not Go.
+// Token positions are unadjusted: //line directives are not applied.
+func tokenize(code string) ([]tok, bool) {
+	src := []byte(code)
+	fset := token.NewFileSet()
+	file := fset.AddFile("", fset.Base(), len(src))
+	var sc scanner.Scanner
+	sc.Init(file, src, nil, scanner.ScanComments)
+	var out []tok
+	for {
+		pos, kind, lit := sc.Scan()
+		switch {
+		case sc.ErrorCount > 0:
+			return nil, false // invalid Go code (or not Go code at all)
+		case kind == token.EOF:
+			return out, true
+		}
+		out = append(out, tok{pos: fset.PositionFor(pos, false), tok: kind, lit: lit})
+	}
+}
+
+// tokensEqual reports whether a and b contain the same token kinds and
+// literals in the same order. Positions are ignored because the two slices
+// normally come from different sources.
+func tokensEqual(a, b []tok) bool {
+	return slices.EqualFunc(a, b, func(x, y tok) bool {
+		return x.tok == y.tok && x.lit == y.lit
+	})
+}
+
+// tokensUniqueMatch returns the index in haystack where needle occurs as a
+// contiguous run of tokens (compared with tokensEqual). It returns -1 if there
+// is no such run or more than one.
+func tokensUniqueMatch(haystack, needle []tok) int {
+	// TODO: optimize; this is a naive scan of every start position.
+	found := -1
+	for i := range haystack {
+		end := i + len(needle)
+		if end > len(haystack) {
+			break
+		}
+		if !tokensEqual(haystack[i:end], needle) {
+			continue
+		}
+		if found != -1 {
+			return -1 // multiple matches
+		}
+		found = i
+	}
+	return found
+}
+
+// UniqueGoTokens is like Unique, but tolerates any difference in insignificant
+// whitespace. needle and haystack are both lexed as Go (comments included).
+// If needle's token sequence occurs exactly once in haystack's, the matching
+// span of haystack is replaced with replace. Like UniqueDedent, it returns a
+// boolean indicating whether a unique match was found.
+//
+// It is safe (enough) because tokens are compared by kind and literal text,
+// so the needle may differ from the haystack only in the whitespace between
+// tokens, where whitespace is not semantically significant.
+func UniqueGoTokens(haystack, needle, replace string) (*Spec, bool) {
+	nt, ok := tokenize(needle)
+	if !ok || len(nt) == 0 {
+		// An empty token sequence matches anywhere, so it can never be unique.
+		return nil, false
+	}
+	ht, ok := tokenize(haystack)
+	if !ok {
+		return nil, false
+	}
+	match := tokensUniqueMatch(ht, nt)
+	if match == -1 {
+		return nil, false
+	}
+	// The matched span runs from the first matched token up to the start of
+	// the token after the last one, or to the end of haystack.
+	start := ht[match].pos.Offset
+	end := len(haystack)
+	if next := match + len(nt); next < len(ht) {
+		end = ht[next].pos.Offset
+	}
+	old := haystack[start:end]
+	// That span also covers the whitespace after the last matched token,
+	// notably the newline behind an auto-inserted semicolon. When needle does
+	// not itself end in whitespace, replace presumably doesn't either.
+	// Consuming that whitespace would then join the following line onto the
+	// replacement. If the replacement ends in a // comment, that line would be
+	// silently commented out. Leave the whitespace in the haystack instead.
+	const goSpace = " \t\r\n"
+	if strings.TrimRight(needle, goSpace) == needle {
+		old = strings.TrimRight(old, goSpace)
+	}
+	if old == "" {
+		return nil, false
+	}
+	// The token match already pins down the one location to patch. Build the
+	// Spec directly rather than searching for old again, because the same text
+	// may legitimately occur elsewhere in haystack.
+	return &Spec{Off: start, Len: len(old), Src: haystack, Old: old, New: replace}, true
+}
+
+// UniqueInValidGo is like Unique, but tolerates differences in leading and
+// trailing whitespace on each line of needle. Like UniqueDedent, it returns a
+// boolean indicating whether a unique match was found.
+//
+// needle must match a unique run of haystack lines once every line is trimmed
+// of surrounding whitespace. Those whole haystack lines are then what gets
+// replaced.
+//
+// This is safe (enough) for two reasons. haystack must lex as Go. And the
+// match is refused if any token spanning several lines (a raw string literal
+// or a block comment) touches the matched lines, so the whitespace being
+// altered is not semantically significant. In practice, this appears safe.
+func UniqueInValidGo(haystack, needle, replace string) (*Spec, bool) {
+	haystackLines := slices.Collect(strings.Lines(haystack))
+	needleLines := slices.Collect(strings.Lines(needle))
+	if len(needleLines) == 0 {
+		return nil, false
+	}
+	matchStart := uniqueTrimmedLineMatch(haystackLines, needleLines)
+	if matchStart == -1 {
+		return nil, false
+	}
+	// Only the adjusted replacement is needed: the needle itself is rebuilt
+	// from whole haystack lines below.
+	_, replace = improveNeedle(haystack, needle, replace, matchStart)
+	// matchEnd is the index just past the last matched haystack line. Do not
+	// derive it by counting newlines in needle. That count comes up one short
+	// when the match ends on a final haystack line without a trailing newline.
+	// The last line would then drop out of the replaced region and be left
+	// dangling after the replacement.
+	matchEnd := matchStart + len(needleLines)
+
+	// Ensure that none of the lines that we fuzzy-matched involve a multiline comment or string literal.
+	var s scanner.Scanner
+	fset := token.NewFileSet()
+	file := fset.AddFile("", fset.Base(), len(haystack))
+	s.Init(file, []byte(haystack), nil, scanner.ScanComments)
+	for {
+		pos, kind, lit := s.Scan()
+		if s.ErrorCount > 0 {
+			return nil, false // invalid Go code (or not Go code at all)
+		}
+		if kind == token.EOF {
+			break
+		}
+		if kind == token.SEMICOLON || !strings.Contains(lit, "\n") {
+			continue
+		}
+		// This token spans multiple lines, so imperfectly matched whitespace
+		// on those lines might be unsafe. Use the unadjusted position: a
+		// //line directive in haystack must not remap line numbers away from
+		// indices into haystackLines.
+		tokenStart := fset.PositionFor(pos, false).Line - 1 // 1-based to 0-based
+		tokenEnd := tokenStart + strings.Count(lit, "\n")
+		// Reject any overlap between [tokenStart, tokenEnd] and the closed
+		// range [matchStart, matchEnd]. Including matchEnd, the line just
+		// after the match, is deliberately conservative.
+		// TODO: leading and trailing whitespace on this token's lines is not
+		// semantically significant, so this could be relaxed.
+		if tokenStart <= matchEnd && matchStart <= tokenEnd {
+			return nil, false // this token overlaps the range we're replacing, not safe
+		}
+	}
+
+	// TODO: restore this sanity check? it's mildly expensive and i've never seen it fail.
+	// replaced := strings.Join(haystackLines[:matchStart], "") + replace + strings.Join(haystackLines[matchEnd:], "")
+	// _, err := format.Source([]byte(replaced))
+	// if err != nil {
+	// 	return nil, false
+	// }
+
+	// OK, declare this very fuzzy match (the whole matched haystack lines) to be our new needle.
+	needle = strings.Join(haystackLines[matchStart:matchEnd], "")
+	spec, count := Unique(haystack, needle, replace)
+	if count != 1 {
+		return nil, false
+	}
+	return spec, true
+}
+
+// UniqueTrim is like Unique, but first drops the first line from both needle
+// and replace when that line is identical in the two. LLMs appear to
+// particularly struggle with the first line of a patch. If the edit leaves
+// that line unchanged anyway, it can be left out of the match. UniqueTrim
+// reports whether the shortened needle occurs exactly once.
+func UniqueTrim(haystack, needle, replace string) (*Spec, bool) {
+	needleHead, needleTail, needleHasNL := strings.Cut(needle, "\n")
+	replaceHead, replaceTail, replaceHasNL := strings.Cut(replace, "\n")
+	sameHead := needleHasNL && replaceHasNL && needleHead == replaceHead
+	if !sameHead {
+		return nil, false
+	}
+	if spec, n := Unique(haystack, needleTail, replaceTail); n == 1 {
+		return spec, true
+	}
+	return nil, false
+}
+
+// uniqueTrimmedLineMatch looks for needleLines as a contiguous run within
+// haystackLines, comparing lines after trimming leading and trailing
+// whitespace. It returns the index of the run's first line when there is
+// exactly one such run, and -1 when there are none or several.
+func uniqueTrimmedLineMatch(haystackLines, needleLines []string) int {
+	// TODO: optimize
+	hay := trimSpaceAll(haystackLines)
+	want := trimSpaceAll(needleLines)
+	found := -1
+	for start := range hay {
+		end := start + len(want)
+		if end > len(hay) {
+			break
+		}
+		if slices.Equal(hay[start:end], want) {
+			if found != -1 {
+				return -1 // multiple matches
+			}
+			found = start
+		}
+	}
+	return found
+}
+
+// trimSpaceAll returns a new, non-nil slice holding each element of x with
+// leading and trailing white space removed. x itself is left unmodified.
+func trimSpaceAll(x []string) []string {
+	out := make([]string, 0, len(x))
+	for _, s := range x {
+		out = append(out, strings.TrimSpace(s))
+	}
+	return out
+}
+
+// improveNeedle adjusts needle and replacement in tandem so that needle lines
+// up better with the haystack lines starting at index matchLine:
+//   - if needle lacks a trailing newline but the last corresponding haystack
+//     line has one, a newline is appended to both;
+//   - if the first corresponding haystack line ends with needle's first line,
+//     the text before it is prepended to both. When matchLine comes from
+//     uniqueTrimmedLineMatch, that text is the line's indentation.
+//
+// needle and replacement are returned unchanged if needle is empty or if
+// matchLine does not leave room for all of needle's lines in haystack.
+func improveNeedle(haystack string, needle, replacement string, matchLine int) (string, string) {
+	// TODO: we make new slices too much
+	needleLines := slices.Collect(strings.Lines(needle))
+	if len(needleLines) == 0 {
+		return needle, replacement
+	}
+	haystackLines := slices.Collect(strings.Lines(haystack))
+	if matchLine < 0 || matchLine+len(needleLines) > len(haystackLines) {
+		// should be impossible, but just in case
+		return needle, replacement
+	}
+	// Add trailing last-line newline if needed to better match haystack.
+	last := len(needleLines) - 1
+	if !strings.HasSuffix(needle, "\n") && strings.HasSuffix(haystackLines[matchLine+last], "\n") {
+		needle += "\n"
+		replacement += "\n"
+		// Keep needleLines in sync. For a one-line needle this is also the
+		// first line, which the prefix check below compares against a
+		// newline-terminated haystack line.
+		needleLines[last] += "\n"
+	}
+	// Add leading first-line prefix if needed to better match haystack.
+	if prefix, ok := strings.CutSuffix(haystackLines[matchLine], needleLines[0]); ok {
+		needle = prefix + needle
+		replacement = prefix + replacement
+	}
+	return needle, replacement
+}
+
+// isNonSpace is the complement of unicode.IsSpace: it reports whether r is
+// not a Unicode white space character.
+func isNonSpace(r rune) bool {
+	space := unicode.IsSpace(r)
+	return !space
+}
+
+// whitespacePrefix returns the leading white space of s. A string made up
+// entirely of white space (or empty) yields "", not the whole string.
+func whitespacePrefix(s string) string {
+	i := strings.IndexFunc(s, isNonSpace)
+	if i < 0 {
+		return ""
+	}
+	return s[:i]
+}
+
+// commonWhitespacePrefix returns the longest run of leading white space
+// shared by the lines in x, somewhat flexibly:
+//   - line endings ("\n", "\r") are ignored, and lines that are empty apart
+//     from them are skipped entirely;
+//   - a line made up only of white space has an empty prefix (per
+//     whitespacePrefix) and therefore forces the result to "".
+//
+// The result always consists solely of white space, even when only one
+// non-blank line is present.
+func commonWhitespacePrefix(x []string) string {
+	var pre string
+	seeded := false
+	for _, s := range x {
+		// ignore line endings (this is just for prefixes)
+		s = strings.TrimRight(s, "\n\r")
+		if s == "" {
+			continue
+		}
+		ws := whitespacePrefix(s)
+		if !seeded {
+			// Seed from the line's indentation only, never the whole line.
+			// Otherwise a lone line such as "\t\tbar)\n" would come back
+			// whole, and callers stripping that "prefix" would delete the
+			// line's content.
+			pre, seeded = ws, true
+			continue
+		}
+		pre = pre[:commonPrefixLen(pre, ws)]
+		if pre == "" {
+			return ""
+		}
+	}
+	return pre
+}
+
+// commonPrefixLen returns the number of leading bytes that a and b share.
+// The comparison is bytewise, so the result may fall inside a multi-byte
+// UTF-8 sequence.
+// TODO: optimize, see e.g. https://go-review.googlesource.com/c/go/+/408116
+func commonPrefixLen(a, b string) int {
+	n := 0
+	for n < len(a) && n < len(b) && a[n] == b[n] {
+		n++
+	}
+	return n
+}
+
+// commonSuffixLen returns the number of trailing bytes that a and b share.
+// Like commonPrefixLen, the comparison is bytewise.
+// TODO: optimize
+func commonSuffixLen(a, b string) int {
+	i, j := len(a), len(b)
+	for i > 0 && j > 0 && a[i-1] == b[j-1] {
+		i--
+		j--
+	}
+	return len(a) - i
+}
diff --git a/claudetool/shared.go b/claudetool/shared.go
new file mode 100644
index 0000000..83048b9
--- /dev/null
+++ b/claudetool/shared.go
@@ -0,0 +1,72 @@
+// Package claudetool provides tools for Claude AI models.
+//
+// When adding, removing, or modifying tools in this package,
+// remember to update the tool display template in termui/termui.go
+// to ensure proper tool output formatting.
+package claudetool
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "log/slog"
+ "net/http"
+ "os"
+ "time"
+)
+
+// workingDirCtxKeyType is an unexported key type, so the context key below
+// cannot collide with keys defined in other packages.
+type workingDirCtxKeyType string
+
+// workingDirCtxKey is the context key under which WithWorkingDir stores the
+// working directory.
+const workingDirCtxKey workingDirCtxKeyType = "workingDir"
+
+// WithWorkingDir returns a copy of ctx that carries wd as the working
+// directory, retrievable with WorkingDir.
+func WithWorkingDir(ctx context.Context, wd string) context.Context {
+	return context.WithValue(ctx, workingDirCtxKey, wd)
+}
+
+// WorkingDir returns the working directory stored in ctx by WithWorkingDir,
+// or "" if none was stored. The empty string is a usable fallback: an
+// exec.Cmd with an empty Dir runs in the current working directory.
+func WorkingDir(ctx context.Context) string {
+	if wd, ok := ctx.Value(workingDirCtxKey).(string); ok {
+		return wd
+	}
+	return ""
+}
+
+// sendTelemetry posts data as JSON to the debug logging server named by the
+// SKETCH_TELEMETRY_ENDPOINT environment variable. It is meant for people
+// developing sketch and does nothing unless that variable is set. It is
+// best-effort: failures are logged at debug level, never returned.
+func sendTelemetry(ctx context.Context, typ string, data any) {
+	endpoint := os.Getenv("SKETCH_TELEMETRY_ENDPOINT")
+	if endpoint == "" {
+		return
+	}
+	if err := doPostTelemetry(ctx, endpoint, typ, data); err != nil {
+		slog.DebugContext(ctx, "failed to send JSON to server", "type", typ, "error", err)
+	}
+}
+
+// doPostTelemetry POSTs data, encoded as JSON, to
+// telemetryEndpoint + "/<typ>_<unix seconds>.json". It returns an error if
+// encoding fails, the request cannot be built or sent, or the server answers
+// with a non-2xx status.
+//
+// The request is bound to ctx and additionally capped at telemetryTimeout,
+// because http.DefaultClient has no timeout of its own.
+func doPostTelemetry(ctx context.Context, telemetryEndpoint, typ string, data any) error {
+	// Telemetry is best-effort; an unresponsive server must not stall the
+	// caller indefinitely.
+	const telemetryTimeout = 30 * time.Second
+
+	jsonData, err := json.Marshal(data)
+	if err != nil {
+		return fmt.Errorf("failed to marshal %#v as JSON: %w", data, err)
+	}
+	// Append the formatted path instead of using the endpoint as the format
+	// string: a '%' in the endpoint (e.g. a percent-escape) would otherwise be
+	// misread as a formatting verb and corrupt the URL.
+	url := telemetryEndpoint + fmt.Sprintf("/%s_%d.json", typ, time.Now().Unix())
+
+	ctx, cancel := context.WithTimeout(ctx, telemetryTimeout)
+	defer cancel()
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonData))
+	if err != nil {
+		return fmt.Errorf("failed to create HTTP request for %s: %w", typ, err)
+	}
+	req.Header.Set("Content-Type", "application/json")
+
+	resp, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return fmt.Errorf("failed to send %s JSON to server: %w", typ, err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode/100 != 2 {
+		return fmt.Errorf("server returned non-success status for %s: %d", typ, resp.StatusCode)
+	}
+	slog.DebugContext(ctx, "successfully sent JSON to server", "file_type", typ, "url", url)
+	return nil
+}
diff --git a/claudetool/think.go b/claudetool/think.go
new file mode 100644
index 0000000..293cc0b
--- /dev/null
+++ b/claudetool/think.go
@@ -0,0 +1,39 @@
+package claudetool
+
+import (
+ "context"
+ "encoding/json"
+
+ "sketch.dev/ant"
+)
+
+// Think is a tool that gives the model room to think out loud, take notes,
+// and form plans. Its Run function, thinkRun, has no effects beyond returning
+// a fixed acknowledgement.
+var Think = &ant.Tool{
+	Name:        thinkName,
+	Run:         thinkRun,
+	Description: thinkDescription,
+	InputSchema: ant.MustSchema(thinkInputSchema),
+}
+
+const (
+ // thinkName is the tool name used by Think.
+ thinkName = "think"
+ // thinkDescription is used as the tool's Description in Think.
+ thinkDescription = `Think out loud, take notes, form plans. Has no external effects.`
+
+ // thinkInputSchema is the JSON Schema for the tool's input: an object with a
+ // single required string property, "thoughts".
+ // If you modify this, update the termui template for prettier rendering.
+ thinkInputSchema = `
+{
+ "type": "object",
+ "required": ["thoughts"],
+ "properties": {
+ "thoughts": {
+ "type": "string",
+ "description": "The thoughts, notes, or plans to record"
+ }
+ }
+}
+`
+)
+
+// thinkRun implements the think tool. It ignores both its context and its
+// input, and always succeeds with the fixed acknowledgement "recorded".
+func thinkRun(_ context.Context, _ json.RawMessage) (string, error) {
+	const ack = "recorded"
+	return ack, nil
+}