blob: 358770cdee19c49e956f27fcdae17462129ef5a0 [file] [log] [blame]
David Crawshaw5a234062025-05-04 17:52:08 +00001package gem
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "math/rand"
10 "net/http"
11 "strings"
12 "time"
13
14 "sketch.dev/llm"
15 "sketch.dev/llm/gem/gemini"
16)
17
// DefaultModel is the Gemini model name used when Service.Model is empty.
const DefaultModel = "gemini-2.5-pro-preview-03-25"

// GeminiAPIKeyEnv is the name of the environment variable associated with the
// Gemini API key. Nothing in this file reads the variable itself.
const GeminiAPIKeyEnv = "GEMINI_API_KEY"
22
// Service is a Gemini-backed completion service.
//
// Configure the fields before use; they must not be modified concurrently
// with a call to any method on Service.
type Service struct {
	// HTTPC is the HTTP client used for API calls; nil means http.DefaultClient.
	HTTPC *http.Client
	// URL is the Gemini API URL; empty means the gemini package default.
	URL string
	// APIKey authenticates requests and must be non-empty.
	APIKey string
	// Model is the Gemini model name; empty means DefaultModel.
	Model string
	// DumpLLM, when true, dumps request/response text to files for debugging.
	// It defaults to false.
	DumpLLM bool
}
32
33var _ llm.Service = (*Service)(nil)
34
35// These maps convert between Sketch's llm package and Gemini API formats
36var fromLLMRole = map[llm.MessageRole]string{
37 llm.MessageRoleAssistant: "model",
38 llm.MessageRoleUser: "user",
39}
40
41// convertToolSchemas converts Sketch's llm.Tool schemas to Gemini's schema format
42func convertToolSchemas(tools []*llm.Tool) ([]gemini.FunctionDeclaration, error) {
43 if len(tools) == 0 {
44 return nil, nil
45 }
46
47 var decls []gemini.FunctionDeclaration
48 for _, tool := range tools {
49 // Parse the schema from raw JSON
50 var schemaJSON map[string]any
51 if err := json.Unmarshal(tool.InputSchema, &schemaJSON); err != nil {
52 return nil, fmt.Errorf("failed to unmarshal tool %s schema: %w", tool.Name, err)
53 }
54 decls = append(decls, gemini.FunctionDeclaration{
55 Name: tool.Name,
56 Description: tool.Description,
57 Parameters: convertJSONSchemaToGeminiSchema(schemaJSON),
58 })
59 }
60
61 return decls, nil
62}
63
64// convertJSONSchemaToGeminiSchema converts a JSON schema to Gemini's schema format
65func convertJSONSchemaToGeminiSchema(schemaJSON map[string]any) gemini.Schema {
66 schema := gemini.Schema{}
67
68 // Set the type based on the JSON schema type
69 if typeVal, ok := schemaJSON["type"].(string); ok {
70 switch typeVal {
71 case "string":
72 schema.Type = gemini.DataTypeSTRING
73 case "number":
74 schema.Type = gemini.DataTypeNUMBER
75 case "integer":
76 schema.Type = gemini.DataTypeINTEGER
77 case "boolean":
78 schema.Type = gemini.DataTypeBOOLEAN
79 case "array":
80 schema.Type = gemini.DataTypeARRAY
81 case "object":
82 schema.Type = gemini.DataTypeOBJECT
83 default:
84 schema.Type = gemini.DataTypeSTRING // Default to string for unknown types
85 }
86 }
87
88 // Set description if available
89 if desc, ok := schemaJSON["description"].(string); ok {
90 schema.Description = desc
91 }
92
93 // Handle enum values
94 if enumValues, ok := schemaJSON["enum"].([]any); ok {
95 schema.Enum = make([]string, len(enumValues))
96 for i, v := range enumValues {
97 if strVal, ok := v.(string); ok {
98 schema.Enum[i] = strVal
99 } else {
100 // Convert non-string values to string
101 valBytes, _ := json.Marshal(v)
102 schema.Enum[i] = string(valBytes)
103 }
104 }
105 }
106
107 // Handle object properties
108 if properties, ok := schemaJSON["properties"].(map[string]any); ok && schema.Type == gemini.DataTypeOBJECT {
109 schema.Properties = make(map[string]gemini.Schema)
110 for propName, propSchema := range properties {
111 if propSchemaMap, ok := propSchema.(map[string]any); ok {
112 schema.Properties[propName] = convertJSONSchemaToGeminiSchema(propSchemaMap)
113 }
114 }
115 }
116
117 // Handle required properties
118 if required, ok := schemaJSON["required"].([]any); ok {
119 schema.Required = make([]string, len(required))
120 for i, r := range required {
121 if strVal, ok := r.(string); ok {
122 schema.Required[i] = strVal
123 }
124 }
125 }
126
127 // Handle array items
128 if items, ok := schemaJSON["items"].(map[string]any); ok && schema.Type == gemini.DataTypeARRAY {
129 itemSchema := convertJSONSchemaToGeminiSchema(items)
130 schema.Items = &itemSchema
131 }
132
133 // Handle minimum/maximum items for arrays
134 if minItems, ok := schemaJSON["minItems"].(float64); ok {
135 schema.MinItems = fmt.Sprintf("%d", int(minItems))
136 }
137 if maxItems, ok := schemaJSON["maxItems"].(float64); ok {
138 schema.MaxItems = fmt.Sprintf("%d", int(maxItems))
139 }
140
141 return schema
142}
143
144// buildGeminiRequest converts Sketch's llm.Request to Gemini's request format
145func (s *Service) buildGeminiRequest(req *llm.Request) (*gemini.Request, error) {
146 gemReq := &gemini.Request{}
147
148 // Add system instruction if provided
149 if len(req.System) > 0 {
150 // Combine all system messages into a single system instruction
151 systemText := ""
152 for i, sys := range req.System {
153 if i > 0 && systemText != "" && sys.Text != "" {
154 systemText += "\n"
155 }
156 systemText += sys.Text
157 }
158
159 if systemText != "" {
160 gemReq.SystemInstruction = &gemini.Content{
161 Parts: []gemini.Part{{Text: systemText}},
162 }
163 }
164 }
165
166 // Convert messages to Gemini content format
167 for _, msg := range req.Messages {
168 // Set the role based on the message role
169 role, ok := fromLLMRole[msg.Role]
170 if !ok {
171 return nil, fmt.Errorf("unsupported message role: %v", msg.Role)
172 }
173
174 content := gemini.Content{
175 Role: role,
176 }
177
178 // Store tool usage information to correlate tool uses with responses
179 toolNameToID := make(map[string]string)
180
181 // First pass: collect tool use IDs for correlation
182 for _, c := range msg.Content {
183 if c.Type == llm.ContentTypeToolUse && c.ID != "" {
184 toolNameToID[c.ToolName] = c.ID
185 }
186 }
187
188 // Map each content item to Gemini's format
189 for _, c := range msg.Content {
190 switch c.Type {
191 case llm.ContentTypeText, llm.ContentTypeThinking, llm.ContentTypeRedactedThinking:
192 // Simple text content
193 content.Parts = append(content.Parts, gemini.Part{
194 Text: c.Text,
195 })
196 case llm.ContentTypeToolUse:
197 // Tool use becomes a function call
198 var args map[string]any
199 if err := json.Unmarshal(c.ToolInput, &args); err != nil {
200 return nil, fmt.Errorf("failed to unmarshal tool input: %w", err)
201 }
202
203 // Make sure we have a valid ID for this tool use
204 if c.ID == "" {
205 c.ID = fmt.Sprintf("gemini_tool_%s_%d", c.ToolName, time.Now().UnixNano())
206 }
207
208 // Save the ID for this tool name for future correlation
209 toolNameToID[c.ToolName] = c.ID
210
211 slog.DebugContext(context.Background(), "gemini_preparing_tool_use",
212 "tool_name", c.ToolName,
213 "tool_id", c.ID,
214 "input", string(c.ToolInput))
215
216 content.Parts = append(content.Parts, gemini.Part{
217 FunctionCall: &gemini.FunctionCall{
218 Name: c.ToolName,
219 Args: args,
220 },
221 })
222 case llm.ContentTypeToolResult:
223 // Tool result becomes a function response
224 // Create a map for the response
225 response := map[string]any{
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700226 "error": c.ToolError,
David Crawshaw5a234062025-05-04 17:52:08 +0000227 }
228
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700229 // Handle tool results: Gemini only supports string results
230 // Combine all text content into a single string
231 var resultText string
232 if len(c.ToolResult) > 0 {
233 // Collect all text from content objects
234 texts := make([]string, 0, len(c.ToolResult))
235 for _, result := range c.ToolResult {
236 if result.Text != "" {
237 texts = append(texts, result.Text)
238 }
239 }
240 resultText = strings.Join(texts, "\n")
241 }
242 response["result"] = resultText
243
David Crawshaw5a234062025-05-04 17:52:08 +0000244 // Determine the function name to use - this is critical
245 funcName := ""
246
247 // First try to find the function name from a stored toolUseID if we have one
248 if c.ToolUseID != "" {
249 // Try to derive the tool name from the previous tools we've seen
250 for name, id := range toolNameToID {
251 if id == c.ToolUseID {
252 funcName = name
253 break
254 }
255 }
256 }
257
258 // Fallback options if we couldn't find the tool name
259 if funcName == "" {
260 // Try the tool name directly
261 if c.ToolName != "" {
262 funcName = c.ToolName
263 } else {
264 // Last resort fallback
265 funcName = "default_tool"
266 }
267 }
268
269 slog.DebugContext(context.Background(), "gemini_preparing_tool_result",
270 "tool_use_id", c.ToolUseID,
271 "mapped_func_name", funcName,
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700272 "result_count", len(c.ToolResult))
David Crawshaw5a234062025-05-04 17:52:08 +0000273
274 content.Parts = append(content.Parts, gemini.Part{
275 FunctionResponse: &gemini.FunctionResponse{
276 Name: funcName,
277 Response: response,
278 },
279 })
280 }
281 }
282
283 gemReq.Contents = append(gemReq.Contents, content)
284 }
285
286 // Handle tools/functions
287 if len(req.Tools) > 0 {
288 // Convert tool schemas
289 decls, err := convertToolSchemas(req.Tools)
290 if err != nil {
291 return nil, fmt.Errorf("failed to convert tool schemas: %w", err)
292 }
293 if len(decls) > 0 {
294 gemReq.Tools = []gemini.Tool{{FunctionDeclarations: decls}}
295 }
296 }
297
298 return gemReq, nil
299}
300
301// convertGeminiResponsesToContent converts a Gemini response to llm.Content
302func convertGeminiResponseToContent(res *gemini.Response) []llm.Content {
303 if res == nil || len(res.Candidates) == 0 || len(res.Candidates[0].Content.Parts) == 0 {
304 return []llm.Content{{
305 Type: llm.ContentTypeText,
306 Text: "",
307 }}
308 }
309
310 var contents []llm.Content
311
312 // Process each part in the first candidate's content
313 for i, part := range res.Candidates[0].Content.Parts {
314 // Log the part type for debugging
315 slog.DebugContext(context.Background(), "processing_gemini_part",
316 "index", i,
317 "has_text", part.Text != "",
318 "has_function_call", part.FunctionCall != nil,
319 "has_function_response", part.FunctionResponse != nil)
320
321 if part.Text != "" {
322 // Simple text response
323 contents = append(contents, llm.Content{
324 Type: llm.ContentTypeText,
325 Text: part.Text,
326 })
327 } else if part.FunctionCall != nil {
328 // Function call (tool use)
329 args, err := json.Marshal(part.FunctionCall.Args)
330 if err != nil {
331 // If we can't marshal, use empty args
332 slog.DebugContext(context.Background(), "gemini_failed_to_markshal_args",
333 "tool_name", part.FunctionCall.Name,
334 "args", string(args),
335 "err", err.Error(),
336 )
337 args = []byte("{}")
338 }
339
340 // Generate a unique ID for this tool use that includes the function name
341 // to make it easier to correlate with responses
342 toolID := fmt.Sprintf("gemini_tool_%s_%d", part.FunctionCall.Name, time.Now().UnixNano())
343
344 contents = append(contents, llm.Content{
345 ID: toolID,
346 Type: llm.ContentTypeToolUse,
347 ToolName: part.FunctionCall.Name,
348 ToolInput: json.RawMessage(args),
349 })
350
351 slog.DebugContext(context.Background(), "gemini_tool_call",
352 "tool_id", toolID,
353 "tool_name", part.FunctionCall.Name,
354 "args", string(args))
355 } else if part.FunctionResponse != nil {
356 // We shouldn't normally get function responses from the model, but just in case
357 respData, _ := json.Marshal(part.FunctionResponse.Response)
358 slog.DebugContext(context.Background(), "unexpected_function_response",
359 "name", part.FunctionResponse.Name,
360 "response", string(respData))
361 }
362 }
363
364 // If no content was added, add an empty text content
365 if len(contents) == 0 {
366 slog.DebugContext(context.Background(), "empty_gemini_response", "adding_empty_text", true)
367 contents = append(contents, llm.Content{
368 Type: llm.ContentTypeText,
369 Text: "",
370 })
371 }
372
373 return contents
374}
375
376// Gemini doesn't provide usage info directly, so we need to estimate it
377// ensureToolIDs makes sure all tool uses have proper IDs
378func ensureToolIDs(contents []llm.Content) {
379 for i, content := range contents {
380 if content.Type == llm.ContentTypeToolUse && content.ID == "" {
381 // Generate a stable ID using the tool name and timestamp
382 contents[i].ID = fmt.Sprintf("gemini_tool_%s_%d", content.ToolName, time.Now().UnixNano())
383 slog.DebugContext(context.Background(), "assigned_missing_tool_id",
384 "tool_name", content.ToolName,
385 "new_id", contents[i].ID)
386 }
387 }
388}
389
390func calculateUsage(req *gemini.Request, res *gemini.Response) llm.Usage {
391 // Very rough estimation of token counts
392 var inputTokens uint64
393 var outputTokens uint64
394
395 // Count system tokens
396 if req.SystemInstruction != nil {
397 for _, part := range req.SystemInstruction.Parts {
398 if part.Text != "" {
399 // Very rough estimation: 1 token per 4 characters
400 inputTokens += uint64(len(part.Text)) / 4
401 }
402 }
403 }
404
405 // Count input tokens
406 for _, content := range req.Contents {
407 for _, part := range content.Parts {
408 if part.Text != "" {
409 inputTokens += uint64(len(part.Text)) / 4
410 } else if part.FunctionCall != nil {
411 // Estimate function call tokens
412 argBytes, _ := json.Marshal(part.FunctionCall.Args)
413 inputTokens += uint64(len(part.FunctionCall.Name)+len(argBytes)) / 4
414 } else if part.FunctionResponse != nil {
415 // Estimate function response tokens
416 resBytes, _ := json.Marshal(part.FunctionResponse.Response)
417 inputTokens += uint64(len(part.FunctionResponse.Name)+len(resBytes)) / 4
418 }
419 }
420 }
421
422 // Count output tokens
423 if res != nil && len(res.Candidates) > 0 {
424 for _, part := range res.Candidates[0].Content.Parts {
425 if part.Text != "" {
426 outputTokens += uint64(len(part.Text)) / 4
427 } else if part.FunctionCall != nil {
428 // Estimate function call tokens
429 argBytes, _ := json.Marshal(part.FunctionCall.Args)
430 outputTokens += uint64(len(part.FunctionCall.Name)+len(argBytes)) / 4
431 }
432 }
433 }
434
David Crawshaw5a234062025-05-04 17:52:08 +0000435 return llm.Usage{
436 InputTokens: inputTokens,
437 OutputTokens: outputTokens,
David Crawshaw5a234062025-05-04 17:52:08 +0000438 }
439}
440
Philip Zeyligerb8a8f352025-06-02 07:39:37 -0700441// TokenContextWindow returns the maximum token context window size for this service
442func (s *Service) TokenContextWindow() int {
443 model := s.Model
444 if model == "" {
445 model = DefaultModel
446 }
447
448 // Gemini models generally have large context windows
449 switch model {
450 case "gemini-2.5-pro-preview-03-25":
451 return 1000000 // 1M tokens for Gemini 2.5 Pro
452 case "gemini-2.0-flash-exp":
453 return 1000000 // 1M tokens for Gemini 2.0 Flash
454 case "gemini-1.5-pro", "gemini-1.5-pro-latest":
455 return 2000000 // 2M tokens for Gemini 1.5 Pro
456 case "gemini-1.5-flash", "gemini-1.5-flash-latest":
457 return 1000000 // 1M tokens for Gemini 1.5 Flash
458 default:
459 // Default for unknown models
460 return 1000000
461 }
462}
463
David Crawshaw5a234062025-05-04 17:52:08 +0000464// Do sends a request to Gemini.
465func (s *Service) Do(ctx context.Context, ir *llm.Request) (*llm.Response, error) {
466 // Log the incoming request for debugging
467 slog.DebugContext(ctx, "gemini_request",
468 "message_count", len(ir.Messages),
469 "tool_count", len(ir.Tools),
470 "system_count", len(ir.System))
471
472 // Log tool-related information if any tools are present
473 if len(ir.Tools) > 0 {
474 var toolNames []string
475 for _, tool := range ir.Tools {
476 toolNames = append(toolNames, tool.Name)
477 }
478 slog.DebugContext(ctx, "gemini_tools", "tools", toolNames)
479 }
480
481 // Log details about the messages being sent
482 for i, msg := range ir.Messages {
483 contentTypes := make([]string, len(msg.Content))
484 for j, c := range msg.Content {
485 contentTypes[j] = c.Type.String()
486
487 // Log tool-related content with more details
488 if c.Type == llm.ContentTypeToolUse {
489 slog.DebugContext(ctx, "gemini_tool_use",
490 "message_idx", i,
491 "content_idx", j,
492 "tool_name", c.ToolName,
493 "tool_input", string(c.ToolInput))
494 } else if c.Type == llm.ContentTypeToolResult {
495 slog.DebugContext(ctx, "gemini_tool_result",
496 "message_idx", i,
497 "content_idx", j,
498 "tool_use_id", c.ToolUseID,
499 "tool_error", c.ToolError,
Philip Zeyliger72252cb2025-05-10 17:00:08 -0700500 "result_count", len(c.ToolResult))
David Crawshaw5a234062025-05-04 17:52:08 +0000501 }
502 }
503 slog.DebugContext(ctx, "gemini_message",
504 "idx", i,
505 "role", msg.Role.String(),
506 "content_types", contentTypes)
507 }
508 // Build the Gemini request
509 gemReq, err := s.buildGeminiRequest(ir)
510 if err != nil {
511 return nil, fmt.Errorf("failed to build Gemini request: %w", err)
512 }
513
514 // Log the structured Gemini request for debugging
515 if reqJSON, err := json.MarshalIndent(gemReq, "", " "); err == nil {
516 slog.DebugContext(ctx, "gemini_request_json", "request", string(reqJSON))
Josh Bleecher Snyder57afbca2025-07-23 13:29:59 -0700517 if s.DumpLLM {
518 // Construct the same URL that the Gemini client will use
519 endpoint := cmp.Or(s.URL, "https://generativelanguage.googleapis.com/v1beta")
520 url := fmt.Sprintf("%s/models/%s:generateContent?key=%s", endpoint, cmp.Or(s.Model, DefaultModel), s.APIKey)
521 if err := llm.DumpToFile("request", url, reqJSON); err != nil {
522 slog.WarnContext(ctx, "failed to dump gemini request to file", "error", err)
523 }
524 }
David Crawshaw5a234062025-05-04 17:52:08 +0000525 }
526
527 // Create a Gemini model instance
528 model := gemini.Model{
David Crawshaw3659d872025-05-05 17:52:23 -0700529 Model: "models/" + cmp.Or(s.Model, DefaultModel),
530 Endpoint: s.URL,
531 APIKey: s.APIKey,
532 HTTPC: cmp.Or(s.HTTPC, http.DefaultClient),
David Crawshaw5a234062025-05-04 17:52:08 +0000533 }
534
535 // Send the request to Gemini with retry logic
536 startTime := time.Now()
537 endTime := startTime // Initialize endTime
538 var gemRes *gemini.Response
539
540 // Retry mechanism for handling server errors and rate limiting
541 backoff := []time.Duration{1 * time.Second, 3 * time.Second, 5 * time.Second, 10 * time.Second}
542 for attempts := 0; attempts <= len(backoff); attempts++ {
543 gemApiErr := error(nil)
544 gemRes, gemApiErr = model.GenerateContent(ctx, gemReq)
545 endTime = time.Now()
546
547 if gemApiErr == nil {
548 // Successful response
549 // Log the structured Gemini response
550 if resJSON, err := json.MarshalIndent(gemRes, "", " "); err == nil {
551 slog.DebugContext(ctx, "gemini_response_json", "response", string(resJSON))
Josh Bleecher Snyder57afbca2025-07-23 13:29:59 -0700552 if s.DumpLLM {
553 if err := llm.DumpToFile("response", "", resJSON); err != nil {
554 slog.WarnContext(ctx, "failed to dump gemini response to file", "error", err)
555 }
556 }
David Crawshaw5a234062025-05-04 17:52:08 +0000557 }
558 break
559 }
560
561 if attempts == len(backoff) {
562 // We've exhausted all retry attempts
563 return nil, fmt.Errorf("gemini: API error after %d attempts: %w", attempts, gemApiErr)
564 }
565
566 // Check if the error is retryable (e.g., server error or rate limiting)
567 if strings.Contains(gemApiErr.Error(), "429") || strings.Contains(gemApiErr.Error(), "5") {
568 // Rate limited or server error - wait and retry
569 random := time.Duration(rand.Int63n(int64(time.Second)))
570 sleep := backoff[attempts] + random
571 slog.WarnContext(ctx, "gemini_request_retry", "error", gemApiErr.Error(), "attempt", attempts+1, "sleep", sleep)
572 time.Sleep(sleep)
573 continue
574 }
575
576 // Non-retryable error
577 return nil, fmt.Errorf("gemini: API error: %w", gemApiErr)
578 }
579
580 content := convertGeminiResponseToContent(gemRes)
581
582 ensureToolIDs(content)
583
584 usage := calculateUsage(gemReq, gemRes)
Josh Bleecher Snyder59bb27d2025-06-05 07:32:10 -0700585 usage.CostUSD = llm.CostUSDFromResponse(gemRes.Header())
David Crawshaw5a234062025-05-04 17:52:08 +0000586
587 stopReason := llm.StopReasonEndTurn
588 for _, part := range content {
589 if part.Type == llm.ContentTypeToolUse {
590 stopReason = llm.StopReasonToolUse
591 slog.DebugContext(ctx, "gemini_tool_use_detected",
592 "setting_stop_reason", "llm.StopReasonToolUse",
593 "tool_name", part.ToolName)
594 break
595 }
596 }
597
598 return &llm.Response{
599 Role: llm.MessageRoleAssistant,
600 Model: s.Model,
601 Content: content,
602 StopReason: stopReason,
603 Usage: usage,
604 StartTime: &startTime,
605 EndTime: &endTime,
606 }, nil
607}