blob: 88964055a68d37879262bdc1db5665f099c3939d [file] [log] [blame]
iomodobe473d12025-07-26 11:33:08 +04001package openai
2
3import (
4 "testing"
5 "time"
6
7 "github.com/iomodo/staff/llm"
8)
9
10func TestOpenAIProvider_Interface(t *testing.T) {
11 // Test that OpenAIProvider implements LLMProvider interface
12 var _ llm.LLMProvider = (*OpenAIProvider)(nil)
13}
14
15func TestOpenAIFactory_CreateProvider(t *testing.T) {
16 factory := &OpenAIFactory{}
17
18 // Test valid config
19 config := llm.Config{
20 Provider: llm.ProviderOpenAI,
21 APIKey: "test-key",
22 BaseURL: "https://api.openai.com/v1",
23 Timeout: 30 * time.Second,
24 }
25
26 provider, err := factory.CreateProvider(config)
27 if err != nil {
28 t.Fatalf("Failed to create provider: %v", err)
29 }
30
31 if provider == nil {
32 t.Fatal("Provider should not be nil")
33 }
34
35 // Test invalid provider
36 invalidConfig := llm.Config{
37 Provider: llm.ProviderClaude,
38 APIKey: "test-key",
39 }
40
41 _, err = factory.CreateProvider(invalidConfig)
42 if err == nil {
43 t.Fatal("Should fail with invalid provider")
44 }
45
46 // Test missing API key
47 noKeyConfig := llm.Config{
48 Provider: llm.ProviderOpenAI,
49 BaseURL: "https://api.openai.com/v1",
50 }
51
52 _, err = factory.CreateProvider(noKeyConfig)
53 if err == nil {
54 t.Fatal("Should fail with missing API key")
55 }
56}
57
58func TestOpenAIFactory_SupportsProvider(t *testing.T) {
59 factory := &OpenAIFactory{}
60
61 if !factory.SupportsProvider(llm.ProviderOpenAI) {
62 t.Fatal("Should support OpenAI provider")
63 }
64
65 if factory.SupportsProvider(llm.ProviderClaude) {
66 t.Fatal("Should not support Claude provider")
67 }
68}
69
70func TestOpenAIProvider_ConvertRequest(t *testing.T) {
71 provider := &OpenAIProvider{
72 config: llm.Config{
73 Provider: llm.ProviderOpenAI,
74 APIKey: "test-key",
75 BaseURL: "https://api.openai.com/v1",
76 },
77 }
78
79 // Test basic request conversion
80 req := llm.ChatCompletionRequest{
81 Model: "gpt-3.5-turbo",
82 Messages: []llm.Message{
83 {Role: llm.RoleUser, Content: "Hello"},
84 },
85 MaxTokens: &[]int{100}[0],
86 Temperature: &[]float64{0.7}[0],
87 }
88
89 openAIReq := provider.convertToOpenAIRequest(req)
90
91 if openAIReq.Model != "gpt-3.5-turbo" {
92 t.Errorf("Expected model 'gpt-3.5-turbo', got '%s'", openAIReq.Model)
93 }
94
95 if len(openAIReq.Messages) != 1 {
96 t.Errorf("Expected 1 message, got %d", len(openAIReq.Messages))
97 }
98
99 if openAIReq.Messages[0].Role != "user" {
100 t.Errorf("Expected role 'user', got '%s'", openAIReq.Messages[0].Role)
101 }
102
103 if openAIReq.Messages[0].Content != "Hello" {
104 t.Errorf("Expected content 'Hello', got '%s'", openAIReq.Messages[0].Content)
105 }
106
107 if *openAIReq.MaxTokens != 100 {
108 t.Errorf("Expected max_tokens 100, got %d", *openAIReq.MaxTokens)
109 }
110
111 if *openAIReq.Temperature != 0.7 {
112 t.Errorf("Expected temperature 0.7, got %f", *openAIReq.Temperature)
113 }
114}
115
116func TestOpenAIProvider_ConvertResponse(t *testing.T) {
117 provider := &OpenAIProvider{
118 config: llm.Config{
119 Provider: llm.ProviderOpenAI,
120 APIKey: "test-key",
121 BaseURL: "https://api.openai.com/v1",
122 },
123 }
124
125 // Test basic response conversion
126 openAIResp := OpenAIResponse{
127 ID: "test-id",
128 Object: "chat.completion",
129 Created: 1234567890,
130 Model: "gpt-3.5-turbo",
131 Choices: []OpenAIChoice{
132 {
133 Index: 0,
134 Message: OpenAIMessage{
135 Role: "assistant",
136 Content: "Hello! How can I help you?",
137 },
138 FinishReason: "stop",
139 },
140 },
141 Usage: OpenAIUsage{
142 PromptTokens: 10,
143 CompletionTokens: 20,
144 TotalTokens: 30,
145 },
146 }
147
148 resp := provider.convertFromOpenAIResponse(openAIResp)
149
150 if resp.ID != "test-id" {
151 t.Errorf("Expected ID 'test-id', got '%s'", resp.ID)
152 }
153
154 if resp.Model != "gpt-3.5-turbo" {
155 t.Errorf("Expected model 'gpt-3.5-turbo', got '%s'", resp.Model)
156 }
157
158 if resp.Provider != llm.ProviderOpenAI {
159 t.Errorf("Expected provider OpenAI, got %s", resp.Provider)
160 }
161
162 if len(resp.Choices) != 1 {
163 t.Errorf("Expected 1 choice, got %d", len(resp.Choices))
164 }
165
166 if resp.Choices[0].Message.Role != llm.RoleAssistant {
167 t.Errorf("Expected role assistant, got %s", resp.Choices[0].Message.Role)
168 }
169
170 if resp.Choices[0].Message.Content != "Hello! How can I help you?" {
171 t.Errorf("Expected content 'Hello! How can I help you?', got '%s'", resp.Choices[0].Message.Content)
172 }
173
174 if resp.Usage.PromptTokens != 10 {
175 t.Errorf("Expected prompt tokens 10, got %d", resp.Usage.PromptTokens)
176 }
177
178 if resp.Usage.CompletionTokens != 20 {
179 t.Errorf("Expected completion tokens 20, got %d", resp.Usage.CompletionTokens)
180 }
181
182 if resp.Usage.TotalTokens != 30 {
183 t.Errorf("Expected total tokens 30, got %d", resp.Usage.TotalTokens)
184 }
185}
186
187func TestOpenAIProvider_ConvertRequestWithTools(t *testing.T) {
188 provider := &OpenAIProvider{
189 config: llm.Config{
190 Provider: llm.ProviderOpenAI,
191 APIKey: "test-key",
192 BaseURL: "https://api.openai.com/v1",
193 },
194 }
195
196 // Test request with tools
197 tools := []llm.Tool{
198 {
199 Type: "function",
200 Function: llm.Function{
201 Name: "get_weather",
202 Description: "Get weather information",
203 Parameters: map[string]interface{}{
204 "type": "object",
205 "properties": map[string]interface{}{
206 "location": map[string]interface{}{
207 "type": "string",
208 },
209 },
210 },
211 },
212 },
213 }
214
215 req := llm.ChatCompletionRequest{
216 Model: "gpt-3.5-turbo",
217 Messages: []llm.Message{
218 {Role: llm.RoleUser, Content: "What's the weather like?"},
219 },
220 Tools: tools,
221 }
222
223 openAIReq := provider.convertToOpenAIRequest(req)
224
225 if len(openAIReq.Tools) != 1 {
226 t.Errorf("Expected 1 tool, got %d", len(openAIReq.Tools))
227 }
228
229 if openAIReq.Tools[0].Type != "function" {
230 t.Errorf("Expected tool type 'function', got '%s'", openAIReq.Tools[0].Type)
231 }
232
233 if openAIReq.Tools[0].Function.Name != "get_weather" {
234 t.Errorf("Expected function name 'get_weather', got '%s'", openAIReq.Tools[0].Function.Name)
235 }
236}
237
238func TestOpenAIProvider_ConvertResponseWithToolCalls(t *testing.T) {
239 provider := &OpenAIProvider{
240 config: llm.Config{
241 Provider: llm.ProviderOpenAI,
242 APIKey: "test-key",
243 BaseURL: "https://api.openai.com/v1",
244 },
245 }
246
247 // Test response with tool calls
248 openAIResp := OpenAIResponse{
249 ID: "test-id",
250 Object: "chat.completion",
251 Model: "gpt-3.5-turbo",
252 Choices: []OpenAIChoice{
253 {
254 Index: 0,
255 Message: OpenAIMessage{
256 Role: "assistant",
257 ToolCalls: []OpenAIToolCall{
258 {
259 ID: "call_123",
260 Type: "function",
261 Function: OpenAIFunction{
262 Name: "get_weather",
263 Parameters: map[string]interface{}{
264 "location": "Tokyo",
265 },
266 },
267 },
268 },
269 },
270 FinishReason: "tool_calls",
271 },
272 },
273 Usage: OpenAIUsage{
274 PromptTokens: 10,
275 CompletionTokens: 20,
276 TotalTokens: 30,
277 },
278 }
279
280 resp := provider.convertFromOpenAIResponse(openAIResp)
281
282 if len(resp.Choices[0].Message.ToolCalls) != 1 {
283 t.Errorf("Expected 1 tool call, got %d", len(resp.Choices[0].Message.ToolCalls))
284 }
285
286 if resp.Choices[0].Message.ToolCalls[0].ID != "call_123" {
287 t.Errorf("Expected tool call ID 'call_123', got '%s'", resp.Choices[0].Message.ToolCalls[0].ID)
288 }
289
290 if resp.Choices[0].Message.ToolCalls[0].Function.Name != "get_weather" {
291 t.Errorf("Expected function name 'get_weather', got '%s'", resp.Choices[0].Message.ToolCalls[0].Function.Name)
292 }
293
294 if resp.Choices[0].FinishReason != "tool_calls" {
295 t.Errorf("Expected finish reason 'tool_calls', got '%s'", resp.Choices[0].FinishReason)
296 }
297}
298
299func TestOpenAIProvider_ConvertEmbeddingRequest(t *testing.T) {
300 req := llm.EmbeddingRequest{
301 Input: "Hello, world!",
302 Model: "text-embedding-ada-002",
303 User: "test-user",
304 }
305
306 // The conversion is done inline in CreateEmbeddings, so we'll test the structure
307 if req.Input != "Hello, world!" {
308 t.Errorf("Expected input 'Hello, world!', got '%v'", req.Input)
309 }
310
311 if req.Model != "text-embedding-ada-002" {
312 t.Errorf("Expected model 'text-embedding-ada-002', got '%s'", req.Model)
313 }
314
315 if req.User != "test-user" {
316 t.Errorf("Expected user 'test-user', got '%s'", req.User)
317 }
318}
319
320func TestOpenAIProvider_ConvertEmbeddingResponse(t *testing.T) {
321 provider := &OpenAIProvider{
322 config: llm.Config{
323 Provider: llm.ProviderOpenAI,
324 APIKey: "test-key",
325 BaseURL: "https://api.openai.com/v1",
326 },
327 }
328
329 // Test embedding response conversion
330 openAIResp := OpenAIEmbeddingResponse{
331 Object: "list",
332 Model: "text-embedding-ada-002",
333 Data: []OpenAIEmbeddingData{
334 {
335 Object: "embedding",
336 Embedding: []float64{0.1, 0.2, 0.3},
337 Index: 0,
338 },
339 },
340 Usage: OpenAIUsage{
341 PromptTokens: 5,
342 CompletionTokens: 0,
343 TotalTokens: 5,
344 },
345 }
346
347 resp := provider.convertFromOpenAIEmbeddingResponse(openAIResp)
348
349 if resp.Object != "list" {
350 t.Errorf("Expected object 'list', got '%s'", resp.Object)
351 }
352
353 if resp.Model != "text-embedding-ada-002" {
354 t.Errorf("Expected model 'text-embedding-ada-002', got '%s'", resp.Model)
355 }
356
357 if resp.Provider != llm.ProviderOpenAI {
358 t.Errorf("Expected provider OpenAI, got %s", resp.Provider)
359 }
360
361 if len(resp.Data) != 1 {
362 t.Errorf("Expected 1 embedding, got %d", len(resp.Data))
363 }
364
365 if len(resp.Data[0].Embedding) != 3 {
366 t.Errorf("Expected embedding dimension 3, got %d", len(resp.Data[0].Embedding))
367 }
368
369 if resp.Data[0].Embedding[0] != 0.1 {
370 t.Errorf("Expected first embedding value 0.1, got %f", resp.Data[0].Embedding[0])
371 }
372}
373
374func TestOpenAIProvider_Close(t *testing.T) {
375 provider := &OpenAIProvider{
376 config: llm.Config{
377 Provider: llm.ProviderOpenAI,
378 APIKey: "test-key",
379 BaseURL: "https://api.openai.com/v1",
380 },
381 }
382
383 // Test that Close doesn't return an error
384 err := provider.Close()
385 if err != nil {
386 t.Errorf("Close should not return an error: %v", err)
387 }
388}
389
// TestOpenAIProvider_Integration is a placeholder for an end-to-end check
// against the live OpenAI API. It needs a real API key and network access, so
// it is explicitly skipped. Previously its body was empty, which made it
// report PASS without testing anything.
//
// To run it manually, replace the t.Skip call with the sketch below and supply
// a real key. NOTE(review): the sketch assumes the provider exposes
// ChatCompletion(ctx, req), as the original commented-out code did. Confirm
// this against llm.LLMProvider before enabling it, and add the "context"
// import.
func TestOpenAIProvider_Integration(t *testing.T) {
	t.Skip("integration test: requires a real OpenAI API key and network access")

	// Sketch (uses the llm-qualified names the rest of this file uses):
	//
	//	factory := &OpenAIFactory{}
	//	provider, err := factory.CreateProvider(llm.Config{
	//		Provider: llm.ProviderOpenAI,
	//		APIKey:   "your-real-api-key",
	//		BaseURL:  "https://api.openai.com/v1",
	//		Timeout:  30 * time.Second,
	//	})
	//	if err != nil {
	//		t.Fatalf("Failed to create provider: %v", err)
	//	}
	//	defer provider.Close()
	//
	//	maxTokens := 50
	//	req := llm.ChatCompletionRequest{
	//		Model: "gpt-3.5-turbo",
	//		Messages: []llm.Message{
	//			{Role: llm.RoleUser, Content: "Say hello!"},
	//		},
	//		MaxTokens: &maxTokens,
	//	}
	//
	//	resp, err := provider.ChatCompletion(context.Background(), req)
	//	if err != nil {
	//		t.Fatalf("Chat completion failed: %v", err)
	//	}
	//	if len(resp.Choices) == 0 {
	//		t.Fatal("Expected at least one choice")
	//	}
	//	if resp.Choices[0].Message.Content == "" {
	//		t.Fatal("Expected non-empty response content")
	//	}
}