llm/oai: retry more on failure
Co-Authored-By: sketch <hello@sketch.dev>
Change-ID: s6b8e59e9e006f5bak
diff --git a/llm/oai/oai.go b/llm/oai/oai.go
index 37484e0..2b5b7ed 100644
--- a/llm/oai/oai.go
+++ b/llm/oai/oai.go
@@ -674,10 +674,20 @@
// fmt.Printf("\n")
// Retry mechanism
- backoff := []time.Duration{1 * time.Second, 2 * time.Second, 5 * time.Second}
+ backoff := []time.Duration{1 * time.Second, 2 * time.Second, 5 * time.Second, 10 * time.Second, 15 * time.Second}
// retry loop
+ var errs error // accumulated errors across all attempts
for attempts := 0; ; attempts++ {
+ if attempts > 10 {
+ return nil, fmt.Errorf("openai request failed after %d attempts: %w", attempts, errs)
+ }
+ if attempts > 0 {
+ sleep := backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
+ slog.WarnContext(ctx, "openai request sleep before retry", "sleep", sleep, "attempts", attempts)
+ time.Sleep(sleep)
+ }
+
resp, err := client.CreateChatCompletion(ctx, req)
// Handle successful response
@@ -688,28 +698,33 @@
// Handle errors
var apiErr *openai.APIError
if ok := errors.As(err, &apiErr); !ok {
- // Not an OpenAI API error, return immediately
- return nil, err
+ // Not an OpenAI API error, return immediately with accumulated errors
+ return nil, errors.Join(errs, err)
}
switch {
case apiErr.HTTPStatusCode >= 500:
// Server error, try again with backoff
- sleep := backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
- slog.WarnContext(ctx, "openai_request_failed", "error", apiErr.Error(), "status_code", apiErr.HTTPStatusCode, "sleep", sleep)
- time.Sleep(sleep)
+ slog.WarnContext(ctx, "openai_request_failed", "error", apiErr.Error(), "status_code", apiErr.HTTPStatusCode)
+ errs = errors.Join(errs, fmt.Errorf("status %d: %s", apiErr.HTTPStatusCode, apiErr.Error()))
continue
case apiErr.HTTPStatusCode == 429:
- // Rate limited, back off longer
- sleep := 20*time.Second + backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
- slog.WarnContext(ctx, "openai_request_rate_limited", "error", apiErr.Error(), "sleep", sleep)
- time.Sleep(sleep)
+ // Rate limited, accumulate error and retry
+ slog.WarnContext(ctx, "openai_request_rate_limited", "error", apiErr.Error())
+ errs = errors.Join(errs, fmt.Errorf("status %d (rate limited): %s", apiErr.HTTPStatusCode, apiErr.Error()))
continue
+ case apiErr.HTTPStatusCode >= 400 && apiErr.HTTPStatusCode < 500:
+ // Client error, probably unrecoverable
+ slog.WarnContext(ctx, "openai_request_failed", "error", apiErr.Error(), "status_code", apiErr.HTTPStatusCode)
+ return nil, errors.Join(errs, fmt.Errorf("status %d: %s", apiErr.HTTPStatusCode, apiErr.Error()))
+
default:
- // Other error, return immediately
- return nil, fmt.Errorf("OpenAI API error: %w", err)
+ // Other error, accumulate and retry
+ slog.WarnContext(ctx, "openai_request_failed", "error", apiErr.Error(), "status_code", apiErr.HTTPStatusCode)
+ errs = errors.Join(errs, fmt.Errorf("status %d: %s", apiErr.HTTPStatusCode, apiErr.Error()))
+ continue
}
}
}