[PyTorch] Integrate the cuBLAS MXFP8 NN, NT support for sm120#3050
Draft
KshitijLakhani wants to merge 5 commits into
Draft
[PyTorch] Integrate the cuBLAS MXFP8 NN, NT support for sm120#3050KshitijLakhani wants to merge 5 commits into
KshitijLakhani wants to merge 5 commits into
Conversation
Adds NVTE_ENABLE_MXFP8_SM120 environment variable to unblock MXFP8 testing on sm120 (compute capability 12.0) devices. Default behavior unchanged; MXFP8 remains gated off on sm120 without explicit opt-in since not all GEMM layouts are currently supported. Also adds tests/pytorch/mxfp8/test_mxfp8_gemm_exact.py: a focused layout x shape x dtype matrix exercising MXFP8 single GEMM via the underlying general_gemm call directly. The TN layout is exercised across small/medium/transformer-sized shapes and BF16/FP32 outputs. NN and NT layouts on sm120 are marked strict-xfail; the suite will fail-on-XPASS once full layout support is added so the markers can be removed.
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…13.6+ Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Remove tests/pytorch/mxfp8/test_mxfp8_gemm_exact.py. The TN/NN/NT MXFP8 GEMM code paths it was added to localize are already exercised end-to-end on sm_120 (with cuBLASLt >= 13.6.0.2) by the existing te.Linear / te.LayerNormLinear / te.GroupedLinear / te.TransformerLayer numerics tests in tests/pytorch/test_numerics.py via the MXFP8BlockScaling entry in fp8_recipes (each Linear forward + backward dispatches the three cuBLAS GEMMs as fwd=TN, dgrad=NN, wgrad=NT). The runtime _compute_mxfp8_support gate added in the earlier commits on this branch already module-skips MXFP8 below cuBLASLt 13.6.0.2 on sm_120, so the per-layout strict-xfail layer in this file is redundant. Out-of-tree triage material (Testing/repro_mxfp8_layouts.cu and the Testing/repro_mxfp8_layouts.py driver) remains available if a future cuBLAS regression needs layout-localized signal again.
for more information, see https://pre-commit.ci
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
Enable MXFP8 support for NT, NN single GEMMs via cuBLAS for sm120.
Fixes #2668
Type of change
Testing
Tested test_numerics locally on sm120 for support
Checklist: