blob: fec5dd7fc0d8335bdd9dd87f3ede4de2f1de9d85 [file] [log] [blame]
Earl Lee2e463fb2025-04-17 11:22:22 -07001package loop
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "net/http"
9 "os"
10 "os/exec"
11 "runtime/debug"
12 "slices"
13 "strings"
14 "sync"
15 "time"
16
17 "sketch.dev/ant"
18 "sketch.dev/claudetool"
19)
20
const (
	// userCancelMessage is the content of the error-typed message that
	// InnerLoop pushes to the outbox when the user cancels in-flight tool use.
	userCancelMessage = "user requested agent to stop handling responses"
)
24
// CodingAgent is the interface through which callers drive a coding agent
// and observe its messages and state.
type CodingAgent interface {
	// Init initializes an agent inside a docker container.
	Init(AgentInit) error

	// Ready returns a channel that is closed after Init has been called successfully.
	Ready() <-chan struct{}

	// URL reports the HTTP URL of this agent.
	URL() string

	// UserMessage enqueues a message to the agent and returns immediately.
	UserMessage(ctx context.Context, msg string)

	// WaitForMessage blocks until the agent has a response to give.
	// Use AgentMessage.EndOfTurn to help determine if you want to
	// drain the agent.
	WaitForMessage(ctx context.Context) AgentMessage

	// Loop begins the agent loop; it returns only when ctx is cancelled.
	Loop(ctx context.Context)

	// CancelInnerLoop cancels whatever the current inner loop iteration is
	// doing, recording cause as the reason, without stopping Loop itself.
	CancelInnerLoop(cause error)

	// CancelToolUse requests cancellation of the tool use identified by
	// toolUseID, recording cause as the reason.
	CancelToolUse(toolUseID string, cause error) error

	// Messages returns a subset of the agent's message history:
	// the entries with indices in [start, end).
	Messages(start int, end int) []AgentMessage

	// MessageCount returns the current number of messages in the history.
	MessageCount() int

	// TotalUsage returns the cumulative usage of the agent's conversation.
	TotalUsage() ant.CumulativeUsage
	// OriginalBudget returns the budget the agent was originally configured with.
	OriginalBudget() ant.Budget

	// WaitForMessageCount returns when the agent has more than greaterThan messages or the context is done.
	WaitForMessageCount(ctx context.Context, greaterThan int)

	// WorkingDir returns the agent's working directory.
	WorkingDir() string

	// Diff returns a unified diff of changes made since the agent was instantiated.
	// If commit is non-nil, it shows the diff for just that specific commit.
	Diff(commit *string) (string, error)

	// InitialCommit returns the Git commit hash that was saved when the agent was instantiated.
	InitialCommit() string

	// Title returns the current title of the conversation.
	Title() string

	// OS returns the operating system of the client.
	OS() string

	// OutstandingLLMCallCount returns the number of outstanding LLM calls.
	OutstandingLLMCallCount() int

	// OutstandingToolCalls returns the names of outstanding tool calls.
	OutstandingToolCalls() []string
	// OutsideOS returns the operating system of the outside system.
	OutsideOS() string
	// OutsideHostname returns the hostname of the outside system.
	OutsideHostname() string
	// OutsideWorkingDir returns the working directory on the outside system.
	OutsideWorkingDir() string
	// GitOrigin returns the URL of the git remote 'origin' if it exists.
	GitOrigin() string
}
87
// CodingAgentMessageType identifies the kind of an AgentMessage.
// Its string value is what appears in the message's JSON "type" field.
type CodingAgentMessageType string

const (
	UserMessageType    CodingAgentMessageType = "user"   // message typed by the user
	AgentMessageType   CodingAgentMessageType = "agent"  // response from the LLM
	ErrorMessageType   CodingAgentMessageType = "error"  // error surfaced to the user
	BudgetMessageType  CodingAgentMessageType = "budget" // dedicated for "out of budget" errors
	ToolUseMessageType CodingAgentMessageType = "tool"   // result of a tool call
	CommitMessageType  CodingAgentMessageType = "commit" // for displaying git commits
	AutoMessageType    CodingAgentMessageType = "auto"   // for automated notifications like autoformatting

	// cancelToolUseMessage is the text sent to the LLM, alongside the tool
	// results, when the user has cancelled the tool use in progress.
	cancelToolUseMessage = "Stop responding to my previous message. Wait for me to ask you something else before attempting to use any more tools."
)
101
// AgentMessage is a single entry in the agent's message history. The same
// struct carries user input, LLM responses, tool results, errors, budget
// notices, commit notifications and automated messages; Type says which.
type AgentMessage struct {
	Type CodingAgentMessageType `json:"type"`
	// EndOfTurn indicates that the AI is done working and is ready for the next user input.
	EndOfTurn bool `json:"end_of_turn"`

	// Content is the message text. The Tool* fields below are populated for
	// tool-result messages (see Agent.OnToolResult).
	Content    string `json:"content"`
	ToolName   string `json:"tool_name,omitempty"`
	ToolInput  string `json:"input,omitempty"`
	ToolResult string `json:"tool_result,omitempty"`
	ToolError  bool   `json:"tool_error,omitempty"`
	ToolCallId string `json:"tool_call_id,omitempty"`

	// ToolCalls is a list of all tool calls requested in this message (name and input pairs)
	ToolCalls []ToolCall `json:"tool_calls,omitempty"`

	// ToolResponses is a list of all responses to tool calls requested in this message.
	ToolResponses []AgentMessage `json:"toolResponses,omitempty"`

	// Commits is a list of git commits for a commit message
	Commits []*GitCommit `json:"commits,omitempty"`

	// Timestamp is filled in by pushToOutbox when left zero.
	Timestamp            time.Time  `json:"timestamp"`
	ConversationID       string     `json:"conversation_id"`
	ParentConversationID *string    `json:"parent_conversation_id,omitempty"`
	Usage                *ant.Usage `json:"usage,omitempty"`

	// Message timing information
	StartTime *time.Time     `json:"start_time,omitempty"`
	EndTime   *time.Time     `json:"end_time,omitempty"`
	Elapsed   *time.Duration `json:"elapsed,omitempty"`

	// Turn duration - the time taken for a complete agent turn
	TurnDuration *time.Duration `json:"turnDuration,omitempty"`

	// Idx is the message's position in the agent's history (set by pushToOutbox).
	Idx int `json:"idx"`
}
138
// GitCommit represents a single git commit carried by a commit-type
// AgentMessage (see AgentMessage.Commits).
type GitCommit struct {
	Hash         string `json:"hash"`                    // Full commit hash
	Subject      string `json:"subject"`                 // Commit subject line
	Body         string `json:"body"`                    // Full commit message body
	PushedBranch string `json:"pushed_branch,omitempty"` // If set, this commit was pushed to this branch
}
146
// ToolCall represents a single tool call within an agent message.
// Agent.OnResponse fills in Name, Input and ToolCallId from the LLM
// response; the remaining fields are not set anywhere in this part of the
// file (NOTE(review): presumably populated by a consumer - confirm).
type ToolCall struct {
	Name          string        `json:"name"`         // tool name as requested by the LLM
	Input         string        `json:"input"`        // raw tool input, as a string
	ToolCallId    string        `json:"tool_call_id"` // ID of the tool_use content part
	ResultMessage *AgentMessage `json:"result_message,omitempty"`
	Args          string        `json:"args,omitempty"`
	Result        string        `json:"result,omitempty"`
}
156
157func (a *AgentMessage) Attr() slog.Attr {
158 var attrs []any = []any{
159 slog.String("type", string(a.Type)),
160 }
161 if a.EndOfTurn {
162 attrs = append(attrs, slog.Bool("end_of_turn", a.EndOfTurn))
163 }
164 if a.Content != "" {
165 attrs = append(attrs, slog.String("content", a.Content))
166 }
167 if a.ToolName != "" {
168 attrs = append(attrs, slog.String("tool_name", a.ToolName))
169 }
170 if a.ToolInput != "" {
171 attrs = append(attrs, slog.String("tool_input", a.ToolInput))
172 }
173 if a.Elapsed != nil {
174 attrs = append(attrs, slog.Int64("elapsed", a.Elapsed.Nanoseconds()))
175 }
176 if a.TurnDuration != nil {
177 attrs = append(attrs, slog.Int64("turnDuration", a.TurnDuration.Nanoseconds()))
178 }
179 if a.ToolResult != "" {
180 attrs = append(attrs, slog.String("tool_result", a.ToolResult))
181 }
182 if a.ToolError {
183 attrs = append(attrs, slog.Bool("tool_error", a.ToolError))
184 }
185 if len(a.ToolCalls) > 0 {
186 toolCallAttrs := make([]any, 0, len(a.ToolCalls))
187 for i, tc := range a.ToolCalls {
188 toolCallAttrs = append(toolCallAttrs, slog.Group(
189 fmt.Sprintf("tool_call_%d", i),
190 slog.String("name", tc.Name),
191 slog.String("input", tc.Input),
192 ))
193 }
194 attrs = append(attrs, slog.Group("tool_calls", toolCallAttrs...))
195 }
196 if a.ConversationID != "" {
197 attrs = append(attrs, slog.String("convo_id", a.ConversationID))
198 }
199 if a.ParentConversationID != nil {
200 attrs = append(attrs, slog.String("parent_convo_id", *a.ParentConversationID))
201 }
202 if a.Usage != nil && !a.Usage.IsZero() {
203 attrs = append(attrs, a.Usage.Attr())
204 }
205 // TODO: timestamp, convo ids, idx?
206 return slog.Group("agent_message", attrs...)
207}
208
209func errorMessage(err error) AgentMessage {
210 // It's somewhat unknowable whether error messages are "end of turn" or not, but it seems like the best approach.
211 if os.Getenv(("DEBUG")) == "1" {
212 return AgentMessage{Type: ErrorMessageType, Content: err.Error() + " Stacktrace: " + string(debug.Stack()), EndOfTurn: true}
213 }
214
215 return AgentMessage{Type: ErrorMessageType, Content: err.Error(), EndOfTurn: true}
216}
217
218func budgetMessage(err error) AgentMessage {
219 return AgentMessage{Type: BudgetMessageType, Content: err.Error(), EndOfTurn: true}
220}
221
// ConvoInterface defines the interface for conversation interactions.
// It lists the conversation operations the Agent uses (Init stores an
// *ant.Convo in a field of this type).
type ConvoInterface interface {
	// CumulativeUsage reports usage so far; surfaced via Agent.TotalUsage.
	CumulativeUsage() ant.CumulativeUsage
	// ResetBudget installs a new budget; the Agent calls it with its
	// original budget after reporting an over-budget condition.
	ResetBudget(ant.Budget)
	// OverBudget returns a non-nil error when the budget has been exceeded.
	OverBudget() error
	// SendMessage sends message to the LLM and returns its response.
	SendMessage(message ant.Message) (*ant.MessageResponse, error)
	SendUserTextMessage(s string, otherContents ...ant.Content) (*ant.MessageResponse, error)
	// ToolResultContents produces the tool-result contents for the tool
	// calls requested in resp.
	ToolResultContents(ctx context.Context, resp *ant.MessageResponse) ([]ant.Content, error)
	// ToolResultCancelContents produces tool-result contents for resp without
	// running the tools; the Agent uses it when the user has cancelled.
	ToolResultCancelContents(resp *ant.MessageResponse) ([]ant.Content, error)
	// CancelToolUse requests cancellation of the tool use with the given ID.
	CancelToolUse(toolUseID string, cause error) error
}
233
// Agent is the concrete coding agent. It owns the LLM conversation, relays
// user messages (inbox) and agent messages (outbox), and keeps the message
// history. Create one with NewAgent and call Init before use.
type Agent struct {
	convo          ConvoInterface // set by Init
	config         AgentConfig    // config for this agent
	workingDir     string
	repoRoot       string        // workingDir may be a subdir of repoRoot
	url            string        // "http://" + AgentInit.HostAddr (only when under docker)
	lastHEAD       string        // hash of the last HEAD that was pushed to the host (only when under docker)
	initialCommit  string        // hash of the Git HEAD when the agent was instantiated or Init()
	gitRemoteAddr  string        // HTTP URL of the host git repo (only when under docker)
	ready          chan struct{} // closed at the end of a successful Init
	startedAt      time.Time
	originalBudget ant.Budget
	title          string // guarded by mu
	codereview     *claudetool.CodeReviewer
	// Outside information
	outsideHostname   string
	outsideOS         string
	outsideWorkingDir string
	// URL of the git remote 'origin' if it exists
	gitOrigin string

	// Time when the current turn started (reset at the beginning of InnerLoop)
	startOfTurn time.Time

	// Inbox - for messages from the user to the agent.
	// sent on by UserMessage
	// . e.g. when user types into the chat textarea
	// read from by GatherMessages
	inbox chan string

	// Outbox
	// sent on by pushToOutbox
	//  via OnToolResult and OnResponse callbacks
	// read from by WaitForMessage
	//	called by termui inside its repl loop.
	outbox chan AgentMessage

	// protects cancelInnerLoop
	cancelInnerLoopMu sync.Mutex
	// cancels potentially long-running tool_use calls or chains of them
	cancelInnerLoop context.CancelCauseFunc

	// protects following
	mu sync.Mutex

	// Stores all messages for this agent
	history []AgentMessage

	// Channels closed (and then dropped) whenever the history or title
	// changes; registered by WaitForMessageCount.
	listeners []chan struct{}

	// Track git commits we've already seen (by hash)
	seenCommits map[string]bool

	// Track outstanding LLM call IDs
	outstandingLLMCalls map[string]struct{}

	// Track outstanding tool calls by ID with their names
	outstandingToolCalls map[string]string
}
293
294func (a *Agent) URL() string { return a.url }
295
296// Title returns the current title of the conversation.
297// If no title has been set, returns an empty string.
298func (a *Agent) Title() string {
299 a.mu.Lock()
300 defer a.mu.Unlock()
301 return a.title
302}
303
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000304// OutstandingLLMCallCount returns the number of outstanding LLM calls.
305func (a *Agent) OutstandingLLMCallCount() int {
306 a.mu.Lock()
307 defer a.mu.Unlock()
308 return len(a.outstandingLLMCalls)
309}
310
311// OutstandingToolCalls returns the names of outstanding tool calls.
312func (a *Agent) OutstandingToolCalls() []string {
313 a.mu.Lock()
314 defer a.mu.Unlock()
315
316 tools := make([]string, 0, len(a.outstandingToolCalls))
317 for _, toolName := range a.outstandingToolCalls {
318 tools = append(tools, toolName)
319 }
320 return tools
321}
322
Earl Lee2e463fb2025-04-17 11:22:22 -0700323// OS returns the operating system of the client.
324func (a *Agent) OS() string {
325 return a.config.ClientGOOS
326}
327
Philip Zeyliger18532b22025-04-23 21:11:46 +0000328// OutsideOS returns the operating system of the outside system.
329func (a *Agent) OutsideOS() string {
330 return a.outsideOS
Philip Zeyligerd1402952025-04-23 03:54:37 +0000331}
332
Philip Zeyliger18532b22025-04-23 21:11:46 +0000333// OutsideHostname returns the hostname of the outside system.
334func (a *Agent) OutsideHostname() string {
335 return a.outsideHostname
Philip Zeyligerd1402952025-04-23 03:54:37 +0000336}
337
Philip Zeyliger18532b22025-04-23 21:11:46 +0000338// OutsideWorkingDir returns the working directory on the outside system.
339func (a *Agent) OutsideWorkingDir() string {
340 return a.outsideWorkingDir
Philip Zeyligerd1402952025-04-23 03:54:37 +0000341}
342
343// GitOrigin returns the URL of the git remote 'origin' if it exists.
344func (a *Agent) GitOrigin() string {
345 return a.gitOrigin
346}
347
Earl Lee2e463fb2025-04-17 11:22:22 -0700348// SetTitle sets the title of the conversation.
349func (a *Agent) SetTitle(title string) {
350 a.mu.Lock()
351 defer a.mu.Unlock()
352 a.title = title
353 // Notify all listeners that the state has changed
354 for _, ch := range a.listeners {
355 close(ch)
356 }
357 a.listeners = a.listeners[:0]
358}
359
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000360// OnToolCall implements ant.Listener and tracks the start of a tool call.
361func (a *Agent) OnToolCall(ctx context.Context, convo *ant.Convo, id string, toolName string, toolInput json.RawMessage, content ant.Content) {
362 // Track the tool call
363 a.mu.Lock()
364 a.outstandingToolCalls[id] = toolName
365 a.mu.Unlock()
366}
367
Earl Lee2e463fb2025-04-17 11:22:22 -0700368// OnToolResult implements ant.Listener.
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000369func (a *Agent) OnToolResult(ctx context.Context, convo *ant.Convo, toolID string, toolName string, toolInput json.RawMessage, content ant.Content, result *string, err error) {
370 // Remove the tool call from outstanding calls
371 a.mu.Lock()
372 delete(a.outstandingToolCalls, toolID)
373 a.mu.Unlock()
374
Earl Lee2e463fb2025-04-17 11:22:22 -0700375 m := AgentMessage{
376 Type: ToolUseMessageType,
377 Content: content.Text,
378 ToolResult: content.ToolResult,
379 ToolError: content.ToolError,
380 ToolName: toolName,
381 ToolInput: string(toolInput),
382 ToolCallId: content.ToolUseID,
383 StartTime: content.StartTime,
384 EndTime: content.EndTime,
385 }
386
387 // Calculate the elapsed time if both start and end times are set
388 if content.StartTime != nil && content.EndTime != nil {
389 elapsed := content.EndTime.Sub(*content.StartTime)
390 m.Elapsed = &elapsed
391 }
392
393 m.ConversationID = convo.ID
394 if convo.Parent != nil {
395 m.ParentConversationID = &convo.Parent.ID
396 }
397 a.pushToOutbox(ctx, m)
398}
399
400// OnRequest implements ant.Listener.
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000401func (a *Agent) OnRequest(ctx context.Context, convo *ant.Convo, id string, msg *ant.Message) {
402 a.mu.Lock()
403 defer a.mu.Unlock()
404 a.outstandingLLMCalls[id] = struct{}{}
Earl Lee2e463fb2025-04-17 11:22:22 -0700405 // We already get tool results from the above. We send user messages to the outbox in the agent loop.
406}
407
408// OnResponse implements ant.Listener. Responses contain messages from the LLM
409// that need to be displayed (as well as tool calls that we send along when
410// they're done). (It would be reasonable to also mention tool calls when they're
411// started, but we don't do that yet.)
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000412func (a *Agent) OnResponse(ctx context.Context, convo *ant.Convo, id string, resp *ant.MessageResponse) {
413 // Remove the LLM call from outstanding calls
414 a.mu.Lock()
415 delete(a.outstandingLLMCalls, id)
416 a.mu.Unlock()
417
Earl Lee2e463fb2025-04-17 11:22:22 -0700418 endOfTurn := false
419 if resp.StopReason != ant.StopReasonToolUse {
420 endOfTurn = true
421 }
422 m := AgentMessage{
423 Type: AgentMessageType,
424 Content: collectTextContent(resp),
425 EndOfTurn: endOfTurn,
426 Usage: &resp.Usage,
427 StartTime: resp.StartTime,
428 EndTime: resp.EndTime,
429 }
430
431 // Extract any tool calls from the response
432 if resp.StopReason == ant.StopReasonToolUse {
433 var toolCalls []ToolCall
434 for _, part := range resp.Content {
435 if part.Type == "tool_use" {
436 toolCalls = append(toolCalls, ToolCall{
437 Name: part.ToolName,
438 Input: string(part.ToolInput),
439 ToolCallId: part.ID,
440 })
441 }
442 }
443 m.ToolCalls = toolCalls
444 }
445
446 // Calculate the elapsed time if both start and end times are set
447 if resp.StartTime != nil && resp.EndTime != nil {
448 elapsed := resp.EndTime.Sub(*resp.StartTime)
449 m.Elapsed = &elapsed
450 }
451
452 m.ConversationID = convo.ID
453 if convo.Parent != nil {
454 m.ParentConversationID = &convo.Parent.ID
455 }
456 a.pushToOutbox(ctx, m)
457}
458
459// WorkingDir implements CodingAgent.
460func (a *Agent) WorkingDir() string {
461 return a.workingDir
462}
463
464// MessageCount implements CodingAgent.
465func (a *Agent) MessageCount() int {
466 a.mu.Lock()
467 defer a.mu.Unlock()
468 return len(a.history)
469}
470
471// Messages implements CodingAgent.
472func (a *Agent) Messages(start int, end int) []AgentMessage {
473 a.mu.Lock()
474 defer a.mu.Unlock()
475 return slices.Clone(a.history[start:end])
476}
477
478func (a *Agent) OriginalBudget() ant.Budget {
479 return a.originalBudget
480}
481
// AgentConfig contains configuration for creating a new Agent.
type AgentConfig struct {
	// Context is used for the git commands run by Init and for the
	// conversation created by initConvo.
	Context context.Context
	// AntURL, when non-empty, overrides the conversation's URL.
	AntURL string
	// APIKey is passed to ant.NewConvo.
	APIKey string
	// HTTPC, when non-nil, overrides the conversation's HTTP client.
	HTTPC *http.Client
	// Budget is the conversation budget; it is also what the budget is
	// reset to after an over-budget report.
	Budget ant.Budget
	// GitUsername and GitEmail are handed to the done tool.
	GitUsername string
	GitEmail    string
	SessionID   string
	// ClientGOOS and ClientGOARCH are reported in the system prompt's
	// <platform> section; ClientGOOS is also returned by Agent.OS.
	ClientGOOS   string
	ClientGOARCH string
	// UseAnthropicEdit selects the str_replace_editor tool (and matching
	// prompt text) instead of the patch tool.
	UseAnthropicEdit bool
	// Outside information
	OutsideHostname   string
	OutsideOS         string
	OutsideWorkingDir string
}
500
501// NewAgent creates a new Agent.
502// It is not usable until Init() is called.
503func NewAgent(config AgentConfig) *Agent {
504 agent := &Agent{
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000505 config: config,
506 ready: make(chan struct{}),
507 inbox: make(chan string, 100),
508 outbox: make(chan AgentMessage, 100),
509 startedAt: time.Now(),
510 originalBudget: config.Budget,
511 seenCommits: make(map[string]bool),
512 outsideHostname: config.OutsideHostname,
513 outsideOS: config.OutsideOS,
514 outsideWorkingDir: config.OutsideWorkingDir,
515 outstandingLLMCalls: make(map[string]struct{}),
516 outstandingToolCalls: make(map[string]string),
Earl Lee2e463fb2025-04-17 11:22:22 -0700517 }
518 return agent
519}
520
// AgentInit carries the parameters for Agent.Init.
type AgentInit struct {
	// WorkingDir becomes the agent's working directory; Init also runs its
	// git commands there.
	WorkingDir string
	NoGit      bool // only for testing

	// InDocker makes Init prepare the checkout: it stashes local changes,
	// adds GitRemoteAddr as the "sketch-host" remote, fetches it, and force
	// checks out Commit.
	InDocker bool
	// Commit is the commit checked out (and recorded as the initial commit)
	// when InDocker is set.
	Commit string
	// GitRemoteAddr is the address of the "sketch-host" remote.
	GitRemoteAddr string
	// HostAddr, when non-empty and InDocker is set, gives the agent the URL
	// "http://" + HostAddr.
	HostAddr string
}
530
531func (a *Agent) Init(ini AgentInit) error {
532 ctx := a.config.Context
533 if ini.InDocker {
534 cmd := exec.CommandContext(ctx, "git", "stash")
535 cmd.Dir = ini.WorkingDir
536 if out, err := cmd.CombinedOutput(); err != nil {
537 return fmt.Errorf("git stash: %s: %v", out, err)
538 }
Philip Zeyligerd0ac1ea2025-04-21 20:04:19 -0700539 cmd = exec.CommandContext(ctx, "git", "remote", "add", "sketch-host", ini.GitRemoteAddr)
540 cmd.Dir = ini.WorkingDir
541 if out, err := cmd.CombinedOutput(); err != nil {
542 return fmt.Errorf("git remote add: %s: %v", out, err)
543 }
544 cmd = exec.CommandContext(ctx, "git", "fetch", "sketch-host")
Earl Lee2e463fb2025-04-17 11:22:22 -0700545 cmd.Dir = ini.WorkingDir
546 if out, err := cmd.CombinedOutput(); err != nil {
547 return fmt.Errorf("git fetch: %s: %w", out, err)
548 }
549 cmd = exec.CommandContext(ctx, "git", "checkout", "-f", ini.Commit)
550 cmd.Dir = ini.WorkingDir
551 if out, err := cmd.CombinedOutput(); err != nil {
552 return fmt.Errorf("git checkout %s: %s: %w", ini.Commit, out, err)
553 }
Earl Lee2e463fb2025-04-17 11:22:22 -0700554 a.lastHEAD = ini.Commit
555 a.gitRemoteAddr = ini.GitRemoteAddr
556 a.initialCommit = ini.Commit
557 if ini.HostAddr != "" {
558 a.url = "http://" + ini.HostAddr
559 }
560 }
561 a.workingDir = ini.WorkingDir
562
563 if !ini.NoGit {
564 repoRoot, err := repoRoot(ctx, a.workingDir)
565 if err != nil {
566 return fmt.Errorf("repoRoot: %w", err)
567 }
568 a.repoRoot = repoRoot
569
570 commitHash, err := resolveRef(ctx, a.repoRoot, "HEAD")
571 if err != nil {
572 return fmt.Errorf("resolveRef: %w", err)
573 }
574 a.initialCommit = commitHash
575
576 codereview, err := claudetool.NewCodeReviewer(ctx, a.repoRoot, a.initialCommit)
577 if err != nil {
578 return fmt.Errorf("Agent.Init: claudetool.NewCodeReviewer: %w", err)
579 }
580 a.codereview = codereview
Philip Zeyligerd1402952025-04-23 03:54:37 +0000581
582 a.gitOrigin = getGitOrigin(ctx, ini.WorkingDir)
Earl Lee2e463fb2025-04-17 11:22:22 -0700583 }
584 a.lastHEAD = a.initialCommit
585 a.convo = a.initConvo()
586 close(a.ready)
587 return nil
588}
589
590// initConvo initializes the conversation.
591// It must not be called until all agent fields are initialized,
592// particularly workingDir and git.
593func (a *Agent) initConvo() *ant.Convo {
594 ctx := a.config.Context
595 convo := ant.NewConvo(ctx, a.config.APIKey)
596 if a.config.HTTPC != nil {
597 convo.HTTPC = a.config.HTTPC
598 }
599 if a.config.AntURL != "" {
600 convo.URL = a.config.AntURL
601 }
602 convo.PromptCaching = true
603 convo.Budget = a.config.Budget
604
605 var editPrompt string
606 if a.config.UseAnthropicEdit {
607 editPrompt = "Then use the str_replace_editor tool to make those edits. For short complete file replacements, you may use the bash tool with cat and heredoc stdin."
608 } else {
609 editPrompt = "Then use the patch tool to make those edits. Combine all edits to any given file into a single patch tool call."
610 }
611
612 convo.SystemPrompt = fmt.Sprintf(`
613You are an expert coding assistant and architect, with a specialty in Go.
614You are assisting the user to achieve their goals.
615
616Start by asking concise clarifying questions as needed.
617Once the intent is clear, work autonomously.
618
619Call the title tool early in the conversation to provide a brief summary of
620what the chat is about.
621
622Break down the overall goal into a series of smaller steps.
623(The first step is often: "Make a plan.")
624Then execute each step using tools.
625Update the plan if you have encountered problems or learned new information.
626
627When in doubt about a step, follow this broad workflow:
628
629- Think about how the current step fits into the overall plan.
630- Do research. Good tool choices: bash, think, keyword_search
631- Make edits.
632- Repeat.
633
634To make edits reliably and efficiently, first think about the intent of the edit,
635and what set of patches will achieve that intent.
636%s
637
638For renames or refactors, consider invoking gopls (via bash).
639
640The done tool provides a checklist of items you MUST verify and
641review before declaring that you are done. Before executing
642the done tool, run all the tools the done tool checklist asks
643for, including creating a git commit. Do not forget to run tests.
644
645<platform>
646%s/%s
647</platform>
648<pwd>
649%v
650</pwd>
651<git_root>
652%v
653</git_root>
Josh Bleecher Snyder833a0f82025-04-24 18:39:36 +0000654<HEAD>
655%v
656</HEAD>
657`, editPrompt, a.config.ClientGOOS, a.config.ClientGOARCH, a.workingDir, a.repoRoot, a.initialCommit)
Earl Lee2e463fb2025-04-17 11:22:22 -0700658
659 // Register all tools with the conversation
660 // When adding, removing, or modifying tools here, double-check that the termui tool display
661 // template in termui/termui.go has pretty-printing support for all tools.
662 convo.Tools = []*ant.Tool{
663 claudetool.Bash, claudetool.Keyword,
664 claudetool.Think, a.titleTool(), makeDoneTool(a.codereview, a.config.GitUsername, a.config.GitEmail),
665 a.codereview.Tool(),
666 }
667 if a.config.UseAnthropicEdit {
668 convo.Tools = append(convo.Tools, claudetool.AnthropicEditTool)
669 } else {
670 convo.Tools = append(convo.Tools, claudetool.Patch)
671 }
672 convo.Listener = a
673 return convo
674}
675
676func (a *Agent) titleTool() *ant.Tool {
677 // titleTool creates the title tool that sets the conversation title.
678 title := &ant.Tool{
679 Name: "title",
680 Description: `Use this tool early in the conversation, BEFORE MAKING ANY GIT COMMITS, to summarize what the chat is about briefly.`,
681 InputSchema: json.RawMessage(`{
682 "type": "object",
683 "properties": {
684 "title": {
685 "type": "string",
686 "description": "A brief title summarizing what this chat is about"
687 }
688 },
689 "required": ["title"]
690}`),
691 Run: func(ctx context.Context, input json.RawMessage) (string, error) {
692 var params struct {
693 Title string `json:"title"`
694 }
695 if err := json.Unmarshal(input, &params); err != nil {
696 return "", err
697 }
698 a.SetTitle(params.Title)
699 return fmt.Sprintf("Title set to: %s", params.Title), nil
700 },
701 }
702 return title
703}
704
705func (a *Agent) Ready() <-chan struct{} {
706 return a.ready
707}
708
709func (a *Agent) UserMessage(ctx context.Context, msg string) {
710 a.pushToOutbox(ctx, AgentMessage{Type: UserMessageType, Content: msg})
711 a.inbox <- msg
712}
713
714func (a *Agent) WaitForMessage(ctx context.Context) AgentMessage {
715 // TODO: Should this drain any outbox messages in case there are multiple?
716 select {
717 case msg := <-a.outbox:
718 return msg
719 case <-ctx.Done():
720 return errorMessage(ctx.Err())
721 }
722}
723
724func (a *Agent) CancelToolUse(toolUseID string, cause error) error {
725 return a.convo.CancelToolUse(toolUseID, cause)
726}
727
728func (a *Agent) CancelInnerLoop(cause error) {
729 a.cancelInnerLoopMu.Lock()
730 defer a.cancelInnerLoopMu.Unlock()
731 if a.cancelInnerLoop != nil {
732 a.cancelInnerLoop(cause)
733 }
734}
735
736func (a *Agent) Loop(ctxOuter context.Context) {
737 for {
738 select {
739 case <-ctxOuter.Done():
740 return
741 default:
742 ctxInner, cancel := context.WithCancelCause(ctxOuter)
743 a.cancelInnerLoopMu.Lock()
744 // Set .cancelInnerLoop so the user can cancel whatever is happening
745 // inside InnerLoop(ctxInner) without canceling this outer Loop execution.
746 // This CancelInnerLoop func is intended be called from other goroutines,
747 // hence the mutex.
748 a.cancelInnerLoop = cancel
749 a.cancelInnerLoopMu.Unlock()
750 a.InnerLoop(ctxInner)
751 cancel(nil)
752 }
753 }
754}
755
756func (a *Agent) pushToOutbox(ctx context.Context, m AgentMessage) {
757 if m.Timestamp.IsZero() {
758 m.Timestamp = time.Now()
759 }
760
761 // If this is an end-of-turn message, calculate the turn duration and add it to the message
762 if m.EndOfTurn && m.Type == AgentMessageType {
763 turnDuration := time.Since(a.startOfTurn)
764 m.TurnDuration = &turnDuration
765 slog.InfoContext(ctx, "Turn completed", "turnDuration", turnDuration)
766 }
767
768 slog.InfoContext(ctx, "agent message", m.Attr())
769
770 a.mu.Lock()
771 defer a.mu.Unlock()
772 m.Idx = len(a.history)
773 a.history = append(a.history, m)
774 a.outbox <- m
775
776 // Notify all listeners:
777 for _, ch := range a.listeners {
778 close(ch)
779 }
780 a.listeners = a.listeners[:0]
781}
782
783func (a *Agent) GatherMessages(ctx context.Context, block bool) ([]ant.Content, error) {
784 var m []ant.Content
785 if block {
786 select {
787 case <-ctx.Done():
788 return m, ctx.Err()
789 case msg := <-a.inbox:
790 m = append(m, ant.Content{Type: "text", Text: msg})
791 }
792 }
793 for {
794 select {
795 case msg := <-a.inbox:
796 m = append(m, ant.Content{Type: "text", Text: msg})
797 default:
798 return m, nil
799 }
800 }
801}
802
// InnerLoop handles one user turn. It blocks until at least one user
// message is available, sends it to the LLM, and then, for as long as the
// LLM asks for tool use, collects the tool results (or cancellation
// notices) and sends them back together with any user messages that
// arrived in the meantime. It returns when the LLM stops asking for tools,
// when the budget is exceeded, when ctx is cancelled, or on a send error.
//
// Errors are reported to the user through the outbox rather than returned.
func (a *Agent) InnerLoop(ctx context.Context) {
	// Reset the start of turn time
	a.startOfTurn = time.Now()

	// Wait for at least one message from the user.
	msgs, err := a.GatherMessages(ctx, true)
	if err != nil { // e.g. the context was canceled while blocking in GatherMessages
		return
	}
	// We do this as we go, but let's also do it at the end of the turn
	defer func() {
		if _, err := a.handleGitCommits(ctx); err != nil {
			// Just log the error, don't stop execution
			slog.WarnContext(ctx, "Failed to check for new git commits", "error", err)
		}
	}()

	userMessage := ant.Message{
		Role:    "user",
		Content: msgs,
	}
	// convo.SendMessage does the actual network call to send this to anthropic. This blocks until the response is ready.
	// TODO: pass ctx to SendMessage, and figure out how to square that ctx with convo's own .Ctx.  Who owns the scope of this call?
	resp, err := a.convo.SendMessage(userMessage)
	if err != nil {
		a.pushToOutbox(ctx, errorMessage(err))
		return
	}
	// Each iteration answers one tool-use response from the LLM.
	for {
		// TODO: here and below where we check the budget,
		// we should review the UX: is it clear what happened?
		// is it clear how to resume?
		// should we let the user set a new budget?
		if err := a.overBudget(ctx); err != nil {
			return
		}
		// Anything other than a tool-use stop means the LLM is done with this turn.
		if resp.StopReason != ant.StopReasonToolUse {
			break
		}
		var results []ant.Content
		cancelled := false
		select {
		case <-ctx.Done():
			// Don't actually run any of the tools, but rather build a response
			// for each tool_use message letting the LLM know that user canceled it.
			results, err = a.convo.ToolResultCancelContents(resp)
			if err != nil {
				a.pushToOutbox(ctx, errorMessage(err))
			}
			cancelled = true
		default:
			ctx = claudetool.WithWorkingDir(ctx, a.workingDir)
			// fall-through, when the user has not canceled the inner loop:
			results, err = a.convo.ToolResultContents(ctx, resp)
			if ctx.Err() != nil { // e.g. the user canceled the operation
				cancelled = true
			} else if err != nil {
				a.pushToOutbox(ctx, errorMessage(err))
			}
		}

		// Check for git commits. Currently we do this here, after we collect
		// tool results, since that's when we know commits could have happened.
		// We could instead do this when the turn ends, but I think it makes sense
		// to do this as we go.
		newCommits, err := a.handleGitCommits(ctx)
		if err != nil {
			// Just log the error, don't stop execution
			slog.WarnContext(ctx, "Failed to check for new git commits", "error", err)
		}
		var autoqualityMessages []string
		// Autoformatting runs only when exactly one new commit was found.
		if len(newCommits) == 1 {
			formatted := a.codereview.Autoformat(ctx)
			if len(formatted) > 0 {
				// The [1:] below strips the leading newline of the raw string.
				msg := fmt.Sprintf(`
I ran autoformatters and they updated these files:

%s

Please amend your latest git commit with these changes and then continue with what you were doing.`,
					strings.Join(formatted, "\n"),
				)[1:]
				a.pushToOutbox(ctx, AgentMessage{
					Type:      AutoMessageType,
					Content:   msg,
					Timestamp: time.Now(),
				})
				autoqualityMessages = append(autoqualityMessages, msg)
			}
		}

		if err := a.overBudget(ctx); err != nil {
			return
		}

		// Include, along with the tool results (which must go first for whatever reason),
		// any messages that the user has sent along while the tool_use was executing concurrently.
		msgs, err = a.GatherMessages(ctx, false)
		if err != nil {
			return
		}
		// Inject any auto-generated messages from quality checks.
		for _, msg := range autoqualityMessages {
			msgs = append(msgs, ant.Content{Type: "text", Text: msg})
		}
		if cancelled {
			msgs = append(msgs, ant.Content{Type: "text", Text: cancelToolUseMessage})
			// EndOfTurn is false here so that the client of this agent keeps processing
			// messages from WaitForMessage() and gets the response from the LLM (usually
			// something like "okay, I'll wait further instructions", but the user should
			// be made aware of it regardless).
			a.pushToOutbox(ctx, AgentMessage{Type: ErrorMessageType, Content: userCancelMessage, EndOfTurn: false})
		} else if err := a.convo.OverBudget(); err != nil {
			budgetMsg := "We've exceeded our budget. Please ask the user to confirm before continuing by ending the turn."
			msgs = append(msgs, ant.Content{Type: "text", Text: budgetMsg})
			a.pushToOutbox(ctx, budgetMessage(fmt.Errorf("warning: %w (ask to keep trying, if you'd like)", err)))
		}
		// Tool results first, then the gathered user and automated messages.
		results = append(results, msgs...)
		resp, err = a.convo.SendMessage(ant.Message{
			Role:    "user",
			Content: results,
		})
		if err != nil {
			a.pushToOutbox(ctx, errorMessage(fmt.Errorf("error: failed to continue conversation: %s", err.Error())))
			break
		}
		// After a cancellation the LLM's acknowledgement has been received
		// above, so the turn ends here.
		if cancelled {
			return
		}
	}
}
934
935func (a *Agent) overBudget(ctx context.Context) error {
936 if err := a.convo.OverBudget(); err != nil {
937 m := budgetMessage(err)
938 m.Content = m.Content + "\n\nBudget reset."
939 a.pushToOutbox(ctx, budgetMessage(err))
940 a.convo.ResetBudget(a.originalBudget)
941 return err
942 }
943 return nil
944}
945
946func collectTextContent(msg *ant.MessageResponse) string {
947 // Collect all text content
948 var allText strings.Builder
949 for _, content := range msg.Content {
950 if content.Type == "text" && content.Text != "" {
951 if allText.Len() > 0 {
952 allText.WriteString("\n\n")
953 }
954 allText.WriteString(content.Text)
955 }
956 }
957 return allText.String()
958}
959
960func (a *Agent) TotalUsage() ant.CumulativeUsage {
961 a.mu.Lock()
962 defer a.mu.Unlock()
963 return a.convo.CumulativeUsage()
964}
965
966// WaitForMessageCount returns when the agent has at more than clientMessageCount messages or the context is done.
967func (a *Agent) WaitForMessageCount(ctx context.Context, greaterThan int) {
968 for a.MessageCount() <= greaterThan {
969 a.mu.Lock()
970 ch := make(chan struct{})
971 // Deletion happens when we notify.
972 a.listeners = append(a.listeners, ch)
973 a.mu.Unlock()
974
975 select {
976 case <-ctx.Done():
977 return
978 case <-ch:
979 continue
980 }
981 }
982}
983
984// Diff returns a unified diff of changes made since the agent was instantiated.
985func (a *Agent) Diff(commit *string) (string, error) {
986 if a.initialCommit == "" {
987 return "", fmt.Errorf("no initial commit reference available")
988 }
989
990 // Find the repository root
991 ctx := context.Background()
992
993 // If a specific commit hash is provided, show just that commit's changes
994 if commit != nil && *commit != "" {
995 // Validate that the commit looks like a valid git SHA
996 if !isValidGitSHA(*commit) {
997 return "", fmt.Errorf("invalid git commit SHA format: %s", *commit)
998 }
999
1000 // Get the diff for just this commit
1001 cmd := exec.CommandContext(ctx, "git", "show", "--unified=10", *commit)
1002 cmd.Dir = a.repoRoot
1003 output, err := cmd.CombinedOutput()
1004 if err != nil {
1005 return "", fmt.Errorf("failed to get diff for commit %s: %w - %s", *commit, err, string(output))
1006 }
1007 return string(output), nil
1008 }
1009
1010 // Otherwise, get the diff between the initial commit and the current state using exec.Command
1011 cmd := exec.CommandContext(ctx, "git", "diff", "--unified=10", a.initialCommit)
1012 cmd.Dir = a.repoRoot
1013 output, err := cmd.CombinedOutput()
1014 if err != nil {
1015 return "", fmt.Errorf("failed to get diff: %w - %s", err, string(output))
1016 }
1017
1018 return string(output), nil
1019}
1020
1021// InitialCommit returns the Git commit hash that was saved when the agent was instantiated.
1022func (a *Agent) InitialCommit() string {
1023 return a.initialCommit
1024}
1025
1026// handleGitCommits() highlights new commits to the user. When running
1027// under docker, new HEADs are pushed to a branch according to the title.
1028func (a *Agent) handleGitCommits(ctx context.Context) ([]*GitCommit, error) {
1029 if a.repoRoot == "" {
1030 return nil, nil
1031 }
1032
1033 head, err := resolveRef(ctx, a.repoRoot, "HEAD")
1034 if err != nil {
1035 return nil, err
1036 }
1037 if head == a.lastHEAD {
1038 return nil, nil // nothing to do
1039 }
1040 defer func() {
1041 a.lastHEAD = head
1042 }()
1043
1044 // Get new commits. Because it's possible that the agent does rebases, fixups, and
1045 // so forth, we use, as our fixed point, the "initialCommit", and we limit ourselves
1046 // to the last 100 commits.
1047 var commits []*GitCommit
1048
1049 // Get commits since the initial commit
1050 // Format: <hash>\0<subject>\0<body>\0
1051 // This uses NULL bytes as separators to avoid issues with newlines in commit messages
1052 // Limit to 100 commits to avoid overwhelming the user
1053 cmd := exec.CommandContext(ctx, "git", "log", "-n", "100", "--pretty=format:%H%x00%s%x00%b%x00", "^"+a.initialCommit, head)
1054 cmd.Dir = a.repoRoot
1055 output, err := cmd.Output()
1056 if err != nil {
1057 return nil, fmt.Errorf("failed to get git log: %w", err)
1058 }
1059
1060 // Parse git log output and filter out already seen commits
1061 parsedCommits := parseGitLog(string(output))
1062
1063 var headCommit *GitCommit
1064
1065 // Filter out commits we've already seen
1066 for _, commit := range parsedCommits {
1067 if commit.Hash == head {
1068 headCommit = &commit
1069 }
1070
1071 // Skip if we've seen this commit before. If our head has changed, always include that.
1072 if a.seenCommits[commit.Hash] && commit.Hash != head {
1073 continue
1074 }
1075
1076 // Mark this commit as seen
1077 a.seenCommits[commit.Hash] = true
1078
1079 // Add to our list of new commits
1080 commits = append(commits, &commit)
1081 }
1082
1083 if a.gitRemoteAddr != "" {
1084 if headCommit == nil {
1085 // I think this can only happen if we have a bug or if there's a race.
1086 headCommit = &GitCommit{}
1087 headCommit.Hash = head
1088 headCommit.Subject = "unknown"
1089 commits = append(commits, headCommit)
1090 }
1091
1092 cleanTitle := titleToBranch(a.title)
1093 if cleanTitle == "" {
1094 cleanTitle = a.config.SessionID
1095 }
1096 branch := "sketch/" + cleanTitle
1097
1098 // TODO: I don't love the force push here. We could see if the push is a fast-forward, and,
1099 // if it's not, we could make a backup with a unique name (perhaps append a timestamp) and
1100 // then use push with lease to replace.
1101 cmd = exec.Command("git", "push", "--force", a.gitRemoteAddr, "HEAD:refs/heads/"+branch)
1102 cmd.Dir = a.workingDir
1103 if out, err := cmd.CombinedOutput(); err != nil {
1104 a.pushToOutbox(ctx, errorMessage(fmt.Errorf("git push to host: %s: %v", out, err)))
1105 } else {
1106 headCommit.PushedBranch = branch
1107 }
1108 }
1109
1110 // If we found new commits, create a message
1111 if len(commits) > 0 {
1112 msg := AgentMessage{
1113 Type: CommitMessageType,
1114 Timestamp: time.Now(),
1115 Commits: commits,
1116 }
1117 a.pushToOutbox(ctx, msg)
1118 }
1119 return commits, nil
1120}
1121
// titleToBranch turns a free-form title into a branch-name fragment.
// The title is lowercased, each space becomes a hyphen, and every rune that
// is then neither a-z nor a hyphen is dropped (digits and punctuation included).
// The result may be empty.
func titleToBranch(s string) string {
	keep := func(r rune) rune {
		switch {
		case r == ' ':
			return '-'
		case r == '-', 'a' <= r && r <= 'z':
			return r
		default:
			return -1 // a negative value tells strings.Map to drop the rune
		}
	}
	return strings.Map(keep, strings.ToLower(s))
}
1138
1139// parseGitLog parses the output of git log with format '%H%x00%s%x00%b%x00'
1140// and returns an array of GitCommit structs.
1141func parseGitLog(output string) []GitCommit {
1142 var commits []GitCommit
1143
1144 // No output means no commits
1145 if len(output) == 0 {
1146 return commits
1147 }
1148
1149 // Split by NULL byte
1150 parts := strings.Split(output, "\x00")
1151
1152 // Process in triplets (hash, subject, body)
1153 for i := 0; i < len(parts); i++ {
1154 // Skip empty parts
1155 if parts[i] == "" {
1156 continue
1157 }
1158
1159 // This should be a hash
1160 hash := strings.TrimSpace(parts[i])
1161
1162 // Make sure we have at least a subject part available
1163 if i+1 >= len(parts) {
1164 break // No more parts available
1165 }
1166
1167 // Get the subject
1168 subject := strings.TrimSpace(parts[i+1])
1169
1170 // Get the body if available
1171 body := ""
1172 if i+2 < len(parts) {
1173 body = strings.TrimSpace(parts[i+2])
1174 }
1175
1176 // Skip to the next triplet
1177 i += 2
1178
1179 commits = append(commits, GitCommit{
1180 Hash: hash,
1181 Subject: subject,
1182 Body: body,
1183 })
1184 }
1185
1186 return commits
1187}
1188
// repoRoot runs "git rev-parse --show-toplevel" in dir and returns its
// whitespace-trimmed standard output. On failure the returned error wraps
// the exec error and carries whatever git wrote to stderr.
func repoRoot(ctx context.Context, dir string) (string, error) {
	var errOut strings.Builder

	revParse := exec.CommandContext(ctx, "git", "rev-parse", "--show-toplevel")
	revParse.Dir = dir
	revParse.Stderr = &errOut

	stdout, err := revParse.Output()
	if err != nil {
		return "", fmt.Errorf("git rev-parse failed: %w\n%s", err, errOut.String())
	}
	return strings.TrimSpace(string(stdout)), nil
}
1200
// resolveRef runs "git rev-parse <refName>" in dir and returns its
// whitespace-trimmed standard output. On failure the returned error wraps
// the exec error and carries whatever git wrote to stderr.
func resolveRef(ctx context.Context, dir, refName string) (string, error) {
	var errOut strings.Builder

	revParse := exec.CommandContext(ctx, "git", "rev-parse", refName)
	revParse.Dir = dir
	revParse.Stderr = &errOut

	stdout, err := revParse.Output()
	if err != nil {
		return "", fmt.Errorf("git rev-parse failed: %w\n%s", err, errOut.String())
	}
	// TODO: validate that the output is valid hex
	return strings.TrimSpace(string(stdout)), nil
}
1213
// isValidGitSHA reports whether sha looks like a (possibly abbreviated) git
// SHA: between 4 and 40 characters long, inclusive, and made up solely of
// hexadecimal digits in either case.
func isValidGitSHA(sha string) bool {
	if n := len(sha); n < 4 || n > 40 {
		return false
	}

	for _, r := range sha {
		switch {
		case '0' <= r && r <= '9':
		case 'a' <= r && r <= 'f':
		case 'A' <= r && r <= 'F':
		default:
			// Anything outside 0-9, a-f, A-F disqualifies the string.
			return false
		}
	}
	return true
}
Philip Zeyligerd1402952025-04-23 03:54:37 +00001231
// getGitOrigin returns the whitespace-trimmed URL configured as
// remote.origin.url for the repository at dir, as reported by
// "git config --get remote.origin.url".
//
// The lookup is best-effort: any failure to run the command (including git
// exiting non-zero) yields an empty string rather than an error.
func getGitOrigin(ctx context.Context, dir string) string {
	cmd := exec.CommandContext(ctx, "git", "config", "--get", "remote.origin.url")
	cmd.Dir = dir
	out, err := cmd.Output()
	if err != nil {
		// Deliberately swallowed: callers only care whether an origin URL is available.
		return ""
	}
	return strings.TrimSpace(string(out))
}