blob: 694b977c767a8ab63b245224593f06e0e1c0171c [file] [log] [blame]
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -07001package conversation
2
3import (
4 "cmp"
5 "context"
6 "net/http"
7 "os"
Josh Bleecher Snydera3e28fb2025-05-30 15:53:23 +00008 "slices"
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -07009 "strings"
10 "testing"
11
12 "sketch.dev/httprr"
Josh Bleecher Snydera3e28fb2025-05-30 15:53:23 +000013 "sketch.dev/llm"
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -070014 "sketch.dev/llm/ant"
15)
16
17func TestBasicConvo(t *testing.T) {
18 ctx := context.Background()
19 rr, err := httprr.Open("testdata/basic_convo.httprr", http.DefaultTransport)
20 if err != nil {
21 t.Fatal(err)
22 }
23 rr.ScrubReq(func(req *http.Request) error {
24 req.Header.Del("x-api-key")
25 return nil
26 })
27
David Crawshaw3659d872025-05-05 17:52:23 -070028 apiKey := cmp.Or(os.Getenv("OUTER_SKETCH_MODEL_API_KEY"), os.Getenv("ANTHROPIC_API_KEY"))
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -070029 srv := &ant.Service{
30 APIKey: apiKey,
31 HTTPC: rr.Client(),
32 }
philip.zeyliger882e7ea2025-06-20 14:31:16 +000033 convo := New(ctx, srv, nil)
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -070034
35 const name = "Cornelius"
36 res, err := convo.SendUserTextMessage("Hi, my name is " + name)
37 if err != nil {
38 t.Fatal(err)
39 }
40 for _, part := range res.Content {
41 t.Logf("%s", part.Text)
42 }
43 res, err = convo.SendUserTextMessage("What is my name?")
44 if err != nil {
45 t.Fatal(err)
46 }
47 got := ""
48 for _, part := range res.Content {
49 got += part.Text
50 }
51 if !strings.Contains(got, name) {
52 t.Errorf("model does not know the given name %s: %q", name, got)
53 }
54}
55
56// TestCancelToolUse tests the CancelToolUse function of the Convo struct
57func TestCancelToolUse(t *testing.T) {
58 tests := []struct {
59 name string
60 setupToolUse bool
61 toolUseID string
62 cancelErr error
63 expectError bool
64 expectCancel bool
65 }{
66 {
67 name: "Cancel existing tool use",
68 setupToolUse: true,
69 toolUseID: "tool123",
70 cancelErr: nil,
71 expectError: false,
72 expectCancel: true,
73 },
74 {
75 name: "Cancel existing tool use with error",
76 setupToolUse: true,
77 toolUseID: "tool456",
78 cancelErr: context.Canceled,
79 expectError: false,
80 expectCancel: true,
81 },
82 {
83 name: "Cancel non-existent tool use",
84 setupToolUse: false,
85 toolUseID: "tool789",
86 cancelErr: nil,
87 expectError: true,
88 expectCancel: false,
89 },
90 }
91
92 srv := &ant.Service{}
93 for _, tt := range tests {
94 t.Run(tt.name, func(t *testing.T) {
philip.zeyliger882e7ea2025-06-20 14:31:16 +000095 convo := New(context.Background(), srv, nil)
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -070096
97 var cancelCalled bool
98 var cancelledWithErr error
99
100 if tt.setupToolUse {
101 // Setup a mock cancel function to track calls
102 mockCancel := func(err error) {
103 cancelCalled = true
104 cancelledWithErr = err
105 }
106
Josh Bleecher Snyder495c1fa2025-05-29 00:37:22 +0000107 convo.toolUseCancelMu.Lock()
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700108 convo.toolUseCancel[tt.toolUseID] = mockCancel
Josh Bleecher Snyder495c1fa2025-05-29 00:37:22 +0000109 convo.toolUseCancelMu.Unlock()
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700110 }
111
112 err := convo.CancelToolUse(tt.toolUseID, tt.cancelErr)
113
114 // Check if we got the expected error state
115 if (err != nil) != tt.expectError {
116 t.Errorf("CancelToolUse() error = %v, expectError %v", err, tt.expectError)
117 }
118
119 // Check if the cancel function was called as expected
120 if cancelCalled != tt.expectCancel {
121 t.Errorf("Cancel function called = %v, expectCancel %v", cancelCalled, tt.expectCancel)
122 }
123
124 // If we expected the cancel to be called, verify it was called with the right error
125 if tt.expectCancel && cancelledWithErr != tt.cancelErr {
126 t.Errorf("Cancel function called with error = %v, expected %v", cancelledWithErr, tt.cancelErr)
127 }
128
129 // Verify the toolUseID was removed from the map if it was initially added
130 if tt.setupToolUse {
Josh Bleecher Snyder495c1fa2025-05-29 00:37:22 +0000131 convo.toolUseCancelMu.Lock()
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700132 _, exists := convo.toolUseCancel[tt.toolUseID]
Josh Bleecher Snyder495c1fa2025-05-29 00:37:22 +0000133 convo.toolUseCancelMu.Unlock()
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700134
135 if exists {
136 t.Errorf("toolUseID %s still exists in the map after cancellation", tt.toolUseID)
137 }
138 }
139 })
140 }
141}
Josh Bleecher Snydera3e28fb2025-05-30 15:53:23 +0000142
143// TestInsertMissingToolResults tests the insertMissingToolResults function
144// to ensure it doesn't create duplicate tool results when multiple tool uses are missing results.
145func TestInsertMissingToolResults(t *testing.T) {
146 tests := []struct {
147 name string
148 messages []llm.Message
149 currentMsg llm.Message
150 expectedCount int
151 expectedToolIDs []string
152 }{
153 {
154 name: "Single missing tool result",
155 messages: []llm.Message{
156 {
157 Role: llm.MessageRoleAssistant,
158 Content: []llm.Content{
159 {
160 Type: llm.ContentTypeToolUse,
161 ID: "tool1",
162 },
163 },
164 },
165 },
166 currentMsg: llm.Message{
167 Role: llm.MessageRoleUser,
168 Content: []llm.Content{},
169 },
170 expectedCount: 1,
171 expectedToolIDs: []string{"tool1"},
172 },
173 {
174 name: "Multiple missing tool results",
175 messages: []llm.Message{
176 {
177 Role: llm.MessageRoleAssistant,
178 Content: []llm.Content{
179 {
180 Type: llm.ContentTypeToolUse,
181 ID: "tool1",
182 },
183 {
184 Type: llm.ContentTypeToolUse,
185 ID: "tool2",
186 },
187 {
188 Type: llm.ContentTypeToolUse,
189 ID: "tool3",
190 },
191 },
192 },
193 },
194 currentMsg: llm.Message{
195 Role: llm.MessageRoleUser,
196 Content: []llm.Content{},
197 },
198 expectedCount: 3,
199 expectedToolIDs: []string{"tool1", "tool2", "tool3"},
200 },
201 {
202 name: "No missing tool results when results already present",
203 messages: []llm.Message{
204 {
205 Role: llm.MessageRoleAssistant,
206 Content: []llm.Content{
207 {
208 Type: llm.ContentTypeToolUse,
209 ID: "tool1",
210 },
211 },
212 },
213 },
214 currentMsg: llm.Message{
215 Role: llm.MessageRoleUser,
216 Content: []llm.Content{
217 {
218 Type: llm.ContentTypeToolResult,
219 ToolUseID: "tool1",
220 },
221 },
222 },
223 expectedCount: 1, // Only the existing one
224 expectedToolIDs: []string{"tool1"},
225 },
226 {
227 name: "No tool uses in previous message",
228 messages: []llm.Message{
229 {
230 Role: llm.MessageRoleAssistant,
231 Content: []llm.Content{
232 {
233 Type: llm.ContentTypeText,
234 Text: "Just some text",
235 },
236 },
237 },
238 },
239 currentMsg: llm.Message{
240 Role: llm.MessageRoleUser,
241 Content: []llm.Content{},
242 },
243 expectedCount: 0,
244 expectedToolIDs: []string{},
245 },
246 }
247
248 for _, tt := range tests {
249 t.Run(tt.name, func(t *testing.T) {
250 srv := &ant.Service{}
philip.zeyliger882e7ea2025-06-20 14:31:16 +0000251 convo := New(context.Background(), srv, nil)
Josh Bleecher Snydera3e28fb2025-05-30 15:53:23 +0000252
253 // Create request with messages
254 req := &llm.Request{
255 Messages: append(tt.messages, tt.currentMsg),
256 }
257
258 // Call insertMissingToolResults
259 msg := tt.currentMsg
260 convo.insertMissingToolResults(req, &msg)
261
262 // Count tool results in the message
263 toolResultCount := 0
264 toolIDs := []string{}
265 for _, content := range msg.Content {
266 if content.Type == llm.ContentTypeToolResult {
267 toolResultCount++
268 toolIDs = append(toolIDs, content.ToolUseID)
269 }
270 }
271
272 // Verify count
273 if toolResultCount != tt.expectedCount {
274 t.Errorf("Expected %d tool results, got %d", tt.expectedCount, toolResultCount)
275 }
276
277 // Verify no duplicates by checking unique tool IDs
278 seenIDs := make(map[string]int)
279 for _, id := range toolIDs {
280 seenIDs[id]++
281 }
282
283 // Check for duplicates
284 for id, count := range seenIDs {
285 if count > 1 {
286 t.Errorf("Duplicate tool result for ID %s: found %d times", id, count)
287 }
288 }
289
290 // Verify all expected tool IDs are present
291 for _, expectedID := range tt.expectedToolIDs {
292 if !slices.Contains(toolIDs, expectedID) {
293 t.Errorf("Expected tool ID %s not found in results", expectedID)
294 }
295 }
296 })
297 }
298}