all: support openai-compatible models
The support is rather minimal at this point:
Only hard-coded models, only -unsafe, only -skabandaddr="".
The "shared" LLM package is strongly Claude-flavored.
We can fix all of this and more over time, if we are inspired to.
(Maybe we'll switch to https://github.com/maruel/genai?)
The goal for now is to get the rough structure in place.
I've rebased and rebuilt this more times than I care to remember.
diff --git a/llm/ant/ant.go b/llm/ant/ant.go
new file mode 100644
index 0000000..dce17f1
--- /dev/null
+++ b/llm/ant/ant.go
@@ -0,0 +1,480 @@
+package ant
+
+import (
+ "bytes"
+ "cmp"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log/slog"
+ "math/rand/v2"
+ "net/http"
+ "strings"
+ "testing"
+ "time"
+
+ "sketch.dev/llm"
+)
+
+// Defaults that Service falls back to when the matching field is left at its
+// zero value.
+const (
+ // DefaultModel is the model requested when Service.Model is empty.
+ DefaultModel = Claude37Sonnet
+ // DefaultMaxTokens is the max_tokens value sent when Service.MaxTokens is zero.
+ // See https://docs.anthropic.com/en/docs/about-claude/models/all-models for
+ // current maximums. There's currently a flag to enable 128k output (output-128k-2025-02-19)
+ DefaultMaxTokens = 8192
+ // DefaultURL is the endpoint requests are posted to when Service.URL is empty.
+ DefaultURL = "https://api.anthropic.com/v1/messages"
+)
+
+// Model identifier strings. Each one has a pricing entry in modelCost, which
+// TotalDollars needs in order to price a response.
+const (
+ Claude35Sonnet = "claude-3-5-sonnet-20241022"
+ Claude35Haiku = "claude-3-5-haiku-20241022"
+ Claude37Sonnet = "claude-3-7-sonnet-20250219"
+)
+
+// Service provides Claude completions by posting requests to the messages
+// endpoint named by URL. Every field except APIKey may be left at its zero
+// value, in which case the documented default is used.
+// Fields should not be altered concurrently with calling any method on Service.
+type Service struct {
+ HTTPC *http.Client // defaults to http.DefaultClient if nil
+ URL string // defaults to DefaultURL if empty
+ APIKey string // must be non-empty
+ Model string // defaults to DefaultModel if empty
+ MaxTokens int // defaults to DefaultMaxTokens if zero
+}
+
+// Compile-time check that *Service implements llm.Service.
+var _ llm.Service = (*Service)(nil)
+
+// content is a single content block of a message, in the JSON shape exchanged
+// with the messages API. One struct covers every kind of block: Type names the
+// kind, and each field group below is only relevant to the kinds mentioned in
+// its comment. Empty fields are omitted when marshaling.
+type content struct {
+ // TODO: image support?
+ // https://docs.anthropic.com/en/api/messages
+ ID string `json:"id,omitempty"`
+ Type string `json:"type,omitempty"`
+ Text string `json:"text,omitempty"`
+
+ // for thinking
+ Thinking string `json:"thinking,omitempty"`
+ Data string `json:"data,omitempty"` // for redacted_thinking
+ Signature string `json:"signature,omitempty"` // for thinking
+
+ // for tool_use
+ ToolName string `json:"name,omitempty"`
+ ToolInput json.RawMessage `json:"input,omitempty"`
+
+ // for tool_result
+ ToolUseID string `json:"tool_use_id,omitempty"`
+ ToolError bool `json:"is_error,omitempty"`
+ ToolResult string `json:"content,omitempty"`
+
+ // timing information for tool_result; not sent to Claude
+ // (the "-" tag keeps both fields out of the JSON encoding)
+ StartTime *time.Time `json:"-"`
+ EndTime *time.Time `json:"-"`
+
+ // CacheControl is raw JSON passed through as-is; fromLLMCache produces the
+ // only value this package sets.
+ CacheControl json.RawMessage `json:"cache_control,omitempty"`
+}
+
+// message represents a message in the conversation, in the wire format of the
+// messages API.
+type message struct {
+ // Role is the wire name of the speaker; fromLLMRole lists the names this
+ // package produces.
+ Role string `json:"role"`
+ // Content holds the message's content blocks in order.
+ Content []content `json:"content"`
+ ToolUse *toolUse `json:"tool_use,omitempty"` // use to control whether/which tool to use
+}
+
+// toolUse represents a tool use in the message content.
+// It is the wire counterpart of llm.ToolUse; fromLLMToolUse copies both fields.
+type toolUse struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+}
+
+// tool represents a tool available to Claude.
+// It is the wire counterpart of llm.Tool; fromLLMTool copies all four fields.
+type tool struct {
+ Name string `json:"name"`
+ // Type is used by the text editor tool; see
+ // https://docs.anthropic.com/en/docs/build-with-claude/tool-use/text-editor-tool
+ Type string `json:"type,omitempty"`
+ Description string `json:"description,omitempty"`
+ // InputSchema is raw JSON forwarded without inspection.
+ InputSchema json.RawMessage `json:"input_schema,omitempty"`
+}
+
+// usage represents the billing and rate-limit usage.
+// The four token counters are decoded from the API response; they are the
+// quantities TotalDollars multiplies by the per-model prices.
+type usage struct {
+ InputTokens uint64 `json:"input_tokens"`
+ CacheCreationInputTokens uint64 `json:"cache_creation_input_tokens"`
+ CacheReadInputTokens uint64 `json:"cache_read_input_tokens"`
+ OutputTokens uint64 `json:"output_tokens"`
+ // CostUSD is the cost in dollars. Do overwrites it with the result of
+ // TotalDollars before returning a response.
+ CostUSD float64 `json:"cost_usd"`
+}
+
+// Add folds other into u: each token counter and the dollar cost become the
+// sum of u's and other's values.
+func (u *usage) Add(other usage) {
+	*u = usage{
+		InputTokens:              u.InputTokens + other.InputTokens,
+		CacheCreationInputTokens: u.CacheCreationInputTokens + other.CacheCreationInputTokens,
+		CacheReadInputTokens:     u.CacheReadInputTokens + other.CacheReadInputTokens,
+		OutputTokens:             u.OutputTokens + other.OutputTokens,
+		CostUSD:                  u.CostUSD + other.CostUSD,
+	}
+}
+
+// errorResponse holds the type and message fields of an error payload.
+// NOTE(review): nothing visible in this file decodes into it yet; the TODO on
+// the 400 case in Do suggests it is meant for that path — confirm.
+type errorResponse struct {
+ Type string `json:"type"`
+ Message string `json:"message"`
+}
+
+// response represents the response from the message API.
+// Do decodes a 200 response body into this struct and converts it with
+// toLLMResponse.
+type response struct {
+ ID string `json:"id"`
+ Type string `json:"type"`
+ Role string `json:"role"`
+ // Model is the key TotalDollars uses to look up prices in modelCost.
+ Model string `json:"model"`
+ Content []content `json:"content"`
+ // StopReason is translated through toLLMStopReason; Do also checks it for
+ // "max_tokens" to decide whether to retry with a larger output limit.
+ StopReason string `json:"stop_reason"`
+ StopSequence *string `json:"stop_sequence,omitempty"`
+ Usage usage `json:"usage"`
+}
+
+// toolChoice is the wire form of llm.ToolChoice, sent as the request's
+// tool_choice field.
+type toolChoice struct {
+ // Type is a wire name from fromLLMToolChoiceType.
+ Type string `json:"type"`
+ // Name is omitted when empty. Presumably it names the tool to use when Type
+ // is "tool" — confirm against the API reference.
+ Name string `json:"name,omitempty"`
+}
+
+// systemContent is one entry of the request's system field; it is the wire
+// counterpart of llm.SystemContent (see fromLLMSystem).
+// https://docs.anthropic.com/en/api/messages#body-system
+type systemContent struct {
+ Text string `json:"text,omitempty"`
+ Type string `json:"type,omitempty"`
+ // CacheControl is raw JSON; fromLLMCache produces the only value this
+ // package sets.
+ CacheControl json.RawMessage `json:"cache_control,omitempty"`
+}
+
+// request represents the request payload for creating a message.
+// fromLLMRequest fills in Model, Messages, MaxTokens, ToolChoice, Tools and
+// System; the remaining fields are left at their zero values by this package
+// and, being omitempty, are not sent.
+type request struct {
+ Model string `json:"model"`
+ Messages []message `json:"messages"`
+ ToolChoice *toolChoice `json:"tool_choice,omitempty"`
+ MaxTokens int `json:"max_tokens"`
+ Tools []*tool `json:"tools,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ System []systemContent `json:"system,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ StopSequences []string `json:"stop_sequences,omitempty"`
+
+ // TokenEfficientToolUse is never serialized into the body; when set, Do adds
+ // the token-efficient-tool-use beta to the anthropic-beta header instead.
+ TokenEfficientToolUse bool `json:"-"` // DO NOT USE, broken on Anthropic's side as of 2025-02-28
+}
+
+// dumpText, when true, makes Do print each raw request and response body to
+// stdout and marshal requests with indentation.
+const dumpText = false // debugging toggle to see raw communications with Claude
+
+// mapped applies f to every element of s, in order, and returns the results as
+// a new slice of the same length. The result is never nil, even when s is.
+func mapped[Slice ~[]E, E, T any](s Slice, f func(E) T) []T {
+	results := make([]T, 0, len(s))
+	for _, elem := range s {
+		results = append(results, f(elem))
+	}
+	return results
+}
+
+// inverted returns a new map with m's keys and values swapped.
+//
+// It panics if two keys of m share a value, because the inverse would then be
+// ambiguous. K and V only need to be comparable, so any type usable as a map
+// key may appear on either side.
+func inverted[K, V comparable](m map[K]V) map[V]K {
+	// The inverse has exactly one entry per entry of m, so size it up front.
+	inv := make(map[V]K, len(m))
+	for k, v := range m {
+		if _, dup := inv[v]; dup {
+			panic(fmt.Errorf("inverted map has multiple keys for value %v", v))
+		}
+		inv[v] = k
+	}
+	return inv
+}
+
+// Translation tables between the llm package's enumerations and the strings
+// used on the wire. The fromLLM* tables are used when building a request and
+// the toLLM* tables when converting a response. A lookup with a key that is
+// missing from a table yields that table's zero value.
+var (
+ // Message roles; toLLMRole is the exact inverse of fromLLMRole.
+ fromLLMRole = map[llm.MessageRole]string{
+ llm.MessageRoleAssistant: "assistant",
+ llm.MessageRoleUser: "user",
+ }
+ toLLMRole = inverted(fromLLMRole)
+
+ // Content block kinds; toLLMContentType is the exact inverse of
+ // fromLLMContentType.
+ fromLLMContentType = map[llm.ContentType]string{
+ llm.ContentTypeText: "text",
+ llm.ContentTypeThinking: "thinking",
+ llm.ContentTypeRedactedThinking: "redacted_thinking",
+ llm.ContentTypeToolUse: "tool_use",
+ llm.ContentTypeToolResult: "tool_result",
+ }
+ toLLMContentType = inverted(fromLLMContentType)
+
+ // Tool choice kinds. Only the request direction is needed, so there is no
+ // inverse table.
+ fromLLMToolChoiceType = map[llm.ToolChoiceType]string{
+ llm.ToolChoiceTypeAuto: "auto",
+ llm.ToolChoiceTypeAny: "any",
+ llm.ToolChoiceTypeNone: "none",
+ llm.ToolChoiceTypeTool: "tool",
+ }
+
+ // Stop reasons. Only the response direction is needed, so there is no
+ // inverse table.
+ toLLMStopReason = map[string]llm.StopReason{
+ "stop_sequence": llm.StopReasonStopSequence,
+ "max_tokens": llm.StopReasonMaxTokens,
+ "end_turn": llm.StopReasonEndTurn,
+ "tool_use": llm.StopReasonToolUse,
+ }
+)
+
+// fromLLMCache returns the raw cache_control value for a cache flag: the
+// ephemeral marker when c is true, and nil (so the field is omitted) otherwise.
+func fromLLMCache(c bool) json.RawMessage {
+	if c {
+		return json.RawMessage(`{"type":"ephemeral"}`)
+	}
+	return nil
+}
+
+func fromLLMContent(c llm.Content) content {
+ return content{
+ ID: c.ID,
+ Type: fromLLMContentType[c.Type],
+ Text: c.Text,
+ Thinking: c.Thinking,
+ Data: c.Data,
+ Signature: c.Signature,
+ ToolName: c.ToolName,
+ ToolInput: c.ToolInput,
+ ToolUseID: c.ToolUseID,
+ ToolError: c.ToolError,
+ ToolResult: c.ToolResult,
+ CacheControl: fromLLMCache(c.Cache),
+ }
+}
+
+// fromLLMToolUse converts an optional llm.ToolUse into its wire form.
+// A nil input yields nil, which keeps the field out of the encoded message.
+func fromLLMToolUse(tu *llm.ToolUse) *toolUse {
+	if tu != nil {
+		return &toolUse{ID: tu.ID, Name: tu.Name}
+	}
+	return nil
+}
+
+func fromLLMMessage(msg llm.Message) message {
+ return message{
+ Role: fromLLMRole[msg.Role],
+ Content: mapped(msg.Content, fromLLMContent),
+ ToolUse: fromLLMToolUse(msg.ToolUse),
+ }
+}
+
+// fromLLMToolChoice converts an optional llm.ToolChoice into its wire form,
+// translating the type through fromLLMToolChoiceType. A nil input yields nil,
+// which keeps tool_choice out of the encoded request.
+func fromLLMToolChoice(tc *llm.ToolChoice) *toolChoice {
+	if tc != nil {
+		return &toolChoice{Type: fromLLMToolChoiceType[tc.Type], Name: tc.Name}
+	}
+	return nil
+}
+
+// fromLLMTool converts an llm.Tool into its wire form by copying its four
+// fields into a newly allocated tool. t must not be nil.
+func fromLLMTool(t *llm.Tool) *tool {
+	out := new(tool)
+	out.Name = t.Name
+	out.Type = t.Type
+	out.Description = t.Description
+	out.InputSchema = t.InputSchema
+	return out
+}
+
+func fromLLMSystem(s llm.SystemContent) systemContent {
+ return systemContent{
+ Text: s.Text,
+ Type: s.Type,
+ CacheControl: fromLLMCache(s.Cache),
+ }
+}
+
+// fromLLMRequest builds the wire request for r. Model and MaxTokens come from
+// the Service (falling back to DefaultModel and DefaultMaxTokens when unset);
+// messages, tool choice, tools and system content are converted from r.
+func (s *Service) fromLLMRequest(r *llm.Request) *request {
+	req := &request{
+		Model:     cmp.Or(s.Model, DefaultModel),
+		MaxTokens: cmp.Or(s.MaxTokens, DefaultMaxTokens),
+	}
+	req.Messages = mapped(r.Messages, fromLLMMessage)
+	req.ToolChoice = fromLLMToolChoice(r.ToolChoice)
+	req.Tools = mapped(r.Tools, fromLLMTool)
+	req.System = mapped(r.System, fromLLMSystem)
+	return req
+}
+
+func toLLMUsage(u usage) llm.Usage {
+ return llm.Usage{
+ InputTokens: u.InputTokens,
+ CacheCreationInputTokens: u.CacheCreationInputTokens,
+ CacheReadInputTokens: u.CacheReadInputTokens,
+ OutputTokens: u.OutputTokens,
+ CostUSD: u.CostUSD,
+ }
+}
+
+func toLLMContent(c content) llm.Content {
+ return llm.Content{
+ ID: c.ID,
+ Type: toLLMContentType[c.Type],
+ Text: c.Text,
+ Thinking: c.Thinking,
+ Data: c.Data,
+ Signature: c.Signature,
+ ToolName: c.ToolName,
+ ToolInput: c.ToolInput,
+ ToolUseID: c.ToolUseID,
+ ToolError: c.ToolError,
+ ToolResult: c.ToolResult,
+ }
+}
+
+// toLLMResponse converts a decoded wire response into an llm.Response.
+// Role and StopReason are translated through toLLMRole and toLLMStopReason (a
+// wire name missing from either table yields that type's zero value); content
+// blocks and usage go through toLLMContent and toLLMUsage.
+func toLLMResponse(r *response) *llm.Response {
+	out := new(llm.Response)
+	out.ID = r.ID
+	out.Type = r.Type
+	out.Role = toLLMRole[r.Role]
+	out.Model = r.Model
+	out.Content = mapped(r.Content, toLLMContent)
+	out.StopReason = toLLMStopReason[r.StopReason]
+	out.StopSequence = r.StopSequence
+	out.Usage = toLLMUsage(r.Usage)
+	return out
+}
+
+// Do sends ir to the Anthropic messages endpoint and returns the converted
+// response.
+//
+// Responses with a 5xx or 429 status are retried without limit, with backoff
+// between attempts; the wait ends early with ctx's error if ctx is done. A 200
+// response from Claude 3.7 Sonnet that stopped on max_tokens is retried once
+// with the 128k-output beta enabled, and the usage of the truncated attempt is
+// added to the usage of the final response. Any other status, and any
+// marshaling, transport, body-read or decoding failure, is returned as an error.
+func (s *Service) Do(ctx context.Context, ir *llm.Request) (*llm.Response, error) {
+	request := s.fromLLMRequest(ir)
+	url := cmp.Or(s.URL, DefaultURL)
+	httpc := cmp.Or(s.HTTPC, http.DefaultClient)
+
+	backoff := []time.Duration{15 * time.Second, 30 * time.Second, time.Minute}
+	largerMaxTokens := false
+	var partialUsage usage
+
+	// retry loop
+	for attempts := 0; ; attempts++ {
+		var features []string
+		if request.TokenEfficientToolUse {
+			features = append(features, "token-efficient-tool-use-2025-02-19")
+		}
+		if largerMaxTokens {
+			features = append(features, "output-128k-2025-02-19")
+			request.MaxTokens = 128 * 1024
+		}
+
+		// Marshal on every attempt: MaxTokens may have been raised just above, and
+		// a payload built once before the loop would resend the old limit.
+		var payload []byte
+		var err error
+		if dumpText || testing.Testing() {
+			payload, err = json.MarshalIndent(request, "", " ")
+		} else {
+			payload, err = json.Marshal(request)
+			payload = append(payload, '\n')
+		}
+		if err != nil {
+			return nil, err
+		}
+		if dumpText {
+			fmt.Printf("RAW REQUEST:\n%s\n\n", payload)
+		}
+
+		req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(payload))
+		if err != nil {
+			return nil, err
+		}
+		req.Header.Set("Content-Type", "application/json")
+		req.Header.Set("X-API-Key", s.APIKey)
+		req.Header.Set("Anthropic-Version", "2023-06-01")
+		if len(features) > 0 {
+			req.Header.Set("anthropic-beta", strings.Join(features, ","))
+		}
+
+		resp, err := httpc.Do(req)
+		if err != nil {
+			return nil, err
+		}
+		buf, err := io.ReadAll(resp.Body)
+		resp.Body.Close()
+		if err != nil {
+			// A partial body can be neither decoded nor reported faithfully.
+			return nil, fmt.Errorf("reading response body: %w", err)
+		}
+
+		var sleep time.Duration
+		switch {
+		case resp.StatusCode == http.StatusOK:
+			if dumpText {
+				fmt.Printf("RAW RESPONSE:\n%s\n\n", buf)
+			}
+			var response response
+			if err := json.NewDecoder(bytes.NewReader(buf)).Decode(&response); err != nil {
+				return nil, err
+			}
+			// Retry a truncated response once with more output tokens. Per
+			// Anthropic's model documentation only Claude 3.7 Sonnet accepts the
+			// 128k-output beta; other models would reject the larger max_tokens,
+			// so their truncated response is returned as-is.
+			if response.StopReason == "max_tokens" && !largerMaxTokens && response.Model == Claude37Sonnet {
+				slog.InfoContext(ctx, "anthropic_request_retry_larger_max_tokens")
+				largerMaxTokens = true
+				response.Usage.CostUSD = response.TotalDollars()
+				partialUsage = response.Usage
+				continue
+			}
+
+			// Calculate and set the cost_usd field
+			if largerMaxTokens {
+				response.Usage.Add(partialUsage)
+			}
+			response.Usage.CostUSD = response.TotalDollars()
+
+			return toLLMResponse(&response), nil
+		case resp.StatusCode >= 500 && resp.StatusCode < 600:
+			// overloaded or unhappy, in one form or another
+			sleep = backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
+			slog.WarnContext(ctx, "anthropic_request_failed", "response", string(buf), "status_code", resp.StatusCode, "sleep", sleep)
+		case resp.StatusCode == 429:
+			// rate limited. wait 1 minute as a starting point, because that's the rate limiting window.
+			// and then add some additional time for backoff.
+			sleep = time.Minute + backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
+			slog.WarnContext(ctx, "anthropic_request_rate_limited", "response", string(buf), "sleep", sleep)
+		// case resp.StatusCode == 400:
+		// TODO: parse ErrorResponse, make (*ErrorResponse) implement error
+		default:
+			return nil, fmt.Errorf("API request failed with status %s\n%s", resp.Status, buf)
+		}
+
+		// Back off before the next attempt, but give up promptly if ctx is done
+		// rather than sleeping through a cancellation.
+		timer := time.NewTimer(sleep)
+		select {
+		case <-ctx.Done():
+			timer.Stop()
+			return nil, ctx.Err()
+		case <-timer.C:
+		}
+	}
+}
+
+// centsPer1MTokens is a price list for one model, with one price per kind of
+// token counted in usage. Each price is in
+// cents per million tokens
+// (not dollars because i'm twitchy about using floats for money)
+type centsPer1MTokens struct {
+ Input uint64
+ Output uint64
+ CacheRead uint64
+ CacheCreation uint64
+}
+
+// modelCost maps a model identifier to its prices, in cents per million tokens;
+// the trailing comments give the same price in dollars. TotalDollars panics for
+// a model that has no entry here.
+// https://www.anthropic.com/pricing#anthropic-api
+var modelCost = map[string]centsPer1MTokens{
+ Claude37Sonnet: {
+ Input: 300, // $3
+ Output: 1500, // $15
+ CacheRead: 30, // $0.30
+ CacheCreation: 375, // $3.75
+ },
+ Claude35Haiku: {
+ Input: 80, // $0.80
+ Output: 400, // $4.00
+ CacheRead: 8, // $0.08
+ CacheCreation: 100, // $1.00
+ },
+ Claude35Sonnet: {
+ Input: 300, // $3
+ Output: 1500, // $15
+ CacheRead: 30, // $0.30
+ CacheCreation: 375, // $3.75
+ },
+}
+
+// TotalDollars returns the total cost to obtain this response, in dollars.
+func (mr *response) TotalDollars() float64 {
+ cpm, ok := modelCost[mr.Model]
+ if !ok {
+ panic(fmt.Sprintf("no pricing info for model: %s", mr.Model))
+ }
+ use := mr.Usage
+ megaCents := use.InputTokens*cpm.Input +
+ use.OutputTokens*cpm.Output +
+ use.CacheReadInputTokens*cpm.CacheRead +
+ use.CacheCreationInputTokens*cpm.CacheCreation
+ cents := float64(megaCents) / 1_000_000.0
+ return cents / 100.0
+}
diff --git a/llm/ant/ant_test.go b/llm/ant/ant_test.go
new file mode 100644
index 0000000..67cc5db
--- /dev/null
+++ b/llm/ant/ant_test.go
@@ -0,0 +1,93 @@
+package ant
+
+import (
+ "math"
+ "testing"
+)
+
+// TestCalculateCostFromTokens tests the calculateCostFromTokens function
+func TestCalculateCostFromTokens(t *testing.T) {
+ tests := []struct {
+ name string
+ model string
+ inputTokens uint64
+ outputTokens uint64
+ cacheReadInputTokens uint64
+ cacheCreationInputTokens uint64
+ want float64
+ }{
+ {
+ name: "Zero tokens",
+ model: Claude37Sonnet,
+ inputTokens: 0,
+ outputTokens: 0,
+ cacheReadInputTokens: 0,
+ cacheCreationInputTokens: 0,
+ want: 0,
+ },
+ {
+ name: "1000 input tokens, 500 output tokens",
+ model: Claude37Sonnet,
+ inputTokens: 1000,
+ outputTokens: 500,
+ cacheReadInputTokens: 0,
+ cacheCreationInputTokens: 0,
+ want: 0.0105,
+ },
+ {
+ name: "10000 input tokens, 5000 output tokens",
+ model: Claude37Sonnet,
+ inputTokens: 10000,
+ outputTokens: 5000,
+ cacheReadInputTokens: 0,
+ cacheCreationInputTokens: 0,
+ want: 0.105,
+ },
+ {
+ name: "With cache read tokens",
+ model: Claude37Sonnet,
+ inputTokens: 1000,
+ outputTokens: 500,
+ cacheReadInputTokens: 2000,
+ cacheCreationInputTokens: 0,
+ want: 0.0111,
+ },
+ {
+ name: "With cache creation tokens",
+ model: Claude37Sonnet,
+ inputTokens: 1000,
+ outputTokens: 500,
+ cacheReadInputTokens: 0,
+ cacheCreationInputTokens: 1500,
+ want: 0.016125,
+ },
+ {
+ name: "With all token types",
+ model: Claude37Sonnet,
+ inputTokens: 1000,
+ outputTokens: 500,
+ cacheReadInputTokens: 2000,
+ cacheCreationInputTokens: 1500,
+ want: 0.016725,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ usage := usage{
+ InputTokens: tt.inputTokens,
+ OutputTokens: tt.outputTokens,
+ CacheReadInputTokens: tt.cacheReadInputTokens,
+ CacheCreationInputTokens: tt.cacheCreationInputTokens,
+ }
+ mr := response{
+ Model: tt.model,
+ Usage: usage,
+ }
+ totalCost := mr.TotalDollars()
+ if math.Abs(totalCost-tt.want) > 0.0001 {
+ t.Errorf("totalCost = %v, want %v", totalCost, tt.want)
+ }
+ })
+ }
+}