blob: 6ec2adc37e78ff87f4e13773230d0c3e7114cc33 [file] [log] [blame]
iomodo76f9a2d2025-07-26 12:14:40 +04001package agent
2
3import (
4 "context"
iomodo0c203b12025-07-26 19:44:57 +04005 "log/slog"
iomodo76f9a2d2025-07-26 12:14:40 +04006 "os"
7 "path/filepath"
8 "testing"
9 "time"
10
11 "github.com/iomodo/staff/git"
12 "github.com/iomodo/staff/llm"
13 "github.com/iomodo/staff/tm"
14 "github.com/iomodo/staff/tm/git_tm"
15 "github.com/stretchr/testify/assert"
16 "github.com/stretchr/testify/require"
17)
18
19// MockLLMProvider implements LLMProvider for testing
20type MockLLMProvider struct{}
21
22func (m *MockLLMProvider) ChatCompletion(ctx context.Context, req llm.ChatCompletionRequest) (*llm.ChatCompletionResponse, error) {
23 return &llm.ChatCompletionResponse{
24 ID: "mock-response-id",
25 Model: req.Model,
26 Choices: []llm.ChatCompletionChoice{
27 {
28 Index: 0,
29 Message: llm.Message{
30 Role: llm.RoleAssistant,
31 Content: "This is a mock response for testing purposes.",
32 },
33 FinishReason: "stop",
34 },
35 },
36 Usage: llm.Usage{
37 PromptTokens: 10,
38 CompletionTokens: 20,
39 TotalTokens: 30,
40 },
41 Provider: llm.ProviderOpenAI,
42 }, nil
43}
44
45func (m *MockLLMProvider) CreateEmbeddings(ctx context.Context, req llm.EmbeddingRequest) (*llm.EmbeddingResponse, error) {
46 return &llm.EmbeddingResponse{
47 Object: "list",
48 Data: []llm.Embedding{
49 {
50 Object: "embedding",
51 Embedding: []float64{0.1, 0.2, 0.3},
52 Index: 0,
53 },
54 },
55 Usage: llm.Usage{
56 PromptTokens: 5,
57 TotalTokens: 5,
58 },
59 Model: req.Model,
60 Provider: llm.ProviderOpenAI,
61 }, nil
62}
63
64func (m *MockLLMProvider) Close() error {
65 return nil
66}
67
68// MockLLMFactory implements ProviderFactory for testing
69type MockLLMFactory struct{}
70
71func (f *MockLLMFactory) CreateProvider(config llm.Config) (llm.LLMProvider, error) {
72 return &MockLLMProvider{}, nil
73}
74
75func (f *MockLLMFactory) SupportsProvider(provider llm.Provider) bool {
76 return provider == llm.ProviderOpenAI
77}
78
79func setupTestAgent(t *testing.T) (*Agent, func()) {
80 // Create temporary directories
81 tempDir, err := os.MkdirTemp("", "agent-test")
82 require.NoError(t, err)
83
84 tasksDir := filepath.Join(tempDir, "tasks")
85 workspaceDir := filepath.Join(tempDir, "workspace")
86 codeRepoDir := filepath.Join(tempDir, "code-repo")
87
88 // Create directories
89 require.NoError(t, os.MkdirAll(tasksDir, 0755))
90 require.NoError(t, os.MkdirAll(workspaceDir, 0755))
91 require.NoError(t, os.MkdirAll(codeRepoDir, 0755))
92
93 // Initialize git repositories
94 gitInterface := git.DefaultGit(tasksDir)
95 ctx := context.Background()
96
97 err = gitInterface.Init(ctx, tasksDir)
98 require.NoError(t, err)
99
100 // Set git user config
101 userConfig := git.UserConfig{
102 Name: "Test User",
103 Email: "test@example.com",
104 }
105 err = gitInterface.SetUserConfig(ctx, userConfig)
106 require.NoError(t, err)
107
iomodo0c203b12025-07-26 19:44:57 +0400108 // Create logger for testing
109 logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo}))
110
iomodo76f9a2d2025-07-26 12:14:40 +0400111 // Create task manager
iomodo0c203b12025-07-26 19:44:57 +0400112 taskManager := git_tm.NewGitTaskManagerWithLogger(gitInterface, tasksDir, logger)
iomodo76f9a2d2025-07-26 12:14:40 +0400113
114 // Create LLM config (using a mock configuration)
115 llmConfig := llm.Config{
116 Provider: llm.ProviderOpenAI,
117 APIKey: "test-key",
118 BaseURL: "https://api.openai.com/v1",
119 Timeout: 30 * time.Second,
120 }
121
122 // Create agent config
123 config := AgentConfig{
124 Name: "test-agent",
125 Role: "Test Engineer",
126 GitUsername: "test-agent",
127 GitEmail: "test-agent@test.com",
128 WorkingDir: workspaceDir,
129 LLMProvider: llm.ProviderOpenAI,
130 LLMModel: "gpt-3.5-turbo",
131 LLMConfig: llmConfig,
132 SystemPrompt: "You are a test agent. Provide simple, clear solutions.",
133 TaskManager: taskManager,
134 GitRepoPath: codeRepoDir,
135 GitRemote: "origin",
136 GitBranch: "main",
137 }
138
iomodo0c203b12025-07-26 19:44:57 +0400139 // Create logger for testing
140 logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo}))
141
iomodo76f9a2d2025-07-26 12:14:40 +0400142 // Create agent with mock LLM provider
143 agent := &Agent{
144 Config: config,
145 llmProvider: &MockLLMProvider{},
146 gitInterface: git.DefaultGit(codeRepoDir),
147 ctx: context.Background(),
148 cancel: func() {},
iomodo0c203b12025-07-26 19:44:57 +0400149 logger: logger,
iomodo76f9a2d2025-07-26 12:14:40 +0400150 }
151
152 cleanup := func() {
153 agent.Stop()
154 os.RemoveAll(tempDir)
155 }
156
157 return agent, cleanup
158}
159
160func TestNewAgent(t *testing.T) {
iomodo0c203b12025-07-26 19:44:57 +0400161 // Create temporary directories
162 tempDir, err := os.MkdirTemp("", "agent-test")
163 require.NoError(t, err)
164 defer os.RemoveAll(tempDir)
165
166 tasksDir := filepath.Join(tempDir, "tasks")
167 workspaceDir := filepath.Join(tempDir, "workspace")
168 codeRepoDir := filepath.Join(tempDir, "code-repo")
169
170 // Create directories
171 require.NoError(t, os.MkdirAll(tasksDir, 0755))
172 require.NoError(t, os.MkdirAll(workspaceDir, 0755))
173 require.NoError(t, os.MkdirAll(codeRepoDir, 0755))
174
175 // Initialize git repositories
176 gitInterface := git.DefaultGit(tasksDir)
177 ctx := context.Background()
178
179 err = gitInterface.Init(ctx, tasksDir)
180 require.NoError(t, err)
181
182 // Set git user config
183 userConfig := git.UserConfig{
184 Name: "Test User",
185 Email: "test@example.com",
186 }
187 err = gitInterface.SetUserConfig(ctx, userConfig)
188 require.NoError(t, err)
189
190 // Create logger for testing
191 logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo}))
192
193 // Create task manager
194 taskManager := git_tm.NewGitTaskManagerWithLogger(gitInterface, tasksDir, logger)
195
196 // Create LLM config (using a mock configuration)
197 llmConfig := llm.Config{
198 Provider: llm.ProviderOpenAI,
199 APIKey: "test-key",
200 BaseURL: "https://api.openai.com/v1",
201 Timeout: 30 * time.Second,
202 }
203
204 // Create agent config
205 config := AgentConfig{
206 Name: "test-agent",
207 Role: "Test Engineer",
208 GitUsername: "test-agent",
209 GitEmail: "test-agent@test.com",
210 WorkingDir: workspaceDir,
211 LLMProvider: llm.ProviderOpenAI,
212 LLMModel: "gpt-3.5-turbo",
213 LLMConfig: llmConfig,
214 SystemPrompt: "You are a test agent. Provide simple, clear solutions.",
215 TaskManager: taskManager,
216 GitRepoPath: codeRepoDir,
217 GitRemote: "origin",
218 GitBranch: "main",
219 }
220
221 // Create logger for testing
222 logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo}))
223
224 // Create agent using NewAgent function
225 agent, err := NewAgent(config, logger)
226 require.NoError(t, err)
227 defer agent.Stop()
iomodo76f9a2d2025-07-26 12:14:40 +0400228
229 assert.NotNil(t, agent)
230 assert.Equal(t, "test-agent", agent.Config.Name)
231 assert.Equal(t, "Test Engineer", agent.Config.Role)
232}
233
234func TestValidateConfig(t *testing.T) {
235 // Test valid config
236 validConfig := AgentConfig{
237 Name: "test",
238 Role: "test",
239 WorkingDir: "/tmp",
240 SystemPrompt: "test",
241 TaskManager: &git_tm.GitTaskManager{},
242 GitRepoPath: "/tmp",
243 }
244
245 err := validateConfig(validConfig)
246 assert.NoError(t, err)
247
248 // Test invalid configs
249 testCases := []struct {
250 name string
251 config AgentConfig
252 }{
253 {"empty name", AgentConfig{Role: "test", WorkingDir: "/tmp", SystemPrompt: "test", TaskManager: &git_tm.GitTaskManager{}, GitRepoPath: "/tmp"}},
254 {"empty role", AgentConfig{Name: "test", WorkingDir: "/tmp", SystemPrompt: "test", TaskManager: &git_tm.GitTaskManager{}, GitRepoPath: "/tmp"}},
255 {"empty working dir", AgentConfig{Name: "test", Role: "test", SystemPrompt: "test", TaskManager: &git_tm.GitTaskManager{}, GitRepoPath: "/tmp"}},
256 {"empty system prompt", AgentConfig{Name: "test", Role: "test", WorkingDir: "/tmp", SystemPrompt: "test", GitRepoPath: "/tmp"}},
257 {"nil task manager", AgentConfig{Name: "test", Role: "test", WorkingDir: "/tmp", SystemPrompt: "test", GitRepoPath: "/tmp"}},
258 {"empty git repo path", AgentConfig{Name: "test", Role: "test", WorkingDir: "/tmp", SystemPrompt: "test", TaskManager: &git_tm.GitTaskManager{}}},
259 }
260
261 for _, tc := range testCases {
262 t.Run(tc.name, func(t *testing.T) {
263 err := validateConfig(tc.config)
264 assert.Error(t, err)
265 })
266 }
267}
268
269func TestGenerateBranchName(t *testing.T) {
270 agent, cleanup := setupTestAgent(t)
271 defer cleanup()
272
273 task := &tm.Task{
274 ID: "task-123",
275 Title: "Implement User Authentication",
276 }
277
278 branchName := agent.generateBranchName(task)
279 assert.Contains(t, branchName, "task-123")
280 assert.Contains(t, branchName, "implement-user-authentication")
281}
282
283func TestBuildTaskPrompt(t *testing.T) {
284 agent, cleanup := setupTestAgent(t)
285 defer cleanup()
286
287 dueDate := time.Now().AddDate(0, 0, 7)
288 task := &tm.Task{
289 ID: "task-123",
290 Title: "Test Task",
291 Description: "This is a test task",
292 Priority: tm.PriorityHigh,
293 DueDate: &dueDate,
294 }
295
296 prompt := agent.buildTaskPrompt(task)
297 assert.Contains(t, prompt, "task-123")
298 assert.Contains(t, prompt, "Test Task")
299 assert.Contains(t, prompt, "This is a test task")
300 assert.Contains(t, prompt, "high")
301}
302
303func TestFormatSolution(t *testing.T) {
304 agent, cleanup := setupTestAgent(t)
305 defer cleanup()
306
307 task := &tm.Task{
308 ID: "task-123",
309 Title: "Test Task",
310 Description: "This is a test task description",
311 Priority: tm.PriorityMedium,
312 }
313
314 solution := "This is the solution to the task."
315 formatted := agent.formatSolution(task, solution)
316
317 assert.Contains(t, formatted, "# Task Solution: Test Task")
318 assert.Contains(t, formatted, "**Task ID:** task-123")
319 assert.Contains(t, formatted, "**Agent:** test-agent (Test Engineer)")
320 assert.Contains(t, formatted, "## Task Description")
321 assert.Contains(t, formatted, "This is a test task description")
322 assert.Contains(t, formatted, "## Solution")
323 assert.Contains(t, formatted, "This is the solution to the task.")
324 assert.Contains(t, formatted, "*This solution was generated by AI Agent*")
325}
326
327func TestAgentStop(t *testing.T) {
328 agent, cleanup := setupTestAgent(t)
329 defer cleanup()
330
331 // Test that Stop doesn't panic
332 assert.NotPanics(t, func() {
333 agent.Stop()
334 })
335}
336
337func TestGenerateBranchNameWithSpecialCharacters(t *testing.T) {
338 agent, cleanup := setupTestAgent(t)
339 defer cleanup()
340
341 testCases := []struct {
342 title string
343 expected string
344 }{
345 {
346 title: "Simple Task",
347 expected: "task/task-123-simple-task",
348 },
349 {
350 title: "Task with (parentheses) and [brackets]",
351 expected: "task/task-123-task-with-parentheses-and-brackets",
352 },
353 {
354 title: "Very Long Task Title That Should Be Truncated Because It Exceeds The Maximum Length Allowed For Branch Names",
355 expected: "task/task-123-very-long-task-title-that-should-be-truncated-beca",
356 },
357 }
358
359 for _, tc := range testCases {
360 t.Run(tc.title, func(t *testing.T) {
361 task := &tm.Task{
362 ID: "task-123",
363 Title: tc.title,
364 }
365
366 branchName := agent.generateBranchName(task)
367 assert.Equal(t, tc.expected, branchName)
368 })
369 }
370}
371
372func TestProcessTaskWithLLM(t *testing.T) {
373 agent, cleanup := setupTestAgent(t)
374 defer cleanup()
375
376 task := &tm.Task{
377 ID: "task-123",
378 Title: "Test Task",
379 Description: "This is a test task",
380 Priority: tm.PriorityHigh,
381 }
382
383 solution, err := agent.processTaskWithLLM(task)
384 assert.NoError(t, err)
385 assert.Contains(t, solution, "mock response")
386}
387
388func TestMockLLMProvider(t *testing.T) {
389 mockProvider := &MockLLMProvider{}
390
391 // Test ChatCompletion
392 req := llm.ChatCompletionRequest{
393 Model: "gpt-3.5-turbo",
394 Messages: []llm.Message{
395 {Role: llm.RoleUser, Content: "Hello"},
396 },
397 }
398
399 resp, err := mockProvider.ChatCompletion(context.Background(), req)
400 assert.NoError(t, err)
401 assert.NotNil(t, resp)
402 assert.Equal(t, "gpt-3.5-turbo", resp.Model)
403 assert.Len(t, resp.Choices, 1)
404 assert.Contains(t, resp.Choices[0].Message.Content, "mock response")
405
406 // Test CreateEmbeddings
407 embedReq := llm.EmbeddingRequest{
408 Input: "test",
409 Model: "text-embedding-ada-002",
410 }
411
412 embedResp, err := mockProvider.CreateEmbeddings(context.Background(), embedReq)
413 assert.NoError(t, err)
414 assert.NotNil(t, embedResp)
415 assert.Len(t, embedResp.Data, 1)
416 assert.Len(t, embedResp.Data[0].Embedding, 3)
417
418 // Test Close
419 err = mockProvider.Close()
420 assert.NoError(t, err)
421}