
[PyTorch] Propagate FP8 graph weight update flag in GroupedLinear #3052

Open
allenphilipj wants to merge 2 commits into NVIDIA:main from allenphilipj:codex-grouped-linear-fp8-cudagraph-skip

Conversation

@allenphilipj

Summary:

  • Propagate the FP8 graph-capture skip_fp8_weight_update tensor through GroupedLinear.
  • Align GroupedLinear graph-capture handling with Linear, LayerNormLinear, and LayerNormMLP.
  • Add a focused regression test for the forwarded skip tensor and graph-compatible is_first_microbatch behavior.
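
The propagation described in the summary can be sketched in isolation as follows. This is a minimal illustration, not the Transformer Engine code: `FakeGraphState` and `resolve_skip_flag` are hypothetical stand-ins for the global FP8 state and for the logic inside the module's forward method, and the skip "tensor" is any Python object.

```python
from typing import Any, Optional, Tuple


class FakeGraphState:
    """Hypothetical stand-in for the global FP8 state (not the real class)."""

    capturing: bool = False
    skip_fp8_weight_update_tensor: Optional[Any] = None


def resolve_skip_flag(
    state: FakeGraphState, is_first_microbatch: Optional[bool]
) -> Tuple[Optional[bool], Optional[Any]]:
    """Return the (is_first_microbatch, skip tensor) pair a module would forward."""
    # The skip tensor is only consulted while a graph is being captured.
    skip = state.skip_fp8_weight_update_tensor if state.capturing else None
    # When a skip tensor is present, the caller's microbatch hint is overridden.
    if skip is not None:
        is_first_microbatch = False
    return is_first_microbatch, skip
```

Outside graph capture the caller's `is_first_microbatch` value passes through unchanged and no skip tensor is forwarded, which matches the behavior the summary describes for the sibling modules.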

Validation:

  • git diff --check
  • python3 -m py_compile transformer_engine/pytorch/module/grouped_linear.py tests/pytorch/test_cuda_graphs.py
  • Not run: focused pytest, because pytest is not installed in this local environment.

Fixes #3051

@allenphilipj requested a review from ksivaman as a code owner on May 28, 2026 12:36
@github-actions bot added the community-contribution label (PRs from external contributor outside the core maintainers, representing community-driven work) on May 28, 2026
Signed-off-by: allenphilipj <allenphilipj@users.noreply.github.com>
@allenphilipj force-pushed the codex-grouped-linear-fp8-cudagraph-skip branch from 937ef34 to 80304fa on May 28, 2026 12:40
@greptile-apps
Contributor

greptile-apps Bot commented May 28, 2026

Greptile Summary

This PR fixes a missing propagation of the CUDA graph FP8 weight-update skip tensor (skip_fp8_weight_update_tensor) in GroupedLinear, aligning it with Linear, LayerNormLinear, and LayerNormMLP.

  • grouped_linear.py: Before building non_tensor_args, the forward method now reads skip_fp8_weight_update_tensor from FP8GlobalStateManager.quantization_state when fp8_graph_capturing() is true, forces is_first_microbatch = False when that tensor is set, and passes the tensor into the tuple instead of the previously hardcoded None.
  • test_cuda_graphs.py: Adds a focused unit test that monkeypatches the graph-capture state, intercepts _GroupedLinear.forward, and asserts both the overridden is_first_microbatch value and the propagated skip tensor are present in non_tensor_args at the correct positions using named index constants.
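
The testing approach described above (patch the capture state, intercept the inner forward, assert by named index) can be illustrated with a toy example. Everything here is an assumption made for illustration: `ToyState`, `ToyInner`, `ToyGroupedLinear`, and the two-element argument tuple are hypothetical, the index constants do not match the real `non_tensor_args` layout, and `unittest.mock` stands in for pytest's `monkeypatch` fixture.

```python
from unittest import mock

# Hypothetical positions in a toy argument tuple; the real tuple is longer.
IS_FIRST_MICROBATCH_IDX = 0
SKIP_FP8_WEIGHT_UPDATE_IDX = 1


class ToyState:
    """Toy replacement for the global graph-capture state."""

    @staticmethod
    def graph_capturing():
        return False

    @staticmethod
    def skip_tensor():
        return None


class ToyInner:
    """Toy replacement for the inner function that consumes the tuple."""

    @staticmethod
    def forward(inp, non_tensor_args):
        return inp


class ToyGroupedLinear:
    def forward(self, inp, is_first_microbatch=None):
        skip = ToyState.skip_tensor() if ToyState.graph_capturing() else None
        if skip is not None:
            is_first_microbatch = False
        non_tensor_args = (is_first_microbatch, skip)
        return ToyInner.forward(inp, non_tensor_args)


def run_check():
    """Patch the state, capture the forwarded tuple, and return it."""
    sentinel = object()
    captured = {}

    def fake_forward(inp, non_tensor_args):
        captured["args"] = non_tensor_args
        return inp

    with mock.patch.object(ToyState, "graph_capturing", return_value=True), \
            mock.patch.object(ToyState, "skip_tensor", return_value=sentinel), \
            mock.patch.object(ToyInner, "forward", fake_forward):
        ToyGroupedLinear().forward("x", is_first_microbatch=True)
    return captured["args"], sentinel
```

Because the skip tensor is a plain sentinel object and the inner forward is replaced, a check of this shape needs neither a CUDA device nor the FP8 stack. Named index constants keep the assertions readable if the tuple layout changes.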

Confidence Score: 5/5

Safe to merge: the change is a straight port of an established pattern from sibling modules and introduces no behavioral differences from them.

The new block in grouped_linear.py is character-for-character identical to the corresponding block in linear.py, layernorm_linear.py, and layernorm_mlp.py, and the inner _GroupedLinear function was already wired to consume skip_fp8_weight_update. The only change to externally visible behavior is that the tensor is no longer silently dropped, which is the intended fix.

No files require special attention.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/module/grouped_linear.py | Adds the standard FP8-graph-capture skip-tensor propagation block (9 lines) that was already present in Linear, LayerNormLinear, and LayerNormMLP; replaces the hardcoded None with the live tensor in non_tensor_args. |
| tests/pytorch/test_cuda_graphs.py | Adds a focused regression test using pytest monkeypatching to verify skip-tensor propagation and the is_first_microbatch override without requiring a CUDA device or the full FP8 stack; uses named index constants instead of bare magic numbers. |

Sequence Diagram

sequenceDiagram
    participant Caller
    participant GroupedLinear
    participant FP8GlobalStateManager
    participant _GroupedLinear

    Caller->>GroupedLinear: "forward(inp, m_splits, is_first_microbatch=True)"
    GroupedLinear->>FP8GlobalStateManager: fp8_graph_capturing()
    FP8GlobalStateManager-->>GroupedLinear: True
    GroupedLinear->>FP8GlobalStateManager: quantization_state.skip_fp8_weight_update_tensor
    FP8GlobalStateManager-->>GroupedLinear: skip_tensor (non-None)
    Note over GroupedLinear: is_first_microbatch = False (overridden)
    GroupedLinear->>GroupedLinear: "build non_tensor_args [2]=False, [19]=skip_tensor"
    GroupedLinear->>_GroupedLinear: forward(ctx, inp, non_tensor_args, weights...)
    _GroupedLinear->>_GroupedLinear: "skip_update_flag=skip_tensor used in FP8 weight cast"
    _GroupedLinear-->>GroupedLinear: out, workspaces
    GroupedLinear-->>Caller: out

Reviews (2): Last reviewed commit: "Update tests/pytorch/test_cuda_graphs.py"

Comment thread tests/pytorch/test_cuda_graphs.py Outdated
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: allenphilipj <allen.philip@intercom.io>

Development

Successfully merging this pull request may close these issues.

[PyTorch] GroupedLinear does not propagate skip_fp8_weight_update during FP8 CUDA graph capture
