sketch/loop: add PortMonitor for TCP port monitoring with Agent integration
Add a PortMonitor struct that uses the Tailscale portlist library to monitor
open/listening TCP ports and send AgentMessage notifications to the Agent
when ports are opened or closed. It also provides a method to access the
cached port list.
When I asked Sketch to do this with the old implementation, it did
OK parsing /proc, but then it tried to convert the result to ss format...
Using a library seems to work OK!
Co-Authored-By: sketch <hello@sketch.dev>
Change-ID: s8fc57de4b5583d34k
diff --git a/loop/agent.go b/loop/agent.go
index 8f96924..82d823e 100644
--- a/loop/agent.go
+++ b/loop/agent.go
@@ -30,6 +30,7 @@
"sketch.dev/llm/conversation"
"sketch.dev/mcp"
"sketch.dev/skabandclient"
+ "tailscale.com/portlist"
)
const (
@@ -155,6 +156,9 @@
// SkabandAddr returns the skaband address if configured
SkabandAddr() string
+
+ // GetPorts returns the cached list of open TCP ports
+ GetPorts() []portlist.Port
}
type CodingAgentMessageType string
@@ -168,6 +172,7 @@
CommitMessageType CodingAgentMessageType = "commit" // for displaying git commits
AutoMessageType CodingAgentMessageType = "auto" // for automated notifications like autoformatting
CompactMessageType CodingAgentMessageType = "compact" // for conversation compaction notifications
+ PortMessageType CodingAgentMessageType = "port" // for port monitoring events
cancelToolUseMessage = "Stop responding to my previous message. Wait for me to ask you something else before attempting to use any more tools."
)
@@ -430,6 +435,8 @@
gitOrigin string
// MCP manager for handling MCP server connections
mcpManager *mcp.MCPManager
+ // Port monitor for tracking TCP ports
+ portMonitor *PortMonitor
// Time when the current turn started (reset at the beginning of InnerLoop)
startOfTurn time.Time
@@ -654,6 +661,14 @@
func (a *Agent) URL() string { return a.url }
+// GetPorts returns the cached list of open TCP ports.
+func (a *Agent) GetPorts() []portlist.Port {
+ if a.portMonitor == nil {
+ return nil
+ }
+ return a.portMonitor.GetPorts()
+}
+
// BranchName returns the git branch name for the conversation.
func (a *Agent) BranchName() string {
return a.gitState.BranchName(a.config.BranchPrefix)
@@ -1070,6 +1085,10 @@
mcpManager: mcp.NewMCPManager(),
}
+
+ // Initialize port monitor with 5-second interval
+ agent.portMonitor = NewPortMonitor(agent, 5*time.Second)
+
return agent
}
@@ -1522,11 +1541,23 @@
}
func (a *Agent) Loop(ctxOuter context.Context) {
+ // Start port monitoring
+ if a.portMonitor != nil && a.IsInContainer() {
+ if err := a.portMonitor.Start(ctxOuter); err != nil {
+ slog.WarnContext(ctxOuter, "Failed to start port monitor", "error", err)
+ } else {
+ slog.InfoContext(ctxOuter, "Port monitor started")
+ }
+ }
+
// Set up cleanup when context is done
defer func() {
if a.mcpManager != nil {
a.mcpManager.Close()
}
+ if a.portMonitor != nil && a.IsInContainer() {
+ a.portMonitor.Stop()
+ }
}()
for {
diff --git a/loop/port_monitor.go b/loop/port_monitor.go
new file mode 100644
index 0000000..c122e51
--- /dev/null
+++ b/loop/port_monitor.go
@@ -0,0 +1,246 @@
+package loop
+
+import (
+ "context"
+ "fmt"
+ "log/slog"
+ "sort"
+ "sync"
+ "time"
+
+ "tailscale.com/portlist"
+)
+
+// PortMonitor monitors open/listening TCP ports and sends notifications
+// to an Agent when ports are detected or removed.
+type PortMonitor struct {
+ mu sync.RWMutex
+ ports []portlist.Port // cached list of current ports
+ poller *portlist.Poller
+ agent *Agent
+ ctx context.Context
+ cancel context.CancelFunc
+ interval time.Duration
+ running bool
+ wg sync.WaitGroup
+}
+
+// NewPortMonitor creates a new PortMonitor instance.
+func NewPortMonitor(agent *Agent, interval time.Duration) *PortMonitor {
+ if interval <= 0 {
+ interval = 5 * time.Second // default polling interval
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
+ poller := &portlist.Poller{
+ IncludeLocalhost: true, // include localhost-bound services
+ }
+
+ return &PortMonitor{
+ poller: poller,
+ agent: agent,
+ ctx: ctx,
+ cancel: cancel,
+ interval: interval,
+ }
+}
+
+// Start begins monitoring ports in a background goroutine.
+func (pm *PortMonitor) Start(ctx context.Context) error {
+ pm.mu.Lock()
+ defer pm.mu.Unlock()
+
+ if pm.running {
+ return fmt.Errorf("port monitor is already running")
+ }
+
+ // Update the internal context to use the provided context
+ pm.cancel() // Cancel the old context
+ pm.ctx, pm.cancel = context.WithCancel(ctx)
+
+ pm.running = true
+ pm.wg.Add(1)
+
+ // Do initial port scan
+ if err := pm.initialScan(); err != nil {
+ pm.running = false
+ pm.wg.Done()
+ return fmt.Errorf("initial port scan failed: %w", err)
+ }
+
+ go pm.monitor()
+ return nil
+}
+
+// Stop stops the port monitor.
+func (pm *PortMonitor) Stop() {
+ pm.mu.Lock()
+ defer pm.mu.Unlock()
+
+ if !pm.running {
+ return
+ }
+
+ pm.running = false
+ pm.cancel()
+ pm.wg.Wait()
+ pm.poller.Close()
+}
+
+// GetPorts returns the cached list of open ports.
+func (pm *PortMonitor) GetPorts() []portlist.Port {
+ pm.mu.RLock()
+ defer pm.mu.RUnlock()
+
+ // Return a copy to prevent data races
+ ports := make([]portlist.Port, len(pm.ports))
+ copy(ports, pm.ports)
+ return ports
+}
+
+// initialScan performs the initial port scan without sending notifications.
+func (pm *PortMonitor) initialScan() error {
+ ports, _, err := pm.poller.Poll()
+ if err != nil {
+ return err
+ }
+
+ // Filter for TCP ports only
+ pm.ports = filterTCPPorts(ports)
+ sortPorts(pm.ports)
+
+ return nil
+}
+
+// monitor runs the port monitoring loop.
+func (pm *PortMonitor) monitor() {
+ defer pm.wg.Done()
+
+ ticker := time.NewTicker(pm.interval)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-pm.ctx.Done():
+ return
+ case <-ticker.C:
+ if err := pm.checkPorts(); err != nil {
+ slog.WarnContext(pm.ctx, "port monitoring error", "error", err)
+ }
+ }
+ }
+}
+
+// checkPorts polls for current ports and sends notifications for changes.
+func (pm *PortMonitor) checkPorts() error {
+ ports, changed, err := pm.poller.Poll()
+ if err != nil {
+ return err
+ }
+
+ if !changed {
+ return nil
+ }
+
+ // Filter for TCP ports only
+ currentTCPPorts := filterTCPPorts(ports)
+ sortPorts(currentTCPPorts)
+
+ pm.mu.Lock()
+ previousPorts := pm.ports
+ pm.ports = currentTCPPorts
+ pm.mu.Unlock()
+
+ // Find added and removed ports
+ addedPorts := findAddedPorts(previousPorts, currentTCPPorts)
+ removedPorts := findRemovedPorts(previousPorts, currentTCPPorts)
+
+ // Send notifications for changes
+ for _, port := range addedPorts {
+ pm.sendPortNotification("opened", port)
+ }
+
+ for _, port := range removedPorts {
+ pm.sendPortNotification("closed", port)
+ }
+
+ return nil
+}
+
+// sendPortNotification sends a port event notification to the agent.
+func (pm *PortMonitor) sendPortNotification(event string, port portlist.Port) {
+ if pm.agent == nil {
+ return
+ }
+
+ // Skip low ports and sketch's ports
+ if port.Port < 1024 || port.Pid == 1 {
+ return
+ }
+
+ // TODO: Structure this so that UI can display it more nicely.
+ content := fmt.Sprintf("Port %s: %s:%d", event, port.Proto, port.Port)
+ if port.Process != "" {
+ content += fmt.Sprintf(" (process: %s)", port.Process)
+ }
+ if port.Pid != 0 {
+ content += fmt.Sprintf(" (pid: %d)", port.Pid)
+ }
+
+ msg := AgentMessage{
+ Type: PortMessageType,
+ Content: content,
+ }
+
+ pm.agent.pushToOutbox(pm.ctx, msg)
+}
+
+// filterTCPPorts filters the port list to include only TCP ports.
+func filterTCPPorts(ports []portlist.Port) []portlist.Port {
+ var tcpPorts []portlist.Port
+ for _, port := range ports {
+ if port.Proto == "tcp" {
+ tcpPorts = append(tcpPorts, port)
+ }
+ }
+ return tcpPorts
+}
+
+// sortPorts sorts ports by port number for consistent comparisons.
+func sortPorts(ports []portlist.Port) {
+ sort.Slice(ports, func(i, j int) bool {
+ return ports[i].Port < ports[j].Port
+ })
+}
+
+// findAddedPorts finds ports that are in current but not in previous.
+func findAddedPorts(previous, current []portlist.Port) []portlist.Port {
+ prevSet := make(map[uint16]bool)
+ for _, port := range previous {
+ prevSet[port.Port] = true
+ }
+
+ var added []portlist.Port
+ for _, port := range current {
+ if !prevSet[port.Port] {
+ added = append(added, port)
+ }
+ }
+ return added
+}
+
+// findRemovedPorts finds ports that are in previous but not in current.
+func findRemovedPorts(previous, current []portlist.Port) []portlist.Port {
+ currentSet := make(map[uint16]bool)
+ for _, port := range current {
+ currentSet[port.Port] = true
+ }
+
+ var removed []portlist.Port
+ for _, port := range previous {
+ if !currentSet[port.Port] {
+ removed = append(removed, port)
+ }
+ }
+ return removed
+}
diff --git a/loop/port_monitor_demo_test.go b/loop/port_monitor_demo_test.go
new file mode 100644
index 0000000..4190e5a
--- /dev/null
+++ b/loop/port_monitor_demo_test.go
@@ -0,0 +1,170 @@
+package loop
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "net/http"
+ "sync"
+ "testing"
+ "time"
+)
+
+// TestPortMonitor_IntegrationDemo demonstrates the full integration of PortMonitor with an Agent.
+// This test shows how the PortMonitor detects port changes and sends notifications to an Agent.
+func TestPortMonitor_IntegrationDemo(t *testing.T) {
+ // Create a test agent
+ agent := createTestAgent(t)
+
+ // Create and start the port monitor
+ pm := NewPortMonitor(agent, 100*time.Millisecond) // Fast polling for demo
+ ctx := context.Background()
+ err := pm.Start(ctx)
+ if err != nil {
+ t.Fatalf("Failed to start port monitor: %v", err)
+ }
+ defer pm.Stop()
+
+ // Wait for initial scan
+ time.Sleep(200 * time.Millisecond)
+
+ // Show current ports
+ currentPorts := pm.GetPorts()
+ t.Logf("Initial TCP ports detected: %d", len(currentPorts))
+ for _, port := range currentPorts {
+ t.Logf(" - Port %d (process: %s, pid: %d)", port.Port, port.Process, port.Pid)
+ }
+
+ // Start multiple test servers to demonstrate detection
+ var listeners []net.Listener
+ var wg sync.WaitGroup
+
+ // Start 3 test HTTP servers
+ for i := 0; i < 3; i++ {
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("Failed to start test listener %d: %v", i, err)
+ }
+ listeners = append(listeners, listener)
+
+ addr := listener.Addr().(*net.TCPAddr)
+ port := addr.Port
+ t.Logf("Started test HTTP server %d on port %d", i+1, port)
+
+ // Start a simple HTTP server
+ wg.Add(1)
+ go func(l net.Listener, serverID int) {
+ defer wg.Done()
+ mux := http.NewServeMux()
+ mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintf(w, "Hello from test server %d!\n", serverID)
+ })
+ server := &http.Server{Handler: mux}
+ server.Serve(l)
+ }(listener, i+1)
+ }
+
+ // Wait for ports to be detected
+ time.Sleep(500 * time.Millisecond)
+
+ // Check that the new ports were detected
+ updatedPorts := pm.GetPorts()
+ t.Logf("Updated TCP ports detected: %d", len(updatedPorts))
+
+ // Verify that we have at least 3 more ports than initially
+ if len(updatedPorts) < len(currentPorts)+3 {
+ t.Errorf("Expected at least %d ports, got %d", len(currentPorts)+3, len(updatedPorts))
+ }
+
+ // Find the new test server ports
+ var testPorts []uint16
+ for _, listener := range listeners {
+ addr := listener.Addr().(*net.TCPAddr)
+ testPorts = append(testPorts, uint16(addr.Port))
+ }
+
+ // Verify all test ports are detected
+ portMap := make(map[uint16]bool)
+ for _, port := range updatedPorts {
+ portMap[port.Port] = true
+ }
+
+ allPortsDetected := true
+ for _, testPort := range testPorts {
+ if !portMap[testPort] {
+ allPortsDetected = false
+ t.Errorf("Test port %d was not detected", testPort)
+ }
+ }
+
+ if allPortsDetected {
+ t.Logf("All test ports successfully detected!")
+ }
+
+ // Close all listeners
+ for i, listener := range listeners {
+ t.Logf("Closing test server %d", i+1)
+ listener.Close()
+ }
+
+ // Wait for servers to stop
+ wg.Wait()
+
+ // Wait for ports to be removed
+ time.Sleep(500 * time.Millisecond)
+
+ // Check that ports were removed
+ finalPorts := pm.GetPorts()
+ t.Logf("Final TCP ports detected: %d", len(finalPorts))
+
+ // Verify the final port count is back to near the original
+ if len(finalPorts) > len(currentPorts)+1 {
+ t.Errorf("Expected final port count to be close to initial (%d), got %d", len(currentPorts), len(finalPorts))
+ }
+
+ // Verify test ports are no longer detected
+ portMap = make(map[uint16]bool)
+ for _, port := range finalPorts {
+ portMap[port.Port] = true
+ }
+
+ allPortsRemoved := true
+ for _, testPort := range testPorts {
+ if portMap[testPort] {
+ allPortsRemoved = false
+ t.Errorf("Test port %d was not removed", testPort)
+ }
+ }
+
+ if allPortsRemoved {
+ t.Logf("All test ports successfully removed!")
+ }
+
+ t.Logf("Integration test completed successfully!")
+ t.Logf("- Initial ports: %d", len(currentPorts))
+ t.Logf("- Peak ports: %d", len(updatedPorts))
+ t.Logf("- Final ports: %d", len(finalPorts))
+ t.Logf("- Test ports added and removed: %d", len(testPorts))
+}
+
// contains reports whether substr occurs anywhere in s. An empty substr is
// contained in every string, including the empty string.
func contains(s, substr string) bool {
	n := len(substr)
	for start := 0; start+n <= len(s); start++ {
		if s[start:start+n] == substr {
			return true
		}
	}
	return false
}
+
// indexOfSubstring returns the byte offset of the first occurrence of substr
// in s, or -1 if there is none. An empty substr matches at offset 0.
func indexOfSubstring(s, substr string) int {
	width := len(substr)
	if width == 0 {
		return 0
	}
	for start, end := 0, width; end <= len(s); start, end = start+1, end+1 {
		if s[start:end] == substr {
			return start
		}
	}
	return -1
}
diff --git a/loop/port_monitor_test.go b/loop/port_monitor_test.go
new file mode 100644
index 0000000..5ad5c2e
--- /dev/null
+++ b/loop/port_monitor_test.go
@@ -0,0 +1,319 @@
+package loop
+
+import (
+ "context"
+ "net"
+ "testing"
+ "time"
+
+ "tailscale.com/portlist"
+)
+
+// TestPortMonitor_NewPortMonitor tests the creation of a new PortMonitor.
+func TestPortMonitor_NewPortMonitor(t *testing.T) {
+ agent := createTestAgent(t)
+ interval := 2 * time.Second
+
+ pm := NewPortMonitor(agent, interval)
+
+ if pm == nil {
+ t.Fatal("NewPortMonitor returned nil")
+ }
+
+ if pm.agent != agent {
+ t.Errorf("expected agent %v, got %v", agent, pm.agent)
+ }
+
+ if pm.interval != interval {
+ t.Errorf("expected interval %v, got %v", interval, pm.interval)
+ }
+
+ if pm.running {
+ t.Error("expected monitor to not be running initially")
+ }
+
+ if pm.poller == nil {
+ t.Error("expected poller to be initialized")
+ }
+
+ if !pm.poller.IncludeLocalhost {
+ t.Error("expected IncludeLocalhost to be true")
+ }
+}
+
+// TestPortMonitor_DefaultInterval tests that a default interval is set when invalid.
+func TestPortMonitor_DefaultInterval(t *testing.T) {
+ agent := createTestAgent(t)
+
+ pm := NewPortMonitor(agent, 0)
+ if pm.interval != 5*time.Second {
+ t.Errorf("expected default interval 5s, got %v", pm.interval)
+ }
+
+ pm2 := NewPortMonitor(agent, -1*time.Second)
+ if pm2.interval != 5*time.Second {
+ t.Errorf("expected default interval 5s, got %v", pm2.interval)
+ }
+}
+
+// TestPortMonitor_StartStop tests starting and stopping the monitor.
+func TestPortMonitor_StartStop(t *testing.T) {
+ agent := createTestAgent(t)
+ pm := NewPortMonitor(agent, 100*time.Millisecond)
+
+ // Test starting
+ ctx := context.Background()
+ err := pm.Start(ctx)
+ if err != nil {
+ t.Fatalf("failed to start port monitor: %v", err)
+ }
+
+ if !pm.running {
+ t.Error("expected monitor to be running after start")
+ }
+
+ // Test double start fails
+ err = pm.Start(ctx)
+ if err == nil {
+ t.Error("expected error when starting already running monitor")
+ }
+
+ // Test stopping
+ pm.Stop()
+ if pm.running {
+ t.Error("expected monitor to not be running after stop")
+ }
+
+ // Test double stop is safe
+ pm.Stop() // should not panic
+}
+
+// TestPortMonitor_GetPorts tests getting the cached port list.
+func TestPortMonitor_GetPorts(t *testing.T) {
+ agent := createTestAgent(t)
+ pm := NewPortMonitor(agent, 100*time.Millisecond)
+
+ // Initially should be empty
+ ports := pm.GetPorts()
+ if len(ports) != 0 {
+ t.Errorf("expected empty ports initially, got %d", len(ports))
+ }
+
+ // Start monitoring to populate ports
+ ctx := context.Background()
+ err := pm.Start(ctx)
+ if err != nil {
+ t.Fatalf("failed to start port monitor: %v", err)
+ }
+ defer pm.Stop()
+
+ // Allow some time for initial scan
+ time.Sleep(200 * time.Millisecond)
+
+ // Should have some ports now (at least system ports)
+ ports = pm.GetPorts()
+ // We can't guarantee specific ports, but there should be at least some TCP ports
+ // on most systems (like SSH, etc.)
+ t.Logf("Found %d TCP ports", len(ports))
+
+ // Verify all returned ports are TCP
+ for _, port := range ports {
+ if port.Proto != "tcp" {
+ t.Errorf("expected TCP port, got %s", port.Proto)
+ }
+ }
+}
+
+// TestPortMonitor_PortDetection tests actual port detection with a test server.
+func TestPortMonitor_PortDetection(t *testing.T) {
+ agent := createTestAgent(t)
+ pm := NewPortMonitor(agent, 50*time.Millisecond) // Fast polling for test
+
+ ctx := context.Background()
+ err := pm.Start(ctx)
+ if err != nil {
+ t.Fatalf("failed to start port monitor: %v", err)
+ }
+ defer pm.Stop()
+
+ // Allow initial scan
+ time.Sleep(100 * time.Millisecond)
+
+ // Get initial port count
+ initialPorts := pm.GetPorts()
+ initialCount := len(initialPorts)
+
+ // Start a test server
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("failed to start test listener: %v", err)
+ }
+ defer listener.Close()
+
+ addr := listener.Addr().(*net.TCPAddr)
+ testPort := uint16(addr.Port)
+
+ t.Logf("Started test server on port %d", testPort)
+
+ // Wait for port to be detected
+ detected := false
+ for i := 0; i < 50; i++ { // Wait up to 2.5 seconds
+ time.Sleep(50 * time.Millisecond)
+ ports := pm.GetPorts()
+ for _, port := range ports {
+ if port.Port == testPort {
+ detected = true
+ break
+ }
+ }
+ if detected {
+ break
+ }
+ }
+
+ if !detected {
+ t.Errorf("test port %d was not detected", testPort)
+ }
+
+ // Verify port count increased
+ currentPorts := pm.GetPorts()
+ if len(currentPorts) <= initialCount {
+ t.Errorf("expected port count to increase from %d, got %d", initialCount, len(currentPorts))
+ }
+
+ // Close the listener
+ listener.Close()
+
+ // Wait for port to be removed
+ removed := false
+ for i := 0; i < 50; i++ { // Wait up to 2.5 seconds
+ time.Sleep(50 * time.Millisecond)
+ ports := pm.GetPorts()
+ found := false
+ for _, port := range ports {
+ if port.Port == testPort {
+ found = true
+ break
+ }
+ }
+ if !found {
+ removed = true
+ break
+ }
+ }
+
+ if !removed {
+ t.Errorf("test port %d was not removed after listener closed", testPort)
+ }
+}
+
+// TestPortMonitor_FilterTCPPorts tests the TCP port filtering.
+func TestPortMonitor_FilterTCPPorts(t *testing.T) {
+ ports := []portlist.Port{
+ {Proto: "tcp", Port: 80},
+ {Proto: "udp", Port: 53},
+ {Proto: "tcp", Port: 443},
+ {Proto: "udp", Port: 123},
+ }
+
+ tcpPorts := filterTCPPorts(ports)
+
+ if len(tcpPorts) != 2 {
+ t.Errorf("expected 2 TCP ports, got %d", len(tcpPorts))
+ }
+
+ for _, port := range tcpPorts {
+ if port.Proto != "tcp" {
+ t.Errorf("expected TCP port, got %s", port.Proto)
+ }
+ }
+}
+
+// TestPortMonitor_SortPorts tests the port sorting.
+func TestPortMonitor_SortPorts(t *testing.T) {
+ ports := []portlist.Port{
+ {Proto: "tcp", Port: 443},
+ {Proto: "tcp", Port: 80},
+ {Proto: "tcp", Port: 8080},
+ {Proto: "tcp", Port: 22},
+ }
+
+ sortPorts(ports)
+
+ expected := []uint16{22, 80, 443, 8080}
+ for i, port := range ports {
+ if port.Port != expected[i] {
+ t.Errorf("expected port %d at index %d, got %d", expected[i], i, port.Port)
+ }
+ }
+}
+
+// TestPortMonitor_FindAddedPorts tests finding added ports.
+func TestPortMonitor_FindAddedPorts(t *testing.T) {
+ previous := []portlist.Port{
+ {Proto: "tcp", Port: 80},
+ {Proto: "tcp", Port: 443},
+ }
+
+ current := []portlist.Port{
+ {Proto: "tcp", Port: 80},
+ {Proto: "tcp", Port: 443},
+ {Proto: "tcp", Port: 8080},
+ {Proto: "tcp", Port: 22},
+ }
+
+ added := findAddedPorts(previous, current)
+
+ if len(added) != 2 {
+ t.Errorf("expected 2 added ports, got %d", len(added))
+ }
+
+ addedPorts := make(map[uint16]bool)
+ for _, port := range added {
+ addedPorts[port.Port] = true
+ }
+
+ if !addedPorts[8080] || !addedPorts[22] {
+ t.Errorf("expected ports 8080 and 22 to be added, got %v", added)
+ }
+}
+
+// TestPortMonitor_FindRemovedPorts tests finding removed ports.
+func TestPortMonitor_FindRemovedPorts(t *testing.T) {
+ previous := []portlist.Port{
+ {Proto: "tcp", Port: 80},
+ {Proto: "tcp", Port: 443},
+ {Proto: "tcp", Port: 8080},
+ {Proto: "tcp", Port: 22},
+ }
+
+ current := []portlist.Port{
+ {Proto: "tcp", Port: 80},
+ {Proto: "tcp", Port: 443},
+ }
+
+ removed := findRemovedPorts(previous, current)
+
+ if len(removed) != 2 {
+ t.Errorf("expected 2 removed ports, got %d", len(removed))
+ }
+
+ removedPorts := make(map[uint16]bool)
+ for _, port := range removed {
+ removedPorts[port.Port] = true
+ }
+
+ if !removedPorts[8080] || !removedPorts[22] {
+ t.Errorf("expected ports 8080 and 22 to be removed, got %v", removed)
+ }
+}
+
+// createTestAgent creates a minimal test agent for testing.
+func createTestAgent(t *testing.T) *Agent {
+ // Create a minimal agent for testing
+ // We need to initialize the required fields for the PortMonitor to work
+ agent := &Agent{
+ subscribers: make([]chan *AgentMessage, 0),
+ }
+ return agent
+}
diff --git a/loop/server/loophttp.go b/loop/server/loophttp.go
index b6e2259..3f90f8b 100644
--- a/loop/server/loophttp.go
+++ b/loop/server/loophttp.go
@@ -104,6 +104,15 @@
SSHConnectionString string `json:"ssh_connection_string,omitempty"` // SSH connection string for container
DiffLinesAdded int `json:"diff_lines_added"` // Lines added from sketch-base to HEAD
DiffLinesRemoved int `json:"diff_lines_removed"` // Lines removed from sketch-base to HEAD
+ OpenPorts []Port `json:"open_ports,omitempty"` // Currently open TCP ports
+}
+
// Port is the JSON form of one listening port in the state response's
// open_ports list. Field order here determines the JSON key order.
type Port struct {
	// Proto is the transport protocol, e.g. "tcp" or "udp".
	Proto string `json:"proto"`
	// Port is the port number.
	Port uint16 `json:"port"`
	// Process is the owning process's name. It may be empty, and it is still
	// emitted in that case.
	Process string `json:"process"`
	// Pid is the owning process's ID.
	Pid int `json:"pid"`
}
type InitRequest struct {
@@ -1308,9 +1317,29 @@
SSHConnectionString: s.agent.SSHConnectionString(),
DiffLinesAdded: diffAdded,
DiffLinesRemoved: diffRemoved,
+ OpenPorts: s.getOpenPorts(),
}
}
+// getOpenPorts retrieves the current open ports from the agent
+func (s *Server) getOpenPorts() []Port {
+ ports := s.agent.GetPorts()
+ if ports == nil {
+ return nil
+ }
+
+ result := make([]Port, len(ports))
+ for i, port := range ports {
+ result[i] = Port{
+ Proto: port.Proto,
+ Port: port.Port,
+ Process: port.Process,
+ Pid: port.Pid,
+ }
+ }
+ return result
+}
+
func (s *Server) handleGitRawDiff(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
w.WriteHeader(http.StatusMethodNotAllowed)
diff --git a/loop/server/loophttp_test.go b/loop/server/loophttp_test.go
index f6ec8c7..cd97b12 100644
--- a/loop/server/loophttp_test.go
+++ b/loop/server/loophttp_test.go
@@ -14,6 +14,7 @@
"sketch.dev/llm/conversation"
"sketch.dev/loop"
"sketch.dev/loop/server"
+ "tailscale.com/portlist"
)
// mockAgent is a mock implementation of loop.CodingAgent for testing
@@ -263,6 +264,14 @@
func (m *mockAgent) SkabandAddr() string { return m.skabandAddr }
func (m *mockAgent) LinkToGitHub() bool { return false }
func (m *mockAgent) DiffStats() (int, int) { return 0, 0 }
+func (m *mockAgent) GetPorts() []portlist.Port {
+ // Mock returns a few test ports
+ return []portlist.Port{
+ {Proto: "tcp", Port: 22, Process: "sshd", Pid: 1234},
+ {Proto: "tcp", Port: 80, Process: "nginx", Pid: 5678},
+ {Proto: "tcp", Port: 8080, Process: "test-server", Pid: 9012},
+ }
+}
// TestSSEStream tests the SSE stream endpoint
func TestSSEStream(t *testing.T) {
@@ -588,3 +597,79 @@
})
}
}
+
+// TestStateEndpointIncludesPorts tests that the /state endpoint includes port information
+func TestStateEndpointIncludesPorts(t *testing.T) {
+ mockAgent := &mockAgent{
+ messages: []loop.AgentMessage{},
+ messageCount: 0,
+ currentState: "initial",
+ subscribers: []chan *loop.AgentMessage{},
+ gitUsername: "test-user",
+ initialCommit: "abc123",
+ branchName: "test-branch",
+ branchPrefix: "test-",
+ workingDir: "/tmp/test",
+ sessionID: "test-session",
+ slug: "test-slug",
+ skabandAddr: "http://localhost:8080",
+ }
+
+ // Create a test server
+ server, err := server.New(mockAgent, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Create a test request to the /state endpoint
+ req, err := http.NewRequest("GET", "/state", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Create a response recorder
+ rr := httptest.NewRecorder()
+
+ // Execute the request
+ server.ServeHTTP(rr, req)
+
+ // Check the response
+ if status := rr.Code; status != http.StatusOK {
+ t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK)
+ }
+
+ // Check that the response contains port information
+ responseBody := rr.Body.String()
+ t.Logf("Response body: %s", responseBody)
+
+ // Verify the response contains the expected ports
+ if !strings.Contains(responseBody, `"open_ports"`) {
+ t.Error("Response should contain 'open_ports' field")
+ }
+
+ if !strings.Contains(responseBody, `"port": 22`) {
+ t.Error("Response should contain port 22 from mock")
+ }
+
+ if !strings.Contains(responseBody, `"port": 80`) {
+ t.Error("Response should contain port 80 from mock")
+ }
+
+ if !strings.Contains(responseBody, `"port": 8080`) {
+ t.Error("Response should contain port 8080 from mock")
+ }
+
+ if !strings.Contains(responseBody, `"process": "sshd"`) {
+ t.Error("Response should contain process name 'sshd'")
+ }
+
+ if !strings.Contains(responseBody, `"process": "nginx"`) {
+ t.Error("Response should contain process name 'nginx'")
+ }
+
+ if !strings.Contains(responseBody, `"proto": "tcp"`) {
+ t.Error("Response should contain protocol 'tcp'")
+ }
+
+ t.Log("State endpoint includes port information correctly")
+}