| crawshaw | 5c86165 | 2025-07-29 16:34:52 +0000 | [diff] [blame] | 1 | package oai |
| 2 | |
| 3 | import ( |
| 4 | "context" |
| 5 | "errors" |
| 6 | "net/http" |
| 7 | "strings" |
| 8 | "testing" |
| 9 | "time" |
| 10 | |
| 11 | "sketch.dev/llm" |
| 12 | ) |
| 13 | |
// mockRoundTripper is a test double for http.RoundTripper. It counts the
// requests it receives and fails selected attempts with a simulated
// "tls: bad record MAC" error, so retry logic can be exercised without any
// real network traffic.
type mockRoundTripper struct {
	// callCount is the number of RoundTrip calls observed so far.
	callCount int
	// errorOnAttempt lists the attempt numbers (1-based: the first call is
	// attempt 1) that should fail with a TLS error.
	errorOnAttempt []int
}
| 19 | |
| 20 | func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { |
| 21 | m.callCount++ |
| 22 | |
| 23 | // Check if this attempt should return a TLS error |
| 24 | for _, errorAttempt := range m.errorOnAttempt { |
| 25 | if m.callCount == errorAttempt { |
| 26 | return nil, errors.New(`Post "https://api.fireworks.ai/inference/v1/chat/completions": remote error: tls: bad record MAC`) |
| 27 | } |
| 28 | } |
| 29 | |
| 30 | // Simulate timeout for other cases to avoid actual HTTP calls |
| 31 | ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) |
| 32 | defer cancel() |
| 33 | <-ctx.Done() |
| 34 | return nil, ctx.Err() |
| 35 | } |
| 36 | |
| 37 | func TestTLSBadRecordMACRetry(t *testing.T) { |
| 38 | tests := []struct { |
| 39 | name string |
| 40 | errorOnAttempt []int |
| 41 | expectedCalls int |
| 42 | shouldSucceed bool |
| 43 | }{ |
| 44 | { |
| 45 | name: "first attempt succeeds", |
| 46 | errorOnAttempt: []int{}, // no TLS errors |
| 47 | expectedCalls: 1, |
| 48 | shouldSucceed: false, // will timeout, but that's expected for this test |
| 49 | }, |
| 50 | { |
| 51 | name: "first attempt fails with TLS error, second succeeds", |
| 52 | errorOnAttempt: []int{1}, // TLS error on first attempt |
| 53 | expectedCalls: 2, |
| 54 | shouldSucceed: false, // will timeout on second attempt |
| 55 | }, |
| 56 | { |
| 57 | name: "both attempts fail with TLS error", |
| 58 | errorOnAttempt: []int{1, 2}, // TLS error on both attempts |
| 59 | expectedCalls: 2, |
| 60 | shouldSucceed: false, // should fail after second TLS error |
| 61 | }, |
| 62 | } |
| 63 | |
| 64 | for _, tt := range tests { |
| 65 | t.Run(tt.name, func(t *testing.T) { |
| 66 | mockRT := &mockRoundTripper{ |
| 67 | errorOnAttempt: tt.errorOnAttempt, |
| 68 | } |
| 69 | mockClient := &http.Client{ |
| 70 | Transport: mockRT, |
| 71 | } |
| 72 | |
| 73 | service := &Service{ |
| 74 | HTTPC: mockClient, |
| 75 | Model: Qwen3CoderFireworks, |
| 76 | APIKey: "test-key", |
| 77 | } |
| 78 | |
| 79 | req := &llm.Request{ |
| 80 | Messages: []llm.Message{ |
| 81 | {Role: llm.MessageRoleUser, Content: []llm.Content{{Type: llm.ContentTypeText, Text: "test"}}}, |
| 82 | }, |
| 83 | } |
| 84 | |
| 85 | _, err := service.Do(context.Background(), req) |
| 86 | |
| 87 | // Verify the expected number of calls were made |
| 88 | if mockRT.callCount != tt.expectedCalls { |
| 89 | t.Errorf("expected %d calls, got %d", tt.expectedCalls, mockRT.callCount) |
| 90 | } |
| 91 | |
| 92 | // For TLS error cases, verify the error message contains both attempts |
| 93 | if len(tt.errorOnAttempt) > 1 { |
| 94 | if err == nil { |
| 95 | t.Error("expected error after multiple TLS failures") |
| 96 | } else if !strings.Contains(err.Error(), "tls: bad record MAC") { |
| 97 | t.Errorf("expected error to contain TLS error message, got: %v", err) |
| 98 | } |
| 99 | } |
| 100 | }) |
| 101 | } |
| 102 | } |