all: support openai-compatible models
The support is rather minimal at this point:
Only hard-coded models, only -unsafe, only -skabandaddr="".
The "shared" LLM package is strongly Claude-flavored.
We can fix all of this and more over time, if we are inspired to.
(Maybe we'll switch to https://github.com/maruel/genai?)
The goal for now is to get the rough structure in place.
I've rebased and rebuilt this more times than I care to remember.
diff --git a/llm/ant/ant.go b/llm/ant/ant.go
new file mode 100644
index 0000000..dce17f1
--- /dev/null
+++ b/llm/ant/ant.go
@@ -0,0 +1,480 @@
+package ant
+
+import (
+ "bytes"
+ "cmp"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log/slog"
+ "math/rand/v2"
+ "net/http"
+ "strings"
+ "testing"
+ "time"
+
+ "sketch.dev/llm"
+)
+
+// Defaults applied by Do when the corresponding Service field is zero.
+const (
+ DefaultModel = Claude37Sonnet
+ // See https://docs.anthropic.com/en/docs/about-claude/models/all-models for
+ // current maximums. There's currently a flag to enable 128k output (output-128k-2025-02-19)
+ DefaultMaxTokens = 8192
+ DefaultURL = "https://api.anthropic.com/v1/messages"
+)
+
+// Model identifiers accepted by the Anthropic messages API.
+// Each must have a pricing entry in modelCost.
+const (
+ Claude35Sonnet = "claude-3-5-sonnet-20241022"
+ Claude35Haiku = "claude-3-5-haiku-20241022"
+ Claude37Sonnet = "claude-3-7-sonnet-20250219"
+)
+
+// Service provides Claude completions.
+// Fields should not be altered concurrently with calling any method on Service.
+type Service struct {
+ HTTPC *http.Client // defaults to http.DefaultClient if nil
+ URL string // defaults to DefaultURL if empty
+ APIKey string // must be non-empty
+ Model string // defaults to DefaultModel if empty
+ MaxTokens int // defaults to DefaultMaxTokens if zero
+}
+
+// Compile-time check that Service implements llm.Service.
+var _ llm.Service = (*Service)(nil)
+
+// content is a single content block within a message.
+// It is a union over the Anthropic content block types;
+// Type selects which of the remaining fields are meaningful.
+type content struct {
+ // TODO: image support?
+ // https://docs.anthropic.com/en/api/messages
+ ID string `json:"id,omitempty"`
+ Type string `json:"type,omitempty"`
+ Text string `json:"text,omitempty"`
+
+ // for thinking
+ Thinking string `json:"thinking,omitempty"`
+ Data string `json:"data,omitempty"` // for redacted_thinking
+ Signature string `json:"signature,omitempty"` // for thinking
+
+ // for tool_use
+ ToolName string `json:"name,omitempty"`
+ ToolInput json.RawMessage `json:"input,omitempty"`
+
+ // for tool_result
+ ToolUseID string `json:"tool_use_id,omitempty"`
+ ToolError bool `json:"is_error,omitempty"`
+ ToolResult string `json:"content,omitempty"`
+
+ // timing information for tool_result; not sent to Claude
+ StartTime *time.Time `json:"-"`
+ EndTime *time.Time `json:"-"`
+
+ CacheControl json.RawMessage `json:"cache_control,omitempty"`
+}
+
+// message represents a message in the conversation.
+type message struct {
+ Role string `json:"role"`
+ Content []content `json:"content"`
+ ToolUse *toolUse `json:"tool_use,omitempty"` // use to control whether/which tool to use
+}
+
+// toolUse represents a tool use in the message content.
+type toolUse struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+}
+
+// tool represents a tool available to Claude.
+type tool struct {
+ Name string `json:"name"`
+ // Type is used by the text editor tool; see
+ // https://docs.anthropic.com/en/docs/build-with-claude/tool-use/text-editor-tool
+ Type string `json:"type,omitempty"`
+ Description string `json:"description,omitempty"`
+ InputSchema json.RawMessage `json:"input_schema,omitempty"`
+}
+
+// usage represents the billing and rate-limit usage.
+type usage struct {
+ InputTokens uint64 `json:"input_tokens"`
+ CacheCreationInputTokens uint64 `json:"cache_creation_input_tokens"`
+ CacheReadInputTokens uint64 `json:"cache_read_input_tokens"`
+ OutputTokens uint64 `json:"output_tokens"`
+ CostUSD float64 `json:"cost_usd"` // computed locally via TotalDollars; not part of the API response
+}
+
+// Add accumulates other into u, field by field.
+func (u *usage) Add(other usage) {
+ u.InputTokens += other.InputTokens
+ u.CacheCreationInputTokens += other.CacheCreationInputTokens
+ u.CacheReadInputTokens += other.CacheReadInputTokens
+ u.OutputTokens += other.OutputTokens
+ u.CostUSD += other.CostUSD
+}
+
+// errorResponse is the error payload returned by the API.
+// Currently unused; see the TODO on the 400 case in Do.
+type errorResponse struct {
+ Type string `json:"type"`
+ Message string `json:"message"`
+}
+
+// response represents the response from the message API.
+type response struct {
+ ID string `json:"id"`
+ Type string `json:"type"`
+ Role string `json:"role"`
+ Model string `json:"model"`
+ Content []content `json:"content"`
+ StopReason string `json:"stop_reason"`
+ StopSequence *string `json:"stop_sequence,omitempty"`
+ Usage usage `json:"usage"`
+}
+
+// toolChoice controls whether and which tool Claude may use.
+type toolChoice struct {
+ Type string `json:"type"`
+ Name string `json:"name,omitempty"`
+}
+
+// systemContent is one block of the system prompt.
+// https://docs.anthropic.com/en/api/messages#body-system
+type systemContent struct {
+ Text string `json:"text,omitempty"`
+ Type string `json:"type,omitempty"`
+ CacheControl json.RawMessage `json:"cache_control,omitempty"`
+}
+
+// request represents the request payload for creating a message.
+type request struct {
+ Model string `json:"model"`
+ Messages []message `json:"messages"`
+ ToolChoice *toolChoice `json:"tool_choice,omitempty"`
+ MaxTokens int `json:"max_tokens"`
+ Tools []*tool `json:"tools,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ System []systemContent `json:"system,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ StopSequences []string `json:"stop_sequences,omitempty"`
+
+ TokenEfficientToolUse bool `json:"-"` // DO NOT USE, broken on Anthropic's side as of 2025-02-28
+}
+
+const dumpText = false // debugging toggle to see raw communications with Claude
+
+// mapped returns a new slice containing f applied to each element of s,
+// in order.
+func mapped[Slice ~[]E, E, T any](s Slice, f func(E) T) []T {
+	out := make([]T, 0, len(s))
+	for _, elem := range s {
+		out = append(out, f(elem))
+	}
+	return out
+}
+
+// inverted returns the value->key inverse of m.
+// It panics if m is not injective (two keys map to the same value).
+func inverted[K, V cmp.Ordered](m map[K]V) map[V]K {
+	inv := make(map[V]K, len(m))
+	for key, val := range m {
+		if _, dup := inv[val]; dup {
+			panic(fmt.Errorf("inverted map has multiple keys for value %v", val))
+		}
+		inv[val] = key
+	}
+	return inv
+}
+
+// Translation tables between llm package enums and Anthropic wire strings.
+// The to* variants are derived mechanically via inverted.
+var (
+ fromLLMRole = map[llm.MessageRole]string{
+ llm.MessageRoleAssistant: "assistant",
+ llm.MessageRoleUser: "user",
+ }
+ toLLMRole = inverted(fromLLMRole)
+
+ fromLLMContentType = map[llm.ContentType]string{
+ llm.ContentTypeText: "text",
+ llm.ContentTypeThinking: "thinking",
+ llm.ContentTypeRedactedThinking: "redacted_thinking",
+ llm.ContentTypeToolUse: "tool_use",
+ llm.ContentTypeToolResult: "tool_result",
+ }
+ toLLMContentType = inverted(fromLLMContentType)
+
+ // No inverse needed: tool choice only appears in requests.
+ fromLLMToolChoiceType = map[llm.ToolChoiceType]string{
+ llm.ToolChoiceTypeAuto: "auto",
+ llm.ToolChoiceTypeAny: "any",
+ llm.ToolChoiceTypeNone: "none",
+ llm.ToolChoiceTypeTool: "tool",
+ }
+
+ toLLMStopReason = map[string]llm.StopReason{
+ "stop_sequence": llm.StopReasonStopSequence,
+ "max_tokens": llm.StopReasonMaxTokens,
+ "end_turn": llm.StopReasonEndTurn,
+ "tool_use": llm.StopReasonToolUse,
+ }
+)
+
+// fromLLMCache converts the boolean cache flag into Anthropic's
+// cache_control JSON; nil (omitted on the wire) when caching is off.
+func fromLLMCache(c bool) json.RawMessage {
+ if !c {
+ return nil
+ }
+ return json.RawMessage(`{"type":"ephemeral"}`)
+}
+
+// fromLLMContent converts one llm content block to the wire format.
+func fromLLMContent(c llm.Content) content {
+ return content{
+ ID: c.ID,
+ Type: fromLLMContentType[c.Type],
+ Text: c.Text,
+ Thinking: c.Thinking,
+ Data: c.Data,
+ Signature: c.Signature,
+ ToolName: c.ToolName,
+ ToolInput: c.ToolInput,
+ ToolUseID: c.ToolUseID,
+ ToolError: c.ToolError,
+ ToolResult: c.ToolResult,
+ CacheControl: fromLLMCache(c.Cache),
+ }
+}
+
+// fromLLMToolUse converts an optional tool use; nil passes through.
+func fromLLMToolUse(tu *llm.ToolUse) *toolUse {
+ if tu == nil {
+ return nil
+ }
+ return &toolUse{
+ ID: tu.ID,
+ Name: tu.Name,
+ }
+}
+
+// fromLLMMessage converts one conversation message to the wire format.
+func fromLLMMessage(msg llm.Message) message {
+ return message{
+ Role: fromLLMRole[msg.Role],
+ Content: mapped(msg.Content, fromLLMContent),
+ ToolUse: fromLLMToolUse(msg.ToolUse),
+ }
+}
+
+// fromLLMToolChoice converts an optional tool choice; nil passes through.
+func fromLLMToolChoice(tc *llm.ToolChoice) *toolChoice {
+ if tc == nil {
+ return nil
+ }
+ return &toolChoice{
+ Type: fromLLMToolChoiceType[tc.Type],
+ Name: tc.Name,
+ }
+}
+
+// fromLLMTool converts a tool definition to the wire format.
+func fromLLMTool(t *llm.Tool) *tool {
+ return &tool{
+ Name: t.Name,
+ Type: t.Type,
+ Description: t.Description,
+ InputSchema: t.InputSchema,
+ }
+}
+
+// fromLLMSystem converts one system prompt block to the wire format.
+func fromLLMSystem(s llm.SystemContent) systemContent {
+ return systemContent{
+ Text: s.Text,
+ Type: s.Type,
+ CacheControl: fromLLMCache(s.Cache),
+ }
+}
+
+// fromLLMRequest builds the wire request, applying the Service's
+// model and max-token defaults.
+func (s *Service) fromLLMRequest(r *llm.Request) *request {
+ return &request{
+ Model: cmp.Or(s.Model, DefaultModel),
+ Messages: mapped(r.Messages, fromLLMMessage),
+ MaxTokens: cmp.Or(s.MaxTokens, DefaultMaxTokens),
+ ToolChoice: fromLLMToolChoice(r.ToolChoice),
+ Tools: mapped(r.Tools, fromLLMTool),
+ System: mapped(r.System, fromLLMSystem),
+ }
+}
+
+// toLLMUsage converts wire usage to the provider-agnostic llm.Usage.
+func toLLMUsage(u usage) llm.Usage {
+ return llm.Usage{
+ InputTokens: u.InputTokens,
+ CacheCreationInputTokens: u.CacheCreationInputTokens,
+ CacheReadInputTokens: u.CacheReadInputTokens,
+ OutputTokens: u.OutputTokens,
+ CostUSD: u.CostUSD,
+ }
+}
+
+// toLLMContent converts one wire content block to llm.Content.
+// Note that cache_control and timing fields do not round-trip.
+func toLLMContent(c content) llm.Content {
+ return llm.Content{
+ ID: c.ID,
+ Type: toLLMContentType[c.Type],
+ Text: c.Text,
+ Thinking: c.Thinking,
+ Data: c.Data,
+ Signature: c.Signature,
+ ToolName: c.ToolName,
+ ToolInput: c.ToolInput,
+ ToolUseID: c.ToolUseID,
+ ToolError: c.ToolError,
+ ToolResult: c.ToolResult,
+ }
+}
+
+// toLLMResponse converts a wire response to llm.Response.
+func toLLMResponse(r *response) *llm.Response {
+ return &llm.Response{
+ ID: r.ID,
+ Type: r.Type,
+ Role: toLLMRole[r.Role],
+ Model: r.Model,
+ Content: mapped(r.Content, toLLMContent),
+ StopReason: toLLMStopReason[r.StopReason],
+ StopSequence: r.StopSequence,
+ Usage: toLLMUsage(r.Usage),
+ }
+}
+
+// Do sends a request to Anthropic. It retries with backoff on 5xx and
+// 429 responses, and re-issues the request once with the 128k-output
+// beta if the response was truncated by max_tokens. It returns the
+// converted response or the first non-retryable error.
+func (s *Service) Do(ctx context.Context, ir *llm.Request) (*llm.Response, error) {
+	request := s.fromLLMRequest(ir)
+
+	// marshal renders request; it must be re-run whenever request is
+	// mutated (e.g. the larger-max-tokens retry below), since the
+	// payload bytes are what actually go on the wire.
+	marshal := func() ([]byte, error) {
+		if dumpText || testing.Testing() {
+			// Indent for readability when debugging or in test logs.
+			return json.MarshalIndent(request, "", " ")
+		}
+		b, err := json.Marshal(request)
+		if err != nil {
+			return nil, err
+		}
+		return append(b, '\n'), nil
+	}
+	payload, err := marshal()
+	if err != nil {
+		return nil, err
+	}
+
+	backoff := []time.Duration{15 * time.Second, 30 * time.Second, time.Minute}
+	largerMaxTokens := false
+	var partialUsage usage
+
+	url := cmp.Or(s.URL, DefaultURL)
+	httpc := cmp.Or(s.HTTPC, http.DefaultClient)
+
+	// retry loop
+	for attempts := 0; ; attempts++ {
+		if dumpText {
+			fmt.Printf("RAW REQUEST:\n%s\n\n", payload)
+		}
+		req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(payload))
+		if err != nil {
+			return nil, err
+		}
+
+		req.Header.Set("Content-Type", "application/json")
+		req.Header.Set("X-API-Key", s.APIKey)
+		req.Header.Set("Anthropic-Version", "2023-06-01")
+
+		// Opt in to beta features via the anthropic-beta header.
+		var features []string
+		if request.TokenEfficientToolUse {
+			features = append(features, "token-efficient-tool-use-2025-02-19")
+		}
+		if largerMaxTokens {
+			features = append(features, "output-128k-2025-02-19")
+		}
+		if len(features) > 0 {
+			req.Header.Set("anthropic-beta", strings.Join(features, ","))
+		}
+
+		resp, err := httpc.Do(req)
+		if err != nil {
+			return nil, err
+		}
+		// Best-effort read; the status code drives the handling below.
+		buf, _ := io.ReadAll(resp.Body)
+		resp.Body.Close()
+
+		switch {
+		case resp.StatusCode == http.StatusOK:
+			if dumpText {
+				fmt.Printf("RAW RESPONSE:\n%s\n\n", buf)
+			}
+			var response response
+			if err := json.Unmarshal(buf, &response); err != nil {
+				return nil, err
+			}
+			if response.StopReason == "max_tokens" && !largerMaxTokens {
+				// Output was truncated: retry once with the 128k output
+				// beta. Bump MaxTokens and RE-MARSHAL the payload so the
+				// larger limit is actually sent (mutating request alone
+				// would not change the bytes already marshaled).
+				slog.InfoContext(ctx, "anthropic_retry_larger_max_tokens")
+				largerMaxTokens = true
+				request.MaxTokens = 128 * 1024
+				if payload, err = marshal(); err != nil {
+					return nil, err
+				}
+				// Remember what the truncated attempt cost; it is merged
+				// into the final response's usage below.
+				response.Usage.CostUSD = response.TotalDollars()
+				partialUsage = response.Usage
+				continue
+			}
+
+			// Calculate and set the cost_usd field, including the cost
+			// of the truncated first attempt, if any.
+			if largerMaxTokens {
+				response.Usage.Add(partialUsage)
+			}
+			response.Usage.CostUSD = response.TotalDollars()
+
+			return toLLMResponse(&response), nil
+		case resp.StatusCode >= 500 && resp.StatusCode < 600:
+			// overloaded or unhappy, in one form or another
+			sleep := backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
+			slog.WarnContext(ctx, "anthropic_request_failed", "response", string(buf), "status_code", resp.StatusCode, "sleep", sleep)
+			time.Sleep(sleep)
+		case resp.StatusCode == 429:
+			// rate limited. wait 1 minute as a starting point, because that's the rate limiting window.
+			// and then add some additional time for backoff.
+			sleep := time.Minute + backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
+			slog.WarnContext(ctx, "anthropic_request_rate_limited", "response", string(buf), "sleep", sleep)
+			time.Sleep(sleep)
+		// case resp.StatusCode == 400:
+		// TODO: parse ErrorResponse, make (*ErrorResponse) implement error
+		default:
+			return nil, fmt.Errorf("API request failed with status %s\n%s", resp.Status, buf)
+		}
+	}
+}
+
+// cents per million tokens
+// (not dollars because i'm twitchy about using floats for money)
+type centsPer1MTokens struct {
+ Input uint64
+ Output uint64
+ CacheRead uint64
+ CacheCreation uint64
+}
+
+// modelCost maps model identifiers to their per-token prices.
+// https://www.anthropic.com/pricing#anthropic-api
+var modelCost = map[string]centsPer1MTokens{
+ Claude37Sonnet: {
+ Input: 300, // $3
+ Output: 1500, // $15
+ CacheRead: 30, // $0.30
+ CacheCreation: 375, // $3.75
+ },
+ Claude35Haiku: {
+ Input: 80, // $0.80
+ Output: 400, // $4.00
+ CacheRead: 8, // $0.08
+ CacheCreation: 100, // $1.00
+ },
+ Claude35Sonnet: {
+ Input: 300, // $3
+ Output: 1500, // $15
+ CacheRead: 30, // $0.30
+ CacheCreation: 375, // $3.75
+ },
+}
+
+// TotalDollars returns the total cost to obtain this response, in dollars.
+// It panics if mr.Model has no entry in modelCost.
+func (mr *response) TotalDollars() float64 {
+ cpm, ok := modelCost[mr.Model]
+ if !ok {
+ panic(fmt.Sprintf("no pricing info for model: %s", mr.Model))
+ }
+ use := mr.Usage
+ // All arithmetic is integer "mega-cents" (cents * 1e6) until the
+ // final conversion to dollars, to avoid float rounding in between.
+ megaCents := use.InputTokens*cpm.Input +
+ use.OutputTokens*cpm.Output +
+ use.CacheReadInputTokens*cpm.CacheRead +
+ use.CacheCreationInputTokens*cpm.CacheCreation
+ cents := float64(megaCents) / 1_000_000.0
+ return cents / 100.0
+}
diff --git a/llm/ant/ant_test.go b/llm/ant/ant_test.go
new file mode 100644
index 0000000..67cc5db
--- /dev/null
+++ b/llm/ant/ant_test.go
@@ -0,0 +1,93 @@
+package ant
+
+import (
+ "math"
+ "testing"
+)
+
+// TestCalculateCostFromTokens exercises response.TotalDollars across
+// combinations of input, output, cache-read, and cache-creation tokens.
+func TestCalculateCostFromTokens(t *testing.T) {
+ tests := []struct {
+ name string
+ model string
+ inputTokens uint64
+ outputTokens uint64
+ cacheReadInputTokens uint64
+ cacheCreationInputTokens uint64
+ want float64
+ }{
+ {
+ name: "Zero tokens",
+ model: Claude37Sonnet,
+ inputTokens: 0,
+ outputTokens: 0,
+ cacheReadInputTokens: 0,
+ cacheCreationInputTokens: 0,
+ want: 0,
+ },
+ {
+ name: "1000 input tokens, 500 output tokens",
+ model: Claude37Sonnet,
+ inputTokens: 1000,
+ outputTokens: 500,
+ cacheReadInputTokens: 0,
+ cacheCreationInputTokens: 0,
+ want: 0.0105,
+ },
+ {
+ name: "10000 input tokens, 5000 output tokens",
+ model: Claude37Sonnet,
+ inputTokens: 10000,
+ outputTokens: 5000,
+ cacheReadInputTokens: 0,
+ cacheCreationInputTokens: 0,
+ want: 0.105,
+ },
+ {
+ name: "With cache read tokens",
+ model: Claude37Sonnet,
+ inputTokens: 1000,
+ outputTokens: 500,
+ cacheReadInputTokens: 2000,
+ cacheCreationInputTokens: 0,
+ want: 0.0111,
+ },
+ {
+ name: "With cache creation tokens",
+ model: Claude37Sonnet,
+ inputTokens: 1000,
+ outputTokens: 500,
+ cacheReadInputTokens: 0,
+ cacheCreationInputTokens: 1500,
+ want: 0.016125,
+ },
+ {
+ name: "With all token types",
+ model: Claude37Sonnet,
+ inputTokens: 1000,
+ outputTokens: 500,
+ cacheReadInputTokens: 2000,
+ cacheCreationInputTokens: 1500,
+ want: 0.016725,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ usage := usage{
+ InputTokens: tt.inputTokens,
+ OutputTokens: tt.outputTokens,
+ CacheReadInputTokens: tt.cacheReadInputTokens,
+ CacheCreationInputTokens: tt.cacheCreationInputTokens,
+ }
+ mr := response{
+ Model: tt.model,
+ Usage: usage,
+ }
+ totalCost := mr.TotalDollars()
+ // Compare with a tolerance: costs are floats derived from
+ // integer mega-cents, so tiny rounding differences are expected.
+ if math.Abs(totalCost-tt.want) > 0.0001 {
+ t.Errorf("totalCost = %v, want %v", totalCost, tt.want)
+ }
+ })
+ }
+}
diff --git a/llm/conversation/convo.go b/llm/conversation/convo.go
new file mode 100644
index 0000000..5a12256
--- /dev/null
+++ b/llm/conversation/convo.go
@@ -0,0 +1,617 @@
+package conversation
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "log/slog"
+ "maps"
+ "math/rand/v2"
+ "slices"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/oklog/ulid/v2"
+ "github.com/richardlehane/crock32"
+ "sketch.dev/llm"
+ "sketch.dev/skribe"
+)
+
+// Listener observes conversation traffic: outgoing requests, incoming
+// responses, and tool calls with their results. Implementations are
+// invoked synchronously from Convo methods.
+type Listener interface {
+ // TODO: Content is leaking an anthropic API; should we avoid it?
+ // TODO: Where should we include start/end time and usage?
+ OnToolCall(ctx context.Context, convo *Convo, toolCallID string, toolName string, toolInput json.RawMessage, content llm.Content)
+ OnToolResult(ctx context.Context, convo *Convo, toolCallID string, toolName string, toolInput json.RawMessage, content llm.Content, result *string, err error)
+ OnRequest(ctx context.Context, convo *Convo, requestID string, msg *llm.Message)
+ OnResponse(ctx context.Context, convo *Convo, requestID string, msg *llm.Response)
+}
+
+// NoopListener is a Listener that ignores all events.
+// It is the default Listener installed by New.
+type NoopListener struct{}
+
+func (n *NoopListener) OnToolCall(ctx context.Context, convo *Convo, id string, toolName string, toolInput json.RawMessage, content llm.Content) {
+}
+
+func (n *NoopListener) OnToolResult(ctx context.Context, convo *Convo, id string, toolName string, toolInput json.RawMessage, content llm.Content, result *string, err error) {
+}
+
+func (n *NoopListener) OnResponse(ctx context.Context, convo *Convo, id string, msg *llm.Response) {
+}
+func (n *NoopListener) OnRequest(ctx context.Context, convo *Convo, id string, msg *llm.Message) {}
+
+// ErrDoNotRespond may be returned by a tool to indicate that no
+// tool_result should be sent back for its tool_use; see ToolResultContents.
+var ErrDoNotRespond = errors.New("do not respond")
+
+// A Convo is a managed conversation with Claude.
+// It automatically manages the state of the conversation,
+// including appending messages sent/received,
+// calling tools and sending their results,
+// tracking usage, etc.
+//
+// Exported fields must not be altered concurrently with calling any method on Convo.
+// Typical usage is to configure a Convo once before using it.
+type Convo struct {
+ // ID is a unique ID for the conversation
+ ID string
+ // Ctx is the context for the entire conversation.
+ Ctx context.Context
+ // Service is the LLM service to use.
+ Service llm.Service
+ // Tools are the tools available during the conversation.
+ Tools []*llm.Tool
+ // SystemPrompt is the system prompt for the conversation.
+ SystemPrompt string
+ // PromptCaching indicates whether to use Anthropic's prompt caching.
+ // See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#continuing-a-multi-turn-conversation
+ // for the documentation. At request send time, we set the cache_control field on the
+ // last message. We also cache the system prompt.
+ // Default: true.
+ PromptCaching bool
+ // ToolUseOnly indicates whether Claude may only use tools during this conversation.
+ // TODO: add more fine-grained control over tool use?
+ ToolUseOnly bool
+ // Parent is the parent conversation, if any.
+ // It is non-nil for "subagent" calls.
+ // It is set automatically when calling SubConvo,
+ // and usually should not be set manually.
+ Parent *Convo
+ // Budget is the budget for this conversation (and all sub-conversations).
+ // The Conversation DOES NOT automatically enforce the budget.
+ // It is up to the caller to call OverBudget() as appropriate.
+ Budget Budget
+
+ // messages tracks the messages so far in the conversation.
+ messages []llm.Message
+
+ // Listener receives messages being sent.
+ Listener Listener
+
+ // muToolUseCancel guards toolUseCancel, the per-tool_use cancel
+ // functions registered by newToolUseContext.
+ muToolUseCancel *sync.Mutex
+ toolUseCancel map[string]context.CancelCauseFunc
+
+ // Protects usage. This is used for subconversations (that share part of CumulativeUsage) as well.
+ mu *sync.Mutex
+ // usage tracks usage for this conversation and all sub-conversations.
+ usage *CumulativeUsage
+}
+
+// newConvoID generates a new short random id (32 random bits,
+// Crockford base32 encoded, formatted as xxx-xxxx).
+// The uniqueness/collision requirements here are very low.
+// They are not global identifiers,
+// just enough to distinguish different convos in a single session.
+func newConvoID() string {
+ u1 := rand.Uint32()
+ s := crock32.Encode(uint64(u1))
+ if len(s) < 7 {
+ // pad so the id always formats as xxx-xxxx
+ s += strings.Repeat("0", 7-len(s))
+ }
+ return s[:3] + "-" + s[3:]
+}
+
+// New creates a new conversation with Claude with sensible defaults.
+// ctx is the context for the entire conversation.
+func New(ctx context.Context, srv llm.Service) *Convo {
+ id := newConvoID()
+ return &Convo{
+ Ctx: skribe.ContextWithAttr(ctx, slog.String("convo_id", id)),
+ Service: srv,
+ PromptCaching: true,
+ usage: newUsage(),
+ Listener: &NoopListener{},
+ ID: id,
+ muToolUseCancel: &sync.Mutex{},
+ toolUseCancel: map[string]context.CancelCauseFunc{},
+ mu: &sync.Mutex{},
+ }
+}
+
+// SubConvo creates a sub-conversation with the same configuration as the parent conversation.
+// (This propagates context for cancellation, HTTP client, API key, etc.)
+// The sub-conversation shares no messages with the parent conversation.
+// It does not inherit tools from the parent conversation.
+func (c *Convo) SubConvo() *Convo {
+ id := newConvoID()
+ return &Convo{
+ Ctx: skribe.ContextWithAttr(c.Ctx, slog.String("convo_id", id), slog.String("parent_convo_id", c.ID)),
+ Service: c.Service,
+ PromptCaching: c.PromptCaching,
+ Parent: c,
+ // For convenience, sub-convo usage shares tool uses map with parent,
+ // all other fields separate, propagated in AddResponse
+ usage: newUsageWithSharedToolUses(c.usage),
+ mu: c.mu,
+ Listener: c.Listener,
+ ID: id,
+ // Do not copy Budget. Each budget is independent,
+ // and OverBudget checks whether any ancestor is over budget.
+ }
+}
+
+// SubConvoWithHistory is like SubConvo, except that the sub-conversation
+// starts with a copy of this conversation's message history.
+func (c *Convo) SubConvoWithHistory() *Convo {
+ id := newConvoID()
+ return &Convo{
+ Ctx: skribe.ContextWithAttr(c.Ctx, slog.String("convo_id", id), slog.String("parent_convo_id", c.ID)),
+ Service: c.Service,
+ PromptCaching: c.PromptCaching,
+ Parent: c,
+ // For convenience, sub-convo usage shares tool uses map with parent,
+ // all other fields separate, propagated in AddResponse
+ usage: newUsageWithSharedToolUses(c.usage),
+ mu: c.mu,
+ Listener: c.Listener,
+ ID: id,
+ // Do not copy Budget. Each budget is independent,
+ // and OverBudget checks whether any ancestor is over budget.
+ messages: slices.Clone(c.messages),
+ }
+}
+
+// Depth reports how many "sub-conversations" deep this conversation is.
+// That is, it walks up parents until it finds a root.
+func (c *Convo) Depth() int {
+ x := c
+ var depth int
+ for x.Parent != nil {
+ x = x.Parent
+ depth++
+ }
+ return depth
+}
+
+// SendUserTextMessage sends a text message to the LLM in this conversation.
+// otherContents contains additional contents to send with the message, usually tool results.
+// If s is empty, only otherContents are sent.
+func (c *Convo) SendUserTextMessage(s string, otherContents ...llm.Content) (*llm.Response, error) {
+ // Clone so appending the text block cannot mutate the caller's slice.
+ contents := slices.Clone(otherContents)
+ if s != "" {
+ contents = append(contents, llm.Content{Type: llm.ContentTypeText, Text: s})
+ }
+ msg := llm.Message{
+ Role: llm.MessageRoleUser,
+ Content: contents,
+ }
+ return c.SendMessage(msg)
+}
+
+// messageRequest assembles an llm.Request from the conversation state
+// plus the outgoing message msg. msg is appended last but not yet
+// committed to c.messages; SendMessage records it only on success.
+func (c *Convo) messageRequest(msg llm.Message) *llm.Request {
+	var system []llm.SystemContent
+	if c.SystemPrompt != "" {
+		system = []llm.SystemContent{{
+			Type: "text",
+			Text: c.SystemPrompt,
+			// Cache the system prompt whenever prompt caching is enabled.
+			Cache: c.PromptCaching,
+		}}
+	}
+
+	// Claude is happy to return an empty response in response to our Done() call,
+	// and, if so, you'll see something like:
+	// API request failed with status 400 Bad Request
+	// {"type":"error","error": {"type":"invalid_request_error",
+	// "message":"messages.5: all messages must have non-empty content except for the optional final assistant message"}}
+	// So, we filter out those empty messages.
+	nonEmptyMessages := make([]llm.Message, 0, len(c.messages)+1)
+	for _, m := range c.messages {
+		if len(m.Content) > 0 {
+			nonEmptyMessages = append(nonEmptyMessages, m)
+		}
+	}
+
+	mr := &llm.Request{
+		Messages: append(nonEmptyMessages, msg), // not yet committed to keeping msg
+		System:   system,
+		Tools:    c.Tools,
+	}
+	if c.ToolUseOnly {
+		mr.ToolChoice = &llm.ToolChoice{Type: llm.ToolChoiceTypeAny}
+	}
+	return mr
+}
+
+// findTool returns the tool registered under name, or an error if this
+// conversation has no such tool.
+func (c *Convo) findTool(name string) (*llm.Tool, error) {
+	for i := range c.Tools {
+		if t := c.Tools[i]; t.Name == name {
+			return t, nil
+		}
+	}
+	return nil, fmt.Errorf("tool %q not found", name)
+}
+
+// insertMissingToolResults adds error results for tool uses that were requested
+// but not included in the message, which can happen in error paths like "out of budget."
+// We only insert these if there were no tool responses at all, since an incorrect
+// number of tool results would be a programmer error. Mutates inputs.
+func (c *Convo) insertMissingToolResults(mr *llm.Request, msg *llm.Message) {
+	if len(mr.Messages) < 2 {
+		return
+	}
+	prev := mr.Messages[len(mr.Messages)-2]
+	var toolUsePrev int
+	for _, part := range prev.Content {
+		if part.Type == llm.ContentTypeToolUse {
+			toolUsePrev++
+		}
+	}
+	if toolUsePrev == 0 {
+		return
+	}
+	var toolResultCurrent int
+	for _, part := range msg.Content {
+		if part.Type == llm.ContentTypeToolResult {
+			toolResultCurrent++
+		}
+	}
+	if toolResultCurrent != 0 {
+		return
+	}
+	// Synthesize one error tool_result per outstanding tool_use.
+	var prefix []llm.Content
+	for _, part := range prev.Content {
+		if part.Type != llm.ContentTypeToolUse {
+			continue
+		}
+		prefix = append(prefix, llm.Content{
+			Type:       llm.ContentTypeToolResult,
+			ToolUseID:  part.ID,
+			ToolError:  true,
+			ToolResult: "not executed; retry possible",
+		})
+	}
+	// Rebuild the outgoing message exactly once, after the loop.
+	// (Rebuilding inside the loop, as before, duplicated content
+	// whenever prev contained more than one tool_use.)
+	msg.Content = append(prefix, msg.Content...)
+	mr.Messages[len(mr.Messages)-1].Content = msg.Content
+	slog.DebugContext(c.Ctx, "inserted missing tool results")
+}
+
+// SendMessage sends a message to Claude.
+// The conversation records (internally) all messages successfully sent and received.
+func (c *Convo) SendMessage(msg llm.Message) (*llm.Response, error) {
+ id := ulid.Make().String()
+ mr := c.messageRequest(msg)
+ var lastMessage *llm.Message
+ if c.PromptCaching {
+ // Mark the final content block of the request cacheable for this
+ // request only; the deferred func below clears the flag so the
+ // retained history is not left with a stale cache marker.
+ lastMessage = &mr.Messages[len(mr.Messages)-1]
+ if len(lastMessage.Content) > 0 {
+ lastMessage.Content[len(lastMessage.Content)-1].Cache = true
+ }
+ }
+ defer func() {
+ if lastMessage == nil {
+ return
+ }
+ if len(lastMessage.Content) > 0 {
+ lastMessage.Content[len(lastMessage.Content)-1].Cache = false
+ }
+ }()
+ c.insertMissingToolResults(mr, &msg)
+ c.Listener.OnRequest(c.Ctx, c, id, &msg)
+
+ startTime := time.Now()
+ resp, err := c.Service.Do(c.Ctx, mr)
+ if resp != nil {
+ resp.StartTime = &startTime
+ endTime := time.Now()
+ resp.EndTime = &endTime
+ }
+
+ if err != nil {
+ // Notify the listener with a nil response so it can observe the failure.
+ c.Listener.OnResponse(c.Ctx, c, id, nil)
+ return nil, err
+ }
+ c.messages = append(c.messages, msg, resp.ToMessage())
+ // Propagate usage to all ancestors (including us).
+ for x := c; x != nil; x = x.Parent {
+ x.usage.Add(resp.Usage)
+ }
+ c.Listener.OnResponse(c.Ctx, c, id, resp)
+ return resp, err
+}
+
+// toolCallInfoKeyType is a private context key type to avoid collisions.
+type toolCallInfoKeyType string
+
+// toolCallInfoKey is the context key under which ToolCallInfo is stored.
+var toolCallInfoKey toolCallInfoKeyType
+
+// ToolCallInfo identifies the tool_use (and its Convo) associated with
+// a tool execution context; see ToolResultContents.
+type ToolCallInfo struct {
+ ToolUseID string
+ Convo *Convo
+}
+
+// ToolCallInfoFromContext returns the ToolCallInfo stored in ctx,
+// or the zero value if none is present.
+func ToolCallInfoFromContext(ctx context.Context) ToolCallInfo {
+ v := ctx.Value(toolCallInfoKey)
+ i, _ := v.(ToolCallInfo)
+ return i
+}
+
+// ToolResultCancelContents returns error tool_result contents for every
+// tool_use in resp, marking each as canceled by the user. It does not
+// run any tools. It returns nil if resp did not stop to use tools.
+func (c *Convo) ToolResultCancelContents(resp *llm.Response) ([]llm.Content, error) {
+	if resp.StopReason != llm.StopReasonToolUse {
+		return nil, nil
+	}
+	var toolResults []llm.Content
+
+	for _, part := range resp.Content {
+		if part.Type != llm.ContentTypeToolUse {
+			continue
+		}
+		c.incrementToolUse(part.ToolName)
+
+		content := llm.Content{
+			Type:      llm.ContentTypeToolResult,
+			ToolUseID: part.ID,
+		}
+
+		content.ToolError = true
+		// Fixed typo in the message sent back to the model: was "too_use".
+		content.ToolResult = "user canceled this tool_use"
+		toolResults = append(toolResults, content)
+	}
+	return toolResults, nil
+}
+
+// GetID returns the conversation ID
+func (c *Convo) GetID() string {
+ return c.ID
+}
+
+// CancelToolUse cancels the in-flight tool use identified by toolUseID,
+// using err as the cancellation cause. It returns an error if no cancel
+// function is registered for that tool_use_id.
+func (c *Convo) CancelToolUse(toolUseID string, err error) error {
+ c.muToolUseCancel.Lock()
+ defer c.muToolUseCancel.Unlock()
+ cancel, ok := c.toolUseCancel[toolUseID]
+ if !ok {
+ return fmt.Errorf("cannot cancel %s: no cancel function registered for this tool_use_id. All I have is %+v", toolUseID, c.toolUseCancel)
+ }
+ delete(c.toolUseCancel, toolUseID)
+ cancel(err)
+ return nil
+}
+
+// newToolUseContext derives a per-tool_use cancellable context from ctx
+// and registers its cancel function under toolUseID so the tool use can
+// be canceled individually via CancelToolUse.
+func (c *Convo) newToolUseContext(ctx context.Context, toolUseID string) (context.Context, context.CancelFunc) {
+ c.muToolUseCancel.Lock()
+ defer c.muToolUseCancel.Unlock()
+ ctx, cancel := context.WithCancelCause(ctx)
+ c.toolUseCancel[toolUseID] = cancel
+ return ctx, func() { c.CancelToolUse(toolUseID, nil) }
+}
+
+// ToolResultContents runs all tool uses requested by the response and returns their results.
+// Cancelling ctx will cancel any running tool calls.
+// Tools run concurrently, one goroutine per tool_use; the order of the
+// returned results is therefore not guaranteed to match resp.Content.
+func (c *Convo) ToolResultContents(ctx context.Context, resp *llm.Response) ([]llm.Content, error) {
+ if resp.StopReason != llm.StopReasonToolUse {
+ return nil, nil
+ }
+ // Extract all tool calls from the response, call the tools, and gather the results.
+ var wg sync.WaitGroup
+ // Buffered to the content count so worker goroutines never block on send.
+ toolResultC := make(chan llm.Content, len(resp.Content))
+ for _, part := range resp.Content {
+ if part.Type != llm.ContentTypeToolUse {
+ continue
+ }
+ c.incrementToolUse(part.ToolName)
+ startTime := time.Now()
+
+ c.Listener.OnToolCall(ctx, c, part.ID, part.ToolName, part.ToolInput, llm.Content{
+ Type: llm.ContentTypeToolUse,
+ ToolUseID: part.ID,
+ ToolUseStartTime: &startTime,
+ })
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+
+ content := llm.Content{
+ Type: llm.ContentTypeToolResult,
+ ToolUseID: part.ID,
+ ToolUseStartTime: &startTime,
+ }
+ // sendErr reports a failed tool use to the listener and channel.
+ sendErr := func(err error) {
+ // Record end time
+ endTime := time.Now()
+ content.ToolUseEndTime = &endTime
+
+ content.ToolError = true
+ content.ToolResult = err.Error()
+ c.Listener.OnToolResult(ctx, c, part.ID, part.ToolName, part.ToolInput, content, nil, err)
+ toolResultC <- content
+ }
+ // sendRes reports a successful tool use to the listener and channel.
+ sendRes := func(res string) {
+ // Record end time
+ endTime := time.Now()
+ content.ToolUseEndTime = &endTime
+
+ content.ToolResult = res
+ c.Listener.OnToolResult(ctx, c, part.ID, part.ToolName, part.ToolInput, content, &res, nil)
+ toolResultC <- content
+ }
+
+ tool, err := c.findTool(part.ToolName)
+ if err != nil {
+ sendErr(err)
+ return
+ }
+ // Create a new context for just this tool_use call, and register its
+ // cancel function so that it can be canceled individually.
+ toolUseCtx, cancel := c.newToolUseContext(ctx, part.ID)
+ defer cancel()
+ // TODO: move this into newToolUseContext?
+ toolUseCtx = context.WithValue(toolUseCtx, toolCallInfoKey, ToolCallInfo{ToolUseID: part.ID, Convo: c})
+ toolResult, err := tool.Run(toolUseCtx, part.ToolInput)
+ if errors.Is(err, ErrDoNotRespond) {
+ // The tool asked us to send no tool_result at all.
+ return
+ }
+ if toolUseCtx.Err() != nil {
+ // Canceled (or timed out); report the cancellation cause.
+ sendErr(context.Cause(toolUseCtx))
+ return
+ }
+
+ if err != nil {
+ sendErr(err)
+ return
+ }
+ sendRes(toolResult)
+ }()
+ }
+ wg.Wait()
+ close(toolResultC)
+ var toolResults []llm.Content
+ for toolResult := range toolResultC {
+ toolResults = append(toolResults, toolResult)
+ }
+ if ctx.Err() != nil {
+ return nil, ctx.Err()
+ }
+ return toolResults, nil
+}
+
+// incrementToolUse bumps the usage counter for the named tool.
+// c.mu guards ToolUses, which may be shared with ancestor conversations.
+func (c *Convo) incrementToolUse(name string) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ c.usage.ToolUses[name]++
+}
+
+// CumulativeUsage represents cumulative usage across a Convo, including all sub-conversations.
+type CumulativeUsage struct {
+ StartTime time.Time `json:"start_time"`
+ // NOTE(review): json tag is "messages", not "responses" — confirm intentional.
+ Responses uint64 `json:"messages"` // count of responses
+ InputTokens uint64 `json:"input_tokens"`
+ OutputTokens uint64 `json:"output_tokens"`
+ CacheReadInputTokens uint64 `json:"cache_read_input_tokens"`
+ CacheCreationInputTokens uint64 `json:"cache_creation_input_tokens"`
+ TotalCostUSD float64 `json:"total_cost_usd"`
+ ToolUses map[string]int `json:"tool_uses"` // tool name -> number of uses
+}
+
+// newUsage returns a fresh CumulativeUsage with its clock started now.
+func newUsage() *CumulativeUsage {
+ return &CumulativeUsage{ToolUses: make(map[string]int), StartTime: time.Now()}
+}
+
+// newUsageWithSharedToolUses returns a fresh CumulativeUsage whose
+// ToolUses map is shared (aliased) with the parent's; see SubConvo.
+func newUsageWithSharedToolUses(parent *CumulativeUsage) *CumulativeUsage {
+ return &CumulativeUsage{ToolUses: parent.ToolUses, StartTime: time.Now()}
+}
+
+// Clone returns a copy of u with an independent ToolUses map.
+func (u *CumulativeUsage) Clone() CumulativeUsage {
+ v := *u
+ v.ToolUses = maps.Clone(u.ToolUses)
+ return v
+}
+
+func (c *Convo) CumulativeUsage() CumulativeUsage {
+ if c == nil {
+ return CumulativeUsage{}
+ }
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ return c.usage.Clone()
+}
+
+func (u *CumulativeUsage) WallTime() time.Duration {
+ return time.Since(u.StartTime)
+}
+
+func (u *CumulativeUsage) DollarsPerHour() float64 {
+ hours := u.WallTime().Hours()
+ // Prevent division by very small numbers that could cause issues
+ if hours < 1e-6 {
+ return 0
+ }
+ return u.TotalCostUSD / hours
+}
+
+func (u *CumulativeUsage) Add(usage llm.Usage) {
+ u.Responses++
+ u.InputTokens += usage.InputTokens
+ u.OutputTokens += usage.OutputTokens
+ u.CacheReadInputTokens += usage.CacheReadInputTokens
+ u.CacheCreationInputTokens += usage.CacheCreationInputTokens
+ u.TotalCostUSD += usage.CostUSD
+}
+
+// TotalInputTokens returns the grand total cumulative input tokens in u.
+func (u *CumulativeUsage) TotalInputTokens() uint64 {
+ return u.InputTokens + u.CacheReadInputTokens + u.CacheCreationInputTokens
+}
+
+// Attr returns the cumulative usage as a slog.Attr with key "usage".
+func (u CumulativeUsage) Attr() slog.Attr {
+ elapsed := time.Since(u.StartTime)
+ return slog.Group("usage",
+ slog.Duration("wall_time", elapsed),
+ slog.Uint64("responses", u.Responses),
+ slog.Uint64("input_tokens", u.InputTokens),
+ slog.Uint64("output_tokens", u.OutputTokens),
+ slog.Uint64("cache_read_input_tokens", u.CacheReadInputTokens),
+ slog.Uint64("cache_creation_input_tokens", u.CacheCreationInputTokens),
+ slog.Float64("total_cost_usd", u.TotalCostUSD),
+ slog.Float64("dollars_per_hour", u.TotalCostUSD/elapsed.Hours()),
+ slog.Any("tool_uses", maps.Clone(u.ToolUses)),
+ )
+}
+
+// A Budget represents the maximum amount of resources that may be spent on a conversation.
+// Note that the default (zero) budget is unlimited.
+type Budget struct {
+ MaxResponses uint64 // if > 0, max number of iterations (=responses)
+ MaxDollars float64 // if > 0, max dollars that may be spent
+ MaxWallTime time.Duration // if > 0, max wall time that may be spent
+}
+
+// OverBudget returns an error if the convo (or any of its parents) has exceeded its budget.
+// TODO: document parent vs sub budgets, multiple errors, etc, once we know the desired behavior.
+func (c *Convo) OverBudget() error {
+ for x := c; x != nil; x = x.Parent {
+ if err := x.overBudget(); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// ResetBudget sets the budget to the passed in budget and
+// adjusts it by what's been used so far.
+func (c *Convo) ResetBudget(budget Budget) {
+ c.Budget = budget
+ if c.Budget.MaxDollars > 0 {
+ c.Budget.MaxDollars += c.CumulativeUsage().TotalCostUSD
+ }
+ if c.Budget.MaxResponses > 0 {
+ c.Budget.MaxResponses += c.CumulativeUsage().Responses
+ }
+ if c.Budget.MaxWallTime > 0 {
+ c.Budget.MaxWallTime += c.usage.WallTime()
+ }
+}
+
+func (c *Convo) overBudget() error {
+ usage := c.CumulativeUsage()
+ // TODO: stop before we exceed the budget instead of after?
+ // Top priority is money, then time, then response count.
+ var err error
+ cont := "Continuing to chat will reset the budget."
+ if c.Budget.MaxDollars > 0 && usage.TotalCostUSD >= c.Budget.MaxDollars {
+ err = errors.Join(err, fmt.Errorf("$%.2f spent, budget is $%.2f. %s", usage.TotalCostUSD, c.Budget.MaxDollars, cont))
+ }
+ if c.Budget.MaxWallTime > 0 && usage.WallTime() >= c.Budget.MaxWallTime {
+ err = errors.Join(err, fmt.Errorf("%v elapsed, budget is %v. %s", usage.WallTime().Truncate(time.Second), c.Budget.MaxWallTime.Truncate(time.Second), cont))
+ }
+ if c.Budget.MaxResponses > 0 && usage.Responses >= c.Budget.MaxResponses {
+ err = errors.Join(err, fmt.Errorf("%d responses received, budget is %d. %s", usage.Responses, c.Budget.MaxResponses, cont))
+ }
+ return err
+}
diff --git a/llm/conversation/convo_test.go b/llm/conversation/convo_test.go
new file mode 100644
index 0000000..3fb1750
--- /dev/null
+++ b/llm/conversation/convo_test.go
@@ -0,0 +1,139 @@
+package conversation
+
+import (
+ "cmp"
+ "context"
+ "net/http"
+ "os"
+ "strings"
+ "testing"
+
+ "sketch.dev/httprr"
+ "sketch.dev/llm/ant"
+)
+
+func TestBasicConvo(t *testing.T) {
+ ctx := context.Background()
+ rr, err := httprr.Open("testdata/basic_convo.httprr", http.DefaultTransport)
+ if err != nil {
+ t.Fatal(err)
+ }
+ rr.ScrubReq(func(req *http.Request) error {
+ req.Header.Del("x-api-key")
+ return nil
+ })
+
+ apiKey := cmp.Or(os.Getenv("OUTER_SKETCH_ANTHROPIC_API_KEY"), os.Getenv("ANTHROPIC_API_KEY"))
+ srv := &ant.Service{
+ APIKey: apiKey,
+ HTTPC: rr.Client(),
+ }
+ convo := New(ctx, srv)
+
+ const name = "Cornelius"
+ res, err := convo.SendUserTextMessage("Hi, my name is " + name)
+ if err != nil {
+ t.Fatal(err)
+ }
+ for _, part := range res.Content {
+ t.Logf("%s", part.Text)
+ }
+ res, err = convo.SendUserTextMessage("What is my name?")
+ if err != nil {
+ t.Fatal(err)
+ }
+ got := ""
+ for _, part := range res.Content {
+ got += part.Text
+ }
+ if !strings.Contains(got, name) {
+ t.Errorf("model does not know the given name %s: %q", name, got)
+ }
+}
+
+// TestCancelToolUse tests the CancelToolUse function of the Convo struct
+func TestCancelToolUse(t *testing.T) {
+ tests := []struct {
+ name string
+ setupToolUse bool
+ toolUseID string
+ cancelErr error
+ expectError bool
+ expectCancel bool
+ }{
+ {
+ name: "Cancel existing tool use",
+ setupToolUse: true,
+ toolUseID: "tool123",
+ cancelErr: nil,
+ expectError: false,
+ expectCancel: true,
+ },
+ {
+ name: "Cancel existing tool use with error",
+ setupToolUse: true,
+ toolUseID: "tool456",
+ cancelErr: context.Canceled,
+ expectError: false,
+ expectCancel: true,
+ },
+ {
+ name: "Cancel non-existent tool use",
+ setupToolUse: false,
+ toolUseID: "tool789",
+ cancelErr: nil,
+ expectError: true,
+ expectCancel: false,
+ },
+ }
+
+ srv := &ant.Service{}
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ convo := New(context.Background(), srv)
+
+ var cancelCalled bool
+ var cancelledWithErr error
+
+ if tt.setupToolUse {
+ // Setup a mock cancel function to track calls
+ mockCancel := func(err error) {
+ cancelCalled = true
+ cancelledWithErr = err
+ }
+
+ convo.muToolUseCancel.Lock()
+ convo.toolUseCancel[tt.toolUseID] = mockCancel
+ convo.muToolUseCancel.Unlock()
+ }
+
+ err := convo.CancelToolUse(tt.toolUseID, tt.cancelErr)
+
+ // Check if we got the expected error state
+ if (err != nil) != tt.expectError {
+ t.Errorf("CancelToolUse() error = %v, expectError %v", err, tt.expectError)
+ }
+
+ // Check if the cancel function was called as expected
+ if cancelCalled != tt.expectCancel {
+ t.Errorf("Cancel function called = %v, expectCancel %v", cancelCalled, tt.expectCancel)
+ }
+
+ // If we expected the cancel to be called, verify it was called with the right error
+ if tt.expectCancel && cancelledWithErr != tt.cancelErr {
+ t.Errorf("Cancel function called with error = %v, expected %v", cancelledWithErr, tt.cancelErr)
+ }
+
+ // Verify the toolUseID was removed from the map if it was initially added
+ if tt.setupToolUse {
+ convo.muToolUseCancel.Lock()
+ _, exists := convo.toolUseCancel[tt.toolUseID]
+ convo.muToolUseCancel.Unlock()
+
+ if exists {
+ t.Errorf("toolUseID %s still exists in the map after cancellation", tt.toolUseID)
+ }
+ }
+ })
+ }
+}
diff --git a/llm/conversation/testdata/basic_convo.httprr b/llm/conversation/testdata/basic_convo.httprr
new file mode 100644
index 0000000..663de8d
--- /dev/null
+++ b/llm/conversation/testdata/basic_convo.httprr
@@ -0,0 +1,116 @@
+httprr trace v1
+457 1329
+POST https://api.anthropic.com/v1/messages HTTP/1.1
+Host: api.anthropic.com
+User-Agent: Go-http-client/1.1
+Content-Length: 261
+Anthropic-Version: 2023-06-01
+Content-Type: application/json
+
+{
+ "model": "claude-3-7-sonnet-20250219",
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "Hi, my name is Cornelius",
+ "cache_control": {
+ "type": "ephemeral"
+ }
+ }
+ ]
+ }
+ ],
+ "max_tokens": 8192
+}HTTP/2.0 200 OK
+Anthropic-Organization-Id: 3c473a21-7208-450a-a9f8-80aebda45c1b
+Anthropic-Ratelimit-Input-Tokens-Limit: 200000
+Anthropic-Ratelimit-Input-Tokens-Remaining: 200000
+Anthropic-Ratelimit-Input-Tokens-Reset: 2025-03-11T17:45:06Z
+Anthropic-Ratelimit-Output-Tokens-Limit: 80000
+Anthropic-Ratelimit-Output-Tokens-Remaining: 79000
+Anthropic-Ratelimit-Output-Tokens-Reset: 2025-03-11T17:45:07Z
+Anthropic-Ratelimit-Requests-Limit: 4000
+Anthropic-Ratelimit-Requests-Remaining: 3999
+Anthropic-Ratelimit-Requests-Reset: 2025-03-11T17:45:05Z
+Anthropic-Ratelimit-Tokens-Limit: 280000
+Anthropic-Ratelimit-Tokens-Remaining: 279000
+Anthropic-Ratelimit-Tokens-Reset: 2025-03-11T17:45:06Z
+Cf-Cache-Status: DYNAMIC
+Cf-Ray: 91ecdd10fdc3f97f-SJC
+Content-Type: application/json
+Date: Tue, 11 Mar 2025 17:45:07 GMT
+Request-Id: req_01LBtxMdNzxDcDVPGJSh7giv
+Server: cloudflare
+Via: 1.1 google
+X-Robots-Tag: none
+
+{"id":"msg_01S1uUyUsTaKPBKuDUGGX8J2","type":"message","role":"assistant","model":"claude-3-7-sonnet-20250219","content":[{"type":"text","text":"Hello, Cornelius! It's nice to meet you. How are you doing today? Is there something I can help you with?"}],"stop_reason":"end_turn","stop_sequence":null,"usage":{"input_tokens":15,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":31}}779 1286
+POST https://api.anthropic.com/v1/messages HTTP/1.1
+Host: api.anthropic.com
+User-Agent: Go-http-client/1.1
+Content-Length: 583
+Anthropic-Version: 2023-06-01
+Content-Type: application/json
+
+{
+ "model": "claude-3-7-sonnet-20250219",
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "Hi, my name is Cornelius"
+ }
+ ]
+ },
+ {
+ "role": "assistant",
+ "content": [
+ {
+ "type": "text",
+ "text": "Hello, Cornelius! It's nice to meet you. How are you doing today? Is there something I can help you with?"
+ }
+ ]
+ },
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "What is my name?",
+ "cache_control": {
+ "type": "ephemeral"
+ }
+ }
+ ]
+ }
+ ],
+ "max_tokens": 8192
+}HTTP/2.0 200 OK
+Anthropic-Organization-Id: 3c473a21-7208-450a-a9f8-80aebda45c1b
+Anthropic-Ratelimit-Input-Tokens-Limit: 200000
+Anthropic-Ratelimit-Input-Tokens-Remaining: 200000
+Anthropic-Ratelimit-Input-Tokens-Reset: 2025-03-11T17:45:07Z
+Anthropic-Ratelimit-Output-Tokens-Limit: 80000
+Anthropic-Ratelimit-Output-Tokens-Remaining: 80000
+Anthropic-Ratelimit-Output-Tokens-Reset: 2025-03-11T17:45:07Z
+Anthropic-Ratelimit-Requests-Limit: 4000
+Anthropic-Ratelimit-Requests-Remaining: 3999
+Anthropic-Ratelimit-Requests-Reset: 2025-03-11T17:45:07Z
+Anthropic-Ratelimit-Tokens-Limit: 280000
+Anthropic-Ratelimit-Tokens-Remaining: 280000
+Anthropic-Ratelimit-Tokens-Reset: 2025-03-11T17:45:07Z
+Cf-Cache-Status: DYNAMIC
+Cf-Ray: 91ecdd1ae9a6f97f-SJC
+Content-Type: application/json
+Date: Tue, 11 Mar 2025 17:45:07 GMT
+Request-Id: req_01MBf3RWXNfQgwhVRwwkBYSn
+Server: cloudflare
+Via: 1.1 google
+X-Robots-Tag: none
+
+{"id":"msg_01FGz6DeWeDpspJG8cuxyVE9","type":"message","role":"assistant","model":"claude-3-7-sonnet-20250219","content":[{"type":"text","text":"Your name is Cornelius, as you mentioned in your introduction."}],"stop_reason":"end_turn","stop_sequence":null,"usage":{"input_tokens":54,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":17}}
\ No newline at end of file
diff --git a/llm/llm.go b/llm/llm.go
new file mode 100644
index 0000000..3ba6ed4
--- /dev/null
+++ b/llm/llm.go
@@ -0,0 +1,229 @@
+// Package llm provides a unified interface for interacting with LLMs.
+package llm
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log/slog"
+ "strings"
+ "time"
+)
+
+type Service interface {
+ // Do sends a request to an LLM.
+ Do(context.Context, *Request) (*Response, error)
+}
+
+// MustSchema validates that schema is a valid JSON schema and returns it as a json.RawMessage.
+// It panics if the schema is invalid.
+func MustSchema(schema string) json.RawMessage {
+ // TODO: validate schema, for now just make sure it's valid JSON
+ schema = strings.TrimSpace(schema)
+ bytes := []byte(schema)
+ if !json.Valid(bytes) {
+ panic("invalid JSON schema: " + schema)
+ }
+ return json.RawMessage(bytes)
+}
+
+type Request struct {
+ Messages []Message
+ ToolChoice *ToolChoice
+ Tools []*Tool
+ System []SystemContent
+}
+
+// Message represents a message in the conversation.
+type Message struct {
+ Role MessageRole
+ Content []Content
+ ToolUse *ToolUse // use to control whether/which tool to use
+}
+
+// ToolUse represents a tool use in the message content.
+type ToolUse struct {
+ ID string
+ Name string
+}
+
+type ToolChoice struct {
+ Type ToolChoiceType
+ Name string
+}
+
+type SystemContent struct {
+ Text string
+ Type string
+ Cache bool
+}
+
+// Tool represents a tool available to an LLM.
+type Tool struct {
+ Name string
+ // Type is used by the text editor tool; see
+ // https://docs.anthropic.com/en/docs/build-with-claude/tool-use/text-editor-tool
+ Type string
+ Description string
+ InputSchema json.RawMessage
+
+ // The Run function is automatically called when the tool is used.
+ // Run functions may be called concurrently with each other and themselves.
+ // The input to the Run function is the tool input provided by Claude, conforming to the input schema.
+ // The outputs from Run will be sent back to Claude.
+ // If you do not want to respond to the tool call request from Claude, return ErrDoNotRespond.
+ // ctx contains extra (rarely used) tool call information; retrieve it with ToolCallInfoFromContext.
+ Run func(ctx context.Context, input json.RawMessage) (string, error) `json:"-"`
+}
+
+type Content struct {
+ ID string
+ Type ContentType
+ Text string
+
+ // for thinking
+ Thinking string
+ Data string
+ Signature string
+
+ // for tool_use
+ ToolName string
+ ToolInput json.RawMessage
+
+ // for tool_result
+ ToolUseID string
+ ToolError bool
+ ToolResult string
+
+ // timing information for tool_result; added externally; not sent to the LLM
+ ToolUseStartTime *time.Time
+ ToolUseEndTime *time.Time
+
+ Cache bool
+}
+
+func StringContent(s string) Content {
+ return Content{Type: ContentTypeText, Text: s}
+}
+
+// ContentsAttr returns contents as a slog.Attr.
+// It is meant for logging.
+func ContentsAttr(contents []Content) slog.Attr {
+ var contentAttrs []any // slog.Attr
+ for _, content := range contents {
+ var attrs []any // slog.Attr
+ switch content.Type {
+ case ContentTypeText:
+ attrs = append(attrs, slog.String("text", content.Text))
+ case ContentTypeToolUse:
+ attrs = append(attrs, slog.String("tool_name", content.ToolName))
+ attrs = append(attrs, slog.String("tool_input", string(content.ToolInput)))
+ case ContentTypeToolResult:
+ attrs = append(attrs, slog.String("tool_result", content.ToolResult))
+ attrs = append(attrs, slog.Bool("tool_error", content.ToolError))
+ case ContentTypeThinking:
+ attrs = append(attrs, slog.String("thinking", content.Text))
+ default:
+ attrs = append(attrs, slog.String("unknown_content_type", content.Type.String()))
+ attrs = append(attrs, slog.Any("text", content)) // just log it all raw, better to have too much than not enough
+ }
+ contentAttrs = append(contentAttrs, slog.Group(content.ID, attrs...))
+ }
+ return slog.Group("contents", contentAttrs...)
+}
+
+type (
+ MessageRole int
+ ContentType int
+ ToolChoiceType int
+ StopReason int
+)
+
+//go:generate go tool golang.org/x/tools/cmd/stringer -type=MessageRole,ContentType,ToolChoiceType,StopReason -output=llm_string.go
+
+const (
+ MessageRoleUser MessageRole = iota
+ MessageRoleAssistant
+
+ ContentTypeText ContentType = iota
+ ContentTypeThinking
+ ContentTypeRedactedThinking
+ ContentTypeToolUse
+ ContentTypeToolResult
+
+ ToolChoiceTypeAuto ToolChoiceType = iota // default
+ ToolChoiceTypeAny // any tool, but must use one
+ ToolChoiceTypeNone // no tools allowed
+ ToolChoiceTypeTool // must use the tool specified in the Name field
+
+ StopReasonStopSequence StopReason = iota
+ StopReasonMaxTokens
+ StopReasonEndTurn
+ StopReasonToolUse
+)
+
+type Response struct {
+ ID string
+ Type string
+ Role MessageRole
+ Model string
+ Content []Content
+ StopReason StopReason
+ StopSequence *string
+ Usage Usage
+ StartTime *time.Time
+ EndTime *time.Time
+}
+
+func (m *Response) ToMessage() Message {
+ return Message{
+ Role: m.Role,
+ Content: m.Content,
+ }
+}
+
+// Usage represents the billing and rate-limit usage.
+// Most LLM structs do not have JSON tags, to avoid accidental direct use in specific providers.
+// However, the front-end uses this struct, and it relies on its JSON serialization.
+// Do NOT use this struct directly when implementing an llm.Service.
+type Usage struct {
+ InputTokens uint64 `json:"input_tokens"`
+ CacheCreationInputTokens uint64 `json:"cache_creation_input_tokens"`
+ CacheReadInputTokens uint64 `json:"cache_read_input_tokens"`
+ OutputTokens uint64 `json:"output_tokens"`
+ CostUSD float64 `json:"cost_usd"`
+}
+
+func (u *Usage) Add(other Usage) {
+ u.InputTokens += other.InputTokens
+ u.CacheCreationInputTokens += other.CacheCreationInputTokens
+ u.CacheReadInputTokens += other.CacheReadInputTokens
+ u.OutputTokens += other.OutputTokens
+ u.CostUSD += other.CostUSD
+}
+
+func (u *Usage) String() string {
+ return fmt.Sprintf("in: %d, out: %d", u.InputTokens, u.OutputTokens)
+}
+
+func (u *Usage) IsZero() bool {
+ return *u == Usage{}
+}
+
+func (u *Usage) Attr() slog.Attr {
+ return slog.Group("usage",
+ slog.Uint64("input_tokens", u.InputTokens),
+ slog.Uint64("output_tokens", u.OutputTokens),
+ slog.Uint64("cache_creation_input_tokens", u.CacheCreationInputTokens),
+ slog.Uint64("cache_read_input_tokens", u.CacheReadInputTokens),
+ slog.Float64("cost_usd", u.CostUSD),
+ )
+}
+
+// UserStringMessage creates a user message with a single text content item.
+func UserStringMessage(text string) Message {
+ return Message{
+ Role: MessageRoleUser,
+ Content: []Content{StringContent(text)},
+ }
+}
diff --git a/llm/llm_string.go b/llm/llm_string.go
new file mode 100644
index 0000000..1c3189e
--- /dev/null
+++ b/llm/llm_string.go
@@ -0,0 +1,88 @@
+// Code generated by "stringer -type=MessageRole,ContentType,ToolChoiceType,StopReason -output=llm_string.go"; DO NOT EDIT.
+
+package llm
+
+import "strconv"
+
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[MessageRoleUser-0]
+ _ = x[MessageRoleAssistant-1]
+}
+
+const _MessageRole_name = "MessageRoleUserMessageRoleAssistant"
+
+var _MessageRole_index = [...]uint8{0, 15, 35}
+
+func (i MessageRole) String() string {
+ if i < 0 || i >= MessageRole(len(_MessageRole_index)-1) {
+ return "MessageRole(" + strconv.FormatInt(int64(i), 10) + ")"
+ }
+ return _MessageRole_name[_MessageRole_index[i]:_MessageRole_index[i+1]]
+}
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[ContentTypeText-2]
+ _ = x[ContentTypeThinking-3]
+ _ = x[ContentTypeRedactedThinking-4]
+ _ = x[ContentTypeToolUse-5]
+ _ = x[ContentTypeToolResult-6]
+}
+
+const _ContentType_name = "ContentTypeTextContentTypeThinkingContentTypeRedactedThinkingContentTypeToolUseContentTypeToolResult"
+
+var _ContentType_index = [...]uint8{0, 15, 34, 61, 79, 100}
+
+func (i ContentType) String() string {
+ i -= 2
+ if i < 0 || i >= ContentType(len(_ContentType_index)-1) {
+ return "ContentType(" + strconv.FormatInt(int64(i+2), 10) + ")"
+ }
+ return _ContentType_name[_ContentType_index[i]:_ContentType_index[i+1]]
+}
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[ToolChoiceTypeAuto-7]
+ _ = x[ToolChoiceTypeAny-8]
+ _ = x[ToolChoiceTypeNone-9]
+ _ = x[ToolChoiceTypeTool-10]
+}
+
+const _ToolChoiceType_name = "ToolChoiceTypeAutoToolChoiceTypeAnyToolChoiceTypeNoneToolChoiceTypeTool"
+
+var _ToolChoiceType_index = [...]uint8{0, 18, 35, 53, 71}
+
+func (i ToolChoiceType) String() string {
+ i -= 7
+ if i < 0 || i >= ToolChoiceType(len(_ToolChoiceType_index)-1) {
+ return "ToolChoiceType(" + strconv.FormatInt(int64(i+7), 10) + ")"
+ }
+ return _ToolChoiceType_name[_ToolChoiceType_index[i]:_ToolChoiceType_index[i+1]]
+}
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[StopReasonStopSequence-11]
+ _ = x[StopReasonMaxTokens-12]
+ _ = x[StopReasonEndTurn-13]
+ _ = x[StopReasonToolUse-14]
+}
+
+const _StopReason_name = "StopReasonStopSequenceStopReasonMaxTokensStopReasonEndTurnStopReasonToolUse"
+
+var _StopReason_index = [...]uint8{0, 22, 41, 58, 75}
+
+func (i StopReason) String() string {
+ i -= 11
+ if i < 0 || i >= StopReason(len(_StopReason_index)-1) {
+ return "StopReason(" + strconv.FormatInt(int64(i+11), 10) + ")"
+ }
+ return _StopReason_name[_StopReason_index[i]:_StopReason_index[i+1]]
+}
diff --git a/llm/oai/oai.go b/llm/oai/oai.go
new file mode 100644
index 0000000..3e772ab
--- /dev/null
+++ b/llm/oai/oai.go
@@ -0,0 +1,592 @@
+package oai
+
+import (
+ "cmp"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "log/slog"
+ "math/rand/v2"
+ "net/http"
+ "time"
+
+ "github.com/sashabaranov/go-openai"
+ "sketch.dev/llm"
+)
+
+const (
+ DefaultMaxTokens = 8192
+
+ OpenAIURL = "https://api.openai.com/v1"
+ FireworksURL = "https://api.fireworks.ai/inference/v1"
+ LlamaCPPURL = "http://localhost:8080/v1"
+ TogetherURL = "https://api.together.xyz/v1"
+ GeminiURL = "https://generativelanguage.googleapis.com/v1beta/openai/"
+
+ // Environment variable names for API keys
+ OpenAIAPIKeyEnv = "OPENAI_API_KEY"
+ FireworksAPIKeyEnv = "FIREWORKS_API_KEY"
+ TogetherAPIKeyEnv = "TOGETHER_API_KEY"
+ GeminiAPIKeyEnv = "GEMINI_API_KEY"
+)
+
+type Model struct {
+ UserName string // provided by the user to identify this model (e.g. "gpt4.1")
+ ModelName string // provided to the service provider to specify which model to use (e.g. "gpt-4.1-2025-04-14")
+ URL string
+ Cost ModelCost
+ APIKeyEnv string // environment variable name for the API key
+}
+
+type ModelCost struct {
+ Input uint64 // in cents per million tokens
+ CachedInput uint64 // in cents per million tokens
+ Output uint64 // in cents per million tokens
+}
+
+var (
+ DefaultModel = GPT41
+
+ GPT41 = Model{
+ UserName: "gpt4.1",
+ ModelName: "gpt-4.1-2025-04-14",
+ URL: OpenAIURL,
+ Cost: ModelCost{Input: 200, CachedInput: 50, Output: 800},
+ APIKeyEnv: OpenAIAPIKeyEnv,
+ }
+
+ Gemini25Flash = Model{
+ UserName: "gemini-flash-2.5",
+ ModelName: "gemini-2.5-flash-preview-04-17",
+ URL: GeminiURL,
+ Cost: ModelCost{Input: 15, Output: 60},
+ APIKeyEnv: GeminiAPIKeyEnv,
+ }
+
+ Gemini25Pro = Model{
+ UserName: "gemini-pro-2.5",
+ ModelName: "gemini-2.5-pro-preview-03-25",
+ URL: GeminiURL,
+ // Gemini 2.5 Pro pricing is tiered by prompt size:
+ // Input: $1.25/M tokens for prompts <= 200k tokens, $2.50/M above.
+ // Output: $10.00/M tokens for prompts <= 200k tokens, $15.00/M above.
+ // Caching: $0.31/M (<= 200k) or $0.625/M (> 200k), plus $4.50/M tokens
+ // per hour of storage; it is unclear whether caching applies to us here.
+ // We deliberately use the lower input tier and a flat output rate rather
+ // than modeling the tiering; revisit if prompts regularly exceed 200k tokens.
+ Cost: ModelCost{Input: 125, Output: 1000},
+ APIKeyEnv: GeminiAPIKeyEnv,
+ }
+
+ TogetherDeepseekV3 = Model{
+ UserName: "together-deepseek-v3",
+ ModelName: "deepseek-ai/DeepSeek-V3",
+ URL: TogetherURL,
+ Cost: ModelCost{Input: 125, Output: 125},
+ APIKeyEnv: TogetherAPIKeyEnv,
+ }
+
+ TogetherLlama4Maverick = Model{
+ UserName: "together-llama4-maverick",
+ ModelName: "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
+ URL: TogetherURL,
+ Cost: ModelCost{Input: 27, Output: 85},
+ APIKeyEnv: TogetherAPIKeyEnv,
+ }
+
+ TogetherLlama3_3_70B = Model{
+ UserName: "together-llama3-70b",
+ ModelName: "meta-llama/Llama-3.3-70B-Instruct-Turbo",
+ URL: TogetherURL,
+ Cost: ModelCost{Input: 88, Output: 88},
+ APIKeyEnv: TogetherAPIKeyEnv,
+ }
+
+ TogetherMistralSmall = Model{
+ UserName: "together-mistral-small",
+ ModelName: "mistralai/Mistral-Small-24B-Instruct-2501",
+ URL: TogetherURL,
+ Cost: ModelCost{Input: 80, Output: 80},
+ APIKeyEnv: TogetherAPIKeyEnv,
+ }
+
+ LlamaCPP = Model{
+ UserName: "llama.cpp",
+ ModelName: "llama.cpp local model",
+ URL: LlamaCPPURL,
+ // zero cost
+ Cost: ModelCost{},
+ }
+
+ FireworksDeepseekV3 = Model{
+ UserName: "fireworks-deepseek-v3",
+ ModelName: "accounts/fireworks/models/deepseek-v3-0324",
+ URL: FireworksURL,
+ Cost: ModelCost{Input: 90, Output: 90}, // not entirely sure about this, they don't list pricing anywhere convenient
+ APIKeyEnv: FireworksAPIKeyEnv,
+ }
+)
+
+// Service provides chat completions.
+// Fields should not be altered concurrently with calling any method on Service.
+type Service struct {
+ HTTPC *http.Client // defaults to http.DefaultClient if nil
+ APIKey string // optional, if not set will try to load from env var
+ Model Model // defaults to DefaultModel if zero value
+ MaxTokens int // defaults to DefaultMaxTokens if zero
+ Org string // optional - organization ID
+}
+
+var _ llm.Service = (*Service)(nil)
+
+// ModelsRegistry is a registry of all known models with their user-friendly names.
+var ModelsRegistry = []Model{
+ GPT41,
+ Gemini25Flash,
+ Gemini25Pro,
+ TogetherDeepseekV3,
+ TogetherLlama4Maverick,
+ TogetherLlama3_3_70B,
+ TogetherMistralSmall,
+ LlamaCPP,
+ FireworksDeepseekV3,
+}
+
+// ListModels returns a list of all available models with their user-friendly names.
+func ListModels() []string {
+ var names []string
+ for _, model := range ModelsRegistry {
+ if model.UserName != "" {
+ names = append(names, model.UserName)
+ }
+ }
+ return names
+}
+
+// ModelByUserName returns a model by its user-friendly name.
+// Returns nil if no model with the given name is found.
+func ModelByUserName(name string) *Model {
+ for _, model := range ModelsRegistry {
+ if model.UserName == name {
+ return &model
+ }
+ }
+ return nil
+}
+
+var (
+ fromLLMRole = map[llm.MessageRole]string{
+ llm.MessageRoleAssistant: "assistant",
+ llm.MessageRoleUser: "user",
+ }
+ fromLLMContentType = map[llm.ContentType]string{
+ llm.ContentTypeText: "text",
+ llm.ContentTypeToolUse: "function", // OpenAI uses function instead of tool_call
+ llm.ContentTypeToolResult: "tool_result",
+ llm.ContentTypeThinking: "text", // Map thinking to text since OpenAI doesn't have thinking
+ llm.ContentTypeRedactedThinking: "text", // Map redacted_thinking to text
+ }
+ fromLLMToolChoiceType = map[llm.ToolChoiceType]string{
+ llm.ToolChoiceTypeAuto: "auto",
+ llm.ToolChoiceTypeAny: "any", // NOTE(review): OpenAI's tool_choice equivalent is "required", not "any" — verify against the API
+ llm.ToolChoiceTypeNone: "none",
+ llm.ToolChoiceTypeTool: "function", // OpenAI uses "function" instead of "tool"
+ }
+ toLLMRole = map[string]llm.MessageRole{
+ "assistant": llm.MessageRoleAssistant,
+ "user": llm.MessageRoleUser,
+ }
+ toLLMStopReason = map[string]llm.StopReason{
+ "stop": llm.StopReasonStopSequence,
+ "length": llm.StopReasonMaxTokens,
+ "tool_calls": llm.StopReasonToolUse,
+ "function_call": llm.StopReasonToolUse, // Map both to ToolUse
+ "content_filter": llm.StopReasonStopSequence, // No direct equivalent
+ }
+)
+
+// fromLLMContent converts llm.Content to the format expected by OpenAI.
+func fromLLMContent(c llm.Content) (string, []openai.ToolCall) {
+ switch c.Type {
+ case llm.ContentTypeText:
+ return c.Text, nil
+ case llm.ContentTypeToolUse:
+ // For OpenAI, tool use is sent as a null content with tool_calls in the message
+ return "", []openai.ToolCall{
+ {
+ Type: openai.ToolTypeFunction,
+ ID: c.ID, // Use the content ID if provided
+ Function: openai.FunctionCall{
+ Name: c.ToolName,
+ Arguments: string(c.ToolInput),
+ },
+ },
+ }
+ case llm.ContentTypeToolResult:
+ // Tool results in OpenAI are sent as a separate message with tool_call_id
+ return c.ToolResult, nil
+ default:
+ // For thinking or other types, convert to text
+ return c.Text, nil
+ }
+}
+
+// fromLLMMessage converts llm.Message to OpenAI ChatCompletionMessage format
+func fromLLMMessage(msg llm.Message) []openai.ChatCompletionMessage {
+ // For OpenAI, we need to handle tool results differently than regular messages
+ // Each tool result becomes its own message with role="tool"
+
+ var messages []openai.ChatCompletionMessage
+
+ // Check if this is a regular message or contains tool results
+ var regularContent []llm.Content
+ var toolResults []llm.Content
+
+ for _, c := range msg.Content {
+ if c.Type == llm.ContentTypeToolResult {
+ toolResults = append(toolResults, c)
+ } else {
+ regularContent = append(regularContent, c)
+ }
+ }
+
+ // Process tool results first: each becomes its own message with role "tool".
+ for _, tr := range toolResults {
+ m := openai.ChatCompletionMessage{
+ Role: "tool",
+ Content: cmp.Or(tr.ToolResult, " "), // TODO: remove omitempty upstream
+ ToolCallID: tr.ToolUseID,
+ }
+ messages = append(messages, m)
+ }
+ // Process regular content second
+ if len(regularContent) > 0 {
+ m := openai.ChatCompletionMessage{
+ Role: fromLLMRole[msg.Role],
+ }
+
+ // For assistant messages that contain tool calls
+ var toolCalls []openai.ToolCall
+ var textContent string
+
+ for _, c := range regularContent {
+ content, tools := fromLLMContent(c)
+ if len(tools) > 0 {
+ toolCalls = append(toolCalls, tools...)
+ } else if content != "" {
+ if textContent != "" {
+ textContent += "\n"
+ }
+ textContent += content
+ }
+ }
+
+ m.Content = textContent
+ m.ToolCalls = toolCalls
+
+ messages = append(messages, m)
+ }
+
+ return messages
+}
+
+// fromLLMToolChoice converts llm.ToolChoice to the format expected by OpenAI.
+func fromLLMToolChoice(tc *llm.ToolChoice) any {
+ if tc == nil {
+ return nil
+ }
+
+ if tc.Type == llm.ToolChoiceTypeTool && tc.Name != "" {
+ return openai.ToolChoice{
+ Type: openai.ToolTypeFunction,
+ Function: openai.ToolFunction{
+ Name: tc.Name,
+ },
+ }
+ }
+
+ // For non-specific tool choice, just use the string
+ return fromLLMToolChoiceType[tc.Type]
+}
+
+// fromLLMTool converts llm.Tool to the format expected by OpenAI.
+func fromLLMTool(t *llm.Tool) openai.Tool {
+ return openai.Tool{
+ Type: openai.ToolTypeFunction,
+ Function: &openai.FunctionDefinition{
+ Name: t.Name,
+ Description: t.Description,
+ Parameters: t.InputSchema,
+ },
+ }
+}
+
+// fromLLMSystem converts llm.SystemContent to an OpenAI system message.
+func fromLLMSystem(systemContent []llm.SystemContent) []openai.ChatCompletionMessage {
+ if len(systemContent) == 0 {
+ return nil
+ }
+
+ // Combine all system content into a single system message
+ var systemText string
+ for i, content := range systemContent {
+ if i > 0 && systemText != "" && content.Text != "" {
+ systemText += "\n"
+ }
+ systemText += content.Text
+ }
+
+ if systemText == "" {
+ return nil
+ }
+
+ return []openai.ChatCompletionMessage{
+ {
+ Role: "system",
+ Content: systemText,
+ },
+ }
+}
+
+// toRawLLMContent converts a raw content string from OpenAI to llm.Content.
+func toRawLLMContent(content string) llm.Content {
+ return llm.Content{
+ Type: llm.ContentTypeText,
+ Text: content,
+ }
+}
+
+// toToolCallLLMContent converts a tool call from OpenAI to llm.Content.
+func toToolCallLLMContent(toolCall openai.ToolCall) llm.Content {
+ // Generate a content ID if needed
+ id := toolCall.ID
+ if id == "" {
+ // Create a deterministic ID based on the function name if no ID is provided
+ id = "tc_" + toolCall.Function.Name
+ }
+
+ return llm.Content{
+ ID: id,
+ Type: llm.ContentTypeToolUse,
+ ToolName: toolCall.Function.Name,
+ ToolInput: json.RawMessage(toolCall.Function.Arguments),
+ }
+}
+
+// toToolResultLLMContent converts an OpenAI "tool" role message into an
+// llm.Content tool result. OpenAI carries no explicit error flag on tool
+// results, so ToolError is always left false.
+func toToolResultLLMContent(msg openai.ChatCompletionMessage) llm.Content {
+	result := llm.Content{Type: llm.ContentTypeToolResult}
+	result.ToolUseID = msg.ToolCallID
+	result.ToolResult = msg.Content
+	return result
+}
+
+// toLLMContents converts message content from OpenAI to []llm.Content.
+func toLLMContents(msg openai.ChatCompletionMessage) []llm.Content {
+ var contents []llm.Content
+
+ // If this is a tool response, handle it separately
+ if msg.Role == "tool" && msg.ToolCallID != "" {
+ return []llm.Content{toToolResultLLMContent(msg)}
+ }
+
+ // If there's text content, add it
+ if msg.Content != "" {
+ contents = append(contents, toRawLLMContent(msg.Content))
+ }
+
+ // If there are tool calls, add them
+ for _, tc := range msg.ToolCalls {
+ contents = append(contents, toToolCallLLMContent(tc))
+ }
+
+ // If empty, add an empty text content
+ if len(contents) == 0 {
+ contents = append(contents, llm.Content{
+ Type: llm.ContentTypeText,
+ Text: "",
+ })
+ }
+
+ return contents
+}
+
+// toLLMUsage converts OpenAI usage accounting into llm.Usage and prices it.
+//
+// OpenAI reports prompt_tokens as the TOTAL input (cached + uncached), with
+// the cached portion broken out separately in PromptTokensDetails. The
+// llm.Usage fields follow Anthropic's split, where cache reads and fresh
+// input are disjoint buckets, so the uncached portion is total minus cached.
+// (Storing the full prompt count in CacheCreationInputTokens, as before,
+// double-billed cached tokens in calculateCostFromTokens: once at the full
+// input rate and again at the cached-input rate.)
+//
+// The model argument is currently unused; pricing comes from s.Model.
+func (s *Service) toLLMUsage(model string, au openai.Usage) llm.Usage {
+	total := uint64(au.PromptTokens)
+	var cached uint64
+	if au.PromptTokensDetails != nil {
+		cached = uint64(au.PromptTokensDetails.CachedTokens)
+	}
+	// Guard against a malformed response reporting more cached tokens
+	// than total prompt tokens (would underflow the unsigned subtraction).
+	uncached := total - min(cached, total)
+	u := llm.Usage{
+		InputTokens:              total,
+		CacheReadInputTokens:     cached,
+		CacheCreationInputTokens: uncached,
+		OutputTokens:             uint64(au.CompletionTokens),
+	}
+	u.CostUSD = s.calculateCostFromTokens(u)
+	return u
+}
+
+// toLLMResponse maps an OpenAI chat completion response onto llm.Response.
+// Only the first choice is consulted; a response with no choices yields an
+// assistant-role response carrying usage but no content.
+func (s *Service) toLLMResponse(r *openai.ChatCompletionResponse) *llm.Response {
+	resp := &llm.Response{
+		ID:    r.ID,
+		Model: r.Model,
+		Role:  llm.MessageRoleAssistant,
+		Usage: s.toLLMUsage(r.Model, r.Usage),
+	}
+	if len(r.Choices) == 0 {
+		return resp
+	}
+
+	first := r.Choices[0]
+	resp.Role = toRoleFromString(first.Message.Role)
+	resp.Content = toLLMContents(first.Message)
+	resp.StopReason = toStopReason(string(first.FinishReason))
+	return resp
+}
+
+// toRoleFromString maps an OpenAI role string onto llm.MessageRole. The
+// OpenAI-only roles (tool, system, function) collapse to assistant so the
+// conversation stays two-sided; anything unrecognized defaults to user.
+func toRoleFromString(role string) llm.MessageRole {
+	switch role {
+	case "tool", "system", "function":
+		return llm.MessageRoleAssistant
+	}
+	if mapped, ok := toLLMRole[role]; ok {
+		return mapped
+	}
+	return llm.MessageRoleUser
+}
+
+// toStopReason maps an OpenAI finish-reason string onto llm.StopReason,
+// falling back to stop_sequence for values not present in toLLMStopReason.
+func toStopReason(reason string) llm.StopReason {
+	sr, ok := toLLMStopReason[reason]
+	if !ok {
+		return llm.StopReasonStopSequence
+	}
+	return sr
+}
+
+// calculateCostFromTokens prices the given usage in US dollars using the
+// model's per-million-token rates (expressed in cents; see the oai_test.go
+// tables). Uncached input, cached input, and output tokens are each billed
+// at their own rate; the combined InputTokens total is deliberately not
+// billed, to avoid double counting.
+func (s *Service) calculateCostFromTokens(u llm.Usage) float64 {
+	rates := s.Model.Cost
+	// Accumulate in "token-count times cents-per-million" units so the
+	// arithmetic stays integral until the final division.
+	scaled := u.CacheCreationInputTokens*rates.Input +
+		u.CacheReadInputTokens*rates.CachedInput +
+		u.OutputTokens*rates.Output
+	cents := float64(scaled) / 1_000_000
+	return cents / 100.0
+}
+
+// Do sends a chat completion request to an OpenAI-compatible endpoint and
+// converts the reply to an llm.Response.
+//
+// Server errors (5xx) and rate limits (429) are retried indefinitely with
+// jittered backoff; all other errors are returned to the caller. Unlike the
+// previous time.Sleep-based backoff, waiting between attempts now respects
+// ctx cancellation, so a canceled request returns promptly.
+func (s *Service) Do(ctx context.Context, ir *llm.Request) (*llm.Response, error) {
+	httpc := cmp.Or(s.HTTPC, http.DefaultClient)
+	model := cmp.Or(s.Model, DefaultModel)
+
+	// TODO: do this one during Service setup? maybe with a constructor instead?
+	config := openai.DefaultConfig(s.APIKey)
+	if model.URL != "" {
+		config.BaseURL = model.URL
+	}
+	if s.Org != "" {
+		config.OrgID = s.Org
+	}
+	config.HTTPClient = httpc
+	client := openai.NewClientWithConfig(config)
+
+	// System messages first, then the conversation proper.
+	var allMessages []openai.ChatCompletionMessage
+	allMessages = append(allMessages, fromLLMSystem(ir.System)...)
+	for _, msg := range ir.Messages {
+		allMessages = append(allMessages, fromLLMMessage(msg)...)
+	}
+
+	// Convert tool definitions.
+	var tools []openai.Tool
+	for _, t := range ir.Tools {
+		tools = append(tools, fromLLMTool(t))
+	}
+
+	req := openai.ChatCompletionRequest{
+		Model:      model.ModelName,
+		Messages:   allMessages,
+		MaxTokens:  cmp.Or(s.MaxTokens, DefaultMaxTokens),
+		Tools:      tools,
+		ToolChoice: fromLLMToolChoice(ir.ToolChoice), // TODO: make fromLLMToolChoice return an error when a perfect translation is not possible
+	}
+
+	backoff := []time.Duration{1 * time.Second, 2 * time.Second, 5 * time.Second}
+	for attempts := 0; ; attempts++ {
+		resp, err := client.CreateChatCompletion(ctx, req)
+		if err == nil {
+			return s.toLLMResponse(&resp), nil
+		}
+
+		var apiErr *openai.APIError
+		if !errors.As(err, &apiErr) {
+			// Not an OpenAI API error (e.g. transport failure); don't retry.
+			return nil, err
+		}
+
+		var sleep time.Duration
+		switch {
+		case apiErr.HTTPStatusCode >= 500:
+			// Transient server error: retry with jittered backoff.
+			sleep = backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
+			slog.WarnContext(ctx, "openai_request_failed", "error", apiErr.Error(), "status_code", apiErr.HTTPStatusCode, "sleep", sleep)
+		case apiErr.HTTPStatusCode == 429:
+			// Rate limited: back off substantially longer.
+			sleep = 20*time.Second + backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
+			slog.WarnContext(ctx, "openai_request_rate_limited", "error", apiErr.Error(), "sleep", sleep)
+		default:
+			// Client-side or other non-retryable error.
+			return nil, fmt.Errorf("OpenAI API error: %w", err)
+		}
+
+		// Wait out the backoff, but abandon the request as soon as the
+		// context is canceled (time.Sleep would have ignored cancellation).
+		select {
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		case <-time.After(sleep):
+		}
+	}
+}
diff --git a/llm/oai/oai_test.go b/llm/oai/oai_test.go
new file mode 100644
index 0000000..7bea552
--- /dev/null
+++ b/llm/oai/oai_test.go
@@ -0,0 +1,96 @@
+package oai
+
+import (
+ "math"
+ "testing"
+
+ "sketch.dev/llm"
+)
+
+// TestCalculateCostFromTokens tests the calculateCostFromTokens method
+func TestCalculateCostFromTokens(t *testing.T) {
+ tests := []struct {
+ name string
+ model Model
+ cacheCreationTokens uint64
+ cacheReadTokens uint64
+ outputTokens uint64
+ want float64
+ }{
+ {
+ name: "Zero tokens",
+ model: GPT41,
+ cacheCreationTokens: 0,
+ cacheReadTokens: 0,
+ outputTokens: 0,
+ want: 0,
+ },
+ {
+ name: "1000 input tokens, 500 output tokens",
+ model: GPT41,
+ cacheCreationTokens: 1000,
+ cacheReadTokens: 0,
+ outputTokens: 500,
+ // GPT41: Input: 200 per million, Output: 800 per million
+ // (1000 * 200 + 500 * 800) / 1_000_000 / 100 = 0.006
+ want: 0.006,
+ },
+ {
+ name: "10000 input tokens, 5000 output tokens",
+ model: GPT41,
+ cacheCreationTokens: 10000,
+ cacheReadTokens: 0,
+ outputTokens: 5000,
+ // (10000 * 200 + 5000 * 800) / 1_000_000 / 100 = 0.06
+ want: 0.06,
+ },
+ {
+ name: "1000 input tokens, 500 output tokens Gemini",
+ model: Gemini25Flash,
+ cacheCreationTokens: 1000,
+ cacheReadTokens: 0,
+ outputTokens: 500,
+ // Gemini25Flash: Input: 15 per million, Output: 60 per million
+ // (1000 * 15 + 500 * 60) / 1_000_000 / 100 = 0.00045
+ want: 0.00045,
+ },
+ {
+ name: "With cache read tokens",
+ model: GPT41,
+ cacheCreationTokens: 500,
+ cacheReadTokens: 500, // 500 tokens from cache
+ outputTokens: 500,
+ // (500 * 200 + 500 * 50 + 500 * 800) / 1_000_000 / 100 = 0.00525
+ want: 0.00525,
+ },
+ {
+ name: "With all token types",
+ model: GPT41,
+ cacheCreationTokens: 1000,
+ cacheReadTokens: 1000,
+ outputTokens: 1000,
+ // (1000 * 200 + 1000 * 50 + 1000 * 800) / 1_000_000 / 100 = 0.0105
+ want: 0.0105,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Create a service with the test model
+ svc := &Service{Model: tt.model}
+
+ // Create a usage object
+ usage := llm.Usage{
+ CacheCreationInputTokens: tt.cacheCreationTokens,
+ CacheReadInputTokens: tt.cacheReadTokens,
+ OutputTokens: tt.outputTokens,
+ }
+
+ totalCost := svc.calculateCostFromTokens(usage)
+ if math.Abs(totalCost-tt.want) > 0.0001 {
+ t.Errorf("calculateCostFromTokens(%s, cache_creation=%d, cache_read=%d, output=%d) = %v, want %v",
+ tt.model.ModelName, tt.cacheCreationTokens, tt.cacheReadTokens, tt.outputTokens, totalCost, tt.want)
+ }
+ })
+ }
+}