blob: ce362e6717615c0ddc957ef5adc8c4e6b5210482 [file] [log] [blame]
Earl Lee2e463fb2025-04-17 11:22:22 -07001package loop
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "net/http"
9 "os"
10 "os/exec"
11 "runtime/debug"
12 "slices"
13 "strings"
14 "sync"
15 "time"
16
17 "sketch.dev/ant"
18 "sketch.dev/claudetool"
19)
20
const (
	// userCancelMessage is the content of the error-typed AgentMessage that
	// InnerLoop pushes to the outbox when tool use was cancelled.
	userCancelMessage = "user requested agent to stop handling responses"
)
24
// CodingAgent is the surface through which front ends drive a coding agent:
// they feed it user messages, read its responses, cancel work in progress,
// and inspect its history, usage and git state.
type CodingAgent interface {
	// Init initializes an agent inside a docker container.
	Init(AgentInit) error

	// Ready returns a channel that is closed after Init has completed successfully.
	Ready() <-chan struct{}

	// URL reports the HTTP URL of this agent.
	URL() string

	// UserMessage enqueues a message to the agent and returns immediately.
	UserMessage(ctx context.Context, msg string)

	// WaitForMessage blocks until the agent has a response to give.
	// Use AgentMessage.EndOfTurn to help determine if you want to
	// drain the agent.
	WaitForMessage(ctx context.Context) AgentMessage

	// Loop begins the agent loop and returns only when ctx is cancelled.
	Loop(ctx context.Context)

	// CancelInnerLoop cancels the work the agent is currently doing for the
	// user, recording cause as the reason, without stopping Loop itself.
	CancelInnerLoop(cause error)

	// CancelToolUse cancels the tool call identified by toolUseID,
	// recording cause as the reason.
	CancelToolUse(toolUseID string, cause error) error

	// Messages returns a subset of the agent's message history.
	Messages(start int, end int) []AgentMessage

	// MessageCount returns the current number of messages in the history.
	MessageCount() int

	// TotalUsage reports the cumulative usage of the agent's conversation.
	TotalUsage() ant.CumulativeUsage
	// OriginalBudget reports the budget the agent was created with.
	OriginalBudget() ant.Budget

	// WaitForMessageCount returns when the agent has more than greaterThan
	// messages or the context is done.
	WaitForMessageCount(ctx context.Context, greaterThan int)

	// WorkingDir returns the directory the agent operates in.
	WorkingDir() string

	// Diff returns a unified diff of changes made since the agent was instantiated.
	// If commit is non-nil, it shows the diff for just that specific commit.
	Diff(commit *string) (string, error)

	// InitialCommit returns the Git commit hash that was saved when the agent was instantiated.
	InitialCommit() string

	// Title returns the current title of the conversation.
	Title() string

	// OS returns the operating system of the client.
	OS() string
}
77
// CodingAgentMessageType labels the kind of an AgentMessage; the string value
// is what appears in the message's JSON "type" field.
type CodingAgentMessageType string

const (
	UserMessageType    CodingAgentMessageType = "user"   // text entered by the user
	AgentMessageType   CodingAgentMessageType = "agent"  // response from the LLM
	ErrorMessageType   CodingAgentMessageType = "error"  // errors, including user cancellation notices
	BudgetMessageType  CodingAgentMessageType = "budget" // dedicated for "out of budget" errors
	ToolUseMessageType CodingAgentMessageType = "tool"   // result of a tool call
	CommitMessageType  CodingAgentMessageType = "commit" // for displaying git commits
	AutoMessageType    CodingAgentMessageType = "auto"   // for automated notifications like autoformatting

	// cancelToolUseMessage is the text sent to the LLM, alongside the tool
	// results, after tool use was cancelled.
	cancelToolUseMessage = "Stop responding to my previous message. Wait for me to ask you something else before attempting to use any more tools."
)
91
// AgentMessage is a single entry in the agent's history and the unit that
// WaitForMessage hands to clients. Which fields are populated depends on Type.
type AgentMessage struct {
	Type CodingAgentMessageType `json:"type"`
	// EndOfTurn indicates that the AI is done working and is ready for the next user input.
	EndOfTurn bool `json:"end_of_turn"`

	// Content is the message text. The Tool* fields below describe a single
	// tool call and its outcome; they are filled in for tool result messages.
	Content    string `json:"content"`
	ToolName   string `json:"tool_name,omitempty"`
	ToolInput  string `json:"input,omitempty"`
	ToolResult string `json:"tool_result,omitempty"`
	ToolError  bool   `json:"tool_error,omitempty"`
	ToolCallId string `json:"tool_call_id,omitempty"`

	// ToolCalls is a list of all tool calls requested in this message (name and input pairs)
	ToolCalls []ToolCall `json:"tool_calls,omitempty"`

	// Commits is a list of git commits for a commit message
	Commits []*GitCommit `json:"commits,omitempty"`

	// Timestamp is set to the current time by pushToOutbox when left zero.
	Timestamp            time.Time  `json:"timestamp"`
	ConversationID       string     `json:"conversation_id"`
	ParentConversationID *string    `json:"parent_conversation_id,omitempty"`
	Usage                *ant.Usage `json:"usage,omitempty"`

	// Message timing information
	StartTime *time.Time     `json:"start_time,omitempty"`
	EndTime   *time.Time     `json:"end_time,omitempty"`
	Elapsed   *time.Duration `json:"elapsed,omitempty"`

	// Turn duration - the time taken for a complete agent turn
	TurnDuration *time.Duration `json:"turnDuration,omitempty"`

	// Idx is the position of this message in the agent's history,
	// assigned when the message is pushed to the outbox.
	Idx int `json:"idx"`
}
125
// GitCommit represents a single git commit for a commit message.
// Commit messages carry a list of these in AgentMessage.Commits.
type GitCommit struct {
	Hash         string `json:"hash"`                    // Full commit hash
	Subject      string `json:"subject"`                 // Commit subject line
	Body         string `json:"body"`                    // Full commit message body
	PushedBranch string `json:"pushed_branch,omitempty"` // If set, this commit was pushed to this branch
}
133
// ToolCall represents a single tool call within an agent message.
// OnResponse builds one per "tool_use" part of an LLM response.
type ToolCall struct {
	Name       string `json:"name"`         // tool name
	Input      string `json:"input"`        // raw tool input, as a string
	ToolCallId string `json:"tool_call_id"` // ID of the tool_use part
}
140
141func (a *AgentMessage) Attr() slog.Attr {
142 var attrs []any = []any{
143 slog.String("type", string(a.Type)),
144 }
145 if a.EndOfTurn {
146 attrs = append(attrs, slog.Bool("end_of_turn", a.EndOfTurn))
147 }
148 if a.Content != "" {
149 attrs = append(attrs, slog.String("content", a.Content))
150 }
151 if a.ToolName != "" {
152 attrs = append(attrs, slog.String("tool_name", a.ToolName))
153 }
154 if a.ToolInput != "" {
155 attrs = append(attrs, slog.String("tool_input", a.ToolInput))
156 }
157 if a.Elapsed != nil {
158 attrs = append(attrs, slog.Int64("elapsed", a.Elapsed.Nanoseconds()))
159 }
160 if a.TurnDuration != nil {
161 attrs = append(attrs, slog.Int64("turnDuration", a.TurnDuration.Nanoseconds()))
162 }
163 if a.ToolResult != "" {
164 attrs = append(attrs, slog.String("tool_result", a.ToolResult))
165 }
166 if a.ToolError {
167 attrs = append(attrs, slog.Bool("tool_error", a.ToolError))
168 }
169 if len(a.ToolCalls) > 0 {
170 toolCallAttrs := make([]any, 0, len(a.ToolCalls))
171 for i, tc := range a.ToolCalls {
172 toolCallAttrs = append(toolCallAttrs, slog.Group(
173 fmt.Sprintf("tool_call_%d", i),
174 slog.String("name", tc.Name),
175 slog.String("input", tc.Input),
176 ))
177 }
178 attrs = append(attrs, slog.Group("tool_calls", toolCallAttrs...))
179 }
180 if a.ConversationID != "" {
181 attrs = append(attrs, slog.String("convo_id", a.ConversationID))
182 }
183 if a.ParentConversationID != nil {
184 attrs = append(attrs, slog.String("parent_convo_id", *a.ParentConversationID))
185 }
186 if a.Usage != nil && !a.Usage.IsZero() {
187 attrs = append(attrs, a.Usage.Attr())
188 }
189 // TODO: timestamp, convo ids, idx?
190 return slog.Group("agent_message", attrs...)
191}
192
193func errorMessage(err error) AgentMessage {
194 // It's somewhat unknowable whether error messages are "end of turn" or not, but it seems like the best approach.
195 if os.Getenv(("DEBUG")) == "1" {
196 return AgentMessage{Type: ErrorMessageType, Content: err.Error() + " Stacktrace: " + string(debug.Stack()), EndOfTurn: true}
197 }
198
199 return AgentMessage{Type: ErrorMessageType, Content: err.Error(), EndOfTurn: true}
200}
201
202func budgetMessage(err error) AgentMessage {
203 return AgentMessage{Type: BudgetMessageType, Content: err.Error(), EndOfTurn: true}
204}
205
// ConvoInterface defines the interface for conversation interactions.
// It is the subset of conversation behavior that Agent relies on; Init
// assigns an *ant.Convo to the field of this type.
type ConvoInterface interface {
	// Usage and budget accounting.
	CumulativeUsage() ant.CumulativeUsage
	ResetBudget(ant.Budget)
	OverBudget() error // non-nil when the budget has been exceeded

	// Sending messages to the LLM.
	SendMessage(message ant.Message) (*ant.MessageResponse, error)
	SendUserTextMessage(s string, otherContents ...ant.Content) (*ant.MessageResponse, error)

	// Producing the tool results for a response: either the real results,
	// or the replacement contents used when the user cancelled.
	ToolResultContents(ctx context.Context, resp *ant.MessageResponse) ([]ant.Content, error)
	ToolResultCancelContents(resp *ant.MessageResponse) ([]ant.Content, error)
	CancelToolUse(toolUseID string, cause error) error
}
217
// Agent is the concrete agent: it owns the LLM conversation, the channels
// connecting it to the user interface, the message history, and the git
// bookkeeping used to report (and, under docker, push) new commits.
type Agent struct {
	convo          ConvoInterface
	config         AgentConfig // config for this agent
	workingDir     string
	repoRoot       string        // workingDir may be a subdir of repoRoot
	url            string
	lastHEAD       string        // hash of the last HEAD that was pushed to the host (only when under docker)
	initialCommit  string        // hash of the Git HEAD when the agent was instantiated or Init()
	gitRemoteAddr  string        // HTTP URL of the host git repo (only when under docker)
	ready          chan struct{} // closed when the agent is initialized (only when under docker)
	startedAt      time.Time
	originalBudget ant.Budget
	title          string // read and written under mu by Title and SetTitle
	codereview     *claudetool.CodeReviewer

	// Time when the current turn started (reset at the beginning of InnerLoop)
	startOfTurn time.Time

	// Inbox - for messages from the user to the agent.
	//   sent on by UserMessage
	//     e.g. when user types into the chat textarea
	//   read from by GatherMessages
	inbox chan string

	// Outbox
	//   sent on by pushToOutbox
	//     via OnToolResult and OnResponse callbacks
	//   read from by WaitForMessage
	//     called by termui inside its repl loop.
	outbox chan AgentMessage

	// protects cancelInnerLoop
	cancelInnerLoopMu sync.Mutex
	// cancels potentially long-running tool_use calls or chains of them
	cancelInnerLoop context.CancelCauseFunc

	// protects following
	mu sync.Mutex

	// Stores all messages for this agent
	history []AgentMessage

	// Channels of goroutines waiting for a state change; each is closed (and
	// the slice emptied) when a message is pushed or the title changes.
	listeners []chan struct{}

	// Track git commits we've already seen (by hash)
	seenCommits map[string]bool
}
265
266func (a *Agent) URL() string { return a.url }
267
268// Title returns the current title of the conversation.
269// If no title has been set, returns an empty string.
270func (a *Agent) Title() string {
271 a.mu.Lock()
272 defer a.mu.Unlock()
273 return a.title
274}
275
// OS returns the operating system of the client, as supplied in
// AgentConfig.ClientGOOS when the agent was created.
func (a *Agent) OS() string {
	return a.config.ClientGOOS
}
280
281// SetTitle sets the title of the conversation.
282func (a *Agent) SetTitle(title string) {
283 a.mu.Lock()
284 defer a.mu.Unlock()
285 a.title = title
286 // Notify all listeners that the state has changed
287 for _, ch := range a.listeners {
288 close(ch)
289 }
290 a.listeners = a.listeners[:0]
291}
292
293// OnToolResult implements ant.Listener.
294func (a *Agent) OnToolResult(ctx context.Context, convo *ant.Convo, toolName string, toolInput json.RawMessage, content ant.Content, result *string, err error) {
295 m := AgentMessage{
296 Type: ToolUseMessageType,
297 Content: content.Text,
298 ToolResult: content.ToolResult,
299 ToolError: content.ToolError,
300 ToolName: toolName,
301 ToolInput: string(toolInput),
302 ToolCallId: content.ToolUseID,
303 StartTime: content.StartTime,
304 EndTime: content.EndTime,
305 }
306
307 // Calculate the elapsed time if both start and end times are set
308 if content.StartTime != nil && content.EndTime != nil {
309 elapsed := content.EndTime.Sub(*content.StartTime)
310 m.Elapsed = &elapsed
311 }
312
313 m.ConversationID = convo.ID
314 if convo.Parent != nil {
315 m.ParentConversationID = &convo.Parent.ID
316 }
317 a.pushToOutbox(ctx, m)
318}
319
// OnRequest implements ant.Listener. It intentionally does nothing.
func (a *Agent) OnRequest(ctx context.Context, convo *ant.Convo, msg *ant.Message) {
	// No-op.
	// We already get tool results from OnToolResult. We send user messages to the outbox in the agent loop.
}
325
326// OnResponse implements ant.Listener. Responses contain messages from the LLM
327// that need to be displayed (as well as tool calls that we send along when
328// they're done). (It would be reasonable to also mention tool calls when they're
329// started, but we don't do that yet.)
330func (a *Agent) OnResponse(ctx context.Context, convo *ant.Convo, resp *ant.MessageResponse) {
331 endOfTurn := false
332 if resp.StopReason != ant.StopReasonToolUse {
333 endOfTurn = true
334 }
335 m := AgentMessage{
336 Type: AgentMessageType,
337 Content: collectTextContent(resp),
338 EndOfTurn: endOfTurn,
339 Usage: &resp.Usage,
340 StartTime: resp.StartTime,
341 EndTime: resp.EndTime,
342 }
343
344 // Extract any tool calls from the response
345 if resp.StopReason == ant.StopReasonToolUse {
346 var toolCalls []ToolCall
347 for _, part := range resp.Content {
348 if part.Type == "tool_use" {
349 toolCalls = append(toolCalls, ToolCall{
350 Name: part.ToolName,
351 Input: string(part.ToolInput),
352 ToolCallId: part.ID,
353 })
354 }
355 }
356 m.ToolCalls = toolCalls
357 }
358
359 // Calculate the elapsed time if both start and end times are set
360 if resp.StartTime != nil && resp.EndTime != nil {
361 elapsed := resp.EndTime.Sub(*resp.StartTime)
362 m.Elapsed = &elapsed
363 }
364
365 m.ConversationID = convo.ID
366 if convo.Parent != nil {
367 m.ParentConversationID = &convo.Parent.ID
368 }
369 a.pushToOutbox(ctx, m)
370}
371
// WorkingDir implements CodingAgent. It returns the working directory that
// was passed to Init.
func (a *Agent) WorkingDir() string {
	return a.workingDir
}
376
377// MessageCount implements CodingAgent.
378func (a *Agent) MessageCount() int {
379 a.mu.Lock()
380 defer a.mu.Unlock()
381 return len(a.history)
382}
383
384// Messages implements CodingAgent.
385func (a *Agent) Messages(start int, end int) []AgentMessage {
386 a.mu.Lock()
387 defer a.mu.Unlock()
388 return slices.Clone(a.history[start:end])
389}
390
// OriginalBudget implements CodingAgent. It returns the budget the agent was
// created with (AgentConfig.Budget), which is also what the budget is reset
// to after it has been exceeded.
func (a *Agent) OriginalBudget() ant.Budget {
	return a.originalBudget
}
394
// AgentConfig contains configuration for creating a new Agent.
type AgentConfig struct {
	Context          context.Context // used for the conversation and for git commands run during Init
	AntURL           string          // when non-empty, overrides the conversation's URL
	APIKey           string          // API key passed to the conversation
	HTTPC            *http.Client    // when non-nil, overrides the conversation's HTTP client
	Budget           ant.Budget      // budget for the conversation; also kept as the agent's original budget
	GitUsername      string          // passed to the done tool
	GitEmail         string          // passed to the done tool
	SessionID        string          // fallback branch name component when the title yields none
	ClientGOOS       string          // client operating system; reported in the system prompt
	ClientGOARCH     string          // client architecture; reported in the system prompt
	UseAnthropicEdit bool            // offer the Anthropic edit tool instead of the patch tool
}
409
410// NewAgent creates a new Agent.
411// It is not usable until Init() is called.
412func NewAgent(config AgentConfig) *Agent {
413 agent := &Agent{
414 config: config,
415 ready: make(chan struct{}),
416 inbox: make(chan string, 100),
417 outbox: make(chan AgentMessage, 100),
418 startedAt: time.Now(),
419 originalBudget: config.Budget,
420 seenCommits: make(map[string]bool),
421 }
422 return agent
423}
424
// AgentInit carries the parameters for Agent.Init.
type AgentInit struct {
	WorkingDir string // directory the agent works in
	NoGit      bool   // only for testing

	// The fields below are only consulted when InDocker is true: Init then
	// stashes local changes, fetches GitRemoteAddr and force-checks-out
	// Commit in WorkingDir.
	InDocker      bool
	Commit        string
	GitRemoteAddr string
	HostAddr      string // when non-empty, the agent URL becomes "http://" + HostAddr
}
434
435func (a *Agent) Init(ini AgentInit) error {
436 ctx := a.config.Context
437 if ini.InDocker {
438 cmd := exec.CommandContext(ctx, "git", "stash")
439 cmd.Dir = ini.WorkingDir
440 if out, err := cmd.CombinedOutput(); err != nil {
441 return fmt.Errorf("git stash: %s: %v", out, err)
442 }
443 cmd = exec.CommandContext(ctx, "git", "fetch", ini.GitRemoteAddr)
444 cmd.Dir = ini.WorkingDir
445 if out, err := cmd.CombinedOutput(); err != nil {
446 return fmt.Errorf("git fetch: %s: %w", out, err)
447 }
448 cmd = exec.CommandContext(ctx, "git", "checkout", "-f", ini.Commit)
449 cmd.Dir = ini.WorkingDir
450 if out, err := cmd.CombinedOutput(); err != nil {
451 return fmt.Errorf("git checkout %s: %s: %w", ini.Commit, out, err)
452 }
453 a.lastHEAD = ini.Commit
454 a.gitRemoteAddr = ini.GitRemoteAddr
455 a.initialCommit = ini.Commit
456 if ini.HostAddr != "" {
457 a.url = "http://" + ini.HostAddr
458 }
459 }
460 a.workingDir = ini.WorkingDir
461
462 if !ini.NoGit {
463 repoRoot, err := repoRoot(ctx, a.workingDir)
464 if err != nil {
465 return fmt.Errorf("repoRoot: %w", err)
466 }
467 a.repoRoot = repoRoot
468
469 commitHash, err := resolveRef(ctx, a.repoRoot, "HEAD")
470 if err != nil {
471 return fmt.Errorf("resolveRef: %w", err)
472 }
473 a.initialCommit = commitHash
474
475 codereview, err := claudetool.NewCodeReviewer(ctx, a.repoRoot, a.initialCommit)
476 if err != nil {
477 return fmt.Errorf("Agent.Init: claudetool.NewCodeReviewer: %w", err)
478 }
479 a.codereview = codereview
480 }
481 a.lastHEAD = a.initialCommit
482 a.convo = a.initConvo()
483 close(a.ready)
484 return nil
485}
486
// initConvo initializes the conversation.
// It must not be called until all agent fields are initialized,
// particularly workingDir and git.
//
// The conversation is configured from a.config (API key, optional HTTP
// client and URL overrides, budget), given a system prompt that embeds the
// client platform, working directory and git root, and equipped with the
// agent's tool set. The agent registers itself as the conversation's
// listener so that responses and tool results reach the outbox.
func (a *Agent) initConvo() *ant.Convo {
	ctx := a.config.Context
	convo := ant.NewConvo(ctx, a.config.APIKey)
	if a.config.HTTPC != nil {
		convo.HTTPC = a.config.HTTPC
	}
	if a.config.AntURL != "" {
		convo.URL = a.config.AntURL
	}
	convo.PromptCaching = true
	convo.Budget = a.config.Budget

	// The editing instructions in the prompt must match the edit tool that is
	// registered below.
	var editPrompt string
	if a.config.UseAnthropicEdit {
		editPrompt = "Then use the str_replace_editor tool to make those edits. For short complete file replacements, you may use the bash tool with cat and heredoc stdin."
	} else {
		editPrompt = "Then use the patch tool to make those edits. Combine all edits to any given file into a single patch tool call."
	}

	convo.SystemPrompt = fmt.Sprintf(`
You are an expert coding assistant and architect, with a specialty in Go.
You are assisting the user to achieve their goals.

Start by asking concise clarifying questions as needed.
Once the intent is clear, work autonomously.

Call the title tool early in the conversation to provide a brief summary of
what the chat is about.

Break down the overall goal into a series of smaller steps.
(The first step is often: "Make a plan.")
Then execute each step using tools.
Update the plan if you have encountered problems or learned new information.

When in doubt about a step, follow this broad workflow:

- Think about how the current step fits into the overall plan.
- Do research. Good tool choices: bash, think, keyword_search
- Make edits.
- Repeat.

To make edits reliably and efficiently, first think about the intent of the edit,
and what set of patches will achieve that intent.
%s

For renames or refactors, consider invoking gopls (via bash).

The done tool provides a checklist of items you MUST verify and
review before declaring that you are done. Before executing
the done tool, run all the tools the done tool checklist asks
for, including creating a git commit. Do not forget to run tests.

<platform>
%s/%s
</platform>
<pwd>
%v
</pwd>
<git_root>
%v
</git_root>
`, editPrompt, a.config.ClientGOOS, a.config.ClientGOARCH, a.workingDir, a.repoRoot)

	// Register all tools with the conversation
	// When adding, removing, or modifying tools here, double-check that the termui tool display
	// template in termui/termui.go has pretty-printing support for all tools.
	// NOTE(review): a.codereview is only set by Init when NoGit is false;
	// confirm that makeDoneTool and CodeReviewer.Tool tolerate a nil reviewer.
	convo.Tools = []*ant.Tool{
		claudetool.Bash, claudetool.Keyword,
		claudetool.Think, a.titleTool(), makeDoneTool(a.codereview, a.config.GitUsername, a.config.GitEmail),
		a.codereview.Tool(),
	}
	if a.config.UseAnthropicEdit {
		convo.Tools = append(convo.Tools, claudetool.AnthropicEditTool)
	} else {
		convo.Tools = append(convo.Tools, claudetool.Patch)
	}
	convo.Listener = a
	return convo
}
569
570func (a *Agent) titleTool() *ant.Tool {
571 // titleTool creates the title tool that sets the conversation title.
572 title := &ant.Tool{
573 Name: "title",
574 Description: `Use this tool early in the conversation, BEFORE MAKING ANY GIT COMMITS, to summarize what the chat is about briefly.`,
575 InputSchema: json.RawMessage(`{
576 "type": "object",
577 "properties": {
578 "title": {
579 "type": "string",
580 "description": "A brief title summarizing what this chat is about"
581 }
582 },
583 "required": ["title"]
584}`),
585 Run: func(ctx context.Context, input json.RawMessage) (string, error) {
586 var params struct {
587 Title string `json:"title"`
588 }
589 if err := json.Unmarshal(input, &params); err != nil {
590 return "", err
591 }
592 a.SetTitle(params.Title)
593 return fmt.Sprintf("Title set to: %s", params.Title), nil
594 },
595 }
596 return title
597}
598
// Ready implements CodingAgent. The returned channel is closed by Init once
// initialization has succeeded.
func (a *Agent) Ready() <-chan struct{} {
	return a.ready
}
602
603func (a *Agent) UserMessage(ctx context.Context, msg string) {
604 a.pushToOutbox(ctx, AgentMessage{Type: UserMessageType, Content: msg})
605 a.inbox <- msg
606}
607
608func (a *Agent) WaitForMessage(ctx context.Context) AgentMessage {
609 // TODO: Should this drain any outbox messages in case there are multiple?
610 select {
611 case msg := <-a.outbox:
612 return msg
613 case <-ctx.Done():
614 return errorMessage(ctx.Err())
615 }
616}
617
// CancelToolUse implements CodingAgent by delegating to the conversation's
// CancelToolUse with the same tool-use ID and cause.
func (a *Agent) CancelToolUse(toolUseID string, cause error) error {
	return a.convo.CancelToolUse(toolUseID, cause)
}
621
622func (a *Agent) CancelInnerLoop(cause error) {
623 a.cancelInnerLoopMu.Lock()
624 defer a.cancelInnerLoopMu.Unlock()
625 if a.cancelInnerLoop != nil {
626 a.cancelInnerLoop(cause)
627 }
628}
629
630func (a *Agent) Loop(ctxOuter context.Context) {
631 for {
632 select {
633 case <-ctxOuter.Done():
634 return
635 default:
636 ctxInner, cancel := context.WithCancelCause(ctxOuter)
637 a.cancelInnerLoopMu.Lock()
638 // Set .cancelInnerLoop so the user can cancel whatever is happening
639 // inside InnerLoop(ctxInner) without canceling this outer Loop execution.
640 // This CancelInnerLoop func is intended be called from other goroutines,
641 // hence the mutex.
642 a.cancelInnerLoop = cancel
643 a.cancelInnerLoopMu.Unlock()
644 a.InnerLoop(ctxInner)
645 cancel(nil)
646 }
647 }
648}
649
650func (a *Agent) pushToOutbox(ctx context.Context, m AgentMessage) {
651 if m.Timestamp.IsZero() {
652 m.Timestamp = time.Now()
653 }
654
655 // If this is an end-of-turn message, calculate the turn duration and add it to the message
656 if m.EndOfTurn && m.Type == AgentMessageType {
657 turnDuration := time.Since(a.startOfTurn)
658 m.TurnDuration = &turnDuration
659 slog.InfoContext(ctx, "Turn completed", "turnDuration", turnDuration)
660 }
661
662 slog.InfoContext(ctx, "agent message", m.Attr())
663
664 a.mu.Lock()
665 defer a.mu.Unlock()
666 m.Idx = len(a.history)
667 a.history = append(a.history, m)
668 a.outbox <- m
669
670 // Notify all listeners:
671 for _, ch := range a.listeners {
672 close(ch)
673 }
674 a.listeners = a.listeners[:0]
675}
676
677func (a *Agent) GatherMessages(ctx context.Context, block bool) ([]ant.Content, error) {
678 var m []ant.Content
679 if block {
680 select {
681 case <-ctx.Done():
682 return m, ctx.Err()
683 case msg := <-a.inbox:
684 m = append(m, ant.Content{Type: "text", Text: msg})
685 }
686 }
687 for {
688 select {
689 case msg := <-a.inbox:
690 m = append(m, ant.Content{Type: "text", Text: msg})
691 default:
692 return m, nil
693 }
694 }
695}
696
// InnerLoop handles one user turn. It waits for user input, sends it to the
// LLM, and then keeps executing the tool calls the LLM asks for — feeding the
// results (plus any user messages that arrived meanwhile) back to it — until
// the LLM stops asking for tools, the budget is exceeded, the context is
// cancelled, or sending a message fails. New git commits are checked for
// after each round of tool calls and once more when the turn ends.
func (a *Agent) InnerLoop(ctx context.Context) {
	// Reset the start of turn time
	a.startOfTurn = time.Now()

	// Wait for at least one message from the user.
	msgs, err := a.GatherMessages(ctx, true)
	if err != nil { // e.g. the context was canceled while blocking in GatherMessages
		return
	}
	// We do this as we go, but let's also do it at the end of the turn
	defer func() {
		if _, err := a.handleGitCommits(ctx); err != nil {
			// Just log the error, don't stop execution
			slog.WarnContext(ctx, "Failed to check for new git commits", "error", err)
		}
	}()

	userMessage := ant.Message{
		Role:    "user",
		Content: msgs,
	}
	// convo.SendMessage does the actual network call to send this to anthropic. This blocks until the response is ready.
	// TODO: pass ctx to SendMessage, and figure out how to square that ctx with convo's own .Ctx. Who owns the scope of this call?
	resp, err := a.convo.SendMessage(userMessage)
	if err != nil {
		a.pushToOutbox(ctx, errorMessage(err))
		return
	}
	for {
		// TODO: here and below where we check the budget,
		// we should review the UX: is it clear what happened?
		// is it clear how to resume?
		// should we let the user set a new budget?
		if err := a.overBudget(ctx); err != nil {
			return
		}
		// Anything other than a tool-use stop means the LLM is done with this turn.
		if resp.StopReason != ant.StopReasonToolUse {
			break
		}
		var results []ant.Content
		cancelled := false
		select {
		case <-ctx.Done():
			// Don't actually run any of the tools, but rather build a response
			// for each tool_use message letting the LLM know that user canceled it.
			results, err = a.convo.ToolResultCancelContents(resp)
			if err != nil {
				a.pushToOutbox(ctx, errorMessage(err))
			}
			cancelled = true
		default:
			ctx = claudetool.WithWorkingDir(ctx, a.workingDir)
			// fall-through, when the user has not canceled the inner loop:
			results, err = a.convo.ToolResultContents(ctx, resp)
			if ctx.Err() != nil { // e.g. the user canceled the operation
				cancelled = true
			} else if err != nil {
				a.pushToOutbox(ctx, errorMessage(err))
			}
		}

		// Check for git commits. Currently we do this here, after we collect
		// tool results, since that's when we know commits could have happened.
		// We could instead do this when the turn ends, but I think it makes sense
		// to do this as we go.
		newCommits, err := a.handleGitCommits(ctx)
		if err != nil {
			// Just log the error, don't stop execution
			slog.WarnContext(ctx, "Failed to check for new git commits", "error", err)
		}
		// Autoformatting is only attempted when exactly one new commit
		// appeared, since the follow-up asks the LLM to amend its latest commit.
		var autoqualityMessages []string
		if len(newCommits) == 1 {
			formatted := a.codereview.Autoformat(ctx)
			if len(formatted) > 0 {
				// The [1:] below strips the newline that follows the opening backquote.
				msg := fmt.Sprintf(`
I ran autoformatters and they updated these files:

%s

Please amend your latest git commit with these changes and then continue with what you were doing.`,
					strings.Join(formatted, "\n"),
				)[1:]
				a.pushToOutbox(ctx, AgentMessage{
					Type:      AutoMessageType,
					Content:   msg,
					Timestamp: time.Now(),
				})
				autoqualityMessages = append(autoqualityMessages, msg)
			}
		}

		if err := a.overBudget(ctx); err != nil {
			return
		}

		// Include, along with the tool results (which must go first for whatever reason),
		// any messages that the user has sent along while the tool_use was executing concurrently.
		msgs, err = a.GatherMessages(ctx, false)
		if err != nil {
			return
		}
		// Inject any auto-generated messages from quality checks.
		for _, msg := range autoqualityMessages {
			msgs = append(msgs, ant.Content{Type: "text", Text: msg})
		}
		if cancelled {
			msgs = append(msgs, ant.Content{Type: "text", Text: cancelToolUseMessage})
			// EndOfTurn is false here so that the client of this agent keeps processing
			// messages from WaitForMessage() and gets the response from the LLM (usually
			// something like "okay, I'll wait further instructions", but the user should
			// be made aware of it regardless).
			a.pushToOutbox(ctx, AgentMessage{Type: ErrorMessageType, Content: userCancelMessage, EndOfTurn: false})
		} else if err := a.convo.OverBudget(); err != nil {
			budgetMsg := "We've exceeded our budget. Please ask the user to confirm before continuing by ending the turn."
			msgs = append(msgs, ant.Content{Type: "text", Text: budgetMsg})
			a.pushToOutbox(ctx, budgetMessage(fmt.Errorf("warning: %w (ask to keep trying, if you'd like)", err)))
		}
		// Tool results first, then the gathered user/auto/cancel/budget texts.
		results = append(results, msgs...)
		resp, err = a.convo.SendMessage(ant.Message{
			Role:    "user",
			Content: results,
		})
		if err != nil {
			a.pushToOutbox(ctx, errorMessage(fmt.Errorf("error: failed to continue conversation: %s", err.Error())))
			break
		}
		// After a cancellation the LLM has been told to stop; its reply was
		// delivered via OnResponse, so the turn ends here.
		if cancelled {
			return
		}
	}
}
828
829func (a *Agent) overBudget(ctx context.Context) error {
830 if err := a.convo.OverBudget(); err != nil {
831 m := budgetMessage(err)
832 m.Content = m.Content + "\n\nBudget reset."
833 a.pushToOutbox(ctx, budgetMessage(err))
834 a.convo.ResetBudget(a.originalBudget)
835 return err
836 }
837 return nil
838}
839
840func collectTextContent(msg *ant.MessageResponse) string {
841 // Collect all text content
842 var allText strings.Builder
843 for _, content := range msg.Content {
844 if content.Type == "text" && content.Text != "" {
845 if allText.Len() > 0 {
846 allText.WriteString("\n\n")
847 }
848 allText.WriteString(content.Text)
849 }
850 }
851 return allText.String()
852}
853
// TotalUsage implements CodingAgent. It returns the conversation's
// cumulative usage, read while holding the agent's mutex.
func (a *Agent) TotalUsage() ant.CumulativeUsage {
	a.mu.Lock()
	defer a.mu.Unlock()
	return a.convo.CumulativeUsage()
}
859
860// WaitForMessageCount returns when the agent has at more than clientMessageCount messages or the context is done.
861func (a *Agent) WaitForMessageCount(ctx context.Context, greaterThan int) {
862 for a.MessageCount() <= greaterThan {
863 a.mu.Lock()
864 ch := make(chan struct{})
865 // Deletion happens when we notify.
866 a.listeners = append(a.listeners, ch)
867 a.mu.Unlock()
868
869 select {
870 case <-ctx.Done():
871 return
872 case <-ch:
873 continue
874 }
875 }
876}
877
878// Diff returns a unified diff of changes made since the agent was instantiated.
879func (a *Agent) Diff(commit *string) (string, error) {
880 if a.initialCommit == "" {
881 return "", fmt.Errorf("no initial commit reference available")
882 }
883
884 // Find the repository root
885 ctx := context.Background()
886
887 // If a specific commit hash is provided, show just that commit's changes
888 if commit != nil && *commit != "" {
889 // Validate that the commit looks like a valid git SHA
890 if !isValidGitSHA(*commit) {
891 return "", fmt.Errorf("invalid git commit SHA format: %s", *commit)
892 }
893
894 // Get the diff for just this commit
895 cmd := exec.CommandContext(ctx, "git", "show", "--unified=10", *commit)
896 cmd.Dir = a.repoRoot
897 output, err := cmd.CombinedOutput()
898 if err != nil {
899 return "", fmt.Errorf("failed to get diff for commit %s: %w - %s", *commit, err, string(output))
900 }
901 return string(output), nil
902 }
903
904 // Otherwise, get the diff between the initial commit and the current state using exec.Command
905 cmd := exec.CommandContext(ctx, "git", "diff", "--unified=10", a.initialCommit)
906 cmd.Dir = a.repoRoot
907 output, err := cmd.CombinedOutput()
908 if err != nil {
909 return "", fmt.Errorf("failed to get diff: %w - %s", err, string(output))
910 }
911
912 return string(output), nil
913}
914
// InitialCommit returns the Git commit hash that was saved when the agent was
// initialized. It is empty if Init has not recorded one.
func (a *Agent) InitialCommit() string {
	return a.initialCommit
}
919
920// handleGitCommits() highlights new commits to the user. When running
921// under docker, new HEADs are pushed to a branch according to the title.
922func (a *Agent) handleGitCommits(ctx context.Context) ([]*GitCommit, error) {
923 if a.repoRoot == "" {
924 return nil, nil
925 }
926
927 head, err := resolveRef(ctx, a.repoRoot, "HEAD")
928 if err != nil {
929 return nil, err
930 }
931 if head == a.lastHEAD {
932 return nil, nil // nothing to do
933 }
934 defer func() {
935 a.lastHEAD = head
936 }()
937
938 // Get new commits. Because it's possible that the agent does rebases, fixups, and
939 // so forth, we use, as our fixed point, the "initialCommit", and we limit ourselves
940 // to the last 100 commits.
941 var commits []*GitCommit
942
943 // Get commits since the initial commit
944 // Format: <hash>\0<subject>\0<body>\0
945 // This uses NULL bytes as separators to avoid issues with newlines in commit messages
946 // Limit to 100 commits to avoid overwhelming the user
947 cmd := exec.CommandContext(ctx, "git", "log", "-n", "100", "--pretty=format:%H%x00%s%x00%b%x00", "^"+a.initialCommit, head)
948 cmd.Dir = a.repoRoot
949 output, err := cmd.Output()
950 if err != nil {
951 return nil, fmt.Errorf("failed to get git log: %w", err)
952 }
953
954 // Parse git log output and filter out already seen commits
955 parsedCommits := parseGitLog(string(output))
956
957 var headCommit *GitCommit
958
959 // Filter out commits we've already seen
960 for _, commit := range parsedCommits {
961 if commit.Hash == head {
962 headCommit = &commit
963 }
964
965 // Skip if we've seen this commit before. If our head has changed, always include that.
966 if a.seenCommits[commit.Hash] && commit.Hash != head {
967 continue
968 }
969
970 // Mark this commit as seen
971 a.seenCommits[commit.Hash] = true
972
973 // Add to our list of new commits
974 commits = append(commits, &commit)
975 }
976
977 if a.gitRemoteAddr != "" {
978 if headCommit == nil {
979 // I think this can only happen if we have a bug or if there's a race.
980 headCommit = &GitCommit{}
981 headCommit.Hash = head
982 headCommit.Subject = "unknown"
983 commits = append(commits, headCommit)
984 }
985
986 cleanTitle := titleToBranch(a.title)
987 if cleanTitle == "" {
988 cleanTitle = a.config.SessionID
989 }
990 branch := "sketch/" + cleanTitle
991
992 // TODO: I don't love the force push here. We could see if the push is a fast-forward, and,
993 // if it's not, we could make a backup with a unique name (perhaps append a timestamp) and
994 // then use push with lease to replace.
995 cmd = exec.Command("git", "push", "--force", a.gitRemoteAddr, "HEAD:refs/heads/"+branch)
996 cmd.Dir = a.workingDir
997 if out, err := cmd.CombinedOutput(); err != nil {
998 a.pushToOutbox(ctx, errorMessage(fmt.Errorf("git push to host: %s: %v", out, err)))
999 } else {
1000 headCommit.PushedBranch = branch
1001 }
1002 }
1003
1004 // If we found new commits, create a message
1005 if len(commits) > 0 {
1006 msg := AgentMessage{
1007 Type: CommitMessageType,
1008 Timestamp: time.Now(),
1009 Commits: commits,
1010 }
1011 a.pushToOutbox(ctx, msg)
1012 }
1013 return commits, nil
1014}
1015
// titleToBranch derives a branch-name fragment from a free-form title.
// The title is lowercased, each space becomes a hyphen, and every rune other
// than 'a'..'z' and '-' is dropped (digits and punctuation included). The
// result may be empty.
func titleToBranch(s string) string {
	keep := func(r rune) rune {
		switch {
		case r == ' ':
			return '-'
		case r == '-', 'a' <= r && r <= 'z':
			return r
		default:
			// A negative return tells strings.Map to drop the rune.
			return -1
		}
	}
	return strings.Map(keep, strings.ToLower(s))
}
1032
1033// parseGitLog parses the output of git log with format '%H%x00%s%x00%b%x00'
1034// and returns an array of GitCommit structs.
1035func parseGitLog(output string) []GitCommit {
1036 var commits []GitCommit
1037
1038 // No output means no commits
1039 if len(output) == 0 {
1040 return commits
1041 }
1042
1043 // Split by NULL byte
1044 parts := strings.Split(output, "\x00")
1045
1046 // Process in triplets (hash, subject, body)
1047 for i := 0; i < len(parts); i++ {
1048 // Skip empty parts
1049 if parts[i] == "" {
1050 continue
1051 }
1052
1053 // This should be a hash
1054 hash := strings.TrimSpace(parts[i])
1055
1056 // Make sure we have at least a subject part available
1057 if i+1 >= len(parts) {
1058 break // No more parts available
1059 }
1060
1061 // Get the subject
1062 subject := strings.TrimSpace(parts[i+1])
1063
1064 // Get the body if available
1065 body := ""
1066 if i+2 < len(parts) {
1067 body = strings.TrimSpace(parts[i+2])
1068 }
1069
1070 // Skip to the next triplet
1071 i += 2
1072
1073 commits = append(commits, GitCommit{
1074 Hash: hash,
1075 Subject: subject,
1076 Body: body,
1077 })
1078 }
1079
1080 return commits
1081}
1082
// repoRoot runs "git rev-parse --show-toplevel" in dir and returns its
// whitespace-trimmed standard output. If the command fails, the returned
// error wraps the command error and appends whatever git wrote to stderr.
func repoRoot(ctx context.Context, dir string) (string, error) {
	var errOut strings.Builder

	revParse := exec.CommandContext(ctx, "git", "rev-parse", "--show-toplevel")
	revParse.Dir = dir
	revParse.Stderr = &errOut

	stdout, err := revParse.Output()
	if err != nil {
		return "", fmt.Errorf("git rev-parse failed: %w\n%s", err, errOut.String())
	}
	return strings.TrimSpace(string(stdout)), nil
}
1094
// resolveRef runs "git rev-parse <refName>" in dir and returns its
// whitespace-trimmed standard output. If the command fails, the returned
// error wraps the command error and appends whatever git wrote to stderr.
func resolveRef(ctx context.Context, dir, refName string) (string, error) {
	var errOut strings.Builder

	revParse := exec.CommandContext(ctx, "git", "rev-parse", refName)
	revParse.Dir = dir
	revParse.Stderr = &errOut

	stdout, err := revParse.Output()
	if err != nil {
		return "", fmt.Errorf("git rev-parse failed: %w\n%s", err, errOut.String())
	}
	// TODO: validate that the output is valid hex
	return strings.TrimSpace(string(stdout)), nil
}
1107
// isValidGitSHA reports whether sha has the shape of a git SHA: between 4 and
// 40 bytes long (inclusive) and made up solely of hexadecimal digits, in
// either case. It checks form only; it does not consult any repository.
func isValidGitSHA(sha string) bool {
	if n := len(sha); n < 4 || n > 40 {
		return false
	}

	// Scan byte by byte: any byte outside the ASCII hex digits, including
	// every byte of a multi-byte rune, disqualifies the string.
	for i := 0; i < len(sha); i++ {
		switch c := sha[i]; {
		case '0' <= c && c <= '9':
		case 'a' <= c && c <= 'f':
		case 'A' <= c && c <= 'F':
		default:
			return false
		}
	}
	return true
}