blob: d68928e1ee5d5c6ca51477cd8d48d5bac8f8c59b [file] [log] [blame]
Philip Zeyliger25f6ff12025-05-02 04:24:10 +00001package server_test
2
3import (
4 "bufio"
5 "context"
6 "net/http"
7 "net/http/httptest"
8 "slices"
9 "strings"
10 "sync"
11 "testing"
12 "time"
13
14 "sketch.dev/llm/conversation"
15 "sketch.dev/loop"
16 "sketch.dev/loop/server"
17)
18
19// mockAgent is a mock implementation of loop.CodingAgent for testing
20type mockAgent struct {
Philip Zeyligereab12de2025-05-14 02:35:53 +000021 mu sync.RWMutex
22 messages []loop.AgentMessage
23 messageCount int
24 currentState string
25 subscribers []chan *loop.AgentMessage
26 stateTransitionListeners []chan loop.StateTransition
27 initialCommit string
28 title string
29 branchName string
Philip Zeyliger25f6ff12025-05-02 04:24:10 +000030}
31
32func (m *mockAgent) NewIterator(ctx context.Context, nextMessageIdx int) loop.MessageIterator {
33 m.mu.RLock()
34 // Send existing messages that should be available immediately
35 ch := make(chan *loop.AgentMessage, 100)
36 iter := &mockIterator{
37 agent: m,
38 ctx: ctx,
39 nextMessageIdx: nextMessageIdx,
40 ch: ch,
41 }
42 m.mu.RUnlock()
43 return iter
44}
45
46type mockIterator struct {
47 agent *mockAgent
48 ctx context.Context
49 nextMessageIdx int
50 ch chan *loop.AgentMessage
51 subscribed bool
52}
53
54func (m *mockIterator) Next() *loop.AgentMessage {
55 if !m.subscribed {
56 m.agent.mu.Lock()
57 m.agent.subscribers = append(m.agent.subscribers, m.ch)
58 m.agent.mu.Unlock()
59 m.subscribed = true
60 }
61
62 for {
63 select {
64 case <-m.ctx.Done():
65 return nil
66 case msg := <-m.ch:
67 return msg
68 }
69 }
70}
71
72func (m *mockIterator) Close() {
73 // Remove from subscribers using slices.Delete
74 m.agent.mu.Lock()
75 for i, ch := range m.agent.subscribers {
76 if ch == m.ch {
77 m.agent.subscribers = slices.Delete(m.agent.subscribers, i, i+1)
78 break
79 }
80 }
81 m.agent.mu.Unlock()
82 close(m.ch)
83}
84
85func (m *mockAgent) Messages(start int, end int) []loop.AgentMessage {
86 m.mu.RLock()
87 defer m.mu.RUnlock()
88
89 if start >= len(m.messages) || end > len(m.messages) || start < 0 || end < 0 {
90 return []loop.AgentMessage{}
91 }
92 return slices.Clone(m.messages[start:end])
93}
94
95func (m *mockAgent) MessageCount() int {
96 m.mu.RLock()
97 defer m.mu.RUnlock()
98 return m.messageCount
99}
100
101func (m *mockAgent) AddMessage(msg loop.AgentMessage) {
102 m.mu.Lock()
103 msg.Idx = m.messageCount
104 m.messages = append(m.messages, msg)
105 m.messageCount++
106
107 // Create a copy of subscribers to avoid holding the lock while sending
108 subscribers := make([]chan *loop.AgentMessage, len(m.subscribers))
109 copy(subscribers, m.subscribers)
110 m.mu.Unlock()
111
112 // Notify subscribers
113 msgCopy := msg // Create a copy to avoid race conditions
114 for _, ch := range subscribers {
115 ch <- &msgCopy
116 }
117}
118
Philip Zeyligereab12de2025-05-14 02:35:53 +0000119func (m *mockAgent) NewStateTransitionIterator(ctx context.Context) loop.StateTransitionIterator {
120 m.mu.Lock()
121 ch := make(chan loop.StateTransition, 10)
122 m.stateTransitionListeners = append(m.stateTransitionListeners, ch)
123 m.mu.Unlock()
124
125 return &mockStateTransitionIterator{
126 agent: m,
127 ctx: ctx,
128 ch: ch,
129 }
130}
131
132type mockStateTransitionIterator struct {
133 agent *mockAgent
134 ctx context.Context
135 ch chan loop.StateTransition
136}
137
138func (m *mockStateTransitionIterator) Next() *loop.StateTransition {
139 select {
140 case <-m.ctx.Done():
141 return nil
142 case transition, ok := <-m.ch:
143 if !ok {
144 return nil
145 }
146 transitionCopy := transition
147 return &transitionCopy
148 }
149}
150
151func (m *mockStateTransitionIterator) Close() {
152 m.agent.mu.Lock()
153 for i, ch := range m.agent.stateTransitionListeners {
154 if ch == m.ch {
155 m.agent.stateTransitionListeners = slices.Delete(m.agent.stateTransitionListeners, i, i+1)
156 break
157 }
158 }
159 m.agent.mu.Unlock()
160 close(m.ch)
161}
162
Philip Zeyliger25f6ff12025-05-02 04:24:10 +0000163func (m *mockAgent) CurrentStateName() string {
164 m.mu.RLock()
165 defer m.mu.RUnlock()
166 return m.currentState
167}
168
Philip Zeyligereab12de2025-05-14 02:35:53 +0000169func (m *mockAgent) TriggerStateTransition(from, to loop.State, event loop.TransitionEvent) {
170 m.mu.Lock()
171 m.currentState = to.String()
172 transition := loop.StateTransition{
173 From: from,
174 To: to,
175 Event: event,
176 }
177
178 // Create a copy of listeners to avoid holding the lock while sending
179 listeners := make([]chan loop.StateTransition, len(m.stateTransitionListeners))
180 copy(listeners, m.stateTransitionListeners)
181 m.mu.Unlock()
182
183 // Notify listeners
184 for _, ch := range listeners {
185 ch <- transition
186 }
187}
188
Philip Zeyliger25f6ff12025-05-02 04:24:10 +0000189func (m *mockAgent) InitialCommit() string {
190 m.mu.RLock()
191 defer m.mu.RUnlock()
192 return m.initialCommit
193}
194
Philip Zeyliger49edc922025-05-14 09:45:45 -0700195func (m *mockAgent) SketchGitBase() string {
196 m.mu.RLock()
197 defer m.mu.RUnlock()
198 return m.initialCommit
199}
200
Philip Zeyliger25f6ff12025-05-02 04:24:10 +0000201func (m *mockAgent) Title() string {
202 m.mu.RLock()
203 defer m.mu.RUnlock()
204 return m.title
205}
206
207func (m *mockAgent) BranchName() string {
208 m.mu.RLock()
209 defer m.mu.RUnlock()
210 return m.branchName
211}
212
213// Other required methods of loop.CodingAgent with minimal implementation
214func (m *mockAgent) Init(loop.AgentInit) error { return nil }
215func (m *mockAgent) Ready() <-chan struct{} { ch := make(chan struct{}); close(ch); return ch }
216func (m *mockAgent) URL() string { return "http://localhost:8080" }
217func (m *mockAgent) UserMessage(ctx context.Context, msg string) {}
218func (m *mockAgent) Loop(ctx context.Context) {}
219func (m *mockAgent) CancelTurn(cause error) {}
220func (m *mockAgent) CancelToolUse(id string, cause error) error { return nil }
221func (m *mockAgent) TotalUsage() conversation.CumulativeUsage { return conversation.CumulativeUsage{} }
222func (m *mockAgent) OriginalBudget() conversation.Budget { return conversation.Budget{} }
223func (m *mockAgent) WorkingDir() string { return "/app" }
224func (m *mockAgent) Diff(commit *string) (string, error) { return "", nil }
225func (m *mockAgent) OS() string { return "linux" }
226func (m *mockAgent) SessionID() string { return "test-session" }
227func (m *mockAgent) OutstandingLLMCallCount() int { return 0 }
228func (m *mockAgent) OutstandingToolCalls() []string { return nil }
229func (m *mockAgent) OutsideOS() string { return "linux" }
230func (m *mockAgent) OutsideHostname() string { return "test-host" }
231func (m *mockAgent) OutsideWorkingDir() string { return "/app" }
232func (m *mockAgent) GitOrigin() string { return "" }
233func (m *mockAgent) OpenBrowser(url string) {}
234func (m *mockAgent) RestartConversation(ctx context.Context, rev string, initialPrompt string) error {
235 return nil
236}
237func (m *mockAgent) SuggestReprompt(ctx context.Context) (string, error) { return "", nil }
238func (m *mockAgent) IsInContainer() bool { return false }
239func (m *mockAgent) FirstMessageIndex() int { return 0 }
240
241// TestSSEStream tests the SSE stream endpoint
242func TestSSEStream(t *testing.T) {
243 // Create a mock agent with initial messages
244 mockAgent := &mockAgent{
Philip Zeyligereab12de2025-05-14 02:35:53 +0000245 messages: []loop.AgentMessage{},
246 messageCount: 0,
247 currentState: "Ready",
248 subscribers: []chan *loop.AgentMessage{},
249 stateTransitionListeners: []chan loop.StateTransition{},
250 initialCommit: "abcd1234",
251 title: "Test Title",
252 branchName: "sketch/test-branch",
Philip Zeyliger25f6ff12025-05-02 04:24:10 +0000253 }
254
255 // Add the initial messages before creating the server
256 // to ensure they're available in the Messages slice
257 msg1 := loop.AgentMessage{
258 Type: loop.UserMessageType,
259 Content: "Hello, this is a test message",
260 Timestamp: time.Now(),
261 }
262 mockAgent.messages = append(mockAgent.messages, msg1)
263 msg1.Idx = mockAgent.messageCount
264 mockAgent.messageCount++
265
266 msg2 := loop.AgentMessage{
267 Type: loop.AgentMessageType,
268 Content: "This is a response message",
269 Timestamp: time.Now(),
270 EndOfTurn: true,
271 }
272 mockAgent.messages = append(mockAgent.messages, msg2)
273 msg2.Idx = mockAgent.messageCount
274 mockAgent.messageCount++
275
276 // Create a server with the mock agent
277 srv, err := server.New(mockAgent, nil)
278 if err != nil {
279 t.Fatalf("Failed to create server: %v", err)
280 }
281
282 // Create a test server
283 ts := httptest.NewServer(srv)
284 defer ts.Close()
285
286 // Create a context with cancellation for the client request
287 ctx, cancel := context.WithCancel(context.Background())
288
289 // Create a request to the /stream endpoint
290 req, err := http.NewRequestWithContext(ctx, "GET", ts.URL+"/stream?from=0", nil)
291 if err != nil {
292 t.Fatalf("Failed to create request: %v", err)
293 }
294
295 // Execute the request
296 res, err := http.DefaultClient.Do(req)
297 if err != nil {
298 t.Fatalf("Failed to execute request: %v", err)
299 }
300 defer res.Body.Close()
301
302 // Check response status
303 if res.StatusCode != http.StatusOK {
304 t.Fatalf("Expected status OK, got %v", res.Status)
305 }
306
307 // Check content type
308 if contentType := res.Header.Get("Content-Type"); contentType != "text/event-stream" {
309 t.Fatalf("Expected Content-Type text/event-stream, got %s", contentType)
310 }
311
312 // Read response events using a scanner
313 scanner := bufio.NewScanner(res.Body)
314
315 // Track events received
316 eventsReceived := map[string]int{
317 "state": 0,
318 "message": 0,
319 "heartbeat": 0,
320 }
321
322 // Read for a short time to capture initial state and messages
323 dataLines := []string{}
324 eventType := ""
325
326 go func() {
327 // After reading for a while, add a new message to test real-time updates
328 time.Sleep(500 * time.Millisecond)
329
330 mockAgent.AddMessage(loop.AgentMessage{
331 Type: loop.ToolUseMessageType,
332 Content: "This is a new real-time message",
333 Timestamp: time.Now(),
334 ToolName: "test_tool",
335 })
336
Philip Zeyligereab12de2025-05-14 02:35:53 +0000337 // Trigger a state transition to test state updates
338 time.Sleep(200 * time.Millisecond)
339 mockAgent.TriggerStateTransition(loop.StateReady, loop.StateSendingToLLM, loop.TransitionEvent{
340 Description: "Agent started thinking",
341 Data: "start_thinking",
342 })
343
Philip Zeyliger25f6ff12025-05-02 04:24:10 +0000344 // Let it process for longer
345 time.Sleep(1000 * time.Millisecond)
346 cancel() // Cancel to end the test
347 }()
348
349 // Read events
350 for scanner.Scan() {
351 line := scanner.Text()
352
353 if strings.HasPrefix(line, "event: ") {
354 eventType = strings.TrimPrefix(line, "event: ")
355 eventsReceived[eventType]++
356 } else if strings.HasPrefix(line, "data: ") {
357 dataLines = append(dataLines, line)
358 } else if line == "" && eventType != "" {
359 // End of event
360 eventType = ""
361 }
362
363 // Break if context is done
364 if ctx.Err() != nil {
365 break
366 }
367 }
368
369 if err := scanner.Err(); err != nil && ctx.Err() == nil {
370 t.Fatalf("Scanner error: %v", err)
371 }
372
373 // Simplified validation - just make sure we received something
374 t.Logf("Events received: %v", eventsReceived)
375 t.Logf("Data lines received: %d", len(dataLines))
376
377 // Basic validation that we received at least some events
378 if eventsReceived["state"] == 0 && eventsReceived["message"] == 0 {
379 t.Errorf("Did not receive any events")
380 }
381}