claudetool: add "related files" detection to codereview tool
diff --git a/claudetool/differential.go b/claudetool/differential.go
index 014fc13..710a09d 100644
--- a/claudetool/differential.go
+++ b/claudetool/differential.go
@@ -2,6 +2,7 @@
import (
"bytes"
+ "cmp"
"context"
"encoding/json"
"fmt"
@@ -12,8 +13,10 @@
"os"
"os/exec"
"path/filepath"
+ "runtime"
"slices"
"strings"
+ "sync"
"time"
"golang.org/x/tools/go/packages"
@@ -78,7 +81,20 @@
}
allPkgList := slices.Collect(maps.Keys(allPkgs))
- var msgs []string
+ var errorMessages []string // problems we want the model to address
+ var infoMessages []string // info the model should consider
+
+ // Find potentially related files that should also be considered
+ // TODO: add some caching here, since this depends only on the initial commit and the changed files, not the details of the changes
+ relatedFiles, err := r.findRelatedFiles(ctx, changedFiles)
+ if err != nil {
+ slog.DebugContext(ctx, "CodeReviewer.Run: failed to find related files", "err", err)
+ } else {
+ relatedMsg := r.formatRelatedFiles(relatedFiles)
+ if relatedMsg != "" {
+ infoMessages = append(infoMessages, relatedMsg)
+ }
+ }
testMsg, err := r.checkTests(ctx, allPkgList)
if err != nil {
@@ -86,7 +102,7 @@
return "", err
}
if testMsg != "" {
- msgs = append(msgs, testMsg)
+ errorMessages = append(errorMessages, testMsg)
}
goplsMsg, err := r.checkGopls(ctx, changedFiles) // includes vet checks
@@ -95,17 +111,24 @@
return "", err
}
if goplsMsg != "" {
- msgs = append(msgs, goplsMsg)
+ errorMessages = append(errorMessages, goplsMsg)
}
- if len(msgs) == 0 {
- slog.DebugContext(ctx, "CodeReviewer.Run: no issues found")
- return "OK", nil
+ buf := new(strings.Builder)
+ if len(infoMessages) > 0 {
+ buf.WriteString("# Info\n\n")
+ buf.WriteString(strings.Join(infoMessages, "\n\n"))
+ buf.WriteString("\n\n")
}
-
- msgs = append(msgs, "Please fix before proceeding.")
- slog.DebugContext(ctx, "CodeReviewer.Run: found issues", "issues", msgs)
- return strings.Join(msgs, "\n\n"), nil
+ if len(errorMessages) > 0 {
+ buf.WriteString("# Errors\n\n")
+ buf.WriteString(strings.Join(errorMessages, "\n\n"))
+ buf.WriteString("\n\nPlease fix before proceeding.\n")
+ }
+ if buf.Len() == 0 {
+ buf.WriteString("OK")
+ }
+ return buf.String(), nil
}
func (r *CodeReviewer) initializeInitialCommitWorktree(ctx context.Context) error {
@@ -845,3 +868,213 @@
return buf.String()
}
+
+// RelatedFile represents a file historically related to the changed files.
+type RelatedFile struct {
+	Path        string  // Path to the file as reported by git (slash-separated, relative to the repository root)
+	Correlation float64 // Correlation score (0.0-1.0)
+}
+
+// findRelatedFiles identifies files that are historically related to the changed files
+// by analyzing git commit history for co-occurrences.
+//
+// Every commit reachable from any ref (git rev-list --all) that touches at least
+// one of changedFiles is inspected. Each file in such a commit is credited with
+// the number of changedFiles that commit touched, and the resulting scores are
+// normalized by the highest score seen. The changed files themselves are
+// excluded, files scoring below 0.1 are dropped, and at most len(changedFiles)
+// files are returned, highest correlation first (ties broken by path, so the
+// result is deterministic).
+//
+// changedFiles must be paths under r.repoRoot. It returns (nil, nil) when there
+// is no relevant history, and ctx.Err() if ctx is done before every commit has
+// been inspected. Commits whose files cannot be listed are logged and skipped.
+func (r *CodeReviewer) findRelatedFiles(ctx context.Context, changedFiles []string) ([]RelatedFile, error) {
+	commits, err := r.getCommitsTouchingFiles(ctx, changedFiles)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get commits touching files: %w", err)
+	}
+	if len(commits) == 0 {
+		return nil, nil
+	}
+
+	// git reports slash-separated paths relative to the repository root;
+	// normalize the inputs the same way so the lookups below match.
+	relChanged := make(map[string]bool, len(changedFiles))
+	for _, file := range changedFiles {
+		rel, err := filepath.Rel(r.repoRoot, file)
+		if err != nil {
+			return nil, fmt.Errorf("failed to get relative path for %s: %w", file, err)
+		}
+		relChanged[filepath.ToSlash(rel)] = true
+	}
+
+	// historyFiles accumulates a co-occurrence score per file; guarded by historyMu.
+	historyFiles := make(map[string]int)
+	var historyMu sync.Mutex
+
+	// Each commit costs one git subprocess; bound how many run at once.
+	semaphore := make(chan struct{}, runtime.GOMAXPROCS(0))
+	var wg sync.WaitGroup
+
+dispatch:
+	for _, commit := range commits {
+		select {
+		case semaphore <- struct{}{}: // acquire
+		case <-ctx.Done():
+			// Don't spawn more git processes; they would fail immediately.
+			break dispatch
+		}
+		wg.Add(1)
+		go func(commit string) {
+			defer wg.Done()
+			defer func() { <-semaphore }() // release
+			commitFiles, err := r.getFilesInCommit(ctx, commit)
+			if err != nil {
+				// Failures caused by cancellation are reported once, after wg.Wait.
+				if ctx.Err() == nil {
+					slog.WarnContext(ctx, "Failed to get files in commit", "commit", commit, "err", err)
+				}
+				return
+			}
+			// Weight the commit by how many of the changed files it touched.
+			incr := 0
+			for _, file := range commitFiles {
+				if relChanged[file] {
+					incr++
+				}
+			}
+			if incr == 0 {
+				return
+			}
+			historyMu.Lock()
+			defer historyMu.Unlock()
+			for _, file := range commitFiles {
+				historyFiles[file] += incr
+			}
+		}(commit)
+	}
+	wg.Wait()
+	if err := ctx.Err(); err != nil {
+		// The scores would reflect only part of the history; don't report them.
+		return nil, err
+	}
+
+	// Normalize scores relative to the best-scoring file.
+	maxCount := 0
+	for _, count := range historyFiles {
+		maxCount = max(maxCount, count)
+	}
+	if maxCount == 0 {
+		return nil, nil
+	}
+
+	var relatedFiles []RelatedFile
+	for file, count := range historyFiles {
+		if relChanged[file] {
+			// Don't include inputs in the output.
+			continue
+		}
+		correlation := float64(count) / float64(maxCount)
+		// Require min correlation to avoid noise.
+		if correlation >= 0.1 {
+			relatedFiles = append(relatedFiles, RelatedFile{Path: file, Correlation: correlation})
+		}
+	}
+
+	// Highest correlation first. Break ties by path: historyFiles is a map, so
+	// without a tie-breaker both the order and the truncation below would vary
+	// from run to run.
+	slices.SortFunc(relatedFiles, func(a, b RelatedFile) int {
+		return cmp.Or(
+			cmp.Compare(b.Correlation, a.Correlation),
+			cmp.Compare(a.Path, b.Path),
+		)
+	})
+
+	// Limit to 1 correlated file per input file.
+	// (Arbitrary limit, to be adjusted.)
+	maxFiles := len(changedFiles)
+	if len(relatedFiles) > maxFiles {
+		relatedFiles = relatedFiles[:maxFiles]
+	}
+
+	// TODO: add an LLM in the mix here (like the keyword search tool) to do a filtering pass,
+	// and then increase the strength of the wording in the relatedFiles message.
+
+	return relatedFiles, nil
+}
+
+// getCommitsTouchingFiles returns the hashes of all commits reachable from any ref
+// (git rev-list --all) that touch any of the specified files.
+// It returns nil without running git when files is empty.
+func (r *CodeReviewer) getCommitsTouchingFiles(ctx context.Context, files []string) ([]string, error) {
+	if len(files) == 0 {
+		return nil, nil
+	}
+	fileArgs := append([]string{"rev-list", "--all", "--"}, files...)
+	cmd := exec.CommandContext(ctx, "git", fileArgs...)
+	cmd.Dir = r.repoRoot
+	// Capture stderr separately so that warnings git prints on success are
+	// never mistaken for commit hashes.
+	stderr := new(strings.Builder)
+	cmd.Stderr = stderr
+	out, err := cmd.Output()
+	if err != nil {
+		return nil, fmt.Errorf("failed to get commits: %w\n%s", err, stderr)
+	}
+	return nonEmptyTrimmedLines(out), nil
+}
+
+// getFilesInCommit returns the paths of all files changed in a specific commit,
+// exactly as git records them (slash-separated, relative to the repository root).
+// A root commit is compared against the empty tree, so all of its files are listed.
+func (r *CodeReviewer) getFilesInCommit(ctx context.Context, commit string) ([]string, error) {
+	// --root: without it, diff-tree prints nothing for a commit that has no parent.
+	// -z: NUL-terminated, unquoted names; otherwise git C-quotes paths containing
+	// "unusual" (e.g. non-ASCII) characters and they would not match the caller's paths.
+	cmd := exec.CommandContext(ctx, "git", "diff-tree", "--root", "--no-commit-id", "--name-only", "-r", "-z", commit)
+	cmd.Dir = r.repoRoot
+	stderr := new(strings.Builder)
+	cmd.Stderr = stderr
+	out, err := cmd.Output()
+	if err != nil {
+		return nil, fmt.Errorf("failed to get files in commit: %w\n%s", err, stderr)
+	}
+	var files []string
+	for _, name := range strings.Split(string(out), "\x00") {
+		if name != "" {
+			files = append(files, name)
+		}
+	}
+	return files, nil
+}
+
+// nonEmptyTrimmedLines splits b into lines, trims surrounding whitespace from each,
+// and returns the lines that are still non-empty (nil if there are none).
+func nonEmptyTrimmedLines(b []byte) []string {
+	var lines []string
+	for line := range strings.Lines(string(b)) {
+		line = strings.TrimSpace(line)
+		if line != "" {
+			lines = append(lines, line)
+		}
+	}
+	return lines
+}
+
+// formatRelatedFiles formats the related files list into a human-readable message,
+// showing each path with its correlation as a rounded percentage. It returns "" for an empty list.
+func (r *CodeReviewer) formatRelatedFiles(files []RelatedFile) string {
+	if len(files) == 0 {
+		return ""
+	}
+
+	buf := new(strings.Builder)
+	buf.WriteString("Potentially related files:\n\n")
+	for _, file := range files {
+		fmt.Fprintf(buf, "- %s (%0.0f%%)\n", file.Path, 100*file.Correlation)
+	}
+	buf.WriteString("\nThese files have historically changed with the files you have modified. Consider whether they require updates as well.\n")
+	return buf.String()
+}