package openai

import (
	"testing"
	"time"

	"github.com/iomodo/staff/llm"
)

func TestOpenAIProvider_Interface(t *testing.T) {
	// Compile-time assertion that OpenAIProvider implements the LLMProvider
	// interface: the build fails if the type stops satisfying llm.LLMProvider.
	var _ llm.LLMProvider = (*OpenAIProvider)(nil)
}

func TestOpenAIFactory_CreateProvider(t *testing.T) {
	factory := &OpenAIFactory{}

	// Test valid config
	config := llm.Config{
		Provider: llm.ProviderOpenAI,
		APIKey:   "test-key",
		BaseURL:  "https://api.openai.com/v1",
		Timeout:  30 * time.Second,
	}
	provider, err := factory.CreateProvider(config)
	if err != nil {
		t.Fatalf("Failed to create provider: %v", err)
	}
	if provider == nil {
		t.Fatal("Provider should not be nil")
	}

	// Test invalid provider
	invalidConfig := llm.Config{
		Provider: llm.ProviderClaude,
		APIKey:   "test-key",
	}
	_, err = factory.CreateProvider(invalidConfig)
	if err == nil {
		t.Fatal("Should fail with invalid provider")
	}

	// Test missing API key
	noKeyConfig := llm.Config{
		Provider: llm.ProviderOpenAI,
		BaseURL:  "https://api.openai.com/v1",
	}
	_, err = factory.CreateProvider(noKeyConfig)
	if err == nil {
		t.Fatal("Should fail with missing API key")
	}
}
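
// The three cases above could also be written as a single table-driven test,
// the idiomatic Go pattern once provider/key combinations start to accumulate.
// A minimal sketch using only the types already exercised in this file:
func TestOpenAIFactory_CreateProvider_Table(t *testing.T) {
	factory := &OpenAIFactory{}
	cases := []struct {
		name    string
		config  llm.Config
		wantErr bool
	}{
		{"valid config", llm.Config{Provider: llm.ProviderOpenAI, APIKey: "test-key", BaseURL: "https://api.openai.com/v1", Timeout: 30 * time.Second}, false},
		{"wrong provider", llm.Config{Provider: llm.ProviderClaude, APIKey: "test-key"}, true},
		{"missing API key", llm.Config{Provider: llm.ProviderOpenAI, BaseURL: "https://api.openai.com/v1"}, true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			_, err := factory.CreateProvider(tc.config)
			if (err != nil) != tc.wantErr {
				t.Fatalf("CreateProvider() error = %v, wantErr %v", err, tc.wantErr)
			}
		})
	}
}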

func TestOpenAIFactory_SupportsProvider(t *testing.T) {
	factory := &OpenAIFactory{}
	if !factory.SupportsProvider(llm.ProviderOpenAI) {
		t.Fatal("Should support OpenAI provider")
	}
	if factory.SupportsProvider(llm.ProviderClaude) {
		t.Fatal("Should not support Claude provider")
	}
}

func TestOpenAIProvider_ConvertRequest(t *testing.T) {
	provider := &OpenAIProvider{
		config: llm.Config{
			Provider: llm.ProviderOpenAI,
			APIKey:   "test-key",
			BaseURL:  "https://api.openai.com/v1",
		},
	}

	// Test basic request conversion
	req := llm.ChatCompletionRequest{
		Model: "gpt-3.5-turbo",
		Messages: []llm.Message{
			{Role: llm.RoleUser, Content: "Hello"},
		},
		MaxTokens:   &[]int{100}[0],
		Temperature: &[]float64{0.7}[0],
	}
	openAIReq := provider.convertToOpenAIRequest(req)
	if openAIReq.Model != "gpt-3.5-turbo" {
		t.Errorf("Expected model 'gpt-3.5-turbo', got '%s'", openAIReq.Model)
	}
	if len(openAIReq.Messages) != 1 {
		t.Errorf("Expected 1 message, got %d", len(openAIReq.Messages))
	}
	if openAIReq.Messages[0].Role != "user" {
		t.Errorf("Expected role 'user', got '%s'", openAIReq.Messages[0].Role)
	}
	if openAIReq.Messages[0].Content != "Hello" {
		t.Errorf("Expected content 'Hello', got '%s'", openAIReq.Messages[0].Content)
	}
	if *openAIReq.MaxTokens != 100 {
		t.Errorf("Expected max_tokens 100, got %d", *openAIReq.MaxTokens)
	}
	if *openAIReq.Temperature != 0.7 {
		t.Errorf("Expected temperature 0.7, got %f", *openAIReq.Temperature)
	}
}
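
// The &[]int{100}[0] construction above is a pre-generics trick for taking the
// address of a literal. On Go 1.18+ a tiny generic helper is easier to read; a
// sketch (the helper name is ours, not part of the package API):
func ptr[T any](v T) *T { return &v }

// With it, the request above becomes: MaxTokens: ptr(100), Temperature: ptr(0.7).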

func TestOpenAIProvider_ConvertResponse(t *testing.T) {
	provider := &OpenAIProvider{
		config: llm.Config{
			Provider: llm.ProviderOpenAI,
			APIKey:   "test-key",
			BaseURL:  "https://api.openai.com/v1",
		},
	}

	// Test basic response conversion
	openAIResp := OpenAIResponse{
		ID:      "test-id",
		Object:  "chat.completion",
		Created: 1234567890,
		Model:   "gpt-3.5-turbo",
		Choices: []OpenAIChoice{
			{
				Index: 0,
				Message: OpenAIMessage{
					Role:    "assistant",
					Content: "Hello! How can I help you?",
				},
				FinishReason: "stop",
			},
		},
		Usage: OpenAIUsage{
			PromptTokens:     10,
			CompletionTokens: 20,
			TotalTokens:      30,
		},
	}
	resp := provider.convertFromOpenAIResponse(openAIResp)
	if resp.ID != "test-id" {
		t.Errorf("Expected ID 'test-id', got '%s'", resp.ID)
	}
	if resp.Model != "gpt-3.5-turbo" {
		t.Errorf("Expected model 'gpt-3.5-turbo', got '%s'", resp.Model)
	}
	if resp.Provider != llm.ProviderOpenAI {
		t.Errorf("Expected provider OpenAI, got %s", resp.Provider)
	}
	if len(resp.Choices) != 1 {
		t.Errorf("Expected 1 choice, got %d", len(resp.Choices))
	}
	if resp.Choices[0].Message.Role != llm.RoleAssistant {
		t.Errorf("Expected role assistant, got %s", resp.Choices[0].Message.Role)
	}
	if resp.Choices[0].Message.Content != "Hello! How can I help you?" {
		t.Errorf("Expected content 'Hello! How can I help you?', got '%s'", resp.Choices[0].Message.Content)
	}
	if resp.Usage.PromptTokens != 10 {
		t.Errorf("Expected prompt tokens 10, got %d", resp.Usage.PromptTokens)
	}
	if resp.Usage.CompletionTokens != 20 {
		t.Errorf("Expected completion tokens 20, got %d", resp.Usage.CompletionTokens)
	}
	if resp.Usage.TotalTokens != 30 {
		t.Errorf("Expected total tokens 30, got %d", resp.Usage.TotalTokens)
	}
}

func TestOpenAIProvider_ConvertRequestWithTools(t *testing.T) {
	provider := &OpenAIProvider{
		config: llm.Config{
			Provider: llm.ProviderOpenAI,
			APIKey:   "test-key",
			BaseURL:  "https://api.openai.com/v1",
		},
	}

	// Test request with tools
	tools := []llm.Tool{
		{
			Type: "function",
			Function: llm.Function{
				Name:        "get_weather",
				Description: "Get weather information",
				Parameters: map[string]interface{}{
					"type": "object",
					"properties": map[string]interface{}{
						"location": map[string]interface{}{
							"type": "string",
						},
					},
				},
			},
		},
	}
	req := llm.ChatCompletionRequest{
		Model: "gpt-3.5-turbo",
		Messages: []llm.Message{
			{Role: llm.RoleUser, Content: "What's the weather like?"},
		},
		Tools: tools,
	}
	openAIReq := provider.convertToOpenAIRequest(req)
	if len(openAIReq.Tools) != 1 {
		t.Errorf("Expected 1 tool, got %d", len(openAIReq.Tools))
	}
	if openAIReq.Tools[0].Type != "function" {
		t.Errorf("Expected tool type 'function', got '%s'", openAIReq.Tools[0].Type)
	}
	if openAIReq.Tools[0].Function.Name != "get_weather" {
		t.Errorf("Expected function name 'get_weather', got '%s'", openAIReq.Tools[0].Function.Name)
	}
}
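
// Note: the Parameters map above is a JSON Schema fragment, which is the format
// the OpenAI function-calling API expects for tool parameter definitions.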

func TestOpenAIProvider_ConvertResponseWithToolCalls(t *testing.T) {
	provider := &OpenAIProvider{
		config: llm.Config{
			Provider: llm.ProviderOpenAI,
			APIKey:   "test-key",
			BaseURL:  "https://api.openai.com/v1",
		},
	}

	// Test response with tool calls
	openAIResp := OpenAIResponse{
		ID:     "test-id",
		Object: "chat.completion",
		Model:  "gpt-3.5-turbo",
		Choices: []OpenAIChoice{
			{
				Index: 0,
				Message: OpenAIMessage{
					Role: "assistant",
					ToolCalls: []OpenAIToolCall{
						{
							ID:   "call_123",
							Type: "function",
							Function: OpenAIFunction{
								Name: "get_weather",
								Parameters: map[string]interface{}{
									"location": "Tokyo",
								},
							},
						},
					},
				},
				FinishReason: "tool_calls",
			},
		},
		Usage: OpenAIUsage{
			PromptTokens:     10,
			CompletionTokens: 20,
			TotalTokens:      30,
		},
	}
	resp := provider.convertFromOpenAIResponse(openAIResp)
	if len(resp.Choices[0].Message.ToolCalls) != 1 {
		t.Errorf("Expected 1 tool call, got %d", len(resp.Choices[0].Message.ToolCalls))
	}
	if resp.Choices[0].Message.ToolCalls[0].ID != "call_123" {
		t.Errorf("Expected tool call ID 'call_123', got '%s'", resp.Choices[0].Message.ToolCalls[0].ID)
	}
	if resp.Choices[0].Message.ToolCalls[0].Function.Name != "get_weather" {
		t.Errorf("Expected function name 'get_weather', got '%s'", resp.Choices[0].Message.ToolCalls[0].Function.Name)
	}
	if resp.Choices[0].FinishReason != "tool_calls" {
		t.Errorf("Expected finish reason 'tool_calls', got '%s'", resp.Choices[0].FinishReason)
	}
}
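
// Caveat: on the wire, the OpenAI chat completions API returns tool-call
// arguments as a JSON-encoded string ("arguments"), not a decoded object. The
// map used in OpenAIFunction above therefore assumes the provider decodes the
// arguments during conversion.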

func TestOpenAIProvider_ConvertEmbeddingRequest(t *testing.T) {
	req := llm.EmbeddingRequest{
		Input: "Hello, world!",
		Model: "text-embedding-ada-002",
		User:  "test-user",
	}

	// The request conversion happens inline in CreateEmbeddings, so this test
	// only verifies that the request struct carries its fields faithfully.
	if req.Input != "Hello, world!" {
		t.Errorf("Expected input 'Hello, world!', got '%v'", req.Input)
	}
	if req.Model != "text-embedding-ada-002" {
		t.Errorf("Expected model 'text-embedding-ada-002', got '%s'", req.Model)
	}
	if req.User != "test-user" {
		t.Errorf("Expected user 'test-user', got '%s'", req.User)
	}
}

func TestOpenAIProvider_ConvertEmbeddingResponse(t *testing.T) {
	provider := &OpenAIProvider{
		config: llm.Config{
			Provider: llm.ProviderOpenAI,
			APIKey:   "test-key",
			BaseURL:  "https://api.openai.com/v1",
		},
	}

	// Test embedding response conversion
	openAIResp := OpenAIEmbeddingResponse{
		Object: "list",
		Model:  "text-embedding-ada-002",
		Data: []OpenAIEmbeddingData{
			{
				Object:    "embedding",
				Embedding: []float64{0.1, 0.2, 0.3},
				Index:     0,
			},
		},
		Usage: OpenAIUsage{
			PromptTokens:     5,
			CompletionTokens: 0,
			TotalTokens:      5,
		},
	}
	resp := provider.convertFromOpenAIEmbeddingResponse(openAIResp)
	if resp.Object != "list" {
		t.Errorf("Expected object 'list', got '%s'", resp.Object)
	}
	if resp.Model != "text-embedding-ada-002" {
		t.Errorf("Expected model 'text-embedding-ada-002', got '%s'", resp.Model)
	}
	if resp.Provider != llm.ProviderOpenAI {
		t.Errorf("Expected provider OpenAI, got %s", resp.Provider)
	}
	if len(resp.Data) != 1 {
		t.Errorf("Expected 1 embedding, got %d", len(resp.Data))
	}
	if len(resp.Data[0].Embedding) != 3 {
		t.Errorf("Expected embedding dimension 3, got %d", len(resp.Data[0].Embedding))
	}
	if resp.Data[0].Embedding[0] != 0.1 {
		t.Errorf("Expected first embedding value 0.1, got %f", resp.Data[0].Embedding[0])
	}
}
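
// Comparing Embedding[0] against the literal 0.1 holds here only because the
// converter copies the slice verbatim; for any computed floating-point value,
// prefer a tolerance check, e.g. math.Abs(got-want) > 1e-9 (requires "math").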

func TestOpenAIProvider_Close(t *testing.T) {
	provider := &OpenAIProvider{
		config: llm.Config{
			Provider: llm.ProviderOpenAI,
			APIKey:   "test-key",
			BaseURL:  "https://api.openai.com/v1",
		},
	}

	// Test that Close doesn't return an error
	err := provider.Close()
	if err != nil {
		t.Errorf("Close should not return an error: %v", err)
	}
}

func TestOpenAIProvider_Integration(t *testing.T) {
	// This test requires a real API key and makes live API calls, so its body
	// is commented out to keep the suite hermetic. Uncommenting it also
	// requires importing "context".
	/*
		factory := &OpenAIFactory{}
		config := llm.Config{
			Provider: llm.ProviderOpenAI,
			APIKey:   "your-real-api-key",
			BaseURL:  "https://api.openai.com/v1",
			Timeout:  30 * time.Second,
		}
		provider, err := factory.CreateProvider(config)
		if err != nil {
			t.Fatalf("Failed to create provider: %v", err)
		}
		defer provider.Close()

		req := llm.ChatCompletionRequest{
			Model: "gpt-3.5-turbo",
			Messages: []llm.Message{
				{Role: llm.RoleUser, Content: "Say hello!"},
			},
			MaxTokens: &[]int{50}[0],
		}
		resp, err := provider.ChatCompletion(context.Background(), req)
		if err != nil {
			t.Fatalf("Chat completion failed: %v", err)
		}
		if len(resp.Choices) == 0 {
			t.Fatal("Expected at least one choice")
		}
		if resp.Choices[0].Message.Content == "" {
			t.Fatal("Expected non-empty response content")
		}
	*/
}
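
// A common alternative to commenting the body out entirely is to gate live
// tests on an environment variable, so they run only when explicitly enabled.
// A sketch, assuming the same API as above (requires importing "os"):
//
//	func TestOpenAIProvider_Live(t *testing.T) {
//		apiKey := os.Getenv("OPENAI_API_KEY")
//		if apiKey == "" {
//			t.Skip("OPENAI_API_KEY not set; skipping live test")
//		}
//		provider, err := (&OpenAIFactory{}).CreateProvider(llm.Config{
//			Provider: llm.ProviderOpenAI,
//			APIKey:   apiKey,
//			BaseURL:  "https://api.openai.com/v1",
//			Timeout:  30 * time.Second,
//		})
//		if err != nil {
//			t.Fatalf("Failed to create provider: %v", err)
//		}
//		defer provider.Close()
//	}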