Add OpenAI implementation

Change-Id: Iea3191cf538959002e6ae095857a7aa6126b3e2f
diff --git a/server/llm/openai/openai.go b/server/llm/openai/openai.go
new file mode 100644
index 0000000..c513c53
--- /dev/null
+++ b/server/llm/openai/openai.go
@@ -0,0 +1,468 @@
+package openai
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+
+	"github.com/iomodo/staff/llm"
+)
+
+// OpenAIProvider implements the LLMProvider interface for OpenAI
+type OpenAIProvider struct {
+	config llm.Config
+	client *http.Client
+}
+
+// OpenAIRequest represents the OpenAI API request format
+type OpenAIRequest struct {
+	Model            string                `json:"model"`
+	Messages         []OpenAIMessage       `json:"messages"`
+	MaxTokens        *int                  `json:"max_tokens,omitempty"`
+	Temperature      *float64              `json:"temperature,omitempty"`
+	TopP             *float64              `json:"top_p,omitempty"`
+	N                *int                  `json:"n,omitempty"`
+	Stream           *bool                 `json:"stream,omitempty"`
+	Stop             []string              `json:"stop,omitempty"`
+	PresencePenalty  *float64              `json:"presence_penalty,omitempty"`
+	FrequencyPenalty *float64              `json:"frequency_penalty,omitempty"`
+	LogitBias        map[string]int        `json:"logit_bias,omitempty"`
+	User             string                `json:"user,omitempty"`
+	Tools            []OpenAITool          `json:"tools,omitempty"`
+	ToolChoice       interface{}           `json:"tool_choice,omitempty"`
+	ResponseFormat   *OpenAIResponseFormat `json:"response_format,omitempty"`
+	Seed             *int64                `json:"seed,omitempty"`
+}
+
+// OpenAIMessage represents a message in OpenAI format
+type OpenAIMessage struct {
+	Role       string           `json:"role"`
+	Content    string           `json:"content"`
+	ToolCalls  []OpenAIToolCall `json:"tool_calls,omitempty"`
+	ToolCallID string           `json:"tool_call_id,omitempty"`
+	Name       string           `json:"name,omitempty"`
+}
+
+// OpenAIToolCall represents a tool call in OpenAI format
+type OpenAIToolCall struct {
+	ID       string         `json:"id"`
+	Type     string         `json:"type"`
+	Function OpenAIFunction `json:"function"`
+}
+
+// OpenAIFunction represents a function in OpenAI format
+type OpenAIFunction struct {
+	Name        string                 `json:"name"`
+	Description string                 `json:"description,omitempty"`
+	Parameters  map[string]interface{} `json:"parameters,omitempty"`
+}
+
+// OpenAITool represents a tool in OpenAI format
+type OpenAITool struct {
+	Type     string         `json:"type"`
+	Function OpenAIFunction `json:"function"`
+}
+
+// OpenAIResponseFormat represents response format in OpenAI format
+type OpenAIResponseFormat struct {
+	Type string `json:"type"`
+}
+
+// OpenAIResponse represents the OpenAI API response format
+type OpenAIResponse struct {
+	ID                string         `json:"id"`
+	Object            string         `json:"object"`
+	Created           int64          `json:"created"`
+	Model             string         `json:"model"`
+	SystemFingerprint string         `json:"system_fingerprint,omitempty"`
+	Choices           []OpenAIChoice `json:"choices"`
+	Usage             OpenAIUsage    `json:"usage"`
+}
+
+// OpenAIChoice represents a choice in OpenAI response
+type OpenAIChoice struct {
+	Index        int             `json:"index"`
+	Message      OpenAIMessage   `json:"message"`
+	Logprobs     *OpenAILogprobs `json:"logprobs,omitempty"`
+	FinishReason string          `json:"finish_reason"`
+	Delta        *OpenAIMessage  `json:"delta,omitempty"`
+}
+
+// OpenAILogprobs represents log probabilities in OpenAI format
+type OpenAILogprobs struct {
+	Content []OpenAILogprobContent `json:"content,omitempty"`
+}
+
+// OpenAILogprobContent represents log probability content in OpenAI format
+type OpenAILogprobContent struct {
+	Token       string             `json:"token"`
+	Logprob     float64            `json:"logprob"`
+	Bytes       []int              `json:"bytes,omitempty"`
+	TopLogprobs []OpenAITopLogprob `json:"top_logprobs,omitempty"`
+}
+
+// OpenAITopLogprob represents a top log probability in OpenAI format
+type OpenAITopLogprob struct {
+	Token   string  `json:"token"`
+	Logprob float64 `json:"logprob"`
+	Bytes   []int   `json:"bytes,omitempty"`
+}
+
+// OpenAIUsage represents usage information in OpenAI format
+type OpenAIUsage struct {
+	PromptTokens     int `json:"prompt_tokens"`
+	CompletionTokens int `json:"completion_tokens"`
+	TotalTokens      int `json:"total_tokens"`
+}
+
+// OpenAIEmbeddingRequest represents OpenAI embedding request
+type OpenAIEmbeddingRequest struct {
+	Input          interface{} `json:"input"`
+	Model          string      `json:"model"`
+	EncodingFormat string      `json:"encoding_format,omitempty"`
+	Dimensions     *int        `json:"dimensions,omitempty"`
+	User           string      `json:"user,omitempty"`
+}
+
+// OpenAIEmbeddingResponse represents OpenAI embedding response
+type OpenAIEmbeddingResponse struct {
+	Object string                `json:"object"`
+	Data   []OpenAIEmbeddingData `json:"data"`
+	Usage  OpenAIUsage           `json:"usage"`
+	Model  string                `json:"model"`
+}
+
+// OpenAIEmbeddingData represents embedding data in OpenAI format
+type OpenAIEmbeddingData struct {
+	Object    string    `json:"object"`
+	Embedding []float64 `json:"embedding"`
+	Index     int       `json:"index"`
+}
+
+// OpenAIError represents an error from OpenAI API
+type OpenAIError struct {
+	Error struct {
+		Message string `json:"message"`
+		Type    string `json:"type"`
+		Code    string `json:"code,omitempty"`
+		Param   string `json:"param,omitempty"`
+	} `json:"error"`
+}
+
+// NewOpenAIProvider creates a new OpenAI provider
+func NewOpenAIProvider(config llm.Config) *OpenAIProvider {
+	client := &http.Client{
+		Timeout: config.Timeout,
+	}
+
+	return &OpenAIProvider{
+		config: config,
+		client: client,
+	}
+}
+
+// ChatCompletion implements the LLMProvider interface for OpenAI
+func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req llm.ChatCompletionRequest) (*llm.ChatCompletionResponse, error) {
+	// Convert our request to OpenAI format
+	openAIReq := p.convertToOpenAIRequest(req)
+
+	// Make the API call
+	resp, err := p.makeOpenAIRequest(ctx, "/chat/completions", openAIReq)
+	if err != nil {
+		return nil, fmt.Errorf("OpenAI API request failed: %w", err)
+	}
+
+	// Parse the response
+	var openAIResp OpenAIResponse
+	if err := json.Unmarshal(resp, &openAIResp); err != nil {
+		return nil, fmt.Errorf("failed to parse OpenAI response: %w", err)
+	}
+
+	// Convert back to our format
+	return p.convertFromOpenAIResponse(openAIResp), nil
+}
+
+// CreateEmbeddings implements the LLMProvider interface for OpenAI
+func (p *OpenAIProvider) CreateEmbeddings(ctx context.Context, req llm.EmbeddingRequest) (*llm.EmbeddingResponse, error) {
+	// Convert our request to OpenAI format
+	openAIReq := OpenAIEmbeddingRequest{
+		Input:          req.Input,
+		Model:          req.Model,
+		EncodingFormat: req.EncodingFormat,
+		Dimensions:     req.Dimensions,
+		User:           req.User,
+	}
+
+	// Make the API call
+	resp, err := p.makeOpenAIRequest(ctx, "/embeddings", openAIReq)
+	if err != nil {
+		return nil, fmt.Errorf("OpenAI embeddings API request failed: %w", err)
+	}
+
+	// Parse the response
+	var openAIResp OpenAIEmbeddingResponse
+	if err := json.Unmarshal(resp, &openAIResp); err != nil {
+		return nil, fmt.Errorf("failed to parse OpenAI embeddings response: %w", err)
+	}
+
+	// Convert back to our format
+	return p.convertFromOpenAIEmbeddingResponse(openAIResp), nil
+}
+
+// Close implements the LLMProvider interface
+func (p *OpenAIProvider) Close() error {
+	// Nothing to clean up for HTTP client
+	return nil
+}
+
+// convertToOpenAIRequest converts our request format to OpenAI format
+func (p *OpenAIProvider) convertToOpenAIRequest(req llm.ChatCompletionRequest) OpenAIRequest {
+	openAIReq := OpenAIRequest{
+		Model:            req.Model,
+		MaxTokens:        req.MaxTokens,
+		Temperature:      req.Temperature,
+		TopP:             req.TopP,
+		N:                req.N,
+		Stream:           req.Stream,
+		Stop:             req.Stop,
+		PresencePenalty:  req.PresencePenalty,
+		FrequencyPenalty: req.FrequencyPenalty,
+		LogitBias:        req.LogitBias,
+		User:             req.User,
+		ToolChoice:       req.ToolChoice,
+		Seed:             req.Seed,
+	}
+
+	// Convert messages
+	openAIReq.Messages = make([]OpenAIMessage, len(req.Messages))
+	for i, msg := range req.Messages {
+		openAIReq.Messages[i] = OpenAIMessage{
+			Role:       string(msg.Role),
+			Content:    msg.Content,
+			ToolCallID: msg.ToolCallID,
+			Name:       msg.Name,
+		}
+
+		// Convert tool calls if present
+		if len(msg.ToolCalls) > 0 {
+			openAIReq.Messages[i].ToolCalls = make([]OpenAIToolCall, len(msg.ToolCalls))
+			for j, toolCall := range msg.ToolCalls {
+				openAIReq.Messages[i].ToolCalls[j] = OpenAIToolCall{
+					ID:   toolCall.ID,
+					Type: toolCall.Type,
+					Function: OpenAIFunction{
+						Name:        toolCall.Function.Name,
+						Description: toolCall.Function.Description,
+						Parameters:  toolCall.Function.Parameters,
+					},
+				}
+			}
+		}
+	}
+
+	// Convert tools if present
+	if len(req.Tools) > 0 {
+		openAIReq.Tools = make([]OpenAITool, len(req.Tools))
+		for i, tool := range req.Tools {
+			openAIReq.Tools[i] = OpenAITool{
+				Type: tool.Type,
+				Function: OpenAIFunction{
+					Name:        tool.Function.Name,
+					Description: tool.Function.Description,
+					Parameters:  tool.Function.Parameters,
+				},
+			}
+		}
+	}
+
+	// Convert response format if present
+	if req.ResponseFormat != nil {
+		openAIReq.ResponseFormat = &OpenAIResponseFormat{
+			Type: req.ResponseFormat.Type,
+		}
+	}
+
+	return openAIReq
+}
+
+// convertFromOpenAIResponse converts OpenAI response to our format
+func (p *OpenAIProvider) convertFromOpenAIResponse(openAIResp OpenAIResponse) *llm.ChatCompletionResponse {
+	resp := &llm.ChatCompletionResponse{
+		ID:                openAIResp.ID,
+		Object:            openAIResp.Object,
+		Created:           openAIResp.Created,
+		Model:             openAIResp.Model,
+		SystemFingerprint: openAIResp.SystemFingerprint,
+		Provider:          llm.ProviderOpenAI,
+		Usage: llm.Usage{
+			PromptTokens:     openAIResp.Usage.PromptTokens,
+			CompletionTokens: openAIResp.Usage.CompletionTokens,
+			TotalTokens:      openAIResp.Usage.TotalTokens,
+		},
+	}
+
+	// Convert choices
+	resp.Choices = make([]llm.ChatCompletionChoice, len(openAIResp.Choices))
+	for i, choice := range openAIResp.Choices {
+		resp.Choices[i] = llm.ChatCompletionChoice{
+			Index:        choice.Index,
+			FinishReason: choice.FinishReason,
+			Message: llm.Message{
+				Role:    llm.Role(choice.Message.Role),
+				Content: choice.Message.Content,
+				Name:    choice.Message.Name,
+			},
+		}
+
+		// Convert tool calls if present
+		if len(choice.Message.ToolCalls) > 0 {
+			resp.Choices[i].Message.ToolCalls = make([]llm.ToolCall, len(choice.Message.ToolCalls))
+			for j, toolCall := range choice.Message.ToolCalls {
+				resp.Choices[i].Message.ToolCalls[j] = llm.ToolCall{
+					ID:   toolCall.ID,
+					Type: toolCall.Type,
+					Function: llm.Function{
+						Name:        toolCall.Function.Name,
+						Description: toolCall.Function.Description,
+						Parameters:  toolCall.Function.Parameters,
+					},
+				}
+			}
+		}
+
+		// Convert logprobs if present
+		if choice.Logprobs != nil {
+			resp.Choices[i].Logprobs = &llm.Logprobs{
+				Content: make([]llm.LogprobContent, len(choice.Logprobs.Content)),
+			}
+			for j, content := range choice.Logprobs.Content {
+				resp.Choices[i].Logprobs.Content[j] = llm.LogprobContent{
+					Token:   content.Token,
+					Logprob: content.Logprob,
+					Bytes:   content.Bytes,
+				}
+				if len(content.TopLogprobs) > 0 {
+					resp.Choices[i].Logprobs.Content[j].TopLogprobs = make([]llm.TopLogprob, len(content.TopLogprobs))
+					for k, topLogprob := range content.TopLogprobs {
+						resp.Choices[i].Logprobs.Content[j].TopLogprobs[k] = llm.TopLogprob{
+							Token:   topLogprob.Token,
+							Logprob: topLogprob.Logprob,
+							Bytes:   topLogprob.Bytes,
+						}
+					}
+				}
+			}
+		}
+	}
+
+	return resp
+}
+
+// convertFromOpenAIEmbeddingResponse converts OpenAI embedding response to our format
+func (p *OpenAIProvider) convertFromOpenAIEmbeddingResponse(openAIResp OpenAIEmbeddingResponse) *llm.EmbeddingResponse {
+	resp := &llm.EmbeddingResponse{
+		Object:   openAIResp.Object,
+		Model:    openAIResp.Model,
+		Provider: llm.ProviderOpenAI,
+		Usage: llm.Usage{
+			PromptTokens:     openAIResp.Usage.PromptTokens,
+			CompletionTokens: openAIResp.Usage.CompletionTokens,
+			TotalTokens:      openAIResp.Usage.TotalTokens,
+		},
+	}
+
+	// Convert embedding data
+	resp.Data = make([]llm.Embedding, len(openAIResp.Data))
+	for i, data := range openAIResp.Data {
+		resp.Data[i] = llm.Embedding{
+			Object:    data.Object,
+			Embedding: data.Embedding,
+			Index:     data.Index,
+		}
+	}
+
+	return resp
+}
+
+// makeOpenAIRequest makes an HTTP request to the OpenAI API
+func (p *OpenAIProvider) makeOpenAIRequest(ctx context.Context, endpoint string, payload interface{}) ([]byte, error) {
+	// Prepare request body
+	jsonData, err := json.Marshal(payload)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal request: %w", err)
+	}
+
+	// Create HTTP request
+	url := p.config.BaseURL + endpoint
+	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
+	if err != nil {
+		return nil, fmt.Errorf("failed to create request: %w", err)
+	}
+
+	// Set headers
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Authorization", "Bearer "+p.config.APIKey)
+
+	// Add organization header if present
+	if org, ok := p.config.ExtraConfig["organization"].(string); ok && org != "" {
+		req.Header.Set("OpenAI-Organization", org)
+	}
+
+	// Make the request
+	resp, err := p.client.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("HTTP request failed: %w", err)
+	}
+	defer resp.Body.Close()
+
+	// Read response body
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read response body: %w", err)
+	}
+
+	// Check for errors
+	if resp.StatusCode != http.StatusOK {
+		var openAIErr OpenAIError
+		if err := json.Unmarshal(body, &openAIErr); err != nil {
+			return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
+		}
+		return nil, fmt.Errorf("OpenAI API error: %s (type: %s, code: %s)",
+			openAIErr.Error.Message, openAIErr.Error.Type, openAIErr.Error.Code)
+	}
+
+	return body, nil
+}
+
+// OpenAIFactory implements ProviderFactory for OpenAI
+type OpenAIFactory struct{}
+
+// CreateProvider creates a new OpenAI provider
+func (f *OpenAIFactory) CreateProvider(config llm.Config) (llm.LLMProvider, error) {
+	if config.Provider != llm.ProviderOpenAI {
+		return nil, fmt.Errorf("OpenAI factory cannot create provider: %s", config.Provider)
+	}
+
+	// Validate config
+	if err := llm.ValidateConfig(config); err != nil {
+		return nil, fmt.Errorf("invalid OpenAI config: %w", err)
+	}
+
+	// Merge with defaults
+	config = llm.MergeConfig(config)
+
+	return NewOpenAIProvider(config), nil
+}
+
+// SupportsProvider checks if this factory supports the given provider
+func (f *OpenAIFactory) SupportsProvider(provider llm.Provider) bool {
+	return provider == llm.ProviderOpenAI
+}
+
+// Register OpenAI provider with the default registry
+func init() {
+	llm.RegisterProvider(llm.ProviderOpenAI, &OpenAIFactory{})
+}
diff --git a/server/llm/openai/openai_example.go b/server/llm/openai/openai_example.go
new file mode 100644
index 0000000..6cd7dbf
--- /dev/null
+++ b/server/llm/openai/openai_example.go
@@ -0,0 +1,254 @@
+package openai
+
+import (
+	"context"
+	"fmt"
+	"log"
+	"time"
+
+	"github.com/iomodo/staff/llm"
+)
+
+// ExampleOpenAI demonstrates how to use the OpenAI implementation
+func ExampleOpenAI() {
+	// Create OpenAI configuration
+	config := llm.Config{
+		Provider: llm.ProviderOpenAI,
+		APIKey:   "your-openai-api-key-here", // Replace with your actual API key
+		BaseURL:  "https://api.openai.com/v1",
+		Timeout:  30 * time.Second,
+		ExtraConfig: map[string]interface{}{
+			"organization": "your-org-id", // Optional: your OpenAI organization ID
+		},
+	}
+
+	// Create the provider
+	provider, err := llm.CreateProvider(config)
+	if err != nil {
+		log.Fatalf("Failed to create OpenAI provider: %v", err)
+	}
+	defer provider.Close()
+
+	// Example 1: Basic chat completion
+	exampleOpenAIChatCompletion(provider)
+
+	// Example 2: Function calling
+	exampleOpenAIFunctionCalling(provider)
+
+	// Example 3: Embeddings
+	exampleOpenAIEmbeddings(provider)
+}
+
+// exampleOpenAIChatCompletion demonstrates basic chat completion with OpenAI
+func exampleOpenAIChatCompletion(provider llm.LLMProvider) {
+	fmt.Println("=== OpenAI Chat Completion Example ===")
+
+	req := llm.ChatCompletionRequest{
+		Model: "gpt-3.5-turbo",
+		Messages: []llm.Message{
+			{Role: llm.RoleSystem, Content: "You are a helpful assistant."},
+			{Role: llm.RoleUser, Content: "What is the capital of France?"},
+		},
+		MaxTokens:   &[]int{100}[0],
+		Temperature: &[]float64{0.7}[0],
+	}
+
+	ctx := context.Background()
+	resp, err := provider.ChatCompletion(ctx, req)
+	if err != nil {
+		log.Printf("Chat completion failed: %v", err)
+		return
+	}
+
+	fmt.Printf("Response: %s\n", resp.Choices[0].Message.Content)
+	fmt.Printf("Model: %s\n", resp.Model)
+	fmt.Printf("Provider: %s\n", resp.Provider)
+	fmt.Printf("Usage: %+v\n", resp.Usage)
+}
+
+// exampleOpenAIFunctionCalling demonstrates function calling with OpenAI
+func exampleOpenAIFunctionCalling(provider llm.LLMProvider) {
+	fmt.Println("\n=== OpenAI Function Calling Example ===")
+
+	// Define a function that can be called
+	tools := []llm.Tool{
+		{
+			Type: "function",
+			Function: llm.Function{
+				Name:        "get_weather",
+				Description: "Get the current weather for a location",
+				Parameters: map[string]interface{}{
+					"type": "object",
+					"properties": map[string]interface{}{
+						"location": map[string]interface{}{
+							"type":        "string",
+							"description": "The city and state, e.g. San Francisco, CA",
+						},
+						"unit": map[string]interface{}{
+							"type": "string",
+							"enum": []string{"celsius", "fahrenheit"},
+						},
+					},
+					"required": []string{"location"},
+				},
+			},
+		},
+	}
+
+	req := llm.ChatCompletionRequest{
+		Model: "gpt-3.5-turbo",
+		Messages: []llm.Message{
+			{Role: llm.RoleUser, Content: "What's the weather like in Tokyo?"},
+		},
+		Tools:       tools,
+		MaxTokens:   &[]int{150}[0],
+		Temperature: &[]float64{0.1}[0],
+	}
+
+	ctx := context.Background()
+	resp, err := provider.ChatCompletion(ctx, req)
+	if err != nil {
+		log.Printf("Function calling failed: %v", err)
+		return
+	}
+
+	// Check if the model wants to call a function
+	if len(resp.Choices[0].Message.ToolCalls) > 0 {
+		fmt.Println("Model wants to call a function:")
+		for _, toolCall := range resp.Choices[0].Message.ToolCalls {
+			fmt.Printf("Function: %s\n", toolCall.Function.Name)
+			fmt.Printf("Arguments: %+v\n", toolCall.Function.Parameters)
+		}
+	} else {
+		fmt.Printf("Response: %s\n", resp.Choices[0].Message.Content)
+	}
+}
+
+// exampleOpenAIEmbeddings demonstrates embedding generation with OpenAI
+func exampleOpenAIEmbeddings(provider llm.LLMProvider) {
+	fmt.Println("\n=== OpenAI Embeddings Example ===")
+
+	req := llm.EmbeddingRequest{
+		Input: "Hello, world! This is a test sentence for embeddings.",
+		Model: "text-embedding-ada-002",
+	}
+
+	ctx := context.Background()
+	resp, err := provider.CreateEmbeddings(ctx, req)
+	if err != nil {
+		log.Printf("Embeddings failed: %v", err)
+		return
+	}
+
+	fmt.Printf("Embedding dimensions: %d\n", len(resp.Data[0].Embedding))
+	fmt.Printf("First 5 values: %v\n", resp.Data[0].Embedding[:5])
+	fmt.Printf("Model: %s\n", resp.Model)
+	fmt.Printf("Provider: %s\n", resp.Provider)
+	fmt.Printf("Usage: %+v\n", resp.Usage)
+}
+
+// ExampleOpenAIWithErrorHandling demonstrates proper error handling
+func ExampleOpenAIWithErrorHandling() {
+	fmt.Println("\n=== OpenAI Error Handling Example ===")
+
+	// Test with invalid API key
+	config := llm.Config{
+		Provider: llm.ProviderOpenAI,
+		APIKey:   "invalid-key",
+		BaseURL:  "https://api.openai.com/v1",
+		Timeout:  30 * time.Second,
+	}
+
+	provider, err := llm.CreateProvider(config)
+	if err != nil {
+		fmt.Printf("Failed to create provider: %v\n", err)
+		return
+	}
+	defer provider.Close()
+
+	req := llm.ChatCompletionRequest{
+		Model: "gpt-3.5-turbo",
+		Messages: []llm.Message{
+			{Role: llm.RoleUser, Content: "Hello"},
+		},
+		MaxTokens: &[]int{50}[0],
+	}
+
+	ctx := context.Background()
+	resp, err := provider.ChatCompletion(ctx, req)
+	if err != nil {
+		fmt.Printf("Expected error with invalid API key: %v\n", err)
+		return
+	}
+
+	fmt.Printf("Unexpected success: %s\n", resp.Choices[0].Message.Content)
+}
+
+// ExampleOpenAIWithCustomConfig demonstrates custom configuration
+func ExampleOpenAIWithCustomConfig() {
+	fmt.Println("\n=== OpenAI Custom Configuration Example ===")
+
+	// Start with minimal config
+	customConfig := llm.Config{
+		Provider: llm.ProviderOpenAI,
+		APIKey:   "your-api-key",
+	}
+
+	// Merge with defaults
+	config := llm.MergeConfig(customConfig)
+
+	fmt.Printf("Merged config: BaseURL=%s, Timeout=%v, MaxRetries=%d\n",
+		config.BaseURL, config.Timeout, config.MaxRetries)
+
+	// Add extra configuration
+	config.ExtraConfig = map[string]interface{}{
+		"organization": "your-org-id",
+		"project":      "your-project-id",
+	}
+
+	provider, err := llm.CreateProvider(config)
+	if err != nil {
+		fmt.Printf("Failed to create provider: %v\n", err)
+		return
+	}
+	defer provider.Close()
+
+	fmt.Println("Provider created successfully with custom configuration")
+}
+
+// ExampleOpenAIWithValidation demonstrates request validation
+func ExampleOpenAIWithValidation() {
+	fmt.Println("\n=== OpenAI Validation Example ===")
+
+	// Valid request
+	validReq := llm.ChatCompletionRequest{
+		Model: "gpt-3.5-turbo",
+		Messages: []llm.Message{
+			{Role: llm.RoleUser, Content: "Hello"},
+		},
+		Temperature: &[]float64{0.5}[0],
+	}
+
+	if err := llm.ValidateChatCompletionRequest(validReq); err != nil {
+		fmt.Printf("Valid request validation failed: %v\n", err)
+	} else {
+		fmt.Println("Valid request passed validation")
+	}
+
+	// Invalid request (no model)
+	invalidReq := llm.ChatCompletionRequest{
+		Messages: []llm.Message{
+			{Role: llm.RoleUser, Content: "Hello"},
+		},
+	}
+
+	if err := llm.ValidateChatCompletionRequest(invalidReq); err != nil {
+		fmt.Printf("Invalid request correctly caught: %v\n", err)
+	} else {
+		fmt.Println("Invalid request should have failed validation")
+	}
+
+	// Test token estimation
+	tokens := llm.EstimateTokens(validReq)
+	fmt.Printf("Estimated tokens for request: %d\n", tokens)
+}
diff --git a/server/llm/openai/openai_test.go b/server/llm/openai/openai_test.go
new file mode 100644
index 0000000..8896405
--- /dev/null
+++ b/server/llm/openai/openai_test.go
@@ -0,0 +1,428 @@
+package openai
+
+import (
+	"testing"
+	"time"
+
+	"github.com/iomodo/staff/llm"
+)
+
+func TestOpenAIProvider_Interface(t *testing.T) {
+	// Test that OpenAIProvider implements LLMProvider interface
+	var _ llm.LLMProvider = (*OpenAIProvider)(nil)
+}
+
+func TestOpenAIFactory_CreateProvider(t *testing.T) {
+	factory := &OpenAIFactory{}
+
+	// Test valid config
+	config := llm.Config{
+		Provider: llm.ProviderOpenAI,
+		APIKey:   "test-key",
+		BaseURL:  "https://api.openai.com/v1",
+		Timeout:  30 * time.Second,
+	}
+
+	provider, err := factory.CreateProvider(config)
+	if err != nil {
+		t.Fatalf("Failed to create provider: %v", err)
+	}
+
+	if provider == nil {
+		t.Fatal("Provider should not be nil")
+	}
+
+	// Test invalid provider
+	invalidConfig := llm.Config{
+		Provider: llm.ProviderClaude,
+		APIKey:   "test-key",
+	}
+
+	_, err = factory.CreateProvider(invalidConfig)
+	if err == nil {
+		t.Fatal("Should fail with invalid provider")
+	}
+
+	// Test missing API key
+	noKeyConfig := llm.Config{
+		Provider: llm.ProviderOpenAI,
+		BaseURL:  "https://api.openai.com/v1",
+	}
+
+	_, err = factory.CreateProvider(noKeyConfig)
+	if err == nil {
+		t.Fatal("Should fail with missing API key")
+	}
+}
+
+func TestOpenAIFactory_SupportsProvider(t *testing.T) {
+	factory := &OpenAIFactory{}
+
+	if !factory.SupportsProvider(llm.ProviderOpenAI) {
+		t.Fatal("Should support OpenAI provider")
+	}
+
+	if factory.SupportsProvider(llm.ProviderClaude) {
+		t.Fatal("Should not support Claude provider")
+	}
+}
+
+func TestOpenAIProvider_ConvertRequest(t *testing.T) {
+	provider := &OpenAIProvider{
+		config: llm.Config{
+			Provider: llm.ProviderOpenAI,
+			APIKey:   "test-key",
+			BaseURL:  "https://api.openai.com/v1",
+		},
+	}
+
+	// Test basic request conversion
+	req := llm.ChatCompletionRequest{
+		Model: "gpt-3.5-turbo",
+		Messages: []llm.Message{
+			{Role: llm.RoleUser, Content: "Hello"},
+		},
+		MaxTokens:   &[]int{100}[0],
+		Temperature: &[]float64{0.7}[0],
+	}
+
+	openAIReq := provider.convertToOpenAIRequest(req)
+
+	if openAIReq.Model != "gpt-3.5-turbo" {
+		t.Errorf("Expected model 'gpt-3.5-turbo', got '%s'", openAIReq.Model)
+	}
+
+	if len(openAIReq.Messages) != 1 {
+		t.Errorf("Expected 1 message, got %d", len(openAIReq.Messages))
+	}
+
+	if openAIReq.Messages[0].Role != "user" {
+		t.Errorf("Expected role 'user', got '%s'", openAIReq.Messages[0].Role)
+	}
+
+	if openAIReq.Messages[0].Content != "Hello" {
+		t.Errorf("Expected content 'Hello', got '%s'", openAIReq.Messages[0].Content)
+	}
+
+	if *openAIReq.MaxTokens != 100 {
+		t.Errorf("Expected max_tokens 100, got %d", *openAIReq.MaxTokens)
+	}
+
+	if *openAIReq.Temperature != 0.7 {
+		t.Errorf("Expected temperature 0.7, got %f", *openAIReq.Temperature)
+	}
+}
+
+func TestOpenAIProvider_ConvertResponse(t *testing.T) {
+	provider := &OpenAIProvider{
+		config: llm.Config{
+			Provider: llm.ProviderOpenAI,
+			APIKey:   "test-key",
+			BaseURL:  "https://api.openai.com/v1",
+		},
+	}
+
+	// Test basic response conversion
+	openAIResp := OpenAIResponse{
+		ID:      "test-id",
+		Object:  "chat.completion",
+		Created: 1234567890,
+		Model:   "gpt-3.5-turbo",
+		Choices: []OpenAIChoice{
+			{
+				Index: 0,
+				Message: OpenAIMessage{
+					Role:    "assistant",
+					Content: "Hello! How can I help you?",
+				},
+				FinishReason: "stop",
+			},
+		},
+		Usage: OpenAIUsage{
+			PromptTokens:     10,
+			CompletionTokens: 20,
+			TotalTokens:      30,
+		},
+	}
+
+	resp := provider.convertFromOpenAIResponse(openAIResp)
+
+	if resp.ID != "test-id" {
+		t.Errorf("Expected ID 'test-id', got '%s'", resp.ID)
+	}
+
+	if resp.Model != "gpt-3.5-turbo" {
+		t.Errorf("Expected model 'gpt-3.5-turbo', got '%s'", resp.Model)
+	}
+
+	if resp.Provider != llm.ProviderOpenAI {
+		t.Errorf("Expected provider OpenAI, got %s", resp.Provider)
+	}
+
+	if len(resp.Choices) != 1 {
+		t.Errorf("Expected 1 choice, got %d", len(resp.Choices))
+	}
+
+	if resp.Choices[0].Message.Role != llm.RoleAssistant {
+		t.Errorf("Expected role assistant, got %s", resp.Choices[0].Message.Role)
+	}
+
+	if resp.Choices[0].Message.Content != "Hello! How can I help you?" {
+		t.Errorf("Expected content 'Hello! How can I help you?', got '%s'", resp.Choices[0].Message.Content)
+	}
+
+	if resp.Usage.PromptTokens != 10 {
+		t.Errorf("Expected prompt tokens 10, got %d", resp.Usage.PromptTokens)
+	}
+
+	if resp.Usage.CompletionTokens != 20 {
+		t.Errorf("Expected completion tokens 20, got %d", resp.Usage.CompletionTokens)
+	}
+
+	if resp.Usage.TotalTokens != 30 {
+		t.Errorf("Expected total tokens 30, got %d", resp.Usage.TotalTokens)
+	}
+}
+
+func TestOpenAIProvider_ConvertRequestWithTools(t *testing.T) {
+	provider := &OpenAIProvider{
+		config: llm.Config{
+			Provider: llm.ProviderOpenAI,
+			APIKey:   "test-key",
+			BaseURL:  "https://api.openai.com/v1",
+		},
+	}
+
+	// Test request with tools
+	tools := []llm.Tool{
+		{
+			Type: "function",
+			Function: llm.Function{
+				Name:        "get_weather",
+				Description: "Get weather information",
+				Parameters: map[string]interface{}{
+					"type": "object",
+					"properties": map[string]interface{}{
+						"location": map[string]interface{}{
+							"type": "string",
+						},
+					},
+				},
+			},
+		},
+	}
+
+	req := llm.ChatCompletionRequest{
+		Model: "gpt-3.5-turbo",
+		Messages: []llm.Message{
+			{Role: llm.RoleUser, Content: "What's the weather like?"},
+		},
+		Tools: tools,
+	}
+
+	openAIReq := provider.convertToOpenAIRequest(req)
+
+	if len(openAIReq.Tools) != 1 {
+		t.Errorf("Expected 1 tool, got %d", len(openAIReq.Tools))
+	}
+
+	if openAIReq.Tools[0].Type != "function" {
+		t.Errorf("Expected tool type 'function', got '%s'", openAIReq.Tools[0].Type)
+	}
+
+	if openAIReq.Tools[0].Function.Name != "get_weather" {
+		t.Errorf("Expected function name 'get_weather', got '%s'", openAIReq.Tools[0].Function.Name)
+	}
+}
+
+func TestOpenAIProvider_ConvertResponseWithToolCalls(t *testing.T) {
+	provider := &OpenAIProvider{
+		config: llm.Config{
+			Provider: llm.ProviderOpenAI,
+			APIKey:   "test-key",
+			BaseURL:  "https://api.openai.com/v1",
+		},
+	}
+
+	// Test response with tool calls
+	openAIResp := OpenAIResponse{
+		ID:     "test-id",
+		Object: "chat.completion",
+		Model:  "gpt-3.5-turbo",
+		Choices: []OpenAIChoice{
+			{
+				Index: 0,
+				Message: OpenAIMessage{
+					Role: "assistant",
+					ToolCalls: []OpenAIToolCall{
+						{
+							ID:   "call_123",
+							Type: "function",
+							Function: OpenAIFunction{
+								Name: "get_weather",
+								Parameters: map[string]interface{}{
+									"location": "Tokyo",
+								},
+							},
+						},
+					},
+				},
+				FinishReason: "tool_calls",
+			},
+		},
+		Usage: OpenAIUsage{
+			PromptTokens:     10,
+			CompletionTokens: 20,
+			TotalTokens:      30,
+		},
+	}
+
+	resp := provider.convertFromOpenAIResponse(openAIResp)
+
+	if len(resp.Choices[0].Message.ToolCalls) != 1 {
+		t.Errorf("Expected 1 tool call, got %d", len(resp.Choices[0].Message.ToolCalls))
+	}
+
+	if resp.Choices[0].Message.ToolCalls[0].ID != "call_123" {
+		t.Errorf("Expected tool call ID 'call_123', got '%s'", resp.Choices[0].Message.ToolCalls[0].ID)
+	}
+
+	if resp.Choices[0].Message.ToolCalls[0].Function.Name != "get_weather" {
+		t.Errorf("Expected function name 'get_weather', got '%s'", resp.Choices[0].Message.ToolCalls[0].Function.Name)
+	}
+
+	if resp.Choices[0].FinishReason != "tool_calls" {
+		t.Errorf("Expected finish reason 'tool_calls', got '%s'", resp.Choices[0].FinishReason)
+	}
+}
+
+func TestOpenAIProvider_ConvertEmbeddingRequest(t *testing.T) {
+	req := llm.EmbeddingRequest{
+		Input: "Hello, world!",
+		Model: "text-embedding-ada-002",
+		User:  "test-user",
+	}
+
+	// The conversion is done inline in CreateEmbeddings, so we'll test the structure
+	if req.Input != "Hello, world!" {
+		t.Errorf("Expected input 'Hello, world!', got '%v'", req.Input)
+	}
+
+	if req.Model != "text-embedding-ada-002" {
+		t.Errorf("Expected model 'text-embedding-ada-002', got '%s'", req.Model)
+	}
+
+	if req.User != "test-user" {
+		t.Errorf("Expected user 'test-user', got '%s'", req.User)
+	}
+}
+
+func TestOpenAIProvider_ConvertEmbeddingResponse(t *testing.T) {
+	provider := &OpenAIProvider{
+		config: llm.Config{
+			Provider: llm.ProviderOpenAI,
+			APIKey:   "test-key",
+			BaseURL:  "https://api.openai.com/v1",
+		},
+	}
+
+	// Test embedding response conversion
+	openAIResp := OpenAIEmbeddingResponse{
+		Object: "list",
+		Model:  "text-embedding-ada-002",
+		Data: []OpenAIEmbeddingData{
+			{
+				Object:    "embedding",
+				Embedding: []float64{0.1, 0.2, 0.3},
+				Index:     0,
+			},
+		},
+		Usage: OpenAIUsage{
+			PromptTokens:     5,
+			CompletionTokens: 0,
+			TotalTokens:      5,
+		},
+	}
+
+	resp := provider.convertFromOpenAIEmbeddingResponse(openAIResp)
+
+	if resp.Object != "list" {
+		t.Errorf("Expected object 'list', got '%s'", resp.Object)
+	}
+
+	if resp.Model != "text-embedding-ada-002" {
+		t.Errorf("Expected model 'text-embedding-ada-002', got '%s'", resp.Model)
+	}
+
+	if resp.Provider != llm.ProviderOpenAI {
+		t.Errorf("Expected provider OpenAI, got %s", resp.Provider)
+	}
+
+	if len(resp.Data) != 1 {
+		t.Errorf("Expected 1 embedding, got %d", len(resp.Data))
+	}
+
+	if len(resp.Data[0].Embedding) != 3 {
+		t.Errorf("Expected embedding dimension 3, got %d", len(resp.Data[0].Embedding))
+	}
+
+	if resp.Data[0].Embedding[0] != 0.1 {
+		t.Errorf("Expected first embedding value 0.1, got %f", resp.Data[0].Embedding[0])
+	}
+}
+
+func TestOpenAIProvider_Close(t *testing.T) {
+	provider := &OpenAIProvider{
+		config: llm.Config{
+			Provider: llm.ProviderOpenAI,
+			APIKey:   "test-key",
+			BaseURL:  "https://api.openai.com/v1",
+		},
+	}
+
+	// Test that Close doesn't return an error
+	err := provider.Close()
+	if err != nil {
+		t.Errorf("Close should not return an error: %v", err)
+	}
+}
+
+func TestOpenAIProvider_Integration(t *testing.T) {
+	// This test would require a real API key and would make actual API calls
+	// It's commented out to avoid making real API calls during testing
+	/*
+		config := Config{
+			Provider: ProviderOpenAI,
+			APIKey:   "your-real-api-key",
+			BaseURL:  "https://api.openai.com/v1",
+			Timeout:  30 * time.Second,
+		}
+
+		provider, err := CreateProvider(config)
+		if err != nil {
+			t.Fatalf("Failed to create provider: %v", err)
+		}
+		defer provider.Close()
+
+		req := ChatCompletionRequest{
+			Model: "gpt-3.5-turbo",
+			Messages: []Message{
+				{Role: RoleUser, Content: "Say hello!"},
+			},
+			MaxTokens: &[]int{50}[0],
+		}
+
+		resp, err := provider.ChatCompletion(context.Background(), req)
+		if err != nil {
+			t.Fatalf("Chat completion failed: %v", err)
+		}
+
+		if len(resp.Choices) == 0 {
+			t.Fatal("Expected at least one choice")
+		}
+
+		if resp.Choices[0].Message.Content == "" {
+			t.Fatal("Expected non-empty response content")
+		}
+	*/
+}