blob: fd9c63b26d6b17109309dd673c46c97440f96697 [file] [log] [blame]
Earl Lee2e463fb2025-04-17 11:22:22 -07001package loop
2
3import (
4 "context"
Josh Bleecher Snyderdbe02302025-04-29 16:44:23 -07005 _ "embed"
Earl Lee2e463fb2025-04-17 11:22:22 -07006 "encoding/json"
7 "fmt"
8 "log/slog"
9 "net/http"
10 "os"
11 "os/exec"
12 "runtime/debug"
13 "slices"
14 "strings"
15 "sync"
16 "time"
17
18 "sketch.dev/ant"
19 "sketch.dev/claudetool"
20)
21
// userCancelMessage is the content of the error message pushed to the outbox
// when the user cancels the tool use that the inner loop is working on.
const userCancelMessage = "user requested agent to stop handling responses"
25
// CodingAgent is the interface through which clients drive a coding agent
// and observe its state and message history.
type CodingAgent interface {
	// Init initializes an agent inside a docker container.
	Init(AgentInit) error

	// Ready returns a channel that is closed after Init has been called successfully.
	Ready() <-chan struct{}

	// URL reports the HTTP URL of this agent.
	URL() string

	// UserMessage enqueues a message to the agent and returns immediately.
	UserMessage(ctx context.Context, msg string)

	// WaitForMessage blocks until the agent has a response to give.
	// Use AgentMessage.EndOfTurn to help determine if you want to
	// drain the agent.
	WaitForMessage(ctx context.Context) AgentMessage

	// Loop begins the agent loop and returns only when ctx is cancelled.
	Loop(ctx context.Context)

	// CancelInnerLoop cancels, with the given cause, whatever the agent loop
	// is currently working on without stopping Loop itself.
	CancelInnerLoop(cause error)

	// CancelToolUse cancels the tool use identified by toolUseID with the given cause.
	CancelToolUse(toolUseID string, cause error) error

	// Messages returns a subset of the agent's message history.
	Messages(start int, end int) []AgentMessage

	// MessageCount returns the current number of messages in the history.
	MessageCount() int

	// TotalUsage returns the cumulative usage of the agent's conversation.
	TotalUsage() ant.CumulativeUsage
	// OriginalBudget returns the budget the agent was originally configured with.
	OriginalBudget() ant.Budget

	// WaitForMessageCount returns when the agent has more than greaterThan messages or the context is done.
	WaitForMessageCount(ctx context.Context, greaterThan int)

	// WorkingDir returns the directory the agent works in.
	WorkingDir() string

	// Diff returns a unified diff of changes made since the agent was instantiated.
	// If commit is non-nil, it shows the diff for just that specific commit.
	Diff(commit *string) (string, error)

	// InitialCommit returns the Git commit hash that was saved when the agent was instantiated.
	InitialCommit() string

	// Title returns the current title of the conversation.
	Title() string

	// OS returns the operating system of the client.
	OS() string

	// SessionID returns the unique session identifier.
	SessionID() string

	// OutstandingLLMCallCount returns the number of outstanding LLM calls.
	OutstandingLLMCallCount() int

	// OutstandingToolCalls returns the names of outstanding tool calls.
	OutstandingToolCalls() []string
	// OutsideOS returns the operating system of the outside system.
	OutsideOS() string
	// OutsideHostname returns the hostname of the outside system.
	OutsideHostname() string
	// OutsideWorkingDir returns the working directory on the outside system.
	OutsideWorkingDir() string
	// GitOrigin returns the URL of the git remote 'origin' if it exists.
	GitOrigin() string
}
91
// CodingAgentMessageType identifies the kind of an AgentMessage.
type CodingAgentMessageType string

// The message types an AgentMessage can carry in its Type field.
const (
	UserMessageType    CodingAgentMessageType = "user"
	AgentMessageType   CodingAgentMessageType = "agent"
	ErrorMessageType   CodingAgentMessageType = "error"
	BudgetMessageType  CodingAgentMessageType = "budget" // dedicated for "out of budget" errors
	ToolUseMessageType CodingAgentMessageType = "tool"
	CommitMessageType  CodingAgentMessageType = "commit" // for displaying git commits
	AutoMessageType    CodingAgentMessageType = "auto"   // for automated notifications like autoformatting
)

// cancelToolUseMessage is the text added to the next LLM request after the
// user has cancelled the tool use in progress.
const cancelToolUseMessage = "Stop responding to my previous message. Wait for me to ask you something else before attempting to use any more tools."
105
// AgentMessage is one entry of the agent's message history. The same value is
// delivered to clients through the agent's outbox and serialized as JSON.
type AgentMessage struct {
	// Type says what kind of message this is (user, agent, tool, error, ...).
	Type CodingAgentMessageType `json:"type"`
	// EndOfTurn indicates that the AI is done working and is ready for the next user input.
	EndOfTurn bool `json:"end_of_turn"`

	// Content is the message text. The Tool* fields are filled in for tool
	// result messages (see OnToolResult).
	Content    string `json:"content"`
	ToolName   string `json:"tool_name,omitempty"`
	ToolInput  string `json:"input,omitempty"`
	ToolResult string `json:"tool_result,omitempty"`
	ToolError  bool   `json:"tool_error,omitempty"`
	ToolCallId string `json:"tool_call_id,omitempty"`

	// ToolCalls is a list of all tool calls requested in this message (name and input pairs)
	ToolCalls []ToolCall `json:"tool_calls,omitempty"`

	// ToolResponses is a list of all responses to tool calls requested in this message.
	// NOTE(review): nothing in this part of the file populates it — confirm who does.
	ToolResponses []AgentMessage `json:"toolResponses,omitempty"`

	// Commits is a list of git commits for a commit message
	Commits []*GitCommit `json:"commits,omitempty"`

	// Timestamp defaults to the time the message is pushed to the outbox when left zero.
	Timestamp            time.Time  `json:"timestamp"`
	ConversationID       string     `json:"conversation_id"`
	ParentConversationID *string    `json:"parent_conversation_id,omitempty"`
	Usage                *ant.Usage `json:"usage,omitempty"`

	// Message timing information
	StartTime *time.Time     `json:"start_time,omitempty"`
	EndTime   *time.Time     `json:"end_time,omitempty"`
	Elapsed   *time.Duration `json:"elapsed,omitempty"`

	// Turn duration - the time taken for a complete agent turn.
	// Set when an end-of-turn agent message is pushed to the outbox.
	TurnDuration *time.Duration `json:"turnDuration,omitempty"`

	// Idx is the position of this message in the agent's history.
	Idx int `json:"idx"`
}
142
Josh Bleecher Snyder50a1d622025-04-29 09:59:03 -0700143// SetConvo sets m.ConversationID and m.ParentConversationID based on convo.
144func (m *AgentMessage) SetConvo(convo *ant.Convo) {
145 if convo == nil {
146 m.ConversationID = ""
147 m.ParentConversationID = nil
148 return
149 }
150 m.ConversationID = convo.ID
151 if convo.Parent != nil {
152 m.ParentConversationID = &convo.Parent.ID
153 }
154}
155
// GitCommit represents a single git commit for a commit message.
// It is carried in AgentMessage.Commits and serialized as JSON.
type GitCommit struct {
	Hash         string `json:"hash"`                    // Full commit hash
	Subject      string `json:"subject"`                 // Commit subject line
	Body         string `json:"body"`                    // Full commit message body
	PushedBranch string `json:"pushed_branch,omitempty"` // If set, this commit was pushed to this branch
}
163
// ToolCall represents a single tool call within an agent message.
//
// OnResponse fills in Name, Input and ToolCallId from the LLM response.
// NOTE(review): ResultMessage, Args and Result are not set anywhere in this
// part of the file — confirm which component populates them.
type ToolCall struct {
	Name          string        `json:"name"`
	Input         string        `json:"input"`
	ToolCallId    string        `json:"tool_call_id"`
	ResultMessage *AgentMessage `json:"result_message,omitempty"`
	Args          string        `json:"args,omitempty"`
	Result        string        `json:"result,omitempty"`
}
173
174func (a *AgentMessage) Attr() slog.Attr {
175 var attrs []any = []any{
176 slog.String("type", string(a.Type)),
177 }
178 if a.EndOfTurn {
179 attrs = append(attrs, slog.Bool("end_of_turn", a.EndOfTurn))
180 }
181 if a.Content != "" {
182 attrs = append(attrs, slog.String("content", a.Content))
183 }
184 if a.ToolName != "" {
185 attrs = append(attrs, slog.String("tool_name", a.ToolName))
186 }
187 if a.ToolInput != "" {
188 attrs = append(attrs, slog.String("tool_input", a.ToolInput))
189 }
190 if a.Elapsed != nil {
191 attrs = append(attrs, slog.Int64("elapsed", a.Elapsed.Nanoseconds()))
192 }
193 if a.TurnDuration != nil {
194 attrs = append(attrs, slog.Int64("turnDuration", a.TurnDuration.Nanoseconds()))
195 }
196 if a.ToolResult != "" {
197 attrs = append(attrs, slog.String("tool_result", a.ToolResult))
198 }
199 if a.ToolError {
200 attrs = append(attrs, slog.Bool("tool_error", a.ToolError))
201 }
202 if len(a.ToolCalls) > 0 {
203 toolCallAttrs := make([]any, 0, len(a.ToolCalls))
204 for i, tc := range a.ToolCalls {
205 toolCallAttrs = append(toolCallAttrs, slog.Group(
206 fmt.Sprintf("tool_call_%d", i),
207 slog.String("name", tc.Name),
208 slog.String("input", tc.Input),
209 ))
210 }
211 attrs = append(attrs, slog.Group("tool_calls", toolCallAttrs...))
212 }
213 if a.ConversationID != "" {
214 attrs = append(attrs, slog.String("convo_id", a.ConversationID))
215 }
216 if a.ParentConversationID != nil {
217 attrs = append(attrs, slog.String("parent_convo_id", *a.ParentConversationID))
218 }
219 if a.Usage != nil && !a.Usage.IsZero() {
220 attrs = append(attrs, a.Usage.Attr())
221 }
222 // TODO: timestamp, convo ids, idx?
223 return slog.Group("agent_message", attrs...)
224}
225
226func errorMessage(err error) AgentMessage {
227 // It's somewhat unknowable whether error messages are "end of turn" or not, but it seems like the best approach.
228 if os.Getenv(("DEBUG")) == "1" {
229 return AgentMessage{Type: ErrorMessageType, Content: err.Error() + " Stacktrace: " + string(debug.Stack()), EndOfTurn: true}
230 }
231
232 return AgentMessage{Type: ErrorMessageType, Content: err.Error(), EndOfTurn: true}
233}
234
235func budgetMessage(err error) AgentMessage {
236 return AgentMessage{Type: BudgetMessageType, Content: err.Error(), EndOfTurn: true}
237}
238
// ConvoInterface defines the interface for conversation interactions.
// It lists the conversation operations the Agent relies on.
type ConvoInterface interface {
	// CumulativeUsage reports the usage accumulated by the conversation.
	CumulativeUsage() ant.CumulativeUsage
	// ResetBudget replaces the conversation's budget.
	ResetBudget(ant.Budget)
	// OverBudget returns a non-nil error when the budget has been exceeded.
	OverBudget() error
	// SendMessage sends message to the LLM and blocks until the response is ready.
	SendMessage(message ant.Message) (*ant.MessageResponse, error)
	// SendUserTextMessage sends a user text message, optionally with more contents.
	SendUserTextMessage(s string, otherContents ...ant.Content) (*ant.MessageResponse, error)
	// ToolResultContents produces the tool results for the tool uses requested in resp.
	// NOTE(review): presumably this is where the tools actually run — confirm in package ant.
	ToolResultContents(ctx context.Context, resp *ant.MessageResponse) ([]ant.Content, error)
	// ToolResultCancelContents builds, without running any tool, a result for
	// each tool use in resp telling the LLM that the user cancelled it.
	ToolResultCancelContents(resp *ant.MessageResponse) ([]ant.Content, error)
	// CancelToolUse cancels the tool use identified by toolUseID with the given cause.
	CancelToolUse(toolUseID string, cause error) error
}
250
// Agent is the CodingAgent implementation in this file. It owns the
// conversation with the LLM, the message history, and the inbox/outbox
// channels that connect it to its clients.
type Agent struct {
	convo          ConvoInterface
	config         AgentConfig // config for this agent
	workingDir     string
	repoRoot       string // workingDir may be a subdir of repoRoot
	url            string
	lastHEAD       string        // hash of the last HEAD that was pushed to the host (only when under docker)
	initialCommit  string        // hash of the Git HEAD when the agent was instantiated or Init()
	gitRemoteAddr  string        // HTTP URL of the host git repo (only when under docker)
	ready          chan struct{} // closed by Init once the agent is initialized
	startedAt      time.Time
	originalBudget ant.Budget
	title          string
	codereview     *claudetool.CodeReviewer
	// Outside information
	outsideHostname   string
	outsideOS         string
	outsideWorkingDir string
	// URL of the git remote 'origin' if it exists
	gitOrigin string

	// Time when the current turn started (reset at the beginning of InnerLoop)
	startOfTurn time.Time

	// Inbox - for messages from the user to the agent.
	// sent on by UserMessage
	// . e.g. when user types into the chat textarea
	// read from by GatherMessages
	inbox chan string

	// Outbox
	// sent on by pushToOutbox
	// via OnToolResult and OnResponse callbacks
	// read from by WaitForMessage
	// called by termui inside its repl loop.
	outbox chan AgentMessage

	// protects cancelInnerLoop
	cancelInnerLoopMu sync.Mutex
	// cancels potentially long-running tool_use calls or chains of them
	cancelInnerLoop context.CancelCauseFunc

	// mu protects the fields below. Title and SetTitle also hold it while
	// accessing title.
	mu sync.Mutex

	// Stores all messages for this agent
	history []AgentMessage

	// listeners holds one channel per pending WaitForMessageCount call; all
	// of them are closed and dropped whenever a message is pushed or the
	// title changes.
	listeners []chan struct{}

	// Track git commits we've already seen (by hash)
	seenCommits map[string]bool

	// Track outstanding LLM call IDs
	outstandingLLMCalls map[string]struct{}

	// Track outstanding tool calls by ID with their names
	outstandingToolCalls map[string]string
}
310
// URL reports the HTTP URL of this agent. Init sets it only when running in
// docker with a non-empty host address; otherwise it is empty.
func (a *Agent) URL() string { return a.url }
312
313// Title returns the current title of the conversation.
314// If no title has been set, returns an empty string.
315func (a *Agent) Title() string {
316 a.mu.Lock()
317 defer a.mu.Unlock()
318 return a.title
319}
320
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000321// OutstandingLLMCallCount returns the number of outstanding LLM calls.
322func (a *Agent) OutstandingLLMCallCount() int {
323 a.mu.Lock()
324 defer a.mu.Unlock()
325 return len(a.outstandingLLMCalls)
326}
327
328// OutstandingToolCalls returns the names of outstanding tool calls.
329func (a *Agent) OutstandingToolCalls() []string {
330 a.mu.Lock()
331 defer a.mu.Unlock()
332
333 tools := make([]string, 0, len(a.outstandingToolCalls))
334 for _, toolName := range a.outstandingToolCalls {
335 tools = append(tools, toolName)
336 }
337 return tools
338}
339
Earl Lee2e463fb2025-04-17 11:22:22 -0700340// OS returns the operating system of the client.
341func (a *Agent) OS() string {
342 return a.config.ClientGOOS
343}
344
Philip Zeyligerc72fff52025-04-29 20:17:54 +0000345func (a *Agent) SessionID() string {
346 return a.config.SessionID
347}
348
Philip Zeyliger18532b22025-04-23 21:11:46 +0000349// OutsideOS returns the operating system of the outside system.
350func (a *Agent) OutsideOS() string {
351 return a.outsideOS
Philip Zeyligerd1402952025-04-23 03:54:37 +0000352}
353
Philip Zeyliger18532b22025-04-23 21:11:46 +0000354// OutsideHostname returns the hostname of the outside system.
355func (a *Agent) OutsideHostname() string {
356 return a.outsideHostname
Philip Zeyligerd1402952025-04-23 03:54:37 +0000357}
358
Philip Zeyliger18532b22025-04-23 21:11:46 +0000359// OutsideWorkingDir returns the working directory on the outside system.
360func (a *Agent) OutsideWorkingDir() string {
361 return a.outsideWorkingDir
Philip Zeyligerd1402952025-04-23 03:54:37 +0000362}
363
364// GitOrigin returns the URL of the git remote 'origin' if it exists.
365func (a *Agent) GitOrigin() string {
366 return a.gitOrigin
367}
368
Earl Lee2e463fb2025-04-17 11:22:22 -0700369// SetTitle sets the title of the conversation.
370func (a *Agent) SetTitle(title string) {
371 a.mu.Lock()
372 defer a.mu.Unlock()
373 a.title = title
374 // Notify all listeners that the state has changed
375 for _, ch := range a.listeners {
376 close(ch)
377 }
378 a.listeners = a.listeners[:0]
379}
380
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000381// OnToolCall implements ant.Listener and tracks the start of a tool call.
382func (a *Agent) OnToolCall(ctx context.Context, convo *ant.Convo, id string, toolName string, toolInput json.RawMessage, content ant.Content) {
383 // Track the tool call
384 a.mu.Lock()
385 a.outstandingToolCalls[id] = toolName
386 a.mu.Unlock()
387}
388
Earl Lee2e463fb2025-04-17 11:22:22 -0700389// OnToolResult implements ant.Listener.
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000390func (a *Agent) OnToolResult(ctx context.Context, convo *ant.Convo, toolID string, toolName string, toolInput json.RawMessage, content ant.Content, result *string, err error) {
391 // Remove the tool call from outstanding calls
392 a.mu.Lock()
393 delete(a.outstandingToolCalls, toolID)
394 a.mu.Unlock()
395
Earl Lee2e463fb2025-04-17 11:22:22 -0700396 m := AgentMessage{
397 Type: ToolUseMessageType,
398 Content: content.Text,
399 ToolResult: content.ToolResult,
400 ToolError: content.ToolError,
401 ToolName: toolName,
402 ToolInput: string(toolInput),
403 ToolCallId: content.ToolUseID,
404 StartTime: content.StartTime,
405 EndTime: content.EndTime,
406 }
407
408 // Calculate the elapsed time if both start and end times are set
409 if content.StartTime != nil && content.EndTime != nil {
410 elapsed := content.EndTime.Sub(*content.StartTime)
411 m.Elapsed = &elapsed
412 }
413
Josh Bleecher Snyder50a1d622025-04-29 09:59:03 -0700414 m.SetConvo(convo)
Earl Lee2e463fb2025-04-17 11:22:22 -0700415 a.pushToOutbox(ctx, m)
416}
417
418// OnRequest implements ant.Listener.
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000419func (a *Agent) OnRequest(ctx context.Context, convo *ant.Convo, id string, msg *ant.Message) {
420 a.mu.Lock()
421 defer a.mu.Unlock()
422 a.outstandingLLMCalls[id] = struct{}{}
Earl Lee2e463fb2025-04-17 11:22:22 -0700423 // We already get tool results from the above. We send user messages to the outbox in the agent loop.
424}
425
426// OnResponse implements ant.Listener. Responses contain messages from the LLM
427// that need to be displayed (as well as tool calls that we send along when
428// they're done). (It would be reasonable to also mention tool calls when they're
429// started, but we don't do that yet.)
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000430func (a *Agent) OnResponse(ctx context.Context, convo *ant.Convo, id string, resp *ant.MessageResponse) {
431 // Remove the LLM call from outstanding calls
432 a.mu.Lock()
433 delete(a.outstandingLLMCalls, id)
434 a.mu.Unlock()
435
Josh Bleecher Snyder50a1d622025-04-29 09:59:03 -0700436 if resp == nil {
437 // LLM API call failed
438 m := AgentMessage{
439 Type: ErrorMessageType,
440 Content: "API call failed, type 'continue' to try again",
441 }
442 m.SetConvo(convo)
443 a.pushToOutbox(ctx, m)
444 return
445 }
446
Earl Lee2e463fb2025-04-17 11:22:22 -0700447 endOfTurn := false
448 if resp.StopReason != ant.StopReasonToolUse {
449 endOfTurn = true
450 }
451 m := AgentMessage{
452 Type: AgentMessageType,
453 Content: collectTextContent(resp),
454 EndOfTurn: endOfTurn,
455 Usage: &resp.Usage,
456 StartTime: resp.StartTime,
457 EndTime: resp.EndTime,
458 }
459
460 // Extract any tool calls from the response
461 if resp.StopReason == ant.StopReasonToolUse {
462 var toolCalls []ToolCall
463 for _, part := range resp.Content {
464 if part.Type == "tool_use" {
465 toolCalls = append(toolCalls, ToolCall{
466 Name: part.ToolName,
467 Input: string(part.ToolInput),
468 ToolCallId: part.ID,
469 })
470 }
471 }
472 m.ToolCalls = toolCalls
473 }
474
475 // Calculate the elapsed time if both start and end times are set
476 if resp.StartTime != nil && resp.EndTime != nil {
477 elapsed := resp.EndTime.Sub(*resp.StartTime)
478 m.Elapsed = &elapsed
479 }
480
Josh Bleecher Snyder50a1d622025-04-29 09:59:03 -0700481 m.SetConvo(convo)
Earl Lee2e463fb2025-04-17 11:22:22 -0700482 a.pushToOutbox(ctx, m)
483}
484
485// WorkingDir implements CodingAgent.
486func (a *Agent) WorkingDir() string {
487 return a.workingDir
488}
489
490// MessageCount implements CodingAgent.
491func (a *Agent) MessageCount() int {
492 a.mu.Lock()
493 defer a.mu.Unlock()
494 return len(a.history)
495}
496
497// Messages implements CodingAgent.
498func (a *Agent) Messages(start int, end int) []AgentMessage {
499 a.mu.Lock()
500 defer a.mu.Unlock()
501 return slices.Clone(a.history[start:end])
502}
503
504func (a *Agent) OriginalBudget() ant.Budget {
505 return a.originalBudget
506}
507
// AgentConfig contains configuration for creating a new Agent.
//
// Context is used for the git commands run by Init and for the conversation.
// AntURL and HTTPC override the conversation's defaults when set; APIKey and
// Budget are handed to the conversation as is. ClientGOOS and ClientGOARCH
// are substituted into the system prompt, and UseAnthropicEdit selects the
// str_replace_editor tool instead of the patch tool.
type AgentConfig struct {
	Context          context.Context
	AntURL           string
	APIKey           string
	HTTPC            *http.Client
	Budget           ant.Budget
	GitUsername      string
	GitEmail         string
	SessionID        string
	ClientGOOS       string
	ClientGOARCH     string
	UseAnthropicEdit bool
	// Outside information
	OutsideHostname   string
	OutsideOS         string
	OutsideWorkingDir string
}
526
527// NewAgent creates a new Agent.
528// It is not usable until Init() is called.
529func NewAgent(config AgentConfig) *Agent {
530 agent := &Agent{
Philip Zeyliger99a9a022025-04-27 15:15:25 +0000531 config: config,
532 ready: make(chan struct{}),
533 inbox: make(chan string, 100),
534 outbox: make(chan AgentMessage, 100),
535 startedAt: time.Now(),
536 originalBudget: config.Budget,
537 seenCommits: make(map[string]bool),
538 outsideHostname: config.OutsideHostname,
539 outsideOS: config.OutsideOS,
540 outsideWorkingDir: config.OutsideWorkingDir,
541 outstandingLLMCalls: make(map[string]struct{}),
542 outstandingToolCalls: make(map[string]string),
Earl Lee2e463fb2025-04-17 11:22:22 -0700543 }
544 return agent
545}
546
// AgentInit holds the parameters of Agent.Init.
//
// Commit, GitRemoteAddr and HostAddr are only read when InDocker is set:
// Init then adds GitRemoteAddr as the "sketch-host" git remote, checks out
// Commit, and derives the agent's URL from HostAddr.
type AgentInit struct {
	WorkingDir string
	NoGit      bool // only for testing

	InDocker      bool
	Commit        string
	GitRemoteAddr string
	HostAddr      string
}
556
557func (a *Agent) Init(ini AgentInit) error {
Josh Bleecher Snyder9c07e1d2025-04-28 19:25:37 -0700558 if a.convo != nil {
559 return fmt.Errorf("Agent.Init: already initialized")
560 }
Earl Lee2e463fb2025-04-17 11:22:22 -0700561 ctx := a.config.Context
562 if ini.InDocker {
563 cmd := exec.CommandContext(ctx, "git", "stash")
564 cmd.Dir = ini.WorkingDir
565 if out, err := cmd.CombinedOutput(); err != nil {
566 return fmt.Errorf("git stash: %s: %v", out, err)
567 }
Philip Zeyligerd0ac1ea2025-04-21 20:04:19 -0700568 cmd = exec.CommandContext(ctx, "git", "remote", "add", "sketch-host", ini.GitRemoteAddr)
569 cmd.Dir = ini.WorkingDir
570 if out, err := cmd.CombinedOutput(); err != nil {
571 return fmt.Errorf("git remote add: %s: %v", out, err)
572 }
573 cmd = exec.CommandContext(ctx, "git", "fetch", "sketch-host")
Earl Lee2e463fb2025-04-17 11:22:22 -0700574 cmd.Dir = ini.WorkingDir
575 if out, err := cmd.CombinedOutput(); err != nil {
576 return fmt.Errorf("git fetch: %s: %w", out, err)
577 }
578 cmd = exec.CommandContext(ctx, "git", "checkout", "-f", ini.Commit)
579 cmd.Dir = ini.WorkingDir
580 if out, err := cmd.CombinedOutput(); err != nil {
581 return fmt.Errorf("git checkout %s: %s: %w", ini.Commit, out, err)
582 }
Earl Lee2e463fb2025-04-17 11:22:22 -0700583 a.lastHEAD = ini.Commit
584 a.gitRemoteAddr = ini.GitRemoteAddr
585 a.initialCommit = ini.Commit
586 if ini.HostAddr != "" {
587 a.url = "http://" + ini.HostAddr
588 }
589 }
590 a.workingDir = ini.WorkingDir
591
592 if !ini.NoGit {
593 repoRoot, err := repoRoot(ctx, a.workingDir)
594 if err != nil {
595 return fmt.Errorf("repoRoot: %w", err)
596 }
597 a.repoRoot = repoRoot
598
599 commitHash, err := resolveRef(ctx, a.repoRoot, "HEAD")
600 if err != nil {
601 return fmt.Errorf("resolveRef: %w", err)
602 }
603 a.initialCommit = commitHash
604
605 codereview, err := claudetool.NewCodeReviewer(ctx, a.repoRoot, a.initialCommit)
606 if err != nil {
607 return fmt.Errorf("Agent.Init: claudetool.NewCodeReviewer: %w", err)
608 }
609 a.codereview = codereview
Philip Zeyligerd1402952025-04-23 03:54:37 +0000610
611 a.gitOrigin = getGitOrigin(ctx, ini.WorkingDir)
Earl Lee2e463fb2025-04-17 11:22:22 -0700612 }
613 a.lastHEAD = a.initialCommit
614 a.convo = a.initConvo()
615 close(a.ready)
616 return nil
617}
618
// agentSystemPrompt is the system prompt template, embedded at build time
// from agent_system_prompt.txt. initConvo uses it as a fmt.Sprintf format
// string.
//
//go:embed agent_system_prompt.txt
var agentSystemPrompt string
621
Earl Lee2e463fb2025-04-17 11:22:22 -0700622// initConvo initializes the conversation.
623// It must not be called until all agent fields are initialized,
624// particularly workingDir and git.
625func (a *Agent) initConvo() *ant.Convo {
626 ctx := a.config.Context
627 convo := ant.NewConvo(ctx, a.config.APIKey)
628 if a.config.HTTPC != nil {
629 convo.HTTPC = a.config.HTTPC
630 }
631 if a.config.AntURL != "" {
632 convo.URL = a.config.AntURL
633 }
634 convo.PromptCaching = true
635 convo.Budget = a.config.Budget
636
637 var editPrompt string
638 if a.config.UseAnthropicEdit {
639 editPrompt = "Then use the str_replace_editor tool to make those edits. For short complete file replacements, you may use the bash tool with cat and heredoc stdin."
640 } else {
641 editPrompt = "Then use the patch tool to make those edits. Combine all edits to any given file into a single patch tool call."
642 }
643
Josh Bleecher Snyderdbe02302025-04-29 16:44:23 -0700644 convo.SystemPrompt = fmt.Sprintf(agentSystemPrompt, editPrompt, a.config.ClientGOOS, a.config.ClientGOARCH, a.workingDir, a.repoRoot, a.initialCommit)
Earl Lee2e463fb2025-04-17 11:22:22 -0700645
646 // Register all tools with the conversation
647 // When adding, removing, or modifying tools here, double-check that the termui tool display
648 // template in termui/termui.go has pretty-printing support for all tools.
649 convo.Tools = []*ant.Tool{
650 claudetool.Bash, claudetool.Keyword,
651 claudetool.Think, a.titleTool(), makeDoneTool(a.codereview, a.config.GitUsername, a.config.GitEmail),
652 a.codereview.Tool(),
653 }
654 if a.config.UseAnthropicEdit {
655 convo.Tools = append(convo.Tools, claudetool.AnthropicEditTool)
656 } else {
657 convo.Tools = append(convo.Tools, claudetool.Patch)
658 }
659 convo.Listener = a
660 return convo
661}
662
663func (a *Agent) titleTool() *ant.Tool {
664 // titleTool creates the title tool that sets the conversation title.
665 title := &ant.Tool{
666 Name: "title",
667 Description: `Use this tool early in the conversation, BEFORE MAKING ANY GIT COMMITS, to summarize what the chat is about briefly.`,
668 InputSchema: json.RawMessage(`{
669 "type": "object",
670 "properties": {
671 "title": {
672 "type": "string",
673 "description": "A brief title summarizing what this chat is about"
674 }
675 },
676 "required": ["title"]
677}`),
678 Run: func(ctx context.Context, input json.RawMessage) (string, error) {
679 var params struct {
680 Title string `json:"title"`
681 }
682 if err := json.Unmarshal(input, &params); err != nil {
683 return "", err
684 }
685 a.SetTitle(params.Title)
686 return fmt.Sprintf("Title set to: %s", params.Title), nil
687 },
688 }
689 return title
690}
691
692func (a *Agent) Ready() <-chan struct{} {
693 return a.ready
694}
695
696func (a *Agent) UserMessage(ctx context.Context, msg string) {
697 a.pushToOutbox(ctx, AgentMessage{Type: UserMessageType, Content: msg})
698 a.inbox <- msg
699}
700
701func (a *Agent) WaitForMessage(ctx context.Context) AgentMessage {
702 // TODO: Should this drain any outbox messages in case there are multiple?
703 select {
704 case msg := <-a.outbox:
705 return msg
706 case <-ctx.Done():
707 return errorMessage(ctx.Err())
708 }
709}
710
711func (a *Agent) CancelToolUse(toolUseID string, cause error) error {
712 return a.convo.CancelToolUse(toolUseID, cause)
713}
714
715func (a *Agent) CancelInnerLoop(cause error) {
716 a.cancelInnerLoopMu.Lock()
717 defer a.cancelInnerLoopMu.Unlock()
718 if a.cancelInnerLoop != nil {
719 a.cancelInnerLoop(cause)
720 }
721}
722
723func (a *Agent) Loop(ctxOuter context.Context) {
724 for {
725 select {
726 case <-ctxOuter.Done():
727 return
728 default:
729 ctxInner, cancel := context.WithCancelCause(ctxOuter)
730 a.cancelInnerLoopMu.Lock()
731 // Set .cancelInnerLoop so the user can cancel whatever is happening
732 // inside InnerLoop(ctxInner) without canceling this outer Loop execution.
733 // This CancelInnerLoop func is intended be called from other goroutines,
734 // hence the mutex.
735 a.cancelInnerLoop = cancel
736 a.cancelInnerLoopMu.Unlock()
737 a.InnerLoop(ctxInner)
738 cancel(nil)
739 }
740 }
741}
742
743func (a *Agent) pushToOutbox(ctx context.Context, m AgentMessage) {
744 if m.Timestamp.IsZero() {
745 m.Timestamp = time.Now()
746 }
747
748 // If this is an end-of-turn message, calculate the turn duration and add it to the message
749 if m.EndOfTurn && m.Type == AgentMessageType {
750 turnDuration := time.Since(a.startOfTurn)
751 m.TurnDuration = &turnDuration
752 slog.InfoContext(ctx, "Turn completed", "turnDuration", turnDuration)
753 }
754
755 slog.InfoContext(ctx, "agent message", m.Attr())
756
757 a.mu.Lock()
758 defer a.mu.Unlock()
759 m.Idx = len(a.history)
760 a.history = append(a.history, m)
761 a.outbox <- m
762
763 // Notify all listeners:
764 for _, ch := range a.listeners {
765 close(ch)
766 }
767 a.listeners = a.listeners[:0]
768}
769
770func (a *Agent) GatherMessages(ctx context.Context, block bool) ([]ant.Content, error) {
771 var m []ant.Content
772 if block {
773 select {
774 case <-ctx.Done():
775 return m, ctx.Err()
776 case msg := <-a.inbox:
777 m = append(m, ant.Content{Type: "text", Text: msg})
778 }
779 }
780 for {
781 select {
782 case msg := <-a.inbox:
783 m = append(m, ant.Content{Type: "text", Text: msg})
784 default:
785 return m, nil
786 }
787 }
788}
789
790func (a *Agent) InnerLoop(ctx context.Context) {
791 // Reset the start of turn time
792 a.startOfTurn = time.Now()
793
794 // Wait for at least one message from the user.
795 msgs, err := a.GatherMessages(ctx, true)
796 if err != nil { // e.g. the context was canceled while blocking in GatherMessages
797 return
798 }
799 // We do this as we go, but let's also do it at the end of the turn
800 defer func() {
801 if _, err := a.handleGitCommits(ctx); err != nil {
802 // Just log the error, don't stop execution
803 slog.WarnContext(ctx, "Failed to check for new git commits", "error", err)
804 }
805 }()
806
807 userMessage := ant.Message{
808 Role: "user",
809 Content: msgs,
810 }
811 // convo.SendMessage does the actual network call to send this to anthropic. This blocks until the response is ready.
812 // TODO: pass ctx to SendMessage, and figure out how to square that ctx with convo's own .Ctx. Who owns the scope of this call?
813 resp, err := a.convo.SendMessage(userMessage)
814 if err != nil {
815 a.pushToOutbox(ctx, errorMessage(err))
816 return
817 }
818 for {
819 // TODO: here and below where we check the budget,
820 // we should review the UX: is it clear what happened?
821 // is it clear how to resume?
822 // should we let the user set a new budget?
823 if err := a.overBudget(ctx); err != nil {
824 return
825 }
826 if resp.StopReason != ant.StopReasonToolUse {
827 break
828 }
829 var results []ant.Content
830 cancelled := false
831 select {
832 case <-ctx.Done():
833 // Don't actually run any of the tools, but rather build a response
834 // for each tool_use message letting the LLM know that user canceled it.
835 results, err = a.convo.ToolResultCancelContents(resp)
836 if err != nil {
837 a.pushToOutbox(ctx, errorMessage(err))
838 }
839 cancelled = true
840 default:
841 ctx = claudetool.WithWorkingDir(ctx, a.workingDir)
842 // fall-through, when the user has not canceled the inner loop:
843 results, err = a.convo.ToolResultContents(ctx, resp)
844 if ctx.Err() != nil { // e.g. the user canceled the operation
845 cancelled = true
846 } else if err != nil {
847 a.pushToOutbox(ctx, errorMessage(err))
848 }
849 }
850
851 // Check for git commits. Currently we do this here, after we collect
852 // tool results, since that's when we know commits could have happened.
853 // We could instead do this when the turn ends, but I think it makes sense
854 // to do this as we go.
855 newCommits, err := a.handleGitCommits(ctx)
856 if err != nil {
857 // Just log the error, don't stop execution
858 slog.WarnContext(ctx, "Failed to check for new git commits", "error", err)
859 }
860 var autoqualityMessages []string
861 if len(newCommits) == 1 {
862 formatted := a.codereview.Autoformat(ctx)
863 if len(formatted) > 0 {
864 msg := fmt.Sprintf(`
865I ran autoformatters and they updated these files:
866
867%s
868
869Please amend your latest git commit with these changes and then continue with what you were doing.`,
870 strings.Join(formatted, "\n"),
871 )[1:]
872 a.pushToOutbox(ctx, AgentMessage{
873 Type: AutoMessageType,
874 Content: msg,
875 Timestamp: time.Now(),
876 })
877 autoqualityMessages = append(autoqualityMessages, msg)
878 }
879 }
880
881 if err := a.overBudget(ctx); err != nil {
882 return
883 }
884
885 // Include, along with the tool results (which must go first for whatever reason),
886 // any messages that the user has sent along while the tool_use was executing concurrently.
887 msgs, err = a.GatherMessages(ctx, false)
888 if err != nil {
889 return
890 }
891 // Inject any auto-generated messages from quality checks.
892 for _, msg := range autoqualityMessages {
893 msgs = append(msgs, ant.Content{Type: "text", Text: msg})
894 }
895 if cancelled {
896 msgs = append(msgs, ant.Content{Type: "text", Text: cancelToolUseMessage})
897 // EndOfTurn is false here so that the client of this agent keeps processing
898 // messages from WaitForMessage() and gets the response from the LLM (usually
899 // something like "okay, I'll wait further instructions", but the user should
900 // be made aware of it regardless).
901 a.pushToOutbox(ctx, AgentMessage{Type: ErrorMessageType, Content: userCancelMessage, EndOfTurn: false})
902 } else if err := a.convo.OverBudget(); err != nil {
903 budgetMsg := "We've exceeded our budget. Please ask the user to confirm before continuing by ending the turn."
904 msgs = append(msgs, ant.Content{Type: "text", Text: budgetMsg})
905 a.pushToOutbox(ctx, budgetMessage(fmt.Errorf("warning: %w (ask to keep trying, if you'd like)", err)))
906 }
907 results = append(results, msgs...)
908 resp, err = a.convo.SendMessage(ant.Message{
909 Role: "user",
910 Content: results,
911 })
912 if err != nil {
913 a.pushToOutbox(ctx, errorMessage(fmt.Errorf("error: failed to continue conversation: %s", err.Error())))
914 break
915 }
916 if cancelled {
917 return
918 }
919 }
920}
921
922func (a *Agent) overBudget(ctx context.Context) error {
923 if err := a.convo.OverBudget(); err != nil {
924 m := budgetMessage(err)
925 m.Content = m.Content + "\n\nBudget reset."
926 a.pushToOutbox(ctx, budgetMessage(err))
927 a.convo.ResetBudget(a.originalBudget)
928 return err
929 }
930 return nil
931}
932
933func collectTextContent(msg *ant.MessageResponse) string {
934 // Collect all text content
935 var allText strings.Builder
936 for _, content := range msg.Content {
937 if content.Type == "text" && content.Text != "" {
938 if allText.Len() > 0 {
939 allText.WriteString("\n\n")
940 }
941 allText.WriteString(content.Text)
942 }
943 }
944 return allText.String()
945}
946
947func (a *Agent) TotalUsage() ant.CumulativeUsage {
948 a.mu.Lock()
949 defer a.mu.Unlock()
950 return a.convo.CumulativeUsage()
951}
952
953// WaitForMessageCount returns when the agent has at more than clientMessageCount messages or the context is done.
954func (a *Agent) WaitForMessageCount(ctx context.Context, greaterThan int) {
955 for a.MessageCount() <= greaterThan {
956 a.mu.Lock()
957 ch := make(chan struct{})
958 // Deletion happens when we notify.
959 a.listeners = append(a.listeners, ch)
960 a.mu.Unlock()
961
962 select {
963 case <-ctx.Done():
964 return
965 case <-ch:
966 continue
967 }
968 }
969}
970
971// Diff returns a unified diff of changes made since the agent was instantiated.
972func (a *Agent) Diff(commit *string) (string, error) {
973 if a.initialCommit == "" {
974 return "", fmt.Errorf("no initial commit reference available")
975 }
976
977 // Find the repository root
978 ctx := context.Background()
979
980 // If a specific commit hash is provided, show just that commit's changes
981 if commit != nil && *commit != "" {
982 // Validate that the commit looks like a valid git SHA
983 if !isValidGitSHA(*commit) {
984 return "", fmt.Errorf("invalid git commit SHA format: %s", *commit)
985 }
986
987 // Get the diff for just this commit
988 cmd := exec.CommandContext(ctx, "git", "show", "--unified=10", *commit)
989 cmd.Dir = a.repoRoot
990 output, err := cmd.CombinedOutput()
991 if err != nil {
992 return "", fmt.Errorf("failed to get diff for commit %s: %w - %s", *commit, err, string(output))
993 }
994 return string(output), nil
995 }
996
997 // Otherwise, get the diff between the initial commit and the current state using exec.Command
998 cmd := exec.CommandContext(ctx, "git", "diff", "--unified=10", a.initialCommit)
999 cmd.Dir = a.repoRoot
1000 output, err := cmd.CombinedOutput()
1001 if err != nil {
1002 return "", fmt.Errorf("failed to get diff: %w - %s", err, string(output))
1003 }
1004
1005 return string(output), nil
1006}
1007
1008// InitialCommit returns the Git commit hash that was saved when the agent was instantiated.
1009func (a *Agent) InitialCommit() string {
1010 return a.initialCommit
1011}
1012
1013// handleGitCommits() highlights new commits to the user. When running
1014// under docker, new HEADs are pushed to a branch according to the title.
1015func (a *Agent) handleGitCommits(ctx context.Context) ([]*GitCommit, error) {
1016 if a.repoRoot == "" {
1017 return nil, nil
1018 }
1019
1020 head, err := resolveRef(ctx, a.repoRoot, "HEAD")
1021 if err != nil {
1022 return nil, err
1023 }
1024 if head == a.lastHEAD {
1025 return nil, nil // nothing to do
1026 }
1027 defer func() {
1028 a.lastHEAD = head
1029 }()
1030
1031 // Get new commits. Because it's possible that the agent does rebases, fixups, and
1032 // so forth, we use, as our fixed point, the "initialCommit", and we limit ourselves
1033 // to the last 100 commits.
1034 var commits []*GitCommit
1035
1036 // Get commits since the initial commit
1037 // Format: <hash>\0<subject>\0<body>\0
1038 // This uses NULL bytes as separators to avoid issues with newlines in commit messages
1039 // Limit to 100 commits to avoid overwhelming the user
1040 cmd := exec.CommandContext(ctx, "git", "log", "-n", "100", "--pretty=format:%H%x00%s%x00%b%x00", "^"+a.initialCommit, head)
1041 cmd.Dir = a.repoRoot
1042 output, err := cmd.Output()
1043 if err != nil {
1044 return nil, fmt.Errorf("failed to get git log: %w", err)
1045 }
1046
1047 // Parse git log output and filter out already seen commits
1048 parsedCommits := parseGitLog(string(output))
1049
1050 var headCommit *GitCommit
1051
1052 // Filter out commits we've already seen
1053 for _, commit := range parsedCommits {
1054 if commit.Hash == head {
1055 headCommit = &commit
1056 }
1057
1058 // Skip if we've seen this commit before. If our head has changed, always include that.
1059 if a.seenCommits[commit.Hash] && commit.Hash != head {
1060 continue
1061 }
1062
1063 // Mark this commit as seen
1064 a.seenCommits[commit.Hash] = true
1065
1066 // Add to our list of new commits
1067 commits = append(commits, &commit)
1068 }
1069
1070 if a.gitRemoteAddr != "" {
1071 if headCommit == nil {
1072 // I think this can only happen if we have a bug or if there's a race.
1073 headCommit = &GitCommit{}
1074 headCommit.Hash = head
1075 headCommit.Subject = "unknown"
1076 commits = append(commits, headCommit)
1077 }
1078
1079 cleanTitle := titleToBranch(a.title)
1080 if cleanTitle == "" {
1081 cleanTitle = a.config.SessionID
1082 }
1083 branch := "sketch/" + cleanTitle
1084
1085 // TODO: I don't love the force push here. We could see if the push is a fast-forward, and,
1086 // if it's not, we could make a backup with a unique name (perhaps append a timestamp) and
1087 // then use push with lease to replace.
1088 cmd = exec.Command("git", "push", "--force", a.gitRemoteAddr, "HEAD:refs/heads/"+branch)
1089 cmd.Dir = a.workingDir
1090 if out, err := cmd.CombinedOutput(); err != nil {
1091 a.pushToOutbox(ctx, errorMessage(fmt.Errorf("git push to host: %s: %v", out, err)))
1092 } else {
1093 headCommit.PushedBranch = branch
1094 }
1095 }
1096
1097 // If we found new commits, create a message
1098 if len(commits) > 0 {
1099 msg := AgentMessage{
1100 Type: CommitMessageType,
1101 Timestamp: time.Now(),
1102 Commits: commits,
1103 }
1104 a.pushToOutbox(ctx, msg)
1105 }
1106 return commits, nil
1107}
1108
// titleToBranch reduces a title to the characters a-z and '-': the input is
// lowercased, each space becomes a hyphen, and every other character
// (digits, punctuation, non-ASCII letters, ...) is dropped. The result may be
// empty.
func titleToBranch(s string) string {
	keep := func(r rune) rune {
		switch {
		case r == ' ':
			return '-'
		case r == '-', 'a' <= r && r <= 'z':
			return r
		default:
			return -1 // a negative rune tells strings.Map to drop the character
		}
	}
	return strings.Map(keep, strings.ToLower(s))
}
1125
1126// parseGitLog parses the output of git log with format '%H%x00%s%x00%b%x00'
1127// and returns an array of GitCommit structs.
1128func parseGitLog(output string) []GitCommit {
1129 var commits []GitCommit
1130
1131 // No output means no commits
1132 if len(output) == 0 {
1133 return commits
1134 }
1135
1136 // Split by NULL byte
1137 parts := strings.Split(output, "\x00")
1138
1139 // Process in triplets (hash, subject, body)
1140 for i := 0; i < len(parts); i++ {
1141 // Skip empty parts
1142 if parts[i] == "" {
1143 continue
1144 }
1145
1146 // This should be a hash
1147 hash := strings.TrimSpace(parts[i])
1148
1149 // Make sure we have at least a subject part available
1150 if i+1 >= len(parts) {
1151 break // No more parts available
1152 }
1153
1154 // Get the subject
1155 subject := strings.TrimSpace(parts[i+1])
1156
1157 // Get the body if available
1158 body := ""
1159 if i+2 < len(parts) {
1160 body = strings.TrimSpace(parts[i+2])
1161 }
1162
1163 // Skip to the next triplet
1164 i += 2
1165
1166 commits = append(commits, GitCommit{
1167 Hash: hash,
1168 Subject: subject,
1169 Body: body,
1170 })
1171 }
1172
1173 return commits
1174}
1175
// repoRoot runs `git rev-parse --show-toplevel` in dir and returns its
// output with surrounding whitespace trimmed. On failure it returns an error
// that wraps the underlying error and includes git's stderr output.
func repoRoot(ctx context.Context, dir string) (string, error) {
	var errOut strings.Builder

	revParse := exec.CommandContext(ctx, "git", "rev-parse", "--show-toplevel")
	revParse.Dir = dir
	revParse.Stderr = &errOut

	top, err := revParse.Output()
	if err != nil {
		return "", fmt.Errorf("git rev-parse failed: %w\n%s", err, errOut.String())
	}
	return strings.TrimSpace(string(top)), nil
}
1187
// resolveRef runs `git rev-parse <refName>` in dir and returns its output
// with surrounding whitespace trimmed. On failure it returns an error that
// wraps the underlying error and includes git's stderr output.
func resolveRef(ctx context.Context, dir, refName string) (string, error) {
	var errOut strings.Builder

	revParse := exec.CommandContext(ctx, "git", "rev-parse", refName)
	revParse.Dir = dir
	revParse.Stderr = &errOut

	resolved, err := revParse.Output()
	if err != nil {
		return "", fmt.Errorf("git rev-parse failed: %w\n%s", err, errOut.String())
	}
	// TODO: validate that the output is valid hex
	return strings.TrimSpace(string(resolved)), nil
}
1200
// isValidGitSHA reports whether sha looks like a git SHA hash: between 4 and
// 40 characters long (inclusive) and made up solely of hexadecimal digits in
// either case. It does not check that the hash exists in any repository.
func isValidGitSHA(sha string) bool {
	if n := len(sha); n < 4 || n > 40 {
		return false
	}

	for _, r := range sha {
		switch {
		case '0' <= r && r <= '9':
		case 'a' <= r && r <= 'f':
		case 'A' <= r && r <= 'F':
		default:
			// Anything outside the three hex ranges disqualifies the string.
			return false
		}
	}
	return true
}
Philip Zeyligerd1402952025-04-23 03:54:37 +00001218
// getGitOrigin returns the URL of the git remote 'origin' as reported by
// `git config --get remote.origin.url` run in dir, with surrounding
// whitespace trimmed. Any failure of the git command is reported as the empty
// string.
func getGitOrigin(ctx context.Context, dir string) string {
	cmd := exec.CommandContext(ctx, "git", "config", "--get", "remote.origin.url")
	cmd.Dir = dir
	// stderr is not captured: errors are collapsed to "" below, so nothing
	// would ever read it.
	out, err := cmd.Output()
	if err != nil {
		return ""
	}
	return strings.TrimSpace(string(out))
}