blob: 7518d4960264863e9350162519de55e197cc5b44 [file] [log] [blame]
David Crawshaw5a234062025-05-04 17:52:08 +00001package gem
2
3import (
4 "encoding/json"
5 "testing"
6
7 "sketch.dev/llm"
8 "sketch.dev/llm/gem/gemini"
9)
10
11func TestBuildGeminiRequest(t *testing.T) {
12 // Create a service
13 service := &Service{
14 Model: DefaultModel,
15 APIKey: "test-api-key",
16 }
17
18 // Create a simple request
19 req := &llm.Request{
20 Messages: []llm.Message{
21 {
22 Role: llm.MessageRoleUser,
23 Content: []llm.Content{
24 {
25 Type: llm.ContentTypeText,
26 Text: "Hello, world!",
27 },
28 },
29 },
30 },
31 System: []llm.SystemContent{
32 {
33 Text: "You are a helpful assistant.",
34 },
35 },
36 }
37
38 // Build the Gemini request
39 gemReq, err := service.buildGeminiRequest(req)
40 if err != nil {
41 t.Fatalf("Failed to build Gemini request: %v", err)
42 }
43
44 // Verify the system instruction
45 if gemReq.SystemInstruction == nil {
46 t.Fatalf("Expected system instruction, got nil")
47 }
48 if len(gemReq.SystemInstruction.Parts) != 1 {
49 t.Fatalf("Expected 1 system part, got %d", len(gemReq.SystemInstruction.Parts))
50 }
51 if gemReq.SystemInstruction.Parts[0].Text != "You are a helpful assistant." {
52 t.Fatalf("Expected system text 'You are a helpful assistant.', got '%s'", gemReq.SystemInstruction.Parts[0].Text)
53 }
54
55 // Verify the contents
56 if len(gemReq.Contents) != 1 {
57 t.Fatalf("Expected 1 content, got %d", len(gemReq.Contents))
58 }
59 if len(gemReq.Contents[0].Parts) != 1 {
60 t.Fatalf("Expected 1 part, got %d", len(gemReq.Contents[0].Parts))
61 }
62 if gemReq.Contents[0].Parts[0].Text != "Hello, world!" {
63 t.Fatalf("Expected text 'Hello, world!', got '%s'", gemReq.Contents[0].Parts[0].Text)
64 }
65 // Verify the role is set correctly
66 if gemReq.Contents[0].Role != "user" {
67 t.Fatalf("Expected role 'user', got '%s'", gemReq.Contents[0].Role)
68 }
69}
70
71func TestConvertToolSchemas(t *testing.T) {
72 // Create a simple tool with a JSON schema
73 schema := `{
74 "type": "object",
75 "properties": {
76 "name": {
77 "type": "string",
78 "description": "The name of the person"
79 },
80 "age": {
81 "type": "integer",
82 "description": "The age of the person"
83 }
84 },
85 "required": ["name"]
86 }`
87
88 tools := []*llm.Tool{
89 {
90 Name: "get_person",
91 Description: "Get information about a person",
92 InputSchema: json.RawMessage(schema),
93 },
94 }
95
96 // Convert the tools
97 decls, err := convertToolSchemas(tools)
98 if err != nil {
99 t.Fatalf("Failed to convert tool schemas: %v", err)
100 }
101
102 // Verify the result
103 if len(decls) != 1 {
104 t.Fatalf("Expected 1 declaration, got %d", len(decls))
105 }
106 if decls[0].Name != "get_person" {
107 t.Fatalf("Expected name 'get_person', got '%s'", decls[0].Name)
108 }
109 if decls[0].Description != "Get information about a person" {
110 t.Fatalf("Expected description 'Get information about a person', got '%s'", decls[0].Description)
111 }
112
113 // Verify the schema properties
114 if decls[0].Parameters.Type != 6 { // DataTypeOBJECT
115 t.Fatalf("Expected type OBJECT (6), got %d", decls[0].Parameters.Type)
116 }
117 if len(decls[0].Parameters.Properties) != 2 {
118 t.Fatalf("Expected 2 properties, got %d", len(decls[0].Parameters.Properties))
119 }
120 if decls[0].Parameters.Properties["name"].Type != 1 { // DataTypeSTRING
121 t.Fatalf("Expected name type STRING (1), got %d", decls[0].Parameters.Properties["name"].Type)
122 }
123 if decls[0].Parameters.Properties["age"].Type != 3 { // DataTypeINTEGER
124 t.Fatalf("Expected age type INTEGER (3), got %d", decls[0].Parameters.Properties["age"].Type)
125 }
126 if len(decls[0].Parameters.Required) != 1 || decls[0].Parameters.Required[0] != "name" {
127 t.Fatalf("Expected required field 'name', got %v", decls[0].Parameters.Required)
128 }
129}
130
131func TestService_Do_MockResponse(t *testing.T) {
132 // This is a mock test that doesn't make actual API calls
133 // Create a mock HTTP client that returns a predefined response
134
135 // Create a Service with a mock client
136 service := &Service{
137 Model: DefaultModel,
138 APIKey: "test-api-key",
139 // We would use a mock HTTP client here in a real test
140 }
141
142 // Create a sample request
143 ir := &llm.Request{
144 Messages: []llm.Message{
145 {
146 Role: llm.MessageRoleUser,
147 Content: []llm.Content{
148 {
149 Type: llm.ContentTypeText,
150 Text: "Hello",
151 },
152 },
153 },
154 },
155 }
156
157 // In a real test, we would execute service.Do with a mock client
158 // and verify the response structure
159
160 // For now, we'll just test that buildGeminiRequest works correctly
161 _, err := service.buildGeminiRequest(ir)
162 if err != nil {
163 t.Fatalf("Failed to build request: %v", err)
164 }
165}
166
167func TestConvertResponseWithToolCall(t *testing.T) {
168 // Create a mock Gemini response with a function call
169 gemRes := &gemini.Response{
170 Candidates: []gemini.Candidate{
171 {
172 Content: gemini.Content{
173 Parts: []gemini.Part{
174 {
175 FunctionCall: &gemini.FunctionCall{
176 Name: "bash",
177 Args: map[string]any{
178 "command": "cat README.md",
179 },
180 },
181 },
182 },
183 },
184 },
185 },
186 }
187
188 // Convert the response
189 content := convertGeminiResponseToContent(gemRes)
190
191 // Verify that content has a tool use
192 if len(content) != 1 {
193 t.Fatalf("Expected 1 content item, got %d", len(content))
194 }
195
196 if content[0].Type != llm.ContentTypeToolUse {
197 t.Fatalf("Expected content type ToolUse, got %s", content[0].Type)
198 }
199
200 if content[0].ToolName != "bash" {
201 t.Fatalf("Expected tool name 'bash', got '%s'", content[0].ToolName)
202 }
203
204 // Verify the tool input
205 var args map[string]any
206 if err := json.Unmarshal(content[0].ToolInput, &args); err != nil {
207 t.Fatalf("Failed to unmarshal tool input: %v", err)
208 }
209
210 cmd, ok := args["command"]
211 if !ok {
212 t.Fatalf("Expected 'command' argument, not found")
213 }
214
215 if cmd != "cat README.md" {
216 t.Fatalf("Expected command 'cat README.md', got '%s'", cmd)
217 }
218}