blob: 73ab4fa635ad4ea77be55cf3da3e6b0947ba0d51 [file] [log] [blame]
crawshaw5c861652025-07-29 16:34:52 +00001package oai
2
3import (
4 "context"
5 "errors"
6 "net/http"
7 "strings"
8 "testing"
9 "time"
10
11 "sketch.dev/llm"
12)
13
// mockRoundTripper is an http.RoundTripper that never performs real network I/O.
// Attempts whose 1-based number appears in errorOnAttempt fail with a simulated
// "tls: bad record MAC" error; every other attempt fails as if it timed out.
//
// It is not safe for concurrent use: callCount is updated without synchronization.
type mockRoundTripper struct {
	callCount      int   // number of RoundTrip calls made so far
	errorOnAttempt []int // 1-based attempt numbers that return the TLS error
}

// RoundTrip records the attempt and always returns an error, never a response.
// As required by the http.RoundTripper contract, it closes the request body on
// every path, including errors.
func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if req.Body != nil {
		defer req.Body.Close()
	}
	m.callCount++

	for _, attempt := range m.errorOnAttempt {
		if attempt == m.callCount {
			// Note: http.Client wraps transport errors in a *url.Error, so callers
			// of Client.Do see this text behind an extra `Post "<url>": ` prefix.
			return nil, errors.New(`Post "https://api.fireworks.ai/inference/v1/chat/completions": remote error: tls: bad record MAC`)
		}
	}

	// Fail every other attempt like a timeout without touching the network.
	// Deriving from req.Context() lets a caller-side cancellation or deadline
	// take effect instead of being silently ignored.
	ctx, cancel := context.WithTimeout(req.Context(), time.Millisecond)
	defer cancel()
	<-ctx.Done()
	return nil, ctx.Err()
}
35}
36
37func TestTLSBadRecordMACRetry(t *testing.T) {
38 tests := []struct {
39 name string
40 errorOnAttempt []int
41 expectedCalls int
42 shouldSucceed bool
43 }{
44 {
45 name: "first attempt succeeds",
46 errorOnAttempt: []int{}, // no TLS errors
47 expectedCalls: 1,
48 shouldSucceed: false, // will timeout, but that's expected for this test
49 },
50 {
51 name: "first attempt fails with TLS error, second succeeds",
52 errorOnAttempt: []int{1}, // TLS error on first attempt
53 expectedCalls: 2,
54 shouldSucceed: false, // will timeout on second attempt
55 },
56 {
57 name: "both attempts fail with TLS error",
58 errorOnAttempt: []int{1, 2}, // TLS error on both attempts
59 expectedCalls: 2,
60 shouldSucceed: false, // should fail after second TLS error
61 },
62 }
63
64 for _, tt := range tests {
65 t.Run(tt.name, func(t *testing.T) {
66 mockRT := &mockRoundTripper{
67 errorOnAttempt: tt.errorOnAttempt,
68 }
69 mockClient := &http.Client{
70 Transport: mockRT,
71 }
72
73 service := &Service{
74 HTTPC: mockClient,
75 Model: Qwen3CoderFireworks,
76 APIKey: "test-key",
77 }
78
79 req := &llm.Request{
80 Messages: []llm.Message{
81 {Role: llm.MessageRoleUser, Content: []llm.Content{{Type: llm.ContentTypeText, Text: "test"}}},
82 },
83 }
84
85 _, err := service.Do(context.Background(), req)
86
87 // Verify the expected number of calls were made
88 if mockRT.callCount != tt.expectedCalls {
89 t.Errorf("expected %d calls, got %d", tt.expectedCalls, mockRT.callCount)
90 }
91
92 // For TLS error cases, verify the error message contains both attempts
93 if len(tt.errorOnAttempt) > 1 {
94 if err == nil {
95 t.Error("expected error after multiple TLS failures")
96 } else if !strings.Contains(err.Error(), "tls: bad record MAC") {
97 t.Errorf("expected error to contain TLS error message, got: %v", err)
98 }
99 }
100 })
101 }
102}