| Philip Zeyliger | 194bfa8 | 2025-06-24 06:03:06 -0700 | [diff] [blame] | 1 | package mcp |
| 2 | |
| 3 | import ( |
| 4 | "context" |
| 5 | "encoding/json" |
| 6 | "fmt" |
| 7 | "log/slog" |
| 8 | "strings" |
| 9 | "sync" |
| 10 | "time" |
| 11 | |
| 12 | "github.com/mark3labs/mcp-go/client" |
| 13 | "github.com/mark3labs/mcp-go/client/transport" |
| 14 | "github.com/mark3labs/mcp-go/mcp" |
| 15 | "sketch.dev/llm" |
| 16 | ) |
| 17 | |
// Timeouts used when talking to MCP servers.
const (
	// DefaultMCPConnectionTimeout is the default timeout for connecting to MCP
	// servers (two minutes).
	DefaultMCPConnectionTimeout = 2 * time.Minute

	// DefaultMCPToolTimeout is the default timeout for executing a single MCP
	// tool call (two minutes). executeMCPTool applies it on top of the caller's
	// context.
	DefaultMCPToolTimeout = 2 * time.Minute
)
| 25 | |
// ServerConfig describes how to reach one MCP server. Values are decoded from
// the JSON strings passed to ParseServerConfigs.
//
// Type selects the transport. "stdio" (also used when Type is empty) launches
// Command with Args, adding Env to the subprocess environment as KEY=VALUE
// entries. "http" and "sse" connect to URL and hand Headers to the transport.
type ServerConfig struct {
	Name    string            `json:"name,omitempty"`    // required; also the prefix of every tool name
	Type    string            `json:"type,omitempty"`    // "stdio" (default), "http" or "sse"
	URL     string            `json:"url,omitempty"`     // endpoint, http/sse only
	Command string            `json:"command,omitempty"` // executable, stdio only
	Args    []string          `json:"args,omitempty"`    // command arguments, stdio only
	Env     map[string]string `json:"env,omitempty"`     // extra environment, stdio only
	Headers map[string]string `json:"headers,omitempty"` // extra request headers, http/sse only
}
| 36 | |
| 37 | // MCPManager manages multiple MCP server connections |
| 38 | type MCPManager struct { |
| 39 | mu sync.RWMutex |
| 40 | clients map[string]*MCPClientWrapper |
| 41 | } |
| 42 | |
| 43 | // MCPClientWrapper wraps an MCP client connection |
| 44 | type MCPClientWrapper struct { |
| 45 | name string |
| 46 | config ServerConfig |
| 47 | client *client.Client |
| 48 | tools []*llm.Tool |
| 49 | } |
| 50 | |
| 51 | // MCPServerConnection represents a successful MCP server connection with its tools |
| 52 | type MCPServerConnection struct { |
| 53 | ServerName string |
| 54 | Tools []*llm.Tool |
| 55 | ToolNames []string // Original tool names without server prefix |
| 56 | } |
| 57 | |
| 58 | // NewMCPManager creates a new MCP manager |
| 59 | func NewMCPManager() *MCPManager { |
| 60 | return &MCPManager{ |
| 61 | clients: make(map[string]*MCPClientWrapper), |
| 62 | } |
| 63 | } |
| 64 | |
| 65 | // ParseServerConfigs parses JSON configuration strings into ServerConfig structs |
| 66 | func ParseServerConfigs(ctx context.Context, configs []string) ([]ServerConfig, []error) { |
| 67 | if len(configs) == 0 { |
| 68 | return nil, nil |
| 69 | } |
| 70 | |
| 71 | var serverConfigs []ServerConfig |
| 72 | var errors []error |
| 73 | |
| 74 | for i, configStr := range configs { |
| 75 | var config ServerConfig |
| 76 | if err := json.Unmarshal([]byte(configStr), &config); err != nil { |
| 77 | slog.ErrorContext(ctx, "Failed to parse MCP server config", "config", configStr, "error", err) |
| 78 | errors = append(errors, fmt.Errorf("config %d: %w", i, err)) |
| 79 | continue |
| 80 | } |
| 81 | // Require a name |
| 82 | if config.Name == "" { |
| 83 | errors = append(errors, fmt.Errorf("config %d: name is required", i)) |
| 84 | continue |
| 85 | } |
| 86 | serverConfigs = append(serverConfigs, config) |
| 87 | } |
| 88 | |
| 89 | return serverConfigs, errors |
| 90 | } |
| 91 | |
| Philip Zeyliger | 4201bde | 2025-06-27 17:22:43 -0700 | [diff] [blame] | 92 | // ConnectToServerConfigs connects to multiple parsed MCP server configs in parallel |
| Philip Zeyliger | 194bfa8 | 2025-06-24 06:03:06 -0700 | [diff] [blame] | 93 | func (m *MCPManager) ConnectToServerConfigs(ctx context.Context, serverConfigs []ServerConfig, timeout time.Duration, existingErrors []error) ([]MCPServerConnection, []error) { |
| 94 | if len(serverConfigs) == 0 { |
| 95 | return nil, existingErrors |
| 96 | } |
| 97 | |
| 98 | slog.InfoContext(ctx, "Connecting to MCP servers", "count", len(serverConfigs), "timeout", timeout) |
| 99 | |
| 100 | // Connect to servers in parallel using sync.WaitGroup |
| 101 | type result struct { |
| 102 | tools []*llm.Tool |
| 103 | err error |
| 104 | serverName string |
| 105 | originalTools []string // Original tool names without server prefix |
| 106 | } |
| 107 | |
| 108 | results := make(chan result, len(serverConfigs)) |
| Philip Zeyliger | a8ac150 | 2025-07-27 21:24:42 -0700 | [diff] [blame] | 109 | // Create a timeout context only for the connection establishment goroutines |
| 110 | connectionCtx, connectionCancel := context.WithTimeout(context.Background(), timeout) |
| 111 | defer connectionCancel() |
| Philip Zeyliger | 194bfa8 | 2025-06-24 06:03:06 -0700 | [diff] [blame] | 112 | |
| 113 | for _, config := range serverConfigs { |
| 114 | go func(cfg ServerConfig) { |
| 115 | slog.InfoContext(ctx, "Connecting to MCP server", "server", cfg.Name, "type", cfg.Type, "url", cfg.URL, "command", cfg.Command) |
| Philip Zeyliger | a8ac150 | 2025-07-27 21:24:42 -0700 | [diff] [blame] | 116 | // Pass both the long-running context (ctx) and the connection timeout context |
| 117 | tools, originalToolNames, err := m.connectToServerWithNames(ctx, connectionCtx, cfg) |
| Philip Zeyliger | 194bfa8 | 2025-06-24 06:03:06 -0700 | [diff] [blame] | 118 | results <- result{ |
| 119 | tools: tools, |
| 120 | err: err, |
| 121 | serverName: cfg.Name, |
| 122 | originalTools: originalToolNames, |
| 123 | } |
| 124 | }(config) |
| 125 | } |
| 126 | |
| 127 | // Collect results |
| 128 | var connections []MCPServerConnection |
| 129 | errors := make([]error, 0, len(existingErrors)) |
| 130 | errors = append(errors, existingErrors...) |
| 131 | |
| Josh Bleecher Snyder | 44de46c | 2025-07-21 16:14:34 -0700 | [diff] [blame] | 132 | NextServer: |
| Philip Zeyliger | 194bfa8 | 2025-06-24 06:03:06 -0700 | [diff] [blame] | 133 | for range len(serverConfigs) { |
| 134 | select { |
| 135 | case res := <-results: |
| 136 | if res.err != nil { |
| 137 | slog.ErrorContext(ctx, "Failed to connect to MCP server", "server", res.serverName, "error", res.err) |
| 138 | errors = append(errors, fmt.Errorf("MCP server %q: %w", res.serverName, res.err)) |
| 139 | } else { |
| 140 | connection := MCPServerConnection{ |
| 141 | ServerName: res.serverName, |
| 142 | Tools: res.tools, |
| 143 | ToolNames: res.originalTools, |
| 144 | } |
| 145 | connections = append(connections, connection) |
| 146 | slog.InfoContext(ctx, "Successfully connected to MCP server", "server", res.serverName, "tools", len(res.tools), "tool_names", res.originalTools) |
| 147 | } |
| Philip Zeyliger | a8ac150 | 2025-07-27 21:24:42 -0700 | [diff] [blame] | 148 | case <-connectionCtx.Done(): |
| Philip Zeyliger | 194bfa8 | 2025-06-24 06:03:06 -0700 | [diff] [blame] | 149 | errors = append(errors, fmt.Errorf("timeout connecting to MCP servers")) |
| Josh Bleecher Snyder | 44de46c | 2025-07-21 16:14:34 -0700 | [diff] [blame] | 150 | break NextServer |
| Philip Zeyliger | 194bfa8 | 2025-06-24 06:03:06 -0700 | [diff] [blame] | 151 | } |
| 152 | } |
| 153 | |
| 154 | return connections, errors |
| 155 | } |
| 156 | |
| 157 | // connectToServerWithNames connects to a single MCP server and returns tools with original names |
| Philip Zeyliger | a8ac150 | 2025-07-27 21:24:42 -0700 | [diff] [blame] | 158 | func (m *MCPManager) connectToServerWithNames(longRunningCtx context.Context, connectionCtx context.Context, config ServerConfig) ([]*llm.Tool, []string, error) { |
| 159 | tools, err := m.connectToServer(longRunningCtx, connectionCtx, config) |
| Philip Zeyliger | 194bfa8 | 2025-06-24 06:03:06 -0700 | [diff] [blame] | 160 | if err != nil { |
| 161 | return nil, nil, err |
| 162 | } |
| 163 | |
| 164 | // Extract original tool names (remove server prefix) |
| 165 | originalNames := make([]string, len(tools)) |
| 166 | for i, tool := range tools { |
| 167 | // Tool names are in format "servername_toolname" |
| 168 | parts := strings.SplitN(tool.Name, "_", 2) |
| 169 | if len(parts) == 2 { |
| 170 | originalNames[i] = parts[1] |
| 171 | } else { |
| 172 | originalNames[i] = tool.Name // fallback if no prefix |
| 173 | } |
| 174 | } |
| 175 | |
| 176 | return tools, originalNames, nil |
| 177 | } |
| 178 | |
| 179 | // connectToServer connects to a single MCP server |
| Philip Zeyliger | a8ac150 | 2025-07-27 21:24:42 -0700 | [diff] [blame] | 180 | // longRunningCtx: context for the ongoing MCP client lifecycle (SSE streams) |
| 181 | // connectionCtx: context with timeout for connection establishment only |
| 182 | func (m *MCPManager) connectToServer(longRunningCtx context.Context, connectionCtx context.Context, config ServerConfig) ([]*llm.Tool, error) { |
| Philip Zeyliger | 194bfa8 | 2025-06-24 06:03:06 -0700 | [diff] [blame] | 183 | var mcpClient *client.Client |
| 184 | var err error |
| 185 | |
| 186 | // Convert environment variables to []string format |
| 187 | var envVars []string |
| 188 | for k, v := range config.Env { |
| 189 | envVars = append(envVars, k+"="+v) |
| 190 | } |
| 191 | |
| 192 | switch config.Type { |
| 193 | case "stdio", "": |
| 194 | if config.Command == "" { |
| 195 | return nil, fmt.Errorf("command is required for stdio transport") |
| 196 | } |
| 197 | mcpClient, err = client.NewStdioMCPClient(config.Command, envVars, config.Args...) |
| Philip Zeyliger | 08b073b | 2025-07-18 10:40:00 -0700 | [diff] [blame] | 198 | // TODO: Get the transport, cast it to *transport.Stdio, and start a goroutine to pipe stderr from the subprocess |
| 199 | // to our subprocess, but with each line prefixed with the server name. |
| Philip Zeyliger | 194bfa8 | 2025-06-24 06:03:06 -0700 | [diff] [blame] | 200 | case "http": |
| 201 | if config.URL == "" { |
| 202 | return nil, fmt.Errorf("URL is required for HTTP transport") |
| 203 | } |
| 204 | // Use streamable HTTP client for HTTP transport |
| 205 | var httpOptions []transport.StreamableHTTPCOption |
| 206 | if len(config.Headers) > 0 { |
| 207 | httpOptions = append(httpOptions, transport.WithHTTPHeaders(config.Headers)) |
| 208 | } |
| 209 | mcpClient, err = client.NewStreamableHttpClient(config.URL, httpOptions...) |
| 210 | case "sse": |
| 211 | if config.URL == "" { |
| 212 | return nil, fmt.Errorf("URL is required for SSE transport") |
| 213 | } |
| 214 | var sseOptions []transport.ClientOption |
| 215 | if len(config.Headers) > 0 { |
| 216 | sseOptions = append(sseOptions, transport.WithHeaders(config.Headers)) |
| 217 | } |
| 218 | mcpClient, err = client.NewSSEMCPClient(config.URL, sseOptions...) |
| 219 | default: |
| 220 | return nil, fmt.Errorf("unsupported MCP transport type: %s", config.Type) |
| 221 | } |
| 222 | |
| 223 | if err != nil { |
| 224 | return nil, fmt.Errorf("failed to create MCP client: %w", err) |
| 225 | } |
| 226 | |
| Philip Zeyliger | a8ac150 | 2025-07-27 21:24:42 -0700 | [diff] [blame] | 227 | // Start the client with the long-running context for SSE streams |
| 228 | if err := mcpClient.Start(longRunningCtx); err != nil { |
| Philip Zeyliger | 194bfa8 | 2025-06-24 06:03:06 -0700 | [diff] [blame] | 229 | return nil, fmt.Errorf("failed to start MCP client: %w", err) |
| 230 | } |
| 231 | |
| Philip Zeyliger | a8ac150 | 2025-07-27 21:24:42 -0700 | [diff] [blame] | 232 | // Initialize the client with connection timeout context |
| Philip Zeyliger | 194bfa8 | 2025-06-24 06:03:06 -0700 | [diff] [blame] | 233 | initReq := mcp.InitializeRequest{ |
| 234 | Params: mcp.InitializeParams{ |
| 235 | ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, |
| 236 | Capabilities: mcp.ClientCapabilities{}, |
| 237 | ClientInfo: mcp.Implementation{ |
| 238 | Name: "sketch", |
| 239 | Version: "1.0.0", |
| 240 | }, |
| 241 | }, |
| 242 | } |
| Philip Zeyliger | a8ac150 | 2025-07-27 21:24:42 -0700 | [diff] [blame] | 243 | if _, err := mcpClient.Initialize(connectionCtx, initReq); err != nil { |
| Philip Zeyliger | 194bfa8 | 2025-06-24 06:03:06 -0700 | [diff] [blame] | 244 | return nil, fmt.Errorf("failed to initialize MCP client: %w", err) |
| 245 | } |
| 246 | |
| Philip Zeyliger | a8ac150 | 2025-07-27 21:24:42 -0700 | [diff] [blame] | 247 | // Get available tools with connection timeout context |
| Philip Zeyliger | 194bfa8 | 2025-06-24 06:03:06 -0700 | [diff] [blame] | 248 | toolsReq := mcp.ListToolsRequest{} |
| Philip Zeyliger | a8ac150 | 2025-07-27 21:24:42 -0700 | [diff] [blame] | 249 | toolsResp, err := mcpClient.ListTools(connectionCtx, toolsReq) |
| Philip Zeyliger | 194bfa8 | 2025-06-24 06:03:06 -0700 | [diff] [blame] | 250 | if err != nil { |
| 251 | return nil, fmt.Errorf("failed to list tools: %w", err) |
| 252 | } |
| 253 | |
| 254 | // Convert MCP tools to llm.Tool |
| 255 | llmTools, err := m.convertMCPTools(config.Name, mcpClient, toolsResp.Tools) |
| 256 | if err != nil { |
| 257 | return nil, fmt.Errorf("failed to convert tools: %w", err) |
| 258 | } |
| 259 | |
| 260 | // Store the client |
| 261 | clientWrapper := &MCPClientWrapper{ |
| 262 | name: config.Name, |
| 263 | config: config, |
| 264 | client: mcpClient, |
| 265 | tools: llmTools, |
| 266 | } |
| 267 | |
| 268 | m.mu.Lock() |
| 269 | m.clients[config.Name] = clientWrapper |
| 270 | m.mu.Unlock() |
| 271 | |
| 272 | return llmTools, nil |
| 273 | } |
| 274 | |
| 275 | // convertMCPTools converts MCP tools to llm.Tool format |
| 276 | func (m *MCPManager) convertMCPTools(serverName string, mcpClient *client.Client, mcpTools []mcp.Tool) ([]*llm.Tool, error) { |
| 277 | var llmTools []*llm.Tool |
| 278 | |
| 279 | for _, mcpTool := range mcpTools { |
| 280 | // Convert the input schema |
| 281 | schemaBytes, err := json.Marshal(mcpTool.InputSchema) |
| 282 | if err != nil { |
| 283 | return nil, fmt.Errorf("failed to marshal input schema for tool %s: %w", mcpTool.Name, err) |
| 284 | } |
| 285 | |
| 286 | llmTool := &llm.Tool{ |
| 287 | Name: fmt.Sprintf("%s_%s", serverName, mcpTool.Name), |
| 288 | Description: mcpTool.Description, |
| 289 | InputSchema: json.RawMessage(schemaBytes), |
| Josh Bleecher Snyder | 43b60b9 | 2025-07-21 14:57:10 -0700 | [diff] [blame] | 290 | Run: func(toolName string, client *client.Client) func(ctx context.Context, input json.RawMessage) llm.ToolOut { |
| 291 | return func(ctx context.Context, input json.RawMessage) llm.ToolOut { |
| Philip Zeyliger | 194bfa8 | 2025-06-24 06:03:06 -0700 | [diff] [blame] | 292 | result, err := m.executeMCPTool(ctx, client, toolName, input) |
| 293 | if err != nil { |
| Josh Bleecher Snyder | 43b60b9 | 2025-07-21 14:57:10 -0700 | [diff] [blame] | 294 | return llm.ErrorToolOut(err) |
| Philip Zeyliger | 194bfa8 | 2025-06-24 06:03:06 -0700 | [diff] [blame] | 295 | } |
| 296 | // Convert result to llm.Content |
| Josh Bleecher Snyder | 43b60b9 | 2025-07-21 14:57:10 -0700 | [diff] [blame] | 297 | return llm.ToolOut{LLMContent: []llm.Content{llm.StringContent(fmt.Sprintf("%v", result))}} |
| Philip Zeyliger | 194bfa8 | 2025-06-24 06:03:06 -0700 | [diff] [blame] | 298 | } |
| 299 | }(mcpTool.Name, mcpClient), |
| 300 | } |
| 301 | |
| 302 | llmTools = append(llmTools, llmTool) |
| 303 | } |
| 304 | |
| 305 | return llmTools, nil |
| 306 | } |
| 307 | |
| 308 | // executeMCPTool executes an MCP tool call |
| 309 | func (m *MCPManager) executeMCPTool(ctx context.Context, mcpClient *client.Client, toolName string, input json.RawMessage) (any, error) { |
| 310 | // Add timeout for tool execution |
| Philip Zeyliger | 08b073b | 2025-07-18 10:40:00 -0700 | [diff] [blame] | 311 | // TODO: Expose the timeout as a tool call argument. |
| Philip Zeyliger | c540df7 | 2025-07-25 09:21:56 -0700 | [diff] [blame] | 312 | ctxWithTimeout, cancel := context.WithTimeout(ctx, DefaultMCPToolTimeout) |
| Philip Zeyliger | 194bfa8 | 2025-06-24 06:03:06 -0700 | [diff] [blame] | 313 | defer cancel() |
| 314 | |
| 315 | // Parse input arguments |
| 316 | var args map[string]any |
| 317 | if len(input) > 0 { |
| 318 | if err := json.Unmarshal(input, &args); err != nil { |
| 319 | return nil, fmt.Errorf("failed to parse tool arguments: %w", err) |
| 320 | } |
| 321 | } |
| 322 | |
| 323 | // Call the MCP tool |
| 324 | req := mcp.CallToolRequest{ |
| 325 | Params: mcp.CallToolParams{ |
| 326 | Name: toolName, |
| 327 | Arguments: args, |
| 328 | }, |
| 329 | } |
| 330 | resp, err := mcpClient.CallTool(ctxWithTimeout, req) |
| 331 | if err != nil { |
| 332 | return nil, fmt.Errorf("MCP tool call failed: %w", err) |
| 333 | } |
| 334 | |
| 335 | // Return the content from the response |
| 336 | return resp.Content, nil |
| 337 | } |
| 338 | |
| 339 | // Close closes all MCP client connections |
| 340 | func (m *MCPManager) Close() { |
| 341 | m.mu.Lock() |
| 342 | defer m.mu.Unlock() |
| 343 | |
| 344 | for _, clientWrapper := range m.clients { |
| 345 | if clientWrapper.client != nil { |
| 346 | clientWrapper.client.Close() |
| 347 | } |
| 348 | } |
| 349 | m.clients = make(map[string]*MCPClientWrapper) |
| 350 | } |