Add LLM interface
Change-Id: Idf599500fc131fb9509102e38736a6baeff6d6d8
diff --git a/server/llm/factory.go b/server/llm/factory.go
new file mode 100644
index 0000000..e425eee
--- /dev/null
+++ b/server/llm/factory.go
@@ -0,0 +1,206 @@
+package llm
+
+import (
+ "fmt"
+ "sync"
+)
+
+// GlobalProviderFactory maps each Provider to the ProviderFactory that builds
+// LLM provider instances for it. Every method in this file takes mu around
+// its access to the map. Because the struct embeds a mutex by value it must
+// not be copied after first use; pass it around as a pointer.
+type GlobalProviderFactory struct {
+	// mu guards providers.
+	mu sync.RWMutex
+	// providers holds the factory registered for each provider.
+	providers map[Provider]ProviderFactory
+}
+
+// NewGlobalProviderFactory returns a GlobalProviderFactory with an empty,
+// ready-to-use provider map. No providers are registered on the result.
+func NewGlobalProviderFactory() *GlobalProviderFactory {
+	f := new(GlobalProviderFactory)
+	f.providers = make(map[Provider]ProviderFactory)
+	return f
+}
+
+// RegisterProvider associates factory with provider so that later
+// CreateProvider calls for that provider are routed to it. Registering a
+// provider that already has a factory replaces the previous one.
+//
+// It returns an error if provider is rejected by IsValidProvider or if
+// factory is nil; in both cases the registered set is left unchanged.
+func (f *GlobalProviderFactory) RegisterProvider(provider Provider, factory ProviderFactory) error {
+	// Validate before taking the write lock: neither check reads f.providers.
+	if !IsValidProvider(provider) {
+		return fmt.Errorf("unsupported provider: %s", provider)
+	}
+	// A nil factory would be stored without complaint and only surface later
+	// as a panic when CreateProvider calls a method on it, so reject it here.
+	// Note that a typed nil pointer wrapped in the interface is not caught.
+	if factory == nil {
+		return fmt.Errorf("nil factory for provider: %s", provider)
+	}
+
+	f.mu.Lock()
+	defer f.mu.Unlock()
+
+	// Writing to a nil map panics; allocate lazily so a zero-value
+	// GlobalProviderFactory (one not built by NewGlobalProviderFactory) works.
+	if f.providers == nil {
+		f.providers = make(map[Provider]ProviderFactory)
+	}
+	f.providers[provider] = factory
+	return nil
+}
+
+// CreateProvider builds an LLM provider for config using the factory
+// registered under config.Provider.
+//
+// The factory is looked up first, so an unregistered provider is reported
+// before the config is validated. A config rejected by ValidateConfig yields
+// an error wrapping the validation error. Otherwise the config is passed
+// through MergeConfig and handed to the factory, whose result is returned
+// as-is.
+func (f *GlobalProviderFactory) CreateProvider(config Config) (LLMProvider, error) {
+	// Hold the read lock only for the map lookup; validation and provider
+	// construction below run without it.
+	f.mu.RLock()
+	builder, registered := f.providers[config.Provider]
+	f.mu.RUnlock()
+	if !registered {
+		return nil, fmt.Errorf("no factory registered for provider: %s", config.Provider)
+	}
+
+	err := ValidateConfig(config)
+	if err != nil {
+		return nil, fmt.Errorf("invalid config: %w", err)
+	}
+
+	return builder.CreateProvider(MergeConfig(config))
+}
+
+// SupportsProvider reports whether a factory is currently registered for
+// provider. It only consults the registration map; it does not call
+// IsValidProvider.
+func (f *GlobalProviderFactory) SupportsProvider(provider Provider) bool {
+	f.mu.RLock()
+	_, registered := f.providers[provider]
+	f.mu.RUnlock()
+
+	return registered
+}
+
+// ListSupportedProviders returns the providers that currently have a factory
+// registered. The result is a fresh, non-nil slice owned by the caller; its
+// order follows Go's map iteration order and is therefore unspecified.
+func (f *GlobalProviderFactory) ListSupportedProviders() []Provider {
+	f.mu.RLock()
+	defer f.mu.RUnlock()
+
+	// The final length is known up front, so fill by index.
+	names := make([]Provider, len(f.providers))
+	i := 0
+	for p := range f.providers {
+		names[i] = p
+		i++
+	}
+	return names
+}
+
+// UnregisterProvider drops the factory registered for provider, if any.
+// Removing a provider that is not registered is a no-op.
+func (f *GlobalProviderFactory) UnregisterProvider(provider Provider) {
+	f.mu.Lock()
+	delete(f.providers, provider)
+	f.mu.Unlock()
+}
+
+// DefaultFactory is the package-level GlobalProviderFactory used by
+// RegisterDefaultProvider, CreateDefaultProvider, SupportsDefaultProvider and
+// ListDefaultSupportedProviders. It is created empty and is a separate
+// instance from DefaultRegistry, so registrations made on one are not
+// visible through the other.
+var DefaultFactory = NewGlobalProviderFactory()
+
+// RegisterDefaultProvider registers factory for provider on DefaultFactory.
+// It is a convenience wrapper around DefaultFactory.RegisterProvider and
+// returns that method's error unchanged.
+func RegisterDefaultProvider(provider Provider, factory ProviderFactory) error {
+ return DefaultFactory.RegisterProvider(provider, factory)
+}
+
+// CreateDefaultProvider builds an LLM provider for config via DefaultFactory.
+// It is a convenience wrapper around DefaultFactory.CreateProvider and
+// returns that method's results unchanged.
+func CreateDefaultProvider(config Config) (LLMProvider, error) {
+ return DefaultFactory.CreateProvider(config)
+}
+
+// SupportsDefaultProvider reports whether DefaultFactory currently has a
+// factory registered for provider. It is a convenience wrapper around
+// DefaultFactory.SupportsProvider.
+func SupportsDefaultProvider(provider Provider) bool {
+ return DefaultFactory.SupportsProvider(provider)
+}
+
+// ListDefaultSupportedProviders returns the providers currently registered
+// on DefaultFactory. It is a convenience wrapper around
+// DefaultFactory.ListSupportedProviders; the result is built by ranging over
+// a map, so its order is unspecified.
+func ListDefaultSupportedProviders() []Provider {
+ return DefaultFactory.ListSupportedProviders()
+}
+
+// ProviderRegistry keeps a ProviderFactory per Provider and can look one up,
+// list them, or use one to build an LLM provider. Every method in this file
+// takes mu around its access to the map. Because the struct embeds a mutex
+// by value it must not be copied after first use; pass it around as a
+// pointer.
+type ProviderRegistry struct {
+	// mu guards factories.
+	mu sync.RWMutex
+	// factories holds the factory registered for each provider.
+	factories map[Provider]ProviderFactory
+}
+
+// NewProviderRegistry returns a ProviderRegistry with an empty, ready-to-use
+// factory map. No providers are registered on the result.
+func NewProviderRegistry() *ProviderRegistry {
+	r := new(ProviderRegistry)
+	r.factories = make(map[Provider]ProviderFactory)
+	return r
+}
+
+// Register associates factory with provider so that later Get and Create
+// calls for that provider resolve to it. Registering a provider that already
+// has a factory replaces the previous one.
+//
+// It returns an error if provider is rejected by IsValidProvider or if
+// factory is nil; in both cases the registered set is left unchanged.
+func (r *ProviderRegistry) Register(provider Provider, factory ProviderFactory) error {
+	// Validate before taking the write lock: neither check reads r.factories.
+	if !IsValidProvider(provider) {
+		return fmt.Errorf("unsupported provider: %s", provider)
+	}
+	// A nil factory would be stored without complaint and only surface later
+	// as a panic when Create calls a method on it, so reject it here. Note
+	// that a typed nil pointer wrapped in the interface is not caught.
+	if factory == nil {
+		return fmt.Errorf("nil factory for provider: %s", provider)
+	}
+
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	// Writing to a nil map panics; allocate lazily so a zero-value
+	// ProviderRegistry (one not built by NewProviderRegistry) works.
+	if r.factories == nil {
+		r.factories = make(map[Provider]ProviderFactory)
+	}
+	r.factories[provider] = factory
+	return nil
+}
+
+// Get returns the factory registered for provider. The boolean is false,
+// and the factory is the zero value, when nothing is registered for it.
+func (r *ProviderRegistry) Get(provider Provider) (ProviderFactory, bool) {
+	r.mu.RLock()
+	registered, ok := r.factories[provider]
+	r.mu.RUnlock()
+
+	return registered, ok
+}
+
+// Create builds an LLM provider for config using the factory registered
+// under config.Provider.
+//
+// The factory is looked up first, so an unregistered provider is reported
+// before the config is validated. A config rejected by ValidateConfig yields
+// an error wrapping the validation error. Otherwise the config is passed
+// through MergeConfig and handed to the factory, whose result is returned
+// as-is.
+func (r *ProviderRegistry) Create(config Config) (LLMProvider, error) {
+	builder, registered := r.Get(config.Provider)
+	if !registered {
+		return nil, fmt.Errorf("no factory registered for provider: %s", config.Provider)
+	}
+
+	err := ValidateConfig(config)
+	if err != nil {
+		return nil, fmt.Errorf("invalid config: %w", err)
+	}
+
+	return builder.CreateProvider(MergeConfig(config))
+}
+
+// List returns the providers that currently have a factory registered. The
+// result is a fresh, non-nil slice owned by the caller; its order follows
+// Go's map iteration order and is therefore unspecified.
+func (r *ProviderRegistry) List() []Provider {
+	r.mu.RLock()
+	defer r.mu.RUnlock()
+
+	// The final length is known up front, so fill by index.
+	names := make([]Provider, len(r.factories))
+	i := 0
+	for p := range r.factories {
+		names[i] = p
+		i++
+	}
+	return names
+}
+
+// Unregister drops the factory registered for provider, if any. Removing a
+// provider that is not registered is a no-op.
+func (r *ProviderRegistry) Unregister(provider Provider) {
+	r.mu.Lock()
+	delete(r.factories, provider)
+	r.mu.Unlock()
+}
+
+// DefaultRegistry is the package-level ProviderRegistry used by the
+// package-level RegisterProvider, CreateProvider, GetProviderFactory,
+// ListProviders and UnregisterProvider functions. It is created empty and is
+// a separate instance from DefaultFactory, so registrations made on one are
+// not visible through the other.
+var DefaultRegistry = NewProviderRegistry()
+
+// RegisterProvider registers factory for provider on DefaultRegistry. It is
+// a convenience wrapper around DefaultRegistry.Register and returns that
+// method's error unchanged.
+func RegisterProvider(provider Provider, factory ProviderFactory) error {
+ return DefaultRegistry.Register(provider, factory)
+}
+
+// CreateProvider builds an LLM provider for config via DefaultRegistry. It
+// is a convenience wrapper around DefaultRegistry.Create and returns that
+// method's results unchanged.
+func CreateProvider(config Config) (LLMProvider, error) {
+ return DefaultRegistry.Create(config)
+}
+
+// GetProviderFactory returns the factory registered for provider on
+// DefaultRegistry. The boolean is false when no factory is registered. It is
+// a convenience wrapper around DefaultRegistry.Get.
+func GetProviderFactory(provider Provider) (ProviderFactory, bool) {
+ return DefaultRegistry.Get(provider)
+}
+
+// ListProviders returns the providers currently registered on
+// DefaultRegistry. It is a convenience wrapper around DefaultRegistry.List;
+// the result is built by ranging over a map, so its order is unspecified.
+func ListProviders() []Provider {
+ return DefaultRegistry.List()
+}
+
+// UnregisterProvider removes the factory registered for provider on
+// DefaultRegistry. It is a convenience wrapper around
+// DefaultRegistry.Unregister; removing an unregistered provider is a no-op.
+func UnregisterProvider(provider Provider) {
+ DefaultRegistry.Unregister(provider)
+}