blob: fcce0cddd22f94bf652a45e0d81b9a27dac0accf [file] [log] [blame]
Earl Lee2e463fb2025-04-17 11:22:22 -07001package ant
2
3import (
Josh Bleecher Snyder4d5e9972025-05-01 15:56:37 -07004 "cmp"
Earl Lee2e463fb2025-04-17 11:22:22 -07005 "context"
6 "math"
7 "net/http"
8 "os"
9 "strings"
10 "testing"
11
12 "sketch.dev/httprr"
13)
14
15func TestBasicConvo(t *testing.T) {
16 ctx := context.Background()
17 rr, err := httprr.Open("testdata/basic_convo.httprr", http.DefaultTransport)
18 if err != nil {
19 t.Fatal(err)
20 }
21 rr.ScrubReq(func(req *http.Request) error {
22 req.Header.Del("x-api-key")
23 return nil
24 })
25
Josh Bleecher Snyder4d5e9972025-05-01 15:56:37 -070026 apiKey := cmp.Or(os.Getenv("OUTER_SKETCH_ANTHROPIC_API_KEY"), os.Getenv("ANTHROPIC_API_KEY"))
27 convo := NewConvo(ctx, apiKey)
Earl Lee2e463fb2025-04-17 11:22:22 -070028 convo.HTTPC = rr.Client()
29
30 const name = "Cornelius"
31 res, err := convo.SendUserTextMessage("Hi, my name is " + name)
32 if err != nil {
33 t.Fatal(err)
34 }
35 for _, part := range res.Content {
36 t.Logf("%s", part.Text)
37 }
38 res, err = convo.SendUserTextMessage("What is my name?")
39 if err != nil {
40 t.Fatal(err)
41 }
42 got := ""
43 for _, part := range res.Content {
44 got += part.Text
45 }
46 if !strings.Contains(got, name) {
47 t.Errorf("model does not know the given name %s: %q", name, got)
48 }
49}
50
51// TestCalculateCostFromTokens tests the calculateCostFromTokens function
52func TestCalculateCostFromTokens(t *testing.T) {
53 tests := []struct {
54 name string
55 model string
56 inputTokens uint64
57 outputTokens uint64
58 cacheReadInputTokens uint64
59 cacheCreationInputTokens uint64
60 want float64
61 }{
62 {
63 name: "Zero tokens",
64 model: Claude37Sonnet,
65 inputTokens: 0,
66 outputTokens: 0,
67 cacheReadInputTokens: 0,
68 cacheCreationInputTokens: 0,
69 want: 0,
70 },
71 {
72 name: "1000 input tokens, 500 output tokens",
73 model: Claude37Sonnet,
74 inputTokens: 1000,
75 outputTokens: 500,
76 cacheReadInputTokens: 0,
77 cacheCreationInputTokens: 0,
78 want: 0.0105,
79 },
80 {
81 name: "10000 input tokens, 5000 output tokens",
82 model: Claude37Sonnet,
83 inputTokens: 10000,
84 outputTokens: 5000,
85 cacheReadInputTokens: 0,
86 cacheCreationInputTokens: 0,
87 want: 0.105,
88 },
89 {
90 name: "With cache read tokens",
91 model: Claude37Sonnet,
92 inputTokens: 1000,
93 outputTokens: 500,
94 cacheReadInputTokens: 2000,
95 cacheCreationInputTokens: 0,
96 want: 0.0111,
97 },
98 {
99 name: "With cache creation tokens",
100 model: Claude37Sonnet,
101 inputTokens: 1000,
102 outputTokens: 500,
103 cacheReadInputTokens: 0,
104 cacheCreationInputTokens: 1500,
105 want: 0.016125,
106 },
107 {
108 name: "With all token types",
109 model: Claude37Sonnet,
110 inputTokens: 1000,
111 outputTokens: 500,
112 cacheReadInputTokens: 2000,
113 cacheCreationInputTokens: 1500,
114 want: 0.016725,
115 },
116 }
117
118 for _, tt := range tests {
119 t.Run(tt.name, func(t *testing.T) {
120 usage := Usage{
121 InputTokens: tt.inputTokens,
122 OutputTokens: tt.outputTokens,
123 CacheReadInputTokens: tt.cacheReadInputTokens,
124 CacheCreationInputTokens: tt.cacheCreationInputTokens,
125 }
126 mr := MessageResponse{
127 Model: tt.model,
128 Usage: usage,
129 }
130 totalCost := mr.TotalDollars()
131 if math.Abs(totalCost-tt.want) > 0.0001 {
132 t.Errorf("totalCost = %v, want %v", totalCost, tt.want)
133 }
134 })
135 }
136}
137
138// TestCancelToolUse tests the CancelToolUse function of the Convo struct
139func TestCancelToolUse(t *testing.T) {
140 tests := []struct {
141 name string
142 setupToolUse bool
143 toolUseID string
144 cancelErr error
145 expectError bool
146 expectCancel bool
147 }{
148 {
149 name: "Cancel existing tool use",
150 setupToolUse: true,
151 toolUseID: "tool123",
152 cancelErr: nil,
153 expectError: false,
154 expectCancel: true,
155 },
156 {
157 name: "Cancel existing tool use with error",
158 setupToolUse: true,
159 toolUseID: "tool456",
160 cancelErr: context.Canceled,
161 expectError: false,
162 expectCancel: true,
163 },
164 {
165 name: "Cancel non-existent tool use",
166 setupToolUse: false,
167 toolUseID: "tool789",
168 cancelErr: nil,
169 expectError: true,
170 expectCancel: false,
171 },
172 }
173
174 for _, tt := range tests {
175 t.Run(tt.name, func(t *testing.T) {
176 convo := NewConvo(context.Background(), "")
177
178 var cancelCalled bool
179 var cancelledWithErr error
180
181 if tt.setupToolUse {
182 // Setup a mock cancel function to track calls
183 mockCancel := func(err error) {
184 cancelCalled = true
185 cancelledWithErr = err
186 }
187
188 convo.muToolUseCancel.Lock()
189 convo.toolUseCancel[tt.toolUseID] = mockCancel
190 convo.muToolUseCancel.Unlock()
191 }
192
193 err := convo.CancelToolUse(tt.toolUseID, tt.cancelErr)
194
195 // Check if we got the expected error state
196 if (err != nil) != tt.expectError {
197 t.Errorf("CancelToolUse() error = %v, expectError %v", err, tt.expectError)
198 }
199
200 // Check if the cancel function was called as expected
201 if cancelCalled != tt.expectCancel {
202 t.Errorf("Cancel function called = %v, expectCancel %v", cancelCalled, tt.expectCancel)
203 }
204
205 // If we expected the cancel to be called, verify it was called with the right error
206 if tt.expectCancel && cancelledWithErr != tt.cancelErr {
207 t.Errorf("Cancel function called with error = %v, expected %v", cancelledWithErr, tt.cancelErr)
208 }
209
210 // Verify the toolUseID was removed from the map if it was initially added
211 if tt.setupToolUse {
212 convo.muToolUseCancel.Lock()
213 _, exists := convo.toolUseCancel[tt.toolUseID]
214 convo.muToolUseCancel.Unlock()
215
216 if exists {
217 t.Errorf("toolUseID %s still exists in the map after cancellation", tt.toolUseID)
218 }
219 }
220 })
221 }
222}