| package llm |
| |
| import ( |
| "context" |
| "fmt" |
| "time" |
| ) |
| |
// Sentinel errors shared by the llm package.
//
// Each value is created once at package initialization. The helpers in this
// file return them unwrapped, so callers may test for them with errors.Is
// (or ==).
var (
	// ErrInvalidConfig is returned by ValidateConfig when BaseURL is empty.
	ErrInvalidConfig = fmt.Errorf("invalid configuration")

	// ErrUnsupportedProvider is returned by GetDefaultConfig for a provider
	// that IsValidProvider does not accept.
	ErrUnsupportedProvider = fmt.Errorf("unsupported provider")

	// ErrAPIKeyRequired is returned by ValidateConfig when APIKey is empty.
	ErrAPIKeyRequired = fmt.Errorf("API key is required")

	// The remaining errors are not returned by any function in this file;
	// presumably provider implementations use them — verify against callers.

	// ErrModelNotFound carries the message "model not found".
	ErrModelNotFound = fmt.Errorf("model not found")

	// ErrRateLimitExceeded carries the message "rate limit exceeded".
	ErrRateLimitExceeded = fmt.Errorf("rate limit exceeded")

	// ErrContextCancelled carries the message "context cancelled".
	ErrContextCancelled = fmt.Errorf("context cancelled")

	// ErrTimeout carries the message "request timeout".
	ErrTimeout = fmt.Errorf("request timeout")
)
| |
| // ValidateConfig validates a configuration for an LLM provider |
| func ValidateConfig(config Config) error { |
| if config.APIKey == "" { |
| return ErrAPIKeyRequired |
| } |
| |
| if config.BaseURL == "" { |
| return ErrInvalidConfig |
| } |
| |
| if config.Timeout <= 0 { |
| config.Timeout = 30 * time.Second |
| } |
| |
| if config.MaxRetries < 0 { |
| config.MaxRetries = 3 |
| } |
| |
| return nil |
| } |
| |
| // IsValidProvider checks if a provider is supported |
| func IsValidProvider(provider Provider) bool { |
| switch provider { |
| case ProviderOpenAI, ProviderXAI, ProviderClaude, ProviderGemini, ProviderLocal: |
| return true |
| default: |
| return false |
| } |
| } |
| |
| // GetDefaultConfig returns the default configuration for a provider |
| func GetDefaultConfig(provider Provider) (Config, error) { |
| if !IsValidProvider(provider) { |
| return Config{}, ErrUnsupportedProvider |
| } |
| |
| config, exists := DefaultConfigs[provider] |
| if !exists { |
| return Config{}, fmt.Errorf("no default config for provider: %s", provider) |
| } |
| |
| return config, nil |
| } |
| |
| // MergeConfig merges a custom config with default config |
| func MergeConfig(custom Config) Config { |
| defaultConfig, err := GetDefaultConfig(custom.Provider) |
| if err != nil { |
| // If no default config, return the custom config as-is |
| return custom |
| } |
| |
| // Merge custom config with defaults |
| if custom.BaseURL == "" { |
| custom.BaseURL = defaultConfig.BaseURL |
| } |
| if custom.Timeout == 0 { |
| custom.Timeout = defaultConfig.Timeout |
| } |
| if custom.MaxRetries == 0 { |
| custom.MaxRetries = defaultConfig.MaxRetries |
| } |
| |
| return custom |
| } |
| |
// CreateContextWithTimeout returns a context derived from context.Background
// that is cancelled automatically once timeout has elapsed.
//
// The caller must call the returned CancelFunc (typically via defer) so the
// context's resources are released as soon as the work finishes. Because the
// parent is always Background, cancellation of any caller-owned context is not
// propagated. Prefer context.WithTimeout(parent, timeout) when a parent
// context is available.
func CreateContextWithTimeout(timeout time.Duration) (context.Context, context.CancelFunc) {
	root := context.Background()
	ctx, cancel := context.WithTimeout(root, timeout)
	return ctx, cancel
}
| |
| // ValidateChatCompletionRequest validates a chat completion request |
| func ValidateChatCompletionRequest(req ChatCompletionRequest) error { |
| if req.Model == "" { |
| return fmt.Errorf("model is required") |
| } |
| |
| if len(req.Messages) == 0 { |
| return fmt.Errorf("at least one message is required") |
| } |
| |
| // Validate temperature range |
| if req.Temperature != nil { |
| if *req.Temperature < 0 || *req.Temperature > 2 { |
| return fmt.Errorf("temperature must be between 0 and 2") |
| } |
| } |
| |
| // Validate top_p range |
| if req.TopP != nil { |
| if *req.TopP < 0 || *req.TopP > 1 { |
| return fmt.Errorf("top_p must be between 0 and 1") |
| } |
| } |
| |
| // Validate max_tokens |
| if req.MaxTokens != nil { |
| if *req.MaxTokens <= 0 { |
| return fmt.Errorf("max_tokens must be positive") |
| } |
| } |
| |
| return nil |
| } |
| |
| // ValidateEmbeddingRequest validates an embedding request |
| func ValidateEmbeddingRequest(req EmbeddingRequest) error { |
| if req.Model == "" { |
| return fmt.Errorf("model is required") |
| } |
| |
| if req.Input == nil { |
| return fmt.Errorf("input is required") |
| } |
| |
| // Validate input type |
| switch req.Input.(type) { |
| case string, []string, []int: |
| // Valid types |
| default: |
| return fmt.Errorf("input must be string, []string, or []int") |
| } |
| |
| return nil |
| } |
| |
| // ConvertMessagesToString converts messages to a string representation |
| func ConvertMessagesToString(messages []Message) string { |
| var result string |
| for i, msg := range messages { |
| if i > 0 { |
| result += "\n" |
| } |
| result += fmt.Sprintf("%s: %s", msg.Role, msg.Content) |
| } |
| return result |
| } |
| |
// CountTokens returns a rough estimate of how many tokens text occupies.
//
// It assumes about four bytes per token. len counts UTF-8 bytes, not runes.
// The result is never less than 1, even for an empty string. Use a real
// tokenizer when an accurate count matters.
func CountTokens(text string) int {
	const bytesPerToken = 4
	if estimate := len(text) / bytesPerToken; estimate > 1 {
		return estimate
	}
	return 1
}
| |
| // EstimateTokens estimates tokens for a chat completion request |
| func EstimateTokens(req ChatCompletionRequest) int { |
| total := 0 |
| |
| // Count tokens in messages |
| for _, msg := range req.Messages { |
| total += CountTokens(msg.Content) |
| // Add some overhead for role and formatting |
| total += 4 |
| } |
| |
| // Add some overhead for the request |
| total += 10 |
| |
| return total |
| } |
| |
| // ProviderDisplayName returns a human-readable name for a provider |
| func ProviderDisplayName(provider Provider) string { |
| switch provider { |
| case ProviderOpenAI: |
| return "OpenAI" |
| case ProviderXAI: |
| return "xAI" |
| case ProviderClaude: |
| return "Claude (Anthropic)" |
| case ProviderGemini: |
| return "Gemini (Google)" |
| case ProviderLocal: |
| return "Local" |
| default: |
| return string(provider) |
| } |
| } |
| |
| // IsToolCall checks if a message contains tool calls |
| func IsToolCall(msg Message) bool { |
| return len(msg.ToolCalls) > 0 |
| } |
| |
| // HasToolCalls checks if any message in the conversation has tool calls |
| func HasToolCalls(messages []Message) bool { |
| for _, msg := range messages { |
| if IsToolCall(msg) { |
| return true |
| } |
| } |
| return false |
| } |