blob: 243827522943b0332832a5ca9ee754162a0cff2d [file] [log] [blame]
Earl Lee2e463fb2025-04-17 11:22:22 -07001package claudetool
2
3import (
4 "context"
5 _ "embed"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "os/exec"
10 "strings"
11
12 "sketch.dev/ant"
13)
14
15// The Keyword tool provides keyword search.
16// TODO: use an embedding model + re-ranker or otherwise do something nicer than this kludge.
17// TODO: if we can get this fast enough, do it on the fly while the user is typing their prompt.
18var Keyword = &ant.Tool{
19 Name: keywordName,
20 Description: keywordDescription,
21 InputSchema: ant.MustSchema(keywordInputSchema),
22 Run: keywordRun,
23}
24
const (
	// keywordName is the name under which the tool is registered.
	keywordName = "keyword_search"

	// keywordDescription is the usage guidance attached to the tool: when to
	// use it, how to phrase the query and keywords, and when to prefer
	// direct approaches instead.
	keywordDescription = `
keyword_search locates files with a search-and-filter approach.
Use when navigating unfamiliar codebases with only conceptual understanding or vague user questions.

Effective use:
- Provide a detailed query for accurate relevance ranking
- Include extensive but uncommon keywords to ensure comprehensive results
- Order keywords by importance (most important first) - less important keywords may be dropped if there are too many results

IMPORTANT: Do NOT use this tool if you have precise information like log lines, error messages, filenames, symbols, or package names. Use direct approaches (grep, cat, go doc, etc.) instead.
`

	// keywordInputSchema is the JSON schema for the tool input; it mirrors
	// the keywordInput struct.
	//
	// If you modify this, update the termui template for prettier rendering.
	keywordInputSchema = `
{
  "type": "object",
  "required": [
    "query",
    "keywords"
  ],
  "properties": {
    "query": {
      "type": "string",
      "description": "A detailed statement of what you're trying to find or learn."
    },
    "keywords": {
      "type": "array",
      "items": {
        "type": "string"
      },
      "description": "List of keywords in descending order of importance."
    }
  }
}
`
)
63
// keywordInput is the decoded form of the tool's JSON input, matching
// keywordInputSchema.
type keywordInput struct {
	// Query is the free-text statement of what the caller wants to find.
	Query string `json:"query"`
	// Keywords are the search terms, most important first.
	Keywords []string `json:"keywords"`
}
68
69//go:embed keyword_system_prompt.txt
70var keywordSystemPrompt string
71
// findRepoRoot reports the top-level directory of the git repository that
// contains wd, as printed by `git rev-parse --show-toplevel` run inside wd.
// It returns an error if the git command cannot be run or exits unsuccessfully.
func findRepoRoot(wd string) (string, error) {
	revParse := exec.Command("git", "rev-parse", "--show-toplevel")
	revParse.Dir = wd
	// TODO: cwd here and throughout
	stdout, err := revParse.Output()
	if err == nil {
		// Strip surrounding whitespace (e.g. a trailing newline) from git's output.
		return strings.TrimSpace(string(stdout)), nil
	}
	return "", fmt.Errorf("failed to find git repository root: %w", err)
}
83
84func keywordRun(ctx context.Context, m json.RawMessage) (string, error) {
85 var input keywordInput
86 if err := json.Unmarshal(m, &input); err != nil {
87 return "", err
88 }
89 wd := WorkingDir(ctx)
90 root, err := findRepoRoot(wd)
91 if err == nil {
92 wd = root
93 }
94 slog.InfoContext(ctx, "keyword search input", "query", input.Query, "keywords", input.Keywords, "wd", wd)
95
96 // first remove stopwords
97 var keep []string
98 for _, term := range input.Keywords {
99 out, err := ripgrep(ctx, wd, []string{term})
100 if err != nil {
101 return "", err
102 }
103 if len(out) > 64*1024 {
104 slog.InfoContext(ctx, "keyword search result too large", "term", term, "bytes", len(out))
105 continue
106 }
107 keep = append(keep, term)
108 }
109
110 // peel off keywords until we get a result that fits in the query window
111 var out string
112 for {
113 var err error
114 out, err = ripgrep(ctx, wd, keep)
115 if err != nil {
116 return "", err
117 }
118 if len(out) < 128*1024 {
119 break
120 }
121 keep = keep[:len(keep)-1]
122 }
123
124 info := ant.ToolCallInfoFromContext(ctx)
125 convo := info.Convo.SubConvo()
126 convo.SystemPrompt = strings.TrimSpace(keywordSystemPrompt)
127
128 initialMessage := ant.Message{
129 Role: ant.MessageRoleUser,
130 Content: []ant.Content{
131 ant.StringContent("<pwd>\n" + wd + "\n</pwd>"),
132 ant.StringContent("<ripgrep_results>\n" + out + "\n</ripgrep_results>"),
133 ant.StringContent("<query>\n" + input.Query + "\n</query>"),
134 },
135 }
136
137 resp, err := convo.SendMessage(initialMessage)
138 if err != nil {
139 return "", fmt.Errorf("failed to send relevance filtering message: %w", err)
140 }
141 if len(resp.Content) != 1 {
142 return "", fmt.Errorf("unexpected number of messages in relevance filtering response: %d", len(resp.Content))
143 }
144
145 filtered := resp.Content[0].Text
146
147 slog.InfoContext(ctx, "keyword search results processed",
148 "bytes", len(out),
149 "lines", strings.Count(out, "\n"),
150 "files", strings.Count(out, "\n\n"),
151 "query", input.Query,
152 "filtered", filtered,
153 )
154
155 return resp.Content[0].Text, nil
156}
157
// ripgrep runs `rg` in directory wd, searching case-insensitively for any of
// terms (each passed as its own -e pattern) with 10 lines of context, line
// numbers, and file names, and returns the command's combined stdout and
// stderr.
//
// If rg exits with status 1 (no matches), ripgrep returns the string
// "no matches found" and a nil error. It returns an error if terms is empty,
// if rg cannot be run, or if rg exits with any other non-zero status.
func ripgrep(ctx context.Context, wd string, terms []string) (string, error) {
	if len(terms) == 0 {
		// With no -e flags rg has no pattern to search for and would exit
		// with a usage error; fail early with a clearer message instead of
		// spawning the process.
		return "", fmt.Errorf("search failed: no search terms provided")
	}

	// Five fixed flags/values plus "-e <term>" per term.
	args := make([]string, 0, 5+2*len(terms))
	args = append(args, "-C", "10", "-i", "--line-number", "--with-filename")
	for _, term := range terms {
		args = append(args, "-e", term)
	}

	cmd := exec.CommandContext(ctx, "rg", args...)
	cmd.Dir = wd
	out, err := cmd.CombinedOutput()
	if err != nil {
		// ripgrep returns exit code 1 when no matches are found, which is not an error for us
		if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
			return "no matches found", nil
		}
		return "", fmt.Errorf("search failed: %w\n%s", err, out)
	}
	return string(out), nil
}