blob: ab4853303a91a8cc116a9acdd8352996daa6404a [file] [log] [blame]
Earl Lee2e463fb2025-04-17 11:22:22 -07001package loop
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "net/http"
9 "os"
10 "os/exec"
11 "runtime/debug"
12 "slices"
13 "strings"
14 "sync"
15 "time"
16
17 "sketch.dev/ant"
18 "sketch.dev/claudetool"
19)
20
// userCancelMessage is the error text shown to the user when the inner loop
// is cancelled while tools are pending.
const userCancelMessage = "user requested agent to stop handling responses"
24
25type CodingAgent interface {
26 // Init initializes an agent inside a docker container.
27 Init(AgentInit) error
28
29 // Ready returns a channel closed after Init successfully called.
30 Ready() <-chan struct{}
31
32 // URL reports the HTTP URL of this agent.
33 URL() string
34
35 // UserMessage enqueues a message to the agent and returns immediately.
36 UserMessage(ctx context.Context, msg string)
37
38 // WaitForMessage blocks until the agent has a response to give.
39 // Use AgentMessage.EndOfTurn to help determine if you want to
40 // drain the agent.
41 WaitForMessage(ctx context.Context) AgentMessage
42
43 // Loop begins the agent loop returns only when ctx is cancelled.
44 Loop(ctx context.Context)
45
46 CancelInnerLoop(cause error)
47
48 CancelToolUse(toolUseID string, cause error) error
49
50 // Returns a subset of the agent's message history.
51 Messages(start int, end int) []AgentMessage
52
53 // Returns the current number of messages in the history
54 MessageCount() int
55
56 TotalUsage() ant.CumulativeUsage
57 OriginalBudget() ant.Budget
58
59 // WaitForMessageCount returns when the agent has at more than clientMessageCount messages or the context is done.
60 WaitForMessageCount(ctx context.Context, greaterThan int)
61
62 WorkingDir() string
63
64 // Diff returns a unified diff of changes made since the agent was instantiated.
65 // If commit is non-nil, it shows the diff for just that specific commit.
66 Diff(commit *string) (string, error)
67
68 // InitialCommit returns the Git commit hash that was saved when the agent was instantiated.
69 InitialCommit() string
70
71 // Title returns the current title of the conversation.
72 Title() string
73
74 // OS returns the operating system of the client.
75 OS() string
Philip Zeyliger99a9a022025-04-27 15:15:25 +000076
Philip Zeyligerc72fff52025-04-29 20:17:54 +000077 // SessionID returns the unique session identifier.
78 SessionID() string
79
Philip Zeyliger99a9a022025-04-27 15:15:25 +000080 // OutstandingLLMCallCount returns the number of outstanding LLM calls.
81 OutstandingLLMCallCount() int
82
83 // OutstandingToolCalls returns the names of outstanding tool calls.
84 OutstandingToolCalls() []string
Philip Zeyliger18532b22025-04-23 21:11:46 +000085 OutsideOS() string
86 OutsideHostname() string
87 OutsideWorkingDir() string
Philip Zeyligerd1402952025-04-23 03:54:37 +000088 GitOrigin() string
Earl Lee2e463fb2025-04-17 11:22:22 -070089}
90
// CodingAgentMessageType identifies what kind of event an AgentMessage carries.
type CodingAgentMessageType string

// Values for AgentMessage.Type.
const (
	UserMessageType    CodingAgentMessageType = "user"
	AgentMessageType   CodingAgentMessageType = "agent"
	ErrorMessageType   CodingAgentMessageType = "error"
	BudgetMessageType  CodingAgentMessageType = "budget" // dedicated for "out of budget" errors
	ToolUseMessageType CodingAgentMessageType = "tool"
	CommitMessageType  CodingAgentMessageType = "commit" // for displaying git commits
	AutoMessageType    CodingAgentMessageType = "auto"   // for automated notifications like autoformatting
)

// cancelToolUseMessage is appended as user text to the LLM request after the
// user cancels in-flight tool use.
const cancelToolUseMessage = "Stop responding to my previous message. Wait for me to ask you something else before attempting to use any more tools."
104
105type AgentMessage struct {
106 Type CodingAgentMessageType `json:"type"`
107 // EndOfTurn indicates that the AI is done working and is ready for the next user input.
108 EndOfTurn bool `json:"end_of_turn"`
109
110 Content string `json:"content"`
111 ToolName string `json:"tool_name,omitempty"`
112 ToolInput string `json:"input,omitempty"`
113 ToolResult string `json:"tool_result,omitempty"`
114 ToolError bool `json:"tool_error,omitempty"`
115 ToolCallId string `json:"tool_call_id,omitempty"`
116
117 // ToolCalls is a list of all tool calls requested in this message (name and input pairs)
118 ToolCalls []ToolCall `json:"tool_calls,omitempty"`
119
Sean McCulloughd9f13372025-04-21 15:08:49 -0700120 // ToolResponses is a list of all responses to tool calls requested in this message (name and input pairs)
121 ToolResponses []AgentMessage `json:"toolResponses,omitempty"`
122
Earl Lee2e463fb2025-04-17 11:22:22 -0700123 // Commits is a list of git commits for a commit message
124 Commits []*GitCommit `json:"commits,omitempty"`
125
126 Timestamp time.Time `json:"timestamp"`
127 ConversationID string `json:"conversation_id"`
128 ParentConversationID *string `json:"parent_conversation_id,omitempty"`
129 Usage *ant.Usage `json:"usage,omitempty"`
130
131 // Message timing information
132 StartTime *time.Time `json:"start_time,omitempty"`
133 EndTime *time.Time `json:"end_time,omitempty"`
134 Elapsed *time.Duration `json:"elapsed,omitempty"`
135
136 // Turn duration - the time taken for a complete agent turn
137 TurnDuration *time.Duration `json:"turnDuration,omitempty"`
138
139 Idx int `json:"idx"`
140}
141
Josh Bleecher Snyder50a1d622025-04-29 09:59:03 -0700142// SetConvo sets m.ConversationID and m.ParentConversationID based on convo.
143func (m *AgentMessage) SetConvo(convo *ant.Convo) {
144 if convo == nil {
145 m.ConversationID = ""
146 m.ParentConversationID = nil
147 return
148 }
149 m.ConversationID = convo.ID
150 if convo.Parent != nil {
151 m.ParentConversationID = &convo.Parent.ID
152 }
153}
154
// GitCommit describes one git commit shown in a commit message.
type GitCommit struct {
	Hash    string `json:"hash"`    // full commit hash
	Subject string `json:"subject"` // first line of the commit message
	Body    string `json:"body"`    // full commit message body

	// PushedBranch, when set, names the branch this commit was pushed to.
	PushedBranch string `json:"pushed_branch,omitempty"`
}
162
163// ToolCall represents a single tool call within an agent message
164type ToolCall struct {
Sean McCulloughd9f13372025-04-21 15:08:49 -0700165 Name string `json:"name"`
166 Input string `json:"input"`
167 ToolCallId string `json:"tool_call_id"`
168 ResultMessage *AgentMessage `json:"result_message,omitempty"`
169 Args string `json:"args,omitempty"`
170 Result string `json:"result,omitempty"`
Earl Lee2e463fb2025-04-17 11:22:22 -0700171}
172
173func (a *AgentMessage) Attr() slog.Attr {
174 var attrs []any = []any{
175 slog.String("type", string(a.Type)),
176 }
177 if a.EndOfTurn {
178 attrs = append(attrs, slog.Bool("end_of_turn", a.EndOfTurn))
179 }
180 if a.Content != "" {
181 attrs = append(attrs, slog.String("content", a.Content))
182 }
183 if a.ToolName != "" {
184 attrs = append(attrs, slog.String("tool_name", a.ToolName))
185 }
186 if a.ToolInput != "" {
187 attrs = append(attrs, slog.String("tool_input", a.ToolInput))
188 }
189 if a.Elapsed != nil {
190 attrs = append(attrs, slog.Int64("elapsed", a.Elapsed.Nanoseconds()))
191 }
192 if a.TurnDuration != nil {
193 attrs = append(attrs, slog.Int64("turnDuration", a.TurnDuration.Nanoseconds()))
194 }
195 if a.ToolResult != "" {
196 attrs = append(attrs, slog.String("tool_result", a.ToolResult))
197 }
198 if a.ToolError {
199 attrs = append(attrs, slog.Bool("tool_error", a.ToolError))
200 }
201 if len(a.ToolCalls) > 0 {
202 toolCallAttrs := make([]any, 0, len(a.ToolCalls))
203 for i, tc := range a.ToolCalls {
204 toolCallAttrs = append(toolCallAttrs, slog.Group(
205 fmt.Sprintf("tool_call_%d", i),
206 slog.String("name", tc.Name),
207 slog.String("input", tc.Input),
208 ))
209 }
210 attrs = append(attrs, slog.Group("tool_calls", toolCallAttrs...))
211 }
212 if a.ConversationID != "" {
213 attrs = append(attrs, slog.String("convo_id", a.ConversationID))
214 }
215 if a.ParentConversationID != nil {
216 attrs = append(attrs, slog.String("parent_convo_id", *a.ParentConversationID))
217 }
218 if a.Usage != nil && !a.Usage.IsZero() {
219 attrs = append(attrs, a.Usage.Attr())
220 }
221 // TODO: timestamp, convo ids, idx?
222 return slog.Group("agent_message", attrs...)
223}
224
225func errorMessage(err error) AgentMessage {
226 // It's somewhat unknowable whether error messages are "end of turn" or not, but it seems like the best approach.
227 if os.Getenv(("DEBUG")) == "1" {
228 return AgentMessage{Type: ErrorMessageType, Content: err.Error() + " Stacktrace: " + string(debug.Stack()), EndOfTurn: true}
229 }
230
231 return AgentMessage{Type: ErrorMessageType, Content: err.Error(), EndOfTurn: true}
232}
233
234func budgetMessage(err error) AgentMessage {
235 return AgentMessage{Type: BudgetMessageType, Content: err.Error(), EndOfTurn: true}
236}
237
238// ConvoInterface defines the interface for conversation interactions
239type ConvoInterface interface {
240 CumulativeUsage() ant.CumulativeUsage
241 ResetBudget(ant.Budget)
242 OverBudget() error
243 SendMessage(message ant.Message) (*ant.MessageResponse, error)
244 SendUserTextMessage(s string, otherContents ...ant.Content) (*ant.MessageResponse, error)
245 ToolResultContents(ctx context.Context, resp *ant.MessageResponse) ([]ant.Content, error)
246 ToolResultCancelContents(resp *ant.MessageResponse) ([]ant.Content, error)
247 CancelToolUse(toolUseID string, cause error) error
248}
249
250type Agent struct {
251 convo ConvoInterface
252 config AgentConfig // config for this agent
253 workingDir string
254 repoRoot string // workingDir may be a subdir of repoRoot
255 url string
256 lastHEAD string // hash of the last HEAD that was pushed to the host (only when under docker)
257 initialCommit string // hash of the Git HEAD when the agent was instantiated or Init()
258 gitRemoteAddr string // HTTP URL of the host git repo (only when under docker)
259 ready chan struct{} // closed when the agent is initialized (only when under docker)
260 startedAt time.Time
261 originalBudget ant.Budget
262 title string
263 codereview *claudetool.CodeReviewer
Philip Zeyliger18532b22025-04-23 21:11:46 +0000264 // Outside information
265 outsideHostname string
266 outsideOS string
267 outsideWorkingDir string
Philip Zeyligerd1402952025-04-23 03:54:37 +0000268 // URL of the git remote 'origin' if it exists
269 gitOrigin string
Earl Lee2e463fb2025-04-17 11:22:22 -0700270
271 // Time when the current turn started (reset at the beginning of InnerLoop)
272 startOfTurn time.Time
273
274 // Inbox - for messages from the user to the agent.
275 // sent on by UserMessage
276 // . e.g. when user types into the chat textarea
277 // read from by GatherMessages
278 inbox chan string
279
280 // Outbox
281 // sent on by pushToOutbox
282 // via OnToolResult and OnResponse callbacks
283 // read from by WaitForMessage
284 // called by termui inside its repl loop.
285 outbox chan AgentMessage
286
287 // protects cancelInnerLoop
288 cancelInnerLoopMu sync.Mutex
289 // cancels potentially long-running tool_use calls or chains of them
290 cancelInnerLoop context.CancelCauseFunc
291
292 // protects following
293 mu sync.Mutex
294
295 // Stores all messages for this agent
296 history []AgentMessage
297
298 listeners []chan struct{}
299
300 // Track git commits we've already seen (by hash)
301 seenCommits map[string]bool
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000302
303 // Track outstanding LLM call IDs
304 outstandingLLMCalls map[string]struct{}
305
306 // Track outstanding tool calls by ID with their names
307 outstandingToolCalls map[string]string
Earl Lee2e463fb2025-04-17 11:22:22 -0700308}
309
310func (a *Agent) URL() string { return a.url }
311
312// Title returns the current title of the conversation.
313// If no title has been set, returns an empty string.
314func (a *Agent) Title() string {
315 a.mu.Lock()
316 defer a.mu.Unlock()
317 return a.title
318}
319
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000320// OutstandingLLMCallCount returns the number of outstanding LLM calls.
321func (a *Agent) OutstandingLLMCallCount() int {
322 a.mu.Lock()
323 defer a.mu.Unlock()
324 return len(a.outstandingLLMCalls)
325}
326
327// OutstandingToolCalls returns the names of outstanding tool calls.
328func (a *Agent) OutstandingToolCalls() []string {
329 a.mu.Lock()
330 defer a.mu.Unlock()
331
332 tools := make([]string, 0, len(a.outstandingToolCalls))
333 for _, toolName := range a.outstandingToolCalls {
334 tools = append(tools, toolName)
335 }
336 return tools
337}
338
Earl Lee2e463fb2025-04-17 11:22:22 -0700339// OS returns the operating system of the client.
340func (a *Agent) OS() string {
341 return a.config.ClientGOOS
342}
343
Philip Zeyligerc72fff52025-04-29 20:17:54 +0000344func (a *Agent) SessionID() string {
345 return a.config.SessionID
346}
347
Philip Zeyliger18532b22025-04-23 21:11:46 +0000348// OutsideOS returns the operating system of the outside system.
349func (a *Agent) OutsideOS() string {
350 return a.outsideOS
Philip Zeyligerd1402952025-04-23 03:54:37 +0000351}
352
Philip Zeyliger18532b22025-04-23 21:11:46 +0000353// OutsideHostname returns the hostname of the outside system.
354func (a *Agent) OutsideHostname() string {
355 return a.outsideHostname
Philip Zeyligerd1402952025-04-23 03:54:37 +0000356}
357
Philip Zeyliger18532b22025-04-23 21:11:46 +0000358// OutsideWorkingDir returns the working directory on the outside system.
359func (a *Agent) OutsideWorkingDir() string {
360 return a.outsideWorkingDir
Philip Zeyligerd1402952025-04-23 03:54:37 +0000361}
362
363// GitOrigin returns the URL of the git remote 'origin' if it exists.
364func (a *Agent) GitOrigin() string {
365 return a.gitOrigin
366}
367
Earl Lee2e463fb2025-04-17 11:22:22 -0700368// SetTitle sets the title of the conversation.
369func (a *Agent) SetTitle(title string) {
370 a.mu.Lock()
371 defer a.mu.Unlock()
372 a.title = title
373 // Notify all listeners that the state has changed
374 for _, ch := range a.listeners {
375 close(ch)
376 }
377 a.listeners = a.listeners[:0]
378}
379
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000380// OnToolCall implements ant.Listener and tracks the start of a tool call.
381func (a *Agent) OnToolCall(ctx context.Context, convo *ant.Convo, id string, toolName string, toolInput json.RawMessage, content ant.Content) {
382 // Track the tool call
383 a.mu.Lock()
384 a.outstandingToolCalls[id] = toolName
385 a.mu.Unlock()
386}
387
Earl Lee2e463fb2025-04-17 11:22:22 -0700388// OnToolResult implements ant.Listener.
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000389func (a *Agent) OnToolResult(ctx context.Context, convo *ant.Convo, toolID string, toolName string, toolInput json.RawMessage, content ant.Content, result *string, err error) {
390 // Remove the tool call from outstanding calls
391 a.mu.Lock()
392 delete(a.outstandingToolCalls, toolID)
393 a.mu.Unlock()
394
Earl Lee2e463fb2025-04-17 11:22:22 -0700395 m := AgentMessage{
396 Type: ToolUseMessageType,
397 Content: content.Text,
398 ToolResult: content.ToolResult,
399 ToolError: content.ToolError,
400 ToolName: toolName,
401 ToolInput: string(toolInput),
402 ToolCallId: content.ToolUseID,
403 StartTime: content.StartTime,
404 EndTime: content.EndTime,
405 }
406
407 // Calculate the elapsed time if both start and end times are set
408 if content.StartTime != nil && content.EndTime != nil {
409 elapsed := content.EndTime.Sub(*content.StartTime)
410 m.Elapsed = &elapsed
411 }
412
Josh Bleecher Snyder50a1d622025-04-29 09:59:03 -0700413 m.SetConvo(convo)
Earl Lee2e463fb2025-04-17 11:22:22 -0700414 a.pushToOutbox(ctx, m)
415}
416
417// OnRequest implements ant.Listener.
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000418func (a *Agent) OnRequest(ctx context.Context, convo *ant.Convo, id string, msg *ant.Message) {
419 a.mu.Lock()
420 defer a.mu.Unlock()
421 a.outstandingLLMCalls[id] = struct{}{}
Earl Lee2e463fb2025-04-17 11:22:22 -0700422 // We already get tool results from the above. We send user messages to the outbox in the agent loop.
423}
424
425// OnResponse implements ant.Listener. Responses contain messages from the LLM
426// that need to be displayed (as well as tool calls that we send along when
427// they're done). (It would be reasonable to also mention tool calls when they're
428// started, but we don't do that yet.)
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000429func (a *Agent) OnResponse(ctx context.Context, convo *ant.Convo, id string, resp *ant.MessageResponse) {
430 // Remove the LLM call from outstanding calls
431 a.mu.Lock()
432 delete(a.outstandingLLMCalls, id)
433 a.mu.Unlock()
434
Josh Bleecher Snyder50a1d622025-04-29 09:59:03 -0700435 if resp == nil {
436 // LLM API call failed
437 m := AgentMessage{
438 Type: ErrorMessageType,
439 Content: "API call failed, type 'continue' to try again",
440 }
441 m.SetConvo(convo)
442 a.pushToOutbox(ctx, m)
443 return
444 }
445
Earl Lee2e463fb2025-04-17 11:22:22 -0700446 endOfTurn := false
447 if resp.StopReason != ant.StopReasonToolUse {
448 endOfTurn = true
449 }
450 m := AgentMessage{
451 Type: AgentMessageType,
452 Content: collectTextContent(resp),
453 EndOfTurn: endOfTurn,
454 Usage: &resp.Usage,
455 StartTime: resp.StartTime,
456 EndTime: resp.EndTime,
457 }
458
459 // Extract any tool calls from the response
460 if resp.StopReason == ant.StopReasonToolUse {
461 var toolCalls []ToolCall
462 for _, part := range resp.Content {
463 if part.Type == "tool_use" {
464 toolCalls = append(toolCalls, ToolCall{
465 Name: part.ToolName,
466 Input: string(part.ToolInput),
467 ToolCallId: part.ID,
468 })
469 }
470 }
471 m.ToolCalls = toolCalls
472 }
473
474 // Calculate the elapsed time if both start and end times are set
475 if resp.StartTime != nil && resp.EndTime != nil {
476 elapsed := resp.EndTime.Sub(*resp.StartTime)
477 m.Elapsed = &elapsed
478 }
479
Josh Bleecher Snyder50a1d622025-04-29 09:59:03 -0700480 m.SetConvo(convo)
Earl Lee2e463fb2025-04-17 11:22:22 -0700481 a.pushToOutbox(ctx, m)
482}
483
484// WorkingDir implements CodingAgent.
485func (a *Agent) WorkingDir() string {
486 return a.workingDir
487}
488
489// MessageCount implements CodingAgent.
490func (a *Agent) MessageCount() int {
491 a.mu.Lock()
492 defer a.mu.Unlock()
493 return len(a.history)
494}
495
496// Messages implements CodingAgent.
497func (a *Agent) Messages(start int, end int) []AgentMessage {
498 a.mu.Lock()
499 defer a.mu.Unlock()
500 return slices.Clone(a.history[start:end])
501}
502
503func (a *Agent) OriginalBudget() ant.Budget {
504 return a.originalBudget
505}
506
507// AgentConfig contains configuration for creating a new Agent.
508type AgentConfig struct {
509 Context context.Context
510 AntURL string
511 APIKey string
512 HTTPC *http.Client
513 Budget ant.Budget
514 GitUsername string
515 GitEmail string
516 SessionID string
517 ClientGOOS string
518 ClientGOARCH string
519 UseAnthropicEdit bool
Philip Zeyliger18532b22025-04-23 21:11:46 +0000520 // Outside information
521 OutsideHostname string
522 OutsideOS string
523 OutsideWorkingDir string
Earl Lee2e463fb2025-04-17 11:22:22 -0700524}
525
526// NewAgent creates a new Agent.
527// It is not usable until Init() is called.
528func NewAgent(config AgentConfig) *Agent {
529 agent := &Agent{
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000530 config: config,
531 ready: make(chan struct{}),
532 inbox: make(chan string, 100),
533 outbox: make(chan AgentMessage, 100),
534 startedAt: time.Now(),
535 originalBudget: config.Budget,
536 seenCommits: make(map[string]bool),
537 outsideHostname: config.OutsideHostname,
538 outsideOS: config.OutsideOS,
539 outsideWorkingDir: config.OutsideWorkingDir,
540 outstandingLLMCalls: make(map[string]struct{}),
541 outstandingToolCalls: make(map[string]string),
Earl Lee2e463fb2025-04-17 11:22:22 -0700542 }
543 return agent
544}
545
// AgentInit holds the parameters for Agent.Init.
type AgentInit struct {
	// WorkingDir is the directory the agent works in.
	WorkingDir string
	// NoGit skips all git setup; only for testing.
	NoGit bool

	// The following are used only when InDocker is set.
	InDocker      bool
	Commit        string // commit to check out from the host
	GitRemoteAddr string // address of the host git repo, added as remote "sketch-host"
	HostAddr      string // when non-empty, the agent URL becomes "http://" + HostAddr
}
555
556func (a *Agent) Init(ini AgentInit) error {
Josh Bleecher Snyder9c07e1d2025-04-28 19:25:37 -0700557 if a.convo != nil {
558 return fmt.Errorf("Agent.Init: already initialized")
559 }
Earl Lee2e463fb2025-04-17 11:22:22 -0700560 ctx := a.config.Context
561 if ini.InDocker {
562 cmd := exec.CommandContext(ctx, "git", "stash")
563 cmd.Dir = ini.WorkingDir
564 if out, err := cmd.CombinedOutput(); err != nil {
565 return fmt.Errorf("git stash: %s: %v", out, err)
566 }
Philip Zeyligerd0ac1ea2025-04-21 20:04:19 -0700567 cmd = exec.CommandContext(ctx, "git", "remote", "add", "sketch-host", ini.GitRemoteAddr)
568 cmd.Dir = ini.WorkingDir
569 if out, err := cmd.CombinedOutput(); err != nil {
570 return fmt.Errorf("git remote add: %s: %v", out, err)
571 }
572 cmd = exec.CommandContext(ctx, "git", "fetch", "sketch-host")
Earl Lee2e463fb2025-04-17 11:22:22 -0700573 cmd.Dir = ini.WorkingDir
574 if out, err := cmd.CombinedOutput(); err != nil {
575 return fmt.Errorf("git fetch: %s: %w", out, err)
576 }
577 cmd = exec.CommandContext(ctx, "git", "checkout", "-f", ini.Commit)
578 cmd.Dir = ini.WorkingDir
579 if out, err := cmd.CombinedOutput(); err != nil {
580 return fmt.Errorf("git checkout %s: %s: %w", ini.Commit, out, err)
581 }
Earl Lee2e463fb2025-04-17 11:22:22 -0700582 a.lastHEAD = ini.Commit
583 a.gitRemoteAddr = ini.GitRemoteAddr
584 a.initialCommit = ini.Commit
585 if ini.HostAddr != "" {
586 a.url = "http://" + ini.HostAddr
587 }
588 }
589 a.workingDir = ini.WorkingDir
590
591 if !ini.NoGit {
592 repoRoot, err := repoRoot(ctx, a.workingDir)
593 if err != nil {
594 return fmt.Errorf("repoRoot: %w", err)
595 }
596 a.repoRoot = repoRoot
597
598 commitHash, err := resolveRef(ctx, a.repoRoot, "HEAD")
599 if err != nil {
600 return fmt.Errorf("resolveRef: %w", err)
601 }
602 a.initialCommit = commitHash
603
604 codereview, err := claudetool.NewCodeReviewer(ctx, a.repoRoot, a.initialCommit)
605 if err != nil {
606 return fmt.Errorf("Agent.Init: claudetool.NewCodeReviewer: %w", err)
607 }
608 a.codereview = codereview
Philip Zeyligerd1402952025-04-23 03:54:37 +0000609
610 a.gitOrigin = getGitOrigin(ctx, ini.WorkingDir)
Earl Lee2e463fb2025-04-17 11:22:22 -0700611 }
612 a.lastHEAD = a.initialCommit
613 a.convo = a.initConvo()
614 close(a.ready)
615 return nil
616}
617
618// initConvo initializes the conversation.
619// It must not be called until all agent fields are initialized,
620// particularly workingDir and git.
621func (a *Agent) initConvo() *ant.Convo {
622 ctx := a.config.Context
623 convo := ant.NewConvo(ctx, a.config.APIKey)
624 if a.config.HTTPC != nil {
625 convo.HTTPC = a.config.HTTPC
626 }
627 if a.config.AntURL != "" {
628 convo.URL = a.config.AntURL
629 }
630 convo.PromptCaching = true
631 convo.Budget = a.config.Budget
632
633 var editPrompt string
634 if a.config.UseAnthropicEdit {
635 editPrompt = "Then use the str_replace_editor tool to make those edits. For short complete file replacements, you may use the bash tool with cat and heredoc stdin."
636 } else {
637 editPrompt = "Then use the patch tool to make those edits. Combine all edits to any given file into a single patch tool call."
638 }
639
640 convo.SystemPrompt = fmt.Sprintf(`
641You are an expert coding assistant and architect, with a specialty in Go.
642You are assisting the user to achieve their goals.
643
644Start by asking concise clarifying questions as needed.
645Once the intent is clear, work autonomously.
646
647Call the title tool early in the conversation to provide a brief summary of
648what the chat is about.
649
650Break down the overall goal into a series of smaller steps.
651(The first step is often: "Make a plan.")
652Then execute each step using tools.
653Update the plan if you have encountered problems or learned new information.
654
655When in doubt about a step, follow this broad workflow:
656
657- Think about how the current step fits into the overall plan.
658- Do research. Good tool choices: bash, think, keyword_search
659- Make edits.
660- Repeat.
661
662To make edits reliably and efficiently, first think about the intent of the edit,
663and what set of patches will achieve that intent.
664%s
665
666For renames or refactors, consider invoking gopls (via bash).
667
668The done tool provides a checklist of items you MUST verify and
669review before declaring that you are done. Before executing
670the done tool, run all the tools the done tool checklist asks
671for, including creating a git commit. Do not forget to run tests.
672
673<platform>
674%s/%s
675</platform>
676<pwd>
677%v
678</pwd>
679<git_root>
680%v
681</git_root>
Josh Bleecher Snyder833a0f82025-04-24 18:39:36 +0000682<HEAD>
683%v
684</HEAD>
685`, editPrompt, a.config.ClientGOOS, a.config.ClientGOARCH, a.workingDir, a.repoRoot, a.initialCommit)
Earl Lee2e463fb2025-04-17 11:22:22 -0700686
687 // Register all tools with the conversation
688 // When adding, removing, or modifying tools here, double-check that the termui tool display
689 // template in termui/termui.go has pretty-printing support for all tools.
690 convo.Tools = []*ant.Tool{
691 claudetool.Bash, claudetool.Keyword,
692 claudetool.Think, a.titleTool(), makeDoneTool(a.codereview, a.config.GitUsername, a.config.GitEmail),
693 a.codereview.Tool(),
694 }
695 if a.config.UseAnthropicEdit {
696 convo.Tools = append(convo.Tools, claudetool.AnthropicEditTool)
697 } else {
698 convo.Tools = append(convo.Tools, claudetool.Patch)
699 }
700 convo.Listener = a
701 return convo
702}
703
704func (a *Agent) titleTool() *ant.Tool {
705 // titleTool creates the title tool that sets the conversation title.
706 title := &ant.Tool{
707 Name: "title",
708 Description: `Use this tool early in the conversation, BEFORE MAKING ANY GIT COMMITS, to summarize what the chat is about briefly.`,
709 InputSchema: json.RawMessage(`{
710 "type": "object",
711 "properties": {
712 "title": {
713 "type": "string",
714 "description": "A brief title summarizing what this chat is about"
715 }
716 },
717 "required": ["title"]
718}`),
719 Run: func(ctx context.Context, input json.RawMessage) (string, error) {
720 var params struct {
721 Title string `json:"title"`
722 }
723 if err := json.Unmarshal(input, &params); err != nil {
724 return "", err
725 }
726 a.SetTitle(params.Title)
727 return fmt.Sprintf("Title set to: %s", params.Title), nil
728 },
729 }
730 return title
731}
732
733func (a *Agent) Ready() <-chan struct{} {
734 return a.ready
735}
736
737func (a *Agent) UserMessage(ctx context.Context, msg string) {
738 a.pushToOutbox(ctx, AgentMessage{Type: UserMessageType, Content: msg})
739 a.inbox <- msg
740}
741
742func (a *Agent) WaitForMessage(ctx context.Context) AgentMessage {
743 // TODO: Should this drain any outbox messages in case there are multiple?
744 select {
745 case msg := <-a.outbox:
746 return msg
747 case <-ctx.Done():
748 return errorMessage(ctx.Err())
749 }
750}
751
752func (a *Agent) CancelToolUse(toolUseID string, cause error) error {
753 return a.convo.CancelToolUse(toolUseID, cause)
754}
755
756func (a *Agent) CancelInnerLoop(cause error) {
757 a.cancelInnerLoopMu.Lock()
758 defer a.cancelInnerLoopMu.Unlock()
759 if a.cancelInnerLoop != nil {
760 a.cancelInnerLoop(cause)
761 }
762}
763
764func (a *Agent) Loop(ctxOuter context.Context) {
765 for {
766 select {
767 case <-ctxOuter.Done():
768 return
769 default:
770 ctxInner, cancel := context.WithCancelCause(ctxOuter)
771 a.cancelInnerLoopMu.Lock()
772 // Set .cancelInnerLoop so the user can cancel whatever is happening
773 // inside InnerLoop(ctxInner) without canceling this outer Loop execution.
774 // This CancelInnerLoop func is intended be called from other goroutines,
775 // hence the mutex.
776 a.cancelInnerLoop = cancel
777 a.cancelInnerLoopMu.Unlock()
778 a.InnerLoop(ctxInner)
779 cancel(nil)
780 }
781 }
782}
783
784func (a *Agent) pushToOutbox(ctx context.Context, m AgentMessage) {
785 if m.Timestamp.IsZero() {
786 m.Timestamp = time.Now()
787 }
788
789 // If this is an end-of-turn message, calculate the turn duration and add it to the message
790 if m.EndOfTurn && m.Type == AgentMessageType {
791 turnDuration := time.Since(a.startOfTurn)
792 m.TurnDuration = &turnDuration
793 slog.InfoContext(ctx, "Turn completed", "turnDuration", turnDuration)
794 }
795
796 slog.InfoContext(ctx, "agent message", m.Attr())
797
798 a.mu.Lock()
799 defer a.mu.Unlock()
800 m.Idx = len(a.history)
801 a.history = append(a.history, m)
802 a.outbox <- m
803
804 // Notify all listeners:
805 for _, ch := range a.listeners {
806 close(ch)
807 }
808 a.listeners = a.listeners[:0]
809}
810
811func (a *Agent) GatherMessages(ctx context.Context, block bool) ([]ant.Content, error) {
812 var m []ant.Content
813 if block {
814 select {
815 case <-ctx.Done():
816 return m, ctx.Err()
817 case msg := <-a.inbox:
818 m = append(m, ant.Content{Type: "text", Text: msg})
819 }
820 }
821 for {
822 select {
823 case msg := <-a.inbox:
824 m = append(m, ant.Content{Type: "text", Text: msg})
825 default:
826 return m, nil
827 }
828 }
829}
830
831func (a *Agent) InnerLoop(ctx context.Context) {
832 // Reset the start of turn time
833 a.startOfTurn = time.Now()
834
835 // Wait for at least one message from the user.
836 msgs, err := a.GatherMessages(ctx, true)
837 if err != nil { // e.g. the context was canceled while blocking in GatherMessages
838 return
839 }
840 // We do this as we go, but let's also do it at the end of the turn
841 defer func() {
842 if _, err := a.handleGitCommits(ctx); err != nil {
843 // Just log the error, don't stop execution
844 slog.WarnContext(ctx, "Failed to check for new git commits", "error", err)
845 }
846 }()
847
848 userMessage := ant.Message{
849 Role: "user",
850 Content: msgs,
851 }
852 // convo.SendMessage does the actual network call to send this to anthropic. This blocks until the response is ready.
853 // TODO: pass ctx to SendMessage, and figure out how to square that ctx with convo's own .Ctx. Who owns the scope of this call?
854 resp, err := a.convo.SendMessage(userMessage)
855 if err != nil {
856 a.pushToOutbox(ctx, errorMessage(err))
857 return
858 }
859 for {
860 // TODO: here and below where we check the budget,
861 // we should review the UX: is it clear what happened?
862 // is it clear how to resume?
863 // should we let the user set a new budget?
864 if err := a.overBudget(ctx); err != nil {
865 return
866 }
867 if resp.StopReason != ant.StopReasonToolUse {
868 break
869 }
870 var results []ant.Content
871 cancelled := false
872 select {
873 case <-ctx.Done():
874 // Don't actually run any of the tools, but rather build a response
875 // for each tool_use message letting the LLM know that user canceled it.
876 results, err = a.convo.ToolResultCancelContents(resp)
877 if err != nil {
878 a.pushToOutbox(ctx, errorMessage(err))
879 }
880 cancelled = true
881 default:
882 ctx = claudetool.WithWorkingDir(ctx, a.workingDir)
883 // fall-through, when the user has not canceled the inner loop:
884 results, err = a.convo.ToolResultContents(ctx, resp)
885 if ctx.Err() != nil { // e.g. the user canceled the operation
886 cancelled = true
887 } else if err != nil {
888 a.pushToOutbox(ctx, errorMessage(err))
889 }
890 }
891
892 // Check for git commits. Currently we do this here, after we collect
893 // tool results, since that's when we know commits could have happened.
894 // We could instead do this when the turn ends, but I think it makes sense
895 // to do this as we go.
896 newCommits, err := a.handleGitCommits(ctx)
897 if err != nil {
898 // Just log the error, don't stop execution
899 slog.WarnContext(ctx, "Failed to check for new git commits", "error", err)
900 }
901 var autoqualityMessages []string
902 if len(newCommits) == 1 {
903 formatted := a.codereview.Autoformat(ctx)
904 if len(formatted) > 0 {
905 msg := fmt.Sprintf(`
906I ran autoformatters and they updated these files:
907
908%s
909
910Please amend your latest git commit with these changes and then continue with what you were doing.`,
911 strings.Join(formatted, "\n"),
912 )[1:]
913 a.pushToOutbox(ctx, AgentMessage{
914 Type: AutoMessageType,
915 Content: msg,
916 Timestamp: time.Now(),
917 })
918 autoqualityMessages = append(autoqualityMessages, msg)
919 }
920 }
921
922 if err := a.overBudget(ctx); err != nil {
923 return
924 }
925
926 // Include, along with the tool results (which must go first for whatever reason),
927 // any messages that the user has sent along while the tool_use was executing concurrently.
928 msgs, err = a.GatherMessages(ctx, false)
929 if err != nil {
930 return
931 }
932 // Inject any auto-generated messages from quality checks.
933 for _, msg := range autoqualityMessages {
934 msgs = append(msgs, ant.Content{Type: "text", Text: msg})
935 }
936 if cancelled {
937 msgs = append(msgs, ant.Content{Type: "text", Text: cancelToolUseMessage})
938 // EndOfTurn is false here so that the client of this agent keeps processing
939 // messages from WaitForMessage() and gets the response from the LLM (usually
940 // something like "okay, I'll wait further instructions", but the user should
941 // be made aware of it regardless).
942 a.pushToOutbox(ctx, AgentMessage{Type: ErrorMessageType, Content: userCancelMessage, EndOfTurn: false})
943 } else if err := a.convo.OverBudget(); err != nil {
944 budgetMsg := "We've exceeded our budget. Please ask the user to confirm before continuing by ending the turn."
945 msgs = append(msgs, ant.Content{Type: "text", Text: budgetMsg})
946 a.pushToOutbox(ctx, budgetMessage(fmt.Errorf("warning: %w (ask to keep trying, if you'd like)", err)))
947 }
948 results = append(results, msgs...)
949 resp, err = a.convo.SendMessage(ant.Message{
950 Role: "user",
951 Content: results,
952 })
953 if err != nil {
954 a.pushToOutbox(ctx, errorMessage(fmt.Errorf("error: failed to continue conversation: %s", err.Error())))
955 break
956 }
957 if cancelled {
958 return
959 }
960 }
961}
962
963func (a *Agent) overBudget(ctx context.Context) error {
964 if err := a.convo.OverBudget(); err != nil {
965 m := budgetMessage(err)
966 m.Content = m.Content + "\n\nBudget reset."
967 a.pushToOutbox(ctx, budgetMessage(err))
968 a.convo.ResetBudget(a.originalBudget)
969 return err
970 }
971 return nil
972}
973
974func collectTextContent(msg *ant.MessageResponse) string {
975 // Collect all text content
976 var allText strings.Builder
977 for _, content := range msg.Content {
978 if content.Type == "text" && content.Text != "" {
979 if allText.Len() > 0 {
980 allText.WriteString("\n\n")
981 }
982 allText.WriteString(content.Text)
983 }
984 }
985 return allText.String()
986}
987
988func (a *Agent) TotalUsage() ant.CumulativeUsage {
989 a.mu.Lock()
990 defer a.mu.Unlock()
991 return a.convo.CumulativeUsage()
992}
993
994// WaitForMessageCount returns when the agent has at more than clientMessageCount messages or the context is done.
995func (a *Agent) WaitForMessageCount(ctx context.Context, greaterThan int) {
996 for a.MessageCount() <= greaterThan {
997 a.mu.Lock()
998 ch := make(chan struct{})
999 // Deletion happens when we notify.
1000 a.listeners = append(a.listeners, ch)
1001 a.mu.Unlock()
1002
1003 select {
1004 case <-ctx.Done():
1005 return
1006 case <-ch:
1007 continue
1008 }
1009 }
1010}
1011
1012// Diff returns a unified diff of changes made since the agent was instantiated.
1013func (a *Agent) Diff(commit *string) (string, error) {
1014 if a.initialCommit == "" {
1015 return "", fmt.Errorf("no initial commit reference available")
1016 }
1017
1018 // Find the repository root
1019 ctx := context.Background()
1020
1021 // If a specific commit hash is provided, show just that commit's changes
1022 if commit != nil && *commit != "" {
1023 // Validate that the commit looks like a valid git SHA
1024 if !isValidGitSHA(*commit) {
1025 return "", fmt.Errorf("invalid git commit SHA format: %s", *commit)
1026 }
1027
1028 // Get the diff for just this commit
1029 cmd := exec.CommandContext(ctx, "git", "show", "--unified=10", *commit)
1030 cmd.Dir = a.repoRoot
1031 output, err := cmd.CombinedOutput()
1032 if err != nil {
1033 return "", fmt.Errorf("failed to get diff for commit %s: %w - %s", *commit, err, string(output))
1034 }
1035 return string(output), nil
1036 }
1037
1038 // Otherwise, get the diff between the initial commit and the current state using exec.Command
1039 cmd := exec.CommandContext(ctx, "git", "diff", "--unified=10", a.initialCommit)
1040 cmd.Dir = a.repoRoot
1041 output, err := cmd.CombinedOutput()
1042 if err != nil {
1043 return "", fmt.Errorf("failed to get diff: %w - %s", err, string(output))
1044 }
1045
1046 return string(output), nil
1047}
1048
1049// InitialCommit returns the Git commit hash that was saved when the agent was instantiated.
1050func (a *Agent) InitialCommit() string {
1051 return a.initialCommit
1052}
1053
1054// handleGitCommits() highlights new commits to the user. When running
1055// under docker, new HEADs are pushed to a branch according to the title.
1056func (a *Agent) handleGitCommits(ctx context.Context) ([]*GitCommit, error) {
1057 if a.repoRoot == "" {
1058 return nil, nil
1059 }
1060
1061 head, err := resolveRef(ctx, a.repoRoot, "HEAD")
1062 if err != nil {
1063 return nil, err
1064 }
1065 if head == a.lastHEAD {
1066 return nil, nil // nothing to do
1067 }
1068 defer func() {
1069 a.lastHEAD = head
1070 }()
1071
1072 // Get new commits. Because it's possible that the agent does rebases, fixups, and
1073 // so forth, we use, as our fixed point, the "initialCommit", and we limit ourselves
1074 // to the last 100 commits.
1075 var commits []*GitCommit
1076
1077 // Get commits since the initial commit
1078 // Format: <hash>\0<subject>\0<body>\0
1079 // This uses NULL bytes as separators to avoid issues with newlines in commit messages
1080 // Limit to 100 commits to avoid overwhelming the user
1081 cmd := exec.CommandContext(ctx, "git", "log", "-n", "100", "--pretty=format:%H%x00%s%x00%b%x00", "^"+a.initialCommit, head)
1082 cmd.Dir = a.repoRoot
1083 output, err := cmd.Output()
1084 if err != nil {
1085 return nil, fmt.Errorf("failed to get git log: %w", err)
1086 }
1087
1088 // Parse git log output and filter out already seen commits
1089 parsedCommits := parseGitLog(string(output))
1090
1091 var headCommit *GitCommit
1092
1093 // Filter out commits we've already seen
1094 for _, commit := range parsedCommits {
1095 if commit.Hash == head {
1096 headCommit = &commit
1097 }
1098
1099 // Skip if we've seen this commit before. If our head has changed, always include that.
1100 if a.seenCommits[commit.Hash] && commit.Hash != head {
1101 continue
1102 }
1103
1104 // Mark this commit as seen
1105 a.seenCommits[commit.Hash] = true
1106
1107 // Add to our list of new commits
1108 commits = append(commits, &commit)
1109 }
1110
1111 if a.gitRemoteAddr != "" {
1112 if headCommit == nil {
1113 // I think this can only happen if we have a bug or if there's a race.
1114 headCommit = &GitCommit{}
1115 headCommit.Hash = head
1116 headCommit.Subject = "unknown"
1117 commits = append(commits, headCommit)
1118 }
1119
1120 cleanTitle := titleToBranch(a.title)
1121 if cleanTitle == "" {
1122 cleanTitle = a.config.SessionID
1123 }
1124 branch := "sketch/" + cleanTitle
1125
1126 // TODO: I don't love the force push here. We could see if the push is a fast-forward, and,
1127 // if it's not, we could make a backup with a unique name (perhaps append a timestamp) and
1128 // then use push with lease to replace.
1129 cmd = exec.Command("git", "push", "--force", a.gitRemoteAddr, "HEAD:refs/heads/"+branch)
1130 cmd.Dir = a.workingDir
1131 if out, err := cmd.CombinedOutput(); err != nil {
1132 a.pushToOutbox(ctx, errorMessage(fmt.Errorf("git push to host: %s: %v", out, err)))
1133 } else {
1134 headCommit.PushedBranch = branch
1135 }
1136 }
1137
1138 // If we found new commits, create a message
1139 if len(commits) > 0 {
1140 msg := AgentMessage{
1141 Type: CommitMessageType,
1142 Timestamp: time.Now(),
1143 Commits: commits,
1144 }
1145 a.pushToOutbox(ctx, msg)
1146 }
1147 return commits, nil
1148}
1149
// titleToBranch turns a free-form title into a branch-name fragment.
//
// The title is lowercased first. Each space then becomes '-', ASCII letters
// a-z and existing hyphens are kept, and every other rune (digits,
// punctuation, other whitespace, non-ASCII letters) is dropped. The result may
// be empty.
func titleToBranch(s string) string {
	return strings.Map(func(r rune) rune {
		switch {
		case r == ' ':
			return '-'
		case r == '-', 'a' <= r && r <= 'z':
			return r
		default:
			return -1 // drop the rune
		}
	}, strings.ToLower(s))
}
1166
1167// parseGitLog parses the output of git log with format '%H%x00%s%x00%b%x00'
1168// and returns an array of GitCommit structs.
1169func parseGitLog(output string) []GitCommit {
1170 var commits []GitCommit
1171
1172 // No output means no commits
1173 if len(output) == 0 {
1174 return commits
1175 }
1176
1177 // Split by NULL byte
1178 parts := strings.Split(output, "\x00")
1179
1180 // Process in triplets (hash, subject, body)
1181 for i := 0; i < len(parts); i++ {
1182 // Skip empty parts
1183 if parts[i] == "" {
1184 continue
1185 }
1186
1187 // This should be a hash
1188 hash := strings.TrimSpace(parts[i])
1189
1190 // Make sure we have at least a subject part available
1191 if i+1 >= len(parts) {
1192 break // No more parts available
1193 }
1194
1195 // Get the subject
1196 subject := strings.TrimSpace(parts[i+1])
1197
1198 // Get the body if available
1199 body := ""
1200 if i+2 < len(parts) {
1201 body = strings.TrimSpace(parts[i+2])
1202 }
1203
1204 // Skip to the next triplet
1205 i += 2
1206
1207 commits = append(commits, GitCommit{
1208 Hash: hash,
1209 Subject: subject,
1210 Body: body,
1211 })
1212 }
1213
1214 return commits
1215}
1216
// repoRoot returns the top-level directory of the git working tree that
// contains dir, as printed by `git rev-parse --show-toplevel` (trimmed).
// On failure the returned error includes git's stderr.
func repoRoot(ctx context.Context, dir string) (string, error) {
	var errOut strings.Builder
	cmd := exec.CommandContext(ctx, "git", "rev-parse", "--show-toplevel")
	cmd.Dir = dir
	cmd.Stderr = &errOut

	stdout, err := cmd.Output()
	if err != nil {
		return "", fmt.Errorf("git rev-parse failed: %w\n%s", err, errOut.String())
	}
	return strings.TrimSpace(string(stdout)), nil
}
1228
// resolveRef resolves refName (e.g. "HEAD") in the git repository at dir to a
// full object name, using `git rev-parse --verify`.
//
// --verify makes git require exactly one revision. The output is then checked
// to be a full-length lowercase hex object name: 40 digits for SHA-1
// repositories, 64 for SHA-256. Inputs whose rev-parse output is not a single
// hash are rejected with an error instead of being returned as if they were a
// commit; examples are ranges ("A..B") and option-like strings
// ("--show-toplevel"). On git failure the error includes git's stderr.
func resolveRef(ctx context.Context, dir, refName string) (string, error) {
	cmd := exec.CommandContext(ctx, "git", "rev-parse", "--verify", refName)
	stderr := new(strings.Builder)
	cmd.Stderr = stderr
	cmd.Dir = dir
	out, err := cmd.Output()
	if err != nil {
		return "", fmt.Errorf("git rev-parse failed: %w\n%s", err, stderr)
	}

	hash := strings.TrimSpace(string(out))
	valid := len(hash) == 40 || len(hash) == 64
	for _, r := range hash {
		if !(('0' <= r && r <= '9') || ('a' <= r && r <= 'f')) {
			valid = false
			break
		}
	}
	if !valid {
		return "", fmt.Errorf("git rev-parse %q: unexpected output %q", refName, hash)
	}
	return hash, nil
}
1241
// isValidGitSHA reports whether sha looks like a (possibly abbreviated) git
// object name: between 4 and 64 hexadecimal digits, in either case.
//
// Full object names are 40 hex digits in SHA-1 repositories and 64 in SHA-256
// repositories. Abbreviations are accepted down to git's minimum length of 4.
// Because only hex digits pass, an accepted value can never be mistaken by git
// for a command-line option.
func isValidGitSHA(sha string) bool {
	if len(sha) < 4 || len(sha) > 64 {
		return false
	}
	for _, c := range sha {
		switch {
		case '0' <= c && c <= '9', 'a' <= c && c <= 'f', 'A' <= c && c <= 'F':
			// hex digit
		default:
			return false
		}
	}
	return true
}
Philip Zeyligerd1402952025-04-23 03:54:37 +00001259
// getGitOrigin returns the URL configured for the git remote "origin" in the
// repository at dir (the value of remote.origin.url), with surrounding
// whitespace trimmed. It returns "" when the key is unset or git fails for any
// other reason; the error is not reported to the caller.
func getGitOrigin(ctx context.Context, dir string) string {
	// git's stderr is not captured: every failure maps to "" anyway.
	cmd := exec.CommandContext(ctx, "git", "config", "--get", "remote.origin.url")
	cmd.Dir = dir
	out, err := cmd.Output()
	if err != nil {
		return ""
	}
	return strings.TrimSpace(string(out))
}