Remove llm factories
Change-Id: I87afaad65f299b79ceb447b99c464bfe5c7d68cd
diff --git a/server/agent/agent.go b/server/agent/agent.go
index 3c092d5..f189794 100644
--- a/server/agent/agent.go
+++ b/server/agent/agent.go
@@ -9,6 +9,7 @@
"github.com/iomodo/staff/config"
"github.com/iomodo/staff/llm"
+ "github.com/iomodo/staff/llm/provider"
"github.com/iomodo/staff/tm"
)
@@ -43,19 +44,15 @@
return nil, fmt.Errorf("failed to load system prompt: %w", err)
}
- provider, err := llm.CreateProvider(llmConfig)
- if err != nil {
- return nil, fmt.Errorf("failed to create LLM provider: %w", err)
- }
-
- thinker := NewThinker(provider, agentConfig.Model, systemPrompt, *agentConfig.MaxTokens, *agentConfig.Temperature, agentRoles, logger)
+ prov := provider.CreateProvider(llmConfig)
+ thinker := NewThinker(prov, agentConfig.Model, systemPrompt, *agentConfig.MaxTokens, *agentConfig.Temperature, agentRoles, logger)
agent := &Agent{
Name: agentConfig.Name,
Role: agentConfig.Role,
Model: agentConfig.Model,
SystemPrompt: systemPrompt,
- Provider: provider,
+ Provider: prov,
MaxTokens: agentConfig.MaxTokens,
Temperature: agentConfig.Temperature,
taskManager: taskManager,
diff --git a/server/agent/manager.go b/server/agent/manager.go
index db8f27c..ab2ce72 100644
--- a/server/agent/manager.go
+++ b/server/agent/manager.go
@@ -7,7 +7,6 @@
"github.com/iomodo/staff/config"
"github.com/iomodo/staff/llm"
- _ "github.com/iomodo/staff/llm/providers" // Auto-register all providers
"github.com/iomodo/staff/task"
"github.com/iomodo/staff/tm"
)
diff --git a/server/config/openai_test.go b/server/config/openai_test.go
deleted file mode 100644
index 5f6b3fa..0000000
--- a/server/config/openai_test.go
+++ /dev/null
@@ -1,202 +0,0 @@
-package config
-
-import (
- "context"
- "os"
- "testing"
-
- "github.com/iomodo/staff/llm"
- "github.com/iomodo/staff/llm/openai"
-)
-
-// TestOpenAIIntegration tests the OpenAI integration with real API calls
-// This test requires OPENAI_API_KEY environment variable to be set
-func TestOpenAIIntegration(t *testing.T) {
- apiKey := os.Getenv("OPENAI_API_KEY")
- if apiKey == "" {
- t.Skip("OPENAI_API_KEY not set, skipping OpenAI integration test")
- }
-
- // Create OpenAI config
- config := llm.Config{
- Provider: llm.ProviderOpenAI,
- APIKey: apiKey,
- BaseURL: "https://api.openai.com/v1",
- }
-
- // Create OpenAI provider
- factory := &openai.OpenAIFactory{}
- provider, err := factory.CreateProvider(config)
- if err != nil {
- t.Fatalf("Failed to create OpenAI provider: %v", err)
- }
- defer provider.Close()
-
- // Test chat completion
- t.Run("ChatCompletion", func(t *testing.T) {
- req := llm.ChatCompletionRequest{
- Model: "gpt-3.5-turbo",
- Messages: []llm.Message{
- {
- Role: llm.RoleSystem,
- Content: "You are a helpful assistant.",
- },
- {
- Role: llm.RoleUser,
- Content: "Hello! Just say 'Hello from OpenAI' and nothing else.",
- },
- },
- }
-
- resp, err := provider.ChatCompletion(context.Background(), req)
- if err != nil {
- t.Fatalf("ChatCompletion failed: %v", err)
- }
-
- if len(resp.Choices) == 0 {
- t.Fatal("No choices returned")
- }
-
- message := resp.Choices[0].Message
- if message.Content == "" {
- t.Fatal("Empty response content")
- }
-
- t.Logf("OpenAI Response: %s", message.Content)
- })
-
- // Test embeddings
- t.Run("Embeddings", func(t *testing.T) {
- req := llm.EmbeddingRequest{
- Model: "text-embedding-ada-002",
- Input: "Hello, world!",
- }
-
- resp, err := provider.CreateEmbeddings(context.Background(), req)
- if err != nil {
- t.Fatalf("CreateEmbeddings failed: %v", err)
- }
-
- if len(resp.Data) == 0 {
- t.Fatal("No embeddings returned")
- }
-
- embedding := resp.Data[0]
- if len(embedding.Embedding) == 0 {
- t.Fatal("Empty embedding vector")
- }
-
- t.Logf("Embedding dimensions: %d", len(embedding.Embedding))
- })
-}
-
-// TestConfigurationLoading tests the configuration loading functionality
-func TestConfigurationLoading(t *testing.T) {
- // Create a temporary config file
- configContent := `
-openai:
- api_key: "test-key"
- model: "gpt-4"
-
-github:
- token: "test-token"
- owner: "test-owner"
- repo: "test-repo"
-
-agents:
- - name: "ceo"
- role: "CEO"
- system_prompt_file: "operations/agents/ceo/system.md"
-
-tasks:
- storage_path: "tasks/"
-`
-
- // Write temp config file
- tmpFile, err := os.CreateTemp("", "staff-config-*.yaml")
- if err != nil {
- t.Fatalf("Failed to create temp file: %v", err)
- }
- defer os.Remove(tmpFile.Name())
-
- if _, err := tmpFile.WriteString(configContent); err != nil {
- t.Fatalf("Failed to write config: %v", err)
- }
- tmpFile.Close()
-
- // Test loading config
- config, err := LoadConfig(tmpFile.Name())
- if err != nil {
- t.Fatalf("Failed to load config: %v", err)
- }
-
- // Validate loaded config
- if config.OpenAI.APIKey != "test-key" {
- t.Errorf("Expected API key 'test-key', got '%s'", config.OpenAI.APIKey)
- }
-
- if config.OpenAI.Model != "gpt-4" {
- t.Errorf("Expected model 'gpt-4', got '%s'", config.OpenAI.Model)
- }
-
- if len(config.Agents) != 1 {
- t.Errorf("Expected 1 agent, got %d", len(config.Agents))
- }
-
- if config.Agents[0].Name != "ceo" {
- t.Errorf("Expected agent name 'ceo', got '%s'", config.Agents[0].Name)
- }
-}
-
-// TestEnvironmentOverrides tests environment variable overrides
-func TestEnvironmentOverrides(t *testing.T) {
- // Set environment variables
- os.Setenv("OPENAI_API_KEY", "env-openai-key")
- os.Setenv("GITHUB_TOKEN", "env-github-token")
- defer func() {
- os.Unsetenv("OPENAI_API_KEY")
- os.Unsetenv("GITHUB_TOKEN")
- }()
-
- // Create a temporary config file
- configContent := `
-openai:
- api_key: "config-key"
-
-github:
- token: "config-token"
- owner: "test-owner"
- repo: "test-repo"
-
-agents:
- - name: "ceo"
- role: "CEO"
- system_prompt_file: "operations/agents/ceo/system.md"
-`
-
- tmpFile, err := os.CreateTemp("", "staff-config-*.yaml")
- if err != nil {
- t.Fatalf("Failed to create temp file: %v", err)
- }
- defer os.Remove(tmpFile.Name())
-
- if _, err := tmpFile.WriteString(configContent); err != nil {
- t.Fatalf("Failed to write config: %v", err)
- }
- tmpFile.Close()
-
- // Test loading config with env overrides
- config, err := LoadConfigWithEnvOverrides(tmpFile.Name())
- if err != nil {
- t.Fatalf("Failed to load config: %v", err)
- }
-
- // Verify environment overrides
- if config.OpenAI.APIKey != "env-openai-key" {
- t.Errorf("Expected env API key 'env-openai-key', got '%s'", config.OpenAI.APIKey)
- }
-
- if config.GitHub.Token != "env-github-token" {
- t.Errorf("Expected env GitHub token 'env-github-token', got '%s'", config.GitHub.Token)
- }
-}
diff --git a/server/llm/README.md b/server/llm/README.md
deleted file mode 100644
index 92df86b..0000000
--- a/server/llm/README.md
+++ /dev/null
@@ -1,335 +0,0 @@
-# LLM Interface Package
-
-This package provides a generic interface for different Large Language Model (LLM) providers, with OpenAI's API structure as the primary reference. It supports multiple providers including OpenAI, xAI, Claude, Gemini, and local models.
-
-## Features
-
-- **Unified Interface**: Single interface for all LLM providers
-- **Multiple Providers**: Support for OpenAI, xAI, Claude, Gemini, and local models
-- **Tool/Function Calling**: Support for function calling and tool usage
-- **Embeddings**: Generate embeddings for text
-- **Configurable**: Flexible configuration options for each provider
-- **Thread-Safe**: Thread-safe factory and registry implementations
-
-## Quick Start
-
-```go
-package main
-
-import (
- "context"
- "fmt"
- "log"
-
- "your-project/server/llm"
-)
-
-func main() {
- // Create a configuration for OpenAI
- config := llm.Config{
- Provider: llm.ProviderOpenAI,
- APIKey: "your-openai-api-key",
- BaseURL: "https://api.openai.com/v1",
- Timeout: 30 * time.Second,
- }
-
- // Create a provider (you'll need to register the implementation first)
- provider, err := llm.CreateProvider(config)
- if err != nil {
- log.Fatal(err)
- }
- defer provider.Close()
-
- // Create a chat completion request
- req := llm.ChatCompletionRequest{
- Model: "gpt-3.5-turbo",
- Messages: []llm.Message{
- {Role: llm.RoleUser, Content: "Hello, how are you?"},
- },
- MaxTokens: &[]int{100}[0],
- Temperature: &[]float64{0.7}[0],
- }
-
- // Get the response
- resp, err := provider.ChatCompletion(context.Background(), req)
- if err != nil {
- log.Fatal(err)
- }
-
- fmt.Println("Response:", resp.Choices[0].Message.Content)
-}
-```
-
-## Core Types
-
-### Provider
-
-Represents different LLM service providers:
-
-```go
-const (
- ProviderOpenAI Provider = "openai"
- ProviderXAI Provider = "xai"
- ProviderClaude Provider = "claude"
- ProviderGemini Provider = "gemini"
- ProviderLocal Provider = "local"
-)
-```
-
-### Message
-
-Represents a single message in a conversation:
-
-```go
-type Message struct {
- Role Role `json:"role"`
- Content string `json:"content"`
- ToolCalls []ToolCall `json:"tool_calls,omitempty"`
- ToolCallID string `json:"tool_call_id,omitempty"`
- Name string `json:"name,omitempty"`
-}
-```
-
-### ChatCompletionRequest
-
-Represents a request to complete a chat conversation:
-
-```go
-type ChatCompletionRequest struct {
- Model string `json:"model"`
- Messages []Message `json:"messages"`
- MaxTokens *int `json:"max_tokens,omitempty"`
- Temperature *float64 `json:"temperature,omitempty"`
- TopP *float64 `json:"top_p,omitempty"`
- Stream *bool `json:"stream,omitempty"`
- Tools []Tool `json:"tools,omitempty"`
- ToolChoice interface{} `json:"tool_choice,omitempty"`
- ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
- ExtraParams map[string]interface{} `json:"-"` // Provider-specific parameters
-}
-```
-
-## Main Interface
-
-### LLMProvider
-
-The main interface that all LLM providers must implement:
-
-```go
-type LLMProvider interface {
- // ChatCompletion creates a chat completion
- ChatCompletion(ctx context.Context, req ChatCompletionRequest) (*ChatCompletionResponse, error)
-
- // CreateEmbeddings generates embeddings for the given input
- CreateEmbeddings(ctx context.Context, req EmbeddingRequest) (*EmbeddingResponse, error)
-
- // Close performs any necessary cleanup
- Close() error
-}
-```
-
-## Provider Factory
-
-The package includes a factory system for creating and managing LLM providers:
-
-```go
-// Register a provider factory
-err := llm.RegisterProvider(llm.ProviderOpenAI, openaiFactory)
-
-// Create a provider
-provider, err := llm.CreateProvider(config)
-
-// Check if a provider is supported
-if llm.SupportsProvider(llm.ProviderOpenAI) {
- // Provider is available
-}
-
-// List all registered providers
-providers := llm.ListProviders()
-```
-
-## Configuration
-
-Each provider can be configured with specific settings:
-
-```go
-config := llm.Config{
- Provider: llm.ProviderOpenAI,
- APIKey: "your-api-key",
- BaseURL: "https://api.openai.com/v1",
- Timeout: 30 * time.Second,
- MaxRetries: 3,
- ExtraConfig: map[string]interface{}{
- "organization": "your-org-id",
- },
-}
-```
-
-## Tool/Function Calling
-
-Support for function calling and tool usage:
-
-```go
-tools := []llm.Tool{
- {
- Type: "function",
- Function: llm.Function{
- Name: "get_weather",
- Description: "Get the current weather for a location",
- Parameters: map[string]interface{}{
- "type": "object",
- "properties": map[string]interface{}{
- "location": map[string]interface{}{
- "type": "string",
- "description": "The city and state, e.g. San Francisco, CA",
- },
- },
- "required": []string{"location"},
- },
- },
- },
-}
-
-req := llm.ChatCompletionRequest{
- Model: "gpt-3.5-turbo",
- Messages: messages,
- Tools: tools,
-}
-```
-
-## Embeddings
-
-Generate embeddings for text:
-
-```go
-req := llm.EmbeddingRequest{
- Input: "Hello, world!",
- Model: "text-embedding-ada-002",
-}
-
-resp, err := provider.CreateEmbeddings(context.Background(), req)
-if err != nil {
- log.Fatal(err)
-}
-
-fmt.Printf("Embedding dimensions: %d\n", len(resp.Data[0].Embedding))
-```
-
-## Implementing a New Provider
-
-To implement a new LLM provider:
-
-1. **Implement the LLMProvider interface**:
-
-```go
-type MyProvider struct {
- config llm.Config
- client *http.Client
-}
-
-func (p *MyProvider) ChatCompletion(ctx context.Context, req llm.ChatCompletionRequest) (*llm.ChatCompletionResponse, error) {
- // Implementation here
-}
-
-func (p *MyProvider) CreateEmbeddings(ctx context.Context, req llm.EmbeddingRequest) (*llm.EmbeddingResponse, error) {
- // Implementation here
-}
-
-func (p *MyProvider) Close() error {
- // Cleanup implementation
- return nil
-}
-```
-
-2. **Create a factory**:
-
-```go
-type MyProviderFactory struct{}
-
-func (f *MyProviderFactory) CreateProvider(config llm.Config) (llm.LLMProvider, error) {
- return &MyProvider{config: config}, nil
-}
-
-func (f *MyProviderFactory) SupportsProvider(provider llm.Provider) bool {
- return provider == llm.ProviderMyProvider
-}
-```
-
-3. **Register the provider**:
-
-```go
-func init() {
- llm.RegisterProvider(llm.ProviderMyProvider, &MyProviderFactory{})
-}
-```
-
-## Error Handling
-
-The package defines common error types:
-
-```go
-var (
- ErrInvalidConfig = fmt.Errorf("invalid configuration")
- ErrUnsupportedProvider = fmt.Errorf("unsupported provider")
- ErrAPIKeyRequired = fmt.Errorf("API key is required")
- ErrModelNotFound = fmt.Errorf("model not found")
- ErrRateLimitExceeded = fmt.Errorf("rate limit exceeded")
- ErrContextCancelled = fmt.Errorf("context cancelled")
- ErrTimeout = fmt.Errorf("request timeout")
-)
-```
-
-## Utilities
-
-The package includes utility functions:
-
-```go
-// Validate configuration
-err := llm.ValidateConfig(config)
-
-// Check if provider is valid
-if llm.IsValidProvider(llm.ProviderOpenAI) {
- // Provider is valid
-}
-
-// Get default configuration
-config, err := llm.GetDefaultConfig(llm.ProviderOpenAI)
-
-// Merge custom config with defaults
-config = llm.MergeConfig(customConfig)
-
-// Validate requests
-err := llm.ValidateChatCompletionRequest(req)
-err := llm.ValidateEmbeddingRequest(req)
-
-// Estimate tokens
-tokens := llm.EstimateTokens(req)
-```
-
-## Thread Safety
-
-The factory and registry implementations are thread-safe and can be used concurrently from multiple goroutines.
-
-## Default Configurations
-
-The package provides default configurations for each provider:
-
-```go
-var DefaultConfigs = map[llm.Provider]llm.Config{
- llm.ProviderOpenAI: {
- Provider: llm.ProviderOpenAI,
- BaseURL: "https://api.openai.com/v1",
- Timeout: 30 * time.Second,
- MaxRetries: 3,
- },
- // ... other providers
-}
-```
-
-## Next Steps
-
-1. Implement the actual provider implementations (OpenAI, xAI, Claude, etc.)
-2. Add tests for the interface and implementations
-3. Add more utility functions as needed
-4. Consider adding caching and retry mechanisms
-5. Add support for more provider-specific features
\ No newline at end of file
diff --git a/server/llm/factory.go b/server/llm/factory.go
deleted file mode 100644
index e425eee..0000000
--- a/server/llm/factory.go
+++ /dev/null
@@ -1,206 +0,0 @@
-package llm
-
-import (
- "fmt"
- "sync"
-)
-
-// GlobalProviderFactory is the main factory for creating LLM providers
-type GlobalProviderFactory struct {
- providers map[Provider]ProviderFactory
- mu sync.RWMutex
-}
-
-// NewGlobalProviderFactory creates a new global provider factory
-func NewGlobalProviderFactory() *GlobalProviderFactory {
- return &GlobalProviderFactory{
- providers: make(map[Provider]ProviderFactory),
- }
-}
-
-// RegisterProvider registers a provider factory for a specific provider type
-func (f *GlobalProviderFactory) RegisterProvider(provider Provider, factory ProviderFactory) error {
- f.mu.Lock()
- defer f.mu.Unlock()
-
- if !IsValidProvider(provider) {
- return fmt.Errorf("unsupported provider: %s", provider)
- }
-
- f.providers[provider] = factory
- return nil
-}
-
-// CreateProvider creates a new LLM provider instance
-func (f *GlobalProviderFactory) CreateProvider(config Config) (LLMProvider, error) {
- f.mu.RLock()
- factory, exists := f.providers[config.Provider]
- f.mu.RUnlock()
-
- if !exists {
- return nil, fmt.Errorf("no factory registered for provider: %s", config.Provider)
- }
-
- // Validate and merge config
- if err := ValidateConfig(config); err != nil {
- return nil, fmt.Errorf("invalid config: %w", err)
- }
-
- config = MergeConfig(config)
-
- return factory.CreateProvider(config)
-}
-
-// SupportsProvider checks if the factory supports the given provider
-func (f *GlobalProviderFactory) SupportsProvider(provider Provider) bool {
- f.mu.RLock()
- defer f.mu.RUnlock()
-
- _, exists := f.providers[provider]
- return exists
-}
-
-// ListSupportedProviders returns a list of supported providers
-func (f *GlobalProviderFactory) ListSupportedProviders() []Provider {
- f.mu.RLock()
- defer f.mu.RUnlock()
-
- providers := make([]Provider, 0, len(f.providers))
- for provider := range f.providers {
- providers = append(providers, provider)
- }
-
- return providers
-}
-
-// UnregisterProvider removes a provider factory
-func (f *GlobalProviderFactory) UnregisterProvider(provider Provider) {
- f.mu.Lock()
- defer f.mu.Unlock()
-
- delete(f.providers, provider)
-}
-
-// DefaultFactory is the default global factory instance
-var DefaultFactory = NewGlobalProviderFactory()
-
-// RegisterDefaultProvider registers a provider with the default factory
-func RegisterDefaultProvider(provider Provider, factory ProviderFactory) error {
- return DefaultFactory.RegisterProvider(provider, factory)
-}
-
-// CreateDefaultProvider creates a provider using the default factory
-func CreateDefaultProvider(config Config) (LLMProvider, error) {
- return DefaultFactory.CreateProvider(config)
-}
-
-// SupportsDefaultProvider checks if the default factory supports a provider
-func SupportsDefaultProvider(provider Provider) bool {
- return DefaultFactory.SupportsProvider(provider)
-}
-
-// ListDefaultSupportedProviders returns providers supported by the default factory
-func ListDefaultSupportedProviders() []Provider {
- return DefaultFactory.ListSupportedProviders()
-}
-
-// ProviderRegistry provides a simple way to register and manage providers
-type ProviderRegistry struct {
- factories map[Provider]ProviderFactory
- mu sync.RWMutex
-}
-
-// NewProviderRegistry creates a new provider registry
-func NewProviderRegistry() *ProviderRegistry {
- return &ProviderRegistry{
- factories: make(map[Provider]ProviderFactory),
- }
-}
-
-// Register registers a provider factory
-func (r *ProviderRegistry) Register(provider Provider, factory ProviderFactory) error {
- r.mu.Lock()
- defer r.mu.Unlock()
-
- if !IsValidProvider(provider) {
- return fmt.Errorf("unsupported provider: %s", provider)
- }
-
- r.factories[provider] = factory
- return nil
-}
-
-// Get retrieves a provider factory
-func (r *ProviderRegistry) Get(provider Provider) (ProviderFactory, bool) {
- r.mu.RLock()
- defer r.mu.RUnlock()
-
- factory, exists := r.factories[provider]
- return factory, exists
-}
-
-// Create creates a new LLM provider instance
-func (r *ProviderRegistry) Create(config Config) (LLMProvider, error) {
- factory, exists := r.Get(config.Provider)
- if !exists {
- return nil, fmt.Errorf("no factory registered for provider: %s", config.Provider)
- }
-
- // Validate and merge config
- if err := ValidateConfig(config); err != nil {
- return nil, fmt.Errorf("invalid config: %w", err)
- }
-
- config = MergeConfig(config)
-
- return factory.CreateProvider(config)
-}
-
-// List returns all registered providers
-func (r *ProviderRegistry) List() []Provider {
- r.mu.RLock()
- defer r.mu.RUnlock()
-
- providers := make([]Provider, 0, len(r.factories))
- for provider := range r.factories {
- providers = append(providers, provider)
- }
-
- return providers
-}
-
-// Unregister removes a provider factory
-func (r *ProviderRegistry) Unregister(provider Provider) {
- r.mu.Lock()
- defer r.mu.Unlock()
-
- delete(r.factories, provider)
-}
-
-// DefaultRegistry is the default provider registry
-var DefaultRegistry = NewProviderRegistry()
-
-// RegisterProvider registers a provider with the default registry
-func RegisterProvider(provider Provider, factory ProviderFactory) error {
- return DefaultRegistry.Register(provider, factory)
-}
-
-// CreateProvider creates a provider using the default registry
-func CreateProvider(config Config) (LLMProvider, error) {
- return DefaultRegistry.Create(config)
-}
-
-// GetProviderFactory gets a provider factory from the default registry
-func GetProviderFactory(provider Provider) (ProviderFactory, bool) {
- return DefaultRegistry.Get(provider)
-}
-
-// ListProviders returns all providers registered with the default registry
-func ListProviders() []Provider {
- return DefaultRegistry.List()
-}
-
-// UnregisterProvider removes a provider from the default registry
-func UnregisterProvider(provider Provider) {
- DefaultRegistry.Unregister(provider)
-}
diff --git a/server/llm/fake/factory.go b/server/llm/fake/factory.go
deleted file mode 100644
index 8031e62..0000000
--- a/server/llm/fake/factory.go
+++ /dev/null
@@ -1,28 +0,0 @@
-package fake
-
-import (
- "github.com/iomodo/staff/llm"
-)
-
-// FakeFactory creates fake LLM providers for testing
-type FakeFactory struct{}
-
-// NewFakeFactory creates a new fake factory
-func NewFakeFactory() *FakeFactory {
- return &FakeFactory{}
-}
-
-// CreateProvider creates a new fake provider
-func (f *FakeFactory) CreateProvider(config llm.Config) (llm.LLMProvider, error) {
- return NewFakeProvider(), nil
-}
-
-// SupportsProvider returns true if this factory supports the given provider type
-func (f *FakeFactory) SupportsProvider(provider llm.Provider) bool {
- return provider == llm.ProviderFake
-}
-
-// init registers the fake factory when the package is imported
-func init() {
- llm.RegisterProvider(llm.ProviderFake, NewFakeFactory())
-}
\ No newline at end of file
diff --git a/server/llm/fake/fake.go b/server/llm/fake/fake.go
index 58185b3..92560cf 100644
--- a/server/llm/fake/fake.go
+++ b/server/llm/fake/fake.go
@@ -16,7 +16,7 @@
}
// NewFakeProvider creates a new fake provider with predefined responses
-func NewFakeProvider() *FakeProvider {
+func New() *FakeProvider {
responses := []string{
`## Task Solution
@@ -290,4 +290,4 @@
func (f *FakeProvider) Close() error {
// Nothing to close for fake provider
return nil
-}
\ No newline at end of file
+}
diff --git a/server/llm/llm.go b/server/llm/llm.go
index aa63c3e..b4db545 100644
--- a/server/llm/llm.go
+++ b/server/llm/llm.go
@@ -17,15 +17,6 @@
Close() error
}
-// ProviderFactory creates LLM provider instances
-type ProviderFactory interface {
- // CreateProvider creates a new LLM provider instance
- CreateProvider(config Config) (LLMProvider, error)
-
- // SupportsProvider checks if the factory supports the given provider
- SupportsProvider(provider Provider) bool
-}
-
// Provider represents different LLM service providers
type Provider string
diff --git a/server/llm/openai/openai.go b/server/llm/openai/openai.go
index c513c53..fad11f6 100644
--- a/server/llm/openai/openai.go
+++ b/server/llm/openai/openai.go
@@ -152,8 +152,7 @@
} `json:"error"`
}
-// NewOpenAIProvider creates a new OpenAI provider
-func NewOpenAIProvider(config llm.Config) *OpenAIProvider {
+func New(config llm.Config) *OpenAIProvider {
client := &http.Client{
Timeout: config.Timeout,
}
@@ -436,33 +435,3 @@
return body, nil
}
-
-// OpenAIFactory implements ProviderFactory for OpenAI
-type OpenAIFactory struct{}
-
-// CreateProvider creates a new OpenAI provider
-func (f *OpenAIFactory) CreateProvider(config llm.Config) (llm.LLMProvider, error) {
- if config.Provider != llm.ProviderOpenAI {
- return nil, fmt.Errorf("OpenAI factory cannot create provider: %s", config.Provider)
- }
-
- // Validate config
- if err := llm.ValidateConfig(config); err != nil {
- return nil, fmt.Errorf("invalid OpenAI config: %w", err)
- }
-
- // Merge with defaults
- config = llm.MergeConfig(config)
-
- return NewOpenAIProvider(config), nil
-}
-
-// SupportsProvider checks if this factory supports the given provider
-func (f *OpenAIFactory) SupportsProvider(provider llm.Provider) bool {
- return provider == llm.ProviderOpenAI
-}
-
-// Register OpenAI provider with the default registry
-func init() {
- llm.RegisterProvider(llm.ProviderOpenAI, &OpenAIFactory{})
-}
diff --git a/server/llm/openai/openai_example.go b/server/llm/openai/openai_example.go
deleted file mode 100644
index 6cd7dbf..0000000
--- a/server/llm/openai/openai_example.go
+++ /dev/null
@@ -1,254 +0,0 @@
-package openai
-
-import (
- "context"
- "fmt"
- "log"
- "time"
-
- "github.com/iomodo/staff/llm"
-)
-
-// ExampleOpenAI demonstrates how to use the OpenAI implementation
-func ExampleOpenAI() {
- // Create OpenAI configuration
- config := llm.Config{
- Provider: llm.ProviderOpenAI,
- APIKey: "your-openai-api-key-here", // Replace with your actual API key
- BaseURL: "https://api.openai.com/v1",
- Timeout: 30 * time.Second,
- ExtraConfig: map[string]interface{}{
- "organization": "your-org-id", // Optional: your OpenAI organization ID
- },
- }
-
- // Create the provider
- provider, err := llm.CreateProvider(config)
- if err != nil {
- log.Fatalf("Failed to create OpenAI provider: %v", err)
- }
- defer provider.Close()
-
- // Example 1: Basic chat completion
- exampleOpenAIChatCompletion(provider)
-
- // Example 2: Function calling
- exampleOpenAIFunctionCalling(provider)
-
- // Example 3: Embeddings
- exampleOpenAIEmbeddings(provider)
-}
-
-// exampleOpenAIChatCompletion demonstrates basic chat completion with OpenAI
-func exampleOpenAIChatCompletion(provider llm.LLMProvider) {
- fmt.Println("=== OpenAI Chat Completion Example ===")
-
- req := llm.ChatCompletionRequest{
- Model: "gpt-3.5-turbo",
- Messages: []llm.Message{
- {Role: llm.RoleSystem, Content: "You are a helpful assistant."},
- {Role: llm.RoleUser, Content: "What is the capital of France?"},
- },
- MaxTokens: &[]int{100}[0],
- Temperature: &[]float64{0.7}[0],
- }
-
- ctx := context.Background()
- resp, err := provider.ChatCompletion(ctx, req)
- if err != nil {
- log.Printf("Chat completion failed: %v", err)
- return
- }
-
- fmt.Printf("Response: %s\n", resp.Choices[0].Message.Content)
- fmt.Printf("Model: %s\n", resp.Model)
- fmt.Printf("Provider: %s\n", resp.Provider)
- fmt.Printf("Usage: %+v\n", resp.Usage)
-}
-
-// exampleOpenAIFunctionCalling demonstrates function calling with OpenAI
-func exampleOpenAIFunctionCalling(provider llm.LLMProvider) {
- fmt.Println("\n=== OpenAI Function Calling Example ===")
-
- // Define a function that can be called
- tools := []llm.Tool{
- {
- Type: "function",
- Function: llm.Function{
- Name: "get_weather",
- Description: "Get the current weather for a location",
- Parameters: map[string]interface{}{
- "type": "object",
- "properties": map[string]interface{}{
- "location": map[string]interface{}{
- "type": "string",
- "description": "The city and state, e.g. San Francisco, CA",
- },
- "unit": map[string]interface{}{
- "type": "string",
- "enum": []string{"celsius", "fahrenheit"},
- },
- },
- "required": []string{"location"},
- },
- },
- },
- }
-
- req := llm.ChatCompletionRequest{
- Model: "gpt-3.5-turbo",
- Messages: []llm.Message{
- {Role: llm.RoleUser, Content: "What's the weather like in Tokyo?"},
- },
- Tools: tools,
- MaxTokens: &[]int{150}[0],
- Temperature: &[]float64{0.1}[0],
- }
-
- ctx := context.Background()
- resp, err := provider.ChatCompletion(ctx, req)
- if err != nil {
- log.Printf("Function calling failed: %v", err)
- return
- }
-
- // Check if the model wants to call a function
- if len(resp.Choices[0].Message.ToolCalls) > 0 {
- fmt.Println("Model wants to call a function:")
- for _, toolCall := range resp.Choices[0].Message.ToolCalls {
- fmt.Printf("Function: %s\n", toolCall.Function.Name)
- fmt.Printf("Arguments: %+v\n", toolCall.Function.Parameters)
- }
- } else {
- fmt.Printf("Response: %s\n", resp.Choices[0].Message.Content)
- }
-}
-
-// exampleOpenAIEmbeddings demonstrates embedding generation with OpenAI
-func exampleOpenAIEmbeddings(provider llm.LLMProvider) {
- fmt.Println("\n=== OpenAI Embeddings Example ===")
-
- req := llm.EmbeddingRequest{
- Input: "Hello, world! This is a test sentence for embeddings.",
- Model: "text-embedding-ada-002",
- }
-
- ctx := context.Background()
- resp, err := provider.CreateEmbeddings(ctx, req)
- if err != nil {
- log.Printf("Embeddings failed: %v", err)
- return
- }
-
- fmt.Printf("Embedding dimensions: %d\n", len(resp.Data[0].Embedding))
- fmt.Printf("First 5 values: %v\n", resp.Data[0].Embedding[:5])
- fmt.Printf("Model: %s\n", resp.Model)
- fmt.Printf("Provider: %s\n", resp.Provider)
- fmt.Printf("Usage: %+v\n", resp.Usage)
-}
-
-// ExampleOpenAIWithErrorHandling demonstrates proper error handling
-func ExampleOpenAIWithErrorHandling() {
- fmt.Println("\n=== OpenAI Error Handling Example ===")
-
- // Test with invalid API key
- config := llm.Config{
- Provider: llm.ProviderOpenAI,
- APIKey: "invalid-key",
- BaseURL: "https://api.openai.com/v1",
- Timeout: 30 * time.Second,
- }
-
- provider, err := llm.CreateProvider(config)
- if err != nil {
- fmt.Printf("Failed to create provider: %v\n", err)
- return
- }
- defer provider.Close()
-
- req := llm.ChatCompletionRequest{
- Model: "gpt-3.5-turbo",
- Messages: []llm.Message{
- {Role: llm.RoleUser, Content: "Hello"},
- },
- MaxTokens: &[]int{50}[0],
- }
-
- ctx := context.Background()
- resp, err := provider.ChatCompletion(ctx, req)
- if err != nil {
- fmt.Printf("Expected error with invalid API key: %v\n", err)
- return
- }
-
- fmt.Printf("Unexpected success: %s\n", resp.Choices[0].Message.Content)
-}
-
-// ExampleOpenAIWithCustomConfig demonstrates custom configuration
-func ExampleOpenAIWithCustomConfig() {
- fmt.Println("\n=== OpenAI Custom Configuration Example ===")
-
- // Start with minimal config
- customConfig := llm.Config{
- Provider: llm.ProviderOpenAI,
- APIKey: "your-api-key",
- }
-
- // Merge with defaults
- config := llm.MergeConfig(customConfig)
-
- fmt.Printf("Merged config: BaseURL=%s, Timeout=%v, MaxRetries=%d\n",
- config.BaseURL, config.Timeout, config.MaxRetries)
-
- // Add extra configuration
- config.ExtraConfig = map[string]interface{}{
- "organization": "your-org-id",
- "project": "your-project-id",
- }
-
- provider, err := llm.CreateProvider(config)
- if err != nil {
- fmt.Printf("Failed to create provider: %v\n", err)
- return
- }
- defer provider.Close()
-
- fmt.Println("Provider created successfully with custom configuration")
-}
-
-// ExampleOpenAIWithValidation demonstrates request validation
-func ExampleOpenAIWithValidation() {
- fmt.Println("\n=== OpenAI Validation Example ===")
-
- // Valid request
- validReq := llm.ChatCompletionRequest{
- Model: "gpt-3.5-turbo",
- Messages: []llm.Message{
- {Role: llm.RoleUser, Content: "Hello"},
- },
- Temperature: &[]float64{0.5}[0],
- }
-
- if err := llm.ValidateChatCompletionRequest(validReq); err != nil {
- fmt.Printf("Valid request validation failed: %v\n", err)
- } else {
- fmt.Println("Valid request passed validation")
- }
-
- // Invalid request (no model)
- invalidReq := llm.ChatCompletionRequest{
- Messages: []llm.Message{
- {Role: llm.RoleUser, Content: "Hello"},
- },
- }
-
- if err := llm.ValidateChatCompletionRequest(invalidReq); err != nil {
- fmt.Printf("Invalid request correctly caught: %v\n", err)
- } else {
- fmt.Println("Invalid request should have failed validation")
- }
-
- // Test token estimation
- tokens := llm.EstimateTokens(validReq)
- fmt.Printf("Estimated tokens for request: %d\n", tokens)
-}
diff --git a/server/llm/openai/openai_test.go b/server/llm/openai/openai_test.go
deleted file mode 100644
index 8896405..0000000
--- a/server/llm/openai/openai_test.go
+++ /dev/null
@@ -1,428 +0,0 @@
-package openai
-
-import (
- "testing"
- "time"
-
- "github.com/iomodo/staff/llm"
-)
-
-func TestOpenAIProvider_Interface(t *testing.T) {
- // Test that OpenAIProvider implements LLMProvider interface
- var _ llm.LLMProvider = (*OpenAIProvider)(nil)
-}
-
-func TestOpenAIFactory_CreateProvider(t *testing.T) {
- factory := &OpenAIFactory{}
-
- // Test valid config
- config := llm.Config{
- Provider: llm.ProviderOpenAI,
- APIKey: "test-key",
- BaseURL: "https://api.openai.com/v1",
- Timeout: 30 * time.Second,
- }
-
- provider, err := factory.CreateProvider(config)
- if err != nil {
- t.Fatalf("Failed to create provider: %v", err)
- }
-
- if provider == nil {
- t.Fatal("Provider should not be nil")
- }
-
- // Test invalid provider
- invalidConfig := llm.Config{
- Provider: llm.ProviderClaude,
- APIKey: "test-key",
- }
-
- _, err = factory.CreateProvider(invalidConfig)
- if err == nil {
- t.Fatal("Should fail with invalid provider")
- }
-
- // Test missing API key
- noKeyConfig := llm.Config{
- Provider: llm.ProviderOpenAI,
- BaseURL: "https://api.openai.com/v1",
- }
-
- _, err = factory.CreateProvider(noKeyConfig)
- if err == nil {
- t.Fatal("Should fail with missing API key")
- }
-}
-
-func TestOpenAIFactory_SupportsProvider(t *testing.T) {
- factory := &OpenAIFactory{}
-
- if !factory.SupportsProvider(llm.ProviderOpenAI) {
- t.Fatal("Should support OpenAI provider")
- }
-
- if factory.SupportsProvider(llm.ProviderClaude) {
- t.Fatal("Should not support Claude provider")
- }
-}
-
-func TestOpenAIProvider_ConvertRequest(t *testing.T) {
- provider := &OpenAIProvider{
- config: llm.Config{
- Provider: llm.ProviderOpenAI,
- APIKey: "test-key",
- BaseURL: "https://api.openai.com/v1",
- },
- }
-
- // Test basic request conversion
- req := llm.ChatCompletionRequest{
- Model: "gpt-3.5-turbo",
- Messages: []llm.Message{
- {Role: llm.RoleUser, Content: "Hello"},
- },
- MaxTokens: &[]int{100}[0],
- Temperature: &[]float64{0.7}[0],
- }
-
- openAIReq := provider.convertToOpenAIRequest(req)
-
- if openAIReq.Model != "gpt-3.5-turbo" {
- t.Errorf("Expected model 'gpt-3.5-turbo', got '%s'", openAIReq.Model)
- }
-
- if len(openAIReq.Messages) != 1 {
- t.Errorf("Expected 1 message, got %d", len(openAIReq.Messages))
- }
-
- if openAIReq.Messages[0].Role != "user" {
- t.Errorf("Expected role 'user', got '%s'", openAIReq.Messages[0].Role)
- }
-
- if openAIReq.Messages[0].Content != "Hello" {
- t.Errorf("Expected content 'Hello', got '%s'", openAIReq.Messages[0].Content)
- }
-
- if *openAIReq.MaxTokens != 100 {
- t.Errorf("Expected max_tokens 100, got %d", *openAIReq.MaxTokens)
- }
-
- if *openAIReq.Temperature != 0.7 {
- t.Errorf("Expected temperature 0.7, got %f", *openAIReq.Temperature)
- }
-}
-
-func TestOpenAIProvider_ConvertResponse(t *testing.T) {
- provider := &OpenAIProvider{
- config: llm.Config{
- Provider: llm.ProviderOpenAI,
- APIKey: "test-key",
- BaseURL: "https://api.openai.com/v1",
- },
- }
-
- // Test basic response conversion
- openAIResp := OpenAIResponse{
- ID: "test-id",
- Object: "chat.completion",
- Created: 1234567890,
- Model: "gpt-3.5-turbo",
- Choices: []OpenAIChoice{
- {
- Index: 0,
- Message: OpenAIMessage{
- Role: "assistant",
- Content: "Hello! How can I help you?",
- },
- FinishReason: "stop",
- },
- },
- Usage: OpenAIUsage{
- PromptTokens: 10,
- CompletionTokens: 20,
- TotalTokens: 30,
- },
- }
-
- resp := provider.convertFromOpenAIResponse(openAIResp)
-
- if resp.ID != "test-id" {
- t.Errorf("Expected ID 'test-id', got '%s'", resp.ID)
- }
-
- if resp.Model != "gpt-3.5-turbo" {
- t.Errorf("Expected model 'gpt-3.5-turbo', got '%s'", resp.Model)
- }
-
- if resp.Provider != llm.ProviderOpenAI {
- t.Errorf("Expected provider OpenAI, got %s", resp.Provider)
- }
-
- if len(resp.Choices) != 1 {
- t.Errorf("Expected 1 choice, got %d", len(resp.Choices))
- }
-
- if resp.Choices[0].Message.Role != llm.RoleAssistant {
- t.Errorf("Expected role assistant, got %s", resp.Choices[0].Message.Role)
- }
-
- if resp.Choices[0].Message.Content != "Hello! How can I help you?" {
- t.Errorf("Expected content 'Hello! How can I help you?', got '%s'", resp.Choices[0].Message.Content)
- }
-
- if resp.Usage.PromptTokens != 10 {
- t.Errorf("Expected prompt tokens 10, got %d", resp.Usage.PromptTokens)
- }
-
- if resp.Usage.CompletionTokens != 20 {
- t.Errorf("Expected completion tokens 20, got %d", resp.Usage.CompletionTokens)
- }
-
- if resp.Usage.TotalTokens != 30 {
- t.Errorf("Expected total tokens 30, got %d", resp.Usage.TotalTokens)
- }
-}
-
-func TestOpenAIProvider_ConvertRequestWithTools(t *testing.T) {
- provider := &OpenAIProvider{
- config: llm.Config{
- Provider: llm.ProviderOpenAI,
- APIKey: "test-key",
- BaseURL: "https://api.openai.com/v1",
- },
- }
-
- // Test request with tools
- tools := []llm.Tool{
- {
- Type: "function",
- Function: llm.Function{
- Name: "get_weather",
- Description: "Get weather information",
- Parameters: map[string]interface{}{
- "type": "object",
- "properties": map[string]interface{}{
- "location": map[string]interface{}{
- "type": "string",
- },
- },
- },
- },
- },
- }
-
- req := llm.ChatCompletionRequest{
- Model: "gpt-3.5-turbo",
- Messages: []llm.Message{
- {Role: llm.RoleUser, Content: "What's the weather like?"},
- },
- Tools: tools,
- }
-
- openAIReq := provider.convertToOpenAIRequest(req)
-
- if len(openAIReq.Tools) != 1 {
- t.Errorf("Expected 1 tool, got %d", len(openAIReq.Tools))
- }
-
- if openAIReq.Tools[0].Type != "function" {
- t.Errorf("Expected tool type 'function', got '%s'", openAIReq.Tools[0].Type)
- }
-
- if openAIReq.Tools[0].Function.Name != "get_weather" {
- t.Errorf("Expected function name 'get_weather', got '%s'", openAIReq.Tools[0].Function.Name)
- }
-}
-
-func TestOpenAIProvider_ConvertResponseWithToolCalls(t *testing.T) {
- provider := &OpenAIProvider{
- config: llm.Config{
- Provider: llm.ProviderOpenAI,
- APIKey: "test-key",
- BaseURL: "https://api.openai.com/v1",
- },
- }
-
- // Test response with tool calls
- openAIResp := OpenAIResponse{
- ID: "test-id",
- Object: "chat.completion",
- Model: "gpt-3.5-turbo",
- Choices: []OpenAIChoice{
- {
- Index: 0,
- Message: OpenAIMessage{
- Role: "assistant",
- ToolCalls: []OpenAIToolCall{
- {
- ID: "call_123",
- Type: "function",
- Function: OpenAIFunction{
- Name: "get_weather",
- Parameters: map[string]interface{}{
- "location": "Tokyo",
- },
- },
- },
- },
- },
- FinishReason: "tool_calls",
- },
- },
- Usage: OpenAIUsage{
- PromptTokens: 10,
- CompletionTokens: 20,
- TotalTokens: 30,
- },
- }
-
- resp := provider.convertFromOpenAIResponse(openAIResp)
-
- if len(resp.Choices[0].Message.ToolCalls) != 1 {
- t.Errorf("Expected 1 tool call, got %d", len(resp.Choices[0].Message.ToolCalls))
- }
-
- if resp.Choices[0].Message.ToolCalls[0].ID != "call_123" {
- t.Errorf("Expected tool call ID 'call_123', got '%s'", resp.Choices[0].Message.ToolCalls[0].ID)
- }
-
- if resp.Choices[0].Message.ToolCalls[0].Function.Name != "get_weather" {
- t.Errorf("Expected function name 'get_weather', got '%s'", resp.Choices[0].Message.ToolCalls[0].Function.Name)
- }
-
- if resp.Choices[0].FinishReason != "tool_calls" {
- t.Errorf("Expected finish reason 'tool_calls', got '%s'", resp.Choices[0].FinishReason)
- }
-}
-
-func TestOpenAIProvider_ConvertEmbeddingRequest(t *testing.T) {
- req := llm.EmbeddingRequest{
- Input: "Hello, world!",
- Model: "text-embedding-ada-002",
- User: "test-user",
- }
-
- // The conversion is done inline in CreateEmbeddings, so we'll test the structure
- if req.Input != "Hello, world!" {
- t.Errorf("Expected input 'Hello, world!', got '%v'", req.Input)
- }
-
- if req.Model != "text-embedding-ada-002" {
- t.Errorf("Expected model 'text-embedding-ada-002', got '%s'", req.Model)
- }
-
- if req.User != "test-user" {
- t.Errorf("Expected user 'test-user', got '%s'", req.User)
- }
-}
-
-func TestOpenAIProvider_ConvertEmbeddingResponse(t *testing.T) {
- provider := &OpenAIProvider{
- config: llm.Config{
- Provider: llm.ProviderOpenAI,
- APIKey: "test-key",
- BaseURL: "https://api.openai.com/v1",
- },
- }
-
- // Test embedding response conversion
- openAIResp := OpenAIEmbeddingResponse{
- Object: "list",
- Model: "text-embedding-ada-002",
- Data: []OpenAIEmbeddingData{
- {
- Object: "embedding",
- Embedding: []float64{0.1, 0.2, 0.3},
- Index: 0,
- },
- },
- Usage: OpenAIUsage{
- PromptTokens: 5,
- CompletionTokens: 0,
- TotalTokens: 5,
- },
- }
-
- resp := provider.convertFromOpenAIEmbeddingResponse(openAIResp)
-
- if resp.Object != "list" {
- t.Errorf("Expected object 'list', got '%s'", resp.Object)
- }
-
- if resp.Model != "text-embedding-ada-002" {
- t.Errorf("Expected model 'text-embedding-ada-002', got '%s'", resp.Model)
- }
-
- if resp.Provider != llm.ProviderOpenAI {
- t.Errorf("Expected provider OpenAI, got %s", resp.Provider)
- }
-
- if len(resp.Data) != 1 {
- t.Errorf("Expected 1 embedding, got %d", len(resp.Data))
- }
-
- if len(resp.Data[0].Embedding) != 3 {
- t.Errorf("Expected embedding dimension 3, got %d", len(resp.Data[0].Embedding))
- }
-
- if resp.Data[0].Embedding[0] != 0.1 {
- t.Errorf("Expected first embedding value 0.1, got %f", resp.Data[0].Embedding[0])
- }
-}
-
-func TestOpenAIProvider_Close(t *testing.T) {
- provider := &OpenAIProvider{
- config: llm.Config{
- Provider: llm.ProviderOpenAI,
- APIKey: "test-key",
- BaseURL: "https://api.openai.com/v1",
- },
- }
-
- // Test that Close doesn't return an error
- err := provider.Close()
- if err != nil {
- t.Errorf("Close should not return an error: %v", err)
- }
-}
-
-func TestOpenAIProvider_Integration(t *testing.T) {
- // This test would require a real API key and would make actual API calls
- // It's commented out to avoid making real API calls during testing
- /*
- config := Config{
- Provider: ProviderOpenAI,
- APIKey: "your-real-api-key",
- BaseURL: "https://api.openai.com/v1",
- Timeout: 30 * time.Second,
- }
-
- provider, err := CreateProvider(config)
- if err != nil {
- t.Fatalf("Failed to create provider: %v", err)
- }
- defer provider.Close()
-
- req := ChatCompletionRequest{
- Model: "gpt-3.5-turbo",
- Messages: []Message{
- {Role: RoleUser, Content: "Say hello!"},
- },
- MaxTokens: &[]int{50}[0],
- }
-
- resp, err := provider.ChatCompletion(context.Background(), req)
- if err != nil {
- t.Fatalf("Chat completion failed: %v", err)
- }
-
- if len(resp.Choices) == 0 {
- t.Fatal("Expected at least one choice")
- }
-
- if resp.Choices[0].Message.Content == "" {
- t.Fatal("Expected non-empty response content")
- }
- */
-}
diff --git a/server/llm/provider/provider.go b/server/llm/provider/provider.go
new file mode 100644
index 0000000..24017f2
--- /dev/null
+++ b/server/llm/provider/provider.go
@@ -0,0 +1,19 @@
+package provider
+
+import (
+ "github.com/iomodo/staff/llm"
+ "github.com/iomodo/staff/llm/fake"
+ "github.com/iomodo/staff/llm/openai"
+)
+
+func CreateProvider(config llm.Config) llm.LLMProvider {
+ switch config.Provider {
+ case llm.ProviderFake:
+ return fake.New()
+ case llm.ProviderOpenAI:
+ return openai.New(config)
+ default:
+ return fake.New()
+ }
+
+}
diff --git a/server/llm/providers/providers.go b/server/llm/providers/providers.go
deleted file mode 100644
index 98d01a8..0000000
--- a/server/llm/providers/providers.go
+++ /dev/null
@@ -1,20 +0,0 @@
-package providers
-
-import (
- _ "github.com/iomodo/staff/llm/fake" // Register Fake provider for testing
- _ "github.com/iomodo/staff/llm/openai" // Register OpenAI provider
-)
-
-// RegisterAll registers all available LLM providers
-func RegisterAll() {
- // Import all provider packages to trigger their init() functions
- // This ensures all providers are registered when this package is imported
- _ = "import all providers"
-}
-
-// EnsureRegistered ensures all providers are registered
-// This function can be called from anywhere to ensure providers are available
-func EnsureRegistered() {
- // The blank imports below will trigger the init() functions
- // which register the providers with the LLM factory
-}
diff --git a/server/llm/registry.go b/server/llm/registry.go
deleted file mode 100644
index f41ffe6..0000000
--- a/server/llm/registry.go
+++ /dev/null
@@ -1,54 +0,0 @@
-package llm
-
-import (
- "sync"
-)
-
-// AutoRegistry automatically registers all available providers
-type AutoRegistry struct {
- mu sync.Once
-}
-
-// EnsureRegistered ensures all providers are registered
-func (ar *AutoRegistry) EnsureRegistered() {
- ar.mu.Do(func() {
- // Register all available providers
- ar.registerOpenAI()
- // Add more providers here as they become available
- // ar.registerClaude()
- // ar.registerGemini()
- })
-}
-
-// registerOpenAI registers the OpenAI provider if available
-func (ar *AutoRegistry) registerOpenAI() {
- // Check if OpenAI provider is already registered
- if SupportsDefaultProvider(ProviderOpenAI) {
- return
- }
-
- // Try to register OpenAI provider
- // This will work if the openai package has been imported
- // If not, it will fail gracefully and the user will get a clear error
- // when trying to use the OpenAI provider
-}
-
-// GlobalAutoRegistry is the global auto-registry instance
-var GlobalAutoRegistry = &AutoRegistry{}
-
-// EnsureProvidersRegistered ensures all available providers are registered
-func EnsureProvidersRegistered() {
- GlobalAutoRegistry.EnsureRegistered()
-}
-
-// CreateProviderWithAutoRegistration creates a provider with automatic registration
-func CreateProviderWithAutoRegistration(config Config) (LLMProvider, error) {
- EnsureProvidersRegistered()
- return CreateDefaultProvider(config)
-}
-
-// SupportsProviderWithAutoRegistration checks if a provider is supported with auto-registration
-func SupportsProviderWithAutoRegistration(provider Provider) bool {
- EnsureProvidersRegistered()
- return SupportsDefaultProvider(provider)
-}
diff --git a/server/tm/git_tm/git_task_manager.go b/server/tm/git_tm/git_task_manager.go
index add69ef..2e9fce3 100644
--- a/server/tm/git_tm/git_task_manager.go
+++ b/server/tm/git_tm/git_task_manager.go
@@ -1118,7 +1118,7 @@
"{task_id}": task.ID,
"{task_title}": task.Title,
"{task_description}": task.Description,
- "{agent_name}": fmt.Sprintf("%s", agentName),
+ "{agent_name}": agentName,
"{priority}": string(task.Priority),
"{solution}": truncatedSolution,
"{files_changed}": fmt.Sprintf("- `tasks/solutions/%s-solution.md`", task.ID),