Add LLM interface
Change-Id: Idf599500fc131fb9509102e38736a6baeff6d6d8
diff --git a/server/llm/utils.go b/server/llm/utils.go
new file mode 100644
index 0000000..59064b3
--- /dev/null
+++ b/server/llm/utils.go
@@ -0,0 +1,218 @@
+package llm
+
+import (
+ "context"
+ "fmt"
+ "time"
+)
+
+// Sentinel errors for the llm package. These are error values, not distinct
+// error types; compare against them with errors.Is.
+var (
+	// ErrInvalidConfig reports a configuration that failed validation.
+	ErrInvalidConfig = fmt.Errorf("invalid configuration")
+
+	// ErrUnsupportedProvider reports a Provider not accepted by IsValidProvider.
+	ErrUnsupportedProvider = fmt.Errorf("unsupported provider")
+
+	// ErrAPIKeyRequired reports a configuration with an empty APIKey.
+	ErrAPIKeyRequired = fmt.Errorf("API key is required")
+
+	// ErrModelNotFound reports that a requested model does not exist.
+	ErrModelNotFound = fmt.Errorf("model not found")
+
+	// ErrRateLimitExceeded reports that a provider rate limit was hit.
+	ErrRateLimitExceeded = fmt.Errorf("rate limit exceeded")
+
+	// ErrContextCancelled reports that a request's context was cancelled.
+	ErrContextCancelled = fmt.Errorf("context cancelled")
+
+	// ErrTimeout reports that a request did not finish in time.
+	ErrTimeout = fmt.Errorf("request timeout")
+)
+
+// ValidateConfig checks the fields of config that every provider request
+// needs. It returns ErrAPIKeyRequired when APIKey is empty, ErrInvalidConfig
+// when BaseURL is empty, and nil otherwise.
+//
+// config is received by value, so ValidateConfig cannot fill in defaults for
+// the caller; it only inspects. Timeout and MaxRetries are not checked. Use
+// MergeConfig to fill unset fields from the provider's defaults before
+// validating.
+//
+// NOTE(review): the APIKey check applies to every provider, including
+// ProviderLocal — confirm local backends always supply a key.
+func ValidateConfig(config Config) error {
+	if config.APIKey == "" {
+		return ErrAPIKeyRequired
+	}
+	if config.BaseURL == "" {
+		return ErrInvalidConfig
+	}
+	return nil
+}
+
+// IsValidProvider reports whether provider is one of the providers this
+// package recognizes: OpenAI, xAI, Claude, Gemini, or Local.
+func IsValidProvider(provider Provider) bool {
+	return provider == ProviderOpenAI ||
+		provider == ProviderXAI ||
+		provider == ProviderClaude ||
+		provider == ProviderGemini ||
+		provider == ProviderLocal
+}
+
+// GetDefaultConfig looks up the built-in configuration registered for
+// provider in DefaultConfigs.
+//
+// It returns ErrUnsupportedProvider when IsValidProvider rejects provider. It
+// returns a descriptive error when the provider is recognized but has no
+// DefaultConfigs entry. On success the Config is returned by value, as a
+// shallow copy of the map entry.
+func GetDefaultConfig(provider Provider) (Config, error) {
+	if !IsValidProvider(provider) {
+		return Config{}, ErrUnsupportedProvider
+	}
+	if cfg, ok := DefaultConfigs[provider]; ok {
+		return cfg, nil
+	}
+	return Config{}, fmt.Errorf("no default config for provider: %s", provider)
+}
+
+// MergeConfig fills the unset fields of custom from the default configuration
+// of custom.Provider and returns the result; custom itself is a copy and is
+// not modified for the caller.
+//
+// The following fields fall back to the provider defaults:
+//   - BaseURL, when it is empty.
+//   - Timeout, when it is not positive.
+//   - MaxRetries, when it is not positive.
+//
+// Because 0 means "unset", an explicit MaxRetries of 0 cannot be expressed and
+// is replaced by the default. All other fields, including APIKey, are kept as
+// given. If GetDefaultConfig fails for custom.Provider, custom is returned
+// unchanged.
+func MergeConfig(custom Config) Config {
+	defaultConfig, err := GetDefaultConfig(custom.Provider)
+	if err != nil {
+		// Unknown provider or no registered defaults: nothing to merge.
+		return custom
+	}
+
+	if custom.BaseURL == "" {
+		custom.BaseURL = defaultConfig.BaseURL
+	}
+	// A negative timeout or retry count is no more meaningful than zero, so
+	// treat it as unset too instead of propagating a nonsensical value.
+	if custom.Timeout <= 0 {
+		custom.Timeout = defaultConfig.Timeout
+	}
+	if custom.MaxRetries <= 0 {
+		custom.MaxRetries = defaultConfig.MaxRetries
+	}
+
+	return custom
+}
+
+// CreateContextWithTimeout returns a context whose deadline is timeout from
+// now, together with its CancelFunc. Call the CancelFunc (typically with
+// defer) to release the context's timer as soon as the work is done. A
+// non-positive timeout yields a context that is already expired.
+//
+// The context is derived from context.Background, not from any caller
+// context, so cancelling an incoming request does not propagate to it. When a
+// ctx is available, prefer context.WithTimeout(ctx, timeout).
+func CreateContextWithTimeout(timeout time.Duration) (context.Context, context.CancelFunc) {
+	parent := context.Background()
+	ctx, cancel := context.WithTimeout(parent, timeout)
+	return ctx, cancel
+}
+
+// ValidateChatCompletionRequest performs client-side sanity checks on req
+// before it is sent to a provider.
+//
+// It requires a non-empty Model and at least one message. When the optional
+// fields are set, it also requires Temperature to be within [0, 2], TopP to be
+// within [0, 1], and MaxTokens to be positive. It returns nil if every check
+// passes, otherwise an error describing the first failure.
+func ValidateChatCompletionRequest(req ChatCompletionRequest) error {
+	if req.Model == "" {
+		return fmt.Errorf("model is required")
+	}
+	if len(req.Messages) == 0 {
+		return fmt.Errorf("at least one message is required")
+	}
+
+	// The range checks are written as !(lo <= v && v <= hi) rather than
+	// v < lo || v > hi. For floating-point fields this rejects NaN, which
+	// compares false against every bound and would otherwise slip through.
+	if t := req.Temperature; t != nil && !(*t >= 0 && *t <= 2) {
+		return fmt.Errorf("temperature must be between 0 and 2")
+	}
+	if p := req.TopP; p != nil && !(*p >= 0 && *p <= 1) {
+		return fmt.Errorf("top_p must be between 0 and 1")
+	}
+	if m := req.MaxTokens; m != nil && *m <= 0 {
+		return fmt.Errorf("max_tokens must be positive")
+	}
+
+	return nil
+}
+
+// ValidateEmbeddingRequest performs client-side sanity checks on req.
+//
+// Model must be non-empty. Input must be one of:
+//   - a non-empty string;
+//   - a non-empty []string with no empty elements;
+//   - a non-empty []int.
+//
+// Empty inputs carry nothing to embed, and the OpenAI embeddings API, for
+// example, rejects them. Catching them here gives a clearer error. It returns
+// nil if every check passes, otherwise an error describing the first failure.
+func ValidateEmbeddingRequest(req EmbeddingRequest) error {
+	if req.Model == "" {
+		return fmt.Errorf("model is required")
+	}
+	if req.Input == nil {
+		return fmt.Errorf("input is required")
+	}
+
+	switch in := req.Input.(type) {
+	case string:
+		if in == "" {
+			return fmt.Errorf("input must not be empty")
+		}
+	case []string:
+		// A typed nil or empty slice passes the nil-interface check above, so
+		// emptiness has to be tested on the concrete value.
+		if len(in) == 0 {
+			return fmt.Errorf("input must not be empty")
+		}
+		for i, s := range in {
+			if s == "" {
+				return fmt.Errorf("input[%d] must not be empty", i)
+			}
+		}
+	case []int:
+		if len(in) == 0 {
+			return fmt.Errorf("input must not be empty")
+		}
+	default:
+		return fmt.Errorf("input must be string, []string, or []int")
+	}
+
+	return nil
+}
+
+// ConvertMessagesToString renders messages as "role: content" entries
+// separated by "\n", with no trailing newline. An empty slice yields "".
+func ConvertMessagesToString(messages []Message) string {
+	// Append into one growing byte slice. Repeated string += re-copies the
+	// whole accumulated string on every iteration, which is O(n²) in the
+	// total output size.
+	var buf []byte
+	for i := range messages {
+		if i > 0 {
+			buf = append(buf, '\n')
+		}
+		buf = append(buf, fmt.Sprintf("%s: %s", messages[i].Role, messages[i].Content)...)
+	}
+	return string(buf)
+}
+
+// CountTokens estimates the number of tokens in a text (rough approximation)
+func CountTokens(text string) int {
+ // This is a very rough approximation
+ // In practice, you'd want to use a proper tokenizer
+ words := len(text) / 4 // Rough estimate: 1 token ≈ 4 characters
+ if words < 1 {
+ words = 1
+ }
+ return words
+}
+
+// EstimateTokens estimates tokens for a chat completion request
+func EstimateTokens(req ChatCompletionRequest) int {
+ total := 0
+
+ // Count tokens in messages
+ for _, msg := range req.Messages {
+ total += CountTokens(msg.Content)
+ // Add some overhead for role and formatting
+ total += 4
+ }
+
+ // Add some overhead for the request
+ total += 10
+
+ return total
+}
+
+// ProviderDisplayName returns a human-readable name for provider. Providers
+// without a dedicated label fall back to their underlying string value.
+func ProviderDisplayName(provider Provider) string {
+	name := string(provider)
+	switch provider {
+	case ProviderOpenAI:
+		name = "OpenAI"
+	case ProviderXAI:
+		name = "xAI"
+	case ProviderClaude:
+		name = "Claude (Anthropic)"
+	case ProviderGemini:
+		name = "Gemini (Google)"
+	case ProviderLocal:
+		name = "Local"
+	}
+	return name
+}
+
+// IsToolCall reports whether msg carries at least one entry in ToolCalls.
+func IsToolCall(msg Message) bool {
+	return len(msg.ToolCalls) != 0
+}
+
+// HasToolCalls reports whether any message in messages satisfies IsToolCall.
+// It stops at the first match and returns false for an empty slice.
+func HasToolCalls(messages []Message) bool {
+	for i := range messages {
+		if IsToolCall(messages[i]) {
+			return true
+		}
+	}
+	return false
+}