blob: 2f4efe99732bb9de12ee97d4ef300b7681a848f6 [file] [log] [blame]
Earl Lee2e463fb2025-04-17 11:22:22 -07001package loop
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "net/http"
9 "os"
10 "os/exec"
11 "runtime/debug"
12 "slices"
13 "strings"
14 "sync"
15 "time"
16
17 "sketch.dev/ant"
18 "sketch.dev/claudetool"
19)
20
// userCancelMessage is the error text shown to the user after they cancel
// the turn that the agent is currently working on.
const userCancelMessage = "user requested agent to stop handling responses"
24
// CodingAgent is the interface through which a front end drives a coding
// agent: it delivers user messages, reads back the agent's messages, and
// queries conversation and repository state. *Agent implements it.
type CodingAgent interface {
	// Init initializes the agent (optionally syncing git state inside a
	// docker container, as requested by AgentInit). It must succeed before
	// the agent is used.
	Init(AgentInit) error

	// Ready returns a channel that is closed once Init has completed successfully.
	Ready() <-chan struct{}

	// URL reports the HTTP URL of this agent. It may be empty.
	URL() string

	// UserMessage enqueues a message to the agent and returns immediately.
	UserMessage(ctx context.Context, msg string)

	// WaitForMessage blocks until the agent has a message to deliver or ctx
	// is done. Use AgentMessage.EndOfTurn to decide whether to keep draining
	// the agent.
	WaitForMessage(ctx context.Context) AgentMessage

	// Loop runs the agent loop and returns only when ctx is cancelled.
	Loop(ctx context.Context)

	// CancelInnerLoop aborts the work in progress for the current turn,
	// recording cause, without stopping Loop.
	CancelInnerLoop(cause error)

	// CancelToolUse cancels the in-flight tool call identified by toolUseID.
	CancelToolUse(toolUseID string, cause error) error

	// Messages returns the history messages in the half-open range [start, end).
	Messages(start int, end int) []AgentMessage

	// MessageCount returns the current number of messages in the history.
	MessageCount() int

	// TotalUsage returns the cumulative API usage of the conversation.
	TotalUsage() ant.CumulativeUsage
	// OriginalBudget returns the budget the agent was configured with.
	OriginalBudget() ant.Budget

	// WaitForMessageCount returns once the agent has more than greaterThan
	// messages, or when ctx is done.
	WaitForMessageCount(ctx context.Context, greaterThan int)

	// WorkingDir returns the directory the agent operates in.
	WorkingDir() string

	// Diff returns a unified diff of changes made since the agent was instantiated.
	// If commit is non-nil, it shows the diff for just that specific commit.
	Diff(commit *string) (string, error)

	// InitialCommit returns the Git commit hash that was saved when the agent was instantiated.
	InitialCommit() string

	// Title returns the current title of the conversation.
	Title() string

	// OS returns the operating system of the client.
	OS() string
	// OutsideOS returns the operating system of the outside system.
	OutsideOS() string
	// OutsideHostname returns the hostname of the outside system.
	OutsideHostname() string
	// OutsideWorkingDir returns the working directory on the outside system.
	OutsideWorkingDir() string
	// GitOrigin returns the URL of the git remote 'origin', or "" if there is none.
	GitOrigin() string
}
81
// CodingAgentMessageType identifies the kind of an AgentMessage
// (user input, agent reply, tool result, error, and so on).
type CodingAgentMessageType string
83
84const (
85 UserMessageType CodingAgentMessageType = "user"
86 AgentMessageType CodingAgentMessageType = "agent"
87 ErrorMessageType CodingAgentMessageType = "error"
88 BudgetMessageType CodingAgentMessageType = "budget" // dedicated for "out of budget" errors
89 ToolUseMessageType CodingAgentMessageType = "tool"
90 CommitMessageType CodingAgentMessageType = "commit" // for displaying git commits
91 AutoMessageType CodingAgentMessageType = "auto" // for automated notifications like autoformatting
92
93 cancelToolUseMessage = "Stop responding to my previous message. Wait for me to ask you something else before attempting to use any more tools."
94)
95
// AgentMessage is one entry in the agent's history and outbox: a user
// message, an agent reply, a tool result, an error, a budget notice, a
// commit listing, or an automated notice (see CodingAgentMessageType).
type AgentMessage struct {
	Type CodingAgentMessageType `json:"type"`
	// EndOfTurn indicates that the AI is done working and is ready for the next user input.
	EndOfTurn bool `json:"end_of_turn"`

	// Content is the human-readable text of the message.
	Content string `json:"content"`
	// Tool fields are populated for ToolUseMessageType messages.
	ToolName   string `json:"tool_name,omitempty"`
	ToolInput  string `json:"input,omitempty"`
	ToolResult string `json:"tool_result,omitempty"`
	ToolError  bool   `json:"tool_error,omitempty"`
	ToolCallId string `json:"tool_call_id,omitempty"`

	// ToolCalls is a list of all tool calls requested in this message (name and input pairs)
	ToolCalls []ToolCall `json:"tool_calls,omitempty"`

	// ToolResponses holds the messages that responded to this message's tool
	// calls (presumably tool-result messages; confirm with the producers of
	// this field).
	ToolResponses []AgentMessage `json:"toolResponses,omitempty"`

	// Commits is a list of git commits for a commit message
	Commits []*GitCommit `json:"commits,omitempty"`

	// Timestamp is filled in by pushToOutbox when left zero.
	Timestamp            time.Time  `json:"timestamp"`
	ConversationID       string     `json:"conversation_id"`
	ParentConversationID *string    `json:"parent_conversation_id,omitempty"`
	Usage                *ant.Usage `json:"usage,omitempty"`

	// Message timing information. Elapsed is EndTime-StartTime when both are set.
	StartTime *time.Time     `json:"start_time,omitempty"`
	EndTime   *time.Time     `json:"end_time,omitempty"`
	Elapsed   *time.Duration `json:"elapsed,omitempty"`

	// Turn duration - the time taken for a complete agent turn. Set only on
	// end-of-turn AgentMessageType messages.
	TurnDuration *time.Duration `json:"turnDuration,omitempty"`

	// Idx is the message's position in the agent history.
	Idx int `json:"idx"`
}
132
// GitCommit describes a single git commit listed in a CommitMessageType
// AgentMessage.
type GitCommit struct {
	Hash         string `json:"hash"`                    // Full commit hash
	Subject      string `json:"subject"`                 // Commit subject line
	Body         string `json:"body"`                    // Full commit message body
	PushedBranch string `json:"pushed_branch,omitempty"` // If set, this commit was pushed to this branch
}
140
// ToolCall represents a single tool call within an agent message.
// OnResponse fills in Name, Input and ToolCallId; the remaining fields are
// not set by the code in this file (presumably filled in elsewhere — confirm).
type ToolCall struct {
	// Name is the tool's name.
	Name string `json:"name"`
	// Input is the raw JSON input the model supplied to the tool.
	Input string `json:"input"`
	// ToolCallId is the tool_use ID assigned by the model.
	ToolCallId    string        `json:"tool_call_id"`
	ResultMessage *AgentMessage `json:"result_message,omitempty"`
	Args          string        `json:"args,omitempty"`
	Result        string        `json:"result,omitempty"`
}
150
151func (a *AgentMessage) Attr() slog.Attr {
152 var attrs []any = []any{
153 slog.String("type", string(a.Type)),
154 }
155 if a.EndOfTurn {
156 attrs = append(attrs, slog.Bool("end_of_turn", a.EndOfTurn))
157 }
158 if a.Content != "" {
159 attrs = append(attrs, slog.String("content", a.Content))
160 }
161 if a.ToolName != "" {
162 attrs = append(attrs, slog.String("tool_name", a.ToolName))
163 }
164 if a.ToolInput != "" {
165 attrs = append(attrs, slog.String("tool_input", a.ToolInput))
166 }
167 if a.Elapsed != nil {
168 attrs = append(attrs, slog.Int64("elapsed", a.Elapsed.Nanoseconds()))
169 }
170 if a.TurnDuration != nil {
171 attrs = append(attrs, slog.Int64("turnDuration", a.TurnDuration.Nanoseconds()))
172 }
173 if a.ToolResult != "" {
174 attrs = append(attrs, slog.String("tool_result", a.ToolResult))
175 }
176 if a.ToolError {
177 attrs = append(attrs, slog.Bool("tool_error", a.ToolError))
178 }
179 if len(a.ToolCalls) > 0 {
180 toolCallAttrs := make([]any, 0, len(a.ToolCalls))
181 for i, tc := range a.ToolCalls {
182 toolCallAttrs = append(toolCallAttrs, slog.Group(
183 fmt.Sprintf("tool_call_%d", i),
184 slog.String("name", tc.Name),
185 slog.String("input", tc.Input),
186 ))
187 }
188 attrs = append(attrs, slog.Group("tool_calls", toolCallAttrs...))
189 }
190 if a.ConversationID != "" {
191 attrs = append(attrs, slog.String("convo_id", a.ConversationID))
192 }
193 if a.ParentConversationID != nil {
194 attrs = append(attrs, slog.String("parent_convo_id", *a.ParentConversationID))
195 }
196 if a.Usage != nil && !a.Usage.IsZero() {
197 attrs = append(attrs, a.Usage.Attr())
198 }
199 // TODO: timestamp, convo ids, idx?
200 return slog.Group("agent_message", attrs...)
201}
202
203func errorMessage(err error) AgentMessage {
204 // It's somewhat unknowable whether error messages are "end of turn" or not, but it seems like the best approach.
205 if os.Getenv(("DEBUG")) == "1" {
206 return AgentMessage{Type: ErrorMessageType, Content: err.Error() + " Stacktrace: " + string(debug.Stack()), EndOfTurn: true}
207 }
208
209 return AgentMessage{Type: ErrorMessageType, Content: err.Error(), EndOfTurn: true}
210}
211
212func budgetMessage(err error) AgentMessage {
213 return AgentMessage{Type: BudgetMessageType, Content: err.Error(), EndOfTurn: true}
214}
215
// ConvoInterface defines the interface for conversation interactions.
// Agent holds its conversation through this interface; Init stores the
// *ant.Convo built by initConvo in it.
type ConvoInterface interface {
	// CumulativeUsage returns the usage accumulated over the conversation.
	CumulativeUsage() ant.CumulativeUsage
	// ResetBudget replaces the conversation's budget.
	ResetBudget(ant.Budget)
	// OverBudget returns a non-nil error when the budget has been exceeded.
	OverBudget() error
	// SendMessage sends a message and blocks until the model responds.
	SendMessage(message ant.Message) (*ant.MessageResponse, error)
	SendUserTextMessage(s string, otherContents ...ant.Content) (*ant.MessageResponse, error)
	// ToolResultContents runs the tools requested in resp and returns their results.
	ToolResultContents(ctx context.Context, resp *ant.MessageResponse) ([]ant.Content, error)
	// ToolResultCancelContents builds "cancelled" results for resp's tool
	// calls without running the tools.
	ToolResultCancelContents(resp *ant.MessageResponse) ([]ant.Content, error)
	// CancelToolUse cancels the in-flight tool call identified by toolUseID.
	CancelToolUse(toolUseID string, cause error) error
}
227
// Agent implements CodingAgent. It owns the LLM conversation, runs the turn
// loop, and records every message in a history that clients can page
// through and wait on.
type Agent struct {
	convo          ConvoInterface // set by Init
	config         AgentConfig    // config for this agent
	workingDir     string
	repoRoot       string // workingDir may be a subdir of repoRoot
	url            string // HTTP URL of this agent; set by Init only in docker with a HostAddr
	lastHEAD       string // last HEAD hash processed by handleGitCommits (initialized by Init)
	initialCommit  string // hash of the Git HEAD resolved during Init
	gitRemoteAddr  string // HTTP URL of the host git repo (only when under docker)
	ready          chan struct{} // closed when Init completes successfully
	startedAt      time.Time
	originalBudget ant.Budget
	title          string // guarded by mu (see Title/SetTitle)
	codereview     *claudetool.CodeReviewer // nil when Init ran with NoGit
	// Outside information
	outsideHostname   string
	outsideOS         string
	outsideWorkingDir string
	// URL of the git remote 'origin' if it exists
	gitOrigin string

	// Time when the current turn started (reset at the beginning of InnerLoop).
	// Read by pushToOutbox without holding mu.
	startOfTurn time.Time

	// Inbox - for messages from the user to the agent.
	// sent on by UserMessage
	// . e.g. when user types into the chat textarea
	// read from by GatherMessages
	inbox chan string

	// Outbox
	// sent on by pushToOutbox
	// via OnToolResult and OnResponse callbacks
	// read from by WaitForMessage
	// called by termui inside its repl loop.
	outbox chan AgentMessage

	// protects cancelInnerLoop
	cancelInnerLoopMu sync.Mutex
	// cancels potentially long-running tool_use calls or chains of them
	cancelInnerLoop context.CancelCauseFunc

	// protects following (and also title, above)
	mu sync.Mutex

	// Stores all messages for this agent
	history []AgentMessage

	// Channels registered by WaitForMessageCount; closed and cleared by
	// pushToOutbox and SetTitle to wake the waiters.
	listeners []chan struct{}

	// Track git commits we've already seen (by hash)
	seenCommits map[string]bool
}
281
282func (a *Agent) URL() string { return a.url }
283
284// Title returns the current title of the conversation.
285// If no title has been set, returns an empty string.
286func (a *Agent) Title() string {
287 a.mu.Lock()
288 defer a.mu.Unlock()
289 return a.title
290}
291
292// OS returns the operating system of the client.
293func (a *Agent) OS() string {
294 return a.config.ClientGOOS
295}
296
Philip Zeyliger18532b22025-04-23 21:11:46 +0000297// OutsideOS returns the operating system of the outside system.
298func (a *Agent) OutsideOS() string {
299 return a.outsideOS
Philip Zeyligerd1402952025-04-23 03:54:37 +0000300}
301
Philip Zeyliger18532b22025-04-23 21:11:46 +0000302// OutsideHostname returns the hostname of the outside system.
303func (a *Agent) OutsideHostname() string {
304 return a.outsideHostname
Philip Zeyligerd1402952025-04-23 03:54:37 +0000305}
306
Philip Zeyliger18532b22025-04-23 21:11:46 +0000307// OutsideWorkingDir returns the working directory on the outside system.
308func (a *Agent) OutsideWorkingDir() string {
309 return a.outsideWorkingDir
Philip Zeyligerd1402952025-04-23 03:54:37 +0000310}
311
312// GitOrigin returns the URL of the git remote 'origin' if it exists.
313func (a *Agent) GitOrigin() string {
314 return a.gitOrigin
315}
316
Earl Lee2e463fb2025-04-17 11:22:22 -0700317// SetTitle sets the title of the conversation.
318func (a *Agent) SetTitle(title string) {
319 a.mu.Lock()
320 defer a.mu.Unlock()
321 a.title = title
322 // Notify all listeners that the state has changed
323 for _, ch := range a.listeners {
324 close(ch)
325 }
326 a.listeners = a.listeners[:0]
327}
328
329// OnToolResult implements ant.Listener.
330func (a *Agent) OnToolResult(ctx context.Context, convo *ant.Convo, toolName string, toolInput json.RawMessage, content ant.Content, result *string, err error) {
331 m := AgentMessage{
332 Type: ToolUseMessageType,
333 Content: content.Text,
334 ToolResult: content.ToolResult,
335 ToolError: content.ToolError,
336 ToolName: toolName,
337 ToolInput: string(toolInput),
338 ToolCallId: content.ToolUseID,
339 StartTime: content.StartTime,
340 EndTime: content.EndTime,
341 }
342
343 // Calculate the elapsed time if both start and end times are set
344 if content.StartTime != nil && content.EndTime != nil {
345 elapsed := content.EndTime.Sub(*content.StartTime)
346 m.Elapsed = &elapsed
347 }
348
349 m.ConversationID = convo.ID
350 if convo.Parent != nil {
351 m.ParentConversationID = &convo.Parent.ID
352 }
353 a.pushToOutbox(ctx, m)
354}
355
356// OnRequest implements ant.Listener.
357func (a *Agent) OnRequest(ctx context.Context, convo *ant.Convo, msg *ant.Message) {
358 // No-op.
359 // We already get tool results from the above. We send user messages to the outbox in the agent loop.
360}
361
362// OnResponse implements ant.Listener. Responses contain messages from the LLM
363// that need to be displayed (as well as tool calls that we send along when
364// they're done). (It would be reasonable to also mention tool calls when they're
365// started, but we don't do that yet.)
366func (a *Agent) OnResponse(ctx context.Context, convo *ant.Convo, resp *ant.MessageResponse) {
367 endOfTurn := false
368 if resp.StopReason != ant.StopReasonToolUse {
369 endOfTurn = true
370 }
371 m := AgentMessage{
372 Type: AgentMessageType,
373 Content: collectTextContent(resp),
374 EndOfTurn: endOfTurn,
375 Usage: &resp.Usage,
376 StartTime: resp.StartTime,
377 EndTime: resp.EndTime,
378 }
379
380 // Extract any tool calls from the response
381 if resp.StopReason == ant.StopReasonToolUse {
382 var toolCalls []ToolCall
383 for _, part := range resp.Content {
384 if part.Type == "tool_use" {
385 toolCalls = append(toolCalls, ToolCall{
386 Name: part.ToolName,
387 Input: string(part.ToolInput),
388 ToolCallId: part.ID,
389 })
390 }
391 }
392 m.ToolCalls = toolCalls
393 }
394
395 // Calculate the elapsed time if both start and end times are set
396 if resp.StartTime != nil && resp.EndTime != nil {
397 elapsed := resp.EndTime.Sub(*resp.StartTime)
398 m.Elapsed = &elapsed
399 }
400
401 m.ConversationID = convo.ID
402 if convo.Parent != nil {
403 m.ParentConversationID = &convo.Parent.ID
404 }
405 a.pushToOutbox(ctx, m)
406}
407
408// WorkingDir implements CodingAgent.
409func (a *Agent) WorkingDir() string {
410 return a.workingDir
411}
412
413// MessageCount implements CodingAgent.
414func (a *Agent) MessageCount() int {
415 a.mu.Lock()
416 defer a.mu.Unlock()
417 return len(a.history)
418}
419
420// Messages implements CodingAgent.
421func (a *Agent) Messages(start int, end int) []AgentMessage {
422 a.mu.Lock()
423 defer a.mu.Unlock()
424 return slices.Clone(a.history[start:end])
425}
426
427func (a *Agent) OriginalBudget() ant.Budget {
428 return a.originalBudget
429}
430
// AgentConfig contains configuration for creating a new Agent.
// NOTE(review): Context is stored in the struct and used for all git
// commands and the conversation; it must be non-nil before Init.
type AgentConfig struct {
	Context          context.Context
	AntURL           string       // overrides the conversation's API URL when non-empty
	APIKey           string
	HTTPC            *http.Client // overrides the conversation's HTTP client when non-nil
	Budget           ant.Budget   // initial budget; also restored by overBudget after an overrun
	GitUsername      string       // passed to the done tool
	GitEmail         string       // passed to the done tool
	SessionID        string       // not used by the code visible here — confirm consumers
	ClientGOOS       string       // reported by OS() and in the system prompt
	ClientGOARCH     string       // reported in the system prompt
	UseAnthropicEdit bool         // true: str_replace_editor tool; false: patch tool
	// Outside information
	OutsideHostname   string
	OutsideOS         string
	OutsideWorkingDir string
}
449
450// NewAgent creates a new Agent.
451// It is not usable until Init() is called.
452func NewAgent(config AgentConfig) *Agent {
453 agent := &Agent{
Philip Zeyliger18532b22025-04-23 21:11:46 +0000454 config: config,
455 ready: make(chan struct{}),
456 inbox: make(chan string, 100),
457 outbox: make(chan AgentMessage, 100),
458 startedAt: time.Now(),
459 originalBudget: config.Budget,
460 seenCommits: make(map[string]bool),
461 outsideHostname: config.OutsideHostname,
462 outsideOS: config.OutsideOS,
463 outsideWorkingDir: config.OutsideWorkingDir,
Earl Lee2e463fb2025-04-17 11:22:22 -0700464 }
465 return agent
466}
467
// AgentInit holds the parameters for Agent.Init.
type AgentInit struct {
	WorkingDir string // directory the agent works in
	NoGit      bool   // only for testing: skip repo root, HEAD, code reviewer and origin lookup

	// The fields below are used only when InDocker is true.
	InDocker      bool   // sync the container checkout with the host before starting
	Commit        string // commit to force-check-out after fetching from the host
	GitRemoteAddr string // host repo URL, added as the "sketch-host" remote
	HostAddr      string // if non-empty, the agent URL becomes "http://" + HostAddr
}
477
478func (a *Agent) Init(ini AgentInit) error {
479 ctx := a.config.Context
480 if ini.InDocker {
481 cmd := exec.CommandContext(ctx, "git", "stash")
482 cmd.Dir = ini.WorkingDir
483 if out, err := cmd.CombinedOutput(); err != nil {
484 return fmt.Errorf("git stash: %s: %v", out, err)
485 }
Philip Zeyligerd0ac1ea2025-04-21 20:04:19 -0700486 cmd = exec.CommandContext(ctx, "git", "remote", "add", "sketch-host", ini.GitRemoteAddr)
487 cmd.Dir = ini.WorkingDir
488 if out, err := cmd.CombinedOutput(); err != nil {
489 return fmt.Errorf("git remote add: %s: %v", out, err)
490 }
491 cmd = exec.CommandContext(ctx, "git", "fetch", "sketch-host")
Earl Lee2e463fb2025-04-17 11:22:22 -0700492 cmd.Dir = ini.WorkingDir
493 if out, err := cmd.CombinedOutput(); err != nil {
494 return fmt.Errorf("git fetch: %s: %w", out, err)
495 }
496 cmd = exec.CommandContext(ctx, "git", "checkout", "-f", ini.Commit)
497 cmd.Dir = ini.WorkingDir
498 if out, err := cmd.CombinedOutput(); err != nil {
499 return fmt.Errorf("git checkout %s: %s: %w", ini.Commit, out, err)
500 }
Earl Lee2e463fb2025-04-17 11:22:22 -0700501 a.lastHEAD = ini.Commit
502 a.gitRemoteAddr = ini.GitRemoteAddr
503 a.initialCommit = ini.Commit
504 if ini.HostAddr != "" {
505 a.url = "http://" + ini.HostAddr
506 }
507 }
508 a.workingDir = ini.WorkingDir
509
510 if !ini.NoGit {
511 repoRoot, err := repoRoot(ctx, a.workingDir)
512 if err != nil {
513 return fmt.Errorf("repoRoot: %w", err)
514 }
515 a.repoRoot = repoRoot
516
517 commitHash, err := resolveRef(ctx, a.repoRoot, "HEAD")
518 if err != nil {
519 return fmt.Errorf("resolveRef: %w", err)
520 }
521 a.initialCommit = commitHash
522
523 codereview, err := claudetool.NewCodeReviewer(ctx, a.repoRoot, a.initialCommit)
524 if err != nil {
525 return fmt.Errorf("Agent.Init: claudetool.NewCodeReviewer: %w", err)
526 }
527 a.codereview = codereview
Philip Zeyligerd1402952025-04-23 03:54:37 +0000528
529 a.gitOrigin = getGitOrigin(ctx, ini.WorkingDir)
Earl Lee2e463fb2025-04-17 11:22:22 -0700530 }
531 a.lastHEAD = a.initialCommit
532 a.convo = a.initConvo()
533 close(a.ready)
534 return nil
535}
536
// initConvo initializes the conversation.
// It must not be called until all agent fields are initialized,
// particularly workingDir and git.
//
// The conversation:
//   - uses the configured HTTP client and API URL when set;
//   - enables prompt caching and applies the configured budget;
//   - gets a system prompt describing the workflow, the client platform, the
//     working directory, the git root and the initial HEAD;
//   - registers the tools (bash, keyword search, think, title, done, code
//     review, plus either the Anthropic edit tool or patch).
//
// The agent itself is registered as the conversation's listener.
//
// NOTE(review): a.codereview is nil when Init ran with NoGit; whether
// makeDoneTool and (*CodeReviewer).Tool tolerate a nil reviewer cannot be
// determined here — confirm.
func (a *Agent) initConvo() *ant.Convo {
	ctx := a.config.Context
	convo := ant.NewConvo(ctx, a.config.APIKey)
	if a.config.HTTPC != nil {
		convo.HTTPC = a.config.HTTPC
	}
	if a.config.AntURL != "" {
		convo.URL = a.config.AntURL
	}
	convo.PromptCaching = true
	convo.Budget = a.config.Budget

	// The edit instructions in the prompt must match the edit tool registered below.
	var editPrompt string
	if a.config.UseAnthropicEdit {
		editPrompt = "Then use the str_replace_editor tool to make those edits. For short complete file replacements, you may use the bash tool with cat and heredoc stdin."
	} else {
		editPrompt = "Then use the patch tool to make those edits. Combine all edits to any given file into a single patch tool call."
	}

	convo.SystemPrompt = fmt.Sprintf(`
You are an expert coding assistant and architect, with a specialty in Go.
You are assisting the user to achieve their goals.

Start by asking concise clarifying questions as needed.
Once the intent is clear, work autonomously.

Call the title tool early in the conversation to provide a brief summary of
what the chat is about.

Break down the overall goal into a series of smaller steps.
(The first step is often: "Make a plan.")
Then execute each step using tools.
Update the plan if you have encountered problems or learned new information.

When in doubt about a step, follow this broad workflow:

- Think about how the current step fits into the overall plan.
- Do research. Good tool choices: bash, think, keyword_search
- Make edits.
- Repeat.

To make edits reliably and efficiently, first think about the intent of the edit,
and what set of patches will achieve that intent.
%s

For renames or refactors, consider invoking gopls (via bash).

The done tool provides a checklist of items you MUST verify and
review before declaring that you are done. Before executing
the done tool, run all the tools the done tool checklist asks
for, including creating a git commit. Do not forget to run tests.

<platform>
%s/%s
</platform>
<pwd>
%v
</pwd>
<git_root>
%v
</git_root>
<HEAD>
%v
</HEAD>
`, editPrompt, a.config.ClientGOOS, a.config.ClientGOARCH, a.workingDir, a.repoRoot, a.initialCommit)

	// Register all tools with the conversation
	// When adding, removing, or modifying tools here, double-check that the termui tool display
	// template in termui/termui.go has pretty-printing support for all tools.
	convo.Tools = []*ant.Tool{
		claudetool.Bash, claudetool.Keyword,
		claudetool.Think, a.titleTool(), makeDoneTool(a.codereview, a.config.GitUsername, a.config.GitEmail),
		a.codereview.Tool(),
	}
	if a.config.UseAnthropicEdit {
		convo.Tools = append(convo.Tools, claudetool.AnthropicEditTool)
	} else {
		convo.Tools = append(convo.Tools, claudetool.Patch)
	}
	convo.Listener = a
	return convo
}
622
623func (a *Agent) titleTool() *ant.Tool {
624 // titleTool creates the title tool that sets the conversation title.
625 title := &ant.Tool{
626 Name: "title",
627 Description: `Use this tool early in the conversation, BEFORE MAKING ANY GIT COMMITS, to summarize what the chat is about briefly.`,
628 InputSchema: json.RawMessage(`{
629 "type": "object",
630 "properties": {
631 "title": {
632 "type": "string",
633 "description": "A brief title summarizing what this chat is about"
634 }
635 },
636 "required": ["title"]
637}`),
638 Run: func(ctx context.Context, input json.RawMessage) (string, error) {
639 var params struct {
640 Title string `json:"title"`
641 }
642 if err := json.Unmarshal(input, &params); err != nil {
643 return "", err
644 }
645 a.SetTitle(params.Title)
646 return fmt.Sprintf("Title set to: %s", params.Title), nil
647 },
648 }
649 return title
650}
651
652func (a *Agent) Ready() <-chan struct{} {
653 return a.ready
654}
655
656func (a *Agent) UserMessage(ctx context.Context, msg string) {
657 a.pushToOutbox(ctx, AgentMessage{Type: UserMessageType, Content: msg})
658 a.inbox <- msg
659}
660
661func (a *Agent) WaitForMessage(ctx context.Context) AgentMessage {
662 // TODO: Should this drain any outbox messages in case there are multiple?
663 select {
664 case msg := <-a.outbox:
665 return msg
666 case <-ctx.Done():
667 return errorMessage(ctx.Err())
668 }
669}
670
671func (a *Agent) CancelToolUse(toolUseID string, cause error) error {
672 return a.convo.CancelToolUse(toolUseID, cause)
673}
674
675func (a *Agent) CancelInnerLoop(cause error) {
676 a.cancelInnerLoopMu.Lock()
677 defer a.cancelInnerLoopMu.Unlock()
678 if a.cancelInnerLoop != nil {
679 a.cancelInnerLoop(cause)
680 }
681}
682
683func (a *Agent) Loop(ctxOuter context.Context) {
684 for {
685 select {
686 case <-ctxOuter.Done():
687 return
688 default:
689 ctxInner, cancel := context.WithCancelCause(ctxOuter)
690 a.cancelInnerLoopMu.Lock()
691 // Set .cancelInnerLoop so the user can cancel whatever is happening
692 // inside InnerLoop(ctxInner) without canceling this outer Loop execution.
693 // This CancelInnerLoop func is intended be called from other goroutines,
694 // hence the mutex.
695 a.cancelInnerLoop = cancel
696 a.cancelInnerLoopMu.Unlock()
697 a.InnerLoop(ctxInner)
698 cancel(nil)
699 }
700 }
701}
702
// pushToOutbox records m in the history and delivers it to the outbox.
//
// It:
//   - stamps m.Timestamp if unset;
//   - attaches TurnDuration (time since a.startOfTurn) to end-of-turn
//     AgentMessageType messages;
//   - logs the message;
//   - assigns m.Idx as its position in the history;
//   - wakes every WaitForMessageCount listener.
//
// NOTE(review): the send on a.outbox happens while a.mu is held. The outbox
// is buffered (100 in NewAgent). If it fills because nobody is calling
// WaitForMessage, this blocks with a.mu held. That in turn blocks
// MessageCount, Messages, Title, SetTitle, TotalUsage and
// WaitForMessageCount. Holding the lock does keep outbox order identical to
// history order.
func (a *Agent) pushToOutbox(ctx context.Context, m AgentMessage) {
	if m.Timestamp.IsZero() {
		m.Timestamp = time.Now()
	}

	// If this is an end-of-turn message, calculate the turn duration and add it to the message
	if m.EndOfTurn && m.Type == AgentMessageType {
		turnDuration := time.Since(a.startOfTurn)
		m.TurnDuration = &turnDuration
		slog.InfoContext(ctx, "Turn completed", "turnDuration", turnDuration)
	}

	slog.InfoContext(ctx, "agent message", m.Attr())

	a.mu.Lock()
	defer a.mu.Unlock()
	m.Idx = len(a.history)
	a.history = append(a.history, m)
	a.outbox <- m

	// Notify all listeners:
	for _, ch := range a.listeners {
		close(ch)
	}
	a.listeners = a.listeners[:0]
}
729
730func (a *Agent) GatherMessages(ctx context.Context, block bool) ([]ant.Content, error) {
731 var m []ant.Content
732 if block {
733 select {
734 case <-ctx.Done():
735 return m, ctx.Err()
736 case msg := <-a.inbox:
737 m = append(m, ant.Content{Type: "text", Text: msg})
738 }
739 }
740 for {
741 select {
742 case msg := <-a.inbox:
743 m = append(m, ant.Content{Type: "text", Text: msg})
744 default:
745 return m, nil
746 }
747 }
748}
749
// InnerLoop handles one user turn. It waits for user input and sends it to
// the model. It then keeps running the requested tools and sending their
// results back until one of these happens:
//   - the model stops asking for tools;
//   - the turn is cancelled via ctx (one final "cancelled" round-trip is sent);
//   - the budget is exceeded;
//   - an API error occurs.
//
// After each tool round, and once more when the turn ends, it checks for
// new git commits. When exactly one new commit appeared, it runs the
// autoformatters and asks the model to amend the commit if files changed.
//
// NOTE(review): both overBudget checks return without sending the pending
// tool results. This appears to leave unanswered tool_use blocks in the
// conversation; confirm how the next SendMessage copes.
//
// NOTE(review): overBudget is checked just before the
// "else if a.convo.OverBudget()" branch below, so that branch looks
// reachable only if the budget state changes in between (for example with
// a time-based budget). Confirm whether it is meant to fire.
//
// NOTE(review): ctx is re-wrapped with claudetool.WithWorkingDir on every
// tool round, so the context chain grows with each iteration.
func (a *Agent) InnerLoop(ctx context.Context) {
	// Reset the start of turn time
	a.startOfTurn = time.Now()

	// Wait for at least one message from the user.
	msgs, err := a.GatherMessages(ctx, true)
	if err != nil { // e.g. the context was canceled while blocking in GatherMessages
		return
	}
	// We do this as we go, but let's also do it at the end of the turn.
	// (The closure captures the ctx variable, so it sees any re-wrapping below.)
	defer func() {
		if _, err := a.handleGitCommits(ctx); err != nil {
			// Just log the error, don't stop execution
			slog.WarnContext(ctx, "Failed to check for new git commits", "error", err)
		}
	}()

	userMessage := ant.Message{
		Role:    "user",
		Content: msgs,
	}
	// convo.SendMessage does the actual network call to send this to anthropic. This blocks until the response is ready.
	// TODO: pass ctx to SendMessage, and figure out how to square that ctx with convo's own .Ctx. Who owns the scope of this call?
	resp, err := a.convo.SendMessage(userMessage)
	if err != nil {
		a.pushToOutbox(ctx, errorMessage(err))
		return
	}
	for {
		// TODO: here and below where we check the budget,
		// we should review the UX: is it clear what happened?
		// is it clear how to resume?
		// should we let the user set a new budget?
		if err := a.overBudget(ctx); err != nil {
			return
		}
		if resp.StopReason != ant.StopReasonToolUse {
			break
		}
		var results []ant.Content
		cancelled := false
		select {
		case <-ctx.Done():
			// Don't actually run any of the tools, but rather build a response
			// for each tool_use message letting the LLM know that user canceled it.
			results, err = a.convo.ToolResultCancelContents(resp)
			if err != nil {
				a.pushToOutbox(ctx, errorMessage(err))
			}
			cancelled = true
		default:
			ctx = claudetool.WithWorkingDir(ctx, a.workingDir)
			// fall-through, when the user has not canceled the inner loop:
			results, err = a.convo.ToolResultContents(ctx, resp)
			if ctx.Err() != nil { // e.g. the user canceled the operation
				cancelled = true
			} else if err != nil {
				a.pushToOutbox(ctx, errorMessage(err))
			}
		}

		// Check for git commits. Currently we do this here, after we collect
		// tool results, since that's when we know commits could have happened.
		// We could instead do this when the turn ends, but I think it makes sense
		// to do this as we go.
		// (This := declares a loop-scoped err that shadows the outer one.)
		newCommits, err := a.handleGitCommits(ctx)
		if err != nil {
			// Just log the error, don't stop execution
			slog.WarnContext(ctx, "Failed to check for new git commits", "error", err)
		}
		var autoqualityMessages []string
		if len(newCommits) == 1 {
			formatted := a.codereview.Autoformat(ctx)
			if len(formatted) > 0 {
				// [1:] strips the leading newline of the raw string literal.
				msg := fmt.Sprintf(`
I ran autoformatters and they updated these files:

%s

Please amend your latest git commit with these changes and then continue with what you were doing.`,
					strings.Join(formatted, "\n"),
				)[1:]
				a.pushToOutbox(ctx, AgentMessage{
					Type:      AutoMessageType,
					Content:   msg,
					Timestamp: time.Now(),
				})
				autoqualityMessages = append(autoqualityMessages, msg)
			}
		}

		if err := a.overBudget(ctx); err != nil {
			return
		}

		// Include, along with the tool results (which must go first for whatever reason),
		// any messages that the user has sent along while the tool_use was executing concurrently.
		msgs, err = a.GatherMessages(ctx, false)
		if err != nil {
			return
		}
		// Inject any auto-generated messages from quality checks.
		for _, msg := range autoqualityMessages {
			msgs = append(msgs, ant.Content{Type: "text", Text: msg})
		}
		if cancelled {
			msgs = append(msgs, ant.Content{Type: "text", Text: cancelToolUseMessage})
			// EndOfTurn is false here so that the client of this agent keeps processing
			// messages from WaitForMessage() and gets the response from the LLM (usually
			// something like "okay, I'll wait further instructions", but the user should
			// be made aware of it regardless).
			a.pushToOutbox(ctx, AgentMessage{Type: ErrorMessageType, Content: userCancelMessage, EndOfTurn: false})
		} else if err := a.convo.OverBudget(); err != nil {
			budgetMsg := "We've exceeded our budget. Please ask the user to confirm before continuing by ending the turn."
			msgs = append(msgs, ant.Content{Type: "text", Text: budgetMsg})
			a.pushToOutbox(ctx, budgetMessage(fmt.Errorf("warning: %w (ask to keep trying, if you'd like)", err)))
		}
		results = append(results, msgs...)
		resp, err = a.convo.SendMessage(ant.Message{
			Role:    "user",
			Content: results,
		})
		if err != nil {
			a.pushToOutbox(ctx, errorMessage(fmt.Errorf("error: failed to continue conversation: %s", err.Error())))
			break
		}
		// After a cancellation, the model's single acknowledgement ends the turn.
		if cancelled {
			return
		}
	}
}
881
882func (a *Agent) overBudget(ctx context.Context) error {
883 if err := a.convo.OverBudget(); err != nil {
884 m := budgetMessage(err)
885 m.Content = m.Content + "\n\nBudget reset."
886 a.pushToOutbox(ctx, budgetMessage(err))
887 a.convo.ResetBudget(a.originalBudget)
888 return err
889 }
890 return nil
891}
892
893func collectTextContent(msg *ant.MessageResponse) string {
894 // Collect all text content
895 var allText strings.Builder
896 for _, content := range msg.Content {
897 if content.Type == "text" && content.Text != "" {
898 if allText.Len() > 0 {
899 allText.WriteString("\n\n")
900 }
901 allText.WriteString(content.Text)
902 }
903 }
904 return allText.String()
905}
906
907func (a *Agent) TotalUsage() ant.CumulativeUsage {
908 a.mu.Lock()
909 defer a.mu.Unlock()
910 return a.convo.CumulativeUsage()
911}
912
913// WaitForMessageCount returns when the agent has at more than clientMessageCount messages or the context is done.
914func (a *Agent) WaitForMessageCount(ctx context.Context, greaterThan int) {
915 for a.MessageCount() <= greaterThan {
916 a.mu.Lock()
917 ch := make(chan struct{})
918 // Deletion happens when we notify.
919 a.listeners = append(a.listeners, ch)
920 a.mu.Unlock()
921
922 select {
923 case <-ctx.Done():
924 return
925 case <-ch:
926 continue
927 }
928 }
929}
930
931// Diff returns a unified diff of changes made since the agent was instantiated.
932func (a *Agent) Diff(commit *string) (string, error) {
933 if a.initialCommit == "" {
934 return "", fmt.Errorf("no initial commit reference available")
935 }
936
937 // Find the repository root
938 ctx := context.Background()
939
940 // If a specific commit hash is provided, show just that commit's changes
941 if commit != nil && *commit != "" {
942 // Validate that the commit looks like a valid git SHA
943 if !isValidGitSHA(*commit) {
944 return "", fmt.Errorf("invalid git commit SHA format: %s", *commit)
945 }
946
947 // Get the diff for just this commit
948 cmd := exec.CommandContext(ctx, "git", "show", "--unified=10", *commit)
949 cmd.Dir = a.repoRoot
950 output, err := cmd.CombinedOutput()
951 if err != nil {
952 return "", fmt.Errorf("failed to get diff for commit %s: %w - %s", *commit, err, string(output))
953 }
954 return string(output), nil
955 }
956
957 // Otherwise, get the diff between the initial commit and the current state using exec.Command
958 cmd := exec.CommandContext(ctx, "git", "diff", "--unified=10", a.initialCommit)
959 cmd.Dir = a.repoRoot
960 output, err := cmd.CombinedOutput()
961 if err != nil {
962 return "", fmt.Errorf("failed to get diff: %w - %s", err, string(output))
963 }
964
965 return string(output), nil
966}
967
968// InitialCommit returns the Git commit hash that was saved when the agent was instantiated.
969func (a *Agent) InitialCommit() string {
970 return a.initialCommit
971}
972
973// handleGitCommits() highlights new commits to the user. When running
974// under docker, new HEADs are pushed to a branch according to the title.
975func (a *Agent) handleGitCommits(ctx context.Context) ([]*GitCommit, error) {
976 if a.repoRoot == "" {
977 return nil, nil
978 }
979
980 head, err := resolveRef(ctx, a.repoRoot, "HEAD")
981 if err != nil {
982 return nil, err
983 }
984 if head == a.lastHEAD {
985 return nil, nil // nothing to do
986 }
987 defer func() {
988 a.lastHEAD = head
989 }()
990
991 // Get new commits. Because it's possible that the agent does rebases, fixups, and
992 // so forth, we use, as our fixed point, the "initialCommit", and we limit ourselves
993 // to the last 100 commits.
994 var commits []*GitCommit
995
996 // Get commits since the initial commit
997 // Format: <hash>\0<subject>\0<body>\0
998 // This uses NULL bytes as separators to avoid issues with newlines in commit messages
999 // Limit to 100 commits to avoid overwhelming the user
1000 cmd := exec.CommandContext(ctx, "git", "log", "-n", "100", "--pretty=format:%H%x00%s%x00%b%x00", "^"+a.initialCommit, head)
1001 cmd.Dir = a.repoRoot
1002 output, err := cmd.Output()
1003 if err != nil {
1004 return nil, fmt.Errorf("failed to get git log: %w", err)
1005 }
1006
1007 // Parse git log output and filter out already seen commits
1008 parsedCommits := parseGitLog(string(output))
1009
1010 var headCommit *GitCommit
1011
1012 // Filter out commits we've already seen
1013 for _, commit := range parsedCommits {
1014 if commit.Hash == head {
1015 headCommit = &commit
1016 }
1017
1018 // Skip if we've seen this commit before. If our head has changed, always include that.
1019 if a.seenCommits[commit.Hash] && commit.Hash != head {
1020 continue
1021 }
1022
1023 // Mark this commit as seen
1024 a.seenCommits[commit.Hash] = true
1025
1026 // Add to our list of new commits
1027 commits = append(commits, &commit)
1028 }
1029
1030 if a.gitRemoteAddr != "" {
1031 if headCommit == nil {
1032 // I think this can only happen if we have a bug or if there's a race.
1033 headCommit = &GitCommit{}
1034 headCommit.Hash = head
1035 headCommit.Subject = "unknown"
1036 commits = append(commits, headCommit)
1037 }
1038
1039 cleanTitle := titleToBranch(a.title)
1040 if cleanTitle == "" {
1041 cleanTitle = a.config.SessionID
1042 }
1043 branch := "sketch/" + cleanTitle
1044
1045 // TODO: I don't love the force push here. We could see if the push is a fast-forward, and,
1046 // if it's not, we could make a backup with a unique name (perhaps append a timestamp) and
1047 // then use push with lease to replace.
1048 cmd = exec.Command("git", "push", "--force", a.gitRemoteAddr, "HEAD:refs/heads/"+branch)
1049 cmd.Dir = a.workingDir
1050 if out, err := cmd.CombinedOutput(); err != nil {
1051 a.pushToOutbox(ctx, errorMessage(fmt.Errorf("git push to host: %s: %v", out, err)))
1052 } else {
1053 headCommit.PushedBranch = branch
1054 }
1055 }
1056
1057 // If we found new commits, create a message
1058 if len(commits) > 0 {
1059 msg := AgentMessage{
1060 Type: CommitMessageType,
1061 Timestamp: time.Now(),
1062 Commits: commits,
1063 }
1064 a.pushToOutbox(ctx, msg)
1065 }
1066 return commits, nil
1067}
1068
1069func titleToBranch(s string) string {
1070 // Convert to lowercase
1071 s = strings.ToLower(s)
1072
1073 // Replace spaces with hyphens
1074 s = strings.ReplaceAll(s, " ", "-")
1075
1076 // Remove any character that isn't a-z or hyphen
1077 var result strings.Builder
1078 for _, r := range s {
1079 if (r >= 'a' && r <= 'z') || r == '-' {
1080 result.WriteRune(r)
1081 }
1082 }
1083 return result.String()
1084}
1085
1086// parseGitLog parses the output of git log with format '%H%x00%s%x00%b%x00'
1087// and returns an array of GitCommit structs.
1088func parseGitLog(output string) []GitCommit {
1089 var commits []GitCommit
1090
1091 // No output means no commits
1092 if len(output) == 0 {
1093 return commits
1094 }
1095
1096 // Split by NULL byte
1097 parts := strings.Split(output, "\x00")
1098
1099 // Process in triplets (hash, subject, body)
1100 for i := 0; i < len(parts); i++ {
1101 // Skip empty parts
1102 if parts[i] == "" {
1103 continue
1104 }
1105
1106 // This should be a hash
1107 hash := strings.TrimSpace(parts[i])
1108
1109 // Make sure we have at least a subject part available
1110 if i+1 >= len(parts) {
1111 break // No more parts available
1112 }
1113
1114 // Get the subject
1115 subject := strings.TrimSpace(parts[i+1])
1116
1117 // Get the body if available
1118 body := ""
1119 if i+2 < len(parts) {
1120 body = strings.TrimSpace(parts[i+2])
1121 }
1122
1123 // Skip to the next triplet
1124 i += 2
1125
1126 commits = append(commits, GitCommit{
1127 Hash: hash,
1128 Subject: subject,
1129 Body: body,
1130 })
1131 }
1132
1133 return commits
1134}
1135
1136func repoRoot(ctx context.Context, dir string) (string, error) {
1137 cmd := exec.CommandContext(ctx, "git", "rev-parse", "--show-toplevel")
1138 stderr := new(strings.Builder)
1139 cmd.Stderr = stderr
1140 cmd.Dir = dir
1141 out, err := cmd.Output()
1142 if err != nil {
1143 return "", fmt.Errorf("git rev-parse failed: %w\n%s", err, stderr)
1144 }
1145 return strings.TrimSpace(string(out)), nil
1146}
1147
1148func resolveRef(ctx context.Context, dir, refName string) (string, error) {
1149 cmd := exec.CommandContext(ctx, "git", "rev-parse", refName)
1150 stderr := new(strings.Builder)
1151 cmd.Stderr = stderr
1152 cmd.Dir = dir
1153 out, err := cmd.Output()
1154 if err != nil {
1155 return "", fmt.Errorf("git rev-parse failed: %w\n%s", err, stderr)
1156 }
1157 // TODO: validate that out is valid hex
1158 return strings.TrimSpace(string(out)), nil
1159}
1160
1161// isValidGitSHA validates if a string looks like a valid git SHA hash.
1162// Git SHAs are hexadecimal strings of at least 4 characters but typically 7, 8, or 40 characters.
1163func isValidGitSHA(sha string) bool {
1164 // Git SHA must be a hexadecimal string with at least 4 characters
1165 if len(sha) < 4 || len(sha) > 40 {
1166 return false
1167 }
1168
1169 // Check if the string only contains hexadecimal characters
1170 for _, char := range sha {
1171 if !(char >= '0' && char <= '9') && !(char >= 'a' && char <= 'f') && !(char >= 'A' && char <= 'F') {
1172 return false
1173 }
1174 }
1175
1176 return true
1177}
Philip Zeyligerd1402952025-04-23 03:54:37 +00001178
// getGitOrigin returns the URL configured for the git remote "origin" in dir,
// trimmed of surrounding whitespace. It returns "" if the lookup fails for
// any reason, including when no such remote is configured.
func getGitOrigin(ctx context.Context, dir string) string {
	var errOut strings.Builder
	cmd := exec.CommandContext(ctx, "git", "config", "--get", "remote.origin.url")
	cmd.Dir = dir
	cmd.Stderr = &errOut // captured so git's stderr is not inherited; not read
	out, err := cmd.Output()
	if err == nil {
		return strings.TrimSpace(string(out))
	}
	return ""
}
1190}