llm: get costs from server

Calculating costs on the client has the advantage
that it works when not using skaband.
It requires that we maintain multiple sources of truth, though.
And it makes it very challenging to add server-side tools,
such as Anthropic's web tool.
This commit switches sketch to rely on the server for all costs.
If not using skaband, no costs will be calculated, which also
means that budget constraints won't work.
It's unfortunate, but at the moment it seems like the best path.
diff --git a/llm/ant/ant.go b/llm/ant/ant.go
index 92b55f9..0072b69 100644
--- a/llm/ant/ant.go
+++ b/llm/ant/ant.go
@@ -166,11 +166,6 @@
u.CostUSD += other.CostUSD
}
-type errorResponse struct {
- Type string `json:"type"`
- Message string `json:"message"`
-}
-
// response represents the response from the message API.
type response struct {
ID string `json:"id"`
@@ -513,7 +508,7 @@
slog.InfoContext(ctx, "anthropic_retrying_with_larger_tokens", "message", "Retrying Anthropic API call with larger max tokens size")
// Retry with more output tokens.
largerMaxTokens = true
- response.Usage.CostUSD = response.TotalDollars()
+ response.Usage.CostUSD = llm.CostUSDFromResponse(resp.Header)
partialUsage = response.Usage
continue
}
@@ -522,7 +517,7 @@
if largerMaxTokens {
response.Usage.Add(partialUsage)
}
- response.Usage.CostUSD = response.TotalDollars()
+ response.Usage.CostUSD = llm.CostUSDFromResponse(resp.Header)
return toLLMResponse(&response), nil
case resp.StatusCode >= 500 && resp.StatusCode < 600:
@@ -547,61 +542,3 @@
}
}
}
-
-// cents per million tokens
-// (not dollars because i'm twitchy about using floats for money)
-type centsPer1MTokens struct {
- Input uint64
- Output uint64
- CacheRead uint64
- CacheCreation uint64
-}
-
-// https://www.anthropic.com/pricing#anthropic-api
-var modelCost = map[string]centsPer1MTokens{
- Claude37Sonnet: {
- Input: 300, // $3
- Output: 1500, // $15
- CacheRead: 30, // $0.30
- CacheCreation: 375, // $3.75
- },
- Claude35Haiku: {
- Input: 80, // $0.80
- Output: 400, // $4.00
- CacheRead: 8, // $0.08
- CacheCreation: 100, // $1.00
- },
- Claude35Sonnet: {
- Input: 300, // $3
- Output: 1500, // $15
- CacheRead: 30, // $0.30
- CacheCreation: 375, // $3.75
- },
- Claude4Sonnet: {
- Input: 300, // $3
- Output: 1500, // $15
- CacheRead: 30, // $0.30
- CacheCreation: 375, // $3.75
- },
- Claude4Opus: {
- Input: 1500, // $15
- Output: 7500, // $75
- CacheRead: 150, // $1.50
- CacheCreation: 1875, // $18.75
- },
-}
-
-// TotalDollars returns the total cost to obtain this response, in dollars.
-func (mr *response) TotalDollars() float64 {
- cpm, ok := modelCost[mr.Model]
- if !ok {
- panic(fmt.Sprintf("no pricing info for model: %s", mr.Model))
- }
- use := mr.Usage
- megaCents := use.InputTokens*cpm.Input +
- use.OutputTokens*cpm.Output +
- use.CacheReadInputTokens*cpm.CacheRead +
- use.CacheCreationInputTokens*cpm.CacheCreation
- cents := float64(megaCents) / 1_000_000.0
- return cents / 100.0
-}
diff --git a/llm/ant/ant_test.go b/llm/ant/ant_test.go
deleted file mode 100644
index 67cc5db..0000000
--- a/llm/ant/ant_test.go
+++ /dev/null
@@ -1,93 +0,0 @@
-package ant
-
-import (
- "math"
- "testing"
-)
-
-// TestCalculateCostFromTokens tests the calculateCostFromTokens function
-func TestCalculateCostFromTokens(t *testing.T) {
- tests := []struct {
- name string
- model string
- inputTokens uint64
- outputTokens uint64
- cacheReadInputTokens uint64
- cacheCreationInputTokens uint64
- want float64
- }{
- {
- name: "Zero tokens",
- model: Claude37Sonnet,
- inputTokens: 0,
- outputTokens: 0,
- cacheReadInputTokens: 0,
- cacheCreationInputTokens: 0,
- want: 0,
- },
- {
- name: "1000 input tokens, 500 output tokens",
- model: Claude37Sonnet,
- inputTokens: 1000,
- outputTokens: 500,
- cacheReadInputTokens: 0,
- cacheCreationInputTokens: 0,
- want: 0.0105,
- },
- {
- name: "10000 input tokens, 5000 output tokens",
- model: Claude37Sonnet,
- inputTokens: 10000,
- outputTokens: 5000,
- cacheReadInputTokens: 0,
- cacheCreationInputTokens: 0,
- want: 0.105,
- },
- {
- name: "With cache read tokens",
- model: Claude37Sonnet,
- inputTokens: 1000,
- outputTokens: 500,
- cacheReadInputTokens: 2000,
- cacheCreationInputTokens: 0,
- want: 0.0111,
- },
- {
- name: "With cache creation tokens",
- model: Claude37Sonnet,
- inputTokens: 1000,
- outputTokens: 500,
- cacheReadInputTokens: 0,
- cacheCreationInputTokens: 1500,
- want: 0.016125,
- },
- {
- name: "With all token types",
- model: Claude37Sonnet,
- inputTokens: 1000,
- outputTokens: 500,
- cacheReadInputTokens: 2000,
- cacheCreationInputTokens: 1500,
- want: 0.016725,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- usage := usage{
- InputTokens: tt.inputTokens,
- OutputTokens: tt.outputTokens,
- CacheReadInputTokens: tt.cacheReadInputTokens,
- CacheCreationInputTokens: tt.cacheCreationInputTokens,
- }
- mr := response{
- Model: tt.model,
- Usage: usage,
- }
- totalCost := mr.TotalDollars()
- if math.Abs(totalCost-tt.want) > 0.0001 {
- t.Errorf("totalCost = %v, want %v", totalCost, tt.want)
- }
- })
- }
-}
diff --git a/llm/gem/gem.go b/llm/gem/gem.go
index 178df68..1621f2d 100644
--- a/llm/gem/gem.go
+++ b/llm/gem/gem.go
@@ -431,14 +431,9 @@
}
}
- // For Gemini 2.5 Pro Preview pricing: $1.25 per 1M input tokens, $10 per 1M output tokens
- // Convert to dollars
- costUSD := float64(inputTokens)*1.25/1_000_000.0 + float64(outputTokens)*10/1_000_000.0
-
return llm.Usage{
InputTokens: inputTokens,
OutputTokens: outputTokens,
- CostUSD: costUSD,
}
}
@@ -573,6 +568,7 @@
ensureToolIDs(content)
usage := calculateUsage(gemReq, gemRes)
+ usage.CostUSD = llm.CostUSDFromResponse(gemRes.Header())
stopReason := llm.StopReasonEndTurn
for _, part := range content {
diff --git a/llm/gem/gem_test.go b/llm/gem/gem_test.go
index 7518d49..002b4d1 100644
--- a/llm/gem/gem_test.go
+++ b/llm/gem/gem_test.go
@@ -1,7 +1,11 @@
package gem
import (
+ "bytes"
+ "context"
"encoding/json"
+ "io"
+ "net/http"
"testing"
"sketch.dev/llm"
@@ -216,3 +220,147 @@
t.Fatalf("Expected command 'cat README.md', got '%s'", cmd)
}
}
+
+func TestGeminiHeaderCapture(t *testing.T) {
+ // Create a mock HTTP client that returns a response with headers
+ mockClient := &http.Client{
+ Transport: &mockRoundTripper{
+ response: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{
+ "Content-Type": []string{"application/json"},
+ "Skaband-Cost-Microcents": []string{"123456"},
+ },
+ Body: io.NopCloser(bytes.NewBufferString(`{
+ "candidates": [{
+ "content": {
+ "parts": [{
+ "text": "Hello!"
+ }]
+ }
+ }]
+ }`)),
+ },
+ },
+ }
+
+ // Create a Gemini model with the mock client
+ model := gemini.Model{
+ Model: "models/gemini-test",
+ APIKey: "test-key",
+ HTTPC: mockClient,
+ Endpoint: "https://test.googleapis.com",
+ }
+
+ // Make a request
+ req := &gemini.Request{
+ Contents: []gemini.Content{
+ {
+ Parts: []gemini.Part{{Text: "Hello"}},
+ Role: "user",
+ },
+ },
+ }
+
+ ctx := context.Background()
+ res, err := model.GenerateContent(ctx, req)
+ if err != nil {
+ t.Fatalf("Failed to generate content: %v", err)
+ }
+
+ // Verify that headers were captured
+ headers := res.Header()
+ if headers == nil {
+ t.Fatalf("Expected headers to be captured, got nil")
+ }
+
+ // Check for the cost header
+ costHeader := headers.Get("Skaband-Cost-Microcents")
+ if costHeader != "123456" {
+ t.Fatalf("Expected cost header '123456', got '%s'", costHeader)
+ }
+
+ // Verify that llm.CostUSDFromResponse works with these headers
+ costUSD := llm.CostUSDFromResponse(headers)
+ expectedCost := 0.00123456 // 123456 microcents / 100,000,000
+ if costUSD != expectedCost {
+ t.Fatalf("Expected cost USD %.8f, got %.8f", expectedCost, costUSD)
+ }
+}
+
+// mockRoundTripper is a mock HTTP transport for testing
+type mockRoundTripper struct {
+ response *http.Response
+}
+
+func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+ return m.response, nil
+}
+
+func TestHeaderCostIntegration(t *testing.T) {
+ // Create a mock HTTP client that returns a response with cost headers
+ mockClient := &http.Client{
+ Transport: &mockRoundTripper{
+ response: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{
+ "Content-Type": []string{"application/json"},
+ "Skaband-Cost-Microcents": []string{"50000"}, // 0.5 USD
+ },
+ Body: io.NopCloser(bytes.NewBufferString(`{
+ "candidates": [{
+ "content": {
+ "parts": [{
+ "text": "Test response"
+ }]
+ }
+ }]
+ }`)),
+ },
+ },
+ }
+
+ // Create a Gem service with the mock client
+ service := &Service{
+ Model: "gemini-test",
+ APIKey: "test-key",
+ HTTPC: mockClient,
+ URL: "https://test.googleapis.com",
+ }
+
+ // Create a request
+ ir := &llm.Request{
+ Messages: []llm.Message{
+ {
+ Role: llm.MessageRoleUser,
+ Content: []llm.Content{
+ {
+ Type: llm.ContentTypeText,
+ Text: "Hello",
+ },
+ },
+ },
+ },
+ }
+
+ // Make the request
+ ctx := context.Background()
+ res, err := service.Do(ctx, ir)
+ if err != nil {
+ t.Fatalf("Failed to make request: %v", err)
+ }
+
+ // Verify that the cost was captured from headers
+ expectedCost := 0.0005 // 50000 microcents / 100,000,000
+ if res.Usage.CostUSD != expectedCost {
+ t.Fatalf("Expected cost USD %.8f, got %.8f", expectedCost, res.Usage.CostUSD)
+ }
+
+ // Verify token counts are still estimated
+ if res.Usage.InputTokens == 0 {
+ t.Fatalf("Expected input tokens to be estimated, got 0")
+ }
+ if res.Usage.OutputTokens == 0 {
+ t.Fatalf("Expected output tokens to be estimated, got 0")
+ }
+}
diff --git a/llm/gem/gemini/gemini.go b/llm/gem/gemini/gemini.go
index ab4788c..26aca0a 100644
--- a/llm/gem/gemini/gemini.go
+++ b/llm/gem/gemini/gemini.go
@@ -21,7 +21,13 @@
// https://ai.google.dev/api/generate-content#response-body
type Response struct {
- Candidates []Candidate `json:"candidates"`
+ Candidates []Candidate `json:"candidates"`
+ headers http.Header // captured HTTP response headers
+}
+
+// Header returns the HTTP response headers.
+func (r *Response) Header() http.Header {
+ return r.headers
}
type Candidate struct {
@@ -162,6 +168,7 @@
if err := json.Unmarshal(body, &res); err != nil {
return nil, fmt.Errorf("GenerateContent: unmarshaling response: %w, %s", err, string(body))
}
+ res.headers = httpResp.Header
return &res, nil
}
diff --git a/llm/llm.go b/llm/llm.go
index 2aea24e..9718954 100644
--- a/llm/llm.go
+++ b/llm/llm.go
@@ -6,6 +6,8 @@
"encoding/json"
"fmt"
"log/slog"
+ "net/http"
+ "strconv"
"strings"
"time"
)
@@ -194,6 +196,19 @@
}
}
+func CostUSDFromResponse(headers http.Header) float64 {
+ h := headers.Get("Skaband-Cost-Microcents")
+ if h == "" {
+ return 0
+ }
+ uc, err := strconv.ParseUint(h, 10, 64)
+ if err != nil {
+ slog.Warn("failed to parse cost header", "header", h)
+ return 0
+ }
+ return float64(uc) / 100_000_000
+}
+
// Usage represents the billing and rate-limit usage.
// Most LLM structs do not have JSON tags, to avoid accidental direct use in specific providers.
// However, the front-end uses this struct, and it relies on its JSON serialization.
diff --git a/llm/oai/oai.go b/llm/oai/oai.go
index 654cea4..aa6151f 100644
--- a/llm/oai/oai.go
+++ b/llm/oai/oai.go
@@ -38,17 +38,10 @@
UserName string // provided by the user to identify this model (e.g. "gpt4.1")
ModelName string // provided to the service provide to specify which model to use (e.g. "gpt-4.1-2025-04-14")
URL string
- Cost ModelCost
APIKeyEnv string // environment variable name for the API key
IsReasoningModel bool // whether this model is a reasoning model (e.g. O3, O4-mini)
}
-type ModelCost struct {
- Input uint64 // in cents per million tokens
- CachedInput uint64 // in cents per million tokens
- Output uint64 // in cents per million tokens
-}
-
var (
DefaultModel = GPT41
@@ -56,7 +49,6 @@
UserName: "gpt4.1",
ModelName: "gpt-4.1-2025-04-14",
URL: OpenAIURL,
- Cost: ModelCost{Input: 200, CachedInput: 50, Output: 800},
APIKeyEnv: OpenAIAPIKeyEnv,
}
@@ -64,7 +56,6 @@
UserName: "gpt4o",
ModelName: "gpt-4o-2024-08-06",
URL: OpenAIURL,
- Cost: ModelCost{Input: 250, CachedInput: 125, Output: 1000},
APIKeyEnv: OpenAIAPIKeyEnv,
}
@@ -72,7 +63,6 @@
UserName: "gpt4o-mini",
ModelName: "gpt-4o-mini-2024-07-18",
URL: OpenAIURL,
- Cost: ModelCost{Input: 15, CachedInput: 8, Output: 60}, // 8 is actually 7.5 GRRR round up for now oh well
APIKeyEnv: OpenAIAPIKeyEnv,
}
@@ -80,7 +70,6 @@
UserName: "gpt4.1-mini",
ModelName: "gpt-4.1-mini-2025-04-14",
URL: OpenAIURL,
- Cost: ModelCost{Input: 40, CachedInput: 10, Output: 160},
APIKeyEnv: OpenAIAPIKeyEnv,
}
@@ -88,7 +77,6 @@
UserName: "gpt4.1-nano",
ModelName: "gpt-4.1-nano-2025-04-14",
URL: OpenAIURL,
- Cost: ModelCost{Input: 10, CachedInput: 3, Output: 40}, // 3 is actually 2.5 GRRR round up for now oh well
APIKeyEnv: OpenAIAPIKeyEnv,
}
@@ -96,7 +84,6 @@
UserName: "o3",
ModelName: "o3-2025-04-16",
URL: OpenAIURL,
- Cost: ModelCost{Input: 1000, CachedInput: 250, Output: 4000},
APIKeyEnv: OpenAIAPIKeyEnv,
IsReasoningModel: true,
}
@@ -105,7 +92,6 @@
UserName: "o4-mini",
ModelName: "o4-mini-2025-04-16",
URL: OpenAIURL,
- Cost: ModelCost{Input: 110, CachedInput: 28, Output: 440}, // 28 is actually 27.5 GRRR round up for now oh well
APIKeyEnv: OpenAIAPIKeyEnv,
IsReasoningModel: true,
}
@@ -114,7 +100,6 @@
UserName: "gemini-flash-2.5",
ModelName: "gemini-2.5-flash-preview-04-17",
URL: GeminiURL,
- Cost: ModelCost{Input: 15, Output: 60},
APIKeyEnv: GeminiAPIKeyEnv,
}
@@ -129,7 +114,6 @@
// Whatever that means. Are we caching? I have no idea.
// How do you always manage to be the annoying one, Google?
// I'm not complicating things just for you.
- Cost: ModelCost{Input: 125, Output: 1000},
APIKeyEnv: GeminiAPIKeyEnv,
}
@@ -137,7 +121,6 @@
UserName: "together-deepseek-v3",
ModelName: "deepseek-ai/DeepSeek-V3",
URL: TogetherURL,
- Cost: ModelCost{Input: 125, Output: 125},
APIKeyEnv: TogetherAPIKeyEnv,
}
@@ -145,7 +128,6 @@
UserName: "together-deepseek-r1",
ModelName: "deepseek-ai/DeepSeek-R1",
URL: TogetherURL,
- Cost: ModelCost{Input: 300, Output: 700},
APIKeyEnv: TogetherAPIKeyEnv,
}
@@ -153,7 +135,6 @@
UserName: "together-llama4-maverick",
ModelName: "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
URL: TogetherURL,
- Cost: ModelCost{Input: 27, Output: 85},
APIKeyEnv: TogetherAPIKeyEnv,
}
@@ -161,7 +142,6 @@
UserName: "fireworks-llama4-maverick",
ModelName: "accounts/fireworks/models/llama4-maverick-instruct-basic",
URL: FireworksURL,
- Cost: ModelCost{Input: 22, Output: 88},
APIKeyEnv: FireworksAPIKeyEnv,
}
@@ -169,7 +149,6 @@
UserName: "together-llama3-70b",
ModelName: "meta-llama/Llama-3.3-70B-Instruct-Turbo",
URL: TogetherURL,
- Cost: ModelCost{Input: 88, Output: 88},
APIKeyEnv: TogetherAPIKeyEnv,
}
@@ -177,7 +156,6 @@
UserName: "together-mistral-small",
ModelName: "mistralai/Mistral-Small-24B-Instruct-2501",
URL: TogetherURL,
- Cost: ModelCost{Input: 80, Output: 80},
APIKeyEnv: TogetherAPIKeyEnv,
}
@@ -185,7 +163,6 @@
UserName: "together-qwen3",
ModelName: "Qwen/Qwen3-235B-A22B-fp8-tput",
URL: TogetherURL,
- Cost: ModelCost{Input: 20, Output: 60},
APIKeyEnv: TogetherAPIKeyEnv,
}
@@ -193,7 +170,6 @@
UserName: "together-gemma2",
ModelName: "google/gemma-2-27b-it",
URL: TogetherURL,
- Cost: ModelCost{Input: 80, Output: 80},
APIKeyEnv: TogetherAPIKeyEnv,
}
@@ -201,15 +177,12 @@
UserName: "llama.cpp",
ModelName: "llama.cpp local model",
URL: LlamaCPPURL,
- // zero cost
- Cost: ModelCost{},
}
FireworksDeepseekV3 = Model{
UserName: "fireworks-deepseek-v3",
ModelName: "accounts/fireworks/models/deepseek-v3-0324",
URL: FireworksURL,
- Cost: ModelCost{Input: 90, Output: 90}, // not entirely sure about this, they don't list pricing anywhere convenient
APIKeyEnv: FireworksAPIKeyEnv,
}
@@ -217,7 +190,6 @@
UserName: "mistral-medium-3",
ModelName: "mistral-medium-latest",
URL: MistralURL,
- Cost: ModelCost{Input: 40, Output: 200},
APIKeyEnv: MistralAPIKeyEnv,
}
@@ -225,7 +197,6 @@
UserName: "devstral-small",
ModelName: "devstral-small-latest",
URL: MistralURL,
- Cost: ModelCost{Input: 100, Output: 300},
APIKeyEnv: MistralAPIKeyEnv,
}
)
@@ -294,13 +265,6 @@
llm.MessageRoleAssistant: "assistant",
llm.MessageRoleUser: "user",
}
- fromLLMContentType = map[llm.ContentType]string{
- llm.ContentTypeText: "text",
- llm.ContentTypeToolUse: "function", // OpenAI uses function instead of tool_call
- llm.ContentTypeToolResult: "tool_result",
- llm.ContentTypeThinking: "text", // Map thinking to text since OpenAI doesn't have thinking
- llm.ContentTypeRedactedThinking: "text", // Map redacted_thinking to text
- }
fromLLMToolChoiceType = map[llm.ToolChoiceType]string{
llm.ToolChoiceTypeAuto: "auto",
llm.ToolChoiceTypeAny: "any",
@@ -552,7 +516,7 @@
}
// toLLMUsage converts usage information from OpenAI to llm.Usage.
-func (s *Service) toLLMUsage(au openai.Usage) llm.Usage {
+func (s *Service) toLLMUsage(au openai.Usage, headers http.Header) llm.Usage {
// fmt.Printf("raw usage: %+v / %v / %v\n", au, au.PromptTokensDetails, au.CompletionTokensDetails)
in := uint64(au.PromptTokens)
var inc uint64
@@ -566,7 +530,7 @@
CacheCreationInputTokens: in,
OutputTokens: out,
}
- u.CostUSD = s.calculateCostFromTokens(u)
+ u.CostUSD = llm.CostUSDFromResponse(headers)
return u
}
@@ -583,7 +547,7 @@
ID: r.ID,
Model: r.Model,
Role: llm.MessageRoleAssistant,
- Usage: s.toLLMUsage(r.Usage),
+ Usage: s.toLLMUsage(r.Usage, r.Header()),
}
}
@@ -596,7 +560,7 @@
Role: toRoleFromString(choice.Message.Role),
Content: toLLMContents(choice.Message),
StopReason: toStopReason(string(choice.FinishReason)),
- Usage: s.toLLMUsage(r.Usage),
+ Usage: s.toLLMUsage(r.Usage, r.Header()),
}
}
@@ -619,23 +583,6 @@
return llm.StopReasonStopSequence // Default
}
-// calculateCostFromTokens calculates the cost in dollars for the given model and token counts.
-func (s *Service) calculateCostFromTokens(u llm.Usage) float64 {
- cost := s.Model.Cost
-
- // TODO: check this for correctness, i am skeptical
- // Calculate cost in cents
- megaCents := u.CacheCreationInputTokens*cost.Input +
- u.CacheReadInputTokens*cost.CachedInput +
- u.OutputTokens*cost.Output
-
- cents := float64(megaCents) / 1_000_000
- // Convert to dollars
- dollars := cents / 100.0
- // fmt.Printf("in_new=%d, in_cached=%d, out=%d, cost=%.2f\n", u.CacheCreationInputTokens, u.CacheReadInputTokens, u.OutputTokens, dollars)
- return dollars
-}
-
// TokenContextWindow returns the maximum token context window size for this service
func (s *Service) TokenContextWindow() int {
model := cmp.Or(s.Model, DefaultModel)
diff --git a/llm/oai/oai_test.go b/llm/oai/oai_test.go
deleted file mode 100644
index 7bea552..0000000
--- a/llm/oai/oai_test.go
+++ /dev/null
@@ -1,96 +0,0 @@
-package oai
-
-import (
- "math"
- "testing"
-
- "sketch.dev/llm"
-)
-
-// TestCalculateCostFromTokens tests the calculateCostFromTokens method
-func TestCalculateCostFromTokens(t *testing.T) {
- tests := []struct {
- name string
- model Model
- cacheCreationTokens uint64
- cacheReadTokens uint64
- outputTokens uint64
- want float64
- }{
- {
- name: "Zero tokens",
- model: GPT41,
- cacheCreationTokens: 0,
- cacheReadTokens: 0,
- outputTokens: 0,
- want: 0,
- },
- {
- name: "1000 input tokens, 500 output tokens",
- model: GPT41,
- cacheCreationTokens: 1000,
- cacheReadTokens: 0,
- outputTokens: 500,
- // GPT41: Input: 200 per million, Output: 800 per million
- // (1000 * 200 + 500 * 800) / 1_000_000 / 100 = 0.006
- want: 0.006,
- },
- {
- name: "10000 input tokens, 5000 output tokens",
- model: GPT41,
- cacheCreationTokens: 10000,
- cacheReadTokens: 0,
- outputTokens: 5000,
- // (10000 * 200 + 5000 * 800) / 1_000_000 / 100 = 0.06
- want: 0.06,
- },
- {
- name: "1000 input tokens, 500 output tokens Gemini",
- model: Gemini25Flash,
- cacheCreationTokens: 1000,
- cacheReadTokens: 0,
- outputTokens: 500,
- // Gemini25Flash: Input: 15 per million, Output: 60 per million
- // (1000 * 15 + 500 * 60) / 1_000_000 / 100 = 0.00045
- want: 0.00045,
- },
- {
- name: "With cache read tokens",
- model: GPT41,
- cacheCreationTokens: 500,
- cacheReadTokens: 500, // 500 tokens from cache
- outputTokens: 500,
- // (500 * 200 + 500 * 50 + 500 * 800) / 1_000_000 / 100 = 0.00525
- want: 0.00525,
- },
- {
- name: "With all token types",
- model: GPT41,
- cacheCreationTokens: 1000,
- cacheReadTokens: 1000,
- outputTokens: 1000,
- // (1000 * 200 + 1000 * 50 + 1000 * 800) / 1_000_000 / 100 = 0.0105
- want: 0.0105,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Create a service with the test model
- svc := &Service{Model: tt.model}
-
- // Create a usage object
- usage := llm.Usage{
- CacheCreationInputTokens: tt.cacheCreationTokens,
- CacheReadInputTokens: tt.cacheReadTokens,
- OutputTokens: tt.outputTokens,
- }
-
- totalCost := svc.calculateCostFromTokens(usage)
- if math.Abs(totalCost-tt.want) > 0.0001 {
- t.Errorf("calculateCostFromTokens(%s, cache_creation=%d, cache_read=%d, output=%d) = %v, want %v",
- tt.model.ModelName, tt.cacheCreationTokens, tt.cacheReadTokens, tt.outputTokens, totalCost, tt.want)
- }
- })
- }
-}