blob: 7bea552b10251b2a6450e92982a520b315d92373 [file] [log] [blame]
package oai
import (
"math"
"testing"
"sketch.dev/llm"
)
// TestCalculateCostFromTokens verifies Service.calculateCostFromTokens against
// hand-computed costs for several models and token mixes.
//
// Each case fills an llm.Usage with cache-creation, cache-read, and output
// token counts, runs the calculation on a Service configured with the case's
// model, and compares the result with the expected value.
//
// NOTE(review): the per-case arithmetic divides by 1_000_000 and then by 100,
// which suggests rates are cents per million tokens and the result is in
// dollars — confirm against the Model cost definitions.
func TestCalculateCostFromTokens(t *testing.T) {
	// tolerance is the largest absolute difference accepted between the
	// computed and expected cost. It must stay far below the smallest
	// expected value in the table (0.00045); otherwise a wrong rate could
	// still pass. 1e-9 leaves ample room for float64 rounding.
	const tolerance = 1e-9

	tests := []struct {
		name                string
		model               Model
		cacheCreationTokens uint64 // stored in llm.Usage.CacheCreationInputTokens
		cacheReadTokens     uint64 // stored in llm.Usage.CacheReadInputTokens
		outputTokens        uint64 // stored in llm.Usage.OutputTokens
		want                float64
	}{
		{
			name:                "Zero tokens",
			model:               GPT41,
			cacheCreationTokens: 0,
			cacheReadTokens:     0,
			outputTokens:        0,
			want:                0,
		},
		{
			name:                "1000 input tokens, 500 output tokens",
			model:               GPT41,
			cacheCreationTokens: 1000,
			cacheReadTokens:     0,
			outputTokens:        500,
			// GPT41: Input: 200 per million, Output: 800 per million
			// (1000 * 200 + 500 * 800) / 1_000_000 / 100 = 0.006
			want: 0.006,
		},
		{
			name:                "10000 input tokens, 5000 output tokens",
			model:               GPT41,
			cacheCreationTokens: 10000,
			cacheReadTokens:     0,
			outputTokens:        5000,
			// (10000 * 200 + 5000 * 800) / 1_000_000 / 100 = 0.06
			want: 0.06,
		},
		{
			name:                "1000 input tokens, 500 output tokens Gemini",
			model:               Gemini25Flash,
			cacheCreationTokens: 1000,
			cacheReadTokens:     0,
			outputTokens:        500,
			// Gemini25Flash: Input: 15 per million, Output: 60 per million
			// (1000 * 15 + 500 * 60) / 1_000_000 / 100 = 0.00045
			want: 0.00045,
		},
		{
			name:                "With cache read tokens",
			model:               GPT41,
			cacheCreationTokens: 500,
			cacheReadTokens:     500, // 500 tokens from cache
			outputTokens:        500,
			// (500 * 200 + 500 * 50 + 500 * 800) / 1_000_000 / 100 = 0.00525
			want: 0.00525,
		},
		{
			name:                "With all token types",
			model:               GPT41,
			cacheCreationTokens: 1000,
			cacheReadTokens:     1000,
			outputTokens:        1000,
			// (1000 * 200 + 1000 * 50 + 1000 * 800) / 1_000_000 / 100 = 0.0105
			want: 0.0105,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Service under test, configured with this case's model.
			svc := &Service{Model: tt.model}

			// Usage record carrying this case's token counts.
			usage := llm.Usage{
				CacheCreationInputTokens: tt.cacheCreationTokens,
				CacheReadInputTokens:     tt.cacheReadTokens,
				OutputTokens:             tt.outputTokens,
			}

			got := svc.calculateCostFromTokens(usage)
			if math.Abs(got-tt.want) > tolerance {
				t.Errorf("calculateCostFromTokens(%s, cache_creation=%d, cache_read=%d, output=%d) = %v, want %v",
					tt.model.ModelName, tt.cacheCreationTokens, tt.cacheReadTokens, tt.outputTokens, got, tt.want)
			}
		})
	}
}