Remove llm factories
Change-Id: I87afaad65f299b79ceb447b99c464bfe5c7d68cd
diff --git a/server/llm/openai/openai.go b/server/llm/openai/openai.go
index c513c53..fad11f6 100644
--- a/server/llm/openai/openai.go
+++ b/server/llm/openai/openai.go
@@ -152,8 +152,7 @@
} `json:"error"`
}
-// NewOpenAIProvider creates a new OpenAI provider
-func NewOpenAIProvider(config llm.Config) *OpenAIProvider {
+func New(config llm.Config) *OpenAIProvider {
client := &http.Client{
Timeout: config.Timeout,
}
@@ -436,33 +435,3 @@
return body, nil
}
-
-// OpenAIFactory implements ProviderFactory for OpenAI
-type OpenAIFactory struct{}
-
-// CreateProvider creates a new OpenAI provider
-func (f *OpenAIFactory) CreateProvider(config llm.Config) (llm.LLMProvider, error) {
- if config.Provider != llm.ProviderOpenAI {
- return nil, fmt.Errorf("OpenAI factory cannot create provider: %s", config.Provider)
- }
-
- // Validate config
- if err := llm.ValidateConfig(config); err != nil {
- return nil, fmt.Errorf("invalid OpenAI config: %w", err)
- }
-
- // Merge with defaults
- config = llm.MergeConfig(config)
-
- return NewOpenAIProvider(config), nil
-}
-
-// SupportsProvider checks if this factory supports the given provider
-func (f *OpenAIFactory) SupportsProvider(provider llm.Provider) bool {
- return provider == llm.ProviderOpenAI
-}
-
-// Register OpenAI provider with the default registry
-func init() {
- llm.RegisterProvider(llm.ProviderOpenAI, &OpenAIFactory{})
-}
diff --git a/server/llm/openai/openai_example.go b/server/llm/openai/openai_example.go
deleted file mode 100644
index 6cd7dbf..0000000
--- a/server/llm/openai/openai_example.go
+++ /dev/null
@@ -1,254 +0,0 @@
-package openai
-
-import (
- "context"
- "fmt"
- "log"
- "time"
-
- "github.com/iomodo/staff/llm"
-)
-
-// ExampleOpenAI demonstrates how to use the OpenAI implementation
-func ExampleOpenAI() {
- // Create OpenAI configuration
- config := llm.Config{
- Provider: llm.ProviderOpenAI,
- APIKey: "your-openai-api-key-here", // Replace with your actual API key
- BaseURL: "https://api.openai.com/v1",
- Timeout: 30 * time.Second,
- ExtraConfig: map[string]interface{}{
- "organization": "your-org-id", // Optional: your OpenAI organization ID
- },
- }
-
- // Create the provider
- provider, err := llm.CreateProvider(config)
- if err != nil {
- log.Fatalf("Failed to create OpenAI provider: %v", err)
- }
- defer provider.Close()
-
- // Example 1: Basic chat completion
- exampleOpenAIChatCompletion(provider)
-
- // Example 2: Function calling
- exampleOpenAIFunctionCalling(provider)
-
- // Example 3: Embeddings
- exampleOpenAIEmbeddings(provider)
-}
-
-// exampleOpenAIChatCompletion demonstrates basic chat completion with OpenAI
-func exampleOpenAIChatCompletion(provider llm.LLMProvider) {
- fmt.Println("=== OpenAI Chat Completion Example ===")
-
- req := llm.ChatCompletionRequest{
- Model: "gpt-3.5-turbo",
- Messages: []llm.Message{
- {Role: llm.RoleSystem, Content: "You are a helpful assistant."},
- {Role: llm.RoleUser, Content: "What is the capital of France?"},
- },
- MaxTokens: &[]int{100}[0],
- Temperature: &[]float64{0.7}[0],
- }
-
- ctx := context.Background()
- resp, err := provider.ChatCompletion(ctx, req)
- if err != nil {
- log.Printf("Chat completion failed: %v", err)
- return
- }
-
- fmt.Printf("Response: %s\n", resp.Choices[0].Message.Content)
- fmt.Printf("Model: %s\n", resp.Model)
- fmt.Printf("Provider: %s\n", resp.Provider)
- fmt.Printf("Usage: %+v\n", resp.Usage)
-}
-
-// exampleOpenAIFunctionCalling demonstrates function calling with OpenAI
-func exampleOpenAIFunctionCalling(provider llm.LLMProvider) {
- fmt.Println("\n=== OpenAI Function Calling Example ===")
-
- // Define a function that can be called
- tools := []llm.Tool{
- {
- Type: "function",
- Function: llm.Function{
- Name: "get_weather",
- Description: "Get the current weather for a location",
- Parameters: map[string]interface{}{
- "type": "object",
- "properties": map[string]interface{}{
- "location": map[string]interface{}{
- "type": "string",
- "description": "The city and state, e.g. San Francisco, CA",
- },
- "unit": map[string]interface{}{
- "type": "string",
- "enum": []string{"celsius", "fahrenheit"},
- },
- },
- "required": []string{"location"},
- },
- },
- },
- }
-
- req := llm.ChatCompletionRequest{
- Model: "gpt-3.5-turbo",
- Messages: []llm.Message{
- {Role: llm.RoleUser, Content: "What's the weather like in Tokyo?"},
- },
- Tools: tools,
- MaxTokens: &[]int{150}[0],
- Temperature: &[]float64{0.1}[0],
- }
-
- ctx := context.Background()
- resp, err := provider.ChatCompletion(ctx, req)
- if err != nil {
- log.Printf("Function calling failed: %v", err)
- return
- }
-
- // Check if the model wants to call a function
- if len(resp.Choices[0].Message.ToolCalls) > 0 {
- fmt.Println("Model wants to call a function:")
- for _, toolCall := range resp.Choices[0].Message.ToolCalls {
- fmt.Printf("Function: %s\n", toolCall.Function.Name)
- fmt.Printf("Arguments: %+v\n", toolCall.Function.Parameters)
- }
- } else {
- fmt.Printf("Response: %s\n", resp.Choices[0].Message.Content)
- }
-}
-
-// exampleOpenAIEmbeddings demonstrates embedding generation with OpenAI
-func exampleOpenAIEmbeddings(provider llm.LLMProvider) {
- fmt.Println("\n=== OpenAI Embeddings Example ===")
-
- req := llm.EmbeddingRequest{
- Input: "Hello, world! This is a test sentence for embeddings.",
- Model: "text-embedding-ada-002",
- }
-
- ctx := context.Background()
- resp, err := provider.CreateEmbeddings(ctx, req)
- if err != nil {
- log.Printf("Embeddings failed: %v", err)
- return
- }
-
- fmt.Printf("Embedding dimensions: %d\n", len(resp.Data[0].Embedding))
- fmt.Printf("First 5 values: %v\n", resp.Data[0].Embedding[:5])
- fmt.Printf("Model: %s\n", resp.Model)
- fmt.Printf("Provider: %s\n", resp.Provider)
- fmt.Printf("Usage: %+v\n", resp.Usage)
-}
-
-// ExampleOpenAIWithErrorHandling demonstrates proper error handling
-func ExampleOpenAIWithErrorHandling() {
- fmt.Println("\n=== OpenAI Error Handling Example ===")
-
- // Test with invalid API key
- config := llm.Config{
- Provider: llm.ProviderOpenAI,
- APIKey: "invalid-key",
- BaseURL: "https://api.openai.com/v1",
- Timeout: 30 * time.Second,
- }
-
- provider, err := llm.CreateProvider(config)
- if err != nil {
- fmt.Printf("Failed to create provider: %v\n", err)
- return
- }
- defer provider.Close()
-
- req := llm.ChatCompletionRequest{
- Model: "gpt-3.5-turbo",
- Messages: []llm.Message{
- {Role: llm.RoleUser, Content: "Hello"},
- },
- MaxTokens: &[]int{50}[0],
- }
-
- ctx := context.Background()
- resp, err := provider.ChatCompletion(ctx, req)
- if err != nil {
- fmt.Printf("Expected error with invalid API key: %v\n", err)
- return
- }
-
- fmt.Printf("Unexpected success: %s\n", resp.Choices[0].Message.Content)
-}
-
-// ExampleOpenAIWithCustomConfig demonstrates custom configuration
-func ExampleOpenAIWithCustomConfig() {
- fmt.Println("\n=== OpenAI Custom Configuration Example ===")
-
- // Start with minimal config
- customConfig := llm.Config{
- Provider: llm.ProviderOpenAI,
- APIKey: "your-api-key",
- }
-
- // Merge with defaults
- config := llm.MergeConfig(customConfig)
-
- fmt.Printf("Merged config: BaseURL=%s, Timeout=%v, MaxRetries=%d\n",
- config.BaseURL, config.Timeout, config.MaxRetries)
-
- // Add extra configuration
- config.ExtraConfig = map[string]interface{}{
- "organization": "your-org-id",
- "project": "your-project-id",
- }
-
- provider, err := llm.CreateProvider(config)
- if err != nil {
- fmt.Printf("Failed to create provider: %v\n", err)
- return
- }
- defer provider.Close()
-
- fmt.Println("Provider created successfully with custom configuration")
-}
-
-// ExampleOpenAIWithValidation demonstrates request validation
-func ExampleOpenAIWithValidation() {
- fmt.Println("\n=== OpenAI Validation Example ===")
-
- // Valid request
- validReq := llm.ChatCompletionRequest{
- Model: "gpt-3.5-turbo",
- Messages: []llm.Message{
- {Role: llm.RoleUser, Content: "Hello"},
- },
- Temperature: &[]float64{0.5}[0],
- }
-
- if err := llm.ValidateChatCompletionRequest(validReq); err != nil {
- fmt.Printf("Valid request validation failed: %v\n", err)
- } else {
- fmt.Println("Valid request passed validation")
- }
-
- // Invalid request (no model)
- invalidReq := llm.ChatCompletionRequest{
- Messages: []llm.Message{
- {Role: llm.RoleUser, Content: "Hello"},
- },
- }
-
- if err := llm.ValidateChatCompletionRequest(invalidReq); err != nil {
- fmt.Printf("Invalid request correctly caught: %v\n", err)
- } else {
- fmt.Println("Invalid request should have failed validation")
- }
-
- // Test token estimation
- tokens := llm.EstimateTokens(validReq)
- fmt.Printf("Estimated tokens for request: %d\n", tokens)
-}
diff --git a/server/llm/openai/openai_test.go b/server/llm/openai/openai_test.go
deleted file mode 100644
index 8896405..0000000
--- a/server/llm/openai/openai_test.go
+++ /dev/null
@@ -1,428 +0,0 @@
-package openai
-
-import (
- "testing"
- "time"
-
- "github.com/iomodo/staff/llm"
-)
-
-func TestOpenAIProvider_Interface(t *testing.T) {
- // Test that OpenAIProvider implements LLMProvider interface
- var _ llm.LLMProvider = (*OpenAIProvider)(nil)
-}
-
-func TestOpenAIFactory_CreateProvider(t *testing.T) {
- factory := &OpenAIFactory{}
-
- // Test valid config
- config := llm.Config{
- Provider: llm.ProviderOpenAI,
- APIKey: "test-key",
- BaseURL: "https://api.openai.com/v1",
- Timeout: 30 * time.Second,
- }
-
- provider, err := factory.CreateProvider(config)
- if err != nil {
- t.Fatalf("Failed to create provider: %v", err)
- }
-
- if provider == nil {
- t.Fatal("Provider should not be nil")
- }
-
- // Test invalid provider
- invalidConfig := llm.Config{
- Provider: llm.ProviderClaude,
- APIKey: "test-key",
- }
-
- _, err = factory.CreateProvider(invalidConfig)
- if err == nil {
- t.Fatal("Should fail with invalid provider")
- }
-
- // Test missing API key
- noKeyConfig := llm.Config{
- Provider: llm.ProviderOpenAI,
- BaseURL: "https://api.openai.com/v1",
- }
-
- _, err = factory.CreateProvider(noKeyConfig)
- if err == nil {
- t.Fatal("Should fail with missing API key")
- }
-}
-
-func TestOpenAIFactory_SupportsProvider(t *testing.T) {
- factory := &OpenAIFactory{}
-
- if !factory.SupportsProvider(llm.ProviderOpenAI) {
- t.Fatal("Should support OpenAI provider")
- }
-
- if factory.SupportsProvider(llm.ProviderClaude) {
- t.Fatal("Should not support Claude provider")
- }
-}
-
-func TestOpenAIProvider_ConvertRequest(t *testing.T) {
- provider := &OpenAIProvider{
- config: llm.Config{
- Provider: llm.ProviderOpenAI,
- APIKey: "test-key",
- BaseURL: "https://api.openai.com/v1",
- },
- }
-
- // Test basic request conversion
- req := llm.ChatCompletionRequest{
- Model: "gpt-3.5-turbo",
- Messages: []llm.Message{
- {Role: llm.RoleUser, Content: "Hello"},
- },
- MaxTokens: &[]int{100}[0],
- Temperature: &[]float64{0.7}[0],
- }
-
- openAIReq := provider.convertToOpenAIRequest(req)
-
- if openAIReq.Model != "gpt-3.5-turbo" {
- t.Errorf("Expected model 'gpt-3.5-turbo', got '%s'", openAIReq.Model)
- }
-
- if len(openAIReq.Messages) != 1 {
- t.Errorf("Expected 1 message, got %d", len(openAIReq.Messages))
- }
-
- if openAIReq.Messages[0].Role != "user" {
- t.Errorf("Expected role 'user', got '%s'", openAIReq.Messages[0].Role)
- }
-
- if openAIReq.Messages[0].Content != "Hello" {
- t.Errorf("Expected content 'Hello', got '%s'", openAIReq.Messages[0].Content)
- }
-
- if *openAIReq.MaxTokens != 100 {
- t.Errorf("Expected max_tokens 100, got %d", *openAIReq.MaxTokens)
- }
-
- if *openAIReq.Temperature != 0.7 {
- t.Errorf("Expected temperature 0.7, got %f", *openAIReq.Temperature)
- }
-}
-
-func TestOpenAIProvider_ConvertResponse(t *testing.T) {
- provider := &OpenAIProvider{
- config: llm.Config{
- Provider: llm.ProviderOpenAI,
- APIKey: "test-key",
- BaseURL: "https://api.openai.com/v1",
- },
- }
-
- // Test basic response conversion
- openAIResp := OpenAIResponse{
- ID: "test-id",
- Object: "chat.completion",
- Created: 1234567890,
- Model: "gpt-3.5-turbo",
- Choices: []OpenAIChoice{
- {
- Index: 0,
- Message: OpenAIMessage{
- Role: "assistant",
- Content: "Hello! How can I help you?",
- },
- FinishReason: "stop",
- },
- },
- Usage: OpenAIUsage{
- PromptTokens: 10,
- CompletionTokens: 20,
- TotalTokens: 30,
- },
- }
-
- resp := provider.convertFromOpenAIResponse(openAIResp)
-
- if resp.ID != "test-id" {
- t.Errorf("Expected ID 'test-id', got '%s'", resp.ID)
- }
-
- if resp.Model != "gpt-3.5-turbo" {
- t.Errorf("Expected model 'gpt-3.5-turbo', got '%s'", resp.Model)
- }
-
- if resp.Provider != llm.ProviderOpenAI {
- t.Errorf("Expected provider OpenAI, got %s", resp.Provider)
- }
-
- if len(resp.Choices) != 1 {
- t.Errorf("Expected 1 choice, got %d", len(resp.Choices))
- }
-
- if resp.Choices[0].Message.Role != llm.RoleAssistant {
- t.Errorf("Expected role assistant, got %s", resp.Choices[0].Message.Role)
- }
-
- if resp.Choices[0].Message.Content != "Hello! How can I help you?" {
- t.Errorf("Expected content 'Hello! How can I help you?', got '%s'", resp.Choices[0].Message.Content)
- }
-
- if resp.Usage.PromptTokens != 10 {
- t.Errorf("Expected prompt tokens 10, got %d", resp.Usage.PromptTokens)
- }
-
- if resp.Usage.CompletionTokens != 20 {
- t.Errorf("Expected completion tokens 20, got %d", resp.Usage.CompletionTokens)
- }
-
- if resp.Usage.TotalTokens != 30 {
- t.Errorf("Expected total tokens 30, got %d", resp.Usage.TotalTokens)
- }
-}
-
-func TestOpenAIProvider_ConvertRequestWithTools(t *testing.T) {
- provider := &OpenAIProvider{
- config: llm.Config{
- Provider: llm.ProviderOpenAI,
- APIKey: "test-key",
- BaseURL: "https://api.openai.com/v1",
- },
- }
-
- // Test request with tools
- tools := []llm.Tool{
- {
- Type: "function",
- Function: llm.Function{
- Name: "get_weather",
- Description: "Get weather information",
- Parameters: map[string]interface{}{
- "type": "object",
- "properties": map[string]interface{}{
- "location": map[string]interface{}{
- "type": "string",
- },
- },
- },
- },
- },
- }
-
- req := llm.ChatCompletionRequest{
- Model: "gpt-3.5-turbo",
- Messages: []llm.Message{
- {Role: llm.RoleUser, Content: "What's the weather like?"},
- },
- Tools: tools,
- }
-
- openAIReq := provider.convertToOpenAIRequest(req)
-
- if len(openAIReq.Tools) != 1 {
- t.Errorf("Expected 1 tool, got %d", len(openAIReq.Tools))
- }
-
- if openAIReq.Tools[0].Type != "function" {
- t.Errorf("Expected tool type 'function', got '%s'", openAIReq.Tools[0].Type)
- }
-
- if openAIReq.Tools[0].Function.Name != "get_weather" {
- t.Errorf("Expected function name 'get_weather', got '%s'", openAIReq.Tools[0].Function.Name)
- }
-}
-
-func TestOpenAIProvider_ConvertResponseWithToolCalls(t *testing.T) {
- provider := &OpenAIProvider{
- config: llm.Config{
- Provider: llm.ProviderOpenAI,
- APIKey: "test-key",
- BaseURL: "https://api.openai.com/v1",
- },
- }
-
- // Test response with tool calls
- openAIResp := OpenAIResponse{
- ID: "test-id",
- Object: "chat.completion",
- Model: "gpt-3.5-turbo",
- Choices: []OpenAIChoice{
- {
- Index: 0,
- Message: OpenAIMessage{
- Role: "assistant",
- ToolCalls: []OpenAIToolCall{
- {
- ID: "call_123",
- Type: "function",
- Function: OpenAIFunction{
- Name: "get_weather",
- Parameters: map[string]interface{}{
- "location": "Tokyo",
- },
- },
- },
- },
- },
- FinishReason: "tool_calls",
- },
- },
- Usage: OpenAIUsage{
- PromptTokens: 10,
- CompletionTokens: 20,
- TotalTokens: 30,
- },
- }
-
- resp := provider.convertFromOpenAIResponse(openAIResp)
-
- if len(resp.Choices[0].Message.ToolCalls) != 1 {
- t.Errorf("Expected 1 tool call, got %d", len(resp.Choices[0].Message.ToolCalls))
- }
-
- if resp.Choices[0].Message.ToolCalls[0].ID != "call_123" {
- t.Errorf("Expected tool call ID 'call_123', got '%s'", resp.Choices[0].Message.ToolCalls[0].ID)
- }
-
- if resp.Choices[0].Message.ToolCalls[0].Function.Name != "get_weather" {
- t.Errorf("Expected function name 'get_weather', got '%s'", resp.Choices[0].Message.ToolCalls[0].Function.Name)
- }
-
- if resp.Choices[0].FinishReason != "tool_calls" {
- t.Errorf("Expected finish reason 'tool_calls', got '%s'", resp.Choices[0].FinishReason)
- }
-}
-
-func TestOpenAIProvider_ConvertEmbeddingRequest(t *testing.T) {
- req := llm.EmbeddingRequest{
- Input: "Hello, world!",
- Model: "text-embedding-ada-002",
- User: "test-user",
- }
-
- // The conversion is done inline in CreateEmbeddings, so we'll test the structure
- if req.Input != "Hello, world!" {
- t.Errorf("Expected input 'Hello, world!', got '%v'", req.Input)
- }
-
- if req.Model != "text-embedding-ada-002" {
- t.Errorf("Expected model 'text-embedding-ada-002', got '%s'", req.Model)
- }
-
- if req.User != "test-user" {
- t.Errorf("Expected user 'test-user', got '%s'", req.User)
- }
-}
-
-func TestOpenAIProvider_ConvertEmbeddingResponse(t *testing.T) {
- provider := &OpenAIProvider{
- config: llm.Config{
- Provider: llm.ProviderOpenAI,
- APIKey: "test-key",
- BaseURL: "https://api.openai.com/v1",
- },
- }
-
- // Test embedding response conversion
- openAIResp := OpenAIEmbeddingResponse{
- Object: "list",
- Model: "text-embedding-ada-002",
- Data: []OpenAIEmbeddingData{
- {
- Object: "embedding",
- Embedding: []float64{0.1, 0.2, 0.3},
- Index: 0,
- },
- },
- Usage: OpenAIUsage{
- PromptTokens: 5,
- CompletionTokens: 0,
- TotalTokens: 5,
- },
- }
-
- resp := provider.convertFromOpenAIEmbeddingResponse(openAIResp)
-
- if resp.Object != "list" {
- t.Errorf("Expected object 'list', got '%s'", resp.Object)
- }
-
- if resp.Model != "text-embedding-ada-002" {
- t.Errorf("Expected model 'text-embedding-ada-002', got '%s'", resp.Model)
- }
-
- if resp.Provider != llm.ProviderOpenAI {
- t.Errorf("Expected provider OpenAI, got %s", resp.Provider)
- }
-
- if len(resp.Data) != 1 {
- t.Errorf("Expected 1 embedding, got %d", len(resp.Data))
- }
-
- if len(resp.Data[0].Embedding) != 3 {
- t.Errorf("Expected embedding dimension 3, got %d", len(resp.Data[0].Embedding))
- }
-
- if resp.Data[0].Embedding[0] != 0.1 {
- t.Errorf("Expected first embedding value 0.1, got %f", resp.Data[0].Embedding[0])
- }
-}
-
-func TestOpenAIProvider_Close(t *testing.T) {
- provider := &OpenAIProvider{
- config: llm.Config{
- Provider: llm.ProviderOpenAI,
- APIKey: "test-key",
- BaseURL: "https://api.openai.com/v1",
- },
- }
-
- // Test that Close doesn't return an error
- err := provider.Close()
- if err != nil {
- t.Errorf("Close should not return an error: %v", err)
- }
-}
-
-func TestOpenAIProvider_Integration(t *testing.T) {
- // This test would require a real API key and would make actual API calls
- // It's commented out to avoid making real API calls during testing
- /*
- config := Config{
- Provider: ProviderOpenAI,
- APIKey: "your-real-api-key",
- BaseURL: "https://api.openai.com/v1",
- Timeout: 30 * time.Second,
- }
-
- provider, err := CreateProvider(config)
- if err != nil {
- t.Fatalf("Failed to create provider: %v", err)
- }
- defer provider.Close()
-
- req := ChatCompletionRequest{
- Model: "gpt-3.5-turbo",
- Messages: []Message{
- {Role: RoleUser, Content: "Say hello!"},
- },
- MaxTokens: &[]int{50}[0],
- }
-
- resp, err := provider.ChatCompletion(context.Background(), req)
- if err != nil {
- t.Fatalf("Chat completion failed: %v", err)
- }
-
- if len(resp.Choices) == 0 {
- t.Fatal("Expected at least one choice")
- }
-
- if resp.Choices[0].Message.Content == "" {
- t.Fatal("Expected non-empty response content")
- }
- */
-}