blob: dda3aa6313749e0e1f2cdf534fad45d110727bef [file] [log] [blame]
Earl Lee2e463fb2025-04-17 11:22:22 -07001package loop
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "net/http"
9 "os"
10 "os/exec"
11 "runtime/debug"
12 "slices"
13 "strings"
14 "sync"
15 "time"
16
17 "sketch.dev/ant"
18 "sketch.dev/claudetool"
19)
20
const (
	// userCancelMessage is pushed to the outbox (as an ErrorMessageType message,
	// with EndOfTurn false) when the user cancels the inner loop while tool calls
	// are in flight.
	userCancelMessage = "user requested agent to stop handling responses"
)
24
// CodingAgent is the interface through which a client drives a coding agent:
// initialize it, feed it user messages, read back its messages, and cancel or
// inspect its work.
type CodingAgent interface {
	// Init initializes the agent. See AgentInit for the options, including
	// preparing a checkout when running inside a docker container.
	Init(AgentInit) error

	// Ready returns a channel that is closed once Init has completed successfully.
	Ready() <-chan struct{}

	// URL reports the HTTP URL of this agent.
	URL() string

	// UserMessage enqueues a message for the agent. It does not wait for the
	// agent to process the message.
	UserMessage(ctx context.Context, msg string)

	// WaitForMessage blocks until the agent has a response to give.
	// Use AgentMessage.EndOfTurn to help determine if you want to
	// drain the agent.
	WaitForMessage(ctx context.Context) AgentMessage

	// Loop runs the agent loop and returns only when ctx is cancelled.
	Loop(ctx context.Context)

	// CancelInnerLoop cancels the turn currently in progress (for example,
	// running tools) with the given cause, without stopping Loop.
	CancelInnerLoop(cause error)

	// CancelToolUse asks the conversation to cancel the tool call with the
	// given tool-use ID.
	CancelToolUse(toolUseID string, cause error) error

	// Messages returns a copy of the agent's message history in the
	// half-open range [start, end).
	Messages(start int, end int) []AgentMessage

	// MessageCount returns the current number of messages in the history.
	MessageCount() int

	// TotalUsage reports the cumulative usage of the underlying conversation.
	TotalUsage() ant.CumulativeUsage
	// OriginalBudget returns the budget the agent was configured with.
	OriginalBudget() ant.Budget

	// WaitForMessageCount returns when the agent has more than greaterThan
	// messages in its history, or when ctx is done.
	WaitForMessageCount(ctx context.Context, greaterThan int)

	// WorkingDir returns the directory the agent operates in.
	WorkingDir() string

	// Diff returns a unified diff of changes made since the agent was instantiated.
	// If commit is non-nil (and non-empty), it shows the diff for just that specific commit.
	Diff(commit *string) (string, error)

	// InitialCommit returns the Git commit hash that was saved when the agent was instantiated.
	InitialCommit() string

	// Title returns the current title of the conversation.
	Title() string

	// OS returns the operating system of the client.
	OS() string
}
77
// CodingAgentMessageType identifies the kind of an AgentMessage.
type CodingAgentMessageType string

const (
	UserMessageType    CodingAgentMessageType = "user"   // echo of a message the user sent (see Agent.UserMessage)
	AgentMessageType   CodingAgentMessageType = "agent"  // model response: text and/or requested tool calls
	ErrorMessageType   CodingAgentMessageType = "error"  // errors, including user cancellation notices
	BudgetMessageType  CodingAgentMessageType = "budget" // dedicated for "out of budget" errors
	ToolUseMessageType CodingAgentMessageType = "tool"   // result of a single tool call
	CommitMessageType  CodingAgentMessageType = "commit" // for displaying git commits
	AutoMessageType    CodingAgentMessageType = "auto"   // for automated notifications like autoformatting

	// cancelToolUseMessage is sent to the model as user text after the user
	// cancels in-flight tool calls.
	cancelToolUseMessage = "Stop responding to my previous message. Wait for me to ask you something else before attempting to use any more tools."
)
91
// AgentMessage is one entry in the agent's history and outbox. Type says what
// kind of entry it is (user text, model response, tool result, error, budget
// notice, commit list, or automated notification); only the fields relevant
// to that kind are populated.
type AgentMessage struct {
	Type CodingAgentMessageType `json:"type"`
	// EndOfTurn indicates that the AI is done working and is ready for the next user input.
	EndOfTurn bool `json:"end_of_turn"`

	// Content is the human-readable text of the message. The Tool* fields
	// describe a single tool call and its result (see Agent.OnToolResult).
	Content    string `json:"content"`
	ToolName   string `json:"tool_name,omitempty"`
	ToolInput  string `json:"input,omitempty"`
	ToolResult string `json:"tool_result,omitempty"`
	ToolError  bool   `json:"tool_error,omitempty"`
	ToolCallId string `json:"tool_call_id,omitempty"`

	// ToolCalls is a list of all tool calls requested in this message (name and input pairs)
	ToolCalls []ToolCall `json:"tool_calls,omitempty"`

	// ToolResponses is a list of responses to the tool calls requested in this
	// message. It is not populated anywhere in this file.
	// NOTE(review): the JSON key is camelCase, unlike most keys here; presumably
	// clients depend on it, so keep it unless all consumers change too.
	ToolResponses []AgentMessage `json:"toolResponses,omitempty"`

	// Commits is a list of git commits for a commit message
	Commits []*GitCommit `json:"commits,omitempty"`

	Timestamp            time.Time  `json:"timestamp"`
	ConversationID       string     `json:"conversation_id"`
	ParentConversationID *string    `json:"parent_conversation_id,omitempty"`
	Usage                *ant.Usage `json:"usage,omitempty"`

	// Message timing information
	StartTime *time.Time     `json:"start_time,omitempty"`
	EndTime   *time.Time     `json:"end_time,omitempty"`
	Elapsed   *time.Duration `json:"elapsed,omitempty"`

	// Turn duration - the time taken for a complete agent turn
	TurnDuration *time.Duration `json:"turnDuration,omitempty"`

	// Idx is the message's position in the agent history, assigned by pushToOutbox.
	Idx int `json:"idx"`
}
128
// GitCommit represents a single git commit for a commit message.
// PushedBranch is filled in by handleGitCommits after a successful push.
type GitCommit struct {
	Hash         string `json:"hash"`                    // Full commit hash
	Subject      string `json:"subject"`                 // Commit subject line
	Body         string `json:"body"`                    // Full commit message body
	PushedBranch string `json:"pushed_branch,omitempty"` // If set, this commit was pushed to this branch
}
136
// ToolCall represents a single tool call within an agent message.
type ToolCall struct {
	// Name, Input (the raw JSON tool input) and ToolCallId are filled in by
	// OnResponse from the model's tool_use content.
	Name       string `json:"name"`
	Input      string `json:"input"`
	ToolCallId string `json:"tool_call_id"`

	// The remaining fields are not set anywhere in this file; presumably they
	// are populated by other consumers — TODO confirm.
	ResultMessage *AgentMessage `json:"result_message,omitempty"`
	Args          string        `json:"args,omitempty"`
	Result        string        `json:"result,omitempty"`
}
146
147func (a *AgentMessage) Attr() slog.Attr {
148 var attrs []any = []any{
149 slog.String("type", string(a.Type)),
150 }
151 if a.EndOfTurn {
152 attrs = append(attrs, slog.Bool("end_of_turn", a.EndOfTurn))
153 }
154 if a.Content != "" {
155 attrs = append(attrs, slog.String("content", a.Content))
156 }
157 if a.ToolName != "" {
158 attrs = append(attrs, slog.String("tool_name", a.ToolName))
159 }
160 if a.ToolInput != "" {
161 attrs = append(attrs, slog.String("tool_input", a.ToolInput))
162 }
163 if a.Elapsed != nil {
164 attrs = append(attrs, slog.Int64("elapsed", a.Elapsed.Nanoseconds()))
165 }
166 if a.TurnDuration != nil {
167 attrs = append(attrs, slog.Int64("turnDuration", a.TurnDuration.Nanoseconds()))
168 }
169 if a.ToolResult != "" {
170 attrs = append(attrs, slog.String("tool_result", a.ToolResult))
171 }
172 if a.ToolError {
173 attrs = append(attrs, slog.Bool("tool_error", a.ToolError))
174 }
175 if len(a.ToolCalls) > 0 {
176 toolCallAttrs := make([]any, 0, len(a.ToolCalls))
177 for i, tc := range a.ToolCalls {
178 toolCallAttrs = append(toolCallAttrs, slog.Group(
179 fmt.Sprintf("tool_call_%d", i),
180 slog.String("name", tc.Name),
181 slog.String("input", tc.Input),
182 ))
183 }
184 attrs = append(attrs, slog.Group("tool_calls", toolCallAttrs...))
185 }
186 if a.ConversationID != "" {
187 attrs = append(attrs, slog.String("convo_id", a.ConversationID))
188 }
189 if a.ParentConversationID != nil {
190 attrs = append(attrs, slog.String("parent_convo_id", *a.ParentConversationID))
191 }
192 if a.Usage != nil && !a.Usage.IsZero() {
193 attrs = append(attrs, a.Usage.Attr())
194 }
195 // TODO: timestamp, convo ids, idx?
196 return slog.Group("agent_message", attrs...)
197}
198
199func errorMessage(err error) AgentMessage {
200 // It's somewhat unknowable whether error messages are "end of turn" or not, but it seems like the best approach.
201 if os.Getenv(("DEBUG")) == "1" {
202 return AgentMessage{Type: ErrorMessageType, Content: err.Error() + " Stacktrace: " + string(debug.Stack()), EndOfTurn: true}
203 }
204
205 return AgentMessage{Type: ErrorMessageType, Content: err.Error(), EndOfTurn: true}
206}
207
208func budgetMessage(err error) AgentMessage {
209 return AgentMessage{Type: BudgetMessageType, Content: err.Error(), EndOfTurn: true}
210}
211
// ConvoInterface defines the interface for conversation interactions.
// It is the subset of conversation behavior the Agent relies on; Init stores
// the *ant.Convo built by initConvo in a field of this type.
type ConvoInterface interface {
	// CumulativeUsage reports usage accumulated so far (see Agent.TotalUsage).
	CumulativeUsage() ant.CumulativeUsage
	// ResetBudget replaces the conversation's budget (see Agent.overBudget).
	ResetBudget(ant.Budget)
	// OverBudget returns a non-nil error once the budget has been exceeded.
	OverBudget() error
	// SendMessage sends message to the model and returns its response.
	SendMessage(message ant.Message) (*ant.MessageResponse, error)
	// SendUserTextMessage presumably sends s as user text plus any extra
	// contents; it is not called in this file.
	SendUserTextMessage(s string, otherContents ...ant.Content) (*ant.MessageResponse, error)
	// ToolResultContents runs the tools requested in resp and returns their results.
	ToolResultContents(ctx context.Context, resp *ant.MessageResponse) ([]ant.Content, error)
	// ToolResultCancelContents builds results for the tools requested in resp
	// without running them, telling the model the user cancelled them.
	ToolResultCancelContents(resp *ant.MessageResponse) ([]ant.Content, error)
	// CancelToolUse cancels the in-flight tool call with the given tool-use ID.
	CancelToolUse(toolUseID string, cause error) error
}
223
// Agent implements CodingAgent on top of a single model conversation. User
// input arrives on inbox. Everything shown to the client is recorded in
// history and sent on outbox. Git commits made during a turn are reported,
// and are pushed to the host when a git remote is configured.
type Agent struct {
	convo          ConvoInterface
	config         AgentConfig // config for this agent
	workingDir     string
	repoRoot       string // workingDir may be a subdir of repoRoot
	url            string // set by Init only under docker with a HostAddr
	lastHEAD       string // hash of the HEAD last handled by handleGitCommits (pushed to the host when gitRemoteAddr is set)
	initialCommit  string // hash of the Git HEAD when the agent was instantiated or Init()
	gitRemoteAddr  string // HTTP URL of the host git repo (only when under docker)
	ready          chan struct{} // closed when Init completes successfully
	startedAt      time.Time
	originalBudget ant.Budget
	title          string                   // guarded by mu (see Title and SetTitle)
	codereview     *claudetool.CodeReviewer // left nil by Init when NoGit is set

	// Time when the current turn started (reset at the beginning of InnerLoop)
	startOfTurn time.Time

	// inbox carries messages from the user to the agent. It is sent on by
	// UserMessage (e.g. when the user types into the chat textarea) and read by
	// GatherMessages.
	inbox chan string

	// outbox carries every message shown to the client. It is sent on only by
	// pushToOutbox. pushToOutbox is called from UserMessage, from the
	// OnToolResult/OnResponse callbacks, and from the agent loop. outbox is read
	// by WaitForMessage, e.g. from termui's repl loop.
	outbox chan AgentMessage

	// protects cancelInnerLoop
	cancelInnerLoopMu sync.Mutex
	// cancels potentially long-running tool_use calls or chains of them
	cancelInnerLoop context.CancelCauseFunc

	// protects following (and title, above)
	mu sync.Mutex

	// Stores all messages for this agent
	history []AgentMessage

	// listeners are channels registered by WaitForMessageCount. pushToOutbox
	// and SetTitle close and clear them to wake the waiters.
	listeners []chan struct{}

	// Track git commits we've already seen (by hash)
	// NOTE(review): handleGitCommits reads and writes this without holding mu;
	// it appears to rely on being called only from the InnerLoop goroutine — confirm.
	seenCommits map[string]bool
}
271
272func (a *Agent) URL() string { return a.url }
273
274// Title returns the current title of the conversation.
275// If no title has been set, returns an empty string.
276func (a *Agent) Title() string {
277 a.mu.Lock()
278 defer a.mu.Unlock()
279 return a.title
280}
281
282// OS returns the operating system of the client.
283func (a *Agent) OS() string {
284 return a.config.ClientGOOS
285}
286
287// SetTitle sets the title of the conversation.
288func (a *Agent) SetTitle(title string) {
289 a.mu.Lock()
290 defer a.mu.Unlock()
291 a.title = title
292 // Notify all listeners that the state has changed
293 for _, ch := range a.listeners {
294 close(ch)
295 }
296 a.listeners = a.listeners[:0]
297}
298
299// OnToolResult implements ant.Listener.
300func (a *Agent) OnToolResult(ctx context.Context, convo *ant.Convo, toolName string, toolInput json.RawMessage, content ant.Content, result *string, err error) {
301 m := AgentMessage{
302 Type: ToolUseMessageType,
303 Content: content.Text,
304 ToolResult: content.ToolResult,
305 ToolError: content.ToolError,
306 ToolName: toolName,
307 ToolInput: string(toolInput),
308 ToolCallId: content.ToolUseID,
309 StartTime: content.StartTime,
310 EndTime: content.EndTime,
311 }
312
313 // Calculate the elapsed time if both start and end times are set
314 if content.StartTime != nil && content.EndTime != nil {
315 elapsed := content.EndTime.Sub(*content.StartTime)
316 m.Elapsed = &elapsed
317 }
318
319 m.ConversationID = convo.ID
320 if convo.Parent != nil {
321 m.ParentConversationID = &convo.Parent.ID
322 }
323 a.pushToOutbox(ctx, m)
324}
325
326// OnRequest implements ant.Listener.
327func (a *Agent) OnRequest(ctx context.Context, convo *ant.Convo, msg *ant.Message) {
328 // No-op.
329 // We already get tool results from the above. We send user messages to the outbox in the agent loop.
330}
331
332// OnResponse implements ant.Listener. Responses contain messages from the LLM
333// that need to be displayed (as well as tool calls that we send along when
334// they're done). (It would be reasonable to also mention tool calls when they're
335// started, but we don't do that yet.)
336func (a *Agent) OnResponse(ctx context.Context, convo *ant.Convo, resp *ant.MessageResponse) {
337 endOfTurn := false
338 if resp.StopReason != ant.StopReasonToolUse {
339 endOfTurn = true
340 }
341 m := AgentMessage{
342 Type: AgentMessageType,
343 Content: collectTextContent(resp),
344 EndOfTurn: endOfTurn,
345 Usage: &resp.Usage,
346 StartTime: resp.StartTime,
347 EndTime: resp.EndTime,
348 }
349
350 // Extract any tool calls from the response
351 if resp.StopReason == ant.StopReasonToolUse {
352 var toolCalls []ToolCall
353 for _, part := range resp.Content {
354 if part.Type == "tool_use" {
355 toolCalls = append(toolCalls, ToolCall{
356 Name: part.ToolName,
357 Input: string(part.ToolInput),
358 ToolCallId: part.ID,
359 })
360 }
361 }
362 m.ToolCalls = toolCalls
363 }
364
365 // Calculate the elapsed time if both start and end times are set
366 if resp.StartTime != nil && resp.EndTime != nil {
367 elapsed := resp.EndTime.Sub(*resp.StartTime)
368 m.Elapsed = &elapsed
369 }
370
371 m.ConversationID = convo.ID
372 if convo.Parent != nil {
373 m.ParentConversationID = &convo.Parent.ID
374 }
375 a.pushToOutbox(ctx, m)
376}
377
378// WorkingDir implements CodingAgent.
379func (a *Agent) WorkingDir() string {
380 return a.workingDir
381}
382
383// MessageCount implements CodingAgent.
384func (a *Agent) MessageCount() int {
385 a.mu.Lock()
386 defer a.mu.Unlock()
387 return len(a.history)
388}
389
390// Messages implements CodingAgent.
391func (a *Agent) Messages(start int, end int) []AgentMessage {
392 a.mu.Lock()
393 defer a.mu.Unlock()
394 return slices.Clone(a.history[start:end])
395}
396
397func (a *Agent) OriginalBudget() ant.Budget {
398 return a.originalBudget
399}
400
// AgentConfig contains configuration for creating a new Agent.
type AgentConfig struct {
	// Context is used for the git commands run by Init and is passed to the
	// conversation created by initConvo.
	// NOTE(review): storing a Context in a struct is unidiomatic; kept for compatibility.
	Context context.Context
	// AntURL, if non-empty, overrides the conversation's API URL.
	AntURL string
	// APIKey is passed to ant.NewConvo.
	APIKey string
	// HTTPC, if non-nil, overrides the conversation's HTTP client.
	HTTPC *http.Client
	// Budget is the conversation's initial budget and the value it is reset to
	// after being exceeded.
	Budget ant.Budget
	// GitUsername and GitEmail are passed to the done tool.
	GitUsername string
	GitEmail    string
	// SessionID names the pushed branch (sketch/<SessionID>) when the title
	// does not yield a usable branch name.
	SessionID string
	// ClientGOOS and ClientGOARCH describe the client's platform. Both are
	// embedded in the system prompt, and ClientGOOS is also returned by OS.
	ClientGOOS   string
	ClientGOARCH string
	// UseAnthropicEdit selects the str_replace_editor edit tool instead of the patch tool.
	UseAnthropicEdit bool
}
415
416// NewAgent creates a new Agent.
417// It is not usable until Init() is called.
418func NewAgent(config AgentConfig) *Agent {
419 agent := &Agent{
420 config: config,
421 ready: make(chan struct{}),
422 inbox: make(chan string, 100),
423 outbox: make(chan AgentMessage, 100),
424 startedAt: time.Now(),
425 originalBudget: config.Budget,
426 seenCommits: make(map[string]bool),
427 }
428 return agent
429}
430
// AgentInit holds the parameters for Agent.Init.
type AgentInit struct {
	// WorkingDir is the directory the agent operates in.
	WorkingDir string
	NoGit      bool // only for testing; skips repo discovery and code-review setup

	// InDocker makes Init stash local changes, fetch from GitRemoteAddr and
	// force-checkout Commit in WorkingDir before anything else.
	InDocker      bool
	Commit        string // commit to check out when InDocker
	GitRemoteAddr string // remote fetched from by Init and pushed to by handleGitCommits
	HostAddr      string // if set (and InDocker), the agent URL becomes "http://"+HostAddr
}
440
441func (a *Agent) Init(ini AgentInit) error {
442 ctx := a.config.Context
443 if ini.InDocker {
444 cmd := exec.CommandContext(ctx, "git", "stash")
445 cmd.Dir = ini.WorkingDir
446 if out, err := cmd.CombinedOutput(); err != nil {
447 return fmt.Errorf("git stash: %s: %v", out, err)
448 }
449 cmd = exec.CommandContext(ctx, "git", "fetch", ini.GitRemoteAddr)
450 cmd.Dir = ini.WorkingDir
451 if out, err := cmd.CombinedOutput(); err != nil {
452 return fmt.Errorf("git fetch: %s: %w", out, err)
453 }
454 cmd = exec.CommandContext(ctx, "git", "checkout", "-f", ini.Commit)
455 cmd.Dir = ini.WorkingDir
456 if out, err := cmd.CombinedOutput(); err != nil {
457 return fmt.Errorf("git checkout %s: %s: %w", ini.Commit, out, err)
458 }
459 a.lastHEAD = ini.Commit
460 a.gitRemoteAddr = ini.GitRemoteAddr
461 a.initialCommit = ini.Commit
462 if ini.HostAddr != "" {
463 a.url = "http://" + ini.HostAddr
464 }
465 }
466 a.workingDir = ini.WorkingDir
467
468 if !ini.NoGit {
469 repoRoot, err := repoRoot(ctx, a.workingDir)
470 if err != nil {
471 return fmt.Errorf("repoRoot: %w", err)
472 }
473 a.repoRoot = repoRoot
474
475 commitHash, err := resolveRef(ctx, a.repoRoot, "HEAD")
476 if err != nil {
477 return fmt.Errorf("resolveRef: %w", err)
478 }
479 a.initialCommit = commitHash
480
481 codereview, err := claudetool.NewCodeReviewer(ctx, a.repoRoot, a.initialCommit)
482 if err != nil {
483 return fmt.Errorf("Agent.Init: claudetool.NewCodeReviewer: %w", err)
484 }
485 a.codereview = codereview
486 }
487 a.lastHEAD = a.initialCommit
488 a.convo = a.initConvo()
489 close(a.ready)
490 return nil
491}
492
// initConvo builds the model conversation for this agent. It:
//   - applies the configured HTTP client, API URL and budget;
//   - enables prompt caching;
//   - renders the system prompt (edit instructions, client platform, working
//     directory and git root);
//   - registers the tools;
//   - installs the agent as the conversation's listener.
//
// It must not be called until all agent fields are initialized,
// particularly workingDir and git.
//
// NOTE(review): Init only sets a.codereview when NoGit is false, yet
// a.codereview.Tool() and makeDoneTool(a.codereview, ...) are used
// unconditionally below. Whether that is safe with a nil *CodeReviewer
// depends on claudetool — confirm.
func (a *Agent) initConvo() *ant.Convo {
	ctx := a.config.Context
	convo := ant.NewConvo(ctx, a.config.APIKey)
	if a.config.HTTPC != nil {
		convo.HTTPC = a.config.HTTPC
	}
	if a.config.AntURL != "" {
		convo.URL = a.config.AntURL
	}
	convo.PromptCaching = true
	convo.Budget = a.config.Budget

	// The edit instructions in the prompt must match the edit tool registered below.
	var editPrompt string
	if a.config.UseAnthropicEdit {
		editPrompt = "Then use the str_replace_editor tool to make those edits. For short complete file replacements, you may use the bash tool with cat and heredoc stdin."
	} else {
		editPrompt = "Then use the patch tool to make those edits. Combine all edits to any given file into a single patch tool call."
	}

	convo.SystemPrompt = fmt.Sprintf(`
You are an expert coding assistant and architect, with a specialty in Go.
You are assisting the user to achieve their goals.

Start by asking concise clarifying questions as needed.
Once the intent is clear, work autonomously.

Call the title tool early in the conversation to provide a brief summary of
what the chat is about.

Break down the overall goal into a series of smaller steps.
(The first step is often: "Make a plan.")
Then execute each step using tools.
Update the plan if you have encountered problems or learned new information.

When in doubt about a step, follow this broad workflow:

- Think about how the current step fits into the overall plan.
- Do research. Good tool choices: bash, think, keyword_search
- Make edits.
- Repeat.

To make edits reliably and efficiently, first think about the intent of the edit,
and what set of patches will achieve that intent.
%s

For renames or refactors, consider invoking gopls (via bash).

The done tool provides a checklist of items you MUST verify and
review before declaring that you are done. Before executing
the done tool, run all the tools the done tool checklist asks
for, including creating a git commit. Do not forget to run tests.

<platform>
%s/%s
</platform>
<pwd>
%v
</pwd>
<git_root>
%v
</git_root>
`, editPrompt, a.config.ClientGOOS, a.config.ClientGOARCH, a.workingDir, a.repoRoot)

	// Register all tools with the conversation
	// When adding, removing, or modifying tools here, double-check that the termui tool display
	// template in termui/termui.go has pretty-printing support for all tools.
	convo.Tools = []*ant.Tool{
		claudetool.Bash, claudetool.Keyword,
		claudetool.Think, a.titleTool(), makeDoneTool(a.codereview, a.config.GitUsername, a.config.GitEmail),
		a.codereview.Tool(),
	}
	if a.config.UseAnthropicEdit {
		convo.Tools = append(convo.Tools, claudetool.AnthropicEditTool)
	} else {
		convo.Tools = append(convo.Tools, claudetool.Patch)
	}
	// The agent receives OnRequest/OnResponse/OnToolResult callbacks.
	convo.Listener = a
	return convo
}
575
576func (a *Agent) titleTool() *ant.Tool {
577 // titleTool creates the title tool that sets the conversation title.
578 title := &ant.Tool{
579 Name: "title",
580 Description: `Use this tool early in the conversation, BEFORE MAKING ANY GIT COMMITS, to summarize what the chat is about briefly.`,
581 InputSchema: json.RawMessage(`{
582 "type": "object",
583 "properties": {
584 "title": {
585 "type": "string",
586 "description": "A brief title summarizing what this chat is about"
587 }
588 },
589 "required": ["title"]
590}`),
591 Run: func(ctx context.Context, input json.RawMessage) (string, error) {
592 var params struct {
593 Title string `json:"title"`
594 }
595 if err := json.Unmarshal(input, &params); err != nil {
596 return "", err
597 }
598 a.SetTitle(params.Title)
599 return fmt.Sprintf("Title set to: %s", params.Title), nil
600 },
601 }
602 return title
603}
604
605func (a *Agent) Ready() <-chan struct{} {
606 return a.ready
607}
608
609func (a *Agent) UserMessage(ctx context.Context, msg string) {
610 a.pushToOutbox(ctx, AgentMessage{Type: UserMessageType, Content: msg})
611 a.inbox <- msg
612}
613
614func (a *Agent) WaitForMessage(ctx context.Context) AgentMessage {
615 // TODO: Should this drain any outbox messages in case there are multiple?
616 select {
617 case msg := <-a.outbox:
618 return msg
619 case <-ctx.Done():
620 return errorMessage(ctx.Err())
621 }
622}
623
624func (a *Agent) CancelToolUse(toolUseID string, cause error) error {
625 return a.convo.CancelToolUse(toolUseID, cause)
626}
627
628func (a *Agent) CancelInnerLoop(cause error) {
629 a.cancelInnerLoopMu.Lock()
630 defer a.cancelInnerLoopMu.Unlock()
631 if a.cancelInnerLoop != nil {
632 a.cancelInnerLoop(cause)
633 }
634}
635
636func (a *Agent) Loop(ctxOuter context.Context) {
637 for {
638 select {
639 case <-ctxOuter.Done():
640 return
641 default:
642 ctxInner, cancel := context.WithCancelCause(ctxOuter)
643 a.cancelInnerLoopMu.Lock()
644 // Set .cancelInnerLoop so the user can cancel whatever is happening
645 // inside InnerLoop(ctxInner) without canceling this outer Loop execution.
646 // This CancelInnerLoop func is intended be called from other goroutines,
647 // hence the mutex.
648 a.cancelInnerLoop = cancel
649 a.cancelInnerLoopMu.Unlock()
650 a.InnerLoop(ctxInner)
651 cancel(nil)
652 }
653 }
654}
655
656func (a *Agent) pushToOutbox(ctx context.Context, m AgentMessage) {
657 if m.Timestamp.IsZero() {
658 m.Timestamp = time.Now()
659 }
660
661 // If this is an end-of-turn message, calculate the turn duration and add it to the message
662 if m.EndOfTurn && m.Type == AgentMessageType {
663 turnDuration := time.Since(a.startOfTurn)
664 m.TurnDuration = &turnDuration
665 slog.InfoContext(ctx, "Turn completed", "turnDuration", turnDuration)
666 }
667
668 slog.InfoContext(ctx, "agent message", m.Attr())
669
670 a.mu.Lock()
671 defer a.mu.Unlock()
672 m.Idx = len(a.history)
673 a.history = append(a.history, m)
674 a.outbox <- m
675
676 // Notify all listeners:
677 for _, ch := range a.listeners {
678 close(ch)
679 }
680 a.listeners = a.listeners[:0]
681}
682
683func (a *Agent) GatherMessages(ctx context.Context, block bool) ([]ant.Content, error) {
684 var m []ant.Content
685 if block {
686 select {
687 case <-ctx.Done():
688 return m, ctx.Err()
689 case msg := <-a.inbox:
690 m = append(m, ant.Content{Type: "text", Text: msg})
691 }
692 }
693 for {
694 select {
695 case msg := <-a.inbox:
696 m = append(m, ant.Content{Type: "text", Text: msg})
697 default:
698 return m, nil
699 }
700 }
701}
702
// InnerLoop runs a single user turn:
//  1. wait for at least one user message;
//  2. send it to the model;
//  3. while the model keeps stopping for tool use, run the requested tools
//     (or, if ctx was cancelled, tell the model they were cancelled);
//  4. report new git commits and autoformat results;
//  5. send the results plus any newly queued user messages back to the model.
//
// The turn ends when:
//   - the model stops asking for tools;
//   - the budget is exceeded;
//   - sending fails;
//   - a cancellation has been acknowledged by the model.
//
// New git commits are also checked once more when the turn returns.
func (a *Agent) InnerLoop(ctx context.Context) {
	// Reset the start of turn time
	a.startOfTurn = time.Now()

	// Wait for at least one message from the user.
	msgs, err := a.GatherMessages(ctx, true)
	if err != nil { // e.g. the context was canceled while blocking in GatherMessages
		return
	}
	// We do this as we go, but let's also do it at the end of the turn
	defer func() {
		if _, err := a.handleGitCommits(ctx); err != nil {
			// Just log the error, don't stop execution
			slog.WarnContext(ctx, "Failed to check for new git commits", "error", err)
		}
	}()

	userMessage := ant.Message{
		Role:    "user",
		Content: msgs,
	}
	// convo.SendMessage does the actual network call to send this to anthropic. This blocks until the response is ready.
	// TODO: pass ctx to SendMessage, and figure out how to square that ctx with convo's own .Ctx. Who owns the scope of this call?
	resp, err := a.convo.SendMessage(userMessage)
	if err != nil {
		a.pushToOutbox(ctx, errorMessage(err))
		return
	}
	for {
		// TODO: here and below where we check the budget,
		// we should review the UX: is it clear what happened?
		// is it clear how to resume?
		// should we let the user set a new budget?
		if err := a.overBudget(ctx); err != nil {
			return
		}
		if resp.StopReason != ant.StopReasonToolUse {
			break
		}
		var results []ant.Content
		cancelled := false
		select {
		case <-ctx.Done():
			// Don't actually run any of the tools, but rather build a response
			// for each tool_use message letting the LLM know that user canceled it.
			results, err = a.convo.ToolResultCancelContents(resp)
			if err != nil {
				a.pushToOutbox(ctx, errorMessage(err))
			}
			cancelled = true
		default:
			// NOTE(review): this re-wraps ctx on every tool round, so the context
			// chain grows by one level per iteration.
			ctx = claudetool.WithWorkingDir(ctx, a.workingDir)
			// fall-through, when the user has not canceled the inner loop:
			results, err = a.convo.ToolResultContents(ctx, resp)
			if ctx.Err() != nil { // e.g. the user canceled the operation
				cancelled = true
			} else if err != nil {
				a.pushToOutbox(ctx, errorMessage(err))
			}
		}

		// Check for git commits. Currently we do this here, after we collect
		// tool results, since that's when we know commits could have happened.
		// We could instead do this when the turn ends, but I think it makes sense
		// to do this as we go.
		newCommits, err := a.handleGitCommits(ctx)
		if err != nil {
			// Just log the error, don't stop execution
			slog.WarnContext(ctx, "Failed to check for new git commits", "error", err)
		}
		// Exactly one new commit triggers autoformatting. If files changed, the
		// model is asked (below) to amend that commit.
		var autoqualityMessages []string
		if len(newCommits) == 1 {
			formatted := a.codereview.Autoformat(ctx)
			if len(formatted) > 0 {
				msg := fmt.Sprintf(`
I ran autoformatters and they updated these files:

%s

Please amend your latest git commit with these changes and then continue with what you were doing.`,
					strings.Join(formatted, "\n"),
				)[1:] // drop the leading newline of the raw string
				a.pushToOutbox(ctx, AgentMessage{
					Type:      AutoMessageType,
					Content:   msg,
					Timestamp: time.Now(),
				})
				autoqualityMessages = append(autoqualityMessages, msg)
			}
		}

		if err := a.overBudget(ctx); err != nil {
			return
		}

		// Include, along with the tool results (which must go first for whatever reason),
		// any messages that the user has sent along while the tool_use was executing concurrently.
		msgs, err = a.GatherMessages(ctx, false)
		if err != nil {
			return
		}
		// Inject any auto-generated messages from quality checks.
		for _, msg := range autoqualityMessages {
			msgs = append(msgs, ant.Content{Type: "text", Text: msg})
		}
		if cancelled {
			msgs = append(msgs, ant.Content{Type: "text", Text: cancelToolUseMessage})
			// EndOfTurn is false here so that the client of this agent keeps processing
			// messages from WaitForMessage() and gets the response from the LLM (usually
			// something like "okay, I'll wait further instructions", but the user should
			// be made aware of it regardless).
			a.pushToOutbox(ctx, AgentMessage{Type: ErrorMessageType, Content: userCancelMessage, EndOfTurn: false})
		} else if err := a.convo.OverBudget(); err != nil {
			// NOTE(review): a.overBudget above already returns (after resetting the
			// budget) whenever the conversation is over budget. This branch therefore
			// looks unreachable unless budget state changes in between — confirm.
			budgetMsg := "We've exceeded our budget. Please ask the user to confirm before continuing by ending the turn."
			msgs = append(msgs, ant.Content{Type: "text", Text: budgetMsg})
			a.pushToOutbox(ctx, budgetMessage(fmt.Errorf("warning: %w (ask to keep trying, if you'd like)", err)))
		}
		results = append(results, msgs...)
		resp, err = a.convo.SendMessage(ant.Message{
			Role:    "user",
			Content: results,
		})
		if err != nil {
			a.pushToOutbox(ctx, errorMessage(fmt.Errorf("error: failed to continue conversation: %s", err.Error())))
			break
		}
		// After a cancellation, the model's acknowledgement has been received
		// (and reported via OnResponse), so the turn ends here.
		if cancelled {
			return
		}
	}
}
834
835func (a *Agent) overBudget(ctx context.Context) error {
836 if err := a.convo.OverBudget(); err != nil {
837 m := budgetMessage(err)
838 m.Content = m.Content + "\n\nBudget reset."
839 a.pushToOutbox(ctx, budgetMessage(err))
840 a.convo.ResetBudget(a.originalBudget)
841 return err
842 }
843 return nil
844}
845
846func collectTextContent(msg *ant.MessageResponse) string {
847 // Collect all text content
848 var allText strings.Builder
849 for _, content := range msg.Content {
850 if content.Type == "text" && content.Text != "" {
851 if allText.Len() > 0 {
852 allText.WriteString("\n\n")
853 }
854 allText.WriteString(content.Text)
855 }
856 }
857 return allText.String()
858}
859
860func (a *Agent) TotalUsage() ant.CumulativeUsage {
861 a.mu.Lock()
862 defer a.mu.Unlock()
863 return a.convo.CumulativeUsage()
864}
865
866// WaitForMessageCount returns when the agent has at more than clientMessageCount messages or the context is done.
867func (a *Agent) WaitForMessageCount(ctx context.Context, greaterThan int) {
868 for a.MessageCount() <= greaterThan {
869 a.mu.Lock()
870 ch := make(chan struct{})
871 // Deletion happens when we notify.
872 a.listeners = append(a.listeners, ch)
873 a.mu.Unlock()
874
875 select {
876 case <-ctx.Done():
877 return
878 case <-ch:
879 continue
880 }
881 }
882}
883
884// Diff returns a unified diff of changes made since the agent was instantiated.
885func (a *Agent) Diff(commit *string) (string, error) {
886 if a.initialCommit == "" {
887 return "", fmt.Errorf("no initial commit reference available")
888 }
889
890 // Find the repository root
891 ctx := context.Background()
892
893 // If a specific commit hash is provided, show just that commit's changes
894 if commit != nil && *commit != "" {
895 // Validate that the commit looks like a valid git SHA
896 if !isValidGitSHA(*commit) {
897 return "", fmt.Errorf("invalid git commit SHA format: %s", *commit)
898 }
899
900 // Get the diff for just this commit
901 cmd := exec.CommandContext(ctx, "git", "show", "--unified=10", *commit)
902 cmd.Dir = a.repoRoot
903 output, err := cmd.CombinedOutput()
904 if err != nil {
905 return "", fmt.Errorf("failed to get diff for commit %s: %w - %s", *commit, err, string(output))
906 }
907 return string(output), nil
908 }
909
910 // Otherwise, get the diff between the initial commit and the current state using exec.Command
911 cmd := exec.CommandContext(ctx, "git", "diff", "--unified=10", a.initialCommit)
912 cmd.Dir = a.repoRoot
913 output, err := cmd.CombinedOutput()
914 if err != nil {
915 return "", fmt.Errorf("failed to get diff: %w - %s", err, string(output))
916 }
917
918 return string(output), nil
919}
920
921// InitialCommit returns the Git commit hash that was saved when the agent was instantiated.
922func (a *Agent) InitialCommit() string {
923 return a.initialCommit
924}
925
926// handleGitCommits() highlights new commits to the user. When running
927// under docker, new HEADs are pushed to a branch according to the title.
928func (a *Agent) handleGitCommits(ctx context.Context) ([]*GitCommit, error) {
929 if a.repoRoot == "" {
930 return nil, nil
931 }
932
933 head, err := resolveRef(ctx, a.repoRoot, "HEAD")
934 if err != nil {
935 return nil, err
936 }
937 if head == a.lastHEAD {
938 return nil, nil // nothing to do
939 }
940 defer func() {
941 a.lastHEAD = head
942 }()
943
944 // Get new commits. Because it's possible that the agent does rebases, fixups, and
945 // so forth, we use, as our fixed point, the "initialCommit", and we limit ourselves
946 // to the last 100 commits.
947 var commits []*GitCommit
948
949 // Get commits since the initial commit
950 // Format: <hash>\0<subject>\0<body>\0
951 // This uses NULL bytes as separators to avoid issues with newlines in commit messages
952 // Limit to 100 commits to avoid overwhelming the user
953 cmd := exec.CommandContext(ctx, "git", "log", "-n", "100", "--pretty=format:%H%x00%s%x00%b%x00", "^"+a.initialCommit, head)
954 cmd.Dir = a.repoRoot
955 output, err := cmd.Output()
956 if err != nil {
957 return nil, fmt.Errorf("failed to get git log: %w", err)
958 }
959
960 // Parse git log output and filter out already seen commits
961 parsedCommits := parseGitLog(string(output))
962
963 var headCommit *GitCommit
964
965 // Filter out commits we've already seen
966 for _, commit := range parsedCommits {
967 if commit.Hash == head {
968 headCommit = &commit
969 }
970
971 // Skip if we've seen this commit before. If our head has changed, always include that.
972 if a.seenCommits[commit.Hash] && commit.Hash != head {
973 continue
974 }
975
976 // Mark this commit as seen
977 a.seenCommits[commit.Hash] = true
978
979 // Add to our list of new commits
980 commits = append(commits, &commit)
981 }
982
983 if a.gitRemoteAddr != "" {
984 if headCommit == nil {
985 // I think this can only happen if we have a bug or if there's a race.
986 headCommit = &GitCommit{}
987 headCommit.Hash = head
988 headCommit.Subject = "unknown"
989 commits = append(commits, headCommit)
990 }
991
992 cleanTitle := titleToBranch(a.title)
993 if cleanTitle == "" {
994 cleanTitle = a.config.SessionID
995 }
996 branch := "sketch/" + cleanTitle
997
998 // TODO: I don't love the force push here. We could see if the push is a fast-forward, and,
999 // if it's not, we could make a backup with a unique name (perhaps append a timestamp) and
1000 // then use push with lease to replace.
1001 cmd = exec.Command("git", "push", "--force", a.gitRemoteAddr, "HEAD:refs/heads/"+branch)
1002 cmd.Dir = a.workingDir
1003 if out, err := cmd.CombinedOutput(); err != nil {
1004 a.pushToOutbox(ctx, errorMessage(fmt.Errorf("git push to host: %s: %v", out, err)))
1005 } else {
1006 headCommit.PushedBranch = branch
1007 }
1008 }
1009
1010 // If we found new commits, create a message
1011 if len(commits) > 0 {
1012 msg := AgentMessage{
1013 Type: CommitMessageType,
1014 Timestamp: time.Now(),
1015 Commits: commits,
1016 }
1017 a.pushToOutbox(ctx, msg)
1018 }
1019 return commits, nil
1020}
1021
// titleToBranch derives a branch-name component from a free-form title.
// The title is lowercased and each space becomes a hyphen. Every remaining
// rune outside 'a'-'z' and '-' (digits, punctuation, non-ASCII letters) is
// dropped. The result may be empty, e.g. for a title with no ASCII letters.
func titleToBranch(s string) string {
	keep := func(r rune) rune {
		switch {
		case r == ' ':
			return '-'
		case r == '-', 'a' <= r && r <= 'z':
			return r
		default:
			return -1 // a negative result tells strings.Map to drop the rune
		}
	}
	return strings.Map(keep, strings.ToLower(s))
}
1038
1039// parseGitLog parses the output of git log with format '%H%x00%s%x00%b%x00'
1040// and returns an array of GitCommit structs.
1041func parseGitLog(output string) []GitCommit {
1042 var commits []GitCommit
1043
1044 // No output means no commits
1045 if len(output) == 0 {
1046 return commits
1047 }
1048
1049 // Split by NULL byte
1050 parts := strings.Split(output, "\x00")
1051
1052 // Process in triplets (hash, subject, body)
1053 for i := 0; i < len(parts); i++ {
1054 // Skip empty parts
1055 if parts[i] == "" {
1056 continue
1057 }
1058
1059 // This should be a hash
1060 hash := strings.TrimSpace(parts[i])
1061
1062 // Make sure we have at least a subject part available
1063 if i+1 >= len(parts) {
1064 break // No more parts available
1065 }
1066
1067 // Get the subject
1068 subject := strings.TrimSpace(parts[i+1])
1069
1070 // Get the body if available
1071 body := ""
1072 if i+2 < len(parts) {
1073 body = strings.TrimSpace(parts[i+2])
1074 }
1075
1076 // Skip to the next triplet
1077 i += 2
1078
1079 commits = append(commits, GitCommit{
1080 Hash: hash,
1081 Subject: subject,
1082 Body: body,
1083 })
1084 }
1085
1086 return commits
1087}
1088
// repoRoot returns the top-level directory of the git work tree that contains
// dir, as printed by `git rev-parse --show-toplevel` with surrounding
// whitespace removed. On failure the returned error wraps the exec error and
// includes whatever git wrote to stderr.
func repoRoot(ctx context.Context, dir string) (string, error) {
	var errBuf strings.Builder
	cmd := exec.CommandContext(ctx, "git", "rev-parse", "--show-toplevel")
	cmd.Dir = dir
	cmd.Stderr = &errBuf

	out, err := cmd.Output()
	if err != nil {
		return "", fmt.Errorf("git rev-parse failed: %w\n%s", err, errBuf.String())
	}
	return strings.TrimSpace(string(out)), nil
}
1100
// resolveRef runs `git rev-parse <refName>` in dir and returns its output
// with surrounding whitespace removed. On failure the returned error wraps
// the exec error and includes whatever git wrote to stderr.
func resolveRef(ctx context.Context, dir, refName string) (string, error) {
	errOut := &strings.Builder{}
	cmd := exec.CommandContext(ctx, "git", "rev-parse", refName)
	cmd.Dir, cmd.Stderr = dir, errOut

	raw, runErr := cmd.Output()
	if runErr != nil {
		return "", fmt.Errorf("git rev-parse failed: %w\n%s", runErr, errOut)
	}
	// TODO: validate that the output is a hex object name.
	sha := strings.TrimSpace(string(raw))
	return sha, nil
}
1113
// isValidGitSHA reports whether sha looks like a git object name: between 4
// and 40 characters (inclusive), all hexadecimal digits in either case.
// Abbreviated hashes are accepted. Strings longer than 40 characters are
// rejected, so full 64-character SHA-256 object names do not validate.
func isValidGitSHA(sha string) bool {
	if n := len(sha); n < 4 || n > 40 {
		return false
	}

	// Checking bytes is equivalent to checking runes here: every byte of a
	// multi-byte UTF-8 sequence is >= 0x80 and so fails the hex test too.
	for i := 0; i < len(sha); i++ {
		switch c := sha[i]; {
		case '0' <= c && c <= '9', 'a' <= c && c <= 'f', 'A' <= c && c <= 'F':
			// hex digit; keep scanning
		default:
			return false
		}
	}
	return true
}