blob: d15ae8409929d7838399ac2e1c4504fd6e4fea42 [file] [log] [blame]
package mcp
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"sync"
"time"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
"sketch.dev/llm"
)
// ServerConfig represents the configuration for an MCP server
type ServerConfig struct {
Name string `json:"name,omitempty"`
Type string `json:"type,omitempty"` // "stdio", "http", "sse"
URL string `json:"url,omitempty"` // for http/sse
Command string `json:"command,omitempty"` // for stdio
Args []string `json:"args,omitempty"` // for stdio
Env map[string]string `json:"env,omitempty"` // for stdio
Headers map[string]string `json:"headers,omitempty"` // for http/sse
}
// MCPManager manages multiple MCP server connections
type MCPManager struct {
mu sync.RWMutex
clients map[string]*MCPClientWrapper
}
// MCPClientWrapper wraps an MCP client connection
type MCPClientWrapper struct {
name string
config ServerConfig
client *client.Client
tools []*llm.Tool
}
// MCPServerConnection represents a successful MCP server connection with its tools
type MCPServerConnection struct {
ServerName string
Tools []*llm.Tool
ToolNames []string // Original tool names without server prefix
}
// NewMCPManager creates a new MCP manager
func NewMCPManager() *MCPManager {
return &MCPManager{
clients: make(map[string]*MCPClientWrapper),
}
}
// ParseServerConfigs parses JSON configuration strings into ServerConfig structs
func ParseServerConfigs(ctx context.Context, configs []string) ([]ServerConfig, []error) {
if len(configs) == 0 {
return nil, nil
}
var serverConfigs []ServerConfig
var errors []error
for i, configStr := range configs {
var config ServerConfig
if err := json.Unmarshal([]byte(configStr), &config); err != nil {
slog.ErrorContext(ctx, "Failed to parse MCP server config", "config", configStr, "error", err)
errors = append(errors, fmt.Errorf("config %d: %w", i, err))
continue
}
// Require a name
if config.Name == "" {
errors = append(errors, fmt.Errorf("config %d: name is required", i))
continue
}
serverConfigs = append(serverConfigs, config)
}
return serverConfigs, errors
}
// ConnectToServerConfigs connects to multiple parsed MCP server configs in parallel
func (m *MCPManager) ConnectToServerConfigs(ctx context.Context, serverConfigs []ServerConfig, timeout time.Duration, existingErrors []error) ([]MCPServerConnection, []error) {
if len(serverConfigs) == 0 {
return nil, existingErrors
}
slog.InfoContext(ctx, "Connecting to MCP servers", "count", len(serverConfigs), "timeout", timeout)
// Connect to servers in parallel using sync.WaitGroup
type result struct {
tools []*llm.Tool
err error
serverName string
originalTools []string // Original tool names without server prefix
}
results := make(chan result, len(serverConfigs))
ctxWithTimeout, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
for _, config := range serverConfigs {
go func(cfg ServerConfig) {
slog.InfoContext(ctx, "Connecting to MCP server", "server", cfg.Name, "type", cfg.Type, "url", cfg.URL, "command", cfg.Command)
tools, originalToolNames, err := m.connectToServerWithNames(ctxWithTimeout, cfg)
results <- result{
tools: tools,
err: err,
serverName: cfg.Name,
originalTools: originalToolNames,
}
}(config)
}
// Collect results
var connections []MCPServerConnection
errors := make([]error, 0, len(existingErrors))
errors = append(errors, existingErrors...)
NextServer:
for range len(serverConfigs) {
select {
case res := <-results:
if res.err != nil {
slog.ErrorContext(ctx, "Failed to connect to MCP server", "server", res.serverName, "error", res.err)
errors = append(errors, fmt.Errorf("MCP server %q: %w", res.serverName, res.err))
} else {
connection := MCPServerConnection{
ServerName: res.serverName,
Tools: res.tools,
ToolNames: res.originalTools,
}
connections = append(connections, connection)
slog.InfoContext(ctx, "Successfully connected to MCP server", "server", res.serverName, "tools", len(res.tools), "tool_names", res.originalTools)
}
case <-ctxWithTimeout.Done():
errors = append(errors, fmt.Errorf("timeout connecting to MCP servers"))
break NextServer
}
}
return connections, errors
}
// connectToServerWithNames connects to a single MCP server and returns tools with original names
func (m *MCPManager) connectToServerWithNames(ctx context.Context, config ServerConfig) ([]*llm.Tool, []string, error) {
tools, err := m.connectToServer(ctx, config)
if err != nil {
return nil, nil, err
}
// Extract original tool names (remove server prefix)
originalNames := make([]string, len(tools))
for i, tool := range tools {
// Tool names are in format "servername_toolname"
parts := strings.SplitN(tool.Name, "_", 2)
if len(parts) == 2 {
originalNames[i] = parts[1]
} else {
originalNames[i] = tool.Name // fallback if no prefix
}
}
return tools, originalNames, nil
}
// connectToServer connects to a single MCP server
func (m *MCPManager) connectToServer(ctx context.Context, config ServerConfig) ([]*llm.Tool, error) {
var mcpClient *client.Client
var err error
// Convert environment variables to []string format
var envVars []string
for k, v := range config.Env {
envVars = append(envVars, k+"="+v)
}
switch config.Type {
case "stdio", "":
if config.Command == "" {
return nil, fmt.Errorf("command is required for stdio transport")
}
mcpClient, err = client.NewStdioMCPClient(config.Command, envVars, config.Args...)
// TODO: Get the transport, cast it to *transport.Stdio, and start a goroutine to pipe stderr from the subprocess
// to our subprocess, but with each line prefixed with the server name.
case "http":
if config.URL == "" {
return nil, fmt.Errorf("URL is required for HTTP transport")
}
// Use streamable HTTP client for HTTP transport
var httpOptions []transport.StreamableHTTPCOption
if len(config.Headers) > 0 {
httpOptions = append(httpOptions, transport.WithHTTPHeaders(config.Headers))
}
mcpClient, err = client.NewStreamableHttpClient(config.URL, httpOptions...)
case "sse":
if config.URL == "" {
return nil, fmt.Errorf("URL is required for SSE transport")
}
var sseOptions []transport.ClientOption
if len(config.Headers) > 0 {
sseOptions = append(sseOptions, transport.WithHeaders(config.Headers))
}
mcpClient, err = client.NewSSEMCPClient(config.URL, sseOptions...)
default:
return nil, fmt.Errorf("unsupported MCP transport type: %s", config.Type)
}
if err != nil {
return nil, fmt.Errorf("failed to create MCP client: %w", err)
}
// Start the client first
if err := mcpClient.Start(ctx); err != nil {
return nil, fmt.Errorf("failed to start MCP client: %w", err)
}
// Initialize the client
initReq := mcp.InitializeRequest{
Params: mcp.InitializeParams{
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
Capabilities: mcp.ClientCapabilities{},
ClientInfo: mcp.Implementation{
Name: "sketch",
Version: "1.0.0",
},
},
}
if _, err := mcpClient.Initialize(ctx, initReq); err != nil {
return nil, fmt.Errorf("failed to initialize MCP client: %w", err)
}
// Get available tools
toolsReq := mcp.ListToolsRequest{}
toolsResp, err := mcpClient.ListTools(ctx, toolsReq)
if err != nil {
return nil, fmt.Errorf("failed to list tools: %w", err)
}
// Convert MCP tools to llm.Tool
llmTools, err := m.convertMCPTools(config.Name, mcpClient, toolsResp.Tools)
if err != nil {
return nil, fmt.Errorf("failed to convert tools: %w", err)
}
// Store the client
clientWrapper := &MCPClientWrapper{
name: config.Name,
config: config,
client: mcpClient,
tools: llmTools,
}
m.mu.Lock()
m.clients[config.Name] = clientWrapper
m.mu.Unlock()
return llmTools, nil
}
// convertMCPTools converts MCP tools to llm.Tool format
func (m *MCPManager) convertMCPTools(serverName string, mcpClient *client.Client, mcpTools []mcp.Tool) ([]*llm.Tool, error) {
var llmTools []*llm.Tool
for _, mcpTool := range mcpTools {
// Convert the input schema
schemaBytes, err := json.Marshal(mcpTool.InputSchema)
if err != nil {
return nil, fmt.Errorf("failed to marshal input schema for tool %s: %w", mcpTool.Name, err)
}
llmTool := &llm.Tool{
Name: fmt.Sprintf("%s_%s", serverName, mcpTool.Name),
Description: mcpTool.Description,
InputSchema: json.RawMessage(schemaBytes),
Run: func(toolName string, client *client.Client) func(ctx context.Context, input json.RawMessage) llm.ToolOut {
return func(ctx context.Context, input json.RawMessage) llm.ToolOut {
result, err := m.executeMCPTool(ctx, client, toolName, input)
if err != nil {
return llm.ErrorToolOut(err)
}
// Convert result to llm.Content
return llm.ToolOut{LLMContent: []llm.Content{llm.StringContent(fmt.Sprintf("%v", result))}}
}
}(mcpTool.Name, mcpClient),
}
llmTools = append(llmTools, llmTool)
}
return llmTools, nil
}
// executeMCPTool executes an MCP tool call
func (m *MCPManager) executeMCPTool(ctx context.Context, mcpClient *client.Client, toolName string, input json.RawMessage) (any, error) {
// Add timeout for tool execution
// TODO: Expose the timeout as a tool call argument.
ctxWithTimeout, cancel := context.WithTimeout(ctx, 120*time.Second)
defer cancel()
// Parse input arguments
var args map[string]any
if len(input) > 0 {
if err := json.Unmarshal(input, &args); err != nil {
return nil, fmt.Errorf("failed to parse tool arguments: %w", err)
}
}
// Call the MCP tool
req := mcp.CallToolRequest{
Params: mcp.CallToolParams{
Name: toolName,
Arguments: args,
},
}
resp, err := mcpClient.CallTool(ctxWithTimeout, req)
if err != nil {
return nil, fmt.Errorf("MCP tool call failed: %w", err)
}
// Return the content from the response
return resp.Content, nil
}
// Close closes all MCP client connections
func (m *MCPManager) Close() {
m.mu.Lock()
defer m.mu.Unlock()
for _, clientWrapper := range m.clients {
if clientWrapper.client != nil {
clientWrapper.client.Close()
}
}
m.clients = make(map[string]*MCPClientWrapper)
}