Add llm interface
Change-Id: Idf599500fc131fb9509102e38736a6baeff6d6d8
diff --git a/server/llm/README.md b/server/llm/README.md
new file mode 100644
index 0000000..92df86b
--- /dev/null
+++ b/server/llm/README.md
@@ -0,0 +1,335 @@
+# LLM Interface Package
+
+This package provides a generic interface for different Large Language Model (LLM) providers, with OpenAI's API structure as the primary reference. It supports multiple providers including OpenAI, xAI, Claude, Gemini, and local models.
+
+## Features
+
+- **Unified Interface**: Single interface for all LLM providers
+- **Multiple Providers**: Support for OpenAI, xAI, Claude, Gemini, and local models
+- **Tool/Function Calling**: Support for function calling and tool usage
+- **Embeddings**: Generate embeddings for text
+- **Configurable**: Flexible configuration options for each provider
+- **Thread-Safe**: Thread-safe factory and registry implementations
+
+## Quick Start
+
+```go
+package main
+
+import (
+	"context"
+	"fmt"
+	"log"
+	"time"
+
+	"your-project/server/llm"
+)
+
+func main() {
+ // Create a configuration for OpenAI
+ config := llm.Config{
+ Provider: llm.ProviderOpenAI,
+ APIKey: "your-openai-api-key",
+ BaseURL: "https://api.openai.com/v1",
+ Timeout: 30 * time.Second,
+ }
+
+ // Create a provider (you'll need to register the implementation first)
+ provider, err := llm.CreateProvider(config)
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer provider.Close()
+
+ // Create a chat completion request
+ req := llm.ChatCompletionRequest{
+ Model: "gpt-3.5-turbo",
+ Messages: []llm.Message{
+ {Role: llm.RoleUser, Content: "Hello, how are you?"},
+ },
+ MaxTokens: &[]int{100}[0],
+ Temperature: &[]float64{0.7}[0],
+ }
+
+ // Get the response
+ resp, err := provider.ChatCompletion(context.Background(), req)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ fmt.Println("Response:", resp.Choices[0].Message.Content)
+}
+```
+
+## Core Types
+
+### Provider
+
+Represents different LLM service providers:
+
+```go
+const (
+ ProviderOpenAI Provider = "openai"
+ ProviderXAI Provider = "xai"
+ ProviderClaude Provider = "claude"
+ ProviderGemini Provider = "gemini"
+ ProviderLocal Provider = "local"
+)
+```
+
+### Message
+
+Represents a single message in a conversation:
+
+```go
+type Message struct {
+ Role Role `json:"role"`
+ Content string `json:"content"`
+ ToolCalls []ToolCall `json:"tool_calls,omitempty"`
+ ToolCallID string `json:"tool_call_id,omitempty"`
+ Name string `json:"name,omitempty"`
+}
+```
+
+### ChatCompletionRequest
+
+Represents a request to complete a chat conversation:
+
+```go
+type ChatCompletionRequest struct {
+ Model string `json:"model"`
+ Messages []Message `json:"messages"`
+ MaxTokens *int `json:"max_tokens,omitempty"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ TopP *float64 `json:"top_p,omitempty"`
+ Stream *bool `json:"stream,omitempty"`
+ Tools []Tool `json:"tools,omitempty"`
+ ToolChoice interface{} `json:"tool_choice,omitempty"`
+ ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
+ ExtraParams map[string]interface{} `json:"-"` // Provider-specific parameters
+}
+```
+
+## Main Interface
+
+### LLMProvider
+
+The main interface that all LLM providers must implement:
+
+```go
+type LLMProvider interface {
+ // ChatCompletion creates a chat completion
+ ChatCompletion(ctx context.Context, req ChatCompletionRequest) (*ChatCompletionResponse, error)
+
+ // CreateEmbeddings generates embeddings for the given input
+ CreateEmbeddings(ctx context.Context, req EmbeddingRequest) (*EmbeddingResponse, error)
+
+ // Close performs any necessary cleanup
+ Close() error
+}
+```
+
+## Provider Factory
+
+The package includes a factory system for creating and managing LLM providers:
+
+```go
+// Register a provider factory
+err := llm.RegisterProvider(llm.ProviderOpenAI, openaiFactory)
+
+// Create a provider
+provider, err := llm.CreateProvider(config)
+
+// Check if a provider is supported
+if llm.SupportsProvider(llm.ProviderOpenAI) {
+ // Provider is available
+}
+
+// List all registered providers
+providers := llm.ListProviders()
+```
+
+## Configuration
+
+Each provider can be configured with specific settings:
+
+```go
+config := llm.Config{
+ Provider: llm.ProviderOpenAI,
+ APIKey: "your-api-key",
+ BaseURL: "https://api.openai.com/v1",
+ Timeout: 30 * time.Second,
+ MaxRetries: 3,
+ ExtraConfig: map[string]interface{}{
+ "organization": "your-org-id",
+ },
+}
+```
+
+## Tool/Function Calling
+
+Support for function calling and tool usage:
+
+```go
+tools := []llm.Tool{
+ {
+ Type: "function",
+ Function: llm.Function{
+ Name: "get_weather",
+ Description: "Get the current weather for a location",
+ Parameters: map[string]interface{}{
+ "type": "object",
+ "properties": map[string]interface{}{
+ "location": map[string]interface{}{
+ "type": "string",
+ "description": "The city and state, e.g. San Francisco, CA",
+ },
+ },
+ "required": []string{"location"},
+ },
+ },
+ },
+}
+
+req := llm.ChatCompletionRequest{
+ Model: "gpt-3.5-turbo",
+ Messages: messages,
+ Tools: tools,
+}
+```
+
+## Embeddings
+
+Generate embeddings for text:
+
+```go
+req := llm.EmbeddingRequest{
+ Input: "Hello, world!",
+ Model: "text-embedding-ada-002",
+}
+
+resp, err := provider.CreateEmbeddings(context.Background(), req)
+if err != nil {
+ log.Fatal(err)
+}
+
+fmt.Printf("Embedding dimensions: %d\n", len(resp.Data[0].Embedding))
+```
+
+## Implementing a New Provider
+
+To implement a new LLM provider:
+
+1. **Implement the LLMProvider interface**:
+
+```go
+type MyProvider struct {
+ config llm.Config
+ client *http.Client
+}
+
+func (p *MyProvider) ChatCompletion(ctx context.Context, req llm.ChatCompletionRequest) (*llm.ChatCompletionResponse, error) {
+ // Implementation here
+}
+
+func (p *MyProvider) CreateEmbeddings(ctx context.Context, req llm.EmbeddingRequest) (*llm.EmbeddingResponse, error) {
+ // Implementation here
+}
+
+func (p *MyProvider) Close() error {
+ // Cleanup implementation
+ return nil
+}
+```
+
+2. **Create a factory**:
+
+```go
+type MyProviderFactory struct{}
+
+func (f *MyProviderFactory) CreateProvider(config llm.Config) (llm.LLMProvider, error) {
+ return &MyProvider{config: config}, nil
+}
+
+func (f *MyProviderFactory) SupportsProvider(provider llm.Provider) bool {
+ return provider == llm.ProviderMyProvider
+}
+```
+
+3. **Register the provider**:
+
+```go
+func init() {
+ llm.RegisterProvider(llm.ProviderMyProvider, &MyProviderFactory{})
+}
+```
+
+## Error Handling
+
+The package defines common error types:
+
+```go
+var (
+ ErrInvalidConfig = fmt.Errorf("invalid configuration")
+ ErrUnsupportedProvider = fmt.Errorf("unsupported provider")
+ ErrAPIKeyRequired = fmt.Errorf("API key is required")
+ ErrModelNotFound = fmt.Errorf("model not found")
+ ErrRateLimitExceeded = fmt.Errorf("rate limit exceeded")
+ ErrContextCancelled = fmt.Errorf("context cancelled")
+ ErrTimeout = fmt.Errorf("request timeout")
+)
+```
+
+## Utilities
+
+The package includes utility functions:
+
+```go
+// Validate configuration
+err := llm.ValidateConfig(config)
+
+// Check if provider is valid
+if llm.IsValidProvider(llm.ProviderOpenAI) {
+ // Provider is valid
+}
+
+// Get default configuration
+config, err := llm.GetDefaultConfig(llm.ProviderOpenAI)
+
+// Merge custom config with defaults
+config = llm.MergeConfig(customConfig)
+
+// Validate requests
+err := llm.ValidateChatCompletionRequest(req)
+err := llm.ValidateEmbeddingRequest(req)
+
+// Estimate tokens
+tokens := llm.EstimateTokens(req)
+```
+
+## Thread Safety
+
+The factory and registry implementations are thread-safe and can be used concurrently from multiple goroutines.
+
+## Default Configurations
+
+The package provides default configurations for each provider:
+
+```go
+var DefaultConfigs = map[llm.Provider]llm.Config{
+ llm.ProviderOpenAI: {
+ Provider: llm.ProviderOpenAI,
+ BaseURL: "https://api.openai.com/v1",
+ Timeout: 30 * time.Second,
+ MaxRetries: 3,
+ },
+ // ... other providers
+}
+```
+
+## Next Steps
+
+1. Implement the actual provider implementations (OpenAI, xAI, Claude, etc.)
+2. Add tests for the interface and implementations
+3. Add more utility functions as needed
+4. Consider adding caching and retry mechanisms
+5. Add support for more provider-specific features
\ No newline at end of file
diff --git a/server/llm/factory.go b/server/llm/factory.go
new file mode 100644
index 0000000..e425eee
--- /dev/null
+++ b/server/llm/factory.go
@@ -0,0 +1,206 @@
+package llm
+
+import (
+ "fmt"
+ "sync"
+)
+
+// GlobalProviderFactory is the main factory for creating LLM providers.
+// All access to the internal provider map is guarded by mu, so a single
+// instance is safe for concurrent use by multiple goroutines.
+type GlobalProviderFactory struct {
+	providers map[Provider]ProviderFactory // registered factories, keyed by provider id
+	mu        sync.RWMutex                 // guards providers
+}
+
+// NewGlobalProviderFactory creates a new, empty global provider factory.
+func NewGlobalProviderFactory() *GlobalProviderFactory {
+	return &GlobalProviderFactory{
+		providers: make(map[Provider]ProviderFactory),
+	}
+}
+
+// RegisterProvider registers a provider factory for a specific provider type.
+// Unknown provider identifiers are rejected; registering twice for the same
+// provider silently replaces the earlier factory.
+func (f *GlobalProviderFactory) RegisterProvider(provider Provider, factory ProviderFactory) error {
+	if !IsValidProvider(provider) {
+		return fmt.Errorf("unsupported provider: %s", provider)
+	}
+
+	f.mu.Lock()
+	f.providers[provider] = factory
+	f.mu.Unlock()
+	return nil
+}
+
+// CreateProvider creates a new LLM provider instance for config.Provider.
+//
+// Defaults from DefaultConfigs are merged in BEFORE validation so a caller
+// may omit BaseURL/Timeout/MaxRetries and still pass ValidateConfig (the
+// original validated first, which rejected any config relying on defaults).
+func (f *GlobalProviderFactory) CreateProvider(config Config) (LLMProvider, error) {
+	f.mu.RLock()
+	factory, exists := f.providers[config.Provider]
+	f.mu.RUnlock()
+
+	if !exists {
+		return nil, fmt.Errorf("no factory registered for provider: %s", config.Provider)
+	}
+
+	// Merge defaults first: ValidateConfig requires a non-empty BaseURL,
+	// which MergeConfig fills in from DefaultConfigs when left unset.
+	config = MergeConfig(config)
+
+	if err := ValidateConfig(config); err != nil {
+		return nil, fmt.Errorf("invalid config: %w", err)
+	}
+
+	return factory.CreateProvider(config)
+}
+
+// SupportsProvider reports whether a factory has been registered for the
+// given provider.
+func (f *GlobalProviderFactory) SupportsProvider(provider Provider) bool {
+	f.mu.RLock()
+	defer f.mu.RUnlock()
+	_, ok := f.providers[provider]
+	return ok
+}
+
+// ListSupportedProviders returns the providers that currently have a
+// registered factory, in no particular order.
+func (f *GlobalProviderFactory) ListSupportedProviders() []Provider {
+	f.mu.RLock()
+	defer f.mu.RUnlock()
+
+	out := make([]Provider, 0, len(f.providers))
+	for p := range f.providers {
+		out = append(out, p)
+	}
+	return out
+}
+
+// UnregisterProvider removes a provider factory. Removing a provider that
+// was never registered is a no-op.
+func (f *GlobalProviderFactory) UnregisterProvider(provider Provider) {
+	f.mu.Lock()
+	defer f.mu.Unlock()
+
+	delete(f.providers, provider)
+}
+
+// DefaultFactory is the default global factory instance used by the
+// package-level helper functions below. NOTE(review): package-level mutable
+// state — this duplicates the role of DefaultRegistry further down; consider
+// consolidating on one of the two.
+var DefaultFactory = NewGlobalProviderFactory()
+
+// RegisterDefaultProvider registers a provider with the default factory.
+func RegisterDefaultProvider(provider Provider, factory ProviderFactory) error {
+	return DefaultFactory.RegisterProvider(provider, factory)
+}
+
+// CreateDefaultProvider creates a provider using the default factory.
+func CreateDefaultProvider(config Config) (LLMProvider, error) {
+	return DefaultFactory.CreateProvider(config)
+}
+
+// SupportsDefaultProvider checks if the default factory supports a provider.
+func SupportsDefaultProvider(provider Provider) bool {
+	return DefaultFactory.SupportsProvider(provider)
+}
+
+// ListDefaultSupportedProviders returns providers supported by the default
+// factory.
+func ListDefaultSupportedProviders() []Provider {
+	return DefaultFactory.ListSupportedProviders()
+}
+
+// ProviderRegistry provides a simple way to register and manage providers.
+// It is safe for concurrent use; all map access is guarded by mu.
+// NOTE(review): functionally equivalent to GlobalProviderFactory above —
+// consider keeping only one of the two implementations.
+type ProviderRegistry struct {
+	factories map[Provider]ProviderFactory // registered factories, keyed by provider id
+	mu        sync.RWMutex                 // guards factories
+}
+
+// NewProviderRegistry creates a new, empty provider registry.
+func NewProviderRegistry() *ProviderRegistry {
+	return &ProviderRegistry{
+		factories: make(map[Provider]ProviderFactory),
+	}
+}
+
+// Register registers a provider factory. Unknown provider identifiers are
+// rejected; re-registering replaces the previous factory for that provider.
+func (r *ProviderRegistry) Register(provider Provider, factory ProviderFactory) error {
+	if !IsValidProvider(provider) {
+		return fmt.Errorf("unsupported provider: %s", provider)
+	}
+
+	r.mu.Lock()
+	r.factories[provider] = factory
+	r.mu.Unlock()
+	return nil
+}
+
+// Get retrieves the factory registered for provider, reporting whether one
+// exists.
+func (r *ProviderRegistry) Get(provider Provider) (ProviderFactory, bool) {
+	r.mu.RLock()
+	factory, ok := r.factories[provider]
+	r.mu.RUnlock()
+	return factory, ok
+}
+
+// Create creates a new LLM provider instance for config.Provider.
+//
+// Defaults from DefaultConfigs are merged in BEFORE validation so a caller
+// may omit BaseURL/Timeout/MaxRetries and still pass ValidateConfig (the
+// original validated first, which rejected any config relying on defaults).
+func (r *ProviderRegistry) Create(config Config) (LLMProvider, error) {
+	factory, exists := r.Get(config.Provider)
+	if !exists {
+		return nil, fmt.Errorf("no factory registered for provider: %s", config.Provider)
+	}
+
+	// Merge defaults first: ValidateConfig requires a non-empty BaseURL,
+	// which MergeConfig fills in from DefaultConfigs when left unset.
+	config = MergeConfig(config)
+
+	if err := ValidateConfig(config); err != nil {
+		return nil, fmt.Errorf("invalid config: %w", err)
+	}
+
+	return factory.CreateProvider(config)
+}
+
+// List returns all registered providers, in no particular order.
+func (r *ProviderRegistry) List() []Provider {
+	r.mu.RLock()
+	defer r.mu.RUnlock()
+
+	out := make([]Provider, 0, len(r.factories))
+	for p := range r.factories {
+		out = append(out, p)
+	}
+	return out
+}
+
+// Unregister removes a provider factory. Removing a provider that was never
+// registered is a no-op.
+func (r *ProviderRegistry) Unregister(provider Provider) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	delete(r.factories, provider)
+}
+
+// DefaultRegistry is the default provider registry backing the package-level
+// RegisterProvider/CreateProvider helpers. NOTE(review): this overlaps with
+// DefaultFactory above — confirm which one callers are meant to use.
+var DefaultRegistry = NewProviderRegistry()
+
+// RegisterProvider registers a provider with the default registry.
+func RegisterProvider(provider Provider, factory ProviderFactory) error {
+	return DefaultRegistry.Register(provider, factory)
+}
+
+// CreateProvider creates a provider using the default registry.
+func CreateProvider(config Config) (LLMProvider, error) {
+	return DefaultRegistry.Create(config)
+}
+
+// GetProviderFactory gets a provider factory from the default registry.
+func GetProviderFactory(provider Provider) (ProviderFactory, bool) {
+	return DefaultRegistry.Get(provider)
+}
+
+// ListProviders returns all providers registered with the default registry.
+func ListProviders() []Provider {
+	return DefaultRegistry.List()
+}
+
+// UnregisterProvider removes a provider from the default registry.
+func UnregisterProvider(provider Provider) {
+	DefaultRegistry.Unregister(provider)
+}
diff --git a/server/llm/llm.go b/server/llm/llm.go
new file mode 100644
index 0000000..2fd3bf3
--- /dev/null
+++ b/server/llm/llm.go
@@ -0,0 +1,273 @@
+package llm
+
+import (
+ "context"
+ "time"
+)
+
+// LLMProvider defines the interface that all LLM providers must implement.
+type LLMProvider interface {
+	// ChatCompletion creates a chat completion for the given request.
+	ChatCompletion(ctx context.Context, req ChatCompletionRequest) (*ChatCompletionResponse, error)
+
+	// CreateEmbeddings generates embeddings for the given input.
+	CreateEmbeddings(ctx context.Context, req EmbeddingRequest) (*EmbeddingResponse, error)
+
+	// Close performs any necessary cleanup. The provider must not be used
+	// after Close returns.
+	Close() error
+}
+
+// ProviderFactory creates LLM provider instances.
+type ProviderFactory interface {
+	// CreateProvider creates a new LLM provider instance from config.
+	CreateProvider(config Config) (LLMProvider, error)
+
+	// SupportsProvider checks if the factory supports the given provider.
+	SupportsProvider(provider Provider) bool
+}
+
+// Provider represents different LLM service providers.
+type Provider string
+
+// Known provider identifiers; IsValidProvider accepts exactly this set.
+const (
+	ProviderOpenAI Provider = "openai"
+	ProviderXAI    Provider = "xai"
+	ProviderClaude Provider = "claude"
+	ProviderGemini Provider = "gemini"
+	ProviderLocal  Provider = "local"
+)
+
+// Role represents the role of a message participant.
+type Role string
+
+// Message roles, mirroring the OpenAI chat API role names.
+const (
+	RoleSystem    Role = "system"
+	RoleUser      Role = "user"
+	RoleAssistant Role = "assistant"
+	RoleTool      Role = "tool"
+)
+
+// Message represents a single message in a conversation.
+type Message struct {
+	Role       Role       `json:"role"`
+	Content    string     `json:"content"`
+	ToolCalls  []ToolCall `json:"tool_calls,omitempty"`   // tool calls requested by the model
+	ToolCallID string     `json:"tool_call_id,omitempty"` // presumably links a RoleTool reply to its ToolCall.ID — confirm
+	Name       string     `json:"name,omitempty"`
+}
+
+// ToolCall represents a function/tool call request.
+type ToolCall struct {
+	ID       string   `json:"id"`
+	Type     string   `json:"type"` // e.g. "function"
+	Function Function `json:"function"`
+}
+
+// Function represents a function definition.
+type Function struct {
+	Name        string                 `json:"name"`
+	Description string                 `json:"description,omitempty"`
+	Parameters  map[string]interface{} `json:"parameters,omitempty"` // JSON-Schema-shaped argument description
+}
+
+// Tool represents a tool that can be called by the model.
+type Tool struct {
+	Type     string   `json:"type"`
+	Function Function `json:"function"`
+}
+
+// ChatCompletionRequest represents a request to complete a chat conversation.
+// The field set mirrors the OpenAI chat completions request schema; providers
+// may ignore parameters they do not support.
+type ChatCompletionRequest struct {
+	Model            string                 `json:"model"`
+	Messages         []Message              `json:"messages"`
+	MaxTokens        *int                   `json:"max_tokens,omitempty"`
+	Temperature      *float64               `json:"temperature,omitempty"` // 0–2, enforced by ValidateChatCompletionRequest
+	TopP             *float64               `json:"top_p,omitempty"`       // 0–1, enforced by ValidateChatCompletionRequest
+	N                *int                   `json:"n,omitempty"`
+	Stream           *bool                  `json:"stream,omitempty"`
+	Stop             []string               `json:"stop,omitempty"`
+	PresencePenalty  *float64               `json:"presence_penalty,omitempty"`
+	FrequencyPenalty *float64               `json:"frequency_penalty,omitempty"`
+	LogitBias        map[string]int         `json:"logit_bias,omitempty"`
+	User             string                 `json:"user,omitempty"`
+	Tools            []Tool                 `json:"tools,omitempty"`
+	ToolChoice       interface{}            `json:"tool_choice,omitempty"` // string mode or structured selector, per OpenAI's tool_choice
+	ResponseFormat   *ResponseFormat        `json:"response_format,omitempty"`
+	Seed             *int64                 `json:"seed,omitempty"`
+	ExtraParams      map[string]interface{} `json:"-"` // For provider-specific parameters
+}
+
+// ResponseFormat specifies the format of the response.
+type ResponseFormat struct {
+	Type string `json:"type"` // "text" or "json_object"
+}
+
+// ChatCompletionResponse represents a response from a chat completion request.
+type ChatCompletionResponse struct {
+	ID                string                 `json:"id"`
+	Object            string                 `json:"object"`
+	Created           int64                  `json:"created"` // Unix timestamp (OpenAI convention)
+	Model             string                 `json:"model"`
+	SystemFingerprint string                 `json:"system_fingerprint,omitempty"`
+	Choices           []ChatCompletionChoice `json:"choices"`
+	Usage             Usage                  `json:"usage"`
+	Provider          Provider               `json:"provider"` // which backend produced the response
+}
+
+// ChatCompletionChoice represents a single choice in a chat completion
+// response.
+type ChatCompletionChoice struct {
+	Index        int                    `json:"index"`
+	Message      Message                `json:"message"`
+	Logprobs     *Logprobs              `json:"logprobs,omitempty"`
+	FinishReason string                 `json:"finish_reason"`
+	Delta        *Message               `json:"delta,omitempty"` // For streaming
+	ExtraData    map[string]interface{} `json:"-"`               // For provider-specific data
+}
+
+// Logprobs represents log probability information.
+type Logprobs struct {
+	Content []LogprobContent `json:"content,omitempty"`
+}
+
+// LogprobContent represents content with log probabilities.
+type LogprobContent struct {
+	Token       string       `json:"token"`
+	Logprob     float64      `json:"logprob"`
+	Bytes       []int        `json:"bytes,omitempty"`
+	TopLogprobs []TopLogprob `json:"top_logprobs,omitempty"`
+}
+
+// TopLogprob represents a top log probability.
+type TopLogprob struct {
+	Token   string  `json:"token"`
+	Logprob float64 `json:"logprob"`
+	Bytes   []int   `json:"bytes,omitempty"`
+}
+
+// Usage represents token usage information.
+type Usage struct {
+	PromptTokens     int `json:"prompt_tokens"`
+	CompletionTokens int `json:"completion_tokens"`
+	TotalTokens      int `json:"total_tokens"`
+}
+
+// EmbeddingRequest represents a request to generate embeddings.
+type EmbeddingRequest struct {
+	Input          interface{}            `json:"input"` // string, []string, or []int (enforced by ValidateEmbeddingRequest)
+	Model          string                 `json:"model"`
+	EncodingFormat string                 `json:"encoding_format,omitempty"`
+	Dimensions     *int                   `json:"dimensions,omitempty"`
+	User           string                 `json:"user,omitempty"`
+	ExtraParams    map[string]interface{} `json:"-"` // For provider-specific parameters
+}
+
+// EmbeddingResponse represents a response from an embedding request.
+type EmbeddingResponse struct {
+	Object   string      `json:"object"`
+	Data     []Embedding `json:"data"`
+	Usage    Usage       `json:"usage"`
+	Model    string      `json:"model"`
+	Provider Provider    `json:"provider"`
+}
+
+// Embedding represents a single embedding vector within a response.
+type Embedding struct {
+	Object    string    `json:"object"`
+	Embedding []float64 `json:"embedding"`
+	Index     int       `json:"index"` // position of the corresponding input item
+}
+
+// ModelInfo represents information about an available model.
+type ModelInfo struct {
+	ID         string                 `json:"id"`
+	Object     string                 `json:"object"`
+	Created    int64                  `json:"created"`
+	OwnedBy    string                 `json:"owned_by"`
+	Permission []ModelPermission      `json:"permission"`
+	Root       string                 `json:"root"`
+	Parent     string                 `json:"parent"`
+	Provider   Provider               `json:"provider"`
+	ExtraData  map[string]interface{} `json:"-"` // For provider-specific data
+}
+
+// ModelPermission represents permissions for a model.
+type ModelPermission struct {
+	ID                 string `json:"id"`
+	Object             string `json:"object"`
+	Created            int64  `json:"created"`
+	AllowCreateEngine  bool   `json:"allow_create_engine"`
+	AllowSampling      bool   `json:"allow_sampling"`
+	AllowLogprobs      bool   `json:"allow_logprobs"`
+	AllowSearchIndices bool   `json:"allow_search_indices"`
+	AllowView          bool   `json:"allow_view"`
+	AllowFineTuning    bool   `json:"allow_fine_tuning"`
+	Organization       string `json:"organization"`
+	Group              string `json:"group"`
+	IsBlocking         bool   `json:"is_blocking"`
+}
+
+// Error represents an error response body from an LLM provider's API.
+// NOTE(review): this type does not implement Go's built-in error interface
+// (no Error() string method) and llm.Error stutters — consider renaming to
+// APIError and adding an Error() method; confirm with the provider impls.
+type Error struct {
+	Error struct {
+		Message string `json:"message"`
+		Type    string `json:"type"`
+		Code    string `json:"code,omitempty"`
+		Param   string `json:"param,omitempty"`
+	} `json:"error"`
+}
+
+// Config represents configuration for an LLM provider.
+type Config struct {
+	Provider    Provider               `json:"provider"`
+	APIKey      string                 `json:"api_key"` // NOTE(review): marshals a secret — confirm configs are never logged/serialized
+	BaseURL     string                 `json:"base_url,omitempty"`
+	Timeout     time.Duration          `json:"timeout,omitempty"` // zero means "use provider default" (filled by MergeConfig)
+	MaxRetries  int                    `json:"max_retries,omitempty"`
+	ExtraConfig map[string]interface{} `json:"extra_config,omitempty"`
+}
+
+// StreamResponse represents a streaming response chunk.
+type StreamResponse struct {
+	ID                string                 `json:"id"`
+	Object            string                 `json:"object"`
+	Created           int64                  `json:"created"`
+	Model             string                 `json:"model"`
+	SystemFingerprint string                 `json:"system_fingerprint,omitempty"`
+	Choices           []ChatCompletionChoice `json:"choices"`
+	Usage             *Usage                 `json:"usage,omitempty"` // presumably only present on the final chunk — confirm per provider
+	Provider          Provider               `json:"provider"`
+}
+
+// DefaultConfigs provides default configurations for different providers.
+// MergeConfig uses these to fill in unset BaseURL/Timeout/MaxRetries values.
+var DefaultConfigs = map[Provider]Config{
+	ProviderOpenAI: {
+		Provider:   ProviderOpenAI,
+		BaseURL:    "https://api.openai.com/v1",
+		Timeout:    30 * time.Second,
+		MaxRetries: 3,
+	},
+	ProviderXAI: {
+		Provider:   ProviderXAI,
+		BaseURL:    "https://api.x.ai/v1",
+		Timeout:    30 * time.Second,
+		MaxRetries: 3,
+	},
+	ProviderClaude: {
+		Provider:   ProviderClaude,
+		BaseURL:    "https://api.anthropic.com/v1",
+		Timeout:    30 * time.Second,
+		MaxRetries: 3,
+	},
+	ProviderGemini: {
+		Provider:   ProviderGemini,
+		BaseURL:    "https://generativelanguage.googleapis.com/v1",
+		Timeout:    30 * time.Second,
+		MaxRetries: 3,
+	},
+	// Local models get a longer timeout and a single retry; port 11434 is
+	// presumably an Ollama endpoint — confirm.
+	ProviderLocal: {
+		Provider:   ProviderLocal,
+		BaseURL:    "http://localhost:11434",
+		Timeout:    60 * time.Second,
+		MaxRetries: 1,
+	},
+}
diff --git a/server/llm/utils.go b/server/llm/utils.go
new file mode 100644
index 0000000..59064b3
--- /dev/null
+++ b/server/llm/utils.go
@@ -0,0 +1,218 @@
+package llm
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"strings"
+	"time"
+)
+
+// Common sentinel errors returned by this package. Compare with errors.Is.
+// errors.New is preferred over fmt.Errorf here because these messages contain
+// no format verbs (staticcheck S1039).
+var (
+	ErrInvalidConfig       = errors.New("invalid configuration")
+	ErrUnsupportedProvider = errors.New("unsupported provider")
+	ErrAPIKeyRequired      = errors.New("API key is required")
+	ErrModelNotFound       = errors.New("model not found")
+	ErrRateLimitExceeded   = errors.New("rate limit exceeded")
+	ErrContextCancelled    = errors.New("context cancelled")
+	ErrTimeout             = errors.New("request timeout")
+)
+
+// ValidateConfig validates a configuration for an LLM provider. It checks
+// only the required fields: APIKey and BaseURL must be non-empty.
+//
+// Defaulting of Timeout/MaxRetries is MergeConfig's job, so run MergeConfig
+// before ValidateConfig when the caller may rely on provider defaults.
+// (The original version also assigned fallback Timeout/MaxRetries values
+// here, but config is passed by value, so those writes were dead stores
+// invisible to the caller — they have been removed.)
+func ValidateConfig(config Config) error {
+	if config.APIKey == "" {
+		return ErrAPIKeyRequired
+	}
+
+	if config.BaseURL == "" {
+		return ErrInvalidConfig
+	}
+
+	return nil
+}
+
+// IsValidProvider reports whether provider is one of the known Provider
+// constants supported by this package.
+func IsValidProvider(provider Provider) bool {
+	for _, known := range []Provider{
+		ProviderOpenAI, ProviderXAI, ProviderClaude, ProviderGemini, ProviderLocal,
+	} {
+		if provider == known {
+			return true
+		}
+	}
+	return false
+}
+
+// GetDefaultConfig returns the default configuration for provider. It
+// returns ErrUnsupportedProvider for unknown providers and a descriptive
+// error when a known provider has no entry in DefaultConfigs.
+func GetDefaultConfig(provider Provider) (Config, error) {
+	if !IsValidProvider(provider) {
+		return Config{}, ErrUnsupportedProvider
+	}
+
+	if config, ok := DefaultConfigs[provider]; ok {
+		return config, nil
+	}
+	return Config{}, fmt.Errorf("no default config for provider: %s", provider)
+}
+
+// MergeConfig fills the zero-valued BaseURL, Timeout and MaxRetries fields
+// of custom from the provider's default configuration. Non-zero fields are
+// left untouched; if the provider has no defaults, custom is returned as-is.
+// NOTE(review): an intentional zero (e.g. MaxRetries == 0) is
+// indistinguishable from "unset" and will be replaced by the default.
+func MergeConfig(custom Config) Config {
+	defaults, err := GetDefaultConfig(custom.Provider)
+	if err != nil {
+		// No defaults known for this provider; nothing to merge.
+		return custom
+	}
+
+	merged := custom
+	if merged.BaseURL == "" {
+		merged.BaseURL = defaults.BaseURL
+	}
+	if merged.Timeout == 0 {
+		merged.Timeout = defaults.Timeout
+	}
+	if merged.MaxRetries == 0 {
+		merged.MaxRetries = defaults.MaxRetries
+	}
+	return merged
+}
+
+// CreateContextWithTimeout creates a context, derived from
+// context.Background, that is cancelled after the given timeout. The caller
+// must call the returned CancelFunc to release resources.
+func CreateContextWithTimeout(timeout time.Duration) (context.Context, context.CancelFunc) {
+	return context.WithTimeout(context.Background(), timeout)
+}
+
+// ValidateChatCompletionRequest validates a chat completion request: a model
+// name and at least one message are required, and the optional sampling
+// parameters must fall in their documented ranges (temperature in [0, 2],
+// top_p in [0, 1], max_tokens > 0).
+func ValidateChatCompletionRequest(req ChatCompletionRequest) error {
+	switch {
+	case req.Model == "":
+		return fmt.Errorf("model is required")
+	case len(req.Messages) == 0:
+		return fmt.Errorf("at least one message is required")
+	case req.Temperature != nil && (*req.Temperature < 0 || *req.Temperature > 2):
+		return fmt.Errorf("temperature must be between 0 and 2")
+	case req.TopP != nil && (*req.TopP < 0 || *req.TopP > 1):
+		return fmt.Errorf("top_p must be between 0 and 1")
+	case req.MaxTokens != nil && *req.MaxTokens <= 0:
+		return fmt.Errorf("max_tokens must be positive")
+	}
+
+	return nil
+}
+
+// ValidateEmbeddingRequest validates an embedding request: a model name and
+// a non-nil input of type string, []string or []int are required.
+func ValidateEmbeddingRequest(req EmbeddingRequest) error {
+	if req.Model == "" {
+		return fmt.Errorf("model is required")
+	}
+	if req.Input == nil {
+		return fmt.Errorf("input is required")
+	}
+
+	// Only the input shapes understood by the providers are accepted.
+	switch req.Input.(type) {
+	case string, []string, []int:
+		return nil
+	default:
+		return fmt.Errorf("input must be string, []string, or []int")
+	}
+}
+
+// ConvertMessagesToString renders a conversation as newline-separated
+// "role: content" lines (no trailing newline). A strings.Builder is used
+// instead of the original string concatenation in a loop, which was
+// quadratic in the total conversation size.
+func ConvertMessagesToString(messages []Message) string {
+	var b strings.Builder
+	for i, msg := range messages {
+		if i > 0 {
+			b.WriteByte('\n')
+		}
+		fmt.Fprintf(&b, "%s: %s", msg.Role, msg.Content)
+	}
+	return b.String()
+}
+
+// CountTokens estimates the number of tokens in text using the rough
+// heuristic of one token per four characters, with a minimum of one.
+// For accurate counts, a provider-specific tokenizer is required.
+func CountTokens(text string) int {
+	tokens := len(text) / 4 // rough estimate: 1 token ≈ 4 characters
+	if tokens < 1 {
+		return 1
+	}
+	return tokens
+}
+
+// EstimateTokens estimates tokens for a chat completion request
+func EstimateTokens(req ChatCompletionRequest) int {
+ total := 0
+
+ // Count tokens in messages
+ for _, msg := range req.Messages {
+ total += CountTokens(msg.Content)
+ // Add some overhead for role and formatting
+ total += 4
+ }
+
+ // Add some overhead for the request
+ total += 10
+
+ return total
+}
+
+// ProviderDisplayName returns a human-readable name for a provider. Unknown
+// providers fall back to their raw string value.
+func ProviderDisplayName(provider Provider) string {
+	names := map[Provider]string{
+		ProviderOpenAI: "OpenAI",
+		ProviderXAI:    "xAI",
+		ProviderClaude: "Claude (Anthropic)",
+		ProviderGemini: "Gemini (Google)",
+		ProviderLocal:  "Local",
+	}
+	if name, ok := names[provider]; ok {
+		return name
+	}
+	return string(provider)
+}
+
+// IsToolCall reports whether msg requests at least one tool call.
+func IsToolCall(msg Message) bool {
+	return len(msg.ToolCalls) != 0
+}
+
+// HasToolCalls reports whether any message in the conversation requests a
+// tool call.
+func HasToolCalls(messages []Message) bool {
+	for i := range messages {
+		if IsToolCall(messages[i]) {
+			return true
+		}
+	}
+	return false
+}