blob: 59064b3d91a479c70fa3a7ecb0a44de4cf66ed0f [file] [log] [blame]
package llm
import (
	"context"
	"errors"
	"fmt"
	"strings"
	"time"
)
// Common error types
// Common sentinel errors returned by this package. Compare with errors.Is.
var (
	ErrInvalidConfig       = errors.New("invalid configuration")
	ErrUnsupportedProvider = errors.New("unsupported provider")
	ErrAPIKeyRequired      = errors.New("API key is required")
	ErrModelNotFound       = errors.New("model not found")
	ErrRateLimitExceeded   = errors.New("rate limit exceeded")
	ErrContextCancelled    = errors.New("context cancelled")
	ErrTimeout             = errors.New("request timeout")
)
// ValidateConfig validates a configuration for an LLM provider
// ValidateConfig validates a configuration for an LLM provider.
// It returns ErrAPIKeyRequired when the API key is missing and
// ErrInvalidConfig when the base URL is empty; otherwise it returns nil.
//
// Note: config is received by value, so this function cannot apply
// defaults for the caller. The original assignments to config.Timeout
// and config.MaxRetries mutated only the local copy and were dead code;
// defaulting is handled by MergeConfig instead.
func ValidateConfig(config Config) error {
	if config.APIKey == "" {
		return ErrAPIKeyRequired
	}
	if config.BaseURL == "" {
		return ErrInvalidConfig
	}
	return nil
}
// IsValidProvider checks if a provider is supported
// IsValidProvider reports whether provider is one of the providers
// supported by this package.
func IsValidProvider(provider Provider) bool {
	switch provider {
	case ProviderOpenAI,
		ProviderXAI,
		ProviderClaude,
		ProviderGemini,
		ProviderLocal:
		return true
	}
	return false
}
// GetDefaultConfig returns the default configuration for a provider
// GetDefaultConfig returns the default configuration for a provider.
// It returns ErrUnsupportedProvider for unknown providers, and a
// descriptive error when a supported provider has no entry in
// DefaultConfigs.
func GetDefaultConfig(provider Provider) (Config, error) {
	if !IsValidProvider(provider) {
		return Config{}, ErrUnsupportedProvider
	}
	if cfg, ok := DefaultConfigs[provider]; ok {
		return cfg, nil
	}
	return Config{}, fmt.Errorf("no default config for provider: %s", provider)
}
// MergeConfig merges a custom config with default config
// MergeConfig fills zero-valued fields of custom (BaseURL, Timeout,
// MaxRetries) from the provider's default configuration. When the
// provider has no default config, custom is returned unchanged.
func MergeConfig(custom Config) Config {
	defaults, err := GetDefaultConfig(custom.Provider)
	if err != nil {
		// No defaults known for this provider; nothing to merge.
		return custom
	}
	merged := custom
	if merged.BaseURL == "" {
		merged.BaseURL = defaults.BaseURL
	}
	if merged.Timeout == 0 {
		merged.Timeout = defaults.Timeout
	}
	if merged.MaxRetries == 0 {
		merged.MaxRetries = defaults.MaxRetries
	}
	return merged
}
// CreateContextWithTimeout creates a context with timeout
func CreateContextWithTimeout(timeout time.Duration) (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), timeout)
}
// ValidateChatCompletionRequest validates a chat completion request
// ValidateChatCompletionRequest validates a chat completion request:
// a model and at least one message are required, temperature must lie
// in [0, 2], top_p in [0, 1], and max_tokens must be positive. Pointer
// fields that are nil are treated as unset and skipped.
func ValidateChatCompletionRequest(req ChatCompletionRequest) error {
	if req.Model == "" {
		return fmt.Errorf("model is required")
	}
	if len(req.Messages) == 0 {
		return fmt.Errorf("at least one message is required")
	}
	if t := req.Temperature; t != nil && (*t < 0 || *t > 2) {
		return fmt.Errorf("temperature must be between 0 and 2")
	}
	if p := req.TopP; p != nil && (*p < 0 || *p > 1) {
		return fmt.Errorf("top_p must be between 0 and 1")
	}
	if m := req.MaxTokens; m != nil && *m <= 0 {
		return fmt.Errorf("max_tokens must be positive")
	}
	return nil
}
// ValidateEmbeddingRequest validates an embedding request
// ValidateEmbeddingRequest validates an embedding request: a model is
// required and the input must be a non-nil string, []string, or []int.
func ValidateEmbeddingRequest(req EmbeddingRequest) error {
	if req.Model == "" {
		return fmt.Errorf("model is required")
	}
	if req.Input == nil {
		return fmt.Errorf("input is required")
	}
	switch req.Input.(type) {
	case string, []string, []int:
		return nil
	default:
		return fmt.Errorf("input must be string, []string, or []int")
	}
}
// ConvertMessagesToString converts messages to a string representation
// ConvertMessagesToString renders messages as newline-separated
// "role: content" lines.
//
// Uses strings.Builder instead of repeated string concatenation, which
// is quadratic in the number of messages.
func ConvertMessagesToString(messages []Message) string {
	var b strings.Builder
	for i, msg := range messages {
		if i > 0 {
			b.WriteByte('\n')
		}
		// Fprintf keeps the exact "%s: %s" formatting of the original.
		fmt.Fprintf(&b, "%s: %s", msg.Role, msg.Content)
	}
	return b.String()
}
// CountTokens estimates the number of tokens in a text (rough approximation)
// CountTokens estimates the number of tokens in text using the rough
// heuristic of one token per four bytes. The result is always at least
// one; a proper tokenizer should be used where accuracy matters.
func CountTokens(text string) int {
	tokens := len(text) / 4
	if tokens < 1 {
		return 1
	}
	return tokens
}
// EstimateTokens estimates tokens for a chat completion request
// EstimateTokens estimates the token footprint of a chat completion
// request: the content tokens of each message plus fixed per-message
// and per-request overheads.
func EstimateTokens(req ChatCompletionRequest) int {
	const (
		perMessageOverhead = 4  // role and formatting tokens per message
		requestOverhead    = 10 // framing overhead for the request itself
	)
	estimate := requestOverhead
	for _, msg := range req.Messages {
		estimate += CountTokens(msg.Content) + perMessageOverhead
	}
	return estimate
}
// ProviderDisplayName returns a human-readable name for a provider
// ProviderDisplayName returns a human-readable name for a provider.
// Unknown providers are rendered as their raw string value.
func ProviderDisplayName(provider Provider) string {
	displayNames := map[Provider]string{
		ProviderOpenAI: "OpenAI",
		ProviderXAI:    "xAI",
		ProviderClaude: "Claude (Anthropic)",
		ProviderGemini: "Gemini (Google)",
		ProviderLocal:  "Local",
	}
	if name, ok := displayNames[provider]; ok {
		return name
	}
	return string(provider)
}
// IsToolCall checks if a message contains tool calls
// IsToolCall reports whether msg carries at least one tool call.
func IsToolCall(msg Message) bool {
	hasCalls := len(msg.ToolCalls) != 0
	return hasCalls
}
// HasToolCalls checks if any message in the conversation has tool calls
// HasToolCalls reports whether any message in the conversation carries
// tool calls.
func HasToolCalls(messages []Message) bool {
	for i := range messages {
		if IsToolCall(messages[i]) {
			return true
		}
	}
	return false
}