Initial commit
diff --git a/ant/ant.go b/ant/ant.go
new file mode 100644
index 0000000..a3eb79a
--- /dev/null
+++ b/ant/ant.go
@@ -0,0 +1,976 @@
+package ant
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "log/slog"
+ "maps"
+ "math/rand/v2"
+ "net/http"
+ "slices"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/richardlehane/crock32"
+ "sketch.dev/skribe"
+)
+
+const (
+ DefaultModel = Claude37Sonnet
+ // See https://docs.anthropic.com/en/docs/about-claude/models/all-models for
+ // current maximums. There's currently a flag to enable 128k output (output-128k-2025-02-19)
+ DefaultMaxTokens = 8192
+ DefaultURL = "https://api.anthropic.com/v1/messages"
+)
+
+const (
+ Claude35Sonnet = "claude-3-5-sonnet-20241022"
+ Claude35Haiku = "claude-3-5-haiku-20241022"
+ Claude37Sonnet = "claude-3-7-sonnet-20250219"
+)
+
+const (
+ MessageRoleUser = "user"
+ MessageRoleAssistant = "assistant"
+
+ ContentTypeText = "text"
+ ContentTypeThinking = "thinking"
+ ContentTypeRedactedThinking = "redacted_thinking"
+ ContentTypeToolUse = "tool_use"
+ ContentTypeToolResult = "tool_result"
+
+ StopReasonStopSequence = "stop_sequence"
+ StopReasonMaxTokens = "max_tokens"
+ StopReasonEndTurn = "end_turn"
+ StopReasonToolUse = "tool_use"
+)
+
+type Listener interface {
+ // TODO: Content is leaking an anthropic API; should we avoid it?
+ // TODO: Where should we include start/end time and usage?
+ OnToolResult(ctx context.Context, convo *Convo, toolName string, toolInput json.RawMessage, content Content, result *string, err error)
+ OnResponse(ctx context.Context, convo *Convo, msg *MessageResponse)
+ OnRequest(ctx context.Context, convo *Convo, msg *Message)
+}
+
+type NoopListener struct{}
+
+func (n *NoopListener) OnToolResult(ctx context.Context, convo *Convo, toolName string, toolInput json.RawMessage, content Content, result *string, err error) {
+}
+func (n *NoopListener) OnResponse(ctx context.Context, convo *Convo, msg *MessageResponse) {}
+func (n *NoopListener) OnRequest(ctx context.Context, convo *Convo, msg *Message) {}
+
+type Content struct {
+ // TODO: image support?
+ // https://docs.anthropic.com/en/api/messages
+ ID string `json:"id,omitempty"`
+ Type string `json:"type,omitempty"`
+ Text string `json:"text,omitempty"`
+
+ // for thinking
+ Thinking string `json:"thinking,omitempty"`
+ Data string `json:"data,omitempty"` // for redacted_thinking
+ Signature string `json:"signature,omitempty"` // for thinking
+
+ // for tool_use
+ ToolName string `json:"name,omitempty"`
+ ToolInput json.RawMessage `json:"input,omitempty"`
+
+ // for tool_result
+ ToolUseID string `json:"tool_use_id,omitempty"`
+ ToolError bool `json:"is_error,omitempty"`
+ ToolResult string `json:"content,omitempty"`
+
+ // timing information for tool_result; not sent to Claude
+ StartTime *time.Time `json:"-"`
+ EndTime *time.Time `json:"-"`
+
+ CacheControl json.RawMessage `json:"cache_control,omitempty"`
+}
+
+func StringContent(s string) Content {
+ return Content{Type: ContentTypeText, Text: s}
+}
+
+// Message represents a message in the conversation.
+type Message struct {
+ Role string `json:"role"`
+ Content []Content `json:"content"`
+ ToolUse *ToolUse `json:"tool_use,omitempty"` // use to control whether/which tool to use
+}
+
+// ToolUse represents a tool use in the message content.
+type ToolUse struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+}
+
+// Tool represents a tool available to Claude.
+type Tool struct {
+ Name string `json:"name"`
+ // Type is used by the text editor tool; see
+ // https://docs.anthropic.com/en/docs/build-with-claude/tool-use/text-editor-tool
+ Type string `json:"type,omitempty"`
+ Description string `json:"description,omitempty"`
+ InputSchema json.RawMessage `json:"input_schema,omitempty"`
+
+ // The Run function is automatically called when the tool is used.
+ // Run functions may be called concurrently with each other and themselves.
+ // The input to Run function is the input to the tool, as provided by Claude, in compliance with the input schema.
+ // The outputs from Run will be sent back to Claude.
+ // If you do not want to respond to the tool call request from Claude, return ErrDoNotRespond.
+ // ctx contains extra (rarely used) tool call information; retrieve it with ToolCallInfoFromContext.
+ Run func(ctx context.Context, input json.RawMessage) (string, error) `json:"-"`
+}
+
+var ErrDoNotRespond = errors.New("do not respond")
+
+// Usage represents the billing and rate-limit usage.
+type Usage struct {
+ InputTokens uint64 `json:"input_tokens"`
+ CacheCreationInputTokens uint64 `json:"cache_creation_input_tokens"`
+ CacheReadInputTokens uint64 `json:"cache_read_input_tokens"`
+ OutputTokens uint64 `json:"output_tokens"`
+ CostUSD float64 `json:"cost_usd"`
+}
+
+func (u *Usage) Add(other Usage) {
+ u.InputTokens += other.InputTokens
+ u.CacheCreationInputTokens += other.CacheCreationInputTokens
+ u.CacheReadInputTokens += other.CacheReadInputTokens
+ u.OutputTokens += other.OutputTokens
+ u.CostUSD += other.CostUSD
+}
+
+func (u *Usage) String() string {
+ return fmt.Sprintf("in: %d, out: %d", u.InputTokens, u.OutputTokens)
+}
+
+func (u *Usage) IsZero() bool {
+ return *u == Usage{}
+}
+
+func (u *Usage) Attr() slog.Attr {
+ return slog.Group("usage",
+ slog.Uint64("input_tokens", u.InputTokens),
+ slog.Uint64("output_tokens", u.OutputTokens),
+ slog.Uint64("cache_creation_input_tokens", u.CacheCreationInputTokens),
+ slog.Uint64("cache_read_input_tokens", u.CacheReadInputTokens),
+ )
+}
+
+type ErrorResponse struct {
+ Type string `json:"type"`
+ Message string `json:"message"`
+}
+
+// MessageResponse represents the response from the message API.
+type MessageResponse struct {
+ ID string `json:"id"`
+ Type string `json:"type"`
+ Role string `json:"role"`
+ Model string `json:"model"`
+ Content []Content `json:"content"`
+ StopReason string `json:"stop_reason"`
+ StopSequence *string `json:"stop_sequence,omitempty"`
+ Usage Usage `json:"usage"`
+ StartTime *time.Time `json:"start_time,omitempty"`
+ EndTime *time.Time `json:"end_time,omitempty"`
+}
+
+func (m *MessageResponse) ToMessage() Message {
+ return Message{
+ Role: m.Role,
+ Content: m.Content,
+ }
+}
+
+func (m *MessageResponse) StopSequenceString() string {
+ if m.StopSequence == nil {
+ return ""
+ }
+ return *m.StopSequence
+}
+
+const (
+ ToolChoiceTypeAuto = "auto" // default
+ ToolChoiceTypeAny = "any" // any tool, but must use one
+ ToolChoiceTypeNone = "none" // no tools allowed
+ ToolChoiceTypeTool = "tool" // must use the tool specified in the Name field
+)
+
+type ToolChoice struct {
+ Type string `json:"type"`
+ Name string `json:"name,omitempty"`
+}
+
+// https://docs.anthropic.com/en/api/messages#body-system
+type SystemContent struct {
+ Text string `json:"text,omitempty"`
+ Type string `json:"type,omitempty"`
+ CacheControl json.RawMessage `json:"cache_control,omitempty"`
+}
+
+// MessageRequest represents the request payload for creating a message.
+type MessageRequest struct {
+ Model string `json:"model"`
+ Messages []Message `json:"messages"`
+ ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
+ MaxTokens int `json:"max_tokens"`
+ Tools []*Tool `json:"tools,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ System []SystemContent `json:"system,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ StopSequences []string `json:"stop_sequences,omitempty"`
+
+ TokenEfficientToolUse bool `json:"-"` // DO NOT USE, broken on Anthropic's side as of 2025-02-28
+}
+
+const dumpText = false // debugging toggle to see raw communications with Claude
+
+// createMessage sends a request to the Anthropic message API to create a message.
+func createMessage(ctx context.Context, httpc *http.Client, url, apiKey string, request *MessageRequest) (*MessageResponse, error) {
+ var payload []byte
+ var err error
+ if dumpText || testing.Testing() {
+ payload, err = json.MarshalIndent(request, "", " ")
+ } else {
+ payload, err = json.Marshal(request)
+ payload = append(payload, '\n')
+ }
+ if err != nil {
+ return nil, err
+ }
+
+ if false {
+ fmt.Printf("claude request payload:\n%s\n", payload)
+ }
+
+ backoff := []time.Duration{15 * time.Second, 30 * time.Second, time.Minute}
+ largerMaxTokens := false
+ var partialUsage Usage
+
+ // retry loop
+ for attempts := 0; ; attempts++ {
+ if dumpText {
+ fmt.Printf("RAW REQUEST:\n%s\n\n", payload)
+ }
+ req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(payload))
+ if err != nil {
+ return nil, err
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("X-API-Key", apiKey)
+ req.Header.Set("Anthropic-Version", "2023-06-01")
+
+ features := []string{}
+
+ if request.TokenEfficientToolUse {
+ features = append(features, "token-efficient-tool-use-2025-02-19")
+ }
+ if largerMaxTokens {
+ features = append(features, "output-128k-2025-02-19")
+ request.MaxTokens = 128 * 1024
+ }
+ if len(features) > 0 {
+ req.Header.Set("anthropic-beta", strings.Join(features, ","))
+ }
+
+ resp, err := httpc.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ buf, _ := io.ReadAll(resp.Body)
+ resp.Body.Close()
+
+ switch {
+ case resp.StatusCode == http.StatusOK:
+ if dumpText {
+ fmt.Printf("RAW RESPONSE:\n%s\n\n", buf)
+ }
+ var response MessageResponse
+ err = json.NewDecoder(bytes.NewReader(buf)).Decode(&response)
+ if err != nil {
+ return nil, err
+ }
+ if response.StopReason == StopReasonMaxTokens && !largerMaxTokens {
+ fmt.Printf("Retrying Anthropic API call with larger max tokens size.")
+ // Retry with more output tokens.
+ largerMaxTokens = true
+ response.Usage.CostUSD = response.TotalDollars()
+ partialUsage = response.Usage
+ continue
+ }
+
+ // Calculate and set the cost_usd field
+ if largerMaxTokens {
+ response.Usage.Add(partialUsage)
+ }
+ response.Usage.CostUSD = response.TotalDollars()
+
+ return &response, nil
+ case resp.StatusCode >= 500 && resp.StatusCode < 600:
+ // overloaded or unhappy, in one form or another
+ sleep := backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
+ slog.WarnContext(ctx, "anthropic_request_failed", "response", string(buf), "status_code", resp.StatusCode, "sleep", sleep)
+ time.Sleep(sleep)
+ case resp.StatusCode == 429:
+ // rate limited. wait 1 minute as a starting point, because that's the rate limiting window.
+ // and then add some additional time for backoff.
+ sleep := time.Minute + backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
+ slog.WarnContext(ctx, "anthropic_request_rate_limited", "response", string(buf), "sleep", sleep)
+ // case resp.StatusCode == 400:
+ // TODO: parse ErrorResponse, make (*ErrorResponse) implement error
+ default:
+ return nil, fmt.Errorf("API request failed with status %s\n%s", resp.Status, buf)
+ }
+ }
+}
+
+// A Convo is a managed conversation with Claude.
+// It automatically manages the state of the conversation,
+// including appending messages send/received,
+// calling tools and sending their results,
+// tracking usage, etc.
+//
+// Exported fields must not be altered concurrently with calling any method on Convo.
+// Typical usage is to configure a Convo once before using it.
+type Convo struct {
+ // ID is a unique ID for the conversation
+ ID string
+ // Ctx is the context for the entire conversation.
+ Ctx context.Context
+ // HTTPC is the HTTP client for the conversation.
+ HTTPC *http.Client
+ // URL is the remote messages URL to dial.
+ URL string
+ // APIKey is the API key for the conversation.
+ APIKey string
+ // Model is the model for the conversation.
+ Model string
+ // MaxTokens is the max tokens for each response in the conversation.
+ MaxTokens int
+ // Tools are the tools available during the conversation.
+ Tools []*Tool
+ // SystemPrompt is the system prompt for the conversation.
+ SystemPrompt string
+ // PromptCaching indicates whether to use Anthropic's prompt caching.
+ // See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#continuing-a-multi-turn-conversation
+ // for the documentation. At request send time, we set the cache_control field on the
+ // last message. We also cache the system prompt.
+ // Default: true.
+ PromptCaching bool
+ // ToolUseOnly indicates whether Claude may only use tools during this conversation.
+ // TODO: add more fine-grained control over tool use?
+ ToolUseOnly bool
+ // Parent is the parent conversation, if any.
+ // It is non-nil for "subagent" calls.
+ // It is set automatically when calling SubConvo,
+ // and usually should not be set manually.
+ Parent *Convo
+ // Budget is the budget for this conversation (and all sub-conversations).
+ // The Conversation DOES NOT automatically enforce the budget.
+ // It is up to the caller to call OverBudget() as appropriate.
+ Budget Budget
+
+ // messages tracks the messages so far in the conversation.
+ messages []Message
+
+ // Listener receives messages being sent.
+ Listener Listener
+
+ muToolUseCancel *sync.Mutex
+ toolUseCancel map[string]context.CancelCauseFunc
+
+ // Protects usage. This is used for subconversations (that share part of CumulativeUsage) as well.
+ mu *sync.Mutex
+ // usage tracks usage for this conversation and all sub-conversations.
+ usage *CumulativeUsage
+}
+
+// newConvoID generates a new 8-byte random id.
+// The uniqueness/collision requirements here are very low.
+// They are not global identifiers,
+// just enough to distinguish different convos in a single session.
+func newConvoID() string {
+ u1 := rand.Uint32()
+ s := crock32.Encode(uint64(u1))
+ if len(s) < 7 {
+ s += strings.Repeat("0", 7-len(s))
+ }
+ return s[:3] + "-" + s[3:]
+}
+
+// NewConvo creates a new conversation with Claude with sensible defaults.
+// ctx is the context for the entire conversation.
+func NewConvo(ctx context.Context, apiKey string) *Convo {
+ id := newConvoID()
+ return &Convo{
+ Ctx: skribe.ContextWithAttr(ctx, slog.String("convo_id", id)),
+ HTTPC: http.DefaultClient,
+ URL: DefaultURL,
+ APIKey: apiKey,
+ Model: DefaultModel,
+ MaxTokens: DefaultMaxTokens,
+ PromptCaching: true,
+ usage: newUsage(),
+ Listener: &NoopListener{},
+ ID: id,
+ muToolUseCancel: &sync.Mutex{},
+ toolUseCancel: map[string]context.CancelCauseFunc{},
+ mu: &sync.Mutex{},
+ }
+}
+
+// SubConvo creates a sub-conversation with the same configuration as the parent conversation.
+// (This propagates context for cancellation, HTTP client, API key, etc.)
+// The sub-conversation shares no messages with the parent conversation.
+// It does not inherit tools from the parent conversation.
+func (c *Convo) SubConvo() *Convo {
+ id := newConvoID()
+ return &Convo{
+ Ctx: skribe.ContextWithAttr(c.Ctx, slog.String("convo_id", id), slog.String("parent_convo_id", c.ID)),
+ HTTPC: c.HTTPC,
+ APIKey: c.APIKey,
+ Model: c.Model,
+ MaxTokens: c.MaxTokens,
+ PromptCaching: c.PromptCaching,
+ Parent: c,
+ // For convenience, sub-convo usage shares tool uses map with parent,
+ // all other fields separate, propagated in AddResponse
+ usage: newUsageWithSharedToolUses(c.usage),
+ mu: c.mu,
+ Listener: c.Listener,
+ ID: id,
+ // Do not copy Budget. Each budget is independent,
+ // and OverBudget checks whether any ancestor is over budget.
+ }
+}
+
+// Depth reports how many "sub-conversations" deep this conversation is.
+// That it, it walks up parents until it finds a root.
+func (c *Convo) Depth() int {
+ x := c
+ var depth int
+ for x.Parent != nil {
+ x = x.Parent
+ depth++
+ }
+ return depth
+}
+
+// SendUserTextMessage sends a text message to Claude in this conversation.
+// otherContents contains additional contents to send with the message, usually tool results.
+func (c *Convo) SendUserTextMessage(s string, otherContents ...Content) (*MessageResponse, error) {
+ contents := slices.Clone(otherContents)
+ if s != "" {
+ contents = append(contents, Content{Type: ContentTypeText, Text: s})
+ }
+ msg := Message{
+ Role: MessageRoleUser,
+ Content: contents,
+ }
+ return c.SendMessage(msg)
+}
+
+func (c *Convo) messageRequest(msg Message) *MessageRequest {
+ system := []SystemContent{}
+ if c.SystemPrompt != "" {
+ var d SystemContent
+ d = SystemContent{Type: "text", Text: c.SystemPrompt}
+ if c.PromptCaching {
+ d.CacheControl = json.RawMessage(`{"type":"ephemeral"}`)
+ }
+ system = []SystemContent{d}
+ }
+
+ // Claude is happy to return an empty response in response to our Done() call,
+ // and, if so, you'll see something like:
+ // API request failed with status 400 Bad Request
+ // {"type":"error","error": {"type":"invalid_request_error",
+ // "message":"messages.5: all messages must have non-empty content except for the optional final assistant message"}}
+ // So, we filter out those empty messages.
+ var nonEmptyMessages []Message
+ for _, m := range c.messages {
+ if len(m.Content) > 0 {
+ nonEmptyMessages = append(nonEmptyMessages, m)
+ }
+ }
+
+ mr := &MessageRequest{
+ Model: c.Model,
+ Messages: append(nonEmptyMessages, msg), // not yet committed to keeping msg
+ System: system,
+ Tools: c.Tools,
+ MaxTokens: c.MaxTokens,
+ }
+ if c.ToolUseOnly {
+ mr.ToolChoice = &ToolChoice{Type: ToolChoiceTypeAny}
+ }
+ return mr
+}
+
+func (c *Convo) findTool(name string) (*Tool, error) {
+ for _, tool := range c.Tools {
+ if tool.Name == name {
+ return tool, nil
+ }
+ }
+ return nil, fmt.Errorf("tool %q not found", name)
+}
+
+// insertMissingToolResults adds error results for tool uses that were requested
+// but not included in the message, which can happen in error paths like "out of budget."
+// We only insert these if there were no tool responses at all, since an incorrect
+// number of tool results would be a programmer error. Mutates inputs.
+func (c *Convo) insertMissingToolResults(mr *MessageRequest, msg *Message) {
+ if len(mr.Messages) < 2 {
+ return
+ }
+ prev := mr.Messages[len(mr.Messages)-2]
+ var toolUsePrev int
+ for _, c := range prev.Content {
+ if c.Type == ContentTypeToolUse {
+ toolUsePrev++
+ }
+ }
+ if toolUsePrev == 0 {
+ return
+ }
+ var toolUseCurrent int
+ for _, c := range msg.Content {
+ if c.Type == ContentTypeToolResult {
+ toolUseCurrent++
+ }
+ }
+ if toolUseCurrent != 0 {
+ return
+ }
+ var prefix []Content
+ for _, part := range prev.Content {
+ if part.Type != ContentTypeToolUse {
+ continue
+ }
+ content := Content{
+ Type: ContentTypeToolResult,
+ ToolUseID: part.ID,
+ ToolError: true,
+ ToolResult: "not executed; retry possible",
+ }
+ prefix = append(prefix, content)
+ msg.Content = append(prefix, msg.Content...)
+ mr.Messages[len(mr.Messages)-1].Content = msg.Content
+ }
+ slog.DebugContext(c.Ctx, "inserted missing tool results")
+}
+
+// SendMessage sends a message to Claude.
+// The conversation records (internally) all messages succesfully sent and received.
+func (c *Convo) SendMessage(msg Message) (*MessageResponse, error) {
+ mr := c.messageRequest(msg)
+ var lastMessage *Message
+ if c.PromptCaching {
+ lastMessage = &mr.Messages[len(mr.Messages)-1]
+ if len(lastMessage.Content) > 0 {
+ lastMessage.Content[len(lastMessage.Content)-1].CacheControl = json.RawMessage(`{"type":"ephemeral"}`)
+ }
+ }
+ defer func() {
+ if lastMessage == nil {
+ return
+ }
+ if len(lastMessage.Content) > 0 {
+ lastMessage.Content[len(lastMessage.Content)-1].CacheControl = []byte{}
+ }
+ }()
+ c.insertMissingToolResults(mr, &msg)
+ c.Listener.OnRequest(c.Ctx, c, &msg)
+
+ startTime := time.Now()
+ resp, err := createMessage(c.Ctx, c.HTTPC, c.URL, c.APIKey, mr)
+ if resp != nil {
+ resp.StartTime = &startTime
+ endTime := time.Now()
+ resp.EndTime = &endTime
+ }
+
+ if err != nil {
+ return nil, err
+ }
+ c.messages = append(c.messages, msg, resp.ToMessage())
+ // Propagate usage to all ancestors (including us).
+ for x := c; x != nil; x = x.Parent {
+ x.usage.AddResponse(resp)
+ }
+ c.Listener.OnResponse(c.Ctx, c, resp)
+ return resp, err
+}
+
+type toolCallInfoKeyType string
+
+var toolCallInfoKey toolCallInfoKeyType
+
+type ToolCallInfo struct {
+ ToolUseID string
+ Convo *Convo
+}
+
+func ToolCallInfoFromContext(ctx context.Context) ToolCallInfo {
+ v := ctx.Value(toolCallInfoKey)
+ i, _ := v.(ToolCallInfo)
+ return i
+}
+
+func (c *Convo) ToolResultCancelContents(resp *MessageResponse) ([]Content, error) {
+ if resp.StopReason != StopReasonToolUse {
+ return nil, nil
+ }
+ var toolResults []Content
+
+ for _, part := range resp.Content {
+ if part.Type != ContentTypeToolUse {
+ continue
+ }
+ c.incrementToolUse(part.ToolName)
+
+ content := Content{
+ Type: ContentTypeToolResult,
+ ToolUseID: part.ID,
+ }
+
+ content.ToolError = true
+ content.ToolResult = "user canceled this too_use"
+ toolResults = append(toolResults, content)
+ }
+ return toolResults, nil
+}
+
+func (c *Convo) CancelToolUse(toolUseID string, err error) error {
+ c.muToolUseCancel.Lock()
+ defer c.muToolUseCancel.Unlock()
+ cancel, ok := c.toolUseCancel[toolUseID]
+ if !ok {
+ return fmt.Errorf("cannot cancel %s: no cancel function registered for this tool_use_id. All I have is %+v", toolUseID, c.toolUseCancel)
+ }
+ delete(c.toolUseCancel, toolUseID)
+ cancel(err)
+ return nil
+}
+
+func (c *Convo) newToolUseContext(ctx context.Context, toolUseID string) (context.Context, context.CancelFunc) {
+ c.muToolUseCancel.Lock()
+ defer c.muToolUseCancel.Unlock()
+ ctx, cancel := context.WithCancelCause(ctx)
+ c.toolUseCancel[toolUseID] = cancel
+ return ctx, func() { c.CancelToolUse(toolUseID, nil) }
+}
+
+// ToolResultContents runs all tool uses requested by the response and returns their results.
+// Cancelling ctx will cancel any running tool calls.
+func (c *Convo) ToolResultContents(ctx context.Context, resp *MessageResponse) ([]Content, error) {
+ if resp.StopReason != StopReasonToolUse {
+ return nil, nil
+ }
+ // Extract all tool calls from the response, call the tools, and gather the results.
+ var wg sync.WaitGroup
+ toolResultC := make(chan Content, len(resp.Content))
+ for _, part := range resp.Content {
+ if part.Type != ContentTypeToolUse {
+ continue
+ }
+ c.incrementToolUse(part.ToolName)
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+
+ // Record start time
+ startTime := time.Now()
+
+ content := Content{
+ Type: ContentTypeToolResult,
+ ToolUseID: part.ID,
+ StartTime: &startTime,
+ }
+ sendErr := func(err error) {
+ // Record end time
+ endTime := time.Now()
+ content.EndTime = &endTime
+
+ content.ToolError = true
+ content.ToolResult = err.Error()
+ c.Listener.OnToolResult(ctx, c, part.ToolName, part.ToolInput, content, nil, err)
+ toolResultC <- content
+ }
+ sendRes := func(res string) {
+ // Record end time
+ endTime := time.Now()
+ content.EndTime = &endTime
+
+ content.ToolResult = res
+ c.Listener.OnToolResult(ctx, c, part.ToolName, part.ToolInput, content, &res, nil)
+ toolResultC <- content
+ }
+
+ tool, err := c.findTool(part.ToolName)
+ if err != nil {
+ sendErr(err)
+ return
+ }
+ // Create a new context for just this tool_use call, and register its
+ // cancel function so that it can be canceled individually.
+ toolUseCtx, cancel := c.newToolUseContext(ctx, part.ID)
+ defer cancel()
+ // TODO: move this into newToolUseContext?
+ toolUseCtx = context.WithValue(toolUseCtx, toolCallInfoKey, ToolCallInfo{ToolUseID: part.ID, Convo: c})
+ toolResult, err := tool.Run(toolUseCtx, part.ToolInput)
+ if errors.Is(err, ErrDoNotRespond) {
+ return
+ }
+ if toolUseCtx.Err() != nil {
+ sendErr(context.Cause(toolUseCtx))
+ return
+ }
+
+ if err != nil {
+ sendErr(err)
+ return
+ }
+ sendRes(toolResult)
+ }()
+ }
+ wg.Wait()
+ close(toolResultC)
+ var toolResults []Content
+ for toolResult := range toolResultC {
+ toolResults = append(toolResults, toolResult)
+ }
+ if ctx.Err() != nil {
+ return nil, ctx.Err()
+ }
+ return toolResults, nil
+}
+
+func (c *Convo) incrementToolUse(name string) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ c.usage.ToolUses[name]++
+}
+
+// ContentsAttr returns contents as a slog.Attr.
+// It is meant for logging.
+func ContentsAttr(contents []Content) slog.Attr {
+ var contentAttrs []any // slog.Attr
+ for _, content := range contents {
+ var attrs []any // slog.Attr
+ switch content.Type {
+ case ContentTypeText:
+ attrs = append(attrs, slog.String("text", content.Text))
+ case ContentTypeToolUse:
+ attrs = append(attrs, slog.String("tool_name", content.ToolName))
+ attrs = append(attrs, slog.String("tool_input", string(content.ToolInput)))
+ case ContentTypeToolResult:
+ attrs = append(attrs, slog.String("tool_result", content.ToolResult))
+ attrs = append(attrs, slog.Bool("tool_error", content.ToolError))
+ case ContentTypeThinking:
+ attrs = append(attrs, slog.String("thinking", content.Text))
+ default:
+ attrs = append(attrs, slog.String("unknown_content_type", content.Type))
+ attrs = append(attrs, slog.Any("text", content)) // just log it all raw, better to have too much than not enough
+ }
+ contentAttrs = append(contentAttrs, slog.Group(content.ID, attrs...))
+ }
+ return slog.Group("contents", contentAttrs...)
+}
+
+// MustSchema validates that schema is a valid JSON schema and returns it as a json.RawMessage.
+// It panics if the schema is invalid.
+func MustSchema(schema string) json.RawMessage {
+ // TODO: validate schema, for now just make sure it's valid JSON
+ schema = strings.TrimSpace(schema)
+ bytes := []byte(schema)
+ if !json.Valid(bytes) {
+ panic("invalid JSON schema: " + schema)
+ }
+ return json.RawMessage(bytes)
+}
+
+// cents per million tokens
+// (not dollars because i'm twitchy about using floats for money)
+type centsPer1MTokens struct {
+ Input uint64
+ Output uint64
+ CacheRead uint64
+ CacheCreation uint64
+}
+
+// https://www.anthropic.com/pricing#anthropic-api
+var modelCost = map[string]centsPer1MTokens{
+ Claude37Sonnet: {
+ Input: 300, // $3
+ Output: 1500, // $15
+ CacheRead: 30, // $0.30
+ CacheCreation: 375, // $3.75
+ },
+ Claude35Haiku: {
+ Input: 80, // $0.80
+ Output: 400, // $4.00
+ CacheRead: 8, // $0.08
+ CacheCreation: 100, // $1.00
+ },
+ Claude35Sonnet: {
+ Input: 300, // $3
+ Output: 1500, // $15
+ CacheRead: 30, // $0.30
+ CacheCreation: 375, // $3.75
+ },
+}
+
+// TotalDollars returns the total cost to obtain this response, in dollars.
+func (mr *MessageResponse) TotalDollars() float64 {
+ cpm, ok := modelCost[mr.Model]
+ if !ok {
+ panic(fmt.Sprintf("no pricing info for model: %s", mr.Model))
+ }
+ use := mr.Usage
+ megaCents := use.InputTokens*cpm.Input +
+ use.OutputTokens*cpm.Output +
+ use.CacheReadInputTokens*cpm.CacheRead +
+ use.CacheCreationInputTokens*cpm.CacheCreation
+ cents := float64(megaCents) / 1_000_000.0
+ return cents / 100.0
+}
+
+func newUsage() *CumulativeUsage {
+ return &CumulativeUsage{ToolUses: make(map[string]int), StartTime: time.Now()}
+}
+
+func newUsageWithSharedToolUses(parent *CumulativeUsage) *CumulativeUsage {
+ return &CumulativeUsage{ToolUses: parent.ToolUses, StartTime: time.Now()}
+}
+
+// CumulativeUsage represents cumulative usage across a Convo, including all sub-conversations.
+type CumulativeUsage struct {
+ StartTime time.Time `json:"start_time"`
+ Responses uint64 `json:"messages"` // count of responses
+ InputTokens uint64 `json:"input_tokens"`
+ OutputTokens uint64 `json:"output_tokens"`
+ CacheReadInputTokens uint64 `json:"cache_read_input_tokens"`
+ CacheCreationInputTokens uint64 `json:"cache_creation_input_tokens"`
+ TotalCostUSD float64 `json:"total_cost_usd"`
+ ToolUses map[string]int `json:"tool_uses"` // tool name -> number of uses
+}
+
+func (u *CumulativeUsage) Clone() CumulativeUsage {
+ v := *u
+ v.ToolUses = maps.Clone(u.ToolUses)
+ return v
+}
+
+func (c *Convo) CumulativeUsage() CumulativeUsage {
+ if c == nil {
+ return CumulativeUsage{}
+ }
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ return c.usage.Clone()
+}
+
+func (u *CumulativeUsage) WallTime() time.Duration {
+ return time.Since(u.StartTime)
+}
+
+func (u *CumulativeUsage) DollarsPerHour() float64 {
+ hours := u.WallTime().Hours()
+ if hours == 0 {
+ return 0
+ }
+ return u.TotalCostUSD / hours
+}
+
+func (u *CumulativeUsage) AddResponse(resp *MessageResponse) {
+ usage := resp.Usage
+ u.Responses++
+ u.InputTokens += usage.InputTokens
+ u.OutputTokens += usage.OutputTokens
+ u.CacheReadInputTokens += usage.CacheReadInputTokens
+ u.CacheCreationInputTokens += usage.CacheCreationInputTokens
+ u.TotalCostUSD += resp.TotalDollars()
+}
+
+// Attr returns the cumulative usage as a slog.Attr with key "usage".
+func (u CumulativeUsage) Attr() slog.Attr {
+ elapsed := time.Since(u.StartTime)
+ return slog.Group("usage",
+ slog.Duration("wall_time", elapsed),
+ slog.Uint64("responses", u.Responses),
+ slog.Uint64("input_tokens", u.InputTokens),
+ slog.Uint64("output_tokens", u.OutputTokens),
+ slog.Uint64("cache_read_input_tokens", u.CacheReadInputTokens),
+ slog.Uint64("cache_creation_input_tokens", u.CacheCreationInputTokens),
+ slog.Float64("total_cost_usd", u.TotalCostUSD),
+ slog.Float64("dollars_per_hour", u.TotalCostUSD/elapsed.Hours()),
+ slog.Any("tool_uses", maps.Clone(u.ToolUses)),
+ )
+}
+
+// A Budget represents the maximum amount of resources that may be spent on a conversation.
+// Note that the default (zero) budget is unlimited.
+type Budget struct {
+ MaxResponses uint64 // if > 0, max number of iterations (=responses)
+ MaxDollars float64 // if > 0, max dollars that may be spent
+ MaxWallTime time.Duration // if > 0, max wall time that may be spent
+}
+
+// OverBudget returns an error if the convo (or any of its parents) has exceeded its budget.
+// TODO: document parent vs sub budgets, multiple errors, etc, once we know the desired behavior.
+func (c *Convo) OverBudget() error {
+ for x := c; x != nil; x = x.Parent {
+ if err := x.overBudget(); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// ResetBudget sets the budget to the passed in budget and
+// adjusts it by what's been used so far.
+func (c *Convo) ResetBudget(budget Budget) {
+ c.Budget = budget
+ if c.Budget.MaxDollars > 0 {
+ c.Budget.MaxDollars += c.CumulativeUsage().TotalCostUSD
+ }
+ if c.Budget.MaxResponses > 0 {
+ c.Budget.MaxResponses += c.CumulativeUsage().Responses
+ }
+ if c.Budget.MaxWallTime > 0 {
+ c.Budget.MaxWallTime += c.usage.WallTime()
+ }
+}
+
+func (c *Convo) overBudget() error {
+ usage := c.CumulativeUsage()
+ // TODO: stop before we exceed the budget instead of after?
+ // Top priority is money, then time, then response count.
+ var err error
+ cont := "Continuing to chat will reset the budget."
+ if c.Budget.MaxDollars > 0 && usage.TotalCostUSD >= c.Budget.MaxDollars {
+ err = errors.Join(err, fmt.Errorf("$%.2f spent, budget is $%.2f. %s", usage.TotalCostUSD, c.Budget.MaxDollars, cont))
+ }
+ if c.Budget.MaxWallTime > 0 && usage.WallTime() >= c.Budget.MaxWallTime {
+ err = errors.Join(err, fmt.Errorf("%v elapsed, budget is %v. %s", usage.WallTime().Truncate(time.Second), c.Budget.MaxWallTime.Truncate(time.Second), cont))
+ }
+ if c.Budget.MaxResponses > 0 && usage.Responses >= c.Budget.MaxResponses {
+ err = errors.Join(err, fmt.Errorf("%d responses received, budget is %d. %s", usage.Responses, c.Budget.MaxResponses, cont))
+ }
+ return err
+}
diff --git a/ant/ant_test.go b/ant/ant_test.go
new file mode 100644
index 0000000..6556e9e
--- /dev/null
+++ b/ant/ant_test.go
@@ -0,0 +1,220 @@
+package ant
+
+import (
+ "context"
+ "math"
+ "net/http"
+ "os"
+ "strings"
+ "testing"
+
+ "sketch.dev/httprr"
+)
+
+func TestBasicConvo(t *testing.T) {
+ ctx := context.Background()
+ rr, err := httprr.Open("testdata/basic_convo.httprr", http.DefaultTransport)
+ if err != nil {
+ t.Fatal(err)
+ }
+ rr.ScrubReq(func(req *http.Request) error {
+ req.Header.Del("x-api-key")
+ return nil
+ })
+
+ convo := NewConvo(ctx, os.Getenv("ANTHROPIC_API_KEY"))
+ convo.HTTPC = rr.Client()
+
+ const name = "Cornelius"
+ res, err := convo.SendUserTextMessage("Hi, my name is " + name)
+ if err != nil {
+ t.Fatal(err)
+ }
+ for _, part := range res.Content {
+ t.Logf("%s", part.Text)
+ }
+ res, err = convo.SendUserTextMessage("What is my name?")
+ if err != nil {
+ t.Fatal(err)
+ }
+ got := ""
+ for _, part := range res.Content {
+ got += part.Text
+ }
+ if !strings.Contains(got, name) {
+ t.Errorf("model does not know the given name %s: %q", name, got)
+ }
+}
+
+// TestCalculateCostFromTokens tests the calculateCostFromTokens function
+func TestCalculateCostFromTokens(t *testing.T) {
+ tests := []struct {
+ name string
+ model string
+ inputTokens uint64
+ outputTokens uint64
+ cacheReadInputTokens uint64
+ cacheCreationInputTokens uint64
+ want float64
+ }{
+ {
+ name: "Zero tokens",
+ model: Claude37Sonnet,
+ inputTokens: 0,
+ outputTokens: 0,
+ cacheReadInputTokens: 0,
+ cacheCreationInputTokens: 0,
+ want: 0,
+ },
+ {
+ name: "1000 input tokens, 500 output tokens",
+ model: Claude37Sonnet,
+ inputTokens: 1000,
+ outputTokens: 500,
+ cacheReadInputTokens: 0,
+ cacheCreationInputTokens: 0,
+ want: 0.0105,
+ },
+ {
+ name: "10000 input tokens, 5000 output tokens",
+ model: Claude37Sonnet,
+ inputTokens: 10000,
+ outputTokens: 5000,
+ cacheReadInputTokens: 0,
+ cacheCreationInputTokens: 0,
+ want: 0.105,
+ },
+ {
+ name: "With cache read tokens",
+ model: Claude37Sonnet,
+ inputTokens: 1000,
+ outputTokens: 500,
+ cacheReadInputTokens: 2000,
+ cacheCreationInputTokens: 0,
+ want: 0.0111,
+ },
+ {
+ name: "With cache creation tokens",
+ model: Claude37Sonnet,
+ inputTokens: 1000,
+ outputTokens: 500,
+ cacheReadInputTokens: 0,
+ cacheCreationInputTokens: 1500,
+ want: 0.016125,
+ },
+ {
+ name: "With all token types",
+ model: Claude37Sonnet,
+ inputTokens: 1000,
+ outputTokens: 500,
+ cacheReadInputTokens: 2000,
+ cacheCreationInputTokens: 1500,
+ want: 0.016725,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ usage := Usage{
+ InputTokens: tt.inputTokens,
+ OutputTokens: tt.outputTokens,
+ CacheReadInputTokens: tt.cacheReadInputTokens,
+ CacheCreationInputTokens: tt.cacheCreationInputTokens,
+ }
+ mr := MessageResponse{
+ Model: tt.model,
+ Usage: usage,
+ }
+ totalCost := mr.TotalDollars()
+ if math.Abs(totalCost-tt.want) > 0.0001 {
+ t.Errorf("totalCost = %v, want %v", totalCost, tt.want)
+ }
+ })
+ }
+}
+
+// TestCancelToolUse tests the CancelToolUse function of the Convo struct
+func TestCancelToolUse(t *testing.T) {
+ tests := []struct {
+ name string
+ setupToolUse bool
+ toolUseID string
+ cancelErr error
+ expectError bool
+ expectCancel bool
+ }{
+ {
+ name: "Cancel existing tool use",
+ setupToolUse: true,
+ toolUseID: "tool123",
+ cancelErr: nil,
+ expectError: false,
+ expectCancel: true,
+ },
+ {
+ name: "Cancel existing tool use with error",
+ setupToolUse: true,
+ toolUseID: "tool456",
+ cancelErr: context.Canceled,
+ expectError: false,
+ expectCancel: true,
+ },
+ {
+ name: "Cancel non-existent tool use",
+ setupToolUse: false,
+ toolUseID: "tool789",
+ cancelErr: nil,
+ expectError: true,
+ expectCancel: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ convo := NewConvo(context.Background(), "")
+
+ var cancelCalled bool
+ var cancelledWithErr error
+
+ if tt.setupToolUse {
+ // Setup a mock cancel function to track calls
+ mockCancel := func(err error) {
+ cancelCalled = true
+ cancelledWithErr = err
+ }
+
+ convo.muToolUseCancel.Lock()
+ convo.toolUseCancel[tt.toolUseID] = mockCancel
+ convo.muToolUseCancel.Unlock()
+ }
+
+ err := convo.CancelToolUse(tt.toolUseID, tt.cancelErr)
+
+ // Check if we got the expected error state
+ if (err != nil) != tt.expectError {
+ t.Errorf("CancelToolUse() error = %v, expectError %v", err, tt.expectError)
+ }
+
+ // Check if the cancel function was called as expected
+ if cancelCalled != tt.expectCancel {
+ t.Errorf("Cancel function called = %v, expectCancel %v", cancelCalled, tt.expectCancel)
+ }
+
+ // If we expected the cancel to be called, verify it was called with the right error
+ if tt.expectCancel && cancelledWithErr != tt.cancelErr {
+ t.Errorf("Cancel function called with error = %v, expected %v", cancelledWithErr, tt.cancelErr)
+ }
+
+ // Verify the toolUseID was removed from the map if it was initially added
+ if tt.setupToolUse {
+ convo.muToolUseCancel.Lock()
+ _, exists := convo.toolUseCancel[tt.toolUseID]
+ convo.muToolUseCancel.Unlock()
+
+ if exists {
+ t.Errorf("toolUseID %s still exists in the map after cancellation", tt.toolUseID)
+ }
+ }
+ })
+ }
+}
diff --git a/ant/testdata/basic_convo.httprr b/ant/testdata/basic_convo.httprr
new file mode 100644
index 0000000..663de8d
--- /dev/null
+++ b/ant/testdata/basic_convo.httprr
@@ -0,0 +1,116 @@
+httprr trace v1
+457 1329
+POST https://api.anthropic.com/v1/messages HTTP/1.1
+Host: api.anthropic.com
+User-Agent: Go-http-client/1.1
+Content-Length: 261
+Anthropic-Version: 2023-06-01
+Content-Type: application/json
+
+{
+ "model": "claude-3-7-sonnet-20250219",
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "Hi, my name is Cornelius",
+ "cache_control": {
+ "type": "ephemeral"
+ }
+ }
+ ]
+ }
+ ],
+ "max_tokens": 8192
+}HTTP/2.0 200 OK
+Anthropic-Organization-Id: 3c473a21-7208-450a-a9f8-80aebda45c1b
+Anthropic-Ratelimit-Input-Tokens-Limit: 200000
+Anthropic-Ratelimit-Input-Tokens-Remaining: 200000
+Anthropic-Ratelimit-Input-Tokens-Reset: 2025-03-11T17:45:06Z
+Anthropic-Ratelimit-Output-Tokens-Limit: 80000
+Anthropic-Ratelimit-Output-Tokens-Remaining: 79000
+Anthropic-Ratelimit-Output-Tokens-Reset: 2025-03-11T17:45:07Z
+Anthropic-Ratelimit-Requests-Limit: 4000
+Anthropic-Ratelimit-Requests-Remaining: 3999
+Anthropic-Ratelimit-Requests-Reset: 2025-03-11T17:45:05Z
+Anthropic-Ratelimit-Tokens-Limit: 280000
+Anthropic-Ratelimit-Tokens-Remaining: 279000
+Anthropic-Ratelimit-Tokens-Reset: 2025-03-11T17:45:06Z
+Cf-Cache-Status: DYNAMIC
+Cf-Ray: 91ecdd10fdc3f97f-SJC
+Content-Type: application/json
+Date: Tue, 11 Mar 2025 17:45:07 GMT
+Request-Id: req_01LBtxMdNzxDcDVPGJSh7giv
+Server: cloudflare
+Via: 1.1 google
+X-Robots-Tag: none
+
+{"id":"msg_01S1uUyUsTaKPBKuDUGGX8J2","type":"message","role":"assistant","model":"claude-3-7-sonnet-20250219","content":[{"type":"text","text":"Hello, Cornelius! It's nice to meet you. How are you doing today? Is there something I can help you with?"}],"stop_reason":"end_turn","stop_sequence":null,"usage":{"input_tokens":15,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":31}}779 1286
+POST https://api.anthropic.com/v1/messages HTTP/1.1
+Host: api.anthropic.com
+User-Agent: Go-http-client/1.1
+Content-Length: 583
+Anthropic-Version: 2023-06-01
+Content-Type: application/json
+
+{
+ "model": "claude-3-7-sonnet-20250219",
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "Hi, my name is Cornelius"
+ }
+ ]
+ },
+ {
+ "role": "assistant",
+ "content": [
+ {
+ "type": "text",
+ "text": "Hello, Cornelius! It's nice to meet you. How are you doing today? Is there something I can help you with?"
+ }
+ ]
+ },
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "What is my name?",
+ "cache_control": {
+ "type": "ephemeral"
+ }
+ }
+ ]
+ }
+ ],
+ "max_tokens": 8192
+}HTTP/2.0 200 OK
+Anthropic-Organization-Id: 3c473a21-7208-450a-a9f8-80aebda45c1b
+Anthropic-Ratelimit-Input-Tokens-Limit: 200000
+Anthropic-Ratelimit-Input-Tokens-Remaining: 200000
+Anthropic-Ratelimit-Input-Tokens-Reset: 2025-03-11T17:45:07Z
+Anthropic-Ratelimit-Output-Tokens-Limit: 80000
+Anthropic-Ratelimit-Output-Tokens-Remaining: 80000
+Anthropic-Ratelimit-Output-Tokens-Reset: 2025-03-11T17:45:07Z
+Anthropic-Ratelimit-Requests-Limit: 4000
+Anthropic-Ratelimit-Requests-Remaining: 3999
+Anthropic-Ratelimit-Requests-Reset: 2025-03-11T17:45:07Z
+Anthropic-Ratelimit-Tokens-Limit: 280000
+Anthropic-Ratelimit-Tokens-Remaining: 280000
+Anthropic-Ratelimit-Tokens-Reset: 2025-03-11T17:45:07Z
+Cf-Cache-Status: DYNAMIC
+Cf-Ray: 91ecdd1ae9a6f97f-SJC
+Content-Type: application/json
+Date: Tue, 11 Mar 2025 17:45:07 GMT
+Request-Id: req_01MBf3RWXNfQgwhVRwwkBYSn
+Server: cloudflare
+Via: 1.1 google
+X-Robots-Tag: none
+
+{"id":"msg_01FGz6DeWeDpspJG8cuxyVE9","type":"message","role":"assistant","model":"claude-3-7-sonnet-20250219","content":[{"type":"text","text":"Your name is Cornelius, as you mentioned in your introduction."}],"stop_reason":"end_turn","stop_sequence":null,"usage":{"input_tokens":54,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":17}}
\ No newline at end of file