add dot.rs; rm 4 dupe dot impl; migrate dot_32 call site; lean dispat…#29
add dot.rs; rm 4 dupe dot impl; migrate dot_32 call site; lean dispat…#29DrJesseGlass wants to merge 5 commits into
Conversation
…ch standard cpu_flash
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: eaa72f15a7
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| s += q[i] * k[i] as f32 | ||
| + q[i + 1] * k[i + 1] as f32 | ||
| + q[i + 2] * k[i + 2] as f32 | ||
| + q[i + 3] * k[i + 3] as f32; |
There was a problem hiding this comment.
Preserve f64 precision in CPU flash-attn dot path
The new DotF32 for f64 implementation downcasts every k element to f32 before multiply/accumulate, which is a precision regression from the previous vec_dot::<T>(q_row, k_row).to_f32() behavior that accumulated in f64 and only cast once at the end. This materially changes logits for the DType::F64 CPU flash-attention path (used in candle-transformers/src/models/qwen3.rs), especially on larger head dimensions where rounding error compounds, and effectively removes the numerical benefit of the f64 path.
Useful? React with 👍 / 👎.
|
@codex review |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 2927574de3
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| for t in 0..d { | ||
| q_f32[t] = q_row[t].to_f32().unwrap_or(0.0); | ||
| } |
There was a problem hiding this comment.
Preserve f64 query precision in generic CPU flash path
The generic causal kernel now downcasts each q_row element to f32 before the dot product (q_f32[t] = q_row[t].to_f32()), so flash_attn::<f64> no longer computes q·k with full f64 query precision. In the previous implementation, vec_dot::<T>(q_row, k_row) kept f64 values through multiplication/accumulation and only converted after the dot. This is a numerical regression for real f64 callers (e.g. the CPU DType::F64 path), and can change attention weights for high-dynamic-range inputs.
Useful? React with 👍 / 👎.
…ch standard cpu_flash