
FSDP2 fully_shard breaks Linear4bit forward via NaN canonicalization of packed-NF4 bf16 storage #1945

@neil-the-nowledgeable

Summary

FSDP2's fully_shard produces incorrect forward output for modules containing bitsandbytes.nn.Linear4bit parameters. Two distinct mechanisms are involved; the primary one (Mechanism A) is the focus of this issue,
with the secondary one (Mechanism B) noted for completeness.

  1. Mechanism A (NaN normalization in parameter-swap): post-fully_shard, FSDP2's parameter-swap mechanism reads bf16-stored Params4bit weight bytes through a float-aware code path that normalizes every NaN
    bit pattern to a fixed quiet-NaN representation 0x7FFF
    (sign=0, exp=0xFF, mantissa=0x7F — quiet bit set, all other mantissa bits set). Because bnb stores packed-NF4 nibble indices in bf16-shaped containers,
    ~0.098% of weight elements coincidentally encode bf16 NaN bit patterns; FSDP2 normalizes them all to 0x7FFF. bnb.matmul_4bit then decodes the normalized bytes as different NF4 indices than the original,
    producing wrong matmul output. The effect is invariant to wrap granularity, reduce_dtype, and the cast_forward_inputs flag, and is deterministic across hosts (full empirical matrix below).

  2. Mechanism B (segfault, secondary): weight.redistribute([Replicate]) and weight.full_tensor() SIGSEGV on bf16-packed Params4bit data. DTensor's gather-to-replicate path is broken for packed-NF4 storage.
    Happy to file as a separate issue if preferred.

Both mechanisms are reproducible at WS=1 (single-rank, single-host) and at WS=2 (multi-rank). FSDP1 with the canonical bnb FSDP-QLoRA recipe (use_orig_params=False) is not affected — empirically verified here at WS=1 (test 6 in
the diagnostic chain below; FSDP1 degrades to NO_SHARD), and indirectly corroborated at WS≥2 by the Answer.AI / HF Transformers / Axolotl reference recipes continuing to train correctly in production. I did not
run a direct FSDP1 cross-check at WS≥2 in this report.

Byte-level mechanism

Empirical forensics on TinyLlama-1.1B's layers[0].self_attn.q_proj.base_layer.weight (1,048,576 bf16 elements = 2,097,152 bytes of packed NF4):

| Metric | Value |
| --- | --- |
| Differing bytes (orig vs swap-target read inside forward) | 1095 (0.0522%) |
| Differing bf16 elements | 1030 (0.0982%) |
| Original class of every differing element | NaN (1030/1030 = 100%) |
| Post-swap bit pattern of every differing element | 0x7FFF (a quiet NaN — not IEEE's recommended canonical 0x7FC0) (1030/1030 = 100%) |
| DTensor._local_tensor view vs original | 0 byte diff (byte-true preserved) |
| Per-Linear max_abs_delta (swap vs byte-true cached-ref) | 0.100586 |
| Per-Linear mean_abs_delta | 0.000313 |
| Cross-host (two physically separate Jetson Orin Nano Super units, sm_87, JP6.2) | Bit-identical on every metric, including exact element positions |

The 1095:1030 ratio (~1.06 bytes per differing element) implies ~65 elements have both bytes changed, consistent with negative NaNs (0xFFxx) being flipped to positive 0x7FFF in addition to mantissa
normalization. The remaining ~965 elements differ only in the low byte (mantissa normalization alone).

Sample differing elements (bit-identical across hosts):

| Element | Original bits | Original class | Post-swap bits |
| --- | --- | --- | --- |
| 580 | 0x7fea | qNaN, mant=0x6a | 0x7fff |
| 992 | 0x7f89 | sNaN, mant=0x09 | 0x7fff |
| 1077 | 0x7fb5 | sNaN, mant=0x35 | 0x7fff |
| 4574 | 0x7fd8 | qNaN, mant=0x58 | 0x7fff |
| 5188 | 0x7f88 | sNaN, mant=0x08 | 0x7fff |

The pattern is unambiguous: any bf16 NaN encoding gets normalized to a single fixed quiet-NaN pattern 0x7FFF, regardless of original sign or mantissa — consistent with a float-aware read path that quiets sNaNs
and clears non-quiet-bit mantissa state. The packed-NF4 nibble indices happened to encode NaN bit patterns; after normalization they encode different nibble indices.

Full forensic script: https://gist.github.com/neil-the-nowledgeable/d5bbffd9bd83029314771d9f46472cb2
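A quick way to see the aliasing in isolation (a minimal sketch, separate from the gist; the 512×512 size is arbitrary and the exact NaN fraction depends on the weight distribution):

```python
import torch
import bitsandbytes as bnb

# Element 580 from the table above: little-endian bytes [0xEA, 0x7F] are the
# bf16 bit pattern 0x7FEA -> exponent all-ones, nonzero mantissa -> NaN.
elem = torch.tensor([0xEA, 0x7F], dtype=torch.uint8).view(torch.bfloat16)
print(torch.isnan(elem))  # tensor([True])

# Count how many elements of a bf16-storage Params4bit alias a NaN encoding.
# Any NaN-canonicalizing read corrupts exactly these elements.
# (Requires CUDA: quantization happens on .to('cuda').)
linear = bnb.nn.Linear4bit(512, 512, bias=False, quant_type='nf4',
                           compute_dtype=torch.bfloat16,
                           quant_storage=torch.bfloat16).to('cuda')
w = linear.weight.data  # packed NF4 nibbles stored in bf16 containers
nan_frac = torch.isnan(w).float().mean().item()
# The TinyLlama layer above measured 0.0982%; this synthetic layer will differ somewhat.
print(f"NaN-aliasing elements: {100 * nan_frac:.4f}% of storage")
```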

Configuration-invariance matrix

Single-rank, single-host TinyLlama-1.1B + bnb-NF4 + PEFT-LoRA, forward-only loss on fixed-seed input. All non-failing cells produce loss_fsdp2 = 12.725160598754883, identical in every printed digit, vs baseline
12.691308975219727 (Δ=0.0339, 0.27% relative error):

| Cell | Wrap | FSDP units | mp_policy | other | loss_fsdp2 |
| --- | --- | --- | --- | --- | --- |
| A | per-DecoderLayer | 23 | bf16/fp32 | | 12.725160598754883 |
| C | per-DecoderLayer | 23 | bf16/bf16 | | 12.725160598754883 |
| D | per-Linear4bit | 155 | bf16/fp32 | | 12.725160598754883 |
| E | root-only | 1 | bf16/fp32 | | 12.725160598754883 |
| G | per-DecoderLayer | 23 | bf16/fp32 | cast_forward_inputs=False | 12.725160598754883 |
| F | per-DecoderLayer | 23 | bf16/fp32 | double_quant=False (different baseline 12.6973) | 12.742940902709961 |
| B | per-DecoderLayer | 23 | None | | AttributeError in _init_mp_dtypes (FSDP2 requires MixedPrecisionPolicy instance) |

Invariant in: wrap granularity (1, 23, 155 FSDP units → identical bit pattern), reduce_dtype (bf16 or fp32), cast_forward_inputs flag.
Depends on: param_dtype = torch.bfloat16 — the only configuration FSDP2 supports for sharding bnb's bf16-quant-storage Params4bit, and the configuration that triggers the float-normalization read path.
Note on double_quant: incidental; removing it amplifies the delta by ~35% (0.0339 → 0.0456) because the weight distribution shifts and hits slightly more NaN patterns; the bug is not caused by double-quant.

Full matrix script: https://gist.github.com/neil-the-nowledgeable/24412f2da39880e4fb0570198aca442e
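For reference, the wrap granularities in the matrix correspond to which modules fully_shard is applied to. A minimal sketch, assuming a HF-style LlamaForCausalLM called model already loaded with bnb NF4 (module paths are the usual transformers ones, not taken from the gist):

```python
import torch
import bitsandbytes as bnb
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

mp = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)

# per-DecoderLayer (cells A, C, F, G): 22 TinyLlama layers + root = 23 FSDP units
for layer in model.model.layers:
    fully_shard(layer, mp_policy=mp)
fully_shard(model, mp_policy=mp)

# per-Linear4bit (cell D): 154 quantized Linears + root = 155 units
# for m in model.modules():
#     if isinstance(m, bnb.nn.Linear4bit):
#         fully_shard(m, mp_policy=mp)
# fully_shard(model, mp_policy=mp)

# root-only (cell E): a single FSDP unit
# fully_shard(model, mp_policy=mp)
```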

Empirical workaround (proof-by-construction)

The following pattern produces bit-identical training results to pre-shard baseline at WS=1 AND WS=2, including with real models (TinyLlama-1.1B + real peft.LoraConfig + 154 Linear4bit modules):

```python
# Step 1: BEFORE fully_shard, walk the model and capture each Linear4bit's
# quant_state by module path (because fully_shard re-routes self.weight
# through a swap-target Parameter that loses quant_state).
import bitsandbytes as bnb
import torch
import torch.distributed as dist

qs_cache = {}
for name, module in model.named_modules():
    if isinstance(module, bnb.nn.Linear4bit):
        qs_cache[name] = module.weight.quant_state

# Step 2: call fully_shard(...) on the model / decoder layers as usual.

# Step 3: AFTER fully_shard, walk again and capture the byte-true
# DTensor._local_tensor reference + the quant_state from step 1.
for name, module in model.named_modules():
    if isinstance(module, bnb.nn.Linear4bit) and hasattr(module.weight, '_local_tensor'):
        module._cached_local = module.weight._local_tensor
        module._cached_qs = qs_cache[name]

# Step 4: replace Linear4bit.forward globally with a cached-ref-aware version.
_orig_forward = bnb.nn.Linear4bit.forward

def _ws_aware_forward(self, x_in):
    qs = getattr(self, '_cached_qs', None)
    cached_local = getattr(self, '_cached_local', None)
    if qs is None or cached_local is None:
        return _orig_forward(self, x_in)
    local = cached_local.contiguous()
    ws = dist.get_world_size() if dist.is_initialized() else 1
    if ws > 1:
        # Manual all_gather — DTensor's full_tensor() / redistribute() SIGSEGVs
        # on bf16-packed Params4bit (Mechanism B above).
        gathered = [torch.empty_like(local) for _ in range(ws)]
        dist.all_gather(gathered, local)
        full_w = torch.cat(gathered, dim=0)
    else:
        full_w = local
    if x_in.dtype != torch.bfloat16:
        x_in = x_in.to(torch.bfloat16)
    return bnb.matmul_4bit(x_in, full_w.t(), bias=self.bias, quant_state=qs)

bnb.nn.Linear4bit.forward = _ws_aware_forward
```


**Critical detail**: the `_local_tensor` reference must be captured OUTSIDE forward (where `linear.weight` resolves to the DTensor). Inside `Linear4bit.forward`, FSDP2's swap mechanism replaces `self.weight` with a Parameter that LACKS the `_local_tensor` attribute and points at different memory than what `_local_tensor` references externally. Reading `self.weight._local_tensor` from inside forward therefore either raises AttributeError or, behind a `getattr(..., None)` guard, silently falls through to the broken path.

### Minimal reproducer

A 30-line reproducer demonstrating Mechanism A (wrong output) at WS=1. Numbers below are illustrative; your exact values depend on the seeded weight distribution, but `max_delta` should be O(0.1) (vs O(1e-6) baseline) and `y_pre.sum()` vs `y_post.sum()` should differ by O(1):

```python
import os, torch, torch.distributed as dist, bitsandbytes as bnb
os.environ.update({"MASTER_ADDR":"127.0.0.1","MASTER_PORT":"29501","RANK":"0","WORLD_SIZE":"1"})
dist.init_process_group(backend="gloo", world_size=1, rank=0)
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.device_mesh import DeviceMesh

torch.manual_seed(42)
linear = bnb.nn.Linear4bit(512, 512, bias=False, quant_type='nf4',
                         compute_dtype=torch.bfloat16,
                         quant_storage=torch.bfloat16,
                         compress_statistics=True).to('cuda')

torch.manual_seed(123)
x = torch.randn(2, 512, dtype=torch.bfloat16, device='cuda')
y_pre = linear(x).clone()  # baseline

mesh = DeviceMesh.from_group(dist.distributed_c10d._get_default_group(), "cuda")
mp = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
fully_shard(linear, mesh=mesh, mp_policy=mp)

y_post = linear(x)  # post-shard via vanilla forward

print(f"y_pre.sum  = {y_pre.sum().item():.6f}")
print(f"y_post.sum = {y_post.sum().item():.6f}")
print(f"max_delta  = {(y_pre - y_post).abs().max().item():.6e}")
# Expected: max_delta on the order of 1e-1 (not 1e-6).
```

Mechanism B reproducer (same setup, then):

```python
from torch.distributed._tensor import Replicate  # or: from torch.distributed.tensor import Replicate on newer torch

linear.weight.full_tensor()                              # SIGSEGV
linear.weight.redistribute(placements=[Replicate()])     # SIGSEGV
```

Diagnostic chain (summary)

Twelve empirical tests on Jetson Orin Nano Super (sm_87, JetPack 6.2, torch 2.5.1 source-built USE_GLOO=1, bnb 0.46.1 source-built -DCOMPUTE_CAPABILITY=87) document the bug and the workaround:

| # | Test | Outcome |
| --- | --- | --- |
| 1 | Vanilla linear(x) post-shard at WS=1 | Wrong (max_delta O(0.1–0.3) on synthetic 512×512, vs O(1e-6) baseline) |
| 2 | register_buffer("weight_absmax", ...) + fully_shard | Buffer doesn't shard; FSDP2 replicates buffers |
| 3 | register_parameter("weight_absmax", ...) at wrap-time | Wraps successfully (Shard(dim=0) on both weight + absmax) |
| 4 | Same as 3 + actually run forward | AssertionError: FSDP expects uniform original parameter dtype |
| 5 | FSDP1 + same dual-Parameter | ValueError: Must flatten tensors with uniform dtype (same constraint, different path) |
| 6 | FSDP1 (vanilla) at WS=1 | Bit-identical to baseline (NO_SHARD at WS=1) |
| 7 | FSDP2 with mp_policy=MixedPrecisionPolicy(None, None), reshard_after_forward=False | Same wrong output as vanilla — NOT mp_policy or reshard related |
| 8 | DTensor _local_tensor inspection post-shard | Plain Tensor (not Params4bit); quant_state attribute lost; data_ptr differs from original |
| 9 | Direct byte comparison: original vs post-shard _local_tensor | Bit-identical bytes (uint8 view); only IEEE NaN positions appear different |
| 10 | bnb.matmul_4bit(x, _local_tensor.t(), qs_cached) outside linear() | Bit-identical to baseline — data + dispatch + kernel all correct in isolation |
| 11 | Cached-ref + manual all_gather at WS=1 | Bit-identical forward + backward + 5-step training |
| 12 | Cached-ref at WS=2 with TinyLlama-1.1B + real PEFT + 154 Linear4bits | Bit-identical to WS=1 baseline; both ranks identical; 3 training steps |

The combination of tests 8–11 localizes the bug to FSDP2's parameter-swap mechanism: the bytes reachable via _local_tensor are bit-identical to the original; matmul on _local_tensor directly matches baseline; matmul via self.weight inside forward is wrong.
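Tests 8–10 can be replayed with a small variation of the minimal reproducer above (a sketch: same setup through y_pre, with quant_state snapshotted before fully_shard as in step 1 of the workaround):

```python
# BEFORE fully_shard: snapshot quant_state (it is gone from the parameter after sharding; test 8)
qs_cached = linear.weight.quant_state

fully_shard(linear, mesh=mesh, mp_policy=mp)

local = linear.weight._local_tensor                              # byte-true shard bytes (test 9)
y_direct = bnb.matmul_4bit(x, local.t(), quant_state=qs_cached)  # test 10
y_swap = linear(x)                                               # vanilla forward via the swap-target read (test 1)

print("direct max_delta:", (y_direct - y_pre).abs().max().item())  # ~0: matches baseline
print("swap   max_delta:", (y_swap - y_pre).abs().max().item())    # O(0.1): wrong
```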

Loss values at WS=2 + TinyLlama matched bit-exactly in fp32 (delta = 0.0e+00) across ranks AND against WS=1 reference: [13.512944221496582, 13.435127258300781, 13.27534008026123].

Why FSDP1 works

At WS=1, FSDP1 emits "FSDP is switching to use NO_SHARD instead of ShardingStrategy.FULL_SHARD since the world size is 1" — the wrap is effectively a passthrough. At WS≥2, FSDP1's FlatParamHandle shards via torch.chunk / torch.split. Our reading is that those calls go through the Params4bit.__torch_function__ override (added by PR #1719, "Fix Params4bit tensor subclass handling"), which intercepts chunk / split and re-wraps the results as Params4bit (preserving quant_state, quant_type, etc.), whereas FSDP2's DTensor-based wrap appears to bypass that path; a toy illustration follows below. I have not stepped through DTensor's sharding code directly to confirm — happy to verify if maintainers find that useful.

This is consistent with why the bnb FSDP-QLoRA reference recipe still uses FSDP1.
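To illustrate the general mechanism, here is a toy stand-in (not bitsandbytes' actual Params4bit code from PR #1719): a __torch_function__ override that re-attaches subclass metadata to the outputs of torch.chunk / torch.split. This is the kind of protection a sharding path bypasses if it never routes through those ops.

```python
import torch

class MetaTensor(torch.Tensor):
    """Toy stand-in for a subclass carrying side metadata (like quant_state)."""

    def __new__(cls, data, meta=None):
        t = torch.Tensor._make_subclass(cls, data)
        t.meta = meta
        return t

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        result = super().__torch_function__(func, types, args, kwargs)
        # Propagate metadata onto chunk/split outputs so it survives
        # FlatParamHandle-style sharding through these ops.
        if func in (torch.chunk, torch.Tensor.chunk, torch.split, torch.Tensor.split):
            for r in result:
                r.meta = getattr(args[0], "meta", None)
        return result

w = MetaTensor(torch.arange(8.0), meta={"quant_type": "nf4"})
parts = torch.chunk(w, 2)
print(type(parts[0]).__name__, parts[0].meta)  # MetaTensor {'quant_type': 'nf4'}
```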

Cross-references

  • PR #970 (MERGED) — original FSDP-QLoRA enablement (Answer.AI). Added quant_storage selection (uint8 / bf16 / fp16) so FSDP can shard Params4bit at all. The PR body notes that the sync_module_states memory optimization was left for future work — that's a separate axis from the runtime forward-correctness issue reported here.
  • PR #1719 (MERGED) — "Fix Params4bit tensor subclass handling". Added __torch_function__ override on Params4bit so torch.chunk / torch.split (the ops FSDP1's FlatParamHandle uses for sharding) re-wrap results back as Params4bit rather than returning plain Tensor. This is the protection FSDP2's DTensor wrap appears to bypass.
  • PR #1866 (MERGED) — added __getattr__ to Params4bit so FSDP state_dict traversal can resolve weight.absmax / weight.quant_map / weight.quant_state.bitsandbytes__* FQN paths.
  • PR #1916 (MERGED) — replaced #1866's __getattr__ with @property descriptors to eliminate torch.compile graph breaks under activation checkpointing; FSDP state_dict traversal continues to work through the descriptor protocol.
  • State_dict serialization (covered by #1866 + #1916) works as expected; this issue concerns runtime forward.

What's actionable for bnb maintainers

The fix target is primarily in PyTorch FSDP2 / DTensor, not in bitsandbytes:

  1. Cross-link to a PyTorch core issue. I haven't bisected FSDP2's internals; the empirical symptom (byte-level NaN normalization on the swap-target buffer) points to the parameter-storage path that allocates and initializes the swap-target — torch.distributed._composable.fsdp._fsdp_param.py (likely init_dtype_attrs or the storage allocation that copies bytes through a bf16-typed view). Possible fix shapes:

    • Read the bf16 Tensor-subclass buffer via .view(torch.uint8) (byte-true) rather than through any code path that triggers IEEE 754 NaN canonicalization.
    • Skip the canonicalization for parameters whose owning Tensor subclass overrides __torch_function__ (i.e., Params4bit explicitly opts out of float-semantic handling).

    I'm happy to open the PyTorch core issue when you confirm this framing — wanted to surface here first since you have the most context on the FSDP-QLoRA history.

  2. Document the cached-ref workaround in bnb's FSDP-QLoRA docs as the recommended path for FSDP2 users until the upstream PyTorch fix lands.

  3. (Optional contribution) Ship a bitsandbytes.fsdp2_utils module containing install_fsdp2_workaround(model) that walks Linear4bit modules and applies the patch — ~50 LoC, working implementation already exists. Happy to submit as a PR if you'd like it.
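For concreteness, a rough sketch of the shape such a helper could take, assuming it simply bundles steps 1–4 of the workaround above (capture_quant_states / install_fsdp2_workaround are illustrative names, not an existing bitsandbytes API):

```python
import bitsandbytes as bnb
import torch
import torch.distributed as dist


def capture_quant_states(model):
    """Call BEFORE fully_shard: snapshot each Linear4bit's quant_state by module path."""
    return {name: m.weight.quant_state
            for name, m in model.named_modules()
            if isinstance(m, bnb.nn.Linear4bit)}


def install_fsdp2_workaround(model, qs_cache):
    """Call AFTER fully_shard: cache byte-true local shards and patch Linear4bit.forward."""
    for name, m in model.named_modules():
        if isinstance(m, bnb.nn.Linear4bit) and hasattr(m.weight, "_local_tensor"):
            m._cached_local = m.weight._local_tensor
            m._cached_qs = qs_cache[name]

    orig_forward = bnb.nn.Linear4bit.forward

    def patched_forward(self, x):
        qs = getattr(self, "_cached_qs", None)
        local = getattr(self, "_cached_local", None)
        if qs is None or local is None:
            return orig_forward(self, x)
        local = local.contiguous()
        ws = dist.get_world_size() if dist.is_initialized() else 1
        if ws > 1:
            # manual gather; full_tensor()/redistribute() segfault (Mechanism B)
            gathered = [torch.empty_like(local) for _ in range(ws)]
            dist.all_gather(gathered, local)
            full_w = torch.cat(gathered, dim=0)
        else:
            full_w = local
        if x.dtype != torch.bfloat16:
            x = x.to(torch.bfloat16)
        return bnb.matmul_4bit(x, full_w.t(), bias=self.bias, quant_state=qs)

    bnb.nn.Linear4bit.forward = patched_forward
```

Intended usage would be qs_cache = capture_quant_states(model) before fully_shard, then install_fsdp2_workaround(model, qs_cache) after.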

Test artifacts

All scripts are gist-linked above. Combined runtime is ~3 minutes on a Jetson Orin Nano Super at WS=1, ~5 minutes for the full matrix at WS=2. Happy to provide additional reproducers (FSDP1 cross-check, WS=4 cluster verification) on request.

Thanks for the maintenance work on bnb's FSDP integration so far — quant_storage=bf16 (PR #970) and the @property accessors (PR #1866 + #1916) made this kind of diagnostic possible.
