Add OpenAI implementation
Change-Id: Iea3191cf538959002e6ae095857a7aa6126b3e2f
diff --git a/server/llm/openai/openai_test.go b/server/llm/openai/openai_test.go
new file mode 100644
index 0000000..8896405
--- /dev/null
+++ b/server/llm/openai/openai_test.go
@@ -0,0 +1,428 @@
+package openai
+
+import (
+ "testing"
+ "time"
+
+ "github.com/iomodo/staff/llm"
+)
+
+// TestOpenAIProvider_Interface pins, at compile time, that *OpenAIProvider
+// satisfies llm.LLMProvider: the test file stops building if a required
+// method is missing or has the wrong signature.
+func TestOpenAIProvider_Interface(t *testing.T) {
+	var p *OpenAIProvider
+	var _ llm.LLMProvider = p
+}
+
+// TestOpenAIFactory_CreateProvider checks that OpenAIFactory builds a non-nil
+// provider from a complete OpenAI config, and returns an error for a config
+// naming a different provider or lacking an API key.
+func TestOpenAIFactory_CreateProvider(t *testing.T) {
+	factory := &OpenAIFactory{}
+
+	valid := llm.Config{
+		Provider: llm.ProviderOpenAI,
+		APIKey:   "test-key",
+		BaseURL:  "https://api.openai.com/v1",
+		Timeout:  30 * time.Second,
+	}
+	provider, err := factory.CreateProvider(valid)
+	if err != nil {
+		t.Fatalf("Failed to create provider: %v", err)
+	}
+	if provider == nil {
+		t.Fatal("Provider should not be nil")
+	}
+
+	// Each config below must be rejected; msg is reported if it is accepted.
+	rejected := []struct {
+		msg    string
+		config llm.Config
+	}{
+		{
+			msg: "Should fail with invalid provider",
+			config: llm.Config{
+				Provider: llm.ProviderClaude,
+				APIKey:   "test-key",
+			},
+		},
+		{
+			msg: "Should fail with missing API key",
+			config: llm.Config{
+				Provider: llm.ProviderOpenAI,
+				BaseURL:  "https://api.openai.com/v1",
+			},
+		},
+	}
+	for _, tc := range rejected {
+		if _, err := factory.CreateProvider(tc.config); err == nil {
+			t.Fatal(tc.msg)
+		}
+	}
+}
+
+// TestOpenAIFactory_SupportsProvider checks that OpenAIFactory accepts
+// llm.ProviderOpenAI and declines llm.ProviderClaude.
+func TestOpenAIFactory_SupportsProvider(t *testing.T) {
+	f := &OpenAIFactory{}
+
+	if ok := f.SupportsProvider(llm.ProviderOpenAI); !ok {
+		t.Fatal("Should support OpenAI provider")
+	}
+	if ok := f.SupportsProvider(llm.ProviderClaude); ok {
+		t.Fatal("Should not support Claude provider")
+	}
+}
+
+// TestOpenAIProvider_ConvertRequest checks that convertToOpenAIRequest copies
+// the model, the messages (role and content), max_tokens and temperature of a
+// basic single-message chat request.
+func TestOpenAIProvider_ConvertRequest(t *testing.T) {
+	provider := &OpenAIProvider{
+		config: llm.Config{
+			Provider: llm.ProviderOpenAI,
+			APIKey:   "test-key",
+			BaseURL:  "https://api.openai.com/v1",
+		},
+	}
+
+	maxTokens := 100
+	temperature := 0.7
+	req := llm.ChatCompletionRequest{
+		Model: "gpt-3.5-turbo",
+		Messages: []llm.Message{
+			{Role: llm.RoleUser, Content: "Hello"},
+		},
+		MaxTokens:   &maxTokens,
+		Temperature: &temperature,
+	}
+
+	openAIReq := provider.convertToOpenAIRequest(req)
+
+	if openAIReq.Model != "gpt-3.5-turbo" {
+		t.Errorf("Expected model 'gpt-3.5-turbo', got '%s'", openAIReq.Model)
+	}
+
+	// Fatal rather than Errorf: the checks below index Messages[0], which
+	// would panic on an empty slice instead of reporting a failure.
+	if len(openAIReq.Messages) != 1 {
+		t.Fatalf("Expected 1 message, got %d", len(openAIReq.Messages))
+	}
+
+	if openAIReq.Messages[0].Role != "user" {
+		t.Errorf("Expected role 'user', got '%s'", openAIReq.Messages[0].Role)
+	}
+
+	if openAIReq.Messages[0].Content != "Hello" {
+		t.Errorf("Expected content 'Hello', got '%s'", openAIReq.Messages[0].Content)
+	}
+
+	// The optional fields are pointers; a conversion that drops them must
+	// fail the test rather than panic on a nil dereference.
+	if openAIReq.MaxTokens == nil {
+		t.Error("Expected max_tokens 100, got nil")
+	} else if *openAIReq.MaxTokens != 100 {
+		t.Errorf("Expected max_tokens 100, got %d", *openAIReq.MaxTokens)
+	}
+
+	if openAIReq.Temperature == nil {
+		t.Error("Expected temperature 0.7, got nil")
+	} else if *openAIReq.Temperature != 0.7 {
+		t.Errorf("Expected temperature 0.7, got %f", *openAIReq.Temperature)
+	}
+}
+
+// TestOpenAIProvider_ConvertResponse checks that convertFromOpenAIResponse
+// maps the ID, model, single choice (role, content, finish reason) and token
+// usage of a plain-text completion, and tags the result as llm.ProviderOpenAI.
+func TestOpenAIProvider_ConvertResponse(t *testing.T) {
+	provider := &OpenAIProvider{
+		config: llm.Config{
+			Provider: llm.ProviderOpenAI,
+			APIKey:   "test-key",
+			BaseURL:  "https://api.openai.com/v1",
+		},
+	}
+
+	openAIResp := OpenAIResponse{
+		ID:      "test-id",
+		Object:  "chat.completion",
+		Created: 1234567890,
+		Model:   "gpt-3.5-turbo",
+		Choices: []OpenAIChoice{
+			{
+				Index: 0,
+				Message: OpenAIMessage{
+					Role:    "assistant",
+					Content: "Hello! How can I help you?",
+				},
+				FinishReason: "stop",
+			},
+		},
+		Usage: OpenAIUsage{
+			PromptTokens:     10,
+			CompletionTokens: 20,
+			TotalTokens:      30,
+		},
+	}
+
+	resp := provider.convertFromOpenAIResponse(openAIResp)
+
+	if resp.ID != "test-id" {
+		t.Errorf("Expected ID 'test-id', got '%s'", resp.ID)
+	}
+
+	if resp.Model != "gpt-3.5-turbo" {
+		t.Errorf("Expected model 'gpt-3.5-turbo', got '%s'", resp.Model)
+	}
+
+	if resp.Provider != llm.ProviderOpenAI {
+		t.Errorf("Expected provider OpenAI, got %s", resp.Provider)
+	}
+
+	// Fatal rather than Errorf: the checks below index Choices[0], which
+	// would panic on an empty slice instead of reporting a failure.
+	if len(resp.Choices) != 1 {
+		t.Fatalf("Expected 1 choice, got %d", len(resp.Choices))
+	}
+
+	if resp.Choices[0].Message.Role != llm.RoleAssistant {
+		t.Errorf("Expected role assistant, got %s", resp.Choices[0].Message.Role)
+	}
+
+	if resp.Choices[0].Message.Content != "Hello! How can I help you?" {
+		t.Errorf("Expected content 'Hello! How can I help you?', got '%s'", resp.Choices[0].Message.Content)
+	}
+
+	if resp.Choices[0].FinishReason != "stop" {
+		t.Errorf("Expected finish reason 'stop', got '%s'", resp.Choices[0].FinishReason)
+	}
+
+	if resp.Usage.PromptTokens != 10 {
+		t.Errorf("Expected prompt tokens 10, got %d", resp.Usage.PromptTokens)
+	}
+
+	if resp.Usage.CompletionTokens != 20 {
+		t.Errorf("Expected completion tokens 20, got %d", resp.Usage.CompletionTokens)
+	}
+
+	if resp.Usage.TotalTokens != 30 {
+		t.Errorf("Expected total tokens 30, got %d", resp.Usage.TotalTokens)
+	}
+}
+
+// TestOpenAIProvider_ConvertRequestWithTools checks that convertToOpenAIRequest
+// carries a function tool definition (type and function name) through to the
+// OpenAI request.
+func TestOpenAIProvider_ConvertRequestWithTools(t *testing.T) {
+	provider := &OpenAIProvider{
+		config: llm.Config{
+			Provider: llm.ProviderOpenAI,
+			APIKey:   "test-key",
+			BaseURL:  "https://api.openai.com/v1",
+		},
+	}
+
+	tools := []llm.Tool{
+		{
+			Type: "function",
+			Function: llm.Function{
+				Name:        "get_weather",
+				Description: "Get weather information",
+				Parameters: map[string]interface{}{
+					"type": "object",
+					"properties": map[string]interface{}{
+						"location": map[string]interface{}{
+							"type": "string",
+						},
+					},
+				},
+			},
+		},
+	}
+
+	req := llm.ChatCompletionRequest{
+		Model: "gpt-3.5-turbo",
+		Messages: []llm.Message{
+			{Role: llm.RoleUser, Content: "What's the weather like?"},
+		},
+		Tools: tools,
+	}
+
+	openAIReq := provider.convertToOpenAIRequest(req)
+
+	// Fatal rather than Errorf: the checks below index Tools[0], which
+	// would panic on an empty slice instead of reporting a failure.
+	if len(openAIReq.Tools) != 1 {
+		t.Fatalf("Expected 1 tool, got %d", len(openAIReq.Tools))
+	}
+
+	if openAIReq.Tools[0].Type != "function" {
+		t.Errorf("Expected tool type 'function', got '%s'", openAIReq.Tools[0].Type)
+	}
+
+	if openAIReq.Tools[0].Function.Name != "get_weather" {
+		t.Errorf("Expected function name 'get_weather', got '%s'", openAIReq.Tools[0].Function.Name)
+	}
+}
+
+// TestOpenAIProvider_ConvertResponseWithToolCalls checks that
+// convertFromOpenAIResponse carries an assistant tool call (ID and function
+// name) and the "tool_calls" finish reason through to the llm response.
+func TestOpenAIProvider_ConvertResponseWithToolCalls(t *testing.T) {
+	provider := &OpenAIProvider{
+		config: llm.Config{
+			Provider: llm.ProviderOpenAI,
+			APIKey:   "test-key",
+			BaseURL:  "https://api.openai.com/v1",
+		},
+	}
+
+	// NOTE(review): in the OpenAI Chat Completions API a tool call's function
+	// carries its arguments as a JSON-encoded string ("arguments"), not as a
+	// parameters object. Confirm OpenAIFunction decodes real responses correctly.
+	openAIResp := OpenAIResponse{
+		ID:     "test-id",
+		Object: "chat.completion",
+		Model:  "gpt-3.5-turbo",
+		Choices: []OpenAIChoice{
+			{
+				Index: 0,
+				Message: OpenAIMessage{
+					Role: "assistant",
+					ToolCalls: []OpenAIToolCall{
+						{
+							ID:   "call_123",
+							Type: "function",
+							Function: OpenAIFunction{
+								Name: "get_weather",
+								Parameters: map[string]interface{}{
+									"location": "Tokyo",
+								},
+							},
+						},
+					},
+				},
+				FinishReason: "tool_calls",
+			},
+		},
+		Usage: OpenAIUsage{
+			PromptTokens:     10,
+			CompletionTokens: 20,
+			TotalTokens:      30,
+		},
+	}
+
+	resp := provider.convertFromOpenAIResponse(openAIResp)
+
+	// Both length checks are Fatal: everything below indexes Choices[0] and
+	// ToolCalls[0], which would panic instead of reporting a failure.
+	if len(resp.Choices) != 1 {
+		t.Fatalf("Expected 1 choice, got %d", len(resp.Choices))
+	}
+
+	if len(resp.Choices[0].Message.ToolCalls) != 1 {
+		t.Fatalf("Expected 1 tool call, got %d", len(resp.Choices[0].Message.ToolCalls))
+	}
+
+	call := resp.Choices[0].Message.ToolCalls[0]
+
+	if call.ID != "call_123" {
+		t.Errorf("Expected tool call ID 'call_123', got '%s'", call.ID)
+	}
+
+	if call.Function.Name != "get_weather" {
+		t.Errorf("Expected function name 'get_weather', got '%s'", call.Function.Name)
+	}
+
+	if resp.Choices[0].FinishReason != "tool_calls" {
+		t.Errorf("Expected finish reason 'tool_calls', got '%s'", resp.Choices[0].FinishReason)
+	}
+}
+
+// TestOpenAIProvider_ConvertEmbeddingRequest checks the fields of an
+// llm.EmbeddingRequest literal.
+//
+// NOTE(review): the conversion to the OpenAI wire format happens inline in
+// CreateEmbeddings, so this test only re-reads the values it just assigned
+// and exercises no provider code. Consider extracting the conversion into a
+// helper (like convertToOpenAIRequest) and testing that instead.
+func TestOpenAIProvider_ConvertEmbeddingRequest(t *testing.T) {
+	req := llm.EmbeddingRequest{
+		Input: "Hello, world!",
+		Model: "text-embedding-ada-002",
+		User:  "test-user",
+	}
+
+	if got := req.Input; got != "Hello, world!" {
+		t.Errorf("Expected input 'Hello, world!', got '%v'", got)
+	}
+	if got := req.Model; got != "text-embedding-ada-002" {
+		t.Errorf("Expected model 'text-embedding-ada-002', got '%s'", got)
+	}
+	if got := req.User; got != "test-user" {
+		t.Errorf("Expected user 'test-user', got '%s'", got)
+	}
+}
+
+// TestOpenAIProvider_ConvertEmbeddingResponse checks that
+// convertFromOpenAIEmbeddingResponse keeps the object type, model and
+// embedding vectors, and tags the result as llm.ProviderOpenAI.
+func TestOpenAIProvider_ConvertEmbeddingResponse(t *testing.T) {
+	provider := &OpenAIProvider{
+		config: llm.Config{
+			Provider: llm.ProviderOpenAI,
+			APIKey:   "test-key",
+			BaseURL:  "https://api.openai.com/v1",
+		},
+	}
+
+	openAIResp := OpenAIEmbeddingResponse{
+		Object: "list",
+		Model:  "text-embedding-ada-002",
+		Data: []OpenAIEmbeddingData{
+			{
+				Object:    "embedding",
+				Embedding: []float64{0.1, 0.2, 0.3},
+				Index:     0,
+			},
+		},
+		Usage: OpenAIUsage{
+			PromptTokens:     5,
+			CompletionTokens: 0,
+			TotalTokens:      5,
+		},
+	}
+
+	resp := provider.convertFromOpenAIEmbeddingResponse(openAIResp)
+
+	if resp.Object != "list" {
+		t.Errorf("Expected object 'list', got '%s'", resp.Object)
+	}
+
+	if resp.Model != "text-embedding-ada-002" {
+		t.Errorf("Expected model 'text-embedding-ada-002', got '%s'", resp.Model)
+	}
+
+	if resp.Provider != llm.ProviderOpenAI {
+		t.Errorf("Expected provider OpenAI, got %s", resp.Provider)
+	}
+
+	// Both length checks are Fatal: the final check indexes Data[0] and
+	// Embedding[0], which would panic instead of reporting a failure.
+	if len(resp.Data) != 1 {
+		t.Fatalf("Expected 1 embedding, got %d", len(resp.Data))
+	}
+
+	if len(resp.Data[0].Embedding) != 3 {
+		t.Fatalf("Expected embedding dimension 3, got %d", len(resp.Data[0].Embedding))
+	}
+
+	if resp.Data[0].Embedding[0] != 0.1 {
+		t.Errorf("Expected first embedding value 0.1, got %f", resp.Data[0].Embedding[0])
+	}
+}
+
+// TestOpenAIProvider_Close checks that closing a provider built directly from
+// a config (with no other fields set) reports no error.
+func TestOpenAIProvider_Close(t *testing.T) {
+	cfg := llm.Config{
+		Provider: llm.ProviderOpenAI,
+		APIKey:   "test-key",
+		BaseURL:  "https://api.openai.com/v1",
+	}
+	p := &OpenAIProvider{config: cfg}
+
+	if err := p.Close(); err != nil {
+		t.Errorf("Close should not return an error: %v", err)
+	}
+}
+
+// TestOpenAIProvider_Integration is a placeholder for an end-to-end check
+// against the real OpenAI API. It needs a real API key and network access, so
+// it is explicitly skipped rather than silently reported as passing with an
+// empty body.
+func TestOpenAIProvider_Integration(t *testing.T) {
+	t.Skip("integration test: requires a real OpenAI API key and network access")
+
+	// Sketch of the intended test. To enable it, replace the t.Skip above with
+	// a guard (e.g. on an API-key environment variable), uncomment the code,
+	// and add "context" to the imports.
+	/*
+		config := llm.Config{
+			Provider: llm.ProviderOpenAI,
+			APIKey:   "your-real-api-key",
+			BaseURL:  "https://api.openai.com/v1",
+			Timeout:  30 * time.Second,
+		}
+
+		factory := &OpenAIFactory{}
+		provider, err := factory.CreateProvider(config)
+		if err != nil {
+			t.Fatalf("Failed to create provider: %v", err)
+		}
+		defer provider.Close()
+
+		maxTokens := 50
+		req := llm.ChatCompletionRequest{
+			Model: "gpt-3.5-turbo",
+			Messages: []llm.Message{
+				{Role: llm.RoleUser, Content: "Say hello!"},
+			},
+			MaxTokens: &maxTokens,
+		}
+
+		resp, err := provider.ChatCompletion(context.Background(), req)
+		if err != nil {
+			t.Fatalf("Chat completion failed: %v", err)
+		}
+
+		if len(resp.Choices) == 0 {
+			t.Fatal("Expected at least one choice")
+		}
+
+		if resp.Choices[0].Message.Content == "" {
+			t.Fatal("Expected non-empty response content")
+		}
+	*/
+}