blob: fad11f68c365b6750bae45268110087104cd29ed [file] [log] [blame]
iomodobe473d12025-07-26 11:33:08 +04001package openai
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "fmt"
8 "io"
9 "net/http"
10
11 "github.com/iomodo/staff/llm"
12)
13
// OpenAIProvider implements the LLMProvider interface for OpenAI.
// It is a thin HTTP client around the OpenAI REST API: requests are
// translated to the wire format below, POSTed, and mapped back.
type OpenAIProvider struct {
	config llm.Config   // API key, base URL, timeout, and extra settings
	client *http.Client // reused for all requests; timeout taken from config
}
19
// OpenAIRequest represents the OpenAI API request format for
// POST /chat/completions. Pointer fields and omitempty tags keep unset
// options out of the JSON body so the API's own defaults apply.
type OpenAIRequest struct {
	Model            string                `json:"model"`
	Messages         []OpenAIMessage       `json:"messages"`
	MaxTokens        *int                  `json:"max_tokens,omitempty"`
	Temperature      *float64              `json:"temperature,omitempty"`
	TopP             *float64              `json:"top_p,omitempty"`
	N                *int                  `json:"n,omitempty"`
	Stream           *bool                 `json:"stream,omitempty"`
	Stop             []string              `json:"stop,omitempty"`
	PresencePenalty  *float64              `json:"presence_penalty,omitempty"`
	FrequencyPenalty *float64              `json:"frequency_penalty,omitempty"`
	LogitBias        map[string]int        `json:"logit_bias,omitempty"`
	User             string                `json:"user,omitempty"`
	Tools            []OpenAITool          `json:"tools,omitempty"`
	ToolChoice       interface{}           `json:"tool_choice,omitempty"` // string or object, passed through as-is
	ResponseFormat   *OpenAIResponseFormat `json:"response_format,omitempty"`
	Seed             *int64                `json:"seed,omitempty"`
}
39
// OpenAIMessage represents a message in OpenAI format. The same shape is
// used for request messages, response choices, and streaming deltas.
type OpenAIMessage struct {
	Role       string           `json:"role"`
	Content    string           `json:"content"`
	ToolCalls  []OpenAIToolCall `json:"tool_calls,omitempty"`
	ToolCallID string           `json:"tool_call_id,omitempty"` // links a tool-result message to its call
	Name       string           `json:"name,omitempty"`
}
48
// OpenAIToolCall represents a tool call in OpenAI format.
type OpenAIToolCall struct {
	ID       string         `json:"id"`
	Type     string         `json:"type"` // currently always "function" per the API
	Function OpenAIFunction `json:"function"`
}
55
// OpenAIFunction represents a function in OpenAI format. It doubles as a
// function *definition* (inside OpenAITool, Parameters = JSON schema) and
// a function *invocation* (inside OpenAIToolCall).
type OpenAIFunction struct {
	Name        string                 `json:"name"`
	Description string                 `json:"description,omitempty"`
	Parameters  map[string]interface{} `json:"parameters,omitempty"`
}
62
// OpenAITool represents a tool in OpenAI format.
type OpenAITool struct {
	Type     string         `json:"type"` // e.g. "function"
	Function OpenAIFunction `json:"function"`
}
68
// OpenAIResponseFormat represents response format in OpenAI format,
// e.g. {"type": "json_object"}.
type OpenAIResponseFormat struct {
	Type string `json:"type"`
}
73
// OpenAIResponse represents the OpenAI API response format for chat
// completions.
type OpenAIResponse struct {
	ID                string         `json:"id"`
	Object            string         `json:"object"`
	Created           int64          `json:"created"` // unix timestamp
	Model             string         `json:"model"`
	SystemFingerprint string         `json:"system_fingerprint,omitempty"`
	Choices           []OpenAIChoice `json:"choices"`
	Usage             OpenAIUsage    `json:"usage"`
}
84
// OpenAIChoice represents a choice in OpenAI response. Message is set on
// regular responses; Delta is the incremental variant used by streaming
// chunks (not consumed by this provider's non-streaming path).
type OpenAIChoice struct {
	Index        int             `json:"index"`
	Message      OpenAIMessage   `json:"message"`
	Logprobs     *OpenAILogprobs `json:"logprobs,omitempty"`
	FinishReason string          `json:"finish_reason"`
	Delta        *OpenAIMessage  `json:"delta,omitempty"`
}
93
// OpenAILogprobs represents log probabilities in OpenAI format.
type OpenAILogprobs struct {
	Content []OpenAILogprobContent `json:"content,omitempty"`
}
98
// OpenAILogprobContent represents log probability content in OpenAI
// format: one generated token with its logprob and top alternatives.
type OpenAILogprobContent struct {
	Token       string             `json:"token"`
	Logprob     float64            `json:"logprob"`
	Bytes       []int              `json:"bytes,omitempty"` // UTF-8 bytes of the token
	TopLogprobs []OpenAITopLogprob `json:"top_logprobs,omitempty"`
}
106
// OpenAITopLogprob represents a top log probability in OpenAI format
// (one alternative token considered at a position).
type OpenAITopLogprob struct {
	Token   string  `json:"token"`
	Logprob float64 `json:"logprob"`
	Bytes   []int   `json:"bytes,omitempty"`
}
113
// OpenAIUsage represents usage information in OpenAI format; shared by
// chat-completion and embedding responses.
type OpenAIUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
120
// OpenAIEmbeddingRequest represents OpenAI embedding request for
// POST /embeddings.
type OpenAIEmbeddingRequest struct {
	Input          interface{} `json:"input"` // string or []string, passed through as-is
	Model          string      `json:"model"`
	EncodingFormat string      `json:"encoding_format,omitempty"`
	Dimensions     *int        `json:"dimensions,omitempty"`
	User           string      `json:"user,omitempty"`
}
129
// OpenAIEmbeddingResponse represents OpenAI embedding response.
type OpenAIEmbeddingResponse struct {
	Object string                `json:"object"`
	Data   []OpenAIEmbeddingData `json:"data"`
	Usage  OpenAIUsage           `json:"usage"`
	Model  string                `json:"model"`
}
137
// OpenAIEmbeddingData represents embedding data in OpenAI format: one
// embedding vector, with Index giving its position in the input batch.
type OpenAIEmbeddingData struct {
	Object    string    `json:"object"`
	Embedding []float64 `json:"embedding"`
	Index     int       `json:"index"`
}
144
// OpenAIError represents an error from OpenAI API, as returned in the
// body of non-200 responses.
// NOTE(review): the API occasionally reports a numeric "code"; decoding
// into a string would then fail — confirm against observed payloads.
type OpenAIError struct {
	Error struct {
		Message string `json:"message"`
		Type    string `json:"type"`
		Code    string `json:"code,omitempty"`
		Param   string `json:"param,omitempty"`
	} `json:"error"`
}
154
iomodo75542322025-07-30 19:27:48 +0400155func New(config llm.Config) *OpenAIProvider {
iomodobe473d12025-07-26 11:33:08 +0400156 client := &http.Client{
157 Timeout: config.Timeout,
158 }
159
160 return &OpenAIProvider{
161 config: config,
162 client: client,
163 }
164}
165
166// ChatCompletion implements the LLMProvider interface for OpenAI
167func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req llm.ChatCompletionRequest) (*llm.ChatCompletionResponse, error) {
168 // Convert our request to OpenAI format
169 openAIReq := p.convertToOpenAIRequest(req)
170
171 // Make the API call
172 resp, err := p.makeOpenAIRequest(ctx, "/chat/completions", openAIReq)
173 if err != nil {
174 return nil, fmt.Errorf("OpenAI API request failed: %w", err)
175 }
176
177 // Parse the response
178 var openAIResp OpenAIResponse
179 if err := json.Unmarshal(resp, &openAIResp); err != nil {
180 return nil, fmt.Errorf("failed to parse OpenAI response: %w", err)
181 }
182
183 // Convert back to our format
184 return p.convertFromOpenAIResponse(openAIResp), nil
185}
186
187// CreateEmbeddings implements the LLMProvider interface for OpenAI
188func (p *OpenAIProvider) CreateEmbeddings(ctx context.Context, req llm.EmbeddingRequest) (*llm.EmbeddingResponse, error) {
189 // Convert our request to OpenAI format
190 openAIReq := OpenAIEmbeddingRequest{
191 Input: req.Input,
192 Model: req.Model,
193 EncodingFormat: req.EncodingFormat,
194 Dimensions: req.Dimensions,
195 User: req.User,
196 }
197
198 // Make the API call
199 resp, err := p.makeOpenAIRequest(ctx, "/embeddings", openAIReq)
200 if err != nil {
201 return nil, fmt.Errorf("OpenAI embeddings API request failed: %w", err)
202 }
203
204 // Parse the response
205 var openAIResp OpenAIEmbeddingResponse
206 if err := json.Unmarshal(resp, &openAIResp); err != nil {
207 return nil, fmt.Errorf("failed to parse OpenAI embeddings response: %w", err)
208 }
209
210 // Convert back to our format
211 return p.convertFromOpenAIEmbeddingResponse(openAIResp), nil
212}
213
214// Close implements the LLMProvider interface
215func (p *OpenAIProvider) Close() error {
216 // Nothing to clean up for HTTP client
217 return nil
218}
219
220// convertToOpenAIRequest converts our request format to OpenAI format
221func (p *OpenAIProvider) convertToOpenAIRequest(req llm.ChatCompletionRequest) OpenAIRequest {
222 openAIReq := OpenAIRequest{
223 Model: req.Model,
224 MaxTokens: req.MaxTokens,
225 Temperature: req.Temperature,
226 TopP: req.TopP,
227 N: req.N,
228 Stream: req.Stream,
229 Stop: req.Stop,
230 PresencePenalty: req.PresencePenalty,
231 FrequencyPenalty: req.FrequencyPenalty,
232 LogitBias: req.LogitBias,
233 User: req.User,
234 ToolChoice: req.ToolChoice,
235 Seed: req.Seed,
236 }
237
238 // Convert messages
239 openAIReq.Messages = make([]OpenAIMessage, len(req.Messages))
240 for i, msg := range req.Messages {
241 openAIReq.Messages[i] = OpenAIMessage{
242 Role: string(msg.Role),
243 Content: msg.Content,
244 ToolCallID: msg.ToolCallID,
245 Name: msg.Name,
246 }
247
248 // Convert tool calls if present
249 if len(msg.ToolCalls) > 0 {
250 openAIReq.Messages[i].ToolCalls = make([]OpenAIToolCall, len(msg.ToolCalls))
251 for j, toolCall := range msg.ToolCalls {
252 openAIReq.Messages[i].ToolCalls[j] = OpenAIToolCall{
253 ID: toolCall.ID,
254 Type: toolCall.Type,
255 Function: OpenAIFunction{
256 Name: toolCall.Function.Name,
257 Description: toolCall.Function.Description,
258 Parameters: toolCall.Function.Parameters,
259 },
260 }
261 }
262 }
263 }
264
265 // Convert tools if present
266 if len(req.Tools) > 0 {
267 openAIReq.Tools = make([]OpenAITool, len(req.Tools))
268 for i, tool := range req.Tools {
269 openAIReq.Tools[i] = OpenAITool{
270 Type: tool.Type,
271 Function: OpenAIFunction{
272 Name: tool.Function.Name,
273 Description: tool.Function.Description,
274 Parameters: tool.Function.Parameters,
275 },
276 }
277 }
278 }
279
280 // Convert response format if present
281 if req.ResponseFormat != nil {
282 openAIReq.ResponseFormat = &OpenAIResponseFormat{
283 Type: req.ResponseFormat.Type,
284 }
285 }
286
287 return openAIReq
288}
289
290// convertFromOpenAIResponse converts OpenAI response to our format
291func (p *OpenAIProvider) convertFromOpenAIResponse(openAIResp OpenAIResponse) *llm.ChatCompletionResponse {
292 resp := &llm.ChatCompletionResponse{
293 ID: openAIResp.ID,
294 Object: openAIResp.Object,
295 Created: openAIResp.Created,
296 Model: openAIResp.Model,
297 SystemFingerprint: openAIResp.SystemFingerprint,
298 Provider: llm.ProviderOpenAI,
299 Usage: llm.Usage{
300 PromptTokens: openAIResp.Usage.PromptTokens,
301 CompletionTokens: openAIResp.Usage.CompletionTokens,
302 TotalTokens: openAIResp.Usage.TotalTokens,
303 },
304 }
305
306 // Convert choices
307 resp.Choices = make([]llm.ChatCompletionChoice, len(openAIResp.Choices))
308 for i, choice := range openAIResp.Choices {
309 resp.Choices[i] = llm.ChatCompletionChoice{
310 Index: choice.Index,
311 FinishReason: choice.FinishReason,
312 Message: llm.Message{
313 Role: llm.Role(choice.Message.Role),
314 Content: choice.Message.Content,
315 Name: choice.Message.Name,
316 },
317 }
318
319 // Convert tool calls if present
320 if len(choice.Message.ToolCalls) > 0 {
321 resp.Choices[i].Message.ToolCalls = make([]llm.ToolCall, len(choice.Message.ToolCalls))
322 for j, toolCall := range choice.Message.ToolCalls {
323 resp.Choices[i].Message.ToolCalls[j] = llm.ToolCall{
324 ID: toolCall.ID,
325 Type: toolCall.Type,
326 Function: llm.Function{
327 Name: toolCall.Function.Name,
328 Description: toolCall.Function.Description,
329 Parameters: toolCall.Function.Parameters,
330 },
331 }
332 }
333 }
334
335 // Convert logprobs if present
336 if choice.Logprobs != nil {
337 resp.Choices[i].Logprobs = &llm.Logprobs{
338 Content: make([]llm.LogprobContent, len(choice.Logprobs.Content)),
339 }
340 for j, content := range choice.Logprobs.Content {
341 resp.Choices[i].Logprobs.Content[j] = llm.LogprobContent{
342 Token: content.Token,
343 Logprob: content.Logprob,
344 Bytes: content.Bytes,
345 }
346 if len(content.TopLogprobs) > 0 {
347 resp.Choices[i].Logprobs.Content[j].TopLogprobs = make([]llm.TopLogprob, len(content.TopLogprobs))
348 for k, topLogprob := range content.TopLogprobs {
349 resp.Choices[i].Logprobs.Content[j].TopLogprobs[k] = llm.TopLogprob{
350 Token: topLogprob.Token,
351 Logprob: topLogprob.Logprob,
352 Bytes: topLogprob.Bytes,
353 }
354 }
355 }
356 }
357 }
358 }
359
360 return resp
361}
362
363// convertFromOpenAIEmbeddingResponse converts OpenAI embedding response to our format
364func (p *OpenAIProvider) convertFromOpenAIEmbeddingResponse(openAIResp OpenAIEmbeddingResponse) *llm.EmbeddingResponse {
365 resp := &llm.EmbeddingResponse{
366 Object: openAIResp.Object,
367 Model: openAIResp.Model,
368 Provider: llm.ProviderOpenAI,
369 Usage: llm.Usage{
370 PromptTokens: openAIResp.Usage.PromptTokens,
371 CompletionTokens: openAIResp.Usage.CompletionTokens,
372 TotalTokens: openAIResp.Usage.TotalTokens,
373 },
374 }
375
376 // Convert embedding data
377 resp.Data = make([]llm.Embedding, len(openAIResp.Data))
378 for i, data := range openAIResp.Data {
379 resp.Data[i] = llm.Embedding{
380 Object: data.Object,
381 Embedding: data.Embedding,
382 Index: data.Index,
383 }
384 }
385
386 return resp
387}
388
389// makeOpenAIRequest makes an HTTP request to the OpenAI API
390func (p *OpenAIProvider) makeOpenAIRequest(ctx context.Context, endpoint string, payload interface{}) ([]byte, error) {
391 // Prepare request body
392 jsonData, err := json.Marshal(payload)
393 if err != nil {
394 return nil, fmt.Errorf("failed to marshal request: %w", err)
395 }
396
397 // Create HTTP request
398 url := p.config.BaseURL + endpoint
399 req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
400 if err != nil {
401 return nil, fmt.Errorf("failed to create request: %w", err)
402 }
403
404 // Set headers
405 req.Header.Set("Content-Type", "application/json")
406 req.Header.Set("Authorization", "Bearer "+p.config.APIKey)
407
408 // Add organization header if present
409 if org, ok := p.config.ExtraConfig["organization"].(string); ok && org != "" {
410 req.Header.Set("OpenAI-Organization", org)
411 }
412
413 // Make the request
414 resp, err := p.client.Do(req)
415 if err != nil {
416 return nil, fmt.Errorf("HTTP request failed: %w", err)
417 }
418 defer resp.Body.Close()
419
420 // Read response body
421 body, err := io.ReadAll(resp.Body)
422 if err != nil {
423 return nil, fmt.Errorf("failed to read response body: %w", err)
424 }
425
426 // Check for errors
427 if resp.StatusCode != http.StatusOK {
428 var openAIErr OpenAIError
429 if err := json.Unmarshal(body, &openAIErr); err != nil {
430 return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
431 }
432 return nil, fmt.Errorf("OpenAI API error: %s (type: %s, code: %s)",
433 openAIErr.Error.Message, openAIErr.Error.Type, openAIErr.Error.Code)
434 }
435
436 return body, nil
437}