blob: 22ad237a6dc2a169467c1cabc039dfea6c71e574 [file] [log] [blame]
Philip Zeyliger25f6ff12025-05-02 04:24:10 +00001package server_test
2
3import (
4 "bufio"
5 "context"
6 "net/http"
7 "net/http/httptest"
8 "slices"
9 "strings"
10 "sync"
11 "testing"
12 "time"
13
14 "sketch.dev/llm/conversation"
15 "sketch.dev/loop"
16 "sketch.dev/loop/server"
17)
18
19// mockAgent is a mock implementation of loop.CodingAgent for testing
20type mockAgent struct {
21 mu sync.RWMutex
22 messages []loop.AgentMessage
23 messageCount int
24 currentState string
25 subscribers []chan *loop.AgentMessage
26 initialCommit string
27 title string
28 branchName string
29}
30
31func (m *mockAgent) NewIterator(ctx context.Context, nextMessageIdx int) loop.MessageIterator {
32 m.mu.RLock()
33 // Send existing messages that should be available immediately
34 ch := make(chan *loop.AgentMessage, 100)
35 iter := &mockIterator{
36 agent: m,
37 ctx: ctx,
38 nextMessageIdx: nextMessageIdx,
39 ch: ch,
40 }
41 m.mu.RUnlock()
42 return iter
43}
44
45type mockIterator struct {
46 agent *mockAgent
47 ctx context.Context
48 nextMessageIdx int
49 ch chan *loop.AgentMessage
50 subscribed bool
51}
52
53func (m *mockIterator) Next() *loop.AgentMessage {
54 if !m.subscribed {
55 m.agent.mu.Lock()
56 m.agent.subscribers = append(m.agent.subscribers, m.ch)
57 m.agent.mu.Unlock()
58 m.subscribed = true
59 }
60
61 for {
62 select {
63 case <-m.ctx.Done():
64 return nil
65 case msg := <-m.ch:
66 return msg
67 }
68 }
69}
70
71func (m *mockIterator) Close() {
72 // Remove from subscribers using slices.Delete
73 m.agent.mu.Lock()
74 for i, ch := range m.agent.subscribers {
75 if ch == m.ch {
76 m.agent.subscribers = slices.Delete(m.agent.subscribers, i, i+1)
77 break
78 }
79 }
80 m.agent.mu.Unlock()
81 close(m.ch)
82}
83
84func (m *mockAgent) Messages(start int, end int) []loop.AgentMessage {
85 m.mu.RLock()
86 defer m.mu.RUnlock()
87
88 if start >= len(m.messages) || end > len(m.messages) || start < 0 || end < 0 {
89 return []loop.AgentMessage{}
90 }
91 return slices.Clone(m.messages[start:end])
92}
93
94func (m *mockAgent) MessageCount() int {
95 m.mu.RLock()
96 defer m.mu.RUnlock()
97 return m.messageCount
98}
99
100func (m *mockAgent) AddMessage(msg loop.AgentMessage) {
101 m.mu.Lock()
102 msg.Idx = m.messageCount
103 m.messages = append(m.messages, msg)
104 m.messageCount++
105
106 // Create a copy of subscribers to avoid holding the lock while sending
107 subscribers := make([]chan *loop.AgentMessage, len(m.subscribers))
108 copy(subscribers, m.subscribers)
109 m.mu.Unlock()
110
111 // Notify subscribers
112 msgCopy := msg // Create a copy to avoid race conditions
113 for _, ch := range subscribers {
114 ch <- &msgCopy
115 }
116}
117
118func (m *mockAgent) CurrentStateName() string {
119 m.mu.RLock()
120 defer m.mu.RUnlock()
121 return m.currentState
122}
123
124func (m *mockAgent) InitialCommit() string {
125 m.mu.RLock()
126 defer m.mu.RUnlock()
127 return m.initialCommit
128}
129
130func (m *mockAgent) Title() string {
131 m.mu.RLock()
132 defer m.mu.RUnlock()
133 return m.title
134}
135
136func (m *mockAgent) BranchName() string {
137 m.mu.RLock()
138 defer m.mu.RUnlock()
139 return m.branchName
140}
141
142// Other required methods of loop.CodingAgent with minimal implementation
143func (m *mockAgent) Init(loop.AgentInit) error { return nil }
144func (m *mockAgent) Ready() <-chan struct{} { ch := make(chan struct{}); close(ch); return ch }
145func (m *mockAgent) URL() string { return "http://localhost:8080" }
146func (m *mockAgent) UserMessage(ctx context.Context, msg string) {}
147func (m *mockAgent) Loop(ctx context.Context) {}
148func (m *mockAgent) CancelTurn(cause error) {}
149func (m *mockAgent) CancelToolUse(id string, cause error) error { return nil }
150func (m *mockAgent) TotalUsage() conversation.CumulativeUsage { return conversation.CumulativeUsage{} }
151func (m *mockAgent) OriginalBudget() conversation.Budget { return conversation.Budget{} }
152func (m *mockAgent) WorkingDir() string { return "/app" }
153func (m *mockAgent) Diff(commit *string) (string, error) { return "", nil }
154func (m *mockAgent) OS() string { return "linux" }
155func (m *mockAgent) SessionID() string { return "test-session" }
156func (m *mockAgent) OutstandingLLMCallCount() int { return 0 }
157func (m *mockAgent) OutstandingToolCalls() []string { return nil }
158func (m *mockAgent) OutsideOS() string { return "linux" }
159func (m *mockAgent) OutsideHostname() string { return "test-host" }
160func (m *mockAgent) OutsideWorkingDir() string { return "/app" }
161func (m *mockAgent) GitOrigin() string { return "" }
162func (m *mockAgent) OpenBrowser(url string) {}
163func (m *mockAgent) RestartConversation(ctx context.Context, rev string, initialPrompt string) error {
164 return nil
165}
166func (m *mockAgent) SuggestReprompt(ctx context.Context) (string, error) { return "", nil }
167func (m *mockAgent) IsInContainer() bool { return false }
168func (m *mockAgent) FirstMessageIndex() int { return 0 }
169
170// TestSSEStream tests the SSE stream endpoint
171func TestSSEStream(t *testing.T) {
172 // Create a mock agent with initial messages
173 mockAgent := &mockAgent{
174 messages: []loop.AgentMessage{},
175 messageCount: 0,
176 currentState: "Ready",
177 subscribers: []chan *loop.AgentMessage{},
178 initialCommit: "abcd1234",
179 title: "Test Title",
180 branchName: "sketch/test-branch",
181 }
182
183 // Add the initial messages before creating the server
184 // to ensure they're available in the Messages slice
185 msg1 := loop.AgentMessage{
186 Type: loop.UserMessageType,
187 Content: "Hello, this is a test message",
188 Timestamp: time.Now(),
189 }
190 mockAgent.messages = append(mockAgent.messages, msg1)
191 msg1.Idx = mockAgent.messageCount
192 mockAgent.messageCount++
193
194 msg2 := loop.AgentMessage{
195 Type: loop.AgentMessageType,
196 Content: "This is a response message",
197 Timestamp: time.Now(),
198 EndOfTurn: true,
199 }
200 mockAgent.messages = append(mockAgent.messages, msg2)
201 msg2.Idx = mockAgent.messageCount
202 mockAgent.messageCount++
203
204 // Create a server with the mock agent
205 srv, err := server.New(mockAgent, nil)
206 if err != nil {
207 t.Fatalf("Failed to create server: %v", err)
208 }
209
210 // Create a test server
211 ts := httptest.NewServer(srv)
212 defer ts.Close()
213
214 // Create a context with cancellation for the client request
215 ctx, cancel := context.WithCancel(context.Background())
216
217 // Create a request to the /stream endpoint
218 req, err := http.NewRequestWithContext(ctx, "GET", ts.URL+"/stream?from=0", nil)
219 if err != nil {
220 t.Fatalf("Failed to create request: %v", err)
221 }
222
223 // Execute the request
224 res, err := http.DefaultClient.Do(req)
225 if err != nil {
226 t.Fatalf("Failed to execute request: %v", err)
227 }
228 defer res.Body.Close()
229
230 // Check response status
231 if res.StatusCode != http.StatusOK {
232 t.Fatalf("Expected status OK, got %v", res.Status)
233 }
234
235 // Check content type
236 if contentType := res.Header.Get("Content-Type"); contentType != "text/event-stream" {
237 t.Fatalf("Expected Content-Type text/event-stream, got %s", contentType)
238 }
239
240 // Read response events using a scanner
241 scanner := bufio.NewScanner(res.Body)
242
243 // Track events received
244 eventsReceived := map[string]int{
245 "state": 0,
246 "message": 0,
247 "heartbeat": 0,
248 }
249
250 // Read for a short time to capture initial state and messages
251 dataLines := []string{}
252 eventType := ""
253
254 go func() {
255 // After reading for a while, add a new message to test real-time updates
256 time.Sleep(500 * time.Millisecond)
257
258 mockAgent.AddMessage(loop.AgentMessage{
259 Type: loop.ToolUseMessageType,
260 Content: "This is a new real-time message",
261 Timestamp: time.Now(),
262 ToolName: "test_tool",
263 })
264
265 // Let it process for longer
266 time.Sleep(1000 * time.Millisecond)
267 cancel() // Cancel to end the test
268 }()
269
270 // Read events
271 for scanner.Scan() {
272 line := scanner.Text()
273
274 if strings.HasPrefix(line, "event: ") {
275 eventType = strings.TrimPrefix(line, "event: ")
276 eventsReceived[eventType]++
277 } else if strings.HasPrefix(line, "data: ") {
278 dataLines = append(dataLines, line)
279 } else if line == "" && eventType != "" {
280 // End of event
281 eventType = ""
282 }
283
284 // Break if context is done
285 if ctx.Err() != nil {
286 break
287 }
288 }
289
290 if err := scanner.Err(); err != nil && ctx.Err() == nil {
291 t.Fatalf("Scanner error: %v", err)
292 }
293
294 // Simplified validation - just make sure we received something
295 t.Logf("Events received: %v", eventsReceived)
296 t.Logf("Data lines received: %d", len(dataLines))
297
298 // Basic validation that we received at least some events
299 if eventsReceived["state"] == 0 && eventsReceived["message"] == 0 {
300 t.Errorf("Did not receive any events")
301 }
302}