blob: 67cc5db5f32f0b9bb3109e5e5c565ea0f7225de2 [file] [log] [blame]
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -07001package ant
2
3import (
4 "math"
5 "testing"
6)
7
8// TestCalculateCostFromTokens tests the calculateCostFromTokens function
9func TestCalculateCostFromTokens(t *testing.T) {
10 tests := []struct {
11 name string
12 model string
13 inputTokens uint64
14 outputTokens uint64
15 cacheReadInputTokens uint64
16 cacheCreationInputTokens uint64
17 want float64
18 }{
19 {
20 name: "Zero tokens",
21 model: Claude37Sonnet,
22 inputTokens: 0,
23 outputTokens: 0,
24 cacheReadInputTokens: 0,
25 cacheCreationInputTokens: 0,
26 want: 0,
27 },
28 {
29 name: "1000 input tokens, 500 output tokens",
30 model: Claude37Sonnet,
31 inputTokens: 1000,
32 outputTokens: 500,
33 cacheReadInputTokens: 0,
34 cacheCreationInputTokens: 0,
35 want: 0.0105,
36 },
37 {
38 name: "10000 input tokens, 5000 output tokens",
39 model: Claude37Sonnet,
40 inputTokens: 10000,
41 outputTokens: 5000,
42 cacheReadInputTokens: 0,
43 cacheCreationInputTokens: 0,
44 want: 0.105,
45 },
46 {
47 name: "With cache read tokens",
48 model: Claude37Sonnet,
49 inputTokens: 1000,
50 outputTokens: 500,
51 cacheReadInputTokens: 2000,
52 cacheCreationInputTokens: 0,
53 want: 0.0111,
54 },
55 {
56 name: "With cache creation tokens",
57 model: Claude37Sonnet,
58 inputTokens: 1000,
59 outputTokens: 500,
60 cacheReadInputTokens: 0,
61 cacheCreationInputTokens: 1500,
62 want: 0.016125,
63 },
64 {
65 name: "With all token types",
66 model: Claude37Sonnet,
67 inputTokens: 1000,
68 outputTokens: 500,
69 cacheReadInputTokens: 2000,
70 cacheCreationInputTokens: 1500,
71 want: 0.016725,
72 },
73 }
74
75 for _, tt := range tests {
76 t.Run(tt.name, func(t *testing.T) {
77 usage := usage{
78 InputTokens: tt.inputTokens,
79 OutputTokens: tt.outputTokens,
80 CacheReadInputTokens: tt.cacheReadInputTokens,
81 CacheCreationInputTokens: tt.cacheCreationInputTokens,
82 }
83 mr := response{
84 Model: tt.model,
85 Usage: usage,
86 }
87 totalCost := mr.TotalDollars()
88 if math.Abs(totalCost-tt.want) > 0.0001 {
89 t.Errorf("totalCost = %v, want %v", totalCost, tt.want)
90 }
91 })
92 }
93}