Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions internal/api/api_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ func (c *Client) GetAPIKey() (*APIKeyResponse, error) {
return nil, err
}

resp, err := c.httpClient.Do(req)
resp, err := c.do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
return nil, err
}
defer resp.Body.Close()

Expand Down
43 changes: 43 additions & 0 deletions internal/api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,22 @@ package api
import (
"encoding/json"
"fmt"
"math/rand/v2"
"net/http"
"time"
)

const (
maxRetries = 2
baseDelay = 500 * time.Millisecond
)

var sleep = time.Sleep

func isRetryable(statusCode int) bool {
return statusCode == http.StatusTooManyRequests || statusCode >= 500
}

type APIError struct {
StatusCode int
Message string
Expand Down Expand Up @@ -40,6 +52,37 @@ func errorFromResponse(resp *http.Response) *APIError {
return &APIError{StatusCode: resp.StatusCode, Message: fmt.Sprintf("unexpected status: %d", resp.StatusCode)}
}

func (c *Client) do(req *http.Request) (*http.Response, error) {
var (
resp *http.Response
err error
)
for attempt := 0; attempt <= maxRetries; attempt++ {
if attempt > 0 {
delay := time.Duration(1<<(attempt-1)) * baseDelay
jitter := time.Duration(rand.Int64N(int64(delay / 2)))
sleep(delay + jitter)
}
resp, err = c.httpClient.Do(req)
if err != nil {
if req.Context().Err() != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
continue
}
if !isRetryable(resp.StatusCode) {
return resp, nil
}
if attempt < maxRetries {
resp.Body.Close()
}
}
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
return resp, nil
}

func (c *Client) newRequest(method, path string) (*http.Request, error) {
url := fmt.Sprintf("%s%s", c.baseURL, path)
req, err := http.NewRequest(method, url, nil)
Expand Down
75 changes: 75 additions & 0 deletions internal/api/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package api

import (
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
)

func TestNewRequest(t *testing.T) {
Expand Down Expand Up @@ -48,3 +51,75 @@ func TestNewRequest_InvalidURL(t *testing.T) {
t.Error("expected error for invalid URL, got nil")
}
}

func TestDo_Retries(t *testing.T) {
origSleep := sleep
sleep = func(time.Duration) {}
defer func() { sleep = origSleep }()

tests := []struct {
name string
responses []int
wantStatus int
wantAttempts int32
}{
{
name: "success on first attempt",
responses: []int{200},
wantStatus: 200,
wantAttempts: 1,
},
{
name: "retries on 429 then succeeds",
responses: []int{429, 200},
wantStatus: 200,
wantAttempts: 2,
},
{
name: "retries on 500 then succeeds",
responses: []int{500, 200},
wantStatus: 200,
wantAttempts: 2,
},
{
name: "exhausts retries on persistent 429",
responses: []int{429, 429, 429},
wantStatus: 429,
wantAttempts: 3,
},
{
name: "no retry on 401",
responses: []int{401},
wantStatus: 401,
wantAttempts: 1,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var attempts atomic.Int32
idx := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts.Add(1)
w.WriteHeader(tt.responses[idx])
idx++
}))
defer server.Close()

client := NewClient(server.URL, "test-key")
req, _ := client.newRequest(http.MethodGet, "/")
resp, err := client.do(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != tt.wantStatus {
t.Errorf("status = %d, want %d", resp.StatusCode, tt.wantStatus)
}
if attempts.Load() != tt.wantAttempts {
t.Errorf("attempts = %d, want %d", attempts.Load(), tt.wantAttempts)
}
})
}
}
Loading