blob: fc525322fa47570478f3ed4f3afc5d5377da9da8 [file] [log] [blame]
Philip Zeyliger25f6ff12025-05-02 04:24:10 +00001package server_test
2
3import (
4 "bufio"
5 "context"
6 "net/http"
7 "net/http/httptest"
8 "slices"
9 "strings"
10 "sync"
11 "testing"
12 "time"
13
14 "sketch.dev/llm/conversation"
15 "sketch.dev/loop"
16 "sketch.dev/loop/server"
17)
18
19// mockAgent is a mock implementation of loop.CodingAgent for testing
20type mockAgent struct {
Philip Zeyligereab12de2025-05-14 02:35:53 +000021 mu sync.RWMutex
22 messages []loop.AgentMessage
23 messageCount int
24 currentState string
25 subscribers []chan *loop.AgentMessage
26 stateTransitionListeners []chan loop.StateTransition
27 initialCommit string
28 title string
29 branchName string
Philip Zeyligerd3ac1122025-05-14 02:54:18 +000030 workingDir string
Philip Zeyliger25f6ff12025-05-02 04:24:10 +000031}
32
33func (m *mockAgent) NewIterator(ctx context.Context, nextMessageIdx int) loop.MessageIterator {
34 m.mu.RLock()
35 // Send existing messages that should be available immediately
36 ch := make(chan *loop.AgentMessage, 100)
37 iter := &mockIterator{
38 agent: m,
39 ctx: ctx,
40 nextMessageIdx: nextMessageIdx,
41 ch: ch,
42 }
43 m.mu.RUnlock()
44 return iter
45}
46
47type mockIterator struct {
48 agent *mockAgent
49 ctx context.Context
50 nextMessageIdx int
51 ch chan *loop.AgentMessage
52 subscribed bool
53}
54
55func (m *mockIterator) Next() *loop.AgentMessage {
56 if !m.subscribed {
57 m.agent.mu.Lock()
58 m.agent.subscribers = append(m.agent.subscribers, m.ch)
59 m.agent.mu.Unlock()
60 m.subscribed = true
61 }
62
63 for {
64 select {
65 case <-m.ctx.Done():
66 return nil
67 case msg := <-m.ch:
68 return msg
69 }
70 }
71}
72
73func (m *mockIterator) Close() {
74 // Remove from subscribers using slices.Delete
75 m.agent.mu.Lock()
76 for i, ch := range m.agent.subscribers {
77 if ch == m.ch {
78 m.agent.subscribers = slices.Delete(m.agent.subscribers, i, i+1)
79 break
80 }
81 }
82 m.agent.mu.Unlock()
83 close(m.ch)
84}
85
86func (m *mockAgent) Messages(start int, end int) []loop.AgentMessage {
87 m.mu.RLock()
88 defer m.mu.RUnlock()
89
90 if start >= len(m.messages) || end > len(m.messages) || start < 0 || end < 0 {
91 return []loop.AgentMessage{}
92 }
93 return slices.Clone(m.messages[start:end])
94}
95
96func (m *mockAgent) MessageCount() int {
97 m.mu.RLock()
98 defer m.mu.RUnlock()
99 return m.messageCount
100}
101
102func (m *mockAgent) AddMessage(msg loop.AgentMessage) {
103 m.mu.Lock()
104 msg.Idx = m.messageCount
105 m.messages = append(m.messages, msg)
106 m.messageCount++
107
108 // Create a copy of subscribers to avoid holding the lock while sending
109 subscribers := make([]chan *loop.AgentMessage, len(m.subscribers))
110 copy(subscribers, m.subscribers)
111 m.mu.Unlock()
112
113 // Notify subscribers
114 msgCopy := msg // Create a copy to avoid race conditions
115 for _, ch := range subscribers {
116 ch <- &msgCopy
117 }
118}
119
Philip Zeyligereab12de2025-05-14 02:35:53 +0000120func (m *mockAgent) NewStateTransitionIterator(ctx context.Context) loop.StateTransitionIterator {
121 m.mu.Lock()
122 ch := make(chan loop.StateTransition, 10)
123 m.stateTransitionListeners = append(m.stateTransitionListeners, ch)
124 m.mu.Unlock()
125
126 return &mockStateTransitionIterator{
127 agent: m,
128 ctx: ctx,
129 ch: ch,
130 }
131}
132
133type mockStateTransitionIterator struct {
134 agent *mockAgent
135 ctx context.Context
136 ch chan loop.StateTransition
137}
138
139func (m *mockStateTransitionIterator) Next() *loop.StateTransition {
140 select {
141 case <-m.ctx.Done():
142 return nil
143 case transition, ok := <-m.ch:
144 if !ok {
145 return nil
146 }
147 transitionCopy := transition
148 return &transitionCopy
149 }
150}
151
152func (m *mockStateTransitionIterator) Close() {
153 m.agent.mu.Lock()
154 for i, ch := range m.agent.stateTransitionListeners {
155 if ch == m.ch {
156 m.agent.stateTransitionListeners = slices.Delete(m.agent.stateTransitionListeners, i, i+1)
157 break
158 }
159 }
160 m.agent.mu.Unlock()
161 close(m.ch)
162}
163
Philip Zeyliger25f6ff12025-05-02 04:24:10 +0000164func (m *mockAgent) CurrentStateName() string {
165 m.mu.RLock()
166 defer m.mu.RUnlock()
167 return m.currentState
168}
169
Philip Zeyligereab12de2025-05-14 02:35:53 +0000170func (m *mockAgent) TriggerStateTransition(from, to loop.State, event loop.TransitionEvent) {
171 m.mu.Lock()
172 m.currentState = to.String()
173 transition := loop.StateTransition{
174 From: from,
175 To: to,
176 Event: event,
177 }
178
179 // Create a copy of listeners to avoid holding the lock while sending
180 listeners := make([]chan loop.StateTransition, len(m.stateTransitionListeners))
181 copy(listeners, m.stateTransitionListeners)
182 m.mu.Unlock()
183
184 // Notify listeners
185 for _, ch := range listeners {
186 ch <- transition
187 }
188}
189
Philip Zeyliger25f6ff12025-05-02 04:24:10 +0000190func (m *mockAgent) InitialCommit() string {
191 m.mu.RLock()
192 defer m.mu.RUnlock()
193 return m.initialCommit
194}
195
Philip Zeyliger49edc922025-05-14 09:45:45 -0700196func (m *mockAgent) SketchGitBase() string {
197 m.mu.RLock()
198 defer m.mu.RUnlock()
199 return m.initialCommit
200}
201
Philip Zeyligerd3ac1122025-05-14 02:54:18 +0000202func (m *mockAgent) SketchGitBaseRef() string {
203 return "sketch-base-test-session"
204}
205
Philip Zeyliger25f6ff12025-05-02 04:24:10 +0000206func (m *mockAgent) Title() string {
207 m.mu.RLock()
208 defer m.mu.RUnlock()
209 return m.title
210}
211
212func (m *mockAgent) BranchName() string {
213 m.mu.RLock()
214 defer m.mu.RUnlock()
215 return m.branchName
216}
217
218// Other required methods of loop.CodingAgent with minimal implementation
219func (m *mockAgent) Init(loop.AgentInit) error { return nil }
220func (m *mockAgent) Ready() <-chan struct{} { ch := make(chan struct{}); close(ch); return ch }
221func (m *mockAgent) URL() string { return "http://localhost:8080" }
222func (m *mockAgent) UserMessage(ctx context.Context, msg string) {}
223func (m *mockAgent) Loop(ctx context.Context) {}
224func (m *mockAgent) CancelTurn(cause error) {}
225func (m *mockAgent) CancelToolUse(id string, cause error) error { return nil }
226func (m *mockAgent) TotalUsage() conversation.CumulativeUsage { return conversation.CumulativeUsage{} }
227func (m *mockAgent) OriginalBudget() conversation.Budget { return conversation.Budget{} }
Philip Zeyligerd3ac1122025-05-14 02:54:18 +0000228func (m *mockAgent) WorkingDir() string { return m.workingDir }
Philip Zeyliger25f6ff12025-05-02 04:24:10 +0000229func (m *mockAgent) Diff(commit *string) (string, error) { return "", nil }
230func (m *mockAgent) OS() string { return "linux" }
231func (m *mockAgent) SessionID() string { return "test-session" }
232func (m *mockAgent) OutstandingLLMCallCount() int { return 0 }
233func (m *mockAgent) OutstandingToolCalls() []string { return nil }
234func (m *mockAgent) OutsideOS() string { return "linux" }
235func (m *mockAgent) OutsideHostname() string { return "test-host" }
236func (m *mockAgent) OutsideWorkingDir() string { return "/app" }
237func (m *mockAgent) GitOrigin() string { return "" }
238func (m *mockAgent) OpenBrowser(url string) {}
239func (m *mockAgent) RestartConversation(ctx context.Context, rev string, initialPrompt string) error {
240 return nil
241}
242func (m *mockAgent) SuggestReprompt(ctx context.Context) (string, error) { return "", nil }
243func (m *mockAgent) IsInContainer() bool { return false }
244func (m *mockAgent) FirstMessageIndex() int { return 0 }
245
246// TestSSEStream tests the SSE stream endpoint
247func TestSSEStream(t *testing.T) {
248 // Create a mock agent with initial messages
249 mockAgent := &mockAgent{
Philip Zeyligereab12de2025-05-14 02:35:53 +0000250 messages: []loop.AgentMessage{},
251 messageCount: 0,
252 currentState: "Ready",
253 subscribers: []chan *loop.AgentMessage{},
254 stateTransitionListeners: []chan loop.StateTransition{},
255 initialCommit: "abcd1234",
256 title: "Test Title",
257 branchName: "sketch/test-branch",
Philip Zeyliger25f6ff12025-05-02 04:24:10 +0000258 }
259
260 // Add the initial messages before creating the server
261 // to ensure they're available in the Messages slice
262 msg1 := loop.AgentMessage{
263 Type: loop.UserMessageType,
264 Content: "Hello, this is a test message",
265 Timestamp: time.Now(),
266 }
267 mockAgent.messages = append(mockAgent.messages, msg1)
268 msg1.Idx = mockAgent.messageCount
269 mockAgent.messageCount++
270
271 msg2 := loop.AgentMessage{
272 Type: loop.AgentMessageType,
273 Content: "This is a response message",
274 Timestamp: time.Now(),
275 EndOfTurn: true,
276 }
277 mockAgent.messages = append(mockAgent.messages, msg2)
278 msg2.Idx = mockAgent.messageCount
279 mockAgent.messageCount++
280
281 // Create a server with the mock agent
282 srv, err := server.New(mockAgent, nil)
283 if err != nil {
284 t.Fatalf("Failed to create server: %v", err)
285 }
286
287 // Create a test server
288 ts := httptest.NewServer(srv)
289 defer ts.Close()
290
291 // Create a context with cancellation for the client request
292 ctx, cancel := context.WithCancel(context.Background())
293
294 // Create a request to the /stream endpoint
295 req, err := http.NewRequestWithContext(ctx, "GET", ts.URL+"/stream?from=0", nil)
296 if err != nil {
297 t.Fatalf("Failed to create request: %v", err)
298 }
299
300 // Execute the request
301 res, err := http.DefaultClient.Do(req)
302 if err != nil {
303 t.Fatalf("Failed to execute request: %v", err)
304 }
305 defer res.Body.Close()
306
307 // Check response status
308 if res.StatusCode != http.StatusOK {
309 t.Fatalf("Expected status OK, got %v", res.Status)
310 }
311
312 // Check content type
313 if contentType := res.Header.Get("Content-Type"); contentType != "text/event-stream" {
314 t.Fatalf("Expected Content-Type text/event-stream, got %s", contentType)
315 }
316
317 // Read response events using a scanner
318 scanner := bufio.NewScanner(res.Body)
319
320 // Track events received
321 eventsReceived := map[string]int{
322 "state": 0,
323 "message": 0,
324 "heartbeat": 0,
325 }
326
327 // Read for a short time to capture initial state and messages
328 dataLines := []string{}
329 eventType := ""
330
331 go func() {
332 // After reading for a while, add a new message to test real-time updates
333 time.Sleep(500 * time.Millisecond)
334
335 mockAgent.AddMessage(loop.AgentMessage{
336 Type: loop.ToolUseMessageType,
337 Content: "This is a new real-time message",
338 Timestamp: time.Now(),
339 ToolName: "test_tool",
340 })
341
Philip Zeyligereab12de2025-05-14 02:35:53 +0000342 // Trigger a state transition to test state updates
343 time.Sleep(200 * time.Millisecond)
344 mockAgent.TriggerStateTransition(loop.StateReady, loop.StateSendingToLLM, loop.TransitionEvent{
345 Description: "Agent started thinking",
346 Data: "start_thinking",
347 })
348
Philip Zeyliger25f6ff12025-05-02 04:24:10 +0000349 // Let it process for longer
350 time.Sleep(1000 * time.Millisecond)
351 cancel() // Cancel to end the test
352 }()
353
354 // Read events
355 for scanner.Scan() {
356 line := scanner.Text()
357
358 if strings.HasPrefix(line, "event: ") {
359 eventType = strings.TrimPrefix(line, "event: ")
360 eventsReceived[eventType]++
361 } else if strings.HasPrefix(line, "data: ") {
362 dataLines = append(dataLines, line)
363 } else if line == "" && eventType != "" {
364 // End of event
365 eventType = ""
366 }
367
368 // Break if context is done
369 if ctx.Err() != nil {
370 break
371 }
372 }
373
374 if err := scanner.Err(); err != nil && ctx.Err() == nil {
375 t.Fatalf("Scanner error: %v", err)
376 }
377
378 // Simplified validation - just make sure we received something
379 t.Logf("Events received: %v", eventsReceived)
380 t.Logf("Data lines received: %d", len(dataLines))
381
382 // Basic validation that we received at least some events
383 if eventsReceived["state"] == 0 && eventsReceived["message"] == 0 {
384 t.Errorf("Did not receive any events")
385 }
386}
Philip Zeyligerd3ac1122025-05-14 02:54:18 +0000387
388func TestGitRawDiffHandler(t *testing.T) {
389 // Create a mock agent
390 mockAgent := &mockAgent{
391 workingDir: t.TempDir(), // Use a temp directory
392 }
393
394 // Create the server with the mock agent
395 server, err := server.New(mockAgent, nil)
396 if err != nil {
397 t.Fatalf("Failed to create server: %v", err)
398 }
399
400 // Create a test HTTP server
401 testServer := httptest.NewServer(server)
402 defer testServer.Close()
403
404 // Test missing parameters
405 resp, err := http.Get(testServer.URL + "/git/rawdiff")
406 if err != nil {
407 t.Fatalf("Failed to make HTTP request: %v", err)
408 }
409 if resp.StatusCode != http.StatusBadRequest {
410 t.Errorf("Expected status bad request, got: %d", resp.StatusCode)
411 }
412
413 // Test with commit parameter (this will fail due to no git repo, but we're testing the API, not git)
414 resp, err = http.Get(testServer.URL + "/git/rawdiff?commit=HEAD")
415 if err != nil {
416 t.Fatalf("Failed to make HTTP request: %v", err)
417 }
418 // We expect an error since there's no git repository, but the request should be processed
419 if resp.StatusCode != http.StatusInternalServerError {
420 t.Errorf("Expected status 500, got: %d", resp.StatusCode)
421 }
422
423 // Test with from/to parameters
424 resp, err = http.Get(testServer.URL + "/git/rawdiff?from=HEAD~1&to=HEAD")
425 if err != nil {
426 t.Fatalf("Failed to make HTTP request: %v", err)
427 }
428 // We expect an error since there's no git repository, but the request should be processed
429 if resp.StatusCode != http.StatusInternalServerError {
430 t.Errorf("Expected status 500, got: %d", resp.StatusCode)
431 }
432}
433
434func TestGitShowHandler(t *testing.T) {
435 // Create a mock agent
436 mockAgent := &mockAgent{
437 workingDir: t.TempDir(), // Use a temp directory
438 }
439
440 // Create the server with the mock agent
441 server, err := server.New(mockAgent, nil)
442 if err != nil {
443 t.Fatalf("Failed to create server: %v", err)
444 }
445
446 // Create a test HTTP server
447 testServer := httptest.NewServer(server)
448 defer testServer.Close()
449
450 // Test missing parameter
451 resp, err := http.Get(testServer.URL + "/git/show")
452 if err != nil {
453 t.Fatalf("Failed to make HTTP request: %v", err)
454 }
455 if resp.StatusCode != http.StatusBadRequest {
456 t.Errorf("Expected status bad request, got: %d", resp.StatusCode)
457 }
458
459 // Test with hash parameter (this will fail due to no git repo, but we're testing the API, not git)
460 resp, err = http.Get(testServer.URL + "/git/show?hash=HEAD")
461 if err != nil {
462 t.Fatalf("Failed to make HTTP request: %v", err)
463 }
464 // We expect an error since there's no git repository, but the request should be processed
465 if resp.StatusCode != http.StatusInternalServerError {
466 t.Errorf("Expected status 500, got: %d", resp.StatusCode)
467 }
468}
469
470// Removing duplicate method definition