| package openai |
| |
| import ( |
| "bytes" |
| "context" |
| "encoding/json" |
| "fmt" |
| "io" |
| "net/http" |
| |
| "github.com/iomodo/staff/llm" |
| ) |
| |
// OpenAIProvider implements the LLMProvider interface for OpenAI.
// It holds the immutable provider configuration and a single reusable
// HTTP client (created in New with the configured timeout).
type OpenAIProvider struct {
	config llm.Config   // API key, base URL, timeout, and extra options
	client *http.Client // shared client; reused across all requests
}
| |
// OpenAIRequest represents the OpenAI /chat/completions request body.
// Pointer and omitempty fields are omitted from the JSON entirely when
// unset, so the API's own defaults apply.
type OpenAIRequest struct {
	Model            string                `json:"model"`
	Messages         []OpenAIMessage       `json:"messages"`
	MaxTokens        *int                  `json:"max_tokens,omitempty"`
	Temperature      *float64              `json:"temperature,omitempty"`
	TopP             *float64              `json:"top_p,omitempty"`
	N                *int                  `json:"n,omitempty"`
	Stream           *bool                 `json:"stream,omitempty"`
	Stop             []string              `json:"stop,omitempty"`
	PresencePenalty  *float64              `json:"presence_penalty,omitempty"`
	FrequencyPenalty *float64              `json:"frequency_penalty,omitempty"`
	LogitBias        map[string]int        `json:"logit_bias,omitempty"`
	User             string                `json:"user,omitempty"`
	Tools            []OpenAITool          `json:"tools,omitempty"`
	ToolChoice       interface{}           `json:"tool_choice,omitempty"` // string ("auto"/"none") or object form; passed through as-is
	ResponseFormat   *OpenAIResponseFormat `json:"response_format,omitempty"`
	Seed             *int64                `json:"seed,omitempty"`
}
| |
// OpenAIMessage represents a single chat message in OpenAI wire format.
// The same struct is used in requests, response choices, and streaming
// deltas, which is why all optional fields carry omitempty.
type OpenAIMessage struct {
	Role       string           `json:"role"`
	Content    string           `json:"content"`
	ToolCalls  []OpenAIToolCall `json:"tool_calls,omitempty"`
	ToolCallID string           `json:"tool_call_id,omitempty"` // set on role=tool messages replying to a call
	Name       string           `json:"name,omitempty"`
}
| |
// OpenAIToolCall represents a tool call in OpenAI format.
type OpenAIToolCall struct {
	ID       string         `json:"id"`
	Type     string         `json:"type"` // currently always "function" in the OpenAI API
	Function OpenAIFunction `json:"function"`
}
| |
// OpenAIFunction represents a function in OpenAI format, used both for
// declaring tools (name/description/parameters schema) and inside tool
// calls.
// NOTE(review): OpenAI tool calls in responses carry an "arguments"
// JSON string rather than a "parameters" object; this struct has no
// Arguments field, so call arguments appear to be dropped on decode —
// confirm against the llm package's expectations.
type OpenAIFunction struct {
	Name        string                 `json:"name"`
	Description string                 `json:"description,omitempty"`
	Parameters  map[string]interface{} `json:"parameters,omitempty"` // JSON Schema describing accepted arguments
}
| |
// OpenAITool represents a tool declaration in OpenAI format
// (the request-side "tools" array entry).
type OpenAITool struct {
	Type     string         `json:"type"` // e.g. "function"
	Function OpenAIFunction `json:"function"`
}
| |
// OpenAIResponseFormat represents the response_format request field
// (e.g. type "json_object" to force JSON output).
type OpenAIResponseFormat struct {
	Type string `json:"type"`
}
| |
// OpenAIResponse represents the OpenAI /chat/completions response body.
type OpenAIResponse struct {
	ID                string         `json:"id"`
	Object            string         `json:"object"`
	Created           int64          `json:"created"` // Unix seconds
	Model             string         `json:"model"`
	SystemFingerprint string         `json:"system_fingerprint,omitempty"`
	Choices           []OpenAIChoice `json:"choices"`
	Usage             OpenAIUsage    `json:"usage"`
}
| |
// OpenAIChoice represents one generated alternative in an OpenAI
// response. Message is populated on non-streaming responses; Delta on
// streaming chunks.
type OpenAIChoice struct {
	Index        int             `json:"index"`
	Message      OpenAIMessage   `json:"message"`
	Logprobs     *OpenAILogprobs `json:"logprobs,omitempty"`
	FinishReason string          `json:"finish_reason"` // e.g. "stop", "length", "tool_calls"
	Delta        *OpenAIMessage  `json:"delta,omitempty"`
}
| |
// OpenAILogprobs represents per-token log probabilities attached to a
// choice.
type OpenAILogprobs struct {
	Content []OpenAILogprobContent `json:"content,omitempty"`
}
| |
// OpenAILogprobContent represents the log probability of a single
// generated token, plus the most likely alternatives at that position.
type OpenAILogprobContent struct {
	Token       string             `json:"token"`
	Logprob     float64            `json:"logprob"`
	Bytes       []int              `json:"bytes,omitempty"` // raw UTF-8 bytes of the token
	TopLogprobs []OpenAITopLogprob `json:"top_logprobs,omitempty"`
}
| |
// OpenAITopLogprob represents one alternative token candidate and its
// log probability.
type OpenAITopLogprob struct {
	Token   string  `json:"token"`
	Logprob float64 `json:"logprob"`
	Bytes   []int   `json:"bytes,omitempty"`
}
| |
// OpenAIUsage represents token accounting returned with every
// completion and embedding response.
type OpenAIUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
| |
// OpenAIEmbeddingRequest represents the OpenAI /embeddings request body.
type OpenAIEmbeddingRequest struct {
	Input          interface{} `json:"input"` // string or []string; passed through from llm.EmbeddingRequest
	Model          string      `json:"model"`
	EncodingFormat string      `json:"encoding_format,omitempty"`
	Dimensions     *int        `json:"dimensions,omitempty"`
	User           string      `json:"user,omitempty"`
}
| |
// OpenAIEmbeddingResponse represents the OpenAI /embeddings response
// body.
type OpenAIEmbeddingResponse struct {
	Object string                `json:"object"`
	Data   []OpenAIEmbeddingData `json:"data"`
	Usage  OpenAIUsage           `json:"usage"`
	Model  string                `json:"model"`
}
| |
// OpenAIEmbeddingData represents a single embedding vector; Index is
// the position of the corresponding input.
type OpenAIEmbeddingData struct {
	Object    string    `json:"object"`
	Embedding []float64 `json:"embedding"`
	Index     int       `json:"index"`
}
| |
// OpenAIError represents the {"error": {...}} envelope OpenAI returns
// on non-2xx responses.
// NOTE(review): OpenAI has been observed returning "code" as a number
// for some errors; decoding such a body into the string field would
// fail — verify against the error classes this service actually hits.
type OpenAIError struct {
	Error struct {
		Message string `json:"message"`
		Type    string `json:"type"`
		Code    string `json:"code,omitempty"`
		Param   string `json:"param,omitempty"`
	} `json:"error"`
}
| |
| func New(config llm.Config) *OpenAIProvider { |
| client := &http.Client{ |
| Timeout: config.Timeout, |
| } |
| |
| return &OpenAIProvider{ |
| config: config, |
| client: client, |
| } |
| } |
| |
| // ChatCompletion implements the LLMProvider interface for OpenAI |
| func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req llm.ChatCompletionRequest) (*llm.ChatCompletionResponse, error) { |
| // Convert our request to OpenAI format |
| openAIReq := p.convertToOpenAIRequest(req) |
| |
| // Make the API call |
| resp, err := p.makeOpenAIRequest(ctx, "/chat/completions", openAIReq) |
| if err != nil { |
| return nil, fmt.Errorf("OpenAI API request failed: %w", err) |
| } |
| |
| // Parse the response |
| var openAIResp OpenAIResponse |
| if err := json.Unmarshal(resp, &openAIResp); err != nil { |
| return nil, fmt.Errorf("failed to parse OpenAI response: %w", err) |
| } |
| |
| // Convert back to our format |
| return p.convertFromOpenAIResponse(openAIResp), nil |
| } |
| |
| // CreateEmbeddings implements the LLMProvider interface for OpenAI |
| func (p *OpenAIProvider) CreateEmbeddings(ctx context.Context, req llm.EmbeddingRequest) (*llm.EmbeddingResponse, error) { |
| // Convert our request to OpenAI format |
| openAIReq := OpenAIEmbeddingRequest{ |
| Input: req.Input, |
| Model: req.Model, |
| EncodingFormat: req.EncodingFormat, |
| Dimensions: req.Dimensions, |
| User: req.User, |
| } |
| |
| // Make the API call |
| resp, err := p.makeOpenAIRequest(ctx, "/embeddings", openAIReq) |
| if err != nil { |
| return nil, fmt.Errorf("OpenAI embeddings API request failed: %w", err) |
| } |
| |
| // Parse the response |
| var openAIResp OpenAIEmbeddingResponse |
| if err := json.Unmarshal(resp, &openAIResp); err != nil { |
| return nil, fmt.Errorf("failed to parse OpenAI embeddings response: %w", err) |
| } |
| |
| // Convert back to our format |
| return p.convertFromOpenAIEmbeddingResponse(openAIResp), nil |
| } |
| |
| // Close implements the LLMProvider interface |
| func (p *OpenAIProvider) Close() error { |
| // Nothing to clean up for HTTP client |
| return nil |
| } |
| |
| // convertToOpenAIRequest converts our request format to OpenAI format |
| func (p *OpenAIProvider) convertToOpenAIRequest(req llm.ChatCompletionRequest) OpenAIRequest { |
| openAIReq := OpenAIRequest{ |
| Model: req.Model, |
| MaxTokens: req.MaxTokens, |
| Temperature: req.Temperature, |
| TopP: req.TopP, |
| N: req.N, |
| Stream: req.Stream, |
| Stop: req.Stop, |
| PresencePenalty: req.PresencePenalty, |
| FrequencyPenalty: req.FrequencyPenalty, |
| LogitBias: req.LogitBias, |
| User: req.User, |
| ToolChoice: req.ToolChoice, |
| Seed: req.Seed, |
| } |
| |
| // Convert messages |
| openAIReq.Messages = make([]OpenAIMessage, len(req.Messages)) |
| for i, msg := range req.Messages { |
| openAIReq.Messages[i] = OpenAIMessage{ |
| Role: string(msg.Role), |
| Content: msg.Content, |
| ToolCallID: msg.ToolCallID, |
| Name: msg.Name, |
| } |
| |
| // Convert tool calls if present |
| if len(msg.ToolCalls) > 0 { |
| openAIReq.Messages[i].ToolCalls = make([]OpenAIToolCall, len(msg.ToolCalls)) |
| for j, toolCall := range msg.ToolCalls { |
| openAIReq.Messages[i].ToolCalls[j] = OpenAIToolCall{ |
| ID: toolCall.ID, |
| Type: toolCall.Type, |
| Function: OpenAIFunction{ |
| Name: toolCall.Function.Name, |
| Description: toolCall.Function.Description, |
| Parameters: toolCall.Function.Parameters, |
| }, |
| } |
| } |
| } |
| } |
| |
| // Convert tools if present |
| if len(req.Tools) > 0 { |
| openAIReq.Tools = make([]OpenAITool, len(req.Tools)) |
| for i, tool := range req.Tools { |
| openAIReq.Tools[i] = OpenAITool{ |
| Type: tool.Type, |
| Function: OpenAIFunction{ |
| Name: tool.Function.Name, |
| Description: tool.Function.Description, |
| Parameters: tool.Function.Parameters, |
| }, |
| } |
| } |
| } |
| |
| // Convert response format if present |
| if req.ResponseFormat != nil { |
| openAIReq.ResponseFormat = &OpenAIResponseFormat{ |
| Type: req.ResponseFormat.Type, |
| } |
| } |
| |
| return openAIReq |
| } |
| |
// convertFromOpenAIResponse converts an OpenAI chat-completion response
// into the provider-agnostic llm.ChatCompletionResponse, copying the
// top-level metadata and usage, then deep-converting every choice's
// message, tool calls, and (optional) log probabilities.
func (p *OpenAIProvider) convertFromOpenAIResponse(openAIResp OpenAIResponse) *llm.ChatCompletionResponse {
	resp := &llm.ChatCompletionResponse{
		ID:                openAIResp.ID,
		Object:            openAIResp.Object,
		Created:           openAIResp.Created,
		Model:             openAIResp.Model,
		SystemFingerprint: openAIResp.SystemFingerprint,
		Provider:          llm.ProviderOpenAI, // tag the response with its origin
		Usage: llm.Usage{
			PromptTokens:     openAIResp.Usage.PromptTokens,
			CompletionTokens: openAIResp.Usage.CompletionTokens,
			TotalTokens:      openAIResp.Usage.TotalTokens,
		},
	}

	// Convert choices. Note: streaming Delta fields are not mapped here;
	// only the full Message of each choice is converted.
	resp.Choices = make([]llm.ChatCompletionChoice, len(openAIResp.Choices))
	for i, choice := range openAIResp.Choices {
		resp.Choices[i] = llm.ChatCompletionChoice{
			Index:        choice.Index,
			FinishReason: choice.FinishReason,
			Message: llm.Message{
				Role:    llm.Role(choice.Message.Role),
				Content: choice.Message.Content,
				Name:    choice.Message.Name,
			},
		}

		// Convert tool calls if present.
		if len(choice.Message.ToolCalls) > 0 {
			resp.Choices[i].Message.ToolCalls = make([]llm.ToolCall, len(choice.Message.ToolCalls))
			for j, toolCall := range choice.Message.ToolCalls {
				resp.Choices[i].Message.ToolCalls[j] = llm.ToolCall{
					ID:   toolCall.ID,
					Type: toolCall.Type,
					Function: llm.Function{
						Name:        toolCall.Function.Name,
						Description: toolCall.Function.Description,
						Parameters:  toolCall.Function.Parameters,
					},
				}
			}
		}

		// Convert logprobs if present: one entry per generated token,
		// each optionally carrying its top alternative candidates.
		if choice.Logprobs != nil {
			resp.Choices[i].Logprobs = &llm.Logprobs{
				Content: make([]llm.LogprobContent, len(choice.Logprobs.Content)),
			}
			for j, content := range choice.Logprobs.Content {
				resp.Choices[i].Logprobs.Content[j] = llm.LogprobContent{
					Token:   content.Token,
					Logprob: content.Logprob,
					Bytes:   content.Bytes,
				}
				if len(content.TopLogprobs) > 0 {
					resp.Choices[i].Logprobs.Content[j].TopLogprobs = make([]llm.TopLogprob, len(content.TopLogprobs))
					for k, topLogprob := range content.TopLogprobs {
						resp.Choices[i].Logprobs.Content[j].TopLogprobs[k] = llm.TopLogprob{
							Token:   topLogprob.Token,
							Logprob: topLogprob.Logprob,
							Bytes:   topLogprob.Bytes,
						}
					}
				}
			}
		}
	}

	return resp
}
| |
| // convertFromOpenAIEmbeddingResponse converts OpenAI embedding response to our format |
| func (p *OpenAIProvider) convertFromOpenAIEmbeddingResponse(openAIResp OpenAIEmbeddingResponse) *llm.EmbeddingResponse { |
| resp := &llm.EmbeddingResponse{ |
| Object: openAIResp.Object, |
| Model: openAIResp.Model, |
| Provider: llm.ProviderOpenAI, |
| Usage: llm.Usage{ |
| PromptTokens: openAIResp.Usage.PromptTokens, |
| CompletionTokens: openAIResp.Usage.CompletionTokens, |
| TotalTokens: openAIResp.Usage.TotalTokens, |
| }, |
| } |
| |
| // Convert embedding data |
| resp.Data = make([]llm.Embedding, len(openAIResp.Data)) |
| for i, data := range openAIResp.Data { |
| resp.Data[i] = llm.Embedding{ |
| Object: data.Object, |
| Embedding: data.Embedding, |
| Index: data.Index, |
| } |
| } |
| |
| return resp |
| } |
| |
| // makeOpenAIRequest makes an HTTP request to the OpenAI API |
| func (p *OpenAIProvider) makeOpenAIRequest(ctx context.Context, endpoint string, payload interface{}) ([]byte, error) { |
| // Prepare request body |
| jsonData, err := json.Marshal(payload) |
| if err != nil { |
| return nil, fmt.Errorf("failed to marshal request: %w", err) |
| } |
| |
| // Create HTTP request |
| url := p.config.BaseURL + endpoint |
| req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) |
| if err != nil { |
| return nil, fmt.Errorf("failed to create request: %w", err) |
| } |
| |
| // Set headers |
| req.Header.Set("Content-Type", "application/json") |
| req.Header.Set("Authorization", "Bearer "+p.config.APIKey) |
| |
| // Add organization header if present |
| if org, ok := p.config.ExtraConfig["organization"].(string); ok && org != "" { |
| req.Header.Set("OpenAI-Organization", org) |
| } |
| |
| // Make the request |
| resp, err := p.client.Do(req) |
| if err != nil { |
| return nil, fmt.Errorf("HTTP request failed: %w", err) |
| } |
| defer resp.Body.Close() |
| |
| // Read response body |
| body, err := io.ReadAll(resp.Body) |
| if err != nil { |
| return nil, fmt.Errorf("failed to read response body: %w", err) |
| } |
| |
| // Check for errors |
| if resp.StatusCode != http.StatusOK { |
| var openAIErr OpenAIError |
| if err := json.Unmarshal(body, &openAIErr); err != nil { |
| return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) |
| } |
| return nil, fmt.Errorf("OpenAI API error: %s (type: %s, code: %s)", |
| openAIErr.Error.Message, openAIErr.Error.Type, openAIErr.Error.Code) |
| } |
| |
| return body, nil |
| } |