blob: 1d94211a4890a5d8eb969884f0cb246f122a4f7f [file] [log] [blame]
iomodod60a5352025-07-28 12:56:22 +04001package subtasks
2
3import (
4 "context"
5 "strings"
6 "testing"
7 "time"
8
9 "github.com/iomodo/staff/llm"
10 "github.com/iomodo/staff/tm"
11)
12
// MockLLMProvider is a scripted stand-in for an LLM provider used by the tests
// in this file. Each ChatCompletion call hands back the next string from
// responses, in order.
type MockLLMProvider struct {
	responses []string // scripted completion contents, consumed front to back
	callCount int      // index of the next response to hand out
}

// NewMockLLMProvider returns a mock that will replay responses in order,
// starting from the first one.
func NewMockLLMProvider(responses []string) *MockLLMProvider {
	// callCount deliberately relies on its zero value: replay starts at index 0.
	return &MockLLMProvider{responses: responses}
}
25
26func (m *MockLLMProvider) ChatCompletion(ctx context.Context, req llm.ChatCompletionRequest) (*llm.ChatCompletionResponse, error) {
27 if m.callCount >= len(m.responses) {
28 return nil, nil
29 }
30
31 response := m.responses[m.callCount]
32 m.callCount++
33
34 return &llm.ChatCompletionResponse{
35 ID: "mock-response",
36 Object: "chat.completion",
37 Created: time.Now().Unix(),
38 Model: req.Model,
39 Choices: []llm.ChatCompletionChoice{
40 {
41 Index: 0,
42 Message: llm.Message{
43 Role: llm.RoleAssistant,
44 Content: response,
45 },
46 FinishReason: "stop",
47 },
48 },
49 Usage: llm.Usage{
50 PromptTokens: 100,
51 CompletionTokens: 300,
52 TotalTokens: 400,
53 },
54 }, nil
55}
56
57func (m *MockLLMProvider) CreateEmbeddings(ctx context.Context, req llm.EmbeddingRequest) (*llm.EmbeddingResponse, error) {
58 return &llm.EmbeddingResponse{
59 Object: "list",
60 Data: []llm.Embedding{
61 {
62 Object: "embedding",
63 Index: 0,
64 Embedding: make([]float64, 1536),
65 },
66 },
67 Model: req.Model,
68 Usage: llm.Usage{
69 PromptTokens: 50,
70 TotalTokens: 50,
71 },
72 }, nil
73}
74
75func (m *MockLLMProvider) Close() error {
76 return nil
77}
78
79func TestNewSubtaskService(t *testing.T) {
80 mockProvider := NewMockLLMProvider([]string{})
81 agentRoles := []string{"backend", "frontend", "qa"}
82
83 service := NewSubtaskService(mockProvider, nil, agentRoles)
84
85 if service == nil {
86 t.Fatal("NewSubtaskService returned nil")
87 }
88
89 if service.llmProvider != mockProvider {
90 t.Error("LLM provider not set correctly")
91 }
92
93 if len(service.agentRoles) != 3 {
94 t.Errorf("Expected 3 agent roles, got %d", len(service.agentRoles))
95 }
96}
97
98func TestAnalyzeTaskForSubtasks(t *testing.T) {
99 jsonResponse := `{
100 "analysis_summary": "This task requires breaking down into multiple components",
101 "subtasks": [
102 {
103 "title": "Backend Development",
104 "description": "Implement server-side logic",
105 "priority": "high",
106 "assigned_to": "backend",
107 "estimated_hours": 16,
108 "dependencies": []
109 },
110 {
111 "title": "Frontend Development",
112 "description": "Build user interface",
113 "priority": "medium",
114 "assigned_to": "frontend",
115 "estimated_hours": 12,
116 "dependencies": ["0"]
117 }
118 ],
119 "recommended_approach": "Start with backend then frontend",
120 "estimated_total_hours": 28,
121 "risk_assessment": "Medium complexity with API integration risks"
122}`
123
124 mockProvider := NewMockLLMProvider([]string{jsonResponse})
125 agentRoles := []string{"backend", "frontend", "qa"}
126 service := NewSubtaskService(mockProvider, nil, agentRoles)
127
128 task := &tm.Task{
129 ID: "test-task-123",
130 Title: "Build authentication system",
131 Description: "Implement user login and registration",
132 Priority: tm.PriorityHigh,
133 Status: tm.StatusToDo,
134 CreatedAt: time.Now(),
135 UpdatedAt: time.Now(),
136 }
137
138 analysis, err := service.AnalyzeTaskForSubtasks(context.Background(), task)
139 if err != nil {
140 t.Fatalf("AnalyzeTaskForSubtasks failed: %v", err)
141 }
142
143 if analysis.ParentTaskID != task.ID {
144 t.Errorf("Expected parent task ID %s, got %s", task.ID, analysis.ParentTaskID)
145 }
146
147 if analysis.AnalysisSummary == "" {
148 t.Error("Analysis summary should not be empty")
149 }
150
151 if len(analysis.Subtasks) != 2 {
152 t.Errorf("Expected 2 subtasks, got %d", len(analysis.Subtasks))
153 }
154
155 // Test first subtask
156 subtask1 := analysis.Subtasks[0]
157 if subtask1.Title != "Backend Development" {
158 t.Errorf("Expected title 'Backend Development', got %s", subtask1.Title)
159 }
160 if subtask1.Priority != tm.PriorityHigh {
161 t.Errorf("Expected high priority, got %s", subtask1.Priority)
162 }
163 if subtask1.AssignedTo != "backend" {
164 t.Errorf("Expected assigned_to 'backend', got %s", subtask1.AssignedTo)
165 }
166 if subtask1.EstimatedHours != 16 {
167 t.Errorf("Expected 16 hours, got %d", subtask1.EstimatedHours)
168 }
169
170 // Test second subtask
171 subtask2 := analysis.Subtasks[1]
172 if subtask2.Title != "Frontend Development" {
173 t.Errorf("Expected title 'Frontend Development', got %s", subtask2.Title)
174 }
175 if subtask2.Priority != tm.PriorityMedium {
176 t.Errorf("Expected medium priority, got %s", subtask2.Priority)
177 }
178 if len(subtask2.Dependencies) != 1 || subtask2.Dependencies[0] != "0" {
179 t.Errorf("Expected dependencies [0], got %v", subtask2.Dependencies)
180 }
181
182 if analysis.EstimatedTotalHours != 28 {
183 t.Errorf("Expected 28 total hours, got %d", analysis.EstimatedTotalHours)
184 }
185}
186
187func TestAnalyzeTaskForSubtasks_InvalidJSON(t *testing.T) {
188 invalidResponse := "This is not valid JSON"
189
190 mockProvider := NewMockLLMProvider([]string{invalidResponse})
191 agentRoles := []string{"backend", "frontend"}
192 service := NewSubtaskService(mockProvider, nil, agentRoles)
193
194 task := &tm.Task{
195 ID: "test-task-123",
196 Title: "Test task",
197 }
198
199 _, err := service.AnalyzeTaskForSubtasks(context.Background(), task)
200 if err == nil {
201 t.Error("Expected error for invalid JSON, got nil")
202 }
203
204 if !strings.Contains(err.Error(), "no JSON found") {
205 t.Errorf("Expected 'no JSON found' error, got: %v", err)
206 }
207}
208
209func TestAnalyzeTaskForSubtasks_InvalidAgentRole(t *testing.T) {
210 jsonResponse := `{
211 "analysis_summary": "Test analysis",
212 "subtasks": [
213 {
214 "title": "Invalid Assignment",
215 "description": "Test subtask",
216 "priority": "high",
217 "assigned_to": "invalid_role",
218 "estimated_hours": 8,
219 "dependencies": []
220 }
221 ],
222 "recommended_approach": "Test approach",
223 "estimated_total_hours": 8
224}`
225
226 mockProvider := NewMockLLMProvider([]string{jsonResponse})
227 agentRoles := []string{"backend", "frontend"}
228 service := NewSubtaskService(mockProvider, nil, agentRoles)
229
230 task := &tm.Task{
231 ID: "test-task-123",
232 Title: "Test task",
233 }
234
235 analysis, err := service.AnalyzeTaskForSubtasks(context.Background(), task)
236 if err != nil {
237 t.Fatalf("AnalyzeTaskForSubtasks failed: %v", err)
238 }
239
240 // Should fix invalid agent assignment to first available role
241 if analysis.Subtasks[0].AssignedTo != "backend" {
242 t.Errorf("Expected fixed assignment 'backend', got %s", analysis.Subtasks[0].AssignedTo)
243 }
244}
245
246func TestGenerateSubtaskPR(t *testing.T) {
247 mockProvider := NewMockLLMProvider([]string{})
248 service := NewSubtaskService(mockProvider, nil, []string{"backend"})
249
250 analysis := &tm.SubtaskAnalysis{
251 ParentTaskID: "task-123",
252 AnalysisSummary: "Test analysis summary",
253 RecommendedApproach: "Test approach",
254 EstimatedTotalHours: 40,
255 RiskAssessment: "Low risk",
256 Subtasks: []tm.SubtaskProposal{
257 {
258 Title: "Test Subtask",
259 Description: "Test description",
260 Priority: tm.PriorityHigh,
261 AssignedTo: "backend",
262 EstimatedHours: 8,
263 Dependencies: []string{},
264 },
265 },
266 }
267
268 prURL, err := service.GenerateSubtaskPR(context.Background(), analysis)
269 if err != nil {
270 t.Fatalf("GenerateSubtaskPR failed: %v", err)
271 }
272
273 expectedURL := "https://github.com/example/repo/pull/subtasks-task-123"
274 if prURL != expectedURL {
275 t.Errorf("Expected PR URL %s, got %s", expectedURL, prURL)
276 }
277}
278
279func TestBuildSubtaskAnalysisPrompt(t *testing.T) {
280 mockProvider := NewMockLLMProvider([]string{})
281 agentRoles := []string{"backend", "frontend", "qa"}
282 service := NewSubtaskService(mockProvider, nil, agentRoles)
283
284 task := &tm.Task{
285 Title: "Build authentication system",
286 Description: "Implement user login and registration with OAuth",
287 Priority: tm.PriorityHigh,
288 Status: tm.StatusToDo,
289 }
290
291 prompt := service.buildSubtaskAnalysisPrompt(task)
292
293 if !strings.Contains(prompt, task.Title) {
294 t.Error("Prompt should contain task title")
295 }
296
297 if !strings.Contains(prompt, task.Description) {
298 t.Error("Prompt should contain task description")
299 }
300
301 if !strings.Contains(prompt, string(task.Priority)) {
302 t.Error("Prompt should contain task priority")
303 }
304
305 if !strings.Contains(prompt, string(task.Status)) {
306 t.Error("Prompt should contain task status")
307 }
308}
309
310func TestGetSubtaskAnalysisSystemPrompt(t *testing.T) {
311 mockProvider := NewMockLLMProvider([]string{})
312 agentRoles := []string{"backend", "frontend", "qa", "devops"}
313 service := NewSubtaskService(mockProvider, nil, agentRoles)
314
315 systemPrompt := service.getSubtaskAnalysisSystemPrompt()
316
317 if !strings.Contains(systemPrompt, "backend") {
318 t.Error("System prompt should contain backend role")
319 }
320
321 if !strings.Contains(systemPrompt, "frontend") {
322 t.Error("System prompt should contain frontend role")
323 }
324
325 if !strings.Contains(systemPrompt, "JSON") {
326 t.Error("System prompt should mention JSON format")
327 }
328
329 if !strings.Contains(systemPrompt, "subtasks") {
330 t.Error("System prompt should mention subtasks")
331 }
332}
333
334func TestIsValidAgentRole(t *testing.T) {
335 mockProvider := NewMockLLMProvider([]string{})
336 agentRoles := []string{"backend", "frontend", "qa"}
337 service := NewSubtaskService(mockProvider, nil, agentRoles)
338
339 if !service.isValidAgentRole("backend") {
340 t.Error("'backend' should be a valid agent role")
341 }
342
343 if !service.isValidAgentRole("frontend") {
344 t.Error("'frontend' should be a valid agent role")
345 }
346
347 if service.isValidAgentRole("invalid") {
348 t.Error("'invalid' should not be a valid agent role")
349 }
350
351 if service.isValidAgentRole("") {
352 t.Error("Empty string should not be a valid agent role")
353 }
354}
355
356func TestParseSubtaskAnalysis_Priority(t *testing.T) {
357 mockProvider := NewMockLLMProvider([]string{})
358 service := NewSubtaskService(mockProvider, nil, []string{"backend"})
359
360 tests := []struct {
361 input string
362 expected tm.TaskPriority
363 }{
364 {"high", tm.PriorityHigh},
365 {"HIGH", tm.PriorityHigh},
366 {"High", tm.PriorityHigh},
367 {"low", tm.PriorityLow},
368 {"LOW", tm.PriorityLow},
369 {"Low", tm.PriorityLow},
370 {"medium", tm.PriorityMedium},
371 {"MEDIUM", tm.PriorityMedium},
372 {"Medium", tm.PriorityMedium},
373 {"invalid", tm.PriorityMedium}, // default
374 {"", tm.PriorityMedium}, // default
375 }
376
377 for _, test := range tests {
378 jsonResponse := `{
379 "analysis_summary": "Test",
380 "subtasks": [{
381 "title": "Test",
382 "description": "Test",
383 "priority": "` + test.input + `",
384 "assigned_to": "backend",
385 "estimated_hours": 8,
386 "dependencies": []
387 }],
388 "recommended_approach": "Test",
389 "estimated_total_hours": 8
390}`
391
392 analysis, err := service.parseSubtaskAnalysis(jsonResponse, "test-task")
393 if err != nil {
394 t.Fatalf("parseSubtaskAnalysis failed for priority '%s': %v", test.input, err)
395 }
396
397 if len(analysis.Subtasks) != 1 {
398 t.Fatalf("Expected 1 subtask, got %d", len(analysis.Subtasks))
399 }
400
401 if analysis.Subtasks[0].Priority != test.expected {
402 t.Errorf("For priority '%s', expected %s, got %s",
403 test.input, test.expected, analysis.Subtasks[0].Priority)
404 }
405 }
406}
407
408func TestClose(t *testing.T) {
409 mockProvider := NewMockLLMProvider([]string{})
410 service := NewSubtaskService(mockProvider, nil, []string{"backend"})
411
412 err := service.Close()
413 if err != nil {
414 t.Errorf("Close should not return error, got: %v", err)
415 }
416}
417
418// Benchmark tests
419func BenchmarkAnalyzeTaskForSubtasks(b *testing.B) {
420 jsonResponse := `{
421 "analysis_summary": "Benchmark test",
422 "subtasks": [
423 {
424 "title": "Benchmark Subtask",
425 "description": "Benchmark description",
426 "priority": "high",
427 "assigned_to": "backend",
428 "estimated_hours": 8,
429 "dependencies": []
430 }
431 ],
432 "recommended_approach": "Benchmark approach",
433 "estimated_total_hours": 8
434}`
435
436 mockProvider := NewMockLLMProvider([]string{jsonResponse})
437 service := NewSubtaskService(mockProvider, nil, []string{"backend", "frontend"})
438
439 task := &tm.Task{
440 ID: "benchmark-task",
441 Title: "Benchmark Task",
442 Description: "Task for benchmarking",
443 Priority: tm.PriorityHigh,
444 }
445
446 b.ResetTimer()
447 for i := 0; i < b.N; i++ {
448 // Reset mock provider for each iteration
449 mockProvider.callCount = 0
450 _, err := service.AnalyzeTaskForSubtasks(context.Background(), task)
451 if err != nil {
452 b.Fatalf("AnalyzeTaskForSubtasks failed: %v", err)
453 }
454 }
455}