blob: 002b4d154fefa28b831cb7907ebb79331b83b86c [file] [log] [blame]
package gem
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"testing"
"sketch.dev/llm"
"sketch.dev/llm/gem/gemini"
)
// TestBuildGeminiRequest checks that buildGeminiRequest copies the system
// prompt and a single user text message into the resulting Gemini request.
func TestBuildGeminiRequest(t *testing.T) {
	svc := &Service{
		Model:  DefaultModel,
		APIKey: "test-api-key",
	}

	// Input: one user message holding one text part, plus a one-entry system prompt.
	userMsg := llm.Message{
		Role: llm.MessageRoleUser,
		Content: []llm.Content{
			{Type: llm.ContentTypeText, Text: "Hello, world!"},
		},
	}
	in := &llm.Request{
		Messages: []llm.Message{userMsg},
		System:   []llm.SystemContent{{Text: "You are a helpful assistant."}},
	}

	out, err := svc.buildGeminiRequest(in)
	if err != nil {
		t.Fatalf("Failed to build Gemini request: %v", err)
	}

	// The system instruction must exist and hold exactly the system text.
	sys := out.SystemInstruction
	if sys == nil {
		t.Fatalf("Expected system instruction, got nil")
	}
	if n := len(sys.Parts); n != 1 {
		t.Fatalf("Expected 1 system part, got %d", n)
	}
	if got := sys.Parts[0].Text; got != "You are a helpful assistant." {
		t.Fatalf("Expected system text 'You are a helpful assistant.', got '%s'", got)
	}

	// Exactly one content entry with one text part is expected.
	if n := len(out.Contents); n != 1 {
		t.Fatalf("Expected 1 content, got %d", n)
	}
	first := out.Contents[0]
	if n := len(first.Parts); n != 1 {
		t.Fatalf("Expected 1 part, got %d", n)
	}
	if got := first.Parts[0].Text; got != "Hello, world!" {
		t.Fatalf("Expected text 'Hello, world!', got '%s'", got)
	}

	// The content entry must carry the "user" role.
	if got := first.Role; got != "user" {
		t.Fatalf("Expected role 'user', got '%s'", got)
	}
}
// TestConvertToolSchemas checks that convertToolSchemas turns a tool with a
// JSON object schema into one declaration with the same name, description,
// property types and required list.
func TestConvertToolSchemas(t *testing.T) {
	// JSON schema: an object with a required string "name" and an integer "age".
	personSchema := `{
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "The name of the person"
},
"age": {
"type": "integer",
"description": "The age of the person"
}
},
"required": ["name"]
}`
	personTool := &llm.Tool{
		Name:        "get_person",
		Description: "Get information about a person",
		InputSchema: json.RawMessage(personSchema),
	}

	decls, err := convertToolSchemas([]*llm.Tool{personTool})
	if err != nil {
		t.Fatalf("Failed to convert tool schemas: %v", err)
	}

	// One tool in, one declaration out, with name and description carried over.
	if n := len(decls); n != 1 {
		t.Fatalf("Expected 1 declaration, got %d", n)
	}
	decl := decls[0]
	if decl.Name != "get_person" {
		t.Fatalf("Expected name 'get_person', got '%s'", decl.Name)
	}
	if decl.Description != "Get information about a person" {
		t.Fatalf("Expected description 'Get information about a person', got '%s'", decl.Description)
	}

	// The numeric type codes below follow the labels used in the failure messages.
	params := decl.Parameters
	if got := params.Type; got != 6 { // DataTypeOBJECT
		t.Fatalf("Expected type OBJECT (6), got %d", got)
	}
	props := params.Properties
	if n := len(props); n != 2 {
		t.Fatalf("Expected 2 properties, got %d", n)
	}
	if got := props["name"].Type; got != 1 { // DataTypeSTRING
		t.Fatalf("Expected name type STRING (1), got %d", got)
	}
	if got := props["age"].Type; got != 3 { // DataTypeINTEGER
		t.Fatalf("Expected age type INTEGER (3), got %d", got)
	}

	// Only "name" is listed as required in the schema.
	required := params.Required
	if len(required) != 1 || required[0] != "name" {
		t.Fatalf("Expected required field 'name', got %v", required)
	}
}
// TestService_Do_MockResponse is a placeholder for a mocked Service.Do test.
// It makes no API calls and does not invoke Do; it only confirms that
// buildGeminiRequest accepts a minimal single-message request without error.
func TestService_Do_MockResponse(t *testing.T) {
	// No HTTP client is configured: a full test would inject a mock client,
	// run service.Do, and inspect the response structure.
	svc := &Service{
		Model:  DefaultModel,
		APIKey: "test-api-key",
	}

	// Minimal request: a single user message with one text part.
	hello := llm.Content{
		Type: llm.ContentTypeText,
		Text: "Hello",
	}
	in := &llm.Request{
		Messages: []llm.Message{
			{Role: llm.MessageRoleUser, Content: []llm.Content{hello}},
		},
	}

	// For now only the request-building step is exercised.
	if _, err := svc.buildGeminiRequest(in); err != nil {
		t.Fatalf("Failed to build request: %v", err)
	}
}
// TestConvertResponseWithToolCall checks that a Gemini response whose only
// part is a function call is converted into a single tool-use content item
// carrying the tool name and JSON-encoded arguments.
func TestConvertResponseWithToolCall(t *testing.T) {
	// Gemini response with one candidate containing one function-call part.
	gemRes := &gemini.Response{
		Candidates: []gemini.Candidate{
			{
				Content: gemini.Content{
					Parts: []gemini.Part{
						{
							FunctionCall: &gemini.FunctionCall{
								Name: "bash",
								Args: map[string]any{
									"command": "cat README.md",
								},
							},
						},
					},
				},
			},
		},
	}

	content := convertGeminiResponseToContent(gemRes)

	// Exactly one tool-use item named after the function is expected.
	if len(content) != 1 {
		t.Fatalf("Expected 1 content item, got %d", len(content))
	}
	if content[0].Type != llm.ContentTypeToolUse {
		t.Fatalf("Expected content type ToolUse, got %s", content[0].Type)
	}
	if content[0].ToolName != "bash" {
		t.Fatalf("Expected tool name 'bash', got '%s'", content[0].ToolName)
	}

	// The tool input must decode as JSON and contain the original argument.
	var args map[string]any
	if err := json.Unmarshal(content[0].ToolInput, &args); err != nil {
		t.Fatalf("Failed to unmarshal tool input: %v", err)
	}
	cmd, ok := args["command"]
	if !ok {
		t.Fatalf("Expected 'command' argument, not found")
	}
	if cmd != "cat README.md" {
		// cmd is an `any` decoded from JSON; %v prints non-string values
		// readably, whereas %s would render them as "%!s(type=value)".
		t.Fatalf("Expected command 'cat README.md', got '%v'", cmd)
	}
}
// TestGeminiHeaderCapture checks that gemini.Model.GenerateContent exposes
// the HTTP response headers on its result, and that llm.CostUSDFromResponse
// turns the Skaband-Cost-Microcents header into the expected USD amount.
func TestGeminiHeaderCapture(t *testing.T) {
	// Mock HTTP client: every request gets this canned 200 response, which
	// carries a cost header and a one-candidate JSON body.
	mockClient := &http.Client{
		Transport: &mockRoundTripper{
			response: &http.Response{
				StatusCode: http.StatusOK,
				Header: http.Header{
					"Content-Type":            []string{"application/json"},
					"Skaband-Cost-Microcents": []string{"123456"},
				},
				Body: io.NopCloser(bytes.NewBufferString(`{
"candidates": [{
"content": {
"parts": [{
"text": "Hello!"
}]
}
}]
}`)),
			},
		},
	}

	// Gemini model wired to the mock client, so no network traffic occurs.
	model := gemini.Model{
		Model:    "models/gemini-test",
		APIKey:   "test-key",
		HTTPC:    mockClient,
		Endpoint: "https://test.googleapis.com",
	}

	req := &gemini.Request{
		Contents: []gemini.Content{
			{
				Parts: []gemini.Part{{Text: "Hello"}},
				Role:  "user",
			},
		},
	}

	ctx := context.Background()
	res, err := model.GenerateContent(ctx, req)
	if err != nil {
		t.Fatalf("Failed to generate content: %v", err)
	}

	// The response must expose the headers returned by the transport.
	headers := res.Header()
	if headers == nil {
		t.Fatalf("Expected headers to be captured, got nil")
	}

	costHeader := headers.Get("Skaband-Cost-Microcents")
	if costHeader != "123456" {
		t.Fatalf("Expected cost header '123456', got '%s'", costHeader)
	}

	// llm.CostUSDFromResponse must convert the header value to USD.
	costUSD := llm.CostUSDFromResponse(headers)
	expectedCost := 0.00123456 // 123456 microcents / 100,000,000

	// Compare with a small tolerance instead of exact float equality, so the
	// test does not depend on how the conversion rounds. The negated range
	// check also fails on NaN.
	const tolerance = 1e-12
	if diff := costUSD - expectedCost; !(diff >= -tolerance && diff <= tolerance) {
		t.Fatalf("Expected cost USD %.8f, got %.8f", expectedCost, costUSD)
	}
}
// mockRoundTripper is an HTTP transport test double: it answers every
// request with one pre-built response and never reports an error.
type mockRoundTripper struct {
	response *http.Response // canned reply handed back by RoundTrip
}

// RoundTrip ignores the outgoing request and returns the canned response
// with a nil error. The same *http.Response value is returned on every call.
func (rt *mockRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
	canned := rt.response
	return canned, nil
}
// TestHeaderCostIntegration checks that Service.Do reports the cost taken
// from the Skaband-Cost-Microcents response header in Usage.CostUSD, while
// still reporting non-zero input and output token counts.
func TestHeaderCostIntegration(t *testing.T) {
	// Mock HTTP client: every request gets this canned 200 response, which
	// carries a cost header and a one-candidate JSON body.
	mockClient := &http.Client{
		Transport: &mockRoundTripper{
			response: &http.Response{
				StatusCode: http.StatusOK,
				Header: http.Header{
					"Content-Type":            []string{"application/json"},
					"Skaband-Cost-Microcents": []string{"50000"}, // 0.0005 USD
				},
				Body: io.NopCloser(bytes.NewBufferString(`{
"candidates": [{
"content": {
"parts": [{
"text": "Test response"
}]
}
}]
}`)),
			},
		},
	}

	// Gem service wired to the mock client, so no network traffic occurs.
	service := &Service{
		Model:  "gemini-test",
		APIKey: "test-key",
		HTTPC:  mockClient,
		URL:    "https://test.googleapis.com",
	}

	// Minimal request: a single user message with one text part.
	ir := &llm.Request{
		Messages: []llm.Message{
			{
				Role: llm.MessageRoleUser,
				Content: []llm.Content{
					{
						Type: llm.ContentTypeText,
						Text: "Hello",
					},
				},
			},
		},
	}

	ctx := context.Background()
	res, err := service.Do(ctx, ir)
	if err != nil {
		t.Fatalf("Failed to make request: %v", err)
	}

	// The cost must come from the response header.
	expectedCost := 0.0005 // 50000 microcents / 100,000,000

	// Compare with a small tolerance instead of exact float equality, so the
	// test does not depend on how the conversion rounds. The negated range
	// check also fails on NaN.
	const tolerance = 1e-12
	if diff := res.Usage.CostUSD - expectedCost; !(diff >= -tolerance && diff <= tolerance) {
		t.Fatalf("Expected cost USD %.8f, got %.8f", expectedCost, res.Usage.CostUSD)
	}

	// Token counts must still be populated alongside the header-derived cost.
	if res.Usage.InputTokens == 0 {
		t.Fatalf("Expected input tokens to be estimated, got 0")
	}
	if res.Usage.OutputTokens == 0 {
		t.Fatalf("Expected output tokens to be estimated, got 0")
	}
}