blob: 5f01db7ab567716ec005d8a1319f07290f19056e [file] [log] [blame]
David Crawshaw5a234062025-05-04 17:52:08 +00001package gem
2
import (
	"cmp"
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"math/rand"
	"net/http"
	"strconv"
	"strings"
	"time"

	"sketch.dev/llm"
	"sketch.dev/llm/gem/gemini"
)
17
const (
	// DefaultModel is the Gemini model used when Service.Model is empty.
	DefaultModel = "gemini-2.5-pro-preview-03-25"
	// GeminiAPIKeyEnv names the environment variable that conventionally
	// holds the Gemini API key. NOTE(review): not referenced in this file;
	// presumably read by callers to populate Service.APIKey — confirm.
	GeminiAPIKeyEnv = "GEMINI_API_KEY"
)
22
// Service provides Gemini completions.
// Fields should not be altered concurrently with calling any method on Service.
// The zero value is not usable: APIKey must be set before calling Do.
type Service struct {
	HTTPC  *http.Client // defaults to http.DefaultClient if nil
	URL    string       // Gemini API URL, uses the gemini package default if empty
	APIKey string       // must be non-empty
	Model  string       // defaults to DefaultModel if empty
}

// Compile-time check that *Service implements llm.Service.
var _ llm.Service = (*Service)(nil)
33
Philip Zeyliger022b3632025-05-10 06:14:21 -070034func (s *Service) ModelName() string {
35 return cmp.Or(s.Model, DefaultModel)
36}
37
// These maps convert between Sketch's llm package and Gemini API formats.
// fromLLMRole maps llm message roles to the role strings the Gemini API
// expects ("model" for assistant turns, "user" for user turns). Roles not
// present here are unsupported and rejected by buildGeminiRequest.
var fromLLMRole = map[llm.MessageRole]string{
	llm.MessageRoleAssistant: "model",
	llm.MessageRoleUser:      "user",
}
43
44// convertToolSchemas converts Sketch's llm.Tool schemas to Gemini's schema format
45func convertToolSchemas(tools []*llm.Tool) ([]gemini.FunctionDeclaration, error) {
46 if len(tools) == 0 {
47 return nil, nil
48 }
49
50 var decls []gemini.FunctionDeclaration
51 for _, tool := range tools {
52 // Parse the schema from raw JSON
53 var schemaJSON map[string]any
54 if err := json.Unmarshal(tool.InputSchema, &schemaJSON); err != nil {
55 return nil, fmt.Errorf("failed to unmarshal tool %s schema: %w", tool.Name, err)
56 }
57 decls = append(decls, gemini.FunctionDeclaration{
58 Name: tool.Name,
59 Description: tool.Description,
60 Parameters: convertJSONSchemaToGeminiSchema(schemaJSON),
61 })
62 }
63
64 return decls, nil
65}
66
67// convertJSONSchemaToGeminiSchema converts a JSON schema to Gemini's schema format
68func convertJSONSchemaToGeminiSchema(schemaJSON map[string]any) gemini.Schema {
69 schema := gemini.Schema{}
70
71 // Set the type based on the JSON schema type
72 if typeVal, ok := schemaJSON["type"].(string); ok {
73 switch typeVal {
74 case "string":
75 schema.Type = gemini.DataTypeSTRING
76 case "number":
77 schema.Type = gemini.DataTypeNUMBER
78 case "integer":
79 schema.Type = gemini.DataTypeINTEGER
80 case "boolean":
81 schema.Type = gemini.DataTypeBOOLEAN
82 case "array":
83 schema.Type = gemini.DataTypeARRAY
84 case "object":
85 schema.Type = gemini.DataTypeOBJECT
86 default:
87 schema.Type = gemini.DataTypeSTRING // Default to string for unknown types
88 }
89 }
90
91 // Set description if available
92 if desc, ok := schemaJSON["description"].(string); ok {
93 schema.Description = desc
94 }
95
96 // Handle enum values
97 if enumValues, ok := schemaJSON["enum"].([]any); ok {
98 schema.Enum = make([]string, len(enumValues))
99 for i, v := range enumValues {
100 if strVal, ok := v.(string); ok {
101 schema.Enum[i] = strVal
102 } else {
103 // Convert non-string values to string
104 valBytes, _ := json.Marshal(v)
105 schema.Enum[i] = string(valBytes)
106 }
107 }
108 }
109
110 // Handle object properties
111 if properties, ok := schemaJSON["properties"].(map[string]any); ok && schema.Type == gemini.DataTypeOBJECT {
112 schema.Properties = make(map[string]gemini.Schema)
113 for propName, propSchema := range properties {
114 if propSchemaMap, ok := propSchema.(map[string]any); ok {
115 schema.Properties[propName] = convertJSONSchemaToGeminiSchema(propSchemaMap)
116 }
117 }
118 }
119
120 // Handle required properties
121 if required, ok := schemaJSON["required"].([]any); ok {
122 schema.Required = make([]string, len(required))
123 for i, r := range required {
124 if strVal, ok := r.(string); ok {
125 schema.Required[i] = strVal
126 }
127 }
128 }
129
130 // Handle array items
131 if items, ok := schemaJSON["items"].(map[string]any); ok && schema.Type == gemini.DataTypeARRAY {
132 itemSchema := convertJSONSchemaToGeminiSchema(items)
133 schema.Items = &itemSchema
134 }
135
136 // Handle minimum/maximum items for arrays
137 if minItems, ok := schemaJSON["minItems"].(float64); ok {
138 schema.MinItems = fmt.Sprintf("%d", int(minItems))
139 }
140 if maxItems, ok := schemaJSON["maxItems"].(float64); ok {
141 schema.MaxItems = fmt.Sprintf("%d", int(maxItems))
142 }
143
144 return schema
145}
146
147// buildGeminiRequest converts Sketch's llm.Request to Gemini's request format
148func (s *Service) buildGeminiRequest(req *llm.Request) (*gemini.Request, error) {
149 gemReq := &gemini.Request{}
150
151 // Add system instruction if provided
152 if len(req.System) > 0 {
153 // Combine all system messages into a single system instruction
154 systemText := ""
155 for i, sys := range req.System {
156 if i > 0 && systemText != "" && sys.Text != "" {
157 systemText += "\n"
158 }
159 systemText += sys.Text
160 }
161
162 if systemText != "" {
163 gemReq.SystemInstruction = &gemini.Content{
164 Parts: []gemini.Part{{Text: systemText}},
165 }
166 }
167 }
168
169 // Convert messages to Gemini content format
170 for _, msg := range req.Messages {
171 // Set the role based on the message role
172 role, ok := fromLLMRole[msg.Role]
173 if !ok {
174 return nil, fmt.Errorf("unsupported message role: %v", msg.Role)
175 }
176
177 content := gemini.Content{
178 Role: role,
179 }
180
181 // Store tool usage information to correlate tool uses with responses
182 toolNameToID := make(map[string]string)
183
184 // First pass: collect tool use IDs for correlation
185 for _, c := range msg.Content {
186 if c.Type == llm.ContentTypeToolUse && c.ID != "" {
187 toolNameToID[c.ToolName] = c.ID
188 }
189 }
190
191 // Map each content item to Gemini's format
192 for _, c := range msg.Content {
193 switch c.Type {
194 case llm.ContentTypeText, llm.ContentTypeThinking, llm.ContentTypeRedactedThinking:
195 // Simple text content
196 content.Parts = append(content.Parts, gemini.Part{
197 Text: c.Text,
198 })
199 case llm.ContentTypeToolUse:
200 // Tool use becomes a function call
201 var args map[string]any
202 if err := json.Unmarshal(c.ToolInput, &args); err != nil {
203 return nil, fmt.Errorf("failed to unmarshal tool input: %w", err)
204 }
205
206 // Make sure we have a valid ID for this tool use
207 if c.ID == "" {
208 c.ID = fmt.Sprintf("gemini_tool_%s_%d", c.ToolName, time.Now().UnixNano())
209 }
210
211 // Save the ID for this tool name for future correlation
212 toolNameToID[c.ToolName] = c.ID
213
214 slog.DebugContext(context.Background(), "gemini_preparing_tool_use",
215 "tool_name", c.ToolName,
216 "tool_id", c.ID,
217 "input", string(c.ToolInput))
218
219 content.Parts = append(content.Parts, gemini.Part{
220 FunctionCall: &gemini.FunctionCall{
221 Name: c.ToolName,
222 Args: args,
223 },
224 })
225 case llm.ContentTypeToolResult:
226 // Tool result becomes a function response
227 // Create a map for the response
228 response := map[string]any{
229 "result": c.ToolResult,
230 "error": c.ToolError,
231 }
232
233 // Determine the function name to use - this is critical
234 funcName := ""
235
236 // First try to find the function name from a stored toolUseID if we have one
237 if c.ToolUseID != "" {
238 // Try to derive the tool name from the previous tools we've seen
239 for name, id := range toolNameToID {
240 if id == c.ToolUseID {
241 funcName = name
242 break
243 }
244 }
245 }
246
247 // Fallback options if we couldn't find the tool name
248 if funcName == "" {
249 // Try the tool name directly
250 if c.ToolName != "" {
251 funcName = c.ToolName
252 } else {
253 // Last resort fallback
254 funcName = "default_tool"
255 }
256 }
257
258 slog.DebugContext(context.Background(), "gemini_preparing_tool_result",
259 "tool_use_id", c.ToolUseID,
260 "mapped_func_name", funcName,
261 "result_length", len(c.ToolResult))
262
263 content.Parts = append(content.Parts, gemini.Part{
264 FunctionResponse: &gemini.FunctionResponse{
265 Name: funcName,
266 Response: response,
267 },
268 })
269 }
270 }
271
272 gemReq.Contents = append(gemReq.Contents, content)
273 }
274
275 // Handle tools/functions
276 if len(req.Tools) > 0 {
277 // Convert tool schemas
278 decls, err := convertToolSchemas(req.Tools)
279 if err != nil {
280 return nil, fmt.Errorf("failed to convert tool schemas: %w", err)
281 }
282 if len(decls) > 0 {
283 gemReq.Tools = []gemini.Tool{{FunctionDeclarations: decls}}
284 }
285 }
286
287 return gemReq, nil
288}
289
290// convertGeminiResponsesToContent converts a Gemini response to llm.Content
291func convertGeminiResponseToContent(res *gemini.Response) []llm.Content {
292 if res == nil || len(res.Candidates) == 0 || len(res.Candidates[0].Content.Parts) == 0 {
293 return []llm.Content{{
294 Type: llm.ContentTypeText,
295 Text: "",
296 }}
297 }
298
299 var contents []llm.Content
300
301 // Process each part in the first candidate's content
302 for i, part := range res.Candidates[0].Content.Parts {
303 // Log the part type for debugging
304 slog.DebugContext(context.Background(), "processing_gemini_part",
305 "index", i,
306 "has_text", part.Text != "",
307 "has_function_call", part.FunctionCall != nil,
308 "has_function_response", part.FunctionResponse != nil)
309
310 if part.Text != "" {
311 // Simple text response
312 contents = append(contents, llm.Content{
313 Type: llm.ContentTypeText,
314 Text: part.Text,
315 })
316 } else if part.FunctionCall != nil {
317 // Function call (tool use)
318 args, err := json.Marshal(part.FunctionCall.Args)
319 if err != nil {
320 // If we can't marshal, use empty args
321 slog.DebugContext(context.Background(), "gemini_failed_to_markshal_args",
322 "tool_name", part.FunctionCall.Name,
323 "args", string(args),
324 "err", err.Error(),
325 )
326 args = []byte("{}")
327 }
328
329 // Generate a unique ID for this tool use that includes the function name
330 // to make it easier to correlate with responses
331 toolID := fmt.Sprintf("gemini_tool_%s_%d", part.FunctionCall.Name, time.Now().UnixNano())
332
333 contents = append(contents, llm.Content{
334 ID: toolID,
335 Type: llm.ContentTypeToolUse,
336 ToolName: part.FunctionCall.Name,
337 ToolInput: json.RawMessage(args),
338 })
339
340 slog.DebugContext(context.Background(), "gemini_tool_call",
341 "tool_id", toolID,
342 "tool_name", part.FunctionCall.Name,
343 "args", string(args))
344 } else if part.FunctionResponse != nil {
345 // We shouldn't normally get function responses from the model, but just in case
346 respData, _ := json.Marshal(part.FunctionResponse.Response)
347 slog.DebugContext(context.Background(), "unexpected_function_response",
348 "name", part.FunctionResponse.Name,
349 "response", string(respData))
350 }
351 }
352
353 // If no content was added, add an empty text content
354 if len(contents) == 0 {
355 slog.DebugContext(context.Background(), "empty_gemini_response", "adding_empty_text", true)
356 contents = append(contents, llm.Content{
357 Type: llm.ContentTypeText,
358 Text: "",
359 })
360 }
361
362 return contents
363}
364
365// Gemini doesn't provide usage info directly, so we need to estimate it
366// ensureToolIDs makes sure all tool uses have proper IDs
367func ensureToolIDs(contents []llm.Content) {
368 for i, content := range contents {
369 if content.Type == llm.ContentTypeToolUse && content.ID == "" {
370 // Generate a stable ID using the tool name and timestamp
371 contents[i].ID = fmt.Sprintf("gemini_tool_%s_%d", content.ToolName, time.Now().UnixNano())
372 slog.DebugContext(context.Background(), "assigned_missing_tool_id",
373 "tool_name", content.ToolName,
374 "new_id", contents[i].ID)
375 }
376 }
377}
378
379func calculateUsage(req *gemini.Request, res *gemini.Response) llm.Usage {
380 // Very rough estimation of token counts
381 var inputTokens uint64
382 var outputTokens uint64
383
384 // Count system tokens
385 if req.SystemInstruction != nil {
386 for _, part := range req.SystemInstruction.Parts {
387 if part.Text != "" {
388 // Very rough estimation: 1 token per 4 characters
389 inputTokens += uint64(len(part.Text)) / 4
390 }
391 }
392 }
393
394 // Count input tokens
395 for _, content := range req.Contents {
396 for _, part := range content.Parts {
397 if part.Text != "" {
398 inputTokens += uint64(len(part.Text)) / 4
399 } else if part.FunctionCall != nil {
400 // Estimate function call tokens
401 argBytes, _ := json.Marshal(part.FunctionCall.Args)
402 inputTokens += uint64(len(part.FunctionCall.Name)+len(argBytes)) / 4
403 } else if part.FunctionResponse != nil {
404 // Estimate function response tokens
405 resBytes, _ := json.Marshal(part.FunctionResponse.Response)
406 inputTokens += uint64(len(part.FunctionResponse.Name)+len(resBytes)) / 4
407 }
408 }
409 }
410
411 // Count output tokens
412 if res != nil && len(res.Candidates) > 0 {
413 for _, part := range res.Candidates[0].Content.Parts {
414 if part.Text != "" {
415 outputTokens += uint64(len(part.Text)) / 4
416 } else if part.FunctionCall != nil {
417 // Estimate function call tokens
418 argBytes, _ := json.Marshal(part.FunctionCall.Args)
419 outputTokens += uint64(len(part.FunctionCall.Name)+len(argBytes)) / 4
420 }
421 }
422 }
423
424 // For Gemini 2.5 Pro Preview pricing: $1.25 per 1M input tokens, $10 per 1M output tokens
425 // Convert to dollars
426 costUSD := float64(inputTokens)*1.25/1_000_000.0 + float64(outputTokens)*10/1_000_000.0
427
428 return llm.Usage{
429 InputTokens: inputTokens,
430 OutputTokens: outputTokens,
431 CostUSD: costUSD,
432 }
433}
434
435// Do sends a request to Gemini.
436func (s *Service) Do(ctx context.Context, ir *llm.Request) (*llm.Response, error) {
437 // Log the incoming request for debugging
438 slog.DebugContext(ctx, "gemini_request",
439 "message_count", len(ir.Messages),
440 "tool_count", len(ir.Tools),
441 "system_count", len(ir.System))
442
443 // Log tool-related information if any tools are present
444 if len(ir.Tools) > 0 {
445 var toolNames []string
446 for _, tool := range ir.Tools {
447 toolNames = append(toolNames, tool.Name)
448 }
449 slog.DebugContext(ctx, "gemini_tools", "tools", toolNames)
450 }
451
452 // Log details about the messages being sent
453 for i, msg := range ir.Messages {
454 contentTypes := make([]string, len(msg.Content))
455 for j, c := range msg.Content {
456 contentTypes[j] = c.Type.String()
457
458 // Log tool-related content with more details
459 if c.Type == llm.ContentTypeToolUse {
460 slog.DebugContext(ctx, "gemini_tool_use",
461 "message_idx", i,
462 "content_idx", j,
463 "tool_name", c.ToolName,
464 "tool_input", string(c.ToolInput))
465 } else if c.Type == llm.ContentTypeToolResult {
466 slog.DebugContext(ctx, "gemini_tool_result",
467 "message_idx", i,
468 "content_idx", j,
469 "tool_use_id", c.ToolUseID,
470 "tool_error", c.ToolError,
471 "result_length", len(c.ToolResult))
472 }
473 }
474 slog.DebugContext(ctx, "gemini_message",
475 "idx", i,
476 "role", msg.Role.String(),
477 "content_types", contentTypes)
478 }
479 // Build the Gemini request
480 gemReq, err := s.buildGeminiRequest(ir)
481 if err != nil {
482 return nil, fmt.Errorf("failed to build Gemini request: %w", err)
483 }
484
485 // Log the structured Gemini request for debugging
486 if reqJSON, err := json.MarshalIndent(gemReq, "", " "); err == nil {
487 slog.DebugContext(ctx, "gemini_request_json", "request", string(reqJSON))
488 }
489
490 // Create a Gemini model instance
491 model := gemini.Model{
David Crawshaw3659d872025-05-05 17:52:23 -0700492 Model: "models/" + cmp.Or(s.Model, DefaultModel),
493 Endpoint: s.URL,
494 APIKey: s.APIKey,
495 HTTPC: cmp.Or(s.HTTPC, http.DefaultClient),
David Crawshaw5a234062025-05-04 17:52:08 +0000496 }
497
498 // Send the request to Gemini with retry logic
499 startTime := time.Now()
500 endTime := startTime // Initialize endTime
501 var gemRes *gemini.Response
502
503 // Retry mechanism for handling server errors and rate limiting
504 backoff := []time.Duration{1 * time.Second, 3 * time.Second, 5 * time.Second, 10 * time.Second}
505 for attempts := 0; attempts <= len(backoff); attempts++ {
506 gemApiErr := error(nil)
507 gemRes, gemApiErr = model.GenerateContent(ctx, gemReq)
508 endTime = time.Now()
509
510 if gemApiErr == nil {
511 // Successful response
512 // Log the structured Gemini response
513 if resJSON, err := json.MarshalIndent(gemRes, "", " "); err == nil {
514 slog.DebugContext(ctx, "gemini_response_json", "response", string(resJSON))
515 }
516 break
517 }
518
519 if attempts == len(backoff) {
520 // We've exhausted all retry attempts
521 return nil, fmt.Errorf("gemini: API error after %d attempts: %w", attempts, gemApiErr)
522 }
523
524 // Check if the error is retryable (e.g., server error or rate limiting)
525 if strings.Contains(gemApiErr.Error(), "429") || strings.Contains(gemApiErr.Error(), "5") {
526 // Rate limited or server error - wait and retry
527 random := time.Duration(rand.Int63n(int64(time.Second)))
528 sleep := backoff[attempts] + random
529 slog.WarnContext(ctx, "gemini_request_retry", "error", gemApiErr.Error(), "attempt", attempts+1, "sleep", sleep)
530 time.Sleep(sleep)
531 continue
532 }
533
534 // Non-retryable error
535 return nil, fmt.Errorf("gemini: API error: %w", gemApiErr)
536 }
537
538 content := convertGeminiResponseToContent(gemRes)
539
540 ensureToolIDs(content)
541
542 usage := calculateUsage(gemReq, gemRes)
543
544 stopReason := llm.StopReasonEndTurn
545 for _, part := range content {
546 if part.Type == llm.ContentTypeToolUse {
547 stopReason = llm.StopReasonToolUse
548 slog.DebugContext(ctx, "gemini_tool_use_detected",
549 "setting_stop_reason", "llm.StopReasonToolUse",
550 "tool_name", part.ToolName)
551 break
552 }
553 }
554
555 return &llm.Response{
556 Role: llm.MessageRoleAssistant,
557 Model: s.Model,
558 Content: content,
559 StopReason: stopReason,
560 Usage: usage,
561 StartTime: &startTime,
562 EndTime: &endTime,
563 }, nil
564}