blob: f7e70338bf5de38b4014f6aed0127edc49291c85 [file] [log] [blame]
Earl Lee2e463fb2025-04-17 11:22:22 -07001package loop
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "net/http"
9 "os"
10 "os/exec"
11 "runtime/debug"
12 "slices"
13 "strings"
14 "sync"
15 "time"
16
17 "sketch.dev/ant"
18 "sketch.dev/claudetool"
19)
20
const (
	// userCancelMessage is the content of the error-typed AgentMessage that
	// InnerLoop pushes to the outbox when the inner loop was cancelled while
	// tool calls were pending.
	userCancelMessage = "user requested agent to stop handling responses"
)
24
// CodingAgent is the interface through which clients drive a coding agent:
// initializing it, exchanging messages with it, cancelling in-flight work,
// and inspecting its history and environment.
type CodingAgent interface {
	// Init initializes an agent inside a docker container.
	Init(AgentInit) error

	// Ready returns a channel that is closed after Init has completed successfully.
	Ready() <-chan struct{}

	// URL reports the HTTP URL of this agent.
	URL() string

	// UserMessage enqueues a message to the agent and returns immediately.
	UserMessage(ctx context.Context, msg string)

	// WaitForMessage blocks until the agent has a response to give.
	// Use AgentMessage.EndOfTurn to help determine if you want to
	// drain the agent.
	WaitForMessage(ctx context.Context) AgentMessage

	// Loop begins the agent loop and returns only when ctx is cancelled.
	Loop(ctx context.Context)

	// CancelInnerLoop cancels the currently running inner loop iteration,
	// recording cause as the reason for the cancellation.
	CancelInnerLoop(cause error)

	// CancelToolUse requests cancellation of the tool use identified by
	// toolUseID, with cause as the reason.
	CancelToolUse(toolUseID string, cause error) error

	// Messages returns a subset of the agent's message history:
	// the messages with indices in [start, end).
	Messages(start int, end int) []AgentMessage

	// MessageCount returns the current number of messages in the history.
	MessageCount() int

	// TotalUsage reports the agent's cumulative usage.
	TotalUsage() ant.CumulativeUsage
	// OriginalBudget returns the budget the agent was configured with.
	OriginalBudget() ant.Budget

	// WaitForMessageCount returns when the agent has more than greaterThan
	// messages or the context is done.
	WaitForMessageCount(ctx context.Context, greaterThan int)

	// WorkingDir returns the agent's working directory.
	WorkingDir() string

	// Diff returns a unified diff of changes made since the agent was instantiated.
	// If commit is non-nil, it shows the diff for just that specific commit.
	Diff(commit *string) (string, error)

	// InitialCommit returns the Git commit hash that was saved when the agent was instantiated.
	InitialCommit() string

	// Title returns the current title of the conversation.
	Title() string

	// OS returns the operating system of the client.
	OS() string

	// OutstandingLLMCallCount returns the number of outstanding LLM calls.
	OutstandingLLMCallCount() int

	// OutstandingToolCalls returns the names of outstanding tool calls.
	OutstandingToolCalls() []string
	// OutsideOS returns the operating system of the outside system.
	OutsideOS() string
	// OutsideHostname returns the hostname of the outside system.
	OutsideHostname() string
	// OutsideWorkingDir returns the working directory on the outside system.
	OutsideWorkingDir() string
	// GitOrigin returns the URL of the git remote 'origin' if it exists.
	GitOrigin() string
}
87
// CodingAgentMessageType identifies the kind of an AgentMessage. Its string
// value is what appears in the message's JSON "type" field.
type CodingAgentMessageType string

const (
	UserMessageType    CodingAgentMessageType = "user"   // messages submitted via UserMessage
	AgentMessageType   CodingAgentMessageType = "agent"  // LLM responses (see OnResponse)
	ErrorMessageType   CodingAgentMessageType = "error"  // errors surfaced to the client
	BudgetMessageType  CodingAgentMessageType = "budget" // dedicated for "out of budget" errors
	ToolUseMessageType CodingAgentMessageType = "tool"   // tool results (see OnToolResult)
	CommitMessageType  CodingAgentMessageType = "commit" // for displaying git commits
	AutoMessageType    CodingAgentMessageType = "auto"   // for automated notifications like autoformatting

	// cancelToolUseMessage is the text InnerLoop appends to the content sent
	// to the LLM after the inner loop was cancelled during tool use.
	cancelToolUseMessage = "Stop responding to my previous message. Wait for me to ask you something else before attempting to use any more tools."
)
101
// AgentMessage is one entry in the agent's message history. The same value is
// appended to the history and sent on the outbox by pushToOutbox.
type AgentMessage struct {
	// Type is the kind of message (user, agent, tool, error, ...).
	Type CodingAgentMessageType `json:"type"`

	// EndOfTurn indicates that the AI is done working and is ready for the next user input.
	EndOfTurn bool `json:"end_of_turn"`

	// Content is the text of the message. The Tool* fields below are
	// populated for tool result messages (see OnToolResult).
	Content    string `json:"content"`
	ToolName   string `json:"tool_name,omitempty"`
	ToolInput  string `json:"input,omitempty"`
	ToolResult string `json:"tool_result,omitempty"`
	ToolError  bool   `json:"tool_error,omitempty"`
	ToolCallId string `json:"tool_call_id,omitempty"`

	// ToolCalls is a list of all tool calls requested in this message (name and input pairs)
	ToolCalls []ToolCall `json:"tool_calls,omitempty"`

	// ToolResponses is a list of the messages carrying responses to the tool calls requested in this message
	ToolResponses []AgentMessage `json:"toolResponses,omitempty"`

	// Commits is a list of git commits for a commit message
	Commits []*GitCommit `json:"commits,omitempty"`

	// Timestamp is set by pushToOutbox to the current time when left zero.
	Timestamp time.Time `json:"timestamp"`
	// ConversationID and ParentConversationID are filled in by SetConvo.
	ConversationID       string     `json:"conversation_id"`
	ParentConversationID *string    `json:"parent_conversation_id,omitempty"`
	Usage                *ant.Usage `json:"usage,omitempty"`

	// Message timing information
	StartTime *time.Time     `json:"start_time,omitempty"`
	EndTime   *time.Time     `json:"end_time,omitempty"`
	Elapsed   *time.Duration `json:"elapsed,omitempty"`

	// Turn duration - the time taken for a complete agent turn
	TurnDuration *time.Duration `json:"turnDuration,omitempty"`

	// Idx is the message's index in the agent's history; pushToOutbox assigns it.
	Idx int `json:"idx"`
}
138
Josh Bleecher Snyder50a1d622025-04-29 09:59:03 -0700139// SetConvo sets m.ConversationID and m.ParentConversationID based on convo.
140func (m *AgentMessage) SetConvo(convo *ant.Convo) {
141 if convo == nil {
142 m.ConversationID = ""
143 m.ParentConversationID = nil
144 return
145 }
146 m.ConversationID = convo.ID
147 if convo.Parent != nil {
148 m.ParentConversationID = &convo.Parent.ID
149 }
150}
151
// GitCommit represents a single git commit for a commit message.
// It is carried in AgentMessage.Commits.
type GitCommit struct {
	Hash         string `json:"hash"`                    // Full commit hash
	Subject      string `json:"subject"`                 // Commit subject line
	Body         string `json:"body"`                    // Full commit message body
	PushedBranch string `json:"pushed_branch,omitempty"` // If set, this commit was pushed to this branch
}
159
// ToolCall represents a single tool call within an agent message.
// OnResponse fills in Name, Input and ToolCallId from the LLM response;
// the remaining fields are not set anywhere in this part of the file.
type ToolCall struct {
	Name          string        `json:"name"`
	Input         string        `json:"input"`
	ToolCallId    string        `json:"tool_call_id"`
	ResultMessage *AgentMessage `json:"result_message,omitempty"`
	Args          string        `json:"args,omitempty"`
	Result        string        `json:"result,omitempty"`
}
169
170func (a *AgentMessage) Attr() slog.Attr {
171 var attrs []any = []any{
172 slog.String("type", string(a.Type)),
173 }
174 if a.EndOfTurn {
175 attrs = append(attrs, slog.Bool("end_of_turn", a.EndOfTurn))
176 }
177 if a.Content != "" {
178 attrs = append(attrs, slog.String("content", a.Content))
179 }
180 if a.ToolName != "" {
181 attrs = append(attrs, slog.String("tool_name", a.ToolName))
182 }
183 if a.ToolInput != "" {
184 attrs = append(attrs, slog.String("tool_input", a.ToolInput))
185 }
186 if a.Elapsed != nil {
187 attrs = append(attrs, slog.Int64("elapsed", a.Elapsed.Nanoseconds()))
188 }
189 if a.TurnDuration != nil {
190 attrs = append(attrs, slog.Int64("turnDuration", a.TurnDuration.Nanoseconds()))
191 }
192 if a.ToolResult != "" {
193 attrs = append(attrs, slog.String("tool_result", a.ToolResult))
194 }
195 if a.ToolError {
196 attrs = append(attrs, slog.Bool("tool_error", a.ToolError))
197 }
198 if len(a.ToolCalls) > 0 {
199 toolCallAttrs := make([]any, 0, len(a.ToolCalls))
200 for i, tc := range a.ToolCalls {
201 toolCallAttrs = append(toolCallAttrs, slog.Group(
202 fmt.Sprintf("tool_call_%d", i),
203 slog.String("name", tc.Name),
204 slog.String("input", tc.Input),
205 ))
206 }
207 attrs = append(attrs, slog.Group("tool_calls", toolCallAttrs...))
208 }
209 if a.ConversationID != "" {
210 attrs = append(attrs, slog.String("convo_id", a.ConversationID))
211 }
212 if a.ParentConversationID != nil {
213 attrs = append(attrs, slog.String("parent_convo_id", *a.ParentConversationID))
214 }
215 if a.Usage != nil && !a.Usage.IsZero() {
216 attrs = append(attrs, a.Usage.Attr())
217 }
218 // TODO: timestamp, convo ids, idx?
219 return slog.Group("agent_message", attrs...)
220}
221
222func errorMessage(err error) AgentMessage {
223 // It's somewhat unknowable whether error messages are "end of turn" or not, but it seems like the best approach.
224 if os.Getenv(("DEBUG")) == "1" {
225 return AgentMessage{Type: ErrorMessageType, Content: err.Error() + " Stacktrace: " + string(debug.Stack()), EndOfTurn: true}
226 }
227
228 return AgentMessage{Type: ErrorMessageType, Content: err.Error(), EndOfTurn: true}
229}
230
231func budgetMessage(err error) AgentMessage {
232 return AgentMessage{Type: BudgetMessageType, Content: err.Error(), EndOfTurn: true}
233}
234
// ConvoInterface defines the interface for conversation interactions.
// It lists the operations Agent performs on its conversation; Init stores the
// *ant.Convo built by initConvo in a field of this type.
type ConvoInterface interface {
	CumulativeUsage() ant.CumulativeUsage
	ResetBudget(ant.Budget)
	// OverBudget returns a non-nil error when the conversation has exceeded its budget.
	OverBudget() error
	SendMessage(message ant.Message) (*ant.MessageResponse, error)
	SendUserTextMessage(s string, otherContents ...ant.Content) (*ant.MessageResponse, error)
	// ToolResultContents is used by InnerLoop to obtain the tool results for resp.
	ToolResultContents(ctx context.Context, resp *ant.MessageResponse) ([]ant.Content, error)
	// ToolResultCancelContents is used by InnerLoop instead of ToolResultContents
	// when the inner loop has been cancelled, so that no tools are run.
	ToolResultCancelContents(resp *ant.MessageResponse) ([]ant.Content, error)
	CancelToolUse(toolUseID string, cause error) error
}
246
// Agent holds the state of a coding agent: its conversation, git and
// environment details, the inbox/outbox channels used to exchange messages
// with clients, and the message history. Create one with NewAgent and call
// Init before use.
type Agent struct {
	convo          ConvoInterface // nil until Init succeeds
	config         AgentConfig    // config for this agent
	workingDir     string
	repoRoot       string // workingDir may be a subdir of repoRoot
	url            string
	lastHEAD       string        // hash of the last HEAD that was pushed to the host (only when under docker)
	initialCommit  string        // hash of the Git HEAD when the agent was instantiated or Init()
	gitRemoteAddr  string        // HTTP URL of the host git repo (only when under docker)
	ready          chan struct{} // closed when the agent is initialized (only when under docker)
	startedAt      time.Time
	originalBudget ant.Budget
	title          string
	codereview     *claudetool.CodeReviewer
	// Outside information
	outsideHostname   string
	outsideOS         string
	outsideWorkingDir string
	// URL of the git remote 'origin' if it exists
	gitOrigin string

	// Time when the current turn started (reset at the beginning of InnerLoop)
	startOfTurn time.Time

	// Inbox - for messages from the user to the agent.
	//  sent on by UserMessage
	//   . e.g. when user types into the chat textarea
	//  read from by GatherMessages
	inbox chan string

	// Outbox
	//  sent on by pushToOutbox
	//   via OnToolResult and OnResponse callbacks
	//  read from by WaitForMessage
	//   called by termui inside its repl loop.
	outbox chan AgentMessage

	// protects cancelInnerLoop
	cancelInnerLoopMu sync.Mutex
	// cancels potentially long-running tool_use calls or chains of them
	cancelInnerLoop context.CancelCauseFunc

	// protects following
	mu sync.Mutex

	// Stores all messages for this agent
	history []AgentMessage

	// Channels that are closed, and then dropped from this slice, whenever
	// a message is pushed to the outbox or the title changes.
	listeners []chan struct{}

	// Track git commits we've already seen (by hash)
	seenCommits map[string]bool

	// Track outstanding LLM call IDs
	outstandingLLMCalls map[string]struct{}

	// Track outstanding tool calls by ID with their names
	outstandingToolCalls map[string]string
}
306
// URL implements CodingAgent. Init sets the URL to "http://" + HostAddr when
// running in docker with a non-empty HostAddr.
func (a *Agent) URL() string { return a.url }
308
// Title returns the current title of the conversation.
// If no title has been set, returns an empty string.
// The read is guarded by a.mu, which SetTitle also holds while writing.
func (a *Agent) Title() string {
	a.mu.Lock()
	defer a.mu.Unlock()
	return a.title
}
316
// OutstandingLLMCallCount returns the number of outstanding LLM calls,
// i.e. calls recorded by OnRequest that OnResponse has not yet removed.
func (a *Agent) OutstandingLLMCallCount() int {
	a.mu.Lock()
	defer a.mu.Unlock()
	return len(a.outstandingLLMCalls)
}
323
324// OutstandingToolCalls returns the names of outstanding tool calls.
325func (a *Agent) OutstandingToolCalls() []string {
326 a.mu.Lock()
327 defer a.mu.Unlock()
328
329 tools := make([]string, 0, len(a.outstandingToolCalls))
330 for _, toolName := range a.outstandingToolCalls {
331 tools = append(tools, toolName)
332 }
333 return tools
334}
335
// OS returns the operating system of the client, as supplied in
// AgentConfig.ClientGOOS.
func (a *Agent) OS() string {
	return a.config.ClientGOOS
}
340
// OutsideOS returns the operating system of the outside system, as supplied
// in AgentConfig.OutsideOS.
func (a *Agent) OutsideOS() string {
	return a.outsideOS
}
345
// OutsideHostname returns the hostname of the outside system, as supplied in
// AgentConfig.OutsideHostname.
func (a *Agent) OutsideHostname() string {
	return a.outsideHostname
}
350
// OutsideWorkingDir returns the working directory on the outside system, as
// supplied in AgentConfig.OutsideWorkingDir.
func (a *Agent) OutsideWorkingDir() string {
	return a.outsideWorkingDir
}
355
// GitOrigin returns the URL of the git remote 'origin' if it exists.
// The value is looked up once, during Init, and only when git is in use.
func (a *Agent) GitOrigin() string {
	return a.gitOrigin
}
360
Earl Lee2e463fb2025-04-17 11:22:22 -0700361// SetTitle sets the title of the conversation.
362func (a *Agent) SetTitle(title string) {
363 a.mu.Lock()
364 defer a.mu.Unlock()
365 a.title = title
366 // Notify all listeners that the state has changed
367 for _, ch := range a.listeners {
368 close(ch)
369 }
370 a.listeners = a.listeners[:0]
371}
372
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000373// OnToolCall implements ant.Listener and tracks the start of a tool call.
374func (a *Agent) OnToolCall(ctx context.Context, convo *ant.Convo, id string, toolName string, toolInput json.RawMessage, content ant.Content) {
375 // Track the tool call
376 a.mu.Lock()
377 a.outstandingToolCalls[id] = toolName
378 a.mu.Unlock()
379}
380
Earl Lee2e463fb2025-04-17 11:22:22 -0700381// OnToolResult implements ant.Listener.
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000382func (a *Agent) OnToolResult(ctx context.Context, convo *ant.Convo, toolID string, toolName string, toolInput json.RawMessage, content ant.Content, result *string, err error) {
383 // Remove the tool call from outstanding calls
384 a.mu.Lock()
385 delete(a.outstandingToolCalls, toolID)
386 a.mu.Unlock()
387
Earl Lee2e463fb2025-04-17 11:22:22 -0700388 m := AgentMessage{
389 Type: ToolUseMessageType,
390 Content: content.Text,
391 ToolResult: content.ToolResult,
392 ToolError: content.ToolError,
393 ToolName: toolName,
394 ToolInput: string(toolInput),
395 ToolCallId: content.ToolUseID,
396 StartTime: content.StartTime,
397 EndTime: content.EndTime,
398 }
399
400 // Calculate the elapsed time if both start and end times are set
401 if content.StartTime != nil && content.EndTime != nil {
402 elapsed := content.EndTime.Sub(*content.StartTime)
403 m.Elapsed = &elapsed
404 }
405
Josh Bleecher Snyder50a1d622025-04-29 09:59:03 -0700406 m.SetConvo(convo)
Earl Lee2e463fb2025-04-17 11:22:22 -0700407 a.pushToOutbox(ctx, m)
408}
409
410// OnRequest implements ant.Listener.
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000411func (a *Agent) OnRequest(ctx context.Context, convo *ant.Convo, id string, msg *ant.Message) {
412 a.mu.Lock()
413 defer a.mu.Unlock()
414 a.outstandingLLMCalls[id] = struct{}{}
Earl Lee2e463fb2025-04-17 11:22:22 -0700415 // We already get tool results from the above. We send user messages to the outbox in the agent loop.
416}
417
418// OnResponse implements ant.Listener. Responses contain messages from the LLM
419// that need to be displayed (as well as tool calls that we send along when
420// they're done). (It would be reasonable to also mention tool calls when they're
421// started, but we don't do that yet.)
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000422func (a *Agent) OnResponse(ctx context.Context, convo *ant.Convo, id string, resp *ant.MessageResponse) {
423 // Remove the LLM call from outstanding calls
424 a.mu.Lock()
425 delete(a.outstandingLLMCalls, id)
426 a.mu.Unlock()
427
Josh Bleecher Snyder50a1d622025-04-29 09:59:03 -0700428 if resp == nil {
429 // LLM API call failed
430 m := AgentMessage{
431 Type: ErrorMessageType,
432 Content: "API call failed, type 'continue' to try again",
433 }
434 m.SetConvo(convo)
435 a.pushToOutbox(ctx, m)
436 return
437 }
438
Earl Lee2e463fb2025-04-17 11:22:22 -0700439 endOfTurn := false
440 if resp.StopReason != ant.StopReasonToolUse {
441 endOfTurn = true
442 }
443 m := AgentMessage{
444 Type: AgentMessageType,
445 Content: collectTextContent(resp),
446 EndOfTurn: endOfTurn,
447 Usage: &resp.Usage,
448 StartTime: resp.StartTime,
449 EndTime: resp.EndTime,
450 }
451
452 // Extract any tool calls from the response
453 if resp.StopReason == ant.StopReasonToolUse {
454 var toolCalls []ToolCall
455 for _, part := range resp.Content {
456 if part.Type == "tool_use" {
457 toolCalls = append(toolCalls, ToolCall{
458 Name: part.ToolName,
459 Input: string(part.ToolInput),
460 ToolCallId: part.ID,
461 })
462 }
463 }
464 m.ToolCalls = toolCalls
465 }
466
467 // Calculate the elapsed time if both start and end times are set
468 if resp.StartTime != nil && resp.EndTime != nil {
469 elapsed := resp.EndTime.Sub(*resp.StartTime)
470 m.Elapsed = &elapsed
471 }
472
Josh Bleecher Snyder50a1d622025-04-29 09:59:03 -0700473 m.SetConvo(convo)
Earl Lee2e463fb2025-04-17 11:22:22 -0700474 a.pushToOutbox(ctx, m)
475}
476
// WorkingDir implements CodingAgent. It returns the working directory that
// was passed to Init.
func (a *Agent) WorkingDir() string {
	return a.workingDir
}
481
// MessageCount implements CodingAgent. It returns the number of messages
// currently in the history, read under a.mu.
func (a *Agent) MessageCount() int {
	a.mu.Lock()
	defer a.mu.Unlock()
	return len(a.history)
}
488
489// Messages implements CodingAgent.
490func (a *Agent) Messages(start int, end int) []AgentMessage {
491 a.mu.Lock()
492 defer a.mu.Unlock()
493 return slices.Clone(a.history[start:end])
494}
495
// OriginalBudget implements CodingAgent. It returns the budget the agent was
// created with (AgentConfig.Budget), which overBudget uses when resetting
// the conversation's budget.
func (a *Agent) OriginalBudget() ant.Budget {
	return a.originalBudget
}
499
// AgentConfig contains configuration for creating a new Agent.
type AgentConfig struct {
	// Context is used for the git commands run by Init and is passed to
	// the conversation created by initConvo.
	Context context.Context
	// AntURL, when non-empty, overrides the conversation's URL.
	AntURL string
	APIKey string
	// HTTPC, when non-nil, overrides the conversation's HTTP client.
	HTTPC  *http.Client
	Budget ant.Budget
	// GitUsername and GitEmail are handed to the done tool.
	GitUsername string
	GitEmail    string
	SessionID   string
	// ClientGOOS and ClientGOARCH are reported to the LLM as the platform.
	ClientGOOS   string
	ClientGOARCH string
	// UseAnthropicEdit selects the Anthropic edit tool instead of the patch tool.
	UseAnthropicEdit bool
	// Outside information
	OutsideHostname   string
	OutsideOS         string
	OutsideWorkingDir string
}
518
519// NewAgent creates a new Agent.
520// It is not usable until Init() is called.
521func NewAgent(config AgentConfig) *Agent {
522 agent := &Agent{
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000523 config: config,
524 ready: make(chan struct{}),
525 inbox: make(chan string, 100),
526 outbox: make(chan AgentMessage, 100),
527 startedAt: time.Now(),
528 originalBudget: config.Budget,
529 seenCommits: make(map[string]bool),
530 outsideHostname: config.OutsideHostname,
531 outsideOS: config.OutsideOS,
532 outsideWorkingDir: config.OutsideWorkingDir,
533 outstandingLLMCalls: make(map[string]struct{}),
534 outstandingToolCalls: make(map[string]string),
Earl Lee2e463fb2025-04-17 11:22:22 -0700535 }
536 return agent
537}
538
// AgentInit carries the arguments to Agent.Init.
type AgentInit struct {
	// WorkingDir is the directory the agent works in; git commands run there.
	WorkingDir string
	// NoGit skips repo root discovery, HEAD resolution, code reviewer setup
	// and the origin lookup.
	NoGit bool // only for testing

	// InDocker makes Init stash local changes, add and fetch the
	// "sketch-host" remote at GitRemoteAddr, and force-checkout Commit.
	InDocker      bool
	Commit        string
	GitRemoteAddr string
	// HostAddr, when non-empty (and InDocker is set), determines the agent's URL.
	HostAddr string
}
548
549func (a *Agent) Init(ini AgentInit) error {
Josh Bleecher Snyder9c07e1d2025-04-28 19:25:37 -0700550 if a.convo != nil {
551 return fmt.Errorf("Agent.Init: already initialized")
552 }
Earl Lee2e463fb2025-04-17 11:22:22 -0700553 ctx := a.config.Context
554 if ini.InDocker {
555 cmd := exec.CommandContext(ctx, "git", "stash")
556 cmd.Dir = ini.WorkingDir
557 if out, err := cmd.CombinedOutput(); err != nil {
558 return fmt.Errorf("git stash: %s: %v", out, err)
559 }
Philip Zeyligerd0ac1ea2025-04-21 20:04:19 -0700560 cmd = exec.CommandContext(ctx, "git", "remote", "add", "sketch-host", ini.GitRemoteAddr)
561 cmd.Dir = ini.WorkingDir
562 if out, err := cmd.CombinedOutput(); err != nil {
563 return fmt.Errorf("git remote add: %s: %v", out, err)
564 }
565 cmd = exec.CommandContext(ctx, "git", "fetch", "sketch-host")
Earl Lee2e463fb2025-04-17 11:22:22 -0700566 cmd.Dir = ini.WorkingDir
567 if out, err := cmd.CombinedOutput(); err != nil {
568 return fmt.Errorf("git fetch: %s: %w", out, err)
569 }
570 cmd = exec.CommandContext(ctx, "git", "checkout", "-f", ini.Commit)
571 cmd.Dir = ini.WorkingDir
572 if out, err := cmd.CombinedOutput(); err != nil {
573 return fmt.Errorf("git checkout %s: %s: %w", ini.Commit, out, err)
574 }
Earl Lee2e463fb2025-04-17 11:22:22 -0700575 a.lastHEAD = ini.Commit
576 a.gitRemoteAddr = ini.GitRemoteAddr
577 a.initialCommit = ini.Commit
578 if ini.HostAddr != "" {
579 a.url = "http://" + ini.HostAddr
580 }
581 }
582 a.workingDir = ini.WorkingDir
583
584 if !ini.NoGit {
585 repoRoot, err := repoRoot(ctx, a.workingDir)
586 if err != nil {
587 return fmt.Errorf("repoRoot: %w", err)
588 }
589 a.repoRoot = repoRoot
590
591 commitHash, err := resolveRef(ctx, a.repoRoot, "HEAD")
592 if err != nil {
593 return fmt.Errorf("resolveRef: %w", err)
594 }
595 a.initialCommit = commitHash
596
597 codereview, err := claudetool.NewCodeReviewer(ctx, a.repoRoot, a.initialCommit)
598 if err != nil {
599 return fmt.Errorf("Agent.Init: claudetool.NewCodeReviewer: %w", err)
600 }
601 a.codereview = codereview
Philip Zeyligerd1402952025-04-23 03:54:37 +0000602
603 a.gitOrigin = getGitOrigin(ctx, ini.WorkingDir)
Earl Lee2e463fb2025-04-17 11:22:22 -0700604 }
605 a.lastHEAD = a.initialCommit
606 a.convo = a.initConvo()
607 close(a.ready)
608 return nil
609}
610
// initConvo initializes the conversation.
// It must not be called until all agent fields are initialized,
// particularly workingDir and git.
//
// The returned conversation carries the system prompt, the registered tool
// set, the configured budget, and the agent itself as listener.
func (a *Agent) initConvo() *ant.Convo {
	ctx := a.config.Context
	convo := ant.NewConvo(ctx, a.config.APIKey)
	// Optional overrides from the config.
	if a.config.HTTPC != nil {
		convo.HTTPC = a.config.HTTPC
	}
	if a.config.AntURL != "" {
		convo.URL = a.config.AntURL
	}
	convo.PromptCaching = true
	convo.Budget = a.config.Budget

	// The editing instructions must name the edit tool that is registered below.
	var editPrompt string
	if a.config.UseAnthropicEdit {
		editPrompt = "Then use the str_replace_editor tool to make those edits. For short complete file replacements, you may use the bash tool with cat and heredoc stdin."
	} else {
		editPrompt = "Then use the patch tool to make those edits. Combine all edits to any given file into a single patch tool call."
	}

	// The verbs are filled, in order, with: the edit instructions, client
	// OS/arch, working directory, repository root, and initial commit.
	convo.SystemPrompt = fmt.Sprintf(`
You are an expert coding assistant and architect, with a specialty in Go.
You are assisting the user to achieve their goals.

Start by asking concise clarifying questions as needed.
Once the intent is clear, work autonomously.

Call the title tool early in the conversation to provide a brief summary of
what the chat is about.

Break down the overall goal into a series of smaller steps.
(The first step is often: "Make a plan.")
Then execute each step using tools.
Update the plan if you have encountered problems or learned new information.

When in doubt about a step, follow this broad workflow:

- Think about how the current step fits into the overall plan.
- Do research. Good tool choices: bash, think, keyword_search
- Make edits.
- Repeat.

To make edits reliably and efficiently, first think about the intent of the edit,
and what set of patches will achieve that intent.
%s

For renames or refactors, consider invoking gopls (via bash).

The done tool provides a checklist of items you MUST verify and
review before declaring that you are done. Before executing
the done tool, run all the tools the done tool checklist asks
for, including creating a git commit. Do not forget to run tests.

<platform>
%s/%s
</platform>
<pwd>
%v
</pwd>
<git_root>
%v
</git_root>
<HEAD>
%v
</HEAD>
`, editPrompt, a.config.ClientGOOS, a.config.ClientGOARCH, a.workingDir, a.repoRoot, a.initialCommit)

	// Register all tools with the conversation
	// When adding, removing, or modifying tools here, double-check that the termui tool display
	// template in termui/termui.go has pretty-printing support for all tools.
	// NOTE(review): a.codereview is only set by Init when NoGit is false;
	// confirm that Tool() and makeDoneTool tolerate a nil reviewer.
	convo.Tools = []*ant.Tool{
		claudetool.Bash, claudetool.Keyword,
		claudetool.Think, a.titleTool(), makeDoneTool(a.codereview, a.config.GitUsername, a.config.GitEmail),
		a.codereview.Tool(),
	}
	if a.config.UseAnthropicEdit {
		convo.Tools = append(convo.Tools, claudetool.AnthropicEditTool)
	} else {
		convo.Tools = append(convo.Tools, claudetool.Patch)
	}
	// The agent receives the conversation's callbacks (OnRequest, OnResponse, ...).
	convo.Listener = a
	return convo
}
696
697func (a *Agent) titleTool() *ant.Tool {
698 // titleTool creates the title tool that sets the conversation title.
699 title := &ant.Tool{
700 Name: "title",
701 Description: `Use this tool early in the conversation, BEFORE MAKING ANY GIT COMMITS, to summarize what the chat is about briefly.`,
702 InputSchema: json.RawMessage(`{
703 "type": "object",
704 "properties": {
705 "title": {
706 "type": "string",
707 "description": "A brief title summarizing what this chat is about"
708 }
709 },
710 "required": ["title"]
711}`),
712 Run: func(ctx context.Context, input json.RawMessage) (string, error) {
713 var params struct {
714 Title string `json:"title"`
715 }
716 if err := json.Unmarshal(input, &params); err != nil {
717 return "", err
718 }
719 a.SetTitle(params.Title)
720 return fmt.Sprintf("Title set to: %s", params.Title), nil
721 },
722 }
723 return title
724}
725
// Ready implements CodingAgent. The returned channel is created by NewAgent
// and closed by Init once initialization has succeeded.
func (a *Agent) Ready() <-chan struct{} {
	return a.ready
}
729
730func (a *Agent) UserMessage(ctx context.Context, msg string) {
731 a.pushToOutbox(ctx, AgentMessage{Type: UserMessageType, Content: msg})
732 a.inbox <- msg
733}
734
735func (a *Agent) WaitForMessage(ctx context.Context) AgentMessage {
736 // TODO: Should this drain any outbox messages in case there are multiple?
737 select {
738 case msg := <-a.outbox:
739 return msg
740 case <-ctx.Done():
741 return errorMessage(ctx.Err())
742 }
743}
744
// CancelToolUse implements CodingAgent by delegating to the conversation's
// CancelToolUse and returning its result.
// NOTE(review): a.convo is nil until Init succeeds, so calling this before
// Init would panic — confirm callers wait on Ready first.
func (a *Agent) CancelToolUse(toolUseID string, cause error) error {
	return a.convo.CancelToolUse(toolUseID, cause)
}
748
749func (a *Agent) CancelInnerLoop(cause error) {
750 a.cancelInnerLoopMu.Lock()
751 defer a.cancelInnerLoopMu.Unlock()
752 if a.cancelInnerLoop != nil {
753 a.cancelInnerLoop(cause)
754 }
755}
756
757func (a *Agent) Loop(ctxOuter context.Context) {
758 for {
759 select {
760 case <-ctxOuter.Done():
761 return
762 default:
763 ctxInner, cancel := context.WithCancelCause(ctxOuter)
764 a.cancelInnerLoopMu.Lock()
765 // Set .cancelInnerLoop so the user can cancel whatever is happening
766 // inside InnerLoop(ctxInner) without canceling this outer Loop execution.
767 // This CancelInnerLoop func is intended be called from other goroutines,
768 // hence the mutex.
769 a.cancelInnerLoop = cancel
770 a.cancelInnerLoopMu.Unlock()
771 a.InnerLoop(ctxInner)
772 cancel(nil)
773 }
774 }
775}
776
777func (a *Agent) pushToOutbox(ctx context.Context, m AgentMessage) {
778 if m.Timestamp.IsZero() {
779 m.Timestamp = time.Now()
780 }
781
782 // If this is an end-of-turn message, calculate the turn duration and add it to the message
783 if m.EndOfTurn && m.Type == AgentMessageType {
784 turnDuration := time.Since(a.startOfTurn)
785 m.TurnDuration = &turnDuration
786 slog.InfoContext(ctx, "Turn completed", "turnDuration", turnDuration)
787 }
788
789 slog.InfoContext(ctx, "agent message", m.Attr())
790
791 a.mu.Lock()
792 defer a.mu.Unlock()
793 m.Idx = len(a.history)
794 a.history = append(a.history, m)
795 a.outbox <- m
796
797 // Notify all listeners:
798 for _, ch := range a.listeners {
799 close(ch)
800 }
801 a.listeners = a.listeners[:0]
802}
803
804func (a *Agent) GatherMessages(ctx context.Context, block bool) ([]ant.Content, error) {
805 var m []ant.Content
806 if block {
807 select {
808 case <-ctx.Done():
809 return m, ctx.Err()
810 case msg := <-a.inbox:
811 m = append(m, ant.Content{Type: "text", Text: msg})
812 }
813 }
814 for {
815 select {
816 case msg := <-a.inbox:
817 m = append(m, ant.Content{Type: "text", Text: msg})
818 default:
819 return m, nil
820 }
821 }
822}
823
824func (a *Agent) InnerLoop(ctx context.Context) {
825 // Reset the start of turn time
826 a.startOfTurn = time.Now()
827
828 // Wait for at least one message from the user.
829 msgs, err := a.GatherMessages(ctx, true)
830 if err != nil { // e.g. the context was canceled while blocking in GatherMessages
831 return
832 }
833 // We do this as we go, but let's also do it at the end of the turn
834 defer func() {
835 if _, err := a.handleGitCommits(ctx); err != nil {
836 // Just log the error, don't stop execution
837 slog.WarnContext(ctx, "Failed to check for new git commits", "error", err)
838 }
839 }()
840
841 userMessage := ant.Message{
842 Role: "user",
843 Content: msgs,
844 }
845 // convo.SendMessage does the actual network call to send this to anthropic. This blocks until the response is ready.
846 // TODO: pass ctx to SendMessage, and figure out how to square that ctx with convo's own .Ctx. Who owns the scope of this call?
847 resp, err := a.convo.SendMessage(userMessage)
848 if err != nil {
849 a.pushToOutbox(ctx, errorMessage(err))
850 return
851 }
852 for {
853 // TODO: here and below where we check the budget,
854 // we should review the UX: is it clear what happened?
855 // is it clear how to resume?
856 // should we let the user set a new budget?
857 if err := a.overBudget(ctx); err != nil {
858 return
859 }
860 if resp.StopReason != ant.StopReasonToolUse {
861 break
862 }
863 var results []ant.Content
864 cancelled := false
865 select {
866 case <-ctx.Done():
867 // Don't actually run any of the tools, but rather build a response
868 // for each tool_use message letting the LLM know that user canceled it.
869 results, err = a.convo.ToolResultCancelContents(resp)
870 if err != nil {
871 a.pushToOutbox(ctx, errorMessage(err))
872 }
873 cancelled = true
874 default:
875 ctx = claudetool.WithWorkingDir(ctx, a.workingDir)
876 // fall-through, when the user has not canceled the inner loop:
877 results, err = a.convo.ToolResultContents(ctx, resp)
878 if ctx.Err() != nil { // e.g. the user canceled the operation
879 cancelled = true
880 } else if err != nil {
881 a.pushToOutbox(ctx, errorMessage(err))
882 }
883 }
884
885 // Check for git commits. Currently we do this here, after we collect
886 // tool results, since that's when we know commits could have happened.
887 // We could instead do this when the turn ends, but I think it makes sense
888 // to do this as we go.
889 newCommits, err := a.handleGitCommits(ctx)
890 if err != nil {
891 // Just log the error, don't stop execution
892 slog.WarnContext(ctx, "Failed to check for new git commits", "error", err)
893 }
894 var autoqualityMessages []string
895 if len(newCommits) == 1 {
896 formatted := a.codereview.Autoformat(ctx)
897 if len(formatted) > 0 {
898 msg := fmt.Sprintf(`
899I ran autoformatters and they updated these files:
900
901%s
902
903Please amend your latest git commit with these changes and then continue with what you were doing.`,
904 strings.Join(formatted, "\n"),
905 )[1:]
906 a.pushToOutbox(ctx, AgentMessage{
907 Type: AutoMessageType,
908 Content: msg,
909 Timestamp: time.Now(),
910 })
911 autoqualityMessages = append(autoqualityMessages, msg)
912 }
913 }
914
915 if err := a.overBudget(ctx); err != nil {
916 return
917 }
918
919 // Include, along with the tool results (which must go first for whatever reason),
920 // any messages that the user has sent along while the tool_use was executing concurrently.
921 msgs, err = a.GatherMessages(ctx, false)
922 if err != nil {
923 return
924 }
925 // Inject any auto-generated messages from quality checks.
926 for _, msg := range autoqualityMessages {
927 msgs = append(msgs, ant.Content{Type: "text", Text: msg})
928 }
929 if cancelled {
930 msgs = append(msgs, ant.Content{Type: "text", Text: cancelToolUseMessage})
931 // EndOfTurn is false here so that the client of this agent keeps processing
932 // messages from WaitForMessage() and gets the response from the LLM (usually
933 // something like "okay, I'll wait further instructions", but the user should
934 // be made aware of it regardless).
935 a.pushToOutbox(ctx, AgentMessage{Type: ErrorMessageType, Content: userCancelMessage, EndOfTurn: false})
936 } else if err := a.convo.OverBudget(); err != nil {
937 budgetMsg := "We've exceeded our budget. Please ask the user to confirm before continuing by ending the turn."
938 msgs = append(msgs, ant.Content{Type: "text", Text: budgetMsg})
939 a.pushToOutbox(ctx, budgetMessage(fmt.Errorf("warning: %w (ask to keep trying, if you'd like)", err)))
940 }
941 results = append(results, msgs...)
942 resp, err = a.convo.SendMessage(ant.Message{
943 Role: "user",
944 Content: results,
945 })
946 if err != nil {
947 a.pushToOutbox(ctx, errorMessage(fmt.Errorf("error: failed to continue conversation: %s", err.Error())))
948 break
949 }
950 if cancelled {
951 return
952 }
953 }
954}
955
956func (a *Agent) overBudget(ctx context.Context) error {
957 if err := a.convo.OverBudget(); err != nil {
958 m := budgetMessage(err)
959 m.Content = m.Content + "\n\nBudget reset."
960 a.pushToOutbox(ctx, budgetMessage(err))
961 a.convo.ResetBudget(a.originalBudget)
962 return err
963 }
964 return nil
965}
966
967func collectTextContent(msg *ant.MessageResponse) string {
968 // Collect all text content
969 var allText strings.Builder
970 for _, content := range msg.Content {
971 if content.Type == "text" && content.Text != "" {
972 if allText.Len() > 0 {
973 allText.WriteString("\n\n")
974 }
975 allText.WriteString(content.Text)
976 }
977 }
978 return allText.String()
979}
980
981func (a *Agent) TotalUsage() ant.CumulativeUsage {
982 a.mu.Lock()
983 defer a.mu.Unlock()
984 return a.convo.CumulativeUsage()
985}
986
987// WaitForMessageCount returns when the agent has at more than clientMessageCount messages or the context is done.
988func (a *Agent) WaitForMessageCount(ctx context.Context, greaterThan int) {
989 for a.MessageCount() <= greaterThan {
990 a.mu.Lock()
991 ch := make(chan struct{})
992 // Deletion happens when we notify.
993 a.listeners = append(a.listeners, ch)
994 a.mu.Unlock()
995
996 select {
997 case <-ctx.Done():
998 return
999 case <-ch:
1000 continue
1001 }
1002 }
1003}
1004
1005// Diff returns a unified diff of changes made since the agent was instantiated.
1006func (a *Agent) Diff(commit *string) (string, error) {
1007 if a.initialCommit == "" {
1008 return "", fmt.Errorf("no initial commit reference available")
1009 }
1010
1011 // Find the repository root
1012 ctx := context.Background()
1013
1014 // If a specific commit hash is provided, show just that commit's changes
1015 if commit != nil && *commit != "" {
1016 // Validate that the commit looks like a valid git SHA
1017 if !isValidGitSHA(*commit) {
1018 return "", fmt.Errorf("invalid git commit SHA format: %s", *commit)
1019 }
1020
1021 // Get the diff for just this commit
1022 cmd := exec.CommandContext(ctx, "git", "show", "--unified=10", *commit)
1023 cmd.Dir = a.repoRoot
1024 output, err := cmd.CombinedOutput()
1025 if err != nil {
1026 return "", fmt.Errorf("failed to get diff for commit %s: %w - %s", *commit, err, string(output))
1027 }
1028 return string(output), nil
1029 }
1030
1031 // Otherwise, get the diff between the initial commit and the current state using exec.Command
1032 cmd := exec.CommandContext(ctx, "git", "diff", "--unified=10", a.initialCommit)
1033 cmd.Dir = a.repoRoot
1034 output, err := cmd.CombinedOutput()
1035 if err != nil {
1036 return "", fmt.Errorf("failed to get diff: %w - %s", err, string(output))
1037 }
1038
1039 return string(output), nil
1040}
1041
1042// InitialCommit returns the Git commit hash that was saved when the agent was instantiated.
1043func (a *Agent) InitialCommit() string {
1044 return a.initialCommit
1045}
1046
1047// handleGitCommits() highlights new commits to the user. When running
1048// under docker, new HEADs are pushed to a branch according to the title.
1049func (a *Agent) handleGitCommits(ctx context.Context) ([]*GitCommit, error) {
1050 if a.repoRoot == "" {
1051 return nil, nil
1052 }
1053
1054 head, err := resolveRef(ctx, a.repoRoot, "HEAD")
1055 if err != nil {
1056 return nil, err
1057 }
1058 if head == a.lastHEAD {
1059 return nil, nil // nothing to do
1060 }
1061 defer func() {
1062 a.lastHEAD = head
1063 }()
1064
1065 // Get new commits. Because it's possible that the agent does rebases, fixups, and
1066 // so forth, we use, as our fixed point, the "initialCommit", and we limit ourselves
1067 // to the last 100 commits.
1068 var commits []*GitCommit
1069
1070 // Get commits since the initial commit
1071 // Format: <hash>\0<subject>\0<body>\0
1072 // This uses NULL bytes as separators to avoid issues with newlines in commit messages
1073 // Limit to 100 commits to avoid overwhelming the user
1074 cmd := exec.CommandContext(ctx, "git", "log", "-n", "100", "--pretty=format:%H%x00%s%x00%b%x00", "^"+a.initialCommit, head)
1075 cmd.Dir = a.repoRoot
1076 output, err := cmd.Output()
1077 if err != nil {
1078 return nil, fmt.Errorf("failed to get git log: %w", err)
1079 }
1080
1081 // Parse git log output and filter out already seen commits
1082 parsedCommits := parseGitLog(string(output))
1083
1084 var headCommit *GitCommit
1085
1086 // Filter out commits we've already seen
1087 for _, commit := range parsedCommits {
1088 if commit.Hash == head {
1089 headCommit = &commit
1090 }
1091
1092 // Skip if we've seen this commit before. If our head has changed, always include that.
1093 if a.seenCommits[commit.Hash] && commit.Hash != head {
1094 continue
1095 }
1096
1097 // Mark this commit as seen
1098 a.seenCommits[commit.Hash] = true
1099
1100 // Add to our list of new commits
1101 commits = append(commits, &commit)
1102 }
1103
1104 if a.gitRemoteAddr != "" {
1105 if headCommit == nil {
1106 // I think this can only happen if we have a bug or if there's a race.
1107 headCommit = &GitCommit{}
1108 headCommit.Hash = head
1109 headCommit.Subject = "unknown"
1110 commits = append(commits, headCommit)
1111 }
1112
1113 cleanTitle := titleToBranch(a.title)
1114 if cleanTitle == "" {
1115 cleanTitle = a.config.SessionID
1116 }
1117 branch := "sketch/" + cleanTitle
1118
1119 // TODO: I don't love the force push here. We could see if the push is a fast-forward, and,
1120 // if it's not, we could make a backup with a unique name (perhaps append a timestamp) and
1121 // then use push with lease to replace.
1122 cmd = exec.Command("git", "push", "--force", a.gitRemoteAddr, "HEAD:refs/heads/"+branch)
1123 cmd.Dir = a.workingDir
1124 if out, err := cmd.CombinedOutput(); err != nil {
1125 a.pushToOutbox(ctx, errorMessage(fmt.Errorf("git push to host: %s: %v", out, err)))
1126 } else {
1127 headCommit.PushedBranch = branch
1128 }
1129 }
1130
1131 // If we found new commits, create a message
1132 if len(commits) > 0 {
1133 msg := AgentMessage{
1134 Type: CommitMessageType,
1135 Timestamp: time.Now(),
1136 Commits: commits,
1137 }
1138 a.pushToOutbox(ctx, msg)
1139 }
1140 return commits, nil
1141}
1142
// titleToBranch derives a branch-name fragment from a free-form title.
//
// The title is lowercased, each space becomes a hyphen, and every remaining
// character other than a-z and '-' is dropped (digits included). The result
// may be empty.
func titleToBranch(s string) string {
	sanitize := func(r rune) rune {
		switch {
		case r == ' ':
			return '-'
		case r == '-', 'a' <= r && r <= 'z':
			return r
		default:
			return -1 // a negative rune tells strings.Map to drop the character
		}
	}
	return strings.Map(sanitize, strings.ToLower(s))
}
1159
1160// parseGitLog parses the output of git log with format '%H%x00%s%x00%b%x00'
1161// and returns an array of GitCommit structs.
1162func parseGitLog(output string) []GitCommit {
1163 var commits []GitCommit
1164
1165 // No output means no commits
1166 if len(output) == 0 {
1167 return commits
1168 }
1169
1170 // Split by NULL byte
1171 parts := strings.Split(output, "\x00")
1172
1173 // Process in triplets (hash, subject, body)
1174 for i := 0; i < len(parts); i++ {
1175 // Skip empty parts
1176 if parts[i] == "" {
1177 continue
1178 }
1179
1180 // This should be a hash
1181 hash := strings.TrimSpace(parts[i])
1182
1183 // Make sure we have at least a subject part available
1184 if i+1 >= len(parts) {
1185 break // No more parts available
1186 }
1187
1188 // Get the subject
1189 subject := strings.TrimSpace(parts[i+1])
1190
1191 // Get the body if available
1192 body := ""
1193 if i+2 < len(parts) {
1194 body = strings.TrimSpace(parts[i+2])
1195 }
1196
1197 // Skip to the next triplet
1198 i += 2
1199
1200 commits = append(commits, GitCommit{
1201 Hash: hash,
1202 Subject: subject,
1203 Body: body,
1204 })
1205 }
1206
1207 return commits
1208}
1209
// repoRoot runs `git rev-parse --show-toplevel` in dir and returns its
// standard output with surrounding whitespace trimmed. When the command
// fails, the returned error wraps the underlying error and is followed by
// whatever git wrote to stderr.
func repoRoot(ctx context.Context, dir string) (string, error) {
	var errOut strings.Builder

	cmd := exec.CommandContext(ctx, "git", "rev-parse", "--show-toplevel")
	cmd.Dir = dir
	cmd.Stderr = &errOut

	stdout, err := cmd.Output()
	if err != nil {
		return "", fmt.Errorf("git rev-parse failed: %w\n%s", err, errOut.String())
	}
	return strings.TrimSpace(string(stdout)), nil
}
1221
// resolveRef runs `git rev-parse <refName>` in dir and returns its standard
// output with surrounding whitespace trimmed. When the command fails, the
// returned error wraps the underlying error and is followed by whatever git
// wrote to stderr.
func resolveRef(ctx context.Context, dir, refName string) (string, error) {
	var errOut strings.Builder

	cmd := exec.CommandContext(ctx, "git", "rev-parse", refName)
	cmd.Dir = dir
	cmd.Stderr = &errOut

	stdout, err := cmd.Output()
	if err != nil {
		return "", fmt.Errorf("git rev-parse failed: %w\n%s", err, errOut.String())
	}

	// TODO: validate that the output is valid hex
	resolved := strings.TrimSpace(string(stdout))
	return resolved, nil
}
1234
// isValidGitSHA reports whether sha has the shape of a (possibly abbreviated)
// git SHA: between 4 and 40 bytes long, made up solely of the hexadecimal
// digits 0-9, a-f and A-F. It does not check that such an object exists.
func isValidGitSHA(sha string) bool {
	if n := len(sha); n < 4 || n > 40 {
		return false
	}

	for i := 0; i < len(sha); i++ {
		switch c := sha[i]; {
		case '0' <= c && c <= '9':
		case 'a' <= c && c <= 'f':
		case 'A' <= c && c <= 'F':
		default:
			// Anything else, including any byte of a multi-byte rune,
			// is not a hex digit.
			return false
		}
	}
	return true
}
Philip Zeyligerd1402952025-04-23 03:54:37 +00001252
// getGitOrigin returns the URL of the git remote 'origin' if it exists.
//
// It runs `git config --get remote.origin.url` in dir and returns the output
// with surrounding whitespace trimmed. This is a best-effort lookup: any
// failure to run the command yields the empty string.
func getGitOrigin(ctx context.Context, dir string) string {
	cmd := exec.CommandContext(ctx, "git", "config", "--get", "remote.origin.url")
	cmd.Dir = dir
	// No stderr buffer is wired up: the error detail is never reported, and
	// Output keeps git's stderr away from this process's own stderr.
	out, err := cmd.Output()
	if err != nil {
		return ""
	}
	return strings.TrimSpace(string(out))
}