FSDP2.fully_shard produces incorrect forward output for modules containing bitsandbytes.nn.Linear4bit parameters. Two distinct mechanisms are involved; the primary one (Mechanism A) is the focus of this issue,
with the secondary one (Mechanism B) noted for completeness.
Mechanism A (NaN normalization in parameter-swap): post-fully_shard, FSDP2's parameter-swap mechanism reads bf16-stored Params4bit weight bytes through a float-aware code path that normalizes every NaN
bit pattern to a fixed quiet-NaN representation 0x7FFF (sign=0, exp=0xFF, mantissa=0x7F — quiet bit set, all other mantissa bits set). Because bnb stores packed-NF4 nibble indices in bf16-shaped containers,
~0.098% of weight elements coincidentally encode bf16 NaN bit patterns; FSDP2 normalizes them all to 0x7FFF. bnb.matmul_4bit then decodes the normalized bytes as different NF4 indices than the original,
producing wrong matmul output. The effect is invariant to wrap granularity, reduce_dtype, and cast_forward_inputs, and is deterministic across hosts (full empirical matrix below).
Mechanism B (segfault, secondary): weight.redistribute([Replicate]) and weight.full_tensor() SIGSEGV on bf16-packed Params4bit data. DTensor's gather-to-replicate path is broken for packed-NF4 storage.
Happy to file as a separate issue if preferred.
Both reproducible at WS=1 (single-rank, single-host) and WS=2 (multi-rank). FSDP1 with the canonical bnb FSDP-QLoRA recipe (use_orig_params=False) is not affected — empirically verified here at WS=1 (test 6 in
the diagnostic chain below; FSDP1 degrades to NO_SHARD), and indirectly corroborated at WS≥2 by the Answer.AI / HF Transformers / Axolotl reference recipes continuing to train correctly in production. I did not
run a direct FSDP1 cross-check at WS≥2 in this report.
Byte-level mechanism
Empirical forensics on TinyLlama-1.1B `layers[0].self_attn.q_proj.base_layer.weight` (1,048,576 bf16 elements = 2,097,152 bytes of packed NF4):

| Metric | Value |
| --- | --- |
| Differing bytes (orig vs swap-target read inside forward) | 1095 (0.0522%) |
| Differing bf16 elements | 1030 (0.0982%) |
| Original class of every differing element | NaN (1030/1030 = 100%) |
| Post-swap bit pattern of every differing element | 0x7FFF (a quiet NaN — not IEEE's recommended canonical 0x7FC0) (1030/1030 = 100%) |
| DTensor._local_tensor view vs original | 0 byte diff (byte-true preserved) |
| Per-Linear max_abs_delta (swap vs byte-true cached-ref) | 0.100586 |
| Per-Linear mean_abs_delta | 0.000313 |
| Cross-host (two physically separate Jetson Orin Nano Super units, sm_87, JP6.2) | Bit-identical on every metric, including exact element positions |
The 1095:1030 ratio (~1.06 bytes per differing element) implies ~65 elements have both bytes changed, consistent with negative NaNs (0xFFxx) being flipped to positive 0x7FFF in addition to mantissa
normalization. The remaining ~965 elements differ only in the low byte (mantissa normalization alone).
Sample differing elements (bit-identical across hosts):
| Element | Original bits | Original class | Post-swap bits |
| --- | --- | --- | --- |
| 580 | 0x7fea | qNaN, mant=0x6a | 0x7fff |
| 992 | 0x7f89 | sNaN, mant=0x09 | 0x7fff |
| 1077 | 0x7fb5 | sNaN, mant=0x35 | 0x7fff |
| 4574 | 0x7fd8 | qNaN, mant=0x58 | 0x7fff |
| 5188 | 0x7f88 | sNaN, mant=0x08 | 0x7fff |
The pattern is unambiguous: every bf16 NaN encoding gets normalized to the single fixed quiet-NaN pattern 0x7FFF, regardless of original sign or mantissa — consistent with a float-aware read path that quiets sNaNs, drops the sign, and rewrites the mantissa payload. The packed-NF4 nibble indices happened to encode NaN bit patterns; after normalization they encode different nibble indices.
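To make the hazard concrete, here is a minimal standalone sketch (an illustration of the same class of hazard, not FSDP2's actual code path): copying these bytes through a byte-true view(torch.uint8) preserves the NF4 payload, while a float-semantic round trip canonicalizes the NaN payloads. The exact canonical pattern depends on the conversion kernel (a typical CPU round trip yields 0x7FC0; the FSDP2 swap path observed here yields 0x7FFF); the point is that the payload bits get rewritten either way.

```python
import torch

# Three bf16 bit patterns taken from the table above: valid packed-NF4 bytes
# that also happen to decode as bf16 NaNs.
bits = torch.tensor([0x7F89, 0x7FB5, 0x7FEA], dtype=torch.int16)
w = bits.view(torch.bfloat16)

# Byte-true path: reinterpret as uint8 and copy. Payload bytes survive.
byte_true = w.clone().view(torch.uint8)

# Float-semantic path: a round trip through float conversion typically
# canonicalizes the NaN payloads (exact output pattern depends on the kernel).
float_path = w.to(torch.float32).to(torch.bfloat16).view(torch.uint8)

print(byte_true.tolist())   # e.g. [137, 127, 181, 127, 234, 127]: originals intact
print(float_path.tolist())  # payload bytes rewritten to a canonical NaN
```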
Full forensic script: https://gist.github.com/neil-the-nowledgeable/d5bbffd9bd83029314771d9f46472cb2
Configuration-invariance matrix
Single-rank, single-host TinyLlama-1.1B + bnb-NF4 + PEFT-LoRA, forward-only loss on a fixed-seed input. All non-failing cells produce loss_fsdp2 = 12.725160598754883 to 16 decimal places, vs baseline 12.691308975219727 (Δ=0.0339, 0.27% relative error):
| Cell | Wrap | FSDP units | mp_policy | other | loss_fsdp2 |
| --- | --- | --- | --- | --- | --- |
| A | per-DecoderLayer | 23 | bf16/fp32 | — | 12.725160598754883 |
| C | per-DecoderLayer | 23 | bf16/bf16 | — | 12.725160598754883 |
| D | per-Linear4bit | 155 | bf16/fp32 | — | 12.725160598754883 |
| E | root-only | 1 | bf16/fp32 | — | 12.725160598754883 |
| G | per-DecoderLayer | 23 | bf16/fp32 | cast_forward_inputs=False | 12.725160598754883 |
| F | per-DecoderLayer | 23 | bf16/fp32 | double_quant=False (different baseline 12.6973) | 12.742940902709961 |
| B | per-DecoderLayer | 23 | None | — | AttributeError in _init_mp_dtypes (FSDP2 requires MixedPrecisionPolicy instance) |
Invariant in: wrap granularity (1, 23, or 155 FSDP units → identical bit pattern), reduce_dtype (bf16 or fp32), and the cast_forward_inputs flag. Depends on: param_dtype = torch.bfloat16 — the only configuration FSDP2 supports for sharding bnb's bf16-quant-storage Params4bit, and the configuration that triggers the float-normalization read path. Note on double_quant: it is incidental; removing it amplifies the delta by ~35% (0.0339 → 0.0456) because the weight distribution shifts and hits slightly more NaN patterns. The bug is not caused by double-quant.
Full matrix script: https://gist.github.com/neil-the-nowledgeable/24412f2da39880e4fb0570198aca442e
Empirical workaround (proof-by-construction)
The following pattern produces bit-identical training results to the pre-shard baseline at WS=1 AND WS=2, including with real models (TinyLlama-1.1B + real peft.LoraConfig + 154 Linear4bit modules):
```python
# Step 1: BEFORE fully_shard, walk the model and capture each Linear4bit's
# quant_state by module path (because fully_shard re-routes self.weight
# through a swap-target Parameter that loses quant_state).
import bitsandbytes as bnb
import torch
import torch.distributed as dist

qs_cache = {}
for name, module in model.named_modules():
    if isinstance(module, bnb.nn.Linear4bit):
        qs_cache[name] = module.weight.quant_state

# Step 2: call fully_shard(...) on the model / decoder layers as usual.

# Step 3: AFTER fully_shard, walk again and capture the byte-true
# DTensor._local_tensor reference + the quant_state from step 1.
for name, module in model.named_modules():
    if isinstance(module, bnb.nn.Linear4bit) and hasattr(module.weight, '_local_tensor'):
        module._cached_local = module.weight._local_tensor
        module._cached_qs = qs_cache[name]

# Step 4: replace Linear4bit.forward globally with a cached-ref-aware version.
_orig_forward = bnb.nn.Linear4bit.forward

def _ws_aware_forward(self, x_in):
    qs = getattr(self, '_cached_qs', None)
    cached_local = getattr(self, '_cached_local', None)
    if qs is None or cached_local is None:
        return _orig_forward(self, x_in)
    local = cached_local.contiguous()
    ws = dist.get_world_size() if dist.is_initialized() else 1
    if ws > 1:
        # Manual all_gather — DTensor's full_tensor() / redistribute() SIGSEGVs
        # on bf16-packed Params4bit (Mechanism B above).
        gathered = [torch.empty_like(local) for _ in range(ws)]
        dist.all_gather(gathered, local)
        full_w = torch.cat(gathered, dim=0)
    else:
        full_w = local
    if x_in.dtype != torch.bfloat16:
        x_in = x_in.to(torch.bfloat16)
    return bnb.matmul_4bit(x_in, full_w.t(), bias=self.bias, quant_state=qs)

bnb.nn.Linear4bit.forward = _ws_aware_forward
```

**Critical detail**: the `_local_tensor` reference must be captured OUTSIDE forward (where `linear.weight` resolves to the DTensor). Inside `Linear4bit.forward`, FSDP2's swap mechanism replaces `self.weight` with a Parameter that LACKS the `_local_tensor` attribute and points at different memory than what `_local_tensor` references externally. Reading `self.weight._local_tensor` from inside forward will either raise AttributeError or fall through to broken behavior.
Minimal reproducer
A 30-line reproducer demonstrating Mechanism A (wrong output) at WS=1. Numbers below are illustrative — your exact values depend on the seeded weight distribution, but max_delta should be O(0.1) (vs O(1e-6) baseline) and y_pre.sum() vs y_post.sum() should differ by O(1):
```python
import os, torch, torch.distributed as dist, bitsandbytes as bnb

os.environ.update({"MASTER_ADDR": "127.0.0.1", "MASTER_PORT": "29501",
                   "RANK": "0", "WORLD_SIZE": "1"})
dist.init_process_group(backend="gloo", world_size=1, rank=0)

from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.device_mesh import DeviceMesh

torch.manual_seed(42)
linear = bnb.nn.Linear4bit(512, 512, bias=False, quant_type='nf4',
                           compute_dtype=torch.bfloat16,
                           quant_storage=torch.bfloat16,
                           compress_statistics=True).to('cuda')
torch.manual_seed(123)
x = torch.randn(2, 512, dtype=torch.bfloat16, device='cuda')
y_pre = linear(x).clone()  # baseline

mesh = DeviceMesh.from_group(dist.distributed_c10d._get_default_group(), "cuda")
mp = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
fully_shard(linear, mesh=mesh, mp_policy=mp)

y_post = linear(x)  # post-shard via vanilla forward
print(f"y_pre.sum  = {y_pre.sum().item():.6f}")
print(f"y_post.sum = {y_post.sum().item():.6f}")
print(f"max_delta  = {(y_pre - y_post).abs().max().item():.6e}")
# Expected: max_delta on the order of 1e-1 (not 1e-6).
```
Mechanism B reproducer (same setup, then):
```python
from torch.distributed._tensor import Replicate  # or: from torch.distributed.tensor import Replicate on newer torch

linear.weight.full_tensor()                           # SIGSEGV
linear.weight.redistribute(placements=[Replicate()])  # SIGSEGV
```
Diagnostic chain (summary)
Twelve empirical tests on Jetson Orin Nano Super (sm_87, JetPack 6.2, torch 2.5.1 source-built USE_GLOO=1, bnb 0.46.1 source-built -DCOMPUTE_CAPABILITY=87) document the bug and the workaround:
| # | Test | Outcome |
| --- | --- | --- |
| 1 | Vanilla linear(x) post-shard at WS=1 | Wrong (max_delta O(0.1–0.3) on synthetic 512×512, vs O(1e-6) baseline) |
| 10 | bnb.matmul_4bit(x, _local_tensor.t(), qs_cached) outside linear() | Bit-identical to baseline — data + dispatch + kernel all correct in isolation |
| 11 | Cached-ref + manual all_gather at WS=1 | Bit-identical forward + backward + 5-step training |
| 12 | Cached-ref at WS=2 with TinyLlama-1.1B + real PEFT + 154 Linear4bits | Bit-identical to WS=1 baseline; both ranks identical; 3 training steps |
The combination of tests 8, 10, and 11 localizes the bug to FSDP2's parameter-swap mechanism: the bytes are bit-identical when read via _local_tensor; matmul on _local_tensor directly matches baseline; matmul via self.weight inside forward is wrong.
Loss values at WS=2 + TinyLlama matched bit-exactly in fp32 (delta = 0.0e+00) across ranks AND against WS=1 reference: [13.512944221496582, 13.435127258300781, 13.27534008026123].
Why FSDP1 works
At WS=1, FSDP1 emits "FSDP is switching to use NO_SHARD instead of ShardingStrategy.FULL_SHARD since the world size is 1" — the wrap is effectively a passthrough. At WS≥2, FSDP1's FlatParamHandle uses torch.chunk / torch.split for sharding; our reading is that these go through the Params4bit.__torch_function__ override (added by PR #1719 — "Fix Params4bit tensor subclass handling") which intercepts chunk / split and re-wraps the results as Params4bit (preserving quant_state, quant_type, etc.), whereas FSDP2's DTensor-based wrap appears to bypass that path. I have not stepped through DTensor's sharding code directly to confirm — happy to verify if maintainers find that useful.
This is consistent with why the bnb FSDP-QLoRA reference recipe still uses FSDP1.
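For illustration, here is a toy sketch of the kind of __torch_function__ interception PR #1719 describes. The QuantTensor class below is made up for the example and is not bnb's actual Params4bit implementation; it only shows how chunk/split outputs can be re-wrapped so that attached quantization metadata survives the sharding ops FSDP1 uses.

```python
import torch

class QuantTensor(torch.Tensor):
    """Toy subclass carrying quant metadata (illustrative only, not bnb's Params4bit)."""

    @staticmethod
    def __new__(cls, data, quant_state=None):
        t = data.as_subclass(cls)
        t.quant_state = quant_state
        return t

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        out = super().__torch_function__(func, types, args, kwargs)
        # Re-wrap chunk/split outputs so the metadata survives sharding-style ops.
        if func in (torch.chunk, torch.split):
            qs = getattr(args[0], "quant_state", None)
            out = tuple(cls(t, quant_state=qs) for t in out)
        return out

w = QuantTensor(torch.arange(8.0), quant_state={"blocksize": 64})
shards = torch.chunk(w, 2)
print(type(shards[0]).__name__, shards[0].quant_state)  # QuantTensor {'blocksize': 64}
```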
Cross-references
PR #970 (MERGED) — original FSDP-QLoRA enablement (Answer.AI). Added quant_storage selection (uint8 / bf16 / fp16) so FSDP can shard Params4bit at all. The PR body notes that the sync_module_states memory optimization was left for future work — that's a separate axis from the runtime forward-correctness issue reported here.
PR #1719 (MERGED) — "Fix Params4bit tensor subclass handling". Added __torch_function__ override on Params4bit so torch.chunk / torch.split (the ops FSDP1's FlatParamHandle uses for sharding) re-wrap results back as Params4bit rather than returning plain Tensor. This is the protection FSDP2's DTensor wrap appears to bypass.
PR #1866 (MERGED) — added __getattr__ to Params4bit so FSDP state_dict traversal can resolve weight.absmax / weight.quant_map / weight.quant_state.bitsandbytes__* FQN paths.
PR #1916 (MERGED) — replaced __getattr__ with @property descriptors to eliminate torch.compile graph breaks under activation checkpointing; FSDP state_dict traversal continues to work through the descriptor protocol.
What's actionable for bnb maintainers
The fix target is primarily in PyTorch FSDP2 / DTensor, not in bitsandbytes:
Cross-link to a PyTorch core issue. I have not bisected into FSDP2's internals; the empirical symptom (byte-level NaN normalization on the swap-target buffer) points to the parameter-storage path that allocates and initializes the swap-target — torch.distributed._composable.fsdp._fsdp_param.py (likely init_dtype_attrs or the storage allocation that copies bytes through a bf16-typed view). Possible fix shapes:
Read the bf16 Tensor-subclass buffer via .view(torch.uint8) (byte-true) rather than through any code path that triggers IEEE 754 NaN canonicalization.
Skip the canonicalization for parameters whose owning Tensor subclass overrides __torch_function__ (i.e., Params4bit explicitly opts out of float-semantic handling).
I'm happy to open the PyTorch core issue when you confirm this framing — wanted to surface here first since you have the most context on the FSDP-QLoRA history.
Document the cached-ref workaround in bnb's FSDP-QLoRA docs as the recommended path for FSDP2 users until the upstream PyTorch fix lands.
(Optional contribution) Ship a bitsandbytes.fsdp2_utils module containing install_fsdp2_workaround(model) that walks Linear4bit modules and applies the patch — ~50 LoC, working implementation already exists. Happy to submit as a PR if you'd like it.
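For reference, here is a rough sketch of the shape such a helper could take. The name install_fsdp2_workaround and the shard_fn callback are assumptions for illustration, not an existing bitsandbytes API; the helper simply packages Steps 1–4 of the workaround above and reuses the _ws_aware_forward patch defined there.

```python
import bitsandbytes as bnb

def install_fsdp2_workaround(model, shard_fn):
    """Hypothetical helper (sketch only): capture quant_state, shard, cache
    byte-true local shards, then install the cached-ref-aware forward."""
    # Step 1: quant_state by module path, captured BEFORE fully_shard.
    qs_cache = {name: m.weight.quant_state
                for name, m in model.named_modules()
                if isinstance(m, bnb.nn.Linear4bit)}
    # Step 2: caller-provided sharding (e.g. per-DecoderLayer fully_shard calls).
    shard_fn(model)
    # Step 3: byte-true DTensor._local_tensor reference + cached quant_state.
    for name, m in model.named_modules():
        if isinstance(m, bnb.nn.Linear4bit) and hasattr(m.weight, "_local_tensor"):
            m._cached_local = m.weight._local_tensor
            m._cached_qs = qs_cache[name]
    # Step 4: global forward patch (_ws_aware_forward from the workaround section above).
    bnb.nn.Linear4bit.forward = _ws_aware_forward
    return model
```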
Test artifacts
All scripts are gist-linked above. Combined runtime is ~3 minutes on a Jetson Orin Nano Super at WS=1, ~5 minutes for the full matrix at WS=2. Happy to provide additional reproducers (FSDP1 cross-check, WS=4 cluster verification) on request.
Thanks for the maintenance work on bnb's FSDP integration so far — quant_storage=bf16 (PR #970) and the @property accessors (PR #1866 + #1916) made this kind of diagnostic possible.