blob: af8a90459c8d562a89c90b41a0be2fa500105cb6 [file] [log] [blame]
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -07001package conversation
2
3import (
4 "cmp"
5 "context"
6 "net/http"
7 "os"
8 "strings"
9 "testing"
10
11 "sketch.dev/httprr"
12 "sketch.dev/llm/ant"
13)
14
15func TestBasicConvo(t *testing.T) {
16 ctx := context.Background()
17 rr, err := httprr.Open("testdata/basic_convo.httprr", http.DefaultTransport)
18 if err != nil {
19 t.Fatal(err)
20 }
21 rr.ScrubReq(func(req *http.Request) error {
22 req.Header.Del("x-api-key")
23 return nil
24 })
25
David Crawshaw3659d872025-05-05 17:52:23 -070026 apiKey := cmp.Or(os.Getenv("OUTER_SKETCH_MODEL_API_KEY"), os.Getenv("ANTHROPIC_API_KEY"))
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -070027 srv := &ant.Service{
28 APIKey: apiKey,
29 HTTPC: rr.Client(),
30 }
31 convo := New(ctx, srv)
32
33 const name = "Cornelius"
34 res, err := convo.SendUserTextMessage("Hi, my name is " + name)
35 if err != nil {
36 t.Fatal(err)
37 }
38 for _, part := range res.Content {
39 t.Logf("%s", part.Text)
40 }
41 res, err = convo.SendUserTextMessage("What is my name?")
42 if err != nil {
43 t.Fatal(err)
44 }
45 got := ""
46 for _, part := range res.Content {
47 got += part.Text
48 }
49 if !strings.Contains(got, name) {
50 t.Errorf("model does not know the given name %s: %q", name, got)
51 }
52}
53
54// TestCancelToolUse tests the CancelToolUse function of the Convo struct
55func TestCancelToolUse(t *testing.T) {
56 tests := []struct {
57 name string
58 setupToolUse bool
59 toolUseID string
60 cancelErr error
61 expectError bool
62 expectCancel bool
63 }{
64 {
65 name: "Cancel existing tool use",
66 setupToolUse: true,
67 toolUseID: "tool123",
68 cancelErr: nil,
69 expectError: false,
70 expectCancel: true,
71 },
72 {
73 name: "Cancel existing tool use with error",
74 setupToolUse: true,
75 toolUseID: "tool456",
76 cancelErr: context.Canceled,
77 expectError: false,
78 expectCancel: true,
79 },
80 {
81 name: "Cancel non-existent tool use",
82 setupToolUse: false,
83 toolUseID: "tool789",
84 cancelErr: nil,
85 expectError: true,
86 expectCancel: false,
87 },
88 }
89
90 srv := &ant.Service{}
91 for _, tt := range tests {
92 t.Run(tt.name, func(t *testing.T) {
93 convo := New(context.Background(), srv)
94
95 var cancelCalled bool
96 var cancelledWithErr error
97
98 if tt.setupToolUse {
99 // Setup a mock cancel function to track calls
100 mockCancel := func(err error) {
101 cancelCalled = true
102 cancelledWithErr = err
103 }
104
105 convo.muToolUseCancel.Lock()
106 convo.toolUseCancel[tt.toolUseID] = mockCancel
107 convo.muToolUseCancel.Unlock()
108 }
109
110 err := convo.CancelToolUse(tt.toolUseID, tt.cancelErr)
111
112 // Check if we got the expected error state
113 if (err != nil) != tt.expectError {
114 t.Errorf("CancelToolUse() error = %v, expectError %v", err, tt.expectError)
115 }
116
117 // Check if the cancel function was called as expected
118 if cancelCalled != tt.expectCancel {
119 t.Errorf("Cancel function called = %v, expectCancel %v", cancelCalled, tt.expectCancel)
120 }
121
122 // If we expected the cancel to be called, verify it was called with the right error
123 if tt.expectCancel && cancelledWithErr != tt.cancelErr {
124 t.Errorf("Cancel function called with error = %v, expected %v", cancelledWithErr, tt.cancelErr)
125 }
126
127 // Verify the toolUseID was removed from the map if it was initially added
128 if tt.setupToolUse {
129 convo.muToolUseCancel.Lock()
130 _, exists := convo.toolUseCancel[tt.toolUseID]
131 convo.muToolUseCancel.Unlock()
132
133 if exists {
134 t.Errorf("toolUseID %s still exists in the map after cancellation", tt.toolUseID)
135 }
136 }
137 })
138 }
139}