| iomodo | be473d1 | 2025-07-26 11:33:08 +0400 | [diff] [blame^] | 1 | package openai |
| 2 | |
| 3 | import ( |
| 4 | "context" |
| 5 | "fmt" |
| 6 | "log" |
| 7 | "time" |
| 8 | |
| 9 | "github.com/iomodo/staff/llm" |
| 10 | ) |
| 11 | |
| 12 | // ExampleOpenAI demonstrates how to use the OpenAI implementation |
| 13 | func ExampleOpenAI() { |
| 14 | // Create OpenAI configuration |
| 15 | config := llm.Config{ |
| 16 | Provider: llm.ProviderOpenAI, |
| 17 | APIKey: "your-openai-api-key-here", // Replace with your actual API key |
| 18 | BaseURL: "https://api.openai.com/v1", |
| 19 | Timeout: 30 * time.Second, |
| 20 | ExtraConfig: map[string]interface{}{ |
| 21 | "organization": "your-org-id", // Optional: your OpenAI organization ID |
| 22 | }, |
| 23 | } |
| 24 | |
| 25 | // Create the provider |
| 26 | provider, err := llm.CreateProvider(config) |
| 27 | if err != nil { |
| 28 | log.Fatalf("Failed to create OpenAI provider: %v", err) |
| 29 | } |
| 30 | defer provider.Close() |
| 31 | |
| 32 | // Example 1: Basic chat completion |
| 33 | exampleOpenAIChatCompletion(provider) |
| 34 | |
| 35 | // Example 2: Function calling |
| 36 | exampleOpenAIFunctionCalling(provider) |
| 37 | |
| 38 | // Example 3: Embeddings |
| 39 | exampleOpenAIEmbeddings(provider) |
| 40 | } |
| 41 | |
| 42 | // exampleOpenAIChatCompletion demonstrates basic chat completion with OpenAI |
| 43 | func exampleOpenAIChatCompletion(provider llm.LLMProvider) { |
| 44 | fmt.Println("=== OpenAI Chat Completion Example ===") |
| 45 | |
| 46 | req := llm.ChatCompletionRequest{ |
| 47 | Model: "gpt-3.5-turbo", |
| 48 | Messages: []llm.Message{ |
| 49 | {Role: llm.RoleSystem, Content: "You are a helpful assistant."}, |
| 50 | {Role: llm.RoleUser, Content: "What is the capital of France?"}, |
| 51 | }, |
| 52 | MaxTokens: &[]int{100}[0], |
| 53 | Temperature: &[]float64{0.7}[0], |
| 54 | } |
| 55 | |
| 56 | ctx := context.Background() |
| 57 | resp, err := provider.ChatCompletion(ctx, req) |
| 58 | if err != nil { |
| 59 | log.Printf("Chat completion failed: %v", err) |
| 60 | return |
| 61 | } |
| 62 | |
| 63 | fmt.Printf("Response: %s\n", resp.Choices[0].Message.Content) |
| 64 | fmt.Printf("Model: %s\n", resp.Model) |
| 65 | fmt.Printf("Provider: %s\n", resp.Provider) |
| 66 | fmt.Printf("Usage: %+v\n", resp.Usage) |
| 67 | } |
| 68 | |
| 69 | // exampleOpenAIFunctionCalling demonstrates function calling with OpenAI |
| 70 | func exampleOpenAIFunctionCalling(provider llm.LLMProvider) { |
| 71 | fmt.Println("\n=== OpenAI Function Calling Example ===") |
| 72 | |
| 73 | // Define a function that can be called |
| 74 | tools := []llm.Tool{ |
| 75 | { |
| 76 | Type: "function", |
| 77 | Function: llm.Function{ |
| 78 | Name: "get_weather", |
| 79 | Description: "Get the current weather for a location", |
| 80 | Parameters: map[string]interface{}{ |
| 81 | "type": "object", |
| 82 | "properties": map[string]interface{}{ |
| 83 | "location": map[string]interface{}{ |
| 84 | "type": "string", |
| 85 | "description": "The city and state, e.g. San Francisco, CA", |
| 86 | }, |
| 87 | "unit": map[string]interface{}{ |
| 88 | "type": "string", |
| 89 | "enum": []string{"celsius", "fahrenheit"}, |
| 90 | }, |
| 91 | }, |
| 92 | "required": []string{"location"}, |
| 93 | }, |
| 94 | }, |
| 95 | }, |
| 96 | } |
| 97 | |
| 98 | req := llm.ChatCompletionRequest{ |
| 99 | Model: "gpt-3.5-turbo", |
| 100 | Messages: []llm.Message{ |
| 101 | {Role: llm.RoleUser, Content: "What's the weather like in Tokyo?"}, |
| 102 | }, |
| 103 | Tools: tools, |
| 104 | MaxTokens: &[]int{150}[0], |
| 105 | Temperature: &[]float64{0.1}[0], |
| 106 | } |
| 107 | |
| 108 | ctx := context.Background() |
| 109 | resp, err := provider.ChatCompletion(ctx, req) |
| 110 | if err != nil { |
| 111 | log.Printf("Function calling failed: %v", err) |
| 112 | return |
| 113 | } |
| 114 | |
| 115 | // Check if the model wants to call a function |
| 116 | if len(resp.Choices[0].Message.ToolCalls) > 0 { |
| 117 | fmt.Println("Model wants to call a function:") |
| 118 | for _, toolCall := range resp.Choices[0].Message.ToolCalls { |
| 119 | fmt.Printf("Function: %s\n", toolCall.Function.Name) |
| 120 | fmt.Printf("Arguments: %+v\n", toolCall.Function.Parameters) |
| 121 | } |
| 122 | } else { |
| 123 | fmt.Printf("Response: %s\n", resp.Choices[0].Message.Content) |
| 124 | } |
| 125 | } |
| 126 | |
| 127 | // exampleOpenAIEmbeddings demonstrates embedding generation with OpenAI |
| 128 | func exampleOpenAIEmbeddings(provider llm.LLMProvider) { |
| 129 | fmt.Println("\n=== OpenAI Embeddings Example ===") |
| 130 | |
| 131 | req := llm.EmbeddingRequest{ |
| 132 | Input: "Hello, world! This is a test sentence for embeddings.", |
| 133 | Model: "text-embedding-ada-002", |
| 134 | } |
| 135 | |
| 136 | ctx := context.Background() |
| 137 | resp, err := provider.CreateEmbeddings(ctx, req) |
| 138 | if err != nil { |
| 139 | log.Printf("Embeddings failed: %v", err) |
| 140 | return |
| 141 | } |
| 142 | |
| 143 | fmt.Printf("Embedding dimensions: %d\n", len(resp.Data[0].Embedding)) |
| 144 | fmt.Printf("First 5 values: %v\n", resp.Data[0].Embedding[:5]) |
| 145 | fmt.Printf("Model: %s\n", resp.Model) |
| 146 | fmt.Printf("Provider: %s\n", resp.Provider) |
| 147 | fmt.Printf("Usage: %+v\n", resp.Usage) |
| 148 | } |
| 149 | |
| 150 | // ExampleOpenAIWithErrorHandling demonstrates proper error handling |
| 151 | func ExampleOpenAIWithErrorHandling() { |
| 152 | fmt.Println("\n=== OpenAI Error Handling Example ===") |
| 153 | |
| 154 | // Test with invalid API key |
| 155 | config := llm.Config{ |
| 156 | Provider: llm.ProviderOpenAI, |
| 157 | APIKey: "invalid-key", |
| 158 | BaseURL: "https://api.openai.com/v1", |
| 159 | Timeout: 30 * time.Second, |
| 160 | } |
| 161 | |
| 162 | provider, err := llm.CreateProvider(config) |
| 163 | if err != nil { |
| 164 | fmt.Printf("Failed to create provider: %v\n", err) |
| 165 | return |
| 166 | } |
| 167 | defer provider.Close() |
| 168 | |
| 169 | req := llm.ChatCompletionRequest{ |
| 170 | Model: "gpt-3.5-turbo", |
| 171 | Messages: []llm.Message{ |
| 172 | {Role: llm.RoleUser, Content: "Hello"}, |
| 173 | }, |
| 174 | MaxTokens: &[]int{50}[0], |
| 175 | } |
| 176 | |
| 177 | ctx := context.Background() |
| 178 | resp, err := provider.ChatCompletion(ctx, req) |
| 179 | if err != nil { |
| 180 | fmt.Printf("Expected error with invalid API key: %v\n", err) |
| 181 | return |
| 182 | } |
| 183 | |
| 184 | fmt.Printf("Unexpected success: %s\n", resp.Choices[0].Message.Content) |
| 185 | } |
| 186 | |
| 187 | // ExampleOpenAIWithCustomConfig demonstrates custom configuration |
| 188 | func ExampleOpenAIWithCustomConfig() { |
| 189 | fmt.Println("\n=== OpenAI Custom Configuration Example ===") |
| 190 | |
| 191 | // Start with minimal config |
| 192 | customConfig := llm.Config{ |
| 193 | Provider: llm.ProviderOpenAI, |
| 194 | APIKey: "your-api-key", |
| 195 | } |
| 196 | |
| 197 | // Merge with defaults |
| 198 | config := llm.MergeConfig(customConfig) |
| 199 | |
| 200 | fmt.Printf("Merged config: BaseURL=%s, Timeout=%v, MaxRetries=%d\n", |
| 201 | config.BaseURL, config.Timeout, config.MaxRetries) |
| 202 | |
| 203 | // Add extra configuration |
| 204 | config.ExtraConfig = map[string]interface{}{ |
| 205 | "organization": "your-org-id", |
| 206 | "project": "your-project-id", |
| 207 | } |
| 208 | |
| 209 | provider, err := llm.CreateProvider(config) |
| 210 | if err != nil { |
| 211 | fmt.Printf("Failed to create provider: %v\n", err) |
| 212 | return |
| 213 | } |
| 214 | defer provider.Close() |
| 215 | |
| 216 | fmt.Println("Provider created successfully with custom configuration") |
| 217 | } |
| 218 | |
| 219 | // ExampleOpenAIWithValidation demonstrates request validation |
| 220 | func ExampleOpenAIWithValidation() { |
| 221 | fmt.Println("\n=== OpenAI Validation Example ===") |
| 222 | |
| 223 | // Valid request |
| 224 | validReq := llm.ChatCompletionRequest{ |
| 225 | Model: "gpt-3.5-turbo", |
| 226 | Messages: []llm.Message{ |
| 227 | {Role: llm.RoleUser, Content: "Hello"}, |
| 228 | }, |
| 229 | Temperature: &[]float64{0.5}[0], |
| 230 | } |
| 231 | |
| 232 | if err := llm.ValidateChatCompletionRequest(validReq); err != nil { |
| 233 | fmt.Printf("Valid request validation failed: %v\n", err) |
| 234 | } else { |
| 235 | fmt.Println("Valid request passed validation") |
| 236 | } |
| 237 | |
| 238 | // Invalid request (no model) |
| 239 | invalidReq := llm.ChatCompletionRequest{ |
| 240 | Messages: []llm.Message{ |
| 241 | {Role: llm.RoleUser, Content: "Hello"}, |
| 242 | }, |
| 243 | } |
| 244 | |
| 245 | if err := llm.ValidateChatCompletionRequest(invalidReq); err != nil { |
| 246 | fmt.Printf("Invalid request correctly caught: %v\n", err) |
| 247 | } else { |
| 248 | fmt.Println("Invalid request should have failed validation") |
| 249 | } |
| 250 | |
| 251 | // Test token estimation |
| 252 | tokens := llm.EstimateTokens(validReq) |
| 253 | fmt.Printf("Estimated tokens for request: %d\n", tokens) |
| 254 | } |