Add OpenAI implementation

Change-Id: Iea3191cf538959002e6ae095857a7aa6126b3e2f
diff --git a/server/llm/openai/openai.go b/server/llm/openai/openai.go
new file mode 100644
index 0000000..c513c53
--- /dev/null
+++ b/server/llm/openai/openai.go
@@ -0,0 +1,468 @@
+package openai
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+
+	"github.com/iomodo/staff/llm"
+)
+
+// OpenAIProvider implements the LLMProvider interface for OpenAI.
+//
+// It keeps the llm.Config it was built with (makeOpenAIRequest reads BaseURL,
+// APIKey and ExtraConfig from it) and a single *http.Client, created once by
+// NewOpenAIProvider using config.Timeout, that is reused for every API call.
+type OpenAIProvider struct {
+	config llm.Config
+	client *http.Client
+}
+
+// OpenAIRequest represents the OpenAI API request format: the JSON body that
+// ChatCompletion POSTs to the /chat/completions endpoint.
+//
+// Every field except Model and Messages is tagged omitempty, and numeric and
+// boolean options are pointers, so anything the caller left unset is omitted
+// from the payload instead of being sent as a zero value. ToolChoice is
+// forwarded exactly as the caller supplied it and is not validated here.
+//
+// NOTE(review): ChatCompletion decodes the reply as a single JSON document,
+// so a request with Stream set to true is not handled by the response path.
+type OpenAIRequest struct {
+	Model            string                `json:"model"`
+	Messages         []OpenAIMessage       `json:"messages"`
+	MaxTokens        *int                  `json:"max_tokens,omitempty"`
+	Temperature      *float64              `json:"temperature,omitempty"`
+	TopP             *float64              `json:"top_p,omitempty"`
+	N                *int                  `json:"n,omitempty"`
+	Stream           *bool                 `json:"stream,omitempty"`
+	Stop             []string              `json:"stop,omitempty"`
+	PresencePenalty  *float64              `json:"presence_penalty,omitempty"`
+	FrequencyPenalty *float64              `json:"frequency_penalty,omitempty"`
+	LogitBias        map[string]int        `json:"logit_bias,omitempty"`
+	User             string                `json:"user,omitempty"`
+	Tools            []OpenAITool          `json:"tools,omitempty"`
+	ToolChoice       interface{}           `json:"tool_choice,omitempty"`
+	ResponseFormat   *OpenAIResponseFormat `json:"response_format,omitempty"`
+	Seed             *int64                `json:"seed,omitempty"`
+}
+
+// OpenAIMessage represents a message in OpenAI format. It is used for
+// OpenAIRequest.Messages as well as OpenAIChoice.Message and OpenAIChoice.Delta.
+//
+// Content has no omitempty, so it is always sent (possibly as ""); a JSON
+// null content in a response decodes to "". ToolCalls, ToolCallID and Name
+// are omitted from the payload when empty.
+type OpenAIMessage struct {
+	Role       string           `json:"role"`
+	Content    string           `json:"content"`
+	ToolCalls  []OpenAIToolCall `json:"tool_calls,omitempty"`
+	ToolCallID string           `json:"tool_call_id,omitempty"`
+	Name       string           `json:"name,omitempty"`
+}
+
+// OpenAIToolCall represents a tool call in OpenAI format.
+//
+// On the wire OpenAI encodes a call's arguments as a JSON *string* under
+// "function.arguments", not as a "parameters" object. MarshalJSON and
+// UnmarshalJSON translate between that string and Function.Parameters so the
+// rest of the package can keep treating the arguments as a decoded map.
+// Function.Description is not part of a tool call and is dropped on encode.
+type OpenAIToolCall struct {
+	ID       string         `json:"id"`
+	Type     string         `json:"type"`
+	Function OpenAIFunction `json:"function"`
+}
+
+// openAIToolCallWire is the exact JSON shape of a tool call as exchanged
+// with the chat completions API.
+type openAIToolCallWire struct {
+	ID       string `json:"id"`
+	Type     string `json:"type"`
+	Function struct {
+		Name      string `json:"name"`
+		Arguments string `json:"arguments"`
+	} `json:"function"`
+}
+
+// MarshalJSON encodes the call with Function.Parameters serialized into the
+// "arguments" string. A nil Parameters map is sent as "{}" so the field the
+// API expects is always present.
+func (tc OpenAIToolCall) MarshalJSON() ([]byte, error) {
+	var wire openAIToolCallWire
+	wire.ID = tc.ID
+	wire.Type = tc.Type
+	wire.Function.Name = tc.Function.Name
+	wire.Function.Arguments = "{}"
+	if tc.Function.Parameters != nil {
+		args, err := json.Marshal(tc.Function.Parameters)
+		if err != nil {
+			return nil, fmt.Errorf("encoding arguments of tool call %q: %w", tc.ID, err)
+		}
+		wire.Function.Arguments = string(args)
+	}
+	return json.Marshal(wire)
+}
+
+// UnmarshalJSON decodes a tool call and parses its "arguments" string into
+// Function.Parameters. An empty arguments string leaves Parameters nil;
+// arguments that fail to decode as a JSON object return an error instead of
+// being silently discarded. A JSON null is a no-op, as encoding/json expects.
+func (tc *OpenAIToolCall) UnmarshalJSON(data []byte) error {
+	if string(data) == "null" {
+		return nil
+	}
+	var wire openAIToolCallWire
+	if err := json.Unmarshal(data, &wire); err != nil {
+		return err
+	}
+	var params map[string]interface{}
+	if wire.Function.Arguments != "" {
+		if err := json.Unmarshal([]byte(wire.Function.Arguments), &params); err != nil {
+			return fmt.Errorf("decoding arguments of tool call %q: %w", wire.ID, err)
+		}
+	}
+	*tc = OpenAIToolCall{
+		ID:       wire.ID,
+		Type:     wire.Type,
+		Function: OpenAIFunction{Name: wire.Function.Name, Parameters: params},
+	}
+	return nil
+}
+
+// OpenAIFunction represents a function in OpenAI format. As a tool
+// definition (OpenAITool.Function) it is encoded with a "parameters" object;
+// inside an OpenAIToolCall it is re-encoded by OpenAIToolCall's JSON methods,
+// which carry Parameters as the "arguments" string instead.
+type OpenAIFunction struct {
+	Name        string                 `json:"name"`
+	Description string                 `json:"description,omitempty"`
+	Parameters  map[string]interface{} `json:"parameters,omitempty"`
+}
+
+// OpenAITool represents a tool in OpenAI format: one entry of
+// OpenAIRequest.Tools. convertToOpenAIRequest copies Type and the function's
+// Name, Description and Parameters from the caller's tool definition.
+type OpenAITool struct {
+	Type     string         `json:"type"`
+	Function OpenAIFunction `json:"function"`
+}
+
+// OpenAIResponseFormat represents response format in OpenAI format. Only the
+// Type field is forwarded from llm.ChatCompletionRequest.ResponseFormat.
+type OpenAIResponseFormat struct {
+	Type string `json:"type"`
+}
+
+// OpenAIResponse represents the OpenAI API response format: the JSON body
+// returned by /chat/completions, decoded by ChatCompletion and translated by
+// convertFromOpenAIResponse.
+type OpenAIResponse struct {
+	ID                string         `json:"id"`
+	Object            string         `json:"object"`
+	Created           int64          `json:"created"`
+	Model             string         `json:"model"`
+	SystemFingerprint string         `json:"system_fingerprint,omitempty"`
+	Choices           []OpenAIChoice `json:"choices"`
+	Usage             OpenAIUsage    `json:"usage"`
+}
+
+// OpenAIChoice represents a choice in OpenAI response. Logprobs stays nil
+// when the response carries no logprobs. Delta is decoded if present but is
+// not read by convertFromOpenAIResponse (presumably it is only populated in
+// streamed chunks — confirm if streaming support is added).
+type OpenAIChoice struct {
+	Index        int             `json:"index"`
+	Message      OpenAIMessage   `json:"message"`
+	Logprobs     *OpenAILogprobs `json:"logprobs,omitempty"`
+	FinishReason string          `json:"finish_reason"`
+	Delta        *OpenAIMessage  `json:"delta,omitempty"`
+}
+
+// OpenAILogprobs represents log probabilities in OpenAI format; Content holds
+// one OpenAILogprobContent per entry reported by the API.
+type OpenAILogprobs struct {
+	Content []OpenAILogprobContent `json:"content,omitempty"`
+}
+
+// OpenAILogprobContent represents log probability content in OpenAI format.
+// Token, Logprob and Bytes are copied unchanged into llm.LogprobContent;
+// TopLogprobs lists the additional candidate tokens the API reports for this
+// entry.
+type OpenAILogprobContent struct {
+	Token       string             `json:"token"`
+	Logprob     float64            `json:"logprob"`
+	Bytes       []int              `json:"bytes,omitempty"`
+	TopLogprobs []OpenAITopLogprob `json:"top_logprobs,omitempty"`
+}
+
+// OpenAITopLogprob represents a top log probability in OpenAI format; it is
+// copied field for field into llm.TopLogprob.
+type OpenAITopLogprob struct {
+	Token   string  `json:"token"`
+	Logprob float64 `json:"logprob"`
+	Bytes   []int   `json:"bytes,omitempty"`
+}
+
+// OpenAIUsage represents usage information in OpenAI format. Both the chat
+// and the embeddings conversions copy all three counters into llm.Usage.
+type OpenAIUsage struct {
+	PromptTokens     int `json:"prompt_tokens"`
+	CompletionTokens int `json:"completion_tokens"`
+	TotalTokens      int `json:"total_tokens"`
+}
+
+// OpenAIEmbeddingRequest represents OpenAI embedding request: the JSON body
+// CreateEmbeddings POSTs to /embeddings. Input is forwarded exactly as given
+// in llm.EmbeddingRequest.Input; EncodingFormat, Dimensions and User are
+// omitted from the payload when unset.
+type OpenAIEmbeddingRequest struct {
+	Input          interface{} `json:"input"`
+	Model          string      `json:"model"`
+	EncodingFormat string      `json:"encoding_format,omitempty"`
+	Dimensions     *int        `json:"dimensions,omitempty"`
+	User           string      `json:"user,omitempty"`
+}
+
+// OpenAIEmbeddingResponse represents OpenAI embedding response, decoded by
+// CreateEmbeddings and translated by convertFromOpenAIEmbeddingResponse.
+type OpenAIEmbeddingResponse struct {
+	Object string                `json:"object"`
+	Data   []OpenAIEmbeddingData `json:"data"`
+	Usage  OpenAIUsage           `json:"usage"`
+	Model  string                `json:"model"`
+}
+
+// OpenAIEmbeddingData represents embedding data in OpenAI format. Embedding
+// decodes only a JSON array of numbers; if the vector arrives in another
+// encoding (e.g. a string, as the API documents for encoding_format
+// "base64"), json.Unmarshal of the enclosing response fails.
+type OpenAIEmbeddingData struct {
+	Object    string    `json:"object"`
+	Embedding []float64 `json:"embedding"`
+	Index     int       `json:"index"`
+}
+
+// OpenAIError represents an error from OpenAI API: the JSON object that
+// makeOpenAIRequest tries to decode from an unsuccessful response body.
+// A JSON null in any of the string fields decodes to "". If a field arrives
+// with a non-string JSON type, decoding fails and makeOpenAIRequest falls back
+// to reporting the raw body.
+type OpenAIError struct {
+	Error struct {
+		Message string `json:"message"`
+		Type    string `json:"type"`
+		Code    string `json:"code,omitempty"`
+		Param   string `json:"param,omitempty"`
+	} `json:"error"`
+}
+
+// NewOpenAIProvider returns a provider bound to config. Its HTTP client is
+// built here, once, with config.Timeout as the client-wide request time limit,
+// and is shared by every later API call made through the provider.
+func NewOpenAIProvider(config llm.Config) *OpenAIProvider {
+	p := &OpenAIProvider{config: config}
+	p.client = &http.Client{Timeout: config.Timeout}
+	return p
+}
+
+// ChatCompletion implements the LLMProvider interface for OpenAI.
+//
+// It sends req to the /chat/completions endpoint and converts the decoded
+// reply back into the provider-neutral llm.ChatCompletionResponse. Only
+// non-streaming completions are supported: the response body is decoded as a
+// single JSON document, so a request with Stream set to true is rejected
+// before any network call is made.
+func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req llm.ChatCompletionRequest) (*llm.ChatCompletionResponse, error) {
+	// A streamed reply is a sequence of server-sent events; it would only fail
+	// later in json.Unmarshal, after the request had already been sent.
+	if req.Stream != nil && *req.Stream {
+		return nil, fmt.Errorf("OpenAI streaming is not supported by ChatCompletion; unset Stream")
+	}
+
+	openAIReq := p.convertToOpenAIRequest(req)
+
+	resp, err := p.makeOpenAIRequest(ctx, "/chat/completions", openAIReq)
+	if err != nil {
+		return nil, fmt.Errorf("OpenAI API request failed: %w", err)
+	}
+
+	var openAIResp OpenAIResponse
+	if err := json.Unmarshal(resp, &openAIResp); err != nil {
+		return nil, fmt.Errorf("failed to parse OpenAI response: %w", err)
+	}
+
+	return p.convertFromOpenAIResponse(openAIResp), nil
+}
+
+// CreateEmbeddings implements the LLMProvider interface for OpenAI.
+//
+// It forwards req to the /embeddings endpoint and converts the reply into
+// llm.EmbeddingResponse. Vectors are decoded as JSON number arrays only, so a
+// request asking for encoding_format "base64" is rejected up front instead of
+// failing to parse after the (billed) request has been sent.
+func (p *OpenAIProvider) CreateEmbeddings(ctx context.Context, req llm.EmbeddingRequest) (*llm.EmbeddingResponse, error) {
+	if req.EncodingFormat == "base64" {
+		return nil, fmt.Errorf("OpenAI embeddings encoding_format %q is not supported; use \"float\" or leave it empty", req.EncodingFormat)
+	}
+
+	openAIReq := OpenAIEmbeddingRequest{
+		Input:          req.Input,
+		Model:          req.Model,
+		EncodingFormat: req.EncodingFormat,
+		Dimensions:     req.Dimensions,
+		User:           req.User,
+	}
+
+	resp, err := p.makeOpenAIRequest(ctx, "/embeddings", openAIReq)
+	if err != nil {
+		return nil, fmt.Errorf("OpenAI embeddings API request failed: %w", err)
+	}
+
+	var openAIResp OpenAIEmbeddingResponse
+	if err := json.Unmarshal(resp, &openAIResp); err != nil {
+		return nil, fmt.Errorf("failed to parse OpenAI embeddings response: %w", err)
+	}
+
+	return p.convertFromOpenAIEmbeddingResponse(openAIResp), nil
+}
+
+// Close implements the LLMProvider interface. The provider owns nothing that
+// needs releasing (http.Client has no Close method), so it always returns nil.
+func (p *OpenAIProvider) Close() error { return nil }
+
+// convertToOpenAIRequest translates the provider-neutral request into the
+// OpenAI wire format.
+//
+// Model, sampling settings, stop sequences, logit bias, user, tool choice and
+// seed are copied through unchanged. Messages is always a non-nil slice;
+// Tools, per-message ToolCalls and ResponseFormat are only set when the
+// caller supplied them, so omitempty keeps them out of the payload otherwise.
+func (p *OpenAIProvider) convertToOpenAIRequest(req llm.ChatCompletionRequest) OpenAIRequest {
+	messages := make([]OpenAIMessage, 0, len(req.Messages))
+	for _, m := range req.Messages {
+		converted := OpenAIMessage{
+			Role:       string(m.Role),
+			Content:    m.Content,
+			ToolCallID: m.ToolCallID,
+			Name:       m.Name,
+		}
+		if n := len(m.ToolCalls); n > 0 {
+			converted.ToolCalls = make([]OpenAIToolCall, 0, n)
+			for _, call := range m.ToolCalls {
+				converted.ToolCalls = append(converted.ToolCalls, OpenAIToolCall{
+					ID:   call.ID,
+					Type: call.Type,
+					Function: OpenAIFunction{
+						Name:        call.Function.Name,
+						Description: call.Function.Description,
+						Parameters:  call.Function.Parameters,
+					},
+				})
+			}
+		}
+		messages = append(messages, converted)
+	}
+
+	var tools []OpenAITool
+	if n := len(req.Tools); n > 0 {
+		tools = make([]OpenAITool, 0, n)
+		for _, t := range req.Tools {
+			tools = append(tools, OpenAITool{
+				Type: t.Type,
+				Function: OpenAIFunction{
+					Name:        t.Function.Name,
+					Description: t.Function.Description,
+					Parameters:  t.Function.Parameters,
+				},
+			})
+		}
+	}
+
+	var format *OpenAIResponseFormat
+	if rf := req.ResponseFormat; rf != nil {
+		format = &OpenAIResponseFormat{Type: rf.Type}
+	}
+
+	return OpenAIRequest{
+		Model:            req.Model,
+		Messages:         messages,
+		MaxTokens:        req.MaxTokens,
+		Temperature:      req.Temperature,
+		TopP:             req.TopP,
+		N:                req.N,
+		Stream:           req.Stream,
+		Stop:             req.Stop,
+		PresencePenalty:  req.PresencePenalty,
+		FrequencyPenalty: req.FrequencyPenalty,
+		LogitBias:        req.LogitBias,
+		User:             req.User,
+		Tools:            tools,
+		ToolChoice:       req.ToolChoice,
+		ResponseFormat:   format,
+		Seed:             req.Seed,
+	}
+}
+
+// convertFromOpenAIResponse maps an OpenAI chat completion reply onto the
+// provider-neutral llm.ChatCompletionResponse, tagged with llm.ProviderOpenAI.
+//
+// Choices is always a non-nil slice. For each choice, Message.ToolCalls stays
+// nil when the model made no tool calls, Logprobs is set only when the API
+// returned it, and TopLogprobs only when non-empty. OpenAIChoice.Delta is not
+// read.
+func (p *OpenAIProvider) convertFromOpenAIResponse(openAIResp OpenAIResponse) *llm.ChatCompletionResponse {
+	choices := make([]llm.ChatCompletionChoice, 0, len(openAIResp.Choices))
+	for _, c := range openAIResp.Choices {
+		msg := llm.Message{
+			Role:    llm.Role(c.Message.Role),
+			Content: c.Message.Content,
+			Name:    c.Message.Name,
+		}
+		if n := len(c.Message.ToolCalls); n > 0 {
+			msg.ToolCalls = make([]llm.ToolCall, 0, n)
+			for _, call := range c.Message.ToolCalls {
+				msg.ToolCalls = append(msg.ToolCalls, llm.ToolCall{
+					ID:   call.ID,
+					Type: call.Type,
+					Function: llm.Function{
+						Name:        call.Function.Name,
+						Description: call.Function.Description,
+						Parameters:  call.Function.Parameters,
+					},
+				})
+			}
+		}
+
+		choice := llm.ChatCompletionChoice{
+			Index:        c.Index,
+			FinishReason: c.FinishReason,
+			Message:      msg,
+		}
+
+		if lp := c.Logprobs; lp != nil {
+			tokens := make([]llm.LogprobContent, 0, len(lp.Content))
+			for _, tok := range lp.Content {
+				entry := llm.LogprobContent{
+					Token:   tok.Token,
+					Logprob: tok.Logprob,
+					Bytes:   tok.Bytes,
+				}
+				if n := len(tok.TopLogprobs); n > 0 {
+					entry.TopLogprobs = make([]llm.TopLogprob, 0, n)
+					for _, alt := range tok.TopLogprobs {
+						entry.TopLogprobs = append(entry.TopLogprobs, llm.TopLogprob{
+							Token:   alt.Token,
+							Logprob: alt.Logprob,
+							Bytes:   alt.Bytes,
+						})
+					}
+				}
+				tokens = append(tokens, entry)
+			}
+			choice.Logprobs = &llm.Logprobs{Content: tokens}
+		}
+
+		choices = append(choices, choice)
+	}
+
+	return &llm.ChatCompletionResponse{
+		ID:                openAIResp.ID,
+		Object:            openAIResp.Object,
+		Created:           openAIResp.Created,
+		Model:             openAIResp.Model,
+		SystemFingerprint: openAIResp.SystemFingerprint,
+		Provider:          llm.ProviderOpenAI,
+		Choices:           choices,
+		Usage: llm.Usage{
+			PromptTokens:     openAIResp.Usage.PromptTokens,
+			CompletionTokens: openAIResp.Usage.CompletionTokens,
+			TotalTokens:      openAIResp.Usage.TotalTokens,
+		},
+	}
+}
+
+// convertFromOpenAIEmbeddingResponse maps an OpenAI embeddings reply onto
+// llm.EmbeddingResponse, tagged with llm.ProviderOpenAI. Data is always a
+// non-nil slice holding one entry per returned vector, in the order received.
+func (p *OpenAIProvider) convertFromOpenAIEmbeddingResponse(openAIResp OpenAIEmbeddingResponse) *llm.EmbeddingResponse {
+	data := make([]llm.Embedding, 0, len(openAIResp.Data))
+	for _, item := range openAIResp.Data {
+		data = append(data, llm.Embedding{
+			Object:    item.Object,
+			Embedding: item.Embedding,
+			Index:     item.Index,
+		})
+	}
+
+	usage := openAIResp.Usage
+	return &llm.EmbeddingResponse{
+		Object:   openAIResp.Object,
+		Data:     data,
+		Model:    openAIResp.Model,
+		Provider: llm.ProviderOpenAI,
+		Usage: llm.Usage{
+			PromptTokens:     usage.PromptTokens,
+			CompletionTokens: usage.CompletionTokens,
+			TotalTokens:      usage.TotalTokens,
+		},
+	}
+}
+
+// makeOpenAIRequest POSTs payload, encoded as JSON, to p.config.BaseURL +
+// endpoint and returns the raw response body.
+//
+// The request carries a Bearer token from p.config.APIKey and, when
+// ExtraConfig["organization"] is a non-empty string, an OpenAI-Organization
+// header. Any non-2xx status becomes an error that always includes the HTTP
+// status code. If the body is an OpenAI error object with a message, that
+// message, type and code are reported; otherwise the raw body is quoted.
+func (p *OpenAIProvider) makeOpenAIRequest(ctx context.Context, endpoint string, payload interface{}) ([]byte, error) {
+	jsonData, err := json.Marshal(payload)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal request: %w", err)
+	}
+
+	// Drop a trailing slash from the base URL so a configured
+	// "https://host/v1/" does not produce "https://host/v1//chat/completions".
+	baseURL := p.config.BaseURL
+	if n := len(baseURL); n > 0 && baseURL[n-1] == '/' {
+		baseURL = baseURL[:n-1]
+	}
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+endpoint, bytes.NewReader(jsonData))
+	if err != nil {
+		return nil, fmt.Errorf("failed to create request: %w", err)
+	}
+
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Authorization", "Bearer "+p.config.APIKey)
+	if org, ok := p.config.ExtraConfig["organization"].(string); ok && org != "" {
+		req.Header.Set("OpenAI-Organization", org)
+	}
+
+	resp, err := p.client.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("HTTP request failed: %w", err)
+	}
+	defer resp.Body.Close()
+
+	// Read the full body even on error statuses: the error message quotes it,
+	// and draining it lets the transport reuse the connection.
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read response body: %w", err)
+	}
+
+	if resp.StatusCode < 200 || resp.StatusCode > 299 {
+		var openAIErr OpenAIError
+		// Fall back to the raw body when it is not JSON (e.g. an HTML page
+		// from a proxy) or carries no error message, so no detail is lost.
+		if err := json.Unmarshal(body, &openAIErr); err != nil || openAIErr.Error.Message == "" {
+			return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
+		}
+		return nil, fmt.Errorf("OpenAI API error (status %d): %s (type: %s, code: %s)",
+			resp.StatusCode, openAIErr.Error.Message, openAIErr.Error.Type, openAIErr.Error.Code)
+	}
+
+	return body, nil
+}
+
+// OpenAIFactory implements ProviderFactory for OpenAI. It carries no state;
+// this package's init function registers a *OpenAIFactory under
+// llm.ProviderOpenAI.
+type OpenAIFactory struct{}
+
+// CreateProvider builds an OpenAI provider from config. It rejects configs
+// whose Provider is not llm.ProviderOpenAI, runs llm.ValidateConfig on the
+// config as given, and passes the result of llm.MergeConfig (the config merged
+// with defaults) to NewOpenAIProvider.
+func (f *OpenAIFactory) CreateProvider(config llm.Config) (llm.LLMProvider, error) {
+	if requested := config.Provider; requested != llm.ProviderOpenAI {
+		return nil, fmt.Errorf("OpenAI factory cannot create provider: %s", requested)
+	}
+
+	if err := llm.ValidateConfig(config); err != nil {
+		return nil, fmt.Errorf("invalid OpenAI config: %w", err)
+	}
+
+	provider := NewOpenAIProvider(llm.MergeConfig(config))
+	return provider, nil
+}
+
+// SupportsProvider reports whether provider is llm.ProviderOpenAI, the only
+// provider this factory can build.
+func (f *OpenAIFactory) SupportsProvider(provider llm.Provider) bool {
+	switch provider {
+	case llm.ProviderOpenAI:
+		return true
+	default:
+		return false
+	}
+}
+
+// init registers this package's OpenAI factory with the llm package's
+// default registry under llm.ProviderOpenAI; the registration happens as a
+// side effect of importing this package.
+func init() {
+	factory := new(OpenAIFactory)
+	llm.RegisterProvider(llm.ProviderOpenAI, factory)
+}