blob: 73813641f2028a9b7e24557a84b5d25c499495a0 [file] [log] [blame]
iomodoa97eb222025-07-26 11:18:17 +04001package llm
2
import (
	"context"
	"errors"
	"fmt"
	"strings"
	"time"
)
8
// Common sentinel errors returned by this package.
// Callers should compare with errors.Is rather than ==.
var (
	ErrInvalidConfig       = errors.New("invalid configuration")
	ErrUnsupportedProvider = errors.New("unsupported provider")
	ErrAPIKeyRequired      = errors.New("API key is required")
	ErrModelNotFound       = errors.New("model not found")
	ErrRateLimitExceeded   = errors.New("rate limit exceeded")
	ErrContextCancelled    = errors.New("context cancelled")
	ErrTimeout             = errors.New("request timeout")
)
19
20// ValidateConfig validates a configuration for an LLM provider
21func ValidateConfig(config Config) error {
iomodof1ddefe2025-07-28 09:02:05 +040022 // Fake provider doesn't need API key
23 if config.Provider != ProviderFake && config.APIKey == "" {
iomodoa97eb222025-07-26 11:18:17 +040024 return ErrAPIKeyRequired
25 }
26
27 if config.BaseURL == "" {
28 return ErrInvalidConfig
29 }
30
31 if config.Timeout <= 0 {
32 config.Timeout = 30 * time.Second
33 }
34
35 if config.MaxRetries < 0 {
36 config.MaxRetries = 3
37 }
38
39 return nil
40}
41
42// IsValidProvider checks if a provider is supported
43func IsValidProvider(provider Provider) bool {
44 switch provider {
iomodof1ddefe2025-07-28 09:02:05 +040045 case ProviderOpenAI, ProviderXAI, ProviderClaude, ProviderGemini, ProviderLocal, ProviderFake:
iomodoa97eb222025-07-26 11:18:17 +040046 return true
47 default:
48 return false
49 }
50}
51
52// GetDefaultConfig returns the default configuration for a provider
53func GetDefaultConfig(provider Provider) (Config, error) {
54 if !IsValidProvider(provider) {
55 return Config{}, ErrUnsupportedProvider
56 }
57
58 config, exists := DefaultConfigs[provider]
59 if !exists {
60 return Config{}, fmt.Errorf("no default config for provider: %s", provider)
61 }
62
63 return config, nil
64}
65
66// MergeConfig merges a custom config with default config
67func MergeConfig(custom Config) Config {
68 defaultConfig, err := GetDefaultConfig(custom.Provider)
69 if err != nil {
70 // If no default config, return the custom config as-is
71 return custom
72 }
73
74 // Merge custom config with defaults
75 if custom.BaseURL == "" {
76 custom.BaseURL = defaultConfig.BaseURL
77 }
78 if custom.Timeout == 0 {
79 custom.Timeout = defaultConfig.Timeout
80 }
81 if custom.MaxRetries == 0 {
82 custom.MaxRetries = defaultConfig.MaxRetries
83 }
84
85 return custom
86}
87
// CreateContextWithTimeout returns a child of context.Background that is
// automatically cancelled after timeout. The caller must invoke the
// returned CancelFunc to release resources.
func CreateContextWithTimeout(timeout time.Duration) (context.Context, context.CancelFunc) {
	parent := context.Background()
	return context.WithTimeout(parent, timeout)
}
92
93// ValidateChatCompletionRequest validates a chat completion request
94func ValidateChatCompletionRequest(req ChatCompletionRequest) error {
95 if req.Model == "" {
96 return fmt.Errorf("model is required")
97 }
98
99 if len(req.Messages) == 0 {
100 return fmt.Errorf("at least one message is required")
101 }
102
103 // Validate temperature range
104 if req.Temperature != nil {
105 if *req.Temperature < 0 || *req.Temperature > 2 {
106 return fmt.Errorf("temperature must be between 0 and 2")
107 }
108 }
109
110 // Validate top_p range
111 if req.TopP != nil {
112 if *req.TopP < 0 || *req.TopP > 1 {
113 return fmt.Errorf("top_p must be between 0 and 1")
114 }
115 }
116
117 // Validate max_tokens
118 if req.MaxTokens != nil {
119 if *req.MaxTokens <= 0 {
120 return fmt.Errorf("max_tokens must be positive")
121 }
122 }
123
124 return nil
125}
126
127// ValidateEmbeddingRequest validates an embedding request
128func ValidateEmbeddingRequest(req EmbeddingRequest) error {
129 if req.Model == "" {
130 return fmt.Errorf("model is required")
131 }
132
133 if req.Input == nil {
134 return fmt.Errorf("input is required")
135 }
136
137 // Validate input type
138 switch req.Input.(type) {
139 case string, []string, []int:
140 // Valid types
141 default:
142 return fmt.Errorf("input must be string, []string, or []int")
143 }
144
145 return nil
146}
147
148// ConvertMessagesToString converts messages to a string representation
149func ConvertMessagesToString(messages []Message) string {
150 var result string
151 for i, msg := range messages {
152 if i > 0 {
153 result += "\n"
154 }
155 result += fmt.Sprintf("%s: %s", msg.Role, msg.Content)
156 }
157 return result
158}
159
// CountTokens estimates the number of tokens in text using the rough
// heuristic of one token per four characters. The result is always at
// least 1. For accurate counts a real tokenizer would be needed.
func CountTokens(text string) int {
	const charsPerToken = 4
	estimate := len(text) / charsPerToken
	if estimate < 1 {
		return 1
	}
	return estimate
}
170
171// EstimateTokens estimates tokens for a chat completion request
172func EstimateTokens(req ChatCompletionRequest) int {
173 total := 0
174
175 // Count tokens in messages
176 for _, msg := range req.Messages {
177 total += CountTokens(msg.Content)
178 // Add some overhead for role and formatting
179 total += 4
180 }
181
182 // Add some overhead for the request
183 total += 10
184
185 return total
186}
187
188// ProviderDisplayName returns a human-readable name for a provider
189func ProviderDisplayName(provider Provider) string {
190 switch provider {
191 case ProviderOpenAI:
192 return "OpenAI"
193 case ProviderXAI:
194 return "xAI"
195 case ProviderClaude:
196 return "Claude (Anthropic)"
197 case ProviderGemini:
198 return "Gemini (Google)"
199 case ProviderLocal:
200 return "Local"
iomodof1ddefe2025-07-28 09:02:05 +0400201 case ProviderFake:
202 return "Fake (Testing)"
iomodoa97eb222025-07-26 11:18:17 +0400203 default:
204 return string(provider)
205 }
206}
207
208// IsToolCall checks if a message contains tool calls
209func IsToolCall(msg Message) bool {
210 return len(msg.ToolCalls) > 0
211}
212
213// HasToolCalls checks if any message in the conversation has tool calls
214func HasToolCalls(messages []Message) bool {
215 for _, msg := range messages {
216 if IsToolCall(msg) {
217 return true
218 }
219 }
220 return false
221}