blob: 7e0507051e6658ef16efe68c5a7bd94bda7e7c5a [file] [log] [blame]
Earl Lee2e463fb2025-04-17 11:22:22 -07001package loop
2
3import (
4 "context"
5 "reflect"
6 "sync"
7 "testing"
8
9 "sketch.dev/ant"
10)
11
12// MockConvo is a custom mock for ant.Convo interface
13type MockConvo struct {
14 mu sync.Mutex
15 t *testing.T
16
17 // Maps method name to a list of calls with arguments and return values
18 calls map[string][]*mockCall
19 // Maps method name to expected calls
20 expectations map[string][]*mockExpectation
21}
22
// mockCall is one recorded invocation of a mocked method.
type mockCall struct {
	// args holds the arguments the method was called with.
	args []any
	// result is reserved for return values; recordCall does not populate it.
	result []any
}
27
// mockExpectation describes one anticipated call: the argument pattern it
// matches, the values it returns, and an optional channel to wait on first.
type mockExpectation struct {
	// until, when non-nil, is received from before the matched call returns.
	until chan any
	// args is the pattern compared against actual arguments by matchArgs.
	args []any
	// result holds the values the matched call returns, in result order.
	result []any
}
33
34// Return sets up return values for an expectation
35func (e *mockExpectation) Return(values ...interface{}) {
36 e.result = values
37}
38
39// Return sets up return values for an expectation
40func (e *mockExpectation) BlockAndReturn(until chan any, values ...interface{}) {
41 e.until = until
42 e.result = values
43}
44
45// NewMockConvo creates a new mock Convo
46func NewMockConvo(t *testing.T) *MockConvo {
47 return &MockConvo{
48 t: t,
49 mu: sync.Mutex{},
50 calls: make(map[string][]*mockCall),
51 expectations: make(map[string][]*mockExpectation),
52 }
53}
54
55// ExpectCall sets up an expectation for a method call
56func (m *MockConvo) ExpectCall(method string, args ...interface{}) *mockExpectation {
57 m.mu.Lock()
58 defer m.mu.Unlock()
59 expectation := &mockExpectation{args: args}
60 if _, ok := m.expectations[method]; !ok {
61 m.expectations[method] = []*mockExpectation{}
62 }
63 m.expectations[method] = append(m.expectations[method], expectation)
64 return expectation
65}
66
67// findMatchingExpectation finds a matching expectation for a method call
68func (m *MockConvo) findMatchingExpectation(method string, args ...interface{}) (*mockExpectation, bool) {
69 m.mu.Lock()
70 defer m.mu.Unlock()
71 expectations, ok := m.expectations[method]
72 if !ok {
73 return nil, false
74 }
75
76 for i, exp := range expectations {
77 if matchArgs(exp.args, args) {
78 if exp.until != nil {
79 <-exp.until
80 }
81 // Remove the matched expectation
82 m.expectations[method] = append(expectations[:i], expectations[i+1:]...)
83 return exp, true
84 }
85 }
86 return nil, false
87}
88
// matchArgs reports whether actual satisfies the expected argument pattern.
// Both slices must have the same length. Each expected value must be
// reflect.DeepEqual to the actual value at the same position, except that an
// untyped nil expectation is a wildcard matching anything. A typed nil, such
// as a nil slice stored in an interface, is not a wildcard.
func matchArgs(expected, actual []interface{}) bool {
	if len(actual) != len(expected) {
		return false
	}
	for idx := range expected {
		want := expected[idx]
		if want != nil && !reflect.DeepEqual(want, actual[idx]) {
			return false
		}
	}
	return true
}
108
109// recordCall records a method call
110func (m *MockConvo) recordCall(method string, args ...interface{}) {
111 m.mu.Lock()
112 defer m.mu.Unlock()
113 if _, ok := m.calls[method]; !ok {
114 m.calls[method] = []*mockCall{}
115 }
116 m.calls[method] = append(m.calls[method], &mockCall{args: args})
117}
118
119func (m *MockConvo) SendMessage(message ant.Message) (*ant.MessageResponse, error) {
120 m.recordCall("SendMessage", message)
121 exp, ok := m.findMatchingExpectation("SendMessage", message)
122 if !ok {
123 m.t.Errorf("unexpected call to SendMessage: %+v", message)
124 m.t.FailNow()
125 }
126 var retErr error
127 m.mu.Lock()
128 defer m.mu.Unlock()
129 if err, ok := exp.result[1].(error); ok {
130 retErr = err
131 }
132 return exp.result[0].(*ant.MessageResponse), retErr
133}
134
135func (m *MockConvo) SendUserTextMessage(message string, otherContents ...ant.Content) (*ant.MessageResponse, error) {
136 m.recordCall("SendUserTextMessage", message, otherContents)
137 exp, ok := m.findMatchingExpectation("SendUserTextMessage", message, otherContents)
138 if !ok {
139 m.t.Error("unexpected call to SendUserTextMessage")
140 m.t.FailNow()
141 }
142 var retErr error
143 m.mu.Lock()
144 defer m.mu.Unlock()
145 if err, ok := exp.result[1].(error); ok {
146 retErr = err
147 }
148 return exp.result[0].(*ant.MessageResponse), retErr
149}
150
151func (m *MockConvo) ToolResultContents(ctx context.Context, resp *ant.MessageResponse) ([]ant.Content, error) {
152 m.recordCall("ToolResultContents", resp)
153 exp, ok := m.findMatchingExpectation("ToolResultContents", resp)
154 if !ok {
155 m.t.Error("unexpected call to ToolResultContents")
156 m.t.FailNow()
157 }
158 m.mu.Lock()
159 defer m.mu.Unlock()
160 var retErr error
161 if err, ok := exp.result[1].(error); ok {
162 retErr = err
163 }
164
165 return exp.result[0].([]ant.Content), retErr
166}
167
168func (m *MockConvo) ToolResultCancelContents(resp *ant.MessageResponse) ([]ant.Content, error) {
169 m.recordCall("ToolResultCancelContents", resp)
170 exp, ok := m.findMatchingExpectation("ToolResultCancelContents", resp)
171 if !ok {
172 m.t.Error("unexpected call to ToolResultCancelContents")
173 m.t.FailNow()
174 }
175 var retErr error
176 m.mu.Lock()
177 defer m.mu.Unlock()
178 if err, ok := exp.result[1].(error); ok {
179 retErr = err
180 }
181
182 return exp.result[0].([]ant.Content), retErr
183}
184
185func (m *MockConvo) CumulativeUsage() ant.CumulativeUsage {
186 m.recordCall("CumulativeUsage")
187 return ant.CumulativeUsage{}
188}
189
190func (m *MockConvo) OverBudget() error {
191 m.recordCall("OverBudget")
192 return nil
193}
194
Philip Zeyliger2c4db092025-04-28 16:57:50 -0700195func (m *MockConvo) GetID() string {
196 m.recordCall("GetID")
197 return "mock-conversation-id"
198}
199
200func (m *MockConvo) SubConvoWithHistory() *ant.Convo {
201 m.recordCall("SubConvoWithHistory")
202 return nil
203}
204
Earl Lee2e463fb2025-04-17 11:22:22 -0700205func (m *MockConvo) ResetBudget(_ ant.Budget) {
206 m.recordCall("ResetBudget")
207}
208
209// AssertExpectations checks that all expectations were met
210func (m *MockConvo) AssertExpectations(t *testing.T) {
211 m.mu.Lock()
212 defer m.mu.Unlock()
213
214 for method, expectations := range m.expectations {
215 if len(expectations) > 0 {
216 t.Errorf("not all expectations were met for method %s:", method)
217 }
218 }
219}