Add OpenAI implementation
Change-Id: Iea3191cf538959002e6ae095857a7aa6126b3e2f
diff --git a/server/llm/openai/openai_example.go b/server/llm/openai/openai_example.go
new file mode 100644
index 0000000..6cd7dbf
--- /dev/null
+++ b/server/llm/openai/openai_example.go
@@ -0,0 +1,254 @@
+package openai
+
+import (
+ "context"
+ "fmt"
+ "log"
+ "time"
+
+ "github.com/iomodo/staff/llm"
+)
+
+// ExampleOpenAI runs the chat-completion, function-calling and embeddings
+// demos against one provider; it calls log.Fatalf if creation fails.
+func ExampleOpenAI() {
+	cfg := llm.Config{
+		Provider: llm.ProviderOpenAI,
+		APIKey:   "your-openai-api-key-here", // Replace with your actual API key
+		BaseURL:  "https://api.openai.com/v1",
+		Timeout:  30 * time.Second,
+		ExtraConfig: map[string]interface{}{
+			"organization": "your-org-id", // Optional: your OpenAI organization ID
+		},
+	}
+
+	prov, err := llm.CreateProvider(cfg)
+	if err != nil {
+		log.Fatalf("Failed to create OpenAI provider: %v", err)
+	}
+	defer prov.Close()
+
+	// Run each demo in order against the same provider.
+	demos := []func(llm.LLMProvider){
+		exampleOpenAIChatCompletion,  // basic chat completion
+		exampleOpenAIFunctionCalling, // tool / function calling
+		exampleOpenAIEmbeddings,      // embeddings
+	}
+	for _, demo := range demos {
+		demo(prov)
+	}
+}
+
+// exampleOpenAIChatCompletion sends a short system+user conversation and prints
+// the reply, model, provider and usage; failed or choice-less responses are logged.
+func exampleOpenAIChatCompletion(provider llm.LLMProvider) {
+	fmt.Println("=== OpenAI Chat Completion Example ===")
+
+	req := llm.ChatCompletionRequest{
+		Model: "gpt-3.5-turbo",
+		Messages: []llm.Message{
+			{Role: llm.RoleSystem, Content: "You are a helpful assistant."},
+			{Role: llm.RoleUser, Content: "What is the capital of France?"},
+		},
+		MaxTokens:   &[]int{100}[0],
+		Temperature: &[]float64{0.7}[0],
+	}
+	resp, err := provider.ChatCompletion(context.Background(), req)
+	switch {
+	case err != nil:
+		log.Printf("Chat completion failed: %v", err)
+		return
+	case len(resp.Choices) == 0: // indexing Choices[0] below would panic
+		log.Print("Chat completion returned no choices")
+		return
+	}
+	fmt.Printf("Response: %s\n", resp.Choices[0].Message.Content)
+	fmt.Printf("Model: %s\nProvider: %s\nUsage: %+v\n", resp.Model, resp.Provider, resp.Usage)
+}
+
+// exampleOpenAIFunctionCalling offers the model a get_weather tool and prints
+// either the tool calls it asks for or, when it makes none, its text reply.
+func exampleOpenAIFunctionCalling(provider llm.LLMProvider) {
+	fmt.Println("\n=== OpenAI Function Calling Example ===")
+
+	tools := []llm.Tool{
+		{
+			Type: "function",
+			Function: llm.Function{
+				Name:        "get_weather",
+				Description: "Get the current weather for a location",
+				Parameters: map[string]interface{}{
+					"type": "object",
+					"properties": map[string]interface{}{
+						"location": map[string]interface{}{
+							"type":        "string",
+							"description": "The city and state, e.g. San Francisco, CA",
+						},
+						"unit": map[string]interface{}{
+							"type": "string",
+							"enum": []string{"celsius", "fahrenheit"},
+						},
+					},
+					"required": []string{"location"},
+				},
+			},
+		},
+	}
+
+	req := llm.ChatCompletionRequest{
+		Model: "gpt-3.5-turbo",
+		Messages: []llm.Message{
+			{Role: llm.RoleUser, Content: "What's the weather like in Tokyo?"},
+		},
+		Tools:       tools,
+		MaxTokens:   &[]int{150}[0],
+		Temperature: &[]float64{0.1}[0],
+	}
+	resp, err := provider.ChatCompletion(context.Background(), req)
+	switch {
+	case err != nil:
+		log.Printf("Function calling failed: %v", err)
+		return
+	case len(resp.Choices) == 0: // indexing Choices[0] below would panic
+		log.Print("Function calling returned no choices")
+		return
+	}
+	msg := resp.Choices[0].Message
+	if len(msg.ToolCalls) == 0 {
+		fmt.Printf("Response: %s\n", msg.Content)
+		return
+	}
+	fmt.Println("Model wants to call a function:")
+	for _, toolCall := range msg.ToolCalls {
+		fmt.Printf("Function: %s\nArguments: %+v\n", toolCall.Function.Name, toolCall.Function.Parameters)
+	}
+}
+
+// exampleOpenAIEmbeddings embeds a test sentence and prints its length, leading values, model, provider and usage.
+func exampleOpenAIEmbeddings(provider llm.LLMProvider) {
+	fmt.Println("\n=== OpenAI Embeddings Example ===")
+	resp, err := provider.CreateEmbeddings(context.Background(), llm.EmbeddingRequest{
+		Input: "Hello, world! This is a test sentence for embeddings.",
+		Model: "text-embedding-ada-002",
+	})
+	switch {
+	case err != nil:
+		log.Printf("Embeddings failed: %v", err)
+		return
+	case len(resp.Data) == 0: // indexing Data[0] below would panic
+		log.Print("Embeddings returned no data")
+		return
+	}
+	vec, head := resp.Data[0].Embedding, 5
+	if len(vec) < head { // vec[:5] would panic on a shorter vector
+		head = len(vec)
+	}
+	fmt.Printf("Embedding dimensions: %d\nFirst 5 values: %v\n", len(vec), vec[:head])
+	fmt.Printf("Model: %s\nProvider: %s\nUsage: %+v\n", resp.Model, resp.Provider, resp.Usage)
+}
+
+// ExampleOpenAIWithErrorHandling sends a request with a deliberately invalid API
+// key and prints the resulting error; a successful call is reported as unexpected.
+func ExampleOpenAIWithErrorHandling() {
+	fmt.Println("\n=== OpenAI Error Handling Example ===")
+
+	// Test with invalid API key
+	config := llm.Config{
+		Provider: llm.ProviderOpenAI,
+		APIKey:   "invalid-key",
+		BaseURL:  "https://api.openai.com/v1",
+		Timeout:  30 * time.Second,
+	}
+	provider, err := llm.CreateProvider(config)
+	if err != nil {
+		fmt.Printf("Failed to create provider: %v\n", err)
+		return
+	}
+	defer provider.Close()
+
+	req := llm.ChatCompletionRequest{
+		Model: "gpt-3.5-turbo",
+		Messages: []llm.Message{
+			{Role: llm.RoleUser, Content: "Hello"},
+		},
+		MaxTokens: &[]int{50}[0],
+	}
+	resp, err := provider.ChatCompletion(context.Background(), req)
+	switch {
+	case err != nil:
+		fmt.Printf("Expected error with invalid API key: %v\n", err)
+	case len(resp.Choices) == 0: // still unexpected, but avoid panicking on Choices[0]
+		fmt.Println("Unexpected success with no choices")
+	default:
+		fmt.Printf("Unexpected success: %s\n", resp.Choices[0].Message.Content)
+	}
+}
+
+// ExampleOpenAIWithCustomConfig starts from a minimal Config, passes it through
+// llm.MergeConfig, prints the merged BaseURL/Timeout/MaxRetries, attaches
+// organization and project IDs via ExtraConfig, and then creates a provider.
+func ExampleOpenAIWithCustomConfig() {
+	fmt.Println("\n=== OpenAI Custom Configuration Example ===")
+
+	// Only Provider and APIKey are set; MergeConfig is expected to fill in defaults.
+	merged := llm.MergeConfig(llm.Config{
+		Provider: llm.ProviderOpenAI,
+		APIKey:   "your-api-key",
+	})
+	fmt.Printf(
+		"Merged config: BaseURL=%s, Timeout=%v, MaxRetries=%d\n",
+		merged.BaseURL, merged.Timeout, merged.MaxRetries,
+	)
+
+	// Assigning a fresh map replaces any ExtraConfig the merge produced.
+	merged.ExtraConfig = map[string]interface{}{
+		"organization": "your-org-id",
+		"project":      "your-project-id",
+	}
+
+	p, err := llm.CreateProvider(merged)
+	if err != nil {
+		fmt.Printf("Failed to create provider: %v\n", err)
+		return
+	}
+	defer p.Close()
+
+	fmt.Println("Provider created successfully with custom configuration")
+}
+
+// ExampleOpenAIWithValidation runs llm.ValidateChatCompletionRequest against a
+// well-formed request and one without a Model, reports whether each result was
+// as expected, and prints llm.EstimateTokens for the well-formed request.
+func ExampleOpenAIWithValidation() {
+	fmt.Println("\n=== OpenAI Validation Example ===")
+
+	good := llm.ChatCompletionRequest{
+		Model: "gpt-3.5-turbo",
+		Messages: []llm.Message{
+			{Role: llm.RoleUser, Content: "Hello"},
+		},
+		Temperature: &[]float64{0.5}[0],
+	}
+	switch err := llm.ValidateChatCompletionRequest(good); {
+	case err != nil:
+		fmt.Printf("Valid request validation failed: %v\n", err)
+	default:
+		fmt.Println("Valid request passed validation")
+	}
+
+	// Model is deliberately omitted, so validation ought to reject this one.
+	bad := llm.ChatCompletionRequest{
+		Messages: []llm.Message{
+			{Role: llm.RoleUser, Content: "Hello"},
+		},
+	}
+	switch err := llm.ValidateChatCompletionRequest(bad); {
+	case err != nil:
+		fmt.Printf("Invalid request correctly caught: %v\n", err)
+	default:
+		fmt.Println("Invalid request should have failed validation")
+	}
+
+	// Token estimation is only shown for the well-formed request.
+	fmt.Printf("Estimated tokens for request: %d\n", llm.EstimateTokens(good))
+}