sketch/mcp: fix mcp sse stream context cancellation
The SSE stream was being started with a timeout context that got canceled
immediately after connection establishment, causing 'context canceled' errors
and breaking MCP tool execution with 'Could not find session' errors.
Separate connection establishment timeout from long-running SSE stream context:
- Use agent's main context for mcpClient.Start() (SSE stream lifecycle)
- Use separate timeout context for Initialize() and ListTools() (connection only)
Fixes SSE stream persistence and enables successful MCP tool execution.
Co-Authored-By: sketch <hello@sketch.dev>
Change-ID: sb595ea17e6f1205ck
diff --git a/mcp/client.go b/mcp/client.go
index 8f65614..22839ca 100644
--- a/mcp/client.go
+++ b/mcp/client.go
@@ -106,13 +106,15 @@
}
results := make(chan result, len(serverConfigs))
- ctxWithTimeout, cancel := context.WithTimeout(ctx, timeout)
- defer cancel()
+ // Create a timeout context only for the connection establishment goroutines
+ connectionCtx, connectionCancel := context.WithTimeout(context.Background(), timeout)
+ defer connectionCancel()
for _, config := range serverConfigs {
go func(cfg ServerConfig) {
slog.InfoContext(ctx, "Connecting to MCP server", "server", cfg.Name, "type", cfg.Type, "url", cfg.URL, "command", cfg.Command)
- tools, originalToolNames, err := m.connectToServerWithNames(ctxWithTimeout, cfg)
+ // Pass both the long-running context (ctx) and the connection timeout context
+ tools, originalToolNames, err := m.connectToServerWithNames(ctx, connectionCtx, cfg)
results <- result{
tools: tools,
err: err,
@@ -143,7 +145,7 @@
connections = append(connections, connection)
slog.InfoContext(ctx, "Successfully connected to MCP server", "server", res.serverName, "tools", len(res.tools), "tool_names", res.originalTools)
}
- case <-ctxWithTimeout.Done():
+ case <-connectionCtx.Done():
errors = append(errors, fmt.Errorf("timeout connecting to MCP servers"))
break NextServer
}
@@ -153,8 +155,8 @@
}
// connectToServerWithNames connects to a single MCP server and returns tools with original names
-func (m *MCPManager) connectToServerWithNames(ctx context.Context, config ServerConfig) ([]*llm.Tool, []string, error) {
- tools, err := m.connectToServer(ctx, config)
+func (m *MCPManager) connectToServerWithNames(longRunningCtx context.Context, connectionCtx context.Context, config ServerConfig) ([]*llm.Tool, []string, error) {
+ tools, err := m.connectToServer(longRunningCtx, connectionCtx, config)
if err != nil {
return nil, nil, err
}
@@ -175,7 +177,9 @@
}
// connectToServer connects to a single MCP server
-func (m *MCPManager) connectToServer(ctx context.Context, config ServerConfig) ([]*llm.Tool, error) {
+// longRunningCtx: context for the ongoing MCP client lifecycle (SSE streams)
+// connectionCtx: context with timeout for connection establishment only
+func (m *MCPManager) connectToServer(longRunningCtx context.Context, connectionCtx context.Context, config ServerConfig) ([]*llm.Tool, error) {
var mcpClient *client.Client
var err error
@@ -220,12 +224,12 @@
return nil, fmt.Errorf("failed to create MCP client: %w", err)
}
- // Start the client first
- if err := mcpClient.Start(ctx); err != nil {
+ // Start the client with the long-running context for SSE streams
+ if err := mcpClient.Start(longRunningCtx); err != nil {
return nil, fmt.Errorf("failed to start MCP client: %w", err)
}
- // Initialize the client
+ // Initialize the client with connection timeout context
initReq := mcp.InitializeRequest{
Params: mcp.InitializeParams{
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
@@ -236,13 +240,13 @@
},
},
}
- if _, err := mcpClient.Initialize(ctx, initReq); err != nil {
+ if _, err := mcpClient.Initialize(connectionCtx, initReq); err != nil {
return nil, fmt.Errorf("failed to initialize MCP client: %w", err)
}
- // Get available tools
+ // Get available tools with connection timeout context
toolsReq := mcp.ListToolsRequest{}
- toolsResp, err := mcpClient.ListTools(ctx, toolsReq)
+ toolsResp, err := mcpClient.ListTools(connectionCtx, toolsReq)
if err != nil {
return nil, fmt.Errorf("failed to list tools: %w", err)
}