all: support openai-compatible models
The support is rather minimal at this point:
Only hard-coded models, only -unsafe, only -skabandaddr="".
The "shared" LLM package is strongly Claude-flavored.
We can fix all of this and more over time, if we are inspired to.
(Maybe we'll switch to https://github.com/maruel/genai?)
The goal for now is to get the rough structure in place.
I've rebased and rebuilt this more times than I care to remember.
diff --git a/ant/ant_test.go b/ant/ant_test.go
deleted file mode 100644
index fcce0cd..0000000
--- a/ant/ant_test.go
+++ /dev/null
@@ -1,222 +0,0 @@
-package ant
-
-import (
- "cmp"
- "context"
- "math"
- "net/http"
- "os"
- "strings"
- "testing"
-
- "sketch.dev/httprr"
-)
-
-func TestBasicConvo(t *testing.T) {
- ctx := context.Background()
- rr, err := httprr.Open("testdata/basic_convo.httprr", http.DefaultTransport)
- if err != nil {
- t.Fatal(err)
- }
- rr.ScrubReq(func(req *http.Request) error {
- req.Header.Del("x-api-key")
- return nil
- })
-
- apiKey := cmp.Or(os.Getenv("OUTER_SKETCH_ANTHROPIC_API_KEY"), os.Getenv("ANTHROPIC_API_KEY"))
- convo := NewConvo(ctx, apiKey)
- convo.HTTPC = rr.Client()
-
- const name = "Cornelius"
- res, err := convo.SendUserTextMessage("Hi, my name is " + name)
- if err != nil {
- t.Fatal(err)
- }
- for _, part := range res.Content {
- t.Logf("%s", part.Text)
- }
- res, err = convo.SendUserTextMessage("What is my name?")
- if err != nil {
- t.Fatal(err)
- }
- got := ""
- for _, part := range res.Content {
- got += part.Text
- }
- if !strings.Contains(got, name) {
- t.Errorf("model does not know the given name %s: %q", name, got)
- }
-}
-
-// TestCalculateCostFromTokens tests the calculateCostFromTokens function
-func TestCalculateCostFromTokens(t *testing.T) {
- tests := []struct {
- name string
- model string
- inputTokens uint64
- outputTokens uint64
- cacheReadInputTokens uint64
- cacheCreationInputTokens uint64
- want float64
- }{
- {
- name: "Zero tokens",
- model: Claude37Sonnet,
- inputTokens: 0,
- outputTokens: 0,
- cacheReadInputTokens: 0,
- cacheCreationInputTokens: 0,
- want: 0,
- },
- {
- name: "1000 input tokens, 500 output tokens",
- model: Claude37Sonnet,
- inputTokens: 1000,
- outputTokens: 500,
- cacheReadInputTokens: 0,
- cacheCreationInputTokens: 0,
- want: 0.0105,
- },
- {
- name: "10000 input tokens, 5000 output tokens",
- model: Claude37Sonnet,
- inputTokens: 10000,
- outputTokens: 5000,
- cacheReadInputTokens: 0,
- cacheCreationInputTokens: 0,
- want: 0.105,
- },
- {
- name: "With cache read tokens",
- model: Claude37Sonnet,
- inputTokens: 1000,
- outputTokens: 500,
- cacheReadInputTokens: 2000,
- cacheCreationInputTokens: 0,
- want: 0.0111,
- },
- {
- name: "With cache creation tokens",
- model: Claude37Sonnet,
- inputTokens: 1000,
- outputTokens: 500,
- cacheReadInputTokens: 0,
- cacheCreationInputTokens: 1500,
- want: 0.016125,
- },
- {
- name: "With all token types",
- model: Claude37Sonnet,
- inputTokens: 1000,
- outputTokens: 500,
- cacheReadInputTokens: 2000,
- cacheCreationInputTokens: 1500,
- want: 0.016725,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- usage := Usage{
- InputTokens: tt.inputTokens,
- OutputTokens: tt.outputTokens,
- CacheReadInputTokens: tt.cacheReadInputTokens,
- CacheCreationInputTokens: tt.cacheCreationInputTokens,
- }
- mr := MessageResponse{
- Model: tt.model,
- Usage: usage,
- }
- totalCost := mr.TotalDollars()
- if math.Abs(totalCost-tt.want) > 0.0001 {
- t.Errorf("totalCost = %v, want %v", totalCost, tt.want)
- }
- })
- }
-}
-
-// TestCancelToolUse tests the CancelToolUse function of the Convo struct
-func TestCancelToolUse(t *testing.T) {
- tests := []struct {
- name string
- setupToolUse bool
- toolUseID string
- cancelErr error
- expectError bool
- expectCancel bool
- }{
- {
- name: "Cancel existing tool use",
- setupToolUse: true,
- toolUseID: "tool123",
- cancelErr: nil,
- expectError: false,
- expectCancel: true,
- },
- {
- name: "Cancel existing tool use with error",
- setupToolUse: true,
- toolUseID: "tool456",
- cancelErr: context.Canceled,
- expectError: false,
- expectCancel: true,
- },
- {
- name: "Cancel non-existent tool use",
- setupToolUse: false,
- toolUseID: "tool789",
- cancelErr: nil,
- expectError: true,
- expectCancel: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- convo := NewConvo(context.Background(), "")
-
- var cancelCalled bool
- var cancelledWithErr error
-
- if tt.setupToolUse {
- // Setup a mock cancel function to track calls
- mockCancel := func(err error) {
- cancelCalled = true
- cancelledWithErr = err
- }
-
- convo.muToolUseCancel.Lock()
- convo.toolUseCancel[tt.toolUseID] = mockCancel
- convo.muToolUseCancel.Unlock()
- }
-
- err := convo.CancelToolUse(tt.toolUseID, tt.cancelErr)
-
- // Check if we got the expected error state
- if (err != nil) != tt.expectError {
- t.Errorf("CancelToolUse() error = %v, expectError %v", err, tt.expectError)
- }
-
- // Check if the cancel function was called as expected
- if cancelCalled != tt.expectCancel {
- t.Errorf("Cancel function called = %v, expectCancel %v", cancelCalled, tt.expectCancel)
- }
-
- // If we expected the cancel to be called, verify it was called with the right error
- if tt.expectCancel && cancelledWithErr != tt.cancelErr {
- t.Errorf("Cancel function called with error = %v, expected %v", cancelledWithErr, tt.cancelErr)
- }
-
- // Verify the toolUseID was removed from the map if it was initially added
- if tt.setupToolUse {
- convo.muToolUseCancel.Lock()
- _, exists := convo.toolUseCancel[tt.toolUseID]
- convo.muToolUseCancel.Unlock()
-
- if exists {
- t.Errorf("toolUseID %s still exists in the map after cancellation", tt.toolUseID)
- }
- }
- })
- }
-}
diff --git a/claudetool/bash.go b/claudetool/bash.go
index b3b8b03..882dddf 100644
--- a/claudetool/bash.go
+++ b/claudetool/bash.go
@@ -13,8 +13,8 @@
"syscall"
"time"
- "sketch.dev/ant"
"sketch.dev/claudetool/bashkit"
+ "sketch.dev/llm"
)
// PermissionCallback is a function type for checking if a command is allowed to run
@@ -27,15 +27,15 @@
}
// NewBashTool creates a new Bash tool with optional permission callback
-func NewBashTool(checkPermission PermissionCallback) *ant.Tool {
+func NewBashTool(checkPermission PermissionCallback) *llm.Tool {
tool := &BashTool{
CheckPermission: checkPermission,
}
- return &ant.Tool{
+ return &llm.Tool{
Name: bashName,
Description: strings.TrimSpace(bashDescription),
- InputSchema: ant.MustSchema(bashInputSchema),
+ InputSchema: llm.MustSchema(bashInputSchema),
Run: tool.Run,
}
}
diff --git a/claudetool/differential.go b/claudetool/differential.go
index a6b3413..d76209a 100644
--- a/claudetool/differential.go
+++ b/claudetool/differential.go
@@ -20,18 +20,18 @@
"time"
"golang.org/x/tools/go/packages"
- "sketch.dev/ant"
+ "sketch.dev/llm"
)
// This file does differential quality analysis of a commit relative to a base commit.
// Tool returns a tool spec for a CodeReview tool backed by r.
-func (r *CodeReviewer) Tool() *ant.Tool {
- spec := &ant.Tool{
+func (r *CodeReviewer) Tool() *llm.Tool {
+ spec := &llm.Tool{
Name: "codereview",
Description: `Run an automated code review.`,
// If you modify this, update the termui template for prettier rendering.
- InputSchema: ant.MustSchema(`{"type": "object", "properties": {}}`),
+ InputSchema: llm.MustSchema(`{"type": "object", "properties": {}}`),
Run: r.Run,
}
return spec
@@ -663,7 +663,7 @@
// testStatus represents the status of a test in a given commit
type testStatus int
-//go:generate go tool stringer -type=testStatus -trimprefix=testStatus
+//go:generate go tool golang.org/x/tools/cmd/stringer -type=testStatus -trimprefix=testStatus
const (
testStatusUnknown testStatus = iota
testStatusPass
diff --git a/claudetool/edit.go b/claudetool/edit.go
index df83139..50084b7 100644
--- a/claudetool/edit.go
+++ b/claudetool/edit.go
@@ -18,7 +18,7 @@
"path/filepath"
"strings"
- "sketch.dev/ant"
+ "sketch.dev/llm"
)
// Constants for the AnthropicEditTool
@@ -59,7 +59,7 @@
var fileHistory = make(map[string][]string)
// AnthropicEditTool is a tool for viewing, creating, and editing files
-var AnthropicEditTool = &ant.Tool{
+var AnthropicEditTool = &llm.Tool{
// Note that Type is model-dependent, and would be different for Claude 3.5, for example.
Type: "text_editor_20250124",
Name: editName,
diff --git a/claudetool/keyword.go b/claudetool/keyword.go
index a99e3cd..8c693be 100644
--- a/claudetool/keyword.go
+++ b/claudetool/keyword.go
@@ -9,16 +9,17 @@
"os/exec"
"strings"
- "sketch.dev/ant"
+ "sketch.dev/llm"
+ "sketch.dev/llm/conversation"
)
// The Keyword tool provides keyword search.
// TODO: use an embedding model + re-ranker or otherwise do something nicer than this kludge.
// TODO: if we can get this fast enough, do it on the fly while the user is typing their prompt.
-var Keyword = &ant.Tool{
+var Keyword = &llm.Tool{
Name: keywordName,
Description: keywordDescription,
- InputSchema: ant.MustSchema(keywordInputSchema),
+ InputSchema: llm.MustSchema(keywordInputSchema),
Run: keywordRun,
}
@@ -122,16 +123,16 @@
keep = keep[:len(keep)-1]
}
- info := ant.ToolCallInfoFromContext(ctx)
+ info := conversation.ToolCallInfoFromContext(ctx)
convo := info.Convo.SubConvo()
convo.SystemPrompt = strings.TrimSpace(keywordSystemPrompt)
- initialMessage := ant.Message{
- Role: ant.MessageRoleUser,
- Content: []ant.Content{
- ant.StringContent("<pwd>\n" + wd + "\n</pwd>"),
- ant.StringContent("<ripgrep_results>\n" + out + "\n</ripgrep_results>"),
- ant.StringContent("<query>\n" + input.Query + "\n</query>"),
+ initialMessage := llm.Message{
+ Role: llm.MessageRoleUser,
+ Content: []llm.Content{
+ llm.StringContent("<pwd>\n" + wd + "\n</pwd>"),
+ llm.StringContent("<ripgrep_results>\n" + out + "\n</ripgrep_results>"),
+ llm.StringContent("<query>\n" + input.Query + "\n</query>"),
},
}
diff --git a/claudetool/patch.go b/claudetool/patch.go
index 9254319..0886e66 100644
--- a/claudetool/patch.go
+++ b/claudetool/patch.go
@@ -13,16 +13,16 @@
"path/filepath"
"strings"
- "sketch.dev/ant"
"sketch.dev/claudetool/editbuf"
"sketch.dev/claudetool/patchkit"
+ "sketch.dev/llm"
)
// Patch is a tool for precise text modifications in files.
-var Patch = &ant.Tool{
+var Patch = &llm.Tool{
Name: PatchName,
Description: strings.TrimSpace(PatchDescription),
- InputSchema: ant.MustSchema(PatchInputSchema),
+ InputSchema: llm.MustSchema(PatchInputSchema),
Run: PatchRun,
}
diff --git a/claudetool/think.go b/claudetool/think.go
index 293cc0b..69aac3c 100644
--- a/claudetool/think.go
+++ b/claudetool/think.go
@@ -4,14 +4,14 @@
"context"
"encoding/json"
- "sketch.dev/ant"
+ "sketch.dev/llm"
)
// The Think tool provides space to think.
-var Think = &ant.Tool{
+var Think = &llm.Tool{
Name: thinkName,
Description: thinkDescription,
- InputSchema: ant.MustSchema(thinkInputSchema),
+ InputSchema: llm.MustSchema(thinkInputSchema),
Run: thinkRun,
}
diff --git a/cmd/go2ts/go2ts.go b/cmd/go2ts/go2ts.go
index f7d773f..0a175a9 100644
--- a/cmd/go2ts/go2ts.go
+++ b/cmd/go2ts/go2ts.go
@@ -12,7 +12,7 @@
"os"
"go.skia.org/infra/go/go2ts"
- "sketch.dev/ant"
+ "sketch.dev/llm"
"sketch.dev/loop"
"sketch.dev/loop/server"
)
@@ -54,7 +54,7 @@
loop.AgentMessage{},
loop.GitCommit{},
loop.ToolCall{},
- ant.Usage{},
+ llm.Usage{},
server.State{},
)
diff --git a/cmd/sketch/main.go b/cmd/sketch/main.go
index e99d354..183aef0 100644
--- a/cmd/sketch/main.go
+++ b/cmd/sketch/main.go
@@ -17,11 +17,15 @@
"strings"
"time"
+ "sketch.dev/llm"
+ "sketch.dev/llm/oai"
+
"github.com/richardlehane/crock32"
- "sketch.dev/ant"
"sketch.dev/browser"
"sketch.dev/dockerimg"
"sketch.dev/httprr"
+ "sketch.dev/llm/ant"
+ "sketch.dev/llm/conversation"
"sketch.dev/loop"
"sketch.dev/loop/server"
"sketch.dev/skabandclient"
@@ -40,10 +44,8 @@
// run is the main entry point that parses flags and dispatches to the appropriate
// execution path based on whether we're running in a container or not.
func run() error {
- // Parse command-line flags
flagArgs := parseCLIFlags()
- // Handle version flag early
if flagArgs.version {
bi, ok := debug.ReadBuildInfo()
if ok {
@@ -52,6 +54,26 @@
return nil
}
+ if flagArgs.listModels {
+ fmt.Println("Available models:")
+ fmt.Println("- claude (default, uses Anthropic service)")
+ for _, name := range oai.ListModels() {
+ note := ""
+ if name != "gpt4.1" {
+ note = " (not recommended)"
+ }
+ fmt.Printf("- %s%s\n", name, note)
+ }
+ return nil
+ }
+
+ // For now, only Claude is supported in container mode.
+ // TODO: finish support--thread through API keys, add server support
+ isClaude := flagArgs.modelName == "claude" || flagArgs.modelName == ""
+ if !isClaude && (!flagArgs.unsafe || flagArgs.skabandAddr != "") {
+ return fmt.Errorf("only -model=claude is supported in safe mode right now, use -unsafe -skaband-addr=''")
+ }
+
// Add a global "session_id" to all logs using this context.
// A "session" is a single full run of the agent.
ctx := skribe.ContextWithAttr(context.Background(), slog.String("session_id", flagArgs.sessionID))
@@ -120,6 +142,8 @@
maxDollars float64
oneShot bool
prompt string
+ modelName string
+ listModels bool
verbose bool
version bool
workingDir string
@@ -152,6 +176,8 @@
flag.Float64Var(&flags.maxDollars, "max-dollars", 5.0, "maximum dollars the agent should spend per turn, 0 to disable limit")
flag.BoolVar(&flags.oneShot, "one-shot", false, "exit after the first turn without termui")
flag.StringVar(&flags.prompt, "prompt", "", "prompt to send to sketch")
+ flag.StringVar(&flags.modelName, "model", "claude", "model to use (e.g. claude, gpt4.1)")
+ flag.BoolVar(&flags.listModels, "list-models", false, "list all available models and exit")
flag.BoolVar(&flags.verbose, "verbose", false, "enable verbose output")
flag.BoolVar(&flags.version, "version", false, "print the version and exit")
flag.StringVar(&flags.workingDir, "C", "", "when set, change to this directory before running")
@@ -318,19 +344,25 @@
client = rr.Client()
}
- // Get current working directory
wd, err := os.Getwd()
if err != nil {
return err
}
- // Create and configure the agent
+ llmService, err := selectLLMService(client, flags.modelName, antURL, apiKey)
+ if err != nil {
+ return fmt.Errorf("failed to initialize LLM service: %w", err)
+ }
+ budget := conversation.Budget{
+ MaxResponses: flags.maxIterations,
+ MaxWallTime: flags.maxWallTime,
+ MaxDollars: flags.maxDollars,
+ }
+
agentConfig := loop.AgentConfig{
Context: ctx,
- AntURL: antURL,
- APIKey: apiKey,
- HTTPC: client,
- Budget: ant.Budget{MaxResponses: flags.maxIterations, MaxWallTime: flags.maxWallTime, MaxDollars: flags.maxDollars},
+ Service: llmService,
+ Budget: budget,
GitUsername: flags.gitUsername,
GitEmail: flags.gitEmail,
SessionID: flags.sessionID,
@@ -507,3 +539,37 @@
}
return strings.TrimSpace(string(out))
}
+
+// selectLLMService creates an LLM service based on the specified model name.
+// If modelName is empty or "claude", it uses the Anthropic service.
+// Otherwise, it tries to use the OpenAI service with the specified model.
+// Returns an error if the model name is not recognized or if required configuration is missing.
+func selectLLMService(client *http.Client, modelName string, antURL, apiKey string) (llm.Service, error) {
+ if modelName == "" || modelName == "claude" {
+ if apiKey == "" {
+ return nil, fmt.Errorf("missing ANTHROPIC_API_KEY")
+ }
+ return &ant.Service{
+ HTTPC: client,
+ URL: antURL,
+ APIKey: apiKey,
+ }, nil
+ }
+
+ model := oai.ModelByUserName(modelName)
+ if model == nil {
+ return nil, fmt.Errorf("unknown model '%s', use -list-models to see available models", modelName)
+ }
+
+ // Verify we have an API key, if necessary.
+ apiKey = os.Getenv(model.APIKeyEnv)
+ if model.APIKeyEnv != "" && apiKey == "" {
+ return nil, fmt.Errorf("missing API key for %s model, set %s environment variable", model.UserName, model.APIKeyEnv)
+ }
+
+ return &oai.Service{
+ HTTPC: client,
+ Model: *model,
+ APIKey: apiKey,
+ }, nil
+}
diff --git a/dockerimg/createdockerfile.go b/dockerimg/createdockerfile.go
index 12e876a..75ece83 100644
--- a/dockerimg/createdockerfile.go
+++ b/dockerimg/createdockerfile.go
@@ -15,7 +15,8 @@
"strings"
"text/template"
- "sketch.dev/ant"
+ "sketch.dev/llm"
+ "sketch.dev/llm/conversation"
)
func hashInitFiles(initFiles map[string]string) string {
@@ -166,7 +167,7 @@
// It expects the relevant initFiles to have been provided.
// If the sketch binary is being executed in a sub-directory of the repository,
// the relative path is provided on subPathWorkingDir.
-func createDockerfile(ctx context.Context, httpc *http.Client, antURL, antAPIKey string, initFiles map[string]string, subPathWorkingDir string) (string, error) {
+func createDockerfile(ctx context.Context, srv llm.Service, initFiles map[string]string, subPathWorkingDir string) (string, error) {
if subPathWorkingDir == "." {
subPathWorkingDir = ""
} else if subPathWorkingDir != "" && subPathWorkingDir[0] != '/' {
@@ -188,18 +189,14 @@
toolCalled = true
return "OK", nil
}
- convo := ant.NewConvo(ctx, antAPIKey)
- if httpc != nil {
- convo.HTTPC = httpc
- }
- if antURL != "" {
- convo.URL = antURL
- }
- convo.Tools = []*ant.Tool{{
+
+ convo := conversation.New(ctx, srv)
+
+ convo.Tools = []*llm.Tool{{
Name: "dockerfile",
Description: "Helps define a Dockerfile that sets up a dev environment for this project.",
Run: runDockerfile,
- InputSchema: ant.MustSchema(`{
+ InputSchema: llm.MustSchema(`{
"type": "object",
"required": ["extra_cmds"],
"properties": {
@@ -223,10 +220,10 @@
// git diff dockerimg/testdata/*.dockerfile
//
// If the dockerfile changes are a strict improvement, commit all the changes.
- msg := ant.Message{
- Role: ant.MessageRoleUser,
- Content: []ant.Content{{
- Type: ant.ContentTypeText,
+ msg := llm.Message{
+ Role: llm.MessageRoleUser,
+ Content: []llm.Content{{
+ Type: llm.ContentTypeText,
Text: `
Call the dockerfile tool to create a Dockerfile.
The parameters to dockerfile fill out the From and ExtraCmds
@@ -250,15 +247,15 @@
}
for _, name := range slices.Sorted(maps.Keys(initFiles)) {
- msg.Content = append(msg.Content, ant.StringContent(fmt.Sprintf("Here is the contents %s:\n<file>\n%s\n</file>\n\n", name, initFiles[name])))
+ msg.Content = append(msg.Content, llm.StringContent(fmt.Sprintf("Here is the contents %s:\n<file>\n%s\n</file>\n\n", name, initFiles[name])))
}
- msg.Content = append(msg.Content, ant.StringContent("Now call the dockerfile tool."))
+ msg.Content = append(msg.Content, llm.StringContent("Now call the dockerfile tool."))
res, err := convo.SendMessage(msg)
if err != nil {
return "", err
}
- if res.StopReason != ant.StopReasonToolUse {
- return "", fmt.Errorf("expected stop reason %q, got %q", ant.StopReasonToolUse, res.StopReason)
+ if res.StopReason != llm.StopReasonToolUse {
+ return "", fmt.Errorf("expected stop reason %q, got %q", llm.StopReasonToolUse, res.StopReason)
}
if _, err := convo.ToolResultContents(context.TODO(), res); err != nil {
return "", err
diff --git a/dockerimg/dockerimg.go b/dockerimg/dockerimg.go
index 292c8b6..1486435 100644
--- a/dockerimg/dockerimg.go
+++ b/dockerimg/dockerimg.go
@@ -21,6 +21,7 @@
"time"
"sketch.dev/browser"
+ "sketch.dev/llm/ant"
"sketch.dev/loop/server"
"sketch.dev/skribe"
"sketch.dev/webui"
@@ -654,7 +655,12 @@
}
start := time.Now()
- dockerfile, err := createDockerfile(ctx, http.DefaultClient, antURL, antAPIKey, initFiles, subPathWorkingDir)
+ srv := &ant.Service{
+ URL: antURL,
+ APIKey: antAPIKey,
+ HTTPC: http.DefaultClient,
+ }
+ dockerfile, err := createDockerfile(ctx, srv, initFiles, subPathWorkingDir)
if err != nil {
return "", fmt.Errorf("create dockerfile: %w", err)
}
diff --git a/dockerimg/dockerimg_test.go b/dockerimg/dockerimg_test.go
index 9e39e9c..7e41742 100644
--- a/dockerimg/dockerimg_test.go
+++ b/dockerimg/dockerimg_test.go
@@ -13,6 +13,7 @@
gcmp "github.com/google/go-cmp/cmp"
"sketch.dev/httprr"
+ "sketch.dev/llm/ant"
)
var flagRewriteWant = flag.Bool("rewritewant", false, "rewrite the dockerfiles we want from the model")
@@ -89,7 +90,11 @@
t.Fatal(err)
}
apiKey := cmp.Or(os.Getenv("OUTER_SKETCH_ANTHROPIC_API_KEY"), os.Getenv("ANTHROPIC_API_KEY"))
- result, err := createDockerfile(ctx, rr.Client(), "", apiKey, initFiles, "")
+ srv := &ant.Service{
+ APIKey: apiKey,
+ HTTPC: rr.Client(),
+ }
+ result, err := createDockerfile(ctx, srv, initFiles, "")
if err != nil {
t.Fatal(err)
}
diff --git a/go.mod b/go.mod
index 85d199c..ccadaf1 100644
--- a/go.mod
+++ b/go.mod
@@ -25,6 +25,7 @@
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/oklog/ulid/v2 v2.1.0 // indirect
+ github.com/sashabaranov/go-openai v1.38.2 // indirect
golang.org/x/mod v0.24.0 // indirect
golang.org/x/sys v0.32.0 // indirect
golang.org/x/text v0.24.0 // indirect
diff --git a/go.sum b/go.sum
index 9f65af8..5e1d734 100644
--- a/go.sum
+++ b/go.sum
@@ -42,6 +42,8 @@
go.skia.org/infra v0.0.0-20250421160028-59e18403fd4a/go.mod h1:itQeLiwIYtXPJJEqdxRpOlS77LNv/quHjkyy+SaXrkw=
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
+github.com/sashabaranov/go-openai v1.38.2 h1:akrssjj+6DY3lWuDwHv6cBvJ8Z+FZDM9XEaaYFt0Auo=
+github.com/sashabaranov/go-openai v1.38.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU=
golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
diff --git a/llm/ant/ant.go b/llm/ant/ant.go
new file mode 100644
index 0000000..dce17f1
--- /dev/null
+++ b/llm/ant/ant.go
@@ -0,0 +1,480 @@
+package ant
+
+import (
+ "bytes"
+ "cmp"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log/slog"
+ "math/rand/v2"
+ "net/http"
+ "strings"
+ "testing"
+ "time"
+
+ "sketch.dev/llm"
+)
+
+const (
+ DefaultModel = Claude37Sonnet
+ // See https://docs.anthropic.com/en/docs/about-claude/models/all-models for
+ // current maximums. There's currently a flag to enable 128k output (output-128k-2025-02-19)
+ DefaultMaxTokens = 8192
+ DefaultURL = "https://api.anthropic.com/v1/messages"
+)
+
+const (
+ Claude35Sonnet = "claude-3-5-sonnet-20241022"
+ Claude35Haiku = "claude-3-5-haiku-20241022"
+ Claude37Sonnet = "claude-3-7-sonnet-20250219"
+)
+
+// Service provides Claude completions.
+// Fields should not be altered concurrently with calling any method on Service.
+type Service struct {
+ HTTPC *http.Client // defaults to http.DefaultClient if nil
+ URL string // defaults to DefaultURL if empty
+ APIKey string // must be non-empty
+ Model string // defaults to DefaultModel if empty
+ MaxTokens int // defaults to DefaultMaxTokens if zero
+}
+
+var _ llm.Service = (*Service)(nil)
+
+type content struct {
+ // TODO: image support?
+ // https://docs.anthropic.com/en/api/messages
+ ID string `json:"id,omitempty"`
+ Type string `json:"type,omitempty"`
+ Text string `json:"text,omitempty"`
+
+ // for thinking
+ Thinking string `json:"thinking,omitempty"`
+ Data string `json:"data,omitempty"` // for redacted_thinking
+ Signature string `json:"signature,omitempty"` // for thinking
+
+ // for tool_use
+ ToolName string `json:"name,omitempty"`
+ ToolInput json.RawMessage `json:"input,omitempty"`
+
+ // for tool_result
+ ToolUseID string `json:"tool_use_id,omitempty"`
+ ToolError bool `json:"is_error,omitempty"`
+ ToolResult string `json:"content,omitempty"`
+
+ // timing information for tool_result; not sent to Claude
+ StartTime *time.Time `json:"-"`
+ EndTime *time.Time `json:"-"`
+
+ CacheControl json.RawMessage `json:"cache_control,omitempty"`
+}
+
+// message represents a message in the conversation.
+type message struct {
+ Role string `json:"role"`
+ Content []content `json:"content"`
+ ToolUse *toolUse `json:"tool_use,omitempty"` // use to control whether/which tool to use
+}
+
+// toolUse represents a tool use in the message content.
+type toolUse struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+}
+
+// tool represents a tool available to Claude.
+type tool struct {
+ Name string `json:"name"`
+ // Type is used by the text editor tool; see
+ // https://docs.anthropic.com/en/docs/build-with-claude/tool-use/text-editor-tool
+ Type string `json:"type,omitempty"`
+ Description string `json:"description,omitempty"`
+ InputSchema json.RawMessage `json:"input_schema,omitempty"`
+}
+
+// usage represents the billing and rate-limit usage.
+type usage struct {
+ InputTokens uint64 `json:"input_tokens"`
+ CacheCreationInputTokens uint64 `json:"cache_creation_input_tokens"`
+ CacheReadInputTokens uint64 `json:"cache_read_input_tokens"`
+ OutputTokens uint64 `json:"output_tokens"`
+ CostUSD float64 `json:"cost_usd"`
+}
+
+func (u *usage) Add(other usage) {
+ u.InputTokens += other.InputTokens
+ u.CacheCreationInputTokens += other.CacheCreationInputTokens
+ u.CacheReadInputTokens += other.CacheReadInputTokens
+ u.OutputTokens += other.OutputTokens
+ u.CostUSD += other.CostUSD
+}
+
+type errorResponse struct {
+ Type string `json:"type"`
+ Message string `json:"message"`
+}
+
+// response represents the response from the message API.
+type response struct {
+ ID string `json:"id"`
+ Type string `json:"type"`
+ Role string `json:"role"`
+ Model string `json:"model"`
+ Content []content `json:"content"`
+ StopReason string `json:"stop_reason"`
+ StopSequence *string `json:"stop_sequence,omitempty"`
+ Usage usage `json:"usage"`
+}
+
+type toolChoice struct {
+ Type string `json:"type"`
+ Name string `json:"name,omitempty"`
+}
+
+// https://docs.anthropic.com/en/api/messages#body-system
+type systemContent struct {
+ Text string `json:"text,omitempty"`
+ Type string `json:"type,omitempty"`
+ CacheControl json.RawMessage `json:"cache_control,omitempty"`
+}
+
+// request represents the request payload for creating a message.
+type request struct {
+ Model string `json:"model"`
+ Messages []message `json:"messages"`
+ ToolChoice *toolChoice `json:"tool_choice,omitempty"`
+ MaxTokens int `json:"max_tokens"`
+ Tools []*tool `json:"tools,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ System []systemContent `json:"system,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ StopSequences []string `json:"stop_sequences,omitempty"`
+
+ TokenEfficientToolUse bool `json:"-"` // DO NOT USE, broken on Anthropic's side as of 2025-02-28
+}
+
+const dumpText = false // debugging toggle to see raw communications with Claude
+
+func mapped[Slice ~[]E, E, T any](s Slice, f func(E) T) []T {
+ out := make([]T, len(s))
+ for i, v := range s {
+ out[i] = f(v)
+ }
+ return out
+}
+
+func inverted[K, V cmp.Ordered](m map[K]V) map[V]K {
+ inv := make(map[V]K)
+ for k, v := range m {
+ if _, ok := inv[v]; ok {
+ panic(fmt.Errorf("inverted map has multiple keys for value %v", v))
+ }
+ inv[v] = k
+ }
+ return inv
+}
+
+var (
+ fromLLMRole = map[llm.MessageRole]string{
+ llm.MessageRoleAssistant: "assistant",
+ llm.MessageRoleUser: "user",
+ }
+ toLLMRole = inverted(fromLLMRole)
+
+ fromLLMContentType = map[llm.ContentType]string{
+ llm.ContentTypeText: "text",
+ llm.ContentTypeThinking: "thinking",
+ llm.ContentTypeRedactedThinking: "redacted_thinking",
+ llm.ContentTypeToolUse: "tool_use",
+ llm.ContentTypeToolResult: "tool_result",
+ }
+ toLLMContentType = inverted(fromLLMContentType)
+
+ fromLLMToolChoiceType = map[llm.ToolChoiceType]string{
+ llm.ToolChoiceTypeAuto: "auto",
+ llm.ToolChoiceTypeAny: "any",
+ llm.ToolChoiceTypeNone: "none",
+ llm.ToolChoiceTypeTool: "tool",
+ }
+
+ toLLMStopReason = map[string]llm.StopReason{
+ "stop_sequence": llm.StopReasonStopSequence,
+ "max_tokens": llm.StopReasonMaxTokens,
+ "end_turn": llm.StopReasonEndTurn,
+ "tool_use": llm.StopReasonToolUse,
+ }
+)
+
+func fromLLMCache(c bool) json.RawMessage {
+ if !c {
+ return nil
+ }
+ return json.RawMessage(`{"type":"ephemeral"}`)
+}
+
+func fromLLMContent(c llm.Content) content {
+ return content{
+ ID: c.ID,
+ Type: fromLLMContentType[c.Type],
+ Text: c.Text,
+ Thinking: c.Thinking,
+ Data: c.Data,
+ Signature: c.Signature,
+ ToolName: c.ToolName,
+ ToolInput: c.ToolInput,
+ ToolUseID: c.ToolUseID,
+ ToolError: c.ToolError,
+ ToolResult: c.ToolResult,
+ CacheControl: fromLLMCache(c.Cache),
+ }
+}
+
+func fromLLMToolUse(tu *llm.ToolUse) *toolUse {
+ if tu == nil {
+ return nil
+ }
+ return &toolUse{
+ ID: tu.ID,
+ Name: tu.Name,
+ }
+}
+
+func fromLLMMessage(msg llm.Message) message {
+ return message{
+ Role: fromLLMRole[msg.Role],
+ Content: mapped(msg.Content, fromLLMContent),
+ ToolUse: fromLLMToolUse(msg.ToolUse),
+ }
+}
+
+func fromLLMToolChoice(tc *llm.ToolChoice) *toolChoice {
+ if tc == nil {
+ return nil
+ }
+ return &toolChoice{
+ Type: fromLLMToolChoiceType[tc.Type],
+ Name: tc.Name,
+ }
+}
+
+func fromLLMTool(t *llm.Tool) *tool {
+ return &tool{
+ Name: t.Name,
+ Type: t.Type,
+ Description: t.Description,
+ InputSchema: t.InputSchema,
+ }
+}
+
+func fromLLMSystem(s llm.SystemContent) systemContent {
+ return systemContent{
+ Text: s.Text,
+ Type: s.Type,
+ CacheControl: fromLLMCache(s.Cache),
+ }
+}
+
+func (s *Service) fromLLMRequest(r *llm.Request) *request {
+ return &request{
+ Model: cmp.Or(s.Model, DefaultModel),
+ Messages: mapped(r.Messages, fromLLMMessage),
+ MaxTokens: cmp.Or(s.MaxTokens, DefaultMaxTokens),
+ ToolChoice: fromLLMToolChoice(r.ToolChoice),
+ Tools: mapped(r.Tools, fromLLMTool),
+ System: mapped(r.System, fromLLMSystem),
+ }
+}
+
+func toLLMUsage(u usage) llm.Usage {
+ return llm.Usage{
+ InputTokens: u.InputTokens,
+ CacheCreationInputTokens: u.CacheCreationInputTokens,
+ CacheReadInputTokens: u.CacheReadInputTokens,
+ OutputTokens: u.OutputTokens,
+ CostUSD: u.CostUSD,
+ }
+}
+
+func toLLMContent(c content) llm.Content {
+ return llm.Content{
+ ID: c.ID,
+ Type: toLLMContentType[c.Type],
+ Text: c.Text,
+ Thinking: c.Thinking,
+ Data: c.Data,
+ Signature: c.Signature,
+ ToolName: c.ToolName,
+ ToolInput: c.ToolInput,
+ ToolUseID: c.ToolUseID,
+ ToolError: c.ToolError,
+ ToolResult: c.ToolResult,
+ }
+}
+
+func toLLMResponse(r *response) *llm.Response {
+ return &llm.Response{
+ ID: r.ID,
+ Type: r.Type,
+ Role: toLLMRole[r.Role],
+ Model: r.Model,
+ Content: mapped(r.Content, toLLMContent),
+ StopReason: toLLMStopReason[r.StopReason],
+ StopSequence: r.StopSequence,
+ Usage: toLLMUsage(r.Usage),
+ }
+}
+
+// Do sends a request to Anthropic.
+func (s *Service) Do(ctx context.Context, ir *llm.Request) (*llm.Response, error) {
+ request := s.fromLLMRequest(ir)
+
+ var payload []byte
+ var err error
+ if dumpText || testing.Testing() {
+ payload, err = json.MarshalIndent(request, "", " ")
+ } else {
+ payload, err = json.Marshal(request)
+ payload = append(payload, '\n')
+ }
+ if err != nil {
+ return nil, err
+ }
+
+ if false {
+ fmt.Printf("claude request payload:\n%s\n", payload)
+ }
+
+ backoff := []time.Duration{15 * time.Second, 30 * time.Second, time.Minute}
+ largerMaxTokens := false
+ var partialUsage usage
+
+ url := cmp.Or(s.URL, DefaultURL)
+ httpc := cmp.Or(s.HTTPC, http.DefaultClient)
+
+ // retry loop
+ for attempts := 0; ; attempts++ {
+ if dumpText {
+ fmt.Printf("RAW REQUEST:\n%s\n\n", payload)
+ }
+ req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(payload))
+ if err != nil {
+ return nil, err
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("X-API-Key", s.APIKey)
+ req.Header.Set("Anthropic-Version", "2023-06-01")
+
+ var features []string
+ if request.TokenEfficientToolUse {
+ features = append(features, "token-efficient-tool-use-2025-02-19")
+ }
+ if largerMaxTokens {
+ features = append(features, "output-128k-2025-02-19")
+ request.MaxTokens = 128 * 1024
+ }
+ if len(features) > 0 {
+ req.Header.Set("anthropic-beta", strings.Join(features, ","))
+ }
+
+ resp, err := httpc.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ buf, _ := io.ReadAll(resp.Body)
+ resp.Body.Close()
+
+ switch {
+ case resp.StatusCode == http.StatusOK:
+ if dumpText {
+ fmt.Printf("RAW RESPONSE:\n%s\n\n", buf)
+ }
+ var response response
+ err = json.NewDecoder(bytes.NewReader(buf)).Decode(&response)
+ if err != nil {
+ return nil, err
+ }
+ if response.StopReason == "max_tokens" && !largerMaxTokens {
+ fmt.Printf("Retrying Anthropic API call with larger max tokens size.")
+ // Retry with more output tokens.
+ largerMaxTokens = true
+ response.Usage.CostUSD = response.TotalDollars()
+ partialUsage = response.Usage
+ continue
+ }
+
+ // Calculate and set the cost_usd field
+ if largerMaxTokens {
+ response.Usage.Add(partialUsage)
+ }
+ response.Usage.CostUSD = response.TotalDollars()
+
+ return toLLMResponse(&response), nil
+ case resp.StatusCode >= 500 && resp.StatusCode < 600:
+ // overloaded or unhappy, in one form or another
+ sleep := backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
+ slog.WarnContext(ctx, "anthropic_request_failed", "response", string(buf), "status_code", resp.StatusCode, "sleep", sleep)
+ time.Sleep(sleep)
+ case resp.StatusCode == 429:
+ // rate limited. wait 1 minute as a starting point, because that's the rate limiting window.
+ // and then add some additional time for backoff.
+ sleep := time.Minute + backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
+ slog.WarnContext(ctx, "anthropic_request_rate_limited", "response", string(buf), "sleep", sleep)
+ time.Sleep(sleep)
+ // case resp.StatusCode == 400:
+ // TODO: parse ErrorResponse, make (*ErrorResponse) implement error
+ default:
+ return nil, fmt.Errorf("API request failed with status %s\n%s", resp.Status, buf)
+ }
+ }
+}
+
+// cents per million tokens
+// (not dollars because i'm twitchy about using floats for money)
+type centsPer1MTokens struct {
+ Input uint64
+ Output uint64
+ CacheRead uint64
+ CacheCreation uint64
+}
+
+// https://www.anthropic.com/pricing#anthropic-api
+var modelCost = map[string]centsPer1MTokens{
+ Claude37Sonnet: {
+ Input: 300, // $3
+ Output: 1500, // $15
+ CacheRead: 30, // $0.30
+ CacheCreation: 375, // $3.75
+ },
+ Claude35Haiku: {
+ Input: 80, // $0.80
+ Output: 400, // $4.00
+ CacheRead: 8, // $0.08
+ CacheCreation: 100, // $1.00
+ },
+ Claude35Sonnet: {
+ Input: 300, // $3
+ Output: 1500, // $15
+ CacheRead: 30, // $0.30
+ CacheCreation: 375, // $3.75
+ },
+}
+
+// TotalDollars returns the total cost to obtain this response, in dollars.
+func (mr *response) TotalDollars() float64 {
+ cpm, ok := modelCost[mr.Model]
+ if !ok {
+ panic(fmt.Sprintf("no pricing info for model: %s", mr.Model))
+ }
+ use := mr.Usage
+ megaCents := use.InputTokens*cpm.Input +
+ use.OutputTokens*cpm.Output +
+ use.CacheReadInputTokens*cpm.CacheRead +
+ use.CacheCreationInputTokens*cpm.CacheCreation
+ cents := float64(megaCents) / 1_000_000.0
+ return cents / 100.0
+}
diff --git a/llm/ant/ant_test.go b/llm/ant/ant_test.go
new file mode 100644
index 0000000..67cc5db
--- /dev/null
+++ b/llm/ant/ant_test.go
@@ -0,0 +1,93 @@
+package ant
+
+import (
+ "math"
+ "testing"
+)
+
+// TestCalculateCostFromTokens tests the calculateCostFromTokens function
+func TestCalculateCostFromTokens(t *testing.T) {
+ tests := []struct {
+ name string
+ model string
+ inputTokens uint64
+ outputTokens uint64
+ cacheReadInputTokens uint64
+ cacheCreationInputTokens uint64
+ want float64
+ }{
+ {
+ name: "Zero tokens",
+ model: Claude37Sonnet,
+ inputTokens: 0,
+ outputTokens: 0,
+ cacheReadInputTokens: 0,
+ cacheCreationInputTokens: 0,
+ want: 0,
+ },
+ {
+ name: "1000 input tokens, 500 output tokens",
+ model: Claude37Sonnet,
+ inputTokens: 1000,
+ outputTokens: 500,
+ cacheReadInputTokens: 0,
+ cacheCreationInputTokens: 0,
+ want: 0.0105,
+ },
+ {
+ name: "10000 input tokens, 5000 output tokens",
+ model: Claude37Sonnet,
+ inputTokens: 10000,
+ outputTokens: 5000,
+ cacheReadInputTokens: 0,
+ cacheCreationInputTokens: 0,
+ want: 0.105,
+ },
+ {
+ name: "With cache read tokens",
+ model: Claude37Sonnet,
+ inputTokens: 1000,
+ outputTokens: 500,
+ cacheReadInputTokens: 2000,
+ cacheCreationInputTokens: 0,
+ want: 0.0111,
+ },
+ {
+ name: "With cache creation tokens",
+ model: Claude37Sonnet,
+ inputTokens: 1000,
+ outputTokens: 500,
+ cacheReadInputTokens: 0,
+ cacheCreationInputTokens: 1500,
+ want: 0.016125,
+ },
+ {
+ name: "With all token types",
+ model: Claude37Sonnet,
+ inputTokens: 1000,
+ outputTokens: 500,
+ cacheReadInputTokens: 2000,
+ cacheCreationInputTokens: 1500,
+ want: 0.016725,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ usage := usage{
+ InputTokens: tt.inputTokens,
+ OutputTokens: tt.outputTokens,
+ CacheReadInputTokens: tt.cacheReadInputTokens,
+ CacheCreationInputTokens: tt.cacheCreationInputTokens,
+ }
+ mr := response{
+ Model: tt.model,
+ Usage: usage,
+ }
+ totalCost := mr.TotalDollars()
+ if math.Abs(totalCost-tt.want) > 0.0001 {
+ t.Errorf("totalCost = %v, want %v", totalCost, tt.want)
+ }
+ })
+ }
+}
diff --git a/llm/conversation/convo.go b/llm/conversation/convo.go
new file mode 100644
index 0000000..5a12256
--- /dev/null
+++ b/llm/conversation/convo.go
@@ -0,0 +1,617 @@
+package conversation
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "log/slog"
+ "maps"
+ "math/rand/v2"
+ "slices"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/oklog/ulid/v2"
+ "github.com/richardlehane/crock32"
+ "sketch.dev/llm"
+ "sketch.dev/skribe"
+)
+
+type Listener interface {
+ // TODO: Content is leaking an anthropic API; should we avoid it?
+ // TODO: Where should we include start/end time and usage?
+ OnToolCall(ctx context.Context, convo *Convo, toolCallID string, toolName string, toolInput json.RawMessage, content llm.Content)
+ OnToolResult(ctx context.Context, convo *Convo, toolCallID string, toolName string, toolInput json.RawMessage, content llm.Content, result *string, err error)
+ OnRequest(ctx context.Context, convo *Convo, requestID string, msg *llm.Message)
+ OnResponse(ctx context.Context, convo *Convo, requestID string, msg *llm.Response)
+}
+
+type NoopListener struct{}
+
+func (n *NoopListener) OnToolCall(ctx context.Context, convo *Convo, id string, toolName string, toolInput json.RawMessage, content llm.Content) {
+}
+
+func (n *NoopListener) OnToolResult(ctx context.Context, convo *Convo, id string, toolName string, toolInput json.RawMessage, content llm.Content, result *string, err error) {
+}
+
+func (n *NoopListener) OnResponse(ctx context.Context, convo *Convo, id string, msg *llm.Response) {
+}
+func (n *NoopListener) OnRequest(ctx context.Context, convo *Convo, id string, msg *llm.Message) {}
+
+var ErrDoNotRespond = errors.New("do not respond")
+
+// A Convo is a managed conversation with Claude.
+// It automatically manages the state of the conversation,
+// including appending messages send/received,
+// calling tools and sending their results,
+// tracking usage, etc.
+//
+// Exported fields must not be altered concurrently with calling any method on Convo.
+// Typical usage is to configure a Convo once before using it.
+type Convo struct {
+ // ID is a unique ID for the conversation
+ ID string
+ // Ctx is the context for the entire conversation.
+ Ctx context.Context
+ // Service is the LLM service to use.
+ Service llm.Service
+ // Tools are the tools available during the conversation.
+ Tools []*llm.Tool
+ // SystemPrompt is the system prompt for the conversation.
+ SystemPrompt string
+ // PromptCaching indicates whether to use Anthropic's prompt caching.
+ // See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#continuing-a-multi-turn-conversation
+ // for the documentation. At request send time, we set the cache_control field on the
+ // last message. We also cache the system prompt.
+ // Default: true.
+ PromptCaching bool
+ // ToolUseOnly indicates whether Claude may only use tools during this conversation.
+ // TODO: add more fine-grained control over tool use?
+ ToolUseOnly bool
+ // Parent is the parent conversation, if any.
+ // It is non-nil for "subagent" calls.
+ // It is set automatically when calling SubConvo,
+ // and usually should not be set manually.
+ Parent *Convo
+ // Budget is the budget for this conversation (and all sub-conversations).
+ // The Conversation DOES NOT automatically enforce the budget.
+ // It is up to the caller to call OverBudget() as appropriate.
+ Budget Budget
+
+ // messages tracks the messages so far in the conversation.
+ messages []llm.Message
+
+ // Listener receives messages being sent.
+ Listener Listener
+
+ muToolUseCancel *sync.Mutex
+ toolUseCancel map[string]context.CancelCauseFunc
+
+ // Protects usage. This is used for subconversations (that share part of CumulativeUsage) as well.
+ mu *sync.Mutex
+ // usage tracks usage for this conversation and all sub-conversations.
+ usage *CumulativeUsage
+}
+
+// newConvoID generates a new 8-byte random id.
+// The uniqueness/collision requirements here are very low.
+// They are not global identifiers,
+// just enough to distinguish different convos in a single session.
+func newConvoID() string {
+ u1 := rand.Uint32()
+ s := crock32.Encode(uint64(u1))
+ if len(s) < 7 {
+ s += strings.Repeat("0", 7-len(s))
+ }
+ return s[:3] + "-" + s[3:]
+}
+
+// New creates a new conversation with Claude with sensible defaults.
+// ctx is the context for the entire conversation.
+func New(ctx context.Context, srv llm.Service) *Convo {
+ id := newConvoID()
+ return &Convo{
+ Ctx: skribe.ContextWithAttr(ctx, slog.String("convo_id", id)),
+ Service: srv,
+ PromptCaching: true,
+ usage: newUsage(),
+ Listener: &NoopListener{},
+ ID: id,
+ muToolUseCancel: &sync.Mutex{},
+ toolUseCancel: map[string]context.CancelCauseFunc{},
+ mu: &sync.Mutex{},
+ }
+}
+
+// SubConvo creates a sub-conversation with the same configuration as the parent conversation.
+// (This propagates context for cancellation, HTTP client, API key, etc.)
+// The sub-conversation shares no messages with the parent conversation.
+// It does not inherit tools from the parent conversation.
+func (c *Convo) SubConvo() *Convo {
+ id := newConvoID()
+ return &Convo{
+ Ctx: skribe.ContextWithAttr(c.Ctx, slog.String("convo_id", id), slog.String("parent_convo_id", c.ID)),
+ Service: c.Service,
+ PromptCaching: c.PromptCaching,
+ Parent: c,
+ // For convenience, sub-convo usage shares tool uses map with parent,
+ // all other fields separate, propagated in AddResponse
+ usage: newUsageWithSharedToolUses(c.usage),
+ mu: c.mu,
+ Listener: c.Listener,
+ ID: id,
+ // Do not copy Budget. Each budget is independent,
+ // and OverBudget checks whether any ancestor is over budget.
+ }
+}
+
+func (c *Convo) SubConvoWithHistory() *Convo {
+ id := newConvoID()
+ return &Convo{
+ Ctx: skribe.ContextWithAttr(c.Ctx, slog.String("convo_id", id), slog.String("parent_convo_id", c.ID)),
+ Service: c.Service,
+ PromptCaching: c.PromptCaching,
+ Parent: c,
+ // For convenience, sub-convo usage shares tool uses map with parent,
+ // all other fields separate, propagated in AddResponse
+ usage: newUsageWithSharedToolUses(c.usage),
+ mu: c.mu,
+ Listener: c.Listener,
+ ID: id,
+ // Do not copy Budget. Each budget is independent,
+ // and OverBudget checks whether any ancestor is over budget.
+ messages: slices.Clone(c.messages),
+ }
+}
+
+// Depth reports how many "sub-conversations" deep this conversation is.
+// That it, it walks up parents until it finds a root.
+func (c *Convo) Depth() int {
+ x := c
+ var depth int
+ for x.Parent != nil {
+ x = x.Parent
+ depth++
+ }
+ return depth
+}
+
+// SendUserTextMessage sends a text message to the LLM in this conversation.
+// otherContents contains additional contents to send with the message, usually tool results.
+func (c *Convo) SendUserTextMessage(s string, otherContents ...llm.Content) (*llm.Response, error) {
+ contents := slices.Clone(otherContents)
+ if s != "" {
+ contents = append(contents, llm.Content{Type: llm.ContentTypeText, Text: s})
+ }
+ msg := llm.Message{
+ Role: llm.MessageRoleUser,
+ Content: contents,
+ }
+ return c.SendMessage(msg)
+}
+
+func (c *Convo) messageRequest(msg llm.Message) *llm.Request {
+ system := []llm.SystemContent{}
+ if c.SystemPrompt != "" {
+ var d llm.SystemContent
+ d = llm.SystemContent{Type: "text", Text: c.SystemPrompt}
+ if c.PromptCaching {
+ d.Cache = true
+ }
+ system = []llm.SystemContent{d}
+ }
+
+ // Claude is happy to return an empty response in response to our Done() call,
+ // and, if so, you'll see something like:
+ // API request failed with status 400 Bad Request
+ // {"type":"error","error": {"type":"invalid_request_error",
+ // "message":"messages.5: all messages must have non-empty content except for the optional final assistant message"}}
+ // So, we filter out those empty messages.
+ var nonEmptyMessages []llm.Message
+ for _, m := range c.messages {
+ if len(m.Content) > 0 {
+ nonEmptyMessages = append(nonEmptyMessages, m)
+ }
+ }
+
+ mr := &llm.Request{
+ Messages: append(nonEmptyMessages, msg), // not yet committed to keeping msg
+ System: system,
+ Tools: c.Tools,
+ }
+ if c.ToolUseOnly {
+ mr.ToolChoice = &llm.ToolChoice{Type: llm.ToolChoiceTypeAny}
+ }
+ return mr
+}
+
+func (c *Convo) findTool(name string) (*llm.Tool, error) {
+ for _, tool := range c.Tools {
+ if tool.Name == name {
+ return tool, nil
+ }
+ }
+ return nil, fmt.Errorf("tool %q not found", name)
+}
+
+// insertMissingToolResults adds error results for tool uses that were requested
+// but not included in the message, which can happen in error paths like "out of budget."
+// We only insert these if there were no tool responses at all, since an incorrect
+// number of tool results would be a programmer error. Mutates inputs.
+func (c *Convo) insertMissingToolResults(mr *llm.Request, msg *llm.Message) {
+ if len(mr.Messages) < 2 {
+ return
+ }
+ prev := mr.Messages[len(mr.Messages)-2]
+ var toolUsePrev int
+ for _, c := range prev.Content {
+ if c.Type == llm.ContentTypeToolUse {
+ toolUsePrev++
+ }
+ }
+ if toolUsePrev == 0 {
+ return
+ }
+ var toolUseCurrent int
+ for _, c := range msg.Content {
+ if c.Type == llm.ContentTypeToolResult {
+ toolUseCurrent++
+ }
+ }
+ if toolUseCurrent != 0 {
+ return
+ }
+ var prefix []llm.Content
+ for _, part := range prev.Content {
+ if part.Type != llm.ContentTypeToolUse {
+ continue
+ }
+ content := llm.Content{
+ Type: llm.ContentTypeToolResult,
+ ToolUseID: part.ID,
+ ToolError: true,
+ ToolResult: "not executed; retry possible",
+ }
+ prefix = append(prefix, content)
+ msg.Content = append(prefix, msg.Content...)
+ mr.Messages[len(mr.Messages)-1].Content = msg.Content
+ }
+ slog.DebugContext(c.Ctx, "inserted missing tool results")
+}
+
+// SendMessage sends a message to Claude.
+// The conversation records (internally) all messages succesfully sent and received.
+func (c *Convo) SendMessage(msg llm.Message) (*llm.Response, error) {
+ id := ulid.Make().String()
+ mr := c.messageRequest(msg)
+ var lastMessage *llm.Message
+ if c.PromptCaching {
+ lastMessage = &mr.Messages[len(mr.Messages)-1]
+ if len(lastMessage.Content) > 0 {
+ lastMessage.Content[len(lastMessage.Content)-1].Cache = true
+ }
+ }
+ defer func() {
+ if lastMessage == nil {
+ return
+ }
+ if len(lastMessage.Content) > 0 {
+ lastMessage.Content[len(lastMessage.Content)-1].Cache = false
+ }
+ }()
+ c.insertMissingToolResults(mr, &msg)
+ c.Listener.OnRequest(c.Ctx, c, id, &msg)
+
+ startTime := time.Now()
+ resp, err := c.Service.Do(c.Ctx, mr)
+ if resp != nil {
+ resp.StartTime = &startTime
+ endTime := time.Now()
+ resp.EndTime = &endTime
+ }
+
+ if err != nil {
+ c.Listener.OnResponse(c.Ctx, c, id, nil)
+ return nil, err
+ }
+ c.messages = append(c.messages, msg, resp.ToMessage())
+ // Propagate usage to all ancestors (including us).
+ for x := c; x != nil; x = x.Parent {
+ x.usage.Add(resp.Usage)
+ }
+ c.Listener.OnResponse(c.Ctx, c, id, resp)
+ return resp, err
+}
+
+type toolCallInfoKeyType string
+
+var toolCallInfoKey toolCallInfoKeyType
+
+type ToolCallInfo struct {
+ ToolUseID string
+ Convo *Convo
+}
+
+func ToolCallInfoFromContext(ctx context.Context) ToolCallInfo {
+ v := ctx.Value(toolCallInfoKey)
+ i, _ := v.(ToolCallInfo)
+ return i
+}
+
+func (c *Convo) ToolResultCancelContents(resp *llm.Response) ([]llm.Content, error) {
+ if resp.StopReason != llm.StopReasonToolUse {
+ return nil, nil
+ }
+ var toolResults []llm.Content
+
+ for _, part := range resp.Content {
+ if part.Type != llm.ContentTypeToolUse {
+ continue
+ }
+ c.incrementToolUse(part.ToolName)
+
+ content := llm.Content{
+ Type: llm.ContentTypeToolResult,
+ ToolUseID: part.ID,
+ }
+
+ content.ToolError = true
+ content.ToolResult = "user canceled this too_use"
+ toolResults = append(toolResults, content)
+ }
+ return toolResults, nil
+}
+
+// GetID returns the conversation ID
+func (c *Convo) GetID() string {
+ return c.ID
+}
+
+func (c *Convo) CancelToolUse(toolUseID string, err error) error {
+ c.muToolUseCancel.Lock()
+ defer c.muToolUseCancel.Unlock()
+ cancel, ok := c.toolUseCancel[toolUseID]
+ if !ok {
+ return fmt.Errorf("cannot cancel %s: no cancel function registered for this tool_use_id. All I have is %+v", toolUseID, c.toolUseCancel)
+ }
+ delete(c.toolUseCancel, toolUseID)
+ cancel(err)
+ return nil
+}
+
+func (c *Convo) newToolUseContext(ctx context.Context, toolUseID string) (context.Context, context.CancelFunc) {
+ c.muToolUseCancel.Lock()
+ defer c.muToolUseCancel.Unlock()
+ ctx, cancel := context.WithCancelCause(ctx)
+ c.toolUseCancel[toolUseID] = cancel
+ return ctx, func() { c.CancelToolUse(toolUseID, nil) }
+}
+
+// ToolResultContents runs all tool uses requested by the response and returns their results.
+// Cancelling ctx will cancel any running tool calls.
+func (c *Convo) ToolResultContents(ctx context.Context, resp *llm.Response) ([]llm.Content, error) {
+ if resp.StopReason != llm.StopReasonToolUse {
+ return nil, nil
+ }
+ // Extract all tool calls from the response, call the tools, and gather the results.
+ var wg sync.WaitGroup
+ toolResultC := make(chan llm.Content, len(resp.Content))
+ for _, part := range resp.Content {
+ if part.Type != llm.ContentTypeToolUse {
+ continue
+ }
+ c.incrementToolUse(part.ToolName)
+ startTime := time.Now()
+
+ c.Listener.OnToolCall(ctx, c, part.ID, part.ToolName, part.ToolInput, llm.Content{
+ Type: llm.ContentTypeToolUse,
+ ToolUseID: part.ID,
+ ToolUseStartTime: &startTime,
+ })
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+
+ content := llm.Content{
+ Type: llm.ContentTypeToolResult,
+ ToolUseID: part.ID,
+ ToolUseStartTime: &startTime,
+ }
+ sendErr := func(err error) {
+ // Record end time
+ endTime := time.Now()
+ content.ToolUseEndTime = &endTime
+
+ content.ToolError = true
+ content.ToolResult = err.Error()
+ c.Listener.OnToolResult(ctx, c, part.ID, part.ToolName, part.ToolInput, content, nil, err)
+ toolResultC <- content
+ }
+ sendRes := func(res string) {
+ // Record end time
+ endTime := time.Now()
+ content.ToolUseEndTime = &endTime
+
+ content.ToolResult = res
+ c.Listener.OnToolResult(ctx, c, part.ID, part.ToolName, part.ToolInput, content, &res, nil)
+ toolResultC <- content
+ }
+
+ tool, err := c.findTool(part.ToolName)
+ if err != nil {
+ sendErr(err)
+ return
+ }
+ // Create a new context for just this tool_use call, and register its
+ // cancel function so that it can be canceled individually.
+ toolUseCtx, cancel := c.newToolUseContext(ctx, part.ID)
+ defer cancel()
+ // TODO: move this into newToolUseContext?
+ toolUseCtx = context.WithValue(toolUseCtx, toolCallInfoKey, ToolCallInfo{ToolUseID: part.ID, Convo: c})
+ toolResult, err := tool.Run(toolUseCtx, part.ToolInput)
+ if errors.Is(err, ErrDoNotRespond) {
+ return
+ }
+ if toolUseCtx.Err() != nil {
+ sendErr(context.Cause(toolUseCtx))
+ return
+ }
+
+ if err != nil {
+ sendErr(err)
+ return
+ }
+ sendRes(toolResult)
+ }()
+ }
+ wg.Wait()
+ close(toolResultC)
+ var toolResults []llm.Content
+ for toolResult := range toolResultC {
+ toolResults = append(toolResults, toolResult)
+ }
+ if ctx.Err() != nil {
+ return nil, ctx.Err()
+ }
+ return toolResults, nil
+}
+
+func (c *Convo) incrementToolUse(name string) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ c.usage.ToolUses[name]++
+}
+
+// CumulativeUsage represents cumulative usage across a Convo, including all sub-conversations.
+type CumulativeUsage struct {
+ StartTime time.Time `json:"start_time"`
+ Responses uint64 `json:"messages"` // count of responses
+ InputTokens uint64 `json:"input_tokens"`
+ OutputTokens uint64 `json:"output_tokens"`
+ CacheReadInputTokens uint64 `json:"cache_read_input_tokens"`
+ CacheCreationInputTokens uint64 `json:"cache_creation_input_tokens"`
+ TotalCostUSD float64 `json:"total_cost_usd"`
+ ToolUses map[string]int `json:"tool_uses"` // tool name -> number of uses
+}
+
+func newUsage() *CumulativeUsage {
+ return &CumulativeUsage{ToolUses: make(map[string]int), StartTime: time.Now()}
+}
+
+func newUsageWithSharedToolUses(parent *CumulativeUsage) *CumulativeUsage {
+ return &CumulativeUsage{ToolUses: parent.ToolUses, StartTime: time.Now()}
+}
+
+func (u *CumulativeUsage) Clone() CumulativeUsage {
+ v := *u
+ v.ToolUses = maps.Clone(u.ToolUses)
+ return v
+}
+
+func (c *Convo) CumulativeUsage() CumulativeUsage {
+ if c == nil {
+ return CumulativeUsage{}
+ }
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ return c.usage.Clone()
+}
+
+func (u *CumulativeUsage) WallTime() time.Duration {
+ return time.Since(u.StartTime)
+}
+
+func (u *CumulativeUsage) DollarsPerHour() float64 {
+ hours := u.WallTime().Hours()
+ // Prevent division by very small numbers that could cause issues
+ if hours < 1e-6 {
+ return 0
+ }
+ return u.TotalCostUSD / hours
+}
+
+func (u *CumulativeUsage) Add(usage llm.Usage) {
+ u.Responses++
+ u.InputTokens += usage.InputTokens
+ u.OutputTokens += usage.OutputTokens
+ u.CacheReadInputTokens += usage.CacheReadInputTokens
+ u.CacheCreationInputTokens += usage.CacheCreationInputTokens
+ u.TotalCostUSD += usage.CostUSD
+}
+
+// TotalInputTokens returns the grand total cumulative input tokens in u.
+func (u *CumulativeUsage) TotalInputTokens() uint64 {
+ return u.InputTokens + u.CacheReadInputTokens + u.CacheCreationInputTokens
+}
+
+// Attr returns the cumulative usage as a slog.Attr with key "usage".
+func (u CumulativeUsage) Attr() slog.Attr {
+ elapsed := time.Since(u.StartTime)
+ return slog.Group("usage",
+ slog.Duration("wall_time", elapsed),
+ slog.Uint64("responses", u.Responses),
+ slog.Uint64("input_tokens", u.InputTokens),
+ slog.Uint64("output_tokens", u.OutputTokens),
+ slog.Uint64("cache_read_input_tokens", u.CacheReadInputTokens),
+ slog.Uint64("cache_creation_input_tokens", u.CacheCreationInputTokens),
+ slog.Float64("total_cost_usd", u.TotalCostUSD),
+ slog.Float64("dollars_per_hour", u.TotalCostUSD/elapsed.Hours()),
+ slog.Any("tool_uses", maps.Clone(u.ToolUses)),
+ )
+}
+
+// A Budget represents the maximum amount of resources that may be spent on a conversation.
+// Note that the default (zero) budget is unlimited.
+type Budget struct {
+ MaxResponses uint64 // if > 0, max number of iterations (=responses)
+ MaxDollars float64 // if > 0, max dollars that may be spent
+ MaxWallTime time.Duration // if > 0, max wall time that may be spent
+}
+
+// OverBudget returns an error if the convo (or any of its parents) has exceeded its budget.
+// TODO: document parent vs sub budgets, multiple errors, etc, once we know the desired behavior.
+func (c *Convo) OverBudget() error {
+ for x := c; x != nil; x = x.Parent {
+ if err := x.overBudget(); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// ResetBudget sets the budget to the passed in budget and
+// adjusts it by what's been used so far.
+func (c *Convo) ResetBudget(budget Budget) {
+ c.Budget = budget
+ if c.Budget.MaxDollars > 0 {
+ c.Budget.MaxDollars += c.CumulativeUsage().TotalCostUSD
+ }
+ if c.Budget.MaxResponses > 0 {
+ c.Budget.MaxResponses += c.CumulativeUsage().Responses
+ }
+ if c.Budget.MaxWallTime > 0 {
+ c.Budget.MaxWallTime += c.usage.WallTime()
+ }
+}
+
+func (c *Convo) overBudget() error {
+ usage := c.CumulativeUsage()
+ // TODO: stop before we exceed the budget instead of after?
+ // Top priority is money, then time, then response count.
+ var err error
+ cont := "Continuing to chat will reset the budget."
+ if c.Budget.MaxDollars > 0 && usage.TotalCostUSD >= c.Budget.MaxDollars {
+ err = errors.Join(err, fmt.Errorf("$%.2f spent, budget is $%.2f. %s", usage.TotalCostUSD, c.Budget.MaxDollars, cont))
+ }
+ if c.Budget.MaxWallTime > 0 && usage.WallTime() >= c.Budget.MaxWallTime {
+ err = errors.Join(err, fmt.Errorf("%v elapsed, budget is %v. %s", usage.WallTime().Truncate(time.Second), c.Budget.MaxWallTime.Truncate(time.Second), cont))
+ }
+ if c.Budget.MaxResponses > 0 && usage.Responses >= c.Budget.MaxResponses {
+ err = errors.Join(err, fmt.Errorf("%d responses received, budget is %d. %s", usage.Responses, c.Budget.MaxResponses, cont))
+ }
+ return err
+}
diff --git a/llm/conversation/convo_test.go b/llm/conversation/convo_test.go
new file mode 100644
index 0000000..3fb1750
--- /dev/null
+++ b/llm/conversation/convo_test.go
@@ -0,0 +1,139 @@
+package conversation
+
+import (
+ "cmp"
+ "context"
+ "net/http"
+ "os"
+ "strings"
+ "testing"
+
+ "sketch.dev/httprr"
+ "sketch.dev/llm/ant"
+)
+
+func TestBasicConvo(t *testing.T) {
+ ctx := context.Background()
+ rr, err := httprr.Open("testdata/basic_convo.httprr", http.DefaultTransport)
+ if err != nil {
+ t.Fatal(err)
+ }
+ rr.ScrubReq(func(req *http.Request) error {
+ req.Header.Del("x-api-key")
+ return nil
+ })
+
+ apiKey := cmp.Or(os.Getenv("OUTER_SKETCH_ANTHROPIC_API_KEY"), os.Getenv("ANTHROPIC_API_KEY"))
+ srv := &ant.Service{
+ APIKey: apiKey,
+ HTTPC: rr.Client(),
+ }
+ convo := New(ctx, srv)
+
+ const name = "Cornelius"
+ res, err := convo.SendUserTextMessage("Hi, my name is " + name)
+ if err != nil {
+ t.Fatal(err)
+ }
+ for _, part := range res.Content {
+ t.Logf("%s", part.Text)
+ }
+ res, err = convo.SendUserTextMessage("What is my name?")
+ if err != nil {
+ t.Fatal(err)
+ }
+ got := ""
+ for _, part := range res.Content {
+ got += part.Text
+ }
+ if !strings.Contains(got, name) {
+ t.Errorf("model does not know the given name %s: %q", name, got)
+ }
+}
+
+// TestCancelToolUse tests the CancelToolUse function of the Convo struct
+func TestCancelToolUse(t *testing.T) {
+ tests := []struct {
+ name string
+ setupToolUse bool
+ toolUseID string
+ cancelErr error
+ expectError bool
+ expectCancel bool
+ }{
+ {
+ name: "Cancel existing tool use",
+ setupToolUse: true,
+ toolUseID: "tool123",
+ cancelErr: nil,
+ expectError: false,
+ expectCancel: true,
+ },
+ {
+ name: "Cancel existing tool use with error",
+ setupToolUse: true,
+ toolUseID: "tool456",
+ cancelErr: context.Canceled,
+ expectError: false,
+ expectCancel: true,
+ },
+ {
+ name: "Cancel non-existent tool use",
+ setupToolUse: false,
+ toolUseID: "tool789",
+ cancelErr: nil,
+ expectError: true,
+ expectCancel: false,
+ },
+ }
+
+ srv := &ant.Service{}
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ convo := New(context.Background(), srv)
+
+ var cancelCalled bool
+ var cancelledWithErr error
+
+ if tt.setupToolUse {
+ // Setup a mock cancel function to track calls
+ mockCancel := func(err error) {
+ cancelCalled = true
+ cancelledWithErr = err
+ }
+
+ convo.muToolUseCancel.Lock()
+ convo.toolUseCancel[tt.toolUseID] = mockCancel
+ convo.muToolUseCancel.Unlock()
+ }
+
+ err := convo.CancelToolUse(tt.toolUseID, tt.cancelErr)
+
+ // Check if we got the expected error state
+ if (err != nil) != tt.expectError {
+ t.Errorf("CancelToolUse() error = %v, expectError %v", err, tt.expectError)
+ }
+
+ // Check if the cancel function was called as expected
+ if cancelCalled != tt.expectCancel {
+ t.Errorf("Cancel function called = %v, expectCancel %v", cancelCalled, tt.expectCancel)
+ }
+
+ // If we expected the cancel to be called, verify it was called with the right error
+ if tt.expectCancel && cancelledWithErr != tt.cancelErr {
+ t.Errorf("Cancel function called with error = %v, expected %v", cancelledWithErr, tt.cancelErr)
+ }
+
+ // Verify the toolUseID was removed from the map if it was initially added
+ if tt.setupToolUse {
+ convo.muToolUseCancel.Lock()
+ _, exists := convo.toolUseCancel[tt.toolUseID]
+ convo.muToolUseCancel.Unlock()
+
+ if exists {
+ t.Errorf("toolUseID %s still exists in the map after cancellation", tt.toolUseID)
+ }
+ }
+ })
+ }
+}
diff --git a/ant/testdata/basic_convo.httprr b/llm/conversation/testdata/basic_convo.httprr
similarity index 100%
rename from ant/testdata/basic_convo.httprr
rename to llm/conversation/testdata/basic_convo.httprr
diff --git a/llm/llm.go b/llm/llm.go
new file mode 100644
index 0000000..3ba6ed4
--- /dev/null
+++ b/llm/llm.go
@@ -0,0 +1,229 @@
+// Package llm provides a unified interface for interacting with LLMs.
+package llm
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log/slog"
+ "strings"
+ "time"
+)
+
+type Service interface {
+ // Do sends a request to an LLM.
+ Do(context.Context, *Request) (*Response, error)
+}
+
+// MustSchema validates that schema is a valid JSON schema and returns it as a json.RawMessage.
+// It panics if the schema is invalid.
+func MustSchema(schema string) json.RawMessage {
+ // TODO: validate schema, for now just make sure it's valid JSON
+ schema = strings.TrimSpace(schema)
+ bytes := []byte(schema)
+ if !json.Valid(bytes) {
+ panic("invalid JSON schema: " + schema)
+ }
+ return json.RawMessage(bytes)
+}
+
+type Request struct {
+ Messages []Message
+ ToolChoice *ToolChoice
+ Tools []*Tool
+ System []SystemContent
+}
+
+// Message represents a message in the conversation.
+type Message struct {
+ Role MessageRole
+ Content []Content
+ ToolUse *ToolUse // use to control whether/which tool to use
+}
+
+// ToolUse represents a tool use in the message content.
+type ToolUse struct {
+ ID string
+ Name string
+}
+
+type ToolChoice struct {
+ Type ToolChoiceType
+ Name string
+}
+
+type SystemContent struct {
+ Text string
+ Type string
+ Cache bool
+}
+
+// Tool represents a tool available to an LLM.
+type Tool struct {
+ Name string
+ // Type is used by the text editor tool; see
+ // https://docs.anthropic.com/en/docs/build-with-claude/tool-use/text-editor-tool
+ Type string
+ Description string
+ InputSchema json.RawMessage
+
+ // The Run function is automatically called when the tool is used.
+ // Run functions may be called concurrently with each other and themselves.
+ // The input to Run function is the input to the tool, as provided by Claude, in compliance with the input schema.
+ // The outputs from Run will be sent back to Claude.
+ // If you do not want to respond to the tool call request from Claude, return ErrDoNotRespond.
+ // ctx contains extra (rarely used) tool call information; retrieve it with ToolCallInfoFromContext.
+ Run func(ctx context.Context, input json.RawMessage) (string, error) `json:"-"`
+}
+
+type Content struct {
+ ID string
+ Type ContentType
+ Text string
+
+ // for thinking
+ Thinking string
+ Data string
+ Signature string
+
+ // for tool_use
+ ToolName string
+ ToolInput json.RawMessage
+
+ // for tool_result
+ ToolUseID string
+ ToolError bool
+ ToolResult string
+
+ // timing information for tool_result; added externally; not sent to the LLM
+ ToolUseStartTime *time.Time
+ ToolUseEndTime *time.Time
+
+ Cache bool
+}
+
+func StringContent(s string) Content {
+ return Content{Type: ContentTypeText, Text: s}
+}
+
+// ContentsAttr returns contents as a slog.Attr.
+// It is meant for logging.
+func ContentsAttr(contents []Content) slog.Attr {
+ var contentAttrs []any // slog.Attr
+ for _, content := range contents {
+ var attrs []any // slog.Attr
+ switch content.Type {
+ case ContentTypeText:
+ attrs = append(attrs, slog.String("text", content.Text))
+ case ContentTypeToolUse:
+ attrs = append(attrs, slog.String("tool_name", content.ToolName))
+ attrs = append(attrs, slog.String("tool_input", string(content.ToolInput)))
+ case ContentTypeToolResult:
+ attrs = append(attrs, slog.String("tool_result", content.ToolResult))
+ attrs = append(attrs, slog.Bool("tool_error", content.ToolError))
+ case ContentTypeThinking:
+ attrs = append(attrs, slog.String("thinking", content.Text))
+ default:
+ attrs = append(attrs, slog.String("unknown_content_type", content.Type.String()))
+ attrs = append(attrs, slog.Any("text", content)) // just log it all raw, better to have too much than not enough
+ }
+ contentAttrs = append(contentAttrs, slog.Group(content.ID, attrs...))
+ }
+ return slog.Group("contents", contentAttrs...)
+}
+
+type (
+ MessageRole int
+ ContentType int
+ ToolChoiceType int
+ StopReason int
+)
+
+//go:generate go tool golang.org/x/tools/cmd/stringer -type=MessageRole,ContentType,ToolChoiceType,StopReason -output=llm_string.go
+
+const (
+ MessageRoleUser MessageRole = iota
+ MessageRoleAssistant
+
+ ContentTypeText ContentType = iota
+ ContentTypeThinking
+ ContentTypeRedactedThinking
+ ContentTypeToolUse
+ ContentTypeToolResult
+
+ ToolChoiceTypeAuto ToolChoiceType = iota // default
+ ToolChoiceTypeAny // any tool, but must use one
+ ToolChoiceTypeNone // no tools allowed
+ ToolChoiceTypeTool // must use the tool specified in the Name field
+
+ StopReasonStopSequence StopReason = iota
+ StopReasonMaxTokens
+ StopReasonEndTurn
+ StopReasonToolUse
+)
+
+type Response struct {
+ ID string
+ Type string
+ Role MessageRole
+ Model string
+ Content []Content
+ StopReason StopReason
+ StopSequence *string
+ Usage Usage
+ StartTime *time.Time
+ EndTime *time.Time
+}
+
+func (m *Response) ToMessage() Message {
+ return Message{
+ Role: m.Role,
+ Content: m.Content,
+ }
+}
+
+// Usage represents the billing and rate-limit usage.
+// Most LLM structs do not have JSON tags, to avoid accidental direct use in specific providers.
+// However, the front-end uses this struct, and it relies on its JSON serialization.
+// Do NOT use this struct directly when implementing an llm.Service.
+type Usage struct {
+ InputTokens uint64 `json:"input_tokens"`
+ CacheCreationInputTokens uint64 `json:"cache_creation_input_tokens"`
+ CacheReadInputTokens uint64 `json:"cache_read_input_tokens"`
+ OutputTokens uint64 `json:"output_tokens"`
+ CostUSD float64 `json:"cost_usd"`
+}
+
+func (u *Usage) Add(other Usage) {
+ u.InputTokens += other.InputTokens
+ u.CacheCreationInputTokens += other.CacheCreationInputTokens
+ u.CacheReadInputTokens += other.CacheReadInputTokens
+ u.OutputTokens += other.OutputTokens
+ u.CostUSD += other.CostUSD
+}
+
+func (u *Usage) String() string {
+ return fmt.Sprintf("in: %d, out: %d", u.InputTokens, u.OutputTokens)
+}
+
+func (u *Usage) IsZero() bool {
+ return *u == Usage{}
+}
+
+func (u *Usage) Attr() slog.Attr {
+ return slog.Group("usage",
+ slog.Uint64("input_tokens", u.InputTokens),
+ slog.Uint64("output_tokens", u.OutputTokens),
+ slog.Uint64("cache_creation_input_tokens", u.CacheCreationInputTokens),
+ slog.Uint64("cache_read_input_tokens", u.CacheReadInputTokens),
+ slog.Float64("cost_usd", u.CostUSD),
+ )
+}
+
+// UserStringMessage creates a user message with a single text content item.
+func UserStringMessage(text string) Message {
+ return Message{
+ Role: MessageRoleUser,
+ Content: []Content{StringContent(text)},
+ }
+}
diff --git a/llm/llm_string.go b/llm/llm_string.go
new file mode 100644
index 0000000..1c3189e
--- /dev/null
+++ b/llm/llm_string.go
@@ -0,0 +1,88 @@
+// Code generated by "stringer -type=MessageRole,ContentType,ToolChoiceType,StopReason -output=llm_string.go"; DO NOT EDIT.
+
+package llm
+
+import "strconv"
+
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[MessageRoleUser-0]
+ _ = x[MessageRoleAssistant-1]
+}
+
+const _MessageRole_name = "MessageRoleUserMessageRoleAssistant"
+
+var _MessageRole_index = [...]uint8{0, 15, 35}
+
+func (i MessageRole) String() string {
+ if i < 0 || i >= MessageRole(len(_MessageRole_index)-1) {
+ return "MessageRole(" + strconv.FormatInt(int64(i), 10) + ")"
+ }
+ return _MessageRole_name[_MessageRole_index[i]:_MessageRole_index[i+1]]
+}
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[ContentTypeText-2]
+ _ = x[ContentTypeThinking-3]
+ _ = x[ContentTypeRedactedThinking-4]
+ _ = x[ContentTypeToolUse-5]
+ _ = x[ContentTypeToolResult-6]
+}
+
+const _ContentType_name = "ContentTypeTextContentTypeThinkingContentTypeRedactedThinkingContentTypeToolUseContentTypeToolResult"
+
+var _ContentType_index = [...]uint8{0, 15, 34, 61, 79, 100}
+
+func (i ContentType) String() string {
+ i -= 2
+ if i < 0 || i >= ContentType(len(_ContentType_index)-1) {
+ return "ContentType(" + strconv.FormatInt(int64(i+2), 10) + ")"
+ }
+ return _ContentType_name[_ContentType_index[i]:_ContentType_index[i+1]]
+}
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[ToolChoiceTypeAuto-7]
+ _ = x[ToolChoiceTypeAny-8]
+ _ = x[ToolChoiceTypeNone-9]
+ _ = x[ToolChoiceTypeTool-10]
+}
+
+const _ToolChoiceType_name = "ToolChoiceTypeAutoToolChoiceTypeAnyToolChoiceTypeNoneToolChoiceTypeTool"
+
+var _ToolChoiceType_index = [...]uint8{0, 18, 35, 53, 71}
+
+func (i ToolChoiceType) String() string {
+ i -= 7
+ if i < 0 || i >= ToolChoiceType(len(_ToolChoiceType_index)-1) {
+ return "ToolChoiceType(" + strconv.FormatInt(int64(i+7), 10) + ")"
+ }
+ return _ToolChoiceType_name[_ToolChoiceType_index[i]:_ToolChoiceType_index[i+1]]
+}
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[StopReasonStopSequence-11]
+ _ = x[StopReasonMaxTokens-12]
+ _ = x[StopReasonEndTurn-13]
+ _ = x[StopReasonToolUse-14]
+}
+
+const _StopReason_name = "StopReasonStopSequenceStopReasonMaxTokensStopReasonEndTurnStopReasonToolUse"
+
+var _StopReason_index = [...]uint8{0, 22, 41, 58, 75}
+
+func (i StopReason) String() string {
+ i -= 11
+ if i < 0 || i >= StopReason(len(_StopReason_index)-1) {
+ return "StopReason(" + strconv.FormatInt(int64(i+11), 10) + ")"
+ }
+ return _StopReason_name[_StopReason_index[i]:_StopReason_index[i+1]]
+}
diff --git a/llm/oai/oai.go b/llm/oai/oai.go
new file mode 100644
index 0000000..3e772ab
--- /dev/null
+++ b/llm/oai/oai.go
@@ -0,0 +1,592 @@
+package oai
+
+import (
+ "cmp"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "log/slog"
+ "math/rand/v2"
+ "net/http"
+ "time"
+
+ "github.com/sashabaranov/go-openai"
+ "sketch.dev/llm"
+)
+
+const (
+ DefaultMaxTokens = 8192
+
+ OpenAIURL = "https://api.openai.com/v1"
+ FireworksURL = "https://api.fireworks.ai/inference/v1"
+ LlamaCPPURL = "http://localhost:8080/v1"
+ TogetherURL = "https://api.together.xyz/v1"
+ GeminiURL = "https://generativelanguage.googleapis.com/v1beta/openai/"
+
+ // Environment variable names for API keys
+ OpenAIAPIKeyEnv = "OPENAI_API_KEY"
+ FireworksAPIKeyEnv = "FIREWORKS_API_KEY"
+ TogetherAPIKeyEnv = "TOGETHER_API_KEY"
+ GeminiAPIKeyEnv = "GEMINI_API_KEY"
+)
+
+type Model struct {
+ UserName string // provided by the user to identify this model (e.g. "gpt4.1")
+ ModelName string // provided to the service provide to specify which model to use (e.g. "gpt-4.1-2025-04-14")
+ URL string
+ Cost ModelCost
+ APIKeyEnv string // environment variable name for the API key
+}
+
+type ModelCost struct {
+ Input uint64 // in cents per million tokens
+ CachedInput uint64 // in cents per million tokens
+ Output uint64 // in cents per million tokens
+}
+
+var (
+ DefaultModel = GPT41
+
+ GPT41 = Model{
+ UserName: "gpt4.1",
+ ModelName: "gpt-4.1-2025-04-14",
+ URL: OpenAIURL,
+ Cost: ModelCost{Input: 200, CachedInput: 50, Output: 800},
+ APIKeyEnv: OpenAIAPIKeyEnv,
+ }
+
+ Gemini25Flash = Model{
+ UserName: "gemini-flash-2.5",
+ ModelName: "gemini-2.5-flash-preview-04-17",
+ URL: GeminiURL,
+ Cost: ModelCost{Input: 15, Output: 60},
+ APIKeyEnv: GeminiAPIKeyEnv,
+ }
+
+ Gemini25Pro = Model{
+ UserName: "gemini-pro-2.5",
+ ModelName: "gemini-2.5-pro-preview-03-25",
+ URL: GeminiURL,
+ // GRRRR. Really??
+ // Input is: $1.25, prompts <= 200k tokens, $2.50, prompts > 200k tokens
+ // Output is: $10.00, prompts <= 200k tokens, $15.00, prompts > 200k
+ // Caching is: $0.31, prompts <= 200k tokens, $0.625, prompts > 200k, $4.50 / 1,000,000 tokens per hour
+ // Whatever that means. Are we caching? I have no idea.
+ // How do you always manage to be the annoying one, Google?
+ // I'm not complicating things just for you.
+ Cost: ModelCost{Input: 125, Output: 1000},
+ APIKeyEnv: GeminiAPIKeyEnv,
+ }
+
+ TogetherDeepseekV3 = Model{
+ UserName: "together-deepseek-v3",
+ ModelName: "deepseek-ai/DeepSeek-V3",
+ URL: TogetherURL,
+ Cost: ModelCost{Input: 125, Output: 125},
+ APIKeyEnv: TogetherAPIKeyEnv,
+ }
+
+ TogetherLlama4Maverick = Model{
+ UserName: "together-llama4-maverick",
+ ModelName: "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
+ URL: TogetherURL,
+ Cost: ModelCost{Input: 27, Output: 85},
+ APIKeyEnv: TogetherAPIKeyEnv,
+ }
+
+ TogetherLlama3_3_70B = Model{
+ UserName: "together-llama3-70b",
+ ModelName: "meta-llama/Llama-3.3-70B-Instruct-Turbo",
+ URL: TogetherURL,
+ Cost: ModelCost{Input: 88, Output: 88},
+ APIKeyEnv: TogetherAPIKeyEnv,
+ }
+
+ TogetherMistralSmall = Model{
+ UserName: "together-mistral-small",
+ ModelName: "mistralai/Mistral-Small-24B-Instruct-2501",
+ URL: TogetherURL,
+ Cost: ModelCost{Input: 80, Output: 80},
+ APIKeyEnv: TogetherAPIKeyEnv,
+ }
+
+ LlamaCPP = Model{
+ UserName: "llama.cpp",
+ ModelName: "llama.cpp local model",
+ URL: LlamaCPPURL,
+ // zero cost
+ Cost: ModelCost{},
+ }
+
+ FireworksDeepseekV3 = Model{
+ UserName: "fireworks-deepseek-v3",
+ ModelName: "accounts/fireworks/models/deepseek-v3-0324",
+ URL: FireworksURL,
+ Cost: ModelCost{Input: 90, Output: 90}, // not entirely sure about this, they don't list pricing anywhere convenient
+ APIKeyEnv: FireworksAPIKeyEnv,
+ }
+)
+
+// Service provides chat completions.
+// Fields should not be altered concurrently with calling any method on Service.
+type Service struct {
+ HTTPC *http.Client // defaults to http.DefaultClient if nil
+ APIKey string // optional, if not set will try to load from env var
+ Model Model // defaults to DefaultModel if zero value
+ MaxTokens int // defaults to DefaultMaxTokens if zero
+ Org string // optional - organization ID
+}
+
+var _ llm.Service = (*Service)(nil)
+
+// ModelsRegistry is a registry of all known models with their user-friendly names.
+var ModelsRegistry = []Model{
+ GPT41,
+ Gemini25Flash,
+ Gemini25Pro,
+ TogetherDeepseekV3,
+ TogetherLlama4Maverick,
+ TogetherLlama3_3_70B,
+ TogetherMistralSmall,
+ LlamaCPP,
+ FireworksDeepseekV3,
+}
+
+// ListModels returns a list of all available models with their user-friendly names.
+func ListModels() []string {
+ var names []string
+ for _, model := range ModelsRegistry {
+ if model.UserName != "" {
+ names = append(names, model.UserName)
+ }
+ }
+ return names
+}
+
+// ModelByUserName returns a model by its user-friendly name.
+// Returns nil if no model with the given name is found.
+func ModelByUserName(name string) *Model {
+ for _, model := range ModelsRegistry {
+ if model.UserName == name {
+ return &model
+ }
+ }
+ return nil
+}
+
+var (
+ fromLLMRole = map[llm.MessageRole]string{
+ llm.MessageRoleAssistant: "assistant",
+ llm.MessageRoleUser: "user",
+ }
+ fromLLMContentType = map[llm.ContentType]string{
+ llm.ContentTypeText: "text",
+ llm.ContentTypeToolUse: "function", // OpenAI uses function instead of tool_call
+ llm.ContentTypeToolResult: "tool_result",
+ llm.ContentTypeThinking: "text", // Map thinking to text since OpenAI doesn't have thinking
+ llm.ContentTypeRedactedThinking: "text", // Map redacted_thinking to text
+ }
+ fromLLMToolChoiceType = map[llm.ToolChoiceType]string{
+ llm.ToolChoiceTypeAuto: "auto",
+ llm.ToolChoiceTypeAny: "any",
+ llm.ToolChoiceTypeNone: "none",
+ llm.ToolChoiceTypeTool: "function", // OpenAI uses "function" instead of "tool"
+ }
+ toLLMRole = map[string]llm.MessageRole{
+ "assistant": llm.MessageRoleAssistant,
+ "user": llm.MessageRoleUser,
+ }
+ toLLMStopReason = map[string]llm.StopReason{
+ "stop": llm.StopReasonStopSequence,
+ "length": llm.StopReasonMaxTokens,
+ "tool_calls": llm.StopReasonToolUse,
+ "function_call": llm.StopReasonToolUse, // Map both to ToolUse
+ "content_filter": llm.StopReasonStopSequence, // No direct equivalent
+ }
+)
+
+// fromLLMContent converts llm.Content to the format expected by OpenAI.
+func fromLLMContent(c llm.Content) (string, []openai.ToolCall) {
+ switch c.Type {
+ case llm.ContentTypeText:
+ return c.Text, nil
+ case llm.ContentTypeToolUse:
+ // For OpenAI, tool use is sent as a null content with tool_calls in the message
+ return "", []openai.ToolCall{
+ {
+ Type: openai.ToolTypeFunction,
+ ID: c.ID, // Use the content ID if provided
+ Function: openai.FunctionCall{
+ Name: c.ToolName,
+ Arguments: string(c.ToolInput),
+ },
+ },
+ }
+ case llm.ContentTypeToolResult:
+ // Tool results in OpenAI are sent as a separate message with tool_call_id
+ return c.ToolResult, nil
+ default:
+ // For thinking or other types, convert to text
+ return c.Text, nil
+ }
+}
+
+// fromLLMMessage converts llm.Message to OpenAI ChatCompletionMessage format
+func fromLLMMessage(msg llm.Message) []openai.ChatCompletionMessage {
+ // For OpenAI, we need to handle tool results differently than regular messages
+ // Each tool result becomes its own message with role="tool"
+
+ var messages []openai.ChatCompletionMessage
+
+ // Check if this is a regular message or contains tool results
+ var regularContent []llm.Content
+ var toolResults []llm.Content
+
+ for _, c := range msg.Content {
+ if c.Type == llm.ContentTypeToolResult {
+ toolResults = append(toolResults, c)
+ } else {
+ regularContent = append(regularContent, c)
+ }
+ }
+
+ // Process tool results as separate messages, but first
+ for _, tr := range toolResults {
+ m := openai.ChatCompletionMessage{
+ Role: "tool",
+ Content: cmp.Or(tr.ToolResult, " "), // TODO: remove omitempty upstream
+ ToolCallID: tr.ToolUseID,
+ }
+ messages = append(messages, m)
+ }
+ // Process regular content second
+ if len(regularContent) > 0 {
+ m := openai.ChatCompletionMessage{
+ Role: fromLLMRole[msg.Role],
+ }
+
+ // For assistant messages that contain tool calls
+ var toolCalls []openai.ToolCall
+ var textContent string
+
+ for _, c := range regularContent {
+ content, tools := fromLLMContent(c)
+ if len(tools) > 0 {
+ toolCalls = append(toolCalls, tools...)
+ } else if content != "" {
+ if textContent != "" {
+ textContent += "\n"
+ }
+ textContent += content
+ }
+ }
+
+ m.Content = textContent
+ m.ToolCalls = toolCalls
+
+ messages = append(messages, m)
+ }
+
+ return messages
+}
+
+// fromLLMToolChoice converts llm.ToolChoice to the format expected by OpenAI.
+func fromLLMToolChoice(tc *llm.ToolChoice) any {
+ if tc == nil {
+ return nil
+ }
+
+ if tc.Type == llm.ToolChoiceTypeTool && tc.Name != "" {
+ return openai.ToolChoice{
+ Type: openai.ToolTypeFunction,
+ Function: openai.ToolFunction{
+ Name: tc.Name,
+ },
+ }
+ }
+
+ // For non-specific tool choice, just use the string
+ return fromLLMToolChoiceType[tc.Type]
+}
+
+// fromLLMTool converts llm.Tool to the format expected by OpenAI.
+func fromLLMTool(t *llm.Tool) openai.Tool {
+ return openai.Tool{
+ Type: openai.ToolTypeFunction,
+ Function: &openai.FunctionDefinition{
+ Name: t.Name,
+ Description: t.Description,
+ Parameters: t.InputSchema,
+ },
+ }
+}
+
+// fromLLMSystem converts llm.SystemContent to an OpenAI system message.
+func fromLLMSystem(systemContent []llm.SystemContent) []openai.ChatCompletionMessage {
+ if len(systemContent) == 0 {
+ return nil
+ }
+
+ // Combine all system content into a single system message
+ var systemText string
+ for i, content := range systemContent {
+ if i > 0 && systemText != "" && content.Text != "" {
+ systemText += "\n"
+ }
+ systemText += content.Text
+ }
+
+ if systemText == "" {
+ return nil
+ }
+
+ return []openai.ChatCompletionMessage{
+ {
+ Role: "system",
+ Content: systemText,
+ },
+ }
+}
+
+// toRawLLMContent converts a raw content string from OpenAI to llm.Content.
+func toRawLLMContent(content string) llm.Content {
+ return llm.Content{
+ Type: llm.ContentTypeText,
+ Text: content,
+ }
+}
+
+// toToolCallLLMContent converts a tool call from OpenAI to llm.Content.
+func toToolCallLLMContent(toolCall openai.ToolCall) llm.Content {
+ // Generate a content ID if needed
+ id := toolCall.ID
+ if id == "" {
+ // Create a deterministic ID based on the function name if no ID is provided
+ id = "tc_" + toolCall.Function.Name
+ }
+
+ return llm.Content{
+ ID: id,
+ Type: llm.ContentTypeToolUse,
+ ToolName: toolCall.Function.Name,
+ ToolInput: json.RawMessage(toolCall.Function.Arguments),
+ }
+}
+
+// toToolResultLLMContent converts a tool result message from OpenAI to llm.Content.
+func toToolResultLLMContent(msg openai.ChatCompletionMessage) llm.Content {
+ return llm.Content{
+ Type: llm.ContentTypeToolResult,
+ ToolUseID: msg.ToolCallID,
+ ToolResult: msg.Content,
+ ToolError: false, // OpenAI doesn't specify errors explicitly
+ }
+}
+
+// toLLMContents converts message content from OpenAI to []llm.Content.
+func toLLMContents(msg openai.ChatCompletionMessage) []llm.Content {
+ var contents []llm.Content
+
+ // If this is a tool response, handle it separately
+ if msg.Role == "tool" && msg.ToolCallID != "" {
+ return []llm.Content{toToolResultLLMContent(msg)}
+ }
+
+ // If there's text content, add it
+ if msg.Content != "" {
+ contents = append(contents, toRawLLMContent(msg.Content))
+ }
+
+ // If there are tool calls, add them
+ for _, tc := range msg.ToolCalls {
+ contents = append(contents, toToolCallLLMContent(tc))
+ }
+
+ // If empty, add an empty text content
+ if len(contents) == 0 {
+ contents = append(contents, llm.Content{
+ Type: llm.ContentTypeText,
+ Text: "",
+ })
+ }
+
+ return contents
+}
+
+// toLLMUsage converts usage information from OpenAI to llm.Usage.
+func (s *Service) toLLMUsage(model string, au openai.Usage) llm.Usage {
+ // fmt.Printf("raw usage: %+v / %v / %v\n", au, au.PromptTokensDetails, au.CompletionTokensDetails)
+ in := uint64(au.PromptTokens)
+ var inc uint64
+ if au.PromptTokensDetails != nil {
+ inc = uint64(au.PromptTokensDetails.CachedTokens)
+ }
+ out := uint64(au.CompletionTokens)
+ u := llm.Usage{
+ InputTokens: in,
+ CacheReadInputTokens: inc,
+ CacheCreationInputTokens: in,
+ OutputTokens: out,
+ }
+ u.CostUSD = s.calculateCostFromTokens(u)
+ return u
+}
+
+// toLLMResponse converts the OpenAI response to llm.Response.
+func (s *Service) toLLMResponse(r *openai.ChatCompletionResponse) *llm.Response {
+ // fmt.Printf("Raw response\n")
+ // enc := json.NewEncoder(os.Stdout)
+ // enc.SetIndent("", " ")
+ // enc.Encode(r)
+ // fmt.Printf("\n")
+
+ if len(r.Choices) == 0 {
+ return &llm.Response{
+ ID: r.ID,
+ Model: r.Model,
+ Role: llm.MessageRoleAssistant,
+ Usage: s.toLLMUsage(r.Model, r.Usage),
+ }
+ }
+
+ // Process the primary choice
+ choice := r.Choices[0]
+
+ return &llm.Response{
+ ID: r.ID,
+ Model: r.Model,
+ Role: toRoleFromString(choice.Message.Role),
+ Content: toLLMContents(choice.Message),
+ StopReason: toStopReason(string(choice.FinishReason)),
+ Usage: s.toLLMUsage(r.Model, r.Usage),
+ }
+}
+
+// toRoleFromString converts a role string to llm.MessageRole.
+func toRoleFromString(role string) llm.MessageRole {
+ if role == "tool" || role == "system" || role == "function" {
+ return llm.MessageRoleAssistant // Map special roles to assistant for consistency
+ }
+ if mr, ok := toLLMRole[role]; ok {
+ return mr
+ }
+ return llm.MessageRoleUser // Default to user if unknown
+}
+
+// toStopReason converts a finish reason string to llm.StopReason.
+func toStopReason(reason string) llm.StopReason {
+ if sr, ok := toLLMStopReason[reason]; ok {
+ return sr
+ }
+ return llm.StopReasonStopSequence // Default
+}
+
+// calculateCostFromTokens calculates the cost in dollars for the given model and token counts.
+func (s *Service) calculateCostFromTokens(u llm.Usage) float64 {
+ cost := s.Model.Cost
+
+ // TODO: check this for correctness, i am skeptical
+ // Calculate cost in cents
+ megaCents := u.CacheCreationInputTokens*cost.Input +
+ u.CacheReadInputTokens*cost.CachedInput +
+ u.OutputTokens*cost.Output
+
+ cents := float64(megaCents) / 1_000_000
+ // Convert to dollars
+ dollars := cents / 100.0
+ // fmt.Printf("in_new=%d, in_cached=%d, out=%d, cost=%.2f\n", u.CacheCreationInputTokens, u.CacheReadInputTokens, u.OutputTokens, dollars)
+ return dollars
+}
+
+// Do sends a request to OpenAI using the go-openai package.
+func (s *Service) Do(ctx context.Context, ir *llm.Request) (*llm.Response, error) {
+ // Configure the OpenAI client
+ httpc := cmp.Or(s.HTTPC, http.DefaultClient)
+ model := cmp.Or(s.Model, DefaultModel)
+
+ // TODO: do this one during Service setup? maybe with a constructor instead?
+ config := openai.DefaultConfig(s.APIKey)
+ if model.URL != "" {
+ config.BaseURL = model.URL
+ }
+ if s.Org != "" {
+ config.OrgID = s.Org
+ }
+ config.HTTPClient = httpc
+
+ client := openai.NewClientWithConfig(config)
+
+ // Start with system messages if provided
+ var allMessages []openai.ChatCompletionMessage
+ if len(ir.System) > 0 {
+ sysMessages := fromLLMSystem(ir.System)
+ allMessages = append(allMessages, sysMessages...)
+ }
+
+ // Add regular and tool messages
+ for _, msg := range ir.Messages {
+ msgs := fromLLMMessage(msg)
+ allMessages = append(allMessages, msgs...)
+ }
+
+ // Convert tools
+ var tools []openai.Tool
+ for _, t := range ir.Tools {
+ tools = append(tools, fromLLMTool(t))
+ }
+
+ // Create the OpenAI request
+ req := openai.ChatCompletionRequest{
+ Model: model.ModelName,
+ Messages: allMessages,
+ MaxTokens: cmp.Or(s.MaxTokens, DefaultMaxTokens),
+ Tools: tools,
+ ToolChoice: fromLLMToolChoice(ir.ToolChoice), // TODO: make fromLLMToolChoice return an error when a perfect translation is not possible
+ }
+ // fmt.Printf("Sending request to OpenAI\n")
+ // enc := json.NewEncoder(os.Stdout)
+ // enc.SetIndent("", " ")
+ // enc.Encode(req)
+ // fmt.Printf("\n")
+
+ // Retry mechanism
+ backoff := []time.Duration{1 * time.Second, 2 * time.Second, 5 * time.Second}
+
+ // retry loop
+ for attempts := 0; ; attempts++ {
+ resp, err := client.CreateChatCompletion(ctx, req)
+
+ // Handle successful response
+ if err == nil {
+ return s.toLLMResponse(&resp), nil
+ }
+
+ // Handle errors
+ var apiErr *openai.APIError
+ if ok := errors.As(err, &apiErr); !ok {
+ // Not an OpenAI API error, return immediately
+ return nil, err
+ }
+
+ switch {
+ case apiErr.HTTPStatusCode >= 500:
+ // Server error, try again with backoff
+ sleep := backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
+ slog.WarnContext(ctx, "openai_request_failed", "error", apiErr.Error(), "status_code", apiErr.HTTPStatusCode, "sleep", sleep)
+ time.Sleep(sleep)
+ continue
+
+ case apiErr.HTTPStatusCode == 429:
+ // Rate limited, back off longer
+ sleep := 20*time.Second + backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
+ slog.WarnContext(ctx, "openai_request_rate_limited", "error", apiErr.Error(), "sleep", sleep)
+ time.Sleep(sleep)
+ continue
+
+ default:
+ // Other error, return immediately
+ return nil, fmt.Errorf("OpenAI API error: %w", err)
+ }
+ }
+}
diff --git a/llm/oai/oai_test.go b/llm/oai/oai_test.go
new file mode 100644
index 0000000..7bea552
--- /dev/null
+++ b/llm/oai/oai_test.go
@@ -0,0 +1,96 @@
+package oai
+
+import (
+ "math"
+ "testing"
+
+ "sketch.dev/llm"
+)
+
+// TestCalculateCostFromTokens tests the calculateCostFromTokens method
+func TestCalculateCostFromTokens(t *testing.T) {
+ tests := []struct {
+ name string
+ model Model
+ cacheCreationTokens uint64
+ cacheReadTokens uint64
+ outputTokens uint64
+ want float64
+ }{
+ {
+ name: "Zero tokens",
+ model: GPT41,
+ cacheCreationTokens: 0,
+ cacheReadTokens: 0,
+ outputTokens: 0,
+ want: 0,
+ },
+ {
+ name: "1000 input tokens, 500 output tokens",
+ model: GPT41,
+ cacheCreationTokens: 1000,
+ cacheReadTokens: 0,
+ outputTokens: 500,
+ // GPT41: Input: 200 per million, Output: 800 per million
+ // (1000 * 200 + 500 * 800) / 1_000_000 / 100 = 0.006
+ want: 0.006,
+ },
+ {
+ name: "10000 input tokens, 5000 output tokens",
+ model: GPT41,
+ cacheCreationTokens: 10000,
+ cacheReadTokens: 0,
+ outputTokens: 5000,
+ // (10000 * 200 + 5000 * 800) / 1_000_000 / 100 = 0.06
+ want: 0.06,
+ },
+ {
+ name: "1000 input tokens, 500 output tokens Gemini",
+ model: Gemini25Flash,
+ cacheCreationTokens: 1000,
+ cacheReadTokens: 0,
+ outputTokens: 500,
+ // Gemini25Flash: Input: 15 per million, Output: 60 per million
+ // (1000 * 15 + 500 * 60) / 1_000_000 / 100 = 0.00045
+ want: 0.00045,
+ },
+ {
+ name: "With cache read tokens",
+ model: GPT41,
+ cacheCreationTokens: 500,
+ cacheReadTokens: 500, // 500 tokens from cache
+ outputTokens: 500,
+ // (500 * 200 + 500 * 50 + 500 * 800) / 1_000_000 / 100 = 0.00525
+ want: 0.00525,
+ },
+ {
+ name: "With all token types",
+ model: GPT41,
+ cacheCreationTokens: 1000,
+ cacheReadTokens: 1000,
+ outputTokens: 1000,
+ // (1000 * 200 + 1000 * 50 + 1000 * 800) / 1_000_000 / 100 = 0.0105
+ want: 0.0105,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Create a service with the test model
+ svc := &Service{Model: tt.model}
+
+ // Create a usage object
+ usage := llm.Usage{
+ CacheCreationInputTokens: tt.cacheCreationTokens,
+ CacheReadInputTokens: tt.cacheReadTokens,
+ OutputTokens: tt.outputTokens,
+ }
+
+ totalCost := svc.calculateCostFromTokens(usage)
+ if math.Abs(totalCost-tt.want) > 0.0001 {
+ t.Errorf("calculateCostFromTokens(%s, cache_creation=%d, cache_read=%d, output=%d) = %v, want %v",
+ tt.model.ModelName, tt.cacheCreationTokens, tt.cacheReadTokens, tt.outputTokens, totalCost, tt.want)
+ }
+ })
+ }
+}
diff --git a/loop/agent.go b/loop/agent.go
index 3076385..960bf5a 100644
--- a/loop/agent.go
+++ b/loop/agent.go
@@ -17,10 +17,11 @@
"sync"
"time"
- "sketch.dev/ant"
"sketch.dev/browser"
"sketch.dev/claudetool"
"sketch.dev/claudetool/bashkit"
+ "sketch.dev/llm"
+ "sketch.dev/llm/conversation"
)
const (
@@ -64,8 +65,8 @@
// Returns the current number of messages in the history
MessageCount() int
- TotalUsage() ant.CumulativeUsage
- OriginalBudget() ant.Budget
+ TotalUsage() conversation.CumulativeUsage
+ OriginalBudget() conversation.Budget
WorkingDir() string
@@ -150,7 +151,7 @@
Timestamp time.Time `json:"timestamp"`
ConversationID string `json:"conversation_id"`
ParentConversationID *string `json:"parent_conversation_id,omitempty"`
- Usage *ant.Usage `json:"usage,omitempty"`
+ Usage *llm.Usage `json:"usage,omitempty"`
// Message timing information
StartTime *time.Time `json:"start_time,omitempty"`
@@ -164,7 +165,7 @@
}
// SetConvo sets m.ConversationID and m.ParentConversationID based on convo.
-func (m *AgentMessage) SetConvo(convo *ant.Convo) {
+func (m *AgentMessage) SetConvo(convo *conversation.Convo) {
if convo == nil {
m.ConversationID = ""
m.ParentConversationID = nil
@@ -262,16 +263,16 @@
// ConvoInterface defines the interface for conversation interactions
type ConvoInterface interface {
- CumulativeUsage() ant.CumulativeUsage
- ResetBudget(ant.Budget)
+ CumulativeUsage() conversation.CumulativeUsage
+ ResetBudget(conversation.Budget)
OverBudget() error
- SendMessage(message ant.Message) (*ant.MessageResponse, error)
- SendUserTextMessage(s string, otherContents ...ant.Content) (*ant.MessageResponse, error)
+ SendMessage(message llm.Message) (*llm.Response, error)
+ SendUserTextMessage(s string, otherContents ...llm.Content) (*llm.Response, error)
GetID() string
- ToolResultContents(ctx context.Context, resp *ant.MessageResponse) ([]ant.Content, error)
- ToolResultCancelContents(resp *ant.MessageResponse) ([]ant.Content, error)
+ ToolResultContents(ctx context.Context, resp *llm.Response) ([]llm.Content, error)
+ ToolResultCancelContents(resp *llm.Response) ([]llm.Content, error)
CancelToolUse(toolUseID string, cause error) error
- SubConvoWithHistory() *ant.Convo
+ SubConvoWithHistory() *conversation.Convo
}
type Agent struct {
@@ -287,7 +288,7 @@
outsideHTTP string // base address of the outside webserver (only when under docker)
ready chan struct{} // closed when the agent is initialized (only when under docker)
startedAt time.Time
- originalBudget ant.Budget
+ originalBudget conversation.Budget
title string
branchName string
codereview *claudetool.CodeReviewer
@@ -531,7 +532,7 @@
}
// OnToolCall implements ant.Listener and tracks the start of a tool call.
-func (a *Agent) OnToolCall(ctx context.Context, convo *ant.Convo, id string, toolName string, toolInput json.RawMessage, content ant.Content) {
+func (a *Agent) OnToolCall(ctx context.Context, convo *conversation.Convo, id string, toolName string, toolInput json.RawMessage, content llm.Content) {
// Track the tool call
a.mu.Lock()
a.outstandingToolCalls[id] = toolName
@@ -539,7 +540,7 @@
}
// OnToolResult implements ant.Listener.
-func (a *Agent) OnToolResult(ctx context.Context, convo *ant.Convo, toolID string, toolName string, toolInput json.RawMessage, content ant.Content, result *string, err error) {
+func (a *Agent) OnToolResult(ctx context.Context, convo *conversation.Convo, toolID string, toolName string, toolInput json.RawMessage, content llm.Content, result *string, err error) {
// Remove the tool call from outstanding calls
a.mu.Lock()
delete(a.outstandingToolCalls, toolID)
@@ -553,13 +554,13 @@
ToolName: toolName,
ToolInput: string(toolInput),
ToolCallId: content.ToolUseID,
- StartTime: content.StartTime,
- EndTime: content.EndTime,
+ StartTime: content.ToolUseStartTime,
+ EndTime: content.ToolUseEndTime,
}
// Calculate the elapsed time if both start and end times are set
- if content.StartTime != nil && content.EndTime != nil {
- elapsed := content.EndTime.Sub(*content.StartTime)
+ if content.ToolUseStartTime != nil && content.ToolUseEndTime != nil {
+ elapsed := content.ToolUseEndTime.Sub(*content.ToolUseStartTime)
m.Elapsed = &elapsed
}
@@ -568,18 +569,18 @@
}
// OnRequest implements ant.Listener.
-func (a *Agent) OnRequest(ctx context.Context, convo *ant.Convo, id string, msg *ant.Message) {
+func (a *Agent) OnRequest(ctx context.Context, convo *conversation.Convo, id string, msg *llm.Message) {
a.mu.Lock()
defer a.mu.Unlock()
a.outstandingLLMCalls[id] = struct{}{}
// We already get tool results from the above. We send user messages to the outbox in the agent loop.
}
-// OnResponse implements ant.Listener. Responses contain messages from the LLM
+// OnResponse implements conversation.Listener. Responses contain messages from the LLM
// that need to be displayed (as well as tool calls that we send along when
// they're done). (It would be reasonable to also mention tool calls when they're
// started, but we don't do that yet.)
-func (a *Agent) OnResponse(ctx context.Context, convo *ant.Convo, id string, resp *ant.MessageResponse) {
+func (a *Agent) OnResponse(ctx context.Context, convo *conversation.Convo, id string, resp *llm.Response) {
// Remove the LLM call from outstanding calls
a.mu.Lock()
delete(a.outstandingLLMCalls, id)
@@ -597,7 +598,7 @@
}
endOfTurn := false
- if resp.StopReason != ant.StopReasonToolUse && convo.Parent == nil {
+ if resp.StopReason != llm.StopReasonToolUse && convo.Parent == nil {
endOfTurn = true
}
m := AgentMessage{
@@ -610,10 +611,10 @@
}
// Extract any tool calls from the response
- if resp.StopReason == ant.StopReasonToolUse {
+ if resp.StopReason == llm.StopReasonToolUse {
var toolCalls []ToolCall
for _, part := range resp.Content {
- if part.Type == ant.ContentTypeToolUse {
+ if part.Type == llm.ContentTypeToolUse {
toolCalls = append(toolCalls, ToolCall{
Name: part.ToolName,
Input: string(part.ToolInput),
@@ -653,17 +654,15 @@
return slices.Clone(a.history[start:end])
}
-func (a *Agent) OriginalBudget() ant.Budget {
+func (a *Agent) OriginalBudget() conversation.Budget {
return a.originalBudget
}
// AgentConfig contains configuration for creating a new Agent.
type AgentConfig struct {
Context context.Context
- AntURL string
- APIKey string
- HTTPC *http.Client
- Budget ant.Budget
+ Service llm.Service
+ Budget conversation.Budget
GitUsername string
GitEmail string
SessionID string
@@ -778,15 +777,9 @@
// initConvo initializes the conversation.
// It must not be called until all agent fields are initialized,
// particularly workingDir and git.
-func (a *Agent) initConvo() *ant.Convo {
+func (a *Agent) initConvo() *conversation.Convo {
ctx := a.config.Context
- convo := ant.NewConvo(ctx, a.config.APIKey)
- if a.config.HTTPC != nil {
- convo.HTTPC = a.config.HTTPC
- }
- if a.config.AntURL != "" {
- convo.URL = a.config.AntURL
- }
+ convo := conversation.New(ctx, a.config.Service)
convo.PromptCaching = true
convo.Budget = a.config.Budget
@@ -832,7 +825,7 @@
// Register all tools with the conversation
// When adding, removing, or modifying tools here, double-check that the termui tool display
// template in termui/termui.go has pretty-printing support for all tools.
- convo.Tools = []*ant.Tool{
+ convo.Tools = []*llm.Tool{
bashTool, claudetool.Keyword,
claudetool.Think, a.titleTool(), makeDoneTool(a.codereview, a.config.GitUsername, a.config.GitEmail),
a.codereview.Tool(),
@@ -863,8 +856,8 @@
return false
}
-func (a *Agent) titleTool() *ant.Tool {
- title := &ant.Tool{
+func (a *Agent) titleTool() *llm.Tool {
+ title := &llm.Tool{
Name: "title",
Description: `Sets the conversation title and creates a git branch for tracking work. MANDATORY: You must use this tool before making any git commits.`,
InputSchema: json.RawMessage(`{
@@ -990,20 +983,20 @@
}
}
-func (a *Agent) GatherMessages(ctx context.Context, block bool) ([]ant.Content, error) {
- var m []ant.Content
+func (a *Agent) GatherMessages(ctx context.Context, block bool) ([]llm.Content, error) {
+ var m []llm.Content
if block {
select {
case <-ctx.Done():
return m, ctx.Err()
case msg := <-a.inbox:
- m = append(m, ant.StringContent(msg))
+ m = append(m, llm.StringContent(msg))
}
}
for {
select {
case msg := <-a.inbox:
- m = append(m, ant.StringContent(msg))
+ m = append(m, llm.StringContent(msg))
default:
return m, nil
}
@@ -1052,7 +1045,7 @@
}
// If the model is not requesting to use a tool, we're done
- if resp.StopReason != ant.StopReasonToolUse {
+ if resp.StopReason != llm.StopReasonToolUse {
a.stateMachine.Transition(ctx, StateEndOfTurn, "LLM completed response, ending turn")
break
}
@@ -1078,7 +1071,7 @@
}
// processUserMessage waits for user messages and sends them to the model
-func (a *Agent) processUserMessage(ctx context.Context) (*ant.MessageResponse, error) {
+func (a *Agent) processUserMessage(ctx context.Context) (*llm.Response, error) {
// Wait for at least one message from the user
msgs, err := a.GatherMessages(ctx, true)
if err != nil { // e.g. the context was canceled while blocking in GatherMessages
@@ -1086,8 +1079,8 @@
return nil, err
}
- userMessage := ant.Message{
- Role: ant.MessageRoleUser,
+ userMessage := llm.Message{
+ Role: llm.MessageRoleUser,
Content: msgs,
}
@@ -1109,8 +1102,8 @@
}
// handleToolExecution processes a tool use request from the model
-func (a *Agent) handleToolExecution(ctx context.Context, resp *ant.MessageResponse) (bool, *ant.MessageResponse) {
- var results []ant.Content
+func (a *Agent) handleToolExecution(ctx context.Context, resp *llm.Response) (bool, *llm.Response) {
+ var results []llm.Content
cancelled := false
// Transition to checking for cancellation state
@@ -1200,7 +1193,7 @@
}
// continueTurnWithToolResults continues the conversation with tool results
-func (a *Agent) continueTurnWithToolResults(ctx context.Context, results []ant.Content, autoqualityMessages []string, cancelled bool) (bool, *ant.MessageResponse) {
+func (a *Agent) continueTurnWithToolResults(ctx context.Context, results []llm.Content, autoqualityMessages []string, cancelled bool) (bool, *llm.Response) {
// Get any messages the user sent while tools were executing
a.stateMachine.Transition(ctx, StateGatheringAdditionalMessages, "Gathering additional user messages")
msgs, err := a.GatherMessages(ctx, false)
@@ -1211,19 +1204,19 @@
// Inject any auto-generated messages from quality checks
for _, msg := range autoqualityMessages {
- msgs = append(msgs, ant.StringContent(msg))
+ msgs = append(msgs, llm.StringContent(msg))
}
// Handle cancellation by appending a message about it
if cancelled {
- msgs = append(msgs, ant.StringContent(cancelToolUseMessage))
+ msgs = append(msgs, llm.StringContent(cancelToolUseMessage))
// EndOfTurn is false here so that the client of this agent keeps processing
// further messages; the conversation is not over.
a.pushToOutbox(ctx, AgentMessage{Type: ErrorMessageType, Content: userCancelMessage, EndOfTurn: false})
} else if err := a.convo.OverBudget(); err != nil {
// Handle budget issues by appending a message about it
budgetMsg := "We've exceeded our budget. Please ask the user to confirm before continuing by ending the turn."
- msgs = append(msgs, ant.StringContent(budgetMsg))
+ msgs = append(msgs, llm.StringContent(budgetMsg))
a.pushToOutbox(ctx, budgetMessage(fmt.Errorf("warning: %w (ask to keep trying, if you'd like)", err)))
}
@@ -1232,8 +1225,8 @@
// Send the combined message to continue the conversation
a.stateMachine.Transition(ctx, StateSendingToolResults, "Sending tool results back to LLM")
- resp, err := a.convo.SendMessage(ant.Message{
- Role: ant.MessageRoleUser,
+ resp, err := a.convo.SendMessage(llm.Message{
+ Role: llm.MessageRoleUser,
Content: results,
})
if err != nil {
@@ -1264,11 +1257,11 @@
return nil
}
-func collectTextContent(msg *ant.MessageResponse) string {
+func collectTextContent(msg *llm.Response) string {
// Collect all text content
var allText strings.Builder
for _, content := range msg.Content {
- if content.Type == ant.ContentTypeText && content.Text != "" {
+ if content.Type == llm.ContentTypeText && content.Text != "" {
if allText.Len() > 0 {
allText.WriteString("\n\n")
}
@@ -1278,7 +1271,7 @@
return allText.String()
}
-func (a *Agent) TotalUsage() ant.CumulativeUsage {
+func (a *Agent) TotalUsage() conversation.CumulativeUsage {
a.mu.Lock()
defer a.mu.Unlock()
return a.convo.CumulativeUsage()
@@ -1604,7 +1597,7 @@
Reply with ONLY the reprompt text.
`
- userMessage := ant.UserStringMessage(msg)
+ userMessage := llm.UserStringMessage(msg)
// By doing this in a subconversation, the agent doesn't call tools (because
// there aren't any), and there's not a concurrency risk with on-going other
// outstanding conversations.
diff --git a/loop/agent_test.go b/loop/agent_test.go
index 0924b39..56708e3 100644
--- a/loop/agent_test.go
+++ b/loop/agent_test.go
@@ -11,8 +11,10 @@
"testing"
"time"
- "sketch.dev/ant"
"sketch.dev/httprr"
+ "sketch.dev/llm"
+ "sketch.dev/llm/ant"
+ "sketch.dev/llm/conversation"
)
// TestAgentLoop tests that the Agent loop functionality works correctly.
@@ -58,7 +60,7 @@
if err := os.Chdir("/"); err != nil {
t.Fatal(err)
}
- budget := ant.Budget{MaxResponses: 100}
+ budget := conversation.Budget{MaxResponses: 100}
wd, err := os.Getwd()
if err != nil {
t.Fatal(err)
@@ -66,9 +68,11 @@
apiKey := cmp.Or(os.Getenv("OUTER_SKETCH_ANTHROPIC_API_KEY"), os.Getenv("ANTHROPIC_API_KEY"))
cfg := AgentConfig{
- Context: ctx,
- APIKey: apiKey,
- HTTPC: client,
+ Context: ctx,
+ Service: &ant.Service{
+ APIKey: apiKey,
+ HTTPC: client,
+ },
Budget: budget,
GitUsername: "Test Agent",
GitEmail: "totallyhuman@sketch.dev",
@@ -206,7 +210,7 @@
func TestAgentProcessTurnWithNilResponse(t *testing.T) {
// Create a mock conversation that will return nil and error
mockConvo := &MockConvoInterface{
- sendMessageFunc: func(message ant.Message) (*ant.MessageResponse, error) {
+ sendMessageFunc: func(message llm.Message) (*llm.Response, error) {
return nil, fmt.Errorf("test error: simulating nil response")
},
}
@@ -250,40 +254,40 @@
// MockConvoInterface implements the ConvoInterface for testing
type MockConvoInterface struct {
- sendMessageFunc func(message ant.Message) (*ant.MessageResponse, error)
- sendUserTextMessageFunc func(s string, otherContents ...ant.Content) (*ant.MessageResponse, error)
- toolResultContentsFunc func(ctx context.Context, resp *ant.MessageResponse) ([]ant.Content, error)
- toolResultCancelContentsFunc func(resp *ant.MessageResponse) ([]ant.Content, error)
+ sendMessageFunc func(message llm.Message) (*llm.Response, error)
+ sendUserTextMessageFunc func(s string, otherContents ...llm.Content) (*llm.Response, error)
+ toolResultContentsFunc func(ctx context.Context, resp *llm.Response) ([]llm.Content, error)
+ toolResultCancelContentsFunc func(resp *llm.Response) ([]llm.Content, error)
cancelToolUseFunc func(toolUseID string, cause error) error
- cumulativeUsageFunc func() ant.CumulativeUsage
- resetBudgetFunc func(ant.Budget)
+ cumulativeUsageFunc func() conversation.CumulativeUsage
+ resetBudgetFunc func(conversation.Budget)
overBudgetFunc func() error
getIDFunc func() string
- subConvoWithHistoryFunc func() *ant.Convo
+ subConvoWithHistoryFunc func() *conversation.Convo
}
-func (m *MockConvoInterface) SendMessage(message ant.Message) (*ant.MessageResponse, error) {
+func (m *MockConvoInterface) SendMessage(message llm.Message) (*llm.Response, error) {
if m.sendMessageFunc != nil {
return m.sendMessageFunc(message)
}
return nil, nil
}
-func (m *MockConvoInterface) SendUserTextMessage(s string, otherContents ...ant.Content) (*ant.MessageResponse, error) {
+func (m *MockConvoInterface) SendUserTextMessage(s string, otherContents ...llm.Content) (*llm.Response, error) {
if m.sendUserTextMessageFunc != nil {
return m.sendUserTextMessageFunc(s, otherContents...)
}
return nil, nil
}
-func (m *MockConvoInterface) ToolResultContents(ctx context.Context, resp *ant.MessageResponse) ([]ant.Content, error) {
+func (m *MockConvoInterface) ToolResultContents(ctx context.Context, resp *llm.Response) ([]llm.Content, error) {
if m.toolResultContentsFunc != nil {
return m.toolResultContentsFunc(ctx, resp)
}
return nil, nil
}
-func (m *MockConvoInterface) ToolResultCancelContents(resp *ant.MessageResponse) ([]ant.Content, error) {
+func (m *MockConvoInterface) ToolResultCancelContents(resp *llm.Response) ([]llm.Content, error) {
if m.toolResultCancelContentsFunc != nil {
return m.toolResultCancelContentsFunc(resp)
}
@@ -297,14 +301,14 @@
return nil
}
-func (m *MockConvoInterface) CumulativeUsage() ant.CumulativeUsage {
+func (m *MockConvoInterface) CumulativeUsage() conversation.CumulativeUsage {
if m.cumulativeUsageFunc != nil {
return m.cumulativeUsageFunc()
}
- return ant.CumulativeUsage{}
+ return conversation.CumulativeUsage{}
}
-func (m *MockConvoInterface) ResetBudget(budget ant.Budget) {
+func (m *MockConvoInterface) ResetBudget(budget conversation.Budget) {
if m.resetBudgetFunc != nil {
m.resetBudgetFunc(budget)
}
@@ -324,7 +328,7 @@
return "mock-convo-id"
}
-func (m *MockConvoInterface) SubConvoWithHistory() *ant.Convo {
+func (m *MockConvoInterface) SubConvoWithHistory() *conversation.Convo {
if m.subConvoWithHistoryFunc != nil {
return m.subConvoWithHistoryFunc()
}
@@ -337,7 +341,7 @@
func TestAgentProcessTurnWithNilResponseNilError(t *testing.T) {
// Create a mock conversation that will return nil response and nil error
mockConvo := &MockConvoInterface{
- sendMessageFunc: func(message ant.Message) (*ant.MessageResponse, error) {
+ sendMessageFunc: func(message llm.Message) (*llm.Response, error) {
return nil, nil // This is unusual but now handled gracefully
},
}
@@ -464,48 +468,48 @@
// mockConvoInterface is a mock implementation of ConvoInterface for testing
type mockConvoInterface struct {
- SendMessageFunc func(message ant.Message) (*ant.MessageResponse, error)
- ToolResultContentsFunc func(ctx context.Context, resp *ant.MessageResponse) ([]ant.Content, error)
+ SendMessageFunc func(message llm.Message) (*llm.Response, error)
+ ToolResultContentsFunc func(ctx context.Context, resp *llm.Response) ([]llm.Content, error)
}
func (c *mockConvoInterface) GetID() string {
return "mockConvoInterface-id"
}
-func (c *mockConvoInterface) SubConvoWithHistory() *ant.Convo {
+func (c *mockConvoInterface) SubConvoWithHistory() *conversation.Convo {
return nil
}
-func (m *mockConvoInterface) CumulativeUsage() ant.CumulativeUsage {
- return ant.CumulativeUsage{}
+func (m *mockConvoInterface) CumulativeUsage() conversation.CumulativeUsage {
+ return conversation.CumulativeUsage{}
}
-func (m *mockConvoInterface) ResetBudget(ant.Budget) {}
+func (m *mockConvoInterface) ResetBudget(conversation.Budget) {}
func (m *mockConvoInterface) OverBudget() error {
return nil
}
-func (m *mockConvoInterface) SendMessage(message ant.Message) (*ant.MessageResponse, error) {
+func (m *mockConvoInterface) SendMessage(message llm.Message) (*llm.Response, error) {
if m.SendMessageFunc != nil {
return m.SendMessageFunc(message)
}
- return &ant.MessageResponse{StopReason: ant.StopReasonEndTurn}, nil
+ return &llm.Response{StopReason: llm.StopReasonEndTurn}, nil
}
-func (m *mockConvoInterface) SendUserTextMessage(s string, otherContents ...ant.Content) (*ant.MessageResponse, error) {
- return m.SendMessage(ant.UserStringMessage(s))
+func (m *mockConvoInterface) SendUserTextMessage(s string, otherContents ...llm.Content) (*llm.Response, error) {
+ return m.SendMessage(llm.UserStringMessage(s))
}
-func (m *mockConvoInterface) ToolResultContents(ctx context.Context, resp *ant.MessageResponse) ([]ant.Content, error) {
+func (m *mockConvoInterface) ToolResultContents(ctx context.Context, resp *llm.Response) ([]llm.Content, error) {
if m.ToolResultContentsFunc != nil {
return m.ToolResultContentsFunc(ctx, resp)
}
- return []ant.Content{}, nil
+ return []llm.Content{}, nil
}
-func (m *mockConvoInterface) ToolResultCancelContents(resp *ant.MessageResponse) ([]ant.Content, error) {
- return []ant.Content{ant.StringContent("Tool use cancelled")}, nil
+func (m *mockConvoInterface) ToolResultCancelContents(resp *llm.Response) ([]llm.Content, error) {
+ return []llm.Content{llm.StringContent("Tool use cancelled")}, nil
}
func (m *mockConvoInterface) CancelToolUse(toolUseID string, cause error) error {
@@ -542,11 +546,11 @@
agent.inbox <- "Test message"
// Setup the mock to simulate a model response with end of turn
- mockConvo.SendMessageFunc = func(message ant.Message) (*ant.MessageResponse, error) {
- return &ant.MessageResponse{
- StopReason: ant.StopReasonEndTurn,
- Content: []ant.Content{
- ant.StringContent("This is a test response"),
+ mockConvo.SendMessageFunc = func(message llm.Message) (*llm.Response, error) {
+ return &llm.Response{
+ StopReason: llm.StopReasonEndTurn,
+ Content: []llm.Content{
+ llm.StringContent("This is a test response"),
},
}, nil
}
@@ -615,29 +619,29 @@
// First response requests a tool
firstResponseDone := false
- mockConvo.SendMessageFunc = func(message ant.Message) (*ant.MessageResponse, error) {
+ mockConvo.SendMessageFunc = func(message llm.Message) (*llm.Response, error) {
if !firstResponseDone {
firstResponseDone = true
- return &ant.MessageResponse{
- StopReason: ant.StopReasonToolUse,
- Content: []ant.Content{
- ant.StringContent("I'll use a tool"),
- {Type: ant.ContentTypeToolUse, ToolName: "test_tool", ToolInput: []byte("{}"), ID: "test_id"},
+ return &llm.Response{
+ StopReason: llm.StopReasonToolUse,
+ Content: []llm.Content{
+ llm.StringContent("I'll use a tool"),
+ {Type: llm.ContentTypeToolUse, ToolName: "test_tool", ToolInput: []byte("{}"), ID: "test_id"},
},
}, nil
}
// Second response ends the turn
- return &ant.MessageResponse{
- StopReason: ant.StopReasonEndTurn,
- Content: []ant.Content{
- ant.StringContent("Finished using the tool"),
+ return &llm.Response{
+ StopReason: llm.StopReasonEndTurn,
+ Content: []llm.Content{
+ llm.StringContent("Finished using the tool"),
},
}, nil
}
// Tool result content handler
- mockConvo.ToolResultContentsFunc = func(ctx context.Context, resp *ant.MessageResponse) ([]ant.Content, error) {
- return []ant.Content{ant.StringContent("Tool executed successfully")}, nil
+ mockConvo.ToolResultContentsFunc = func(ctx context.Context, resp *llm.Response) ([]llm.Content, error) {
+ return []llm.Content{llm.StringContent("Tool executed successfully")}, nil
}
// Track state transitions
diff --git a/loop/donetool.go b/loop/donetool.go
index e4b0542..63604d8 100644
--- a/loop/donetool.go
+++ b/loop/donetool.go
@@ -5,8 +5,8 @@
"encoding/json"
"fmt"
- "sketch.dev/ant"
"sketch.dev/claudetool"
+ "sketch.dev/llm"
)
// makeDoneTool creates a tool that provides a checklist to the agent. There
@@ -14,8 +14,8 @@
// not as reliable as it could be. Historically, we've found that Claude ignores
// the tool results here, so we don't tell the tool to say "hey, really check this"
// at the moment, though we've tried.
-func makeDoneTool(codereview *claudetool.CodeReviewer, gitUsername, gitEmail string) *ant.Tool {
- return &ant.Tool{
+func makeDoneTool(codereview *claudetool.CodeReviewer, gitUsername, gitEmail string) *llm.Tool {
+ return &llm.Tool{
Name: "done",
Description: `Use this tool when you have achieved the user's goal. The parameters form a checklist which you should evaluate.`,
InputSchema: json.RawMessage(doneChecklistJSONSchema(gitUsername, gitEmail)),
diff --git a/loop/mocks.go b/loop/mocks.go
index 7e05070..811ab2c 100644
--- a/loop/mocks.go
+++ b/loop/mocks.go
@@ -6,10 +6,11 @@
"sync"
"testing"
- "sketch.dev/ant"
+ "sketch.dev/llm"
+ "sketch.dev/llm/conversation"
)
-// MockConvo is a custom mock for ant.Convo interface
+// MockConvo is a custom mock for conversation.Convo interface
type MockConvo struct {
mu sync.Mutex
t *testing.T
@@ -21,23 +22,23 @@
}
type mockCall struct {
- args []interface{}
- result []interface{}
+ args []any
+ result []any
}
type mockExpectation struct {
until chan any
- args []interface{}
- result []interface{}
+ args []any
+ result []any
}
// Return sets up return values for an expectation
-func (e *mockExpectation) Return(values ...interface{}) {
+func (e *mockExpectation) Return(values ...any) {
e.result = values
}
// Return sets up return values for an expectation
-func (e *mockExpectation) BlockAndReturn(until chan any, values ...interface{}) {
+func (e *mockExpectation) BlockAndReturn(until chan any, values ...any) {
e.until = until
e.result = values
}
@@ -53,7 +54,7 @@
}
// ExpectCall sets up an expectation for a method call
-func (m *MockConvo) ExpectCall(method string, args ...interface{}) *mockExpectation {
+func (m *MockConvo) ExpectCall(method string, args ...any) *mockExpectation {
m.mu.Lock()
defer m.mu.Unlock()
expectation := &mockExpectation{args: args}
@@ -65,7 +66,7 @@
}
// findMatchingExpectation finds a matching expectation for a method call
-func (m *MockConvo) findMatchingExpectation(method string, args ...interface{}) (*mockExpectation, bool) {
+func (m *MockConvo) findMatchingExpectation(method string, args ...any) (*mockExpectation, bool) {
m.mu.Lock()
defer m.mu.Unlock()
expectations, ok := m.expectations[method]
@@ -87,7 +88,7 @@
}
// matchArgs checks if call arguments match expectation arguments
-func matchArgs(expected, actual []interface{}) bool {
+func matchArgs(expected, actual []any) bool {
if len(expected) != len(actual) {
return false
}
@@ -107,7 +108,7 @@
}
// recordCall records a method call
-func (m *MockConvo) recordCall(method string, args ...interface{}) {
+func (m *MockConvo) recordCall(method string, args ...any) {
m.mu.Lock()
defer m.mu.Unlock()
if _, ok := m.calls[method]; !ok {
@@ -116,7 +117,7 @@
m.calls[method] = append(m.calls[method], &mockCall{args: args})
}
-func (m *MockConvo) SendMessage(message ant.Message) (*ant.MessageResponse, error) {
+func (m *MockConvo) SendMessage(message llm.Message) (*llm.Response, error) {
m.recordCall("SendMessage", message)
exp, ok := m.findMatchingExpectation("SendMessage", message)
if !ok {
@@ -129,10 +130,10 @@
if err, ok := exp.result[1].(error); ok {
retErr = err
}
- return exp.result[0].(*ant.MessageResponse), retErr
+ return exp.result[0].(*llm.Response), retErr
}
-func (m *MockConvo) SendUserTextMessage(message string, otherContents ...ant.Content) (*ant.MessageResponse, error) {
+func (m *MockConvo) SendUserTextMessage(message string, otherContents ...llm.Content) (*llm.Response, error) {
m.recordCall("SendUserTextMessage", message, otherContents)
exp, ok := m.findMatchingExpectation("SendUserTextMessage", message, otherContents)
if !ok {
@@ -145,10 +146,10 @@
if err, ok := exp.result[1].(error); ok {
retErr = err
}
- return exp.result[0].(*ant.MessageResponse), retErr
+ return exp.result[0].(*llm.Response), retErr
}
-func (m *MockConvo) ToolResultContents(ctx context.Context, resp *ant.MessageResponse) ([]ant.Content, error) {
+func (m *MockConvo) ToolResultContents(ctx context.Context, resp *llm.Response) ([]llm.Content, error) {
m.recordCall("ToolResultContents", resp)
exp, ok := m.findMatchingExpectation("ToolResultContents", resp)
if !ok {
@@ -162,10 +163,10 @@
retErr = err
}
- return exp.result[0].([]ant.Content), retErr
+ return exp.result[0].([]llm.Content), retErr
}
-func (m *MockConvo) ToolResultCancelContents(resp *ant.MessageResponse) ([]ant.Content, error) {
+func (m *MockConvo) ToolResultCancelContents(resp *llm.Response) ([]llm.Content, error) {
m.recordCall("ToolResultCancelContents", resp)
exp, ok := m.findMatchingExpectation("ToolResultCancelContents", resp)
if !ok {
@@ -179,12 +180,12 @@
retErr = err
}
- return exp.result[0].([]ant.Content), retErr
+ return exp.result[0].([]llm.Content), retErr
}
-func (m *MockConvo) CumulativeUsage() ant.CumulativeUsage {
+func (m *MockConvo) CumulativeUsage() conversation.CumulativeUsage {
m.recordCall("CumulativeUsage")
- return ant.CumulativeUsage{}
+ return conversation.CumulativeUsage{}
}
func (m *MockConvo) OverBudget() error {
@@ -197,12 +198,12 @@
return "mock-conversation-id"
}
-func (m *MockConvo) SubConvoWithHistory() *ant.Convo {
+func (m *MockConvo) SubConvoWithHistory() *conversation.Convo {
m.recordCall("SubConvoWithHistory")
return nil
}
-func (m *MockConvo) ResetBudget(_ ant.Budget) {
+func (m *MockConvo) ResetBudget(_ conversation.Budget) {
m.recordCall("ResetBudget")
}
diff --git a/loop/server/loophttp.go b/loop/server/loophttp.go
index 4a415c8..f7a3979 100644
--- a/loop/server/loophttp.go
+++ b/loop/server/loophttp.go
@@ -23,7 +23,7 @@
"sketch.dev/loop/server/gzhandler"
"github.com/creack/pty"
- "sketch.dev/ant"
+ "sketch.dev/llm/conversation"
"sketch.dev/loop"
"sketch.dev/webui"
)
@@ -50,29 +50,29 @@
}
type State struct {
- MessageCount int `json:"message_count"`
- TotalUsage *ant.CumulativeUsage `json:"total_usage,omitempty"`
- InitialCommit string `json:"initial_commit"`
- Title string `json:"title"`
- BranchName string `json:"branch_name,omitempty"`
- Hostname string `json:"hostname"` // deprecated
- WorkingDir string `json:"working_dir"` // deprecated
- OS string `json:"os"` // deprecated
- GitOrigin string `json:"git_origin,omitempty"`
- OutstandingLLMCalls int `json:"outstanding_llm_calls"`
- OutstandingToolCalls []string `json:"outstanding_tool_calls"`
- SessionID string `json:"session_id"`
- SSHAvailable bool `json:"ssh_available"`
- SSHError string `json:"ssh_error,omitempty"`
- InContainer bool `json:"in_container"`
- FirstMessageIndex int `json:"first_message_index"`
- AgentState string `json:"agent_state,omitempty"`
- OutsideHostname string `json:"outside_hostname,omitempty"`
- InsideHostname string `json:"inside_hostname,omitempty"`
- OutsideOS string `json:"outside_os,omitempty"`
- InsideOS string `json:"inside_os,omitempty"`
- OutsideWorkingDir string `json:"outside_working_dir,omitempty"`
- InsideWorkingDir string `json:"inside_working_dir,omitempty"`
+ MessageCount int `json:"message_count"`
+ TotalUsage *conversation.CumulativeUsage `json:"total_usage,omitempty"`
+ InitialCommit string `json:"initial_commit"`
+ Title string `json:"title"`
+ BranchName string `json:"branch_name,omitempty"`
+ Hostname string `json:"hostname"` // deprecated
+ WorkingDir string `json:"working_dir"` // deprecated
+ OS string `json:"os"` // deprecated
+ GitOrigin string `json:"git_origin,omitempty"`
+ OutstandingLLMCalls int `json:"outstanding_llm_calls"`
+ OutstandingToolCalls []string `json:"outstanding_tool_calls"`
+ SessionID string `json:"session_id"`
+ SSHAvailable bool `json:"ssh_available"`
+ SSHError string `json:"ssh_error,omitempty"`
+ InContainer bool `json:"in_container"`
+ FirstMessageIndex int `json:"first_message_index"`
+ AgentState string `json:"agent_state,omitempty"`
+ OutsideHostname string `json:"outside_hostname,omitempty"`
+ InsideHostname string `json:"inside_hostname,omitempty"`
+ OutsideOS string `json:"outside_os,omitempty"`
+ InsideOS string `json:"inside_os,omitempty"`
+ OutsideWorkingDir string `json:"outside_working_dir,omitempty"`
+ InsideWorkingDir string `json:"inside_working_dir,omitempty"`
}
type InitRequest struct {
@@ -298,12 +298,12 @@
// Create a combined structure with all information
downloadData := struct {
- Messages []loop.AgentMessage `json:"messages"`
- MessageCount int `json:"message_count"`
- TotalUsage ant.CumulativeUsage `json:"total_usage"`
- Hostname string `json:"hostname"`
- WorkingDir string `json:"working_dir"`
- DownloadTime string `json:"download_time"`
+ Messages []loop.AgentMessage `json:"messages"`
+ MessageCount int `json:"message_count"`
+ TotalUsage conversation.CumulativeUsage `json:"total_usage"`
+ Hostname string `json:"hostname"`
+ WorkingDir string `json:"working_dir"`
+ DownloadTime string `json:"download_time"`
}{
Messages: messages,
MessageCount: messageCount,