blob: 6cd7dbfb98f9ce20f9fd6409cf472a6bfa1514ee [file] [log] [blame]
package openai
import (
"context"
"fmt"
"log"
"time"
"github.com/iomodo/staff/llm"
)
// ExampleOpenAI walks through the OpenAI provider end to end: it builds a
// provider from an explicit llm.Config and then runs the chat completion,
// function calling, and embeddings demos against it, in that order.
//
// The API key and organization ID are placeholders and must be replaced
// before running. Failure to construct the provider is fatal (log.Fatalf).
func ExampleOpenAI() {
	// Optional provider-specific settings forwarded through ExtraConfig.
	extra := map[string]interface{}{
		"organization": "your-org-id", // Optional: your OpenAI organization ID
	}

	provider, err := llm.CreateProvider(llm.Config{
		Provider:    llm.ProviderOpenAI,
		APIKey:      "your-openai-api-key-here", // Replace with your actual API key
		BaseURL:     "https://api.openai.com/v1",
		Timeout:     30 * time.Second,
		ExtraConfig: extra,
	})
	if err != nil {
		log.Fatalf("Failed to create OpenAI provider: %v", err)
	}
	defer provider.Close()

	// Run every demo sequentially against the same provider instance.
	demos := []func(llm.LLMProvider){
		exampleOpenAIChatCompletion,  // basic chat completion
		exampleOpenAIFunctionCalling, // tool / function calling
		exampleOpenAIEmbeddings,      // embedding generation
	}
	for _, run := range demos {
		run(provider)
	}
}
// exampleOpenAIChatCompletion sends a short system+user conversation to
// gpt-3.5-turbo and prints the first choice's reply together with the
// response's model, provider, and usage fields.
//
// A request error is logged and the example returns early. A response that
// carries no choices is also logged rather than indexed, so an empty
// response cannot cause an index-out-of-range panic.
func exampleOpenAIChatCompletion(provider llm.LLMProvider) {
	fmt.Println("=== OpenAI Chat Completion Example ===")

	// MaxTokens and Temperature are pointer fields, so take the address of locals.
	maxTokens := 100
	temperature := 0.7
	req := llm.ChatCompletionRequest{
		Model: "gpt-3.5-turbo",
		Messages: []llm.Message{
			{Role: llm.RoleSystem, Content: "You are a helpful assistant."},
			{Role: llm.RoleUser, Content: "What is the capital of France?"},
		},
		MaxTokens:   &maxTokens,
		Temperature: &temperature,
	}

	ctx := context.Background()
	resp, err := provider.ChatCompletion(ctx, req)
	if err != nil {
		log.Printf("Chat completion failed: %v", err)
		return
	}
	// Guard before indexing: a successful call is not guaranteed to return choices.
	if len(resp.Choices) == 0 {
		log.Println("Chat completion returned no choices")
		return
	}

	fmt.Printf("Response: %s\n", resp.Choices[0].Message.Content)
	fmt.Printf("Model: %s\n", resp.Model)
	fmt.Printf("Provider: %s\n", resp.Provider)
	fmt.Printf("Usage: %+v\n", resp.Usage)
}
// exampleOpenAIFunctionCalling offers the model a single "get_weather" tool
// and asks about the weather in Tokyo. The tool's JSON-Schema parameters
// declare a required "location" string and an optional "unit" enum
// (celsius/fahrenheit).
//
// If the first choice contains tool calls, each call's function name and
// parameters are printed. Otherwise the plain-text reply is printed. Request
// errors and responses with no choices are logged, and the example returns
// early.
func exampleOpenAIFunctionCalling(provider llm.LLMProvider) {
	fmt.Println("\n=== OpenAI Function Calling Example ===")

	// Define a function that can be called
	tools := []llm.Tool{
		{
			Type: "function",
			Function: llm.Function{
				Name:        "get_weather",
				Description: "Get the current weather for a location",
				Parameters: map[string]interface{}{
					"type": "object",
					"properties": map[string]interface{}{
						"location": map[string]interface{}{
							"type":        "string",
							"description": "The city and state, e.g. San Francisco, CA",
						},
						"unit": map[string]interface{}{
							"type": "string",
							"enum": []string{"celsius", "fahrenheit"},
						},
					},
					"required": []string{"location"},
				},
			},
		},
	}

	// MaxTokens and Temperature are pointer fields, so take the address of locals.
	maxTokens := 150
	temperature := 0.1
	req := llm.ChatCompletionRequest{
		Model: "gpt-3.5-turbo",
		Messages: []llm.Message{
			{Role: llm.RoleUser, Content: "What's the weather like in Tokyo?"},
		},
		Tools:       tools,
		MaxTokens:   &maxTokens,
		Temperature: &temperature,
	}

	ctx := context.Background()
	resp, err := provider.ChatCompletion(ctx, req)
	if err != nil {
		log.Printf("Function calling failed: %v", err)
		return
	}
	// Guard before indexing: a successful call is not guaranteed to return choices.
	if len(resp.Choices) == 0 {
		log.Println("Function calling returned no choices")
		return
	}

	msg := resp.Choices[0].Message
	// No tool calls means the model answered directly in text.
	if len(msg.ToolCalls) == 0 {
		fmt.Printf("Response: %s\n", msg.Content)
		return
	}

	fmt.Println("Model wants to call a function:")
	for _, toolCall := range msg.ToolCalls {
		fmt.Printf("Function: %s\n", toolCall.Function.Name)
		fmt.Printf("Arguments: %+v\n", toolCall.Function.Parameters)
	}
}
// exampleOpenAIEmbeddings requests an embedding for a single sentence using
// text-embedding-ada-002. It prints the vector's length, a preview of up to
// its first five values, and the response's model, provider, and usage
// fields.
//
// Request errors and responses with no data entries are logged, and the
// example returns early. The preview is clamped to the vector length, so a
// short embedding cannot cause a slice-bounds panic.
func exampleOpenAIEmbeddings(provider llm.LLMProvider) {
	fmt.Println("\n=== OpenAI Embeddings Example ===")

	req := llm.EmbeddingRequest{
		Input: "Hello, world! This is a test sentence for embeddings.",
		Model: "text-embedding-ada-002",
	}

	ctx := context.Background()
	resp, err := provider.CreateEmbeddings(ctx, req)
	if err != nil {
		log.Printf("Embeddings failed: %v", err)
		return
	}
	// Guard before indexing: a successful call is not guaranteed to return data.
	if len(resp.Data) == 0 {
		log.Println("Embeddings returned no data")
		return
	}

	embedding := resp.Data[0].Embedding
	// Show at most five values; slicing past len(embedding) would panic.
	const maxPreview = 5
	preview := embedding
	if len(preview) > maxPreview {
		preview = preview[:maxPreview]
	}

	fmt.Printf("Embedding dimensions: %d\n", len(embedding))
	fmt.Printf("First %d values: %v\n", len(preview), preview)
	fmt.Printf("Model: %s\n", resp.Model)
	fmt.Printf("Provider: %s\n", resp.Provider)
	fmt.Printf("Usage: %+v\n", resp.Usage)
}
// ExampleOpenAIWithErrorHandling shows how errors surface from the provider.
// It builds an OpenAI provider with a deliberately invalid API key, sends a
// minimal chat request, and prints the error that ChatCompletion returns.
//
// Provider construction failure is printed, and the example returns. If the
// request unexpectedly succeeds, the first reply is printed; a response with
// no choices is reported instead of being indexed, to avoid a panic.
func ExampleOpenAIWithErrorHandling() {
	fmt.Println("\n=== OpenAI Error Handling Example ===")

	// Test with invalid API key
	config := llm.Config{
		Provider: llm.ProviderOpenAI,
		APIKey:   "invalid-key",
		BaseURL:  "https://api.openai.com/v1",
		Timeout:  30 * time.Second,
	}

	provider, err := llm.CreateProvider(config)
	if err != nil {
		fmt.Printf("Failed to create provider: %v\n", err)
		return
	}
	defer provider.Close()

	// MaxTokens is a pointer field, so take the address of a local.
	maxTokens := 50
	req := llm.ChatCompletionRequest{
		Model: "gpt-3.5-turbo",
		Messages: []llm.Message{
			{Role: llm.RoleUser, Content: "Hello"},
		},
		MaxTokens: &maxTokens,
	}

	ctx := context.Background()
	resp, err := provider.ChatCompletion(ctx, req)
	if err != nil {
		fmt.Printf("Expected error with invalid API key: %v\n", err)
		return
	}
	// Guard before indexing: even an unexpected success may carry no choices.
	if len(resp.Choices) == 0 {
		fmt.Println("Unexpected success with no choices")
		return
	}

	fmt.Printf("Unexpected success: %s\n", resp.Choices[0].Message.Content)
}
// ExampleOpenAIWithCustomConfig shows how to configure a provider in layers.
// It starts from a minimal llm.Config (provider and API key only), merges in
// defaults via llm.MergeConfig, prints the merged BaseURL, Timeout, and
// MaxRetries, attaches organization and project IDs through ExtraConfig, and
// then builds a provider from the result.
//
// Provider construction failure is printed, and the example returns.
func ExampleOpenAIWithCustomConfig() {
	fmt.Println("\n=== OpenAI Custom Configuration Example ===")

	// Only Provider and APIKey are set here; MergeConfig fills in defaults.
	config := llm.MergeConfig(llm.Config{
		Provider: llm.ProviderOpenAI,
		APIKey:   "your-api-key",
	})
	fmt.Printf("Merged config: BaseURL=%s, Timeout=%v, MaxRetries=%d\n",
		config.BaseURL, config.Timeout, config.MaxRetries)

	// Assign a fresh map rather than mutating whatever MergeConfig returned.
	extra := map[string]interface{}{
		"organization": "your-org-id",
		"project":      "your-project-id",
	}
	config.ExtraConfig = extra

	provider, err := llm.CreateProvider(config)
	if err != nil {
		fmt.Printf("Failed to create provider: %v\n", err)
		return
	}
	defer provider.Close()

	fmt.Println("Provider created successfully with custom configuration")
}
// ExampleOpenAIWithValidation runs llm.ValidateChatCompletionRequest on two
// requests and reports whether each result matched expectations. The first
// request is well formed; the second omits the model. The example then
// prints llm.EstimateTokens for the well-formed request.
func ExampleOpenAIWithValidation() {
	fmt.Println("\n=== OpenAI Validation Example ===")

	hello := llm.Message{Role: llm.RoleUser, Content: "Hello"}

	// Well-formed request: model, one user message, and a temperature.
	temperature := 0.5
	validReq := llm.ChatCompletionRequest{
		Model:       "gpt-3.5-turbo",
		Messages:    []llm.Message{hello},
		Temperature: &temperature,
	}

	switch err := llm.ValidateChatCompletionRequest(validReq); {
	case err != nil:
		fmt.Printf("Valid request validation failed: %v\n", err)
	default:
		fmt.Println("Valid request passed validation")
	}

	// Malformed request: Model left empty, so validation should reject it.
	invalidReq := llm.ChatCompletionRequest{
		Messages: []llm.Message{hello},
	}

	switch err := llm.ValidateChatCompletionRequest(invalidReq); {
	case err != nil:
		fmt.Printf("Invalid request correctly caught: %v\n", err)
	default:
		fmt.Println("Invalid request should have failed validation")
	}

	// Test token estimation
	fmt.Printf("Estimated tokens for request: %d\n", llm.EstimateTokens(validReq))
}