| Josh Bleecher Snyder | 4f84ab7 | 2025-04-22 16:40:54 -0700 | [diff] [blame] | 1 | package conversation |
| 2 | |
| 3 | import ( |
| 4 | "cmp" |
| 5 | "context" |
| 6 | "net/http" |
| 7 | "os" |
| Josh Bleecher Snyder | a3e28fb | 2025-05-30 15:53:23 +0000 | [diff] [blame] | 8 | "slices" |
| Josh Bleecher Snyder | 4f84ab7 | 2025-04-22 16:40:54 -0700 | [diff] [blame] | 9 | "strings" |
| 10 | "testing" |
| 11 | |
| 12 | "sketch.dev/httprr" |
| Josh Bleecher Snyder | a3e28fb | 2025-05-30 15:53:23 +0000 | [diff] [blame] | 13 | "sketch.dev/llm" |
| Josh Bleecher Snyder | 4f84ab7 | 2025-04-22 16:40:54 -0700 | [diff] [blame] | 14 | "sketch.dev/llm/ant" |
| 15 | ) |
| 16 | |
| 17 | func TestBasicConvo(t *testing.T) { |
| 18 | ctx := context.Background() |
| 19 | rr, err := httprr.Open("testdata/basic_convo.httprr", http.DefaultTransport) |
| 20 | if err != nil { |
| 21 | t.Fatal(err) |
| 22 | } |
| 23 | rr.ScrubReq(func(req *http.Request) error { |
| 24 | req.Header.Del("x-api-key") |
| 25 | return nil |
| 26 | }) |
| 27 | |
| David Crawshaw | 3659d87 | 2025-05-05 17:52:23 -0700 | [diff] [blame] | 28 | apiKey := cmp.Or(os.Getenv("OUTER_SKETCH_MODEL_API_KEY"), os.Getenv("ANTHROPIC_API_KEY")) |
| Josh Bleecher Snyder | 4f84ab7 | 2025-04-22 16:40:54 -0700 | [diff] [blame] | 29 | srv := &ant.Service{ |
| 30 | APIKey: apiKey, |
| 31 | HTTPC: rr.Client(), |
| 32 | } |
| philip.zeyliger | 882e7ea | 2025-06-20 14:31:16 +0000 | [diff] [blame] | 33 | convo := New(ctx, srv, nil) |
| Josh Bleecher Snyder | 4f84ab7 | 2025-04-22 16:40:54 -0700 | [diff] [blame] | 34 | |
| 35 | const name = "Cornelius" |
| 36 | res, err := convo.SendUserTextMessage("Hi, my name is " + name) |
| 37 | if err != nil { |
| 38 | t.Fatal(err) |
| 39 | } |
| 40 | for _, part := range res.Content { |
| 41 | t.Logf("%s", part.Text) |
| 42 | } |
| 43 | res, err = convo.SendUserTextMessage("What is my name?") |
| 44 | if err != nil { |
| 45 | t.Fatal(err) |
| 46 | } |
| 47 | got := "" |
| 48 | for _, part := range res.Content { |
| 49 | got += part.Text |
| 50 | } |
| 51 | if !strings.Contains(got, name) { |
| 52 | t.Errorf("model does not know the given name %s: %q", name, got) |
| 53 | } |
| 54 | } |
| 55 | |
| 56 | // TestCancelToolUse tests the CancelToolUse function of the Convo struct |
| 57 | func TestCancelToolUse(t *testing.T) { |
| 58 | tests := []struct { |
| 59 | name string |
| 60 | setupToolUse bool |
| 61 | toolUseID string |
| 62 | cancelErr error |
| 63 | expectError bool |
| 64 | expectCancel bool |
| 65 | }{ |
| 66 | { |
| 67 | name: "Cancel existing tool use", |
| 68 | setupToolUse: true, |
| 69 | toolUseID: "tool123", |
| 70 | cancelErr: nil, |
| 71 | expectError: false, |
| 72 | expectCancel: true, |
| 73 | }, |
| 74 | { |
| 75 | name: "Cancel existing tool use with error", |
| 76 | setupToolUse: true, |
| 77 | toolUseID: "tool456", |
| 78 | cancelErr: context.Canceled, |
| 79 | expectError: false, |
| 80 | expectCancel: true, |
| 81 | }, |
| 82 | { |
| 83 | name: "Cancel non-existent tool use", |
| 84 | setupToolUse: false, |
| 85 | toolUseID: "tool789", |
| 86 | cancelErr: nil, |
| 87 | expectError: true, |
| 88 | expectCancel: false, |
| 89 | }, |
| 90 | } |
| 91 | |
| 92 | srv := &ant.Service{} |
| 93 | for _, tt := range tests { |
| 94 | t.Run(tt.name, func(t *testing.T) { |
| philip.zeyliger | 882e7ea | 2025-06-20 14:31:16 +0000 | [diff] [blame] | 95 | convo := New(context.Background(), srv, nil) |
| Josh Bleecher Snyder | 4f84ab7 | 2025-04-22 16:40:54 -0700 | [diff] [blame] | 96 | |
| 97 | var cancelCalled bool |
| 98 | var cancelledWithErr error |
| 99 | |
| 100 | if tt.setupToolUse { |
| 101 | // Setup a mock cancel function to track calls |
| 102 | mockCancel := func(err error) { |
| 103 | cancelCalled = true |
| 104 | cancelledWithErr = err |
| 105 | } |
| 106 | |
| Josh Bleecher Snyder | 495c1fa | 2025-05-29 00:37:22 +0000 | [diff] [blame] | 107 | convo.toolUseCancelMu.Lock() |
| Josh Bleecher Snyder | 4f84ab7 | 2025-04-22 16:40:54 -0700 | [diff] [blame] | 108 | convo.toolUseCancel[tt.toolUseID] = mockCancel |
| Josh Bleecher Snyder | 495c1fa | 2025-05-29 00:37:22 +0000 | [diff] [blame] | 109 | convo.toolUseCancelMu.Unlock() |
| Josh Bleecher Snyder | 4f84ab7 | 2025-04-22 16:40:54 -0700 | [diff] [blame] | 110 | } |
| 111 | |
| 112 | err := convo.CancelToolUse(tt.toolUseID, tt.cancelErr) |
| 113 | |
| 114 | // Check if we got the expected error state |
| 115 | if (err != nil) != tt.expectError { |
| 116 | t.Errorf("CancelToolUse() error = %v, expectError %v", err, tt.expectError) |
| 117 | } |
| 118 | |
| 119 | // Check if the cancel function was called as expected |
| 120 | if cancelCalled != tt.expectCancel { |
| 121 | t.Errorf("Cancel function called = %v, expectCancel %v", cancelCalled, tt.expectCancel) |
| 122 | } |
| 123 | |
| 124 | // If we expected the cancel to be called, verify it was called with the right error |
| 125 | if tt.expectCancel && cancelledWithErr != tt.cancelErr { |
| 126 | t.Errorf("Cancel function called with error = %v, expected %v", cancelledWithErr, tt.cancelErr) |
| 127 | } |
| 128 | |
| 129 | // Verify the toolUseID was removed from the map if it was initially added |
| 130 | if tt.setupToolUse { |
| Josh Bleecher Snyder | 495c1fa | 2025-05-29 00:37:22 +0000 | [diff] [blame] | 131 | convo.toolUseCancelMu.Lock() |
| Josh Bleecher Snyder | 4f84ab7 | 2025-04-22 16:40:54 -0700 | [diff] [blame] | 132 | _, exists := convo.toolUseCancel[tt.toolUseID] |
| Josh Bleecher Snyder | 495c1fa | 2025-05-29 00:37:22 +0000 | [diff] [blame] | 133 | convo.toolUseCancelMu.Unlock() |
| Josh Bleecher Snyder | 4f84ab7 | 2025-04-22 16:40:54 -0700 | [diff] [blame] | 134 | |
| 135 | if exists { |
| 136 | t.Errorf("toolUseID %s still exists in the map after cancellation", tt.toolUseID) |
| 137 | } |
| 138 | } |
| 139 | }) |
| 140 | } |
| 141 | } |
| Josh Bleecher Snyder | a3e28fb | 2025-05-30 15:53:23 +0000 | [diff] [blame] | 142 | |
| 143 | // TestInsertMissingToolResults tests the insertMissingToolResults function |
| 144 | // to ensure it doesn't create duplicate tool results when multiple tool uses are missing results. |
| 145 | func TestInsertMissingToolResults(t *testing.T) { |
| 146 | tests := []struct { |
| 147 | name string |
| 148 | messages []llm.Message |
| 149 | currentMsg llm.Message |
| 150 | expectedCount int |
| 151 | expectedToolIDs []string |
| 152 | }{ |
| 153 | { |
| 154 | name: "Single missing tool result", |
| 155 | messages: []llm.Message{ |
| 156 | { |
| 157 | Role: llm.MessageRoleAssistant, |
| 158 | Content: []llm.Content{ |
| 159 | { |
| 160 | Type: llm.ContentTypeToolUse, |
| 161 | ID: "tool1", |
| 162 | }, |
| 163 | }, |
| 164 | }, |
| 165 | }, |
| 166 | currentMsg: llm.Message{ |
| 167 | Role: llm.MessageRoleUser, |
| 168 | Content: []llm.Content{}, |
| 169 | }, |
| 170 | expectedCount: 1, |
| 171 | expectedToolIDs: []string{"tool1"}, |
| 172 | }, |
| 173 | { |
| 174 | name: "Multiple missing tool results", |
| 175 | messages: []llm.Message{ |
| 176 | { |
| 177 | Role: llm.MessageRoleAssistant, |
| 178 | Content: []llm.Content{ |
| 179 | { |
| 180 | Type: llm.ContentTypeToolUse, |
| 181 | ID: "tool1", |
| 182 | }, |
| 183 | { |
| 184 | Type: llm.ContentTypeToolUse, |
| 185 | ID: "tool2", |
| 186 | }, |
| 187 | { |
| 188 | Type: llm.ContentTypeToolUse, |
| 189 | ID: "tool3", |
| 190 | }, |
| 191 | }, |
| 192 | }, |
| 193 | }, |
| 194 | currentMsg: llm.Message{ |
| 195 | Role: llm.MessageRoleUser, |
| 196 | Content: []llm.Content{}, |
| 197 | }, |
| 198 | expectedCount: 3, |
| 199 | expectedToolIDs: []string{"tool1", "tool2", "tool3"}, |
| 200 | }, |
| 201 | { |
| 202 | name: "No missing tool results when results already present", |
| 203 | messages: []llm.Message{ |
| 204 | { |
| 205 | Role: llm.MessageRoleAssistant, |
| 206 | Content: []llm.Content{ |
| 207 | { |
| 208 | Type: llm.ContentTypeToolUse, |
| 209 | ID: "tool1", |
| 210 | }, |
| 211 | }, |
| 212 | }, |
| 213 | }, |
| 214 | currentMsg: llm.Message{ |
| 215 | Role: llm.MessageRoleUser, |
| 216 | Content: []llm.Content{ |
| 217 | { |
| 218 | Type: llm.ContentTypeToolResult, |
| 219 | ToolUseID: "tool1", |
| 220 | }, |
| 221 | }, |
| 222 | }, |
| 223 | expectedCount: 1, // Only the existing one |
| 224 | expectedToolIDs: []string{"tool1"}, |
| 225 | }, |
| 226 | { |
| 227 | name: "No tool uses in previous message", |
| 228 | messages: []llm.Message{ |
| 229 | { |
| 230 | Role: llm.MessageRoleAssistant, |
| 231 | Content: []llm.Content{ |
| 232 | { |
| 233 | Type: llm.ContentTypeText, |
| 234 | Text: "Just some text", |
| 235 | }, |
| 236 | }, |
| 237 | }, |
| 238 | }, |
| 239 | currentMsg: llm.Message{ |
| 240 | Role: llm.MessageRoleUser, |
| 241 | Content: []llm.Content{}, |
| 242 | }, |
| 243 | expectedCount: 0, |
| 244 | expectedToolIDs: []string{}, |
| 245 | }, |
| 246 | } |
| 247 | |
| 248 | for _, tt := range tests { |
| 249 | t.Run(tt.name, func(t *testing.T) { |
| 250 | srv := &ant.Service{} |
| philip.zeyliger | 882e7ea | 2025-06-20 14:31:16 +0000 | [diff] [blame] | 251 | convo := New(context.Background(), srv, nil) |
| Josh Bleecher Snyder | a3e28fb | 2025-05-30 15:53:23 +0000 | [diff] [blame] | 252 | |
| 253 | // Create request with messages |
| 254 | req := &llm.Request{ |
| 255 | Messages: append(tt.messages, tt.currentMsg), |
| 256 | } |
| 257 | |
| 258 | // Call insertMissingToolResults |
| 259 | msg := tt.currentMsg |
| 260 | convo.insertMissingToolResults(req, &msg) |
| 261 | |
| 262 | // Count tool results in the message |
| 263 | toolResultCount := 0 |
| 264 | toolIDs := []string{} |
| 265 | for _, content := range msg.Content { |
| 266 | if content.Type == llm.ContentTypeToolResult { |
| 267 | toolResultCount++ |
| 268 | toolIDs = append(toolIDs, content.ToolUseID) |
| 269 | } |
| 270 | } |
| 271 | |
| 272 | // Verify count |
| 273 | if toolResultCount != tt.expectedCount { |
| 274 | t.Errorf("Expected %d tool results, got %d", tt.expectedCount, toolResultCount) |
| 275 | } |
| 276 | |
| 277 | // Verify no duplicates by checking unique tool IDs |
| 278 | seenIDs := make(map[string]int) |
| 279 | for _, id := range toolIDs { |
| 280 | seenIDs[id]++ |
| 281 | } |
| 282 | |
| 283 | // Check for duplicates |
| 284 | for id, count := range seenIDs { |
| 285 | if count > 1 { |
| 286 | t.Errorf("Duplicate tool result for ID %s: found %d times", id, count) |
| 287 | } |
| 288 | } |
| 289 | |
| 290 | // Verify all expected tool IDs are present |
| 291 | for _, expectedID := range tt.expectedToolIDs { |
| 292 | if !slices.Contains(toolIDs, expectedID) { |
| 293 | t.Errorf("Expected tool ID %s not found in results", expectedID) |
| 294 | } |
| 295 | } |
| 296 | }) |
| 297 | } |
| 298 | } |