blob: 405e5b717980a7b72bdc1f2921656cf10b8b9e4f [file] [log] [blame]
package agent
import (
"context"
"os"
"path/filepath"
"testing"
"time"
"github.com/iomodo/staff/git"
"github.com/iomodo/staff/llm"
"github.com/iomodo/staff/tm"
"github.com/iomodo/staff/tm/git_tm"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// MockLLMProvider is a stateless test double that is handed out wherever an
// llm.LLMProvider is needed in these tests. Its methods return fixed,
// successful results and never return an error.
type MockLLMProvider struct{}
// ChatCompletion returns a canned single-choice assistant reply. The only
// request field that is read is req.Model, which is echoed back in the
// response; ctx is ignored. The returned error is always nil.
func (m *MockLLMProvider) ChatCompletion(ctx context.Context, req llm.ChatCompletionRequest) (*llm.ChatCompletionResponse, error) {
	reply := llm.Message{
		Role:    llm.RoleAssistant,
		Content: "This is a mock response for testing purposes.",
	}
	onlyChoice := llm.ChatCompletionChoice{
		Index:        0,
		Message:      reply,
		FinishReason: "stop",
	}
	tokenUsage := llm.Usage{
		PromptTokens:     10,
		CompletionTokens: 20,
		TotalTokens:      30,
	}

	resp := &llm.ChatCompletionResponse{
		ID:       "mock-response-id",
		Model:    req.Model,
		Choices:  []llm.ChatCompletionChoice{onlyChoice},
		Usage:    tokenUsage,
		Provider: llm.ProviderOpenAI,
	}
	return resp, nil
}
// CreateEmbeddings returns a canned response holding one three-element
// embedding vector. The only request field that is read is req.Model, which
// is echoed back in the response; ctx is ignored. The returned error is
// always nil.
func (m *MockLLMProvider) CreateEmbeddings(ctx context.Context, req llm.EmbeddingRequest) (*llm.EmbeddingResponse, error) {
	vector := llm.Embedding{
		Object:    "embedding",
		Embedding: []float64{0.1, 0.2, 0.3},
		Index:     0,
	}
	tokenUsage := llm.Usage{
		PromptTokens: 5,
		TotalTokens:  5,
	}

	resp := &llm.EmbeddingResponse{
		Object:   "list",
		Data:     []llm.Embedding{vector},
		Usage:    tokenUsage,
		Model:    req.Model,
		Provider: llm.ProviderOpenAI,
	}
	return resp, nil
}
// Close is a no-op for the mock and always reports success.
func (m *MockLLMProvider) Close() error { return nil }
// MockLLMFactory is a stateless test factory whose CreateProvider always
// yields a MockLLMProvider.
// NOTE(review): intended to satisfy the project's ProviderFactory interface;
// that interface is not declared in this file — confirm against package llm.
type MockLLMFactory struct{}
// CreateProvider ignores config and returns a fresh MockLLMProvider with a
// nil error.
func (f *MockLLMFactory) CreateProvider(config llm.Config) (llm.LLMProvider, error) {
	mock := new(MockLLMProvider)
	return mock, nil
}
// SupportsProvider reports whether provider is llm.ProviderOpenAI, the only
// provider this mock factory claims to support.
func (f *MockLLMFactory) SupportsProvider(provider llm.Provider) bool {
	switch provider {
	case llm.ProviderOpenAI:
		return true
	default:
		return false
	}
}
// setupTestAgent builds an Agent for unit tests, backed by MockLLMProvider and
// by directories created under a per-test temporary directory.
//
// It creates "tasks", "workspace" and "code-repo" sub-directories, calls Init
// and SetUserConfig on a git interface rooted at the tasks directory, and
// passes a git-backed task manager to the agent configuration. The Agent is
// assembled as a struct literal so the mock provider can be injected directly.
//
// The returned function stops the agent; callers should defer it. The
// temporary directory comes from t.TempDir, so it is removed automatically
// when the test finishes — including when one of the require checks below
// aborts setup before the cleanup function could be returned.
func setupTestAgent(t *testing.T) (*Agent, func()) {
	// Report setup failures at the calling test's line.
	t.Helper()

	// Create temporary directories
	tempDir := t.TempDir()
	tasksDir := filepath.Join(tempDir, "tasks")
	workspaceDir := filepath.Join(tempDir, "workspace")
	codeRepoDir := filepath.Join(tempDir, "code-repo")
	for _, dir := range []string{tasksDir, workspaceDir, codeRepoDir} {
		require.NoError(t, os.MkdirAll(dir, 0755))
	}

	// Initialize git repositories
	gitInterface := git.DefaultGit(tasksDir)
	ctx := context.Background()
	require.NoError(t, gitInterface.Init(ctx, tasksDir))

	// Set git user config
	userConfig := git.UserConfig{
		Name:  "Test User",
		Email: "test@example.com",
	}
	require.NoError(t, gitInterface.SetUserConfig(ctx, userConfig))

	// Create task manager
	taskManager := git_tm.NewGitTaskManager(gitInterface, tasksDir)

	// Create LLM config (using a mock configuration)
	llmConfig := llm.Config{
		Provider: llm.ProviderOpenAI,
		APIKey:   "test-key",
		BaseURL:  "https://api.openai.com/v1",
		Timeout:  30 * time.Second,
	}

	// Create agent config
	config := AgentConfig{
		Name:         "test-agent",
		Role:         "Test Engineer",
		GitUsername:  "test-agent",
		GitEmail:     "test-agent@test.com",
		WorkingDir:   workspaceDir,
		LLMProvider:  llm.ProviderOpenAI,
		LLMModel:     "gpt-3.5-turbo",
		LLMConfig:    llmConfig,
		SystemPrompt: "You are a test agent. Provide simple, clear solutions.",
		TaskManager:  taskManager,
		GitRepoPath:  codeRepoDir,
		GitRemote:    "origin",
		GitBranch:    "main",
	}

	// Create agent with mock LLM provider
	agent := &Agent{
		Config:       config,
		llmProvider:  &MockLLMProvider{},
		gitInterface: git.DefaultGit(codeRepoDir),
		ctx:          context.Background(),
		cancel:       func() {},
	}

	// Directory removal is handled by t.TempDir; only the agent needs stopping.
	cleanup := func() {
		agent.Stop()
	}
	return agent, cleanup
}
// TestNewAgent checks that the shared fixture yields a non-nil agent carrying
// the expected name and role in its configuration.
func TestNewAgent(t *testing.T) {
	a, teardown := setupTestAgent(t)
	defer teardown()

	is := assert.New(t)
	is.NotNil(a)
	is.Equal("test-agent", a.Config.Name)
	is.Equal("Test Engineer", a.Config.Role)
}
// TestValidateConfig checks that validateConfig accepts a fully populated
// configuration and returns an error when any single field exercised below is
// left empty.
//
// Each failing case is derived from the same valid configuration by blanking
// exactly one field, so a case cannot accidentally test a different field
// than its name says. The "empty system prompt" case previously kept
// SystemPrompt set and omitted TaskManager, which made it a duplicate of
// "nil task manager".
func TestValidateConfig(t *testing.T) {
	// newValidConfig returns a fresh config with every checked field set.
	newValidConfig := func() AgentConfig {
		return AgentConfig{
			Name:         "test",
			Role:         "test",
			WorkingDir:   "/tmp",
			SystemPrompt: "test",
			TaskManager:  &git_tm.GitTaskManager{},
			GitRepoPath:  "/tmp",
		}
	}

	// Test valid config
	assert.NoError(t, validateConfig(newValidConfig()))

	// Test invalid configs
	testCases := []struct {
		name  string
		blank func(c *AgentConfig)
	}{
		{"empty name", func(c *AgentConfig) { c.Name = "" }},
		{"empty role", func(c *AgentConfig) { c.Role = "" }},
		{"empty working dir", func(c *AgentConfig) { c.WorkingDir = "" }},
		{"empty system prompt", func(c *AgentConfig) { c.SystemPrompt = "" }},
		{"nil task manager", func(c *AgentConfig) { c.TaskManager = nil }},
		{"empty git repo path", func(c *AgentConfig) { c.GitRepoPath = "" }},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			config := newValidConfig()
			tc.blank(&config)
			assert.Error(t, validateConfig(config))
		})
	}
}
// TestGenerateBranchName checks that the generated branch name contains both
// the task ID and a lower-case, hyphenated form of the task title.
func TestGenerateBranchName(t *testing.T) {
	a, teardown := setupTestAgent(t)
	defer teardown()

	got := a.generateBranchName(&tm.Task{
		ID:    "task-123",
		Title: "Implement User Authentication",
	})

	for _, fragment := range []string{
		"task-123",
		"implement-user-authentication",
	} {
		assert.Contains(t, got, fragment)
	}
}
// TestBuildTaskPrompt checks that the prompt built for a task mentions the
// task's ID, title, description and the string "high" for a high-priority
// task.
func TestBuildTaskPrompt(t *testing.T) {
	a, teardown := setupTestAgent(t)
	defer teardown()

	// Due one week from now.
	due := time.Now().AddDate(0, 0, 7)
	got := a.buildTaskPrompt(&tm.Task{
		ID:          "task-123",
		Title:       "Test Task",
		Description: "This is a test task",
		Priority:    tm.PriorityHigh,
		DueDate:     &due,
	})

	for _, fragment := range []string{
		"task-123",
		"Test Task",
		"This is a test task",
		"high",
	} {
		assert.Contains(t, got, fragment)
	}
}
// TestFormatSolution checks that the formatted solution document contains the
// expected title line, task ID, agent attribution, section headings, the task
// description, the solution text and the trailing AI-generated notice.
func TestFormatSolution(t *testing.T) {
	a, teardown := setupTestAgent(t)
	defer teardown()

	task := &tm.Task{
		ID:          "task-123",
		Title:       "Test Task",
		Description: "This is a test task description",
		Priority:    tm.PriorityMedium,
	}
	got := a.formatSolution(task, "This is the solution to the task.")

	wantFragments := []string{
		"# Task Solution: Test Task",
		"**Task ID:** task-123",
		"**Agent:** test-agent (Test Engineer)",
		"## Task Description",
		"This is a test task description",
		"## Solution",
		"This is the solution to the task.",
		"*This solution was generated by AI Agent*",
	}
	for _, fragment := range wantFragments {
		assert.Contains(t, got, fragment)
	}
}
// TestAgentStop checks that calling Stop on the fixture agent does not panic.
// The deferred teardown calls Stop a second time afterwards.
func TestAgentStop(t *testing.T) {
	a, teardown := setupTestAgent(t)
	defer teardown()

	stop := func() {
		a.Stop()
	}
	assert.New(t).NotPanics(stop)
}
// TestGenerateBranchNameWithSpecialCharacters pins the exact branch name
// produced for titles that are plain, contain parentheses and brackets, or
// are long enough that the expected name is cut short. Each title runs as its
// own subtest.
func TestGenerateBranchNameWithSpecialCharacters(t *testing.T) {
	a, teardown := setupTestAgent(t)
	defer teardown()

	const taskID = "task-123"
	cases := []struct {
		title string
		want  string
	}{
		{"Simple Task", "task/task-123-simple-task"},
		{"Task with (parentheses) and [brackets]", "task/task-123-task-with-parentheses-and-brackets"},
		{
			"Very Long Task Title That Should Be Truncated Because It Exceeds The Maximum Length Allowed For Branch Names",
			"task/task-123-very-long-task-title-that-should-be-truncated-beca",
		},
	}

	for i := range cases {
		c := cases[i]
		t.Run(c.title, func(t *testing.T) {
			got := a.generateBranchName(&tm.Task{ID: taskID, Title: c.title})
			assert.Equal(t, c.want, got)
		})
	}
}
// TestProcessTaskWithLLM checks that processing a task through the mock
// provider succeeds and that the returned solution contains the mock's
// response text.
func TestProcessTaskWithLLM(t *testing.T) {
	a, teardown := setupTestAgent(t)
	defer teardown()

	got, err := a.processTaskWithLLM(&tm.Task{
		ID:          "task-123",
		Title:       "Test Task",
		Description: "This is a test task",
		Priority:    tm.PriorityHigh,
	})

	is := assert.New(t)
	is.NoError(err)
	is.Contains(got, "mock response")
}
func TestMockLLMProvider(t *testing.T) {
mockProvider := &MockLLMProvider{}
// Test ChatCompletion
req := llm.ChatCompletionRequest{
Model: "gpt-3.5-turbo",
Messages: []llm.Message{
{Role: llm.RoleUser, Content: "Hello"},
},
}
resp, err := mockProvider.ChatCompletion(context.Background(), req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, "gpt-3.5-turbo", resp.Model)
assert.Len(t, resp.Choices, 1)
assert.Contains(t, resp.Choices[0].Message.Content, "mock response")
// Test CreateEmbeddings
embedReq := llm.EmbeddingRequest{
Input: "test",
Model: "text-embedding-ada-002",
}
embedResp, err := mockProvider.CreateEmbeddings(context.Background(), embedReq)
assert.NoError(t, err)
assert.NotNil(t, embedResp)
assert.Len(t, embedResp.Data, 1)
assert.Len(t, embedResp.Data[0].Embedding, 3)
// Test Close
err = mockProvider.Close()
assert.NoError(t, err)
}