Initial commit
diff --git a/ant/ant_test.go b/ant/ant_test.go
new file mode 100644
index 0000000..6556e9e
--- /dev/null
+++ b/ant/ant_test.go
@@ -0,0 +1,220 @@
+package ant
+
+import (
+ "context"
+ "math"
+ "net/http"
+ "os"
+ "strings"
+ "testing"
+
+ "sketch.dev/httprr"
+)
+
+// TestBasicConvo runs a two-turn conversation through the HTTP client supplied
+// by httprr and checks that the second reply contains the name from the first.
+func TestBasicConvo(t *testing.T) {
+	rr, err := httprr.Open("testdata/basic_convo.httprr", http.DefaultTransport)
+	if err != nil {
+		t.Fatal(err)
+	}
+	rr.ScrubReq(func(req *http.Request) error {
+		req.Header.Del("x-api-key")
+		return nil
+	})
+	convo := NewConvo(context.Background(), os.Getenv("ANTHROPIC_API_KEY"))
+	convo.HTTPC = rr.Client()
+
+	const name = "Cornelius"
+	intro, err := convo.SendUserTextMessage("Hi, my name is " + name)
+	if err != nil {
+		t.Fatal(err)
+	}
+	for _, part := range intro.Content {
+		t.Logf("%s", part.Text)
+	}
+	answer, err := convo.SendUserTextMessage("What is my name?")
+	if err != nil {
+		t.Fatal(err)
+	}
+	var reply strings.Builder
+	for _, part := range answer.Content {
+		reply.WriteString(part.Text)
+	}
+	if got := reply.String(); !strings.Contains(got, name) {
+		t.Errorf("model does not know the given name %s: %q", name, got)
+	}
+}
+
+// TestCalculateCostFromTokens tests the calculateCostFromTokens function
+func TestCalculateCostFromTokens(t *testing.T) {
+ tests := []struct {
+ name string
+ model string
+ inputTokens uint64
+ outputTokens uint64
+ cacheReadInputTokens uint64
+ cacheCreationInputTokens uint64
+ want float64
+ }{
+ {
+ name: "Zero tokens",
+ model: Claude37Sonnet,
+ inputTokens: 0,
+ outputTokens: 0,
+ cacheReadInputTokens: 0,
+ cacheCreationInputTokens: 0,
+ want: 0,
+ },
+ {
+ name: "1000 input tokens, 500 output tokens",
+ model: Claude37Sonnet,
+ inputTokens: 1000,
+ outputTokens: 500,
+ cacheReadInputTokens: 0,
+ cacheCreationInputTokens: 0,
+ want: 0.0105,
+ },
+ {
+ name: "10000 input tokens, 5000 output tokens",
+ model: Claude37Sonnet,
+ inputTokens: 10000,
+ outputTokens: 5000,
+ cacheReadInputTokens: 0,
+ cacheCreationInputTokens: 0,
+ want: 0.105,
+ },
+ {
+ name: "With cache read tokens",
+ model: Claude37Sonnet,
+ inputTokens: 1000,
+ outputTokens: 500,
+ cacheReadInputTokens: 2000,
+ cacheCreationInputTokens: 0,
+ want: 0.0111,
+ },
+ {
+ name: "With cache creation tokens",
+ model: Claude37Sonnet,
+ inputTokens: 1000,
+ outputTokens: 500,
+ cacheReadInputTokens: 0,
+ cacheCreationInputTokens: 1500,
+ want: 0.016125,
+ },
+ {
+ name: "With all token types",
+ model: Claude37Sonnet,
+ inputTokens: 1000,
+ outputTokens: 500,
+ cacheReadInputTokens: 2000,
+ cacheCreationInputTokens: 1500,
+ want: 0.016725,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ usage := Usage{
+ InputTokens: tt.inputTokens,
+ OutputTokens: tt.outputTokens,
+ CacheReadInputTokens: tt.cacheReadInputTokens,
+ CacheCreationInputTokens: tt.cacheCreationInputTokens,
+ }
+ mr := MessageResponse{
+ Model: tt.model,
+ Usage: usage,
+ }
+ totalCost := mr.TotalDollars()
+ if math.Abs(totalCost-tt.want) > 0.0001 {
+ t.Errorf("totalCost = %v, want %v", totalCost, tt.want)
+ }
+ })
+ }
+}
+
+// TestCancelToolUse exercises Convo.CancelToolUse for tool use IDs that have a
+// cancel function stored in convo.toolUseCancel and for an ID the test never registers.
+func TestCancelToolUse(t *testing.T) {
+	cases := []struct {
+		name       string
+		register   bool
+		id         string
+		cause      error
+		wantErr    bool
+		wantCancel bool
+	}{
+		{
+			name:       "Cancel existing tool use",
+			register:   true,
+			id:         "tool123",
+			cause:      nil,
+			wantErr:    false,
+			wantCancel: true,
+		},
+		{
+			name:       "Cancel existing tool use with error",
+			register:   true,
+			id:         "tool456",
+			cause:      context.Canceled,
+			wantErr:    false,
+			wantCancel: true,
+		},
+		{
+			name:       "Cancel non-existent tool use",
+			register:   false,
+			id:         "tool789",
+			cause:      nil,
+			wantErr:    true,
+			wantCancel: false,
+		},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			convo := NewConvo(context.Background(), "")
+
+			// Filled in by the stand-in cancel function when it is invoked.
+			var (
+				called   bool
+				gotCause error
+			)
+			if tc.register {
+				convo.muToolUseCancel.Lock()
+				convo.toolUseCancel[tc.id] = func(err error) {
+					called = true
+					gotCause = err
+				}
+				convo.muToolUseCancel.Unlock()
+			}
+
+			err := convo.CancelToolUse(tc.id, tc.cause)
+
+			// An error must come back exactly when the case expects one.
+			if gotErr := err != nil; gotErr != tc.wantErr {
+				t.Errorf("CancelToolUse() error = %v, expectError %v", err, tc.wantErr)
+			}
+
+			// The stand-in must have run exactly when the case expects it to.
+			if called != tc.wantCancel {
+				t.Errorf("Cancel function called = %v, expectCancel %v", called, tc.wantCancel)
+			}
+
+			// When it ran, it must have received the error given to CancelToolUse.
+			if tc.wantCancel && gotCause != tc.cause {
+				t.Errorf("Cancel function called with error = %v, expected %v", gotCause, tc.cause)
+			}
+
+			// Cases that registered nothing have no map entry left to check.
+			if !tc.register {
+				return
+			}
+			convo.muToolUseCancel.Lock()
+			_, present := convo.toolUseCancel[tc.id]
+			convo.muToolUseCancel.Unlock()
+			if present {
+				t.Errorf("toolUseID %s still exists in the map after cancellation", tc.id)
+			}
+		})
+	}
+}