blob: e425eeeb6421a244c22fadc2a71ebe85cff38efd [file] [log] [blame]
package llm
import (
"fmt"
"sync"
)
// GlobalProviderFactory is the main factory for creating LLM providers.
//
// It keeps one ProviderFactory per Provider. The map is guarded by mu, and
// every method on this type acquires mu before reading or writing it.
type GlobalProviderFactory struct {
	mu        sync.RWMutex                 // guards providers
	providers map[Provider]ProviderFactory // registered factories, keyed by provider
}
// NewGlobalProviderFactory returns an empty GlobalProviderFactory whose
// internal map is already allocated, so RegisterProvider can be called on it
// immediately.
func NewGlobalProviderFactory() *GlobalProviderFactory {
	gf := new(GlobalProviderFactory)
	gf.providers = make(map[Provider]ProviderFactory)
	return gf
}
// RegisterProvider associates factory with provider, replacing any factory
// previously registered for the same provider.
//
// It returns an error if IsValidProvider rejects provider, or if factory is
// nil. A nil factory is refused here because CreateProvider invokes
// factory.CreateProvider on the stored value and would otherwise panic.
func (f *GlobalProviderFactory) RegisterProvider(provider Provider, factory ProviderFactory) error {
	// Argument validation does not read f.providers, so it runs before the lock
	// is taken to keep the critical section minimal.
	if !IsValidProvider(provider) {
		return fmt.Errorf("unsupported provider: %s", provider)
	}
	if factory == nil {
		return fmt.Errorf("nil factory for provider: %s", provider)
	}

	f.mu.Lock()
	defer f.mu.Unlock()
	f.providers[provider] = factory
	return nil
}
// CreateProvider builds an LLMProvider for config.Provider using the factory
// registered for it.
//
// It returns an error if no factory is registered for config.Provider, or if
// ValidateConfig rejects config (wrapped with "invalid config"). On success the
// factory receives MergeConfig(config), not the original config.
func (f *GlobalProviderFactory) CreateProvider(config Config) (LLMProvider, error) {
	// Hold the read lock only for the lookup; the factory call below runs
	// unlocked so it cannot block concurrent registrations.
	f.mu.RLock()
	factory, registered := f.providers[config.Provider]
	f.mu.RUnlock()

	if !registered {
		return nil, fmt.Errorf("no factory registered for provider: %s", config.Provider)
	}

	validationErr := ValidateConfig(config)
	if validationErr != nil {
		return nil, fmt.Errorf("invalid config: %w", validationErr)
	}
	return factory.CreateProvider(MergeConfig(config))
}
// SupportsProvider reports whether a factory is currently registered for
// provider.
func (f *GlobalProviderFactory) SupportsProvider(provider Provider) bool {
	f.mu.RLock()
	_, found := f.providers[provider]
	f.mu.RUnlock()
	return found
}
// ListSupportedProviders returns a freshly allocated snapshot of every provider
// that has a registered factory. The result is never nil; it is empty when
// nothing is registered. Order is unspecified (it follows Go map iteration
// order) and may differ between calls.
func (f *GlobalProviderFactory) ListSupportedProviders() []Provider {
	f.mu.RLock()
	defer f.mu.RUnlock()

	snapshot := make([]Provider, len(f.providers))
	idx := 0
	for p := range f.providers {
		snapshot[idx] = p
		idx++
	}
	return snapshot
}
// UnregisterProvider drops the factory registered for provider. Removing a
// provider that was never registered is a no-op.
func (f *GlobalProviderFactory) UnregisterProvider(provider Provider) {
	f.mu.Lock()
	delete(f.providers, provider)
	f.mu.Unlock()
}
var (
	// DefaultFactory is the package-level GlobalProviderFactory that backs
	// RegisterDefaultProvider, CreateDefaultProvider, SupportsDefaultProvider
	// and ListDefaultSupportedProviders.
	DefaultFactory = NewGlobalProviderFactory()
)
// RegisterDefaultProvider registers factory for provider on DefaultFactory.
// Its errors are exactly those of (*GlobalProviderFactory).RegisterProvider.
func RegisterDefaultProvider(provider Provider, factory ProviderFactory) error {
	target := DefaultFactory
	return target.RegisterProvider(provider, factory)
}
// CreateDefaultProvider builds an LLMProvider from config via DefaultFactory.
// See (*GlobalProviderFactory).CreateProvider for the failure cases.
func CreateDefaultProvider(config Config) (LLMProvider, error) {
	llm, err := DefaultFactory.CreateProvider(config)
	return llm, err
}
// SupportsDefaultProvider reports whether DefaultFactory has a factory
// registered for provider.
func SupportsDefaultProvider(provider Provider) bool {
	supported := DefaultFactory.SupportsProvider(provider)
	return supported
}
// ListDefaultSupportedProviders returns a snapshot of the providers registered
// on DefaultFactory, in unspecified order.
func ListDefaultSupportedProviders() []Provider {
	target := DefaultFactory
	return target.ListSupportedProviders()
}
// ProviderRegistry provides a simple way to register and manage providers.
//
// It keeps one ProviderFactory per Provider. The map is guarded by mu, and
// every method that reads or writes factories acquires mu (directly or via
// Get).
type ProviderRegistry struct {
	mu        sync.RWMutex                 // guards factories
	factories map[Provider]ProviderFactory // registered factories, keyed by provider
}
// NewProviderRegistry returns an empty ProviderRegistry whose internal map is
// already allocated, so Register can be called on it immediately.
func NewProviderRegistry() *ProviderRegistry {
	var reg ProviderRegistry
	reg.factories = make(map[Provider]ProviderFactory)
	return &reg
}
// Register associates factory with provider, replacing any factory previously
// registered for the same provider.
//
// It returns an error if IsValidProvider rejects provider, or if factory is
// nil. A nil factory is refused here because Create invokes
// factory.CreateProvider on the stored value and would otherwise panic.
func (r *ProviderRegistry) Register(provider Provider, factory ProviderFactory) error {
	// Argument validation does not read r.factories, so it runs before the lock
	// is taken to keep the critical section minimal.
	if !IsValidProvider(provider) {
		return fmt.Errorf("unsupported provider: %s", provider)
	}
	if factory == nil {
		return fmt.Errorf("nil factory for provider: %s", provider)
	}

	r.mu.Lock()
	defer r.mu.Unlock()
	r.factories[provider] = factory
	return nil
}
// Get returns the factory registered for provider and whether one was found.
// When the second result is false, the returned factory is the zero value.
func (r *ProviderRegistry) Get(provider Provider) (ProviderFactory, bool) {
	r.mu.RLock()
	pf, found := r.factories[provider]
	r.mu.RUnlock()
	return pf, found
}
// Create builds an LLMProvider for config.Provider using the factory
// registered in r.
//
// It returns an error if no factory is registered for config.Provider, or if
// ValidateConfig rejects config (wrapped with "invalid config"). On success the
// factory receives MergeConfig(config), not the original config.
func (r *ProviderRegistry) Create(config Config) (LLMProvider, error) {
	pf, found := r.Get(config.Provider)
	if !found {
		return nil, fmt.Errorf("no factory registered for provider: %s", config.Provider)
	}

	if verr := ValidateConfig(config); verr != nil {
		return nil, fmt.Errorf("invalid config: %w", verr)
	}

	merged := MergeConfig(config)
	return pf.CreateProvider(merged)
}
// List returns a freshly allocated snapshot of every registered provider. The
// result is never nil; it is empty when nothing is registered. Order is
// unspecified (it follows Go map iteration order) and may differ between calls.
func (r *ProviderRegistry) List() []Provider {
	r.mu.RLock()
	defer r.mu.RUnlock()

	snapshot := make([]Provider, len(r.factories))
	idx := 0
	for p := range r.factories {
		snapshot[idx] = p
		idx++
	}
	return snapshot
}
// Unregister drops the factory registered for provider. Removing a provider
// that was never registered is a no-op.
func (r *ProviderRegistry) Unregister(provider Provider) {
	r.mu.Lock()
	delete(r.factories, provider)
	r.mu.Unlock()
}
var (
	// DefaultRegistry is the package-level ProviderRegistry that backs
	// RegisterProvider, CreateProvider, GetProviderFactory, ListProviders and
	// UnregisterProvider.
	//
	// It is a separate instance from DefaultFactory: registrations made through
	// one are not visible through the other.
	DefaultRegistry = NewProviderRegistry()
)
// RegisterProvider registers factory for provider on DefaultRegistry.
// Its errors are exactly those of (*ProviderRegistry).Register.
func RegisterProvider(provider Provider, factory ProviderFactory) error {
	reg := DefaultRegistry
	return reg.Register(provider, factory)
}
// CreateProvider builds an LLMProvider from config via DefaultRegistry.
// See (*ProviderRegistry).Create for the failure cases.
func CreateProvider(config Config) (LLMProvider, error) {
	llm, err := DefaultRegistry.Create(config)
	return llm, err
}
// GetProviderFactory looks up the factory registered for provider on
// DefaultRegistry and reports whether one was found.
func GetProviderFactory(provider Provider) (ProviderFactory, bool) {
	pf, found := DefaultRegistry.Get(provider)
	return pf, found
}
// ListProviders returns a snapshot of the providers registered on
// DefaultRegistry, in unspecified order.
func ListProviders() []Provider {
	reg := DefaultRegistry
	return reg.List()
}
// UnregisterProvider drops provider's factory from DefaultRegistry; it is a
// no-op if provider was not registered.
func UnregisterProvider(provider Provider) {
	reg := DefaultRegistry
	reg.Unregister(provider)
}