blob: 22839ca8cdaa69b712af2f1438c1c5614c42942c [file] [log] [blame]
Philip Zeyliger194bfa82025-06-24 06:03:06 -07001package mcp
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "strings"
9 "sync"
10 "time"
11
12 "github.com/mark3labs/mcp-go/client"
13 "github.com/mark3labs/mcp-go/client/transport"
14 "github.com/mark3labs/mcp-go/mcp"
15 "sketch.dev/llm"
16)
17
// Default timeouts for MCP server interactions. Both are two minutes.
const (
	// DefaultMCPConnectionTimeout is the default timeout for connecting to MCP servers.
	DefaultMCPConnectionTimeout = 2 * time.Minute

	// DefaultMCPToolTimeout is the default upper bound on a single MCP tool call.
	DefaultMCPToolTimeout = 2 * time.Minute
)
25
// ServerConfig describes how to reach one MCP server. It is decoded from a
// JSON object; every field is optional in JSON and omitted when empty on
// encoding.
type ServerConfig struct {
	// Name identifies the server; ParseServerConfigs rejects configs without one.
	Name string `json:"name,omitempty"`
	// Type selects the transport: "stdio", "http", or "sse" (empty means stdio).
	Type string `json:"type,omitempty"`

	// URL is the endpoint for the "http" and "sse" transports.
	URL string `json:"url,omitempty"`
	// Command is the executable launched by the "stdio" transport.
	Command string `json:"command,omitempty"`
	// Args are the command-line arguments for Command (stdio only).
	Args []string `json:"args,omitempty"`
	// Env holds extra KEY=VALUE environment entries for Command (stdio only).
	Env map[string]string `json:"env,omitempty"`
	// Headers are extra request headers (http/sse only).
	Headers map[string]string `json:"headers,omitempty"`
}
36
37// MCPManager manages multiple MCP server connections
38type MCPManager struct {
39 mu sync.RWMutex
40 clients map[string]*MCPClientWrapper
41}
42
43// MCPClientWrapper wraps an MCP client connection
44type MCPClientWrapper struct {
45 name string
46 config ServerConfig
47 client *client.Client
48 tools []*llm.Tool
49}
50
51// MCPServerConnection represents a successful MCP server connection with its tools
52type MCPServerConnection struct {
53 ServerName string
54 Tools []*llm.Tool
55 ToolNames []string // Original tool names without server prefix
56}
57
58// NewMCPManager creates a new MCP manager
59func NewMCPManager() *MCPManager {
60 return &MCPManager{
61 clients: make(map[string]*MCPClientWrapper),
62 }
63}
64
65// ParseServerConfigs parses JSON configuration strings into ServerConfig structs
66func ParseServerConfigs(ctx context.Context, configs []string) ([]ServerConfig, []error) {
67 if len(configs) == 0 {
68 return nil, nil
69 }
70
71 var serverConfigs []ServerConfig
72 var errors []error
73
74 for i, configStr := range configs {
75 var config ServerConfig
76 if err := json.Unmarshal([]byte(configStr), &config); err != nil {
77 slog.ErrorContext(ctx, "Failed to parse MCP server config", "config", configStr, "error", err)
78 errors = append(errors, fmt.Errorf("config %d: %w", i, err))
79 continue
80 }
81 // Require a name
82 if config.Name == "" {
83 errors = append(errors, fmt.Errorf("config %d: name is required", i))
84 continue
85 }
86 serverConfigs = append(serverConfigs, config)
87 }
88
89 return serverConfigs, errors
90}
91
Philip Zeyliger4201bde2025-06-27 17:22:43 -070092// ConnectToServerConfigs connects to multiple parsed MCP server configs in parallel
Philip Zeyliger194bfa82025-06-24 06:03:06 -070093func (m *MCPManager) ConnectToServerConfigs(ctx context.Context, serverConfigs []ServerConfig, timeout time.Duration, existingErrors []error) ([]MCPServerConnection, []error) {
94 if len(serverConfigs) == 0 {
95 return nil, existingErrors
96 }
97
98 slog.InfoContext(ctx, "Connecting to MCP servers", "count", len(serverConfigs), "timeout", timeout)
99
100 // Connect to servers in parallel using sync.WaitGroup
101 type result struct {
102 tools []*llm.Tool
103 err error
104 serverName string
105 originalTools []string // Original tool names without server prefix
106 }
107
108 results := make(chan result, len(serverConfigs))
Philip Zeyligera8ac1502025-07-27 21:24:42 -0700109 // Create a timeout context only for the connection establishment goroutines
110 connectionCtx, connectionCancel := context.WithTimeout(context.Background(), timeout)
111 defer connectionCancel()
Philip Zeyliger194bfa82025-06-24 06:03:06 -0700112
113 for _, config := range serverConfigs {
114 go func(cfg ServerConfig) {
115 slog.InfoContext(ctx, "Connecting to MCP server", "server", cfg.Name, "type", cfg.Type, "url", cfg.URL, "command", cfg.Command)
Philip Zeyligera8ac1502025-07-27 21:24:42 -0700116 // Pass both the long-running context (ctx) and the connection timeout context
117 tools, originalToolNames, err := m.connectToServerWithNames(ctx, connectionCtx, cfg)
Philip Zeyliger194bfa82025-06-24 06:03:06 -0700118 results <- result{
119 tools: tools,
120 err: err,
121 serverName: cfg.Name,
122 originalTools: originalToolNames,
123 }
124 }(config)
125 }
126
127 // Collect results
128 var connections []MCPServerConnection
129 errors := make([]error, 0, len(existingErrors))
130 errors = append(errors, existingErrors...)
131
Josh Bleecher Snyder44de46c2025-07-21 16:14:34 -0700132NextServer:
Philip Zeyliger194bfa82025-06-24 06:03:06 -0700133 for range len(serverConfigs) {
134 select {
135 case res := <-results:
136 if res.err != nil {
137 slog.ErrorContext(ctx, "Failed to connect to MCP server", "server", res.serverName, "error", res.err)
138 errors = append(errors, fmt.Errorf("MCP server %q: %w", res.serverName, res.err))
139 } else {
140 connection := MCPServerConnection{
141 ServerName: res.serverName,
142 Tools: res.tools,
143 ToolNames: res.originalTools,
144 }
145 connections = append(connections, connection)
146 slog.InfoContext(ctx, "Successfully connected to MCP server", "server", res.serverName, "tools", len(res.tools), "tool_names", res.originalTools)
147 }
Philip Zeyligera8ac1502025-07-27 21:24:42 -0700148 case <-connectionCtx.Done():
Philip Zeyliger194bfa82025-06-24 06:03:06 -0700149 errors = append(errors, fmt.Errorf("timeout connecting to MCP servers"))
Josh Bleecher Snyder44de46c2025-07-21 16:14:34 -0700150 break NextServer
Philip Zeyliger194bfa82025-06-24 06:03:06 -0700151 }
152 }
153
154 return connections, errors
155}
156
157// connectToServerWithNames connects to a single MCP server and returns tools with original names
Philip Zeyligera8ac1502025-07-27 21:24:42 -0700158func (m *MCPManager) connectToServerWithNames(longRunningCtx context.Context, connectionCtx context.Context, config ServerConfig) ([]*llm.Tool, []string, error) {
159 tools, err := m.connectToServer(longRunningCtx, connectionCtx, config)
Philip Zeyliger194bfa82025-06-24 06:03:06 -0700160 if err != nil {
161 return nil, nil, err
162 }
163
164 // Extract original tool names (remove server prefix)
165 originalNames := make([]string, len(tools))
166 for i, tool := range tools {
167 // Tool names are in format "servername_toolname"
168 parts := strings.SplitN(tool.Name, "_", 2)
169 if len(parts) == 2 {
170 originalNames[i] = parts[1]
171 } else {
172 originalNames[i] = tool.Name // fallback if no prefix
173 }
174 }
175
176 return tools, originalNames, nil
177}
178
179// connectToServer connects to a single MCP server
Philip Zeyligera8ac1502025-07-27 21:24:42 -0700180// longRunningCtx: context for the ongoing MCP client lifecycle (SSE streams)
181// connectionCtx: context with timeout for connection establishment only
182func (m *MCPManager) connectToServer(longRunningCtx context.Context, connectionCtx context.Context, config ServerConfig) ([]*llm.Tool, error) {
Philip Zeyliger194bfa82025-06-24 06:03:06 -0700183 var mcpClient *client.Client
184 var err error
185
186 // Convert environment variables to []string format
187 var envVars []string
188 for k, v := range config.Env {
189 envVars = append(envVars, k+"="+v)
190 }
191
192 switch config.Type {
193 case "stdio", "":
194 if config.Command == "" {
195 return nil, fmt.Errorf("command is required for stdio transport")
196 }
197 mcpClient, err = client.NewStdioMCPClient(config.Command, envVars, config.Args...)
Philip Zeyliger08b073b2025-07-18 10:40:00 -0700198 // TODO: Get the transport, cast it to *transport.Stdio, and start a goroutine to pipe stderr from the subprocess
199 // to our subprocess, but with each line prefixed with the server name.
Philip Zeyliger194bfa82025-06-24 06:03:06 -0700200 case "http":
201 if config.URL == "" {
202 return nil, fmt.Errorf("URL is required for HTTP transport")
203 }
204 // Use streamable HTTP client for HTTP transport
205 var httpOptions []transport.StreamableHTTPCOption
206 if len(config.Headers) > 0 {
207 httpOptions = append(httpOptions, transport.WithHTTPHeaders(config.Headers))
208 }
209 mcpClient, err = client.NewStreamableHttpClient(config.URL, httpOptions...)
210 case "sse":
211 if config.URL == "" {
212 return nil, fmt.Errorf("URL is required for SSE transport")
213 }
214 var sseOptions []transport.ClientOption
215 if len(config.Headers) > 0 {
216 sseOptions = append(sseOptions, transport.WithHeaders(config.Headers))
217 }
218 mcpClient, err = client.NewSSEMCPClient(config.URL, sseOptions...)
219 default:
220 return nil, fmt.Errorf("unsupported MCP transport type: %s", config.Type)
221 }
222
223 if err != nil {
224 return nil, fmt.Errorf("failed to create MCP client: %w", err)
225 }
226
Philip Zeyligera8ac1502025-07-27 21:24:42 -0700227 // Start the client with the long-running context for SSE streams
228 if err := mcpClient.Start(longRunningCtx); err != nil {
Philip Zeyliger194bfa82025-06-24 06:03:06 -0700229 return nil, fmt.Errorf("failed to start MCP client: %w", err)
230 }
231
Philip Zeyligera8ac1502025-07-27 21:24:42 -0700232 // Initialize the client with connection timeout context
Philip Zeyliger194bfa82025-06-24 06:03:06 -0700233 initReq := mcp.InitializeRequest{
234 Params: mcp.InitializeParams{
235 ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
236 Capabilities: mcp.ClientCapabilities{},
237 ClientInfo: mcp.Implementation{
238 Name: "sketch",
239 Version: "1.0.0",
240 },
241 },
242 }
Philip Zeyligera8ac1502025-07-27 21:24:42 -0700243 if _, err := mcpClient.Initialize(connectionCtx, initReq); err != nil {
Philip Zeyliger194bfa82025-06-24 06:03:06 -0700244 return nil, fmt.Errorf("failed to initialize MCP client: %w", err)
245 }
246
Philip Zeyligera8ac1502025-07-27 21:24:42 -0700247 // Get available tools with connection timeout context
Philip Zeyliger194bfa82025-06-24 06:03:06 -0700248 toolsReq := mcp.ListToolsRequest{}
Philip Zeyligera8ac1502025-07-27 21:24:42 -0700249 toolsResp, err := mcpClient.ListTools(connectionCtx, toolsReq)
Philip Zeyliger194bfa82025-06-24 06:03:06 -0700250 if err != nil {
251 return nil, fmt.Errorf("failed to list tools: %w", err)
252 }
253
254 // Convert MCP tools to llm.Tool
255 llmTools, err := m.convertMCPTools(config.Name, mcpClient, toolsResp.Tools)
256 if err != nil {
257 return nil, fmt.Errorf("failed to convert tools: %w", err)
258 }
259
260 // Store the client
261 clientWrapper := &MCPClientWrapper{
262 name: config.Name,
263 config: config,
264 client: mcpClient,
265 tools: llmTools,
266 }
267
268 m.mu.Lock()
269 m.clients[config.Name] = clientWrapper
270 m.mu.Unlock()
271
272 return llmTools, nil
273}
274
275// convertMCPTools converts MCP tools to llm.Tool format
276func (m *MCPManager) convertMCPTools(serverName string, mcpClient *client.Client, mcpTools []mcp.Tool) ([]*llm.Tool, error) {
277 var llmTools []*llm.Tool
278
279 for _, mcpTool := range mcpTools {
280 // Convert the input schema
281 schemaBytes, err := json.Marshal(mcpTool.InputSchema)
282 if err != nil {
283 return nil, fmt.Errorf("failed to marshal input schema for tool %s: %w", mcpTool.Name, err)
284 }
285
286 llmTool := &llm.Tool{
287 Name: fmt.Sprintf("%s_%s", serverName, mcpTool.Name),
288 Description: mcpTool.Description,
289 InputSchema: json.RawMessage(schemaBytes),
Josh Bleecher Snyder43b60b92025-07-21 14:57:10 -0700290 Run: func(toolName string, client *client.Client) func(ctx context.Context, input json.RawMessage) llm.ToolOut {
291 return func(ctx context.Context, input json.RawMessage) llm.ToolOut {
Philip Zeyliger194bfa82025-06-24 06:03:06 -0700292 result, err := m.executeMCPTool(ctx, client, toolName, input)
293 if err != nil {
Josh Bleecher Snyder43b60b92025-07-21 14:57:10 -0700294 return llm.ErrorToolOut(err)
Philip Zeyliger194bfa82025-06-24 06:03:06 -0700295 }
296 // Convert result to llm.Content
Josh Bleecher Snyder43b60b92025-07-21 14:57:10 -0700297 return llm.ToolOut{LLMContent: []llm.Content{llm.StringContent(fmt.Sprintf("%v", result))}}
Philip Zeyliger194bfa82025-06-24 06:03:06 -0700298 }
299 }(mcpTool.Name, mcpClient),
300 }
301
302 llmTools = append(llmTools, llmTool)
303 }
304
305 return llmTools, nil
306}
307
308// executeMCPTool executes an MCP tool call
309func (m *MCPManager) executeMCPTool(ctx context.Context, mcpClient *client.Client, toolName string, input json.RawMessage) (any, error) {
310 // Add timeout for tool execution
Philip Zeyliger08b073b2025-07-18 10:40:00 -0700311 // TODO: Expose the timeout as a tool call argument.
Philip Zeyligerc540df72025-07-25 09:21:56 -0700312 ctxWithTimeout, cancel := context.WithTimeout(ctx, DefaultMCPToolTimeout)
Philip Zeyliger194bfa82025-06-24 06:03:06 -0700313 defer cancel()
314
315 // Parse input arguments
316 var args map[string]any
317 if len(input) > 0 {
318 if err := json.Unmarshal(input, &args); err != nil {
319 return nil, fmt.Errorf("failed to parse tool arguments: %w", err)
320 }
321 }
322
323 // Call the MCP tool
324 req := mcp.CallToolRequest{
325 Params: mcp.CallToolParams{
326 Name: toolName,
327 Arguments: args,
328 },
329 }
330 resp, err := mcpClient.CallTool(ctxWithTimeout, req)
331 if err != nil {
332 return nil, fmt.Errorf("MCP tool call failed: %w", err)
333 }
334
335 // Return the content from the response
336 return resp.Content, nil
337}
338
339// Close closes all MCP client connections
340func (m *MCPManager) Close() {
341 m.mu.Lock()
342 defer m.mu.Unlock()
343
344 for _, clientWrapper := range m.clients {
345 if clientWrapper.client != nil {
346 clientWrapper.client.Close()
347 }
348 }
349 m.clients = make(map[string]*MCPClientWrapper)
350}