[PyTorch] Propagate FP8 graph weight update flag in GroupedLinear#3052
[PyTorch] Propagate FP8 graph weight update flag in GroupedLinear#3052allenphilipj wants to merge 2 commits into
Conversation
Signed-off-by: allenphilipj <allenphilipj@users.noreply.github.com>
937ef34 to
80304fa
Compare
Greptile SummaryThis PR fixes a missing propagation of the CUDA graph FP8 weight-update skip tensor (
Confidence Score: 5/5Safe to merge — the change is a straight port of an established pattern from sibling modules with no behavioral differences. The new block in grouped_linear.py is character-for-character identical to the corresponding block in linear.py, layernorm_linear.py, and layernorm_mlp.py, and the inner _GroupedLinear kernel was already wired to consume skip_fp8_weight_update. The only change to externally-visible behavior is that the tensor is no longer silently dropped, which is the intended fix. No files require special attention. Important Files Changed
Sequence DiagramsequenceDiagram
participant Caller
participant GroupedLinear
participant FP8GlobalStateManager
participant _GroupedLinear
Caller->>GroupedLinear: "forward(inp, m_splits, is_first_microbatch=True)"
GroupedLinear->>FP8GlobalStateManager: fp8_graph_capturing()
FP8GlobalStateManager-->>GroupedLinear: True
GroupedLinear->>FP8GlobalStateManager: quantization_state.skip_fp8_weight_update_tensor
FP8GlobalStateManager-->>GroupedLinear: skip_tensor (non-None)
Note over GroupedLinear: is_first_microbatch = False (overridden)
GroupedLinear->>GroupedLinear: "build non_tensor_args [2]=False, [19]=skip_tensor"
GroupedLinear->>_GroupedLinear: forward(ctx, inp, non_tensor_args, weights...)
_GroupedLinear->>_GroupedLinear: "skip_update_flag=skip_tensor used in FP8 weight cast"
_GroupedLinear-->>GroupedLinear: out, workspaces
GroupedLinear-->>Caller: out
Reviews (2): Last reviewed commit: "Update tests/pytorch/test_cuda_graphs.py" | Re-trigger Greptile |
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: allenphilipj <allen.philip@intercom.io>
Summary:
Validation:
Fixes #3051