blob: ade62fc3dae37b4d7053bc4d2782bbafe9d9472f [file] [log] [blame]
iomodod60a5352025-07-28 12:56:22 +04001package subtasks
2
3import (
4 "context"
5 "strings"
6 "testing"
7 "time"
8
9 "github.com/iomodo/staff/llm"
10 "github.com/iomodo/staff/tm"
11)
12
// MockLLMProvider is a scripted stand-in for a real LLM provider, used by the
// tests in this file. It is not safe for concurrent use.
type MockLLMProvider struct {
	// callCount is the number of scripted responses already handed out; it is
	// also the index of the next response to return.
	callCount int
	// responses holds the canned completion contents, returned in order.
	responses []string
}
18
19func NewMockLLMProvider(responses []string) *MockLLMProvider {
20 return &MockLLMProvider{
21 responses: responses,
22 callCount: 0,
23 }
24}
25
26func (m *MockLLMProvider) ChatCompletion(ctx context.Context, req llm.ChatCompletionRequest) (*llm.ChatCompletionResponse, error) {
27 if m.callCount >= len(m.responses) {
28 return nil, nil
29 }
30
31 response := m.responses[m.callCount]
32 m.callCount++
33
34 return &llm.ChatCompletionResponse{
35 ID: "mock-response",
36 Object: "chat.completion",
37 Created: time.Now().Unix(),
38 Model: req.Model,
39 Choices: []llm.ChatCompletionChoice{
40 {
41 Index: 0,
42 Message: llm.Message{
43 Role: llm.RoleAssistant,
44 Content: response,
45 },
46 FinishReason: "stop",
47 },
48 },
49 Usage: llm.Usage{
50 PromptTokens: 100,
51 CompletionTokens: 300,
52 TotalTokens: 400,
53 },
54 }, nil
55}
56
57func (m *MockLLMProvider) CreateEmbeddings(ctx context.Context, req llm.EmbeddingRequest) (*llm.EmbeddingResponse, error) {
58 return &llm.EmbeddingResponse{
59 Object: "list",
60 Data: []llm.Embedding{
61 {
62 Object: "embedding",
63 Index: 0,
64 Embedding: make([]float64, 1536),
65 },
66 },
67 Model: req.Model,
68 Usage: llm.Usage{
69 PromptTokens: 50,
70 TotalTokens: 50,
71 },
72 }, nil
73}
74
75func (m *MockLLMProvider) Close() error {
76 return nil
77}
78
79func TestNewSubtaskService(t *testing.T) {
80 mockProvider := NewMockLLMProvider([]string{})
81 agentRoles := []string{"backend", "frontend", "qa"}
82
83 service := NewSubtaskService(mockProvider, nil, agentRoles)
84
85 if service == nil {
86 t.Fatal("NewSubtaskService returned nil")
87 }
88
89 if service.llmProvider != mockProvider {
90 t.Error("LLM provider not set correctly")
91 }
92
93 if len(service.agentRoles) != 3 {
94 t.Errorf("Expected 3 agent roles, got %d", len(service.agentRoles))
95 }
96}
97
iomodo5c99a442025-07-28 14:23:52 +040098func TestShouldGenerateSubtasks(t *testing.T) {
99 // Test decision to generate subtasks
100 decisionResponse := `{
101 "needs_subtasks": true,
102 "reasoning": "Complex task requiring multiple skills",
103 "complexity_score": 8,
104 "required_skills": ["backend", "frontend", "database"]
105}`
106
107 mockProvider := NewMockLLMProvider([]string{decisionResponse})
108 agentRoles := []string{"backend", "frontend", "qa"}
109 service := NewSubtaskService(mockProvider, nil, agentRoles)
110
111 // Test the parseSubtaskDecision method directly since ShouldGenerateSubtasks is used by manager
112 decision, err := service.parseSubtaskDecision(decisionResponse)
113 if err != nil {
114 t.Fatalf("parseSubtaskDecision failed: %v", err)
115 }
116
117 if !decision.NeedsSubtasks {
118 t.Error("Expected decision to need subtasks")
119 }
120
121 if decision.ComplexityScore != 8 {
122 t.Errorf("Expected complexity score 8, got %d", decision.ComplexityScore)
123 }
124
125 if len(decision.RequiredSkills) != 3 {
126 t.Errorf("Expected 3 required skills, got %d", len(decision.RequiredSkills))
127 }
128}
129
iomodod60a5352025-07-28 12:56:22 +0400130func TestAnalyzeTaskForSubtasks(t *testing.T) {
131 jsonResponse := `{
132 "analysis_summary": "This task requires breaking down into multiple components",
133 "subtasks": [
134 {
135 "title": "Backend Development",
136 "description": "Implement server-side logic",
137 "priority": "high",
138 "assigned_to": "backend",
139 "estimated_hours": 16,
iomodo5c99a442025-07-28 14:23:52 +0400140 "dependencies": [],
141 "required_skills": ["go", "api_development"]
iomodod60a5352025-07-28 12:56:22 +0400142 },
143 {
144 "title": "Frontend Development",
145 "description": "Build user interface",
146 "priority": "medium",
147 "assigned_to": "frontend",
148 "estimated_hours": 12,
iomodo5c99a442025-07-28 14:23:52 +0400149 "dependencies": ["0"],
150 "required_skills": ["react", "typescript"]
151 }
152 ],
153 "agent_creations": [
154 {
155 "role": "security_specialist",
156 "skills": ["security_audit", "penetration_testing"],
157 "description": "Specialized agent for security tasks",
158 "justification": "Authentication requires security expertise"
iomodod60a5352025-07-28 12:56:22 +0400159 }
160 ],
161 "recommended_approach": "Start with backend then frontend",
162 "estimated_total_hours": 28,
163 "risk_assessment": "Medium complexity with API integration risks"
164}`
165
166 mockProvider := NewMockLLMProvider([]string{jsonResponse})
iomodo5c99a442025-07-28 14:23:52 +0400167 agentRoles := []string{"backend", "frontend", "qa", "ceo"} // Include CEO for agent creation
iomodod60a5352025-07-28 12:56:22 +0400168 service := NewSubtaskService(mockProvider, nil, agentRoles)
169
170 task := &tm.Task{
171 ID: "test-task-123",
172 Title: "Build authentication system",
173 Description: "Implement user login and registration",
174 Priority: tm.PriorityHigh,
175 Status: tm.StatusToDo,
176 CreatedAt: time.Now(),
177 UpdatedAt: time.Now(),
178 }
179
180 analysis, err := service.AnalyzeTaskForSubtasks(context.Background(), task)
181 if err != nil {
182 t.Fatalf("AnalyzeTaskForSubtasks failed: %v", err)
183 }
184
185 if analysis.ParentTaskID != task.ID {
186 t.Errorf("Expected parent task ID %s, got %s", task.ID, analysis.ParentTaskID)
187 }
188
189 if analysis.AnalysisSummary == "" {
190 t.Error("Analysis summary should not be empty")
191 }
192
iomodo5c99a442025-07-28 14:23:52 +0400193 // Should have 3 subtasks (1 for agent creation + 2 original)
194 if len(analysis.Subtasks) != 3 {
195 t.Errorf("Expected 3 subtasks (including agent creation), got %d", len(analysis.Subtasks))
196 t.Logf("Subtasks: %+v", analysis.Subtasks)
197 return // Exit early if count is wrong to avoid index errors
iomodod60a5352025-07-28 12:56:22 +0400198 }
199
iomodo5c99a442025-07-28 14:23:52 +0400200 // Test agent creation was processed
201 if len(analysis.AgentCreations) != 1 {
202 t.Errorf("Expected 1 agent creation, got %d", len(analysis.AgentCreations))
203 } else {
204 agentCreation := analysis.AgentCreations[0]
205 if agentCreation.Role != "security_specialist" {
206 t.Errorf("Expected role 'security_specialist', got %s", agentCreation.Role)
207 }
208 if len(agentCreation.Skills) != 2 {
209 t.Errorf("Expected 2 skills, got %d", len(agentCreation.Skills))
210 }
211 }
212
213 // We already checked the count above
214
215 // Test first subtask (agent creation)
216 subtask0 := analysis.Subtasks[0]
217 if !strings.Contains(subtask0.Title, "Security_specialist") {
218 t.Errorf("Expected agent creation subtask for security_specialist, got %s", subtask0.Title)
219 }
220 if subtask0.AssignedTo != "ceo" {
221 t.Errorf("Expected agent creation assigned to 'ceo', got %s", subtask0.AssignedTo)
222 }
223
224 // Test second subtask (original backend task, now at index 1)
225 subtask1 := analysis.Subtasks[1]
iomodod60a5352025-07-28 12:56:22 +0400226 if subtask1.Title != "Backend Development" {
227 t.Errorf("Expected title 'Backend Development', got %s", subtask1.Title)
228 }
229 if subtask1.Priority != tm.PriorityHigh {
230 t.Errorf("Expected high priority, got %s", subtask1.Priority)
231 }
232 if subtask1.AssignedTo != "backend" {
233 t.Errorf("Expected assigned_to 'backend', got %s", subtask1.AssignedTo)
234 }
235 if subtask1.EstimatedHours != 16 {
236 t.Errorf("Expected 16 hours, got %d", subtask1.EstimatedHours)
237 }
iomodo5c99a442025-07-28 14:23:52 +0400238 if len(subtask1.RequiredSkills) != 2 {
239 t.Errorf("Expected 2 required skills, got %d", len(subtask1.RequiredSkills))
240 }
iomodod60a5352025-07-28 12:56:22 +0400241
iomodo5c99a442025-07-28 14:23:52 +0400242 // Test third subtask (original frontend task, now at index 2 with updated dependencies)
243 subtask2 := analysis.Subtasks[2]
iomodod60a5352025-07-28 12:56:22 +0400244 if subtask2.Title != "Frontend Development" {
245 t.Errorf("Expected title 'Frontend Development', got %s", subtask2.Title)
246 }
247 if subtask2.Priority != tm.PriorityMedium {
248 t.Errorf("Expected medium priority, got %s", subtask2.Priority)
249 }
iomodo5c99a442025-07-28 14:23:52 +0400250 // Dependencies should be updated to account for the new agent creation subtask
251 if len(subtask2.Dependencies) != 1 || subtask2.Dependencies[0] != "1" {
252 t.Errorf("Expected dependencies [1] (updated for agent creation), got %v", subtask2.Dependencies)
253 }
254 if len(subtask2.RequiredSkills) != 2 {
255 t.Errorf("Expected 2 required skills, got %d", len(subtask2.RequiredSkills))
iomodod60a5352025-07-28 12:56:22 +0400256 }
257
iomodo5c99a442025-07-28 14:23:52 +0400258 // Total hours should include agent creation time (4 hours)
iomodod60a5352025-07-28 12:56:22 +0400259 if analysis.EstimatedTotalHours != 28 {
260 t.Errorf("Expected 28 total hours, got %d", analysis.EstimatedTotalHours)
261 }
262}
263
264func TestAnalyzeTaskForSubtasks_InvalidJSON(t *testing.T) {
265 invalidResponse := "This is not valid JSON"
266
267 mockProvider := NewMockLLMProvider([]string{invalidResponse})
268 agentRoles := []string{"backend", "frontend"}
269 service := NewSubtaskService(mockProvider, nil, agentRoles)
270
271 task := &tm.Task{
272 ID: "test-task-123",
273 Title: "Test task",
274 }
275
276 _, err := service.AnalyzeTaskForSubtasks(context.Background(), task)
277 if err == nil {
278 t.Error("Expected error for invalid JSON, got nil")
279 }
280
281 if !strings.Contains(err.Error(), "no JSON found") {
282 t.Errorf("Expected 'no JSON found' error, got: %v", err)
283 }
284}
285
286func TestAnalyzeTaskForSubtasks_InvalidAgentRole(t *testing.T) {
287 jsonResponse := `{
288 "analysis_summary": "Test analysis",
289 "subtasks": [
290 {
291 "title": "Invalid Assignment",
292 "description": "Test subtask",
293 "priority": "high",
294 "assigned_to": "invalid_role",
295 "estimated_hours": 8,
296 "dependencies": []
297 }
298 ],
299 "recommended_approach": "Test approach",
300 "estimated_total_hours": 8
301}`
302
303 mockProvider := NewMockLLMProvider([]string{jsonResponse})
304 agentRoles := []string{"backend", "frontend"}
305 service := NewSubtaskService(mockProvider, nil, agentRoles)
306
307 task := &tm.Task{
308 ID: "test-task-123",
309 Title: "Test task",
310 }
311
312 analysis, err := service.AnalyzeTaskForSubtasks(context.Background(), task)
313 if err != nil {
314 t.Fatalf("AnalyzeTaskForSubtasks failed: %v", err)
315 }
316
317 // Should fix invalid agent assignment to first available role
318 if analysis.Subtasks[0].AssignedTo != "backend" {
319 t.Errorf("Expected fixed assignment 'backend', got %s", analysis.Subtasks[0].AssignedTo)
320 }
321}
322
323func TestGenerateSubtaskPR(t *testing.T) {
324 mockProvider := NewMockLLMProvider([]string{})
325 service := NewSubtaskService(mockProvider, nil, []string{"backend"})
326
327 analysis := &tm.SubtaskAnalysis{
328 ParentTaskID: "task-123",
329 AnalysisSummary: "Test analysis summary",
330 RecommendedApproach: "Test approach",
331 EstimatedTotalHours: 40,
332 RiskAssessment: "Low risk",
333 Subtasks: []tm.SubtaskProposal{
334 {
335 Title: "Test Subtask",
336 Description: "Test description",
337 Priority: tm.PriorityHigh,
338 AssignedTo: "backend",
339 EstimatedHours: 8,
340 Dependencies: []string{},
341 },
342 },
343 }
344
345 prURL, err := service.GenerateSubtaskPR(context.Background(), analysis)
346 if err != nil {
347 t.Fatalf("GenerateSubtaskPR failed: %v", err)
348 }
349
350 expectedURL := "https://github.com/example/repo/pull/subtasks-task-123"
351 if prURL != expectedURL {
352 t.Errorf("Expected PR URL %s, got %s", expectedURL, prURL)
353 }
354}
355
356func TestBuildSubtaskAnalysisPrompt(t *testing.T) {
357 mockProvider := NewMockLLMProvider([]string{})
358 agentRoles := []string{"backend", "frontend", "qa"}
359 service := NewSubtaskService(mockProvider, nil, agentRoles)
360
361 task := &tm.Task{
362 Title: "Build authentication system",
363 Description: "Implement user login and registration with OAuth",
364 Priority: tm.PriorityHigh,
365 Status: tm.StatusToDo,
366 }
367
368 prompt := service.buildSubtaskAnalysisPrompt(task)
369
370 if !strings.Contains(prompt, task.Title) {
371 t.Error("Prompt should contain task title")
372 }
373
374 if !strings.Contains(prompt, task.Description) {
375 t.Error("Prompt should contain task description")
376 }
377
378 if !strings.Contains(prompt, string(task.Priority)) {
379 t.Error("Prompt should contain task priority")
380 }
381
382 if !strings.Contains(prompt, string(task.Status)) {
383 t.Error("Prompt should contain task status")
384 }
385}
386
387func TestGetSubtaskAnalysisSystemPrompt(t *testing.T) {
388 mockProvider := NewMockLLMProvider([]string{})
389 agentRoles := []string{"backend", "frontend", "qa", "devops"}
390 service := NewSubtaskService(mockProvider, nil, agentRoles)
391
392 systemPrompt := service.getSubtaskAnalysisSystemPrompt()
393
394 if !strings.Contains(systemPrompt, "backend") {
395 t.Error("System prompt should contain backend role")
396 }
397
398 if !strings.Contains(systemPrompt, "frontend") {
399 t.Error("System prompt should contain frontend role")
400 }
401
402 if !strings.Contains(systemPrompt, "JSON") {
403 t.Error("System prompt should mention JSON format")
404 }
405
406 if !strings.Contains(systemPrompt, "subtasks") {
407 t.Error("System prompt should mention subtasks")
408 }
409}
410
411func TestIsValidAgentRole(t *testing.T) {
412 mockProvider := NewMockLLMProvider([]string{})
413 agentRoles := []string{"backend", "frontend", "qa"}
414 service := NewSubtaskService(mockProvider, nil, agentRoles)
415
416 if !service.isValidAgentRole("backend") {
417 t.Error("'backend' should be a valid agent role")
418 }
419
420 if !service.isValidAgentRole("frontend") {
421 t.Error("'frontend' should be a valid agent role")
422 }
423
424 if service.isValidAgentRole("invalid") {
425 t.Error("'invalid' should not be a valid agent role")
426 }
427
428 if service.isValidAgentRole("") {
429 t.Error("Empty string should not be a valid agent role")
430 }
431}
432
433func TestParseSubtaskAnalysis_Priority(t *testing.T) {
434 mockProvider := NewMockLLMProvider([]string{})
435 service := NewSubtaskService(mockProvider, nil, []string{"backend"})
436
437 tests := []struct {
438 input string
439 expected tm.TaskPriority
440 }{
441 {"high", tm.PriorityHigh},
442 {"HIGH", tm.PriorityHigh},
443 {"High", tm.PriorityHigh},
444 {"low", tm.PriorityLow},
445 {"LOW", tm.PriorityLow},
446 {"Low", tm.PriorityLow},
447 {"medium", tm.PriorityMedium},
448 {"MEDIUM", tm.PriorityMedium},
449 {"Medium", tm.PriorityMedium},
450 {"invalid", tm.PriorityMedium}, // default
451 {"", tm.PriorityMedium}, // default
452 }
453
454 for _, test := range tests {
455 jsonResponse := `{
456 "analysis_summary": "Test",
457 "subtasks": [{
458 "title": "Test",
459 "description": "Test",
460 "priority": "` + test.input + `",
461 "assigned_to": "backend",
462 "estimated_hours": 8,
463 "dependencies": []
464 }],
465 "recommended_approach": "Test",
466 "estimated_total_hours": 8
467}`
468
469 analysis, err := service.parseSubtaskAnalysis(jsonResponse, "test-task")
470 if err != nil {
471 t.Fatalf("parseSubtaskAnalysis failed for priority '%s': %v", test.input, err)
472 }
473
474 if len(analysis.Subtasks) != 1 {
475 t.Fatalf("Expected 1 subtask, got %d", len(analysis.Subtasks))
476 }
477
478 if analysis.Subtasks[0].Priority != test.expected {
479 t.Errorf("For priority '%s', expected %s, got %s",
480 test.input, test.expected, analysis.Subtasks[0].Priority)
481 }
482 }
483}
484
485func TestClose(t *testing.T) {
486 mockProvider := NewMockLLMProvider([]string{})
487 service := NewSubtaskService(mockProvider, nil, []string{"backend"})
488
489 err := service.Close()
490 if err != nil {
491 t.Errorf("Close should not return error, got: %v", err)
492 }
493}
494
495// Benchmark tests
496func BenchmarkAnalyzeTaskForSubtasks(b *testing.B) {
497 jsonResponse := `{
498 "analysis_summary": "Benchmark test",
499 "subtasks": [
500 {
501 "title": "Benchmark Subtask",
502 "description": "Benchmark description",
503 "priority": "high",
504 "assigned_to": "backend",
505 "estimated_hours": 8,
506 "dependencies": []
507 }
508 ],
509 "recommended_approach": "Benchmark approach",
510 "estimated_total_hours": 8
511}`
512
513 mockProvider := NewMockLLMProvider([]string{jsonResponse})
514 service := NewSubtaskService(mockProvider, nil, []string{"backend", "frontend"})
515
516 task := &tm.Task{
517 ID: "benchmark-task",
518 Title: "Benchmark Task",
519 Description: "Task for benchmarking",
520 Priority: tm.PriorityHigh,
521 }
522
523 b.ResetTimer()
524 for i := 0; i < b.N; i++ {
525 // Reset mock provider for each iteration
526 mockProvider.callCount = 0
527 _, err := service.AnalyzeTaskForSubtasks(context.Background(), task)
528 if err != nil {
529 b.Fatalf("AnalyzeTaskForSubtasks failed: %v", err)
530 }
531 }
532}