
Fix bf16 rounding to IEEE 754 ties-to-even #5648

Open
cyyever wants to merge 3 commits into pytorch:main from cyyever:fix/bf16-ties-to-even

Conversation

Contributor

cyyever commented Apr 16, 2026

No description provided.

meta-cla bot added the cla signed label on Apr 16, 2026
cyyever force-pushed the fix/bf16-ties-to-even branch 5 times, most recently from 778ddfc to 5bb9e0e on April 16, 2026 08:25
Contributor

meta-codesync Bot commented Apr 16, 2026

@q10 has imported this pull request. If you are a Meta employee, you can view this in D101141846.

cyyever force-pushed the fix/bf16-ties-to-even branch 7 times, most recently from f83d6b6 to 24174a4 on April 16, 2026 13:03
Contributor

q10 commented Apr 16, 2026

Could you provide a summary of the motivation for this change, and how it would impact performance and correctness (on both ARM and x86)? It's highly likely that a decision was made a long time ago internally with regard to the rounding technique, with tradeoffs relevant to internal use cases in mind. Also, could you provide microbenchmark reproducers with perf numbers?

Contributor Author

cyyever commented Apr 17, 2026

@q10 We changed the rounding mode to ties-to-even because it is the IEEE 754 default behaviour. Note also that the CUDA implementation, PyTorch, and NumPy all use ties-to-even; FBGEMM should be consistent with these libraries to avoid silent numerical-precision issues. For example, the following snippet confirms PyTorch's rounding mode.

import torch
import numpy as np


def tie_fp32(hi: int) -> float:
    """fp32 with upper 16 bits = hi, low 16 bits = 0x8000 (a rounding tie)."""
    return np.uint32((hi << 16) | 0x8000).view(np.float32).item()


def torch_bf16_bits(f):
    return torch.tensor([f]).to(torch.bfloat16).view(torch.uint16).item()


# Pairs of adjacent ties: first has even keep-bit (should round down),
# second has odd keep-bit (should round up). Both go to the even neighbor.
print(f"{'fp32 hi':>8}  {'keep LSB':>8}  {'expected':>8}   {'torch':>6}")
for hi in [0x3F80, 0x3F81, 0x4000, 0x4001, 0xBF80, 0xBF81]:
    f = tie_fp32(hi)
    expected = hi if (hi & 1) == 0 else hi + 1  # ties-to-even
    parity = "even" if (hi & 1) == 0 else "odd"
    print(f"  0x{hi:04x}  {parity:>8}    0x{expected:04x}   0x{torch_bf16_bits(f):04x}")
    assert torch_bf16_bits(f) == expected

A simple LLM-generated microbenchmark for the AVX2 path reported the following:

  ┌────────────────────────┬───────┬───────┬────────┐
  │      Workload (N)      │ main  │  PR   │    Δ   │
  ├────────────────────────┼───────┼───────┼────────┤
  │ 256   scalar           │ 0.057 │ 0.096 │  +69 % │
  │ 256   AVX2 bulk        │ 0.049 │ 0.084 │  +72 % │
  │ 256   AVX2 asmjit-tail │ 0.088 │ 0.109 │  +24 % │
  │ 4 K   scalar           │ 0.060 │ 0.099 │  +64 % │
  │ 4 K   AVX2 bulk        │ 0.055 │ 0.087 │  +59 % │
  │ 4 K   AVX2 asmjit-tail │ 0.087 │ 0.113 │  +30 % │
  │ 64 K  scalar           │ 0.074 │ 0.093 │  +26 % │
  │ 64 K  AVX2 bulk        │ 0.066 │ 0.089 │  +34 % │
  │ 64 K  AVX2 asmjit-tail │ 0.080 │ 0.104 │  +31 % │
  │ 1 M   scalar           │ 0.094 │ 0.095 │   +1 % │
  │ 1 M   AVX2 bulk        │ 0.093 │ 0.094 │   +0 % │
  │ 1 M   AVX2 asmjit-tail │ 0.093 │ 0.104 │  +12 % │
  └────────────────────────┴───────┴───────┴────────┘

The slowdown occurs because the new rounding mode requires roughly twice as many instructions as the old one.
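
For reference, here is a minimal Python sketch of the two techniques as bit manipulation on the fp32 pattern. This is not FBGEMM's actual code, and it assumes the old mode was a fixed +0x8000 bias (round-half-away-from-zero); ties-to-even additionally extracts the LSB of the kept half and folds it into the bias, which is where the extra shift/and/add per element comes from.

import numpy as np


def bf16_half_up(bits: int) -> int:
    """Assumed old behaviour: a fixed +0x8000 bias, so ties round away from zero."""
    return ((bits + 0x8000) & 0xFFFFFFFF) >> 16


def bf16_ties_to_even(bits: int) -> int:
    """IEEE 754 default: the bias is 0x7FFF plus the LSB of the kept half, so a
    tie (low half == 0x8000) carries into the kept bits only when they are odd.
    NaN handling is omitted; real conversion code must special-case it."""
    lsb = (bits >> 16) & 1
    return ((bits + 0x7FFF + lsb) & 0xFFFFFFFF) >> 16


# fp32 bit pattern 0x3F80_8000 is an exact tie between bf16 0x3F80 and 0x3F81:
# half-up bumps it to 0x3F81, ties-to-even keeps the even neighbour 0x3F80.
bits = np.float32(1.00390625).view(np.uint32).item()
assert bits == 0x3F808000
assert bf16_half_up(bits) == 0x3F81
assert bf16_ties_to_even(bits) == 0x3F80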

cyyever force-pushed the fix/bf16-ties-to-even branch from b57f191 to 9d95dd5 on May 5, 2026 09:28
