package llm
2
import (
	"context"
	"errors"
	"fmt"
	"strings"
	"time"
)
8
// Common sentinel errors returned by this package.
// Callers should compare with errors.Is rather than ==, so wrapped
// errors still match.
var (
	ErrInvalidConfig       = errors.New("invalid configuration")
	ErrUnsupportedProvider = errors.New("unsupported provider")
	ErrAPIKeyRequired      = errors.New("API key is required")
	ErrModelNotFound       = errors.New("model not found")
	ErrRateLimitExceeded   = errors.New("rate limit exceeded")
	ErrContextCancelled    = errors.New("context cancelled")
	ErrTimeout             = errors.New("request timeout")
)
19
20// ValidateConfig validates a configuration for an LLM provider
21func ValidateConfig(config Config) error {
22 if config.APIKey == "" {
23 return ErrAPIKeyRequired
24 }
25
26 if config.BaseURL == "" {
27 return ErrInvalidConfig
28 }
29
30 if config.Timeout <= 0 {
31 config.Timeout = 30 * time.Second
32 }
33
34 if config.MaxRetries < 0 {
35 config.MaxRetries = 3
36 }
37
38 return nil
39}
40
41// IsValidProvider checks if a provider is supported
42func IsValidProvider(provider Provider) bool {
43 switch provider {
44 case ProviderOpenAI, ProviderXAI, ProviderClaude, ProviderGemini, ProviderLocal:
45 return true
46 default:
47 return false
48 }
49}
50
51// GetDefaultConfig returns the default configuration for a provider
52func GetDefaultConfig(provider Provider) (Config, error) {
53 if !IsValidProvider(provider) {
54 return Config{}, ErrUnsupportedProvider
55 }
56
57 config, exists := DefaultConfigs[provider]
58 if !exists {
59 return Config{}, fmt.Errorf("no default config for provider: %s", provider)
60 }
61
62 return config, nil
63}
64
65// MergeConfig merges a custom config with default config
66func MergeConfig(custom Config) Config {
67 defaultConfig, err := GetDefaultConfig(custom.Provider)
68 if err != nil {
69 // If no default config, return the custom config as-is
70 return custom
71 }
72
73 // Merge custom config with defaults
74 if custom.BaseURL == "" {
75 custom.BaseURL = defaultConfig.BaseURL
76 }
77 if custom.Timeout == 0 {
78 custom.Timeout = defaultConfig.Timeout
79 }
80 if custom.MaxRetries == 0 {
81 custom.MaxRetries = defaultConfig.MaxRetries
82 }
83
84 return custom
85}
86
87// CreateContextWithTimeout creates a context with timeout
88func CreateContextWithTimeout(timeout time.Duration) (context.Context, context.CancelFunc) {
89 return context.WithTimeout(context.Background(), timeout)
90}
91
92// ValidateChatCompletionRequest validates a chat completion request
93func ValidateChatCompletionRequest(req ChatCompletionRequest) error {
94 if req.Model == "" {
95 return fmt.Errorf("model is required")
96 }
97
98 if len(req.Messages) == 0 {
99 return fmt.Errorf("at least one message is required")
100 }
101
102 // Validate temperature range
103 if req.Temperature != nil {
104 if *req.Temperature < 0 || *req.Temperature > 2 {
105 return fmt.Errorf("temperature must be between 0 and 2")
106 }
107 }
108
109 // Validate top_p range
110 if req.TopP != nil {
111 if *req.TopP < 0 || *req.TopP > 1 {
112 return fmt.Errorf("top_p must be between 0 and 1")
113 }
114 }
115
116 // Validate max_tokens
117 if req.MaxTokens != nil {
118 if *req.MaxTokens <= 0 {
119 return fmt.Errorf("max_tokens must be positive")
120 }
121 }
122
123 return nil
124}
125
126// ValidateEmbeddingRequest validates an embedding request
127func ValidateEmbeddingRequest(req EmbeddingRequest) error {
128 if req.Model == "" {
129 return fmt.Errorf("model is required")
130 }
131
132 if req.Input == nil {
133 return fmt.Errorf("input is required")
134 }
135
136 // Validate input type
137 switch req.Input.(type) {
138 case string, []string, []int:
139 // Valid types
140 default:
141 return fmt.Errorf("input must be string, []string, or []int")
142 }
143
144 return nil
145}
146
147// ConvertMessagesToString converts messages to a string representation
148func ConvertMessagesToString(messages []Message) string {
149 var result string
150 for i, msg := range messages {
151 if i > 0 {
152 result += "\n"
153 }
154 result += fmt.Sprintf("%s: %s", msg.Role, msg.Content)
155 }
156 return result
157}
158
// CountTokens estimates the number of tokens in text using the rough
// heuristic of one token per four bytes, with a minimum of one. For
// accurate counts a real tokenizer should be used instead.
func CountTokens(text string) int {
	estimate := len(text) / 4
	if estimate < 1 {
		return 1
	}
	return estimate
}
169
170// EstimateTokens estimates tokens for a chat completion request
171func EstimateTokens(req ChatCompletionRequest) int {
172 total := 0
173
174 // Count tokens in messages
175 for _, msg := range req.Messages {
176 total += CountTokens(msg.Content)
177 // Add some overhead for role and formatting
178 total += 4
179 }
180
181 // Add some overhead for the request
182 total += 10
183
184 return total
185}
186
187// ProviderDisplayName returns a human-readable name for a provider
188func ProviderDisplayName(provider Provider) string {
189 switch provider {
190 case ProviderOpenAI:
191 return "OpenAI"
192 case ProviderXAI:
193 return "xAI"
194 case ProviderClaude:
195 return "Claude (Anthropic)"
196 case ProviderGemini:
197 return "Gemini (Google)"
198 case ProviderLocal:
199 return "Local"
200 default:
201 return string(provider)
202 }
203}
204
205// IsToolCall checks if a message contains tool calls
206func IsToolCall(msg Message) bool {
207 return len(msg.ToolCalls) > 0
208}
209
210// HasToolCalls checks if any message in the conversation has tool calls
211func HasToolCalls(messages []Message) bool {
212 for _, msg := range messages {
213 if IsToolCall(msg) {
214 return true
215 }
216 }
217 return false
218}