FSDP2 fully_shard breaks Linear4bit forward via NaN canonicalization of packed-NF4 bf16 storage #1945

@neil-the-nowledgeable

Summary

FSDP2.fully_shard produces incorrect forward output for modules containing bitsandbytes.nn.Linear4bit parameters. Two distinct mechanisms are involved; the primary one (Mechanism A) is the focus of this issue,
with the secondary one (Mechanism B) noted for completeness.

  1. Mechanism A (NaN normalization in parameter-swap): post-fully_shard, FSDP2's parameter-swap mechanism reads bf16-stored Params4bit weight bytes through a float-aware code path that normalizes every NaN
    bit pattern to a fixed quiet-NaN representation 0x7FFF
    (sign=0, exp=0xFF, mantissa=0x7F — quiet bit set, all other mantissa bits set). Because bnb stores packed-NF4 nibble indices in bf16-shaped containers,
    ~0.098% of weight elements coincidentally encode bf16 NaN bit patterns; FSDP2 normalizes them all to 0x7FFF. bnb.matmul_4bit then decodes the normalized bytes as different NF4 indices than the original,
    producing wrong matmul output. Wrap-granularity invariant, reduce_dtype invariant, cast_forward_inputs invariant, cross-host deterministic (full empirical matrix below).

  2. Mechanism B (segfault, secondary): weight.redistribute([Replicate]) and weight.full_tensor() SIGSEGV on bf16-packed Params4bit data. DTensor's gather-to-replicate path is broken for packed-NF4 storage.
    Happy to file as a separate issue if preferred.

Both reproducible at WS=1 (single-rank, single-host) and WS=2 (multi-rank). FSDP1 with the canonical bnb FSDP-QLoRA recipe (use_orig_params=False) is not affected — empirically verified here at WS=1 (test 6 in
the diagnostic chain below; FSDP1 degrades to NO_SHARD), and indirectly corroborated at WS≥2 by the Answer.AI / HF Transformers / Axolotl reference recipes continuing to train correctly in production. I did not
run a direct FSDP1 cross-check at WS≥2 in this report.

Byte-level mechanism

Empirical forensic at TinyLlama-1.1B layers[0].self_attn.q_proj.base_layer.weight (1,048,576 bf16 elements = 2,097,152 bytes of packed NF4):

| Metric | Value |
| --- | --- |
| Differing bytes (orig vs swap-target read inside forward) | 1095 (0.0522%) |
| Differing bf16 elements | 1030 (0.0982%) |
| Original class of every differing element | NaN (1030/1030 = 100%) |
| Post-swap bit pattern of every differing element | 0x7FFF (a quiet NaN, not IEEE's recommended canonical 0x7FC0) (1030/1030 = 100%) |
| DTensor._local_tensor view vs original | 0 byte diff (byte-true preserved) |
| Per-Linear max_abs_delta (swap vs byte-true cached-ref) | 0.100586 |
| Per-Linear mean_abs_delta | 0.000313 |
| Cross-host (two physically separate Jetson Orin Nano Super units, sm_87, JP6.2) | Bit-identical on every metric, including exact element positions |

The 1095:1030 ratio (~1.06 bytes per differing element) implies that 65 elements (1095 − 1030) have both bytes changed, consistent with negative NaNs (0xFFxx) being flipped to positive 0x7FFF in addition to mantissa normalization. The remaining 965 elements differ in a single byte, which for positive NaNs is the low byte (mantissa normalization alone).
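
The byte-versus-element accounting above can be sketched in a few lines. This is an illustration only, not the forensic script: `diff_counts` is a hypothetical helper, and the four sample elements (including the two negative NaNs) are made up to exercise each case.

```python
def diff_counts(orig, post):
    """Return (differing elements, differing bytes) for two lists of 16-bit patterns."""
    elems = 0
    nbytes = 0
    for a, b in zip(orig, post):
        x = a ^ b
        if x:
            elems += 1
            nbytes += (1 if x & 0x00FF else 0) + (1 if x & 0xFF00 else 0)
    return elems, nbytes

# Hypothetical inputs: positive NaN, negative NaN, finite value, negative all-ones NaN.
orig = [0x7FEA, 0xFF89, 0x3F80, 0xFFFF]
post = [0x7FFF, 0x7FFF, 0x3F80, 0x7FFF]  # every NaN rewritten to 0x7FFF

elems, nbytes = diff_counts(orig, post)
two_byte = nbytes - elems  # each differing element contributes 1 or 2 bytes
print(elems, nbytes, two_byte)

# Same arithmetic applied to the reported totals.
reported_two_byte = 1095 - 1030
reported_one_byte = 1030 - reported_two_byte
print(reported_two_byte, reported_one_byte)
```

Because every differing element contributes either one or two differing bytes, the number of two-byte elements is simply the byte count minus the element count.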

Sample differing elements (bit-identical across hosts):

| Element | Original bits | Original class | Post-swap bits |
| --- | --- | --- | --- |
| 580 | 0x7fea | qNaN, mant=0x6a | 0x7fff |
| 992 | 0x7f89 | sNaN, mant=0x09 | 0x7fff |
| 1077 | 0x7fb5 | sNaN, mant=0x35 | 0x7fff |
| 4574 | 0x7fd8 | qNaN, mant=0x58 | 0x7fff |
| 5188 | 0x7f88 | sNaN, mant=0x08 | 0x7fff |

The pattern is unambiguous: any bf16 NaN encoding gets normalized to a single fixed quiet-NaN pattern 0x7FFF, regardless of original sign or mantissa — consistent with a float-aware read path that quiets sNaNs
and clears non-quiet-bit mantissa state. The packed-NF4 nibble indices happened to encode NaN bit patterns; after normalization they encode different nibble indices.
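
To make the nibble corruption concrete, here is a minimal sketch that classifies the sample bit patterns under the bf16 layout (1 sign, 8 exponent, 7 mantissa bits) and shows which 4-bit indices change after normalization. The helper names are hypothetical, and the nibble order (little-endian bytes, high nibble first within each byte) is an assumption about the packing, not something verified in this report.

```python
def classify_bf16(bits: int) -> str:
    """Classify a 16-bit pattern as a bf16 value class."""
    exp = (bits >> 7) & 0xFF
    mant = bits & 0x7F
    if exp != 0xFF:
        return "finite"
    if mant == 0:
        return "inf"
    # The mantissa MSB (0x40) is the quiet bit.
    return "qNaN" if mant & 0x40 else "sNaN"

def normalize_nan(bits: int) -> int:
    """Model the observed behavior: every NaN pattern becomes 0x7FFF."""
    return 0x7FFF if classify_bf16(bits).endswith("NaN") else bits

def nibbles(bits: int) -> list[int]:
    """Split one 16-bit element into four 4-bit indices.

    ASSUMPTION: low byte first (little-endian), high nibble first in each byte.
    """
    lo, hi = bits & 0xFF, bits >> 8
    return [lo >> 4, lo & 0xF, hi >> 4, hi & 0xF]

samples = [0x7FEA, 0x7F89, 0x7FB5, 0x7FD8, 0x7F88]  # from the table above
for s in samples:
    n = normalize_nan(s)
    print(f"{s:#06x} {classify_bf16(s)} {nibbles(s)} -> {n:#06x} {nibbles(n)}")
```

Under that packing assumption, element 580 (0x7fea) would carry indices [14, 10, 7, 15] before the swap and [15, 15, 7, 15] after it, so two of its four NF4 weights would decode to different values.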

Full forensic script: https://gist.github.com/neil-the-nowledgeable/d5bbffd9bd83029314771d9f46472cb2

Configuration-invariance matrix

Single-rank, single-host TinyLlama-1.1B + bnb-NF4 + PEFT-LoRA, forward-only loss on fixed-seed input. Cells A, C, D, E, and G all produce loss_fsdp2 = 12.725160598754883, identical to the last printed digit, vs baseline 12.691308975219727 (Δ=0.0339, 0.27% relative error); cell F uses a different baseline and cell B fails before forward:

| Cell | Wrap | FSDP units | mp_policy | other | loss_fsdp2 |
| --- | --- | --- | --- | --- | --- |
| A | per-DecoderLayer | 23 | bf16/fp32 | | 12.725160598754883 |
| C | per-DecoderLayer | 23 | bf16/bf16 | | 12.725160598754883 |
| D | per-Linear4bit | 155 | bf16/fp32 | | 12.725160598754883 |
| E | root-only | 1 | bf16/fp32 | | 12.725160598754883 |
| G | per-DecoderLayer | 23 | bf16/fp32 | cast_forward_inputs=False | 12.725160598754883 |
| F | per-DecoderLayer | 23 | bf16/fp32 | double_quant=False (different baseline 12.6973) | 12.742940902709961 |
| B | per-DecoderLayer | 23 | None | | AttributeError in _init_mp_dtypes (FSDP2 requires MixedPrecisionPolicy instance) |

Invariant in: wrap granularity (1, 23, 155 FSDP units → identical bit pattern), reduce_dtype (bf16 or fp32), cast_forward_inputs flag.
Depends on: param_dtype = torch.bfloat16 — the only configuration FSDP2 supports for sharding bnb's bf16-quant-storage Params4bit, and the configuration that triggers the float-normalization read path.
Note on double_quant: incidental; removing it amplifies the delta by ~35% (0.0339 → 0.0456) because the weight distribution shifts and hits slightly more NaN patterns; the bug is not caused by double-quant.
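
The deltas quoted above follow directly from the loss values in the matrix. A short check, using only the numbers already reported (the cell F baseline is known to four decimals only, so its figures are approximate):

```python
# Loss values copied from the matrix above.
baseline = 12.691308975219727          # pre-shard baseline (double_quant=True)
loss_fsdp2 = 12.725160598754883        # cells A, C, D, E, G
baseline_no_dq = 12.6973               # cell F baseline, as quoted (4 decimals only)
loss_fsdp2_no_dq = 12.742940902709961  # cell F

delta = loss_fsdp2 - baseline
rel_err_pct = 100.0 * delta / baseline
delta_no_dq = loss_fsdp2_no_dq - baseline_no_dq
amplification_pct = 100.0 * (delta_no_dq / delta - 1.0)

print(f"delta              = {delta:.4f}")
print(f"relative error     = {rel_err_pct:.2f}%")
print(f"delta (no dq)      = {delta_no_dq:.4f}")
print(f"amplification      = {amplification_pct:.0f}%")
```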

Full matrix script: https://gist.github.com/neil-the-nowledgeable/24412f2da39880e4fb0570198aca442e

Empirical workaround (proof-by-construction)

The following pattern produces bit-identical training results to pre-shard baseline at WS=1 AND WS=2, including with real models (TinyLlama-1.1B + real peft.LoraConfig + 154 Linear4bit modules):

```python
# Step 1: BEFORE fully_shard, walk the model and capture each Linear4bit's
# quant_state by module path (because fully_shard re-routes self.weight
# through a swap-target Parameter that loses quant_state).
import bitsandbytes as bnb
import torch
import torch.distributed as dist

qs_cache = {}
for name, module in model.named_modules():
    if isinstance(module, bnb.nn.Linear4bit):
        qs_cache[name] = module.weight.quant_state

# Step 2: call fully_shard(...) on the model / decoder layers as usual.

# Step 3: AFTER fully_shard, walk again and capture the byte-true
# DTensor._local_tensor reference + the quant_state from step 1.
for name, module in model.named_modules():
    if isinstance(module, bnb.nn.Linear4bit) and hasattr(module.weight, '_local_tensor'):
        module._cached_local = module.weight._local_tensor
        module._cached_qs = qs_cache[name]

# Step 4: replace Linear4bit.forward globally with a cached-ref-aware version.
_orig_forward = bnb.nn.Linear4bit.forward

def _ws_aware_forward(self, x_in):
    qs = getattr(self, '_cached_qs', None)
    cached_local = getattr(self, '_cached_local', None)
    if qs is None or cached_local is None:
        return _orig_forward(self, x_in)
    local = cached_local.contiguous()
    ws = dist.get_world_size() if dist.is_initialized() else 1
    if ws > 1:
        # Manual all_gather — DTensor's full_tensor() / redistribute() SIGSEGVs
        # on bf16-packed Params4bit (Mechanism B above).
        gathered = [torch.empty_like(local) for _ in range(ws)]
        dist.all_gather(gathered, local)
        full_w = torch.cat(gathered, dim=0)
    else:
        full_w = local
    if x_in.dtype != torch.bfloat16:
        x_in = x_in.to(torch.bfloat16)
    return bnb.matmul_4bit(x_in, full_w.t(), bias=self.bias, quant_state=qs)

bnb.nn.Linear4bit.forward = _ws_aware_forward
```


**Critical detail**: the `_local_tensor` reference must be captured OUTSIDE forward (where `linear.weight` resolves to the DTensor). Inside `Linear4bit.forward`, FSDP2's swap mechanism replaces `self.weight` with a Parameter that LACKS the `_local_tensor` attribute and points at different memory than what `_local_tensor` references externally. Reading `self.weight._local_tensor` from inside forward will either raise AttributeError or, if the access is guarded (e.g. with `getattr` / `hasattr`), silently fall through to the broken behavior.

### Minimal reproducer

30-line reproducer demonstrating Mechanism A (wrong-output) at WS=1. Numbers below are illustrative: your exact values depend on the seeded weight distribution, but `max_delta` should be O(0.1) (vs O(1e-6) baseline) and `y_pre.sum()` vs `y_post.sum()` should differ by O(1):

```python
import os, torch, torch.distributed as dist, bitsandbytes as bnb
os.environ.update({"MASTER_ADDR":"127.0.0.1","MASTER_PORT":"29501","RANK":"0","WORLD_SIZE":"1"})
dist.init_process_group(backend="gloo", world_size=1, rank=0)
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.device_mesh import DeviceMesh

torch.manual_seed(42)
linear = bnb.nn.Linear4bit(512, 512, bias=False, quant_type='nf4',
                         compute_dtype=torch.bfloat16,
                         quant_storage=torch.bfloat16,
                         compress_statistics=True).to('cuda')

torch.manual_seed(123)
x = torch.randn(2, 512, dtype=torch.bfloat16, device='cuda')
y_pre = linear(x).clone()  # baseline

mesh = DeviceMesh.from_group(dist.distributed_c10d._get_default_group(), "cuda")
mp = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
fully_shard(linear, mesh=mesh, mp_policy=mp)

y_post = linear(x)  # post-shard via vanilla forward

print(f"y_pre.sum  = {y_pre.sum().item():.6f}")
print(f"y_post.sum = {y_post.sum().item():.6f}")
print(f"max_delta  = {(y_pre - y_post).abs().max().item():.6e}")
# Expected: max_delta on the order of 1e-1 (not 1e-6).
```

Mechanism B reproducer (same setup, then):

```python
from torch.distributed._tensor import Replicate  # or: from torch.distributed.tensor import Replicate on newer torch

linear.weight.full_tensor()                              # SIGSEGV
linear.weight.redistribute(placements=[Replicate()])     # SIGSEGV
```

Diagnostic chain (summary)

Twelve empirical tests on Jetson Orin Nano Super (sm_87, JetPack 6.2, torch 2.5.1 source-built USE_GLOO=1, bnb 0.46.1 source-built -DCOMPUTE_CAPABILITY=87) document the bug and the workaround:

| # | Test | Outcome |
| --- | --- | --- |
| 1 | Vanilla linear(x) post-shard at WS=1 | Wrong (max_delta O(0.1–0.3) on synthetic 512×512, vs O(1e-6) baseline) |
| 2 | register_buffer("weight_absmax", ...) + fully_shard | Buffer doesn't shard; FSDP2 replicates buffers |
| 3 | register_parameter("weight_absmax", ...) at wrap-time | Wraps successfully (Shard(dim=0) on both weight + absmax) |
| 4 | Same as 3 + actually run forward | AssertionError: FSDP expects uniform original parameter dtype |
| 5 | FSDP1 + same dual-Parameter | ValueError: Must flatten tensors with uniform dtype (same constraint, different path) |
| 6 | FSDP1 (vanilla) at WS=1 | Bit-identical to baseline (NO_SHARD at WS=1) |
| 7 | FSDP2 with mp_policy=MixedPrecisionPolicy(None, None), reshard_after_forward=False | Same wrong output as vanilla; NOT mp_policy or reshard related |
| 8 | DTensor _local_tensor inspection post-shard | Plain Tensor (not Params4bit); quant_state attribute lost; data_ptr differs from original |
| 9 | Direct byte comparison: original vs post-shard _local_tensor | Bit-identical bytes (uint8 view); only IEEE NaN positions appear different |
| 10 | bnb.matmul_4bit(x, _local_tensor.t(), qs_cached) outside linear() | Bit-identical to baseline: data + dispatch + kernel all correct in isolation |
| 11 | Cached-ref + manual all_gather at WS=1 | Bit-identical forward + backward + 5-step training |
| 12 | Cached-ref at WS=2 with TinyLlama-1.1B + real PEFT + 154 Linear4bits | Bit-identical to WS=1 baseline; both ranks identical; 3 training steps |

The combination of 8 + 10 + 11 localizes the bug to FSDP2's parameter-swap mechanism: bytes are bit-identical via _local_tensor access; matmul on _local_tensor directly matches baseline; matmul via self.weight inside forward is wrong.
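
The byte-true comparison used in test 9 can be sketched as follows. This is a minimal illustration rather than the forensic script itself: `byte_diff_count` is a hypothetical helper, and the two tensors are built from hand-picked bit patterns on CPU instead of real Params4bit storage. It relies only on `Tensor.view(dtype)`, which reinterprets the underlying bytes without converting values.

```python
import torch

def byte_diff_count(a: torch.Tensor, b: torch.Tensor) -> int:
    """Count differing bytes between two same-shaped tensors via a uint8 view."""
    a8 = a.detach().contiguous().view(torch.uint8)
    b8 = b.detach().contiguous().view(torch.uint8)
    return int((a8 != b8).sum().item())

# Build bf16 tensors directly from raw 16-bit patterns (illustrative values):
# 0x7FEA is a NaN pattern, 0x3F80 is the finite value 1.0.
orig = torch.tensor([0x7FEA, 0x3F80], dtype=torch.int16).view(torch.bfloat16)
post = torch.tensor([0x7FFF, 0x3F80], dtype=torch.int16).view(torch.bfloat16)

print(byte_diff_count(orig, post))  # only the low byte of element 0 differs
print(byte_diff_count(orig, orig))
```

Comparing through the uint8 view avoids float equality, under which NaN elements never compare equal and so cannot distinguish a preserved NaN payload from a rewritten one.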

Loss values at WS=2 + TinyLlama matched bit-exactly in fp32 (delta = 0.0e+00) across ranks AND against WS=1 reference: [13.512944221496582, 13.435127258300781, 13.27534008026123].

Why FSDP1 works

At WS=1, FSDP1 emits "FSDP is switching to use NO_SHARD instead of ShardingStrategy.FULL_SHARD since the world size is 1" — the wrap is effectively a passthrough. At WS≥2, FSDP1's FlatParamHandle uses torch.chunk / torch.split for sharding; our reading is that these go through the Params4bit.__torch_function__ override (added by PR #1719 — "Fix Params4bit tensor subclass handling") which intercepts chunk / split and re-wraps the results as Params4bit (preserving quant_state, quant_type, etc.), whereas FSDP2's DTensor-based wrap appears to bypass that path. I have not stepped through DTensor's sharding code directly to confirm — happy to verify if maintainers find that useful.

This is consistent with why the bnb FSDP-QLoRA reference recipe still uses FSDP1.

Cross-references

  • PR #970 (MERGED) — original FSDP-QLoRA enablement (Answer.AI). Added quant_storage selection (uint8 / bf16 / fp16) so FSDP can shard Params4bit at all. The PR body notes that the sync_module_states memory optimization was left for future work — that's a separate axis from the runtime forward-correctness issue reported here.
  • PR #1719 (MERGED) — "Fix Params4bit tensor subclass handling". Added __torch_function__ override on Params4bit so torch.chunk / torch.split (the ops FSDP1's FlatParamHandle uses for sharding) re-wrap results back as Params4bit rather than returning plain Tensor. This is the protection FSDP2's DTensor wrap appears to bypass.
  • PR #1866 (MERGED) — added __getattr__ to Params4bit so FSDP state_dict traversal can resolve weight.absmax / weight.quant_map / weight.quant_state.bitsandbytes__* FQN paths.
  • PR #1916 (MERGED) — replaced PR #1866's __getattr__ with @property descriptors to eliminate torch.compile graph breaks under activation checkpointing; FSDP state_dict traversal continues to work through the descriptor protocol.
  • State_dict serialization (covered by PR #1866 + PR #1916) works as expected; this issue concerns runtime forward.

What's actionable for bnb maintainers

The fix target is primarily in PyTorch FSDP2 / DTensor, not in bitsandbytes:

  1. Cross-link to a PyTorch core issue. I haven't bisected into FSDP2's internals; the empirical symptom (byte-level NaN normalization on the swap-target buffer) points to the parameter-storage path that allocates and initializes the swap-target: torch.distributed._composable.fsdp._fsdp_param.py (likely init_dtype_attrs or the storage allocation that copies bytes through a bf16-typed view). Possible fix shapes:

    • Read the bf16 Tensor-subclass buffer via .view(torch.uint8) (byte-true) rather than through any code path that triggers IEEE 754 NaN canonicalization.
    • Skip the canonicalization for parameters whose owning Tensor subclass overrides __torch_function__ (i.e., Params4bit explicitly opts out of float-semantic handling).

    I'm happy to open the PyTorch core issue when you confirm this framing — wanted to surface here first since you have the most context on the FSDP-QLoRA history.

  2. Document the cached-ref workaround in bnb's FSDP-QLoRA docs as the recommended path for FSDP2 users until the upstream PyTorch fix lands.

  3. (Optional contribution) Ship a bitsandbytes.fsdp2_utils module containing install_fsdp2_workaround(model) that walks Linear4bit modules and applies the patch — ~50 LoC, working implementation already exists. Happy to submit as a PR if you'd like it.

Test artifacts

All scripts are gist-linked above. Combined runtime is ~3 minutes on a Jetson Orin Nano Super at WS=1, ~5 minutes for the full matrix at WS=2. Happy to provide additional reproducers (FSDP1 cross-check, WS=4 cluster verification) on request.

Thanks for the maintenance work on bnb's FSDP integration so far — quant_storage=bf16 (PR #970) and the @property accessors (PR #1866 + #1916) made this kind of diagnostic possible.
