blob: d15ae8409929d7838399ac2e1c4504fd6e4fea42 [file] [log] [blame]
Philip Zeyliger194bfa82025-06-24 06:03:06 -07001package mcp
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "strings"
9 "sync"
10 "time"
11
12 "github.com/mark3labs/mcp-go/client"
13 "github.com/mark3labs/mcp-go/client/transport"
14 "github.com/mark3labs/mcp-go/mcp"
15 "sketch.dev/llm"
16)
17
18// ServerConfig represents the configuration for an MCP server
19type ServerConfig struct {
20 Name string `json:"name,omitempty"`
21 Type string `json:"type,omitempty"` // "stdio", "http", "sse"
22 URL string `json:"url,omitempty"` // for http/sse
23 Command string `json:"command,omitempty"` // for stdio
24 Args []string `json:"args,omitempty"` // for stdio
25 Env map[string]string `json:"env,omitempty"` // for stdio
26 Headers map[string]string `json:"headers,omitempty"` // for http/sse
27}
28
29// MCPManager manages multiple MCP server connections
30type MCPManager struct {
31 mu sync.RWMutex
32 clients map[string]*MCPClientWrapper
33}
34
35// MCPClientWrapper wraps an MCP client connection
36type MCPClientWrapper struct {
37 name string
38 config ServerConfig
39 client *client.Client
40 tools []*llm.Tool
41}
42
43// MCPServerConnection represents a successful MCP server connection with its tools
44type MCPServerConnection struct {
45 ServerName string
46 Tools []*llm.Tool
47 ToolNames []string // Original tool names without server prefix
48}
49
50// NewMCPManager creates a new MCP manager
51func NewMCPManager() *MCPManager {
52 return &MCPManager{
53 clients: make(map[string]*MCPClientWrapper),
54 }
55}
56
// ParseServerConfigs parses JSON-encoded MCP server configurations.
//
// Each element of configs must be a JSON object matching ServerConfig. An entry
// is skipped, and an error prefixed with its index ("config %d: ...") is
// recorded, when it fails to decode, has an empty "name", or reuses a name
// already taken by an earlier entry. Names must be unique because they key the
// manager's client map and prefix every exported tool name; a duplicate would
// replace (and leak) the earlier connection and yield colliding tool names.
//
// It returns the valid configs in input order and the collected errors, or
// (nil, nil) when configs is empty.
func ParseServerConfigs(ctx context.Context, configs []string) ([]ServerConfig, []error) {
	if len(configs) == 0 {
		return nil, nil
	}

	var serverConfigs []ServerConfig
	var errs []error
	firstIndex := make(map[string]int, len(configs)) // server name -> index of the config that claimed it

	for i, configStr := range configs {
		var config ServerConfig
		if err := json.Unmarshal([]byte(configStr), &config); err != nil {
			slog.ErrorContext(ctx, "Failed to parse MCP server config", "config", configStr, "error", err)
			errs = append(errs, fmt.Errorf("config %d: %w", i, err))
			continue
		}
		if config.Name == "" {
			errs = append(errs, fmt.Errorf("config %d: name is required", i))
			continue
		}
		if prev, dup := firstIndex[config.Name]; dup {
			errs = append(errs, fmt.Errorf("config %d: duplicate name %q (already used by config %d)", i, config.Name, prev))
			continue
		}
		firstIndex[config.Name] = i
		serverConfigs = append(serverConfigs, config)
	}

	return serverConfigs, errs
}
83
// ConnectToServerConfigs connects to every server in serverConfigs concurrently
// and waits for the results, bounded by timeout (and by ctx).
//
// It returns one MCPServerConnection per server that connected successfully, in
// completion order. The returned error slice starts with existingErrors and then
// gets one entry per server that failed. If the deadline passes before every
// server has reported, collection stops: a single timeout error naming the
// servers still outstanding (and wrapping the context error) is appended, and
// those servers are absent from the returned connections.
func (m *MCPManager) ConnectToServerConfigs(ctx context.Context, serverConfigs []ServerConfig, timeout time.Duration, existingErrors []error) ([]MCPServerConnection, []error) {
	if len(serverConfigs) == 0 {
		return nil, existingErrors
	}

	slog.InfoContext(ctx, "Connecting to MCP servers", "count", len(serverConfigs), "timeout", timeout)

	// One goroutine per server; each sends exactly one result.
	type result struct {
		tools         []*llm.Tool
		err           error
		serverName    string
		originalTools []string // Original tool names without server prefix
	}

	// Buffered to len(serverConfigs) so goroutines that finish after we stop
	// collecting can still send and exit instead of blocking forever.
	results := make(chan result, len(serverConfigs))
	ctxWithTimeout, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	// pending counts outstanding attempts per server name so a timeout can
	// report which servers never answered. A count rather than a set stays
	// correct even if two configs share a name.
	pending := make(map[string]int, len(serverConfigs))
	for _, config := range serverConfigs {
		pending[config.Name]++
		go func(cfg ServerConfig) {
			slog.InfoContext(ctx, "Connecting to MCP server", "server", cfg.Name, "type", cfg.Type, "url", cfg.URL, "command", cfg.Command)
			tools, originalToolNames, err := m.connectToServerWithNames(ctxWithTimeout, cfg)
			results <- result{
				tools:         tools,
				err:           err,
				serverName:    cfg.Name,
				originalTools: originalToolNames,
			}
		}(config)
	}

	var connections []MCPServerConnection
	errs := make([]error, 0, len(existingErrors)+len(serverConfigs))
	errs = append(errs, existingErrors...)

NextServer:
	for range len(serverConfigs) {
		select {
		case res := <-results:
			pending[res.serverName]--
			if res.err != nil {
				slog.ErrorContext(ctx, "Failed to connect to MCP server", "server", res.serverName, "error", res.err)
				errs = append(errs, fmt.Errorf("MCP server %q: %w", res.serverName, res.err))
				continue
			}
			connections = append(connections, MCPServerConnection{
				ServerName: res.serverName,
				Tools:      res.tools,
				ToolNames:  res.originalTools,
			})
			slog.InfoContext(ctx, "Successfully connected to MCP server", "server", res.serverName, "tools", len(res.tools), "tool_names", res.originalTools)
		case <-ctxWithTimeout.Done():
			// List the servers that have not reported, in config order.
			var waiting []string
			for _, cfg := range serverConfigs {
				if pending[cfg.Name] > 0 {
					pending[cfg.Name]--
					waiting = append(waiting, cfg.Name)
				}
			}
			errs = append(errs, fmt.Errorf("timeout connecting to MCP servers (no response from %s): %w", strings.Join(waiting, ", "), ctxWithTimeout.Err()))
			break NextServer
		}
	}

	return connections, errs
}
146
// connectToServerWithNames connects to a single MCP server via connectToServer
// and also returns, for each tool, the name the server itself uses for it:
// the exported tool name with the "<config.Name>_" prefix removed.
func (m *MCPManager) connectToServerWithNames(ctx context.Context, config ServerConfig) ([]*llm.Tool, []string, error) {
	tools, err := m.connectToServer(ctx, config)
	if err != nil {
		return nil, nil, err
	}

	// Strip exactly the prefix the tools were named with. Splitting on the
	// first "_" is wrong whenever the server name itself contains an
	// underscore ("my_server_search" would become "server_search"). Names
	// without the prefix are returned unchanged.
	prefix := config.Name + "_"
	originalNames := make([]string, len(tools))
	for i, tool := range tools {
		originalNames[i] = strings.TrimPrefix(tool.Name, prefix)
	}

	return tools, originalNames, nil
}
168
// connectToServer connects to a single MCP server. It creates a client for
// config, starts and initializes it, lists the server's tools, converts them
// to llm.Tools named "<config.Name>_<tool name>", and registers the client in
// m.clients under config.Name.
//
// If any step after the client is created fails, the client is closed before
// the error is returned, so a failed connection does not leave a transport
// (for example a stdio subprocess) running.
func (m *MCPManager) connectToServer(ctx context.Context, config ServerConfig) ([]*llm.Tool, error) {
	mcpClient, err := newClientForConfig(config)
	if err != nil {
		return nil, err
	}

	// Until the client is handed to m.clients, this function owns it and must
	// close it on every failure path.
	registered := false
	defer func() {
		if !registered {
			// Best-effort cleanup; the connection error is what the caller needs.
			_ = mcpClient.Close()
		}
	}()

	// Start the client first
	if err := mcpClient.Start(ctx); err != nil {
		return nil, fmt.Errorf("failed to start MCP client: %w", err)
	}

	initReq := mcp.InitializeRequest{
		Params: mcp.InitializeParams{
			ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
			Capabilities:    mcp.ClientCapabilities{},
			ClientInfo: mcp.Implementation{
				Name:    "sketch",
				Version: "1.0.0",
			},
		},
	}
	if _, err := mcpClient.Initialize(ctx, initReq); err != nil {
		return nil, fmt.Errorf("failed to initialize MCP client: %w", err)
	}

	toolsResp, err := mcpClient.ListTools(ctx, mcp.ListToolsRequest{})
	if err != nil {
		return nil, fmt.Errorf("failed to list tools: %w", err)
	}

	llmTools, err := m.convertMCPTools(config.Name, mcpClient, toolsResp.Tools)
	if err != nil {
		return nil, fmt.Errorf("failed to convert tools: %w", err)
	}

	clientWrapper := &MCPClientWrapper{
		name:   config.Name,
		config: config,
		client: mcpClient,
		tools:  llmTools,
	}

	m.mu.Lock()
	m.clients[config.Name] = clientWrapper
	m.mu.Unlock()
	registered = true

	return llmTools, nil
}

// newClientForConfig constructs an MCP client for the transport selected by
// config.Type ("stdio" when empty, "http", or "sse"). It validates that the
// fields the chosen transport needs are present; it does not call Start.
func newClientForConfig(config ServerConfig) (*client.Client, error) {
	var (
		mcpClient *client.Client
		err       error
	)

	switch config.Type {
	case "stdio", "":
		if config.Command == "" {
			return nil, fmt.Errorf("command is required for stdio transport")
		}
		// The stdio transport takes the environment as "KEY=value" entries.
		var envVars []string
		for k, v := range config.Env {
			envVars = append(envVars, k+"="+v)
		}
		mcpClient, err = client.NewStdioMCPClient(config.Command, envVars, config.Args...)
		// TODO: Get the transport, cast it to *transport.Stdio, and start a goroutine to pipe stderr from the subprocess
		// to our subprocess, but with each line prefixed with the server name.
	case "http":
		if config.URL == "" {
			return nil, fmt.Errorf("URL is required for HTTP transport")
		}
		// Use streamable HTTP client for HTTP transport
		var httpOptions []transport.StreamableHTTPCOption
		if len(config.Headers) > 0 {
			httpOptions = append(httpOptions, transport.WithHTTPHeaders(config.Headers))
		}
		mcpClient, err = client.NewStreamableHttpClient(config.URL, httpOptions...)
	case "sse":
		if config.URL == "" {
			return nil, fmt.Errorf("URL is required for SSE transport")
		}
		var sseOptions []transport.ClientOption
		if len(config.Headers) > 0 {
			sseOptions = append(sseOptions, transport.WithHeaders(config.Headers))
		}
		mcpClient, err = client.NewSSEMCPClient(config.URL, sseOptions...)
	default:
		return nil, fmt.Errorf("unsupported MCP transport type: %s", config.Type)
	}

	if err != nil {
		return nil, fmt.Errorf("failed to create MCP client: %w", err)
	}
	return mcpClient, nil
}
262
// convertMCPTools wraps each tool advertised by an MCP server as an llm.Tool.
//
// Exported names are "<serverName>_<tool name>" so tools from different
// servers cannot collide. Each tool's Run forwards its JSON input to the tool
// on mcpClient and returns the tool's output rendered as text; call failures
// (including results the server flags as errors) become llm.ErrorToolOut.
// An error is returned only if a tool's input schema cannot be re-marshaled.
func (m *MCPManager) convertMCPTools(serverName string, mcpClient *client.Client, mcpTools []mcp.Tool) ([]*llm.Tool, error) {
	var llmTools []*llm.Tool

	for _, mcpTool := range mcpTools {
		schemaBytes, err := json.Marshal(mcpTool.InputSchema)
		if err != nil {
			return nil, fmt.Errorf("failed to marshal input schema for tool %s: %w", mcpTool.Name, err)
		}

		// Declared inside the loop body so each Run closure captures its own name.
		toolName := mcpTool.Name
		llmTools = append(llmTools, &llm.Tool{
			Name:        fmt.Sprintf("%s_%s", serverName, toolName),
			Description: mcpTool.Description,
			InputSchema: json.RawMessage(schemaBytes),
			Run: func(ctx context.Context, input json.RawMessage) llm.ToolOut {
				result, err := m.executeMCPTool(ctx, mcpClient, toolName, input)
				if err != nil {
					return llm.ErrorToolOut(err)
				}
				return llm.ToolOut{LLMContent: []llm.Content{llm.StringContent(mcpContentText(result))}}
			},
		})
	}

	return llmTools, nil
}

// mcpContentText renders an MCP tool result as plain text for the model.
// For a []mcp.Content, text items contribute their text, and any other kind
// (image, embedded resource, ...) contributes its JSON encoding (or its %v
// form if that fails); items are separated by newlines. Any other value is
// formatted with %v.
func mcpContentText(result any) string {
	contents, ok := result.([]mcp.Content)
	if !ok {
		return fmt.Sprintf("%v", result)
	}
	parts := make([]string, 0, len(contents))
	for _, c := range contents {
		switch tc := c.(type) {
		case mcp.TextContent:
			parts = append(parts, tc.Text)
		case *mcp.TextContent:
			parts = append(parts, tc.Text)
		default:
			if b, err := json.Marshal(c); err == nil {
				parts = append(parts, string(b))
			} else {
				parts = append(parts, fmt.Sprintf("%v", c))
			}
		}
	}
	return strings.Join(parts, "\n")
}

// executeMCPTool calls toolName on mcpClient, allowing at most 120 seconds.
// input must be a JSON object (or empty for no arguments).
//
// On success it returns the result's content (a []mcp.Content). Malformed
// input, a failed call, or a result the server flagged with isError is
// returned as an error; in the last case the error includes the tool's output
// so the model can see what went wrong.
func (m *MCPManager) executeMCPTool(ctx context.Context, mcpClient *client.Client, toolName string, input json.RawMessage) (any, error) {
	// TODO: Expose the timeout as a tool call argument.
	ctxWithTimeout, cancel := context.WithTimeout(ctx, 120*time.Second)
	defer cancel()

	var args map[string]any
	if len(input) > 0 {
		if err := json.Unmarshal(input, &args); err != nil {
			return nil, fmt.Errorf("failed to parse tool arguments: %w", err)
		}
	}

	req := mcp.CallToolRequest{
		Params: mcp.CallToolParams{
			Name:      toolName,
			Arguments: args,
		},
	}
	resp, err := mcpClient.CallTool(ctxWithTimeout, req)
	if err != nil {
		return nil, fmt.Errorf("MCP tool call failed: %w", err)
	}
	// MCP reports tool-level failures in-band (isError) rather than as protocol
	// errors; surface them as errors so they are not presented as success.
	if resp.IsError {
		return nil, fmt.Errorf("MCP tool %q returned an error: %s", toolName, mcpContentText(resp.Content))
	}

	return resp.Content, nil
}
326
// Close closes every MCP client connection held by the manager and forgets
// them; the manager can still accept new connections afterwards. Shutdown is
// best-effort: close failures are logged, not returned.
func (m *MCPManager) Close() {
	// Detach the clients under the lock, then close them without holding it,
	// so a slow transport shutdown does not block other users of m.mu.
	m.mu.Lock()
	clients := m.clients
	m.clients = make(map[string]*MCPClientWrapper)
	m.mu.Unlock()

	for name, clientWrapper := range clients {
		if clientWrapper.client == nil {
			continue
		}
		if err := clientWrapper.client.Close(); err != nil {
			slog.Warn("Failed to close MCP client", "server", name, "error", err)
		}
	}
}