| package openai |
| |
| import ( |
| "context" |
| "fmt" |
| "log" |
| "time" |
| |
| "github.com/iomodo/staff/llm" |
| ) |
| |
| // ExampleOpenAI demonstrates how to use the OpenAI implementation |
| func ExampleOpenAI() { |
| // Create OpenAI configuration |
| config := llm.Config{ |
| Provider: llm.ProviderOpenAI, |
| APIKey: "your-openai-api-key-here", // Replace with your actual API key |
| BaseURL: "https://api.openai.com/v1", |
| Timeout: 30 * time.Second, |
| ExtraConfig: map[string]interface{}{ |
| "organization": "your-org-id", // Optional: your OpenAI organization ID |
| }, |
| } |
| |
| // Create the provider |
| provider, err := llm.CreateProvider(config) |
| if err != nil { |
| log.Fatalf("Failed to create OpenAI provider: %v", err) |
| } |
| defer provider.Close() |
| |
| // Example 1: Basic chat completion |
| exampleOpenAIChatCompletion(provider) |
| |
| // Example 2: Function calling |
| exampleOpenAIFunctionCalling(provider) |
| |
| // Example 3: Embeddings |
| exampleOpenAIEmbeddings(provider) |
| } |
| |
| // exampleOpenAIChatCompletion demonstrates basic chat completion with OpenAI |
| func exampleOpenAIChatCompletion(provider llm.LLMProvider) { |
| fmt.Println("=== OpenAI Chat Completion Example ===") |
| |
| req := llm.ChatCompletionRequest{ |
| Model: "gpt-3.5-turbo", |
| Messages: []llm.Message{ |
| {Role: llm.RoleSystem, Content: "You are a helpful assistant."}, |
| {Role: llm.RoleUser, Content: "What is the capital of France?"}, |
| }, |
| MaxTokens: &[]int{100}[0], |
| Temperature: &[]float64{0.7}[0], |
| } |
| |
| ctx := context.Background() |
| resp, err := provider.ChatCompletion(ctx, req) |
| if err != nil { |
| log.Printf("Chat completion failed: %v", err) |
| return |
| } |
| |
| fmt.Printf("Response: %s\n", resp.Choices[0].Message.Content) |
| fmt.Printf("Model: %s\n", resp.Model) |
| fmt.Printf("Provider: %s\n", resp.Provider) |
| fmt.Printf("Usage: %+v\n", resp.Usage) |
| } |
| |
| // exampleOpenAIFunctionCalling demonstrates function calling with OpenAI |
| func exampleOpenAIFunctionCalling(provider llm.LLMProvider) { |
| fmt.Println("\n=== OpenAI Function Calling Example ===") |
| |
| // Define a function that can be called |
| tools := []llm.Tool{ |
| { |
| Type: "function", |
| Function: llm.Function{ |
| Name: "get_weather", |
| Description: "Get the current weather for a location", |
| Parameters: map[string]interface{}{ |
| "type": "object", |
| "properties": map[string]interface{}{ |
| "location": map[string]interface{}{ |
| "type": "string", |
| "description": "The city and state, e.g. San Francisco, CA", |
| }, |
| "unit": map[string]interface{}{ |
| "type": "string", |
| "enum": []string{"celsius", "fahrenheit"}, |
| }, |
| }, |
| "required": []string{"location"}, |
| }, |
| }, |
| }, |
| } |
| |
| req := llm.ChatCompletionRequest{ |
| Model: "gpt-3.5-turbo", |
| Messages: []llm.Message{ |
| {Role: llm.RoleUser, Content: "What's the weather like in Tokyo?"}, |
| }, |
| Tools: tools, |
| MaxTokens: &[]int{150}[0], |
| Temperature: &[]float64{0.1}[0], |
| } |
| |
| ctx := context.Background() |
| resp, err := provider.ChatCompletion(ctx, req) |
| if err != nil { |
| log.Printf("Function calling failed: %v", err) |
| return |
| } |
| |
| // Check if the model wants to call a function |
| if len(resp.Choices[0].Message.ToolCalls) > 0 { |
| fmt.Println("Model wants to call a function:") |
| for _, toolCall := range resp.Choices[0].Message.ToolCalls { |
| fmt.Printf("Function: %s\n", toolCall.Function.Name) |
| fmt.Printf("Arguments: %+v\n", toolCall.Function.Parameters) |
| } |
| } else { |
| fmt.Printf("Response: %s\n", resp.Choices[0].Message.Content) |
| } |
| } |
| |
| // exampleOpenAIEmbeddings demonstrates embedding generation with OpenAI |
| func exampleOpenAIEmbeddings(provider llm.LLMProvider) { |
| fmt.Println("\n=== OpenAI Embeddings Example ===") |
| |
| req := llm.EmbeddingRequest{ |
| Input: "Hello, world! This is a test sentence for embeddings.", |
| Model: "text-embedding-ada-002", |
| } |
| |
| ctx := context.Background() |
| resp, err := provider.CreateEmbeddings(ctx, req) |
| if err != nil { |
| log.Printf("Embeddings failed: %v", err) |
| return |
| } |
| |
| fmt.Printf("Embedding dimensions: %d\n", len(resp.Data[0].Embedding)) |
| fmt.Printf("First 5 values: %v\n", resp.Data[0].Embedding[:5]) |
| fmt.Printf("Model: %s\n", resp.Model) |
| fmt.Printf("Provider: %s\n", resp.Provider) |
| fmt.Printf("Usage: %+v\n", resp.Usage) |
| } |
| |
| // ExampleOpenAIWithErrorHandling demonstrates proper error handling |
| func ExampleOpenAIWithErrorHandling() { |
| fmt.Println("\n=== OpenAI Error Handling Example ===") |
| |
| // Test with invalid API key |
| config := llm.Config{ |
| Provider: llm.ProviderOpenAI, |
| APIKey: "invalid-key", |
| BaseURL: "https://api.openai.com/v1", |
| Timeout: 30 * time.Second, |
| } |
| |
| provider, err := llm.CreateProvider(config) |
| if err != nil { |
| fmt.Printf("Failed to create provider: %v\n", err) |
| return |
| } |
| defer provider.Close() |
| |
| req := llm.ChatCompletionRequest{ |
| Model: "gpt-3.5-turbo", |
| Messages: []llm.Message{ |
| {Role: llm.RoleUser, Content: "Hello"}, |
| }, |
| MaxTokens: &[]int{50}[0], |
| } |
| |
| ctx := context.Background() |
| resp, err := provider.ChatCompletion(ctx, req) |
| if err != nil { |
| fmt.Printf("Expected error with invalid API key: %v\n", err) |
| return |
| } |
| |
| fmt.Printf("Unexpected success: %s\n", resp.Choices[0].Message.Content) |
| } |
| |
| // ExampleOpenAIWithCustomConfig demonstrates custom configuration |
| func ExampleOpenAIWithCustomConfig() { |
| fmt.Println("\n=== OpenAI Custom Configuration Example ===") |
| |
| // Start with minimal config |
| customConfig := llm.Config{ |
| Provider: llm.ProviderOpenAI, |
| APIKey: "your-api-key", |
| } |
| |
| // Merge with defaults |
| config := llm.MergeConfig(customConfig) |
| |
| fmt.Printf("Merged config: BaseURL=%s, Timeout=%v, MaxRetries=%d\n", |
| config.BaseURL, config.Timeout, config.MaxRetries) |
| |
| // Add extra configuration |
| config.ExtraConfig = map[string]interface{}{ |
| "organization": "your-org-id", |
| "project": "your-project-id", |
| } |
| |
| provider, err := llm.CreateProvider(config) |
| if err != nil { |
| fmt.Printf("Failed to create provider: %v\n", err) |
| return |
| } |
| defer provider.Close() |
| |
| fmt.Println("Provider created successfully with custom configuration") |
| } |
| |
| // ExampleOpenAIWithValidation demonstrates request validation |
| func ExampleOpenAIWithValidation() { |
| fmt.Println("\n=== OpenAI Validation Example ===") |
| |
| // Valid request |
| validReq := llm.ChatCompletionRequest{ |
| Model: "gpt-3.5-turbo", |
| Messages: []llm.Message{ |
| {Role: llm.RoleUser, Content: "Hello"}, |
| }, |
| Temperature: &[]float64{0.5}[0], |
| } |
| |
| if err := llm.ValidateChatCompletionRequest(validReq); err != nil { |
| fmt.Printf("Valid request validation failed: %v\n", err) |
| } else { |
| fmt.Println("Valid request passed validation") |
| } |
| |
| // Invalid request (no model) |
| invalidReq := llm.ChatCompletionRequest{ |
| Messages: []llm.Message{ |
| {Role: llm.RoleUser, Content: "Hello"}, |
| }, |
| } |
| |
| if err := llm.ValidateChatCompletionRequest(invalidReq); err != nil { |
| fmt.Printf("Invalid request correctly caught: %v\n", err) |
| } else { |
| fmt.Println("Invalid request should have failed validation") |
| } |
| |
| // Test token estimation |
| tokens := llm.EstimateTokens(validReq) |
| fmt.Printf("Estimated tokens for request: %d\n", tokens) |
| } |