blob: 949a95adce479c272785bb533b5f707264ea15fe [file] [log] [blame]
Philip Zeyliger25f6ff12025-05-02 04:24:10 +00001package server_test
2
3import (
4 "bufio"
5 "context"
6 "net/http"
7 "net/http/httptest"
8 "slices"
9 "strings"
10 "sync"
11 "testing"
12 "time"
13
14 "sketch.dev/llm/conversation"
15 "sketch.dev/loop"
16 "sketch.dev/loop/server"
17)
18
19// mockAgent is a mock implementation of loop.CodingAgent for testing
20type mockAgent struct {
Philip Zeyligereab12de2025-05-14 02:35:53 +000021 mu sync.RWMutex
22 messages []loop.AgentMessage
23 messageCount int
24 currentState string
25 subscribers []chan *loop.AgentMessage
26 stateTransitionListeners []chan loop.StateTransition
27 initialCommit string
28 title string
29 branchName string
Philip Zeyligerd3ac1122025-05-14 02:54:18 +000030 workingDir string
Philip Zeyliger25f6ff12025-05-02 04:24:10 +000031}
32
33func (m *mockAgent) NewIterator(ctx context.Context, nextMessageIdx int) loop.MessageIterator {
34 m.mu.RLock()
35 // Send existing messages that should be available immediately
36 ch := make(chan *loop.AgentMessage, 100)
37 iter := &mockIterator{
38 agent: m,
39 ctx: ctx,
40 nextMessageIdx: nextMessageIdx,
41 ch: ch,
42 }
43 m.mu.RUnlock()
44 return iter
45}
46
47type mockIterator struct {
48 agent *mockAgent
49 ctx context.Context
50 nextMessageIdx int
51 ch chan *loop.AgentMessage
52 subscribed bool
53}
54
55func (m *mockIterator) Next() *loop.AgentMessage {
56 if !m.subscribed {
57 m.agent.mu.Lock()
58 m.agent.subscribers = append(m.agent.subscribers, m.ch)
59 m.agent.mu.Unlock()
60 m.subscribed = true
61 }
62
63 for {
64 select {
65 case <-m.ctx.Done():
66 return nil
67 case msg := <-m.ch:
68 return msg
69 }
70 }
71}
72
73func (m *mockIterator) Close() {
74 // Remove from subscribers using slices.Delete
75 m.agent.mu.Lock()
76 for i, ch := range m.agent.subscribers {
77 if ch == m.ch {
78 m.agent.subscribers = slices.Delete(m.agent.subscribers, i, i+1)
79 break
80 }
81 }
82 m.agent.mu.Unlock()
83 close(m.ch)
84}
85
86func (m *mockAgent) Messages(start int, end int) []loop.AgentMessage {
87 m.mu.RLock()
88 defer m.mu.RUnlock()
89
90 if start >= len(m.messages) || end > len(m.messages) || start < 0 || end < 0 {
91 return []loop.AgentMessage{}
92 }
93 return slices.Clone(m.messages[start:end])
94}
95
96func (m *mockAgent) MessageCount() int {
97 m.mu.RLock()
98 defer m.mu.RUnlock()
99 return m.messageCount
100}
101
102func (m *mockAgent) AddMessage(msg loop.AgentMessage) {
103 m.mu.Lock()
104 msg.Idx = m.messageCount
105 m.messages = append(m.messages, msg)
106 m.messageCount++
107
108 // Create a copy of subscribers to avoid holding the lock while sending
109 subscribers := make([]chan *loop.AgentMessage, len(m.subscribers))
110 copy(subscribers, m.subscribers)
111 m.mu.Unlock()
112
113 // Notify subscribers
114 msgCopy := msg // Create a copy to avoid race conditions
115 for _, ch := range subscribers {
116 ch <- &msgCopy
117 }
118}
119
Philip Zeyligereab12de2025-05-14 02:35:53 +0000120func (m *mockAgent) NewStateTransitionIterator(ctx context.Context) loop.StateTransitionIterator {
121 m.mu.Lock()
122 ch := make(chan loop.StateTransition, 10)
123 m.stateTransitionListeners = append(m.stateTransitionListeners, ch)
124 m.mu.Unlock()
125
126 return &mockStateTransitionIterator{
127 agent: m,
128 ctx: ctx,
129 ch: ch,
130 }
131}
132
133type mockStateTransitionIterator struct {
134 agent *mockAgent
135 ctx context.Context
136 ch chan loop.StateTransition
137}
138
139func (m *mockStateTransitionIterator) Next() *loop.StateTransition {
140 select {
141 case <-m.ctx.Done():
142 return nil
143 case transition, ok := <-m.ch:
144 if !ok {
145 return nil
146 }
147 transitionCopy := transition
148 return &transitionCopy
149 }
150}
151
152func (m *mockStateTransitionIterator) Close() {
153 m.agent.mu.Lock()
154 for i, ch := range m.agent.stateTransitionListeners {
155 if ch == m.ch {
156 m.agent.stateTransitionListeners = slices.Delete(m.agent.stateTransitionListeners, i, i+1)
157 break
158 }
159 }
160 m.agent.mu.Unlock()
161 close(m.ch)
162}
163
Philip Zeyliger25f6ff12025-05-02 04:24:10 +0000164func (m *mockAgent) CurrentStateName() string {
165 m.mu.RLock()
166 defer m.mu.RUnlock()
167 return m.currentState
168}
169
Philip Zeyligereab12de2025-05-14 02:35:53 +0000170func (m *mockAgent) TriggerStateTransition(from, to loop.State, event loop.TransitionEvent) {
171 m.mu.Lock()
172 m.currentState = to.String()
173 transition := loop.StateTransition{
174 From: from,
175 To: to,
176 Event: event,
177 }
178
179 // Create a copy of listeners to avoid holding the lock while sending
180 listeners := make([]chan loop.StateTransition, len(m.stateTransitionListeners))
181 copy(listeners, m.stateTransitionListeners)
182 m.mu.Unlock()
183
184 // Notify listeners
185 for _, ch := range listeners {
186 ch <- transition
187 }
188}
189
Philip Zeyliger25f6ff12025-05-02 04:24:10 +0000190func (m *mockAgent) InitialCommit() string {
191 m.mu.RLock()
192 defer m.mu.RUnlock()
193 return m.initialCommit
194}
195
Philip Zeyliger49edc922025-05-14 09:45:45 -0700196func (m *mockAgent) SketchGitBase() string {
197 m.mu.RLock()
198 defer m.mu.RUnlock()
199 return m.initialCommit
200}
201
Philip Zeyligerd3ac1122025-05-14 02:54:18 +0000202func (m *mockAgent) SketchGitBaseRef() string {
203 return "sketch-base-test-session"
204}
205
Philip Zeyliger25f6ff12025-05-02 04:24:10 +0000206func (m *mockAgent) Title() string {
207 m.mu.RLock()
208 defer m.mu.RUnlock()
209 return m.title
210}
211
212func (m *mockAgent) BranchName() string {
213 m.mu.RLock()
214 defer m.mu.RUnlock()
215 return m.branchName
216}
217
218// Other required methods of loop.CodingAgent with minimal implementation
219func (m *mockAgent) Init(loop.AgentInit) error { return nil }
220func (m *mockAgent) Ready() <-chan struct{} { ch := make(chan struct{}); close(ch); return ch }
221func (m *mockAgent) URL() string { return "http://localhost:8080" }
222func (m *mockAgent) UserMessage(ctx context.Context, msg string) {}
223func (m *mockAgent) Loop(ctx context.Context) {}
224func (m *mockAgent) CancelTurn(cause error) {}
225func (m *mockAgent) CancelToolUse(id string, cause error) error { return nil }
226func (m *mockAgent) TotalUsage() conversation.CumulativeUsage { return conversation.CumulativeUsage{} }
227func (m *mockAgent) OriginalBudget() conversation.Budget { return conversation.Budget{} }
Philip Zeyligerd3ac1122025-05-14 02:54:18 +0000228func (m *mockAgent) WorkingDir() string { return m.workingDir }
Philip Zeyliger25f6ff12025-05-02 04:24:10 +0000229func (m *mockAgent) Diff(commit *string) (string, error) { return "", nil }
230func (m *mockAgent) OS() string { return "linux" }
231func (m *mockAgent) SessionID() string { return "test-session" }
232func (m *mockAgent) OutstandingLLMCallCount() int { return 0 }
233func (m *mockAgent) OutstandingToolCalls() []string { return nil }
234func (m *mockAgent) OutsideOS() string { return "linux" }
235func (m *mockAgent) OutsideHostname() string { return "test-host" }
236func (m *mockAgent) OutsideWorkingDir() string { return "/app" }
237func (m *mockAgent) GitOrigin() string { return "" }
238func (m *mockAgent) OpenBrowser(url string) {}
239func (m *mockAgent) RestartConversation(ctx context.Context, rev string, initialPrompt string) error {
240 return nil
241}
242func (m *mockAgent) SuggestReprompt(ctx context.Context) (string, error) { return "", nil }
243func (m *mockAgent) IsInContainer() bool { return false }
244func (m *mockAgent) FirstMessageIndex() int { return 0 }
Philip Zeyliger9bca61e2025-05-22 12:40:06 -0700245func (m *mockAgent) DetectGitChanges(ctx context.Context) error { return nil } // TestSSEStream tests the SSE stream endpoint
Philip Zeyliger25f6ff12025-05-02 04:24:10 +0000246func TestSSEStream(t *testing.T) {
247 // Create a mock agent with initial messages
248 mockAgent := &mockAgent{
Philip Zeyligereab12de2025-05-14 02:35:53 +0000249 messages: []loop.AgentMessage{},
250 messageCount: 0,
251 currentState: "Ready",
252 subscribers: []chan *loop.AgentMessage{},
253 stateTransitionListeners: []chan loop.StateTransition{},
254 initialCommit: "abcd1234",
255 title: "Test Title",
256 branchName: "sketch/test-branch",
Philip Zeyliger25f6ff12025-05-02 04:24:10 +0000257 }
258
259 // Add the initial messages before creating the server
260 // to ensure they're available in the Messages slice
261 msg1 := loop.AgentMessage{
262 Type: loop.UserMessageType,
263 Content: "Hello, this is a test message",
264 Timestamp: time.Now(),
265 }
266 mockAgent.messages = append(mockAgent.messages, msg1)
267 msg1.Idx = mockAgent.messageCount
268 mockAgent.messageCount++
269
270 msg2 := loop.AgentMessage{
271 Type: loop.AgentMessageType,
272 Content: "This is a response message",
273 Timestamp: time.Now(),
274 EndOfTurn: true,
275 }
276 mockAgent.messages = append(mockAgent.messages, msg2)
277 msg2.Idx = mockAgent.messageCount
278 mockAgent.messageCount++
279
280 // Create a server with the mock agent
281 srv, err := server.New(mockAgent, nil)
282 if err != nil {
283 t.Fatalf("Failed to create server: %v", err)
284 }
285
286 // Create a test server
287 ts := httptest.NewServer(srv)
288 defer ts.Close()
289
290 // Create a context with cancellation for the client request
291 ctx, cancel := context.WithCancel(context.Background())
292
293 // Create a request to the /stream endpoint
294 req, err := http.NewRequestWithContext(ctx, "GET", ts.URL+"/stream?from=0", nil)
295 if err != nil {
296 t.Fatalf("Failed to create request: %v", err)
297 }
298
299 // Execute the request
300 res, err := http.DefaultClient.Do(req)
301 if err != nil {
302 t.Fatalf("Failed to execute request: %v", err)
303 }
304 defer res.Body.Close()
305
306 // Check response status
307 if res.StatusCode != http.StatusOK {
308 t.Fatalf("Expected status OK, got %v", res.Status)
309 }
310
311 // Check content type
312 if contentType := res.Header.Get("Content-Type"); contentType != "text/event-stream" {
313 t.Fatalf("Expected Content-Type text/event-stream, got %s", contentType)
314 }
315
316 // Read response events using a scanner
317 scanner := bufio.NewScanner(res.Body)
318
319 // Track events received
320 eventsReceived := map[string]int{
321 "state": 0,
322 "message": 0,
323 "heartbeat": 0,
324 }
325
326 // Read for a short time to capture initial state and messages
327 dataLines := []string{}
328 eventType := ""
329
330 go func() {
331 // After reading for a while, add a new message to test real-time updates
332 time.Sleep(500 * time.Millisecond)
333
334 mockAgent.AddMessage(loop.AgentMessage{
335 Type: loop.ToolUseMessageType,
336 Content: "This is a new real-time message",
337 Timestamp: time.Now(),
338 ToolName: "test_tool",
339 })
340
Philip Zeyligereab12de2025-05-14 02:35:53 +0000341 // Trigger a state transition to test state updates
342 time.Sleep(200 * time.Millisecond)
343 mockAgent.TriggerStateTransition(loop.StateReady, loop.StateSendingToLLM, loop.TransitionEvent{
344 Description: "Agent started thinking",
345 Data: "start_thinking",
346 })
347
Philip Zeyliger25f6ff12025-05-02 04:24:10 +0000348 // Let it process for longer
349 time.Sleep(1000 * time.Millisecond)
350 cancel() // Cancel to end the test
351 }()
352
353 // Read events
354 for scanner.Scan() {
355 line := scanner.Text()
356
357 if strings.HasPrefix(line, "event: ") {
358 eventType = strings.TrimPrefix(line, "event: ")
359 eventsReceived[eventType]++
360 } else if strings.HasPrefix(line, "data: ") {
361 dataLines = append(dataLines, line)
362 } else if line == "" && eventType != "" {
363 // End of event
364 eventType = ""
365 }
366
367 // Break if context is done
368 if ctx.Err() != nil {
369 break
370 }
371 }
372
373 if err := scanner.Err(); err != nil && ctx.Err() == nil {
374 t.Fatalf("Scanner error: %v", err)
375 }
376
377 // Simplified validation - just make sure we received something
378 t.Logf("Events received: %v", eventsReceived)
379 t.Logf("Data lines received: %d", len(dataLines))
380
381 // Basic validation that we received at least some events
382 if eventsReceived["state"] == 0 && eventsReceived["message"] == 0 {
383 t.Errorf("Did not receive any events")
384 }
385}
Philip Zeyligerd3ac1122025-05-14 02:54:18 +0000386
387func TestGitRawDiffHandler(t *testing.T) {
388 // Create a mock agent
389 mockAgent := &mockAgent{
390 workingDir: t.TempDir(), // Use a temp directory
391 }
392
393 // Create the server with the mock agent
394 server, err := server.New(mockAgent, nil)
395 if err != nil {
396 t.Fatalf("Failed to create server: %v", err)
397 }
398
399 // Create a test HTTP server
400 testServer := httptest.NewServer(server)
401 defer testServer.Close()
402
403 // Test missing parameters
404 resp, err := http.Get(testServer.URL + "/git/rawdiff")
405 if err != nil {
406 t.Fatalf("Failed to make HTTP request: %v", err)
407 }
408 if resp.StatusCode != http.StatusBadRequest {
409 t.Errorf("Expected status bad request, got: %d", resp.StatusCode)
410 }
411
412 // Test with commit parameter (this will fail due to no git repo, but we're testing the API, not git)
413 resp, err = http.Get(testServer.URL + "/git/rawdiff?commit=HEAD")
414 if err != nil {
415 t.Fatalf("Failed to make HTTP request: %v", err)
416 }
417 // We expect an error since there's no git repository, but the request should be processed
418 if resp.StatusCode != http.StatusInternalServerError {
419 t.Errorf("Expected status 500, got: %d", resp.StatusCode)
420 }
421
422 // Test with from/to parameters
423 resp, err = http.Get(testServer.URL + "/git/rawdiff?from=HEAD~1&to=HEAD")
424 if err != nil {
425 t.Fatalf("Failed to make HTTP request: %v", err)
426 }
427 // We expect an error since there's no git repository, but the request should be processed
428 if resp.StatusCode != http.StatusInternalServerError {
429 t.Errorf("Expected status 500, got: %d", resp.StatusCode)
430 }
431}
432
433func TestGitShowHandler(t *testing.T) {
434 // Create a mock agent
435 mockAgent := &mockAgent{
436 workingDir: t.TempDir(), // Use a temp directory
437 }
438
439 // Create the server with the mock agent
440 server, err := server.New(mockAgent, nil)
441 if err != nil {
442 t.Fatalf("Failed to create server: %v", err)
443 }
444
445 // Create a test HTTP server
446 testServer := httptest.NewServer(server)
447 defer testServer.Close()
448
449 // Test missing parameter
450 resp, err := http.Get(testServer.URL + "/git/show")
451 if err != nil {
452 t.Fatalf("Failed to make HTTP request: %v", err)
453 }
454 if resp.StatusCode != http.StatusBadRequest {
455 t.Errorf("Expected status bad request, got: %d", resp.StatusCode)
456 }
457
458 // Test with hash parameter (this will fail due to no git repo, but we're testing the API, not git)
459 resp, err = http.Get(testServer.URL + "/git/show?hash=HEAD")
460 if err != nil {
461 t.Fatalf("Failed to make HTTP request: %v", err)
462 }
463 // We expect an error since there's no git repository, but the request should be processed
464 if resp.StatusCode != http.StatusInternalServerError {
465 t.Errorf("Expected status 500, got: %d", resp.StatusCode)
466 }
467}
468
469// Removing duplicate method definition