Skip to content

Add the getter and setter of skip_fp8_weight_update_tensor#3015

Open
xrennvidia wants to merge 6 commits into
NVIDIA:mainfrom
xrennvidia:xren/fix_skip_fp8_weight_update
Open

Add the getter and setter of skip_fp8_weight_update_tensor#3015
xrennvidia wants to merge 6 commits into
NVIDIA:mainfrom
xrennvidia:xren/fix_skip_fp8_weight_update

Conversation

@xrennvidia
Copy link
Copy Markdown
Collaborator

@xrennvidia xrennvidia commented May 20, 2026

Description

The getter and setter of skip_fp8_weight_update_tensor were deleted in @pggPL 's PR2759, but MCore local Cuda Graph implementation still needs it (like here), so create this PR to recover it back.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Xiaowei Ren <xren@nvidia.com>
@xrennvidia xrennvidia requested a review from ksivaman as a code owner May 20, 2026 09:48
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 20, 2026

Greptile Summary

This PR restores the get_skip_fp8_weight_update_tensor and set_skip_fp8_weight_update_tensor class methods on FP8GlobalStateManager that were removed in PR #2759, which Megatron-Core's CUDA-graph implementation still depends on. As a secondary improvement, the existing inline tensor-initialization logic in graph.py is moved into the new setter, reducing duplication.

  • quantization.py: Two new class methods — set_skip_fp8_weight_update_tensor(skip: bool) (initialises the tensor on first call then fills it) and get_skip_fp8_weight_update_tensor() (returns the Optional[torch.Tensor]) — are added to FP8GlobalStateManager.
  • graph.py: Both sites that previously directly mutated quantization_state.skip_fp8_weight_update_tensor are replaced with calls to the new setter, keeping the runtime behaviour identical.

Confidence Score: 5/5

Safe to merge — the change is a minimal, additive restoration of two accessor methods with no behaviour changes to existing call sites.

Both new methods encapsulate logic that was already present and tested inline in graph.py. The setter correctly handles lazy initialisation of the underlying tensor, and the getter faithfully mirrors the field value. No existing call sites are altered in a way that could change runtime behaviour.

No files require special attention.

Important Files Changed

Filename Overview
transformer_engine/pytorch/quantization.py Adds set_skip_fp8_weight_update_tensor and get_skip_fp8_weight_update_tensor class methods to FP8GlobalStateManager; logic is correct and matches the original inline usage.
transformer_engine/pytorch/graph.py Replaces two direct accesses to quantization_state.skip_fp8_weight_update_tensor with calls to the new set_skip_fp8_weight_update_tensor setter; behavior is unchanged.

Sequence Diagram

sequenceDiagram
    participant MCore as MCore / graph.py
    participant Manager as FP8GlobalStateManager
    participant State as FP8GlobalState

    MCore->>Manager: set_skip_fp8_weight_update_tensor(False)
    Manager->>State: skip_fp8_weight_update_tensor is None?
    alt tensor is None
        Manager->>State: create torch.empty(1, float32, cuda)
    end
    Manager->>State: fill_(False)

    MCore->>Manager: get_skip_fp8_weight_update_tensor()
    Manager->>State: read skip_fp8_weight_update_tensor
    State-->>Manager: Optional[Tensor]
    Manager-->>MCore: Optional[Tensor]
Loading

Reviews (5): Last reviewed commit: "Merge branch 'main' into xren/fix_skip_f..." | Re-trigger Greptile

Comment thread transformer_engine/pytorch/quantization.py Outdated
return type fix

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com>
@ptrendx ptrendx requested a review from pggPL May 21, 2026 00:39
Copy link
Copy Markdown
Member

@ptrendx ptrendx left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe there could be a reason why Pawel removed those functions from this object and we may need to change MCore instead in order to have this be compatible with torch.compile. Setting 'request changes' status for now until @pggPL reviews it.

@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 22, 2026
Copy link
Copy Markdown
Collaborator

@pggPL pggPL left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

I didn't know that this is used in mcore, I've run the torch compile test with this code and it also passes.

@ptrendx
Copy link
Copy Markdown
Member

ptrendx commented May 27, 2026

/te-ci pytorch

@xrennvidia
Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants