package openai
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "fmt"
8 "io"
9 "net/http"
10
11 "github.com/iomodo/staff/llm"
12)
13
// OpenAIProvider implements the LLMProvider interface for OpenAI.
// It translates llm requests into OpenAI's HTTP wire format and back.
type OpenAIProvider struct {
	config llm.Config   // API key, base URL, timeout, and extra settings
	client *http.Client // shared across requests; timeout taken from config
}
19
// OpenAIRequest represents the OpenAI API request format for
// /chat/completions. Pointer fields distinguish "unset" from a zero
// value; all omitempty fields are left out of the JSON body when unset.
type OpenAIRequest struct {
	Model            string                `json:"model"`
	Messages         []OpenAIMessage       `json:"messages"`
	MaxTokens        *int                  `json:"max_tokens,omitempty"`
	Temperature      *float64              `json:"temperature,omitempty"`
	TopP             *float64              `json:"top_p,omitempty"`
	N                *int                  `json:"n,omitempty"`
	Stream           *bool                 `json:"stream,omitempty"`
	Stop             []string              `json:"stop,omitempty"`
	PresencePenalty  *float64              `json:"presence_penalty,omitempty"`
	FrequencyPenalty *float64              `json:"frequency_penalty,omitempty"`
	LogitBias        map[string]int        `json:"logit_bias,omitempty"`
	User             string                `json:"user,omitempty"`
	Tools            []OpenAITool          `json:"tools,omitempty"`
	ToolChoice       interface{}           `json:"tool_choice,omitempty"` // string or object per OpenAI API; passed through opaquely
	ResponseFormat   *OpenAIResponseFormat `json:"response_format,omitempty"`
	Seed             *int64                `json:"seed,omitempty"`
}
39
// OpenAIMessage represents a message in OpenAI format. The same struct
// is used for request messages, response messages, and streaming deltas
// (see OpenAIChoice.Delta).
type OpenAIMessage struct {
	Role       string           `json:"role"`
	Content    string           `json:"content"`
	ToolCalls  []OpenAIToolCall `json:"tool_calls,omitempty"`
	ToolCallID string           `json:"tool_call_id,omitempty"` // links a tool-role message to the call it answers
	Name       string           `json:"name,omitempty"`
}
48
// OpenAIToolCall represents a tool call in OpenAI format, emitted by
// the model inside an assistant message.
type OpenAIToolCall struct {
	ID       string         `json:"id"`
	Type     string         `json:"type"`
	Function OpenAIFunction `json:"function"`
}
55
// OpenAIFunction represents a function in OpenAI format. It is reused
// both for declaring callable tools and for the function payload inside
// a tool call.
type OpenAIFunction struct {
	Name        string                 `json:"name"`
	Description string                 `json:"description,omitempty"`
	Parameters  map[string]interface{} `json:"parameters,omitempty"` // JSON-schema-like object, passed through opaquely
}
62
// OpenAITool represents a tool declaration in OpenAI request format.
type OpenAITool struct {
	Type     string         `json:"type"`
	Function OpenAIFunction `json:"function"`
}
68
// OpenAIResponseFormat represents the response_format request field in
// OpenAI format.
type OpenAIResponseFormat struct {
	Type string `json:"type"`
}
73
// OpenAIResponse represents the OpenAI chat-completion API response format.
type OpenAIResponse struct {
	ID                string         `json:"id"`
	Object            string         `json:"object"`
	Created           int64          `json:"created"` // Unix timestamp from the API
	Model             string         `json:"model"`
	SystemFingerprint string         `json:"system_fingerprint,omitempty"`
	Choices           []OpenAIChoice `json:"choices"`
	Usage             OpenAIUsage    `json:"usage"`
}
84
// OpenAIChoice represents a choice in an OpenAI response. Message is
// populated for non-streaming responses; Delta carries the incremental
// message fragment in streaming responses.
type OpenAIChoice struct {
	Index        int             `json:"index"`
	Message      OpenAIMessage   `json:"message"`
	Logprobs     *OpenAILogprobs `json:"logprobs,omitempty"`
	FinishReason string          `json:"finish_reason"`
	Delta        *OpenAIMessage  `json:"delta,omitempty"`
}
93
// OpenAILogprobs represents per-token log probabilities in OpenAI format.
type OpenAILogprobs struct {
	Content []OpenAILogprobContent `json:"content,omitempty"`
}
98
// OpenAILogprobContent represents the log probability of a single
// sampled token, plus the alternatives the model considered.
type OpenAILogprobContent struct {
	Token       string             `json:"token"`
	Logprob     float64            `json:"logprob"`
	Bytes       []int              `json:"bytes,omitempty"` // raw byte values of the token, as reported by the API
	TopLogprobs []OpenAITopLogprob `json:"top_logprobs,omitempty"`
}
106
// OpenAITopLogprob represents one alternative token and its log
// probability in OpenAI format.
type OpenAITopLogprob struct {
	Token   string  `json:"token"`
	Logprob float64 `json:"logprob"`
	Bytes   []int   `json:"bytes,omitempty"`
}
113
// OpenAIUsage represents token-usage accounting in OpenAI format,
// shared by the chat-completion and embedding responses.
type OpenAIUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
120
// OpenAIEmbeddingRequest represents an OpenAI /embeddings request.
type OpenAIEmbeddingRequest struct {
	Input          interface{} `json:"input"` // string or list per OpenAI API; passed through opaquely
	Model          string      `json:"model"`
	EncodingFormat string      `json:"encoding_format,omitempty"`
	Dimensions     *int        `json:"dimensions,omitempty"` // pointer so an unset value is omitted
	User           string      `json:"user,omitempty"`
}
129
// OpenAIEmbeddingResponse represents an OpenAI /embeddings response.
type OpenAIEmbeddingResponse struct {
	Object string                `json:"object"`
	Data   []OpenAIEmbeddingData `json:"data"`
	Usage  OpenAIUsage           `json:"usage"`
	Model  string                `json:"model"`
}
137
// OpenAIEmbeddingData represents one embedding vector in an OpenAI
// embeddings response; Index is its position in the request input.
type OpenAIEmbeddingData struct {
	Object    string    `json:"object"`
	Embedding []float64 `json:"embedding"`
	Index     int       `json:"index"`
}
144
// OpenAIError represents the error envelope returned by the OpenAI API
// on non-200 responses; see makeOpenAIRequest for how it is surfaced.
type OpenAIError struct {
	Error struct {
		Message string `json:"message"`
		Type    string `json:"type"`
		Code    string `json:"code,omitempty"`
		Param   string `json:"param,omitempty"`
	} `json:"error"`
}
154
155// NewOpenAIProvider creates a new OpenAI provider
156func NewOpenAIProvider(config llm.Config) *OpenAIProvider {
157 client := &http.Client{
158 Timeout: config.Timeout,
159 }
160
161 return &OpenAIProvider{
162 config: config,
163 client: client,
164 }
165}
166
167// ChatCompletion implements the LLMProvider interface for OpenAI
168func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req llm.ChatCompletionRequest) (*llm.ChatCompletionResponse, error) {
169 // Convert our request to OpenAI format
170 openAIReq := p.convertToOpenAIRequest(req)
171
172 // Make the API call
173 resp, err := p.makeOpenAIRequest(ctx, "/chat/completions", openAIReq)
174 if err != nil {
175 return nil, fmt.Errorf("OpenAI API request failed: %w", err)
176 }
177
178 // Parse the response
179 var openAIResp OpenAIResponse
180 if err := json.Unmarshal(resp, &openAIResp); err != nil {
181 return nil, fmt.Errorf("failed to parse OpenAI response: %w", err)
182 }
183
184 // Convert back to our format
185 return p.convertFromOpenAIResponse(openAIResp), nil
186}
187
188// CreateEmbeddings implements the LLMProvider interface for OpenAI
189func (p *OpenAIProvider) CreateEmbeddings(ctx context.Context, req llm.EmbeddingRequest) (*llm.EmbeddingResponse, error) {
190 // Convert our request to OpenAI format
191 openAIReq := OpenAIEmbeddingRequest{
192 Input: req.Input,
193 Model: req.Model,
194 EncodingFormat: req.EncodingFormat,
195 Dimensions: req.Dimensions,
196 User: req.User,
197 }
198
199 // Make the API call
200 resp, err := p.makeOpenAIRequest(ctx, "/embeddings", openAIReq)
201 if err != nil {
202 return nil, fmt.Errorf("OpenAI embeddings API request failed: %w", err)
203 }
204
205 // Parse the response
206 var openAIResp OpenAIEmbeddingResponse
207 if err := json.Unmarshal(resp, &openAIResp); err != nil {
208 return nil, fmt.Errorf("failed to parse OpenAI embeddings response: %w", err)
209 }
210
211 // Convert back to our format
212 return p.convertFromOpenAIEmbeddingResponse(openAIResp), nil
213}
214
215// Close implements the LLMProvider interface
216func (p *OpenAIProvider) Close() error {
217 // Nothing to clean up for HTTP client
218 return nil
219}
220
221// convertToOpenAIRequest converts our request format to OpenAI format
222func (p *OpenAIProvider) convertToOpenAIRequest(req llm.ChatCompletionRequest) OpenAIRequest {
223 openAIReq := OpenAIRequest{
224 Model: req.Model,
225 MaxTokens: req.MaxTokens,
226 Temperature: req.Temperature,
227 TopP: req.TopP,
228 N: req.N,
229 Stream: req.Stream,
230 Stop: req.Stop,
231 PresencePenalty: req.PresencePenalty,
232 FrequencyPenalty: req.FrequencyPenalty,
233 LogitBias: req.LogitBias,
234 User: req.User,
235 ToolChoice: req.ToolChoice,
236 Seed: req.Seed,
237 }
238
239 // Convert messages
240 openAIReq.Messages = make([]OpenAIMessage, len(req.Messages))
241 for i, msg := range req.Messages {
242 openAIReq.Messages[i] = OpenAIMessage{
243 Role: string(msg.Role),
244 Content: msg.Content,
245 ToolCallID: msg.ToolCallID,
246 Name: msg.Name,
247 }
248
249 // Convert tool calls if present
250 if len(msg.ToolCalls) > 0 {
251 openAIReq.Messages[i].ToolCalls = make([]OpenAIToolCall, len(msg.ToolCalls))
252 for j, toolCall := range msg.ToolCalls {
253 openAIReq.Messages[i].ToolCalls[j] = OpenAIToolCall{
254 ID: toolCall.ID,
255 Type: toolCall.Type,
256 Function: OpenAIFunction{
257 Name: toolCall.Function.Name,
258 Description: toolCall.Function.Description,
259 Parameters: toolCall.Function.Parameters,
260 },
261 }
262 }
263 }
264 }
265
266 // Convert tools if present
267 if len(req.Tools) > 0 {
268 openAIReq.Tools = make([]OpenAITool, len(req.Tools))
269 for i, tool := range req.Tools {
270 openAIReq.Tools[i] = OpenAITool{
271 Type: tool.Type,
272 Function: OpenAIFunction{
273 Name: tool.Function.Name,
274 Description: tool.Function.Description,
275 Parameters: tool.Function.Parameters,
276 },
277 }
278 }
279 }
280
281 // Convert response format if present
282 if req.ResponseFormat != nil {
283 openAIReq.ResponseFormat = &OpenAIResponseFormat{
284 Type: req.ResponseFormat.Type,
285 }
286 }
287
288 return openAIReq
289}
290
// convertFromOpenAIResponse converts an OpenAI chat-completion response
// into our provider-neutral llm format. Top-level metadata and usage are
// copied directly; each choice is mapped along with its optional tool
// calls and (nested) log-probability data. Provider is stamped so
// callers can tell which backend produced the response.
func (p *OpenAIProvider) convertFromOpenAIResponse(openAIResp OpenAIResponse) *llm.ChatCompletionResponse {
	resp := &llm.ChatCompletionResponse{
		ID:                openAIResp.ID,
		Object:            openAIResp.Object,
		Created:           openAIResp.Created,
		Model:             openAIResp.Model,
		SystemFingerprint: openAIResp.SystemFingerprint,
		Provider:          llm.ProviderOpenAI,
		Usage: llm.Usage{
			PromptTokens:     openAIResp.Usage.PromptTokens,
			CompletionTokens: openAIResp.Usage.CompletionTokens,
			TotalTokens:      openAIResp.Usage.TotalTokens,
		},
	}

	// Convert choices
	resp.Choices = make([]llm.ChatCompletionChoice, len(openAIResp.Choices))
	for i, choice := range openAIResp.Choices {
		resp.Choices[i] = llm.ChatCompletionChoice{
			Index:        choice.Index,
			FinishReason: choice.FinishReason,
			Message: llm.Message{
				Role:    llm.Role(choice.Message.Role),
				Content: choice.Message.Content,
				Name:    choice.Message.Name,
			},
		}

		// Convert tool calls if present (ToolCalls stays nil otherwise)
		if len(choice.Message.ToolCalls) > 0 {
			resp.Choices[i].Message.ToolCalls = make([]llm.ToolCall, len(choice.Message.ToolCalls))
			for j, toolCall := range choice.Message.ToolCalls {
				resp.Choices[i].Message.ToolCalls[j] = llm.ToolCall{
					ID:   toolCall.ID,
					Type: toolCall.Type,
					Function: llm.Function{
						Name:        toolCall.Function.Name,
						Description: toolCall.Function.Description,
						Parameters:  toolCall.Function.Parameters,
					},
				}
			}
		}

		// Convert logprobs if present: one LogprobContent per sampled
		// token, each optionally carrying its top alternatives.
		if choice.Logprobs != nil {
			resp.Choices[i].Logprobs = &llm.Logprobs{
				Content: make([]llm.LogprobContent, len(choice.Logprobs.Content)),
			}
			for j, content := range choice.Logprobs.Content {
				resp.Choices[i].Logprobs.Content[j] = llm.LogprobContent{
					Token:   content.Token,
					Logprob: content.Logprob,
					Bytes:   content.Bytes,
				}
				if len(content.TopLogprobs) > 0 {
					resp.Choices[i].Logprobs.Content[j].TopLogprobs = make([]llm.TopLogprob, len(content.TopLogprobs))
					for k, topLogprob := range content.TopLogprobs {
						resp.Choices[i].Logprobs.Content[j].TopLogprobs[k] = llm.TopLogprob{
							Token:   topLogprob.Token,
							Logprob: topLogprob.Logprob,
							Bytes:   topLogprob.Bytes,
						}
					}
				}
			}
		}
	}

	return resp
}
363
364// convertFromOpenAIEmbeddingResponse converts OpenAI embedding response to our format
365func (p *OpenAIProvider) convertFromOpenAIEmbeddingResponse(openAIResp OpenAIEmbeddingResponse) *llm.EmbeddingResponse {
366 resp := &llm.EmbeddingResponse{
367 Object: openAIResp.Object,
368 Model: openAIResp.Model,
369 Provider: llm.ProviderOpenAI,
370 Usage: llm.Usage{
371 PromptTokens: openAIResp.Usage.PromptTokens,
372 CompletionTokens: openAIResp.Usage.CompletionTokens,
373 TotalTokens: openAIResp.Usage.TotalTokens,
374 },
375 }
376
377 // Convert embedding data
378 resp.Data = make([]llm.Embedding, len(openAIResp.Data))
379 for i, data := range openAIResp.Data {
380 resp.Data[i] = llm.Embedding{
381 Object: data.Object,
382 Embedding: data.Embedding,
383 Index: data.Index,
384 }
385 }
386
387 return resp
388}
389
390// makeOpenAIRequest makes an HTTP request to the OpenAI API
391func (p *OpenAIProvider) makeOpenAIRequest(ctx context.Context, endpoint string, payload interface{}) ([]byte, error) {
392 // Prepare request body
393 jsonData, err := json.Marshal(payload)
394 if err != nil {
395 return nil, fmt.Errorf("failed to marshal request: %w", err)
396 }
397
398 // Create HTTP request
399 url := p.config.BaseURL + endpoint
400 req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
401 if err != nil {
402 return nil, fmt.Errorf("failed to create request: %w", err)
403 }
404
405 // Set headers
406 req.Header.Set("Content-Type", "application/json")
407 req.Header.Set("Authorization", "Bearer "+p.config.APIKey)
408
409 // Add organization header if present
410 if org, ok := p.config.ExtraConfig["organization"].(string); ok && org != "" {
411 req.Header.Set("OpenAI-Organization", org)
412 }
413
414 // Make the request
415 resp, err := p.client.Do(req)
416 if err != nil {
417 return nil, fmt.Errorf("HTTP request failed: %w", err)
418 }
419 defer resp.Body.Close()
420
421 // Read response body
422 body, err := io.ReadAll(resp.Body)
423 if err != nil {
424 return nil, fmt.Errorf("failed to read response body: %w", err)
425 }
426
427 // Check for errors
428 if resp.StatusCode != http.StatusOK {
429 var openAIErr OpenAIError
430 if err := json.Unmarshal(body, &openAIErr); err != nil {
431 return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
432 }
433 return nil, fmt.Errorf("OpenAI API error: %s (type: %s, code: %s)",
434 openAIErr.Error.Message, openAIErr.Error.Type, openAIErr.Error.Code)
435 }
436
437 return body, nil
438}
439
// OpenAIFactory implements ProviderFactory for OpenAI. It is stateless;
// a single instance is registered with the default registry in init.
type OpenAIFactory struct{}
442
443// CreateProvider creates a new OpenAI provider
444func (f *OpenAIFactory) CreateProvider(config llm.Config) (llm.LLMProvider, error) {
445 if config.Provider != llm.ProviderOpenAI {
446 return nil, fmt.Errorf("OpenAI factory cannot create provider: %s", config.Provider)
447 }
448
449 // Validate config
450 if err := llm.ValidateConfig(config); err != nil {
451 return nil, fmt.Errorf("invalid OpenAI config: %w", err)
452 }
453
454 // Merge with defaults
455 config = llm.MergeConfig(config)
456
457 return NewOpenAIProvider(config), nil
458}
459
460// SupportsProvider checks if this factory supports the given provider
461func (f *OpenAIFactory) SupportsProvider(provider llm.Provider) bool {
462 return provider == llm.ProviderOpenAI
463}
464
// Register the OpenAI provider factory with the default registry at
// package load time, so importing this package is enough to make
// llm.ProviderOpenAI available.
func init() {
	llm.RegisterProvider(llm.ProviderOpenAI, &OpenAIFactory{})
}