blob: 6556e9e9ff87d64940c605ccb2f1eeb42ef09eef [file] [log] [blame]
package ant
import (
"context"
"math"
"net/http"
"os"
"strings"
"testing"
"sketch.dev/httprr"
)
// TestBasicConvo sends two user messages through a Convo whose HTTP client is
// supplied by the httprr trace at testdata/basic_convo.httprr, and checks that
// the reply to the second message contains the name given in the first.
func TestBasicConvo(t *testing.T) {
	ctx := context.Background()
	rr, err := httprr.Open("testdata/basic_convo.httprr", http.DefaultTransport)
	if err != nil {
		t.Fatal(err)
	}
	// Drop the API key header from every request passed to the scrubber.
	// NOTE(review): presumably this keeps the key out of the recorded trace
	// and makes recorded requests match regardless of the key — confirm
	// against the httprr package.
	rr.ScrubReq(func(req *http.Request) error {
		req.Header.Del("x-api-key")
		return nil
	})

	convo := NewConvo(ctx, os.Getenv("ANTHROPIC_API_KEY"))
	convo.HTTPC = rr.Client()

	const name = "Cornelius"

	// First turn: introduce the name and log whatever comes back.
	res, err := convo.SendUserTextMessage("Hi, my name is " + name)
	if err != nil {
		t.Fatal(err)
	}
	for _, part := range res.Content {
		t.Logf("%s", part.Text)
	}

	// Second turn: the reply must mention the name from the first turn.
	res, err = convo.SendUserTextMessage("What is my name?")
	if err != nil {
		t.Fatal(err)
	}
	// Join the text of every content part so the name is found no matter
	// which part carries it.
	var reply strings.Builder
	for _, part := range res.Content {
		reply.WriteString(part.Text)
	}
	if got := reply.String(); !strings.Contains(got, name) {
		t.Errorf("model does not know the given name %s: %q", name, got)
	}
}
// TestCalculateCostFromTokens tests the calculateCostFromTokens function
func TestCalculateCostFromTokens(t *testing.T) {
tests := []struct {
name string
model string
inputTokens uint64
outputTokens uint64
cacheReadInputTokens uint64
cacheCreationInputTokens uint64
want float64
}{
{
name: "Zero tokens",
model: Claude37Sonnet,
inputTokens: 0,
outputTokens: 0,
cacheReadInputTokens: 0,
cacheCreationInputTokens: 0,
want: 0,
},
{
name: "1000 input tokens, 500 output tokens",
model: Claude37Sonnet,
inputTokens: 1000,
outputTokens: 500,
cacheReadInputTokens: 0,
cacheCreationInputTokens: 0,
want: 0.0105,
},
{
name: "10000 input tokens, 5000 output tokens",
model: Claude37Sonnet,
inputTokens: 10000,
outputTokens: 5000,
cacheReadInputTokens: 0,
cacheCreationInputTokens: 0,
want: 0.105,
},
{
name: "With cache read tokens",
model: Claude37Sonnet,
inputTokens: 1000,
outputTokens: 500,
cacheReadInputTokens: 2000,
cacheCreationInputTokens: 0,
want: 0.0111,
},
{
name: "With cache creation tokens",
model: Claude37Sonnet,
inputTokens: 1000,
outputTokens: 500,
cacheReadInputTokens: 0,
cacheCreationInputTokens: 1500,
want: 0.016125,
},
{
name: "With all token types",
model: Claude37Sonnet,
inputTokens: 1000,
outputTokens: 500,
cacheReadInputTokens: 2000,
cacheCreationInputTokens: 1500,
want: 0.016725,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
usage := Usage{
InputTokens: tt.inputTokens,
OutputTokens: tt.outputTokens,
CacheReadInputTokens: tt.cacheReadInputTokens,
CacheCreationInputTokens: tt.cacheCreationInputTokens,
}
mr := MessageResponse{
Model: tt.model,
Usage: usage,
}
totalCost := mr.TotalDollars()
if math.Abs(totalCost-tt.want) > 0.0001 {
t.Errorf("totalCost = %v, want %v", totalCost, tt.want)
}
})
}
}
// TestCancelToolUse exercises Convo.CancelToolUse. For each case it optionally
// registers a recording cancel function under a tool-use ID, calls
// CancelToolUse with that ID and a cause, and then verifies four things:
// whether an error was returned, whether the cancel function ran, which cause
// it received, and that a registered ID is gone from the map afterwards.
func TestCancelToolUse(t *testing.T) {
	cases := []struct {
		name       string
		registered bool
		id         string
		cause      error
		wantErr    bool
		wantCalled bool
	}{
		{
			name:       "Cancel existing tool use",
			registered: true,
			id:         "tool123",
			wantCalled: true,
		},
		{
			name:       "Cancel existing tool use with error",
			registered: true,
			id:         "tool456",
			cause:      context.Canceled,
			wantCalled: true,
		},
		{
			name:    "Cancel non-existent tool use",
			id:      "tool789",
			wantErr: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			convo := NewConvo(context.Background(), "")

			// called and gotCause record what the registered cancel function saw.
			var (
				called   bool
				gotCause error
			)
			if tc.registered {
				convo.muToolUseCancel.Lock()
				convo.toolUseCancel[tc.id] = func(err error) {
					called = true
					gotCause = err
				}
				convo.muToolUseCancel.Unlock()
			}

			err := convo.CancelToolUse(tc.id, tc.cause)

			// The presence of an error must match the expectation.
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("CancelToolUse() error = %v, expectError %v", err, tc.wantErr)
			}
			// The cancel function must have run exactly when expected.
			if called != tc.wantCalled {
				t.Errorf("Cancel function called = %v, expectCancel %v", called, tc.wantCalled)
			}
			// When it ran, it must have been handed the same cause we passed in.
			if tc.wantCalled && gotCause != tc.cause {
				t.Errorf("Cancel function called with error = %v, expected %v", gotCause, tc.cause)
			}

			// Only IDs that were registered up front need the removal check.
			if !tc.registered {
				return
			}
			convo.muToolUseCancel.Lock()
			_, stillThere := convo.toolUseCancel[tc.id]
			convo.muToolUseCancel.Unlock()
			if stillThere {
				t.Errorf("toolUseID %s still exists in the map after cancellation", tc.id)
			}
		})
	}
}