blob: 7bea552b10251b2a6450e92982a520b315d92373 [file] [log] [blame]
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -07001package oai
2
3import (
4 "math"
5 "testing"
6
7 "sketch.dev/llm"
8)
9
10// TestCalculateCostFromTokens tests the calculateCostFromTokens method
11func TestCalculateCostFromTokens(t *testing.T) {
12 tests := []struct {
13 name string
14 model Model
15 cacheCreationTokens uint64
16 cacheReadTokens uint64
17 outputTokens uint64
18 want float64
19 }{
20 {
21 name: "Zero tokens",
22 model: GPT41,
23 cacheCreationTokens: 0,
24 cacheReadTokens: 0,
25 outputTokens: 0,
26 want: 0,
27 },
28 {
29 name: "1000 input tokens, 500 output tokens",
30 model: GPT41,
31 cacheCreationTokens: 1000,
32 cacheReadTokens: 0,
33 outputTokens: 500,
34 // GPT41: Input: 200 per million, Output: 800 per million
35 // (1000 * 200 + 500 * 800) / 1_000_000 / 100 = 0.006
36 want: 0.006,
37 },
38 {
39 name: "10000 input tokens, 5000 output tokens",
40 model: GPT41,
41 cacheCreationTokens: 10000,
42 cacheReadTokens: 0,
43 outputTokens: 5000,
44 // (10000 * 200 + 5000 * 800) / 1_000_000 / 100 = 0.06
45 want: 0.06,
46 },
47 {
48 name: "1000 input tokens, 500 output tokens Gemini",
49 model: Gemini25Flash,
50 cacheCreationTokens: 1000,
51 cacheReadTokens: 0,
52 outputTokens: 500,
53 // Gemini25Flash: Input: 15 per million, Output: 60 per million
54 // (1000 * 15 + 500 * 60) / 1_000_000 / 100 = 0.00045
55 want: 0.00045,
56 },
57 {
58 name: "With cache read tokens",
59 model: GPT41,
60 cacheCreationTokens: 500,
61 cacheReadTokens: 500, // 500 tokens from cache
62 outputTokens: 500,
63 // (500 * 200 + 500 * 50 + 500 * 800) / 1_000_000 / 100 = 0.00525
64 want: 0.00525,
65 },
66 {
67 name: "With all token types",
68 model: GPT41,
69 cacheCreationTokens: 1000,
70 cacheReadTokens: 1000,
71 outputTokens: 1000,
72 // (1000 * 200 + 1000 * 50 + 1000 * 800) / 1_000_000 / 100 = 0.0105
73 want: 0.0105,
74 },
75 }
76
77 for _, tt := range tests {
78 t.Run(tt.name, func(t *testing.T) {
79 // Create a service with the test model
80 svc := &Service{Model: tt.model}
81
82 // Create a usage object
83 usage := llm.Usage{
84 CacheCreationInputTokens: tt.cacheCreationTokens,
85 CacheReadInputTokens: tt.cacheReadTokens,
86 OutputTokens: tt.outputTokens,
87 }
88
89 totalCost := svc.calculateCostFromTokens(usage)
90 if math.Abs(totalCost-tt.want) > 0.0001 {
91 t.Errorf("calculateCostFromTokens(%s, cache_creation=%d, cache_read=%d, output=%d) = %v, want %v",
92 tt.model.ModelName, tt.cacheCreationTokens, tt.cacheReadTokens, tt.outputTokens, totalCost, tt.want)
93 }
94 })
95 }
96}