blob: fa9f3b65560a2834a578e325c7b923da559e2ffa [file] [log] [blame]
David Crawshaw5a234062025-05-04 17:52:08 +00001package gem
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "math/rand"
10 "net/http"
11 "strings"
12 "time"
13
14 "sketch.dev/llm"
15 "sketch.dev/llm/gem/gemini"
16)
17
const (
	// DefaultModel is the Gemini model used when Service.Model is empty.
	DefaultModel = "gemini-2.5-pro-preview-03-25"
	// DefaultMaxTokens is the completion cap used when Service.MaxTokens is zero.
	DefaultMaxTokens = 8192
	// GeminiAPIKeyEnv names the environment variable conventionally holding
	// the API key. NOTE(review): not referenced in this file — presumably read
	// by callers when constructing a Service; confirm.
	GeminiAPIKeyEnv = "GEMINI_API_KEY"
)
23
// Service provides Gemini completions.
// Fields should not be altered concurrently with calling any method on Service.
type Service struct {
	HTTPC     *http.Client // defaults to http.DefaultClient if nil
	URL       string       // Gemini API URL, uses the gemini package default if empty
	APIKey    string       // must be non-empty
	Model     string       // defaults to DefaultModel if empty
	MaxTokens int          // defaults to DefaultMaxTokens if zero; NOTE(review): not applied anywhere in this file — confirm it is forwarded to the API request
}
33
// Compile-time check that *Service satisfies llm.Service.
var _ llm.Service = (*Service)(nil)

// fromLLMRole maps Sketch's llm message roles to Gemini API role strings.
// Gemini uses "model" where llm uses "assistant". Roles absent from this map
// (e.g. system) are rejected by buildGeminiRequest.
var fromLLMRole = map[llm.MessageRole]string{
	llm.MessageRoleAssistant: "model",
	llm.MessageRoleUser:      "user",
}
41
42// convertToolSchemas converts Sketch's llm.Tool schemas to Gemini's schema format
43func convertToolSchemas(tools []*llm.Tool) ([]gemini.FunctionDeclaration, error) {
44 if len(tools) == 0 {
45 return nil, nil
46 }
47
48 var decls []gemini.FunctionDeclaration
49 for _, tool := range tools {
50 // Parse the schema from raw JSON
51 var schemaJSON map[string]any
52 if err := json.Unmarshal(tool.InputSchema, &schemaJSON); err != nil {
53 return nil, fmt.Errorf("failed to unmarshal tool %s schema: %w", tool.Name, err)
54 }
55 decls = append(decls, gemini.FunctionDeclaration{
56 Name: tool.Name,
57 Description: tool.Description,
58 Parameters: convertJSONSchemaToGeminiSchema(schemaJSON),
59 })
60 }
61
62 return decls, nil
63}
64
// convertJSONSchemaToGeminiSchema converts a JSON schema (decoded into a
// generic map) into Gemini's typed schema format. It recurses into object
// properties and array items. All lookups are best-effort type assertions:
// keys that are absent or of an unexpected dynamic type are silently skipped.
func convertJSONSchemaToGeminiSchema(schemaJSON map[string]any) gemini.Schema {
	schema := gemini.Schema{}

	// Map the JSON Schema "type" keyword onto Gemini's DataType enum.
	// A missing "type" leaves schema.Type at its zero value; an unknown
	// string falls back to STRING.
	if typeVal, ok := schemaJSON["type"].(string); ok {
		switch typeVal {
		case "string":
			schema.Type = gemini.DataTypeSTRING
		case "number":
			schema.Type = gemini.DataTypeNUMBER
		case "integer":
			schema.Type = gemini.DataTypeINTEGER
		case "boolean":
			schema.Type = gemini.DataTypeBOOLEAN
		case "array":
			schema.Type = gemini.DataTypeARRAY
		case "object":
			schema.Type = gemini.DataTypeOBJECT
		default:
			schema.Type = gemini.DataTypeSTRING // Default to string for unknown types
		}
	}

	// Set description if available.
	if desc, ok := schemaJSON["description"].(string); ok {
		schema.Description = desc
	}

	// Handle enum values. Gemini's Enum is []string, so non-string enum
	// members are converted via JSON encoding (e.g. 1 -> "1", true -> "true").
	if enumValues, ok := schemaJSON["enum"].([]any); ok {
		schema.Enum = make([]string, len(enumValues))
		for i, v := range enumValues {
			if strVal, ok := v.(string); ok {
				schema.Enum[i] = strVal
			} else {
				// Convert non-string values to string.
				valBytes, _ := json.Marshal(v)
				schema.Enum[i] = string(valBytes)
			}
		}
	}

	// Handle object properties (only meaningful when the type is OBJECT);
	// each property schema is converted recursively.
	if properties, ok := schemaJSON["properties"].(map[string]any); ok && schema.Type == gemini.DataTypeOBJECT {
		schema.Properties = make(map[string]gemini.Schema)
		for propName, propSchema := range properties {
			if propSchemaMap, ok := propSchema.(map[string]any); ok {
				schema.Properties[propName] = convertJSONSchemaToGeminiSchema(propSchemaMap)
			}
		}
	}

	// Handle required properties.
	// NOTE(review): a non-string entry in "required" leaves an empty string
	// at its slot rather than being filtered out — confirm Gemini tolerates
	// empty required names.
	if required, ok := schemaJSON["required"].([]any); ok {
		schema.Required = make([]string, len(required))
		for i, r := range required {
			if strVal, ok := r.(string); ok {
				schema.Required[i] = strVal
			}
		}
	}

	// Handle array items (only meaningful when the type is ARRAY).
	if items, ok := schemaJSON["items"].(map[string]any); ok && schema.Type == gemini.DataTypeARRAY {
		itemSchema := convertJSONSchemaToGeminiSchema(items)
		schema.Items = &itemSchema
	}

	// Handle minimum/maximum items for arrays. JSON numbers decode as
	// float64; Gemini represents these bounds as decimal strings.
	if minItems, ok := schemaJSON["minItems"].(float64); ok {
		schema.MinItems = fmt.Sprintf("%d", int(minItems))
	}
	if maxItems, ok := schemaJSON["maxItems"].(float64); ok {
		schema.MaxItems = fmt.Sprintf("%d", int(maxItems))
	}

	return schema
}
144
145// buildGeminiRequest converts Sketch's llm.Request to Gemini's request format
146func (s *Service) buildGeminiRequest(req *llm.Request) (*gemini.Request, error) {
147 gemReq := &gemini.Request{}
148
149 // Add system instruction if provided
150 if len(req.System) > 0 {
151 // Combine all system messages into a single system instruction
152 systemText := ""
153 for i, sys := range req.System {
154 if i > 0 && systemText != "" && sys.Text != "" {
155 systemText += "\n"
156 }
157 systemText += sys.Text
158 }
159
160 if systemText != "" {
161 gemReq.SystemInstruction = &gemini.Content{
162 Parts: []gemini.Part{{Text: systemText}},
163 }
164 }
165 }
166
167 // Convert messages to Gemini content format
168 for _, msg := range req.Messages {
169 // Set the role based on the message role
170 role, ok := fromLLMRole[msg.Role]
171 if !ok {
172 return nil, fmt.Errorf("unsupported message role: %v", msg.Role)
173 }
174
175 content := gemini.Content{
176 Role: role,
177 }
178
179 // Store tool usage information to correlate tool uses with responses
180 toolNameToID := make(map[string]string)
181
182 // First pass: collect tool use IDs for correlation
183 for _, c := range msg.Content {
184 if c.Type == llm.ContentTypeToolUse && c.ID != "" {
185 toolNameToID[c.ToolName] = c.ID
186 }
187 }
188
189 // Map each content item to Gemini's format
190 for _, c := range msg.Content {
191 switch c.Type {
192 case llm.ContentTypeText, llm.ContentTypeThinking, llm.ContentTypeRedactedThinking:
193 // Simple text content
194 content.Parts = append(content.Parts, gemini.Part{
195 Text: c.Text,
196 })
197 case llm.ContentTypeToolUse:
198 // Tool use becomes a function call
199 var args map[string]any
200 if err := json.Unmarshal(c.ToolInput, &args); err != nil {
201 return nil, fmt.Errorf("failed to unmarshal tool input: %w", err)
202 }
203
204 // Make sure we have a valid ID for this tool use
205 if c.ID == "" {
206 c.ID = fmt.Sprintf("gemini_tool_%s_%d", c.ToolName, time.Now().UnixNano())
207 }
208
209 // Save the ID for this tool name for future correlation
210 toolNameToID[c.ToolName] = c.ID
211
212 slog.DebugContext(context.Background(), "gemini_preparing_tool_use",
213 "tool_name", c.ToolName,
214 "tool_id", c.ID,
215 "input", string(c.ToolInput))
216
217 content.Parts = append(content.Parts, gemini.Part{
218 FunctionCall: &gemini.FunctionCall{
219 Name: c.ToolName,
220 Args: args,
221 },
222 })
223 case llm.ContentTypeToolResult:
224 // Tool result becomes a function response
225 // Create a map for the response
226 response := map[string]any{
227 "result": c.ToolResult,
228 "error": c.ToolError,
229 }
230
231 // Determine the function name to use - this is critical
232 funcName := ""
233
234 // First try to find the function name from a stored toolUseID if we have one
235 if c.ToolUseID != "" {
236 // Try to derive the tool name from the previous tools we've seen
237 for name, id := range toolNameToID {
238 if id == c.ToolUseID {
239 funcName = name
240 break
241 }
242 }
243 }
244
245 // Fallback options if we couldn't find the tool name
246 if funcName == "" {
247 // Try the tool name directly
248 if c.ToolName != "" {
249 funcName = c.ToolName
250 } else {
251 // Last resort fallback
252 funcName = "default_tool"
253 }
254 }
255
256 slog.DebugContext(context.Background(), "gemini_preparing_tool_result",
257 "tool_use_id", c.ToolUseID,
258 "mapped_func_name", funcName,
259 "result_length", len(c.ToolResult))
260
261 content.Parts = append(content.Parts, gemini.Part{
262 FunctionResponse: &gemini.FunctionResponse{
263 Name: funcName,
264 Response: response,
265 },
266 })
267 }
268 }
269
270 gemReq.Contents = append(gemReq.Contents, content)
271 }
272
273 // Handle tools/functions
274 if len(req.Tools) > 0 {
275 // Convert tool schemas
276 decls, err := convertToolSchemas(req.Tools)
277 if err != nil {
278 return nil, fmt.Errorf("failed to convert tool schemas: %w", err)
279 }
280 if len(decls) > 0 {
281 gemReq.Tools = []gemini.Tool{{FunctionDeclarations: decls}}
282 }
283 }
284
285 return gemReq, nil
286}
287
288// convertGeminiResponsesToContent converts a Gemini response to llm.Content
289func convertGeminiResponseToContent(res *gemini.Response) []llm.Content {
290 if res == nil || len(res.Candidates) == 0 || len(res.Candidates[0].Content.Parts) == 0 {
291 return []llm.Content{{
292 Type: llm.ContentTypeText,
293 Text: "",
294 }}
295 }
296
297 var contents []llm.Content
298
299 // Process each part in the first candidate's content
300 for i, part := range res.Candidates[0].Content.Parts {
301 // Log the part type for debugging
302 slog.DebugContext(context.Background(), "processing_gemini_part",
303 "index", i,
304 "has_text", part.Text != "",
305 "has_function_call", part.FunctionCall != nil,
306 "has_function_response", part.FunctionResponse != nil)
307
308 if part.Text != "" {
309 // Simple text response
310 contents = append(contents, llm.Content{
311 Type: llm.ContentTypeText,
312 Text: part.Text,
313 })
314 } else if part.FunctionCall != nil {
315 // Function call (tool use)
316 args, err := json.Marshal(part.FunctionCall.Args)
317 if err != nil {
318 // If we can't marshal, use empty args
319 slog.DebugContext(context.Background(), "gemini_failed_to_markshal_args",
320 "tool_name", part.FunctionCall.Name,
321 "args", string(args),
322 "err", err.Error(),
323 )
324 args = []byte("{}")
325 }
326
327 // Generate a unique ID for this tool use that includes the function name
328 // to make it easier to correlate with responses
329 toolID := fmt.Sprintf("gemini_tool_%s_%d", part.FunctionCall.Name, time.Now().UnixNano())
330
331 contents = append(contents, llm.Content{
332 ID: toolID,
333 Type: llm.ContentTypeToolUse,
334 ToolName: part.FunctionCall.Name,
335 ToolInput: json.RawMessage(args),
336 })
337
338 slog.DebugContext(context.Background(), "gemini_tool_call",
339 "tool_id", toolID,
340 "tool_name", part.FunctionCall.Name,
341 "args", string(args))
342 } else if part.FunctionResponse != nil {
343 // We shouldn't normally get function responses from the model, but just in case
344 respData, _ := json.Marshal(part.FunctionResponse.Response)
345 slog.DebugContext(context.Background(), "unexpected_function_response",
346 "name", part.FunctionResponse.Name,
347 "response", string(respData))
348 }
349 }
350
351 // If no content was added, add an empty text content
352 if len(contents) == 0 {
353 slog.DebugContext(context.Background(), "empty_gemini_response", "adding_empty_text", true)
354 contents = append(contents, llm.Content{
355 Type: llm.ContentTypeText,
356 Text: "",
357 })
358 }
359
360 return contents
361}
362
363// Gemini doesn't provide usage info directly, so we need to estimate it
364// ensureToolIDs makes sure all tool uses have proper IDs
365func ensureToolIDs(contents []llm.Content) {
366 for i, content := range contents {
367 if content.Type == llm.ContentTypeToolUse && content.ID == "" {
368 // Generate a stable ID using the tool name and timestamp
369 contents[i].ID = fmt.Sprintf("gemini_tool_%s_%d", content.ToolName, time.Now().UnixNano())
370 slog.DebugContext(context.Background(), "assigned_missing_tool_id",
371 "tool_name", content.ToolName,
372 "new_id", contents[i].ID)
373 }
374 }
375}
376
377func calculateUsage(req *gemini.Request, res *gemini.Response) llm.Usage {
378 // Very rough estimation of token counts
379 var inputTokens uint64
380 var outputTokens uint64
381
382 // Count system tokens
383 if req.SystemInstruction != nil {
384 for _, part := range req.SystemInstruction.Parts {
385 if part.Text != "" {
386 // Very rough estimation: 1 token per 4 characters
387 inputTokens += uint64(len(part.Text)) / 4
388 }
389 }
390 }
391
392 // Count input tokens
393 for _, content := range req.Contents {
394 for _, part := range content.Parts {
395 if part.Text != "" {
396 inputTokens += uint64(len(part.Text)) / 4
397 } else if part.FunctionCall != nil {
398 // Estimate function call tokens
399 argBytes, _ := json.Marshal(part.FunctionCall.Args)
400 inputTokens += uint64(len(part.FunctionCall.Name)+len(argBytes)) / 4
401 } else if part.FunctionResponse != nil {
402 // Estimate function response tokens
403 resBytes, _ := json.Marshal(part.FunctionResponse.Response)
404 inputTokens += uint64(len(part.FunctionResponse.Name)+len(resBytes)) / 4
405 }
406 }
407 }
408
409 // Count output tokens
410 if res != nil && len(res.Candidates) > 0 {
411 for _, part := range res.Candidates[0].Content.Parts {
412 if part.Text != "" {
413 outputTokens += uint64(len(part.Text)) / 4
414 } else if part.FunctionCall != nil {
415 // Estimate function call tokens
416 argBytes, _ := json.Marshal(part.FunctionCall.Args)
417 outputTokens += uint64(len(part.FunctionCall.Name)+len(argBytes)) / 4
418 }
419 }
420 }
421
422 // For Gemini 2.5 Pro Preview pricing: $1.25 per 1M input tokens, $10 per 1M output tokens
423 // Convert to dollars
424 costUSD := float64(inputTokens)*1.25/1_000_000.0 + float64(outputTokens)*10/1_000_000.0
425
426 return llm.Usage{
427 InputTokens: inputTokens,
428 OutputTokens: outputTokens,
429 CostUSD: costUSD,
430 }
431}
432
433// Do sends a request to Gemini.
434func (s *Service) Do(ctx context.Context, ir *llm.Request) (*llm.Response, error) {
435 // Log the incoming request for debugging
436 slog.DebugContext(ctx, "gemini_request",
437 "message_count", len(ir.Messages),
438 "tool_count", len(ir.Tools),
439 "system_count", len(ir.System))
440
441 // Log tool-related information if any tools are present
442 if len(ir.Tools) > 0 {
443 var toolNames []string
444 for _, tool := range ir.Tools {
445 toolNames = append(toolNames, tool.Name)
446 }
447 slog.DebugContext(ctx, "gemini_tools", "tools", toolNames)
448 }
449
450 // Log details about the messages being sent
451 for i, msg := range ir.Messages {
452 contentTypes := make([]string, len(msg.Content))
453 for j, c := range msg.Content {
454 contentTypes[j] = c.Type.String()
455
456 // Log tool-related content with more details
457 if c.Type == llm.ContentTypeToolUse {
458 slog.DebugContext(ctx, "gemini_tool_use",
459 "message_idx", i,
460 "content_idx", j,
461 "tool_name", c.ToolName,
462 "tool_input", string(c.ToolInput))
463 } else if c.Type == llm.ContentTypeToolResult {
464 slog.DebugContext(ctx, "gemini_tool_result",
465 "message_idx", i,
466 "content_idx", j,
467 "tool_use_id", c.ToolUseID,
468 "tool_error", c.ToolError,
469 "result_length", len(c.ToolResult))
470 }
471 }
472 slog.DebugContext(ctx, "gemini_message",
473 "idx", i,
474 "role", msg.Role.String(),
475 "content_types", contentTypes)
476 }
477 // Build the Gemini request
478 gemReq, err := s.buildGeminiRequest(ir)
479 if err != nil {
480 return nil, fmt.Errorf("failed to build Gemini request: %w", err)
481 }
482
483 // Log the structured Gemini request for debugging
484 if reqJSON, err := json.MarshalIndent(gemReq, "", " "); err == nil {
485 slog.DebugContext(ctx, "gemini_request_json", "request", string(reqJSON))
486 }
487
488 // Create a Gemini model instance
489 model := gemini.Model{
David Crawshaw3659d872025-05-05 17:52:23 -0700490 Model: "models/" + cmp.Or(s.Model, DefaultModel),
491 Endpoint: s.URL,
492 APIKey: s.APIKey,
493 HTTPC: cmp.Or(s.HTTPC, http.DefaultClient),
David Crawshaw5a234062025-05-04 17:52:08 +0000494 }
495
496 // Send the request to Gemini with retry logic
497 startTime := time.Now()
498 endTime := startTime // Initialize endTime
499 var gemRes *gemini.Response
500
501 // Retry mechanism for handling server errors and rate limiting
502 backoff := []time.Duration{1 * time.Second, 3 * time.Second, 5 * time.Second, 10 * time.Second}
503 for attempts := 0; attempts <= len(backoff); attempts++ {
504 gemApiErr := error(nil)
505 gemRes, gemApiErr = model.GenerateContent(ctx, gemReq)
506 endTime = time.Now()
507
508 if gemApiErr == nil {
509 // Successful response
510 // Log the structured Gemini response
511 if resJSON, err := json.MarshalIndent(gemRes, "", " "); err == nil {
512 slog.DebugContext(ctx, "gemini_response_json", "response", string(resJSON))
513 }
514 break
515 }
516
517 if attempts == len(backoff) {
518 // We've exhausted all retry attempts
519 return nil, fmt.Errorf("gemini: API error after %d attempts: %w", attempts, gemApiErr)
520 }
521
522 // Check if the error is retryable (e.g., server error or rate limiting)
523 if strings.Contains(gemApiErr.Error(), "429") || strings.Contains(gemApiErr.Error(), "5") {
524 // Rate limited or server error - wait and retry
525 random := time.Duration(rand.Int63n(int64(time.Second)))
526 sleep := backoff[attempts] + random
527 slog.WarnContext(ctx, "gemini_request_retry", "error", gemApiErr.Error(), "attempt", attempts+1, "sleep", sleep)
528 time.Sleep(sleep)
529 continue
530 }
531
532 // Non-retryable error
533 return nil, fmt.Errorf("gemini: API error: %w", gemApiErr)
534 }
535
536 content := convertGeminiResponseToContent(gemRes)
537
538 ensureToolIDs(content)
539
540 usage := calculateUsage(gemReq, gemRes)
541
542 stopReason := llm.StopReasonEndTurn
543 for _, part := range content {
544 if part.Type == llm.ContentTypeToolUse {
545 stopReason = llm.StopReasonToolUse
546 slog.DebugContext(ctx, "gemini_tool_use_detected",
547 "setting_stop_reason", "llm.StopReasonToolUse",
548 "tool_name", part.ToolName)
549 break
550 }
551 }
552
553 return &llm.Response{
554 Role: llm.MessageRoleAssistant,
555 Model: s.Model,
556 Content: content,
557 StopReason: stopReason,
558 Usage: usage,
559 StartTime: &startTime,
560 EndTime: &endTime,
561 }, nil
562}