diff --git a/Cargo.toml b/Cargo.toml index 25e759db..656ae418 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -80,11 +80,18 @@ parallel = [ rayon = ["dep:rayon"] asm = ["ark-ff/asm"] tracing = ["dep:tracing"] +# Enable per-step allocation tracking in ZK WHIR. +# Run: cargo run --bin alloc_report --features alloc-track +alloc-track = [] [[bench]] name = "expand_from_coeff" harness = false +[[bench]] +name = "whir_zk" +harness = false + [[bench]] name = "sumcheck" harness = false diff --git a/benches/whir_zk.rs b/benches/whir_zk.rs new file mode 100644 index 00000000..51c42088 --- /dev/null +++ b/benches/whir_zk.rs @@ -0,0 +1,465 @@ +//! Benchmark: ZK v1 vs ZK v2 WHIR proving (2 polynomials). +//! +//! Run with: +//! cargo bench --bench whir_zk +//! +//! Or filter to a specific group: +//! cargo bench --bench whir_zk -- zk_v1 +//! cargo bench --bench whir_zk -- zk_v2 + +use ark_std::rand::{rngs::StdRng, SeedableRng}; +use divan::{black_box, AllocProfiler, Bencher}; +use whir::{ + algebra::{ + fields::{Field64, Field64_2}, + polynomials::{CoefficientList, MultilinearPoint}, + Weights, + }, + hash, + parameters::{FoldingFactor, MultivariateParameters, ProtocolParameters, SoundnessType}, + protocols::{ + whir::Config, + whir_zk::{utils::ZkPreprocessingPolynomials, ZkParams}, + }, + transcript::{codecs::Empty, DomainSeparator, ProverState, VerifierState}, +}; + +#[global_allocator] +static ALLOC: AllocProfiler = AllocProfiler::system(); + +type F = Field64; +type EF = Field64_2; + +/// Polynomial sizes to benchmark (log₂ of number of coefficients). +const SIZES: &[usize] = &[20]; + +/// Number of polynomials for batched benchmarks. +const NUM_POLYS: usize = 2; + +// ──────────────────────────────────────────────────────────────────────────── +// Shared setup helpers +// ──────────────────────────────────────────────────────────────────────────── + +/// Build `n` deterministic polynomials with distinct coefficients. +fn make_polynomials(num_variables: usize, n: usize) -> Vec> { + let num_coeffs = 1usize << num_variables; + (0..n) + .map(|i| { + CoefficientList::new( + (0..num_coeffs) + .map(|j| F::from((i * num_coeffs + j + 1) as u64)) + .collect(), + ) + }) + .collect() +} + +/// Build weights + evaluations for multiple polynomials. +/// Layout: row-major [w₀_p₀, w₀_p₁, …] (one eval per polynomial per weight). +fn make_weights_and_evaluations_multi( + polynomials: &[CoefficientList], + config: &Config, + num_variables: usize, +) -> (Vec>, Vec) { + let mut rng = StdRng::seed_from_u64(0xBEEF); + let point = MultilinearPoint::rand(&mut rng, num_variables); + let mut evaluations = Vec::with_capacity(polynomials.len()); + for poly in polynomials { + evaluations.push(poly.mixed_evaluate(config.embedding(), &point)); + } + let weights = vec![Weights::evaluation(point)]; + (weights, evaluations) +} + +/// Build N independent ZK preprocessing polynomials. +fn make_preprocessings(n: usize, zk_params: &ZkParams) -> Vec> { + let mut rng = StdRng::seed_from_u64(42); + (0..n) + .map(|_| ZkPreprocessingPolynomials::::sample(&mut rng, zk_params.clone())) + .collect() +} + +// ──────────────────────────────────────────────────────────────────────────── +// ZK v1 helpers +// ──────────────────────────────────────────────────────────────────────────── + +/// ZK v1: WHIR config for committing, μ+1 variables. +/// `batch_size` = number of polynomials × 2 (each poly contributes f̂ and g). +fn zk_v1_commit_config(num_variables: usize, num_polynomials: usize) -> Config { + let extended = num_variables + 1; + let mv = MultivariateParameters::new(extended); + let params = ProtocolParameters { + initial_statement: true, + security_level: 32, + pow_bits: 0, + folding_factor: FoldingFactor::ConstantFromSecondRound(4, 4), + soundness_type: SoundnessType::ConjectureList, + starting_log_inv_rate: 1, + batch_size: 2 * num_polynomials, + hash_id: hash::SHA2, + }; + Config::new(mv, ¶ms) +} + +/// ZK v1: WHIR config for proving P₁..Pₙ, μ+1 variables. +/// `batch_size` = number of P polynomials to prove. +fn zk_v1_prove_config(num_variables: usize, num_polynomials: usize) -> Config { + let extended = num_variables + 1; + let mv = MultivariateParameters::new(extended); + let params = ProtocolParameters { + initial_statement: true, + security_level: 32, + pow_bits: 0, + folding_factor: FoldingFactor::ConstantFromSecondRound(4, 4), + soundness_type: SoundnessType::ConjectureList, + starting_log_inv_rate: 1, + batch_size: num_polynomials, + hash_id: hash::SHA2, + }; + Config::new(mv, ¶ms) +} + +/// ZK v1 polynomial bundle: f̂(x,y) = f(x) + y·msk(x), random g(x,y), P = ρ·f̂ + g. +struct ZkV1Polys { + f_hat: CoefficientList, + g_poly: CoefficientList, + p_poly: CoefficientList, +} + +/// Build N v1 polynomial bundles: for each i, f̂_i, g_i, P_i = ρ·f̂_i + g_i. +fn make_zk_v1_polys(num_variables: usize, n: usize) -> Vec { + use ark_std::UniformRand; + + let mut rng = StdRng::seed_from_u64(0xCAFE); + let num_coeffs = 1usize << num_variables; + let extended_num_coeffs = 1usize << (num_variables + 1); + let rho = F::rand(&mut rng); + + (0..n) + .map(|i| { + // Deterministic base polynomial (distinct per i). + let base_coeffs: Vec = (0..num_coeffs) + .map(|j| F::from((i * num_coeffs + j + 1) as u64)) + .collect(); + + // f̂_i(x,y) = base_i(x) + y·msk_i(x) + let mut f_hat_coeffs = vec![F::from(0u64); extended_num_coeffs]; + for (j, &c) in base_coeffs.iter().enumerate() { + f_hat_coeffs[j] = c; + } + for j in 0..num_coeffs { + f_hat_coeffs[num_coeffs + j] = F::rand(&mut rng); + } + let f_hat = CoefficientList::new(f_hat_coeffs); + + // Random g_i(x,y) + let g_coeffs: Vec = (0..extended_num_coeffs) + .map(|_| F::rand(&mut rng)) + .collect(); + let g_poly = CoefficientList::new(g_coeffs); + + // P_i = ρ·f̂_i + g_i + let p_coeffs: Vec = f_hat + .coeffs() + .iter() + .zip(g_poly.coeffs().iter()) + .map(|(&fh, &gv)| rho * fh + gv) + .collect(); + let p_poly = CoefficientList::new(p_coeffs); + + ZkV1Polys { + f_hat, + g_poly, + p_poly, + } + }) + .collect() +} + +/// Build weights + evaluations for multiple (μ+1)-variable P polynomials at (ā, 0). +/// Evaluations layout: row-major [w₀_P₀, w₀_P₁, …]. +fn make_zk_v1_weights_and_evaluations( + p_polys: &[CoefficientList], + config: &Config, + num_variables: usize, +) -> (Vec>, Vec) { + let mut rng = StdRng::seed_from_u64(0xBEEF); + let base_point = MultilinearPoint::rand(&mut rng, num_variables); + let mut coords = base_point.0; + coords.push(EF::from(0u64)); // y = 0 + let extended_point = MultilinearPoint(coords); + let mut evaluations = Vec::with_capacity(p_polys.len()); + for p in p_polys { + evaluations.push(p.mixed_evaluate(config.embedding(), &extended_point)); + } + (vec![Weights::evaluation(extended_point)], evaluations) +} + +// ──────────────────────────────────────────────────────────────────────────── +// ZK v2 helpers +// ──────────────────────────────────────────────────────────────────────────── + +/// ZK v2 main WHIR configuration (round-0 fold = 2 for small k). +fn zk_main_config(num_variables: usize) -> Config { + let mv = MultivariateParameters::new(num_variables); + let params = ProtocolParameters { + initial_statement: true, + security_level: 32, + pow_bits: 0, + folding_factor: FoldingFactor::ConstantFromSecondRound(2, 4), + soundness_type: SoundnessType::ConjectureList, + starting_log_inv_rate: 1, + batch_size: 1, + hash_id: hash::SHA2, + }; + Config::new(mv, ¶ms) +} + +/// ZK v2 helper WHIR configuration, tuned for the given ZK params and number of polynomials. +fn zk_helper_config(zk_params: &ZkParams, num_polynomials: usize) -> Config { + let helper_vars = zk_params.ell + 1; + let mv = MultivariateParameters::new(helper_vars); + let params = ProtocolParameters { + initial_statement: true, + security_level: 32, + pow_bits: 0, + folding_factor: FoldingFactor::Constant(1), + soundness_type: SoundnessType::ConjectureList, + starting_log_inv_rate: 1, + batch_size: zk_params.helper_batch_size(num_polynomials), + hash_id: hash::SHA2, + }; + Config::new(mv, ¶ms) +} + +// ──────────────────────────────────────────────────────────────────────────── +// ZK v1 benchmarks – 2 polynomials (batched) +// ──────────────────────────────────────────────────────────────────────────── + +/// Commit [f̂₁, g₁, f̂₂, g₂] with batch_size=4. +#[divan::bench(args = SIZES)] +fn zk_v1_commit(bencher: Bencher, num_variables: usize) { + let bundles = make_zk_v1_polys(num_variables, NUM_POLYS); + let commit_config = zk_v1_commit_config(num_variables, NUM_POLYS); + let ds = DomainSeparator::protocol(&commit_config) + .session(&format!("bench-zk-v1-commit-{num_variables}")) + .instance(&Empty); + + // Flatten: [f̂₁, g₁, f̂₂, g₂] + let commit_polys: Vec<&CoefficientList> = + bundles.iter().flat_map(|b| [&b.f_hat, &b.g_poly]).collect(); + + bencher + .with_inputs(|| ProverState::new_std(&ds)) + .bench_values(|mut prover_state| { + let _ = black_box(commit_config.commit(&mut prover_state, &commit_polys)); + }); +} + +/// Prove [P₁, P₂] with batch_size=2, μ+1 variables. +#[divan::bench(args = SIZES)] +fn zk_v1_prove(bencher: Bencher, num_variables: usize) { + let bundles = make_zk_v1_polys(num_variables, NUM_POLYS); + let prove_config = zk_v1_prove_config(num_variables, NUM_POLYS); + let p_polys: Vec> = bundles.iter().map(|b| b.p_poly.clone()).collect(); + let (weights, evaluations) = + make_zk_v1_weights_and_evaluations(&p_polys, &prove_config, num_variables); + let weight_refs: Vec<&Weights> = weights.iter().collect(); + + let ds = DomainSeparator::protocol(&prove_config) + .session(&format!("bench-zk-v1-prove-{num_variables}")) + .instance(&Empty); + + let p_refs: Vec<&CoefficientList> = p_polys.iter().collect(); + + bencher + .with_inputs(|| { + let mut prover_state = ProverState::new_std(&ds); + let witness = prove_config.commit(&mut prover_state, &p_refs); + (prover_state, witness) + }) + .bench_values(|(mut prover_state, witness)| { + black_box(prove_config.prove( + &mut prover_state, + &p_refs, + &[&witness], + &weight_refs, + &evaluations, + )); + }); +} + +/// Verify [P₁, P₂] via standard WHIR. +#[divan::bench(args = SIZES)] +fn zk_v1_verify(bencher: Bencher, num_variables: usize) { + let bundles = make_zk_v1_polys(num_variables, NUM_POLYS); + let prove_config = zk_v1_prove_config(num_variables, NUM_POLYS); + let p_polys: Vec> = bundles.iter().map(|b| b.p_poly.clone()).collect(); + let (weights, evaluations) = + make_zk_v1_weights_and_evaluations(&p_polys, &prove_config, num_variables); + let weight_refs: Vec<&Weights> = weights.iter().collect(); + + let ds = DomainSeparator::protocol(&prove_config) + .session(&format!("bench-zk-v1-verify-{num_variables}")) + .instance(&Empty); + + let p_refs: Vec<&CoefficientList> = p_polys.iter().collect(); + + // Generate a proof once. + let proof = { + let mut ps = ProverState::new_std(&ds); + let w = prove_config.commit(&mut ps, &p_refs); + prove_config.prove(&mut ps, &p_refs, &[&w], &weight_refs, &evaluations); + ps.proof() + }; + + bencher + .with_inputs(|| { + let mut vs = VerifierState::new_std(&ds, &proof); + let commitment = prove_config.receive_commitment(&mut vs).unwrap(); + (vs, commitment) + }) + .bench_values(|(mut vs, commitment)| { + black_box( + prove_config + .verify(&mut vs, &[&commitment], &weight_refs, &evaluations) + .unwrap(), + ); + }); +} + +// ──────────────────────────────────────────────────────────────────────────── +// ZK v2 benchmarks – 2 polynomials (batched) +// ──────────────────────────────────────────────────────────────────────────── + +#[divan::bench(args = SIZES)] +fn zk_v2_commit(bencher: Bencher, num_variables: usize) { + let polynomials = make_polynomials(num_variables, NUM_POLYS); + let config = zk_main_config(num_variables); + let zk_params = ZkParams::from_whir_params(&config); + let helper_config = zk_helper_config(&zk_params, NUM_POLYS); + let preprocessings = make_preprocessings(NUM_POLYS, &zk_params); + + let ds = DomainSeparator::protocol(&config) + .session(&format!("bench-zk-v2-commit-{num_variables}")) + .instance(&Empty); + + let preproc_refs: Vec<&ZkPreprocessingPolynomials> = preprocessings.iter().collect(); + + bencher + .with_inputs(|| ProverState::new_std(&ds)) + .bench_values(|mut prover_state| { + let poly_refs: Vec<&CoefficientList> = polynomials.iter().collect(); + black_box(config.commit_zk( + &mut prover_state, + &poly_refs, + &helper_config, + &preproc_refs, + )); + }); +} + +#[divan::bench(args = SIZES)] +fn zk_v2_prove(bencher: Bencher, num_variables: usize) { + let polynomials = make_polynomials(num_variables, NUM_POLYS); + let config = zk_main_config(num_variables); + let zk_params = ZkParams::from_whir_params(&config); + let helper_config = zk_helper_config(&zk_params, NUM_POLYS); + let preprocessings = make_preprocessings(NUM_POLYS, &zk_params); + + let (weights, evaluations) = + make_weights_and_evaluations_multi(&polynomials, &config, num_variables); + let weight_refs: Vec<&Weights> = weights.iter().collect(); + + let ds = DomainSeparator::protocol(&config) + .session(&format!("bench-zk-v2-prove-{num_variables}")) + .instance(&Empty); + + let preproc_refs: Vec<&ZkPreprocessingPolynomials> = preprocessings.iter().collect(); + + bencher + .with_inputs(|| { + let mut prover_state = ProverState::new_std(&ds); + let poly_refs: Vec<&CoefficientList> = polynomials.iter().collect(); + let zk_witness = + config.commit_zk(&mut prover_state, &poly_refs, &helper_config, &preproc_refs); + (prover_state, zk_witness) + }) + .bench_values(|(mut prover_state, zk_witness)| { + let poly_refs: Vec<&CoefficientList> = polynomials.iter().collect(); + black_box(config.prove_zk( + &mut prover_state, + &poly_refs, + &zk_witness, + &helper_config, + &weight_refs, + &evaluations, + )); + }); +} + +#[divan::bench(args = SIZES)] +fn zk_v2_verify(bencher: Bencher, num_variables: usize) { + let polynomials = make_polynomials(num_variables, NUM_POLYS); + let config = zk_main_config(num_variables); + let zk_params = ZkParams::from_whir_params(&config); + let helper_config = zk_helper_config(&zk_params, NUM_POLYS); + let preprocessings = make_preprocessings(NUM_POLYS, &zk_params); + + let (weights, evaluations) = + make_weights_and_evaluations_multi(&polynomials, &config, num_variables); + let weight_refs: Vec<&Weights> = weights.iter().collect(); + + let ds = DomainSeparator::protocol(&config) + .session(&format!("bench-zk-v2-verify-{num_variables}")) + .instance(&Empty); + + let preproc_refs: Vec<&ZkPreprocessingPolynomials> = preprocessings.iter().collect(); + + // Generate a proof once (outside the benchmark loop). + let proof = { + let mut ps = ProverState::new_std(&ds); + let poly_refs: Vec<&CoefficientList> = polynomials.iter().collect(); + let zk_witness = config.commit_zk(&mut ps, &poly_refs, &helper_config, &preproc_refs); + config.prove_zk( + &mut ps, + &poly_refs, + &zk_witness, + &helper_config, + &weight_refs, + &evaluations, + ); + ps.proof() + }; + + bencher + .with_inputs(|| { + let mut vs = VerifierState::new_std(&ds, &proof); + let mut f_hat_commitments = Vec::with_capacity(NUM_POLYS); + for _ in 0..NUM_POLYS { + f_hat_commitments.push(config.receive_commitment(&mut vs).unwrap()); + } + let helper_commitment = helper_config.receive_commitment(&mut vs).unwrap(); + (vs, f_hat_commitments, helper_commitment) + }) + .bench_values(|(mut vs, f_hat_commitments, helper_commitment)| { + let f_hat_refs: Vec<_> = f_hat_commitments.iter().collect(); + black_box( + config + .verify_zk( + &mut vs, + &f_hat_refs, + &helper_commitment, + &helper_config, + &zk_params, + &weight_refs, + &evaluations, + ) + .unwrap(), + ); + }); +} + +fn main() { + divan::main(); +} diff --git a/src/algebra/mod.rs b/src/algebra/mod.rs index 898826f6..c8686317 100644 --- a/src/algebra/mod.rs +++ b/src/algebra/mod.rs @@ -6,7 +6,7 @@ pub mod polynomials; pub mod sumcheck; mod weights; -use ark_ff::{AdditiveGroup, Field}; +use ark_ff::{AdditiveGroup, FftField, Field}; #[cfg(feature = "parallel")] use rayon::prelude::*; @@ -54,6 +54,13 @@ pub fn lift(embedding: &M, source: &[M::Source]) -> Vec result } +/// Scalar-mul add (same-field AXPY) +/// +/// accumulator[i] += weight * vector[i] +pub fn scalar_mul_add(accumulator: &mut [F], weight: F, vector: &[F]) { + mixed_scalar_mul_add(&embedding::Identity::new(), accumulator, weight, vector); +} + /// Mixed scalar-mul add /// /// accumulator[i] += weight * vector[i] @@ -119,3 +126,73 @@ pub fn mixed_dot( result } + +/// Project an extension field element to its base prime field component. +/// +/// Panics if the element does not lie in the base prime subfield. +#[inline] +pub fn project_to_base(val: F) -> F::BasePrimeField { + val.to_base_prime_field_elements() + .next() + .expect("element should lie in base prime subfield") +} + +/// Project every element of an extension-field slice to the base prime field. +/// +/// Panics if any element does not lie in the base prime subfield. +pub fn project_all_to_base(coeffs: &[F]) -> Vec { + #[cfg(feature = "parallel")] + { + coeffs.par_iter().map(|c| project_to_base(*c)).collect() + } + #[cfg(not(feature = "parallel"))] + { + coeffs.iter().map(|&c| project_to_base(c)).collect() + } +} + +/// Element-wise add a base-field slice with a (possibly shorter) extension-field +/// slice projected to base field. +/// +/// Computes `result[i] = base[i] + project_to_base(ext[i])` for `i < ext.len()`, +/// and `result[i] = base[i]` for `i >= ext.len()`. +/// +/// Each element of `ext` must lie in the base prime subfield. +pub fn add_base_with_projection( + base: &[F::BasePrimeField], + ext_addend: &[F], +) -> Vec { + debug_assert!( + ext_addend.len() <= base.len(), + "ext_addend ({}) must not exceed base ({})", + ext_addend.len(), + base.len(), + ); + let ext_len = ext_addend.len(); + + #[cfg(feature = "parallel")] + { + (0..base.len()) + .into_par_iter() + .map(|i| { + if i < ext_len { + base[i] + project_to_base(ext_addend[i]) + } else { + base[i] + } + }) + .collect() + } + #[cfg(not(feature = "parallel"))] + { + (0..base.len()) + .map(|i| { + if i < ext_len { + base[i] + project_to_base(ext_addend[i]) + } else { + base[i] + } + }) + .collect() + } +} diff --git a/src/algebra/polynomials/coeffs.rs b/src/algebra/polynomials/coeffs.rs index a5987466..7f39d2f4 100644 --- a/src/algebra/polynomials/coeffs.rs +++ b/src/algebra/polynomials/coeffs.rs @@ -80,6 +80,26 @@ impl CoefficientList { } } + /// Embed an ℓ-variate polynomial into an n-variate polynomial (n ≥ ℓ) + /// by treating the extra variables as having zero contribution. + /// + /// Coefficient at index `i` in the ℓ-variate maps to index `i * 2^(n-ℓ)` + /// in the n-variate, with all other coefficients set to zero. + pub fn embed_into_variables(&self, n: usize) -> Self { + let ell = self.num_variables; + assert!(n >= ell); + + let factor = 1 << (n - ell); + let new_size = 1 << n; + let mut coeffs = vec![F::ZERO; new_size]; + + for (i, &c) in self.coeffs.iter().enumerate() { + coeffs[i * factor] = c; + } + + Self::new(coeffs) + } + /// Evaluates the polynomial at an arbitrary point in `F^n`. /// /// This generalizes evaluation beyond `(0,1)^n`, allowing fractional or arbitrary field @@ -137,6 +157,33 @@ impl CoefficientList { num_variables: self.num_variables() - folding_factor, } } + + /// Folds the polynomial in-place along high-indexed variables. + /// + /// Like [`fold`](Self::fold), but modifies the polynomial in-place instead of + /// allocating a new coefficient vector. The excess capacity is freed via truncation. + /// + /// # Safety of in-place overwrite + /// + /// For each output index `i`, `eval_multivariate` reads from + /// `coeffs[i*chunk .. (i+1)*chunk]` and the result is written to `coeffs[i]`. + /// Since `chunk >= 2`, the write target `i` is always strictly less than the + /// start of the next read range `(i+1)*chunk`, so writes never corrupt data + /// needed by subsequent iterations. + pub fn fold_in_place(&mut self, folding_randomness: &MultilinearPoint) { + let folding_factor = folding_randomness.num_variables(); + let chunk_size = 1 << folding_factor; + let new_len = self.coeffs.len() / chunk_size; + for i in 0..new_len { + let val = eval_multivariate( + &self.coeffs[i * chunk_size..(i + 1) * chunk_size], + &folding_randomness.0, + ); + self.coeffs[i] = val; + } + self.coeffs.truncate(new_len); + self.num_variables -= folding_factor; + } } /// Multivariate evaluation in coefficient form. diff --git a/src/algebra/polynomials/evals.rs b/src/algebra/polynomials/evals.rs index 1d20958b..e925411a 100644 --- a/src/algebra/polynomials/evals.rs +++ b/src/algebra/polynomials/evals.rs @@ -108,6 +108,26 @@ where self.num_variables } + /// Folds evaluations in-place by linear interpolation at the given weight. + /// + /// For each pair `(evals[2i], evals[2i+1])`, computes the interpolated value + /// `(evals[2i+1] - evals[2i]) * weight + evals[2i]` and stores it at `evals[i]`. + /// The vector is then truncated to half its original size. + /// + /// This is equivalent to creating a new `EvaluationsList` via + /// `algebra::sumcheck::fold`, but avoids allocating a new vector. + pub fn fold_in_place(&mut self, weight: F) { + assert!(self.evals.len().is_multiple_of(2)); + let half = self.evals.len() / 2; + for i in 0..half { + let v0 = self.evals[2 * i]; + let v1 = self.evals[2 * i + 1]; + self.evals[i] = (v1 - v0) * weight + v0; + } + self.evals.truncate(half); + self.num_variables -= 1; + } + pub fn to_coeffs(&self) -> crate::algebra::polynomials::coeffs::CoefficientList { let mut coeffs = self.evals.clone(); crate::algebra::ntt::inverse_wavelet_transform(&mut coeffs); diff --git a/src/algebra/polynomials/multilinear.rs b/src/algebra/polynomials/multilinear.rs index 65208eb8..df174162 100644 --- a/src/algebra/polynomials/multilinear.rs +++ b/src/algebra/polynomials/multilinear.rs @@ -118,12 +118,31 @@ where acc } - /// Computes eq(c, p) on the hypercube for all p. + /// Computes eq(self, z) for every z ∈ {0,1}ⁿ using a butterfly expansion. + /// + /// Returns a `Vec` of length `2^n` where entry `z` (in lexicographic + /// order) is `eq(self, z)`. + /// + /// Runs in O(2ⁿ) time and O(2ⁿ) space. pub fn eq_weights(&self) -> Vec { - (0..1 << self.0.len()) - .map(BinaryHypercubePoint) - .map(|point| self.eq_poly(point)) - .collect() + let n = self.num_variables(); + let size = 1 << n; + let mut evals = Vec::with_capacity(size); + evals.push(F::ONE); + // Process coordinates in storage order (big-endian: x_{n-1}, …, x_0). + // Each step doubles the vector via the identity: + // eq(c, z||0) = eq(c', z) · (1 − cᵢ) + // eq(c, z||1) = eq(c', z) · cᵢ + for &ci in &self.0 { + let len = evals.len(); + let one_minus_ci = F::ONE - ci; + evals.resize(2 * len, F::ZERO); + for j in (0..len).rev() { + evals[2 * j + 1] = evals[j] * ci; + evals[2 * j] = evals[j] * one_minus_ci; + } + } + evals } pub fn coeff_weights(&self, reversed: bool) -> Vec { diff --git a/src/algebra/weights.rs b/src/algebra/weights.rs index cc79c446..a05056c1 100644 --- a/src/algebra/weights.rs +++ b/src/algebra/weights.rs @@ -384,4 +384,144 @@ mod tests { let expected = weight_list.eval_extension(&folding_randomness); assert_eq!(weight.compute(&folding_randomness), expected); } + + #[test] + fn test_protocol() { + // ── Step 1: Create a CoefficientList (4 variables, 16 coefficients) ── + let coeffs = CoefficientList::new(vec![ + Field64::ONE, + Field64::ONE, + Field64::ONE, + Field64::ONE, + Field64::ONE, + Field64::ONE, + Field64::ONE, + Field64::ONE, + Field64::ONE, + Field64::ONE, + Field64::ONE, + Field64::ONE, + Field64::ONE, + Field64::ONE, + Field64::ONE, + Field64::ONE, + ]); + println!("coeffs: {:?}\n", coeffs); + + // ── Step 2: Evaluate at several MultilinearPoints ── + let evaluation_points = vec![ + MultilinearPoint(vec![ + Field64::ONE, + Field64::ZERO, + Field64::ZERO, + Field64::ZERO, + ]), + MultilinearPoint(vec![ + Field64::ZERO, + Field64::ONE, + Field64::ZERO, + Field64::ZERO, + ]), + MultilinearPoint(vec![ + Field64::ZERO, + Field64::ZERO, + Field64::ONE, + Field64::ZERO, + ]), + MultilinearPoint(vec![ + Field64::ZERO, + Field64::ZERO, + Field64::ZERO, + Field64::ONE, + ]), + ]; + println!("evaluation_points: {:?}\n", evaluation_points); + let weights = evaluation_points + .iter() + .map(|point| Weights::evaluation(point.clone())) + .collect::>(); + println!("weights: {:?}\n", weights); + let evaluations = evaluation_points + .iter() + .map(|point| coeffs.mixed_evaluate(&Identity::new(), point)) + .collect::>(); + println!("evaluations: {:?}\n", evaluations); + + // ── Step 3: Convert CoefficientList → EvaluationsList → CoefficientList ── + // CoefficientList → EvaluationsList (via wavelet transform) + let evals = EvaluationsList::from(coeffs.clone()); + println!("evals (hypercube evaluations): {:?}\n", evals); + + // EvaluationsList → CoefficientList (via inverse wavelet transform) + let coeffs_roundtrip = evals.to_coeffs(); + println!("coeffs_roundtrip: {:?}\n", coeffs_roundtrip); + + // Verify round-trip: coeffs → evals → coeffs gives back the same polynomial + assert_eq!( + coeffs.coeffs(), + coeffs_roundtrip.coeffs(), + "Round-trip CoefficientList → EvaluationsList → CoefficientList must be identity" + ); + + // Both representations should evaluate to the same values at any point + for point in &evaluation_points { + let from_coeffs = coeffs.evaluate(point); + let from_evals = evals.evaluate(point); + assert_eq!( + from_coeffs, from_evals, + "CoefficientList and EvaluationsList must agree at {:?}", + point + ); + } + println!("✓ Round-trip and evaluation consistency verified\n"); + + // ── Step 4: Verify fold_in_place matches fold ── + // fold() creates a new polynomial; fold_in_place() mutates in-place. + // After folding f(X₀, X₁, X₂, X₃) at (r₀, r₁), we get g(X₂, X₃) = f(X₂, X₃, r₀, r₁). + let folding_randomness = MultilinearPoint(vec![Field64::from(3u64), Field64::from(7u64)]); + + // fold() — allocating version + let folded = coeffs.fold(&folding_randomness); + println!("folded (via fold): {:?}", folded); + + // fold_in_place() — in-place version + let mut coeffs_mut = coeffs.clone(); + coeffs_mut.fold_in_place(&folding_randomness); + println!("folded (via fold_in_place): {:?}", coeffs_mut); + + // They must produce identical results + assert_eq!( + folded.coeffs(), + coeffs_mut.coeffs(), + "fold() and fold_in_place() must produce the same polynomial" + ); + println!("✓ fold and fold_in_place match\n"); + + // ── Step 5: Verify folded polynomial is consistent with full evaluation ── + // g(a, b) should equal f(a, b, r₀, r₁) for any (a, b) + let eval_point = MultilinearPoint(vec![Field64::from(5u64), Field64::from(11u64)]); + println!("eval_point: {:?}\n", eval_point); + let full_point = MultilinearPoint(vec![ + eval_point.0[0], + eval_point.0[1], + folding_randomness.0[0], + folding_randomness.0[1], + ]); + println!("full_point: {:?}\n", full_point); + let folded_eval = folded.evaluate(&eval_point); + println!("folded poly: {:?}\n", folded); + let full_eval = coeffs.evaluate(&full_point); + println!("full poly: {:?}\n", coeffs); + + println!("folded_eval: {:?}\n", folded_eval); + println!("full_eval: {:?}\n", full_eval); + assert_eq!( + folded_eval, full_eval, + "f.fold(r).evaluate(a) must equal f.evaluate(a || r)" + ); + println!( + "✓ folded.evaluate({:?}) == coeffs.evaluate({:?}) == {:?}\n", + eval_point.0, full_point.0, folded_eval + ); + } } diff --git a/src/alloc_track.rs b/src/alloc_track.rs new file mode 100644 index 00000000..5682565d --- /dev/null +++ b/src/alloc_track.rs @@ -0,0 +1,152 @@ +//! Allocation tracking for profiling memory usage at each protocol step. +//! +//! # Usage +//! +//! 1. Enable the `alloc-track` feature. +//! 2. In your binary, set the global allocator: +//! ```rust,ignore +//! #[global_allocator] +//! static ALLOC: whir::alloc_track::TrackingAllocator = whir::alloc_track::TrackingAllocator; +//! ``` +//! 3. Run your binary — each instrumented protocol step will print allocation +//! counts and bytes to stderr. +//! +//! See `src/bin/alloc_report.rs` for a ready-to-run example. + +use std::alloc::{GlobalAlloc, Layout, System}; +use std::sync::atomic::{AtomicU64, Ordering}; + +// ── Global counters ────────────────────────────────────────────────────── + +static ALLOC_COUNT: AtomicU64 = AtomicU64::new(0); +static ALLOC_BYTES: AtomicU64 = AtomicU64::new(0); + +/// A global allocator that counts every allocation. +/// +/// Wraps `std::alloc::System` and atomically increments counters on each +/// `alloc` / `realloc`. De-allocations are forwarded without counting +/// (tracking freed memory is possible but adds overhead and isn't needed +/// for allocation-count profiling). +pub struct TrackingAllocator; + +unsafe impl GlobalAlloc for TrackingAllocator { + unsafe fn alloc(&self, layout: Layout) -> *mut u8 { + ALLOC_COUNT.fetch_add(1, Ordering::Relaxed); + ALLOC_BYTES.fetch_add(layout.size() as u64, Ordering::Relaxed); + unsafe { System.alloc(layout) } + } + + unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) { + unsafe { System.dealloc(ptr, layout) } + } + + unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 { + ALLOC_COUNT.fetch_add(1, Ordering::Relaxed); + ALLOC_BYTES.fetch_add(layout.size() as u64, Ordering::Relaxed); + unsafe { System.alloc_zeroed(layout) } + } + + unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 { + ALLOC_COUNT.fetch_add(1, Ordering::Relaxed); + ALLOC_BYTES.fetch_add(new_size as u64, Ordering::Relaxed); + unsafe { System.realloc(ptr, layout, new_size) } + } +} + +// ── Snapshot API ───────────────────────────────────────────────────────── + +/// A point-in-time snapshot of allocation counters. +/// +/// Take a snapshot before a code region, then call [`Snapshot::elapsed`] +/// to see how many allocations occurred in between. +#[derive(Debug, Clone, Copy)] +pub struct Snapshot { + pub allocs: u64, + pub bytes: u64, +} + +/// The delta between two snapshots. +#[derive(Debug, Clone, Copy)] +pub struct AllocDelta { + pub allocs: u64, + pub bytes: u64, +} + +impl Snapshot { + /// Capture the current allocation counters. + #[inline] + pub fn now() -> Self { + Self { + allocs: ALLOC_COUNT.load(Ordering::Relaxed), + bytes: ALLOC_BYTES.load(Ordering::Relaxed), + } + } + + /// Compute how many allocations have happened since this snapshot. + #[inline] + pub fn elapsed(&self) -> AllocDelta { + let now = Self::now(); + AllocDelta { + allocs: now.allocs.wrapping_sub(self.allocs), + bytes: now.bytes.wrapping_sub(self.bytes), + } + } +} + +impl std::fmt::Display for AllocDelta { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{:>8} allocs, {:>10}", + self.allocs, + format_bytes(self.bytes) + ) + } +} + +/// Reset all counters to zero (useful before a profiling run). +pub fn reset() { + ALLOC_COUNT.store(0, Ordering::Relaxed); + ALLOC_BYTES.store(0, Ordering::Relaxed); +} + +/// Print a labelled allocation-delta line to stderr. +pub fn report(label: &str, snap: &Snapshot) { + let delta = snap.elapsed(); + eprintln!(" {label:50} {delta}"); +} + +fn format_bytes(bytes: u64) -> String { + if bytes < 1024 { + format!("{bytes} B") + } else if bytes < 1024 * 1024 { + format!("{:.1} KB", bytes as f64 / 1024.0) + } else if bytes < 1024 * 1024 * 1024 { + format!("{:.2} MB", bytes as f64 / (1024.0 * 1024.0)) + } else { + format!("{:.2} GB", bytes as f64 / (1024.0 * 1024.0 * 1024.0)) + } +} + +// ── Macros for zero-cost instrumentation ───────────────────────────────── + +/// Take an allocation snapshot. Expands to `()` when `alloc-track` is off. +#[macro_export] +macro_rules! alloc_snap { + () => {{ + $crate::alloc_track::Snapshot::now() + }}; +} + +/// Print an allocation report line and reset the snapshot variable. +/// +/// Usage: `alloc_report!("label", snap);` +/// +/// After this macro, `snap` holds a fresh snapshot for the next region. +#[macro_export] +macro_rules! alloc_report { + ($label:expr, $snap:ident) => {{ + $crate::alloc_track::report($label, &$snap); + $snap = $crate::alloc_track::Snapshot::now(); + }}; +} diff --git a/src/bin/alloc_report.rs b/src/bin/alloc_report.rs new file mode 100644 index 00000000..ec01e9be --- /dev/null +++ b/src/bin/alloc_report.rs @@ -0,0 +1,221 @@ +//! Allocation profiling binary for ZK WHIR. +//! +//! Run with: +//! ```bash +//! cargo run --bin alloc_report --features alloc-track --release +//! ``` +//! +//! You can change `NUM_VARIABLES`, `NUM_POLYS`, and `NUM_POINTS` at the +//! top of `run()` to match the configuration you care about. + +fn main() { + #[cfg(feature = "alloc-track")] + run(); + + #[cfg(not(feature = "alloc-track"))] + eprintln!( + "This binary requires the `alloc-track` feature.\n\ + Run: cargo run --bin alloc_report --features alloc-track --release" + ); +} + +#[cfg(feature = "alloc-track")] +#[global_allocator] +static ALLOC: whir::alloc_track::TrackingAllocator = whir::alloc_track::TrackingAllocator; + +#[cfg(feature = "alloc-track")] +fn run() { + use ark_std::rand::{rngs::StdRng, SeedableRng}; + use whir::{ + algebra::{ + fields::{Field64, Field64_2}, + polynomials::{CoefficientList, MultilinearPoint}, + Weights, + }, + alloc_track, hash, + parameters::{FoldingFactor, MultivariateParameters, ProtocolParameters, SoundnessType}, + protocols::{ + whir::Config, + whir_zk::{ZkParams, ZkPreprocessingPolynomials}, + }, + transcript::{codecs::Empty, DomainSeparator, ProverState, VerifierState}, + }; + + type F = Field64; + type EF = Field64_2; + + /// ── Tunables ──────────────────────────────────────────────────── + const NUM_VARIABLES: usize = 20; + const NUM_POLYS: usize = 2; + const NUM_POINTS: usize = 10; + + let mut rng = StdRng::seed_from_u64(42); + let num_coeffs = 1usize << NUM_VARIABLES; + + // ── Build config ───────────────────────────────────────────────── + let mv = MultivariateParameters::new(NUM_VARIABLES); + let params = ProtocolParameters { + initial_statement: true, + security_level: 32, + pow_bits: 0, + folding_factor: FoldingFactor::ConstantFromSecondRound(2, 4), + soundness_type: SoundnessType::ConjectureList, + starting_log_inv_rate: 1, + batch_size: 1, + hash_id: hash::SHA2, + }; + let config = Config::::new(mv, ¶ms); + let zk_params = ZkParams::from_whir_params(&config); + + let helper_mv = MultivariateParameters::new(zk_params.ell + 1); + let helper_params = ProtocolParameters { + initial_statement: true, + security_level: 32, + pow_bits: 0, + folding_factor: FoldingFactor::Constant(4), + soundness_type: SoundnessType::ConjectureList, + starting_log_inv_rate: 1, + batch_size: zk_params.helper_batch_size(NUM_POLYS), + hash_id: hash::SHA2, + }; + let helper_config = Config::::new(helper_mv, &helper_params); + + eprintln!("╔══════════════════════════════════════════════════════════════════════╗"); + eprintln!("║ ZK WHIR Allocation Report ║"); + eprintln!("╠══════════════════════════════════════════════════════════════════════╣"); + eprintln!( + "║ num_variables = {:>4} ║", + NUM_VARIABLES + ); + eprintln!( + "║ num_polys = {:>4} ║", + NUM_POLYS + ); + eprintln!( + "║ num_points = {:>4} ║", + NUM_POINTS + ); + eprintln!( + "║ ell (ZK) = {:>4} ║", + zk_params.ell + ); + eprintln!( + "║ mu (ZK) = {:>4} ║", + zk_params.mu + ); + eprintln!( + "║ WHIR rounds = {:>4} ║", + config.n_rounds() + ); + eprintln!("╚══════════════════════════════════════════════════════════════════════╝"); + eprintln!(); + + // ── Build polynomials and preprocessings ───────────────────────── + eprintln!("── setup ──────────────────────────────────────────────────────────"); + let mut snap = alloc_track::Snapshot::now(); + + let polynomials: Vec> = (0..NUM_POLYS) + .map(|i| { + CoefficientList::new( + (0..num_coeffs) + .map(|j| F::from((i * num_coeffs + j + 1) as u64)) + .collect(), + ) + }) + .collect(); + alloc_track::report("setup::build_polynomials", &snap); + snap = alloc_track::Snapshot::now(); + + let preprocessings: Vec> = (0..NUM_POLYS) + .map(|_| ZkPreprocessingPolynomials::::sample(&mut rng, zk_params.clone())) + .collect(); + alloc_track::report("setup::sample_preprocessing", &snap); + snap = alloc_track::Snapshot::now(); + + // ── Build weights and evaluations ──────────────────────────────── + let mut weights = Vec::new(); + let mut evaluations = Vec::new(); + for _ in 0..NUM_POINTS { + let point = MultilinearPoint::rand(&mut rng, NUM_VARIABLES); + weights.push(Weights::evaluation(point.clone())); + for poly in &polynomials { + evaluations.push(poly.mixed_evaluate(config.embedding(), &point)); + } + } + let weight_refs: Vec<&Weights> = weights.iter().collect(); + let poly_refs: Vec<&CoefficientList> = polynomials.iter().collect(); + alloc_track::report("setup::weights_and_evaluations", &snap); + + // ── Transcript setup ───────────────────────────────────────────── + let ds = DomainSeparator::protocol(&config) + .session(&String::from("alloc-report")) + .instance(&Empty); + + // ══════════════════════════════════════════════════════════════════ + // COMMIT + // ══════════════════════════════════════════════════════════════════ + eprintln!(); + eprintln!("── commit_zk ──────────────────────────────────────────────────────"); + let mut prover_state = ProverState::new_std(&ds); + snap = alloc_track::Snapshot::now(); + let zk_witness = config.commit_zk( + &mut prover_state, + &poly_refs, + &helper_config, + &preprocessings.iter().collect::>(), + ); + eprintln!(" ──────────────────────────────────────────────────────────────"); + alloc_track::report("commit_zk TOTAL", &snap); + + // ══════════════════════════════════════════════════════════════════ + // PROVE + // ══════════════════════════════════════════════════════════════════ + eprintln!(); + eprintln!("── prove_zk ───────────────────────────────────────────────────────"); + snap = alloc_track::Snapshot::now(); + let (_point, _evals) = config.prove_zk( + &mut prover_state, + &poly_refs, + &zk_witness, + &helper_config, + &weight_refs, + &evaluations, + ); + eprintln!(" ──────────────────────────────────────────────────────────────"); + alloc_track::report("prove_zk TOTAL", &snap); + + // ══════════════════════════════════════════════════════════════════ + // VERIFY + // ══════════════════════════════════════════════════════════════════ + let proof = prover_state.proof(); + eprintln!(); + eprintln!("── verify_zk ──────────────────────────────────────────────────────"); + snap = alloc_track::Snapshot::now(); + + let mut verifier_state = VerifierState::new_std(&ds, &proof); + let f_hat_commitments: Vec<_> = (0..NUM_POLYS) + .map(|_| config.receive_commitment(&mut verifier_state).unwrap()) + .collect(); + let f_hat_refs: Vec<_> = f_hat_commitments.iter().collect(); + let helper_commitment = helper_config + .receive_commitment(&mut verifier_state) + .unwrap(); + alloc_track::report("verify_zk::receive_commitments", &snap); + snap = alloc_track::Snapshot::now(); + + let result = config.verify_zk( + &mut verifier_state, + &f_hat_refs, + &helper_commitment, + &helper_config, + &zk_params, + &weight_refs, + &evaluations, + ); + assert!(result.is_ok(), "Verification failed: {result:?}"); + eprintln!(" ──────────────────────────────────────────────────────────────"); + alloc_track::report("verify_zk TOTAL", &snap); + + eprintln!(); + eprintln!("✓ Verification passed. All allocation counts above are per full run."); +} diff --git a/src/lib.rs b/src/lib.rs index 7d1d8b45..3153d20d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,6 @@ pub mod algebra; +#[cfg(feature = "alloc-track")] +pub mod alloc_track; pub mod ark_serde; pub mod bits; pub mod cmdline_utils; diff --git a/src/protocols/mod.rs b/src/protocols/mod.rs index 60b7fd88..4f782391 100644 --- a/src/protocols/mod.rs +++ b/src/protocols/mod.rs @@ -16,3 +16,4 @@ pub mod merkle_tree; pub mod proof_of_work; pub mod sumcheck; pub mod whir; +pub mod whir_zk; diff --git a/src/protocols/sumcheck.rs b/src/protocols/sumcheck.rs index 74599a65..4c7170da 100644 --- a/src/protocols/sumcheck.rs +++ b/src/protocols/sumcheck.rs @@ -12,7 +12,7 @@ use crate::{ algebra::{ dot, polynomials::{EvaluationsList, MultilinearPoint}, - sumcheck::{compute_sumcheck_polynomial, fold}, + sumcheck::compute_sumcheck_polynomial, }, ensure, protocols::proof_of_work, @@ -98,9 +98,9 @@ impl Config { let folding_randomness = prover_state.verifier_message::(); res.push(folding_randomness); - // Fold the inputs - *a = EvaluationsList::new(fold(folding_randomness, a.evals())); - *b = EvaluationsList::new(fold(folding_randomness, b.evals())); + // Fold the inputs in-place (avoids allocating new Vecs each round) + a.fold_in_place(folding_randomness); + b.fold_in_place(folding_randomness); *sum = (c2 * folding_randomness + c1) * folding_randomness + c0; } diff --git a/src/protocols/whir/prover.rs b/src/protocols/whir/prover.rs index 6ba30187..244c6e11 100644 --- a/src/protocols/whir/prover.rs +++ b/src/protocols/whir/prover.rs @@ -171,7 +171,7 @@ impl Config { &mut the_sum, ) }; - coefficients = coefficients.fold(&folding_randomness); + coefficients.fold_in_place(&folding_randomness); if constraint_rlc_coeffs.is_empty() { // We didn't fold evaluations, so compute it here. evaluations = EvaluationsList::from(coefficients.clone()); @@ -234,7 +234,7 @@ impl Config { &mut constraints, &mut the_sum, ); - coefficients = coefficients.fold(&folding_randomness); + coefficients.fold_in_place(&folding_randomness); randomness_vec.extend(folding_randomness.0.iter().rev()); debug_assert_eq!(evaluations, EvaluationsList::from(coefficients.clone())); debug_assert_eq!(dot(evaluations.evals(), constraints.evals()), the_sum); @@ -244,10 +244,7 @@ impl Config { } // Directly send coefficients of the polynomial to the verifier. - assert_eq!(coefficients.num_coeffs(), self.final_sumcheck.initial_size); - for coeff in coefficients.coeffs() { - prover_state.prover_message(coeff); - } + self.send_final_coefficients(prover_state, &coefficients); // PoW self.final_pow.prove(prover_state); @@ -274,14 +271,46 @@ impl Config { randomness_vec.extend(final_folding_randomness.0.iter().rev()); // Hints for deferred constraints + self.compute_deferred_hints(prover_state, weights, &randomness_vec) + } + + /// Send final polynomial coefficients to verifier. + pub(crate) fn send_final_coefficients( + &self, + prover_state: &mut ProverState, + coefficients: &CoefficientList, + ) where + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + F: Codec<[H::U]>, + { + assert_eq!(coefficients.num_coeffs(), self.final_sumcheck.initial_size); + for coeff in coefficients.coeffs() { + prover_state.prover_message(coeff); + } + } + + /// Compute deferred constraint hints and write them to the transcript. + /// + /// Returns `(constraint_evaluation_point, deferred_values)`. + pub(crate) fn compute_deferred_hints( + &self, + prover_state: &mut ProverState, + weights: &[&Weights], + randomness_vec: &[F], + ) -> (MultilinearPoint, Vec) + where + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + F: Codec<[H::U]>, + { let constraint_eval = MultilinearPoint(randomness_vec.iter().copied().rev().collect()); - let deferred = weights + let deferred: Vec = weights .iter() .filter(|w| w.deferred()) .map(|w| w.compute(&constraint_eval)) .collect(); prover_state.prover_hint_ark(&deferred); - (constraint_eval, deferred) } } diff --git a/src/protocols/whir/verifier.rs b/src/protocols/whir/verifier.rs index e5992d5e..fa441c2b 100644 --- a/src/protocols/whir/verifier.rs +++ b/src/protocols/whir/verifier.rs @@ -219,43 +219,69 @@ impl Config { let final_sumcheck_randomness = self.final_sumcheck.verify(verifier_state, &mut the_sum)?; round_folding_randomness.push(final_sumcheck_randomness.clone()); - // Compute folding randomness across all rounds + // Final consistency check (shared with verify_zk) + self.verify_final_consistency( + verifier_state, + &round_constraints, + &round_folding_randomness, + &final_coefficients, + &final_sumcheck_randomness, + the_sum, + ) + } + + /// Verify the final consistency check: compute folding randomness, read deferred + /// hints, evaluate weight functions, and check the final sumcheck equation. + /// + /// This is shared between `verify` and `verify_zk`. + pub(crate) fn verify_final_consistency( + &self, + verifier_state: &mut VerifierState<'_, H>, + round_constraints: &[(Vec, Vec>)], + round_folding_randomness: &[MultilinearPoint], + final_coefficients: &CoefficientList, + final_sumcheck_randomness: &MultilinearPoint, + the_sum: F, + ) -> VerificationResult<(MultilinearPoint, Vec)> + where + H: DuplexSpongeInterface, + F: Codec<[H::U]>, + u8: Decoding<[H::U]>, + { let folding_randomness = MultilinearPoint( round_folding_randomness - .into_iter() + .iter() .rev() - .flat_map(|poly| poly.0.into_iter()) + .flat_map(|poly| poly.0.iter().copied()) .collect(), ); - // Compute evaluation of weights in folding randomness point let deferred: Vec = verifier_state.prover_hint_ark()?; let mut deferred_iter = deferred.iter().copied(); let mut weight_eval = F::ZERO; - for (round, (weights_rlc_coeffs, weights)) in round_constraints.iter().enumerate() { + + for (round, (weights_rlc_coeffs, round_weights)) in round_constraints.iter().enumerate() { let num_variables = round.checked_sub(1).map_or_else( || self.initial_num_variables(), |p| self.round_configs[p].initial_num_variables(), ); let point = MultilinearPoint(folding_randomness.0[..num_variables].to_vec()); - for (rlc_coeff, weights) in zip_strict(weights_rlc_coeffs, weights) { - let eval = if weights.deferred() { - let deferred = deferred_iter.next(); - verify!(deferred.is_some()); - deferred.unwrap() + for (rlc_coeff, w) in zip_strict(weights_rlc_coeffs, round_weights) { + let eval = if w.deferred() { + let d = deferred_iter.next(); + verify!(d.is_some()); + d.unwrap() } else { - weights.compute(&point) + w.compute(&point) }; weight_eval += *rlc_coeff * eval; } } verify!(deferred_iter.next().is_none()); - // Check the final sumcheck equation - let poly_eval = final_coefficients.evaluate(&final_sumcheck_randomness); + let poly_eval = final_coefficients.evaluate(final_sumcheck_randomness); verify!(poly_eval * weight_eval == the_sum); - // Return the evaluation point and the claimed values of the deferred weights. Ok((folding_randomness, deferred)) } } diff --git a/src/protocols/whir_zk/api.rs b/src/protocols/whir_zk/api.rs new file mode 100644 index 00000000..9bed8a61 --- /dev/null +++ b/src/protocols/whir_zk/api.rs @@ -0,0 +1,441 @@ +//! Unified API for ZK-WHIR batch proving with mixed-arity polynomial groups. +//! +//! The caller provides polynomials, weights, and evaluations grouped by arity. +//! The library handles all internal bookkeeping: config creation, preprocessing +//! sampling, commitment, and proof generation. +//! +//! # Example +//! +//! ```ignore +//! // Prover side +//! let groups = vec![ +//! ProverInput::new(vec![&poly_10var], weights_10, evals_10), +//! ProverInput::new(vec![&poly_12var], weights_12, evals_12), +//! ]; +//! let (point, evals) = main_config.batch_prove_zk( +//! &mut prover_state, &whir_params, &groups, &mut rng, +//! ); +//! +//! // Verifier side +//! let claims: Vec> = groups.iter().map(|g| g.to_verifier_input()).collect(); +//! let result = main_config.batch_verify_zk( +//! &mut verifier_state, &whir_params, &claims, +//! ); +//! ``` + +use ark_ff::FftField; +use ark_std::rand::{CryptoRng, RngCore}; + +use crate::{ + algebra::{ + fields::FieldWithSize, + polynomials::{CoefficientList, MultilinearPoint}, + Weights, + }, + hash::Hash, + parameters::ProtocolParameters, + protocols::whir::Config, + transcript::{ + codecs::U64, Codec, Decoding, DuplexSpongeInterface, ProverMessage, ProverState, + VerificationResult, VerifierState, + }, +}; + +use super::{ + prefold::{ + commit_zk_at_level, receive_prefold_commitments, PrefoldGroupCommitments, + PrefoldGroupInput, PrefoldLevelConfig, + }, + utils::{ZkParams, ZkPreprocessingPolynomials, ZkWitness}, +}; + +/// Prover input: a group of polynomials at the same arity with shared constraints. +/// +/// All polynomials in a group must have the same number of variables (arity). +/// The library automatically determines which groups are "native" (minimum arity) +/// and which need prefolding. +pub struct ProverInput<'a, EF: FftField> { + /// Base-field polynomials (all at the same arity). + pub polynomials: Vec<&'a CoefficientList>, + /// Shared evaluation constraint weights. + pub weights: Vec>, + /// Evaluations: row-major `[w₀_p₀, w₀_p₁, ..., w₁_p₀, ...]`. + pub evaluations: Vec, +} + +/// Verifier input: a claim about a group of polynomials at the same arity. +/// +/// Same layout as [`ProverInput`] but without the actual polynomial coefficients — +/// the verifier only needs the arity, group size, and the claimed evaluations. +pub struct VerifierInput { + /// Number of variables (arity) for this group. + pub arity: usize, + /// Number of polynomials in this group. + pub num_polynomials: usize, + /// Shared evaluation constraint weights. + pub weights: Vec>, + /// Evaluations: row-major `[w₀_p₀, w₀_p₁, ..., w₁_p₀, ...]`. + pub evaluations: Vec, +} + +impl<'a, EF: FftField> ProverInput<'a, EF> { + /// Create a new prover input group. + pub fn new( + polynomials: Vec<&'a CoefficientList>, + weights: Vec>, + evaluations: Vec, + ) -> Self { + assert!( + !polynomials.is_empty(), + "ProverInput must have at least one polynomial" + ); + let arity = polynomials[0].num_variables(); + debug_assert!( + polynomials.iter().all(|p| p.num_variables() == arity), + "All polynomials in a ProverInput must have the same arity" + ); + debug_assert_eq!( + evaluations.len(), + weights.len() * polynomials.len(), + "evaluations.len() must equal weights.len() × polynomials.len()" + ); + Self { + polynomials, + weights, + evaluations, + } + } + + /// Infer the arity (number of variables) from the polynomials. + pub fn arity(&self) -> usize { + self.polynomials[0].num_variables() + } + + /// Number of polynomials in this group. + pub fn num_polynomials(&self) -> usize { + self.polynomials.len() + } + + /// Create a corresponding [`VerifierInput`] from this prover input. + pub fn to_verifier_input(&self) -> VerifierInput { + VerifierInput { + arity: self.arity(), + num_polynomials: self.polynomials.len(), + weights: self.weights.clone(), + evaluations: self.evaluations.clone(), + } + } +} + +impl VerifierInput { + /// Create a new verifier input claim. + pub fn new( + arity: usize, + num_polynomials: usize, + weights: Vec>, + evaluations: Vec, + ) -> Self { + debug_assert_eq!( + evaluations.len(), + weights.len() * num_polynomials, + "evaluations.len() must equal weights.len() × num_polynomials" + ); + Self { + arity, + num_polynomials, + weights, + evaluations, + } + } +} + +/// Separate groups into native (arity == n_min) and prefold (arity > n_min). +/// +/// Returns `(native_index, prefold_indices)` where prefold indices are sorted +/// by decreasing arity (highest first), matching the prefold proof order. +fn separate_by_arity(n_min: usize, arities: &[usize]) -> (usize, Vec) { + let mut native_idx: Option = None; + let mut prefold_indices: Vec = Vec::new(); + + for (i, &arity) in arities.iter().enumerate() { + if arity == n_min { + assert!( + native_idx.is_none(), + "only one native group (arity == n_min = {n_min}) allowed" + ); + native_idx = Some(i); + } else { + assert!( + arity > n_min, + "group arity ({arity}) must be >= n_min ({n_min})" + ); + prefold_indices.push(i); + } + } + + let native_idx = native_idx.expect("must have exactly one group at the native arity (n_min)"); + prefold_indices.sort_by(|&a, &b| arities[b].cmp(&arities[a])); + + (native_idx, prefold_indices) +} + +impl Config { + /// Unified ZK-WHIR batch proof for mixed-arity polynomial groups. + /// + /// The caller provides polynomial groups (each at its own arity, with its own + /// constraints and evaluations). The library automatically: + /// + /// 1. Identifies the minimum arity (native) group. + /// 2. Creates prefold configs for any higher-arity groups. + /// 3. Samples ZK preprocessing (blinding polynomials) for every polynomial. + /// 4. Commits all polynomials (native + prefold). + /// 5. Runs the prefold + ZK-WHIR proof pipeline. + /// + /// If all groups are at the same arity (no prefold needed), falls back to + /// the standard `prove_zk` path. + /// + /// # Arguments + /// + /// * `self` — Main WHIR config (must be configured at the minimum arity). + /// * `prover_state` — Fiat-Shamir prover transcript. + /// * `whir_params` — Protocol parameters (used to build prefold sub-configs). + /// * `groups` — Polynomial groups at arbitrary arities. + /// * `rng` — RNG for sampling ZK preprocessing polynomials. + #[allow(clippy::too_many_lines)] + pub fn batch_prove_zk( + &self, + prover_state: &mut ProverState, + whir_params: &ProtocolParameters, + groups: &[ProverInput<'_, F>], + rng: &mut R2, + ) -> (MultilinearPoint, Vec) + where + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + R2: RngCore + CryptoRng, + F: Codec<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + u8: Decoding<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + assert!(!groups.is_empty(), "must provide at least one group"); + + let n_min = self.initial_num_variables(); + + // ── Separate groups into native and prefold ── + let arities: Vec = groups.iter().map(|g| g.arity()).collect(); + let (native_idx, prefold_indices) = separate_by_arity(n_min, &arities); + + let native_group = &groups[native_idx]; + let num_native = native_group.num_polynomials(); + + // Native: ZK params, helper config, preprocessings + let native_zk_params = ZkParams::from_whir_params(self); + let native_helper_config = + native_zk_params.build_helper_config::(num_native, whir_params); + let native_preprocessings: Vec> = (0..num_native) + .map(|_| ZkPreprocessingPolynomials::sample(rng, native_zk_params.clone())) + .collect(); + let native_preproc_refs: Vec<&ZkPreprocessingPolynomials> = + native_preprocessings.iter().collect(); + + // Commit native + let native_witness = self.commit_zk( + prover_state, + &native_group.polynomials, + &native_helper_config, + &native_preproc_refs, + ); + + // Fast path: no prefold groups → standard prove_zk + if prefold_indices.is_empty() { + let native_weight_refs: Vec<&Weights> = native_group.weights.iter().collect(); + return self.prove_zk( + prover_state, + &native_group.polynomials, + &native_witness, + &native_helper_config, + &native_weight_refs, + &native_group.evaluations, + ); + } + + // Prefold groups: configs, preprocessings, commitments + let mut prefold_configs: Vec> = Vec::new(); + let mut prefold_witnesses: Vec> = Vec::new(); + + for &gi in &prefold_indices { + let group = &groups[gi]; + let arity = group.arity(); + let num_polys = group.num_polynomials(); + + let level_config = PrefoldLevelConfig::new(self, arity, whir_params); + let preprocs: Vec> = (0..num_polys) + .map(|_| ZkPreprocessingPolynomials::sample(rng, level_config.zk_params.clone())) + .collect(); + let preproc_refs: Vec<&ZkPreprocessingPolynomials> = preprocs.iter().collect(); + + let witness = commit_zk_at_level( + &level_config, + prover_state, + &group.polynomials, + &preproc_refs, + ); + + prefold_configs.push(level_config); + prefold_witnesses.push(witness); + } + + // Build PrefoldGroupInputs + let prefold_weight_refs: Vec>> = prefold_indices + .iter() + .map(|&gi| groups[gi].weights.iter().collect()) + .collect(); + + let prefold_group_inputs: Vec> = prefold_indices + .iter() + .enumerate() + .map(|(ci, &gi)| PrefoldGroupInput { + polynomials: &groups[gi].polynomials, + witness: &prefold_witnesses[ci], + weights: &prefold_weight_refs[ci], + evaluations: &groups[gi].evaluations, + level_config: &prefold_configs[ci], + }) + .collect(); + + // Prove + let native_weight_refs: Vec<&Weights> = native_group.weights.iter().collect(); + + self.prove_zk_prefold( + prover_state, + &native_group.polynomials, + &native_witness, + &native_helper_config, + &native_weight_refs, + &native_group.evaluations, + &prefold_group_inputs, + ) + } + + /// Unified ZK-WHIR batch verification for mixed-arity polynomial groups. + /// + /// Mirrors [`batch_prove_zk`](Self::batch_prove_zk): the verifier provides + /// the same group structure (arity, number of polynomials, weights, evaluations) + /// and the library handles all config re-creation, commitment reception, and + /// verification routing. + /// + /// If all groups are at the same arity (no prefold), falls back to `verify_zk`. + #[allow(clippy::too_many_lines)] + pub fn batch_verify_zk( + &self, + verifier_state: &mut VerifierState<'_, H>, + whir_params: &ProtocolParameters, + claims: &[VerifierInput], + ) -> VerificationResult<(MultilinearPoint, Vec)> + where + H: DuplexSpongeInterface, + F: Codec<[H::U]>, + u8: Decoding<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + assert!(!claims.is_empty(), "must provide at least one claim"); + + let n_min = self.initial_num_variables(); + + // Separate into native and prefold + let arities: Vec = claims.iter().map(|c| c.arity).collect(); + let (native_idx, prefold_indices) = separate_by_arity(n_min, &arities); + + let native_claim = &claims[native_idx]; + let num_native = native_claim.num_polynomials; + + // Native: ZK params, helper config + let native_zk_params = ZkParams::from_whir_params(self); + let native_helper_config = + native_zk_params.build_helper_config::(num_native, whir_params); + + // Receive native commitments + let native_f_hat_comms: Vec<_> = (0..num_native) + .map(|_| self.receive_commitment(verifier_state)) + .collect::>()?; + let native_f_hat_comm_refs: Vec<_> = native_f_hat_comms.iter().collect(); + let native_helper_comm = native_helper_config.receive_commitment(verifier_state)?; + + let native_weight_refs: Vec<&Weights> = native_claim.weights.iter().collect(); + + // Fast path: no prefold → standard verify_zk + if prefold_indices.is_empty() { + return self.verify_zk( + verifier_state, + &native_f_hat_comm_refs, + &native_helper_comm, + &native_helper_config, + &native_zk_params, + &native_weight_refs, + &native_claim.evaluations, + ); + } + + // Prefold configs + let prefold_configs: Vec> = prefold_indices + .iter() + .map(|&i| PrefoldLevelConfig::new(self, claims[i].arity, whir_params)) + .collect(); + + // Receive prefold commitments + let mut prefold_group_commitments: Vec> = Vec::new(); + for (ci, &gi) in prefold_indices.iter().enumerate() { + let (f_hat_comms, helper_comm) = receive_prefold_commitments( + &prefold_configs[ci], + verifier_state, + claims[gi].num_polynomials, + )?; + prefold_group_commitments.push(PrefoldGroupCommitments { + f_hat_commitments: f_hat_comms, + helper_commitment: helper_comm, + }); + } + + // Build verify arguments + let prefold_groups_verify: Vec<(&PrefoldGroupCommitments, &PrefoldLevelConfig)> = + prefold_group_commitments + .iter() + .zip(prefold_configs.iter()) + .collect(); + + let prefold_weight_refs: Vec>> = prefold_indices + .iter() + .map(|&gi| claims[gi].weights.iter().collect()) + .collect(); + let prefold_weight_slices: Vec<&[&Weights]> = + prefold_weight_refs.iter().map(|v| v.as_slice()).collect(); + + let prefold_eval_slices: Vec<&[F]> = prefold_indices + .iter() + .map(|&gi| claims[gi].evaluations.as_slice()) + .collect(); + + let prefold_num_polys: Vec = prefold_indices + .iter() + .map(|&gi| claims[gi].num_polynomials) + .collect(); + + self.verify_zk_prefold( + verifier_state, + &native_f_hat_comm_refs, + &native_helper_comm, + &native_helper_config, + &native_zk_params, + &native_weight_refs, + &native_claim.evaluations, + &prefold_groups_verify, + &prefold_weight_slices, + &prefold_eval_slices, + &prefold_num_polys, + ) + } +} diff --git a/src/protocols/whir_zk/committer.rs b/src/protocols/whir_zk/committer.rs new file mode 100644 index 00000000..e9d94be8 --- /dev/null +++ b/src/protocols/whir_zk/committer.rs @@ -0,0 +1,80 @@ +#![allow(type_alias_bounds)] // We need the bound to reference F::BasePrimeField. + +use ark_ff::FftField; +use ark_std::rand::{CryptoRng, RngCore}; +#[cfg(feature = "tracing")] +use tracing::instrument; + +use super::utils::{ + interleave_helper_poly_refs, prepare_helper_polynomials, ZkPreprocessingPolynomials, ZkWitness, +}; +use crate::{ + algebra::{add_base_with_projection, polynomials::CoefficientList}, + hash::Hash, + protocols::whir::Config, + transcript::{Codec, DuplexSpongeInterface, ProverMessage, ProverState}, + utils::zip_strict, +}; + +impl Config { + #[allow(clippy::too_many_lines)] + #[cfg_attr(feature = "tracing", instrument(skip_all, fields(num_polynomials = polynomials.len())))] + pub fn commit_zk( + &self, + prover_state: &mut ProverState, + polynomials: &[&CoefficientList], + helper_config: &Config, + preprocessings: &[&ZkPreprocessingPolynomials], + ) -> ZkWitness + where + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + F: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + #[cfg(feature = "alloc-track")] + let mut __snap = crate::alloc_snap!(); + + // Commit to the polynomials + // 1. Compute f̂ = f + msk directly in base field. + // Both f_coeffs and msk are base field elements (msk is sampled from BasePrimeField + // then lifted to F), so the addition can be done in BasePrimeField directly, + // avoiding a needless round-trip through the extension field. + let mut f_hat_witnesses = Vec::new(); + for (polynomial, preprocessing) in zip_strict(polynomials, preprocessings) { + // f̂ = f + msk in base field (msk projected from extension, zero-padded). + let f_hat_coeffs = + add_base_with_projection::(polynomial.coeffs(), preprocessing.msk.coeffs()); + let f_hat = CoefficientList::new(f_hat_coeffs); + let f_hat_witness = self.commit(prover_state, &[&f_hat]); + f_hat_witnesses.push(f_hat_witness); + } + + #[cfg(feature = "alloc-track")] + crate::alloc_report!("commit_zk::f_hat_commit", __snap); + + // 3. Prepare all helper polynomials in base field for batch commitment + let (m_polys_base, g_hats_embedded_bases) = prepare_helper_polynomials(preprocessings); + + #[cfg(feature = "alloc-track")] + crate::alloc_report!("commit_zk::prepare_helper_polys", __snap); + + // 4. Batch-commit all μ+1 helper polynomials in ONE IRS commit + // (helper_config has batch_size = μ+1, so one Merkle tree for all) + // Layout: [M₁, ĝ₁₁, ..., ĝ₁μ, M₂, ĝ₂₁, ..., ĝ₂μ, ..., Mₙ, ĝₙ₁, ..., ĝₙμ] + let helper_poly_refs = + interleave_helper_poly_refs::(&m_polys_base, &g_hats_embedded_bases); + let helper_witness = helper_config.commit(prover_state, &helper_poly_refs); + + #[cfg(feature = "alloc-track")] + crate::alloc_report!("commit_zk::helper_batch_commit", __snap); + + ZkWitness { + f_hat_witnesses, + helper_witness, + preprocessings: preprocessings.iter().copied().cloned().collect(), + m_polys_base, + g_hats_embedded_bases, + } + } +} diff --git a/src/protocols/whir_zk/mod.rs b/src/protocols/whir_zk/mod.rs new file mode 100644 index 00000000..880107ca --- /dev/null +++ b/src/protocols/whir_zk/mod.rs @@ -0,0 +1,315 @@ +pub mod api; +mod committer; +pub mod prefold; +mod prover; +pub mod utils; +mod verifier; + +pub use api::{ProverInput, VerifierInput}; +pub use prefold::{PrefoldGroupCommitments, PrefoldGroupInput, PrefoldLevelConfig}; +pub use utils::{HelperEvaluations, ZkParams, ZkPreprocessingPolynomials, ZkWitness}; + +#[cfg(test)] +mod tests { + + use ark_std::rand::{rngs::StdRng, SeedableRng}; + + use super::*; + use crate::{ + algebra::{ + fields::{Field64, Field64_2}, + polynomials::{CoefficientList, MultilinearPoint}, + Weights, + }, + hash, + parameters::{FoldingFactor, MultivariateParameters, ProtocolParameters, SoundnessType}, + protocols::whir::Config, + transcript::{codecs::Empty, DomainSeparator, ProverState, VerifierState}, + }; + + /// Field type used in the tests. + type F = Field64; + + /// Extension field type used in the tests. + type EF = Field64_2; + + /// What to tamper with in soundness tests. + #[derive(Clone, Copy, Debug, PartialEq)] + enum Tamper { + /// Honest run — no tampering. + None, + /// Corrupt a native evaluation passed to the verifier. + NativeEval, + /// Corrupt a byte in the middle of the serialised proof. + ProofBytes, + /// Corrupt a prefold group's evaluation claim. + PrefoldEval, + } + + /// Run a full `batch_prove_zk` → `batch_verify_zk` round-trip and return + /// whether verification succeeds. + /// + /// Each element of `group_specs` is `(arity, num_polys, num_points)`. + /// Exactly one group must have `arity == n_min` (the native group). + fn batch_api_roundtrip( + n_min: usize, + folding_factor: FoldingFactor, + soundness_type: SoundnessType, + pow_bits: usize, + group_specs: &[(usize, usize, usize)], + tamper: Tamper, + ) -> bool { + let mut rng = StdRng::seed_from_u64(12345); + + let mv_params = MultivariateParameters::::new(n_min); + let whir_params = ProtocolParameters { + initial_statement: true, + security_level: 32, + pow_bits, + folding_factor, + soundness_type, + starting_log_inv_rate: 1, + batch_size: 1, + hash_id: hash::SHA2, + }; + let main_config = Config::new(mv_params, &whir_params); + + // Build polynomial groups + let mut all_polys: Vec>> = Vec::new(); + let mut all_weights: Vec>> = Vec::new(); + let mut all_evals: Vec> = Vec::new(); + + for (gi, &(arity, num_polys, num_points)) in group_specs.iter().enumerate() { + let num_coeffs = 1usize << arity; + let polys: Vec> = (0..num_polys) + .map(|i| { + CoefficientList::new( + (0..num_coeffs) + .map(|j| F::from(((gi + 1) * 1000 + i * num_coeffs + j + 1) as u64)) + .collect(), + ) + }) + .collect(); + + let mut weights = Vec::new(); + let mut evals = Vec::new(); + for _ in 0..num_points { + let point = MultilinearPoint::rand(&mut rng, arity); + weights.push(Weights::evaluation(point.clone())); + for poly in &polys { + evals.push(poly.mixed_evaluate(main_config.embedding(), &point)); + } + } + + all_polys.push(polys); + all_weights.push(weights); + all_evals.push(evals); + } + + // Build ProverInputs + let poly_ref_vecs: Vec>> = + all_polys.iter().map(|ps| ps.iter().collect()).collect(); + + let prover_inputs: Vec> = group_specs + .iter() + .enumerate() + .map(|(gi, _)| { + ProverInput::new( + poly_ref_vecs[gi].clone(), + all_weights[gi].clone(), + all_evals[gi].clone(), + ) + }) + .collect(); + + // Prove + let ds = DomainSeparator::protocol(&main_config) + .session(&format!("Batch API Test at {}:{}", file!(), line!())) + .instance(&Empty); + let mut prover_state = ProverState::new_std(&ds); + + let _ = + main_config.batch_prove_zk(&mut prover_state, &whir_params, &prover_inputs, &mut rng); + let mut proof = prover_state.proof(); + + // Apply tampering + if tamper == Tamper::ProofBytes && proof.narg_string.len() > 10 { + let mid = proof.narg_string.len() / 2; + proof.narg_string[mid] ^= 0xFF; + } + + let mut claims: Vec> = prover_inputs + .iter() + .map(|g| g.to_verifier_input()) + .collect(); + + match tamper { + Tamper::NativeEval => { + if let Some(c) = claims + .iter_mut() + .find(|c| c.arity == n_min && !c.evaluations.is_empty()) + { + c.evaluations[0] += EF::from(42u64); + } + } + Tamper::PrefoldEval => { + if let Some(c) = claims + .iter_mut() + .find(|c| c.arity > n_min && !c.evaluations.is_empty()) + { + c.evaluations[0] += EF::from(42u64); + } + } + _ => {} + } + + // Verify + let mut verifier_state = VerifierState::new_std(&ds, &proof); + main_config + .batch_verify_zk(&mut verifier_state, &whir_params, &claims) + .is_ok() + } + + /// Various (num_variables, folding_factor) combos with 0 / 1 / 2 evaluation + /// constraints, single polynomial per group. + #[test] + fn test_batch_api_basic_configs() { + let configs: &[(usize, usize, usize)] = &[ + // (num_variables, folding_factor, num_points) + (10, 2, 0), + (10, 2, 1), + (10, 2, 2), + (12, 2, 1), + (12, 3, 1), + (12, 4, 1), + ]; + + for &(n, k, num_pts) in configs { + eprintln!(); + dbg!(n, k, num_pts); + + let ok = batch_api_roundtrip( + n, + FoldingFactor::Constant(k), + SoundnessType::ConjectureList, + 0, + &[(n, 1, num_pts)], + Tamper::None, + ); + assert!(ok, "failed for n={n}, k={k}, num_pts={num_pts}"); + } + } + + /// Multiple polynomials at native arity (batched proving / verification). + #[test] + fn test_batch_api_multi_polynomial() { + let configs: &[(usize, usize, usize, usize)] = &[ + // (num_variables, folding_factor, num_points, num_polynomials) + (10, 2, 1, 2), + (10, 2, 2, 2), + (12, 2, 1, 3), + (12, 3, 2, 2), + ]; + + for &(n, k, num_pts, num_polys) in configs { + eprintln!(); + dbg!(n, k, num_pts, num_polys); + + let ok = batch_api_roundtrip( + n, + FoldingFactor::Constant(k), + SoundnessType::ConjectureList, + 0, + &[(n, num_polys, num_pts)], + Tamper::None, + ); + assert!( + ok, + "failed for n={n}, k={k}, pts={num_pts}, polys={num_polys}" + ); + } + } + + /// Multi-arity prefold: polynomials across several arities, including a + /// group with zero constraints and fold depths 1–3. + #[test] + fn test_batch_api_multi_arity() { + let ok = batch_api_roundtrip( + 10, + FoldingFactor::Constant(2), + SoundnessType::ConjectureList, + 0, + &[ + (10, 2, 1), // native: 2 polynomials, 1 constraint + (11, 1, 1), // fold_depth = 1 + (12, 1, 1), // fold_depth = 2 + (13, 1, 0), // fold_depth = 3, no constraints + ], + Tamper::None, + ); + assert!(ok, "multi-arity prefold must verify"); + } + + /// Proof-of-work, alternative soundness types, and mixed folding factors. + #[test] + fn test_batch_api_advanced_configs() { + // PoW + let ok = batch_api_roundtrip( + 12, + FoldingFactor::Constant(2), + SoundnessType::ConjectureList, + 5, + &[(12, 1, 2)], + Tamper::None, + ); + assert!(ok, "PoW test failed"); + + // Soundness types + for st in [SoundnessType::ProvableList, SoundnessType::UniqueDecoding] { + let ok = batch_api_roundtrip( + 12, + FoldingFactor::Constant(2), + st, + 0, + &[(12, 1, 1)], + Tamper::None, + ); + assert!(ok, "soundness type {st:?} failed"); + } + + // Mixed folding + let ok = batch_api_roundtrip( + 12, + FoldingFactor::ConstantFromSecondRound(3, 3), + SoundnessType::ConjectureList, + 0, + &[(12, 1, 1)], + Tamper::None, + ); + assert!(ok, "mixed folding failed"); + } + + /// Soundness: tampered native eval, corrupted proof bytes, tampered prefold eval. + #[test] + fn test_batch_api_soundness() { + let groups = &[ + (10, 1, 1), // native + (11, 1, 1), // prefold + ]; + + for tamper in [Tamper::NativeEval, Tamper::ProofBytes, Tamper::PrefoldEval] { + eprintln!(); + dbg!(tamper); + + let ok = batch_api_roundtrip( + 10, + FoldingFactor::Constant(2), + SoundnessType::ConjectureList, + 0, + groups, + tamper, + ); + assert!(!ok, "verification must FAIL with tamper {tamper:?}"); + } + } +} diff --git a/src/protocols/whir_zk/prefold.rs b/src/protocols/whir_zk/prefold.rs new file mode 100644 index 00000000..c2543919 --- /dev/null +++ b/src/protocols/whir_zk/prefold.rs @@ -0,0 +1,688 @@ +//! Staged ZK Folding (Prefold) for mixed-arity polynomial batching. +//! +//! This module implements the prefold approach: given N polynomials at varying +//! arities, fold each down to the minimum arity using sumcheck-derived randomness, +//! then batch-prove all polynomials at the common arity. +//! +//! ## Architecture +//! +//! For each polynomial at arity L > n_min: +//! 1. Commit f̂ = f + msk at arity L (base field, ZK) +//! 2. Build P = ρ·f + g (blinded polynomial at arity L) +//! 3. Run prefold sumcheck on P's constraints → fold randomness +//! 4. Fold: P' = fold(P, fold_randomness) → arity n_min (extension field) +//! 5. Send P' coefficients in the clear (ZK-blinded by g) +//! 6. **Binding equation**: verify v' = Σ_i rlc_i · eq(a_i_high, r) · P'(a_i_low) +//! 7. Open f̂ → virtual oracle → fold → STIR consistency against P' +//! +//! The binding equation (step 6) closes the soundness gap by ensuring the +//! sumcheck's reduced claim matches the actual P' polynomial. The STIR +//! consistency check (step 7) ensures P' is the correct fold of the committed f̂. +//! +//! Then the main WHIR at arity n_min handles all native polynomials. + +use ark_ff::{FftField, PrimeField}; +use ark_std::rand::{CryptoRng, RngCore}; + +use super::utils::{ + interleave_helper_poly_refs, prepare_helper_polynomials, IrsDomainParams, ZkParams, + ZkPreprocessingPolynomials, ZkWitness, +}; +use crate::{ + algebra::{ + add_base_with_projection, dot, + polynomials::{CoefficientList, EvaluationsList, MultilinearPoint}, + Weights, + }, + bits::Bits, + hash::Hash, + parameters::ProtocolParameters, + protocols::{ + geometric_challenge::geometric_challenge, irs_commit, matrix_commit, proof_of_work, + sumcheck, whir::Config, + }, + transcript::{ + codecs::U64, Codec, Decoding, DuplexSpongeInterface, ProverMessage, ProverState, + VerificationResult, VerifierMessage, VerifierState, + }, + type_info::Type, + utils::zip_strict, + verify, +}; + +/// Configuration for a single prefold arity level (polynomials at arity > n_min). +/// +/// Each distinct arity above n_min needs its own config: an IRS committer for the +/// original f̂ commitment, a helper WHIR config for the virtual oracle sub-proof, +/// and a sumcheck config. After folding, P' coefficients are sent in the clear +/// (they are ZK-blinded by construction) so no separate P' committer is needed. +#[derive(Clone)] +pub struct PrefoldLevelConfig { + /// IRS committer for f̂ at this arity (base field → extension field). + pub f_hat_committer: irs_commit::BasefieldConfig, + + /// Helper WHIR config for virtual oracle proof at this arity level. + pub helper_config: Config, + + /// ZK parameters (ℓ, μ) for this arity level. + pub zk_params: ZkParams, + + /// Sumcheck config for the prefold (folds `fold_depth` variables). + pub prefold_sumcheck: sumcheck::Config, + + /// Proof-of-work config for the prefold STIR queries. + pub prefold_pow: proof_of_work::Config, + + /// The arity at this level (number of variables). + pub arity: usize, + + /// Number of extra variables to fold away (arity − n_min). + pub fold_depth: usize, +} + +/// Input for a group of polynomials at the same prefold arity level. +pub struct PrefoldGroupInput<'a, F: FftField> { + /// Base-field polynomials at this arity level. + pub polynomials: &'a [&'a CoefficientList], + /// ZK witness (f̂ + helper commitments at this arity). + pub witness: &'a ZkWitness, + /// Constraint weights at this arity level. + pub weights: &'a [&'a Weights], + /// Evaluations: row-major \[w₀\_p₀, w₀\_p₁, ..., w₁\_p₀, ...\]. + pub evaluations: &'a [F], + /// Level config for this arity. + pub level_config: &'a PrefoldLevelConfig, +} + +impl PrefoldLevelConfig +where + F: crate::algebra::fields::FieldWithSize, +{ + /// Build a prefold level config for polynomials at `arity` > `n_min`. + /// + /// `main_config` is the WHIR config at n_min. `whir_params` provides the + /// security/folding parameters. + pub fn new(main_config: &Config, arity: usize, whir_params: &ProtocolParameters) -> Self { + let n_min = main_config.initial_num_variables(); + let fold_depth = arity.checked_sub(n_min).expect("arity must be > n_min"); + assert!(fold_depth > 0, "fold_depth must be positive"); + + // f̂ IRS committer at this arity (base field) + // Interleaving depth = 2^fold_depth so that folding gives 1 value per query. + let interleaving_depth = 1usize << fold_depth; + let polynomial_size = 1usize << arity; + let expansion = main_config.initial_committer.expansion; + let num_rows = polynomial_size * expansion / interleaving_depth; + + let f_hat_committer = irs_commit::Config { + embedding: Default::default(), + num_polynomials: whir_params.batch_size, + polynomial_size, + expansion, + interleaving_depth, + matrix_commit: matrix_commit::Config::with_hash( + whir_params.hash_id, + num_rows, + whir_params.batch_size * interleaving_depth, + ), + in_domain_samples: main_config.initial_committer.in_domain_samples, + out_domain_samples: main_config.initial_committer.out_domain_samples, + deduplicate_in_domain: true, + }; + + // ZK params for this arity + let zk_params = ZkParams::from_whir_params_with_arity(main_config, arity); + + // Helper WHIR config (shared with api.rs) + let helper_config = zk_params.build_helper_config(whir_params.batch_size, whir_params); + + // Prefold sumcheck + // Folds `fold_depth` variables of the blinded polynomial at arity L. + let prefold_sumcheck = sumcheck::Config { + field: Type::::new(), + initial_size: polynomial_size, + round_pow: proof_of_work::Config::from_difficulty(Bits::new(0.0)), + num_rounds: fold_depth, + }; + + // Prefold PoW (minimal for first implementation) + let prefold_pow = proof_of_work::Config::from_difficulty(Bits::new(0.0)); + + Self { + f_hat_committer, + helper_config, + zk_params, + prefold_sumcheck, + prefold_pow, + arity, + fold_depth, + } + } +} + +impl ZkParams { + /// Compute ZK parameters for a given arity, using the main config's query + /// parameters as reference. + pub fn from_whir_params_with_arity(main_config: &Config, arity: usize) -> Self { + let mu = arity; + let k = main_config.initial_committer.interleaving_depth; + let q1 = main_config + .round_configs + .first() + .map_or(main_config.initial_committer.in_domain_samples, |r| { + r.irs_committer.in_domain_samples + }); + + let q_ub = 2 * k * q1 + 4 * mu + 10; + let ell = (q_ub as f64).log2().ceil() as usize; + assert!( + ell < mu, + "ZK requires ℓ < μ (ℓ={ell}, μ={mu}). \ + Increase arity or lower security_level/queries." + ); + Self { ell, mu } + } +} + +/// Commit polynomials at a prefold arity level. +/// +/// This is analogous to `Config::commit_zk` but uses the prefold level's IRS +/// committer (at the higher arity) instead of the main config's. +pub fn commit_zk_at_level( + level_config: &PrefoldLevelConfig, + prover_state: &mut ProverState, + polynomials: &[&CoefficientList], + preprocessings: &[&ZkPreprocessingPolynomials], +) -> ZkWitness +where + F: FftField, + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + F: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, +{ + // 1. Commit f̂ᵢ = fᵢ + mskᵢ at the prefold arity + let mut f_hat_witnesses = Vec::new(); + for (polynomial, preprocessing) in zip_strict(polynomials, preprocessings) { + let f_hat_coeffs = + add_base_with_projection::(polynomial.coeffs(), preprocessing.msk.coeffs()); + let f_hat = CoefficientList::new(f_hat_coeffs); + let poly_refs: Vec<&[F::BasePrimeField]> = vec![f_hat.coeffs()]; + let f_hat_witness = level_config + .f_hat_committer + .commit(prover_state, &poly_refs); + f_hat_witnesses.push(f_hat_witness); + } + + // 2. Prepare helper polynomials (shared with commit_zk) + let (m_polys_base, g_hats_embedded_bases) = prepare_helper_polynomials(preprocessings); + + // 3. Batch-commit helpers via the level's helper config + let helper_poly_refs = interleave_helper_poly_refs::(&m_polys_base, &g_hats_embedded_bases); + let helper_witness = level_config + .helper_config + .commit(prover_state, &helper_poly_refs); + + ZkWitness { + f_hat_witnesses, + helper_witness, + preprocessings: preprocessings.iter().copied().cloned().collect(), + m_polys_base, + g_hats_embedded_bases, + } +} + +/// Receive commitments for a prefold level on the verifier side. +pub fn receive_prefold_commitments( + level_config: &PrefoldLevelConfig, + verifier_state: &mut VerifierState<'_, H>, + num_polynomials: usize, +) -> VerificationResult<(Vec>, irs_commit::Commitment)> +where + F: FftField, + H: DuplexSpongeInterface, + F: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, +{ + // Read f̂ commitments (one per polynomial) + let f_hat_commitments: Vec<_> = (0..num_polynomials) + .map(|_| { + level_config + .f_hat_committer + .receive_commitment(verifier_state) + }) + .collect::>()?; + + // Read helper batch commitment + let helper_commitment = level_config + .helper_config + .receive_commitment(verifier_state)?; + + Ok((f_hat_commitments, helper_commitment)) +} + +impl Config { + /// Full ZK prefold + prove pipeline for mixed-arity polynomials. + /// + /// # Architecture + /// + /// 1. Phase 1: ZK blinding (shared β, per-group g evaluations, shared ρ) + /// 2. Phase 2: For each prefold group (highest arity first): + /// - RLC polynomials at this level + /// - Prefold sumcheck (if constraints) → fold randomness + /// - Fold P → P' at n_min + /// - Commit P' via extension-field IRS + /// - Open f̂ → virtual oracle → fold → STIR consistency values + /// 3. Phase 3: Standard prove_zk on native polynomials + /// + /// # Arguments + /// + /// * `self` — Main WHIR config at arity n_min. + /// * `native_polys`, `native_witness`, etc. — Native-arity group. + /// * `prefold_groups` — Higher-arity groups sorted by decreasing arity. + #[allow(clippy::too_many_arguments, clippy::too_many_lines)] + pub fn prove_zk_prefold( + &self, + prover_state: &mut ProverState, + // Native group (arity = n_min) + native_polys: &[&CoefficientList], + native_witness: &ZkWitness, + native_helper_config: &Config, + native_weights: &[&Weights], + native_evals: &[F], + // Prefold groups (highest arity first) + prefold_groups: &[PrefoldGroupInput<'_, F>], + ) -> (MultilinearPoint, Vec) + where + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + F: Codec<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + u8: Decoding<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + let n_min = self.initial_num_variables(); + let num_native = native_polys.len(); + + // ================================================================ + // Phase 1: ZK blinding — shared β, per-group g evals, shared ρ + // ================================================================ + let beta: F = prover_state.verifier_message(); + + // Build g and send evaluations for native group + let mut native_g_polys = Vec::with_capacity(num_native); + let mu_native = native_witness.preprocessings[0].params.mu; + for preprocessing in &native_witness.preprocessings { + native_g_polys.push(self.build_blinding_polynomial(preprocessing, mu_native, beta)); + } + let mut native_g_eval_matrix = vec![F::ZERO; native_weights.len() * num_native]; + for (i, weight) in native_weights.iter().enumerate() { + for (j, g_poly) in native_g_polys.iter().enumerate() { + let eval = weight.evaluate(g_poly); + prover_state.prover_message(&eval); + native_g_eval_matrix[i * num_native + j] = eval; + } + } + + // Build g and send evaluations for each prefold group + let mut prefold_g_polys_per_group: Vec>> = Vec::new(); + let mut prefold_g_eval_matrices: Vec> = Vec::new(); + for group in prefold_groups { + let mu_level = group.level_config.arity; + let num_polys = group.polynomials.len(); + let mut g_polys = Vec::with_capacity(num_polys); + for preprocessing in &group.witness.preprocessings { + g_polys.push(self.build_blinding_polynomial(preprocessing, mu_level, beta)); + } + let mut g_eval_matrix = vec![F::ZERO; group.weights.len() * num_polys]; + for (i, weight) in group.weights.iter().enumerate() { + for (j, g_poly) in g_polys.iter().enumerate() { + let eval = weight.evaluate(g_poly); + prover_state.prover_message(&eval); + g_eval_matrix[i * num_polys + j] = eval; + } + } + prefold_g_polys_per_group.push(g_polys); + prefold_g_eval_matrices.push(g_eval_matrix); + } + + let rho: F = prover_state.verifier_message(); + + // Build P = ρ·f + g for each polynomial + let mut native_p_polys: Vec> = Vec::with_capacity(num_native); + for (polynomial, g_poly) in zip_strict(native_polys.iter(), native_g_polys.into_iter()) { + native_p_polys.push(self.build_blinded_polynomial_p(g_poly, polynomial, rho)); + } + + let mut prefold_p_polys: Vec>> = Vec::new(); + for (group_idx, group) in prefold_groups.iter().enumerate() { + let g_polys = std::mem::take(&mut prefold_g_polys_per_group[group_idx]); + let mut p_polys = Vec::new(); + for (polynomial, g_poly) in zip_strict(group.polynomials.iter(), g_polys.into_iter()) { + p_polys.push(self.build_blinded_polynomial_p(g_poly, polynomial, rho)); + } + prefold_p_polys.push(p_polys); + } + + // ================================================================ + // Phase 2: Prefold stages — fold each group to arity n_min + // ================================================================ + for (group_idx, group) in prefold_groups.iter().enumerate() { + let level_config = group.level_config; + let num_polys = group.polynomials.len(); + let fold_depth = level_config.fold_depth; + + // RLC polynomials at this level + let level_poly_rlc: Vec = geometric_challenge(prover_state, num_polys); + let mut p_combined = { + let p_polys = &prefold_p_polys[group_idx]; + let mut acc = CoefficientList::new(vec![F::ZERO; p_polys[0].num_coeffs()]); + for (rlc, poly) in zip_strict(&level_poly_rlc, p_polys) { + crate::algebra::scalar_mul_add(acc.coeffs_mut(), *rlc, poly.coeffs()); + } + acc + }; + + // Modified evaluations + let modified_evals: Vec = group + .evaluations + .iter() + .zip(prefold_g_eval_matrices[group_idx].iter()) + .map(|(&eval, &g_eval)| rho * eval + g_eval) + .collect(); + + // Prefold sumcheck or random fold + let fold_randomness; + let has_constraints = !group.weights.is_empty(); + if has_constraints { + let constraint_rlc: Vec = geometric_challenge(prover_state, group.weights.len()); + let mut constraints = EvaluationsList::new(vec![F::ZERO; 1 << level_config.arity]); + for (rlc, weight) in zip_strict(&constraint_rlc, group.weights) { + weight.accumulate(&mut constraints, *rlc); + } + let mut the_sum: F = + zip_strict(&constraint_rlc, modified_evals.chunks_exact(num_polys)) + .map(|(w, row)| *w * dot(&level_poly_rlc, row)) + .sum(); + + let mut eval_list = EvaluationsList::from(p_combined.clone()); + fold_randomness = level_config.prefold_sumcheck.prove( + prover_state, + &mut eval_list, + &mut constraints, + &mut the_sum, + ); + + // Fold P → P' at arity n_min + p_combined.fold_in_place(&fold_randomness); + let p_prime_ref = &p_combined; + debug_assert_eq!(p_prime_ref.num_variables(), n_min); + + // Prover-side binding equation sanity check + #[cfg(debug_assertions)] + { + let mut expected = F::ZERO; + for (i, weight) in group.weights.iter().enumerate() { + if let Weights::Evaluation { ref point } = weight { + let a_low = MultilinearPoint(point.0[..n_min].to_vec()); + let a_high = MultilinearPoint(point.0[n_min..].to_vec()); + let eq_factor = a_high.eq_poly_outside(&fold_randomness); + let p_prime_eval = p_prime_ref.evaluate(&a_low); + expected += constraint_rlc[i] * eq_factor * p_prime_eval; + } + } + assert_eq!(the_sum, expected, "[PROVER] Binding equation mismatch"); + } + } else { + // No constraints — sample fold randomness directly + let r: Vec = (0..fold_depth) + .map(|_| prover_state.verifier_message()) + .collect(); + level_config.prefold_pow.prove(prover_state); + fold_randomness = MultilinearPoint(r); + + // Fold P → P' at arity n_min + p_combined.fold_in_place(&fold_randomness); + debug_assert_eq!(p_combined.num_variables(), n_min); + }; + let p_prime = p_combined; + + // Send P' coefficients in the clear + // P' = fold(ρ·f + g, r) is ZK-blinded by g, so revealing + // coefficients does not leak information about f. + // The coefficients are absorbed into the Fiat-Shamir transcript, + // binding the prover to this specific P'. + { + let p_prime_coeffs = p_prime.coeffs(); + let num_coeffs = p_prime_coeffs.len(); + let base_field_size = (F::BasePrimeField::MODULUS_BIT_SIZE.div_ceil(8)) as usize; + let elem_bytes = base_field_size * F::extension_degree() as usize; + let mut encoded = Vec::with_capacity(num_coeffs * elem_bytes); + for c in p_prime_coeffs { + crate::transcript::encode_field_element_into(c, &mut encoded); + } + prover_state.prover_messages_bytes::(num_coeffs, &encoded); + } + + // PoW + level_config.prefold_pow.prove(prover_state); + + // Open f̂ at native arity → virtual oracle → fold → STIR consistency + let f_hat_refs: Vec<_> = group.witness.f_hat_witnesses.iter().collect(); + let in_domain_base = level_config.f_hat_committer.open(prover_state, &f_hat_refs); + + // Prove helper evaluations for the virtual oracle at this level + // (shared implementation with the native ZK prover) + let domain = IrsDomainParams::from_irs_committer(&level_config.f_hat_committer); + super::prover::prove_helper_evaluations( + prover_state, + &domain, + &in_domain_base, + group.witness, + &level_config.helper_config, + rho, + self.embedding(), + ); + } + + // ================================================================ + // Phase 3: Standard prove_zk on native polynomials + // ================================================================ + // The prefold groups are fully proven by their sumcheck + STIR consistency. + // The native group is proven by the standard ZK-WHIR protocol. + self.prove_zk( + prover_state, + native_polys, + native_witness, + native_helper_config, + native_weights, + native_evals, + ) + } +} + +/// Commitments for a prefold group as seen by the verifier. +pub struct PrefoldGroupCommitments { + /// f̂ commitments at this level's arity. + pub f_hat_commitments: Vec>, + /// Helper commitment for this level. + pub helper_commitment: irs_commit::Commitment, +} + +impl Config { + /// Verify a ZK prefold + prove proof for mixed-arity polynomials. + /// + /// Mirrors `prove_zk_prefold` on the verifier side. + #[allow(clippy::too_many_arguments, clippy::too_many_lines)] + pub fn verify_zk_prefold( + &self, + verifier_state: &mut VerifierState<'_, H>, + // Native group + native_f_hat_commitments: &[&irs_commit::Commitment], + native_helper_commitment: &irs_commit::Commitment, + native_helper_config: &Config, + native_zk_params: &ZkParams, + native_weights: &[&Weights], + native_evals: &[F], + // Prefold groups (same order as prover: highest arity first) + prefold_groups: &[(&PrefoldGroupCommitments, &PrefoldLevelConfig)], + prefold_group_weights: &[&[&Weights]], + prefold_group_evals: &[&[F]], + prefold_num_polys: &[usize], + ) -> VerificationResult<(MultilinearPoint, Vec)> + where + H: DuplexSpongeInterface, + F: Codec<[H::U]>, + u8: Decoding<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + // ================================================================ + // Phase 1: Read β, g evals, ρ + // ================================================================ + let beta: F = verifier_state.verifier_message(); + + // Read native g evaluations + let num_native = native_f_hat_commitments.len(); + let _native_g_evals: Vec = + verifier_state.prover_messages_vec(native_weights.len() * num_native)?; + + // Read prefold g evaluations + let mut prefold_g_evals_per_group: Vec> = Vec::new(); + for (group_idx, &num_polys) in prefold_num_polys.iter().enumerate() { + let num_weights = prefold_group_weights[group_idx].len(); + let g_evals: Vec = verifier_state.prover_messages_vec(num_weights * num_polys)?; + prefold_g_evals_per_group.push(g_evals); + } + + let rho: F = verifier_state.verifier_message(); + + // ================================================================ + // Phase 2: Verify prefold stages + // ================================================================ + let n_min = self.initial_num_variables(); + + for (group_idx, (commitments, level_config)) in prefold_groups.iter().enumerate() { + let weights = prefold_group_weights[group_idx]; + let evals = prefold_group_evals[group_idx]; + let num_polys = prefold_num_polys[group_idx]; + let fold_depth = level_config.fold_depth; + + // Modified evaluations for this group + let modified_evals: Vec = + zip_strict(evals.iter(), prefold_g_evals_per_group[group_idx].iter()) + .map(|(&eval, &g_eval)| rho * eval + g_eval) + .collect(); + + // RLC + let level_poly_rlc: Vec = geometric_challenge(verifier_state, num_polys); + + // Prefold sumcheck or random fold + let (fold_randomness, reduced_sum, constraint_rlc) = if !weights.is_empty() { + let constraint_rlc: Vec = geometric_challenge(verifier_state, weights.len()); + let mut the_sum: F = + zip_strict(&constraint_rlc, modified_evals.chunks_exact(num_polys)) + .map(|(w, row)| *w * dot(&level_poly_rlc, row)) + .sum(); + + let fold_rand = level_config + .prefold_sumcheck + .verify(verifier_state, &mut the_sum)?; + (fold_rand, Some(the_sum), Some(constraint_rlc)) + } else { + let r: Vec = (0..fold_depth) + .map(|_| verifier_state.verifier_message()) + .collect(); + level_config.prefold_pow.verify(verifier_state)?; + (MultilinearPoint(r), None, None) + }; + + // Read P' coefficients (sent in the clear by the prover) + let p_prime_coeffs: Vec = verifier_state.read_prover_messages_bytes(1 << n_min)?; + let p_prime = CoefficientList::new(p_prime_coeffs); + + // Binding equation check + // After the prefold sumcheck, the reduced sum v' must satisfy: + // v' = Σ_i constraint_rlc[i] · eq(a_i_high, r) · P'(a_i_low) + // where a_i is the evaluation point for weight i, split into + // a_i_high (first fold_depth components) and a_i_low (rest). + if let (Some(v_prime), Some(ref c_rlc)) = (reduced_sum, &constraint_rlc) { + let mut expected = F::ZERO; + for (i, weight) in weights.iter().enumerate() { + let point = match weight { + Weights::Evaluation { point } => point, + _ => panic!( + "prefold binding equation requires Weights::Evaluation; \ + got a non-evaluation weight at index {i}" + ), + }; + // In MultilinearPoint, point[0] is the MSB (x_{L-1}). + // The fold eliminates the LSB variables x_0,...,x_{d-1} + // which are the LAST fold_depth elements: point[n_min..]. + // The remaining n_min MSB variables are point[..n_min]. + let a_low = MultilinearPoint(point.0[..n_min].to_vec()); + let a_high = MultilinearPoint(point.0[n_min..].to_vec()); + let eq_factor = a_high.eq_poly_outside(&fold_randomness); + let p_prime_eval = p_prime.evaluate(&a_low); + expected += c_rlc[i] * eq_factor * p_prime_eval; + } + verify!(v_prime == expected); + } + + // PoW + level_config.prefold_pow.verify(verifier_state)?; + + // Verify f̂ opening + let f_hat_refs: Vec<&irs_commit::Commitment> = + commitments.f_hat_commitments.iter().collect(); + let in_domain_base = level_config + .f_hat_committer + .verify(verifier_state, &f_hat_refs)?; + + // Verify helper evaluations and reconstruct virtual oracle + // (shared implementation with the native ZK verifier) + let domain = IrsDomainParams::from_irs_committer(&level_config.f_hat_committer); + let virtual_values = super::verifier::verify_helper_evaluations( + verifier_state, + &domain, + &in_domain_base, + &commitments.helper_commitment, + &level_config.helper_config, + &level_config.zk_params, + rho, + beta, + &fold_randomness, + num_polys, + &level_poly_rlc, + self.embedding(), + )?; + + // ── STIR consistency check ── + // The folded virtual oracle at each query point α must match + // P' evaluated at the multilinear expansion of α. + // Since we have P' coefficients, we evaluate directly (no hints). + for (qi, &alpha_base) in in_domain_base.points.iter().enumerate() { + let alpha_ext: F = + crate::algebra::embedding::Embedding::map(self.embedding(), alpha_base); + let point = MultilinearPoint::expand_from_univariate(alpha_ext, n_min); + let p_prime_at_point = p_prime.evaluate(&point); + verify!(virtual_values[qi] == p_prime_at_point); + } + } + + // ================================================================ + // Phase 3: Verify main WHIR on native polynomials + // ================================================================ + self.verify_zk( + verifier_state, + native_f_hat_commitments, + native_helper_commitment, + native_helper_config, + native_zk_params, + native_weights, + native_evals, + ) + } +} diff --git a/src/protocols/whir_zk/prover.rs b/src/protocols/whir_zk/prover.rs new file mode 100644 index 00000000..3d46e865 --- /dev/null +++ b/src/protocols/whir_zk/prover.rs @@ -0,0 +1,533 @@ +use ark_ff::{FftField, PrimeField}; +use ark_std::rand::{CryptoRng, RngCore}; +#[cfg(feature = "tracing")] +use tracing::instrument; + +use super::utils::{ + compute_per_polynomial_claims, construct_batched_eq_weights, interleave_helper_poly_refs, + IrsDomainParams, ZkWitness, +}; +use crate::{ + algebra::{ + dot, mixed_scalar_mul_add, + polynomials::{CoefficientList, EvaluationsList, MultilinearPoint}, + Weights, + }, + hash::Hash, + protocols::{geometric_challenge::geometric_challenge, irs_commit, whir::Config}, + transcript::{ + self, codecs::U64, Codec, Decoding, DuplexSpongeInterface, ProverMessage, ProverState, + VerifierMessage, + }, + utils::zip_strict, +}; + +impl Config { + /// Prove a ZK WHIR opening. + /// + /// This proves knowledge of a polynomial `f` by: + /// 1. Blinding with g to form P = ρ·f + g + /// 2. Running WHIR rounds on P with a virtual oracle L = ρ·f̂ + h + /// 3. Proving helper polynomial evaluations so verifier can reconstruct L + #[cfg_attr(feature = "tracing", instrument(skip_all, fields(num_polynomials = polynomials.len())))] + pub fn prove_zk( + &self, + prover_state: &mut ProverState, + polynomials: &[&CoefficientList], + witness: &ZkWitness, + helper_config: &Config, + weights: &[&Weights], + evaluations: &[F], + ) -> (MultilinearPoint, Vec) + where + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + F: Codec<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + u8: Decoding<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + #[cfg(feature = "alloc-track")] + let mut __snap = crate::alloc_snap!(); + + let mu = witness.preprocessings[0].params.mu; + let num_polys = polynomials.len(); + + // Phase 1: ZK blinding setup — build g, evaluate at constraints, form P = ρ·f + g + let beta: F = prover_state.verifier_message(); + let mut g_polys = Vec::with_capacity(num_polys); + for (_polynomial, preprocessing) in zip_strict(polynomials, &witness.preprocessings) { + let g_poly = self.build_blinding_polynomial(preprocessing, mu, beta); + g_polys.push(g_poly); + } + + // Evaluate each gⱼ at each weight and send to verifier. + // Layout: row-major [weight₀_poly₀, weight₀_poly₁, ..., weight₁_poly₀, ...] + // This matches the evaluations matrix layout. + let mut g_eval_matrix = vec![F::ZERO; weights.len() * num_polys]; + for (i, weight) in weights.iter().enumerate() { + for (j, g_poly) in g_polys.iter().enumerate() { + let eval = weight.evaluate(g_poly); + prover_state.prover_message(&eval); + g_eval_matrix[i * num_polys + j] = eval; + } + } + + let rho: F = prover_state.verifier_message(); + + // Build Pᵢ = ρ·fᵢ + gᵢ for each polynomial + let mut p_polys = Vec::with_capacity(num_polys); + for (polynomial, g_poly) in zip_strict(polynomials, g_polys) { + let p_poly = self.build_blinded_polynomial_p(g_poly, polynomial, rho); + p_polys.push(p_poly); + } + + // RLC the polynomials: P₀ = Σ αᵢ · Pᵢ + let polynomial_rlc_coeffs: Vec = geometric_challenge(prover_state, num_polys); + let mut p_poly = { + let mut iter = p_polys.into_iter(); + let mut acc = iter.next().unwrap(); + for (rlc_coeff, src_poly) in zip_strict(&polynomial_rlc_coeffs[1..], iter) { + crate::algebra::scalar_mul_add(acc.coeffs_mut(), *rlc_coeff, src_poly.coeffs()); + } + acc + }; + + #[cfg(feature = "alloc-track")] + crate::alloc_report!("prove_zk::phase1_blinding_setup", __snap); + + // Phase 2: Build modified evaluations and run initial sumcheck + // modified_evaluations[w * N + p] = ρ · evaluations[w * N + p] + g_eval_matrix[w * N + p] + let modified_evaluations: Vec = evaluations + .iter() + .zip(g_eval_matrix.iter()) + .map(|(&eval, &g_eval)| rho * eval + g_eval) + .collect(); + let constraint_rlc_coeffs: Vec = geometric_challenge(prover_state, weights.len()); + let mut constraints = EvaluationsList::new(vec![F::ZERO; self.initial_size()]); + for (rlc_coeff, weight) in zip_strict(&constraint_rlc_coeffs, weights) { + weight.accumulate(&mut constraints, *rlc_coeff); + } + + // Compute "The Sum": Σ_w rlc_w * dot(poly_rlc, modified_evaluations[w*N..(w+1)*N]) + let mut the_sum: F = zip_strict( + &constraint_rlc_coeffs, + modified_evaluations.chunks_exact(num_polys), + ) + .map(|(weight_coeff, row)| *weight_coeff * dot(&polynomial_rlc_coeffs, row)) + .sum(); + + let mut eval_list = EvaluationsList::from(p_poly.clone()); + let mut folding_randomness = if constraint_rlc_coeffs.is_empty() { + let fr = (0..self.initial_sumcheck.num_rounds) + .map(|_| prover_state.verifier_message()) + .collect(); + self.initial_sumcheck.round_pow.prove(prover_state); + constraints = EvaluationsList::new(vec![F::ZERO; self.initial_sumcheck.final_size()]); + MultilinearPoint(fr) + } else { + self.initial_sumcheck.prove( + prover_state, + &mut eval_list, + &mut constraints, + &mut the_sum, + ) + }; + + p_poly.fold_in_place(&folding_randomness); + let mut coefficients = p_poly; + if constraint_rlc_coeffs.is_empty() { + eval_list = EvaluationsList::from(coefficients.clone()); + } + let mut randomness_vec = Vec::with_capacity(mu); + randomness_vec.extend(folding_randomness.0.iter().rev().copied()); + debug_assert_eq!(eval_list, EvaluationsList::from(coefficients.clone())); + debug_assert_eq!(dot(eval_list.evals(), constraints.evals()), the_sum); + + #[cfg(feature = "alloc-track")] + crate::alloc_report!("prove_zk::phase2_initial_sumcheck", __snap); + + // Phase 3: WHIR round loop + let mut prev_is_initial = true; + let mut prev_round_witness: Option> = None; + + for (round_index, round_config) in self.round_configs.iter().enumerate() { + let round_witness = round_config + .irs_committer + .commit(prover_state, &[coefficients.coeffs()]); + round_config.pow.prove(prover_state); + + let num_variables = round_config.initial_num_variables(); + + let (in_domain, stir_evaluations) = if prev_is_initial { + self.open_initial_zk_round( + prover_state, + witness, + helper_config, + rho, + &coefficients, + &round_witness, + num_variables, + ) + } else { + self.open_subsequent_round( + prover_state, + round_index, + prev_round_witness.as_ref().unwrap(), + &round_witness, + &folding_randomness, + ) + }; + + let stir_challenges: Vec<_> = round_witness + .out_of_domain() + .weights(num_variables) + .chain(in_domain.weights(num_variables)) + .collect(); + + let stir_rlc_coeffs = geometric_challenge(prover_state, stir_challenges.len()); + for (coeff, w) in zip_strict(&stir_rlc_coeffs, &stir_challenges) { + w.accumulate(&mut constraints, *coeff); + } + the_sum += dot(&stir_rlc_coeffs, &stir_evaluations); + debug_assert_eq!(eval_list, EvaluationsList::from(coefficients.clone())); + debug_assert_eq!(dot(eval_list.evals(), constraints.evals()), the_sum); + + folding_randomness = round_config.sumcheck.prove( + prover_state, + &mut eval_list, + &mut constraints, + &mut the_sum, + ); + coefficients.fold_in_place(&folding_randomness); + randomness_vec.extend(folding_randomness.0.iter().rev()); + debug_assert_eq!(eval_list, EvaluationsList::from(coefficients.clone())); + debug_assert_eq!(dot(eval_list.evals(), constraints.evals()), the_sum); + + prev_is_initial = false; + prev_round_witness = Some(round_witness); + } + + #[cfg(feature = "alloc-track")] + crate::alloc_report!("prove_zk::phase3_whir_rounds", __snap); + + // Phase 4: Final round — send coefficients, PoW, open last commitment + self.send_final_coefficients(prover_state, &coefficients); + self.final_pow.prove(prover_state); + + if prev_is_initial { + let f_hat_refs: Vec<_> = witness.f_hat_witnesses.iter().collect(); + let in_domain_base = self.initial_committer.open(prover_state, &f_hat_refs); + self.prove_zk_helper_evaluations( + prover_state, + &in_domain_base, + witness, + helper_config, + rho, + ); + } else { + let prev_config = self.round_configs.last().unwrap(); + let _in_domain = prev_config + .irs_committer + .open(prover_state, &[prev_round_witness.as_ref().unwrap()]); + } + + #[cfg(feature = "alloc-track")] + crate::alloc_report!("prove_zk::phase4_final_opening", __snap); + + // Phase 5: Final sumcheck and deferred constraint hints + let final_folding_randomness = + self.final_sumcheck + .prove(prover_state, &mut eval_list, &mut constraints, &mut the_sum); + randomness_vec.extend(final_folding_randomness.0.iter().rev()); + + #[cfg(feature = "alloc-track")] + crate::alloc_report!("prove_zk::phase5_final_sumcheck", __snap); + + self.compute_deferred_hints(prover_state, weights, &randomness_vec) + } + + /// Build the blinding polynomial g(X) = g₀(X) + Σᵢ₌₁^μ βⁱ · X^(2^(i-1)) · ĝᵢ(X) + /// + /// Returns the blinding polynomial g as a `CoefficientList`. + pub(crate) fn build_blinding_polynomial( + &self, + preprocessing: &super::utils::ZkPreprocessingPolynomials, + mu: usize, + beta: F, + ) -> CoefficientList { + let poly_size = 1 << mu; + let mut coeffs = vec![F::ZERO; poly_size]; + let g0_coeffs = preprocessing.g0_hat.coeffs(); + coeffs[..g0_coeffs.len()].copy_from_slice(g0_coeffs); + + let mut beta_power = beta; + for i in 1..=mu { + let shift = 1 << (i - 1); + let g_hat_coeffs = preprocessing.g_hats[i - 1].coeffs(); + let target = &mut coeffs[shift..shift + g_hat_coeffs.len()]; + crate::algebra::scalar_mul_add(target, beta_power, g_hat_coeffs); + beta_power *= beta; + } + + CoefficientList::new(coeffs) + } + + /// Transform g → P = ρ·f + g in-place: P(X) = ρ·embed(f(X)) + g(X). + pub(crate) fn build_blinded_polynomial_p( + &self, + g_poly: CoefficientList, + polynomial: &CoefficientList, + rho: F, + ) -> CoefficientList { + let mut coeffs = g_poly.into_coeffs(); + let f_coeffs = polynomial.coeffs(); + mixed_scalar_mul_add( + self.embedding(), + &mut coeffs[..f_coeffs.len()], + rho, + f_coeffs, + ); + CoefficientList::new(coeffs) + } + + /// Open the initial f̂ commitment in ZK mode: open f̂, prove helper evaluations, + /// and compute virtual oracle folded values. + /// + /// Returns `(in_domain_evaluations, stir_evaluation_values)`. + fn open_initial_zk_round( + &self, + prover_state: &mut ProverState, + witness: &ZkWitness, + helper_config: &Config, + rho: F, + coefficients: &CoefficientList, + round_witness: &irs_commit::Witness, + num_variables: usize, + ) -> (irs_commit::Evaluations, Vec) + where + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + F: Codec<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + u8: Decoding<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + let f_hat_refs: Vec<_> = witness.f_hat_witnesses.iter().collect(); + let in_domain_base = self.initial_committer.open(prover_state, &f_hat_refs); + + self.prove_zk_helper_evaluations( + prover_state, + &in_domain_base, + witness, + helper_config, + rho, + ); + + let in_domain = in_domain_base.lift(self.embedding()); + + // Virtual oracle values: evaluate folded P at each query point. + // L and P agree on the evaluation domain, so fold_k(L, r̄)(α) = P_folded(α). + let virtual_values: Vec = in_domain + .points + .iter() + .map(|&alpha| { + let point = MultilinearPoint::expand_from_univariate(alpha, num_variables); + coefficients.evaluate(&point) + }) + .collect(); + + let evals: Vec = round_witness + .out_of_domain() + .values(&[F::ONE]) + .chain(virtual_values) + .collect(); + + (in_domain, evals) + } + + /// Open a subsequent (non-initial) round's commitment. + /// + /// Returns `(in_domain_evaluations, stir_evaluation_values)`. + fn open_subsequent_round( + &self, + prover_state: &mut ProverState, + round_index: usize, + prev_witness: &irs_commit::Witness, + round_witness: &irs_commit::Witness, + folding_randomness: &MultilinearPoint, + ) -> (irs_commit::Evaluations, Vec) + where + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + F: Codec<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + u8: Decoding<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + let prev_round_config = &self.round_configs[round_index - 1]; + let in_domain = prev_round_config + .irs_committer + .open(prover_state, &[prev_witness]); + + let evals: Vec = round_witness + .out_of_domain() + .values(&[F::ONE]) + .chain(in_domain.values(&folding_randomness.coeff_weights(true))) + .collect(); + + (in_domain, evals) + } + + /// Prove helper polynomial evaluations for the ZK virtual oracle. + /// + /// Thin wrapper around [`prove_helper_evaluations`] that derives the IRS + /// domain from this config's initial committer. + fn prove_zk_helper_evaluations( + &self, + prover_state: &mut ProverState, + in_domain_base: &irs_commit::Evaluations, + witness: &ZkWitness, + helper_config: &Config, + rho: F, + ) where + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + F: Codec<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + u8: Decoding<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + let domain = IrsDomainParams::::from_config(self); + prove_helper_evaluations( + prover_state, + &domain, + in_domain_base, + witness, + helper_config, + rho, + self.embedding(), + ); + } +} + +/// Prove helper polynomial evaluations for the ZK virtual oracle. +/// +/// This is the shared implementation used by both the main ZK prover +/// (via `Config::prove_zk_helper_evaluations`) and the prefold prover. +/// +/// Given the IRS opening of f̂, this: +/// 1. Computes gamma points (coset elements) for each query +/// 2. For each polynomial, batch-evaluates M, ĝ₁, ..., ĝμ at all gamma points +/// 3. Sends evaluations to the verifier (gamma-major, polynomial-minor order) +/// 4. Runs a helper WHIR proof to bind evaluations to committed polynomials +/// +/// For N polynomials, the helper WHIR covers N×(μ+1) polynomials in one batch. +pub(crate) fn prove_helper_evaluations( + prover_state: &mut ProverState, + domain: &IrsDomainParams, + in_domain_base: &irs_commit::Evaluations, + witness: &ZkWitness, + helper_config: &Config, + rho: F, + embedding: &M, +) where + F: FftField, + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + F: Codec<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + u8: Decoding<[H::U]>, + Hash: ProverMessage<[H::U]>, + M: crate::algebra::embedding::Embedding, +{ + #[cfg(feature = "alloc-track")] + let mut __snap = crate::alloc_snap!(); + + let num_polys = witness.preprocessings.len(); + let mu = witness.preprocessings[0].params.mu; + let ell = witness.preprocessings[0].params.ell; + + // Compute gammas: for each query point, produce k coset elements + // These are the SAME for all polynomials (derived from IRS domain structure). + let gammas = domain.all_gammas(&in_domain_base.points, embedding); + + // For each polynomial, batch-evaluate all helper polynomials at all gamma points + let helper_evals_per_poly: Vec>> = witness + .preprocessings + .iter() + .map(|preprocessing| preprocessing.batch_evaluate_helpers(&gammas, rho)) + .collect(); + + #[cfg(feature = "alloc-track")] + crate::alloc_report!(" helper_evals::batch_evaluate", __snap); + + // Send helper evaluations as a single batch message. + // Order: for each gamma, for each polynomial: m_eval, then mu g_hat_evals. + // This groups all polynomial data per-gamma for natural virtual oracle reconstruction. + let evals_per_point = 1 + mu; // m_eval + mu g_hat_evals + let total_evals = gammas.len() * num_polys * evals_per_point; + let base_field_size = (F::BasePrimeField::MODULUS_BIT_SIZE.div_ceil(8)) as usize; + let elem_bytes = base_field_size * F::extension_degree() as usize; + let mut encoded = Vec::with_capacity(total_evals * elem_bytes); + for gamma_idx in 0..gammas.len() { + for poly_idx in 0..num_polys { + let helper_eval = &helper_evals_per_poly[poly_idx][gamma_idx]; + transcript::encode_field_element_into(&helper_eval.m_eval, &mut encoded); + for g_hat_eval in &helper_eval.g_hat_evals { + transcript::encode_field_element_into(g_hat_eval, &mut encoded); + } + } + } + prover_state.prover_messages_bytes::(total_evals, &encoded); + + #[cfg(feature = "alloc-track")] + crate::alloc_report!(" helper_evals::send_to_verifier", __snap); + + // Sample τ₂ for combining query points + let tau2: F = prover_state.verifier_message(); + + // Construct batched eq weights (uses gammas which are same for all polynomials) + let beq_weights = construct_batched_eq_weights(&helper_evals_per_poly[0], rho, tau2, ell); + + // Compute per-polynomial claims and collect evaluations + // Layout: [m₁_claim, ĝ₁₁_claim, ..., ĝ₁μ_claim, m₂_claim, ĝ₂₁_claim, ..., ĝ₂μ_claim, ...] + let mut all_evaluations: Vec = Vec::with_capacity(num_polys * (1 + mu)); + for poly_idx in 0..num_polys { + let (m_claim, g_hat_claims) = + compute_per_polynomial_claims(&helper_evals_per_poly[poly_idx], tau2); + all_evaluations.push(m_claim); + all_evaluations.extend_from_slice(&g_hat_claims); + } + + // Collect all helper polynomials (base-field): + // [M₁, ĝ₁₁, ..., ĝ₁μ, M₂, ĝ₂₁, ..., ĝ₂μ, ...] + let all_polynomials = + interleave_helper_poly_refs::(&witness.m_polys_base, &witness.g_hats_embedded_bases); + + // Single batch witness (helper_config.batch_size = N×(μ+1)) + let all_witnesses: Vec<&irs_commit::Witness> = + vec![&witness.helper_witness]; + + let weight_refs: Vec<&Weights> = vec![&beq_weights]; + + #[cfg(feature = "alloc-track")] + crate::alloc_report!(" helper_evals::build_weights_claims", __snap); + + // Run helper WHIR proof with existing batch commitment + helper_config.prove( + prover_state, + &all_polynomials, + &all_witnesses, + &weight_refs, + &all_evaluations, + ); + + #[cfg(feature = "alloc-track")] + crate::alloc_report!(" helper_evals::helper_whir_prove", __snap); +} diff --git a/src/protocols/whir_zk/utils.rs b/src/protocols/whir_zk/utils.rs new file mode 100644 index 00000000..b87013b1 --- /dev/null +++ b/src/protocols/whir_zk/utils.rs @@ -0,0 +1,497 @@ +use ark_ff::{FftField, Field}; +use ark_std::{ + rand::{CryptoRng, RngCore}, + UniformRand, +}; + +use crate::{ + algebra::{ + embedding::Embedding, + fields::FieldWithSize, + ntt, + polynomials::{CoefficientList, EvaluationsList, MultilinearPoint}, + project_all_to_base, Weights, + }, + parameters::{FoldingFactor, MultivariateParameters, ProtocolParameters, SoundnessType}, + protocols::{irs_commit, whir::Config}, +}; + +// ── IRS Domain Parameters (shared by prover & verifier) ────────────────── + +/// Precomputed IRS domain structure for the initial commitment. +/// +/// Both prover and verifier need the same domain generators, coset roots, and +/// sub-domain powers to compute gamma query points. This struct deduplicates +/// that computation. +pub(crate) struct IrsDomainParams { + /// Interleaving depth k (= 2^folding_factor for first round) + pub k: usize, + /// Generator ω of the full NTT domain of size num_rows × k + pub omega_full: F::BasePrimeField, + /// k-th root of unity ζ = ω^num_rows + pub zeta: F::BasePrimeField, + /// Precomputed sub-domain powers [1, ω_sub, ω_sub², ..., ω_sub^(num_rows-1)] + pub omega_powers: Vec, +} + +impl IrsDomainParams { + /// Compute domain parameters from a WHIR config's initial committer. + pub fn from_config(config: &Config) -> Self { + Self::from_irs_committer(&config.initial_committer) + } + + /// Compute domain parameters from an arbitrary IRS basefield committer. + pub fn from_irs_committer( + committer: &crate::protocols::irs_commit::BasefieldConfig, + ) -> Self { + let k = committer.interleaving_depth; + let num_rows = committer.num_rows(); + let full_domain_size = num_rows * k; + + let omega_full: F::BasePrimeField = + ntt::generator(full_domain_size).expect("Full IRS domain should have primitive root"); + let omega_sub: F::BasePrimeField = committer.generator(); + let zeta: F::BasePrimeField = omega_full.pow([num_rows as u64]); + + let omega_powers = crate::algebra::geometric_sequence(omega_sub, num_rows); + + Self { + k, + omega_full, + zeta, + omega_powers, + } + } + + /// Find the index of `alpha_base` in the sub-domain. + #[inline] + pub fn query_index(&self, alpha_base: F::BasePrimeField) -> usize { + self.omega_powers + .iter() + .position(|&p| p == alpha_base) + .expect("Query point must be in IRS domain") + } + + /// Compute the k coset gamma points for a query at `alpha_base`, + /// lifted to the extension field via `embedding`. + pub fn coset_gammas>( + &self, + alpha_base: F::BasePrimeField, + embedding: &M, + ) -> Vec { + let idx = self.query_index(alpha_base); + let coset_offset = self.omega_full.pow([idx as u64]); + (0..self.k) + .map(|j| { + let gamma_base = coset_offset * self.zeta.pow([j as u64]); + embedding.map(gamma_base) + }) + .collect() + } + + /// Compute all gamma points for a set of query points (flat list). + pub fn all_gammas>( + &self, + query_points: &[F::BasePrimeField], + embedding: &M, + ) -> Vec { + query_points + .iter() + .flat_map(|&alpha| self.coset_gammas(alpha, embedding)) + .collect() + } +} + +#[derive(Clone)] +pub struct ZkParams { + /// ℓ: Number of variables for helper polynomials + /// Chosen such that 2^ℓ > conservative query upper bound + pub ell: usize, + + /// μ: Number of variables in the witness polynomial + pub mu: usize, +} + +impl ZkParams { + /// Compute ell and mu from WHIR parameters. + pub fn from_whir_params(whir_params: &Config) -> Self { + // mu = number of variables (log2 of polynomial size) + let mu = whir_params.initial_sumcheck.initial_size.ilog2() as usize; + // k = folding factor size (2^folding_factor) + let k = 1 << whir_params.initial_sumcheck.num_rounds; + // q1 = number of in-domain query samples in the first round + // (or initial commitment queries if there are no rounds) + let q1 = whir_params + .round_configs + .first() + .map_or(whir_params.initial_committer.in_domain_samples, |r| { + r.irs_committer.in_domain_samples + }); + + let q_ub = 2 * k * q1 + 4 * mu + 10; + let ell = (q_ub as f64).log2().ceil() as usize; + assert!( + ell < mu, + "ZK requires ℓ < μ (ℓ={ell}, μ={mu}). \ + Increase num_variables or lower security_level/queries. \ + (q_ub={q_ub}, k={k}, q1={q1})" + ); + Self { ell, mu } + } + + pub fn helper_batch_size(&self, number_of_polynomials: usize) -> usize { + number_of_polynomials * (self.mu + 1) + } + + /// Build the helper WHIR config for proving helper polynomial evaluations. + /// + /// This config covers `num_polynomials × (μ+1)` helper polynomials + /// (one M and μ embedded ĝ per polynomial) at `ℓ+1` variables. + pub fn build_helper_config( + &self, + num_polynomials: usize, + whir_params: &ProtocolParameters, + ) -> Config { + let helper_mv = MultivariateParameters::new(self.ell + 1); + let helper_whir_params = ProtocolParameters { + initial_statement: true, + security_level: whir_params.security_level, + pow_bits: 0, + folding_factor: FoldingFactor::Constant(1), + soundness_type: SoundnessType::ConjectureList, + starting_log_inv_rate: whir_params.starting_log_inv_rate, + batch_size: self.helper_batch_size(num_polynomials), + hash_id: whir_params.hash_id, + }; + Config::new(helper_mv, &helper_whir_params) + } +} + +/// Sampling random polynomials before the witness polynomial +#[derive(Clone)] +pub struct ZkPreprocessingPolynomials { + pub msk: CoefficientList, + pub g0_hat: CoefficientList, + pub m_poly: CoefficientList, + pub g_hats: Vec>, + pub params: ZkParams, +} + +impl ZkPreprocessingPolynomials { + pub fn sample(rng: &mut R, params: ZkParams) -> Self { + let poly_size = 1 << params.ell; + let m_poly_size = 1 << (params.ell + 1); + + // Sample all preprocessing polynomials from the BASE FIELD, then lift to extension. + // This is required because these polynomials are committed via base-field IRS commitment, + // and the conversion back to base field (to_base_prime_field_elements().next()) must be + // lossless. + let msk_coeffs: Vec = (0..poly_size) + .map(|_| F::from_base_prime_field(F::BasePrimeField::rand(rng))) + .collect(); + let msk = CoefficientList::new(msk_coeffs.clone()); + + let g0_coeffs: Vec = (0..poly_size) + .map(|_| F::from_base_prime_field(F::BasePrimeField::rand(rng))) + .collect(); + let g0 = CoefficientList::new(g0_coeffs.clone()); + + let mut m_coeffs = vec![F::ZERO; m_poly_size]; + for (i, (g0_c, &msk_c)) in g0_coeffs.iter().zip(msk_coeffs.iter()).enumerate() { + m_coeffs[2 * i] = *g0_c; + m_coeffs[2 * i + 1] = msk_c; + } + let m_poly = CoefficientList::new(m_coeffs); + + let g_hats = (0..params.mu) + .map(|_| { + let coeffs = (0..poly_size) + .map(|_| F::from_base_prime_field(F::BasePrimeField::rand(rng))) + .collect(); + CoefficientList::new(coeffs) + }) + .collect(); + + Self { + msk, + g0_hat: g0, + m_poly, + g_hats, + params, + } + } + + /// Extend msk to μ variables by padding with zeros + pub fn extend_msk(&self) -> CoefficientList { + let target_size = 1 << self.params.mu; + let mut coeffs = self.msk.coeffs().to_vec(); + coeffs.resize(target_size, F::ZERO); + CoefficientList::new(coeffs) + } + + /// Batch-evaluate all helper polynomials at multiple gamma points using + /// fused univariate Horner evaluation. + /// + /// For each gamma, evaluates msk, g₀, and all ĝⱼ in a single pass per gamma + /// point, avoiding intermediate per-polynomial allocation vectors. + /// + /// Returns a Vec of `HelperEvaluations` (one per gamma point), in the same + /// order as the input gammas. + pub fn batch_evaluate_helpers(&self, gammas: &[F], rho: F) -> Vec> { + use crate::algebra::univariate_evaluate; + + // Evaluate all helper polynomials at a single gamma point. + // This fuses msk, g₀, and ĝⱼ evaluations per-gamma, avoiding + // μ+2 intermediate Vec allocations of size |gammas|. + let eval_at = |&gamma: &F| -> HelperEvaluations { + let msk_val = univariate_evaluate(self.msk.coeffs(), gamma); + let g0_val = univariate_evaluate(self.g0_hat.coeffs(), gamma); + let m_eval = g0_val - rho * msk_val; + let g_hat_evals = self + .g_hats + .iter() + .map(|g_hat| univariate_evaluate(g_hat.coeffs(), gamma)) + .collect(); + HelperEvaluations { + gamma, + m_eval, + g_hat_evals, + } + }; + + // Parallelize across gamma points (typically q×k, often hundreds). + #[cfg(feature = "parallel")] + { + use rayon::prelude::*; + gammas.par_iter().map(eval_at).collect() + } + #[cfg(not(feature = "parallel"))] + { + gammas.iter().map(eval_at).collect() + } + } +} + +/// ZK Witness: contains commitment witnesses for all ZK components +#[derive(Clone)] +pub struct ZkWitness { + /// Witnesses for [[f̂₁]] = [[f₁ + msk₁]], ..., [[fₙ]] = [[fₙ + mskₙ]] in main WHIR + pub f_hat_witnesses: Vec>, + + /// Single batch witness for all helper polynomials [[M, ĝ₁, ..., ĝμ]] + /// committed via helper_config with batch_size = μ+1 + pub helper_witness: irs_commit::Witness, + + /// Reference to preprocessing data for each polynomial + pub preprocessings: Vec>, + + /// Base-field representations of M polynomials (for helper WHIR prove) + pub m_polys_base: Vec>, + + /// Base-field representations of embedded ĝⱼ polynomials (for helper WHIR prove) + /// Each ĝⱼ is embedded from ℓ-variate to (ℓ+1)-variate for each polynomial + pub g_hats_embedded_bases: Vec>>, +} + +/// Collect interleaved references to all helper polynomials in batch order: +/// `[M₁, ĝ₁₁, …, ĝ₁μ, M₂, ĝ₂₁, …, ĝ₂μ, …]` +/// +/// Used by both committer (batch-commit) and prover (batch-prove). +pub(crate) fn interleave_helper_poly_refs<'a, F: FftField>( + m_polys: &'a [CoefficientList], + g_hats: &'a [Vec>], +) -> Vec<&'a CoefficientList> { + let num_polys = m_polys.len(); + let mu = g_hats.first().map_or(0, |g| g.len()); + let mut refs = Vec::with_capacity(num_polys * (1 + mu)); + for (m_poly, g_hat_list) in m_polys.iter().zip(g_hats) { + refs.push(m_poly); + for g_hat in g_hat_list { + refs.push(g_hat); + } + } + refs +} + +/// Compute per-polynomial claims from helper evaluations. +/// +/// m_claim = Σᵢ τ₂ⁱ · m(γᵢ, ρ) +/// g_hat_j_claim = Σᵢ τ₂ⁱ · ĝⱼ(pow(γᵢ)) +pub(crate) fn compute_per_polynomial_claims( + helper_evals: &[HelperEvaluations], + tau2: F, +) -> (F, Vec) { + let num_g_hats = helper_evals.first().map_or(0, |h| h.g_hat_evals.len()); + + let mut m_claim = F::ZERO; + let mut g_hat_claims = vec![F::ZERO; num_g_hats]; + let mut tau2_power = F::ONE; + + for helper in helper_evals { + m_claim += tau2_power * helper.m_eval; + for (j, &g_eval) in helper.g_hat_evals.iter().enumerate() { + g_hat_claims[j] += tau2_power * g_eval; + } + tau2_power *= tau2; + } + + (m_claim, g_hat_claims) +} + +/// Construct the weight function for the helper WHIR sumcheck: +/// +/// w(z, t) = eq(-ρ, t) · [Σᵢ τ₂ⁱ · eq(pow(γᵢ), z)] +/// +/// Returns a `Weights::Linear` on (ℓ+1) variables. +pub(crate) fn construct_batched_eq_weights( + helper_evals: &[HelperEvaluations], + rho: F, + tau2: F, + ell: usize, +) -> Weights { + let neg_rho = -rho; + let z_size = 1 << ell; + let weight_size = 1 << (ell + 1); + + // Precompute τ₂ powers + let tau2_powers = crate::algebra::geometric_sequence(tau2, helper_evals.len()); + + // For each γᵢ, compute eq(pow(γᵢ), z) for all z ∈ {0,1}^ℓ using the + // O(2^ℓ) butterfly expansion in MultilinearPoint::eq_weights(), then + // accumulate τ₂ⁱ · eq(pow(γᵢ), ·) into a single batched_eq vector. + #[cfg(feature = "parallel")] + let batched_eq: Vec = { + use rayon::prelude::*; + helper_evals + .par_iter() + .zip(tau2_powers.par_iter()) + .fold( + || vec![F::ZERO; z_size], + |mut acc, (helper, &tau2_pow)| { + let eq_vals = + MultilinearPoint::expand_from_univariate(helper.gamma, ell).eq_weights(); + for (a, v) in acc.iter_mut().zip(eq_vals) { + *a += tau2_pow * v; + } + acc + }, + ) + .reduce( + || vec![F::ZERO; z_size], + |mut a, b| { + for (ai, bi) in a.iter_mut().zip(b) { + *ai += bi; + } + a + }, + ) + }; + #[cfg(not(feature = "parallel"))] + let batched_eq: Vec = { + let mut batched = vec![F::ZERO; z_size]; + for (helper, &tau2_pow) in helper_evals.iter().zip(tau2_powers.iter()) { + let eq_vals = MultilinearPoint::expand_from_univariate(helper.gamma, ell).eq_weights(); + for (a, &v) in batched.iter_mut().zip(eq_vals.iter()) { + *a += tau2_pow * v; + } + } + batched + }; + + // Build weight evaluations on {0,1}^(ℓ+1) + // w(z, t) = eq(-ρ, t) × batched_eq[z] + // eq(-ρ, 0) = 1 + ρ, eq(-ρ, 1) = -ρ + let eq_neg_rho_at_0 = F::ONE - neg_rho; // = 1 + ρ + let eq_neg_rho_at_1 = neg_rho; // = -ρ + + let mut weight_evals = vec![F::ZERO; weight_size]; + for (z_idx, &beq_z) in batched_eq.iter().enumerate() { + weight_evals[z_idx * 2] = eq_neg_rho_at_0 * beq_z; // t = 0 + weight_evals[z_idx * 2 + 1] = eq_neg_rho_at_1 * beq_z; // t = 1 + } + + Weights::linear(EvaluationsList::new(weight_evals)) +} + +/// Prepare helper polynomials for batch commitment. +/// +/// For each preprocessing, projects M to base field and embeds each ĝⱼ +/// from ℓ-variate to (ℓ+1)-variate in base field representation. +/// +/// Returns `(m_polys_base, g_hats_embedded_bases)`. +pub(crate) fn prepare_helper_polynomials( + preprocessings: &[&ZkPreprocessingPolynomials], +) -> ( + Vec>, + Vec>>, +) { + let mut m_polys_base = Vec::new(); + let mut g_hats_embedded_bases = Vec::new(); + + for preprocessing in preprocessings { + let m_base = CoefficientList::new(project_all_to_base(preprocessing.m_poly.coeffs())); + let embed_g_hat = |g_hat: &CoefficientList| -> CoefficientList { + let embedded = g_hat.embed_into_variables(preprocessing.params.ell + 1); + CoefficientList::new(project_all_to_base(embedded.coeffs())) + }; + + #[cfg(feature = "parallel")] + let g_hats_base: Vec> = { + use rayon::prelude::*; + preprocessing.g_hats.par_iter().map(embed_g_hat).collect() + }; + #[cfg(not(feature = "parallel"))] + let g_hats_base: Vec> = + preprocessing.g_hats.iter().map(embed_g_hat).collect(); + + m_polys_base.push(m_base); + g_hats_embedded_bases.push(g_hats_base); + } + + (m_polys_base, g_hats_embedded_bases) +} + +/// Helper evaluations at a single query point γ +#[derive(Clone, Debug)] +pub struct HelperEvaluations { + /// The query point γ + pub gamma: F, + + /// m(γ,ρ) = M(pow(γ), -ρ) + pub m_eval: F, + + /// [ĝ₁(pow(γ)), ..., ĝμ(pow(γ))] + pub g_hat_evals: Vec, +} + +impl HelperEvaluations { + /// Compute the helper polynomial value h(γ) (without the ρ·f̂ term). + /// + /// h(γ) = m(γ,ρ) + Σᵢ βⁱ·γ^(2^(i-1))·ĝᵢ(pow(γ)) + pub fn compute_h_value(&self, beta: F) -> F { + let mut value = self.m_eval; + + let mut beta_power = beta; + let mut gamma_power = self.gamma; + + for (i, &g_hat_eval) in self.g_hat_evals.iter().enumerate() { + value += beta_power * gamma_power * g_hat_eval; + + beta_power *= beta; + if i < self.g_hat_evals.len() - 1 { + gamma_power = gamma_power.square(); + } + } + + value + } + + /// Compute the full virtual oracle value L(γ) = ρ·f̂(γ) + h(γ) + /// + /// L(γ) = ρ·f̂(γ) + m(γ,ρ) + Σᵢ βⁱ·γ^(2^(i-1))·ĝᵢ(pow(γ)) + /// = ρ·(f + msk)(γ) + (ĝ₀ - ρ·msk)(pow(γ)) + blinding_terms + /// = ρ·f(γ) + g(γ) + pub fn compute_virtual_value(&self, f_hat_val: F, rho: F, beta: F) -> F { + rho * f_hat_val + self.compute_h_value(beta) + } +} diff --git a/src/protocols/whir_zk/verifier.rs b/src/protocols/whir_zk/verifier.rs new file mode 100644 index 00000000..23a3aa6e --- /dev/null +++ b/src/protocols/whir_zk/verifier.rs @@ -0,0 +1,627 @@ +use ark_ff::{FftField, Field}; + +use super::utils::{ + compute_per_polynomial_claims, construct_batched_eq_weights, IrsDomainParams, ZkParams, +}; +use crate::{ + algebra::{ + dot, + embedding::Embedding, + polynomials::{CoefficientList, MultilinearPoint}, + Weights, + }, + hash::Hash, + protocols::{ + geometric_challenge::geometric_challenge, + irs_commit, + whir::{Commitment, Config}, + }, + transcript::{ + codecs::U64, Codec, Decoding, DuplexSpongeInterface, ProverMessage, VerificationResult, + VerifierMessage, VerifierState, + }, + utils::zip_strict, + verify, +}; + +impl Config { + /// Verify a ZK WHIR opening proof. + /// + /// This verifies a proof generated by `prove_zk`. The verifier: + /// 1. Reads blinding evaluations g(āᵢ) and reconstructs modified sums + /// 2. Runs WHIR rounds, reconstructing virtual oracle values at each initial opening + /// 3. Verifies helper polynomial evaluations via a nested WHIR proof + /// 4. Checks the final sumcheck equation + pub fn verify_zk( + &self, + verifier_state: &mut VerifierState<'_, H>, + f_hat_commitments: &[&Commitment], + helper_commitment: &Commitment, + helper_config: &Config, + zk_params: &ZkParams, + weights: &[&Weights], + evaluations: &[F], + ) -> VerificationResult<(MultilinearPoint, Vec)> + where + H: DuplexSpongeInterface, + F: Codec<[H::U]>, + u8: Decoding<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + #[cfg(feature = "alloc-track")] + let mut __snap = crate::alloc_snap!(); + + // ==================================================================== + // Phase 1: ZK transcript header — read blinding challenges and build + // modified evaluations for P = ρ·f + g + // ==================================================================== + let num_polys = f_hat_commitments.len(); + let (beta, rho, modified_evaluations) = + self.read_zk_header(verifier_state, weights, evaluations, num_polys)?; + + let polynomial_rlc_coeffs: Vec = + geometric_challenge(verifier_state, f_hat_commitments.len()); + let constraint_rlc_coeffs: Vec = geometric_challenge(verifier_state, weights.len()); + let mut round_constraints: Vec<(Vec, Vec>)> = vec![( + constraint_rlc_coeffs.clone(), + weights.iter().map(|&w| w.clone()).collect(), + )]; + + let mut the_sum: F = zip_strict( + &constraint_rlc_coeffs, + modified_evaluations.chunks_exact(num_polys), + ) + .map(|(weight_coeff, row)| *weight_coeff * dot(&polynomial_rlc_coeffs, row)) + .sum(); + + let mut round_folding_randomness = Vec::new(); + + let folding_randomness = if constraint_rlc_coeffs.is_empty() { + assert_eq!(the_sum, F::ZERO); + let fr = verifier_state.verifier_message_vec(self.initial_sumcheck.num_rounds); + self.initial_sumcheck.round_pow.verify(verifier_state)?; + MultilinearPoint(fr) + } else { + self.initial_sumcheck.verify(verifier_state, &mut the_sum)? + }; + round_folding_randomness.push(folding_randomness); + + #[cfg(feature = "alloc-track")] + crate::alloc_report!("verify_zk::phase1_header_initial_sumcheck", __snap); + + // ==================================================================== + // Phase 2: WHIR round loop + // ==================================================================== + let mut prev_is_initial = true; + let mut prev_round_commitment: Option> = None; + + for (round_index, round_config) in self.round_configs.iter().enumerate() { + let commitment = round_config + .irs_committer + .receive_commitment(verifier_state)?; + round_config.pow.verify(verifier_state)?; + + let num_variables = round_config.initial_num_variables(); + + let (stir_weights, stir_values) = if prev_is_initial { + self.verify_initial_zk_round( + verifier_state, + f_hat_commitments, + helper_commitment, + helper_config, + zk_params, + rho, + beta, + &commitment, + &round_folding_randomness, + num_variables, + &polynomial_rlc_coeffs, + )? + } else { + self.verify_subsequent_round( + verifier_state, + round_index, + prev_round_commitment.as_ref().unwrap(), + &commitment, + &round_folding_randomness, + num_variables, + )? + }; + + let stir_rlc_coeffs = geometric_challenge(verifier_state, stir_weights.len()); + the_sum += dot(&stir_rlc_coeffs, &stir_values); + round_constraints.push((stir_rlc_coeffs, stir_weights)); + + let folding_randomness = round_config.sumcheck.verify(verifier_state, &mut the_sum)?; + round_folding_randomness.push(folding_randomness); + + prev_is_initial = false; + prev_round_commitment = Some(commitment); + } + + #[cfg(feature = "alloc-track")] + crate::alloc_report!("verify_zk::phase2_whir_rounds", __snap); + + // ==================================================================== + // Phase 3: Final round — read polynomial, verify last commitment + // ==================================================================== + let final_coefficients = CoefficientList::new( + verifier_state.prover_messages_vec(self.final_sumcheck.initial_size)?, + ); + self.final_pow.verify(verifier_state)?; + + self.verify_zk_final_opening( + verifier_state, + prev_is_initial, + f_hat_commitments, + helper_commitment, + helper_config, + zk_params, + rho, + beta, + prev_round_commitment.as_ref(), + &round_folding_randomness, + &final_coefficients, + &polynomial_rlc_coeffs, + )?; + + #[cfg(feature = "alloc-track")] + crate::alloc_report!("verify_zk::phase3_final_opening", __snap); + + // ==================================================================== + // Phase 4: Final sumcheck + consistency check + // ==================================================================== + let final_sumcheck_randomness = self.final_sumcheck.verify(verifier_state, &mut the_sum)?; + round_folding_randomness.push(final_sumcheck_randomness.clone()); + + let result = self.verify_final_consistency( + verifier_state, + &round_constraints, + &round_folding_randomness, + &final_coefficients, + &final_sumcheck_randomness, + the_sum, + ); + + #[cfg(feature = "alloc-track")] + crate::alloc_report!("verify_zk::phase4_final_sumcheck_consistency", __snap); + + result + } + + /// Read the ZK transcript header: β, g(āᵢ) evaluations, ρ. + /// + /// Returns `(beta, rho, modified_evaluations)` where + /// modified_evaluations[i] = ρ·evaluations[i] + g_evals[i]. + fn read_zk_header( + &self, + verifier_state: &mut VerifierState<'_, H>, + weights: &[&Weights], + evaluations: &[F], + num_commitments: usize, + ) -> VerificationResult<(F, F, Vec)> + where + H: DuplexSpongeInterface, + F: Codec<[H::U]>, + { + let beta: F = verifier_state.verifier_message(); + let g_evals: Vec = + verifier_state.prover_messages_vec(weights.len() * num_commitments)?; + let rho: F = verifier_state.verifier_message(); + + let modified_evaluations: Vec = zip_strict(evaluations, &g_evals) + .map(|(&eval, &g_eval)| rho * eval + g_eval) + .collect(); + + Ok((beta, rho, modified_evaluations)) + } + + /// Verify the initial ZK round opening: verify f̂ commitment(s), helper evaluations, + /// and reconstruct virtual oracle values. + /// + /// For N committed polynomials, opens all N f̂ commitments, verifies helper + /// evaluations for each, and reconstructs combined virtual oracle values using + /// the polynomial RLC coefficients. + /// + /// Returns `(stir_weights, stir_values)`. + #[allow(clippy::too_many_arguments)] + fn verify_initial_zk_round( + &self, + verifier_state: &mut VerifierState<'_, H>, + f_hat_commitments: &[&Commitment], + helper_commitment: &Commitment, + helper_config: &Config, + zk_params: &ZkParams, + rho: F, + beta: F, + commitment: &irs_commit::Commitment, + round_folding_randomness: &[MultilinearPoint], + num_variables: usize, + polynomial_rlc_coeffs: &[F], + ) -> VerificationResult<(Vec>, Vec)> + where + H: DuplexSpongeInterface, + F: Codec<[H::U]>, + u8: Decoding<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + let num_polys = f_hat_commitments.len(); + let in_domain_base = self + .initial_committer + .verify(verifier_state, f_hat_commitments)?; + + let virtual_values = self.verify_zk_helper_evaluations( + verifier_state, + &in_domain_base, + helper_commitment, + helper_config, + zk_params, + rho, + beta, + round_folding_randomness.last().unwrap(), + num_polys, + polynomial_rlc_coeffs, + )?; + let in_domain = in_domain_base.lift(self.embedding()); + + let stir_weights: Vec> = commitment + .out_of_domain() + .weights(num_variables) + .chain(in_domain.weights(num_variables)) + .collect(); + let stir_values: Vec = commitment + .out_of_domain() + .values(&[F::ONE]) + .chain(virtual_values) + .collect(); + + Ok((stir_weights, stir_values)) + } + + /// Verify a subsequent (non-initial) round opening. + /// + /// Returns `(stir_weights, stir_values)`. + fn verify_subsequent_round( + &self, + verifier_state: &mut VerifierState<'_, H>, + round_index: usize, + prev_commitment: &irs_commit::Commitment, + commitment: &irs_commit::Commitment, + round_folding_randomness: &[MultilinearPoint], + num_variables: usize, + ) -> VerificationResult<(Vec>, Vec)> + where + H: DuplexSpongeInterface, + F: Codec<[H::U]>, + u8: Decoding<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + let prev_round_config = &self.round_configs[round_index - 1]; + let in_domain = prev_round_config + .irs_committer + .verify(verifier_state, &[prev_commitment])?; + + let stir_weights: Vec> = commitment + .out_of_domain() + .weights(num_variables) + .chain(in_domain.weights(num_variables)) + .collect(); + let stir_values: Vec = commitment + .out_of_domain() + .values(&[F::ONE]) + .chain(in_domain.values(&round_folding_randomness.last().unwrap().coeff_weights(true))) + .collect(); + + Ok((stir_weights, stir_values)) + } + + /// Verify the final commitment opening in ZK mode. + /// + /// Handles both the case where this is the initial commitment (no intermediate + /// rounds) and the case where it's the last WHIR round commitment. + /// For the initial case with N polynomials, opens all N f̂ commitments and + /// verifies the batched virtual oracle. + #[allow(clippy::too_many_arguments)] + fn verify_zk_final_opening( + &self, + verifier_state: &mut VerifierState<'_, H>, + prev_is_initial: bool, + f_hat_commitments: &[&Commitment], + helper_commitment: &Commitment, + helper_config: &Config, + zk_params: &ZkParams, + rho: F, + beta: F, + prev_round_commitment: Option<&irs_commit::Commitment>, + round_folding_randomness: &[MultilinearPoint], + final_coefficients: &CoefficientList, + polynomial_rlc_coeffs: &[F], + ) -> VerificationResult<()> + where + H: DuplexSpongeInterface, + F: Codec<[H::U]>, + u8: Decoding<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + if prev_is_initial { + let num_polys = f_hat_commitments.len(); + let in_domain_base = self + .initial_committer + .verify(verifier_state, f_hat_commitments)?; + + let virtual_values = self.verify_zk_helper_evaluations( + verifier_state, + &in_domain_base, + helper_commitment, + helper_config, + zk_params, + rho, + beta, + round_folding_randomness.last().unwrap(), + num_polys, + polynomial_rlc_coeffs, + )?; + + let in_domain = in_domain_base.lift(self.embedding()); + for (weights, value) in zip_strict( + in_domain.weights(final_coefficients.num_variables()), + virtual_values, + ) { + verify!(weights.evaluate(final_coefficients) == value); + } + } else { + let prev_round_config = self.round_configs.last().unwrap(); + let in_domain = prev_round_config + .irs_committer + .verify(verifier_state, &[prev_round_commitment.unwrap()])?; + + for (weights, evals) in zip_strict( + in_domain.weights(final_coefficients.num_variables()), + in_domain.values(&round_folding_randomness.last().unwrap().coeff_weights(true)), + ) { + verify!(weights.evaluate(final_coefficients) == evals); + } + } + Ok(()) + } + + /// Verify the ZK initial commitment opening. + /// + /// Thin wrapper around [`verify_helper_evaluations`] that derives the IRS + /// domain from this config's initial committer. + #[allow(clippy::too_many_arguments)] + fn verify_zk_helper_evaluations( + &self, + verifier_state: &mut VerifierState<'_, H>, + in_domain_base: &irs_commit::Evaluations, + helper_commitment: &Commitment, + helper_config: &Config, + zk_params: &ZkParams, + rho: F, + beta: F, + folding_randomness: &MultilinearPoint, + num_polys: usize, + polynomial_rlc_coeffs: &[F], + ) -> VerificationResult> + where + H: DuplexSpongeInterface, + F: Codec<[H::U]>, + u8: Decoding<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + let domain = IrsDomainParams::::from_config(self); + verify_helper_evaluations( + verifier_state, + &domain, + in_domain_base, + helper_commitment, + helper_config, + zk_params, + rho, + beta, + folding_randomness, + num_polys, + polynomial_rlc_coeffs, + self.embedding(), + ) + } +} + +/// Verify helper polynomial evaluations for the ZK virtual oracle. +/// +/// This is the shared implementation used by both the main ZK verifier +/// (via `Config::verify_zk_helper_evaluations`) and the prefold verifier. +/// +/// For N committed polynomials, reads N helper evaluation sets per gamma point, +/// verifies them via a nested WHIR proof, then reconstructs per-polynomial virtual +/// oracle values L_p(γ) and combines them via polynomial RLC. +/// +/// Returns the virtual oracle folded values for each query point. +#[allow(clippy::too_many_arguments)] +pub(crate) fn verify_helper_evaluations( + verifier_state: &mut VerifierState<'_, H>, + domain: &IrsDomainParams, + in_domain_base: &irs_commit::Evaluations, + helper_commitment: &Commitment, + helper_config: &Config, + zk_params: &ZkParams, + rho: F, + beta: F, + folding_randomness: &MultilinearPoint, + num_polys: usize, + polynomial_rlc_coeffs: &[F], + embedding: &M, +) -> VerificationResult> +where + F: FftField, + H: DuplexSpongeInterface, + F: Codec<[H::U]>, + u8: Decoding<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, + M: Embedding, +{ + use super::utils::HelperEvaluations; + use crate::algebra::polynomials::fold::compute_fold; + + #[cfg(feature = "alloc-track")] + let mut __snap = crate::alloc_snap!(); + + let mu = zk_params.mu; + let k = domain.k; + let fold_factor = k.trailing_zeros() as usize; // log2(k) + + // Precompute inverses for compute_fold (lifted to extension field) + let zeta_ext_inv: F = embedding + .map(domain.zeta) + .inverse() + .expect("coset generator invertible"); + let two_inv = F::from(2u64).inverse().expect("char ≠ 2"); + + // Read ALL helper evaluations from transcript in one batch. + // Prover sends: for each gamma, for each polynomial p: + // m_eval_p, ĝ₁(pow(γ))_p, ..., ĝμ(pow(γ))_p + let q = in_domain_base.points.len(); + let evals_per_point = 1 + mu; // m_eval + mu g_hat_evals + let num_gammas = q * k; + let total_evals = num_gammas * num_polys * evals_per_point; + let all_evals: Vec = verifier_state.read_prover_messages_bytes(total_evals)?; + + // Parse the flat eval vector into per-polynomial HelperEvaluations. + let mut helper_evals_per_poly: Vec>> = (0..num_polys) + .map(|_| Vec::with_capacity(num_gammas)) + .collect(); + let mut query_indices: Vec = Vec::with_capacity(q); + let mut eval_cursor = 0; + + for &alpha_base in &in_domain_base.points { + let idx = domain.query_index(alpha_base); + query_indices.push(idx); + let coset_gammas = domain.coset_gammas(alpha_base, embedding); + + for gamma in coset_gammas { + for p in 0..num_polys { + let m_eval = all_evals[eval_cursor]; + eval_cursor += 1; + let g_hat_evals = all_evals[eval_cursor..eval_cursor + mu].to_vec(); + eval_cursor += mu; + helper_evals_per_poly[p].push(HelperEvaluations { + gamma, + m_eval, + g_hat_evals, + }); + } + } + } + debug_assert_eq!(eval_cursor, total_evals); + + #[cfg(feature = "alloc-track")] + crate::alloc_report!(" verify_helper::read_transcript", __snap); + + // Sample τ₂ (query-batching challenge) + let tau2: F = verifier_state.verifier_message(); + + // Construct batched eq weights (gammas are shared across polynomials) + let beq_weights = + construct_batched_eq_weights(&helper_evals_per_poly[0], rho, tau2, zk_params.ell); + + // Compute per-polynomial claims and collect evaluations + // Layout: [m₁_claim, ĝ₁₁_claim, ..., ĝ₁μ_claim, m₂_claim, ĝ₂₁_claim, ..., ĝ₂μ_claim, ...] + let mut all_evaluations: Vec = Vec::with_capacity(num_polys * (1 + mu)); + for p in 0..num_polys { + let (m_claim, g_hat_claims) = + compute_per_polynomial_claims(&helper_evals_per_poly[p], tau2); + all_evaluations.push(m_claim); + all_evaluations.extend_from_slice(&g_hat_claims); + } + + let weight_refs: Vec<&Weights> = vec![&beq_weights]; + + #[cfg(feature = "alloc-track")] + crate::alloc_report!(" verify_helper::build_weights_claims", __snap); + + // Verify helper WHIR proof (single batch commitment for all N×(μ+1) helper polys) + helper_config.verify( + verifier_state, + &[helper_commitment], + &weight_refs, + &all_evaluations, + )?; + + #[cfg(feature = "alloc-track")] + crate::alloc_report!(" verify_helper::helper_whir_verify", __snap); + + // Reconstruct virtual oracle values from IRS opening + verified helper evaluations. + // For N polynomials: compute per-polynomial L_p(γ_j), RLC across polynomials, + // then fold the combined coset values. + let num_cols = in_domain_base.num_columns(); + let k_per_poly = k; // columns per committed polynomial (batch_size=1 for ZK) + + let virtual_values: Vec = in_domain_base + .points + .iter() + .enumerate() + .map(|(qi, &alpha_base)| { + let idx = query_indices[qi]; + let coset_offset = domain.omega_full.pow([idx as u64]); + let coset_offset_ext_inv: F = embedding.map(coset_offset).inverse().unwrap_or(F::ZERO); + + let row = &in_domain_base.matrix[qi * num_cols..(qi + 1) * num_cols]; + + // Compute L_combined(γ_j) = Σ_p α_p · L_p(γ_j) for each coset element j + let coset_gammas = domain.coset_gammas(alpha_base, embedding); + let l_coset_values: Vec = coset_gammas + .iter() + .enumerate() + .map(|(j, &gamma_ext)| { + let mut l_combined = F::ZERO; + for p in 0..num_polys { + // f̂_p(γ_j) = Σ_{l=0}^{k-1} embed(row[p*k + l]) · γ^l + let f_hat_slice = &row[p * k_per_poly..(p + 1) * k_per_poly]; + let f_hat_p_at_gamma = crate::algebra::mixed_univariate_evaluate( + embedding, + f_hat_slice, + gamma_ext, + ); + + // h_p(γ_j) from verified per-polynomial helper evaluations + let h_p_at_gamma = + helper_evals_per_poly[p][qi * k + j].compute_h_value(beta); + + // L_p(γ_j) = ρ·f̂_p(γ_j) + h_p(γ_j) + let l_p = rho * f_hat_p_at_gamma + h_p_at_gamma; + + l_combined += polynomial_rlc_coeffs[p] * l_p; + } + l_combined + }) + .collect(); + + // Fold the k combined coset values + compute_fold( + &l_coset_values, + &folding_randomness.0, + coset_offset_ext_inv, + zeta_ext_inv, + two_inv, + fold_factor, + ) + }) + .collect(); + + #[cfg(feature = "alloc-track")] + crate::alloc_report!(" verify_helper::reconstruct_virtual_oracle", __snap); + + Ok(virtual_values) +} diff --git a/src/transcript/mod.rs b/src/transcript/mod.rs index 571d9633..d5eafb10 100644 --- a/src/transcript/mod.rs +++ b/src/transcript/mod.rs @@ -10,6 +10,7 @@ mod mock_sponge; use std::any::type_name; use std::fmt::Debug; +use ark_ff::{Field, PrimeField}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use ark_std::rand::{rngs::StdRng, CryptoRng, RngCore}; use serde::{Deserialize, Serialize}; @@ -20,6 +21,67 @@ pub use spongefish::{ VerificationError, VerificationResult, }; +/// Zero-allocation wrapper for pre-encoded bytes. +/// +/// Used by [`ProverState::prover_messages_bytes`] to send a pre-serialized +/// byte buffer as a single transcript message, avoiding the per-element +/// allocation overhead of [`Encoding::encode`] on individual field elements. +pub struct RawBytes<'a>(pub &'a [u8]); + +impl Encoding<[u8]> for RawBytes<'_> { + fn encode(&self) -> impl AsRef<[u8]> { + self.0 + } +} +// NargSerialize is provided by the blanket impl: +// impl> NargSerialize for T + +/// Encode a single field element into `dst` without heap allocations. +/// +/// Produces the same byte representation as [`Encoding<[u8]>::encode`] on +/// ark-ff field elements: each base-prime-field coefficient is written in +/// little-endian limb order, truncated/padded to `base_field_size` bytes. +#[inline] +pub fn encode_field_element_into(f: &F, dst: &mut Vec) { + let base_field_size = (F::BasePrimeField::MODULUS_BIT_SIZE.div_ceil(8)) as usize; + for base_element in f.to_base_prime_field_elements() { + let bigint = base_element.into_bigint(); + let limbs: &[u64] = bigint.as_ref(); + let start = dst.len(); + for limb in limbs { + dst.extend_from_slice(&limb.to_le_bytes()); + } + // Match spongefish's encode: resize to exactly base_field_size bytes + // (truncate high zero bytes if N*8 > base_field_size, pad if less). + dst.resize(start + base_field_size, 0); + } +} + +/// Decode field elements from a byte buffer produced by [`encode_field_element_into`]. +/// +/// Returns `count` field elements, advancing `src` past the consumed bytes. +pub fn decode_field_elements_from_bytes(src: &mut &[u8], count: usize) -> Option> { + let base_field_size = (F::BasePrimeField::MODULUS_BIT_SIZE.div_ceil(8)) as usize; + let ext_degree = F::extension_degree() as usize; + let elem_bytes = base_field_size * ext_degree; + let total = count * elem_bytes; + if src.len() < total { + return None; + } + + let mut result = Vec::with_capacity(count); + for _ in 0..count { + let mut base_elems = Vec::with_capacity(ext_degree); + for _ in 0..ext_degree { + let (chunk, rest) = src.split_at(base_field_size); + *src = rest; + base_elems.push(F::BasePrimeField::from_le_bytes_mod_order(chunk)); + } + result.push(F::from_base_prime_field_elems(base_elems)?); + } + Some(result) +} + #[cfg(test)] pub use self::mock_sponge::MockSponge; @@ -217,6 +279,27 @@ where .expect("Failed to serialize hint"); } + /// Send `count` pre-encoded field elements as a **single** transcript message. + /// + /// `encoded` must contain the exact same bytes that `count` individual + /// `prover_message::()` calls would produce (use [`encode_field_element_into`]). + /// + /// This reduces allocations from O(count) to O(1) because the sponge + /// absorbs the whole buffer at once via [`RawBytes`] (zero-alloc Encoding). + /// + /// Requires a byte-oriented sponge (`H::U = u8`). + #[cfg_attr(test, track_caller)] + pub fn prover_messages_bytes(&mut self, _count: usize, encoded: &[u8]) + where + H: DuplexSpongeInterface, + { + #[cfg(debug_assertions)] + for _ in 0.._count { + self.push(Interaction::ProverMessage(type_name::().to_owned())); + } + self.inner.prover_message(&RawBytes(encoded)); + } + pub fn proof(self) -> Proof { Proof { narg_string: self.inner.narg_string().to_owned(), @@ -302,6 +385,35 @@ where (0..len).map(|_| self.prover_message()).collect() } + /// Read `count` field elements that were sent via + /// [`ProverState::prover_messages_bytes`]. + /// + /// Internally this still calls `prover_message::()` per element + /// (spongefish doesn't expose a raw-byte batch read), but it collects + /// into a single pre-allocated `Vec` rather than creating many small + /// intermediate vectors. + /// + /// The Fiat-Shamir transcript is byte-identical regardless of whether + /// the prover used individual `prover_message` calls or a single + /// `prover_messages_bytes` batch — the sponge absorbs the same bytes + /// in both cases. + #[cfg_attr(test, track_caller)] + pub fn read_prover_messages_bytes(&mut self, count: usize) -> VerificationResult> + where + T: Encoding<[H::U]> + NargDeserialize, + { + #[cfg(debug_assertions)] + for _ in 0..count { + self.pop_pattern(&Interaction::ProverMessage(type_name::().to_owned())); + } + + let mut result = Vec::with_capacity(count); + for _ in 0..count { + result.push(self.inner.prover_message()?); + } + Ok(result) + } + #[cfg_attr(test, track_caller)] pub fn prover_hint(&mut self) -> VerificationResult where