Add OpenAI implementation
Change-Id: Iea3191cf538959002e6ae095857a7aa6126b3e2f
diff --git a/server/llm/openai/openai.go b/server/llm/openai/openai.go
new file mode 100644
index 0000000..c513c53
--- /dev/null
+++ b/server/llm/openai/openai.go
@@ -0,0 +1,468 @@
+package openai
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+
+ "github.com/iomodo/staff/llm"
+)
+
// OpenAIProvider implements the LLMProvider interface for OpenAI.
// It is a thin HTTP wrapper: requests are converted to OpenAI's wire
// format, POSTed to config.BaseURL, and responses mapped back.
type OpenAIProvider struct {
	config llm.Config   // API key, base URL, timeout, and extra settings
	client *http.Client // shared across requests; timeout taken from config
}
+
// OpenAIRequest represents the OpenAI API request format for the
// /chat/completions endpoint. Optional parameters are pointers with
// omitempty tags so that unset values are omitted from the JSON entirely
// and the API's own defaults apply.
type OpenAIRequest struct {
	Model            string           `json:"model"`
	Messages         []OpenAIMessage  `json:"messages"`
	MaxTokens        *int             `json:"max_tokens,omitempty"`
	Temperature      *float64         `json:"temperature,omitempty"`
	TopP             *float64         `json:"top_p,omitempty"`
	N                *int             `json:"n,omitempty"`
	Stream           *bool            `json:"stream,omitempty"`
	Stop             []string         `json:"stop,omitempty"`
	PresencePenalty  *float64         `json:"presence_penalty,omitempty"`
	FrequencyPenalty *float64         `json:"frequency_penalty,omitempty"`
	LogitBias        map[string]int   `json:"logit_bias,omitempty"`
	User             string           `json:"user,omitempty"`
	Tools            []OpenAITool     `json:"tools,omitempty"`
	// ToolChoice is interface{} because the API accepts more than one JSON
	// shape here; whatever the caller supplies is serialized as-is.
	ToolChoice     interface{}           `json:"tool_choice,omitempty"`
	ResponseFormat *OpenAIResponseFormat `json:"response_format,omitempty"`
	Seed           *int64                `json:"seed,omitempty"`
}
+
// OpenAIMessage represents a message in OpenAI format. It is used both in
// requests (conversation history) and in responses (the assistant's
// message inside a choice). ToolCalls, ToolCallID, and Name are omitted
// from JSON when empty; they are only populated on tool-related messages.
type OpenAIMessage struct {
	Role       string           `json:"role"`
	Content    string           `json:"content"`
	ToolCalls  []OpenAIToolCall `json:"tool_calls,omitempty"`
	ToolCallID string           `json:"tool_call_id,omitempty"`
	Name       string           `json:"name,omitempty"`
}
+
// OpenAIToolCall represents a tool call in OpenAI format: a unique ID,
// a type discriminator, and the function invocation details.
type OpenAIToolCall struct {
	ID       string         `json:"id"`
	Type     string         `json:"type"`
	Function OpenAIFunction `json:"function"`
}
+
// OpenAIFunction represents a function in OpenAI format. It doubles as a
// tool definition (name/description/parameter schema) and as the function
// part of a tool call; Parameters holds an arbitrary JSON object
// (presumably a JSON Schema per the OpenAI API — not enforced here).
type OpenAIFunction struct {
	Name        string                 `json:"name"`
	Description string                 `json:"description,omitempty"`
	Parameters  map[string]interface{} `json:"parameters,omitempty"`
}
+
// OpenAITool represents a tool in OpenAI format: a type discriminator
// plus the function definition the model may call.
type OpenAITool struct {
	Type     string         `json:"type"`
	Function OpenAIFunction `json:"function"`
}
+
// OpenAIResponseFormat represents the response_format request field in
// OpenAI format (e.g. selecting JSON output by type).
type OpenAIResponseFormat struct {
	Type string `json:"type"`
}
+
// OpenAIResponse represents the OpenAI API response format for chat
// completions: response metadata, one or more choices, and token usage.
type OpenAIResponse struct {
	ID                string         `json:"id"`
	Object            string         `json:"object"`
	Created           int64          `json:"created"` // Unix timestamp per the API — not interpreted here
	Model             string         `json:"model"`
	SystemFingerprint string         `json:"system_fingerprint,omitempty"`
	Choices           []OpenAIChoice `json:"choices"`
	Usage             OpenAIUsage    `json:"usage"`
}
+
// OpenAIChoice represents a choice in an OpenAI response.
type OpenAIChoice struct {
	Index        int             `json:"index"`
	Message      OpenAIMessage   `json:"message"`
	Logprobs     *OpenAILogprobs `json:"logprobs,omitempty"`
	FinishReason string          `json:"finish_reason"`
	// Delta presumably carries streaming chunks (per the API); the
	// conversion code in this file only reads Message — confirm before
	// relying on Delta.
	Delta *OpenAIMessage `json:"delta,omitempty"`
}
+
// OpenAILogprobs represents log probabilities in OpenAI format, one entry
// per generated token.
type OpenAILogprobs struct {
	Content []OpenAILogprobContent `json:"content,omitempty"`
}
+
// OpenAILogprobContent represents log probability content in OpenAI
// format: one generated token, its log probability, its raw byte
// representation, and the most likely alternatives.
type OpenAILogprobContent struct {
	Token       string             `json:"token"`
	Logprob     float64            `json:"logprob"`
	Bytes       []int              `json:"bytes,omitempty"`
	TopLogprobs []OpenAITopLogprob `json:"top_logprobs,omitempty"`
}
+
// OpenAITopLogprob represents a top log probability in OpenAI format: an
// alternative candidate token at a given position.
type OpenAITopLogprob struct {
	Token   string  `json:"token"`
	Logprob float64 `json:"logprob"`
	Bytes   []int   `json:"bytes,omitempty"`
}
+
// OpenAIUsage represents token usage information in OpenAI format. It is
// shared by chat completion and embedding responses.
type OpenAIUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
+
// OpenAIEmbeddingRequest represents an OpenAI /embeddings request.
type OpenAIEmbeddingRequest struct {
	// Input is interface{} because the API accepts several JSON shapes
	// (a single string or an array); the caller's value is passed through
	// unchanged from llm.EmbeddingRequest.Input.
	Input          interface{} `json:"input"`
	Model          string      `json:"model"`
	EncodingFormat string      `json:"encoding_format,omitempty"`
	Dimensions     *int        `json:"dimensions,omitempty"`
	User           string      `json:"user,omitempty"`
}
+
// OpenAIEmbeddingResponse represents an OpenAI /embeddings response:
// one embedding vector per input plus token usage.
type OpenAIEmbeddingResponse struct {
	Object string                `json:"object"`
	Data   []OpenAIEmbeddingData `json:"data"`
	Usage  OpenAIUsage           `json:"usage"`
	Model  string                `json:"model"`
}
+
// OpenAIEmbeddingData represents a single embedding in OpenAI format;
// Index identifies which input the vector corresponds to.
type OpenAIEmbeddingData struct {
	Object    string    `json:"object"`
	Embedding []float64 `json:"embedding"`
	Index     int       `json:"index"`
}
+
// OpenAIError represents the error envelope returned by the OpenAI API on
// non-200 responses. makeOpenAIRequest unmarshals failure bodies into this
// type to surface the message, type, and code in the returned error.
type OpenAIError struct {
	Error struct {
		Message string `json:"message"`
		Type    string `json:"type"`
		Code    string `json:"code,omitempty"`
		Param   string `json:"param,omitempty"`
	} `json:"error"`
}
+
+// NewOpenAIProvider creates a new OpenAI provider
+func NewOpenAIProvider(config llm.Config) *OpenAIProvider {
+ client := &http.Client{
+ Timeout: config.Timeout,
+ }
+
+ return &OpenAIProvider{
+ config: config,
+ client: client,
+ }
+}
+
+// ChatCompletion implements the LLMProvider interface for OpenAI
+func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req llm.ChatCompletionRequest) (*llm.ChatCompletionResponse, error) {
+ // Convert our request to OpenAI format
+ openAIReq := p.convertToOpenAIRequest(req)
+
+ // Make the API call
+ resp, err := p.makeOpenAIRequest(ctx, "/chat/completions", openAIReq)
+ if err != nil {
+ return nil, fmt.Errorf("OpenAI API request failed: %w", err)
+ }
+
+ // Parse the response
+ var openAIResp OpenAIResponse
+ if err := json.Unmarshal(resp, &openAIResp); err != nil {
+ return nil, fmt.Errorf("failed to parse OpenAI response: %w", err)
+ }
+
+ // Convert back to our format
+ return p.convertFromOpenAIResponse(openAIResp), nil
+}
+
+// CreateEmbeddings implements the LLMProvider interface for OpenAI
+func (p *OpenAIProvider) CreateEmbeddings(ctx context.Context, req llm.EmbeddingRequest) (*llm.EmbeddingResponse, error) {
+ // Convert our request to OpenAI format
+ openAIReq := OpenAIEmbeddingRequest{
+ Input: req.Input,
+ Model: req.Model,
+ EncodingFormat: req.EncodingFormat,
+ Dimensions: req.Dimensions,
+ User: req.User,
+ }
+
+ // Make the API call
+ resp, err := p.makeOpenAIRequest(ctx, "/embeddings", openAIReq)
+ if err != nil {
+ return nil, fmt.Errorf("OpenAI embeddings API request failed: %w", err)
+ }
+
+ // Parse the response
+ var openAIResp OpenAIEmbeddingResponse
+ if err := json.Unmarshal(resp, &openAIResp); err != nil {
+ return nil, fmt.Errorf("failed to parse OpenAI embeddings response: %w", err)
+ }
+
+ // Convert back to our format
+ return p.convertFromOpenAIEmbeddingResponse(openAIResp), nil
+}
+
+// Close implements the LLMProvider interface
+func (p *OpenAIProvider) Close() error {
+ // Nothing to clean up for HTTP client
+ return nil
+}
+
+// convertToOpenAIRequest converts our request format to OpenAI format
+func (p *OpenAIProvider) convertToOpenAIRequest(req llm.ChatCompletionRequest) OpenAIRequest {
+ openAIReq := OpenAIRequest{
+ Model: req.Model,
+ MaxTokens: req.MaxTokens,
+ Temperature: req.Temperature,
+ TopP: req.TopP,
+ N: req.N,
+ Stream: req.Stream,
+ Stop: req.Stop,
+ PresencePenalty: req.PresencePenalty,
+ FrequencyPenalty: req.FrequencyPenalty,
+ LogitBias: req.LogitBias,
+ User: req.User,
+ ToolChoice: req.ToolChoice,
+ Seed: req.Seed,
+ }
+
+ // Convert messages
+ openAIReq.Messages = make([]OpenAIMessage, len(req.Messages))
+ for i, msg := range req.Messages {
+ openAIReq.Messages[i] = OpenAIMessage{
+ Role: string(msg.Role),
+ Content: msg.Content,
+ ToolCallID: msg.ToolCallID,
+ Name: msg.Name,
+ }
+
+ // Convert tool calls if present
+ if len(msg.ToolCalls) > 0 {
+ openAIReq.Messages[i].ToolCalls = make([]OpenAIToolCall, len(msg.ToolCalls))
+ for j, toolCall := range msg.ToolCalls {
+ openAIReq.Messages[i].ToolCalls[j] = OpenAIToolCall{
+ ID: toolCall.ID,
+ Type: toolCall.Type,
+ Function: OpenAIFunction{
+ Name: toolCall.Function.Name,
+ Description: toolCall.Function.Description,
+ Parameters: toolCall.Function.Parameters,
+ },
+ }
+ }
+ }
+ }
+
+ // Convert tools if present
+ if len(req.Tools) > 0 {
+ openAIReq.Tools = make([]OpenAITool, len(req.Tools))
+ for i, tool := range req.Tools {
+ openAIReq.Tools[i] = OpenAITool{
+ Type: tool.Type,
+ Function: OpenAIFunction{
+ Name: tool.Function.Name,
+ Description: tool.Function.Description,
+ Parameters: tool.Function.Parameters,
+ },
+ }
+ }
+ }
+
+ // Convert response format if present
+ if req.ResponseFormat != nil {
+ openAIReq.ResponseFormat = &OpenAIResponseFormat{
+ Type: req.ResponseFormat.Type,
+ }
+ }
+
+ return openAIReq
+}
+
+// convertFromOpenAIResponse converts OpenAI response to our format
+func (p *OpenAIProvider) convertFromOpenAIResponse(openAIResp OpenAIResponse) *llm.ChatCompletionResponse {
+ resp := &llm.ChatCompletionResponse{
+ ID: openAIResp.ID,
+ Object: openAIResp.Object,
+ Created: openAIResp.Created,
+ Model: openAIResp.Model,
+ SystemFingerprint: openAIResp.SystemFingerprint,
+ Provider: llm.ProviderOpenAI,
+ Usage: llm.Usage{
+ PromptTokens: openAIResp.Usage.PromptTokens,
+ CompletionTokens: openAIResp.Usage.CompletionTokens,
+ TotalTokens: openAIResp.Usage.TotalTokens,
+ },
+ }
+
+ // Convert choices
+ resp.Choices = make([]llm.ChatCompletionChoice, len(openAIResp.Choices))
+ for i, choice := range openAIResp.Choices {
+ resp.Choices[i] = llm.ChatCompletionChoice{
+ Index: choice.Index,
+ FinishReason: choice.FinishReason,
+ Message: llm.Message{
+ Role: llm.Role(choice.Message.Role),
+ Content: choice.Message.Content,
+ Name: choice.Message.Name,
+ },
+ }
+
+ // Convert tool calls if present
+ if len(choice.Message.ToolCalls) > 0 {
+ resp.Choices[i].Message.ToolCalls = make([]llm.ToolCall, len(choice.Message.ToolCalls))
+ for j, toolCall := range choice.Message.ToolCalls {
+ resp.Choices[i].Message.ToolCalls[j] = llm.ToolCall{
+ ID: toolCall.ID,
+ Type: toolCall.Type,
+ Function: llm.Function{
+ Name: toolCall.Function.Name,
+ Description: toolCall.Function.Description,
+ Parameters: toolCall.Function.Parameters,
+ },
+ }
+ }
+ }
+
+ // Convert logprobs if present
+ if choice.Logprobs != nil {
+ resp.Choices[i].Logprobs = &llm.Logprobs{
+ Content: make([]llm.LogprobContent, len(choice.Logprobs.Content)),
+ }
+ for j, content := range choice.Logprobs.Content {
+ resp.Choices[i].Logprobs.Content[j] = llm.LogprobContent{
+ Token: content.Token,
+ Logprob: content.Logprob,
+ Bytes: content.Bytes,
+ }
+ if len(content.TopLogprobs) > 0 {
+ resp.Choices[i].Logprobs.Content[j].TopLogprobs = make([]llm.TopLogprob, len(content.TopLogprobs))
+ for k, topLogprob := range content.TopLogprobs {
+ resp.Choices[i].Logprobs.Content[j].TopLogprobs[k] = llm.TopLogprob{
+ Token: topLogprob.Token,
+ Logprob: topLogprob.Logprob,
+ Bytes: topLogprob.Bytes,
+ }
+ }
+ }
+ }
+ }
+ }
+
+ return resp
+}
+
+// convertFromOpenAIEmbeddingResponse converts OpenAI embedding response to our format
+func (p *OpenAIProvider) convertFromOpenAIEmbeddingResponse(openAIResp OpenAIEmbeddingResponse) *llm.EmbeddingResponse {
+ resp := &llm.EmbeddingResponse{
+ Object: openAIResp.Object,
+ Model: openAIResp.Model,
+ Provider: llm.ProviderOpenAI,
+ Usage: llm.Usage{
+ PromptTokens: openAIResp.Usage.PromptTokens,
+ CompletionTokens: openAIResp.Usage.CompletionTokens,
+ TotalTokens: openAIResp.Usage.TotalTokens,
+ },
+ }
+
+ // Convert embedding data
+ resp.Data = make([]llm.Embedding, len(openAIResp.Data))
+ for i, data := range openAIResp.Data {
+ resp.Data[i] = llm.Embedding{
+ Object: data.Object,
+ Embedding: data.Embedding,
+ Index: data.Index,
+ }
+ }
+
+ return resp
+}
+
+// makeOpenAIRequest makes an HTTP request to the OpenAI API
+func (p *OpenAIProvider) makeOpenAIRequest(ctx context.Context, endpoint string, payload interface{}) ([]byte, error) {
+ // Prepare request body
+ jsonData, err := json.Marshal(payload)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal request: %w", err)
+ }
+
+ // Create HTTP request
+ url := p.config.BaseURL + endpoint
+ req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
+ if err != nil {
+ return nil, fmt.Errorf("failed to create request: %w", err)
+ }
+
+ // Set headers
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Authorization", "Bearer "+p.config.APIKey)
+
+ // Add organization header if present
+ if org, ok := p.config.ExtraConfig["organization"].(string); ok && org != "" {
+ req.Header.Set("OpenAI-Organization", org)
+ }
+
+ // Make the request
+ resp, err := p.client.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("HTTP request failed: %w", err)
+ }
+ defer resp.Body.Close()
+
+ // Read response body
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response body: %w", err)
+ }
+
+ // Check for errors
+ if resp.StatusCode != http.StatusOK {
+ var openAIErr OpenAIError
+ if err := json.Unmarshal(body, &openAIErr); err != nil {
+ return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
+ }
+ return nil, fmt.Errorf("OpenAI API error: %s (type: %s, code: %s)",
+ openAIErr.Error.Message, openAIErr.Error.Type, openAIErr.Error.Code)
+ }
+
+ return body, nil
+}
+
// OpenAIFactory implements ProviderFactory for OpenAI. It is stateless;
// the zero value is ready to use.
type OpenAIFactory struct{}
+
+// CreateProvider creates a new OpenAI provider
+func (f *OpenAIFactory) CreateProvider(config llm.Config) (llm.LLMProvider, error) {
+ if config.Provider != llm.ProviderOpenAI {
+ return nil, fmt.Errorf("OpenAI factory cannot create provider: %s", config.Provider)
+ }
+
+ // Validate config
+ if err := llm.ValidateConfig(config); err != nil {
+ return nil, fmt.Errorf("invalid OpenAI config: %w", err)
+ }
+
+ // Merge with defaults
+ config = llm.MergeConfig(config)
+
+ return NewOpenAIProvider(config), nil
+}
+
+// SupportsProvider checks if this factory supports the given provider
+func (f *OpenAIFactory) SupportsProvider(provider llm.Provider) bool {
+ return provider == llm.ProviderOpenAI
+}
+
// init registers the OpenAI factory with the llm package's default
// provider registry, so importing this package for side effects makes
// llm.ProviderOpenAI available to lookups.
func init() {
	llm.RegisterProvider(llm.ProviderOpenAI, &OpenAIFactory{})
}