llm/oai: retry more on failure

Extend the backoff schedule (up to 15s + jitter), cap retries at 10
attempts, and accumulate errors across attempts with errors.Join so the
final error reports every failure. Treat 4xx client errors as
unrecoverable and return immediately; retry 5xx, 429, and other API
errors. Note: the previous extra 20s penalty for 429 responses is
removed — rate-limited requests now use the same backoff as other
retryable errors.

Co-Authored-By: sketch <hello@sketch.dev>
Change-ID: s6b8e59e9e006f5bak
diff --git a/llm/oai/oai.go b/llm/oai/oai.go
index 37484e0..2b5b7ed 100644
--- a/llm/oai/oai.go
+++ b/llm/oai/oai.go
@@ -674,10 +674,20 @@
 	// fmt.Printf("\n")
 
 	// Retry mechanism
-	backoff := []time.Duration{1 * time.Second, 2 * time.Second, 5 * time.Second}
+	backoff := []time.Duration{1 * time.Second, 2 * time.Second, 5 * time.Second, 10 * time.Second, 15 * time.Second}
 
 	// retry loop
+	var errs error // accumulated errors across all attempts
 	for attempts := 0; ; attempts++ {
+		if attempts > 10 {
+			return nil, fmt.Errorf("openai request failed after %d attempts: %w", attempts, errs)
+		}
+		if attempts > 0 {
+			sleep := backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
+			slog.WarnContext(ctx, "openai request sleep before retry", "sleep", sleep, "attempts", attempts)
+			time.Sleep(sleep)
+		}
+
 		resp, err := client.CreateChatCompletion(ctx, req)
 
 		// Handle successful response
@@ -688,28 +698,33 @@
 		// Handle errors
 		var apiErr *openai.APIError
 		if ok := errors.As(err, &apiErr); !ok {
-			// Not an OpenAI API error, return immediately
-			return nil, err
+			// Not an OpenAI API error, return immediately with accumulated errors
+			return nil, errors.Join(errs, err)
 		}
 
 		switch {
 		case apiErr.HTTPStatusCode >= 500:
 			// Server error, try again with backoff
-			sleep := backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
-			slog.WarnContext(ctx, "openai_request_failed", "error", apiErr.Error(), "status_code", apiErr.HTTPStatusCode, "sleep", sleep)
-			time.Sleep(sleep)
+			slog.WarnContext(ctx, "openai_request_failed", "error", apiErr.Error(), "status_code", apiErr.HTTPStatusCode)
+			errs = errors.Join(errs, fmt.Errorf("status %d: %s", apiErr.HTTPStatusCode, apiErr.Error()))
 			continue
 
 		case apiErr.HTTPStatusCode == 429:
-			// Rate limited, back off longer
-			sleep := 20*time.Second + backoff[min(attempts, len(backoff)-1)] + time.Duration(rand.Int64N(int64(time.Second)))
-			slog.WarnContext(ctx, "openai_request_rate_limited", "error", apiErr.Error(), "sleep", sleep)
-			time.Sleep(sleep)
+			// Rate limited, accumulate error and retry
+			slog.WarnContext(ctx, "openai_request_rate_limited", "error", apiErr.Error())
+			errs = errors.Join(errs, fmt.Errorf("status %d (rate limited): %s", apiErr.HTTPStatusCode, apiErr.Error()))
 			continue
 
+		case apiErr.HTTPStatusCode >= 400 && apiErr.HTTPStatusCode < 500:
+			// Client error, probably unrecoverable
+			slog.WarnContext(ctx, "openai_request_failed", "error", apiErr.Error(), "status_code", apiErr.HTTPStatusCode)
+			return nil, errors.Join(errs, fmt.Errorf("status %d: %s", apiErr.HTTPStatusCode, apiErr.Error()))
+
 		default:
-			// Other error, return immediately
-			return nil, fmt.Errorf("OpenAI API error: %w", err)
+			// Other error, accumulate and retry
+			slog.WarnContext(ctx, "openai_request_failed", "error", apiErr.Error(), "status_code", apiErr.HTTPStatusCode)
+			errs = errors.Join(errs, fmt.Errorf("status %d: %s", apiErr.HTTPStatusCode, apiErr.Error()))
+			continue
 		}
 	}
 }