blob: e425eeeb6421a244c22fadc2a71ebe85cff38efd [file] [log] [blame]
iomodoa97eb222025-07-26 11:18:17 +04001package llm
2
3import (
4 "fmt"
5 "sync"
6)
7
8// GlobalProviderFactory is the main factory for creating LLM providers
9type GlobalProviderFactory struct {
10 providers map[Provider]ProviderFactory
11 mu sync.RWMutex
12}
13
14// NewGlobalProviderFactory creates a new global provider factory
15func NewGlobalProviderFactory() *GlobalProviderFactory {
16 return &GlobalProviderFactory{
17 providers: make(map[Provider]ProviderFactory),
18 }
19}
20
21// RegisterProvider registers a provider factory for a specific provider type
22func (f *GlobalProviderFactory) RegisterProvider(provider Provider, factory ProviderFactory) error {
23 f.mu.Lock()
24 defer f.mu.Unlock()
25
26 if !IsValidProvider(provider) {
27 return fmt.Errorf("unsupported provider: %s", provider)
28 }
29
30 f.providers[provider] = factory
31 return nil
32}
33
34// CreateProvider creates a new LLM provider instance
35func (f *GlobalProviderFactory) CreateProvider(config Config) (LLMProvider, error) {
36 f.mu.RLock()
37 factory, exists := f.providers[config.Provider]
38 f.mu.RUnlock()
39
40 if !exists {
41 return nil, fmt.Errorf("no factory registered for provider: %s", config.Provider)
42 }
43
44 // Validate and merge config
45 if err := ValidateConfig(config); err != nil {
46 return nil, fmt.Errorf("invalid config: %w", err)
47 }
48
49 config = MergeConfig(config)
50
51 return factory.CreateProvider(config)
52}
53
54// SupportsProvider checks if the factory supports the given provider
55func (f *GlobalProviderFactory) SupportsProvider(provider Provider) bool {
56 f.mu.RLock()
57 defer f.mu.RUnlock()
58
59 _, exists := f.providers[provider]
60 return exists
61}
62
63// ListSupportedProviders returns a list of supported providers
64func (f *GlobalProviderFactory) ListSupportedProviders() []Provider {
65 f.mu.RLock()
66 defer f.mu.RUnlock()
67
68 providers := make([]Provider, 0, len(f.providers))
69 for provider := range f.providers {
70 providers = append(providers, provider)
71 }
72
73 return providers
74}
75
76// UnregisterProvider removes a provider factory
77func (f *GlobalProviderFactory) UnregisterProvider(provider Provider) {
78 f.mu.Lock()
79 defer f.mu.Unlock()
80
81 delete(f.providers, provider)
82}
83
84// DefaultFactory is the default global factory instance
85var DefaultFactory = NewGlobalProviderFactory()
86
87// RegisterDefaultProvider registers a provider with the default factory
88func RegisterDefaultProvider(provider Provider, factory ProviderFactory) error {
89 return DefaultFactory.RegisterProvider(provider, factory)
90}
91
92// CreateDefaultProvider creates a provider using the default factory
93func CreateDefaultProvider(config Config) (LLMProvider, error) {
94 return DefaultFactory.CreateProvider(config)
95}
96
97// SupportsDefaultProvider checks if the default factory supports a provider
98func SupportsDefaultProvider(provider Provider) bool {
99 return DefaultFactory.SupportsProvider(provider)
100}
101
102// ListDefaultSupportedProviders returns providers supported by the default factory
103func ListDefaultSupportedProviders() []Provider {
104 return DefaultFactory.ListSupportedProviders()
105}
106
107// ProviderRegistry provides a simple way to register and manage providers
108type ProviderRegistry struct {
109 factories map[Provider]ProviderFactory
110 mu sync.RWMutex
111}
112
113// NewProviderRegistry creates a new provider registry
114func NewProviderRegistry() *ProviderRegistry {
115 return &ProviderRegistry{
116 factories: make(map[Provider]ProviderFactory),
117 }
118}
119
120// Register registers a provider factory
121func (r *ProviderRegistry) Register(provider Provider, factory ProviderFactory) error {
122 r.mu.Lock()
123 defer r.mu.Unlock()
124
125 if !IsValidProvider(provider) {
126 return fmt.Errorf("unsupported provider: %s", provider)
127 }
128
129 r.factories[provider] = factory
130 return nil
131}
132
133// Get retrieves a provider factory
134func (r *ProviderRegistry) Get(provider Provider) (ProviderFactory, bool) {
135 r.mu.RLock()
136 defer r.mu.RUnlock()
137
138 factory, exists := r.factories[provider]
139 return factory, exists
140}
141
142// Create creates a new LLM provider instance
143func (r *ProviderRegistry) Create(config Config) (LLMProvider, error) {
144 factory, exists := r.Get(config.Provider)
145 if !exists {
146 return nil, fmt.Errorf("no factory registered for provider: %s", config.Provider)
147 }
148
149 // Validate and merge config
150 if err := ValidateConfig(config); err != nil {
151 return nil, fmt.Errorf("invalid config: %w", err)
152 }
153
154 config = MergeConfig(config)
155
156 return factory.CreateProvider(config)
157}
158
159// List returns all registered providers
160func (r *ProviderRegistry) List() []Provider {
161 r.mu.RLock()
162 defer r.mu.RUnlock()
163
164 providers := make([]Provider, 0, len(r.factories))
165 for provider := range r.factories {
166 providers = append(providers, provider)
167 }
168
169 return providers
170}
171
172// Unregister removes a provider factory
173func (r *ProviderRegistry) Unregister(provider Provider) {
174 r.mu.Lock()
175 defer r.mu.Unlock()
176
177 delete(r.factories, provider)
178}
179
180// DefaultRegistry is the default provider registry
181var DefaultRegistry = NewProviderRegistry()
182
183// RegisterProvider registers a provider with the default registry
184func RegisterProvider(provider Provider, factory ProviderFactory) error {
185 return DefaultRegistry.Register(provider, factory)
186}
187
188// CreateProvider creates a provider using the default registry
189func CreateProvider(config Config) (LLMProvider, error) {
190 return DefaultRegistry.Create(config)
191}
192
193// GetProviderFactory gets a provider factory from the default registry
194func GetProviderFactory(provider Provider) (ProviderFactory, bool) {
195 return DefaultRegistry.Get(provider)
196}
197
198// ListProviders returns all providers registered with the default registry
199func ListProviders() []Provider {
200 return DefaultRegistry.List()
201}
202
203// UnregisterProvider removes a provider from the default registry
204func UnregisterProvider(provider Provider) {
205 DefaultRegistry.Unregister(provider)
206}