blob: 1621f2d22318845249630c7a1ff0e2eec107f7fe [file] [log] [blame]
David Crawshaw5a234062025-05-04 17:52:08 +00001package gem
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "math/rand"
10 "net/http"
11 "strings"
12 "time"
13
14 "sketch.dev/llm"
15 "sketch.dev/llm/gem/gemini"
16)
17
const (
	// DefaultModel is the Gemini model used when Service.Model is empty.
	DefaultModel = "gemini-2.5-pro-preview-03-25"
	// GeminiAPIKeyEnv is the name of the environment variable conventionally
	// used to supply the Gemini API key.
	GeminiAPIKeyEnv = "GEMINI_API_KEY"
)
22
// Service provides Gemini completions.
// Fields should not be altered concurrently with calling any method on Service.
type Service struct {
	HTTPC  *http.Client // defaults to http.DefaultClient if nil
	URL    string       // Gemini API URL, uses the gemini package default if empty
	APIKey string       // must be non-empty
	Model  string       // defaults to DefaultModel if empty
}
31
// Compile-time check that Service implements llm.Service.
var _ llm.Service = (*Service)(nil)

// fromLLMRole maps Sketch's llm message roles to the role strings the
// Gemini API expects ("model" for the assistant, "user" for the user).
var fromLLMRole = map[llm.MessageRole]string{
	llm.MessageRoleAssistant: "model",
	llm.MessageRoleUser:      "user",
}
39
40// convertToolSchemas converts Sketch's llm.Tool schemas to Gemini's schema format
41func convertToolSchemas(tools []*llm.Tool) ([]gemini.FunctionDeclaration, error) {
42 if len(tools) == 0 {
43 return nil, nil
44 }
45
46 var decls []gemini.FunctionDeclaration
47 for _, tool := range tools {
48 // Parse the schema from raw JSON
49 var schemaJSON map[string]any
50 if err := json.Unmarshal(tool.InputSchema, &schemaJSON); err != nil {
51 return nil, fmt.Errorf("failed to unmarshal tool %s schema: %w", tool.Name, err)
52 }
53 decls = append(decls, gemini.FunctionDeclaration{
54 Name: tool.Name,
55 Description: tool.Description,
56 Parameters: convertJSONSchemaToGeminiSchema(schemaJSON),
57 })
58 }
59
60 return decls, nil
61}
62
63// convertJSONSchemaToGeminiSchema converts a JSON schema to Gemini's schema format
64func convertJSONSchemaToGeminiSchema(schemaJSON map[string]any) gemini.Schema {
65 schema := gemini.Schema{}
66
67 // Set the type based on the JSON schema type
68 if typeVal, ok := schemaJSON["type"].(string); ok {
69 switch typeVal {
70 case "string":
71 schema.Type = gemini.DataTypeSTRING
72 case "number":
73 schema.Type = gemini.DataTypeNUMBER
74 case "integer":
75 schema.Type = gemini.DataTypeINTEGER
76 case "boolean":
77 schema.Type = gemini.DataTypeBOOLEAN
78 case "array":
79 schema.Type = gemini.DataTypeARRAY
80 case "object":
81 schema.Type = gemini.DataTypeOBJECT
82 default:
83 schema.Type = gemini.DataTypeSTRING // Default to string for unknown types
84 }
85 }
86
87 // Set description if available
88 if desc, ok := schemaJSON["description"].(string); ok {
89 schema.Description = desc
90 }
91
92 // Handle enum values
93 if enumValues, ok := schemaJSON["enum"].([]any); ok {
94 schema.Enum = make([]string, len(enumValues))
95 for i, v := range enumValues {
96 if strVal, ok := v.(string); ok {
97 schema.Enum[i] = strVal
98 } else {
99 // Convert non-string values to string
100 valBytes, _ := json.Marshal(v)
101 schema.Enum[i] = string(valBytes)
102 }
103 }
104 }
105
106 // Handle object properties
107 if properties, ok := schemaJSON["properties"].(map[string]any); ok && schema.Type == gemini.DataTypeOBJECT {
108 schema.Properties = make(map[string]gemini.Schema)
109 for propName, propSchema := range properties {
110 if propSchemaMap, ok := propSchema.(map[string]any); ok {
111 schema.Properties[propName] = convertJSONSchemaToGeminiSchema(propSchemaMap)
112 }
113 }
114 }
115
116 // Handle required properties
117 if required, ok := schemaJSON["required"].([]any); ok {
118 schema.Required = make([]string, len(required))
119 for i, r := range required {
120 if strVal, ok := r.(string); ok {
121 schema.Required[i] = strVal
122 }
123 }
124 }
125
126 // Handle array items
127 if items, ok := schemaJSON["items"].(map[string]any); ok && schema.Type == gemini.DataTypeARRAY {
128 itemSchema := convertJSONSchemaToGeminiSchema(items)
129 schema.Items = &itemSchema
130 }
131
132 // Handle minimum/maximum items for arrays
133 if minItems, ok := schemaJSON["minItems"].(float64); ok {
134 schema.MinItems = fmt.Sprintf("%d", int(minItems))
135 }
136 if maxItems, ok := schemaJSON["maxItems"].(float64); ok {
137 schema.MaxItems = fmt.Sprintf("%d", int(maxItems))
138 }
139
140 return schema
141}
142
143// buildGeminiRequest converts Sketch's llm.Request to Gemini's request format
144func (s *Service) buildGeminiRequest(req *llm.Request) (*gemini.Request, error) {
145 gemReq := &gemini.Request{}
146
147 // Add system instruction if provided
148 if len(req.System) > 0 {
149 // Combine all system messages into a single system instruction
150 systemText := ""
151 for i, sys := range req.System {
152 if i > 0 && systemText != "" && sys.Text != "" {
153 systemText += "\n"
154 }
155 systemText += sys.Text
156 }
157
158 if systemText != "" {
159 gemReq.SystemInstruction = &gemini.Content{
160 Parts: []gemini.Part{{Text: systemText}},
161 }
162 }
163 }
164
165 // Convert messages to Gemini content format
166 for _, msg := range req.Messages {
167 // Set the role based on the message role
168 role, ok := fromLLMRole[msg.Role]
169 if !ok {
170 return nil, fmt.Errorf("unsupported message role: %v", msg.Role)
171 }
172
173 content := gemini.Content{
174 Role: role,
175 }
176
177 // Store tool usage information to correlate tool uses with responses
178 toolNameToID := make(map[string]string)
179
180 // First pass: collect tool use IDs for correlation
181 for _, c := range msg.Content {
182 if c.Type == llm.ContentTypeToolUse && c.ID != "" {
183 toolNameToID[c.ToolName] = c.ID
184 }
185 }
186
187 // Map each content item to Gemini's format
188 for _, c := range msg.Content {
189 switch c.Type {
190 case llm.ContentTypeText, llm.ContentTypeThinking, llm.ContentTypeRedactedThinking:
191 // Simple text content
192 content.Parts = append(content.Parts, gemini.Part{
193 Text: c.Text,
194 })
195 case llm.ContentTypeToolUse:
196 // Tool use becomes a function call
197 var args map[string]any
198 if err := json.Unmarshal(c.ToolInput, &args); err != nil {
199 return nil, fmt.Errorf("failed to unmarshal tool input: %w", err)
200 }
201
202 // Make sure we have a valid ID for this tool use
203 if c.ID == "" {
204 c.ID = fmt.Sprintf("gemini_tool_%s_%d", c.ToolName, time.Now().UnixNano())
205 }
206
207 // Save the ID for this tool name for future correlation
208 toolNameToID[c.ToolName] = c.ID
209
210 slog.DebugContext(context.Background(), "gemini_preparing_tool_use",
211 "tool_name", c.ToolName,
212 "tool_id", c.ID,
213 "input", string(c.ToolInput))
214
215 content.Parts = append(content.Parts, gemini.Part{
216 FunctionCall: &gemini.FunctionCall{
217 Name: c.ToolName,
218 Args: args,
219 },
220 })
221 case llm.ContentTypeToolResult:
222 // Tool result becomes a function response
223 // Create a map for the response
224 response := map[string]any{
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700225 "error": c.ToolError,
David Crawshaw5a234062025-05-04 17:52:08 +0000226 }
227
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700228 // Handle tool results: Gemini only supports string results
229 // Combine all text content into a single string
230 var resultText string
231 if len(c.ToolResult) > 0 {
232 // Collect all text from content objects
233 texts := make([]string, 0, len(c.ToolResult))
234 for _, result := range c.ToolResult {
235 if result.Text != "" {
236 texts = append(texts, result.Text)
237 }
238 }
239 resultText = strings.Join(texts, "\n")
240 }
241 response["result"] = resultText
242
David Crawshaw5a234062025-05-04 17:52:08 +0000243 // Determine the function name to use - this is critical
244 funcName := ""
245
246 // First try to find the function name from a stored toolUseID if we have one
247 if c.ToolUseID != "" {
248 // Try to derive the tool name from the previous tools we've seen
249 for name, id := range toolNameToID {
250 if id == c.ToolUseID {
251 funcName = name
252 break
253 }
254 }
255 }
256
257 // Fallback options if we couldn't find the tool name
258 if funcName == "" {
259 // Try the tool name directly
260 if c.ToolName != "" {
261 funcName = c.ToolName
262 } else {
263 // Last resort fallback
264 funcName = "default_tool"
265 }
266 }
267
268 slog.DebugContext(context.Background(), "gemini_preparing_tool_result",
269 "tool_use_id", c.ToolUseID,
270 "mapped_func_name", funcName,
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700271 "result_count", len(c.ToolResult))
David Crawshaw5a234062025-05-04 17:52:08 +0000272
273 content.Parts = append(content.Parts, gemini.Part{
274 FunctionResponse: &gemini.FunctionResponse{
275 Name: funcName,
276 Response: response,
277 },
278 })
279 }
280 }
281
282 gemReq.Contents = append(gemReq.Contents, content)
283 }
284
285 // Handle tools/functions
286 if len(req.Tools) > 0 {
287 // Convert tool schemas
288 decls, err := convertToolSchemas(req.Tools)
289 if err != nil {
290 return nil, fmt.Errorf("failed to convert tool schemas: %w", err)
291 }
292 if len(decls) > 0 {
293 gemReq.Tools = []gemini.Tool{{FunctionDeclarations: decls}}
294 }
295 }
296
297 return gemReq, nil
298}
299
300// convertGeminiResponsesToContent converts a Gemini response to llm.Content
301func convertGeminiResponseToContent(res *gemini.Response) []llm.Content {
302 if res == nil || len(res.Candidates) == 0 || len(res.Candidates[0].Content.Parts) == 0 {
303 return []llm.Content{{
304 Type: llm.ContentTypeText,
305 Text: "",
306 }}
307 }
308
309 var contents []llm.Content
310
311 // Process each part in the first candidate's content
312 for i, part := range res.Candidates[0].Content.Parts {
313 // Log the part type for debugging
314 slog.DebugContext(context.Background(), "processing_gemini_part",
315 "index", i,
316 "has_text", part.Text != "",
317 "has_function_call", part.FunctionCall != nil,
318 "has_function_response", part.FunctionResponse != nil)
319
320 if part.Text != "" {
321 // Simple text response
322 contents = append(contents, llm.Content{
323 Type: llm.ContentTypeText,
324 Text: part.Text,
325 })
326 } else if part.FunctionCall != nil {
327 // Function call (tool use)
328 args, err := json.Marshal(part.FunctionCall.Args)
329 if err != nil {
330 // If we can't marshal, use empty args
331 slog.DebugContext(context.Background(), "gemini_failed_to_markshal_args",
332 "tool_name", part.FunctionCall.Name,
333 "args", string(args),
334 "err", err.Error(),
335 )
336 args = []byte("{}")
337 }
338
339 // Generate a unique ID for this tool use that includes the function name
340 // to make it easier to correlate with responses
341 toolID := fmt.Sprintf("gemini_tool_%s_%d", part.FunctionCall.Name, time.Now().UnixNano())
342
343 contents = append(contents, llm.Content{
344 ID: toolID,
345 Type: llm.ContentTypeToolUse,
346 ToolName: part.FunctionCall.Name,
347 ToolInput: json.RawMessage(args),
348 })
349
350 slog.DebugContext(context.Background(), "gemini_tool_call",
351 "tool_id", toolID,
352 "tool_name", part.FunctionCall.Name,
353 "args", string(args))
354 } else if part.FunctionResponse != nil {
355 // We shouldn't normally get function responses from the model, but just in case
356 respData, _ := json.Marshal(part.FunctionResponse.Response)
357 slog.DebugContext(context.Background(), "unexpected_function_response",
358 "name", part.FunctionResponse.Name,
359 "response", string(respData))
360 }
361 }
362
363 // If no content was added, add an empty text content
364 if len(contents) == 0 {
365 slog.DebugContext(context.Background(), "empty_gemini_response", "adding_empty_text", true)
366 contents = append(contents, llm.Content{
367 Type: llm.ContentTypeText,
368 Text: "",
369 })
370 }
371
372 return contents
373}
374
// ensureToolIDs makes sure all tool uses have proper IDs.
// Gemini does not return IDs for function calls, so any tool-use content
// still missing an ID gets one synthesized from the tool name and the
// current timestamp. The slice elements are mutated in place.
func ensureToolIDs(contents []llm.Content) {
	for i, content := range contents {
		if content.Type == llm.ContentTypeToolUse && content.ID == "" {
			// Generate a stable ID using the tool name and timestamp
			contents[i].ID = fmt.Sprintf("gemini_tool_%s_%d", content.ToolName, time.Now().UnixNano())
			slog.DebugContext(context.Background(), "assigned_missing_tool_id",
				"tool_name", content.ToolName,
				"new_id", contents[i].ID)
		}
	}
}
388
389func calculateUsage(req *gemini.Request, res *gemini.Response) llm.Usage {
390 // Very rough estimation of token counts
391 var inputTokens uint64
392 var outputTokens uint64
393
394 // Count system tokens
395 if req.SystemInstruction != nil {
396 for _, part := range req.SystemInstruction.Parts {
397 if part.Text != "" {
398 // Very rough estimation: 1 token per 4 characters
399 inputTokens += uint64(len(part.Text)) / 4
400 }
401 }
402 }
403
404 // Count input tokens
405 for _, content := range req.Contents {
406 for _, part := range content.Parts {
407 if part.Text != "" {
408 inputTokens += uint64(len(part.Text)) / 4
409 } else if part.FunctionCall != nil {
410 // Estimate function call tokens
411 argBytes, _ := json.Marshal(part.FunctionCall.Args)
412 inputTokens += uint64(len(part.FunctionCall.Name)+len(argBytes)) / 4
413 } else if part.FunctionResponse != nil {
414 // Estimate function response tokens
415 resBytes, _ := json.Marshal(part.FunctionResponse.Response)
416 inputTokens += uint64(len(part.FunctionResponse.Name)+len(resBytes)) / 4
417 }
418 }
419 }
420
421 // Count output tokens
422 if res != nil && len(res.Candidates) > 0 {
423 for _, part := range res.Candidates[0].Content.Parts {
424 if part.Text != "" {
425 outputTokens += uint64(len(part.Text)) / 4
426 } else if part.FunctionCall != nil {
427 // Estimate function call tokens
428 argBytes, _ := json.Marshal(part.FunctionCall.Args)
429 outputTokens += uint64(len(part.FunctionCall.Name)+len(argBytes)) / 4
430 }
431 }
432 }
433
David Crawshaw5a234062025-05-04 17:52:08 +0000434 return llm.Usage{
435 InputTokens: inputTokens,
436 OutputTokens: outputTokens,
David Crawshaw5a234062025-05-04 17:52:08 +0000437 }
438}
439
Philip Zeyligerb8a8f352025-06-02 07:39:37 -0700440// TokenContextWindow returns the maximum token context window size for this service
441func (s *Service) TokenContextWindow() int {
442 model := s.Model
443 if model == "" {
444 model = DefaultModel
445 }
446
447 // Gemini models generally have large context windows
448 switch model {
449 case "gemini-2.5-pro-preview-03-25":
450 return 1000000 // 1M tokens for Gemini 2.5 Pro
451 case "gemini-2.0-flash-exp":
452 return 1000000 // 1M tokens for Gemini 2.0 Flash
453 case "gemini-1.5-pro", "gemini-1.5-pro-latest":
454 return 2000000 // 2M tokens for Gemini 1.5 Pro
455 case "gemini-1.5-flash", "gemini-1.5-flash-latest":
456 return 1000000 // 1M tokens for Gemini 1.5 Flash
457 default:
458 // Default for unknown models
459 return 1000000
460 }
461}
462
David Crawshaw5a234062025-05-04 17:52:08 +0000463// Do sends a request to Gemini.
464func (s *Service) Do(ctx context.Context, ir *llm.Request) (*llm.Response, error) {
465 // Log the incoming request for debugging
466 slog.DebugContext(ctx, "gemini_request",
467 "message_count", len(ir.Messages),
468 "tool_count", len(ir.Tools),
469 "system_count", len(ir.System))
470
471 // Log tool-related information if any tools are present
472 if len(ir.Tools) > 0 {
473 var toolNames []string
474 for _, tool := range ir.Tools {
475 toolNames = append(toolNames, tool.Name)
476 }
477 slog.DebugContext(ctx, "gemini_tools", "tools", toolNames)
478 }
479
480 // Log details about the messages being sent
481 for i, msg := range ir.Messages {
482 contentTypes := make([]string, len(msg.Content))
483 for j, c := range msg.Content {
484 contentTypes[j] = c.Type.String()
485
486 // Log tool-related content with more details
487 if c.Type == llm.ContentTypeToolUse {
488 slog.DebugContext(ctx, "gemini_tool_use",
489 "message_idx", i,
490 "content_idx", j,
491 "tool_name", c.ToolName,
492 "tool_input", string(c.ToolInput))
493 } else if c.Type == llm.ContentTypeToolResult {
494 slog.DebugContext(ctx, "gemini_tool_result",
495 "message_idx", i,
496 "content_idx", j,
497 "tool_use_id", c.ToolUseID,
498 "tool_error", c.ToolError,
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700499 "result_count", len(c.ToolResult))
David Crawshaw5a234062025-05-04 17:52:08 +0000500 }
501 }
502 slog.DebugContext(ctx, "gemini_message",
503 "idx", i,
504 "role", msg.Role.String(),
505 "content_types", contentTypes)
506 }
507 // Build the Gemini request
508 gemReq, err := s.buildGeminiRequest(ir)
509 if err != nil {
510 return nil, fmt.Errorf("failed to build Gemini request: %w", err)
511 }
512
513 // Log the structured Gemini request for debugging
514 if reqJSON, err := json.MarshalIndent(gemReq, "", " "); err == nil {
515 slog.DebugContext(ctx, "gemini_request_json", "request", string(reqJSON))
516 }
517
518 // Create a Gemini model instance
519 model := gemini.Model{
David Crawshaw3659d872025-05-05 17:52:23 -0700520 Model: "models/" + cmp.Or(s.Model, DefaultModel),
521 Endpoint: s.URL,
522 APIKey: s.APIKey,
523 HTTPC: cmp.Or(s.HTTPC, http.DefaultClient),
David Crawshaw5a234062025-05-04 17:52:08 +0000524 }
525
526 // Send the request to Gemini with retry logic
527 startTime := time.Now()
528 endTime := startTime // Initialize endTime
529 var gemRes *gemini.Response
530
531 // Retry mechanism for handling server errors and rate limiting
532 backoff := []time.Duration{1 * time.Second, 3 * time.Second, 5 * time.Second, 10 * time.Second}
533 for attempts := 0; attempts <= len(backoff); attempts++ {
534 gemApiErr := error(nil)
535 gemRes, gemApiErr = model.GenerateContent(ctx, gemReq)
536 endTime = time.Now()
537
538 if gemApiErr == nil {
539 // Successful response
540 // Log the structured Gemini response
541 if resJSON, err := json.MarshalIndent(gemRes, "", " "); err == nil {
542 slog.DebugContext(ctx, "gemini_response_json", "response", string(resJSON))
543 }
544 break
545 }
546
547 if attempts == len(backoff) {
548 // We've exhausted all retry attempts
549 return nil, fmt.Errorf("gemini: API error after %d attempts: %w", attempts, gemApiErr)
550 }
551
552 // Check if the error is retryable (e.g., server error or rate limiting)
553 if strings.Contains(gemApiErr.Error(), "429") || strings.Contains(gemApiErr.Error(), "5") {
554 // Rate limited or server error - wait and retry
555 random := time.Duration(rand.Int63n(int64(time.Second)))
556 sleep := backoff[attempts] + random
557 slog.WarnContext(ctx, "gemini_request_retry", "error", gemApiErr.Error(), "attempt", attempts+1, "sleep", sleep)
558 time.Sleep(sleep)
559 continue
560 }
561
562 // Non-retryable error
563 return nil, fmt.Errorf("gemini: API error: %w", gemApiErr)
564 }
565
566 content := convertGeminiResponseToContent(gemRes)
567
568 ensureToolIDs(content)
569
570 usage := calculateUsage(gemReq, gemRes)
Josh Bleecher Snyder59bb27d2025-06-05 07:32:10 -0700571 usage.CostUSD = llm.CostUSDFromResponse(gemRes.Header())
David Crawshaw5a234062025-05-04 17:52:08 +0000572
573 stopReason := llm.StopReasonEndTurn
574 for _, part := range content {
575 if part.Type == llm.ContentTypeToolUse {
576 stopReason = llm.StopReasonToolUse
577 slog.DebugContext(ctx, "gemini_tool_use_detected",
578 "setting_stop_reason", "llm.StopReasonToolUse",
579 "tool_name", part.ToolName)
580 break
581 }
582 }
583
584 return &llm.Response{
585 Role: llm.MessageRoleAssistant,
586 Model: s.Model,
587 Content: content,
588 StopReason: stopReason,
589 Usage: usage,
590 StartTime: &startTime,
591 EndTime: &endTime,
592 }, nil
593}