oai: retry once for TLS bad record MAC errors
Add automatic retry logic for the specific TLS error "tls: bad record MAC"
that occurs occasionally with Fireworks API. The first occurrence triggers
an automatic retry, only failing to the user after a second consecutive failure.
This addresses intermittent connection issues without requiring manual
intervention for transient TLS handshake problems.
Co-Authored-By: sketch <hello@sketch.dev>
Change-ID: s701f83a97373c925k
diff --git a/llm/oai/oai.go b/llm/oai/oai.go
index 2b8b3a1..8a450c4 100644
--- a/llm/oai/oai.go
+++ b/llm/oai/oai.go
@@ -733,6 +733,13 @@
}
// Handle errors
+ // Check for TLS "bad record MAC" errors and retry once
+ if strings.Contains(err.Error(), "tls: bad record MAC") && attempts == 0 {
+ slog.WarnContext(ctx, "tls bad record MAC error, retrying once", "error", err.Error())
+ errs = errors.Join(errs, fmt.Errorf("TLS error (attempt %d): %w", attempts+1, err))
+ continue
+ }
+
var apiErr *openai.APIError
if ok := errors.As(err, &apiErr); !ok {
// Not an OpenAI API error, return immediately with accumulated errors
diff --git a/llm/oai/oai_test.go b/llm/oai/oai_test.go
new file mode 100644
index 0000000..73ab4fa
--- /dev/null
+++ b/llm/oai/oai_test.go
@@ -0,0 +1,102 @@
+package oai
+
+import (
+ "context"
+ "errors"
+ "net/http"
+ "strings"
+ "testing"
+ "time"
+
+ "sketch.dev/llm"
+)
+
+// mockRoundTripper is a mock HTTP round tripper that can simulate TLS errors
+type mockRoundTripper struct {
+ callCount int
+ errorOnAttempt []int // which attempts should return TLS errors
+}
+
+func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+ m.callCount++
+
+ // Check if this attempt should return a TLS error
+ for _, errorAttempt := range m.errorOnAttempt {
+ if m.callCount == errorAttempt {
+ return nil, errors.New(`Post "https://api.fireworks.ai/inference/v1/chat/completions": remote error: tls: bad record MAC`)
+ }
+ }
+
+ // Simulate timeout for other cases to avoid actual HTTP calls
+ ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
+ defer cancel()
+ <-ctx.Done()
+ return nil, ctx.Err()
+}
+
+func TestTLSBadRecordMACRetry(t *testing.T) {
+ tests := []struct {
+ name string
+ errorOnAttempt []int
+ expectedCalls int
+ shouldSucceed bool
+ }{
+ {
+ name: "first attempt succeeds",
+ errorOnAttempt: []int{}, // no TLS errors
+ expectedCalls: 1,
+ shouldSucceed: false, // will timeout, but that's expected for this test
+ },
+ {
+ name: "first attempt fails with TLS error, second succeeds",
+ errorOnAttempt: []int{1}, // TLS error on first attempt
+ expectedCalls: 2,
+ shouldSucceed: false, // will timeout on second attempt
+ },
+ {
+ name: "both attempts fail with TLS error",
+ errorOnAttempt: []int{1, 2}, // TLS error on both attempts
+ expectedCalls: 2,
+ shouldSucceed: false, // should fail after second TLS error
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockRT := &mockRoundTripper{
+ errorOnAttempt: tt.errorOnAttempt,
+ }
+ mockClient := &http.Client{
+ Transport: mockRT,
+ }
+
+ service := &Service{
+ HTTPC: mockClient,
+ Model: Qwen3CoderFireworks,
+ APIKey: "test-key",
+ }
+
+ req := &llm.Request{
+ Messages: []llm.Message{
+ {Role: llm.MessageRoleUser, Content: []llm.Content{{Type: llm.ContentTypeText, Text: "test"}}},
+ },
+ }
+
+ _, err := service.Do(context.Background(), req)
+
+ // Verify the expected number of calls were made
+ if mockRT.callCount != tt.expectedCalls {
+ t.Errorf("expected %d calls, got %d", tt.expectedCalls, mockRT.callCount)
+ }
+
+ // For TLS error cases, verify the error message contains both attempts
+ if len(tt.errorOnAttempt) > 1 {
+ if err == nil {
+ t.Error("expected error after multiple TLS failures")
+ } else if !strings.Contains(err.Error(), "tls: bad record MAC") {
+ t.Errorf("expected error to contain TLS error message, got: %v", err)
+ }
+ }
+ })
+ }
+}