blob: 6556e9e9ff87d64940c605ccb2f1eeb42ef09eef [file] [log] [blame]
Earl Lee2e463fb2025-04-17 11:22:22 -07001package ant
2
3import (
4 "context"
5 "math"
6 "net/http"
7 "os"
8 "strings"
9 "testing"
10
11 "sketch.dev/httprr"
12)
13
14func TestBasicConvo(t *testing.T) {
15 ctx := context.Background()
16 rr, err := httprr.Open("testdata/basic_convo.httprr", http.DefaultTransport)
17 if err != nil {
18 t.Fatal(err)
19 }
20 rr.ScrubReq(func(req *http.Request) error {
21 req.Header.Del("x-api-key")
22 return nil
23 })
24
25 convo := NewConvo(ctx, os.Getenv("ANTHROPIC_API_KEY"))
26 convo.HTTPC = rr.Client()
27
28 const name = "Cornelius"
29 res, err := convo.SendUserTextMessage("Hi, my name is " + name)
30 if err != nil {
31 t.Fatal(err)
32 }
33 for _, part := range res.Content {
34 t.Logf("%s", part.Text)
35 }
36 res, err = convo.SendUserTextMessage("What is my name?")
37 if err != nil {
38 t.Fatal(err)
39 }
40 got := ""
41 for _, part := range res.Content {
42 got += part.Text
43 }
44 if !strings.Contains(got, name) {
45 t.Errorf("model does not know the given name %s: %q", name, got)
46 }
47}
48
49// TestCalculateCostFromTokens tests the calculateCostFromTokens function
50func TestCalculateCostFromTokens(t *testing.T) {
51 tests := []struct {
52 name string
53 model string
54 inputTokens uint64
55 outputTokens uint64
56 cacheReadInputTokens uint64
57 cacheCreationInputTokens uint64
58 want float64
59 }{
60 {
61 name: "Zero tokens",
62 model: Claude37Sonnet,
63 inputTokens: 0,
64 outputTokens: 0,
65 cacheReadInputTokens: 0,
66 cacheCreationInputTokens: 0,
67 want: 0,
68 },
69 {
70 name: "1000 input tokens, 500 output tokens",
71 model: Claude37Sonnet,
72 inputTokens: 1000,
73 outputTokens: 500,
74 cacheReadInputTokens: 0,
75 cacheCreationInputTokens: 0,
76 want: 0.0105,
77 },
78 {
79 name: "10000 input tokens, 5000 output tokens",
80 model: Claude37Sonnet,
81 inputTokens: 10000,
82 outputTokens: 5000,
83 cacheReadInputTokens: 0,
84 cacheCreationInputTokens: 0,
85 want: 0.105,
86 },
87 {
88 name: "With cache read tokens",
89 model: Claude37Sonnet,
90 inputTokens: 1000,
91 outputTokens: 500,
92 cacheReadInputTokens: 2000,
93 cacheCreationInputTokens: 0,
94 want: 0.0111,
95 },
96 {
97 name: "With cache creation tokens",
98 model: Claude37Sonnet,
99 inputTokens: 1000,
100 outputTokens: 500,
101 cacheReadInputTokens: 0,
102 cacheCreationInputTokens: 1500,
103 want: 0.016125,
104 },
105 {
106 name: "With all token types",
107 model: Claude37Sonnet,
108 inputTokens: 1000,
109 outputTokens: 500,
110 cacheReadInputTokens: 2000,
111 cacheCreationInputTokens: 1500,
112 want: 0.016725,
113 },
114 }
115
116 for _, tt := range tests {
117 t.Run(tt.name, func(t *testing.T) {
118 usage := Usage{
119 InputTokens: tt.inputTokens,
120 OutputTokens: tt.outputTokens,
121 CacheReadInputTokens: tt.cacheReadInputTokens,
122 CacheCreationInputTokens: tt.cacheCreationInputTokens,
123 }
124 mr := MessageResponse{
125 Model: tt.model,
126 Usage: usage,
127 }
128 totalCost := mr.TotalDollars()
129 if math.Abs(totalCost-tt.want) > 0.0001 {
130 t.Errorf("totalCost = %v, want %v", totalCost, tt.want)
131 }
132 })
133 }
134}
135
136// TestCancelToolUse tests the CancelToolUse function of the Convo struct
137func TestCancelToolUse(t *testing.T) {
138 tests := []struct {
139 name string
140 setupToolUse bool
141 toolUseID string
142 cancelErr error
143 expectError bool
144 expectCancel bool
145 }{
146 {
147 name: "Cancel existing tool use",
148 setupToolUse: true,
149 toolUseID: "tool123",
150 cancelErr: nil,
151 expectError: false,
152 expectCancel: true,
153 },
154 {
155 name: "Cancel existing tool use with error",
156 setupToolUse: true,
157 toolUseID: "tool456",
158 cancelErr: context.Canceled,
159 expectError: false,
160 expectCancel: true,
161 },
162 {
163 name: "Cancel non-existent tool use",
164 setupToolUse: false,
165 toolUseID: "tool789",
166 cancelErr: nil,
167 expectError: true,
168 expectCancel: false,
169 },
170 }
171
172 for _, tt := range tests {
173 t.Run(tt.name, func(t *testing.T) {
174 convo := NewConvo(context.Background(), "")
175
176 var cancelCalled bool
177 var cancelledWithErr error
178
179 if tt.setupToolUse {
180 // Setup a mock cancel function to track calls
181 mockCancel := func(err error) {
182 cancelCalled = true
183 cancelledWithErr = err
184 }
185
186 convo.muToolUseCancel.Lock()
187 convo.toolUseCancel[tt.toolUseID] = mockCancel
188 convo.muToolUseCancel.Unlock()
189 }
190
191 err := convo.CancelToolUse(tt.toolUseID, tt.cancelErr)
192
193 // Check if we got the expected error state
194 if (err != nil) != tt.expectError {
195 t.Errorf("CancelToolUse() error = %v, expectError %v", err, tt.expectError)
196 }
197
198 // Check if the cancel function was called as expected
199 if cancelCalled != tt.expectCancel {
200 t.Errorf("Cancel function called = %v, expectCancel %v", cancelCalled, tt.expectCancel)
201 }
202
203 // If we expected the cancel to be called, verify it was called with the right error
204 if tt.expectCancel && cancelledWithErr != tt.cancelErr {
205 t.Errorf("Cancel function called with error = %v, expected %v", cancelledWithErr, tt.cancelErr)
206 }
207
208 // Verify the toolUseID was removed from the map if it was initially added
209 if tt.setupToolUse {
210 convo.muToolUseCancel.Lock()
211 _, exists := convo.toolUseCancel[tt.toolUseID]
212 convo.muToolUseCancel.Unlock()
213
214 if exists {
215 t.Errorf("toolUseID %s still exists in the map after cancellation", tt.toolUseID)
216 }
217 }
218 })
219 }
220}