blob: 002b4d154fefa28b831cb7907ebb79331b83b86c [file] [log] [blame]
David Crawshaw5a234062025-05-04 17:52:08 +00001package gem
2
3import (
Josh Bleecher Snyder59bb27d2025-06-05 07:32:10 -07004 "bytes"
5 "context"
David Crawshaw5a234062025-05-04 17:52:08 +00006 "encoding/json"
Josh Bleecher Snyder59bb27d2025-06-05 07:32:10 -07007 "io"
8 "net/http"
David Crawshaw5a234062025-05-04 17:52:08 +00009 "testing"
10
11 "sketch.dev/llm"
12 "sketch.dev/llm/gem/gemini"
13)
14
15func TestBuildGeminiRequest(t *testing.T) {
16 // Create a service
17 service := &Service{
18 Model: DefaultModel,
19 APIKey: "test-api-key",
20 }
21
22 // Create a simple request
23 req := &llm.Request{
24 Messages: []llm.Message{
25 {
26 Role: llm.MessageRoleUser,
27 Content: []llm.Content{
28 {
29 Type: llm.ContentTypeText,
30 Text: "Hello, world!",
31 },
32 },
33 },
34 },
35 System: []llm.SystemContent{
36 {
37 Text: "You are a helpful assistant.",
38 },
39 },
40 }
41
42 // Build the Gemini request
43 gemReq, err := service.buildGeminiRequest(req)
44 if err != nil {
45 t.Fatalf("Failed to build Gemini request: %v", err)
46 }
47
48 // Verify the system instruction
49 if gemReq.SystemInstruction == nil {
50 t.Fatalf("Expected system instruction, got nil")
51 }
52 if len(gemReq.SystemInstruction.Parts) != 1 {
53 t.Fatalf("Expected 1 system part, got %d", len(gemReq.SystemInstruction.Parts))
54 }
55 if gemReq.SystemInstruction.Parts[0].Text != "You are a helpful assistant." {
56 t.Fatalf("Expected system text 'You are a helpful assistant.', got '%s'", gemReq.SystemInstruction.Parts[0].Text)
57 }
58
59 // Verify the contents
60 if len(gemReq.Contents) != 1 {
61 t.Fatalf("Expected 1 content, got %d", len(gemReq.Contents))
62 }
63 if len(gemReq.Contents[0].Parts) != 1 {
64 t.Fatalf("Expected 1 part, got %d", len(gemReq.Contents[0].Parts))
65 }
66 if gemReq.Contents[0].Parts[0].Text != "Hello, world!" {
67 t.Fatalf("Expected text 'Hello, world!', got '%s'", gemReq.Contents[0].Parts[0].Text)
68 }
69 // Verify the role is set correctly
70 if gemReq.Contents[0].Role != "user" {
71 t.Fatalf("Expected role 'user', got '%s'", gemReq.Contents[0].Role)
72 }
73}
74
75func TestConvertToolSchemas(t *testing.T) {
76 // Create a simple tool with a JSON schema
77 schema := `{
78 "type": "object",
79 "properties": {
80 "name": {
81 "type": "string",
82 "description": "The name of the person"
83 },
84 "age": {
85 "type": "integer",
86 "description": "The age of the person"
87 }
88 },
89 "required": ["name"]
90 }`
91
92 tools := []*llm.Tool{
93 {
94 Name: "get_person",
95 Description: "Get information about a person",
96 InputSchema: json.RawMessage(schema),
97 },
98 }
99
100 // Convert the tools
101 decls, err := convertToolSchemas(tools)
102 if err != nil {
103 t.Fatalf("Failed to convert tool schemas: %v", err)
104 }
105
106 // Verify the result
107 if len(decls) != 1 {
108 t.Fatalf("Expected 1 declaration, got %d", len(decls))
109 }
110 if decls[0].Name != "get_person" {
111 t.Fatalf("Expected name 'get_person', got '%s'", decls[0].Name)
112 }
113 if decls[0].Description != "Get information about a person" {
114 t.Fatalf("Expected description 'Get information about a person', got '%s'", decls[0].Description)
115 }
116
117 // Verify the schema properties
118 if decls[0].Parameters.Type != 6 { // DataTypeOBJECT
119 t.Fatalf("Expected type OBJECT (6), got %d", decls[0].Parameters.Type)
120 }
121 if len(decls[0].Parameters.Properties) != 2 {
122 t.Fatalf("Expected 2 properties, got %d", len(decls[0].Parameters.Properties))
123 }
124 if decls[0].Parameters.Properties["name"].Type != 1 { // DataTypeSTRING
125 t.Fatalf("Expected name type STRING (1), got %d", decls[0].Parameters.Properties["name"].Type)
126 }
127 if decls[0].Parameters.Properties["age"].Type != 3 { // DataTypeINTEGER
128 t.Fatalf("Expected age type INTEGER (3), got %d", decls[0].Parameters.Properties["age"].Type)
129 }
130 if len(decls[0].Parameters.Required) != 1 || decls[0].Parameters.Required[0] != "name" {
131 t.Fatalf("Expected required field 'name', got %v", decls[0].Parameters.Required)
132 }
133}
134
135func TestService_Do_MockResponse(t *testing.T) {
136 // This is a mock test that doesn't make actual API calls
137 // Create a mock HTTP client that returns a predefined response
138
139 // Create a Service with a mock client
140 service := &Service{
141 Model: DefaultModel,
142 APIKey: "test-api-key",
143 // We would use a mock HTTP client here in a real test
144 }
145
146 // Create a sample request
147 ir := &llm.Request{
148 Messages: []llm.Message{
149 {
150 Role: llm.MessageRoleUser,
151 Content: []llm.Content{
152 {
153 Type: llm.ContentTypeText,
154 Text: "Hello",
155 },
156 },
157 },
158 },
159 }
160
161 // In a real test, we would execute service.Do with a mock client
162 // and verify the response structure
163
164 // For now, we'll just test that buildGeminiRequest works correctly
165 _, err := service.buildGeminiRequest(ir)
166 if err != nil {
167 t.Fatalf("Failed to build request: %v", err)
168 }
169}
170
171func TestConvertResponseWithToolCall(t *testing.T) {
172 // Create a mock Gemini response with a function call
173 gemRes := &gemini.Response{
174 Candidates: []gemini.Candidate{
175 {
176 Content: gemini.Content{
177 Parts: []gemini.Part{
178 {
179 FunctionCall: &gemini.FunctionCall{
180 Name: "bash",
181 Args: map[string]any{
182 "command": "cat README.md",
183 },
184 },
185 },
186 },
187 },
188 },
189 },
190 }
191
192 // Convert the response
193 content := convertGeminiResponseToContent(gemRes)
194
195 // Verify that content has a tool use
196 if len(content) != 1 {
197 t.Fatalf("Expected 1 content item, got %d", len(content))
198 }
199
200 if content[0].Type != llm.ContentTypeToolUse {
201 t.Fatalf("Expected content type ToolUse, got %s", content[0].Type)
202 }
203
204 if content[0].ToolName != "bash" {
205 t.Fatalf("Expected tool name 'bash', got '%s'", content[0].ToolName)
206 }
207
208 // Verify the tool input
209 var args map[string]any
210 if err := json.Unmarshal(content[0].ToolInput, &args); err != nil {
211 t.Fatalf("Failed to unmarshal tool input: %v", err)
212 }
213
214 cmd, ok := args["command"]
215 if !ok {
216 t.Fatalf("Expected 'command' argument, not found")
217 }
218
219 if cmd != "cat README.md" {
220 t.Fatalf("Expected command 'cat README.md', got '%s'", cmd)
221 }
222}
Josh Bleecher Snyder59bb27d2025-06-05 07:32:10 -0700223
224func TestGeminiHeaderCapture(t *testing.T) {
225 // Create a mock HTTP client that returns a response with headers
226 mockClient := &http.Client{
227 Transport: &mockRoundTripper{
228 response: &http.Response{
229 StatusCode: http.StatusOK,
230 Header: http.Header{
231 "Content-Type": []string{"application/json"},
232 "Skaband-Cost-Microcents": []string{"123456"},
233 },
234 Body: io.NopCloser(bytes.NewBufferString(`{
235 "candidates": [{
236 "content": {
237 "parts": [{
238 "text": "Hello!"
239 }]
240 }
241 }]
242 }`)),
243 },
244 },
245 }
246
247 // Create a Gemini model with the mock client
248 model := gemini.Model{
249 Model: "models/gemini-test",
250 APIKey: "test-key",
251 HTTPC: mockClient,
252 Endpoint: "https://test.googleapis.com",
253 }
254
255 // Make a request
256 req := &gemini.Request{
257 Contents: []gemini.Content{
258 {
259 Parts: []gemini.Part{{Text: "Hello"}},
260 Role: "user",
261 },
262 },
263 }
264
265 ctx := context.Background()
266 res, err := model.GenerateContent(ctx, req)
267 if err != nil {
268 t.Fatalf("Failed to generate content: %v", err)
269 }
270
271 // Verify that headers were captured
272 headers := res.Header()
273 if headers == nil {
274 t.Fatalf("Expected headers to be captured, got nil")
275 }
276
277 // Check for the cost header
278 costHeader := headers.Get("Skaband-Cost-Microcents")
279 if costHeader != "123456" {
280 t.Fatalf("Expected cost header '123456', got '%s'", costHeader)
281 }
282
283 // Verify that llm.CostUSDFromResponse works with these headers
284 costUSD := llm.CostUSDFromResponse(headers)
285 expectedCost := 0.00123456 // 123456 microcents / 100,000,000
286 if costUSD != expectedCost {
287 t.Fatalf("Expected cost USD %.8f, got %.8f", expectedCost, costUSD)
288 }
289}
290
// mockRoundTripper is an http.RoundTripper that answers every request with
// the same canned response and a nil error, without touching the network.
//
// NOTE(review): the same *http.Response, and therefore the same already-read
// Body, is handed back on every call. It is only suitable for tests that issue
// a single request per instance.
type mockRoundTripper struct {
	response *http.Response // returned verbatim from RoundTrip
}

// Compile-time check that mockRoundTripper satisfies http.RoundTripper.
var _ http.RoundTripper = (*mockRoundTripper)(nil)

// RoundTrip ignores the outgoing request and returns the canned response.
func (rt *mockRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
	return rt.response, nil
}
299
300func TestHeaderCostIntegration(t *testing.T) {
301 // Create a mock HTTP client that returns a response with cost headers
302 mockClient := &http.Client{
303 Transport: &mockRoundTripper{
304 response: &http.Response{
305 StatusCode: http.StatusOK,
306 Header: http.Header{
307 "Content-Type": []string{"application/json"},
308 "Skaband-Cost-Microcents": []string{"50000"}, // 0.5 USD
309 },
310 Body: io.NopCloser(bytes.NewBufferString(`{
311 "candidates": [{
312 "content": {
313 "parts": [{
314 "text": "Test response"
315 }]
316 }
317 }]
318 }`)),
319 },
320 },
321 }
322
323 // Create a Gem service with the mock client
324 service := &Service{
325 Model: "gemini-test",
326 APIKey: "test-key",
327 HTTPC: mockClient,
328 URL: "https://test.googleapis.com",
329 }
330
331 // Create a request
332 ir := &llm.Request{
333 Messages: []llm.Message{
334 {
335 Role: llm.MessageRoleUser,
336 Content: []llm.Content{
337 {
338 Type: llm.ContentTypeText,
339 Text: "Hello",
340 },
341 },
342 },
343 },
344 }
345
346 // Make the request
347 ctx := context.Background()
348 res, err := service.Do(ctx, ir)
349 if err != nil {
350 t.Fatalf("Failed to make request: %v", err)
351 }
352
353 // Verify that the cost was captured from headers
354 expectedCost := 0.0005 // 50000 microcents / 100,000,000
355 if res.Usage.CostUSD != expectedCost {
356 t.Fatalf("Expected cost USD %.8f, got %.8f", expectedCost, res.Usage.CostUSD)
357 }
358
359 // Verify token counts are still estimated
360 if res.Usage.InputTokens == 0 {
361 t.Fatalf("Expected input tokens to be estimated, got 0")
362 }
363 if res.Usage.OutputTokens == 0 {
364 t.Fatalf("Expected output tokens to be estimated, got 0")
365 }
366}