Add OpenAI implementation
Change-Id: Iea3191cf538959002e6ae095857a7aa6126b3e2f
diff --git a/server/llm/openai/openai.go b/server/llm/openai/openai.go
new file mode 100644
index 0000000..c513c53
--- /dev/null
+++ b/server/llm/openai/openai.go
@@ -0,0 +1,468 @@
+package openai
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+
+ "github.com/iomodo/staff/llm"
+)
+
// OpenAIProvider implements the LLMProvider interface for OpenAI.
// It holds the immutable provider configuration and a shared HTTP client
// (constructed once in NewOpenAIProvider and reused for every request).
type OpenAIProvider struct {
	config llm.Config   // provider settings: API key, base URL, timeout, extras
	client *http.Client // reused across calls so connections can be pooled
}
+
// OpenAIRequest represents the OpenAI API request format for
// POST /chat/completions. Optional fields are pointers (or omitempty
// slices/maps) so unset values are omitted from the JSON payload instead
// of being sent as zero values.
type OpenAIRequest struct {
	Model            string          `json:"model"`
	Messages         []OpenAIMessage `json:"messages"`
	MaxTokens        *int            `json:"max_tokens,omitempty"`
	Temperature      *float64        `json:"temperature,omitempty"`
	TopP             *float64        `json:"top_p,omitempty"`
	N                *int            `json:"n,omitempty"`
	Stream           *bool           `json:"stream,omitempty"`
	Stop             []string        `json:"stop,omitempty"`
	PresencePenalty  *float64        `json:"presence_penalty,omitempty"`
	FrequencyPenalty *float64        `json:"frequency_penalty,omitempty"`
	LogitBias        map[string]int  `json:"logit_bias,omitempty"`
	User             string          `json:"user,omitempty"`
	Tools            []OpenAITool    `json:"tools,omitempty"`
	// ToolChoice is either a string ("none"/"auto") or an object naming a
	// specific function, so it is left untyped here.
	ToolChoice     interface{}           `json:"tool_choice,omitempty"`
	ResponseFormat *OpenAIResponseFormat `json:"response_format,omitempty"`
	Seed           *int64                `json:"seed,omitempty"`
}

// OpenAIMessage represents a message in OpenAI format. The same struct is
// used for outgoing conversation messages and for the assistant message
// embedded in each response choice.
type OpenAIMessage struct {
	Role       string           `json:"role"`
	Content    string           `json:"content"`
	ToolCalls  []OpenAIToolCall `json:"tool_calls,omitempty"`
	ToolCallID string           `json:"tool_call_id,omitempty"` // set on role "tool" messages that answer a call
	Name       string           `json:"name,omitempty"`
}

// OpenAIToolCall represents a tool call in OpenAI format (a request by the
// model to invoke one of the declared tools).
type OpenAIToolCall struct {
	ID       string         `json:"id"`
	Type     string         `json:"type"` // currently always "function"
	Function OpenAIFunction `json:"function"`
}

// OpenAIFunction represents a function in OpenAI format. It is used both
// for tool definitions (where the JSON-schema lives in Parameters) and for
// the function payload of tool calls.
// NOTE(review): the OpenAI API returns tool-call arguments as a JSON-encoded
// string field named "arguments", not a "parameters" object — verify that
// decoding responses through this struct actually captures the arguments.
type OpenAIFunction struct {
	Name        string                 `json:"name"`
	Description string                 `json:"description,omitempty"`
	Parameters  map[string]interface{} `json:"parameters,omitempty"`
}

// OpenAITool represents a tool in OpenAI format (a function made available
// to the model via the request's "tools" array).
type OpenAITool struct {
	Type     string         `json:"type"` // currently always "function"
	Function OpenAIFunction `json:"function"`
}

// OpenAIResponseFormat represents response format in OpenAI format,
// e.g. {"type": "json_object"} to force JSON output.
type OpenAIResponseFormat struct {
	Type string `json:"type"`
}
+
// OpenAIResponse represents the OpenAI API response format for
// /chat/completions.
type OpenAIResponse struct {
	ID                string         `json:"id"`
	Object            string         `json:"object"`  // e.g. "chat.completion"
	Created           int64          `json:"created"` // Unix timestamp
	Model             string         `json:"model"`
	SystemFingerprint string         `json:"system_fingerprint,omitempty"`
	Choices           []OpenAIChoice `json:"choices"`
	Usage             OpenAIUsage    `json:"usage"`
}

// OpenAIChoice represents a choice in an OpenAI response.
type OpenAIChoice struct {
	Index        int             `json:"index"`
	Message      OpenAIMessage   `json:"message"`
	Logprobs     *OpenAILogprobs `json:"logprobs,omitempty"`
	FinishReason string          `json:"finish_reason"` // e.g. "stop", "tool_calls", "length"
	// Delta carries the incremental message fragment in streaming
	// responses; it is nil for non-streaming completions.
	Delta *OpenAIMessage `json:"delta,omitempty"`
}

// OpenAILogprobs represents log probabilities in OpenAI format.
type OpenAILogprobs struct {
	Content []OpenAILogprobContent `json:"content,omitempty"`
}

// OpenAILogprobContent represents per-token log probability data in OpenAI
// format.
type OpenAILogprobContent struct {
	Token       string             `json:"token"`
	Logprob     float64            `json:"logprob"`
	Bytes       []int              `json:"bytes,omitempty"` // UTF-8 byte values of the token
	TopLogprobs []OpenAITopLogprob `json:"top_logprobs,omitempty"`
}

// OpenAITopLogprob represents one of the most-likely alternative tokens and
// its log probability.
type OpenAITopLogprob struct {
	Token   string  `json:"token"`
	Logprob float64 `json:"logprob"`
	Bytes   []int   `json:"bytes,omitempty"`
}

// OpenAIUsage represents token usage information in OpenAI format.
type OpenAIUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
+
// OpenAIEmbeddingRequest represents an OpenAI /embeddings request.
type OpenAIEmbeddingRequest struct {
	// Input accepts either a single string or an array of strings, so it is
	// left untyped and passed through from the caller.
	Input          interface{} `json:"input"`
	Model          string      `json:"model"`
	EncodingFormat string      `json:"encoding_format,omitempty"` // "float" or "base64"
	Dimensions     *int        `json:"dimensions,omitempty"`
	User           string      `json:"user,omitempty"`
}

// OpenAIEmbeddingResponse represents an OpenAI /embeddings response.
type OpenAIEmbeddingResponse struct {
	Object string                `json:"object"` // e.g. "list"
	Data   []OpenAIEmbeddingData `json:"data"`
	Usage  OpenAIUsage           `json:"usage"`
	Model  string                `json:"model"`
}

// OpenAIEmbeddingData represents one embedding vector in an OpenAI response.
type OpenAIEmbeddingData struct {
	Object    string    `json:"object"` // e.g. "embedding"
	Embedding []float64 `json:"embedding"`
	Index     int       `json:"index"` // position of the corresponding input
}

// OpenAIError represents an error payload from the OpenAI API
// ({"error": {...}} on non-200 responses).
type OpenAIError struct {
	Error struct {
		Message string `json:"message"`
		Type    string `json:"type"`
		Code    string `json:"code,omitempty"`
		Param   string `json:"param,omitempty"`
	} `json:"error"`
}
+
+// NewOpenAIProvider creates a new OpenAI provider
+func NewOpenAIProvider(config llm.Config) *OpenAIProvider {
+ client := &http.Client{
+ Timeout: config.Timeout,
+ }
+
+ return &OpenAIProvider{
+ config: config,
+ client: client,
+ }
+}
+
+// ChatCompletion implements the LLMProvider interface for OpenAI
+func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req llm.ChatCompletionRequest) (*llm.ChatCompletionResponse, error) {
+ // Convert our request to OpenAI format
+ openAIReq := p.convertToOpenAIRequest(req)
+
+ // Make the API call
+ resp, err := p.makeOpenAIRequest(ctx, "/chat/completions", openAIReq)
+ if err != nil {
+ return nil, fmt.Errorf("OpenAI API request failed: %w", err)
+ }
+
+ // Parse the response
+ var openAIResp OpenAIResponse
+ if err := json.Unmarshal(resp, &openAIResp); err != nil {
+ return nil, fmt.Errorf("failed to parse OpenAI response: %w", err)
+ }
+
+ // Convert back to our format
+ return p.convertFromOpenAIResponse(openAIResp), nil
+}
+
+// CreateEmbeddings implements the LLMProvider interface for OpenAI
+func (p *OpenAIProvider) CreateEmbeddings(ctx context.Context, req llm.EmbeddingRequest) (*llm.EmbeddingResponse, error) {
+ // Convert our request to OpenAI format
+ openAIReq := OpenAIEmbeddingRequest{
+ Input: req.Input,
+ Model: req.Model,
+ EncodingFormat: req.EncodingFormat,
+ Dimensions: req.Dimensions,
+ User: req.User,
+ }
+
+ // Make the API call
+ resp, err := p.makeOpenAIRequest(ctx, "/embeddings", openAIReq)
+ if err != nil {
+ return nil, fmt.Errorf("OpenAI embeddings API request failed: %w", err)
+ }
+
+ // Parse the response
+ var openAIResp OpenAIEmbeddingResponse
+ if err := json.Unmarshal(resp, &openAIResp); err != nil {
+ return nil, fmt.Errorf("failed to parse OpenAI embeddings response: %w", err)
+ }
+
+ // Convert back to our format
+ return p.convertFromOpenAIEmbeddingResponse(openAIResp), nil
+}
+
+// Close implements the LLMProvider interface
+func (p *OpenAIProvider) Close() error {
+ // Nothing to clean up for HTTP client
+ return nil
+}
+
// convertToOpenAIRequest converts our request format to OpenAI format.
//
// Scalar and pointer fields are copied directly (pointers are shared, not
// deep-copied — safe because the request is serialized immediately after).
// Messages, tool calls, tool definitions, and the response format are
// mapped element-by-element into their wire-format counterparts.
func (p *OpenAIProvider) convertToOpenAIRequest(req llm.ChatCompletionRequest) OpenAIRequest {
	openAIReq := OpenAIRequest{
		Model:            req.Model,
		MaxTokens:        req.MaxTokens,
		Temperature:      req.Temperature,
		TopP:             req.TopP,
		N:                req.N,
		Stream:           req.Stream,
		Stop:             req.Stop,
		PresencePenalty:  req.PresencePenalty,
		FrequencyPenalty: req.FrequencyPenalty,
		LogitBias:        req.LogitBias,
		User:             req.User,
		ToolChoice:       req.ToolChoice,
		Seed:             req.Seed,
	}

	// Convert messages, preserving order (OpenAI interprets the array as
	// the conversation history).
	openAIReq.Messages = make([]OpenAIMessage, len(req.Messages))
	for i, msg := range req.Messages {
		openAIReq.Messages[i] = OpenAIMessage{
			Role:       string(msg.Role),
			Content:    msg.Content,
			ToolCallID: msg.ToolCallID,
			Name:       msg.Name,
		}

		// Convert tool calls if present (assistant messages replaying a
		// prior round of function calling).
		if len(msg.ToolCalls) > 0 {
			openAIReq.Messages[i].ToolCalls = make([]OpenAIToolCall, len(msg.ToolCalls))
			for j, toolCall := range msg.ToolCalls {
				openAIReq.Messages[i].ToolCalls[j] = OpenAIToolCall{
					ID:   toolCall.ID,
					Type: toolCall.Type,
					Function: OpenAIFunction{
						Name:        toolCall.Function.Name,
						Description: toolCall.Function.Description,
						Parameters:  toolCall.Function.Parameters,
					},
				}
			}
		}
	}

	// Convert tool (function) definitions if present.
	if len(req.Tools) > 0 {
		openAIReq.Tools = make([]OpenAITool, len(req.Tools))
		for i, tool := range req.Tools {
			openAIReq.Tools[i] = OpenAITool{
				Type: tool.Type,
				Function: OpenAIFunction{
					Name:        tool.Function.Name,
					Description: tool.Function.Description,
					Parameters:  tool.Function.Parameters,
				},
			}
		}
	}

	// Convert response format if present (e.g. forced JSON output).
	if req.ResponseFormat != nil {
		openAIReq.ResponseFormat = &OpenAIResponseFormat{
			Type: req.ResponseFormat.Type,
		}
	}

	return openAIReq
}
+
// convertFromOpenAIResponse converts an OpenAI response to our format,
// stamping the result with the OpenAI provider identifier and copying
// usage, choices, tool calls, and (when present) per-token logprobs.
func (p *OpenAIProvider) convertFromOpenAIResponse(openAIResp OpenAIResponse) *llm.ChatCompletionResponse {
	resp := &llm.ChatCompletionResponse{
		ID:                openAIResp.ID,
		Object:            openAIResp.Object,
		Created:           openAIResp.Created,
		Model:             openAIResp.Model,
		SystemFingerprint: openAIResp.SystemFingerprint,
		Provider:          llm.ProviderOpenAI,
		Usage: llm.Usage{
			PromptTokens:     openAIResp.Usage.PromptTokens,
			CompletionTokens: openAIResp.Usage.CompletionTokens,
			TotalTokens:      openAIResp.Usage.TotalTokens,
		},
	}

	// Convert choices (one per requested completion).
	resp.Choices = make([]llm.ChatCompletionChoice, len(openAIResp.Choices))
	for i, choice := range openAIResp.Choices {
		resp.Choices[i] = llm.ChatCompletionChoice{
			Index:        choice.Index,
			FinishReason: choice.FinishReason,
			Message: llm.Message{
				Role:    llm.Role(choice.Message.Role),
				Content: choice.Message.Content,
				Name:    choice.Message.Name,
			},
		}

		// Convert tool calls if present (finish_reason "tool_calls").
		if len(choice.Message.ToolCalls) > 0 {
			resp.Choices[i].Message.ToolCalls = make([]llm.ToolCall, len(choice.Message.ToolCalls))
			for j, toolCall := range choice.Message.ToolCalls {
				resp.Choices[i].Message.ToolCalls[j] = llm.ToolCall{
					ID:   toolCall.ID,
					Type: toolCall.Type,
					Function: llm.Function{
						Name:        toolCall.Function.Name,
						Description: toolCall.Function.Description,
						Parameters:  toolCall.Function.Parameters,
					},
				}
			}
		}

		// Convert logprobs if present (token, logprob, raw bytes, and the
		// top-N alternatives for each position).
		if choice.Logprobs != nil {
			resp.Choices[i].Logprobs = &llm.Logprobs{
				Content: make([]llm.LogprobContent, len(choice.Logprobs.Content)),
			}
			for j, content := range choice.Logprobs.Content {
				resp.Choices[i].Logprobs.Content[j] = llm.LogprobContent{
					Token:   content.Token,
					Logprob: content.Logprob,
					Bytes:   content.Bytes,
				}
				if len(content.TopLogprobs) > 0 {
					resp.Choices[i].Logprobs.Content[j].TopLogprobs = make([]llm.TopLogprob, len(content.TopLogprobs))
					for k, topLogprob := range content.TopLogprobs {
						resp.Choices[i].Logprobs.Content[j].TopLogprobs[k] = llm.TopLogprob{
							Token:   topLogprob.Token,
							Logprob: topLogprob.Logprob,
							Bytes:   topLogprob.Bytes,
						}
					}
				}
			}
		}
	}

	return resp
}
+
+// convertFromOpenAIEmbeddingResponse converts OpenAI embedding response to our format
+func (p *OpenAIProvider) convertFromOpenAIEmbeddingResponse(openAIResp OpenAIEmbeddingResponse) *llm.EmbeddingResponse {
+ resp := &llm.EmbeddingResponse{
+ Object: openAIResp.Object,
+ Model: openAIResp.Model,
+ Provider: llm.ProviderOpenAI,
+ Usage: llm.Usage{
+ PromptTokens: openAIResp.Usage.PromptTokens,
+ CompletionTokens: openAIResp.Usage.CompletionTokens,
+ TotalTokens: openAIResp.Usage.TotalTokens,
+ },
+ }
+
+ // Convert embedding data
+ resp.Data = make([]llm.Embedding, len(openAIResp.Data))
+ for i, data := range openAIResp.Data {
+ resp.Data[i] = llm.Embedding{
+ Object: data.Object,
+ Embedding: data.Embedding,
+ Index: data.Index,
+ }
+ }
+
+ return resp
+}
+
+// makeOpenAIRequest makes an HTTP request to the OpenAI API
+func (p *OpenAIProvider) makeOpenAIRequest(ctx context.Context, endpoint string, payload interface{}) ([]byte, error) {
+ // Prepare request body
+ jsonData, err := json.Marshal(payload)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal request: %w", err)
+ }
+
+ // Create HTTP request
+ url := p.config.BaseURL + endpoint
+ req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
+ if err != nil {
+ return nil, fmt.Errorf("failed to create request: %w", err)
+ }
+
+ // Set headers
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Authorization", "Bearer "+p.config.APIKey)
+
+ // Add organization header if present
+ if org, ok := p.config.ExtraConfig["organization"].(string); ok && org != "" {
+ req.Header.Set("OpenAI-Organization", org)
+ }
+
+ // Make the request
+ resp, err := p.client.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("HTTP request failed: %w", err)
+ }
+ defer resp.Body.Close()
+
+ // Read response body
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response body: %w", err)
+ }
+
+ // Check for errors
+ if resp.StatusCode != http.StatusOK {
+ var openAIErr OpenAIError
+ if err := json.Unmarshal(body, &openAIErr); err != nil {
+ return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
+ }
+ return nil, fmt.Errorf("OpenAI API error: %s (type: %s, code: %s)",
+ openAIErr.Error.Message, openAIErr.Error.Type, openAIErr.Error.Code)
+ }
+
+ return body, nil
+}
+
+// OpenAIFactory implements ProviderFactory for OpenAI
+type OpenAIFactory struct{}
+
+// CreateProvider creates a new OpenAI provider
+func (f *OpenAIFactory) CreateProvider(config llm.Config) (llm.LLMProvider, error) {
+ if config.Provider != llm.ProviderOpenAI {
+ return nil, fmt.Errorf("OpenAI factory cannot create provider: %s", config.Provider)
+ }
+
+ // Validate config
+ if err := llm.ValidateConfig(config); err != nil {
+ return nil, fmt.Errorf("invalid OpenAI config: %w", err)
+ }
+
+ // Merge with defaults
+ config = llm.MergeConfig(config)
+
+ return NewOpenAIProvider(config), nil
+}
+
+// SupportsProvider checks if this factory supports the given provider
+func (f *OpenAIFactory) SupportsProvider(provider llm.Provider) bool {
+ return provider == llm.ProviderOpenAI
+}
+
+// Register OpenAI provider with the default registry
+func init() {
+ llm.RegisterProvider(llm.ProviderOpenAI, &OpenAIFactory{})
+}
diff --git a/server/llm/openai/openai_example.go b/server/llm/openai/openai_example.go
new file mode 100644
index 0000000..6cd7dbf
--- /dev/null
+++ b/server/llm/openai/openai_example.go
@@ -0,0 +1,254 @@
+package openai
+
+import (
+ "context"
+ "fmt"
+ "log"
+ "time"
+
+ "github.com/iomodo/staff/llm"
+)
+
+// ExampleOpenAI demonstrates how to use the OpenAI implementation
+func ExampleOpenAI() {
+ // Create OpenAI configuration
+ config := llm.Config{
+ Provider: llm.ProviderOpenAI,
+ APIKey: "your-openai-api-key-here", // Replace with your actual API key
+ BaseURL: "https://api.openai.com/v1",
+ Timeout: 30 * time.Second,
+ ExtraConfig: map[string]interface{}{
+ "organization": "your-org-id", // Optional: your OpenAI organization ID
+ },
+ }
+
+ // Create the provider
+ provider, err := llm.CreateProvider(config)
+ if err != nil {
+ log.Fatalf("Failed to create OpenAI provider: %v", err)
+ }
+ defer provider.Close()
+
+ // Example 1: Basic chat completion
+ exampleOpenAIChatCompletion(provider)
+
+ // Example 2: Function calling
+ exampleOpenAIFunctionCalling(provider)
+
+ // Example 3: Embeddings
+ exampleOpenAIEmbeddings(provider)
+}
+
+// exampleOpenAIChatCompletion demonstrates basic chat completion with OpenAI
+func exampleOpenAIChatCompletion(provider llm.LLMProvider) {
+ fmt.Println("=== OpenAI Chat Completion Example ===")
+
+ req := llm.ChatCompletionRequest{
+ Model: "gpt-3.5-turbo",
+ Messages: []llm.Message{
+ {Role: llm.RoleSystem, Content: "You are a helpful assistant."},
+ {Role: llm.RoleUser, Content: "What is the capital of France?"},
+ },
+ MaxTokens: &[]int{100}[0],
+ Temperature: &[]float64{0.7}[0],
+ }
+
+ ctx := context.Background()
+ resp, err := provider.ChatCompletion(ctx, req)
+ if err != nil {
+ log.Printf("Chat completion failed: %v", err)
+ return
+ }
+
+ fmt.Printf("Response: %s\n", resp.Choices[0].Message.Content)
+ fmt.Printf("Model: %s\n", resp.Model)
+ fmt.Printf("Provider: %s\n", resp.Provider)
+ fmt.Printf("Usage: %+v\n", resp.Usage)
+}
+
+// exampleOpenAIFunctionCalling demonstrates function calling with OpenAI
+func exampleOpenAIFunctionCalling(provider llm.LLMProvider) {
+ fmt.Println("\n=== OpenAI Function Calling Example ===")
+
+ // Define a function that can be called
+ tools := []llm.Tool{
+ {
+ Type: "function",
+ Function: llm.Function{
+ Name: "get_weather",
+ Description: "Get the current weather for a location",
+ Parameters: map[string]interface{}{
+ "type": "object",
+ "properties": map[string]interface{}{
+ "location": map[string]interface{}{
+ "type": "string",
+ "description": "The city and state, e.g. San Francisco, CA",
+ },
+ "unit": map[string]interface{}{
+ "type": "string",
+ "enum": []string{"celsius", "fahrenheit"},
+ },
+ },
+ "required": []string{"location"},
+ },
+ },
+ },
+ }
+
+ req := llm.ChatCompletionRequest{
+ Model: "gpt-3.5-turbo",
+ Messages: []llm.Message{
+ {Role: llm.RoleUser, Content: "What's the weather like in Tokyo?"},
+ },
+ Tools: tools,
+ MaxTokens: &[]int{150}[0],
+ Temperature: &[]float64{0.1}[0],
+ }
+
+ ctx := context.Background()
+ resp, err := provider.ChatCompletion(ctx, req)
+ if err != nil {
+ log.Printf("Function calling failed: %v", err)
+ return
+ }
+
+ // Check if the model wants to call a function
+ if len(resp.Choices[0].Message.ToolCalls) > 0 {
+ fmt.Println("Model wants to call a function:")
+ for _, toolCall := range resp.Choices[0].Message.ToolCalls {
+ fmt.Printf("Function: %s\n", toolCall.Function.Name)
+ fmt.Printf("Arguments: %+v\n", toolCall.Function.Parameters)
+ }
+ } else {
+ fmt.Printf("Response: %s\n", resp.Choices[0].Message.Content)
+ }
+}
+
+// exampleOpenAIEmbeddings demonstrates embedding generation with OpenAI
+func exampleOpenAIEmbeddings(provider llm.LLMProvider) {
+ fmt.Println("\n=== OpenAI Embeddings Example ===")
+
+ req := llm.EmbeddingRequest{
+ Input: "Hello, world! This is a test sentence for embeddings.",
+ Model: "text-embedding-ada-002",
+ }
+
+ ctx := context.Background()
+ resp, err := provider.CreateEmbeddings(ctx, req)
+ if err != nil {
+ log.Printf("Embeddings failed: %v", err)
+ return
+ }
+
+ fmt.Printf("Embedding dimensions: %d\n", len(resp.Data[0].Embedding))
+ fmt.Printf("First 5 values: %v\n", resp.Data[0].Embedding[:5])
+ fmt.Printf("Model: %s\n", resp.Model)
+ fmt.Printf("Provider: %s\n", resp.Provider)
+ fmt.Printf("Usage: %+v\n", resp.Usage)
+}
+
+// ExampleOpenAIWithErrorHandling demonstrates proper error handling
+func ExampleOpenAIWithErrorHandling() {
+ fmt.Println("\n=== OpenAI Error Handling Example ===")
+
+ // Test with invalid API key
+ config := llm.Config{
+ Provider: llm.ProviderOpenAI,
+ APIKey: "invalid-key",
+ BaseURL: "https://api.openai.com/v1",
+ Timeout: 30 * time.Second,
+ }
+
+ provider, err := llm.CreateProvider(config)
+ if err != nil {
+ fmt.Printf("Failed to create provider: %v\n", err)
+ return
+ }
+ defer provider.Close()
+
+ req := llm.ChatCompletionRequest{
+ Model: "gpt-3.5-turbo",
+ Messages: []llm.Message{
+ {Role: llm.RoleUser, Content: "Hello"},
+ },
+ MaxTokens: &[]int{50}[0],
+ }
+
+ ctx := context.Background()
+ resp, err := provider.ChatCompletion(ctx, req)
+ if err != nil {
+ fmt.Printf("Expected error with invalid API key: %v\n", err)
+ return
+ }
+
+ fmt.Printf("Unexpected success: %s\n", resp.Choices[0].Message.Content)
+}
+
+// ExampleOpenAIWithCustomConfig demonstrates custom configuration
+func ExampleOpenAIWithCustomConfig() {
+ fmt.Println("\n=== OpenAI Custom Configuration Example ===")
+
+ // Start with minimal config
+ customConfig := llm.Config{
+ Provider: llm.ProviderOpenAI,
+ APIKey: "your-api-key",
+ }
+
+ // Merge with defaults
+ config := llm.MergeConfig(customConfig)
+
+ fmt.Printf("Merged config: BaseURL=%s, Timeout=%v, MaxRetries=%d\n",
+ config.BaseURL, config.Timeout, config.MaxRetries)
+
+ // Add extra configuration
+ config.ExtraConfig = map[string]interface{}{
+ "organization": "your-org-id",
+ "project": "your-project-id",
+ }
+
+ provider, err := llm.CreateProvider(config)
+ if err != nil {
+ fmt.Printf("Failed to create provider: %v\n", err)
+ return
+ }
+ defer provider.Close()
+
+ fmt.Println("Provider created successfully with custom configuration")
+}
+
+// ExampleOpenAIWithValidation demonstrates request validation
+func ExampleOpenAIWithValidation() {
+ fmt.Println("\n=== OpenAI Validation Example ===")
+
+ // Valid request
+ validReq := llm.ChatCompletionRequest{
+ Model: "gpt-3.5-turbo",
+ Messages: []llm.Message{
+ {Role: llm.RoleUser, Content: "Hello"},
+ },
+ Temperature: &[]float64{0.5}[0],
+ }
+
+ if err := llm.ValidateChatCompletionRequest(validReq); err != nil {
+ fmt.Printf("Valid request validation failed: %v\n", err)
+ } else {
+ fmt.Println("Valid request passed validation")
+ }
+
+ // Invalid request (no model)
+ invalidReq := llm.ChatCompletionRequest{
+ Messages: []llm.Message{
+ {Role: llm.RoleUser, Content: "Hello"},
+ },
+ }
+
+ if err := llm.ValidateChatCompletionRequest(invalidReq); err != nil {
+ fmt.Printf("Invalid request correctly caught: %v\n", err)
+ } else {
+ fmt.Println("Invalid request should have failed validation")
+ }
+
+ // Test token estimation
+ tokens := llm.EstimateTokens(validReq)
+ fmt.Printf("Estimated tokens for request: %d\n", tokens)
+}
diff --git a/server/llm/openai/openai_test.go b/server/llm/openai/openai_test.go
new file mode 100644
index 0000000..8896405
--- /dev/null
+++ b/server/llm/openai/openai_test.go
@@ -0,0 +1,428 @@
+package openai
+
+import (
+ "testing"
+ "time"
+
+ "github.com/iomodo/staff/llm"
+)
+
// TestOpenAIProvider_Interface is a compile-time assertion that
// *OpenAIProvider satisfies llm.LLMProvider; the body does nothing at
// runtime.
func TestOpenAIProvider_Interface(t *testing.T) {
	// Test that OpenAIProvider implements LLMProvider interface
	var _ llm.LLMProvider = (*OpenAIProvider)(nil)
}
+
+func TestOpenAIFactory_CreateProvider(t *testing.T) {
+ factory := &OpenAIFactory{}
+
+ // Test valid config
+ config := llm.Config{
+ Provider: llm.ProviderOpenAI,
+ APIKey: "test-key",
+ BaseURL: "https://api.openai.com/v1",
+ Timeout: 30 * time.Second,
+ }
+
+ provider, err := factory.CreateProvider(config)
+ if err != nil {
+ t.Fatalf("Failed to create provider: %v", err)
+ }
+
+ if provider == nil {
+ t.Fatal("Provider should not be nil")
+ }
+
+ // Test invalid provider
+ invalidConfig := llm.Config{
+ Provider: llm.ProviderClaude,
+ APIKey: "test-key",
+ }
+
+ _, err = factory.CreateProvider(invalidConfig)
+ if err == nil {
+ t.Fatal("Should fail with invalid provider")
+ }
+
+ // Test missing API key
+ noKeyConfig := llm.Config{
+ Provider: llm.ProviderOpenAI,
+ BaseURL: "https://api.openai.com/v1",
+ }
+
+ _, err = factory.CreateProvider(noKeyConfig)
+ if err == nil {
+ t.Fatal("Should fail with missing API key")
+ }
+}
+
+func TestOpenAIFactory_SupportsProvider(t *testing.T) {
+ factory := &OpenAIFactory{}
+
+ if !factory.SupportsProvider(llm.ProviderOpenAI) {
+ t.Fatal("Should support OpenAI provider")
+ }
+
+ if factory.SupportsProvider(llm.ProviderClaude) {
+ t.Fatal("Should not support Claude provider")
+ }
+}
+
// TestOpenAIProvider_ConvertRequest verifies that a basic chat request maps
// field-for-field into the OpenAI wire format: model, message role/content,
// and the optional max_tokens/temperature pointers.
func TestOpenAIProvider_ConvertRequest(t *testing.T) {
	provider := &OpenAIProvider{
		config: llm.Config{
			Provider: llm.ProviderOpenAI,
			APIKey:   "test-key",
			BaseURL:  "https://api.openai.com/v1",
		},
	}

	// Test basic request conversion. The &[]T{v}[0] idiom produces a
	// pointer to an inline literal.
	req := llm.ChatCompletionRequest{
		Model: "gpt-3.5-turbo",
		Messages: []llm.Message{
			{Role: llm.RoleUser, Content: "Hello"},
		},
		MaxTokens:   &[]int{100}[0],
		Temperature: &[]float64{0.7}[0],
	}

	openAIReq := provider.convertToOpenAIRequest(req)

	if openAIReq.Model != "gpt-3.5-turbo" {
		t.Errorf("Expected model 'gpt-3.5-turbo', got '%s'", openAIReq.Model)
	}

	if len(openAIReq.Messages) != 1 {
		t.Errorf("Expected 1 message, got %d", len(openAIReq.Messages))
	}

	// Role is converted from the typed llm.Role to its string form.
	if openAIReq.Messages[0].Role != "user" {
		t.Errorf("Expected role 'user', got '%s'", openAIReq.Messages[0].Role)
	}

	if openAIReq.Messages[0].Content != "Hello" {
		t.Errorf("Expected content 'Hello', got '%s'", openAIReq.Messages[0].Content)
	}

	if *openAIReq.MaxTokens != 100 {
		t.Errorf("Expected max_tokens 100, got %d", *openAIReq.MaxTokens)
	}

	if *openAIReq.Temperature != 0.7 {
		t.Errorf("Expected temperature 0.7, got %f", *openAIReq.Temperature)
	}
}
+
// TestOpenAIProvider_ConvertResponse verifies that a plain chat-completion
// response maps back into the provider-neutral type: identity fields,
// provider stamp, choice message, and token usage.
func TestOpenAIProvider_ConvertResponse(t *testing.T) {
	provider := &OpenAIProvider{
		config: llm.Config{
			Provider: llm.ProviderOpenAI,
			APIKey:   "test-key",
			BaseURL:  "https://api.openai.com/v1",
		},
	}

	// Test basic response conversion with a single assistant choice.
	openAIResp := OpenAIResponse{
		ID:      "test-id",
		Object:  "chat.completion",
		Created: 1234567890,
		Model:   "gpt-3.5-turbo",
		Choices: []OpenAIChoice{
			{
				Index: 0,
				Message: OpenAIMessage{
					Role:    "assistant",
					Content: "Hello! How can I help you?",
				},
				FinishReason: "stop",
			},
		},
		Usage: OpenAIUsage{
			PromptTokens:     10,
			CompletionTokens: 20,
			TotalTokens:      30,
		},
	}

	resp := provider.convertFromOpenAIResponse(openAIResp)

	if resp.ID != "test-id" {
		t.Errorf("Expected ID 'test-id', got '%s'", resp.ID)
	}

	if resp.Model != "gpt-3.5-turbo" {
		t.Errorf("Expected model 'gpt-3.5-turbo', got '%s'", resp.Model)
	}

	// The converter must stamp the response with the OpenAI provider.
	if resp.Provider != llm.ProviderOpenAI {
		t.Errorf("Expected provider OpenAI, got %s", resp.Provider)
	}

	if len(resp.Choices) != 1 {
		t.Errorf("Expected 1 choice, got %d", len(resp.Choices))
	}

	// Role string is converted back into the typed llm.Role.
	if resp.Choices[0].Message.Role != llm.RoleAssistant {
		t.Errorf("Expected role assistant, got %s", resp.Choices[0].Message.Role)
	}

	if resp.Choices[0].Message.Content != "Hello! How can I help you?" {
		t.Errorf("Expected content 'Hello! How can I help you?', got '%s'", resp.Choices[0].Message.Content)
	}

	if resp.Usage.PromptTokens != 10 {
		t.Errorf("Expected prompt tokens 10, got %d", resp.Usage.PromptTokens)
	}

	if resp.Usage.CompletionTokens != 20 {
		t.Errorf("Expected completion tokens 20, got %d", resp.Usage.CompletionTokens)
	}

	if resp.Usage.TotalTokens != 30 {
		t.Errorf("Expected total tokens 30, got %d", resp.Usage.TotalTokens)
	}
}
+
// TestOpenAIProvider_ConvertRequestWithTools verifies that tool (function)
// definitions survive the request conversion: type, name, and parameter
// schema.
func TestOpenAIProvider_ConvertRequestWithTools(t *testing.T) {
	provider := &OpenAIProvider{
		config: llm.Config{
			Provider: llm.ProviderOpenAI,
			APIKey:   "test-key",
			BaseURL:  "https://api.openai.com/v1",
		},
	}

	// Test request with a single function tool carrying a JSON-schema
	// parameter description.
	tools := []llm.Tool{
		{
			Type: "function",
			Function: llm.Function{
				Name:        "get_weather",
				Description: "Get weather information",
				Parameters: map[string]interface{}{
					"type": "object",
					"properties": map[string]interface{}{
						"location": map[string]interface{}{
							"type": "string",
						},
					},
				},
			},
		},
	}

	req := llm.ChatCompletionRequest{
		Model: "gpt-3.5-turbo",
		Messages: []llm.Message{
			{Role: llm.RoleUser, Content: "What's the weather like?"},
		},
		Tools: tools,
	}

	openAIReq := provider.convertToOpenAIRequest(req)

	if len(openAIReq.Tools) != 1 {
		t.Errorf("Expected 1 tool, got %d", len(openAIReq.Tools))
	}

	if openAIReq.Tools[0].Type != "function" {
		t.Errorf("Expected tool type 'function', got '%s'", openAIReq.Tools[0].Type)
	}

	if openAIReq.Tools[0].Function.Name != "get_weather" {
		t.Errorf("Expected function name 'get_weather', got '%s'", openAIReq.Tools[0].Function.Name)
	}
}
+
// TestOpenAIProvider_ConvertResponseWithToolCalls verifies that tool calls
// attached to a response choice are mapped through: call ID, function name,
// and the "tool_calls" finish reason.
func TestOpenAIProvider_ConvertResponseWithToolCalls(t *testing.T) {
	provider := &OpenAIProvider{
		config: llm.Config{
			Provider: llm.ProviderOpenAI,
			APIKey:   "test-key",
			BaseURL:  "https://api.openai.com/v1",
		},
	}

	// Test response whose single choice is a function-call request rather
	// than plain text content.
	openAIResp := OpenAIResponse{
		ID:     "test-id",
		Object: "chat.completion",
		Model:  "gpt-3.5-turbo",
		Choices: []OpenAIChoice{
			{
				Index: 0,
				Message: OpenAIMessage{
					Role: "assistant",
					ToolCalls: []OpenAIToolCall{
						{
							ID:   "call_123",
							Type: "function",
							Function: OpenAIFunction{
								Name: "get_weather",
								Parameters: map[string]interface{}{
									"location": "Tokyo",
								},
							},
						},
					},
				},
				FinishReason: "tool_calls",
			},
		},
		Usage: OpenAIUsage{
			PromptTokens:     10,
			CompletionTokens: 20,
			TotalTokens:      30,
		},
	}

	resp := provider.convertFromOpenAIResponse(openAIResp)

	if len(resp.Choices[0].Message.ToolCalls) != 1 {
		t.Errorf("Expected 1 tool call, got %d", len(resp.Choices[0].Message.ToolCalls))
	}

	if resp.Choices[0].Message.ToolCalls[0].ID != "call_123" {
		t.Errorf("Expected tool call ID 'call_123', got '%s'", resp.Choices[0].Message.ToolCalls[0].ID)
	}

	if resp.Choices[0].Message.ToolCalls[0].Function.Name != "get_weather" {
		t.Errorf("Expected function name 'get_weather', got '%s'", resp.Choices[0].Message.ToolCalls[0].Function.Name)
	}

	if resp.Choices[0].FinishReason != "tool_calls" {
		t.Errorf("Expected finish reason 'tool_calls', got '%s'", resp.Choices[0].FinishReason)
	}
}
+
// TestOpenAIProvider_ConvertEmbeddingRequest documents the expected shape
// of an llm.EmbeddingRequest.
// NOTE(review): this test only asserts values it just assigned — the actual
// request conversion happens inline in CreateEmbeddings and is never
// exercised here. Consider extracting that conversion into a helper so it
// can be tested directly.
func TestOpenAIProvider_ConvertEmbeddingRequest(t *testing.T) {
	req := llm.EmbeddingRequest{
		Input: "Hello, world!",
		Model: "text-embedding-ada-002",
		User:  "test-user",
	}

	// The conversion is done inline in CreateEmbeddings, so we'll test the structure
	if req.Input != "Hello, world!" {
		t.Errorf("Expected input 'Hello, world!', got '%v'", req.Input)
	}

	if req.Model != "text-embedding-ada-002" {
		t.Errorf("Expected model 'text-embedding-ada-002', got '%s'", req.Model)
	}

	if req.User != "test-user" {
		t.Errorf("Expected user 'test-user', got '%s'", req.User)
	}
}
+
// TestOpenAIProvider_ConvertEmbeddingResponse verifies that an embedding
// response maps back into the provider-neutral type: object/model fields,
// provider stamp, and the embedding vector itself.
func TestOpenAIProvider_ConvertEmbeddingResponse(t *testing.T) {
	provider := &OpenAIProvider{
		config: llm.Config{
			Provider: llm.ProviderOpenAI,
			APIKey:   "test-key",
			BaseURL:  "https://api.openai.com/v1",
		},
	}

	// Test embedding response conversion with a single 3-dimensional
	// vector.
	openAIResp := OpenAIEmbeddingResponse{
		Object: "list",
		Model:  "text-embedding-ada-002",
		Data: []OpenAIEmbeddingData{
			{
				Object:    "embedding",
				Embedding: []float64{0.1, 0.2, 0.3},
				Index:     0,
			},
		},
		Usage: OpenAIUsage{
			PromptTokens:     5,
			CompletionTokens: 0,
			TotalTokens:      5,
		},
	}

	resp := provider.convertFromOpenAIEmbeddingResponse(openAIResp)

	if resp.Object != "list" {
		t.Errorf("Expected object 'list', got '%s'", resp.Object)
	}

	if resp.Model != "text-embedding-ada-002" {
		t.Errorf("Expected model 'text-embedding-ada-002', got '%s'", resp.Model)
	}

	// The converter must stamp the response with the OpenAI provider.
	if resp.Provider != llm.ProviderOpenAI {
		t.Errorf("Expected provider OpenAI, got %s", resp.Provider)
	}

	if len(resp.Data) != 1 {
		t.Errorf("Expected 1 embedding, got %d", len(resp.Data))
	}

	if len(resp.Data[0].Embedding) != 3 {
		t.Errorf("Expected embedding dimension 3, got %d", len(resp.Data[0].Embedding))
	}

	if resp.Data[0].Embedding[0] != 0.1 {
		t.Errorf("Expected first embedding value 0.1, got %f", resp.Data[0].Embedding[0])
	}
}
+
+func TestOpenAIProvider_Close(t *testing.T) {
+ provider := &OpenAIProvider{
+ config: llm.Config{
+ Provider: llm.ProviderOpenAI,
+ APIKey: "test-key",
+ BaseURL: "https://api.openai.com/v1",
+ },
+ }
+
+ // Test that Close doesn't return an error
+ err := provider.Close()
+ if err != nil {
+ t.Errorf("Close should not return an error: %v", err)
+ }
+}
+
// TestOpenAIProvider_Integration is a placeholder for a real end-to-end
// test; its body is commented out because it would call the live OpenAI
// API.
// NOTE(review): prefer gating this on an environment variable (e.g.
// OPENAI_API_KEY) and calling t.Skip when unset, instead of keeping dead
// commented-out code that will rot silently.
func TestOpenAIProvider_Integration(t *testing.T) {
	// This test would require a real API key and would make actual API calls
	// It's commented out to avoid making real API calls during testing
	/*
		config := Config{
			Provider: ProviderOpenAI,
			APIKey:   "your-real-api-key",
			BaseURL:  "https://api.openai.com/v1",
			Timeout:  30 * time.Second,
		}

		provider, err := CreateProvider(config)
		if err != nil {
			t.Fatalf("Failed to create provider: %v", err)
		}
		defer provider.Close()

		req := ChatCompletionRequest{
			Model: "gpt-3.5-turbo",
			Messages: []Message{
				{Role: RoleUser, Content: "Say hello!"},
			},
			MaxTokens: &[]int{50}[0],
		}

		resp, err := provider.ChatCompletion(context.Background(), req)
		if err != nil {
			t.Fatalf("Chat completion failed: %v", err)
		}

		if len(resp.Choices) == 0 {
			t.Fatal("Expected at least one choice")
		}

		if resp.Choices[0].Message.Content == "" {
			t.Fatal("Expected non-empty response content")
		}
	*/
}