[PyT] Reduce test sizes in fused attn fp8 vs fp16 to avoid OOM #3020

Open
vedaanta wants to merge 3 commits into NVIDIA:main from vedaanta:vedaanta/te-fp8-vs-f16-shrink-b1

@vedaanta vedaanta commented May 21, 2026

The 9 fp8_9..fp8_17 configs in model_configs_fp8_vs_f16 use shapes (B=2, S=4096-8192, H=32-128, D=64-192) for the bf16-vs-fp8 reference comparison. The reference path in test_dpa_fp8_vs_f16 materializes the full (B, H, S, S) attention matrix in bf16, and keeps a handful of them live (S, P, dP, dS, dropout-mask) simultaneously. At B=2, S=8192, H=64 the per-test peak is ~70 GiB, which pushes the suite into OOM territory on Blackwell (~91 GB measured with the cuDNN caching allocator residue).
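As a rough sanity check on the ~70 GiB figure, the size of a single (B, H, S, S) bf16 tensor can be computed directly. This is a minimal sketch: the helper name `attn_matrix_bytes` is hypothetical and not part of the test suite, and it assumes 2 bytes per element (bf16) with no padding or allocator overhead.

```python
GIB = 1024**3


def attn_matrix_bytes(b, h, s, bytes_per_elem=2):
    """Bytes needed for one dense (B, H, S, S) tensor; bf16 is 2 bytes/element."""
    return b * h * s * s * bytes_per_elem


# Shape quoted in the description: B=2, S=8192, H=64.
full = attn_matrix_bytes(2, 64, 8192)
# Same shape with the batch halved to B=1.
half = attn_matrix_bytes(1, 64, 8192)

print(f"B=2: {full / GIB:.1f} GiB per tensor")  # 16.0 GiB
print(f"B=1: {half / GIB:.1f} GiB per tensor")  # 8.0 GiB
```

At 16 GiB per tensor, keeping four or five such tensors live at once lands in the 64 to 80 GiB range, which is consistent with the measured ~70 GiB peak.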

Halving B to 1 halves the bytes of every (B, H, S, S) tensor. Measured on B200 (SM_100, cuDNN 9.23, TE main):

per-test peak torch.cuda.max_memory_allocated:
    before: 70.0 GiB (fp8_14)
    after : 36.1 GiB (fp8_14)  -48%
per-test peak nvidia-smi memory.used:
    before: 96.8 GiB
    after : 51.3 GiB           -47%
test outcome (B200, develop FE, this TE):
    identical 618F / 2196P / 891S, wall time within ~3%
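The quoted percentages can be reproduced from the before/after numbers. The sketch below only redoes that arithmetic; `pct_change` is a hypothetical helper, and the inputs are the measurements reported above rather than new data.

```python
def pct_change(before, after):
    """Signed percentage change from `before` to `after`."""
    return (after - before) / before * 100.0


# Measurements quoted in the PR description (GiB).
alloc_change = pct_change(70.0, 36.1)  # torch.cuda.max_memory_allocated
smi_change = pct_change(96.8, 51.3)    # nvidia-smi memory.used

print(f"max_memory_allocated: {alloc_change:.1f}%")
print(f"nvidia-smi memory.used: {smi_change:.1f}%")
```

Both reductions fall slightly short of the 50% that halving every (B, H, S, S) tensor would give on its own, which suggests that part of the peak does not scale with B.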

The shrunk configs still exercise every distinct shape/mask/SWA/GQA combination that the originals did -- only B is smaller. The suite now fits comfortably on 80 GB cards.

fp8_19/20 (B=2, S=2048) are left at B=2 because their peak is small (a few GiB) and the larger batch is useful coverage for padding_causal.

greptile-apps Bot commented May 21, 2026

Greptile Summary

This PR reduces peak GPU memory usage in the test_dpa_fp8_vs_f16 test suite by shrinking the nine large configs (fp8_9–fp8_17) in model_configs_fp8_vs_f16, cutting per-test peak from ~70 GiB to ~36 GiB so the suite fits on 80 GB cards.

  • fp8_9, fp8_11, fp8_14, fp8_17: sequence length halved (S: 4096→2048 or 8192→4096); batch size unchanged.
  • fp8_12, fp8_15, fp8_16: batch size halved (B: 2→1); sequence length unchanged.
  • fp8_10: B doubled (1→2) and S halved (4096→2048); at fixed H the (B, H, S, S) attention matrix scales as B·S², so its size still halves.
  • fp8_13: B and S unchanged (B=2, S=8192), but num_gqa_groups=4 is added, converting a non-GQA + sliding-window config into a GQA + sliding-window config — a subtle coverage trade-off noted below.
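The four strategies in the list above differ in how much they shrink the (B, H, S, S) attention matrix, whose element count scales as B·S². The following sketch compares them. It assumes H is unchanged in each case, `relative_footprint` is a hypothetical helper, and the B=2 used for the "halve S" row is illustrative, since the summary only says the batch size is unchanged.

```python
def relative_footprint(old_b, old_s, new_b, new_s):
    """Ratio new/old of the (B, H, S, S) element count, assuming H is unchanged."""
    return (new_b * new_s * new_s) / (old_b * old_s * old_s)


strategies = {
    "halve S (8192 -> 4096), B unchanged": relative_footprint(2, 8192, 2, 4096),
    "halve B (2 -> 1), S unchanged": relative_footprint(2, 8192, 1, 8192),
    "fp8_10: B 1 -> 2, S 4096 -> 2048": relative_footprint(1, 4096, 2, 2048),
    "fp8_13: B and S unchanged": relative_footprint(2, 8192, 2, 8192),
}

for name, ratio in strategies.items():
    print(f"{name}: {ratio:.2f}x of original")
```

Under these assumptions, halving S quarters the matrix, halving B halves it, the fp8_10 change halves it, and fp8_13 keeps the same attention-matrix size.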

Confidence Score: 5/5

Safe to merge — the change only reduces test config sizes to avoid OOM; no production code is touched.

The entire change is confined to test parameter values. Every config whose B or S changes produces a strictly smaller attention matrix, the test logic itself is untouched, and fp8_19/fp8_20 are intentionally left at B=2 for padding_causal coverage. The one exception is fp8_13: its B and S are unchanged, and it gains num_gqa_groups=4, which shifts its coverage from non-GQA+SWA to GQA+SWA for the H=32/D=128 shape, but this does not break anything.

tests/pytorch/attention/test_attention.py — specifically the fp8_13 config change which alters GQA grouping rather than just reducing batch or sequence size.

Important Files Changed

| Filename | Overview |
| --- | --- |
| tests/pytorch/attention/test_attention.py | Reduces test config sizes (B or S halved) in model_configs_fp8_vs_f16 (fp8_9–fp8_17) to avoid OOM on 80 GB cards; fp8_13 also gains num_gqa_groups=4, changing its coverage profile |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["fp8_9–fp8_17 configs\n(original: up to B=2, S=8192)"] --> B{Memory reduction\nstrategy}
    B -->|"fp8_9, fp8_11\nfp8_14, fp8_17"| C["Halve S\n(S: 4096→2048 or 8192→4096)\nB unchanged"]
    B -->|"fp8_12, fp8_15\nfp8_16"| D["Halve B\n(B: 2→1)\nS unchanged"]
    B -->|"fp8_10"| E["Increase B (1→2)\nHalve S (4096→2048)\nNet memory lower"]
    B -->|"fp8_13"| F["B=2, S=8192 unchanged\nAdd num_gqa_groups=4\nCoverage profile changed"]
    C --> G["Peak ~36 GiB\nFits on 80 GB cards"]
    D --> G
    E --> G
    F --> G

Reviews (5): Last reviewed commit: "tests/attention: black format fp8_13 Mod..."

@vedaanta vedaanta force-pushed the vedaanta/te-fp8-vs-f16-shrink-b1 branch from 1a59d59 to c3f1e50 Compare May 21, 2026 22:36

@KshitijLakhani KshitijLakhani left a comment


LGTM!
The combinations were discussed and adjusted offline: rather than taking the heavy-handed approach of setting B=1 for every config, the reductions are diversified across B, S, H, and D.

Good to merge after CI passes

@KshitijLakhani KshitijLakhani changed the title from "tests/attention: shrink fp8_vs_f16 configs from B=2 to B=1" to "[PyT] Reduce test sizes in fused attn fp8 vs fp16 to avoid OOM" May 21, 2026
@KshitijLakhani

/te-ci pytorch L0

"fp8_15": ModelConfig(2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0)),
"fp8_12": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"),
"fp8_13": ModelConfig(
    2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(128, 0)
Collaborator


Is the "num_gqa_groups=4" intended here? Should we focus on just reducing the size of the config via B or S, instead of adding GQA?
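To make the question concrete, the sketch below shows what num_gqa_groups changes in shape terms under the usual grouped-query attention convention, where K and V carry fewer heads than Q. The helper `gqa_shapes` is hypothetical, the (B, S, H, D) tensor layout is an assumption, and none of this is taken from the Transformer Engine implementation.

```python
def gqa_shapes(b, s, h, d, num_gqa_groups=None):
    """Tensor shapes for attention with optional grouped-query attention.

    Assumes a (B, S, H, D) layout; with GQA, K/V have `num_gqa_groups` heads.
    """
    groups = h if num_gqa_groups is None else num_gqa_groups
    if h % groups != 0:
        raise ValueError("num_heads must be divisible by num_gqa_groups")
    return {
        "q": (b, s, h, d),
        "kv": (b, s, groups, d),
        # One score map per query head, regardless of grouping.
        "scores": (b, h, s, s),
        "q_heads_per_kv_head": h // groups,
    }


# fp8_13 as shown in the diff: B=2, S=8192, H=32, D=128.
without_gqa = gqa_shapes(2, 8192, 32, 128)
with_gqa = gqa_shapes(2, 8192, 32, 128, num_gqa_groups=4)

print(without_gqa["kv"], with_gqa["kv"])
print(without_gqa["scores"] == with_gqa["scores"])
```

Under these assumptions, adding GQA shrinks only the K/V tensors. The (B, H, S, S) score matrix that dominates the reference path's memory stays the same size, so the change mainly alters what the config covers.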

vedaanta and others added 3 commits May 26, 2026 21:48
The 9 fp8_9..fp8_17 configs in `model_configs_fp8_vs_f16` use shapes
(B=2, S=4096-8192, H=32-128, D=64-192) for the bf16-vs-fp8 reference
comparison. The reference path in `test_dpa_fp8_vs_f16` materializes the
full (B, H, S, S) attention matrix in bf16, and keeps a handful of them
live (S, P, dP, dS, dropout-mask) simultaneously. At B=2, S=8192, H=64
the per-test peak is ~70 GiB, which exceeds the memory of common 80 GB
cards (H100) and pushes the suite into OOM territory on Blackwell (~91
GB measured with the cuDNN caching allocator residue).

Halving B to 1 halves the bytes of every (B, H, S, S) tensor. Measured
on B200 (SM_100, cuDNN 9.23, TE main):

  per-test peak `torch.cuda.max_memory_allocated`:
     before: 70.0 GiB (fp8_14)
     after : 36.1 GiB (fp8_14)         -48%
  per-test peak `nvidia-smi memory.used`:
     before: 96.8 GiB
     after : 51.3 GiB                  -47%
  test outcome (B200, develop FE, this TE):
     identical 618F / 2196P / 891S, wall time within ~3%

The shrunk configs still exercise every distinct shape/mask/SWA/GQA
combination that the originals did -- only B is smaller. The suite now
fits comfortably on 80 GB cards.

fp8_19/20 (B=2, S=2048) are left at B=2 because their peak is small
(~few GiB) and the larger batch is useful coverage for padding_causal.

Signed-off-by: Vedaanta Agarwalla <vagarwalla@nvidia.com>
Signed-off-by: Vedaanta Agarwalla <142048820+vedaanta@users.noreply.github.com>
Line was 105 chars; black requires <=100 with the project's preview+
string_processing settings.

Signed-off-by: Vedaanta Agarwalla <vagarwalla@nvidia.com>
@vedaanta vedaanta force-pushed the vedaanta/te-fp8-vs-f16-shrink-b1 branch from cd29763 to 56b1837 Compare May 27, 2026 04:48
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 27, 2026

Labels

attention, community-contribution (PRs from external contributor outside the core maintainers, representing community-driven work)


3 participants