blob: d1c64a46b846c98701338cc55cbb1b0125c0e10e [file] [log] [blame]
iomodob67a3762025-07-25 20:27:04 +04001package agent
2
iomodo76f9a2d2025-07-26 12:14:40 +04003import (
4 "context"
5 "fmt"
6 "log"
7 "os"
8 "path/filepath"
9 "strings"
10 "time"
11
12 "github.com/iomodo/staff/git"
13 "github.com/iomodo/staff/llm"
14 "github.com/iomodo/staff/tm"
15)
16
17// AgentConfig contains configuration for the agent
iomodob67a3762025-07-25 20:27:04 +040018type AgentConfig struct {
19 Name string
20 Role string
21 GitUsername string
22 GitEmail string
23 WorkingDir string
iomodo76f9a2d2025-07-26 12:14:40 +040024
25 // LLM Configuration
26 LLMProvider llm.Provider
27 LLMModel string
28 LLMConfig llm.Config
29
30 // System prompt for the agent
31 SystemPrompt string
32
33 // Task Manager Configuration
34 TaskManager tm.TaskManager
35
36 // Git Configuration
37 GitRepoPath string
38 GitRemote string
39 GitBranch string
iomodob67a3762025-07-25 20:27:04 +040040}
41
iomodo76f9a2d2025-07-26 12:14:40 +040042// Agent represents an AI agent that can process tasks
iomodob67a3762025-07-25 20:27:04 +040043type Agent struct {
iomodo76f9a2d2025-07-26 12:14:40 +040044 Config AgentConfig
45 llmProvider llm.LLMProvider
46 gitInterface git.GitInterface
47 ctx context.Context
48 cancel context.CancelFunc
iomodob67a3762025-07-25 20:27:04 +040049}
50
iomodo76f9a2d2025-07-26 12:14:40 +040051// NewAgent creates a new agent instance
52func NewAgent(config AgentConfig) (*Agent, error) {
53 // Validate configuration
54 if err := validateConfig(config); err != nil {
55 return nil, fmt.Errorf("invalid config: %w", err)
56 }
57
58 // Create LLM provider
59 llmProvider, err := llm.CreateProvider(config.LLMConfig)
60 if err != nil {
61 return nil, fmt.Errorf("failed to create LLM provider: %w", err)
62 }
63
64 // Create git interface
65 gitInterface := git.DefaultGit(config.GitRepoPath)
66
67 // Create context with cancellation
68 ctx, cancel := context.WithCancel(context.Background())
69
70 agent := &Agent{
71 Config: config,
72 llmProvider: llmProvider,
73 gitInterface: gitInterface,
74 ctx: ctx,
75 cancel: cancel,
76 }
77
78 return agent, nil
79}
80
81// validateConfig validates the agent configuration
82func validateConfig(config AgentConfig) error {
83 if config.Name == "" {
84 return fmt.Errorf("agent name is required")
85 }
86 if config.Role == "" {
87 return fmt.Errorf("agent role is required")
88 }
89 if config.WorkingDir == "" {
90 return fmt.Errorf("working directory is required")
91 }
92 if config.SystemPrompt == "" {
93 return fmt.Errorf("system prompt is required")
94 }
95 if config.TaskManager == nil {
96 return fmt.Errorf("task manager is required")
97 }
98 if config.GitRepoPath == "" {
99 return fmt.Errorf("git repository path is required")
100 }
101 return nil
102}
103
104// Run starts the agent's main loop
105func (a *Agent) Run() error {
106 log.Printf("Starting agent %s (%s)", a.Config.Name, a.Config.Role)
107 defer log.Printf("Agent %s stopped", a.Config.Name)
108
109 // Initialize git repository if needed
110 if err := a.initializeGit(); err != nil {
111 return fmt.Errorf("failed to initialize git: %w", err)
112 }
113
114 // Main agent loop
115 for {
116 select {
117 case <-a.ctx.Done():
118 return a.ctx.Err()
119 default:
120 if err := a.processNextTask(); err != nil {
121 log.Printf("Error processing task: %v", err)
122 // Continue running even if there's an error
123 time.Sleep(30 * time.Second)
124 }
125 }
iomodob67a3762025-07-25 20:27:04 +0400126 }
127}
128
iomodo76f9a2d2025-07-26 12:14:40 +0400129// Stop stops the agent
130func (a *Agent) Stop() {
131 log.Printf("Stopping agent %s", a.Config.Name)
132 a.cancel()
133 if a.llmProvider != nil {
134 a.llmProvider.Close()
135 }
136}
iomodob67a3762025-07-25 20:27:04 +0400137
iomodo76f9a2d2025-07-26 12:14:40 +0400138// initializeGit initializes the git repository
139func (a *Agent) initializeGit() error {
140 ctx := context.Background()
141
142 // Check if repository exists
143 isRepo, err := a.gitInterface.IsRepository(ctx, a.Config.GitRepoPath)
144 if err != nil {
145 return fmt.Errorf("failed to check repository: %w", err)
146 }
147
148 if !isRepo {
149 // Initialize new repository
150 if err := a.gitInterface.Init(ctx, a.Config.GitRepoPath); err != nil {
151 return fmt.Errorf("failed to initialize repository: %w", err)
152 }
153 }
154
iomodo76f9a2d2025-07-26 12:14:40 +0400155 // Checkout to the specified branch
156 if a.Config.GitBranch != "" {
157 if err := a.gitInterface.Checkout(ctx, a.Config.GitBranch); err != nil {
158 // Try to create the branch if it doesn't exist
159 if err := a.gitInterface.CreateBranch(ctx, a.Config.GitBranch, ""); err != nil {
160 return fmt.Errorf("failed to create branch %s: %w", a.Config.GitBranch, err)
161 }
162 }
163 }
164
165 return nil
166}
167
168// processNextTask processes the next available task
169func (a *Agent) processNextTask() error {
170 ctx := context.Background()
171
172 // Get tasks assigned to this agent
173 taskList, err := a.Config.TaskManager.GetTasksByOwner(ctx, a.Config.Name, 0, 10)
174 if err != nil {
175 return fmt.Errorf("failed to get tasks: %w", err)
176 }
177
178 // Find a task that's ready to be worked on
179 var taskToProcess *tm.Task
180 for _, task := range taskList.Tasks {
181 if task.Status == tm.StatusToDo {
182 taskToProcess = task
183 break
184 }
185 }
186
187 if taskToProcess == nil {
188 // No tasks to process, wait a bit
189 time.Sleep(60 * time.Second)
190 return nil
191 }
192
193 log.Printf("Processing task: %s - %s", taskToProcess.ID, taskToProcess.Title)
194
195 // Start the task
196 startedTask, err := a.Config.TaskManager.StartTask(ctx, taskToProcess.ID)
197 if err != nil {
198 return fmt.Errorf("failed to start task: %w", err)
199 }
200
201 // Process the task with LLM
202 solution, err := a.processTaskWithLLM(startedTask)
203 if err != nil {
204 // Mark task as failed or retry
205 log.Printf("Failed to process task with LLM: %v", err)
206 return err
207 }
208
209 // Create PR with the solution
210 if err := a.createPullRequest(startedTask, solution); err != nil {
211 return fmt.Errorf("failed to create pull request: %w", err)
212 }
213
214 // Complete the task
215 if _, err := a.Config.TaskManager.CompleteTask(ctx, startedTask.ID); err != nil {
216 return fmt.Errorf("failed to complete task: %w", err)
217 }
218
219 log.Printf("Successfully completed task: %s", startedTask.ID)
220 return nil
221}
222
223// processTaskWithLLM sends the task to the LLM and gets a solution
224func (a *Agent) processTaskWithLLM(task *tm.Task) (string, error) {
225 ctx := context.Background()
226
227 // Prepare the prompt
228 prompt := a.buildTaskPrompt(task)
229
230 // Create chat completion request
231 req := llm.ChatCompletionRequest{
232 Model: a.Config.LLMModel,
233 Messages: []llm.Message{
234 {
235 Role: llm.RoleSystem,
236 Content: a.Config.SystemPrompt,
237 },
238 {
239 Role: llm.RoleUser,
240 Content: prompt,
241 },
242 },
243 MaxTokens: intPtr(4000),
244 Temperature: float64Ptr(0.7),
245 }
246
247 // Get response from LLM
248 resp, err := a.llmProvider.ChatCompletion(ctx, req)
249 if err != nil {
250 return "", fmt.Errorf("LLM chat completion failed: %w", err)
251 }
252
253 if len(resp.Choices) == 0 {
254 return "", fmt.Errorf("no response from LLM")
255 }
256
257 return resp.Choices[0].Message.Content, nil
258}
259
260// buildTaskPrompt builds the prompt for the LLM based on the task
261func (a *Agent) buildTaskPrompt(task *tm.Task) string {
262 var prompt strings.Builder
263
264 prompt.WriteString(fmt.Sprintf("Task ID: %s\n", task.ID))
265 prompt.WriteString(fmt.Sprintf("Title: %s\n", task.Title))
266 prompt.WriteString(fmt.Sprintf("Priority: %s\n", task.Priority))
267
268 if task.Description != "" {
269 prompt.WriteString(fmt.Sprintf("Description: %s\n", task.Description))
270 }
271
272 if task.DueDate != nil {
273 prompt.WriteString(fmt.Sprintf("Due Date: %s\n", task.DueDate.Format("2006-01-02")))
274 }
275
276 prompt.WriteString("\nPlease provide a detailed solution for this task. ")
277 prompt.WriteString("Include any code, documentation, or other deliverables as needed. ")
278 prompt.WriteString("Format your response appropriately for the type of task.")
279
280 return prompt.String()
281}
282
283// createPullRequest creates a pull request with the solution
284func (a *Agent) createPullRequest(task *tm.Task, solution string) error {
285 ctx := context.Background()
286
287 // Generate branch name
288 branchName := a.generateBranchName(task)
289
290 // Create and checkout to new branch
291 if err := a.gitInterface.CreateBranch(ctx, branchName, ""); err != nil {
292 return fmt.Errorf("failed to create branch: %w", err)
293 }
294
295 if err := a.gitInterface.Checkout(ctx, branchName); err != nil {
296 return fmt.Errorf("failed to checkout branch: %w", err)
297 }
298
299 // Create solution file
300 solutionPath := filepath.Join(a.Config.WorkingDir, fmt.Sprintf("task-%s-solution.md", task.ID))
301 solutionContent := a.formatSolution(task, solution)
302
303 if err := os.WriteFile(solutionPath, []byte(solutionContent), 0644); err != nil {
304 return fmt.Errorf("failed to write solution file: %w", err)
305 }
306
307 // Add and commit the solution
308 if err := a.gitInterface.Add(ctx, []string{solutionPath}); err != nil {
309 return fmt.Errorf("failed to add solution file: %w", err)
310 }
311
312 commitMessage := fmt.Sprintf("feat: Complete task %s - %s", task.ID, task.Title)
iomodo7d08e8e2025-07-26 15:24:42 +0400313 if err := a.gitInterface.Commit(ctx, commitMessage, git.CommitOptions{
314 Author: &git.Author{
315 Name: a.Config.GitUsername,
316 Email: a.Config.GitEmail,
317 Time: time.Now(),
318 },
319 }); err != nil {
iomodo76f9a2d2025-07-26 12:14:40 +0400320 return fmt.Errorf("failed to commit solution: %w", err)
321 }
322
323 // Push the branch
324 if err := a.gitInterface.Push(ctx, "origin", branchName, git.PushOptions{SetUpstream: true}); err != nil {
325 return fmt.Errorf("failed to push branch: %w", err)
326 }
327
328 log.Printf("Created pull request for task %s on branch %s", task.ID, branchName)
329 return nil
330}
331
332// generateBranchName generates a branch name for the task
333func (a *Agent) generateBranchName(task *tm.Task) string {
334 // Clean the task title for branch name
335 cleanTitle := strings.ReplaceAll(task.Title, " ", "-")
336 cleanTitle = strings.ToLower(cleanTitle)
337
338 // Remove special characters that are not allowed in git branch names
339 // Keep only alphanumeric characters and hyphens
340 var result strings.Builder
341 for _, char := range cleanTitle {
342 if (char >= 'a' && char <= 'z') || (char >= '0' && char <= '9') || char == '-' {
343 result.WriteRune(char)
344 }
345 }
346 cleanTitle = result.String()
347
348 // Remove consecutive hyphens
349 for strings.Contains(cleanTitle, "--") {
350 cleanTitle = strings.ReplaceAll(cleanTitle, "--", "-")
351 }
352
353 // Remove leading and trailing hyphens
354 cleanTitle = strings.Trim(cleanTitle, "-")
355
356 // Limit length
357 if len(cleanTitle) > 50 {
358 cleanTitle = cleanTitle[:50]
359 // Ensure we don't end with a hyphen after truncation
360 cleanTitle = strings.TrimSuffix(cleanTitle, "-")
361 }
362
363 return fmt.Sprintf("task/%s-%s", task.ID, cleanTitle)
364}
365
366// formatSolution formats the solution for the pull request
367func (a *Agent) formatSolution(task *tm.Task, solution string) string {
368 var content strings.Builder
369
370 content.WriteString(fmt.Sprintf("# Task Solution: %s\n\n", task.Title))
371 content.WriteString(fmt.Sprintf("**Task ID:** %s\n", task.ID))
372 content.WriteString(fmt.Sprintf("**Agent:** %s (%s)\n", a.Config.Name, a.Config.Role))
373 content.WriteString(fmt.Sprintf("**Completed:** %s\n\n", time.Now().Format("2006-01-02 15:04:05")))
374
375 content.WriteString("## Task Description\n\n")
376 content.WriteString(task.Description)
377 content.WriteString("\n\n")
378
379 content.WriteString("## Solution\n\n")
380 content.WriteString(solution)
381 content.WriteString("\n\n")
382
383 content.WriteString("---\n")
384 content.WriteString("*This solution was generated by AI Agent*\n")
385
386 return content.String()
387}
388
// intPtr returns a pointer to a fresh copy of i, which is convenient for
// filling optional *int fields from a literal.
func intPtr(i int) *int {
	v := i
	return &v
}
393
// float64Ptr returns a pointer to a fresh copy of f, which is convenient for
// filling optional *float64 fields from a literal.
func float64Ptr(f float64) *float64 {
	v := f
	return &v
}