| package server_test |
| |
| import ( |
| "bufio" |
| "context" |
| "encoding/json" |
| "net/http" |
| "net/http/httptest" |
| "slices" |
| "strings" |
| "sync" |
| "testing" |
| "time" |
| |
| "sketch.dev/llm/conversation" |
| "sketch.dev/loop" |
| "sketch.dev/loop/server" |
| ) |
| |
| // mockAgent is a mock implementation of loop.CodingAgent for testing |
| type mockAgent struct { |
| mu sync.RWMutex |
| messages []loop.AgentMessage |
| messageCount int |
| currentState string |
| subscribers []chan *loop.AgentMessage |
| stateTransitionListeners []chan loop.StateTransition |
| initialCommit string |
| title string |
| branchName string |
| branchPrefix string |
| workingDir string |
| sessionID string |
| } |
| |
| func (m *mockAgent) NewIterator(ctx context.Context, nextMessageIdx int) loop.MessageIterator { |
| m.mu.RLock() |
| // Send existing messages that should be available immediately |
| ch := make(chan *loop.AgentMessage, 100) |
| iter := &mockIterator{ |
| agent: m, |
| ctx: ctx, |
| nextMessageIdx: nextMessageIdx, |
| ch: ch, |
| } |
| m.mu.RUnlock() |
| return iter |
| } |
| |
| type mockIterator struct { |
| agent *mockAgent |
| ctx context.Context |
| nextMessageIdx int |
| ch chan *loop.AgentMessage |
| subscribed bool |
| } |
| |
| func (m *mockIterator) Next() *loop.AgentMessage { |
| if !m.subscribed { |
| m.agent.mu.Lock() |
| m.agent.subscribers = append(m.agent.subscribers, m.ch) |
| m.agent.mu.Unlock() |
| m.subscribed = true |
| } |
| |
| for { |
| select { |
| case <-m.ctx.Done(): |
| return nil |
| case msg := <-m.ch: |
| return msg |
| } |
| } |
| } |
| |
| func (m *mockIterator) Close() { |
| // Remove from subscribers using slices.Delete |
| m.agent.mu.Lock() |
| for i, ch := range m.agent.subscribers { |
| if ch == m.ch { |
| m.agent.subscribers = slices.Delete(m.agent.subscribers, i, i+1) |
| break |
| } |
| } |
| m.agent.mu.Unlock() |
| close(m.ch) |
| } |
| |
| func (m *mockAgent) Messages(start int, end int) []loop.AgentMessage { |
| m.mu.RLock() |
| defer m.mu.RUnlock() |
| |
| if start >= len(m.messages) || end > len(m.messages) || start < 0 || end < 0 { |
| return []loop.AgentMessage{} |
| } |
| return slices.Clone(m.messages[start:end]) |
| } |
| |
| func (m *mockAgent) MessageCount() int { |
| m.mu.RLock() |
| defer m.mu.RUnlock() |
| return m.messageCount |
| } |
| |
| func (m *mockAgent) AddMessage(msg loop.AgentMessage) { |
| m.mu.Lock() |
| msg.Idx = m.messageCount |
| m.messages = append(m.messages, msg) |
| m.messageCount++ |
| |
| // Create a copy of subscribers to avoid holding the lock while sending |
| subscribers := make([]chan *loop.AgentMessage, len(m.subscribers)) |
| copy(subscribers, m.subscribers) |
| m.mu.Unlock() |
| |
| // Notify subscribers |
| msgCopy := msg // Create a copy to avoid race conditions |
| for _, ch := range subscribers { |
| ch <- &msgCopy |
| } |
| } |
| |
| func (m *mockAgent) NewStateTransitionIterator(ctx context.Context) loop.StateTransitionIterator { |
| m.mu.Lock() |
| ch := make(chan loop.StateTransition, 10) |
| m.stateTransitionListeners = append(m.stateTransitionListeners, ch) |
| m.mu.Unlock() |
| |
| return &mockStateTransitionIterator{ |
| agent: m, |
| ctx: ctx, |
| ch: ch, |
| } |
| } |
| |
| type mockStateTransitionIterator struct { |
| agent *mockAgent |
| ctx context.Context |
| ch chan loop.StateTransition |
| } |
| |
| func (m *mockStateTransitionIterator) Next() *loop.StateTransition { |
| select { |
| case <-m.ctx.Done(): |
| return nil |
| case transition, ok := <-m.ch: |
| if !ok { |
| return nil |
| } |
| transitionCopy := transition |
| return &transitionCopy |
| } |
| } |
| |
| func (m *mockStateTransitionIterator) Close() { |
| m.agent.mu.Lock() |
| for i, ch := range m.agent.stateTransitionListeners { |
| if ch == m.ch { |
| m.agent.stateTransitionListeners = slices.Delete(m.agent.stateTransitionListeners, i, i+1) |
| break |
| } |
| } |
| m.agent.mu.Unlock() |
| close(m.ch) |
| } |
| |
| func (m *mockAgent) CurrentStateName() string { |
| m.mu.RLock() |
| defer m.mu.RUnlock() |
| return m.currentState |
| } |
| |
| func (m *mockAgent) TriggerStateTransition(from, to loop.State, event loop.TransitionEvent) { |
| m.mu.Lock() |
| m.currentState = to.String() |
| transition := loop.StateTransition{ |
| From: from, |
| To: to, |
| Event: event, |
| } |
| |
| // Create a copy of listeners to avoid holding the lock while sending |
| listeners := make([]chan loop.StateTransition, len(m.stateTransitionListeners)) |
| copy(listeners, m.stateTransitionListeners) |
| m.mu.Unlock() |
| |
| // Notify listeners |
| for _, ch := range listeners { |
| ch <- transition |
| } |
| } |
| |
| func (m *mockAgent) InitialCommit() string { |
| m.mu.RLock() |
| defer m.mu.RUnlock() |
| return m.initialCommit |
| } |
| |
| func (m *mockAgent) SketchGitBase() string { |
| m.mu.RLock() |
| defer m.mu.RUnlock() |
| return m.initialCommit |
| } |
| |
| func (m *mockAgent) SketchGitBaseRef() string { |
| return "sketch-base-test-session" |
| } |
| |
| func (m *mockAgent) Title() string { |
| m.mu.RLock() |
| defer m.mu.RUnlock() |
| return m.title |
| } |
| |
| func (m *mockAgent) BranchName() string { |
| m.mu.RLock() |
| defer m.mu.RUnlock() |
| return m.branchName |
| } |
| |
| // Other required methods of loop.CodingAgent with minimal implementation |
| func (m *mockAgent) Init(loop.AgentInit) error { return nil } |
| func (m *mockAgent) Ready() <-chan struct{} { ch := make(chan struct{}); close(ch); return ch } |
| func (m *mockAgent) URL() string { return "http://localhost:8080" } |
| func (m *mockAgent) UserMessage(ctx context.Context, msg string) {} |
| func (m *mockAgent) Loop(ctx context.Context) {} |
| func (m *mockAgent) CancelTurn(cause error) {} |
| func (m *mockAgent) CancelToolUse(id string, cause error) error { return nil } |
| func (m *mockAgent) TotalUsage() conversation.CumulativeUsage { return conversation.CumulativeUsage{} } |
| func (m *mockAgent) OriginalBudget() conversation.Budget { return conversation.Budget{} } |
| func (m *mockAgent) WorkingDir() string { return m.workingDir } |
| func (m *mockAgent) RepoRoot() string { return m.workingDir } |
| func (m *mockAgent) Diff(commit *string) (string, error) { return "", nil } |
| func (m *mockAgent) OS() string { return "linux" } |
| func (m *mockAgent) SessionID() string { return m.sessionID } |
| func (m *mockAgent) BranchPrefix() string { return m.branchPrefix } |
| func (m *mockAgent) CurrentTodoContent() string { return "" } // Mock returns empty for simplicity |
| func (m *mockAgent) OutstandingLLMCallCount() int { return 0 } |
| func (m *mockAgent) OutstandingToolCalls() []string { return nil } |
| func (m *mockAgent) OutsideOS() string { return "linux" } |
| func (m *mockAgent) OutsideHostname() string { return "test-host" } |
| func (m *mockAgent) OutsideWorkingDir() string { return "/app" } |
| func (m *mockAgent) GitOrigin() string { return "" } |
| func (m *mockAgent) OpenBrowser(url string) {} |
| func (m *mockAgent) CompactConversation(ctx context.Context) error { |
| // Mock implementation - just return nil |
| return nil |
| } |
| func (m *mockAgent) IsInContainer() bool { return false } |
| func (m *mockAgent) FirstMessageIndex() int { return 0 } |
| func (m *mockAgent) DetectGitChanges(ctx context.Context) error { return nil } |
| |
| func (m *mockAgent) GetPortMonitor() *loop.PortMonitor { return loop.NewPortMonitor() } |
| |
| // TestSSEStream tests the SSE stream endpoint |
| func TestSSEStream(t *testing.T) { |
| // Create a mock agent with initial messages |
| mockAgent := &mockAgent{ |
| messages: []loop.AgentMessage{}, |
| messageCount: 0, |
| currentState: "Ready", |
| subscribers: []chan *loop.AgentMessage{}, |
| stateTransitionListeners: []chan loop.StateTransition{}, |
| initialCommit: "abcd1234", |
| title: "Test Title", |
| branchName: "sketch/test-branch", |
| branchPrefix: "sketch/", |
| } |
| |
| // Add the initial messages before creating the server |
| // to ensure they're available in the Messages slice |
| msg1 := loop.AgentMessage{ |
| Type: loop.UserMessageType, |
| Content: "Hello, this is a test message", |
| Timestamp: time.Now(), |
| } |
| mockAgent.messages = append(mockAgent.messages, msg1) |
| msg1.Idx = mockAgent.messageCount |
| mockAgent.messageCount++ |
| |
| msg2 := loop.AgentMessage{ |
| Type: loop.AgentMessageType, |
| Content: "This is a response message", |
| Timestamp: time.Now(), |
| EndOfTurn: true, |
| } |
| mockAgent.messages = append(mockAgent.messages, msg2) |
| msg2.Idx = mockAgent.messageCount |
| mockAgent.messageCount++ |
| |
| // Create a server with the mock agent |
| srv, err := server.New(mockAgent, nil) |
| if err != nil { |
| t.Fatalf("Failed to create server: %v", err) |
| } |
| |
| // Create a test server |
| ts := httptest.NewServer(srv) |
| defer ts.Close() |
| |
| // Create a context with cancellation for the client request |
| ctx, cancel := context.WithCancel(context.Background()) |
| |
| // Create a request to the /stream endpoint |
| req, err := http.NewRequestWithContext(ctx, "GET", ts.URL+"/stream?from=0", nil) |
| if err != nil { |
| t.Fatalf("Failed to create request: %v", err) |
| } |
| |
| // Execute the request |
| res, err := http.DefaultClient.Do(req) |
| if err != nil { |
| t.Fatalf("Failed to execute request: %v", err) |
| } |
| defer res.Body.Close() |
| |
| // Check response status |
| if res.StatusCode != http.StatusOK { |
| t.Fatalf("Expected status OK, got %v", res.Status) |
| } |
| |
| // Check content type |
| if contentType := res.Header.Get("Content-Type"); contentType != "text/event-stream" { |
| t.Fatalf("Expected Content-Type text/event-stream, got %s", contentType) |
| } |
| |
| // Read response events using a scanner |
| scanner := bufio.NewScanner(res.Body) |
| |
| // Track events received |
| eventsReceived := map[string]int{ |
| "state": 0, |
| "message": 0, |
| "heartbeat": 0, |
| } |
| |
| // Read for a short time to capture initial state and messages |
| dataLines := []string{} |
| eventType := "" |
| |
| go func() { |
| // After reading for a while, add a new message to test real-time updates |
| time.Sleep(500 * time.Millisecond) |
| |
| mockAgent.AddMessage(loop.AgentMessage{ |
| Type: loop.ToolUseMessageType, |
| Content: "This is a new real-time message", |
| Timestamp: time.Now(), |
| ToolName: "test_tool", |
| }) |
| |
| // Trigger a state transition to test state updates |
| time.Sleep(200 * time.Millisecond) |
| mockAgent.TriggerStateTransition(loop.StateReady, loop.StateSendingToLLM, loop.TransitionEvent{ |
| Description: "Agent started thinking", |
| Data: "start_thinking", |
| }) |
| |
| // Let it process for longer |
| time.Sleep(1000 * time.Millisecond) |
| cancel() // Cancel to end the test |
| }() |
| |
| // Read events |
| for scanner.Scan() { |
| line := scanner.Text() |
| |
| if strings.HasPrefix(line, "event: ") { |
| eventType = strings.TrimPrefix(line, "event: ") |
| eventsReceived[eventType]++ |
| } else if strings.HasPrefix(line, "data: ") { |
| dataLines = append(dataLines, line) |
| } else if line == "" && eventType != "" { |
| // End of event |
| eventType = "" |
| } |
| |
| // Break if context is done |
| if ctx.Err() != nil { |
| break |
| } |
| } |
| |
| if err := scanner.Err(); err != nil && ctx.Err() == nil { |
| t.Fatalf("Scanner error: %v", err) |
| } |
| |
| // Simplified validation - just make sure we received something |
| t.Logf("Events received: %v", eventsReceived) |
| t.Logf("Data lines received: %d", len(dataLines)) |
| |
| // Basic validation that we received at least some events |
| if eventsReceived["state"] == 0 && eventsReceived["message"] == 0 { |
| t.Errorf("Did not receive any events") |
| } |
| } |
| |
| func TestGitRawDiffHandler(t *testing.T) { |
| // Create a mock agent |
| mockAgent := &mockAgent{ |
| workingDir: t.TempDir(), // Use a temp directory |
| branchPrefix: "sketch/", |
| } |
| |
| // Create the server with the mock agent |
| server, err := server.New(mockAgent, nil) |
| if err != nil { |
| t.Fatalf("Failed to create server: %v", err) |
| } |
| |
| // Create a test HTTP server |
| testServer := httptest.NewServer(server) |
| defer testServer.Close() |
| |
| // Test missing parameters |
| resp, err := http.Get(testServer.URL + "/git/rawdiff") |
| if err != nil { |
| t.Fatalf("Failed to make HTTP request: %v", err) |
| } |
| if resp.StatusCode != http.StatusBadRequest { |
| t.Errorf("Expected status bad request, got: %d", resp.StatusCode) |
| } |
| |
| // Test with commit parameter (this will fail due to no git repo, but we're testing the API, not git) |
| resp, err = http.Get(testServer.URL + "/git/rawdiff?commit=HEAD") |
| if err != nil { |
| t.Fatalf("Failed to make HTTP request: %v", err) |
| } |
| // We expect an error since there's no git repository, but the request should be processed |
| if resp.StatusCode != http.StatusInternalServerError { |
| t.Errorf("Expected status 500, got: %d", resp.StatusCode) |
| } |
| |
| // Test with from/to parameters |
| resp, err = http.Get(testServer.URL + "/git/rawdiff?from=HEAD~1&to=HEAD") |
| if err != nil { |
| t.Fatalf("Failed to make HTTP request: %v", err) |
| } |
| // We expect an error since there's no git repository, but the request should be processed |
| if resp.StatusCode != http.StatusInternalServerError { |
| t.Errorf("Expected status 500, got: %d", resp.StatusCode) |
| } |
| } |
| |
| func TestGitShowHandler(t *testing.T) { |
| // Create a mock agent |
| mockAgent := &mockAgent{ |
| workingDir: t.TempDir(), // Use a temp directory |
| branchPrefix: "sketch/", |
| } |
| |
| // Create the server with the mock agent |
| server, err := server.New(mockAgent, nil) |
| if err != nil { |
| t.Fatalf("Failed to create server: %v", err) |
| } |
| |
| // Create a test HTTP server |
| testServer := httptest.NewServer(server) |
| defer testServer.Close() |
| |
| // Test missing parameter |
| resp, err := http.Get(testServer.URL + "/git/show") |
| if err != nil { |
| t.Fatalf("Failed to make HTTP request: %v", err) |
| } |
| if resp.StatusCode != http.StatusBadRequest { |
| t.Errorf("Expected status bad request, got: %d", resp.StatusCode) |
| } |
| |
| // Test with hash parameter (this will fail due to no git repo, but we're testing the API, not git) |
| resp, err = http.Get(testServer.URL + "/git/show?hash=HEAD") |
| if err != nil { |
| t.Fatalf("Failed to make HTTP request: %v", err) |
| } |
| // We expect an error since there's no git repository, but the request should be processed |
| if resp.StatusCode != http.StatusInternalServerError { |
| t.Errorf("Expected status 500, got: %d", resp.StatusCode) |
| } |
| } |
| |
| func TestCompactHandler(t *testing.T) { |
| // Test that mock CompactConversation works |
| mockAgent := &mockAgent{ |
| messages: []loop.AgentMessage{}, |
| messageCount: 0, |
| sessionID: "test-session", |
| branchPrefix: "sketch/", |
| } |
| |
| ctx := context.Background() |
| err := mockAgent.CompactConversation(ctx) |
| if err != nil { |
| t.Errorf("Mock CompactConversation failed: %v", err) |
| } |
| |
| // No HTTP endpoint to test anymore - compaction is done via /compact message |
| t.Log("Mock CompactConversation works correctly") |
| } |
| |
| // TestPortEventsEndpoint tests the /port-events HTTP endpoint |
| func TestPortEventsEndpoint(t *testing.T) { |
| // Create a mock agent that implements the CodingAgent interface |
| agent := &mockAgent{ |
| branchPrefix: "sketch/", |
| } |
| |
| // Create a server with the mock agent |
| server, err := server.New(agent, nil) |
| if err != nil { |
| t.Fatalf("Failed to create server: %v", err) |
| } |
| |
| // Test GET /port-events |
| req, err := http.NewRequest("GET", "/port-events", nil) |
| if err != nil { |
| t.Fatalf("Failed to create request: %v", err) |
| } |
| |
| rr := httptest.NewRecorder() |
| server.ServeHTTP(rr, req) |
| |
| // Should return 200 OK |
| if status := rr.Code; status != http.StatusOK { |
| t.Errorf("Expected status code %d, got %d", http.StatusOK, status) |
| } |
| |
| // Should return JSON content type |
| contentType := rr.Header().Get("Content-Type") |
| if contentType != "application/json" { |
| t.Errorf("Expected Content-Type application/json, got %s", contentType) |
| } |
| |
| // Should return valid JSON (empty array since mock returns no events) |
| var events []any |
| if err := json.Unmarshal(rr.Body.Bytes(), &events); err != nil { |
| t.Errorf("Failed to parse JSON response: %v", err) |
| } |
| |
| // Should be empty array for mock agent |
| if len(events) != 0 { |
| t.Errorf("Expected empty events array, got %d events", len(events)) |
| } |
| } |
| |
| // TestPortEventsEndpointMethodNotAllowed tests that non-GET requests are rejected |
| func TestPortEventsEndpointMethodNotAllowed(t *testing.T) { |
| agent := &mockAgent{ |
| branchPrefix: "sketch/", |
| } |
| server, err := server.New(agent, nil) |
| if err != nil { |
| t.Fatalf("Failed to create server: %v", err) |
| } |
| |
| // Test POST /port-events (should be rejected) |
| req, err := http.NewRequest("POST", "/port-events", nil) |
| if err != nil { |
| t.Fatalf("Failed to create request: %v", err) |
| } |
| |
| rr := httptest.NewRecorder() |
| server.ServeHTTP(rr, req) |
| |
| // Should return 405 Method Not Allowed |
| if status := rr.Code; status != http.StatusMethodNotAllowed { |
| t.Errorf("Expected status code %d, got %d", http.StatusMethodNotAllowed, status) |
| } |
| } |