| Philip Zeyliger | 194bfa8 | 2025-06-24 06:03:06 -0700 | [diff] [blame] | 1 | package mcp |
| 2 | |
| 3 | import ( |
| 4 | "context" |
| 5 | "encoding/json" |
| 6 | "fmt" |
| 7 | "log/slog" |
| 8 | "strings" |
| 9 | "sync" |
| 10 | "time" |
| 11 | |
| 12 | "github.com/mark3labs/mcp-go/client" |
| 13 | "github.com/mark3labs/mcp-go/client/transport" |
| 14 | "github.com/mark3labs/mcp-go/mcp" |
| 15 | "sketch.dev/llm" |
| 16 | ) |
| 17 | |
// ServerConfig describes how to reach one MCP server. It is normally decoded
// from a JSON object such as
//
//	{"name": "fs", "type": "stdio", "command": "npx", "args": ["-y", "server"]}
//
// Which of the remaining fields are used depends on Type.
type ServerConfig struct {
	// Name identifies the server. ParseServerConfigs requires it, and it is
	// used as the prefix of every tool name the server exposes.
	Name string `json:"name,omitempty"`

	// Type selects the transport: "stdio", "http" or "sse". An empty Type is
	// treated as "stdio" when connecting.
	Type string `json:"type,omitempty"`

	// URL is the server endpoint (http and sse transports).
	URL string `json:"url,omitempty"`

	// Command is the executable to launch (stdio transport).
	Command string `json:"command,omitempty"`

	// Args are the command-line arguments for Command (stdio transport).
	Args []string `json:"args,omitempty"`

	// Env holds environment variables handed to Command as KEY=VALUE pairs
	// (stdio transport).
	Env map[string]string `json:"env,omitempty"`

	// Headers are HTTP headers sent to the server (http and sse transports).
	Headers map[string]string `json:"headers,omitempty"`
}
| 28 | |
| 29 | // MCPManager manages multiple MCP server connections |
| 30 | type MCPManager struct { |
| 31 | mu sync.RWMutex |
| 32 | clients map[string]*MCPClientWrapper |
| 33 | } |
| 34 | |
| 35 | // MCPClientWrapper wraps an MCP client connection |
| 36 | type MCPClientWrapper struct { |
| 37 | name string |
| 38 | config ServerConfig |
| 39 | client *client.Client |
| 40 | tools []*llm.Tool |
| 41 | } |
| 42 | |
| 43 | // MCPServerConnection represents a successful MCP server connection with its tools |
| 44 | type MCPServerConnection struct { |
| 45 | ServerName string |
| 46 | Tools []*llm.Tool |
| 47 | ToolNames []string // Original tool names without server prefix |
| 48 | } |
| 49 | |
| 50 | // NewMCPManager creates a new MCP manager |
| 51 | func NewMCPManager() *MCPManager { |
| 52 | return &MCPManager{ |
| 53 | clients: make(map[string]*MCPClientWrapper), |
| 54 | } |
| 55 | } |
| 56 | |
| 57 | // ParseServerConfigs parses JSON configuration strings into ServerConfig structs |
| 58 | func ParseServerConfigs(ctx context.Context, configs []string) ([]ServerConfig, []error) { |
| 59 | if len(configs) == 0 { |
| 60 | return nil, nil |
| 61 | } |
| 62 | |
| 63 | var serverConfigs []ServerConfig |
| 64 | var errors []error |
| 65 | |
| 66 | for i, configStr := range configs { |
| 67 | var config ServerConfig |
| 68 | if err := json.Unmarshal([]byte(configStr), &config); err != nil { |
| 69 | slog.ErrorContext(ctx, "Failed to parse MCP server config", "config", configStr, "error", err) |
| 70 | errors = append(errors, fmt.Errorf("config %d: %w", i, err)) |
| 71 | continue |
| 72 | } |
| 73 | // Require a name |
| 74 | if config.Name == "" { |
| 75 | errors = append(errors, fmt.Errorf("config %d: name is required", i)) |
| 76 | continue |
| 77 | } |
| 78 | serverConfigs = append(serverConfigs, config) |
| 79 | } |
| 80 | |
| 81 | return serverConfigs, errors |
| 82 | } |
| 83 | |
| 84 | // ConnectToServers connects to multiple MCP servers in parallel |
| 85 | func (m *MCPManager) ConnectToServers(ctx context.Context, configs []string, timeout time.Duration) ([]MCPServerConnection, []error) { |
| 86 | serverConfigs, parseErrors := ParseServerConfigs(ctx, configs) |
| 87 | if len(serverConfigs) == 0 { |
| 88 | if len(parseErrors) > 0 { |
| 89 | return nil, parseErrors |
| 90 | } |
| 91 | return nil, nil |
| 92 | } |
| 93 | return m.ConnectToServerConfigs(ctx, serverConfigs, timeout, parseErrors) |
| 94 | } // ConnectToServerConfigs connects to multiple parsed MCP server configs in parallel |
| 95 | func (m *MCPManager) ConnectToServerConfigs(ctx context.Context, serverConfigs []ServerConfig, timeout time.Duration, existingErrors []error) ([]MCPServerConnection, []error) { |
| 96 | if len(serverConfigs) == 0 { |
| 97 | return nil, existingErrors |
| 98 | } |
| 99 | |
| 100 | slog.InfoContext(ctx, "Connecting to MCP servers", "count", len(serverConfigs), "timeout", timeout) |
| 101 | |
| 102 | // Connect to servers in parallel using sync.WaitGroup |
| 103 | type result struct { |
| 104 | tools []*llm.Tool |
| 105 | err error |
| 106 | serverName string |
| 107 | originalTools []string // Original tool names without server prefix |
| 108 | } |
| 109 | |
| 110 | results := make(chan result, len(serverConfigs)) |
| 111 | ctxWithTimeout, cancel := context.WithTimeout(ctx, timeout) |
| 112 | defer cancel() |
| 113 | |
| 114 | for _, config := range serverConfigs { |
| 115 | go func(cfg ServerConfig) { |
| 116 | slog.InfoContext(ctx, "Connecting to MCP server", "server", cfg.Name, "type", cfg.Type, "url", cfg.URL, "command", cfg.Command) |
| 117 | tools, originalToolNames, err := m.connectToServerWithNames(ctxWithTimeout, cfg) |
| 118 | results <- result{ |
| 119 | tools: tools, |
| 120 | err: err, |
| 121 | serverName: cfg.Name, |
| 122 | originalTools: originalToolNames, |
| 123 | } |
| 124 | }(config) |
| 125 | } |
| 126 | |
| 127 | // Collect results |
| 128 | var connections []MCPServerConnection |
| 129 | errors := make([]error, 0, len(existingErrors)) |
| 130 | errors = append(errors, existingErrors...) |
| 131 | |
| 132 | for range len(serverConfigs) { |
| 133 | select { |
| 134 | case res := <-results: |
| 135 | if res.err != nil { |
| 136 | slog.ErrorContext(ctx, "Failed to connect to MCP server", "server", res.serverName, "error", res.err) |
| 137 | errors = append(errors, fmt.Errorf("MCP server %q: %w", res.serverName, res.err)) |
| 138 | } else { |
| 139 | connection := MCPServerConnection{ |
| 140 | ServerName: res.serverName, |
| 141 | Tools: res.tools, |
| 142 | ToolNames: res.originalTools, |
| 143 | } |
| 144 | connections = append(connections, connection) |
| 145 | slog.InfoContext(ctx, "Successfully connected to MCP server", "server", res.serverName, "tools", len(res.tools), "tool_names", res.originalTools) |
| 146 | } |
| 147 | case <-ctxWithTimeout.Done(): |
| 148 | errors = append(errors, fmt.Errorf("timeout connecting to MCP servers")) |
| 149 | break |
| 150 | } |
| 151 | } |
| 152 | |
| 153 | return connections, errors |
| 154 | } |
| 155 | |
| 156 | // connectToServerWithNames connects to a single MCP server and returns tools with original names |
| 157 | func (m *MCPManager) connectToServerWithNames(ctx context.Context, config ServerConfig) ([]*llm.Tool, []string, error) { |
| 158 | tools, err := m.connectToServer(ctx, config) |
| 159 | if err != nil { |
| 160 | return nil, nil, err |
| 161 | } |
| 162 | |
| 163 | // Extract original tool names (remove server prefix) |
| 164 | originalNames := make([]string, len(tools)) |
| 165 | for i, tool := range tools { |
| 166 | // Tool names are in format "servername_toolname" |
| 167 | parts := strings.SplitN(tool.Name, "_", 2) |
| 168 | if len(parts) == 2 { |
| 169 | originalNames[i] = parts[1] |
| 170 | } else { |
| 171 | originalNames[i] = tool.Name // fallback if no prefix |
| 172 | } |
| 173 | } |
| 174 | |
| 175 | return tools, originalNames, nil |
| 176 | } |
| 177 | |
| 178 | // connectToServer connects to a single MCP server |
| 179 | func (m *MCPManager) connectToServer(ctx context.Context, config ServerConfig) ([]*llm.Tool, error) { |
| 180 | var mcpClient *client.Client |
| 181 | var err error |
| 182 | |
| 183 | // Convert environment variables to []string format |
| 184 | var envVars []string |
| 185 | for k, v := range config.Env { |
| 186 | envVars = append(envVars, k+"="+v) |
| 187 | } |
| 188 | |
| 189 | switch config.Type { |
| 190 | case "stdio", "": |
| 191 | if config.Command == "" { |
| 192 | return nil, fmt.Errorf("command is required for stdio transport") |
| 193 | } |
| 194 | mcpClient, err = client.NewStdioMCPClient(config.Command, envVars, config.Args...) |
| 195 | case "http": |
| 196 | if config.URL == "" { |
| 197 | return nil, fmt.Errorf("URL is required for HTTP transport") |
| 198 | } |
| 199 | // Use streamable HTTP client for HTTP transport |
| 200 | var httpOptions []transport.StreamableHTTPCOption |
| 201 | if len(config.Headers) > 0 { |
| 202 | httpOptions = append(httpOptions, transport.WithHTTPHeaders(config.Headers)) |
| 203 | } |
| 204 | mcpClient, err = client.NewStreamableHttpClient(config.URL, httpOptions...) |
| 205 | case "sse": |
| 206 | if config.URL == "" { |
| 207 | return nil, fmt.Errorf("URL is required for SSE transport") |
| 208 | } |
| 209 | var sseOptions []transport.ClientOption |
| 210 | if len(config.Headers) > 0 { |
| 211 | sseOptions = append(sseOptions, transport.WithHeaders(config.Headers)) |
| 212 | } |
| 213 | mcpClient, err = client.NewSSEMCPClient(config.URL, sseOptions...) |
| 214 | default: |
| 215 | return nil, fmt.Errorf("unsupported MCP transport type: %s", config.Type) |
| 216 | } |
| 217 | |
| 218 | if err != nil { |
| 219 | return nil, fmt.Errorf("failed to create MCP client: %w", err) |
| 220 | } |
| 221 | |
| 222 | // Start the client first |
| 223 | if err := mcpClient.Start(ctx); err != nil { |
| 224 | return nil, fmt.Errorf("failed to start MCP client: %w", err) |
| 225 | } |
| 226 | |
| 227 | // Initialize the client |
| 228 | initReq := mcp.InitializeRequest{ |
| 229 | Params: mcp.InitializeParams{ |
| 230 | ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, |
| 231 | Capabilities: mcp.ClientCapabilities{}, |
| 232 | ClientInfo: mcp.Implementation{ |
| 233 | Name: "sketch", |
| 234 | Version: "1.0.0", |
| 235 | }, |
| 236 | }, |
| 237 | } |
| 238 | if _, err := mcpClient.Initialize(ctx, initReq); err != nil { |
| 239 | return nil, fmt.Errorf("failed to initialize MCP client: %w", err) |
| 240 | } |
| 241 | |
| 242 | // Get available tools |
| 243 | toolsReq := mcp.ListToolsRequest{} |
| 244 | toolsResp, err := mcpClient.ListTools(ctx, toolsReq) |
| 245 | if err != nil { |
| 246 | return nil, fmt.Errorf("failed to list tools: %w", err) |
| 247 | } |
| 248 | |
| 249 | // Convert MCP tools to llm.Tool |
| 250 | llmTools, err := m.convertMCPTools(config.Name, mcpClient, toolsResp.Tools) |
| 251 | if err != nil { |
| 252 | return nil, fmt.Errorf("failed to convert tools: %w", err) |
| 253 | } |
| 254 | |
| 255 | // Store the client |
| 256 | clientWrapper := &MCPClientWrapper{ |
| 257 | name: config.Name, |
| 258 | config: config, |
| 259 | client: mcpClient, |
| 260 | tools: llmTools, |
| 261 | } |
| 262 | |
| 263 | m.mu.Lock() |
| 264 | m.clients[config.Name] = clientWrapper |
| 265 | m.mu.Unlock() |
| 266 | |
| 267 | return llmTools, nil |
| 268 | } |
| 269 | |
| 270 | // convertMCPTools converts MCP tools to llm.Tool format |
| 271 | func (m *MCPManager) convertMCPTools(serverName string, mcpClient *client.Client, mcpTools []mcp.Tool) ([]*llm.Tool, error) { |
| 272 | var llmTools []*llm.Tool |
| 273 | |
| 274 | for _, mcpTool := range mcpTools { |
| 275 | // Convert the input schema |
| 276 | schemaBytes, err := json.Marshal(mcpTool.InputSchema) |
| 277 | if err != nil { |
| 278 | return nil, fmt.Errorf("failed to marshal input schema for tool %s: %w", mcpTool.Name, err) |
| 279 | } |
| 280 | |
| 281 | llmTool := &llm.Tool{ |
| 282 | Name: fmt.Sprintf("%s_%s", serverName, mcpTool.Name), |
| 283 | Description: mcpTool.Description, |
| 284 | InputSchema: json.RawMessage(schemaBytes), |
| 285 | Run: func(toolName string, client *client.Client) func(ctx context.Context, input json.RawMessage) ([]llm.Content, error) { |
| 286 | return func(ctx context.Context, input json.RawMessage) ([]llm.Content, error) { |
| 287 | result, err := m.executeMCPTool(ctx, client, toolName, input) |
| 288 | if err != nil { |
| 289 | return nil, err |
| 290 | } |
| 291 | // Convert result to llm.Content |
| 292 | return []llm.Content{llm.StringContent(fmt.Sprintf("%v", result))}, nil |
| 293 | } |
| 294 | }(mcpTool.Name, mcpClient), |
| 295 | } |
| 296 | |
| 297 | llmTools = append(llmTools, llmTool) |
| 298 | } |
| 299 | |
| 300 | return llmTools, nil |
| 301 | } |
| 302 | |
| 303 | // executeMCPTool executes an MCP tool call |
| 304 | func (m *MCPManager) executeMCPTool(ctx context.Context, mcpClient *client.Client, toolName string, input json.RawMessage) (any, error) { |
| 305 | // Add timeout for tool execution |
| 306 | ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second) |
| 307 | defer cancel() |
| 308 | |
| 309 | // Parse input arguments |
| 310 | var args map[string]any |
| 311 | if len(input) > 0 { |
| 312 | if err := json.Unmarshal(input, &args); err != nil { |
| 313 | return nil, fmt.Errorf("failed to parse tool arguments: %w", err) |
| 314 | } |
| 315 | } |
| 316 | |
| 317 | // Call the MCP tool |
| 318 | req := mcp.CallToolRequest{ |
| 319 | Params: mcp.CallToolParams{ |
| 320 | Name: toolName, |
| 321 | Arguments: args, |
| 322 | }, |
| 323 | } |
| 324 | resp, err := mcpClient.CallTool(ctxWithTimeout, req) |
| 325 | if err != nil { |
| 326 | return nil, fmt.Errorf("MCP tool call failed: %w", err) |
| 327 | } |
| 328 | |
| 329 | // Return the content from the response |
| 330 | return resp.Content, nil |
| 331 | } |
| 332 | |
| 333 | // Close closes all MCP client connections |
| 334 | func (m *MCPManager) Close() { |
| 335 | m.mu.Lock() |
| 336 | defer m.mu.Unlock() |
| 337 | |
| 338 | for _, clientWrapper := range m.clients { |
| 339 | if clientWrapper.client != nil { |
| 340 | clientWrapper.client.Close() |
| 341 | } |
| 342 | } |
| 343 | m.clients = make(map[string]*MCPClientWrapper) |
| 344 | } |