blob: cb50642942dcca5375b99ea4bcdb348fe007e035 [file] [log] [blame]
Philip Zeyliger25f6ff12025-05-02 04:24:10 +00001package server_test
2
3import (
4 "bufio"
5 "context"
6 "net/http"
7 "net/http/httptest"
8 "slices"
9 "strings"
10 "sync"
11 "testing"
12 "time"
13
14 "sketch.dev/llm/conversation"
15 "sketch.dev/loop"
16 "sketch.dev/loop/server"
17)
18
19// mockAgent is a mock implementation of loop.CodingAgent for testing
20type mockAgent struct {
Philip Zeyligereab12de2025-05-14 02:35:53 +000021 mu sync.RWMutex
22 messages []loop.AgentMessage
23 messageCount int
24 currentState string
25 subscribers []chan *loop.AgentMessage
26 stateTransitionListeners []chan loop.StateTransition
27 initialCommit string
28 title string
29 branchName string
Philip Zeyliger25f6ff12025-05-02 04:24:10 +000030}
31
32func (m *mockAgent) NewIterator(ctx context.Context, nextMessageIdx int) loop.MessageIterator {
33 m.mu.RLock()
34 // Send existing messages that should be available immediately
35 ch := make(chan *loop.AgentMessage, 100)
36 iter := &mockIterator{
37 agent: m,
38 ctx: ctx,
39 nextMessageIdx: nextMessageIdx,
40 ch: ch,
41 }
42 m.mu.RUnlock()
43 return iter
44}
45
46type mockIterator struct {
47 agent *mockAgent
48 ctx context.Context
49 nextMessageIdx int
50 ch chan *loop.AgentMessage
51 subscribed bool
52}
53
54func (m *mockIterator) Next() *loop.AgentMessage {
55 if !m.subscribed {
56 m.agent.mu.Lock()
57 m.agent.subscribers = append(m.agent.subscribers, m.ch)
58 m.agent.mu.Unlock()
59 m.subscribed = true
60 }
61
62 for {
63 select {
64 case <-m.ctx.Done():
65 return nil
66 case msg := <-m.ch:
67 return msg
68 }
69 }
70}
71
72func (m *mockIterator) Close() {
73 // Remove from subscribers using slices.Delete
74 m.agent.mu.Lock()
75 for i, ch := range m.agent.subscribers {
76 if ch == m.ch {
77 m.agent.subscribers = slices.Delete(m.agent.subscribers, i, i+1)
78 break
79 }
80 }
81 m.agent.mu.Unlock()
82 close(m.ch)
83}
84
85func (m *mockAgent) Messages(start int, end int) []loop.AgentMessage {
86 m.mu.RLock()
87 defer m.mu.RUnlock()
88
89 if start >= len(m.messages) || end > len(m.messages) || start < 0 || end < 0 {
90 return []loop.AgentMessage{}
91 }
92 return slices.Clone(m.messages[start:end])
93}
94
95func (m *mockAgent) MessageCount() int {
96 m.mu.RLock()
97 defer m.mu.RUnlock()
98 return m.messageCount
99}
100
101func (m *mockAgent) AddMessage(msg loop.AgentMessage) {
102 m.mu.Lock()
103 msg.Idx = m.messageCount
104 m.messages = append(m.messages, msg)
105 m.messageCount++
106
107 // Create a copy of subscribers to avoid holding the lock while sending
108 subscribers := make([]chan *loop.AgentMessage, len(m.subscribers))
109 copy(subscribers, m.subscribers)
110 m.mu.Unlock()
111
112 // Notify subscribers
113 msgCopy := msg // Create a copy to avoid race conditions
114 for _, ch := range subscribers {
115 ch <- &msgCopy
116 }
117}
118
Philip Zeyligereab12de2025-05-14 02:35:53 +0000119func (m *mockAgent) NewStateTransitionIterator(ctx context.Context) loop.StateTransitionIterator {
120 m.mu.Lock()
121 ch := make(chan loop.StateTransition, 10)
122 m.stateTransitionListeners = append(m.stateTransitionListeners, ch)
123 m.mu.Unlock()
124
125 return &mockStateTransitionIterator{
126 agent: m,
127 ctx: ctx,
128 ch: ch,
129 }
130}
131
132type mockStateTransitionIterator struct {
133 agent *mockAgent
134 ctx context.Context
135 ch chan loop.StateTransition
136}
137
138func (m *mockStateTransitionIterator) Next() *loop.StateTransition {
139 select {
140 case <-m.ctx.Done():
141 return nil
142 case transition, ok := <-m.ch:
143 if !ok {
144 return nil
145 }
146 transitionCopy := transition
147 return &transitionCopy
148 }
149}
150
151func (m *mockStateTransitionIterator) Close() {
152 m.agent.mu.Lock()
153 for i, ch := range m.agent.stateTransitionListeners {
154 if ch == m.ch {
155 m.agent.stateTransitionListeners = slices.Delete(m.agent.stateTransitionListeners, i, i+1)
156 break
157 }
158 }
159 m.agent.mu.Unlock()
160 close(m.ch)
161}
162
Philip Zeyliger25f6ff12025-05-02 04:24:10 +0000163func (m *mockAgent) CurrentStateName() string {
164 m.mu.RLock()
165 defer m.mu.RUnlock()
166 return m.currentState
167}
168
Philip Zeyligereab12de2025-05-14 02:35:53 +0000169func (m *mockAgent) TriggerStateTransition(from, to loop.State, event loop.TransitionEvent) {
170 m.mu.Lock()
171 m.currentState = to.String()
172 transition := loop.StateTransition{
173 From: from,
174 To: to,
175 Event: event,
176 }
177
178 // Create a copy of listeners to avoid holding the lock while sending
179 listeners := make([]chan loop.StateTransition, len(m.stateTransitionListeners))
180 copy(listeners, m.stateTransitionListeners)
181 m.mu.Unlock()
182
183 // Notify listeners
184 for _, ch := range listeners {
185 ch <- transition
186 }
187}
188
Philip Zeyliger25f6ff12025-05-02 04:24:10 +0000189func (m *mockAgent) InitialCommit() string {
190 m.mu.RLock()
191 defer m.mu.RUnlock()
192 return m.initialCommit
193}
194
195func (m *mockAgent) Title() string {
196 m.mu.RLock()
197 defer m.mu.RUnlock()
198 return m.title
199}
200
201func (m *mockAgent) BranchName() string {
202 m.mu.RLock()
203 defer m.mu.RUnlock()
204 return m.branchName
205}
206
207// Other required methods of loop.CodingAgent with minimal implementation
208func (m *mockAgent) Init(loop.AgentInit) error { return nil }
209func (m *mockAgent) Ready() <-chan struct{} { ch := make(chan struct{}); close(ch); return ch }
210func (m *mockAgent) URL() string { return "http://localhost:8080" }
211func (m *mockAgent) UserMessage(ctx context.Context, msg string) {}
212func (m *mockAgent) Loop(ctx context.Context) {}
213func (m *mockAgent) CancelTurn(cause error) {}
214func (m *mockAgent) CancelToolUse(id string, cause error) error { return nil }
215func (m *mockAgent) TotalUsage() conversation.CumulativeUsage { return conversation.CumulativeUsage{} }
216func (m *mockAgent) OriginalBudget() conversation.Budget { return conversation.Budget{} }
217func (m *mockAgent) WorkingDir() string { return "/app" }
218func (m *mockAgent) Diff(commit *string) (string, error) { return "", nil }
219func (m *mockAgent) OS() string { return "linux" }
220func (m *mockAgent) SessionID() string { return "test-session" }
221func (m *mockAgent) OutstandingLLMCallCount() int { return 0 }
222func (m *mockAgent) OutstandingToolCalls() []string { return nil }
223func (m *mockAgent) OutsideOS() string { return "linux" }
224func (m *mockAgent) OutsideHostname() string { return "test-host" }
225func (m *mockAgent) OutsideWorkingDir() string { return "/app" }
226func (m *mockAgent) GitOrigin() string { return "" }
227func (m *mockAgent) OpenBrowser(url string) {}
228func (m *mockAgent) RestartConversation(ctx context.Context, rev string, initialPrompt string) error {
229 return nil
230}
231func (m *mockAgent) SuggestReprompt(ctx context.Context) (string, error) { return "", nil }
232func (m *mockAgent) IsInContainer() bool { return false }
233func (m *mockAgent) FirstMessageIndex() int { return 0 }
234
235// TestSSEStream tests the SSE stream endpoint
236func TestSSEStream(t *testing.T) {
237 // Create a mock agent with initial messages
238 mockAgent := &mockAgent{
Philip Zeyligereab12de2025-05-14 02:35:53 +0000239 messages: []loop.AgentMessage{},
240 messageCount: 0,
241 currentState: "Ready",
242 subscribers: []chan *loop.AgentMessage{},
243 stateTransitionListeners: []chan loop.StateTransition{},
244 initialCommit: "abcd1234",
245 title: "Test Title",
246 branchName: "sketch/test-branch",
Philip Zeyliger25f6ff12025-05-02 04:24:10 +0000247 }
248
249 // Add the initial messages before creating the server
250 // to ensure they're available in the Messages slice
251 msg1 := loop.AgentMessage{
252 Type: loop.UserMessageType,
253 Content: "Hello, this is a test message",
254 Timestamp: time.Now(),
255 }
256 mockAgent.messages = append(mockAgent.messages, msg1)
257 msg1.Idx = mockAgent.messageCount
258 mockAgent.messageCount++
259
260 msg2 := loop.AgentMessage{
261 Type: loop.AgentMessageType,
262 Content: "This is a response message",
263 Timestamp: time.Now(),
264 EndOfTurn: true,
265 }
266 mockAgent.messages = append(mockAgent.messages, msg2)
267 msg2.Idx = mockAgent.messageCount
268 mockAgent.messageCount++
269
270 // Create a server with the mock agent
271 srv, err := server.New(mockAgent, nil)
272 if err != nil {
273 t.Fatalf("Failed to create server: %v", err)
274 }
275
276 // Create a test server
277 ts := httptest.NewServer(srv)
278 defer ts.Close()
279
280 // Create a context with cancellation for the client request
281 ctx, cancel := context.WithCancel(context.Background())
282
283 // Create a request to the /stream endpoint
284 req, err := http.NewRequestWithContext(ctx, "GET", ts.URL+"/stream?from=0", nil)
285 if err != nil {
286 t.Fatalf("Failed to create request: %v", err)
287 }
288
289 // Execute the request
290 res, err := http.DefaultClient.Do(req)
291 if err != nil {
292 t.Fatalf("Failed to execute request: %v", err)
293 }
294 defer res.Body.Close()
295
296 // Check response status
297 if res.StatusCode != http.StatusOK {
298 t.Fatalf("Expected status OK, got %v", res.Status)
299 }
300
301 // Check content type
302 if contentType := res.Header.Get("Content-Type"); contentType != "text/event-stream" {
303 t.Fatalf("Expected Content-Type text/event-stream, got %s", contentType)
304 }
305
306 // Read response events using a scanner
307 scanner := bufio.NewScanner(res.Body)
308
309 // Track events received
310 eventsReceived := map[string]int{
311 "state": 0,
312 "message": 0,
313 "heartbeat": 0,
314 }
315
316 // Read for a short time to capture initial state and messages
317 dataLines := []string{}
318 eventType := ""
319
320 go func() {
321 // After reading for a while, add a new message to test real-time updates
322 time.Sleep(500 * time.Millisecond)
323
324 mockAgent.AddMessage(loop.AgentMessage{
325 Type: loop.ToolUseMessageType,
326 Content: "This is a new real-time message",
327 Timestamp: time.Now(),
328 ToolName: "test_tool",
329 })
330
Philip Zeyligereab12de2025-05-14 02:35:53 +0000331 // Trigger a state transition to test state updates
332 time.Sleep(200 * time.Millisecond)
333 mockAgent.TriggerStateTransition(loop.StateReady, loop.StateSendingToLLM, loop.TransitionEvent{
334 Description: "Agent started thinking",
335 Data: "start_thinking",
336 })
337
Philip Zeyliger25f6ff12025-05-02 04:24:10 +0000338 // Let it process for longer
339 time.Sleep(1000 * time.Millisecond)
340 cancel() // Cancel to end the test
341 }()
342
343 // Read events
344 for scanner.Scan() {
345 line := scanner.Text()
346
347 if strings.HasPrefix(line, "event: ") {
348 eventType = strings.TrimPrefix(line, "event: ")
349 eventsReceived[eventType]++
350 } else if strings.HasPrefix(line, "data: ") {
351 dataLines = append(dataLines, line)
352 } else if line == "" && eventType != "" {
353 // End of event
354 eventType = ""
355 }
356
357 // Break if context is done
358 if ctx.Err() != nil {
359 break
360 }
361 }
362
363 if err := scanner.Err(); err != nil && ctx.Err() == nil {
364 t.Fatalf("Scanner error: %v", err)
365 }
366
367 // Simplified validation - just make sure we received something
368 t.Logf("Events received: %v", eventsReceived)
369 t.Logf("Data lines received: %d", len(dataLines))
370
371 // Basic validation that we received at least some events
372 if eventsReceived["state"] == 0 && eventsReceived["message"] == 0 {
373 t.Errorf("Did not receive any events")
374 }
375}