diff --git a/internal/api/api_key.go b/internal/api/api_key.go index 72116dd..6e04984 100644 --- a/internal/api/api_key.go +++ b/internal/api/api_key.go @@ -16,9 +16,9 @@ func (c *Client) GetAPIKey() (*APIKeyResponse, error) { return nil, err } - resp, err := c.httpClient.Do(req) + resp, err := c.do(req) if err != nil { - return nil, fmt.Errorf("request failed: %w", err) + return nil, err } defer resp.Body.Close() diff --git a/internal/api/client.go b/internal/api/client.go index 74023ba..053129e 100644 --- a/internal/api/client.go +++ b/internal/api/client.go @@ -3,10 +3,22 @@ package api import ( "encoding/json" "fmt" + "math/rand/v2" "net/http" "time" ) +const ( + maxRetries = 2 + baseDelay = 500 * time.Millisecond +) + +var sleep = time.Sleep + +func isRetryable(statusCode int) bool { + return statusCode == http.StatusTooManyRequests || statusCode >= 500 +} + type APIError struct { StatusCode int Message string @@ -40,6 +52,37 @@ func errorFromResponse(resp *http.Response) *APIError { return &APIError{StatusCode: resp.StatusCode, Message: fmt.Sprintf("unexpected status: %d", resp.StatusCode)} } +func (c *Client) do(req *http.Request) (*http.Response, error) { + var ( + resp *http.Response + err error + ) + for attempt := 0; attempt <= maxRetries; attempt++ { + if attempt > 0 { + delay := time.Duration(1<<(attempt-1)) * baseDelay + jitter := time.Duration(rand.Int64N(int64(delay / 2))) + sleep(delay + jitter) + } + resp, err = c.httpClient.Do(req) + if err != nil { + if req.Context().Err() != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + continue + } + if !isRetryable(resp.StatusCode) { + return resp, nil + } + if attempt < maxRetries { + resp.Body.Close() + } + } + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + return resp, nil +} + func (c *Client) newRequest(method, path string) (*http.Request, error) { url := fmt.Sprintf("%s%s", c.baseURL, path) req, err := http.NewRequest(method, url, nil) diff --git a/internal/api/client_test.go b/internal/api/client_test.go index b4ab380..d76d0d7 100644 --- a/internal/api/client_test.go +++ b/internal/api/client_test.go @@ -2,7 +2,10 @@ package api import ( "net/http" + "net/http/httptest" + "sync/atomic" "testing" + "time" ) func TestNewRequest(t *testing.T) { @@ -48,3 +51,75 @@ func TestNewRequest_InvalidURL(t *testing.T) { t.Error("expected error for invalid URL, got nil") } } + +func TestDo_Retries(t *testing.T) { + origSleep := sleep + sleep = func(time.Duration) {} + defer func() { sleep = origSleep }() + + tests := []struct { + name string + responses []int + wantStatus int + wantAttempts int32 + }{ + { + name: "success on first attempt", + responses: []int{200}, + wantStatus: 200, + wantAttempts: 1, + }, + { + name: "retries on 429 then succeeds", + responses: []int{429, 200}, + wantStatus: 200, + wantAttempts: 2, + }, + { + name: "retries on 500 then succeeds", + responses: []int{500, 200}, + wantStatus: 200, + wantAttempts: 2, + }, + { + name: "exhausts retries on persistent 429", + responses: []int{429, 429, 429}, + wantStatus: 429, + wantAttempts: 3, + }, + { + name: "no retry on 401", + responses: []int{401}, + wantStatus: 401, + wantAttempts: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var attempts atomic.Int32 + idx := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts.Add(1) + w.WriteHeader(tt.responses[idx]) + idx++ + })) + defer server.Close() + + client := NewClient(server.URL, "test-key") + req, _ := client.newRequest(http.MethodGet, "/") + resp, err := client.do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != tt.wantStatus { + t.Errorf("status = %d, want %d", resp.StatusCode, tt.wantStatus) + } + if attempts.Load() != tt.wantAttempts { + t.Errorf("attempts = %d, want %d", attempts.Load(), tt.wantAttempts) + } + }) + } +}