blob: 6de4784d3ba1c68b32fff8a5f8f53abb2562ec71 [file] [log] [blame]
package agent
import (
"context"
"fmt"
"log"
"os"
"path/filepath"
"strings"
"time"
"github.com/iomodo/staff/git"
"github.com/iomodo/staff/llm"
"github.com/iomodo/staff/tm"
)
// AgentConfig contains configuration for the agent.
//
// NewAgent requires Name, Role, WorkingDir, SystemPrompt, TaskManager and
// GitRepoPath to be non-empty/non-nil (see validateConfig); no other field
// is validated.
type AgentConfig struct {
	// Name is the owner name used to look up the agent's tasks; it also
	// appears in log lines and in generated solution files.
	Name string
	// Role is shown in log lines and in generated solution files.
	Role string
	// GitUsername and GitEmail are the author identity of solution commits.
	GitUsername string
	GitEmail    string
	// WorkingDir is the directory solution files are written to.
	WorkingDir string

	// LLM Configuration

	// LLMProvider is not read anywhere in this file; NewAgent builds the
	// provider from LLMConfig. NOTE(review): confirm whether it is still used.
	LLMProvider llm.Provider
	// LLMModel is sent as the model name on every chat completion request.
	LLMModel string
	// LLMConfig is passed to llm.CreateProvider to construct the provider.
	LLMConfig llm.Config

	// System prompt for the agent; sent as the system message of every
	// chat completion request.
	SystemPrompt string

	// Task Manager Configuration

	// TaskManager is used to list, start and complete the agent's tasks.
	TaskManager tm.TaskManager

	// Git Configuration

	// GitRepoPath is the repository the agent initializes and works in.
	GitRepoPath string
	// GitRemote is the URL added as "origin" when no origin exists and
	// Gerrit is disabled.
	GitRemote string
	// GitBranch is the branch checked out at start-up and used as the
	// pull-request base / Gerrit refs/for target. Optional for validation.
	GitBranch string

	// Gerrit Configuration

	// GerritEnabled switches the agent to a Gerrit pull-request provider,
	// a Gerrit-derived origin URL and refs/for pushes.
	GerritEnabled bool
	// GerritConfig holds the Gerrit connection settings; only consulted
	// by NewAgent/initializeGit when GerritEnabled is true, although its
	// Project is also used as the repo name for non-Gerrit pull requests.
	GerritConfig GerritConfig
}
// GerritConfig holds configuration for Gerrit operations.
//
// NewAgent copies these settings into the git package's Gerrit
// pull-request provider, and initializeGit derives the "origin" remote
// URL from them.
type GerritConfig struct {
	// Username is the Gerrit account name; it is also used as the SSH user
	// when the remote URL is derived in SSH form.
	Username string
	Password string // Can be HTTP password or API token
	// BaseURL is the Gerrit server URL the remote URL is derived from.
	BaseURL string
	// Project is the Gerrit project name (also the repository name in the
	// derived remote URL).
	Project string
}
// Agent represents an AI agent that can process tasks.
//
// An Agent is created by NewAgent, driven by Run and shut down by Stop.
type Agent struct {
	// Config is the configuration the agent was created with.
	Config AgentConfig
	// llmProvider answers chat completion requests; Stop closes it.
	llmProvider llm.LLMProvider
	// gitInterface performs repository operations; it carries a Gerrit
	// pull-request provider when Config.GerritEnabled is true and is
	// git.DefaultGit otherwise.
	gitInterface git.GitInterface
	// ctx is cancelled by Stop (via cancel), which ends Run's loop.
	// NOTE(review): storing a context in a struct is discouraged in Go;
	// it is kept because Run and Stop take no context parameter.
	ctx    context.Context
	cancel context.CancelFunc
}
// NewAgent validates config and builds an Agent from it.
//
// The LLM provider is created from config.LLMConfig. The git backend is a
// git interface wired to a Gerrit pull-request provider when
// config.GerritEnabled is true, and git.DefaultGit otherwise. The returned
// agent owns a cancellable context that Stop cancels.
func NewAgent(config AgentConfig) (*Agent, error) {
	if err := validateConfig(config); err != nil {
		return nil, fmt.Errorf("invalid config: %w", err)
	}

	provider, err := llm.CreateProvider(config.LLMConfig)
	if err != nil {
		return nil, fmt.Errorf("failed to create LLM provider: %w", err)
	}

	var repo git.GitInterface
	switch {
	case config.GerritEnabled:
		gc := config.GerritConfig
		prProvider := git.NewGerritPullRequestProvider(gc.Project, git.GerritConfig{
			Username:   gc.Username,
			Password:   gc.Password,
			BaseURL:    gc.BaseURL,
			HTTPClient: nil, // nil: the git package falls back to its default client
		})
		repo = git.NewGitWithPullRequests(config.GitRepoPath, git.GitConfig{
			Timeout:             30 * time.Second,
			PullRequestProvider: prProvider,
		}, prProvider)
	default:
		// Plain git interface (GitHub-style pull requests).
		repo = git.DefaultGit(config.GitRepoPath)
	}

	ctx, cancel := context.WithCancel(context.Background())
	return &Agent{
		Config:       config,
		llmProvider:  provider,
		gitInterface: repo,
		ctx:          ctx,
		cancel:       cancel,
	}, nil
}
// validateConfig reports the first required AgentConfig field that is
// unset, checking Name, Role, WorkingDir, SystemPrompt, TaskManager and
// GitRepoPath in that order, and returns nil when all of them are set.
//
// NOTE(review): GerritConfig (BaseURL, Project) and GitBranch are not
// checked, although the Gerrit remote URL and refs/for push are built
// from them when GerritEnabled is true.
func validateConfig(config AgentConfig) error {
	switch {
	case config.Name == "":
		return fmt.Errorf("agent name is required")
	case config.Role == "":
		return fmt.Errorf("agent role is required")
	case config.WorkingDir == "":
		return fmt.Errorf("working directory is required")
	case config.SystemPrompt == "":
		return fmt.Errorf("system prompt is required")
	case config.TaskManager == nil:
		return fmt.Errorf("task manager is required")
	case config.GitRepoPath == "":
		return fmt.Errorf("git repository path is required")
	default:
		return nil
	}
}
// Run starts the agent's main loop and blocks until the agent is stopped.
//
// It first prepares the repository (initializeGit), then processes tasks
// one after another. A failed iteration is logged and followed by a 30s
// back-off. Run returns a wrapped error if git initialization fails, and
// otherwise returns the agent context's error (context.Canceled) once Stop
// has been called.
func (a *Agent) Run() error {
	log.Printf("Starting agent %s (%s)", a.Config.Name, a.Config.Role)
	defer log.Printf("Agent %s stopped", a.Config.Name)

	if err := a.initializeGit(); err != nil {
		return fmt.Errorf("failed to initialize git: %w", err)
	}

	const errorBackoff = 30 * time.Second
	for {
		// Exit promptly if Stop was called since the last iteration.
		if err := a.ctx.Err(); err != nil {
			return err
		}
		if err := a.processNextTask(); err != nil {
			log.Printf("Error processing task: %v", err)
			// Back off before retrying. Unlike time.Sleep, this wakes up as
			// soon as Stop cancels the context, so shutdown is not delayed by
			// the full back-off period.
			select {
			case <-a.ctx.Done():
				return a.ctx.Err()
			case <-time.After(errorBackoff):
			}
		}
	}
}
// Stop cancels the agent's context, which makes Run's loop exit, and then
// closes the LLM provider when one is set.
//
// NOTE(review): Close may run while Run is still inside a chat completion
// request; confirm the provider tolerates that. Any value Close returns is
// discarded here.
func (a *Agent) Stop() {
	log.Printf("Stopping agent %s", a.Config.Name)
	a.cancel()

	provider := a.llmProvider
	if provider == nil {
		return
	}
	provider.Close()
}
// initializeGit prepares a.Config.GitRepoPath for the agent:
//   - initializes a repository there if none exists;
//   - adds an "origin" remote when missing, derived from GerritConfig when
//     GerritEnabled and taken from a.Config.GitRemote otherwise;
//   - when a.Config.GitBranch is set, switches to it, creating the branch
//     first if checkout reports that it does not exist.
func (a *Agent) initializeGit() error {
	ctx := context.Background()

	isRepo, err := a.gitInterface.IsRepository(ctx, a.Config.GitRepoPath)
	if err != nil {
		return fmt.Errorf("failed to check repository: %w", err)
	}
	if !isRepo {
		if err := a.gitInterface.Init(ctx, a.Config.GitRepoPath); err != nil {
			return fmt.Errorf("failed to initialize repository: %w", err)
		}
	}

	remotes, err := a.gitInterface.ListRemotes(ctx)
	if err != nil {
		return fmt.Errorf("failed to list remotes: %w", err)
	}
	originExists := false
	for _, remote := range remotes {
		if remote.Name == "origin" {
			originExists = true
			break
		}
	}
	if !originExists {
		remoteURL := a.Config.GitRemote
		if a.Config.GerritEnabled {
			gc := a.Config.GerritConfig
			remoteURL = gerritRemoteURL(gc.BaseURL, gc.Username, gc.Project)
		}
		if err := a.gitInterface.AddRemote(ctx, "origin", remoteURL); err != nil {
			return fmt.Errorf("failed to add remote origin: %w", err)
		}
	}

	branch := a.Config.GitBranch
	if branch == "" {
		return nil
	}
	currentBranch, err := a.gitInterface.GetCurrentBranch(ctx)
	if err != nil {
		return fmt.Errorf("failed to get current branch: %w", err)
	}
	if currentBranch == branch {
		log.Printf("Already on target branch: %s", branch)
		return nil
	}

	err = a.gitInterface.Checkout(ctx, branch)
	if err == nil {
		return nil
	}
	// Only create the branch when the failure says it does not exist; any
	// other checkout failure (dirty tree, permissions, ...) is surfaced.
	if !isBranchNotFoundErr(err) {
		return fmt.Errorf("failed to checkout branch %s: %w", branch, err)
	}
	if err := a.gitInterface.CreateBranch(ctx, branch, ""); err != nil {
		return fmt.Errorf("failed to create branch %s: %w", branch, err)
	}
	// createPullRequest follows CreateBranch with an explicit Checkout, so
	// CreateBranch is not relied upon to move HEAD; switch explicitly so the
	// agent really ends up on the target branch.
	if err := a.gitInterface.Checkout(ctx, branch); err != nil {
		return fmt.Errorf("failed to checkout branch %s: %w", branch, err)
	}
	return nil
}

// isBranchNotFoundErr reports whether a checkout error message looks like
// git complaining that the requested branch/ref does not exist.
// NOTE(review): "not found" and "pathspec" are broad markers; narrow them if
// unrelated failures start being treated as a missing branch.
func isBranchNotFoundErr(err error) bool {
	msg := err.Error()
	for _, marker := range []string{
		"did not match any file(s) known to git",
		"not found",
		"unknown revision",
		"reference is not a tree",
		"pathspec",
		"fatal: invalid reference",
	} {
		if strings.Contains(msg, marker) {
			return true
		}
	}
	return false
}

// gerritRemoteURL derives the clone URL of project on a Gerrit server.
//
// An http:// or https:// baseURL yields "<baseURL>/<project>.git". Anything
// else is treated as an SSH host, optionally written with an "ssh://"
// scheme, and yields "ssh://[<username>@]<host>:29418/<project>.git".
// Trailing slashes on baseURL are ignored.
func gerritRemoteURL(baseURL, username, project string) string {
	base := strings.TrimRight(baseURL, "/")
	if strings.HasPrefix(base, "https://") || strings.HasPrefix(base, "http://") {
		return fmt.Sprintf("%s/%s.git", base, project)
	}
	// Drop an explicit scheme so it is not duplicated in the SSH URL.
	host := strings.TrimPrefix(base, "ssh://")
	if username == "" {
		return fmt.Sprintf("ssh://%s:29418/%s.git", host, project)
	}
	return fmt.Sprintf("ssh://%s@%s:29418/%s.git", username, host, project)
}
// processNextTask picks the first task owned by this agent whose status is
// tm.StatusToDo, starts it, asks the LLM for a solution, publishes it via
// createPullRequest and marks the task complete.
//
// When no task is ready it waits up to 60 seconds, returning early if the
// agent is stopped, and returns nil so that Run polls again.
//
// NOTE(review): only the tasks returned by GetTasksByOwner(…, 0, 10) are
// inspected (presumably the first 10), so a ready task beyond that window
// is not picked up. A task that fails after StartTask is left started and,
// presumably because StartTask moves it out of ToDo, is never retried.
func (a *Agent) processNextTask() error {
	ctx := context.Background()

	taskList, err := a.Config.TaskManager.GetTasksByOwner(ctx, a.Config.Name, 0, 10)
	if err != nil {
		return fmt.Errorf("failed to get tasks: %w", err)
	}

	var taskToProcess *tm.Task
	for _, task := range taskList.Tasks {
		if task.Status == tm.StatusToDo {
			taskToProcess = task
			break
		}
	}

	if taskToProcess == nil {
		// Idle wait. Unlike time.Sleep, this returns as soon as Stop cancels
		// the agent, so shutdown is not delayed by up to a minute.
		select {
		case <-a.ctx.Done():
		case <-time.After(60 * time.Second):
		}
		return nil
	}

	log.Printf("Processing task: %s - %s", taskToProcess.ID, taskToProcess.Title)

	startedTask, err := a.Config.TaskManager.StartTask(ctx, taskToProcess.ID)
	if err != nil {
		return fmt.Errorf("failed to start task: %w", err)
	}

	solution, err := a.processTaskWithLLM(startedTask)
	if err != nil {
		// Wrap instead of logging here: Run already logs the returned error.
		return fmt.Errorf("failed to process task with LLM: %w", err)
	}

	if err := a.createPullRequest(startedTask, solution); err != nil {
		return fmt.Errorf("failed to create pull request: %w", err)
	}

	if _, err := a.Config.TaskManager.CompleteTask(ctx, startedTask.ID); err != nil {
		return fmt.Errorf("failed to complete task: %w", err)
	}

	log.Printf("Successfully completed task: %s", startedTask.ID)
	return nil
}
// processTaskWithLLM asks the configured LLM to solve task and returns the
// content of the first choice in the response.
//
// The request contains the agent's system prompt followed by the prompt
// from buildTaskPrompt. It uses a.Config.LLMModel with MaxTokens 4000 and
// Temperature 0.7. An error is returned when the call fails or the response
// has no choices.
func (a *Agent) processTaskWithLLM(task *tm.Task) (string, error) {
	messages := []llm.Message{
		{Role: llm.RoleSystem, Content: a.Config.SystemPrompt},
		{Role: llm.RoleUser, Content: a.buildTaskPrompt(task)},
	}
	req := llm.ChatCompletionRequest{
		Model:       a.Config.LLMModel,
		Messages:    messages,
		MaxTokens:   intPtr(4000),
		Temperature: float64Ptr(0.7),
	}

	resp, err := a.llmProvider.ChatCompletion(context.Background(), req)
	if err != nil {
		return "", fmt.Errorf("LLM chat completion failed: %w", err)
	}
	if len(resp.Choices) == 0 {
		return "", fmt.Errorf("no response from LLM")
	}
	first := resp.Choices[0]
	return first.Message.Content, nil
}
// buildTaskPrompt renders task as the user message sent to the LLM.
//
// The prompt lists the task ID, title and priority, followed by the
// description and due date (YYYY-MM-DD) when present. It ends with fixed
// instructions asking for a complete, appropriately formatted solution.
func (a *Agent) buildTaskPrompt(task *tm.Task) string {
	var b strings.Builder
	fmt.Fprintf(&b, "Task ID: %s\n", task.ID)
	fmt.Fprintf(&b, "Title: %s\n", task.Title)
	fmt.Fprintf(&b, "Priority: %s\n", task.Priority)
	if task.Description != "" {
		fmt.Fprintf(&b, "Description: %s\n", task.Description)
	}
	if task.DueDate != nil {
		fmt.Fprintf(&b, "Due Date: %s\n", task.DueDate.Format("2006-01-02"))
	}
	b.WriteString("\nPlease provide a detailed solution for this task. " +
		"Include any code, documentation, or other deliverables as needed. " +
		"Format your response appropriately for the type of task.")
	return b.String()
}
// createPullRequest publishes solution for task for review.
//
// Steps:
//  1. Create and check out a task branch (generateBranchName).
//  2. Write the formatted solution to WorkingDir/task-<ID>-solution.md.
//  3. Stage the file and commit it as GitUsername/GitEmail.
//  4. Publish the commit. With GerritEnabled, push it to Gerrit's
//     refs/for/<GitBranch>. Otherwise, push the branch to origin and open a
//     pull request against GitBranch.
//
// On return (success or failure after the branch switch), the repository is
// checked out back to the branch that was current on entry.
func (a *Agent) createPullRequest(task *tm.Task, solution string) error {
	ctx := context.Background()

	// Remember the starting branch (normally Config.GitBranch, set up by
	// initializeGit) and return to it afterwards. Otherwise the next task's
	// branch would be cut from this task's branch, and its review would also
	// contain this task's commit.
	baseBranch, err := a.gitInterface.GetCurrentBranch(ctx)
	if err != nil {
		return fmt.Errorf("failed to get current branch: %w", err)
	}

	branchName := a.generateBranchName(task)
	if err := a.gitInterface.CreateBranch(ctx, branchName, ""); err != nil {
		return fmt.Errorf("failed to create branch: %w", err)
	}
	if err := a.gitInterface.Checkout(ctx, branchName); err != nil {
		return fmt.Errorf("failed to checkout branch: %w", err)
	}
	defer func() {
		// Best effort: failing to switch back must not mask the push result.
		if err := a.gitInterface.Checkout(ctx, baseBranch); err != nil {
			log.Printf("Failed to switch back to branch %s: %v", baseBranch, err)
		}
	}()

	solutionPath := filepath.Join(a.Config.WorkingDir, fmt.Sprintf("task-%s-solution.md", task.ID))
	solutionContent := a.formatSolution(task, solution)
	if err := os.WriteFile(solutionPath, []byte(solutionContent), 0644); err != nil {
		return fmt.Errorf("failed to write solution file: %w", err)
	}
	if err := a.gitInterface.Add(ctx, []string{solutionPath}); err != nil {
		return fmt.Errorf("failed to add solution file: %w", err)
	}

	// The same description is used in the commit message and the PR body.
	description := a.formatPullRequestDescription(task, solution)
	commitMessage := fmt.Sprintf("feat: Complete task %s - %s\n\n%s", task.ID, task.Title, description)
	if err := a.gitInterface.Commit(ctx, commitMessage, git.CommitOptions{
		Author: &git.Author{
			Name:  a.Config.GitUsername,
			Email: a.Config.GitEmail,
			Time:  time.Now(),
		},
	}); err != nil {
		return fmt.Errorf("failed to commit solution: %w", err)
	}

	if a.Config.GerritEnabled {
		gerritRef := fmt.Sprintf("refs/for/%s", a.Config.GitBranch)
		// Gerrit creates a change when a commit is pushed to the magic
		// refs/for/<branch> ref. No local ref of that name exists, so the
		// refspec must name HEAD as its source.
		// NOTE(review): assumes Push forwards this string as the git refspec.
		if err := a.gitInterface.Push(ctx, "origin", "HEAD:"+gerritRef, git.PushOptions{}); err != nil {
			return fmt.Errorf("failed to push to Gerrit: %w", err)
		}
		log.Printf("Created Gerrit change for task %s by pushing to %s", task.ID, gerritRef)
		return nil
	}

	// GitHub-style flow: push the branch, then open a pull request.
	if err := a.gitInterface.Push(ctx, "origin", branchName, git.PushOptions{SetUpstream: true}); err != nil {
		return fmt.Errorf("failed to push branch: %w", err)
	}
	prOptions := git.PullRequestOptions{
		Title:       fmt.Sprintf("Complete task %s: %s", task.ID, task.Title),
		Description: description,
		BaseBranch:  a.Config.GitBranch,
		HeadBranch:  branchName,
		// NOTE(review): repo names come from GerritConfig.Project even on
		// this non-Gerrit path; confirm that is intended.
		BaseRepo: a.Config.GerritConfig.Project,
		HeadRepo: a.Config.GerritConfig.Project,
	}
	pr, err := a.gitInterface.CreatePullRequest(ctx, prOptions)
	if err != nil {
		return fmt.Errorf("failed to create pull request: %w", err)
	}
	log.Printf("Created pull request for task %s: %s (ID: %s)", task.ID, pr.Title, pr.ID)
	return nil
}
// generateBranchName returns the branch name used for task's work, in the
// form "task/<ID>-<slug>".
//
// The slug is built from the lowercased title:
//   - spaces and hyphens become single hyphens;
//   - every other character that is not an ASCII letter or digit is dropped;
//   - leading and trailing hyphens are removed;
//   - the result is capped at 50 bytes.
//
// If the title yields an empty slug, the name is just "task/<ID>" instead of
// ending in a dangling hyphen.
//
// NOTE(review): task.ID is used verbatim; an ID containing characters git
// rejects in ref names would produce an invalid branch name.
func (a *Agent) generateBranchName(task *tm.Task) string {
	const maxSlugLen = 50

	var slug strings.Builder
	// Single pass: a separator becomes a hyphen only once another kept
	// character follows. This collapses runs of separators and drops
	// leading/trailing ones without extra passes over the string.
	pendingHyphen := false
	for _, r := range strings.ToLower(task.Title) {
		switch {
		case (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9'):
			if pendingHyphen && slug.Len() > 0 {
				slug.WriteByte('-')
			}
			pendingHyphen = false
			slug.WriteRune(r)
		case r == ' ' || r == '-':
			pendingHyphen = true
		}
		// Any other rune is dropped without separating the surrounding words.
	}

	name := slug.String()
	if len(name) > maxSlugLen {
		// Only ASCII remains, so byte truncation cannot split a rune.
		name = strings.TrimSuffix(name[:maxSlugLen], "-")
	}
	if name == "" {
		return fmt.Sprintf("task/%s", task.ID)
	}
	return fmt.Sprintf("task/%s-%s", task.ID, name)
}
// formatSolution renders the markdown document committed for task.
//
// The document contains a title, the task ID, the agent's name and role, and
// the completion timestamp (local time, "2006-01-02 15:04:05" layout). These
// are followed by the task description, the LLM solution and a closing
// attribution line.
func (a *Agent) formatSolution(task *tm.Task, solution string) string {
	var doc strings.Builder
	fmt.Fprintf(&doc, "# Task Solution: %s\n\n", task.Title)
	fmt.Fprintf(&doc, "**Task ID:** %s\n", task.ID)
	fmt.Fprintf(&doc, "**Agent:** %s (%s)\n", a.Config.Name, a.Config.Role)
	fmt.Fprintf(&doc, "**Completed:** %s\n\n", time.Now().Format("2006-01-02 15:04:05"))

	doc.WriteString("## Task Description\n\n" + task.Description + "\n\n")
	doc.WriteString("## Solution\n\n" + solution + "\n\n")
	doc.WriteString("---\n*This solution was generated by AI Agent*\n")
	return doc.String()
}
// formatPullRequestDescription renders the markdown body shared by the
// solution commit message and the pull request description.
//
// The body lists the task ID, title and priority, then the description and
// due date (YYYY-MM-DD) when present, and finally the LLM solution.
func (a *Agent) formatPullRequestDescription(task *tm.Task, solution string) string {
	var body strings.Builder
	fmt.Fprintf(&body, "**Task ID:** %s\n**Title:** %s\n**Priority:** %s\n",
		task.ID, task.Title, task.Priority)
	if task.Description != "" {
		fmt.Fprintf(&body, "**Description:** %s\n", task.Description)
	}
	if due := task.DueDate; due != nil {
		fmt.Fprintf(&body, "**Due Date:** %s\n", due.Format("2006-01-02"))
	}
	body.WriteString("\n**Solution:**\n\n")
	body.WriteString(solution)
	return body.String()
}
// Pointer helpers for filling optional request fields such as MaxTokens and
// Temperature.

// intPtr returns a pointer to a newly allocated int holding i.
func intPtr(i int) *int {
	p := new(int)
	*p = i
	return p
}

// float64Ptr returns a pointer to a newly allocated float64 holding f.
func float64Ptr(f float64) *float64 {
	p := new(float64)
	*p = f
	return p
}