diff --git a/internal/api/client.go b/internal/api/client.go index 3ad92be..a5afb75 100644 --- a/internal/api/client.go +++ b/internal/api/client.go @@ -66,6 +66,13 @@ func (c *Client) do(req *http.Request) (*http.Response, error) { ) for attempt := 0; attempt <= maxRetries; attempt++ { if attempt > 0 { + if req.GetBody != nil { + body, err := req.GetBody() + if err != nil { + return nil, fmt.Errorf("failed to reset request body: %w", err) + } + req.Body = body + } delay := time.Duration(1<<(attempt-1)) * baseDelay jitter := time.Duration(rand.Int64N(int64(delay / 2))) sleep(delay + jitter) diff --git a/internal/api/client_test.go b/internal/api/client_test.go index 208f0d9..f9c5d91 100644 --- a/internal/api/client_test.go +++ b/internal/api/client_test.go @@ -1,6 +1,8 @@ package api import ( + "bytes" + "io" "net/http" "net/http/httptest" "sync/atomic" @@ -8,6 +10,43 @@ import ( "time" ) +func TestDo_RetryResetsBody(t *testing.T) { + origSleep := sleep + sleep = func(time.Duration) {} + defer func() { sleep = origSleep }() + + var bodies []string + var attempts atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b, _ := io.ReadAll(r.Body) + bodies = append(bodies, string(b)) + n := attempts.Add(1) + if n == 1 { + w.WriteHeader(http.StatusInternalServerError) + } else { + w.WriteHeader(http.StatusOK) + } + })) + defer server.Close() + + client := NewClient(server.URL, "test-key") + req, _ := client.newRequest(http.MethodPost, "/", bytes.NewReader([]byte(`{"hello":"world"}`))) + resp, err := client.do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + if attempts.Load() != 2 { + t.Fatalf("expected 2 attempts, got %d", attempts.Load()) + } + for i, body := range bodies { + if body != `{"hello":"world"}` { + t.Errorf("attempt %d body = %q, want non-empty JSON", i+1, body) + } + } +} + func TestNewRequest(t *testing.T) { client := NewClient("https://example.com/api/v1", "test-key")