From 85dd92fe212cee7b2e5a8fc618397d01d230e348 Mon Sep 17 00:00:00 2001 From: yash Date: Mon, 16 Mar 2026 12:11:55 +0530 Subject: [PATCH] adds support for smooth-{2,3,13} NTT domains --- src/algebra/fields.rs | 2 + src/algebra/linear_form/covector.rs | 5 +- .../linear_form/multilinear_extension.rs | 7 +- .../linear_form/univariate_evaluation.rs | 50 +- src/algebra/ntt/cooley_tukey.rs | 408 +++++++++++++++- src/algebra/ntt/utils.rs | 98 ++-- src/lib.rs | 1 + src/protocols/challenge_indices.rs | 86 +++- src/protocols/matrix_commit.rs | 44 +- src/protocols/merkle_tree.rs | 452 ++++++++++++++---- src/protocols/sumcheck.rs | 6 +- src/protocols/whir/config.rs | 55 ++- src/protocols/whir/mod.rs | 56 ++- src/protocols/whir/prover.rs | 14 +- src/protocols/whir/verifier.rs | 17 +- src/protocols/whir_zk/utils.rs | 10 +- src/protocols/whir_zk/verifier.rs | 7 +- src/smooth_domain.rs | 290 +++++++++++ src/utils.rs | 12 +- 19 files changed, 1431 insertions(+), 189 deletions(-) create mode 100644 src/smooth_domain.rs diff --git a/src/algebra/fields.rs b/src/algebra/fields.rs index 0675d19d..fd15b397 100644 --- a/src/algebra/fields.rs +++ b/src/algebra/fields.rs @@ -65,6 +65,8 @@ impl TypeInfo for F { #[derive(MontConfig)] #[modulus = "21888242871839275222246405745257275088548364400416034343698204186575808495617"] #[generator = "5"] +#[small_subgroup_base = "3"] +#[small_subgroup_power = "2"] pub struct BN254Config; pub type Field256 = Fp256>; diff --git a/src/algebra/linear_form/covector.rs b/src/algebra/linear_form/covector.rs index f6ef866b..c28b4575 100644 --- a/src/algebra/linear_form/covector.rs +++ b/src/algebra/linear_form/covector.rs @@ -15,7 +15,10 @@ impl LinearForm for Covector { } fn mle_evaluate(&self, point: &[F]) -> F { - multilinear_extend(&self.vector, point) + let k = self.vector.len().trailing_zeros() as usize; + let extra = point.len().saturating_sub(k); + let head_factor: F = point[..extra].iter().map(|p| F::ONE - *p).product::(); + head_factor * multilinear_extend(&self.vector, &point[extra..]) } fn accumulate(&self, accumulator: &mut [F], scalar: F) { diff --git a/src/algebra/linear_form/multilinear_extension.rs b/src/algebra/linear_form/multilinear_extension.rs index 643d9922..e943ef30 100644 --- a/src/algebra/linear_form/multilinear_extension.rs +++ b/src/algebra/linear_form/multilinear_extension.rs @@ -30,9 +30,12 @@ impl LinearForm for MultilinearExtension { } fn mle_evaluate(&self, point: &[F]) -> F { - zip_strict(&self.point, point).fold(F::ONE, |acc, (&l, &r)| { + let extra = point.len().saturating_sub(self.point.len()); + let head_factor: F = point[..extra].iter().map(|p| F::ONE - *p).product::(); + let eq_value = zip_strict(&self.point, &point[extra..]).fold(F::ONE, |acc, (&l, &r)| { acc * (l * r + (F::ONE - l) * (F::ONE - r)) - }) + }); + head_factor * eq_value } fn accumulate(&self, accumulator: &mut [F], scalar: F) { diff --git a/src/algebra/linear_form/univariate_evaluation.rs b/src/algebra/linear_form/univariate_evaluation.rs index 4d565198..d084a668 100644 --- a/src/algebra/linear_form/univariate_evaluation.rs +++ b/src/algebra/linear_form/univariate_evaluation.rs @@ -38,21 +38,59 @@ impl UnivariateEvaluation { } } +/// Lagrange basis polynomial L_index(point) on {0,1}^n (MSB-first). +fn lagrange_basis_single(point: &[F], index: usize) -> F { + let n = point.len(); + point.iter().enumerate().fold(F::ONE, |acc, (j, &r)| { + if (index >> (n - 1 - j)) & 1 == 1 { + acc * r + } else { + acc * (F::ONE - r) + } + }) +} + impl LinearForm for UnivariateEvaluation { fn size(&self) -> usize { self.size } fn mle_evaluate(&self, point: &[F]) -> F { - // Multilinear extension of (1, x, x^2, ..) = ⨂_i (1, x^2^i). + let k = self.size.trailing_zeros() as usize; + let extra = point.len().saturating_sub(k); + + if extra == 0 { + // Power-of-2 path: MLE of (1, x, x^2, ..) = ⊗_i (1, x^{2^i}). + let mut x2i = self.point; + let mut result = F::ONE; + for &r in point.iter().rev() { + result *= (F::ONE - r) + r * x2i; + x2i.square_in_place(); + } + return result; + } + + // Smooth path: size = 2^k * odd where odd = 3^b * 13^c. + let odd = self.size >> k; + let leading = &point[..extra]; + let trailing = &point[extra..]; + let mut x2i = self.point; - let mut result = F::ONE; - for &r in point.iter().rev() { - // TODO: Why rev? - result *= (F::ONE - r) + r * x2i; + let mut inner = F::ONE; + for &r in trailing.iter().rev() { + inner *= (F::ONE - r) + r * x2i; x2i.square_in_place(); } - result + + let x_pow_2k = x2i; + let mut outer = F::ZERO; + let mut x_h = F::ONE; + for h in 0..odd { + outer += lagrange_basis_single(leading, h) * x_h; + x_h *= x_pow_2k; + } + + outer * inner } /// See also [`Self::accumulate_many`] for a more efficient batched version. diff --git a/src/algebra/ntt/cooley_tukey.rs b/src/algebra/ntt/cooley_tukey.rs index 74cc5bdb..68530bdf 100644 --- a/src/algebra/ntt/cooley_tukey.rs +++ b/src/algebra/ntt/cooley_tukey.rs @@ -40,6 +40,12 @@ pub struct NttEngine { omega_16_3: F, omega_16_9: F, + // Winograd constants for the 13-point DFT: + // C_m = (ω₁₃^m + ω₁₃^{13-m}) / 2 for m = 1..6 + // S_m = (ω₁₃^m - ω₁₃^{13-m}) / 2 for m = 1..6 + omega_13_c: [F; 6], + omega_13_s: [F; 6], + // Root lookup table (extended on demand) roots: RwLock>, } @@ -86,8 +92,29 @@ impl NttEngine { /// Construct a new engine from the field's `FftField` trait. pub(crate) fn new_from_fftfield() -> Self { - // TODO: Support SMALL_SUBGROUP - if F::TWO_ADICITY <= 63 { + if let Some(large_root) = F::LARGE_SUBGROUP_ROOT_OF_UNITY { + let q = F::SMALL_SUBGROUP_BASE.unwrap() as usize; + let q_adicity = F::SMALL_SUBGROUP_BASE_ADICITY.unwrap(); + let two_adicity = F::TWO_ADICITY.min(63) as u32; + + let mut order = 1usize << two_adicity; + for _ in 0..q_adicity { + order = order + .checked_mul(q) + .expect("Mixed-radix NTT order overflows usize"); + } + + let mut generator = large_root; + if F::TWO_ADICITY > 63 { + for _ in 0..(F::TWO_ADICITY - 63) { + generator = generator.square(); + } + } + + let (order, generator) = try_extend_order_with_primes(order, generator, &[13]); + + Self::new(order, generator) + } else if F::TWO_ADICITY <= 63 { Self::new(1 << F::TWO_ADICITY, F::TWO_ADIC_ROOT_OF_UNITY) } else { let mut generator = F::TWO_ADIC_ROOT_OF_UNITY; @@ -99,11 +126,107 @@ impl NttEngine { } } +/// Extend the NTT order by extra small primes dividing |F*| beyond `order`. +fn try_extend_order_with_primes( + mut order: usize, + mut root: F, + extra_primes: &[usize], +) -> (usize, F) { + let char_limbs = F::characteristic(); + let mut p_minus_1: Vec = char_limbs + .iter() + .flat_map(|limb| limb.to_le_bytes()) + .collect(); + let mut borrow = 1u16; + for byte in &mut p_minus_1 { + let diff = (*byte as u16).wrapping_sub(borrow); + *byte = diff as u8; + borrow = if diff > 255 { 1 } else { 0 }; + } + while p_minus_1.last() == Some(&0) && p_minus_1.len() > 1 { + p_minus_1.pop(); + } + + for &p in extra_primes { + let extended_order = match order.checked_mul(p) { + Some(ext) => ext, + None => continue, + }; + + let mut cofactor = p_minus_1.clone(); + let mut remaining = extended_order; + while remaining % 2 == 0 { + if cofactor[0] & 1 != 0 { + break; + } + let mut carry = 0u8; + for byte in cofactor.iter_mut().rev() { + let new_carry = *byte & 1; + *byte = (*byte >> 1) | (carry << 7); + carry = new_carry; + } + remaining /= 2; + } + if remaining != extended_order >> extended_order.trailing_zeros() { + continue; + } + + let mut divisible = true; + if remaining > 1 { + let divisor = remaining as u128; + let mut carry: u128 = 0; + for byte in cofactor.iter_mut().rev() { + let cur = carry * 256 + *byte as u128; + *byte = (cur / divisor) as u8; + carry = cur % divisor; + } + if carry != 0 { + divisible = false; + } + } + if !divisible { + continue; + } + + while cofactor.last() == Some(&0) && cofactor.len() > 1 { + cofactor.pop(); + } + + let cofactor_limbs: Vec = cofactor + .chunks(8) + .map(|chunk: &[u8]| { + let mut bytes = [0u8; 8]; + bytes[..chunk.len()].copy_from_slice(chunk); + u64::from_le_bytes(bytes) + }) + .collect(); + + for g in [2u64, 3, 5, 7, 11] { + let candidate = F::from(g).pow(&cofactor_limbs); + if candidate.pow([extended_order as u64]) != F::ONE { + continue; + } + if candidate.pow([(extended_order / 2) as u64]) == F::ONE { + continue; + } + if candidate.pow([(extended_order / p) as u64]) == F::ONE { + continue; + } + if extended_order % 3 == 0 && candidate.pow([(extended_order / 3) as u64]) == F::ONE { + continue; + } + order = extended_order; + root = candidate; + break; + } + } + (order, root) +} + /// Creates a new NttEngine. `omega_order` must be a primitive root of unity of even order `omega`. impl NttEngine { pub fn new(order: usize, omega_order: F) -> Self { assert!(order.trailing_zeros() > 0, "Order must be a multiple of 2."); - // TODO: Assert that omega factors into 2s and 3s. assert_eq!(omega_order.pow([order as u64]), F::ONE); assert_ne!(omega_order.pow([order as u64 / 2]), F::ONE); let mut res = Self { @@ -117,6 +240,8 @@ impl NttEngine { omega_16_1: F::ZERO, omega_16_3: F::ZERO, omega_16_9: F::ZERO, + omega_13_c: [F::ZERO; 6], + omega_13_s: [F::ZERO; 6], roots: RwLock::new(Vec::new()), }; if order.is_multiple_of(3) { @@ -138,6 +263,16 @@ impl NttEngine { res.omega_16_3 = res.omega_16_1.pow([3]); res.omega_16_9 = res.omega_16_1.pow([9]); } + if order.is_multiple_of(13) { + let w = res.root(13); + let two_inv = F::from(2u64).inverse().unwrap(); + for m in 1..=6 { + let wm = w.pow([m as u64]); + let wn = w.pow([(13 - m) as u64]); + res.omega_13_c[m - 1] = (wm + wn) * two_inv; + res.omega_13_s[m - 1] = (wm - wn) * two_inv; + } + } res } @@ -353,6 +488,54 @@ impl NttEngine { (v[11], v[14]) = (v[14], v[11]); } } + 13 => { + // Winograd-style 13-point DFT exploiting ω^{13-m} = ω^{-m}. + // Pairs outputs X[k] and X[13-k] via symmetric/antisymmetric + // decomposition, reducing from 169 to 72 multiplications. + let [c1, c2, c3, c4, c5, c6] = self.omega_13_c; + let [t1, t2, t3, t4, t5, t6] = self.omega_13_s; + for v in values.chunks_exact_mut(13) { + let v0 = v[0]; + let (s1, d1) = (v[1] + v[12], v[1] - v[12]); + let (s2, d2) = (v[2] + v[11], v[2] - v[11]); + let (s3, d3) = (v[3] + v[10], v[3] - v[10]); + let (s4, d4) = (v[4] + v[9], v[4] - v[9]); + let (s5, d5) = (v[5] + v[8], v[5] - v[8]); + let (s6, d6) = (v[6] + v[7], v[6] - v[7]); + + v[0] = v0 + s1 + s2 + s3 + s4 + s5 + s6; + + // k=1: C-idx [1,2,3,4,5,6], S-signs [+,+,+,+,+,+] + let sc = c1 * s1 + c2 * s2 + c3 * s3 + c4 * s4 + c5 * s5 + c6 * s6; + let ss = t1 * d1 + t2 * d2 + t3 * d3 + t4 * d4 + t5 * d5 + t6 * d6; + (v[1], v[12]) = (v0 + sc + ss, v0 + sc - ss); + + // k=2: C-idx [2,4,6,5,3,1], S-signs [+,+,+,-,-,-] + let sc = c2 * s1 + c4 * s2 + c6 * s3 + c5 * s4 + c3 * s5 + c1 * s6; + let ss = t2 * d1 + t4 * d2 + t6 * d3 - t5 * d4 - t3 * d5 - t1 * d6; + (v[2], v[11]) = (v0 + sc + ss, v0 + sc - ss); + + // k=3: C-idx [3,6,4,1,2,5], S-signs [+,+,-,-,+,+] + let sc = c3 * s1 + c6 * s2 + c4 * s3 + c1 * s4 + c2 * s5 + c5 * s6; + let ss = t3 * d1 + t6 * d2 - t4 * d3 - t1 * d4 + t2 * d5 + t5 * d6; + (v[3], v[10]) = (v0 + sc + ss, v0 + sc - ss); + + // k=4: C-idx [4,5,1,3,6,2], S-signs [+,-,-,+,-,-] + let sc = c4 * s1 + c5 * s2 + c1 * s3 + c3 * s4 + c6 * s5 + c2 * s6; + let ss = t4 * d1 - t5 * d2 - t1 * d3 + t3 * d4 - t6 * d5 - t2 * d6; + (v[4], v[9]) = (v0 + sc + ss, v0 + sc - ss); + + // k=5: C-idx [5,3,2,6,1,4], S-signs [+,-,+,-,-,+] + let sc = c5 * s1 + c3 * s2 + c2 * s3 + c6 * s4 + c1 * s5 + c4 * s6; + let ss = t5 * d1 - t3 * d2 + t2 * d3 - t6 * d4 - t1 * d5 + t4 * d6; + (v[5], v[8]) = (v0 + sc + ss, v0 + sc - ss); + + // k=6: C-idx [6,1,5,2,4,3], S-signs [+,-,+,-,+,-] + let sc = c6 * s1 + c1 * s2 + c5 * s3 + c2 * s4 + c4 * s5 + c3 * s6; + let ss = t6 * d1 - t1 * d2 + t5 * d3 - t2 * d4 + t4 * d5 - t3 * d6; + (v[6], v[7]) = (v0 + sc + ss, v0 + sc - ss); + } + } size => self.ntt_recurse(values, roots, size), } } @@ -458,7 +641,7 @@ mod tests { use ark_ff::{AdditiveGroup as _, BigInteger, PrimeField}; use super::*; - use crate::algebra::fields::Field64; + use crate::algebra::fields::{Field256, Field64}; #[test] fn test_new_from_fftfield_basic() { @@ -963,4 +1146,221 @@ mod tests { assert_eq!(values_ntt, expected_values); } + + #[test] + fn test_field256_mixed_radix_engine() { + // Field256 (BN254 Fr) should have a mixed-radix engine with order 2^28 * 3^2 * 13. + let engine = NttEngine::::new_from_fftfield(); + let expected_order = (1usize << 28) * 9 * 13; + assert_eq!(engine.order, expected_order); + + // Verify the root of unity has the correct order. + assert_eq!( + engine.omega_order.pow([expected_order as u64]), + Field256::ONE + ); + assert_ne!( + engine.omega_order.pow([(expected_order / 2) as u64]), + Field256::ONE + ); + assert_ne!( + engine.omega_order.pow([(expected_order / 3) as u64]), + Field256::ONE + ); + assert_ne!( + engine.omega_order.pow([(expected_order / 13) as u64]), + Field256::ONE + ); + } + + #[test] + fn test_field256_mixed_radix_roots() { + let engine = NttEngine::::new_from_fftfield(); + + // Verify roots for mixed-radix sizes. + assert!(engine.checked_root(3).is_some()); + assert!(engine.checked_root(9).is_some()); + assert!(engine.checked_root(6).is_some()); // 2 * 3 + assert!(engine.checked_root(12).is_some()); // 4 * 3 + assert!(engine.checked_root(18).is_some()); // 2 * 9 + assert!(engine.checked_root(36).is_some()); // 4 * 9 + + // Verify root properties. + let root_3 = engine.root(3); + assert_eq!(root_3.pow([3]), Field256::ONE); + assert_ne!(root_3, Field256::ONE); + + let root_9 = engine.root(9); + assert_eq!(root_9.pow([9]), Field256::ONE); + assert_ne!(root_9.pow([3]), Field256::ONE); + + let root_36 = engine.root(36); + assert_eq!(root_36.pow([36]), Field256::ONE); + assert_ne!(root_36.pow([18]), Field256::ONE); + } + + #[test] + fn test_field256_ntt_size_3() { + let engine = NttEngine::::new_from_fftfield(); + + // NTT of size 3 over Field256 (BN254 Fr). + let values: Vec<_> = (1..=3).map(Field256::from).collect(); + let mut values_ntt = values.clone(); + + let omega = engine.root(3); + let mut expected = vec![Field256::ZERO; 3]; + for k in 0..3 { + let omega_k = omega.pow([k as u64]); + expected[k] = values + .iter() + .enumerate() + .map(|(j, &v)| v * omega_k.pow([j as u64])) + .sum(); + } + + engine.ntt_batch(&mut values_ntt, 3); + assert_eq!(values_ntt, expected); + } + + #[test] + fn test_field256_ntt_size_9() { + let engine = NttEngine::::new_from_fftfield(); + + // NTT of size 9 over Field256 (BN254 Fr). + let values: Vec<_> = (1..=9).map(Field256::from).collect(); + let mut values_ntt = values.clone(); + + let omega = engine.root(9); + let mut expected = vec![Field256::ZERO; 9]; + for k in 0..9 { + let omega_k = omega.pow([k as u64]); + expected[k] = values + .iter() + .enumerate() + .map(|(j, &v)| v * omega_k.pow([j as u64])) + .sum(); + } + + engine.ntt_batch(&mut values_ntt, 9); + assert_eq!(values_ntt, expected); + } + + #[test] + fn test_field256_ntt_size_12() { + // 12 = 4 * 3, a mixed-radix size. + let engine = NttEngine::::new_from_fftfield(); + + let values: Vec<_> = (1..=12).map(Field256::from).collect(); + let mut values_ntt = values.clone(); + + let omega = engine.root(12); + let mut expected = vec![Field256::ZERO; 12]; + for k in 0..12 { + let omega_k = omega.pow([k as u64]); + expected[k] = values + .iter() + .enumerate() + .map(|(j, &v)| v * omega_k.pow([j as u64])) + .sum(); + } + + engine.ntt_batch(&mut values_ntt, 12); + assert_eq!(values_ntt, expected); + } + + #[test] + fn test_field256_ntt_size_13() { + let engine = NttEngine::::new_from_fftfield(); + + // NTT of size 13 over Field256 (BN254 Fr). + let values: Vec<_> = (1..=13).map(Field256::from).collect(); + let mut values_ntt = values.clone(); + + let omega = engine.root(13); + let mut expected = vec![Field256::ZERO; 13]; + for k in 0..13 { + let omega_k = omega.pow([k as u64]); + expected[k] = values + .iter() + .enumerate() + .map(|(j, &v)| v * omega_k.pow([j as u64])) + .sum(); + } + + engine.ntt_batch(&mut values_ntt, 13); + assert_eq!(values_ntt, expected); + } + + #[test] + fn test_field256_ntt_size_26() { + // 26 = 2 * 13, a mixed-radix size. + let engine = NttEngine::::new_from_fftfield(); + + let values: Vec<_> = (1..=26).map(Field256::from).collect(); + let mut values_ntt = values.clone(); + + let omega = engine.root(26); + let mut expected = vec![Field256::ZERO; 26]; + for k in 0..26 { + let omega_k = omega.pow([k as u64]); + expected[k] = values + .iter() + .enumerate() + .map(|(j, &v)| v * omega_k.pow([j as u64])) + .sum(); + } + + engine.ntt_batch(&mut values_ntt, 26); + assert_eq!(values_ntt, expected); + } + + #[test] + fn test_field256_ntt_size_39() { + // 39 = 3 * 13, a mixed-radix size. + let engine = NttEngine::::new_from_fftfield(); + + let values: Vec<_> = (1..=39).map(Field256::from).collect(); + let mut values_ntt = values.clone(); + + let omega = engine.root(39); + let mut expected = vec![Field256::ZERO; 39]; + for k in 0..39 { + let omega_k = omega.pow([k as u64]); + expected[k] = values + .iter() + .enumerate() + .map(|(j, &v)| v * omega_k.pow([j as u64])) + .sum(); + } + + engine.ntt_batch(&mut values_ntt, 39); + assert_eq!(values_ntt, expected); + } + + #[test] + fn test_field256_ntt_roundtrip_mixed() { + // Test NTT → INTT roundtrip for mixed-radix sizes. + use ark_std::UniformRand; + + let engine = NttEngine::::new_from_fftfield(); + let mut rng = ark_std::test_rng(); + + for size in [ + 3, 6, 9, 12, 13, 18, 26, 36, 39, 52, 72, 78, 104, 117, 144, 156, + ] { + let original: Vec<_> = (0..size).map(|_| Field256::rand(&mut rng)).collect(); + let mut values = original.clone(); + + engine.ntt_batch(&mut values, size); + engine.intt_batch(&mut values, size); + + // INTT omits the 1/n scaling factor, so we need to multiply by 1/n. + let n_inv = Field256::from(size as u64).inverse().unwrap(); + for v in &mut values { + *v *= n_inv; + } + + assert_eq!(values, original, "Roundtrip failed for size {size}"); + } + } } diff --git a/src/algebra/ntt/utils.rs b/src/algebra/ntt/utils.rs index ce3b8c27..c11f81bb 100644 --- a/src/algebra/ntt/utils.rs +++ b/src/algebra/ntt/utils.rs @@ -1,53 +1,35 @@ /// Compute the largest factor of `n` that is ≤ sqrt(n). -/// Assumes `n` is of the form `2^k * {1,3,9}`. +/// Assumes `n` is a smooth-{2,3,13} number, i.e. of the form `2^a * 3^b * 13^c`. pub fn sqrt_factor(n: usize) -> usize { - // Count the number of trailing zeros in `n`, i.e., the power of 2 in `n` - let twos = n.trailing_zeros(); - - // Divide `n` by the highest power of 2 to extract the base component - let base = n >> twos; - - // Determine the largest factor ≤ sqrt(n) based on the extracted `base` - match base { - // Case: `n` is purely a power of 2 (base = 1) - // The largest factor ≤ sqrt(n) is 2^(twos/2) - 1 => 1 << (twos / 2), - - // Case: `n = 2^k * 3` - 3 => { - if twos == 0 { - // sqrt(3) ≈ 1.73, so the largest integer factor ≤ sqrt(3) is 1 - 1 - } else { - // - If `twos` is even: The largest factor is `3 * 2^((twos - 1) / 2)` - // - If `twos` is odd: The largest factor is `2^((twos / 2))` - if twos.is_multiple_of(2) { - 3 << ((twos - 1) / 2) - } else { - 2 << (twos / 2) - } - } - } - - // Case: `n = 2^k * 9` - 9 => { - if twos == 1 { - // sqrt(9 * 2^1) = sqrt(18) ≈ 4.24, largest factor ≤ sqrt(18) is 3 - 3 - } else { - // - If `twos` is even: The largest factor is `3 * 2^(twos / 2)` - // - If `twos` is odd: The largest factor is `4 * 2^(twos / 2)` - if twos.is_multiple_of(2) { - 3 << (twos / 2) - } else { - 4 << (twos / 2) - } - } + let twos = n.trailing_zeros() as usize; + let odd = n >> twos; + + // Enumerate all divisors of the odd part and for each, find the largest + // power-of-2 multiplier that keeps the product ≤ sqrt(n). + let odd_divisors: &[usize] = match odd { + 1 => &[1], + 3 => &[1, 3], + 9 => &[1, 3, 9], + 13 => &[1, 13], + 39 => &[1, 3, 13, 39], + 117 => &[1, 3, 9, 13, 39, 117], + _ => panic!("n is not a smooth-{{2,3,13}} number"), + }; + + let mut best = 1usize; + for &d in odd_divisors { + let d_sq = d * d; + if d_sq > n { + continue; } - - // If `base` is not in {1,3,9}, `n` is not in the expected form - _ => panic!("n is not in the form 2^k * {{1,3,9}}"), + // We need d * 2^a ≤ sqrt(n), i.e. d² * 4^a ≤ n, i.e. 4^a ≤ n/d². + let ratio = n / d_sq; + // max a such that 4^a ≤ ratio: a = floor(log2(ratio)) / 2, capped at twos. + let max_2a = (usize::BITS - 1 - ratio.leading_zeros()) as usize; + let a = (max_2a / 2).min(twos); + best = best.max(d << a); } + best } /// Least common multiple. @@ -136,11 +118,33 @@ mod tests { assert_eq!(sqrt_factor(144), 12); // 144 = 2^4 * 9 assert_eq!(sqrt_factor(576), 24); // 576 = 2^6 * 9 assert_eq!(sqrt_factor(2304), 48); // 2304 = 2^8 * 9 + + // Cases where n = 2^k * 13 + assert_eq!(sqrt_factor(13), 1); // 13 = 2^0 * 13 + assert_eq!(sqrt_factor(26), 2); // 26 = 2^1 * 13 + assert_eq!(sqrt_factor(52), 4); // 52 = 2^2 * 13 + assert_eq!(sqrt_factor(208), 13); // 208 = 2^4 * 13 + assert_eq!(sqrt_factor(832), 26); // 832 = 2^6 * 13 + assert_eq!(sqrt_factor(3328), 52); // 3328 = 2^8 * 13 + + // Cases where n = 2^k * 39 + assert_eq!(sqrt_factor(39), 3); // 39 = 2^0 * 39 + assert_eq!(sqrt_factor(78), 6); // 78 = 2^1 * 39 + assert_eq!(sqrt_factor(156), 12); // 156 = 2^2 * 39 + assert_eq!(sqrt_factor(624), 24); // 624 = 2^4 * 39 + assert_eq!(sqrt_factor(2496), 48); // 2496 = 2^6 * 39 + + // Cases where n = 2^k * 117 + assert_eq!(sqrt_factor(117), 9); // 117 = 2^0 * 117 + assert_eq!(sqrt_factor(234), 13); // 234 = 2^1 * 117 + assert_eq!(sqrt_factor(468), 18); // 468 = 2^2 * 117 + assert_eq!(sqrt_factor(1872), 39); // 1872 = 2^4 * 117 + assert_eq!(sqrt_factor(7488), 78); // 7488 = 2^6 * 117 } proptest! { #[test] - fn proptest_sqrt_factor(k in 0usize..30, base in prop_oneof![Just(1), Just(3), Just(9)]) + fn proptest_sqrt_factor(k in 0usize..30, base in prop_oneof![Just(1), Just(3), Just(9), Just(13), Just(39), Just(117)]) { let n = (1 << k) * base; let expected = get_largest_divisor_up_to_sqrt(n); diff --git a/src/lib.rs b/src/lib.rs index 7d1d8b45..89f808ed 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,7 @@ pub mod engines; pub mod hash; pub mod parameters; pub mod protocols; +pub mod smooth_domain; pub mod transcript; pub mod type_info; pub mod type_map; diff --git a/src/protocols/challenge_indices.rs b/src/protocols/challenge_indices.rs index 87831382..a6baeca7 100644 --- a/src/protocols/challenge_indices.rs +++ b/src/protocols/challenge_indices.rs @@ -3,6 +3,12 @@ use crate::transcript::{Decoding, VerifierMessage}; /// Generate a set of indices for challenges. +/// +/// For power-of-2 `num_leaves`, indices are sampled without bias using exactly +/// `ceil(log2(num_leaves))` bits per index. +/// +/// For other `num_leaves` (e.g., `2^a * 3^b * 13^c`), extra entropy bytes are used +/// to make the modular bias negligible (< 2^{-64}). pub fn challenge_indices( transcript: &mut T, num_leaves: usize, @@ -16,18 +22,23 @@ where if count == 0 { return Vec::new(); } - assert!( - num_leaves.is_power_of_two(), - "Number of leaves must be a power of two for unbiased results." - ); + assert!(num_leaves > 0, "Number of leaves must be positive."); if num_leaves == 1 { // `size_bytes` would be zero, making `chunks_exact` panic. return if deduplicate { vec![0] } else { vec![0; count] }; } - // Calculate the required bytes of entropy + // Calculate the required bytes of entropy per index. + // For power-of-2, use exactly ceil(log2(N)) bits (no bias). + // For non-power-of-2, add 8 extra bytes to make modular bias < 2^{-64}. // TODO: Round total to bytes, instead of per index. - let size_bytes = (num_leaves.ilog2() as usize).div_ceil(8); + let size_bytes = if num_leaves.is_power_of_two() { + (num_leaves.ilog2() as usize).div_ceil(8) + } else { + // ceil(log2(num_leaves)) bits + 64 extra bits for negligible bias + let bits_needed = usize::BITS - (num_leaves - 1).leading_zeros(); + (bits_needed as usize).div_ceil(8) + 8 + }; // Get required entropy bits. let entropy: Vec = (0..count * size_bytes) @@ -200,4 +211,67 @@ mod tests { "Mismatch in computed indices for deduplication test" ); } + + #[test] + fn test_challenge_indices_non_power_of_two() { + // Test with num_leaves = 12 (= 2^2 * 3), a mixed-radix size. + let num_leaves = 12; + let num_queries = 3; + + let ds = DomainSeparator::protocol(&module_path!()) + .session(&format!("Test at {}:{}", file!(), line!())) + .instance(&Empty); + // For non-power-of-2: ceil(log2(12)) = 4 bits → 1 byte + 8 extra = 9 bytes per index + // So 3 queries × 9 bytes = 27 bytes needed. + let sponge = MockSponge { + absorb: None, + squeeze: &[ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, // Query 1: index 5 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0B, // Query 2: index 11 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Query 3: index 0 + ], + }; + let mut prover_state = ProverState::new(&ds, sponge); + + let result = challenge_indices(&mut prover_state, num_leaves, num_queries, true); + + let mut expected = vec![5 % num_leaves, 11 % num_leaves, 0 % num_leaves]; + expected.sort_unstable(); + expected.dedup(); + + assert_eq!(result, expected); + } + + #[test] + fn test_challenge_indices_mixed_radix_large() { + // Test with num_leaves = 2^18 * 3^2 = 2359296. + let num_leaves = (1 << 18) * 9; + assert_eq!(num_leaves, 2_359_296); + + // ceil(log2(2359296)) = 22 bits → 3 bytes + 8 extra = 11 bytes per index + let num_queries = 2; + + let ds = DomainSeparator::protocol(&module_path!()) + .session(&format!("Test at {}:{}", file!(), line!())) + .instance(&Empty); + // 2 queries × 11 bytes = 22 bytes + let sponge = MockSponge { + absorb: None, + squeeze: &[ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x23, 0x45, 0x67, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + ], + }; + let mut prover_state = ProverState::new(&ds, sponge); + + let result = challenge_indices(&mut prover_state, num_leaves, num_queries, true); + + let val_1 = 0x00_00_00_00_00_00_00_01_23_45_67_usize % num_leaves; + let val_2 = 0x00_00_00_00_00_00_00_00_00_00_01_usize % num_leaves; + let mut expected = vec![val_1, val_2]; + expected.sort_unstable(); + expected.dedup(); + + assert_eq!(result, expected); + } } diff --git a/src/protocols/matrix_commit.rs b/src/protocols/matrix_commit.rs index 436331b8..dac0d6f0 100644 --- a/src/protocols/matrix_commit.rs +++ b/src/protocols/matrix_commit.rs @@ -351,8 +351,9 @@ fn hash_rows_serial( if encoder.is_buffered() { // Buffered encoder, find some optimal size. let target = workload_size::() / 8; - let batch_size = (target / message_size).next_multiple_of(engine.preferred_batch_size()); - assert!(batch_size >= 1); + let batch_size = (target / message_size) + .max(1) + .next_multiple_of(engine.preferred_batch_size()); for (matrix, out) in zip_strict(matrix.chunks(batch_size * cols), out.chunks_mut(batch_size)) { @@ -409,7 +410,7 @@ pub(crate) mod tests { Standard: Distribution, { crate::tests::init(); - assert!(layers >= merkle_tree::layers_for_size(num_rows)); + assert!(layers >= merkle_tree::layers_for_size(num_rows).len()); assert!(indices.iter().all(|&index| index < num_rows)); // Config @@ -417,9 +418,27 @@ pub(crate) mod tests { element_type: Type::::new(), num_cols, leaf_hash_id: leaf_hash, - merkle_tree: merkle_tree::Config { - num_leaves: num_rows, - layers: vec![merkle_tree::LayerConfig { hash_id: node_hash }; layers], + merkle_tree: { + // Build a mixed-arity layer config with extra binary layers on top. + let base_arities = merkle_tree::layers_for_size(num_rows); + let extra = layers.saturating_sub(base_arities.len()); + let mut layer_configs = Vec::with_capacity(layers); + for _ in 0..extra { + layer_configs.push(merkle_tree::LayerConfig { + hash_id: node_hash, + arity: 2, + }); + } + for &arity in &base_arities { + layer_configs.push(merkle_tree::LayerConfig { + hash_id: node_hash, + arity, + }); + } + merkle_tree::Config { + num_leaves: num_rows, + layers: layer_configs, + } }, }; let ds = DomainSeparator::protocol(&config) @@ -464,26 +483,33 @@ pub(crate) mod tests { T: Clone + TypeInfo + Encodable + Send + Sync, Standard: Distribution, { + // Smooth-{2,3,13} values (2^a * 3^b * 13^c) up to ~120. + let smooth_rows: Vec = vec![ + 0, 1, 2, 3, 4, 6, 8, 9, 12, 13, 16, 18, 24, 26, 27, 32, 36, 39, 48, 52, 54, 64, 72, 78, + 96, 104, 117, + ]; let hashes = [hash::COPY, hash::SHA2, hash::SHA3, hash::BLAKE3]; proptest!(|( seed: u64, leaf_hash in 0_usize..hashes.len(), node_hash in 1_usize..hashes.len(), layers in 0_usize..10, - num_rows in 0_usize..100, + row_idx in 0_usize..smooth_rows.len(), num_cols in 0_usize..100, num_indices in 0_usize..100, )| { + let num_rows = smooth_rows[row_idx]; + // There are no valid indices without rows. let num_indices = if num_rows == 0 { 0 } else { num_indices }; // We need at least enough layers to cover the number of rows. - let layers = layers + merkle_tree::layers_for_size(num_rows); + let layers = layers + merkle_tree::layers_for_size(num_rows).len(); let leaf_hash = hashes[leaf_hash]; let node_hash = hashes[node_hash]; prop_assume!(hash::ENGINES.retrieve(leaf_hash).unwrap().supports_size(T::encoded_size() * num_cols)); - prop_assume!(hash::ENGINES.retrieve(node_hash).unwrap().supports_size(64)); + prop_assume!(hash::ENGINES.retrieve(node_hash).unwrap().supports_size(merkle_tree::MAX_NODE_HASH_SIZE)); let mut rng = StdRng::seed_from_u64(seed); let indices = (0..num_indices).map(|_| rng.gen_range(0..num_rows)).collect::>(); diff --git a/src/protocols/merkle_tree.rs b/src/protocols/merkle_tree.rs index 8a30f9b8..f8761e57 100644 --- a/src/protocols/merkle_tree.rs +++ b/src/protocols/merkle_tree.rs @@ -1,5 +1,9 @@ //! Protocol for committing to a vector of [`struct@Hash`]es. //! +//! Uses a mixed-arity Merkle tree: each layer has its own arity (2, 3, or 13). +//! For `num_leaves = 2^a * 3^b * 13^c`, the tree uses `c` arity-13 layers +//! (bottom), `b` ternary layers (middle), and `a` binary layers (top). +//! //! See for analysis when used with truncated permutation //! node hashes. @@ -22,6 +26,14 @@ use crate::{ verify, }; +/// Size in bytes of the hash input for one internal node with the given arity. +pub const fn node_hash_size(arity: usize) -> usize { + arity * 32 +} + +/// Maximum supported arity. Used for fixed-size arrays in open/verify loops. +const MAX_ARITY: usize = 13; + #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Default, Serialize, Deserialize)] pub struct Config { /// Number of leaves in the Merkle tree. @@ -37,6 +49,9 @@ pub struct Config { pub struct LayerConfig { /// The engine used to hash siblings. pub hash_id: EngineId, + + /// Arity of this layer (number of children per node). Must be 2, 3, or 13. + pub arity: usize, } impl fmt::Display for Config { @@ -67,15 +82,40 @@ impl Config { Self::with_hash(hash::BLAKE3, num_leaves) } + /// Create a mixed-arity Merkle tree configuration. + /// + /// The tree pads `num_leaves` to the smallest smooth-{2,3,13} number >= `num_leaves`. pub fn with_hash(hash_id: EngineId, num_leaves: usize) -> Self { - Self { - num_leaves, - layers: vec![LayerConfig { hash_id }; layers_for_size(num_leaves)], + let arities = layers_for_size(num_leaves); + let layers = arities + .iter() + .map(|&arity| LayerConfig { hash_id, arity }) + .collect(); + Self { num_leaves, layers } + } + + /// Total number of nodes in the tree (all layers including leaves). + pub fn num_nodes(&self) -> usize { + if self.layers.is_empty() { + return 1; + } + let mut total = 0; + let mut layer_size = 1; // root + // Layers are root-to-bottom, so multiply by arity going down. + for layer in &self.layers { + layer_size *= layer.arity; + total += layer_size; } + total + 1 // +1 for the root node } - pub const fn num_nodes(&self) -> usize { - (1 << (self.layers.len() + 1)) - 1 + /// Capacity of the leaf layer (product of all arities). + fn leaf_capacity(&self) -> usize { + self.layers + .iter() + .map(|l| l.arity) + .product::() + .max(1) } #[cfg_attr(feature = "tracing", instrument(skip(prover_state, leaves), fields(self = %self)))] @@ -93,14 +133,16 @@ impl Config { leaves.len() ); - // Allocate nodes and fill with leaf layer. This implicitely pads the first layer. + // Allocate nodes and fill with leaf layer. This implicitly pads the first layer. let mut nodes = leaves; nodes.resize(self.num_nodes(), Hash::default()); - let (mut previous, mut remaining) = nodes.split_at_mut(1 << self.layers.len()); + let (mut previous, mut remaining) = nodes.split_at_mut(self.leaf_capacity()); - // Compute merkle tree nodes. + // Compute merkle tree nodes bottom-to-top. for layer in self.layers.iter().rev() { - let (current, next_remaining) = remaining.split_at_mut(previous.len() / 2); + let arity = layer.arity; + let parent_count = previous.len() / arity; + let (current, next_remaining) = remaining.split_at_mut(parent_count); let engine = ENGINES .retrieve(layer.hash_id) .expect("Hash Engine not found"); @@ -109,13 +151,15 @@ impl Config { Level::DEBUG, "layer", engine = engine.name().as_ref(), - count = current.len() + count = current.len(), + arity = arity ) .entered(); // TODO: Parallelize over subtrees, not layerwise. This will // increase locality. - parallel_hash(&*engine, 64, previous.as_bytes(), current); + let hash_size = node_hash_size(arity); + parallel_hash(&*engine, hash_size, previous.as_bytes(), current); previous = current; remaining = next_remaining; } @@ -155,31 +199,49 @@ impl Config { assert_eq!(witness.nodes.len(), self.num_nodes()); assert!(indices.iter().all(|&i| i < self.num_leaves)); - // Abstract execution of verify algorithm wrting required hashes. + // Abstract execution of verify algorithm writing required hashes. let mut indices = indices.to_vec(); indices.sort_unstable(); indices.dedup(); - let (mut layer, mut remaining) = witness.nodes.split_at(1 << self.layers.len()); - while layer.len() > 1 { + let (mut layer, mut remaining) = witness.nodes.split_at(self.leaf_capacity()); + + // Walk bottom-to-top through layers. + for layer_config in self.layers.iter().rev() { + let arity = layer_config.arity; let mut next_indices = Vec::with_capacity(indices.len()); let mut iter = indices.iter().copied().peekable(); loop { - match (iter.next(), iter.peek()) { - (Some(a), Some(&b)) if b == a ^ 1 => { - // Neighboring indices, merging branches. - next_indices.push(a >> 1); - iter.next(); // Skip the next index. - } - (Some(a), _) => { - // Single index, pushing the neighbor hash. - prover_state.prover_hint(&layer[a ^ 1]); - next_indices.push(a >> 1); + match iter.next() { + Some(a) => { + let parent = a / arity; + let group_start = parent * arity; + + // Track which siblings are present in the query set. + let mut present = [false; MAX_ARITY]; + present[a - group_start] = true; + while let Some(&b) = iter.peek() { + if b / arity == parent { + present[b - group_start] = true; + iter.next(); + } else { + break; + } + } + + // Send the missing siblings as hints. + for i in 0..arity { + if !present[i] { + prover_state.prover_hint(&layer[group_start + i]); + } + } + next_indices.push(parent); } - (None, _) => break, + None => break, } } indices = next_indices; - let (next_layer, next_remaining) = remaining.split_at(layer.len() / 2); + let parent_count = layer.len() / arity; + let (next_layer, next_remaining) = remaining.split_at(parent_count); layer = next_layer; remaining = next_remaining; } @@ -216,51 +278,60 @@ impl Config { } layer.dedup_by_key(|(i, _)| *i); - // Validate the layers + // Validate the layers bottom-to-top. let mut indices = layer.iter().map(|(i, _)| *i).collect::>(); let mut hashes = layer.iter().map(|(_, h)| *h).collect::>(); let mut next_indices = Vec::with_capacity(layer.len()); - let mut input_hashes = Vec::with_capacity(layer.len() * 2); + let mut input_hashes = Vec::with_capacity(layer.len() * MAX_ARITY); let mut next_hashes = Vec::with_capacity(layer.len()); - for layer in self.layers.iter().rev() { + for layer_config in self.layers.iter().rev() { + let arity = layer_config.arity; next_indices.clear(); input_hashes.clear(); next_hashes.clear(); - // Pair hashes with either hint or neighbor. + // Group hashes by parent, filling missing siblings from hints. let mut indices_iter = indices.iter().copied().peekable(); let mut hashes_iter = hashes.iter().copied(); loop { - match (indices_iter.next(), indices_iter.peek()) { - (Some(a), Some(&b)) if b == a ^ 1 => { - // Neighboring indices, merging branches. - input_hashes.push(hashes_iter.next().unwrap()); - input_hashes.push(hashes_iter.next().unwrap()); - next_indices.push(a >> 1); - indices_iter.next(); // Skip the next index. - } - (Some(a), _) => { - // Single index, receiving the neighbor hash. - let hash = verifier_state.prover_hint()?; - if a & 1 == 0 { - input_hashes.push(hashes_iter.next().unwrap()); - input_hashes.push(hash); - } else { - input_hashes.push(hash); - input_hashes.push(hashes_iter.next().unwrap()); + match indices_iter.next() { + Some(a) => { + let parent = a / arity; + let group_start = parent * arity; + + // Collect known hashes for this sibling group. + let mut group = [None; MAX_ARITY]; + group[a - group_start] = Some(hashes_iter.next().unwrap()); + while let Some(&b) = indices_iter.peek() { + if b / arity == parent { + group[b - group_start] = Some(hashes_iter.next().unwrap()); + indices_iter.next(); + } else { + break; + } + } + + // Push all `arity` children in order, reading hints for + // missing positions. + for slot in &group[..arity] { + match slot { + Some(h) => input_hashes.push(*h), + None => input_hashes.push(verifier_state.prover_hint()?), + } } - next_indices.push(a >> 1); + next_indices.push(parent); } - (None, _) => break, + None => break, } } // Compute next layer hashes - next_hashes.resize(input_hashes.len() / 2, Hash::default()); + let hash_size = node_hash_size(arity); + next_hashes.resize(input_hashes.len() / arity, Hash::default()); ENGINES - .retrieve(layer.hash_id) + .retrieve(layer_config.hash_id) .ok_or(VerificationError)? - .hash_many(64, input_hashes.as_bytes(), &mut next_hashes); + .hash_many(hash_size, input_hashes.as_bytes(), &mut next_hashes); swap(&mut indices, &mut next_indices); swap(&mut hashes, &mut next_hashes); } @@ -278,10 +349,45 @@ impl Witness { } } -pub const fn layers_for_size(size: usize) -> usize { - size.next_power_of_two().ilog2() as usize +/// Compute the optimal mixed-arity layer sequence for a tree with `size` leaves. +/// +/// Returns a vector of arities (root-to-bottom) matching the NTT domain +/// decomposition of `size`. +/// +/// `size` must be a smooth-{2,3,13} number (`2^a * 3^b * 13^c`). The tree +/// uses `a` binary layers on top, `b` ternary layers in the middle, and `c` +/// arity-13 layers on the bottom. This deterministic layout lets +/// circuit-based verifiers (e.g. Gnark) hardcode the arity sequence at +/// compile time. +/// +/// # Panics +/// +/// Panics if `size` is not a smooth-{2,3,13} number. +pub fn layers_for_size(size: usize) -> Vec { + if size <= 1 { + return Vec::new(); + } + + let (a, b, c) = crate::smooth_domain::decompose(size); + + // Root-to-bottom: binary layers first, then ternary, then arity-13. + let mut arities = Vec::with_capacity(a + b + c); + for _ in 0..a { + arities.push(2); + } + for _ in 0..b { + arities.push(3); + } + for _ in 0..c { + arities.push(13); + } + arities } +/// Maximum node hash size across all supported arities (arity 3 → 96 bytes). +/// Used for hash engine compatibility checks. +pub const MAX_NODE_HASH_SIZE: usize = node_hash_size(MAX_ARITY); + #[cfg(not(feature = "parallel"))] fn parallel_hash(engine: &dyn HashEngine, size: usize, input: &[u8], output: &mut [Hash]) { engine.hash_many(size, input, output); @@ -291,9 +397,11 @@ fn parallel_hash(engine: &dyn HashEngine, size: usize, input: &[u8], output: &mu fn parallel_hash(engine: &dyn HashEngine, size: usize, input: &[u8], output: &mut [Hash]) { use crate::utils::workload_size; assert_eq!(input.len(), size * output.len()); - if input.len() > workload_size::() && input.len() / size >= 2 { - let (input_a, input_b) = input.split_at(input.len() / 2); - let (output_a, output_b) = output.split_at_mut(output.len() / 2); + if input.len() > workload_size::() && output.len() >= 2 { + // Split on output count so input bytes stay aligned to `size`. + let split = output.len() / 2; + let (input_a, input_b) = input.split_at(split * size); + let (output_a, output_b) = output.split_at_mut(split); rayon::join( || parallel_hash(engine, size, input_a, output_a), || parallel_hash(engine, size, input_b, output_b), @@ -314,23 +422,47 @@ pub(crate) mod tests { }; pub fn config(num_leaves: usize) -> impl Strategy { - let min_layers = layers_for_size(num_leaves); - // Add up to three unnecessary layers - let num_layers = min_layers..=min_layers + 3; - // Each layer gets its own choice of hash function - let layer = hash_for_size(64).prop_map(|hash_id| LayerConfig { hash_id }); - vec(layer, num_layers).prop_map(move |layers| Config { num_leaves, layers }) + let base_arities = layers_for_size(num_leaves); + let min_layers = base_arities.len(); + // Add up to three unnecessary extra binary layers at the front. + let num_extra = 0..=3usize; + num_extra.prop_flat_map(move |extra| { + // Arities: extra binary layers at the front, then base arities. + let mut arities = Vec::with_capacity(min_layers + extra); + for _ in 0..extra { + arities.push(2usize); + } + arities.extend(&base_arities); + + // For each layer, pick a hash function that supports its node_hash_size. + // Use MAX_NODE_HASH_SIZE so we can use any hash engine for all layers. + let layer_count = arities.len(); + let arities_for_map = arities.clone(); + vec(hash_for_size(MAX_NODE_HASH_SIZE), layer_count..=layer_count).prop_map( + move |hash_ids| { + let layers = hash_ids + .into_iter() + .zip(arities_for_map.iter()) + .map(|(hash_id, &arity)| LayerConfig { hash_id, arity }) + .collect(); + Config { num_leaves, layers } + }, + ) + }) } - #[test] - fn test_merkle_tree() { + fn run_merkle_tree_test(num_leaves: usize, arities: &[usize], indices: &[usize]) { crate::tests::init(); - let config = Config { - num_leaves: 256, - layers: vec![LayerConfig { hash_id: BLAKE3 }; 8], - }; - - let leaves = (0..config.num_leaves) + let layers: Vec = arities + .iter() + .map(|&arity| LayerConfig { + hash_id: BLAKE3, + arity, + }) + .collect(); + let config = Config { num_leaves, layers }; + + let leaves = (0..num_leaves) .map(|i| Hash([i as u8; 32])) .collect::>(); @@ -341,32 +473,178 @@ pub(crate) mod tests { // Prover let mut prover_state = ProverState::new_std(&ds); let tree = config.commit(&mut prover_state, leaves); - config.open(&mut prover_state, &tree, &[13, 42]); + config.open(&mut prover_state, &tree, indices); let proof = prover_state.proof(); // Verifier let mut verifier_state = VerifierState::new_std(&ds, &proof); let root = config.receive_commitment(&mut verifier_state).unwrap(); + let leaf_hashes: Vec = indices.iter().map(|&i| Hash([i as u8; 32])).collect(); config - .verify( - &mut verifier_state, - &root, - &[13, 42], - &[Hash([13; 32]), Hash([42; 32])], - ) + .verify(&mut verifier_state, &root, indices, &leaf_hashes) .unwrap(); } + #[test] + fn test_merkle_tree_binary() { + // Pure binary tree: 2^3 = 8 leaves. + run_merkle_tree_test(8, &[2, 2, 2], &[1, 5, 7]); + } + + #[test] + fn test_merkle_tree_ternary() { + // Pure ternary tree: 3^3 = 27 leaves. + run_merkle_tree_test(27, &[3, 3, 3], &[5, 13, 26]); + } + + #[test] + fn test_merkle_tree_mixed_2a_3b() { + // Mixed: 2^2 * 3^1 = 12 leaves. Binary top, ternary bottom. + run_merkle_tree_test(12, &[2, 2, 3], &[0, 5, 11]); + } + + #[test] + fn test_merkle_tree_mixed_large() { + // Mixed: 2^3 * 3^2 = 72 leaves. + run_merkle_tree_test(72, &[2, 2, 2, 3, 3], &[0, 13, 42, 71]); + } + + #[test] + fn test_merkle_tree_arity_13() { + // Pure arity-13: 13 leaves. + run_merkle_tree_test(13, &[13], &[0, 6, 12]); + // Mixed: 2 * 13 = 26 leaves. + run_merkle_tree_test(26, &[2, 13], &[0, 12, 25]); + // Mixed: 2 * 3 * 13 = 78 leaves. + run_merkle_tree_test(78, &[2, 3, 13], &[0, 39, 77]); + } + + #[test] + fn test_merkle_tree_smooth_domain_sizes() { + // Verify trees work for all smooth sizes with exact capacity. + for &size in &[ + 2, 3, 4, 6, 8, 9, 12, 13, 16, 18, 24, 26, 36, 39, 48, 72, 78, 104, 117, + ] { + let arities = layers_for_size(size); + assert_eq!(arities.iter().product::(), size); + run_merkle_tree_test(size, &arities, &[0, size / 2, size - 1]); + } + } + + #[test] + fn test_merkle_tree_single_leaf() { + run_merkle_tree_test(1, &[], &[0]); + } + #[test] fn test_layers_for_size() { - assert_eq!(layers_for_size(0), 0); - assert_eq!(layers_for_size(1), 0); - assert_eq!(layers_for_size(2), 1); - assert_eq!(layers_for_size(3), 2); - assert_eq!(layers_for_size(4), 2); - assert_eq!(layers_for_size(5), 3); - assert_eq!(layers_for_size(6), 3); - assert_eq!(layers_for_size(7), 3); - assert_eq!(layers_for_size(8), 3); + // Edge cases. + let empty: Vec = vec![]; + assert_eq!(layers_for_size(0), empty); + assert_eq!(layers_for_size(1), empty); + + // Pure powers of 2. + assert_eq!(layers_for_size(2), vec![2]); + assert_eq!(layers_for_size(4), vec![2, 2]); + assert_eq!(layers_for_size(8), vec![2, 2, 2]); + assert_eq!(layers_for_size(16), vec![2, 2, 2, 2]); + + // Pure powers of 3. + assert_eq!(layers_for_size(3), vec![3]); + assert_eq!(layers_for_size(9), vec![3, 3]); + + // Mixed: binary on top, ternary on bottom. + assert_eq!(layers_for_size(6), vec![2, 3]); // 2 * 3 + assert_eq!(layers_for_size(12), vec![2, 2, 3]); // 4 * 3 + assert_eq!(layers_for_size(18), vec![2, 3, 3]); // 2 * 9 + assert_eq!(layers_for_size(36), vec![2, 2, 3, 3]); // 4 * 9 + assert_eq!(layers_for_size(72), vec![2, 2, 2, 3, 3]); // 8 * 9 + + // Arity-13 sizes. + assert_eq!(layers_for_size(13), vec![13]); // 13 + assert_eq!(layers_for_size(26), vec![2, 13]); // 2 * 13 + assert_eq!(layers_for_size(39), vec![3, 13]); // 3 * 13 + assert_eq!(layers_for_size(78), vec![2, 3, 13]); // 2 * 3 * 13 + assert_eq!(layers_for_size(117), vec![3, 3, 13]); // 9 * 13 + assert_eq!(layers_for_size(104), vec![2, 2, 2, 13]); // 8 * 13 + + // Capacity equals size exactly (no over-provisioning). + for &size in &[ + 2, 3, 4, 6, 8, 9, 12, 13, 16, 18, 24, 26, 36, 39, 48, 72, 78, 96, 104, 117, 144, + ] { + let arities = layers_for_size(size); + let cap: usize = arities.iter().product(); + assert_eq!(cap, size, "size={size}, cap={cap}, arities={arities:?}"); + } + } + + #[test] + #[should_panic(expected = "not a smooth-{2,3,13} number")] + fn test_layers_for_size_rejects_non_smooth() { + layers_for_size(5); + } + + #[test] + fn test_num_nodes() { + // 0 layers: just the root = 1 node + let c0 = Config { + num_leaves: 1, + layers: vec![], + }; + assert_eq!(c0.num_nodes(), 1); + + // Binary: 2 layers → 4 leaves + 2 internal + 1 root = 7 + let c_bin = Config { + num_leaves: 4, + layers: vec![ + LayerConfig { + hash_id: BLAKE3, + arity: 2, + }, + LayerConfig { + hash_id: BLAKE3, + arity: 2, + }, + ], + }; + assert_eq!(c_bin.num_nodes(), 7); + + // Ternary: 1 layer → 3 leaves + 1 root = 4 + let c_ter = Config { + num_leaves: 3, + layers: vec![LayerConfig { + hash_id: BLAKE3, + arity: 3, + }], + }; + assert_eq!(c_ter.num_nodes(), 4); + + // Mixed: [2, 3] → leaf capacity = 6. Nodes: 6 + 2 + 1 = 9 + let c_mix = Config { + num_leaves: 6, + layers: vec![ + LayerConfig { + hash_id: BLAKE3, + arity: 2, + }, + LayerConfig { + hash_id: BLAKE3, + arity: 3, + }, + ], + }; + assert_eq!(c_mix.num_nodes(), 9); + } + + #[test] + fn test_merkle_tree_243_ternary() { + // 243 = 3^5, all ternary like the original test. + run_merkle_tree_test(243, &[3, 3, 3, 3, 3], &[13, 42]); + } + + #[test] + fn test_merkle_tree_256_mixed() { + // 256 = 2^8, all binary. + run_merkle_tree_test(256, &[2, 2, 2, 2, 2, 2, 2, 2], &[13, 42, 255]); } } diff --git a/src/protocols/sumcheck.rs b/src/protocols/sumcheck.rs index bc3dd192..1633cb5b 100644 --- a/src/protocols/sumcheck.rs +++ b/src/protocols/sumcheck.rs @@ -39,11 +39,11 @@ where impl Config { pub fn validate(&self) -> Result<(), &'static str> { ensure!( - self.initial_size.is_power_of_two(), - "Initial size must be power of two." + self.initial_size > 0 && self.initial_size.is_multiple_of(2) || self.num_rounds == 0, + "Initial size must be even (or num_rounds must be 0)." ); ensure!( - self.initial_size.ilog2() as usize >= self.num_rounds, + self.initial_size >> self.num_rounds >= 1, "Initial size must be >= 2^{rounds}." ); Ok(()) diff --git a/src/protocols/whir/config.rs b/src/protocols/whir/config.rs index ac427b13..04d2c3a1 100644 --- a/src/protocols/whir/config.rs +++ b/src/protocols/whir/config.rs @@ -8,6 +8,7 @@ use crate::{ bits::Bits, parameters::ProtocolParameters, protocols::{irs_commit, proof_of_work, sumcheck}, + smooth_domain::{extra_rounds, is_smooth}, type_info::Type, }; @@ -22,8 +23,8 @@ where M: Default, { assert!( - size.is_power_of_two(), - "Only powers of two size are supported at the moment." + is_smooth(size), + "Size must be a smooth-{{2,3,13}} number (2^a * 3^b * 13^c), got {size}." ); // Proof of work constructor with the requested hash function. @@ -38,7 +39,10 @@ where .saturating_sub(whir_parameters.pow_bits) as f64; let field_size_bits = M::Target::field_size_bits(); let mut log_inv_rate = whir_parameters.starting_log_inv_rate; + // `num_variables` counts binary variables only (the 2^a part). let mut num_variables = size.trailing_zeros() as usize; + // `current_size` tracks the actual polynomial size including 3^b factor. + let mut current_size = size; #[allow(clippy::cast_possible_wrap)] let initial_committer = irs_commit::Config::new( @@ -71,6 +75,7 @@ where let mut in_domain_samples = initial_committer.in_domain_samples; let mut query_error = initial_committer.rbr_queries(); num_variables -= whir_parameters.initial_folding_factor; + current_size >>= whir_parameters.initial_folding_factor; while num_variables >= whir_parameters.folding_factor { // Queries are set w.r.t. to old rate, while the rest to the new rate let round_folding_factor = if round == 0 { @@ -86,7 +91,7 @@ where whir_parameters.unique_decoding, whir_parameters.hash_id, 1, - 1 << num_variables, + current_size, 1 << whir_parameters.folding_factor, 0.5_f64.powi(next_rate as i32), ); @@ -109,7 +114,7 @@ where irs_committer, sumcheck: sumcheck::Config { field: Type::new(), - initial_size: 1 << num_variables, + initial_size: current_size, round_pow: pow(folding_pow_bits), num_rounds: whir_parameters.folding_factor, }, @@ -118,6 +123,7 @@ where round += 1; num_variables -= whir_parameters.folding_factor; + current_size >>= whir_parameters.folding_factor; log_inv_rate = next_rate; in_domain_samples = config.irs_committer.in_domain_samples; query_error = config.irs_committer.rbr_queries(); @@ -132,6 +138,13 @@ where let final_folding_pow_bits = 0_f64.max(security_level - field_size_bits + 1.0); + // After all binary folds: current_size = 3^b * 2^remaining_binary. + // For the final sumcheck, pad to the next power of two and run + // enough rounds to reduce to a single element. + let extra = extra_rounds(current_size); + let final_padded_size = 1usize << (num_variables + extra); + let final_num_rounds = num_variables + extra; + Self { initial_committer, initial_sumcheck: sumcheck::Config { @@ -144,9 +157,9 @@ where round_configs, final_sumcheck: sumcheck::Config { field: Type::new(), - initial_size: 1 << num_variables, + initial_size: final_padded_size, round_pow: pow(final_folding_pow_bits), - num_rounds: num_variables, + num_rounds: final_num_rounds, }, final_pow: pow(final_pow_bits), } @@ -270,11 +283,24 @@ where self.initial_committer.vector_size } + /// Number of binary sumcheck variables in the initial polynomial. + /// + /// For a smooth-domain size `2^a * 3^b * 13^c`, this returns `a`. The + /// total evaluation-point entries contributed by the initial round is `a` + /// (the `3^b * 13^c` residual is handled in the final sumcheck padding). pub fn initial_num_variables(&self) -> usize { - assert!(self.initial_size().is_power_of_two()); self.initial_size().trailing_zeros() as usize } + /// Total number of evaluation-point entries for the initial polynomial. + /// + /// For a smooth-domain size `2^a * 3^b * 13^c`, this returns `a + ceil(log2(3^b * 13^c))`. + /// The extra rounds come from the final sumcheck padding the `3^b * 13^c` + /// residual to a power of two. + pub fn total_eval_variables(&self) -> usize { + self.initial_num_variables() + extra_rounds(self.initial_size()) + } + pub const fn final_size(&self) -> usize { self.final_sumcheck.final_size() } @@ -447,14 +473,23 @@ impl RoundConfig { self.sumcheck.final_size() } + /// Number of binary sumcheck variables for this round. + /// + /// For a smooth-domain size `2^a * 3^b * 13^c`, this returns `a`. pub fn initial_num_variables(&self) -> usize { - assert!(self.irs_committer.vector_size.is_power_of_two()); - self.irs_committer.vector_size.ilog2() as usize + self.irs_committer.vector_size.trailing_zeros() as usize } pub fn final_num_variables(&self) -> usize { self.initial_num_variables() - self.sumcheck.num_rounds } + + /// Total number of evaluation-point entries for this round's polynomial. + /// + /// For a smooth-domain size `2^a * 3^b * 13^c`, returns `a + ceil(log2(3^b * 13^c))`. + pub fn total_eval_variables(&self) -> usize { + self.initial_num_variables() + extra_rounds(self.initial_size()) + } } impl Display for RoundConfig @@ -462,7 +497,7 @@ where F: FftField, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - writeln!(f, " commit {}", self.irs_committer,)?; + writeln!(f, " commit {}", self.irs_committer)?; writeln!(f, " pow {:.2} bits", self.pow.difficulty())?; writeln!(f, " sumcheck {}", self.sumcheck) } diff --git a/src/protocols/whir/mod.rs b/src/protocols/whir/mod.rs index dd808518..1219a7cc 100644 --- a/src/protocols/whir/mod.rs +++ b/src/protocols/whir/mod.rs @@ -66,6 +66,13 @@ pub struct FinalClaim { /// Claimed value of the rlc of the mle of the linears forms in the point. /// Note: not computed on the prover side, set to zero instead. pub linear_form_rlc: F, + /// Number of binary (halving) folds before the final-sumcheck pad. + /// + /// For power-of-2 sizes this equals `evaluation_point.len()` (all folds + /// are binary). For smooth-domain sizes `2^a * 3^b * 13^c` this equals + /// the number of sumcheck rounds that halve even-length vectors, before + /// the final sumcheck pads the odd residual to a power of two. + pub binary_folds: usize, } impl FinalClaim { @@ -74,13 +81,60 @@ impl FinalClaim { linear_forms: impl IntoIterator>, ) -> VerificationResult<()> { let rlc = zip_strict(&self.rlc_coefficients, linear_forms) - .map(|(&c, l)| c * l.mle_evaluate(&self.evaluation_point)) + .map(|(&c, l)| { + c * fold_based_mle_evaluate(l, &self.evaluation_point, self.binary_folds) + }) .sum::(); verify!(rlc == self.linear_form_rlc); Ok(()) } } +/// Evaluate a linear form's MLE using the fold-based semantics that match +/// the sumcheck prover's actual fold sequence. +/// +/// For power-of-2 sizes this delegates to the fast `mle_evaluate`. For +/// smooth-domain sizes (where `N/2 ≠ 2^{n-1}`), the function materialises the +/// weight vector, performs `binary_folds` halvings (matching the sumcheck's +/// `fold` at `len/2`), pads the residual to a power of two, then finishes +/// with a standard `multilinear_extend`. +pub fn fold_based_mle_evaluate( + lf: &dyn LinearForm, + point: &[F], + binary_folds: usize, +) -> F { + use crate::algebra::multilinear_extend; + + let size = lf.size(); + + // Fast path: power-of-2 sizes have fold ≡ standard MLE. + if size.is_power_of_two() { + return lf.mle_evaluate(point); + } + + // Materialise the weight vector. + let mut w = vec![F::ZERO; size]; + lf.accumulate(&mut w, F::ONE); + + // Binary halvings (matching sumcheck's fold at len/2). + for j in 0..binary_folds { + let half = w.len() / 2; + let r = point[j]; + for i in 0..half { + let delta = (w[i + half] - w[i]) * r; + w[i] += delta; + } + w.truncate(half); + } + + // Pad residual to next power of two. + let padded = w.len().next_power_of_two(); + w.resize(padded, F::ZERO); + + // Standard MLE for remaining variables. + multilinear_extend(&w, &point[binary_folds..]) +} + impl Config where M: Embedding, diff --git a/src/protocols/whir/prover.rs b/src/protocols/whir/prover.rs index ab08ac1c..1c6bfda0 100644 --- a/src/protocols/whir/prover.rs +++ b/src/protocols/whir/prover.rs @@ -282,8 +282,18 @@ where vector_rlc_coeffs = vec![M::Target::ONE]; } + // Pad vector and covector to the final sumcheck's initial_size. + // For smooth-domain polynomials (2^a * 3^b * 13^c), the residual 3^b * 13^c elements + // are zero-padded to the next power of two before the final sumcheck. + let final_padded_size = self.final_sumcheck.initial_size; + vector.resize(final_padded_size, M::Target::ZERO); + if has_constraints { + covector.resize(final_padded_size, M::Target::ZERO); + } + debug_assert!(!has_constraints || dot(&vector, &covector) == the_sum); + // Directly send the vector to the verifier. - assert_eq!(vector.len(), self.final_sumcheck.initial_size); + assert_eq!(vector.len(), final_padded_size); for coeff in &vector { prover_state.prover_message(coeff); } @@ -311,10 +321,12 @@ where .prove(prover_state, &mut vector, &mut covector, &mut the_sum); evaluation_point.extend(final_folding_randomness.0.iter().copied()); + let binary_folds = evaluation_point.len() - self.final_sumcheck.num_rounds; FinalClaim { evaluation_point, rlc_coefficients: initial_forms_rlc_coeffs.to_vec(), linear_form_rlc: M::Target::ZERO, + binary_folds, } } } diff --git a/src/protocols/whir/verifier.rs b/src/protocols/whir/verifier.rs index cf6501bd..3707000b 100644 --- a/src/protocols/whir/verifier.rs +++ b/src/protocols/whir/verifier.rs @@ -2,12 +2,12 @@ use ark_ff::{AdditiveGroup, FftField, Field}; #[cfg(feature = "tracing")] use tracing::instrument; -use super::{Commitment, Config}; +use super::{fold_based_mle_evaluate, Commitment, Config}; use crate::{ algebra::{ dot, embedding::{Embedding, Identity}, - linear_form::{Evaluate, LinearForm, MultilinearExtension}, + linear_form::{Evaluate, MultilinearExtension}, tensor_product, MultilinearPoint, }, hash::Hash, @@ -238,23 +238,30 @@ where .evaluate(&Identity::new(), &final_vector); let mut linear_form_rlc = the_sum / poly_eval; + let final_num_rounds = self.final_sumcheck.num_rounds; + // Subtract all internal linear forms. for (round, (weights_rlc_coeffs, weights)) in round_constraints.into_iter().enumerate() { let num_variables = round.checked_sub(1).map_or_else( - || self.initial_num_variables(), - |p| self.round_configs[p].initial_num_variables(), + || self.total_eval_variables(), + |p| self.round_configs[p].total_eval_variables(), ); let start = evaluation_point.len().saturating_sub(num_variables); + let binary_folds = num_variables - final_num_rounds; for (rlc_coeff, weights) in zip_strict(weights_rlc_coeffs, weights) { - linear_form_rlc -= rlc_coeff * weights.mle_evaluate(&evaluation_point[start..]); + linear_form_rlc -= rlc_coeff + * fold_based_mle_evaluate(&weights, &evaluation_point[start..], binary_folds); } } + let binary_folds_initial = evaluation_point.len() - final_num_rounds; + // Return the evaluation point and the claimed values of the deferred weights. Ok(FinalClaim { evaluation_point, rlc_coefficients: initial_form_rlc_coeffs.to_vec(), linear_form_rlc, + binary_folds: binary_folds_initial, }) } } diff --git a/src/protocols/whir_zk/utils.rs b/src/protocols/whir_zk/utils.rs index 779150cf..6fde60b2 100644 --- a/src/protocols/whir_zk/utils.rs +++ b/src/protocols/whir_zk/utils.rs @@ -227,7 +227,15 @@ pub fn fold_weight_to_mask_size( owned = Covector::from(weight); &owned.vector }; - debug_assert_eq!(vector.len(), 1usize << num_witness_variables); + // For smooth-domain sizes (2^a * 3^b * 13^c), the vector length equals the full + // polynomial size which is >= 2^num_witness_variables. + debug_assert!( + vector.len() >= 1usize << num_witness_variables, + "weight vector too small: {} < 2^{} = {}", + vector.len(), + num_witness_variables, + 1usize << num_witness_variables + ); fold_vector_to_mask_size(vector, mask_size) } diff --git a/src/protocols/whir_zk/verifier.rs b/src/protocols/whir_zk/verifier.rs index 13fe76fa..b5310c41 100644 --- a/src/protocols/whir_zk/verifier.rs +++ b/src/protocols/whir_zk/verifier.rs @@ -113,9 +113,10 @@ impl Config { .map(|(&eval, &m_eval)| eval + m_eval) .collect(); - self.blinded_commitment - .verify(verifier_state, &commitments, &modified_evaluations)? - .verify(weights.iter().copied())?; + let final_claim = + self.blinded_commitment + .verify(verifier_state, &commitments, &modified_evaluations)?; + final_claim.verify(weights.iter().copied())?; verify!(batched_h_claims == expected_batched_h_claims); let g_hat_slices: Vec<&[F]> = g_hat_claims_per_poly.iter().map(Vec::as_slice).collect(); diff --git a/src/smooth_domain.rs b/src/smooth_domain.rs new file mode 100644 index 00000000..d8852981 --- /dev/null +++ b/src/smooth_domain.rs @@ -0,0 +1,290 @@ +//! Utilities for smooth-{2,3,13} numbers: values of the form `2^a * 3^b * 13^c`. +//! +//! BN254 Fr has `r-1 = 2^28 * 3^2 * 13 * …`, so NTT domains can be +//! `2^a * 3^b * 13^c` with `a ≤ 28`, `b ≤ 2`, and `c ≤ 1`. +//! Using smooth sizes instead of strict powers of two lets us pad closer +//! to the actual constraint count. + +/// Returns `true` if `n` is a smooth-{2,3,13} number, i.e. `n = 2^a * 3^b * 13^c`. +/// +/// `0` is NOT smooth. +#[inline] +pub const fn is_smooth(n: usize) -> bool { + if n == 0 { + return false; + } + let mut v = n; + while v % 2 == 0 { + v /= 2; + } + while v % 3 == 0 { + v /= 3; + } + while v % 13 == 0 { + v /= 13; + } + v == 1 +} + +/// Returns the odd part of `n`: after removing all factors of 2, +/// the remaining value `3^b * 13^c`. +#[inline] +pub const fn odd_part(n: usize) -> usize { + assert!(n > 0); + let mut v = n; + while v % 2 == 0 { + v /= 2; + } + v +} + +/// Number of extra sumcheck rounds needed for the odd residual: +/// `ceil(log2(odd_part))` where `odd_part = 3^b * 13^c`. +#[inline] +pub const fn extra_rounds(n: usize) -> usize { + let odd = odd_part(n); + if odd == 1 { + 0 + } else { + // ceil(log2(odd)) = 64 - (odd - 1).leading_zeros() + (usize::BITS - (odd - 1).leading_zeros()) as usize + } +} + +/// Next power of two that is ≥ the odd part. +/// Equivalently `1 << extra_rounds(n)`. +#[inline] +pub const fn odd_part_padded(n: usize) -> usize { + 1 << extra_rounds(n) +} + +/// Padded size: the power-of-two size that the final sumcheck operates on +/// for a smooth polynomial of size `n`. +/// +/// Equal to `2^(a + ceil(log2(odd)))` where `n = 2^a * odd`. +#[inline] +pub const fn padded_size(n: usize) -> usize { + let a = n.trailing_zeros() as usize; + let extra = extra_rounds(n); + 1 << (a + extra) +} + +/// Find the smallest smooth-{2,3,13} number ≥ `n`. +/// +/// Enumerates all `2^a * 3^b * 13^c` with `b ≤ 2`, `c ≤ 1`. +pub const fn next_smooth(n: usize) -> usize { + if n <= 1 { + return n; + } + let mut best = n.next_power_of_two(); // worst case: pure power of 2 + + // Try each combination of 3^b * 13^c + let mut c: usize = 0; + while c <= 1 { + let pow13 = if c == 0 { 1 } else { 13 }; + let mut pow3: usize = 1; + let mut b: usize = 0; + while b <= 2 { + let odd = pow3 * pow13; + // Smallest 2^a such that 2^a * odd >= n + let needed = (n + odd - 1) / odd; // ceil(n / odd) + let pow2 = needed.next_power_of_two(); + let candidate = pow2 * odd; + if candidate < best { + best = candidate; + } + pow3 *= 3; + b += 1; + } + c += 1; + } + best +} + +/// Decompose a smooth-{2,3,13} number into `(a, b, c)` where `n = 2^a * 3^b * 13^c`. +/// +/// Panics if `n` is not smooth-{2,3,13}. +#[inline] +pub const fn decompose(n: usize) -> (usize, usize, usize) { + assert!(n > 0); + let a = n.trailing_zeros() as usize; + let mut odd = n >> a; + let mut b = 0; + while odd % 3 == 0 { + odd /= 3; + b += 1; + } + let mut c = 0; + while odd % 13 == 0 { + odd /= 13; + c += 1; + } + assert!(odd == 1, "not a smooth-{{2,3,13}} number"); + (a, b, c) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_is_smooth() { + assert!(!is_smooth(0)); + assert!(is_smooth(1)); + assert!(is_smooth(2)); + assert!(is_smooth(3)); + assert!(is_smooth(4)); + assert!(!is_smooth(5)); + assert!(is_smooth(6)); + assert!(!is_smooth(7)); + assert!(is_smooth(8)); + assert!(is_smooth(9)); + assert!(!is_smooth(10)); + assert!(is_smooth(12)); + assert!(is_smooth(13)); // 13^1 + assert!(is_smooth(16)); + assert!(is_smooth(18)); + assert!(is_smooth(24)); + assert!(is_smooth(26)); // 2 * 13 + assert!(is_smooth(39)); // 3 * 13 + assert!(is_smooth(78)); // 2 * 3 * 13 + assert!(is_smooth(117)); // 9 * 13 + assert!(!is_smooth(15)); + assert!(!is_smooth(20)); + assert!(is_smooth(1 << 28)); + assert!(is_smooth(9 * (1 << 18))); + assert!(is_smooth(13 * (1 << 10))); + assert!(is_smooth(117 * (1 << 5))); // 9 * 13 * 32 + } + + #[test] + fn test_odd_part() { + assert_eq!(odd_part(1), 1); + assert_eq!(odd_part(2), 1); + assert_eq!(odd_part(4), 1); + assert_eq!(odd_part(8), 1); + assert_eq!(odd_part(3), 3); + assert_eq!(odd_part(6), 3); + assert_eq!(odd_part(12), 3); + assert_eq!(odd_part(9), 9); + assert_eq!(odd_part(18), 9); + assert_eq!(odd_part(36), 9); + assert_eq!(odd_part(9 * (1 << 18)), 9); + assert_eq!(odd_part(13), 13); + assert_eq!(odd_part(26), 13); + assert_eq!(odd_part(39), 39); + assert_eq!(odd_part(78), 39); + assert_eq!(odd_part(117), 117); + assert_eq!(odd_part(117 * 4), 117); + } + + #[test] + fn test_extra_rounds() { + // odd part = 1: no extra rounds + assert_eq!(extra_rounds(1), 0); + assert_eq!(extra_rounds(2), 0); + assert_eq!(extra_rounds(1024), 0); + + // odd part = 3: ceil(log2(3)) = 2 + assert_eq!(extra_rounds(3), 2); + assert_eq!(extra_rounds(6), 2); + + // odd part = 9: ceil(log2(9)) = 4 + assert_eq!(extra_rounds(9), 4); + assert_eq!(extra_rounds(18), 4); + + // odd part = 13: ceil(log2(13)) = 4 + assert_eq!(extra_rounds(13), 4); + assert_eq!(extra_rounds(26), 4); + + // odd part = 39 = 3*13: ceil(log2(39)) = 6 + assert_eq!(extra_rounds(39), 6); + assert_eq!(extra_rounds(78), 6); + + // odd part = 117 = 9*13: ceil(log2(117)) = 7 + assert_eq!(extra_rounds(117), 7); + assert_eq!(extra_rounds(234), 7); + } + + #[test] + fn test_odd_part_padded() { + assert_eq!(odd_part_padded(1), 1); + assert_eq!(odd_part_padded(8), 1); + assert_eq!(odd_part_padded(3), 4); + assert_eq!(odd_part_padded(6), 4); + assert_eq!(odd_part_padded(9), 16); + assert_eq!(odd_part_padded(18), 16); + assert_eq!(odd_part_padded(13), 16); // ceil(log2(13)) = 4, 2^4 = 16 + assert_eq!(odd_part_padded(26), 16); + assert_eq!(odd_part_padded(39), 64); // ceil(log2(39)) = 6, 2^6 = 64 + assert_eq!(odd_part_padded(117), 128); // ceil(log2(117)) = 7, 2^7 = 128 + } + + #[test] + fn test_padded_size() { + assert_eq!(padded_size(1), 1); + assert_eq!(padded_size(8), 8); + assert_eq!(padded_size(1024), 1024); + // 3 * 2^2 → padded = 2^(2+2) = 16 + assert_eq!(padded_size(12), 16); + // 9 * 2^1 → padded = 2^(1+4) = 32 + assert_eq!(padded_size(18), 32); + // 13 * 2^1 → padded = 2^(1+4) = 32 + assert_eq!(padded_size(26), 32); + // 39 * 2^1 → padded = 2^(1+6) = 128 + assert_eq!(padded_size(78), 128); + // 117 * 2^2 → padded = 2^(2+7) = 512 + assert_eq!(padded_size(468), 512); + } + + #[test] + fn test_next_smooth() { + assert_eq!(next_smooth(0), 0); + assert_eq!(next_smooth(1), 1); + assert_eq!(next_smooth(2), 2); + assert_eq!(next_smooth(3), 3); + assert_eq!(next_smooth(4), 4); + assert_eq!(next_smooth(5), 6); + assert_eq!(next_smooth(6), 6); + assert_eq!(next_smooth(7), 8); + assert_eq!(next_smooth(8), 8); + assert_eq!(next_smooth(9), 9); + assert_eq!(next_smooth(10), 12); + assert_eq!(next_smooth(11), 12); + assert_eq!(next_smooth(12), 12); + assert_eq!(next_smooth(13), 13); // now 13 is smooth! + assert_eq!(next_smooth(14), 16); + assert_eq!(next_smooth(17), 18); + assert_eq!(next_smooth(18), 18); + assert_eq!(next_smooth(19), 24); + } + + #[test] + fn test_decompose() { + assert_eq!(decompose(1), (0, 0, 0)); + assert_eq!(decompose(2), (1, 0, 0)); + assert_eq!(decompose(3), (0, 1, 0)); + assert_eq!(decompose(4), (2, 0, 0)); + assert_eq!(decompose(6), (1, 1, 0)); + assert_eq!(decompose(8), (3, 0, 0)); + assert_eq!(decompose(9), (0, 2, 0)); + assert_eq!(decompose(12), (2, 1, 0)); + assert_eq!(decompose(13), (0, 0, 1)); + assert_eq!(decompose(18), (1, 2, 0)); + assert_eq!(decompose(26), (1, 0, 1)); + assert_eq!(decompose(39), (0, 1, 1)); + assert_eq!(decompose(78), (1, 1, 1)); + assert_eq!(decompose(117), (0, 2, 1)); + assert_eq!(decompose(9 * (1 << 18)), (18, 2, 0)); + assert_eq!(decompose(13 * (1 << 10)), (10, 0, 1)); + } + + #[test] + fn test_next_smooth_all_results_are_smooth() { + for n in 1..=10_000 { + let s = next_smooth(n); + assert!(s >= n, "next_smooth({n}) = {s} < {n}"); + assert!(is_smooth(s), "next_smooth({n}) = {s} is not smooth"); + } + } +} diff --git a/src/utils.rs b/src/utils.rs index f37b6145..03beacb7 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -22,15 +22,21 @@ pub const fn workload_size() -> usize { #[cfg(all(target_arch = "aarch64", target_os = "macos"))] const CACHE_SIZE: usize = 1 << 17; // 128KB for Apple Silicon - #[cfg(all(target_arch = "aarch64", any(target_os = "ios", target_os = "android")))] - const CACHE_SIZE: usize = 1 << 16; // 64KB for mobile ARM + #[cfg(all( + target_arch = "aarch64", + any(target_os = "ios", target_os = "android", target_os = "linux") + ))] + const CACHE_SIZE: usize = 1 << 16; // 64KB for mobile/server ARM #[cfg(target_arch = "x86_64")] const CACHE_SIZE: usize = 1 << 15; // 32KB for x86-64 #[cfg(not(any( all(target_arch = "aarch64", target_os = "macos"), - all(target_arch = "aarch64", any(target_os = "ios", target_os = "android")), + all( + target_arch = "aarch64", + any(target_os = "ios", target_os = "android", target_os = "linux") + ), target_arch = "x86_64" )))] const CACHE_SIZE: usize = 1 << 15; // 32KB default