llm: get costs from server
Calculating costs on the client has the advantage
that it works when not using skaband.
It requires that we maintain multiple sources of truth, though.
And it makes it very challenging to add server-side tools,
such as Anthropic's web tool.
This commit switches sketch to rely on the server for all costs.
If not using skaband, no costs will be calculated, which also
means that budget constraints won't work.
It's unfortunate, but at the moment it seems like the best path.
diff --git a/llm/gem/gem_test.go b/llm/gem/gem_test.go
index 7518d49..002b4d1 100644
--- a/llm/gem/gem_test.go
+++ b/llm/gem/gem_test.go
@@ -1,7 +1,11 @@
package gem
import (
+ "bytes"
+ "context"
"encoding/json"
+ "io"
+ "net/http"
"testing"
"sketch.dev/llm"
@@ -216,3 +220,147 @@
t.Fatalf("Expected command 'cat README.md', got '%s'", cmd)
}
}
+
+func TestGeminiHeaderCapture(t *testing.T) {
+ // Create a mock HTTP client that returns a response with headers
+ mockClient := &http.Client{
+ Transport: &mockRoundTripper{
+ response: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{
+ "Content-Type": []string{"application/json"},
+ "Skaband-Cost-Microcents": []string{"123456"},
+ },
+ Body: io.NopCloser(bytes.NewBufferString(`{
+ "candidates": [{
+ "content": {
+ "parts": [{
+ "text": "Hello!"
+ }]
+ }
+ }]
+ }`)),
+ },
+ },
+ }
+
+ // Create a Gemini model with the mock client
+ model := gemini.Model{
+ Model: "models/gemini-test",
+ APIKey: "test-key",
+ HTTPC: mockClient,
+ Endpoint: "https://test.googleapis.com",
+ }
+
+ // Make a request
+ req := &gemini.Request{
+ Contents: []gemini.Content{
+ {
+ Parts: []gemini.Part{{Text: "Hello"}},
+ Role: "user",
+ },
+ },
+ }
+
+ ctx := context.Background()
+ res, err := model.GenerateContent(ctx, req)
+ if err != nil {
+ t.Fatalf("Failed to generate content: %v", err)
+ }
+
+ // Verify that headers were captured
+ headers := res.Header()
+ if headers == nil {
+ t.Fatalf("Expected headers to be captured, got nil")
+ }
+
+ // Check for the cost header
+ costHeader := headers.Get("Skaband-Cost-Microcents")
+ if costHeader != "123456" {
+ t.Fatalf("Expected cost header '123456', got '%s'", costHeader)
+ }
+
+ // Verify that llm.CostUSDFromResponse works with these headers
+ costUSD := llm.CostUSDFromResponse(headers)
+ expectedCost := 0.00123456 // 123456 microcents / 100,000,000
+ if costUSD != expectedCost {
+ t.Fatalf("Expected cost USD %.8f, got %.8f", expectedCost, costUSD)
+ }
+}
+
+// mockRoundTripper is an HTTP transport stub for tests. It ignores the
+// outgoing request and hands back the same preconfigured response on every
+// call.
+type mockRoundTripper struct {
+	response *http.Response
+}
+
+// RoundTrip returns the canned response with a nil error, regardless of the
+// request it is given.
+func (rt *mockRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
+	canned := rt.response
+	return canned, nil
+}
+
+func TestHeaderCostIntegration(t *testing.T) {
+ // Create a mock HTTP client that returns a response with cost headers
+ mockClient := &http.Client{
+ Transport: &mockRoundTripper{
+ response: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{
+ "Content-Type": []string{"application/json"},
+ "Skaband-Cost-Microcents": []string{"50000"}, // 0.5 USD
+ },
+ Body: io.NopCloser(bytes.NewBufferString(`{
+ "candidates": [{
+ "content": {
+ "parts": [{
+ "text": "Test response"
+ }]
+ }
+ }]
+ }`)),
+ },
+ },
+ }
+
+ // Create a Gem service with the mock client
+ service := &Service{
+ Model: "gemini-test",
+ APIKey: "test-key",
+ HTTPC: mockClient,
+ URL: "https://test.googleapis.com",
+ }
+
+ // Create a request
+ ir := &llm.Request{
+ Messages: []llm.Message{
+ {
+ Role: llm.MessageRoleUser,
+ Content: []llm.Content{
+ {
+ Type: llm.ContentTypeText,
+ Text: "Hello",
+ },
+ },
+ },
+ },
+ }
+
+ // Make the request
+ ctx := context.Background()
+ res, err := service.Do(ctx, ir)
+ if err != nil {
+ t.Fatalf("Failed to make request: %v", err)
+ }
+
+ // Verify that the cost was captured from headers
+ expectedCost := 0.0005 // 50000 microcents / 100,000,000
+ if res.Usage.CostUSD != expectedCost {
+ t.Fatalf("Expected cost USD %.8f, got %.8f", expectedCost, res.Usage.CostUSD)
+ }
+
+ // Verify token counts are still estimated
+ if res.Usage.InputTokens == 0 {
+ t.Fatalf("Expected input tokens to be estimated, got 0")
+ }
+ if res.Usage.OutputTokens == 0 {
+ t.Fatalf("Expected output tokens to be estimated, got 0")
+ }
+}