| iomodo | d60a535 | 2025-07-28 12:56:22 +0400 | [diff] [blame^] | 1 | package subtasks |
| 2 | |
| 3 | import ( |
| 4 | "context" |
| 5 | "strings" |
| 6 | "testing" |
| 7 | "time" |
| 8 | |
| 9 | "github.com/iomodo/staff/llm" |
| 10 | "github.com/iomodo/staff/tm" |
| 11 | ) |
| 12 | |
| 13 | // MockLLMProvider implements a mock LLM provider for testing |
| 14 | type MockLLMProvider struct { |
| 15 | responses []string |
| 16 | callCount int |
| 17 | } |
| 18 | |
| 19 | func NewMockLLMProvider(responses []string) *MockLLMProvider { |
| 20 | return &MockLLMProvider{ |
| 21 | responses: responses, |
| 22 | callCount: 0, |
| 23 | } |
| 24 | } |
| 25 | |
| 26 | func (m *MockLLMProvider) ChatCompletion(ctx context.Context, req llm.ChatCompletionRequest) (*llm.ChatCompletionResponse, error) { |
| 27 | if m.callCount >= len(m.responses) { |
| 28 | return nil, nil |
| 29 | } |
| 30 | |
| 31 | response := m.responses[m.callCount] |
| 32 | m.callCount++ |
| 33 | |
| 34 | return &llm.ChatCompletionResponse{ |
| 35 | ID: "mock-response", |
| 36 | Object: "chat.completion", |
| 37 | Created: time.Now().Unix(), |
| 38 | Model: req.Model, |
| 39 | Choices: []llm.ChatCompletionChoice{ |
| 40 | { |
| 41 | Index: 0, |
| 42 | Message: llm.Message{ |
| 43 | Role: llm.RoleAssistant, |
| 44 | Content: response, |
| 45 | }, |
| 46 | FinishReason: "stop", |
| 47 | }, |
| 48 | }, |
| 49 | Usage: llm.Usage{ |
| 50 | PromptTokens: 100, |
| 51 | CompletionTokens: 300, |
| 52 | TotalTokens: 400, |
| 53 | }, |
| 54 | }, nil |
| 55 | } |
| 56 | |
| 57 | func (m *MockLLMProvider) CreateEmbeddings(ctx context.Context, req llm.EmbeddingRequest) (*llm.EmbeddingResponse, error) { |
| 58 | return &llm.EmbeddingResponse{ |
| 59 | Object: "list", |
| 60 | Data: []llm.Embedding{ |
| 61 | { |
| 62 | Object: "embedding", |
| 63 | Index: 0, |
| 64 | Embedding: make([]float64, 1536), |
| 65 | }, |
| 66 | }, |
| 67 | Model: req.Model, |
| 68 | Usage: llm.Usage{ |
| 69 | PromptTokens: 50, |
| 70 | TotalTokens: 50, |
| 71 | }, |
| 72 | }, nil |
| 73 | } |
| 74 | |
| 75 | func (m *MockLLMProvider) Close() error { |
| 76 | return nil |
| 77 | } |
| 78 | |
| 79 | func TestNewSubtaskService(t *testing.T) { |
| 80 | mockProvider := NewMockLLMProvider([]string{}) |
| 81 | agentRoles := []string{"backend", "frontend", "qa"} |
| 82 | |
| 83 | service := NewSubtaskService(mockProvider, nil, agentRoles) |
| 84 | |
| 85 | if service == nil { |
| 86 | t.Fatal("NewSubtaskService returned nil") |
| 87 | } |
| 88 | |
| 89 | if service.llmProvider != mockProvider { |
| 90 | t.Error("LLM provider not set correctly") |
| 91 | } |
| 92 | |
| 93 | if len(service.agentRoles) != 3 { |
| 94 | t.Errorf("Expected 3 agent roles, got %d", len(service.agentRoles)) |
| 95 | } |
| 96 | } |
| 97 | |
| 98 | func TestAnalyzeTaskForSubtasks(t *testing.T) { |
| 99 | jsonResponse := `{ |
| 100 | "analysis_summary": "This task requires breaking down into multiple components", |
| 101 | "subtasks": [ |
| 102 | { |
| 103 | "title": "Backend Development", |
| 104 | "description": "Implement server-side logic", |
| 105 | "priority": "high", |
| 106 | "assigned_to": "backend", |
| 107 | "estimated_hours": 16, |
| 108 | "dependencies": [] |
| 109 | }, |
| 110 | { |
| 111 | "title": "Frontend Development", |
| 112 | "description": "Build user interface", |
| 113 | "priority": "medium", |
| 114 | "assigned_to": "frontend", |
| 115 | "estimated_hours": 12, |
| 116 | "dependencies": ["0"] |
| 117 | } |
| 118 | ], |
| 119 | "recommended_approach": "Start with backend then frontend", |
| 120 | "estimated_total_hours": 28, |
| 121 | "risk_assessment": "Medium complexity with API integration risks" |
| 122 | }` |
| 123 | |
| 124 | mockProvider := NewMockLLMProvider([]string{jsonResponse}) |
| 125 | agentRoles := []string{"backend", "frontend", "qa"} |
| 126 | service := NewSubtaskService(mockProvider, nil, agentRoles) |
| 127 | |
| 128 | task := &tm.Task{ |
| 129 | ID: "test-task-123", |
| 130 | Title: "Build authentication system", |
| 131 | Description: "Implement user login and registration", |
| 132 | Priority: tm.PriorityHigh, |
| 133 | Status: tm.StatusToDo, |
| 134 | CreatedAt: time.Now(), |
| 135 | UpdatedAt: time.Now(), |
| 136 | } |
| 137 | |
| 138 | analysis, err := service.AnalyzeTaskForSubtasks(context.Background(), task) |
| 139 | if err != nil { |
| 140 | t.Fatalf("AnalyzeTaskForSubtasks failed: %v", err) |
| 141 | } |
| 142 | |
| 143 | if analysis.ParentTaskID != task.ID { |
| 144 | t.Errorf("Expected parent task ID %s, got %s", task.ID, analysis.ParentTaskID) |
| 145 | } |
| 146 | |
| 147 | if analysis.AnalysisSummary == "" { |
| 148 | t.Error("Analysis summary should not be empty") |
| 149 | } |
| 150 | |
| 151 | if len(analysis.Subtasks) != 2 { |
| 152 | t.Errorf("Expected 2 subtasks, got %d", len(analysis.Subtasks)) |
| 153 | } |
| 154 | |
| 155 | // Test first subtask |
| 156 | subtask1 := analysis.Subtasks[0] |
| 157 | if subtask1.Title != "Backend Development" { |
| 158 | t.Errorf("Expected title 'Backend Development', got %s", subtask1.Title) |
| 159 | } |
| 160 | if subtask1.Priority != tm.PriorityHigh { |
| 161 | t.Errorf("Expected high priority, got %s", subtask1.Priority) |
| 162 | } |
| 163 | if subtask1.AssignedTo != "backend" { |
| 164 | t.Errorf("Expected assigned_to 'backend', got %s", subtask1.AssignedTo) |
| 165 | } |
| 166 | if subtask1.EstimatedHours != 16 { |
| 167 | t.Errorf("Expected 16 hours, got %d", subtask1.EstimatedHours) |
| 168 | } |
| 169 | |
| 170 | // Test second subtask |
| 171 | subtask2 := analysis.Subtasks[1] |
| 172 | if subtask2.Title != "Frontend Development" { |
| 173 | t.Errorf("Expected title 'Frontend Development', got %s", subtask2.Title) |
| 174 | } |
| 175 | if subtask2.Priority != tm.PriorityMedium { |
| 176 | t.Errorf("Expected medium priority, got %s", subtask2.Priority) |
| 177 | } |
| 178 | if len(subtask2.Dependencies) != 1 || subtask2.Dependencies[0] != "0" { |
| 179 | t.Errorf("Expected dependencies [0], got %v", subtask2.Dependencies) |
| 180 | } |
| 181 | |
| 182 | if analysis.EstimatedTotalHours != 28 { |
| 183 | t.Errorf("Expected 28 total hours, got %d", analysis.EstimatedTotalHours) |
| 184 | } |
| 185 | } |
| 186 | |
| 187 | func TestAnalyzeTaskForSubtasks_InvalidJSON(t *testing.T) { |
| 188 | invalidResponse := "This is not valid JSON" |
| 189 | |
| 190 | mockProvider := NewMockLLMProvider([]string{invalidResponse}) |
| 191 | agentRoles := []string{"backend", "frontend"} |
| 192 | service := NewSubtaskService(mockProvider, nil, agentRoles) |
| 193 | |
| 194 | task := &tm.Task{ |
| 195 | ID: "test-task-123", |
| 196 | Title: "Test task", |
| 197 | } |
| 198 | |
| 199 | _, err := service.AnalyzeTaskForSubtasks(context.Background(), task) |
| 200 | if err == nil { |
| 201 | t.Error("Expected error for invalid JSON, got nil") |
| 202 | } |
| 203 | |
| 204 | if !strings.Contains(err.Error(), "no JSON found") { |
| 205 | t.Errorf("Expected 'no JSON found' error, got: %v", err) |
| 206 | } |
| 207 | } |
| 208 | |
| 209 | func TestAnalyzeTaskForSubtasks_InvalidAgentRole(t *testing.T) { |
| 210 | jsonResponse := `{ |
| 211 | "analysis_summary": "Test analysis", |
| 212 | "subtasks": [ |
| 213 | { |
| 214 | "title": "Invalid Assignment", |
| 215 | "description": "Test subtask", |
| 216 | "priority": "high", |
| 217 | "assigned_to": "invalid_role", |
| 218 | "estimated_hours": 8, |
| 219 | "dependencies": [] |
| 220 | } |
| 221 | ], |
| 222 | "recommended_approach": "Test approach", |
| 223 | "estimated_total_hours": 8 |
| 224 | }` |
| 225 | |
| 226 | mockProvider := NewMockLLMProvider([]string{jsonResponse}) |
| 227 | agentRoles := []string{"backend", "frontend"} |
| 228 | service := NewSubtaskService(mockProvider, nil, agentRoles) |
| 229 | |
| 230 | task := &tm.Task{ |
| 231 | ID: "test-task-123", |
| 232 | Title: "Test task", |
| 233 | } |
| 234 | |
| 235 | analysis, err := service.AnalyzeTaskForSubtasks(context.Background(), task) |
| 236 | if err != nil { |
| 237 | t.Fatalf("AnalyzeTaskForSubtasks failed: %v", err) |
| 238 | } |
| 239 | |
| 240 | // Should fix invalid agent assignment to first available role |
| 241 | if analysis.Subtasks[0].AssignedTo != "backend" { |
| 242 | t.Errorf("Expected fixed assignment 'backend', got %s", analysis.Subtasks[0].AssignedTo) |
| 243 | } |
| 244 | } |
| 245 | |
| 246 | func TestGenerateSubtaskPR(t *testing.T) { |
| 247 | mockProvider := NewMockLLMProvider([]string{}) |
| 248 | service := NewSubtaskService(mockProvider, nil, []string{"backend"}) |
| 249 | |
| 250 | analysis := &tm.SubtaskAnalysis{ |
| 251 | ParentTaskID: "task-123", |
| 252 | AnalysisSummary: "Test analysis summary", |
| 253 | RecommendedApproach: "Test approach", |
| 254 | EstimatedTotalHours: 40, |
| 255 | RiskAssessment: "Low risk", |
| 256 | Subtasks: []tm.SubtaskProposal{ |
| 257 | { |
| 258 | Title: "Test Subtask", |
| 259 | Description: "Test description", |
| 260 | Priority: tm.PriorityHigh, |
| 261 | AssignedTo: "backend", |
| 262 | EstimatedHours: 8, |
| 263 | Dependencies: []string{}, |
| 264 | }, |
| 265 | }, |
| 266 | } |
| 267 | |
| 268 | prURL, err := service.GenerateSubtaskPR(context.Background(), analysis) |
| 269 | if err != nil { |
| 270 | t.Fatalf("GenerateSubtaskPR failed: %v", err) |
| 271 | } |
| 272 | |
| 273 | expectedURL := "https://github.com/example/repo/pull/subtasks-task-123" |
| 274 | if prURL != expectedURL { |
| 275 | t.Errorf("Expected PR URL %s, got %s", expectedURL, prURL) |
| 276 | } |
| 277 | } |
| 278 | |
| 279 | func TestBuildSubtaskAnalysisPrompt(t *testing.T) { |
| 280 | mockProvider := NewMockLLMProvider([]string{}) |
| 281 | agentRoles := []string{"backend", "frontend", "qa"} |
| 282 | service := NewSubtaskService(mockProvider, nil, agentRoles) |
| 283 | |
| 284 | task := &tm.Task{ |
| 285 | Title: "Build authentication system", |
| 286 | Description: "Implement user login and registration with OAuth", |
| 287 | Priority: tm.PriorityHigh, |
| 288 | Status: tm.StatusToDo, |
| 289 | } |
| 290 | |
| 291 | prompt := service.buildSubtaskAnalysisPrompt(task) |
| 292 | |
| 293 | if !strings.Contains(prompt, task.Title) { |
| 294 | t.Error("Prompt should contain task title") |
| 295 | } |
| 296 | |
| 297 | if !strings.Contains(prompt, task.Description) { |
| 298 | t.Error("Prompt should contain task description") |
| 299 | } |
| 300 | |
| 301 | if !strings.Contains(prompt, string(task.Priority)) { |
| 302 | t.Error("Prompt should contain task priority") |
| 303 | } |
| 304 | |
| 305 | if !strings.Contains(prompt, string(task.Status)) { |
| 306 | t.Error("Prompt should contain task status") |
| 307 | } |
| 308 | } |
| 309 | |
| 310 | func TestGetSubtaskAnalysisSystemPrompt(t *testing.T) { |
| 311 | mockProvider := NewMockLLMProvider([]string{}) |
| 312 | agentRoles := []string{"backend", "frontend", "qa", "devops"} |
| 313 | service := NewSubtaskService(mockProvider, nil, agentRoles) |
| 314 | |
| 315 | systemPrompt := service.getSubtaskAnalysisSystemPrompt() |
| 316 | |
| 317 | if !strings.Contains(systemPrompt, "backend") { |
| 318 | t.Error("System prompt should contain backend role") |
| 319 | } |
| 320 | |
| 321 | if !strings.Contains(systemPrompt, "frontend") { |
| 322 | t.Error("System prompt should contain frontend role") |
| 323 | } |
| 324 | |
| 325 | if !strings.Contains(systemPrompt, "JSON") { |
| 326 | t.Error("System prompt should mention JSON format") |
| 327 | } |
| 328 | |
| 329 | if !strings.Contains(systemPrompt, "subtasks") { |
| 330 | t.Error("System prompt should mention subtasks") |
| 331 | } |
| 332 | } |
| 333 | |
| 334 | func TestIsValidAgentRole(t *testing.T) { |
| 335 | mockProvider := NewMockLLMProvider([]string{}) |
| 336 | agentRoles := []string{"backend", "frontend", "qa"} |
| 337 | service := NewSubtaskService(mockProvider, nil, agentRoles) |
| 338 | |
| 339 | if !service.isValidAgentRole("backend") { |
| 340 | t.Error("'backend' should be a valid agent role") |
| 341 | } |
| 342 | |
| 343 | if !service.isValidAgentRole("frontend") { |
| 344 | t.Error("'frontend' should be a valid agent role") |
| 345 | } |
| 346 | |
| 347 | if service.isValidAgentRole("invalid") { |
| 348 | t.Error("'invalid' should not be a valid agent role") |
| 349 | } |
| 350 | |
| 351 | if service.isValidAgentRole("") { |
| 352 | t.Error("Empty string should not be a valid agent role") |
| 353 | } |
| 354 | } |
| 355 | |
| 356 | func TestParseSubtaskAnalysis_Priority(t *testing.T) { |
| 357 | mockProvider := NewMockLLMProvider([]string{}) |
| 358 | service := NewSubtaskService(mockProvider, nil, []string{"backend"}) |
| 359 | |
| 360 | tests := []struct { |
| 361 | input string |
| 362 | expected tm.TaskPriority |
| 363 | }{ |
| 364 | {"high", tm.PriorityHigh}, |
| 365 | {"HIGH", tm.PriorityHigh}, |
| 366 | {"High", tm.PriorityHigh}, |
| 367 | {"low", tm.PriorityLow}, |
| 368 | {"LOW", tm.PriorityLow}, |
| 369 | {"Low", tm.PriorityLow}, |
| 370 | {"medium", tm.PriorityMedium}, |
| 371 | {"MEDIUM", tm.PriorityMedium}, |
| 372 | {"Medium", tm.PriorityMedium}, |
| 373 | {"invalid", tm.PriorityMedium}, // default |
| 374 | {"", tm.PriorityMedium}, // default |
| 375 | } |
| 376 | |
| 377 | for _, test := range tests { |
| 378 | jsonResponse := `{ |
| 379 | "analysis_summary": "Test", |
| 380 | "subtasks": [{ |
| 381 | "title": "Test", |
| 382 | "description": "Test", |
| 383 | "priority": "` + test.input + `", |
| 384 | "assigned_to": "backend", |
| 385 | "estimated_hours": 8, |
| 386 | "dependencies": [] |
| 387 | }], |
| 388 | "recommended_approach": "Test", |
| 389 | "estimated_total_hours": 8 |
| 390 | }` |
| 391 | |
| 392 | analysis, err := service.parseSubtaskAnalysis(jsonResponse, "test-task") |
| 393 | if err != nil { |
| 394 | t.Fatalf("parseSubtaskAnalysis failed for priority '%s': %v", test.input, err) |
| 395 | } |
| 396 | |
| 397 | if len(analysis.Subtasks) != 1 { |
| 398 | t.Fatalf("Expected 1 subtask, got %d", len(analysis.Subtasks)) |
| 399 | } |
| 400 | |
| 401 | if analysis.Subtasks[0].Priority != test.expected { |
| 402 | t.Errorf("For priority '%s', expected %s, got %s", |
| 403 | test.input, test.expected, analysis.Subtasks[0].Priority) |
| 404 | } |
| 405 | } |
| 406 | } |
| 407 | |
| 408 | func TestClose(t *testing.T) { |
| 409 | mockProvider := NewMockLLMProvider([]string{}) |
| 410 | service := NewSubtaskService(mockProvider, nil, []string{"backend"}) |
| 411 | |
| 412 | err := service.Close() |
| 413 | if err != nil { |
| 414 | t.Errorf("Close should not return error, got: %v", err) |
| 415 | } |
| 416 | } |
| 417 | |
| 418 | // Benchmark tests |
| 419 | func BenchmarkAnalyzeTaskForSubtasks(b *testing.B) { |
| 420 | jsonResponse := `{ |
| 421 | "analysis_summary": "Benchmark test", |
| 422 | "subtasks": [ |
| 423 | { |
| 424 | "title": "Benchmark Subtask", |
| 425 | "description": "Benchmark description", |
| 426 | "priority": "high", |
| 427 | "assigned_to": "backend", |
| 428 | "estimated_hours": 8, |
| 429 | "dependencies": [] |
| 430 | } |
| 431 | ], |
| 432 | "recommended_approach": "Benchmark approach", |
| 433 | "estimated_total_hours": 8 |
| 434 | }` |
| 435 | |
| 436 | mockProvider := NewMockLLMProvider([]string{jsonResponse}) |
| 437 | service := NewSubtaskService(mockProvider, nil, []string{"backend", "frontend"}) |
| 438 | |
| 439 | task := &tm.Task{ |
| 440 | ID: "benchmark-task", |
| 441 | Title: "Benchmark Task", |
| 442 | Description: "Task for benchmarking", |
| 443 | Priority: tm.PriorityHigh, |
| 444 | } |
| 445 | |
| 446 | b.ResetTimer() |
| 447 | for i := 0; i < b.N; i++ { |
| 448 | // Reset mock provider for each iteration |
| 449 | mockProvider.callCount = 0 |
| 450 | _, err := service.AnalyzeTaskForSubtasks(context.Background(), task) |
| 451 | if err != nil { |
| 452 | b.Fatalf("AnalyzeTaskForSubtasks failed: %v", err) |
| 453 | } |
| 454 | } |
| 455 | } |