blob: 694b977c767a8ab63b245224593f06e0e1c0171c [file] [log] [blame]
package conversation
import (
"cmp"
"context"
"net/http"
"os"
"slices"
"strings"
"testing"
"sketch.dev/httprr"
"sketch.dev/llm"
"sketch.dev/llm/ant"
)
// TestBasicConvo runs a two-turn conversation through an ant.Service whose
// HTTP client comes from the httprr file testdata/basic_convo.httprr. It
// checks that the reply to "What is my name?" contains the name given in the
// first turn.
//
// The API key is read from OUTER_SKETCH_MODEL_API_KEY, falling back to
// ANTHROPIC_API_KEY when that is empty. The x-api-key header is deleted from
// every request via rr.ScrubReq.
// NOTE(review): presumably this keeps the key out of the recording and the
// key only matters when httprr is recording rather than replaying — confirm.
func TestBasicConvo(t *testing.T) {
	ctx := context.Background()
	rr, err := httprr.Open("testdata/basic_convo.httprr", http.DefaultTransport)
	if err != nil {
		t.Fatal(err)
	}
	// Never let the credential header reach the httprr file.
	rr.ScrubReq(func(req *http.Request) error {
		req.Header.Del("x-api-key")
		return nil
	})
	apiKey := cmp.Or(os.Getenv("OUTER_SKETCH_MODEL_API_KEY"), os.Getenv("ANTHROPIC_API_KEY"))
	srv := &ant.Service{
		APIKey: apiKey,
		HTTPC:  rr.Client(),
	}
	convo := New(ctx, srv, nil)

	const name = "Cornelius"
	res, err := convo.SendUserTextMessage("Hi, my name is " + name)
	if err != nil {
		t.Fatal(err)
	}
	for _, part := range res.Content {
		t.Logf("%s", part.Text)
	}

	res, err = convo.SendUserTextMessage("What is my name?")
	if err != nil {
		t.Fatal(err)
	}
	// The name may appear in any content part, so search their concatenation.
	var reply strings.Builder
	for _, part := range res.Content {
		reply.WriteString(part.Text)
	}
	got := reply.String()
	if !strings.Contains(got, name) {
		t.Errorf("model does not know the given name %s: %q", name, got)
	}
}
// TestCancelToolUse exercises Convo.CancelToolUse with three table cases. Two
// cancel a registered tool use: one with a nil cause and one with
// context.Canceled. The third cancels an ID that was never registered.
//
// For registered IDs, it installs a recording cancel function in
// convo.toolUseCancel (under toolUseCancelMu). It then checks that
// CancelToolUse returns no error, invokes that function with the given cause,
// and removes the ID from the map. For the unregistered ID, it expects an
// error and no cancel call.
func TestCancelToolUse(t *testing.T) {
	type cancelCase struct {
		name         string
		setupToolUse bool
		toolUseID    string
		cancelErr    error
		expectError  bool
		expectCancel bool
	}
	cases := []cancelCase{
		{name: "Cancel existing tool use", setupToolUse: true, toolUseID: "tool123", cancelErr: nil, expectError: false, expectCancel: true},
		{name: "Cancel existing tool use with error", setupToolUse: true, toolUseID: "tool456", cancelErr: context.Canceled, expectError: false, expectCancel: true},
		{name: "Cancel non-existent tool use", setupToolUse: false, toolUseID: "tool789", cancelErr: nil, expectError: true, expectCancel: false},
	}

	service := &ant.Service{}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			c := New(context.Background(), service, nil)

			var (
				invoked  bool
				gotCause error
			)
			if tc.setupToolUse {
				// Register a stand-in cancel function that records its invocation.
				c.toolUseCancelMu.Lock()
				c.toolUseCancel[tc.toolUseID] = func(cause error) {
					invoked = true
					gotCause = cause
				}
				c.toolUseCancelMu.Unlock()
			}

			err := c.CancelToolUse(tc.toolUseID, tc.cancelErr)

			if failed := err != nil; failed != tc.expectError {
				t.Errorf("CancelToolUse() error = %v, expectError %v", err, tc.expectError)
			}
			if invoked != tc.expectCancel {
				t.Errorf("Cancel function called = %v, expectCancel %v", invoked, tc.expectCancel)
			}
			if tc.expectCancel && gotCause != tc.cancelErr {
				t.Errorf("Cancel function called with error = %v, expected %v", gotCause, tc.cancelErr)
			}

			// Only registered IDs can be checked for removal from the map.
			if !tc.setupToolUse {
				return
			}
			c.toolUseCancelMu.Lock()
			_, stillRegistered := c.toolUseCancel[tc.toolUseID]
			c.toolUseCancelMu.Unlock()
			if stillRegistered {
				t.Errorf("toolUseID %s still exists in the map after cancellation", tc.toolUseID)
			}
		})
	}
}
// TestInsertMissingToolResults tests Convo.insertMissingToolResults. In
// particular, it checks that no duplicate tool results are created when
// several tool uses in the preceding assistant message are missing results.
//
// Each case builds a request whose Messages are the case's history followed
// by currentMsg. It calls insertMissingToolResults on a copy of currentMsg and
// then inspects the tool-result contents of that copy, checking three things:
// their count, that no ToolUseID appears more than once, and that every
// expected ToolUseID is present.
func TestInsertMissingToolResults(t *testing.T) {
	tests := []struct {
		name            string
		messages        []llm.Message
		currentMsg      llm.Message
		expectedCount   int
		expectedToolIDs []string
	}{
		{
			name: "Single missing tool result",
			messages: []llm.Message{
				{
					Role: llm.MessageRoleAssistant,
					Content: []llm.Content{
						{
							Type: llm.ContentTypeToolUse,
							ID:   "tool1",
						},
					},
				},
			},
			currentMsg: llm.Message{
				Role:    llm.MessageRoleUser,
				Content: []llm.Content{},
			},
			expectedCount:   1,
			expectedToolIDs: []string{"tool1"},
		},
		{
			name: "Multiple missing tool results",
			messages: []llm.Message{
				{
					Role: llm.MessageRoleAssistant,
					Content: []llm.Content{
						{
							Type: llm.ContentTypeToolUse,
							ID:   "tool1",
						},
						{
							Type: llm.ContentTypeToolUse,
							ID:   "tool2",
						},
						{
							Type: llm.ContentTypeToolUse,
							ID:   "tool3",
						},
					},
				},
			},
			currentMsg: llm.Message{
				Role:    llm.MessageRoleUser,
				Content: []llm.Content{},
			},
			expectedCount:   3,
			expectedToolIDs: []string{"tool1", "tool2", "tool3"},
		},
		{
			name: "No missing tool results when results already present",
			messages: []llm.Message{
				{
					Role: llm.MessageRoleAssistant,
					Content: []llm.Content{
						{
							Type: llm.ContentTypeToolUse,
							ID:   "tool1",
						},
					},
				},
			},
			currentMsg: llm.Message{
				Role: llm.MessageRoleUser,
				Content: []llm.Content{
					{
						Type:      llm.ContentTypeToolResult,
						ToolUseID: "tool1",
					},
				},
			},
			expectedCount:   1, // Only the existing one
			expectedToolIDs: []string{"tool1"},
		},
		{
			name: "No tool uses in previous message",
			messages: []llm.Message{
				{
					Role: llm.MessageRoleAssistant,
					Content: []llm.Content{
						{
							Type: llm.ContentTypeText,
							Text: "Just some text",
						},
					},
				},
			},
			currentMsg: llm.Message{
				Role:    llm.MessageRoleUser,
				Content: []llm.Content{},
			},
			expectedCount:   0,
			expectedToolIDs: []string{},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			srv := &ant.Service{}
			convo := New(context.Background(), srv, nil)
			// Clip before appending. That forces append to allocate a fresh
			// backing array, so the request can never share storage with (and
			// write into) the test table's messages slice.
			req := &llm.Request{
				Messages: append(slices.Clip(tt.messages), tt.currentMsg),
			}
			msg := tt.currentMsg
			convo.insertMissingToolResults(req, &msg)

			// Collect tool results in order, tallying each ToolUseID as we go.
			var toolIDs []string
			seenIDs := make(map[string]int)
			for _, content := range msg.Content {
				if content.Type != llm.ContentTypeToolResult {
					continue
				}
				toolIDs = append(toolIDs, content.ToolUseID)
				seenIDs[content.ToolUseID]++
			}

			if len(toolIDs) != tt.expectedCount {
				t.Errorf("Expected %d tool results, got %d", tt.expectedCount, len(toolIDs))
			}
			for id, count := range seenIDs {
				if count > 1 {
					t.Errorf("Duplicate tool result for ID %s: found %d times", id, count)
				}
			}
			for _, expectedID := range tt.expectedToolIDs {
				if seenIDs[expectedID] == 0 {
					t.Errorf("Expected tool ID %s not found in results", expectedID)
				}
			}
		})
	}
}