
[PyTorch] Integrate the cuBLAS MXFP8 NN, NT support for sm120 #3050

Draft
KshitijLakhani wants to merge 5 commits into NVIDIA:main from KshitijLakhani:klakhani/test/mxfp8-cublas-gemm-sm120

KshitijLakhani (Collaborator) commented May 28, 2026

Description

Enable MXFP8 for single GEMMs in the NN and NT layouts on sm120 via cuBLAS.

Fixes #2668

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Testing

Ran test_numerics locally on sm120 to verify MXFP8 support.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

KshitijLakhani and others added 5 commits May 26, 2026 17:44
Adds NVTE_ENABLE_MXFP8_SM120 environment variable to unblock MXFP8
testing on sm120 (compute capability 12.0) devices. Default behavior
unchanged; MXFP8 remains gated off on sm120 without explicit opt-in
since not all GEMM layouts are currently supported.
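
The opt-in described above can be pictured with a minimal sketch. Only the variable name NVTE_ENABLE_MXFP8_SM120 comes from the commit message. The helper name `mxfp8_enabled_on_sm120` and the choice of "1" as the enabling value are assumptions for illustration, not the code in this branch.

```python
import os


def mxfp8_enabled_on_sm120(env=None):
    """Return True only when the user explicitly opts in to MXFP8 on sm120.

    Hypothetical helper: the accepted value "1" is an assumption.
    """
    env = os.environ if env is None else env
    # Default is "0", so MXFP8 stays gated off unless the variable is set.
    return env.get("NVTE_ENABLE_MXFP8_SM120", "0") == "1"
```

Passing a plain dict as `env` keeps the check testable without touching the process environment.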

Also adds tests/pytorch/mxfp8/test_mxfp8_gemm_exact.py: a focused
layout x shape x dtype matrix exercising MXFP8 single GEMM via the
underlying general_gemm call directly. The TN layout is exercised
across small/medium/transformer-sized shapes and BF16/FP32 outputs.
NN and NT layouts on sm120 are marked strict-xfail; the suite will
fail-on-XPASS once full layout support is added so the markers can
be removed.
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
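
As a rough illustration of the strict-xfail matrix this commit describes, the sketch below builds layout x shape x dtype parameters with pytest and marks NN and NT as strict xfail on sm120. The shapes, dtype labels, and the `build_cases` helper are hypothetical and are not taken from test_mxfp8_gemm_exact.py.

```python
import itertools

import pytest

SM120 = (12, 0)
LAYOUTS = ("TN", "NN", "NT")
# Hypothetical small / medium / transformer-sized (M, K, N) shapes.
SHAPES = ((128, 128, 128), (1024, 1024, 1024), (4096, 4096, 1024))
OUT_DTYPES = ("bfloat16", "float32")


def build_cases(compute_capability):
    """Build pytest params; NN and NT are strict-xfail on sm120."""
    cases = []
    for layout, shape, dtype in itertools.product(LAYOUTS, SHAPES, OUT_DTYPES):
        marks = ()
        if compute_capability == SM120 and layout in ("NN", "NT"):
            # strict=True turns an unexpected pass (XPASS) into a failure,
            # which signals that the marker can be removed.
            marks = (pytest.mark.xfail(strict=True, reason="layout unsupported on sm120"),)
        case_id = f"{layout}-{'x'.join(map(str, shape))}-{dtype}"
        cases.append(pytest.param(layout, shape, dtype, marks=marks, id=case_id))
    return cases
```

A test would consume this with `@pytest.mark.parametrize("layout,shape,dtype", build_cases(capability))`, where `capability` is queried from the device at collection time.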
…13.6+

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Remove tests/pytorch/mxfp8/test_mxfp8_gemm_exact.py. The TN/NN/NT MXFP8
GEMM code paths it was added to localize are already exercised end-to-end
on sm_120 (with cuBLASLt >= 13.6.0.2) by the existing te.Linear /
te.LayerNormLinear / te.GroupedLinear / te.TransformerLayer numerics
tests in tests/pytorch/test_numerics.py via the MXFP8BlockScaling entry
in fp8_recipes (each Linear forward + backward dispatches the three
cuBLAS GEMMs as fwd=TN, dgrad=NN, wgrad=NT).
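
To make the fwd=TN, dgrad=NN, wgrad=NT mapping concrete, here is a small NumPy sketch in full precision (no MXFP8 quantization). It assumes the usual cuBLAS column-major convention with the weight or input as operand A and the input or output gradient as operand B. The `gemm` helper is hypothetical and only mimics the transpose flags, it is not `general_gemm`.

```python
import numpy as np


def gemm(a, b, layout):
    """Compute op(A) @ op(B); layout[0] and layout[1] are 'T' or 'N'."""
    op_a = a.T if layout[0] == "T" else a
    op_b = b.T if layout[1] == "T" else b
    return op_a @ op_b


rng = np.random.default_rng(0)
M, K, N = 4, 8, 6
x = rng.standard_normal((M, K))   # input, row-major [M, K]
w = rng.standard_normal((N, K))   # weight, row-major [N, K]
dy = rng.standard_normal((M, N))  # output gradient, row-major [M, N]

# A row-major [r, c] buffer read as column-major is a [c, r] matrix.
x_cm, w_cm, dy_cm = x.T, w.T, dy.T

y_cm = gemm(w_cm, x_cm, "TN")    # fwd:   y  = x @ w.T
dx_cm = gemm(w_cm, dy_cm, "NN")  # dgrad: dx = dy @ w
dw_cm = gemm(x_cm, dy_cm, "NT")  # wgrad: dw = dy.T @ x
```

Transposing each result back to row-major recovers the familiar Linear formulas, which is why one forward plus backward pass touches all three layouts.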

The runtime _compute_mxfp8_support gate added in the earlier commits
on this branch already module-skips MXFP8 below cuBLASLt 13.6.0.2 on
sm_120, so the per-layout strict-xfail layer in this file is redundant.
Out-of-tree triage material (Testing/repro_mxfp8_layouts.cu and the
Testing/repro_mxfp8_layouts.py driver) remains available if a future
cuBLAS regression needs layout-localized signal again.
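
A version gate of the kind described above could look like the following sketch. The 13.6.0.2 threshold and the sm_120 target come from the commit message. The function name, its signature, and the dotted-string version format are assumptions and do not reproduce `_compute_mxfp8_support`.

```python
SM120 = (12, 0)
MIN_CUBLASLT_FOR_SM120 = (13, 6, 0, 2)


def parse_version(text):
    """Turn a dotted version string such as '13.6.0.2' into a tuple of ints."""
    return tuple(int(part) for part in text.split("."))


def sm120_mxfp8_gate(compute_capability, cublaslt_version):
    """Return (supported, reason) for the sm_120 case only (hypothetical helper)."""
    if compute_capability != SM120:
        return True, ""  # other architectures are out of scope for this sketch
    if parse_version(cublaslt_version) < MIN_CUBLASLT_FOR_SM120:
        return False, "MXFP8 on sm_120 requires cuBLASLt >= 13.6.0.2"
    return True, ""
```

A test module could call this once at import time and skip itself when the first element is False.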

Development

Successfully merging this pull request may close these issues.

[Inquiry] Support status and roadmap for MXFP8 on SM120 (Blackwell)