sketch: add MCP support
Lets you initialize extra tools via MCP. This is additive, so it's
harmless enough.
Here are some examples of the kind of things you can pass to the -mcp
flag.
{"name": "context7", "type": "http", "url": "https://mcp.context7.com/mcp"}
{"name": "context7-http", "type": "http", "url": "https://mcp.context7.com/mcp"}
{"name": "context7-stdio", "type": "stdio", "command": "npx", "args": ["-y", "@upstash/context7-mcp"]}
{"name": "context7-sse", "type": "sse", "url": "https://mcp.context7.com/sse"}
{"name": "local-tool", "type": "stdio", "command": "my_tool", "args": ["--option", "value"], "env": {"TOKEN": "secret"}}
{ "name": "playwright", "command": "npx", "args": [ "@playwright/mcp@latest" ]}
Co-Authored-By: sketch <hello@sketch.dev>
Change-ID: s259a35d11e7bd660k
diff --git a/mcp/client.go b/mcp/client.go
new file mode 100644
index 0000000..e87a465
--- /dev/null
+++ b/mcp/client.go
@@ -0,0 +1,344 @@
+package mcp
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log/slog"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/mark3labs/mcp-go/client"
+ "github.com/mark3labs/mcp-go/client/transport"
+ "github.com/mark3labs/mcp-go/mcp"
+ "sketch.dev/llm"
+)
+
+// ServerConfig represents the configuration for an MCP server
+type ServerConfig struct {
+ Name string `json:"name,omitempty"`
+ Type string `json:"type,omitempty"` // "stdio", "http", "sse"
+ URL string `json:"url,omitempty"` // for http/sse
+ Command string `json:"command,omitempty"` // for stdio
+ Args []string `json:"args,omitempty"` // for stdio
+ Env map[string]string `json:"env,omitempty"` // for stdio
+ Headers map[string]string `json:"headers,omitempty"` // for http/sse
+}
+
+// MCPManager manages multiple MCP server connections
+type MCPManager struct {
+ mu sync.RWMutex
+ clients map[string]*MCPClientWrapper
+}
+
+// MCPClientWrapper wraps an MCP client connection
+type MCPClientWrapper struct {
+ name string
+ config ServerConfig
+ client *client.Client
+ tools []*llm.Tool
+}
+
+// MCPServerConnection represents a successful MCP server connection with its tools
+type MCPServerConnection struct {
+ ServerName string
+ Tools []*llm.Tool
+ ToolNames []string // Original tool names without server prefix
+}
+
+// NewMCPManager creates a new MCP manager
+func NewMCPManager() *MCPManager {
+ return &MCPManager{
+ clients: make(map[string]*MCPClientWrapper),
+ }
+}
+
+// ParseServerConfigs parses JSON configuration strings into ServerConfig structs
+func ParseServerConfigs(ctx context.Context, configs []string) ([]ServerConfig, []error) {
+ if len(configs) == 0 {
+ return nil, nil
+ }
+
+ var serverConfigs []ServerConfig
+ var errors []error
+
+ for i, configStr := range configs {
+ var config ServerConfig
+ if err := json.Unmarshal([]byte(configStr), &config); err != nil {
+ slog.ErrorContext(ctx, "Failed to parse MCP server config", "config", configStr, "error", err)
+ errors = append(errors, fmt.Errorf("config %d: %w", i, err))
+ continue
+ }
+ // Require a name
+ if config.Name == "" {
+ errors = append(errors, fmt.Errorf("config %d: name is required", i))
+ continue
+ }
+ serverConfigs = append(serverConfigs, config)
+ }
+
+ return serverConfigs, errors
+}
+
+// ConnectToServers connects to multiple MCP servers in parallel
+func (m *MCPManager) ConnectToServers(ctx context.Context, configs []string, timeout time.Duration) ([]MCPServerConnection, []error) {
+ serverConfigs, parseErrors := ParseServerConfigs(ctx, configs)
+ if len(serverConfigs) == 0 {
+ if len(parseErrors) > 0 {
+ return nil, parseErrors
+ }
+ return nil, nil
+ }
+ return m.ConnectToServerConfigs(ctx, serverConfigs, timeout, parseErrors)
+} // ConnectToServerConfigs connects to multiple parsed MCP server configs in parallel
+func (m *MCPManager) ConnectToServerConfigs(ctx context.Context, serverConfigs []ServerConfig, timeout time.Duration, existingErrors []error) ([]MCPServerConnection, []error) {
+ if len(serverConfigs) == 0 {
+ return nil, existingErrors
+ }
+
+ slog.InfoContext(ctx, "Connecting to MCP servers", "count", len(serverConfigs), "timeout", timeout)
+
+ // Connect to servers in parallel using sync.WaitGroup
+ type result struct {
+ tools []*llm.Tool
+ err error
+ serverName string
+ originalTools []string // Original tool names without server prefix
+ }
+
+ results := make(chan result, len(serverConfigs))
+ ctxWithTimeout, cancel := context.WithTimeout(ctx, timeout)
+ defer cancel()
+
+ for _, config := range serverConfigs {
+ go func(cfg ServerConfig) {
+ slog.InfoContext(ctx, "Connecting to MCP server", "server", cfg.Name, "type", cfg.Type, "url", cfg.URL, "command", cfg.Command)
+ tools, originalToolNames, err := m.connectToServerWithNames(ctxWithTimeout, cfg)
+ results <- result{
+ tools: tools,
+ err: err,
+ serverName: cfg.Name,
+ originalTools: originalToolNames,
+ }
+ }(config)
+ }
+
+ // Collect results
+ var connections []MCPServerConnection
+ errors := make([]error, 0, len(existingErrors))
+ errors = append(errors, existingErrors...)
+
+ for range len(serverConfigs) {
+ select {
+ case res := <-results:
+ if res.err != nil {
+ slog.ErrorContext(ctx, "Failed to connect to MCP server", "server", res.serverName, "error", res.err)
+ errors = append(errors, fmt.Errorf("MCP server %q: %w", res.serverName, res.err))
+ } else {
+ connection := MCPServerConnection{
+ ServerName: res.serverName,
+ Tools: res.tools,
+ ToolNames: res.originalTools,
+ }
+ connections = append(connections, connection)
+ slog.InfoContext(ctx, "Successfully connected to MCP server", "server", res.serverName, "tools", len(res.tools), "tool_names", res.originalTools)
+ }
+ case <-ctxWithTimeout.Done():
+ errors = append(errors, fmt.Errorf("timeout connecting to MCP servers"))
+ break
+ }
+ }
+
+ return connections, errors
+}
+
+// connectToServerWithNames connects to a single MCP server and returns tools with original names
+func (m *MCPManager) connectToServerWithNames(ctx context.Context, config ServerConfig) ([]*llm.Tool, []string, error) {
+ tools, err := m.connectToServer(ctx, config)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // Extract original tool names (remove server prefix)
+ originalNames := make([]string, len(tools))
+ for i, tool := range tools {
+ // Tool names are in format "servername_toolname"
+ parts := strings.SplitN(tool.Name, "_", 2)
+ if len(parts) == 2 {
+ originalNames[i] = parts[1]
+ } else {
+ originalNames[i] = tool.Name // fallback if no prefix
+ }
+ }
+
+ return tools, originalNames, nil
+}
+
+// connectToServer connects to a single MCP server
+func (m *MCPManager) connectToServer(ctx context.Context, config ServerConfig) ([]*llm.Tool, error) {
+ var mcpClient *client.Client
+ var err error
+
+ // Convert environment variables to []string format
+ var envVars []string
+ for k, v := range config.Env {
+ envVars = append(envVars, k+"="+v)
+ }
+
+ switch config.Type {
+ case "stdio", "":
+ if config.Command == "" {
+ return nil, fmt.Errorf("command is required for stdio transport")
+ }
+ mcpClient, err = client.NewStdioMCPClient(config.Command, envVars, config.Args...)
+ case "http":
+ if config.URL == "" {
+ return nil, fmt.Errorf("URL is required for HTTP transport")
+ }
+ // Use streamable HTTP client for HTTP transport
+ var httpOptions []transport.StreamableHTTPCOption
+ if len(config.Headers) > 0 {
+ httpOptions = append(httpOptions, transport.WithHTTPHeaders(config.Headers))
+ }
+ mcpClient, err = client.NewStreamableHttpClient(config.URL, httpOptions...)
+ case "sse":
+ if config.URL == "" {
+ return nil, fmt.Errorf("URL is required for SSE transport")
+ }
+ var sseOptions []transport.ClientOption
+ if len(config.Headers) > 0 {
+ sseOptions = append(sseOptions, transport.WithHeaders(config.Headers))
+ }
+ mcpClient, err = client.NewSSEMCPClient(config.URL, sseOptions...)
+ default:
+ return nil, fmt.Errorf("unsupported MCP transport type: %s", config.Type)
+ }
+
+ if err != nil {
+ return nil, fmt.Errorf("failed to create MCP client: %w", err)
+ }
+
+ // Start the client first
+ if err := mcpClient.Start(ctx); err != nil {
+ return nil, fmt.Errorf("failed to start MCP client: %w", err)
+ }
+
+ // Initialize the client
+ initReq := mcp.InitializeRequest{
+ Params: mcp.InitializeParams{
+ ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
+ Capabilities: mcp.ClientCapabilities{},
+ ClientInfo: mcp.Implementation{
+ Name: "sketch",
+ Version: "1.0.0",
+ },
+ },
+ }
+ if _, err := mcpClient.Initialize(ctx, initReq); err != nil {
+ return nil, fmt.Errorf("failed to initialize MCP client: %w", err)
+ }
+
+ // Get available tools
+ toolsReq := mcp.ListToolsRequest{}
+ toolsResp, err := mcpClient.ListTools(ctx, toolsReq)
+ if err != nil {
+ return nil, fmt.Errorf("failed to list tools: %w", err)
+ }
+
+ // Convert MCP tools to llm.Tool
+ llmTools, err := m.convertMCPTools(config.Name, mcpClient, toolsResp.Tools)
+ if err != nil {
+ return nil, fmt.Errorf("failed to convert tools: %w", err)
+ }
+
+ // Store the client
+ clientWrapper := &MCPClientWrapper{
+ name: config.Name,
+ config: config,
+ client: mcpClient,
+ tools: llmTools,
+ }
+
+ m.mu.Lock()
+ m.clients[config.Name] = clientWrapper
+ m.mu.Unlock()
+
+ return llmTools, nil
+}
+
+// convertMCPTools converts MCP tools to llm.Tool format
+func (m *MCPManager) convertMCPTools(serverName string, mcpClient *client.Client, mcpTools []mcp.Tool) ([]*llm.Tool, error) {
+ var llmTools []*llm.Tool
+
+ for _, mcpTool := range mcpTools {
+ // Convert the input schema
+ schemaBytes, err := json.Marshal(mcpTool.InputSchema)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal input schema for tool %s: %w", mcpTool.Name, err)
+ }
+
+ llmTool := &llm.Tool{
+ Name: fmt.Sprintf("%s_%s", serverName, mcpTool.Name),
+ Description: mcpTool.Description,
+ InputSchema: json.RawMessage(schemaBytes),
+ Run: func(toolName string, client *client.Client) func(ctx context.Context, input json.RawMessage) ([]llm.Content, error) {
+ return func(ctx context.Context, input json.RawMessage) ([]llm.Content, error) {
+ result, err := m.executeMCPTool(ctx, client, toolName, input)
+ if err != nil {
+ return nil, err
+ }
+ // Convert result to llm.Content
+ return []llm.Content{llm.StringContent(fmt.Sprintf("%v", result))}, nil
+ }
+ }(mcpTool.Name, mcpClient),
+ }
+
+ llmTools = append(llmTools, llmTool)
+ }
+
+ return llmTools, nil
+}
+
+// executeMCPTool executes an MCP tool call
+func (m *MCPManager) executeMCPTool(ctx context.Context, mcpClient *client.Client, toolName string, input json.RawMessage) (any, error) {
+ // Add timeout for tool execution
+ ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second)
+ defer cancel()
+
+ // Parse input arguments
+ var args map[string]any
+ if len(input) > 0 {
+ if err := json.Unmarshal(input, &args); err != nil {
+ return nil, fmt.Errorf("failed to parse tool arguments: %w", err)
+ }
+ }
+
+ // Call the MCP tool
+ req := mcp.CallToolRequest{
+ Params: mcp.CallToolParams{
+ Name: toolName,
+ Arguments: args,
+ },
+ }
+ resp, err := mcpClient.CallTool(ctxWithTimeout, req)
+ if err != nil {
+ return nil, fmt.Errorf("MCP tool call failed: %w", err)
+ }
+
+ // Return the content from the response
+ return resp.Content, nil
+}
+
+// Close closes all MCP client connections
+func (m *MCPManager) Close() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ for _, clientWrapper := range m.clients {
+ if clientWrapper.client != nil {
+ clientWrapper.client.Close()
+ }
+ }
+ m.clients = make(map[string]*MCPClientWrapper)
+}