blob: ade62fc3dae37b4d7053bc4d2782bbafe9d9472f [file] [log] [blame]
package subtasks
import (
	"context"
	"fmt"
	"strings"
	"testing"
	"time"

	"github.com/iomodo/staff/llm"
	"github.com/iomodo/staff/tm"
)
// MockLLMProvider is a scripted stand-in for an LLM provider in tests.
// Successive ChatCompletion calls hand out the entries of responses in order,
// and callCount tracks how many entries have been consumed so far.
type MockLLMProvider struct {
	responses []string // canned assistant replies, handed out in order
	callCount int      // index of the next reply to hand out
}

// NewMockLLMProvider builds a mock that answers successive ChatCompletion
// calls with the given responses, starting from the first one.
func NewMockLLMProvider(responses []string) *MockLLMProvider {
	mock := new(MockLLMProvider)
	mock.responses = responses
	return mock
}
// ChatCompletion returns the next scripted response as a single assistant
// choice ("stop" finish reason). The model is echoed from req and the token
// usage is fixed. Each call consumes one scripted response.
//
// Once every scripted response has been used it returns an error rather than
// a nil response with a nil error. A test that makes more LLM calls than it
// scripted then fails with a clear message instead of a nil-pointer panic
// inside the code under test.
func (m *MockLLMProvider) ChatCompletion(ctx context.Context, req llm.ChatCompletionRequest) (*llm.ChatCompletionResponse, error) {
	if m.callCount >= len(m.responses) {
		return nil, fmt.Errorf("mock LLM provider: no scripted response for call %d (only %d provided)",
			m.callCount+1, len(m.responses))
	}
	response := m.responses[m.callCount]
	m.callCount++
	return &llm.ChatCompletionResponse{
		ID:      "mock-response",
		Object:  "chat.completion",
		Created: time.Now().Unix(),
		Model:   req.Model,
		Choices: []llm.ChatCompletionChoice{
			{
				Index: 0,
				Message: llm.Message{
					Role:    llm.RoleAssistant,
					Content: response,
				},
				FinishReason: "stop",
			},
		},
		Usage: llm.Usage{
			PromptTokens:     100,
			CompletionTokens: 300,
			TotalTokens:      400,
		},
	}, nil
}
// CreateEmbeddings answers any request with a single all-zero embedding of
// length 1536. It echoes the requested model and reports fixed usage. Only
// req.Model is read; the request input is ignored and no error is returned.
func (m *MockLLMProvider) CreateEmbeddings(ctx context.Context, req llm.EmbeddingRequest) (*llm.EmbeddingResponse, error) {
	const dims = 1536
	vector := make([]float64, dims)

	resp := &llm.EmbeddingResponse{
		Object: "list",
		Model:  req.Model,
		Usage:  llm.Usage{PromptTokens: 50, TotalTokens: 50},
	}
	resp.Data = []llm.Embedding{
		{Object: "embedding", Index: 0, Embedding: vector},
	}
	return resp, nil
}
// Close has nothing to release for the mock and always returns nil.
func (*MockLLMProvider) Close() error { return nil }
// TestNewSubtaskService checks that the constructor returns a non-nil service
// that keeps the supplied LLM provider and all supplied agent roles.
func TestNewSubtaskService(t *testing.T) {
	provider := NewMockLLMProvider([]string{})
	roles := []string{"backend", "frontend", "qa"}

	svc := NewSubtaskService(provider, nil, roles)
	if svc == nil {
		t.Fatal("NewSubtaskService returned nil")
	}

	if got := svc.llmProvider; got != provider {
		t.Error("LLM provider not set correctly")
	}
	if n := len(svc.agentRoles); n != 3 {
		t.Errorf("Expected 3 agent roles, got %d", n)
	}
}
func TestShouldGenerateSubtasks(t *testing.T) {
// Test decision to generate subtasks
decisionResponse := `{
"needs_subtasks": true,
"reasoning": "Complex task requiring multiple skills",
"complexity_score": 8,
"required_skills": ["backend", "frontend", "database"]
}`
mockProvider := NewMockLLMProvider([]string{decisionResponse})
agentRoles := []string{"backend", "frontend", "qa"}
service := NewSubtaskService(mockProvider, nil, agentRoles)
// Test the parseSubtaskDecision method directly since ShouldGenerateSubtasks is used by manager
decision, err := service.parseSubtaskDecision(decisionResponse)
if err != nil {
t.Fatalf("parseSubtaskDecision failed: %v", err)
}
if !decision.NeedsSubtasks {
t.Error("Expected decision to need subtasks")
}
if decision.ComplexityScore != 8 {
t.Errorf("Expected complexity score 8, got %d", decision.ComplexityScore)
}
if len(decision.RequiredSkills) != 3 {
t.Errorf("Expected 3 required skills, got %d", len(decision.RequiredSkills))
}
}
// TestAnalyzeTaskForSubtasks runs a full analysis against a scripted LLM reply
// containing two subtasks and one agent-creation request. It checks that:
//   - an agent-creation subtask assigned to "ceo" is prepended,
//   - the original subtasks keep their fields,
//   - dependency indices are shifted by one to account for the prepended
//     subtask,
//   - the LLM-reported total hours are passed through as-is.
func TestAnalyzeTaskForSubtasks(t *testing.T) {
	jsonResponse := `{
"analysis_summary": "This task requires breaking down into multiple components",
"subtasks": [
{
"title": "Backend Development",
"description": "Implement server-side logic",
"priority": "high",
"assigned_to": "backend",
"estimated_hours": 16,
"dependencies": [],
"required_skills": ["go", "api_development"]
},
{
"title": "Frontend Development",
"description": "Build user interface",
"priority": "medium",
"assigned_to": "frontend",
"estimated_hours": 12,
"dependencies": ["0"],
"required_skills": ["react", "typescript"]
}
],
"agent_creations": [
{
"role": "security_specialist",
"skills": ["security_audit", "penetration_testing"],
"description": "Specialized agent for security tasks",
"justification": "Authentication requires security expertise"
}
],
"recommended_approach": "Start with backend then frontend",
"estimated_total_hours": 28,
"risk_assessment": "Medium complexity with API integration risks"
}`
	mockProvider := NewMockLLMProvider([]string{jsonResponse})
	agentRoles := []string{"backend", "frontend", "qa", "ceo"} // Include CEO for agent creation
	service := NewSubtaskService(mockProvider, nil, agentRoles)
	task := &tm.Task{
		ID:          "test-task-123",
		Title:       "Build authentication system",
		Description: "Implement user login and registration",
		Priority:    tm.PriorityHigh,
		Status:      tm.StatusToDo,
		CreatedAt:   time.Now(),
		UpdatedAt:   time.Now(),
	}
	analysis, err := service.AnalyzeTaskForSubtasks(context.Background(), task)
	if err != nil {
		t.Fatalf("AnalyzeTaskForSubtasks failed: %v", err)
	}
	if analysis.ParentTaskID != task.ID {
		t.Errorf("Expected parent task ID %s, got %s", task.ID, analysis.ParentTaskID)
	}
	if analysis.AnalysisSummary == "" {
		t.Error("Analysis summary should not be empty")
	}
	// Expect 3 subtasks: 1 injected for the agent creation + 2 from the LLM.
	// The index-based checks below depend on this count, so stop here if it is wrong.
	if len(analysis.Subtasks) != 3 {
		t.Logf("Subtasks: %+v", analysis.Subtasks)
		t.Fatalf("Expected 3 subtasks (including agent creation), got %d", len(analysis.Subtasks))
	}
	// Test agent creation was processed
	if len(analysis.AgentCreations) != 1 {
		t.Errorf("Expected 1 agent creation, got %d", len(analysis.AgentCreations))
	} else {
		agentCreation := analysis.AgentCreations[0]
		if agentCreation.Role != "security_specialist" {
			t.Errorf("Expected role 'security_specialist', got %s", agentCreation.Role)
		}
		if len(agentCreation.Skills) != 2 {
			t.Errorf("Expected 2 skills, got %d", len(agentCreation.Skills))
		}
	}
	// First subtask: the injected agent-creation task, routed to the CEO.
	subtask0 := analysis.Subtasks[0]
	if !strings.Contains(subtask0.Title, "Security_specialist") {
		t.Errorf("Expected agent creation subtask for security_specialist, got %s", subtask0.Title)
	}
	if subtask0.AssignedTo != "ceo" {
		t.Errorf("Expected agent creation assigned to 'ceo', got %s", subtask0.AssignedTo)
	}
	// Second subtask: the original backend task, now at index 1.
	subtask1 := analysis.Subtasks[1]
	if subtask1.Title != "Backend Development" {
		t.Errorf("Expected title 'Backend Development', got %s", subtask1.Title)
	}
	if subtask1.Priority != tm.PriorityHigh {
		t.Errorf("Expected high priority, got %s", subtask1.Priority)
	}
	if subtask1.AssignedTo != "backend" {
		t.Errorf("Expected assigned_to 'backend', got %s", subtask1.AssignedTo)
	}
	if subtask1.EstimatedHours != 16 {
		t.Errorf("Expected 16 hours, got %d", subtask1.EstimatedHours)
	}
	if len(subtask1.RequiredSkills) != 2 {
		t.Errorf("Expected 2 required skills, got %d", len(subtask1.RequiredSkills))
	}
	// Third subtask: the original frontend task, now at index 2.
	subtask2 := analysis.Subtasks[2]
	if subtask2.Title != "Frontend Development" {
		t.Errorf("Expected title 'Frontend Development', got %s", subtask2.Title)
	}
	if subtask2.Priority != tm.PriorityMedium {
		t.Errorf("Expected medium priority, got %s", subtask2.Priority)
	}
	// The LLM's dependency "0" (backend) must be shifted to "1", because the
	// agent-creation subtask was prepended at index 0.
	if len(subtask2.Dependencies) != 1 || subtask2.Dependencies[0] != "1" {
		t.Errorf("Expected dependencies [1] (updated for agent creation), got %v", subtask2.Dependencies)
	}
	if len(subtask2.RequiredSkills) != 2 {
		t.Errorf("Expected 2 required skills, got %d", len(subtask2.RequiredSkills))
	}
	// The total is expected to be the LLM-reported estimated_total_hours (28)
	// unchanged; this test does not expect agent-creation time to be added.
	if analysis.EstimatedTotalHours != 28 {
		t.Errorf("Expected 28 total hours, got %d", analysis.EstimatedTotalHours)
	}
}
// TestAnalyzeTaskForSubtasks_InvalidJSON checks that a non-JSON LLM reply
// makes AnalyzeTaskForSubtasks fail with a "no JSON found" error.
func TestAnalyzeTaskForSubtasks_InvalidJSON(t *testing.T) {
	invalidResponse := "This is not valid JSON"
	mockProvider := NewMockLLMProvider([]string{invalidResponse})
	agentRoles := []string{"backend", "frontend"}
	service := NewSubtaskService(mockProvider, nil, agentRoles)
	task := &tm.Task{
		ID:    "test-task-123",
		Title: "Test task",
	}
	_, err := service.AnalyzeTaskForSubtasks(context.Background(), task)
	// Fatal, not Error: the message check below calls err.Error(), which
	// would panic on a nil error.
	if err == nil {
		t.Fatal("Expected error for invalid JSON, got nil")
	}
	if !strings.Contains(err.Error(), "no JSON found") {
		t.Errorf("Expected 'no JSON found' error, got: %v", err)
	}
}
// TestAnalyzeTaskForSubtasks_InvalidAgentRole checks that a subtask assigned
// to a role the service does not know is reassigned to the first configured
// role ("backend").
func TestAnalyzeTaskForSubtasks_InvalidAgentRole(t *testing.T) {
	jsonResponse := `{
"analysis_summary": "Test analysis",
"subtasks": [
{
"title": "Invalid Assignment",
"description": "Test subtask",
"priority": "high",
"assigned_to": "invalid_role",
"estimated_hours": 8,
"dependencies": []
}
],
"recommended_approach": "Test approach",
"estimated_total_hours": 8
}`
	mockProvider := NewMockLLMProvider([]string{jsonResponse})
	agentRoles := []string{"backend", "frontend"}
	service := NewSubtaskService(mockProvider, nil, agentRoles)
	task := &tm.Task{
		ID:    "test-task-123",
		Title: "Test task",
	}
	analysis, err := service.AnalyzeTaskForSubtasks(context.Background(), task)
	if err != nil {
		t.Fatalf("AnalyzeTaskForSubtasks failed: %v", err)
	}
	// Guard the index below so an empty result fails cleanly instead of panicking.
	if len(analysis.Subtasks) == 0 {
		t.Fatal("Expected at least 1 subtask, got none")
	}
	// Should fix invalid agent assignment to first available role
	if analysis.Subtasks[0].AssignedTo != "backend" {
		t.Errorf("Expected fixed assignment 'backend', got %s", analysis.Subtasks[0].AssignedTo)
	}
}
func TestGenerateSubtaskPR(t *testing.T) {
mockProvider := NewMockLLMProvider([]string{})
service := NewSubtaskService(mockProvider, nil, []string{"backend"})
analysis := &tm.SubtaskAnalysis{
ParentTaskID: "task-123",
AnalysisSummary: "Test analysis summary",
RecommendedApproach: "Test approach",
EstimatedTotalHours: 40,
RiskAssessment: "Low risk",
Subtasks: []tm.SubtaskProposal{
{
Title: "Test Subtask",
Description: "Test description",
Priority: tm.PriorityHigh,
AssignedTo: "backend",
EstimatedHours: 8,
Dependencies: []string{},
},
},
}
prURL, err := service.GenerateSubtaskPR(context.Background(), analysis)
if err != nil {
t.Fatalf("GenerateSubtaskPR failed: %v", err)
}
expectedURL := "https://github.com/example/repo/pull/subtasks-task-123"
if prURL != expectedURL {
t.Errorf("Expected PR URL %s, got %s", expectedURL, prURL)
}
}
func TestBuildSubtaskAnalysisPrompt(t *testing.T) {
mockProvider := NewMockLLMProvider([]string{})
agentRoles := []string{"backend", "frontend", "qa"}
service := NewSubtaskService(mockProvider, nil, agentRoles)
task := &tm.Task{
Title: "Build authentication system",
Description: "Implement user login and registration with OAuth",
Priority: tm.PriorityHigh,
Status: tm.StatusToDo,
}
prompt := service.buildSubtaskAnalysisPrompt(task)
if !strings.Contains(prompt, task.Title) {
t.Error("Prompt should contain task title")
}
if !strings.Contains(prompt, task.Description) {
t.Error("Prompt should contain task description")
}
if !strings.Contains(prompt, string(task.Priority)) {
t.Error("Prompt should contain task priority")
}
if !strings.Contains(prompt, string(task.Status)) {
t.Error("Prompt should contain task status")
}
}
func TestGetSubtaskAnalysisSystemPrompt(t *testing.T) {
mockProvider := NewMockLLMProvider([]string{})
agentRoles := []string{"backend", "frontend", "qa", "devops"}
service := NewSubtaskService(mockProvider, nil, agentRoles)
systemPrompt := service.getSubtaskAnalysisSystemPrompt()
if !strings.Contains(systemPrompt, "backend") {
t.Error("System prompt should contain backend role")
}
if !strings.Contains(systemPrompt, "frontend") {
t.Error("System prompt should contain frontend role")
}
if !strings.Contains(systemPrompt, "JSON") {
t.Error("System prompt should mention JSON format")
}
if !strings.Contains(systemPrompt, "subtasks") {
t.Error("System prompt should mention subtasks")
}
}
func TestIsValidAgentRole(t *testing.T) {
mockProvider := NewMockLLMProvider([]string{})
agentRoles := []string{"backend", "frontend", "qa"}
service := NewSubtaskService(mockProvider, nil, agentRoles)
if !service.isValidAgentRole("backend") {
t.Error("'backend' should be a valid agent role")
}
if !service.isValidAgentRole("frontend") {
t.Error("'frontend' should be a valid agent role")
}
if service.isValidAgentRole("invalid") {
t.Error("'invalid' should not be a valid agent role")
}
if service.isValidAgentRole("") {
t.Error("Empty string should not be a valid agent role")
}
}
// TestParseSubtaskAnalysis_Priority checks that parseSubtaskAnalysis maps the
// "priority" field to tm priorities regardless of letter case. Unknown or
// empty values are expected to default to medium. Each input runs as its own
// subtest, so a failing case (which calls Fatalf) does not hide the rest.
func TestParseSubtaskAnalysis_Priority(t *testing.T) {
	mockProvider := NewMockLLMProvider([]string{})
	service := NewSubtaskService(mockProvider, nil, []string{"backend"})
	tests := []struct {
		input    string
		expected tm.TaskPriority
	}{
		{"high", tm.PriorityHigh},
		{"HIGH", tm.PriorityHigh},
		{"High", tm.PriorityHigh},
		{"low", tm.PriorityLow},
		{"LOW", tm.PriorityLow},
		{"Low", tm.PriorityLow},
		{"medium", tm.PriorityMedium},
		{"MEDIUM", tm.PriorityMedium},
		{"Medium", tm.PriorityMedium},
		{"invalid", tm.PriorityMedium}, // default
		{"", tm.PriorityMedium},        // default
	}
	for _, test := range tests {
		test := test // keep a per-iteration copy for the closure (pre-Go 1.22 semantics)
		t.Run("priority="+test.input, func(t *testing.T) {
			jsonResponse := `{
"analysis_summary": "Test",
"subtasks": [{
"title": "Test",
"description": "Test",
"priority": "` + test.input + `",
"assigned_to": "backend",
"estimated_hours": 8,
"dependencies": []
}],
"recommended_approach": "Test",
"estimated_total_hours": 8
}`
			analysis, err := service.parseSubtaskAnalysis(jsonResponse, "test-task")
			if err != nil {
				t.Fatalf("parseSubtaskAnalysis failed for priority '%s': %v", test.input, err)
			}
			if len(analysis.Subtasks) != 1 {
				t.Fatalf("Expected 1 subtask, got %d", len(analysis.Subtasks))
			}
			if analysis.Subtasks[0].Priority != test.expected {
				t.Errorf("For priority '%s', expected %s, got %s",
					test.input, test.expected, analysis.Subtasks[0].Priority)
			}
		})
	}
}
func TestClose(t *testing.T) {
mockProvider := NewMockLLMProvider([]string{})
service := NewSubtaskService(mockProvider, nil, []string{"backend"})
err := service.Close()
if err != nil {
t.Errorf("Close should not return error, got: %v", err)
}
}
// Benchmark tests
// BenchmarkAnalyzeTaskForSubtasks measures one AnalyzeTaskForSubtasks call
// against a single scripted one-subtask reply. The mock's call counter is
// rewound before every iteration so each run receives the same reply.
func BenchmarkAnalyzeTaskForSubtasks(b *testing.B) {
	const reply = `{
"analysis_summary": "Benchmark test",
"subtasks": [
{
"title": "Benchmark Subtask",
"description": "Benchmark description",
"priority": "high",
"assigned_to": "backend",
"estimated_hours": 8,
"dependencies": []
}
],
"recommended_approach": "Benchmark approach",
"estimated_total_hours": 8
}`
	provider := NewMockLLMProvider([]string{reply})
	svc := NewSubtaskService(provider, nil, []string{"backend", "frontend"})
	task := &tm.Task{
		ID:          "benchmark-task",
		Title:       "Benchmark Task",
		Description: "Task for benchmarking",
		Priority:    tm.PriorityHigh,
	}
	ctx := context.Background()

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		provider.callCount = 0
		if _, err := svc.AnalyzeTaskForSubtasks(ctx, task); err != nil {
			b.Fatalf("AnalyzeTaskForSubtasks failed: %v", err)
		}
	}
}