blob: c513c532ab6e9660308b2ab8b776bcec888cfc23 [file] [log] [blame]
package openai
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/iomodo/staff/llm"
)
// OpenAIProvider implements the LLMProvider interface for OpenAI.
type OpenAIProvider struct {
	config llm.Config   // provider configuration (API key, base URL, timeout, extra settings)
	client *http.Client // shared HTTP client; its timeout is set from config in NewOpenAIProvider
}
// OpenAIRequest represents the OpenAI API request format for the
// /chat/completions endpoint. Pointer fields and omitempty tags keep
// unset options out of the serialized JSON so the API applies its own
// defaults.
type OpenAIRequest struct {
	Model            string                `json:"model"`
	Messages         []OpenAIMessage       `json:"messages"`
	MaxTokens        *int                  `json:"max_tokens,omitempty"`
	Temperature      *float64              `json:"temperature,omitempty"`
	TopP             *float64              `json:"top_p,omitempty"`
	N                *int                  `json:"n,omitempty"`
	Stream           *bool                 `json:"stream,omitempty"`
	Stop             []string              `json:"stop,omitempty"`
	PresencePenalty  *float64              `json:"presence_penalty,omitempty"`
	FrequencyPenalty *float64              `json:"frequency_penalty,omitempty"`
	LogitBias        map[string]int        `json:"logit_bias,omitempty"`
	User             string                `json:"user,omitempty"`
	Tools            []OpenAITool          `json:"tools,omitempty"`
	ToolChoice       interface{}           `json:"tool_choice,omitempty"` // string or object per the API; passed through opaquely
	ResponseFormat   *OpenAIResponseFormat `json:"response_format,omitempty"`
	Seed             *int64                `json:"seed,omitempty"`
}
// OpenAIMessage represents a message in OpenAI format. It is used both in
// requests (conversation history) and in responses (assistant output).
type OpenAIMessage struct {
	Role       string           `json:"role"`
	Content    string           `json:"content"`
	ToolCalls  []OpenAIToolCall `json:"tool_calls,omitempty"`
	ToolCallID string           `json:"tool_call_id,omitempty"` // set on tool-role messages answering a specific call
	Name       string           `json:"name,omitempty"`
}
// OpenAIToolCall represents a tool call in OpenAI format.
type OpenAIToolCall struct {
	ID       string         `json:"id"`
	Type     string         `json:"type"`
	Function OpenAIFunction `json:"function"`
}
// OpenAIFunction represents a function in OpenAI format, used both inside
// tool definitions and inside tool calls.
type OpenAIFunction struct {
	Name        string                 `json:"name"`
	Description string                 `json:"description,omitempty"`
	Parameters  map[string]interface{} `json:"parameters,omitempty"` // JSON-schema-style parameter description; passed through opaquely
}
// OpenAITool represents a tool in OpenAI format.
type OpenAITool struct {
	Type     string         `json:"type"`
	Function OpenAIFunction `json:"function"`
}
// OpenAIResponseFormat represents response format in OpenAI format
// (e.g. selecting JSON output).
type OpenAIResponseFormat struct {
	Type string `json:"type"`
}
// OpenAIResponse represents the OpenAI API response format for
// /chat/completions.
type OpenAIResponse struct {
	ID                string         `json:"id"`
	Object            string         `json:"object"`
	Created           int64          `json:"created"` // Unix timestamp per the API; treated opaquely here
	Model             string         `json:"model"`
	SystemFingerprint string         `json:"system_fingerprint,omitempty"`
	Choices           []OpenAIChoice `json:"choices"`
	Usage             OpenAIUsage    `json:"usage"`
}
// OpenAIChoice represents a choice in OpenAI response.
type OpenAIChoice struct {
	Index        int             `json:"index"`
	Message      OpenAIMessage   `json:"message"`
	Logprobs     *OpenAILogprobs `json:"logprobs,omitempty"`
	FinishReason string          `json:"finish_reason"`
	// Delta is presumably the incremental message used by streaming
	// responses; it is declared for decoding but not consumed in this file.
	Delta *OpenAIMessage `json:"delta,omitempty"`
}
// OpenAILogprobs represents log probabilities in OpenAI format.
type OpenAILogprobs struct {
	Content []OpenAILogprobContent `json:"content,omitempty"`
}
// OpenAILogprobContent represents log probability content in OpenAI format:
// one token, its log probability, and the most likely alternatives.
type OpenAILogprobContent struct {
	Token       string             `json:"token"`
	Logprob     float64            `json:"logprob"`
	Bytes       []int              `json:"bytes,omitempty"`
	TopLogprobs []OpenAITopLogprob `json:"top_logprobs,omitempty"`
}
// OpenAITopLogprob represents a top log probability in OpenAI format.
type OpenAITopLogprob struct {
	Token   string  `json:"token"`
	Logprob float64 `json:"logprob"`
	Bytes   []int   `json:"bytes,omitempty"`
}
// OpenAIUsage represents usage (token accounting) information in OpenAI format.
type OpenAIUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
// OpenAIEmbeddingRequest represents OpenAI embedding request for the
// /embeddings endpoint.
type OpenAIEmbeddingRequest struct {
	Input          interface{} `json:"input"` // string or array per the API; passed through opaquely
	Model          string      `json:"model"`
	EncodingFormat string      `json:"encoding_format,omitempty"`
	Dimensions     *int        `json:"dimensions,omitempty"`
	User           string      `json:"user,omitempty"`
}
// OpenAIEmbeddingResponse represents OpenAI embedding response.
type OpenAIEmbeddingResponse struct {
	Object string                `json:"object"`
	Data   []OpenAIEmbeddingData `json:"data"`
	Usage  OpenAIUsage           `json:"usage"`
	Model  string                `json:"model"`
}
// OpenAIEmbeddingData represents a single embedding vector in OpenAI format.
type OpenAIEmbeddingData struct {
	Object    string    `json:"object"`
	Embedding []float64 `json:"embedding"`
	Index     int       `json:"index"` // position of this vector within the request's input batch
}
// OpenAIError represents an error payload from the OpenAI API, which wraps
// the details in a nested "error" object.
type OpenAIError struct {
	Error struct {
		Message string `json:"message"`
		Type    string `json:"type"`
		Code    string `json:"code,omitempty"`
		Param   string `json:"param,omitempty"`
	} `json:"error"`
}
// NewOpenAIProvider creates a new OpenAI provider. The embedded HTTP client
// uses the timeout carried by the supplied configuration.
func NewOpenAIProvider(config llm.Config) *OpenAIProvider {
	return &OpenAIProvider{
		config: config,
		client: &http.Client{Timeout: config.Timeout},
	}
}
// ChatCompletion implements the LLMProvider interface for OpenAI. It
// translates req into the OpenAI wire format, posts it to the
// /chat/completions endpoint, and maps the reply back into llm types.
func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req llm.ChatCompletionRequest) (*llm.ChatCompletionResponse, error) {
	body, err := p.makeOpenAIRequest(ctx, "/chat/completions", p.convertToOpenAIRequest(req))
	if err != nil {
		return nil, fmt.Errorf("OpenAI API request failed: %w", err)
	}

	var parsed OpenAIResponse
	if err := json.Unmarshal(body, &parsed); err != nil {
		return nil, fmt.Errorf("failed to parse OpenAI response: %w", err)
	}

	return p.convertFromOpenAIResponse(parsed), nil
}
// CreateEmbeddings implements the LLMProvider interface for OpenAI. It
// forwards the request to the /embeddings endpoint and converts the reply
// into the provider-agnostic llm.EmbeddingResponse.
func (p *OpenAIProvider) CreateEmbeddings(ctx context.Context, req llm.EmbeddingRequest) (*llm.EmbeddingResponse, error) {
	payload := OpenAIEmbeddingRequest{
		Input:          req.Input,
		Model:          req.Model,
		EncodingFormat: req.EncodingFormat,
		Dimensions:     req.Dimensions,
		User:           req.User,
	}

	body, err := p.makeOpenAIRequest(ctx, "/embeddings", payload)
	if err != nil {
		return nil, fmt.Errorf("OpenAI embeddings API request failed: %w", err)
	}

	var parsed OpenAIEmbeddingResponse
	if err := json.Unmarshal(body, &parsed); err != nil {
		return nil, fmt.Errorf("failed to parse OpenAI embeddings response: %w", err)
	}

	return p.convertFromOpenAIEmbeddingResponse(parsed), nil
}
// Close implements the LLMProvider interface. The provider holds no
// resources beyond a shared *http.Client, so there is nothing to release.
func (p *OpenAIProvider) Close() error {
	// Nothing to clean up for HTTP client
	return nil
}
// convertToOpenAIRequest maps a provider-agnostic chat request onto the
// OpenAI wire format. Scalar and optional pointer fields pass through
// unchanged; messages, tool calls, tools, and the response format are
// copied field by field into their OpenAI-shaped counterparts.
func (p *OpenAIProvider) convertToOpenAIRequest(req llm.ChatCompletionRequest) OpenAIRequest {
	out := OpenAIRequest{
		Model:            req.Model,
		MaxTokens:        req.MaxTokens,
		Temperature:      req.Temperature,
		TopP:             req.TopP,
		N:                req.N,
		Stream:           req.Stream,
		Stop:             req.Stop,
		PresencePenalty:  req.PresencePenalty,
		FrequencyPenalty: req.FrequencyPenalty,
		LogitBias:        req.LogitBias,
		User:             req.User,
		ToolChoice:       req.ToolChoice,
		Seed:             req.Seed,
	}

	// Copy the conversation history, including any tool calls attached to
	// individual messages.
	out.Messages = make([]OpenAIMessage, 0, len(req.Messages))
	for _, m := range req.Messages {
		converted := OpenAIMessage{
			Role:       string(m.Role),
			Content:    m.Content,
			ToolCallID: m.ToolCallID,
			Name:       m.Name,
		}
		if calls := m.ToolCalls; len(calls) > 0 {
			converted.ToolCalls = make([]OpenAIToolCall, 0, len(calls))
			for _, call := range calls {
				converted.ToolCalls = append(converted.ToolCalls, OpenAIToolCall{
					ID:   call.ID,
					Type: call.Type,
					Function: OpenAIFunction{
						Name:        call.Function.Name,
						Description: call.Function.Description,
						Parameters:  call.Function.Parameters,
					},
				})
			}
		}
		out.Messages = append(out.Messages, converted)
	}

	// Copy tool definitions, if any were supplied.
	if len(req.Tools) > 0 {
		out.Tools = make([]OpenAITool, 0, len(req.Tools))
		for _, tool := range req.Tools {
			out.Tools = append(out.Tools, OpenAITool{
				Type: tool.Type,
				Function: OpenAIFunction{
					Name:        tool.Function.Name,
					Description: tool.Function.Description,
					Parameters:  tool.Function.Parameters,
				},
			})
		}
	}

	// Copy the response format selector, if set.
	if rf := req.ResponseFormat; rf != nil {
		out.ResponseFormat = &OpenAIResponseFormat{Type: rf.Type}
	}

	return out
}
// convertFromOpenAIResponse maps an OpenAI chat completion reply onto the
// provider-agnostic llm types, tagging the result with ProviderOpenAI.
// Tool calls and logprobs are only materialized when the API returned them.
func (p *OpenAIProvider) convertFromOpenAIResponse(openAIResp OpenAIResponse) *llm.ChatCompletionResponse {
	out := &llm.ChatCompletionResponse{
		ID:                openAIResp.ID,
		Object:            openAIResp.Object,
		Created:           openAIResp.Created,
		Model:             openAIResp.Model,
		SystemFingerprint: openAIResp.SystemFingerprint,
		Provider:          llm.ProviderOpenAI,
		Usage: llm.Usage{
			PromptTokens:     openAIResp.Usage.PromptTokens,
			CompletionTokens: openAIResp.Usage.CompletionTokens,
			TotalTokens:      openAIResp.Usage.TotalTokens,
		},
		Choices: make([]llm.ChatCompletionChoice, 0, len(openAIResp.Choices)),
	}

	for _, src := range openAIResp.Choices {
		choice := llm.ChatCompletionChoice{
			Index:        src.Index,
			FinishReason: src.FinishReason,
			Message: llm.Message{
				Role:    llm.Role(src.Message.Role),
				Content: src.Message.Content,
				Name:    src.Message.Name,
			},
		}

		// Tool calls requested by the assistant, if any.
		if calls := src.Message.ToolCalls; len(calls) > 0 {
			choice.Message.ToolCalls = make([]llm.ToolCall, 0, len(calls))
			for _, call := range calls {
				choice.Message.ToolCalls = append(choice.Message.ToolCalls, llm.ToolCall{
					ID:   call.ID,
					Type: call.Type,
					Function: llm.Function{
						Name:        call.Function.Name,
						Description: call.Function.Description,
						Parameters:  call.Function.Parameters,
					},
				})
			}
		}

		// Per-token log probabilities, when the request asked for them.
		if lp := src.Logprobs; lp != nil {
			converted := &llm.Logprobs{
				Content: make([]llm.LogprobContent, 0, len(lp.Content)),
			}
			for _, entry := range lp.Content {
				item := llm.LogprobContent{
					Token:   entry.Token,
					Logprob: entry.Logprob,
					Bytes:   entry.Bytes,
				}
				if tops := entry.TopLogprobs; len(tops) > 0 {
					item.TopLogprobs = make([]llm.TopLogprob, 0, len(tops))
					for _, top := range tops {
						item.TopLogprobs = append(item.TopLogprobs, llm.TopLogprob{
							Token:   top.Token,
							Logprob: top.Logprob,
							Bytes:   top.Bytes,
						})
					}
				}
				converted.Content = append(converted.Content, item)
			}
			choice.Logprobs = converted
		}

		out.Choices = append(out.Choices, choice)
	}

	return out
}
// convertFromOpenAIEmbeddingResponse maps an OpenAI embeddings reply onto
// the provider-agnostic llm.EmbeddingResponse, tagging it with
// ProviderOpenAI.
func (p *OpenAIProvider) convertFromOpenAIEmbeddingResponse(openAIResp OpenAIEmbeddingResponse) *llm.EmbeddingResponse {
	out := &llm.EmbeddingResponse{
		Object:   openAIResp.Object,
		Model:    openAIResp.Model,
		Provider: llm.ProviderOpenAI,
		Usage: llm.Usage{
			PromptTokens:     openAIResp.Usage.PromptTokens,
			CompletionTokens: openAIResp.Usage.CompletionTokens,
			TotalTokens:      openAIResp.Usage.TotalTokens,
		},
		Data: make([]llm.Embedding, 0, len(openAIResp.Data)),
	}

	for _, item := range openAIResp.Data {
		out.Data = append(out.Data, llm.Embedding{
			Object:    item.Object,
			Embedding: item.Embedding,
			Index:     item.Index,
		})
	}

	return out
}
// makeOpenAIRequest POSTs payload as JSON to the configured base URL plus
// endpoint and returns the raw response body.
//
// The API key is sent as a Bearer token; an optional "organization" entry
// in ExtraConfig is forwarded via the OpenAI-Organization header. Any
// non-2xx status is turned into an error, preferring OpenAI's structured
// {"error": {...}} payload when one can be decoded.
func (p *OpenAIProvider) makeOpenAIRequest(ctx context.Context, endpoint string, payload interface{}) ([]byte, error) {
	jsonData, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	url := p.config.BaseURL + endpoint
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonData))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+p.config.APIKey)
	if org, ok := p.config.ExtraConfig["organization"].(string); ok && org != "" {
		req.Header.Set("OpenAI-Organization", org)
	}

	resp, err := p.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("HTTP request failed: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}

	// Accept the whole 2xx range rather than exactly 200, so success
	// variants do not get misreported as API errors.
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		var openAIErr OpenAIError
		// Fall back to the raw body when the payload is not the documented
		// error shape (e.g. HTML from a proxy, or JSON without "error") —
		// otherwise we would report an empty message.
		if err := json.Unmarshal(body, &openAIErr); err != nil || openAIErr.Error.Message == "" {
			return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
		}
		return nil, fmt.Errorf("OpenAI API error: %s (type: %s, code: %s)",
			openAIErr.Error.Message, openAIErr.Error.Type, openAIErr.Error.Code)
	}

	return body, nil
}
// OpenAIFactory implements ProviderFactory for OpenAI. It is stateless;
// the zero value is ready to use.
type OpenAIFactory struct{}
// CreateProvider creates a new OpenAI provider. The configuration must name
// the OpenAI provider and pass llm.ValidateConfig; it is merged with the
// package defaults before the provider is constructed.
func (f *OpenAIFactory) CreateProvider(config llm.Config) (llm.LLMProvider, error) {
	if config.Provider != llm.ProviderOpenAI {
		return nil, fmt.Errorf("OpenAI factory cannot create provider: %s", config.Provider)
	}
	if err := llm.ValidateConfig(config); err != nil {
		return nil, fmt.Errorf("invalid OpenAI config: %w", err)
	}
	return NewOpenAIProvider(llm.MergeConfig(config)), nil
}
// SupportsProvider reports whether this factory can build the given
// provider; only llm.ProviderOpenAI is supported.
func (f *OpenAIFactory) SupportsProvider(provider llm.Provider) bool {
	switch provider {
	case llm.ProviderOpenAI:
		return true
	default:
		return false
	}
}
// Register the OpenAI provider factory with the default registry at package
// load time, so callers can obtain it through the llm package by provider name.
func init() {
	llm.RegisterProvider(llm.ProviderOpenAI, &OpenAIFactory{})
}