blob: 811ab2c25c9b64d055085ecdc5c940c92e8ac608 [file] [log] [blame]
Earl Lee2e463fb2025-04-17 11:22:22 -07001package loop
2
3import (
4 "context"
5 "reflect"
6 "sync"
7 "testing"
8
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -07009 "sketch.dev/llm"
10 "sketch.dev/llm/conversation"
Earl Lee2e463fb2025-04-17 11:22:22 -070011)
12
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -070013// MockConvo is a custom mock for conversation.Convo interface
Earl Lee2e463fb2025-04-17 11:22:22 -070014type MockConvo struct {
15 mu sync.Mutex
16 t *testing.T
17
18 // Maps method name to a list of calls with arguments and return values
19 calls map[string][]*mockCall
20 // Maps method name to expected calls
21 expectations map[string][]*mockExpectation
22}
23
// mockCall is one entry in MockConvo's call log: the arguments a method was
// invoked with, plus a result slot.
type mockCall struct {
	args, result []any
}
28
// mockExpectation describes one anticipated method call on a MockConvo.
type mockExpectation struct {
	// until, when non-nil, must yield a value (or be closed) before the
	// matched call is allowed to return.
	until chan any
	// args are per-position matchers (nil matches anything); result holds the
	// values the mocked method hands back.
	args, result []any
}
34
35// Return sets up return values for an expectation
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -070036func (e *mockExpectation) Return(values ...any) {
Earl Lee2e463fb2025-04-17 11:22:22 -070037 e.result = values
38}
39
40// Return sets up return values for an expectation
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -070041func (e *mockExpectation) BlockAndReturn(until chan any, values ...any) {
Earl Lee2e463fb2025-04-17 11:22:22 -070042 e.until = until
43 e.result = values
44}
45
46// NewMockConvo creates a new mock Convo
47func NewMockConvo(t *testing.T) *MockConvo {
48 return &MockConvo{
49 t: t,
50 mu: sync.Mutex{},
51 calls: make(map[string][]*mockCall),
52 expectations: make(map[string][]*mockExpectation),
53 }
54}
55
56// ExpectCall sets up an expectation for a method call
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -070057func (m *MockConvo) ExpectCall(method string, args ...any) *mockExpectation {
Earl Lee2e463fb2025-04-17 11:22:22 -070058 m.mu.Lock()
59 defer m.mu.Unlock()
60 expectation := &mockExpectation{args: args}
61 if _, ok := m.expectations[method]; !ok {
62 m.expectations[method] = []*mockExpectation{}
63 }
64 m.expectations[method] = append(m.expectations[method], expectation)
65 return expectation
66}
67
68// findMatchingExpectation finds a matching expectation for a method call
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -070069func (m *MockConvo) findMatchingExpectation(method string, args ...any) (*mockExpectation, bool) {
Earl Lee2e463fb2025-04-17 11:22:22 -070070 m.mu.Lock()
71 defer m.mu.Unlock()
72 expectations, ok := m.expectations[method]
73 if !ok {
74 return nil, false
75 }
76
77 for i, exp := range expectations {
78 if matchArgs(exp.args, args) {
79 if exp.until != nil {
80 <-exp.until
81 }
82 // Remove the matched expectation
83 m.expectations[method] = append(expectations[:i], expectations[i+1:]...)
84 return exp, true
85 }
86 }
87 return nil, false
88}
89
// matchArgs reports whether the actual call arguments satisfy the expected
// matchers. Both slices must have the same length. A nil matcher accepts any
// argument in its position; every other matcher must be reflect.DeepEqual to
// the corresponding argument.
func matchArgs(expected, actual []any) bool {
	if len(actual) != len(expected) {
		return false
	}
	for idx := range expected {
		want := expected[idx]
		if want != nil && !reflect.DeepEqual(want, actual[idx]) {
			return false
		}
	}
	return true
}
109
110// recordCall records a method call
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700111func (m *MockConvo) recordCall(method string, args ...any) {
Earl Lee2e463fb2025-04-17 11:22:22 -0700112 m.mu.Lock()
113 defer m.mu.Unlock()
114 if _, ok := m.calls[method]; !ok {
115 m.calls[method] = []*mockCall{}
116 }
117 m.calls[method] = append(m.calls[method], &mockCall{args: args})
118}
119
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700120func (m *MockConvo) SendMessage(message llm.Message) (*llm.Response, error) {
Earl Lee2e463fb2025-04-17 11:22:22 -0700121 m.recordCall("SendMessage", message)
122 exp, ok := m.findMatchingExpectation("SendMessage", message)
123 if !ok {
124 m.t.Errorf("unexpected call to SendMessage: %+v", message)
125 m.t.FailNow()
126 }
127 var retErr error
128 m.mu.Lock()
129 defer m.mu.Unlock()
130 if err, ok := exp.result[1].(error); ok {
131 retErr = err
132 }
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700133 return exp.result[0].(*llm.Response), retErr
Earl Lee2e463fb2025-04-17 11:22:22 -0700134}
135
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700136func (m *MockConvo) SendUserTextMessage(message string, otherContents ...llm.Content) (*llm.Response, error) {
Earl Lee2e463fb2025-04-17 11:22:22 -0700137 m.recordCall("SendUserTextMessage", message, otherContents)
138 exp, ok := m.findMatchingExpectation("SendUserTextMessage", message, otherContents)
139 if !ok {
140 m.t.Error("unexpected call to SendUserTextMessage")
141 m.t.FailNow()
142 }
143 var retErr error
144 m.mu.Lock()
145 defer m.mu.Unlock()
146 if err, ok := exp.result[1].(error); ok {
147 retErr = err
148 }
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700149 return exp.result[0].(*llm.Response), retErr
Earl Lee2e463fb2025-04-17 11:22:22 -0700150}
151
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700152func (m *MockConvo) ToolResultContents(ctx context.Context, resp *llm.Response) ([]llm.Content, error) {
Earl Lee2e463fb2025-04-17 11:22:22 -0700153 m.recordCall("ToolResultContents", resp)
154 exp, ok := m.findMatchingExpectation("ToolResultContents", resp)
155 if !ok {
156 m.t.Error("unexpected call to ToolResultContents")
157 m.t.FailNow()
158 }
159 m.mu.Lock()
160 defer m.mu.Unlock()
161 var retErr error
162 if err, ok := exp.result[1].(error); ok {
163 retErr = err
164 }
165
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700166 return exp.result[0].([]llm.Content), retErr
Earl Lee2e463fb2025-04-17 11:22:22 -0700167}
168
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700169func (m *MockConvo) ToolResultCancelContents(resp *llm.Response) ([]llm.Content, error) {
Earl Lee2e463fb2025-04-17 11:22:22 -0700170 m.recordCall("ToolResultCancelContents", resp)
171 exp, ok := m.findMatchingExpectation("ToolResultCancelContents", resp)
172 if !ok {
173 m.t.Error("unexpected call to ToolResultCancelContents")
174 m.t.FailNow()
175 }
176 var retErr error
177 m.mu.Lock()
178 defer m.mu.Unlock()
179 if err, ok := exp.result[1].(error); ok {
180 retErr = err
181 }
182
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700183 return exp.result[0].([]llm.Content), retErr
Earl Lee2e463fb2025-04-17 11:22:22 -0700184}
185
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700186func (m *MockConvo) CumulativeUsage() conversation.CumulativeUsage {
Earl Lee2e463fb2025-04-17 11:22:22 -0700187 m.recordCall("CumulativeUsage")
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700188 return conversation.CumulativeUsage{}
Earl Lee2e463fb2025-04-17 11:22:22 -0700189}
190
191func (m *MockConvo) OverBudget() error {
192 m.recordCall("OverBudget")
193 return nil
194}
195
Philip Zeyliger2c4db092025-04-28 16:57:50 -0700196func (m *MockConvo) GetID() string {
197 m.recordCall("GetID")
198 return "mock-conversation-id"
199}
200
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700201func (m *MockConvo) SubConvoWithHistory() *conversation.Convo {
Philip Zeyliger2c4db092025-04-28 16:57:50 -0700202 m.recordCall("SubConvoWithHistory")
203 return nil
204}
205
Josh Bleecher Snyder4f84ab72025-04-22 16:40:54 -0700206func (m *MockConvo) ResetBudget(_ conversation.Budget) {
Earl Lee2e463fb2025-04-17 11:22:22 -0700207 m.recordCall("ResetBudget")
208}
209
210// AssertExpectations checks that all expectations were met
211func (m *MockConvo) AssertExpectations(t *testing.T) {
212 m.mu.Lock()
213 defer m.mu.Unlock()
214
215 for method, expectations := range m.expectations {
216 if len(expectations) > 0 {
217 t.Errorf("not all expectations were met for method %s:", method)
218 }
219 }
220}