blob: e3aa7668f0778f29b513ec499967eb685a267271 [file] [log] [blame]
iomodob67a3762025-07-25 20:27:04 +04001package agent
2
iomodo76f9a2d2025-07-26 12:14:40 +04003import (
4 "context"
5 "fmt"
6 "log"
7 "os"
8 "path/filepath"
9 "strings"
10 "time"
11
12 "github.com/iomodo/staff/git"
13 "github.com/iomodo/staff/llm"
14 "github.com/iomodo/staff/tm"
15)
16
17// AgentConfig contains configuration for the agent
iomodob67a3762025-07-25 20:27:04 +040018type AgentConfig struct {
19 Name string
20 Role string
21 GitUsername string
22 GitEmail string
23 WorkingDir string
iomodo76f9a2d2025-07-26 12:14:40 +040024
25 // LLM Configuration
26 LLMProvider llm.Provider
27 LLMModel string
28 LLMConfig llm.Config
29
30 // System prompt for the agent
31 SystemPrompt string
32
33 // Task Manager Configuration
34 TaskManager tm.TaskManager
35
36 // Git Configuration
37 GitRepoPath string
38 GitRemote string
39 GitBranch string
iomodob67a3762025-07-25 20:27:04 +040040}
41
iomodo76f9a2d2025-07-26 12:14:40 +040042// Agent represents an AI agent that can process tasks
iomodob67a3762025-07-25 20:27:04 +040043type Agent struct {
iomodo76f9a2d2025-07-26 12:14:40 +040044 Config AgentConfig
45 llmProvider llm.LLMProvider
46 gitInterface git.GitInterface
47 ctx context.Context
48 cancel context.CancelFunc
iomodob67a3762025-07-25 20:27:04 +040049}
50
iomodo76f9a2d2025-07-26 12:14:40 +040051// NewAgent creates a new agent instance
52func NewAgent(config AgentConfig) (*Agent, error) {
53 // Validate configuration
54 if err := validateConfig(config); err != nil {
55 return nil, fmt.Errorf("invalid config: %w", err)
56 }
57
58 // Create LLM provider
59 llmProvider, err := llm.CreateProvider(config.LLMConfig)
60 if err != nil {
61 return nil, fmt.Errorf("failed to create LLM provider: %w", err)
62 }
63
64 // Create git interface
65 gitInterface := git.DefaultGit(config.GitRepoPath)
66
67 // Create context with cancellation
68 ctx, cancel := context.WithCancel(context.Background())
69
70 agent := &Agent{
71 Config: config,
72 llmProvider: llmProvider,
73 gitInterface: gitInterface,
74 ctx: ctx,
75 cancel: cancel,
76 }
77
78 return agent, nil
79}
80
81// validateConfig validates the agent configuration
82func validateConfig(config AgentConfig) error {
83 if config.Name == "" {
84 return fmt.Errorf("agent name is required")
85 }
86 if config.Role == "" {
87 return fmt.Errorf("agent role is required")
88 }
89 if config.WorkingDir == "" {
90 return fmt.Errorf("working directory is required")
91 }
92 if config.SystemPrompt == "" {
93 return fmt.Errorf("system prompt is required")
94 }
95 if config.TaskManager == nil {
96 return fmt.Errorf("task manager is required")
97 }
98 if config.GitRepoPath == "" {
99 return fmt.Errorf("git repository path is required")
100 }
101 return nil
102}
103
104// Run starts the agent's main loop
105func (a *Agent) Run() error {
106 log.Printf("Starting agent %s (%s)", a.Config.Name, a.Config.Role)
107 defer log.Printf("Agent %s stopped", a.Config.Name)
108
109 // Initialize git repository if needed
110 if err := a.initializeGit(); err != nil {
111 return fmt.Errorf("failed to initialize git: %w", err)
112 }
113
114 // Main agent loop
115 for {
116 select {
117 case <-a.ctx.Done():
118 return a.ctx.Err()
119 default:
120 if err := a.processNextTask(); err != nil {
121 log.Printf("Error processing task: %v", err)
122 // Continue running even if there's an error
123 time.Sleep(30 * time.Second)
124 }
125 }
iomodob67a3762025-07-25 20:27:04 +0400126 }
127}
128
iomodo76f9a2d2025-07-26 12:14:40 +0400129// Stop stops the agent
130func (a *Agent) Stop() {
131 log.Printf("Stopping agent %s", a.Config.Name)
132 a.cancel()
133 if a.llmProvider != nil {
134 a.llmProvider.Close()
135 }
136}
iomodob67a3762025-07-25 20:27:04 +0400137
iomodo76f9a2d2025-07-26 12:14:40 +0400138// initializeGit initializes the git repository
139func (a *Agent) initializeGit() error {
140 ctx := context.Background()
141
142 // Check if repository exists
143 isRepo, err := a.gitInterface.IsRepository(ctx, a.Config.GitRepoPath)
144 if err != nil {
145 return fmt.Errorf("failed to check repository: %w", err)
146 }
147
148 if !isRepo {
149 // Initialize new repository
150 if err := a.gitInterface.Init(ctx, a.Config.GitRepoPath); err != nil {
151 return fmt.Errorf("failed to initialize repository: %w", err)
152 }
153 }
154
155 // Set git user configuration
156 userConfig := git.UserConfig{
157 Name: a.Config.GitUsername,
158 Email: a.Config.GitEmail,
159 }
160 if err := a.gitInterface.SetUserConfig(ctx, userConfig); err != nil {
161 return fmt.Errorf("failed to set git user config: %w", err)
162 }
163
164 // Checkout to the specified branch
165 if a.Config.GitBranch != "" {
166 if err := a.gitInterface.Checkout(ctx, a.Config.GitBranch); err != nil {
167 // Try to create the branch if it doesn't exist
168 if err := a.gitInterface.CreateBranch(ctx, a.Config.GitBranch, ""); err != nil {
169 return fmt.Errorf("failed to create branch %s: %w", a.Config.GitBranch, err)
170 }
171 }
172 }
173
174 return nil
175}
176
177// processNextTask processes the next available task
178func (a *Agent) processNextTask() error {
179 ctx := context.Background()
180
181 // Get tasks assigned to this agent
182 taskList, err := a.Config.TaskManager.GetTasksByOwner(ctx, a.Config.Name, 0, 10)
183 if err != nil {
184 return fmt.Errorf("failed to get tasks: %w", err)
185 }
186
187 // Find a task that's ready to be worked on
188 var taskToProcess *tm.Task
189 for _, task := range taskList.Tasks {
190 if task.Status == tm.StatusToDo {
191 taskToProcess = task
192 break
193 }
194 }
195
196 if taskToProcess == nil {
197 // No tasks to process, wait a bit
198 time.Sleep(60 * time.Second)
199 return nil
200 }
201
202 log.Printf("Processing task: %s - %s", taskToProcess.ID, taskToProcess.Title)
203
204 // Start the task
205 startedTask, err := a.Config.TaskManager.StartTask(ctx, taskToProcess.ID)
206 if err != nil {
207 return fmt.Errorf("failed to start task: %w", err)
208 }
209
210 // Process the task with LLM
211 solution, err := a.processTaskWithLLM(startedTask)
212 if err != nil {
213 // Mark task as failed or retry
214 log.Printf("Failed to process task with LLM: %v", err)
215 return err
216 }
217
218 // Create PR with the solution
219 if err := a.createPullRequest(startedTask, solution); err != nil {
220 return fmt.Errorf("failed to create pull request: %w", err)
221 }
222
223 // Complete the task
224 if _, err := a.Config.TaskManager.CompleteTask(ctx, startedTask.ID); err != nil {
225 return fmt.Errorf("failed to complete task: %w", err)
226 }
227
228 log.Printf("Successfully completed task: %s", startedTask.ID)
229 return nil
230}
231
232// processTaskWithLLM sends the task to the LLM and gets a solution
233func (a *Agent) processTaskWithLLM(task *tm.Task) (string, error) {
234 ctx := context.Background()
235
236 // Prepare the prompt
237 prompt := a.buildTaskPrompt(task)
238
239 // Create chat completion request
240 req := llm.ChatCompletionRequest{
241 Model: a.Config.LLMModel,
242 Messages: []llm.Message{
243 {
244 Role: llm.RoleSystem,
245 Content: a.Config.SystemPrompt,
246 },
247 {
248 Role: llm.RoleUser,
249 Content: prompt,
250 },
251 },
252 MaxTokens: intPtr(4000),
253 Temperature: float64Ptr(0.7),
254 }
255
256 // Get response from LLM
257 resp, err := a.llmProvider.ChatCompletion(ctx, req)
258 if err != nil {
259 return "", fmt.Errorf("LLM chat completion failed: %w", err)
260 }
261
262 if len(resp.Choices) == 0 {
263 return "", fmt.Errorf("no response from LLM")
264 }
265
266 return resp.Choices[0].Message.Content, nil
267}
268
269// buildTaskPrompt builds the prompt for the LLM based on the task
270func (a *Agent) buildTaskPrompt(task *tm.Task) string {
271 var prompt strings.Builder
272
273 prompt.WriteString(fmt.Sprintf("Task ID: %s\n", task.ID))
274 prompt.WriteString(fmt.Sprintf("Title: %s\n", task.Title))
275 prompt.WriteString(fmt.Sprintf("Priority: %s\n", task.Priority))
276
277 if task.Description != "" {
278 prompt.WriteString(fmt.Sprintf("Description: %s\n", task.Description))
279 }
280
281 if task.DueDate != nil {
282 prompt.WriteString(fmt.Sprintf("Due Date: %s\n", task.DueDate.Format("2006-01-02")))
283 }
284
285 prompt.WriteString("\nPlease provide a detailed solution for this task. ")
286 prompt.WriteString("Include any code, documentation, or other deliverables as needed. ")
287 prompt.WriteString("Format your response appropriately for the type of task.")
288
289 return prompt.String()
290}
291
292// createPullRequest creates a pull request with the solution
293func (a *Agent) createPullRequest(task *tm.Task, solution string) error {
294 ctx := context.Background()
295
296 // Generate branch name
297 branchName := a.generateBranchName(task)
298
299 // Create and checkout to new branch
300 if err := a.gitInterface.CreateBranch(ctx, branchName, ""); err != nil {
301 return fmt.Errorf("failed to create branch: %w", err)
302 }
303
304 if err := a.gitInterface.Checkout(ctx, branchName); err != nil {
305 return fmt.Errorf("failed to checkout branch: %w", err)
306 }
307
308 // Create solution file
309 solutionPath := filepath.Join(a.Config.WorkingDir, fmt.Sprintf("task-%s-solution.md", task.ID))
310 solutionContent := a.formatSolution(task, solution)
311
312 if err := os.WriteFile(solutionPath, []byte(solutionContent), 0644); err != nil {
313 return fmt.Errorf("failed to write solution file: %w", err)
314 }
315
316 // Add and commit the solution
317 if err := a.gitInterface.Add(ctx, []string{solutionPath}); err != nil {
318 return fmt.Errorf("failed to add solution file: %w", err)
319 }
320
321 commitMessage := fmt.Sprintf("feat: Complete task %s - %s", task.ID, task.Title)
322 if err := a.gitInterface.Commit(ctx, commitMessage, git.CommitOptions{}); err != nil {
323 return fmt.Errorf("failed to commit solution: %w", err)
324 }
325
326 // Push the branch
327 if err := a.gitInterface.Push(ctx, "origin", branchName, git.PushOptions{SetUpstream: true}); err != nil {
328 return fmt.Errorf("failed to push branch: %w", err)
329 }
330
331 log.Printf("Created pull request for task %s on branch %s", task.ID, branchName)
332 return nil
333}
334
335// generateBranchName generates a branch name for the task
336func (a *Agent) generateBranchName(task *tm.Task) string {
337 // Clean the task title for branch name
338 cleanTitle := strings.ReplaceAll(task.Title, " ", "-")
339 cleanTitle = strings.ToLower(cleanTitle)
340
341 // Remove special characters that are not allowed in git branch names
342 // Keep only alphanumeric characters and hyphens
343 var result strings.Builder
344 for _, char := range cleanTitle {
345 if (char >= 'a' && char <= 'z') || (char >= '0' && char <= '9') || char == '-' {
346 result.WriteRune(char)
347 }
348 }
349 cleanTitle = result.String()
350
351 // Remove consecutive hyphens
352 for strings.Contains(cleanTitle, "--") {
353 cleanTitle = strings.ReplaceAll(cleanTitle, "--", "-")
354 }
355
356 // Remove leading and trailing hyphens
357 cleanTitle = strings.Trim(cleanTitle, "-")
358
359 // Limit length
360 if len(cleanTitle) > 50 {
361 cleanTitle = cleanTitle[:50]
362 // Ensure we don't end with a hyphen after truncation
363 cleanTitle = strings.TrimSuffix(cleanTitle, "-")
364 }
365
366 return fmt.Sprintf("task/%s-%s", task.ID, cleanTitle)
367}
368
369// formatSolution formats the solution for the pull request
370func (a *Agent) formatSolution(task *tm.Task, solution string) string {
371 var content strings.Builder
372
373 content.WriteString(fmt.Sprintf("# Task Solution: %s\n\n", task.Title))
374 content.WriteString(fmt.Sprintf("**Task ID:** %s\n", task.ID))
375 content.WriteString(fmt.Sprintf("**Agent:** %s (%s)\n", a.Config.Name, a.Config.Role))
376 content.WriteString(fmt.Sprintf("**Completed:** %s\n\n", time.Now().Format("2006-01-02 15:04:05")))
377
378 content.WriteString("## Task Description\n\n")
379 content.WriteString(task.Description)
380 content.WriteString("\n\n")
381
382 content.WriteString("## Solution\n\n")
383 content.WriteString(solution)
384 content.WriteString("\n\n")
385
386 content.WriteString("---\n")
387 content.WriteString("*This solution was generated by AI Agent*\n")
388
389 return content.String()
390}
391
// intPtr returns a pointer to a fresh copy of i, for filling optional *int
// fields from a literal.
func intPtr(i int) *int {
	p := new(int)
	*p = i
	return p
}
396
// float64Ptr returns a pointer to a fresh copy of f, for filling optional
// *float64 fields from a literal.
func float64Ptr(f float64) *float64 {
	p := new(float64)
	*p = f
	return p
}