From 5e541f192deb7ac13980a4525bdbdea6bf20c684 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Sat, 30 May 2026 09:12:41 +0530 Subject: [PATCH 1/6] squashed commit from 3-3 branch --- Cargo.lock | 1 + Cargo.toml | 10 +- benches/expand_from_coeff.rs | 1 - benches/zk_whir.rs.disabled | 443 ------------- benches/zook_vs_whir.rs | 413 ++++++++++++ .../protocols/params/basecase.txt | 1 + .../protocols/params/code_switch.txt | 2 + .../protocols/params/derive.txt | 2 + .../protocols/params/irs_commit.txt | 2 + .../protocols/params/sumcheck.txt | 3 + src/algebra/ntt/cooley_tukey.rs | 53 +- src/algebra/ntt/mod.rs | 81 +-- src/protocols/basecase.rs | 1 - src/protocols/challenge_indices.rs | 114 +++- src/protocols/code_switch.rs | 137 ++-- src/protocols/irs_commit.rs | 74 ++- src/protocols/mask_proximity.rs | 271 +++++++- src/protocols/mod.rs | 1 + src/protocols/params/adaptive.rs | 413 ++++++++++++ src/protocols/params/basecase.rs | 11 + src/protocols/params/build_round.rs | 79 ++- src/protocols/params/derive.rs | 481 ++++++++++---- src/protocols/params/error.rs | 10 + src/protocols/params/irs_commit.rs | 51 +- src/protocols/params/layout.rs | 109 +++- src/protocols/params/mod.rs | 5 +- src/protocols/params/protocol_config.rs | 279 ++++----- src/protocols/params/spec.rs | 143 ++++- src/protocols/params/sumcheck.rs | 12 +- src/protocols/sumcheck.rs | 3 +- src/protocols/whir_zk/mod.rs | 80 +++ src/protocols/zook/commit.rs | 239 +++++++ src/protocols/zook/mod.rs | 120 ++++ src/protocols/zook/prover.rs | 586 ++++++++++++++++++ src/protocols/zook/verifier.rs | 317 ++++++++++ 35 files changed, 3577 insertions(+), 971 deletions(-) delete mode 100644 benches/zk_whir.rs.disabled create mode 100644 benches/zook_vs_whir.rs create mode 100644 src/protocols/params/adaptive.rs create mode 100644 src/protocols/zook/commit.rs create mode 100644 src/protocols/zook/mod.rs create mode 100644 src/protocols/zook/prover.rs create mode 100644 src/protocols/zook/verifier.rs diff --git a/Cargo.lock b/Cargo.lock index 62e5a0d4..e46912da 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1580,6 +1580,7 @@ dependencies = [ "tracing", "tracing-subscriber", "zerocopy", + "zeroize", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 52fc40e5..128f15e9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,6 +55,7 @@ const-oid = "0.9.6" arrayvec = "0.7.6" derive-where = { version = "1.6.0", features = ["safe"] } ordered-float = { version = "5.1.0", features = ["serde"] } +zeroize = { version = "1.8", features = ["zeroize_derive"] } thiserror = "2.0" [dev-dependencies] @@ -87,10 +88,11 @@ harness = false name = "sumcheck" harness = false -# Disable untill fixed. -# [[bench]] -# name = "zk_whir" -# harness = false +# Side-by-side wall-time comparison: zook (Standard) vs non-ZK whir. +# zook in Standard mode should match or beat non-ZK whir on the same workload. +[[bench]] +name = "zook_vs_whir" +harness = false [profile.dev] debug = 1 diff --git a/benches/expand_from_coeff.rs b/benches/expand_from_coeff.rs index c244514e..4e5905d2 100644 --- a/benches/expand_from_coeff.rs +++ b/benches/expand_from_coeff.rs @@ -37,7 +37,6 @@ fn interleaved_rs_encode(bencher: Bencher, case: &(usize, usize, usize)) { let coeffs_refs = coeffs.iter().map(|v| v.as_slice()).collect::>(); black_box(ntt::interleaved_rs_encode( &coeffs_refs, - &[], coeffs[0].len() * expansion, )) }); diff --git a/benches/zk_whir.rs.disabled b/benches/zk_whir.rs.disabled deleted file mode 100644 index 529947ee..00000000 --- a/benches/zk_whir.rs.disabled +++ /dev/null @@ -1,443 +0,0 @@ -//! Benchmark: ZK v1 vs ZK v2 WHIR proving (2 polynomials). -//! -//! Run with: -//! cargo bench --bench zk_whir -//! -//! Or filter to a specific group: -//! cargo bench --bench zk_whir -- zk_v1 -//! cargo bench --bench zk_whir -- zk_v2 - -use std::borrow::Cow; - -use ark_ff::FftField; -use ark_std::rand::{distributions::Standard, prelude::Distribution, rngs::StdRng, SeedableRng}; -use divan::{black_box, AllocProfiler, Bencher}; -use whir::{ - algebra::{ - embedding::{Embedding, Identity}, - fields::Field256, - linear_form::{Evaluate, LinearForm, MultilinearExtension}, - MultilinearPoint, - }, - hash, - parameters::ProtocolParameters, - protocols::{whir::Config, whir_zk}, - transcript::{codecs::Empty, DomainSeparator, ProverState, VerifierState}, -}; - -#[global_allocator] -static ALLOC: AllocProfiler = AllocProfiler::system(); - -type F = Field256; -type M = Identity; - -/// Polynomial sizes to benchmark (log₂ of number of coefficients). -const SIZES: &[usize] = &[20]; - -/// Number of polynomials for batched benchmarks. -const NUM_POLYS: usize = 2; - -// ──────────────────────────────────────────────────────────────────────────── -// Shared setup helpers -// ──────────────────────────────────────────────────────────────────────────── - -/// Build `num_polynomials` deterministic polynomials with distinct coefficients. -fn make_polynomials(num_variables: usize, num_polynomials: usize) -> Vec> { - let num_coeffs = 1usize << num_variables; - (0..num_polynomials) - .map(|poly_idx| { - (0..num_coeffs) - .map(|coeff_idx| F::from((poly_idx * num_coeffs + coeff_idx + 1) as u64)) - .collect() - }) - .collect() -} - -/// Build weights + evaluations for multiple polynomials. -/// Layout: row-major [w₀_p₀, w₀_p₁, …] (one eval per polynomial per weight). -fn make_weights_and_evaluations_multi( - polynomials: &[Vec], - config: &Config, - num_variables: usize, -) -> (Vec>, Vec) -where - M::Source: FftField, - M::Target: FftField, - Standard: Distribution, -{ - let mut rng = StdRng::seed_from_u64(0xBEEF); - let point = MultilinearPoint::::rand(&mut rng, num_variables); - let linear_form = MultilinearExtension::new(point.0); - let mut evaluations = Vec::with_capacity(polynomials.len()); - for poly in polynomials { - evaluations.push(linear_form.evaluate(config.embedding(), poly)); - } - let weights = vec![linear_form]; - (weights, evaluations) -} - -// ──────────────────────────────────────────────────────────────────────────── -// ZK v1 helpers -// ──────────────────────────────────────────────────────────────────────────── - -/// ZK v1: WHIR config for committing, μ+1 variables. -/// `batch_size` = number of polynomials × 2 (each poly contributes f̂ and g). -fn zk_v1_commit_config(num_variables: usize, num_polynomials: usize) -> Config { - let extended = num_variables + 1; - let params = ProtocolParameters { - unique_decoding: false, - security_level: 32, - pow_bits: 0, - initial_folding_factor: 4, - folding_factor: 4, - starting_log_inv_rate: 1, - batch_size: 2 * num_polynomials, - hash_id: hash::SHA2, - }; - Config::new(1 << extended, ¶ms) -} - -/// ZK v1: WHIR config for proving P₁..Pₙ, μ+1 variables. -/// `batch_size` = number of P polynomials to prove. -fn zk_v1_prove_config(num_variables: usize, num_polynomials: usize) -> Config { - let extended = num_variables + 1; - let params = ProtocolParameters { - unique_decoding: false, - security_level: 32, - pow_bits: 0, - initial_folding_factor: 4, - folding_factor: 4, - starting_log_inv_rate: 1, - batch_size: num_polynomials, - hash_id: hash::SHA2, - }; - Config::new(1 << extended, ¶ms) -} - -/// ZK v1 polynomial bundle: f̂(x,y) = f(x) + y·msk(x), random g(x,y), P = ρ·f̂ + g. -struct ZkV1Polys { - f_hat: Vec, - g_poly: Vec, - p_poly: Vec, -} - -/// Build N v1 polynomial bundles: for each polynomial, f̂, g, P = masking·f̂ + g. -fn make_zk_v1_polys(num_variables: usize, num_polynomials: usize) -> Vec { - use ark_std::UniformRand; - - let mut rng = StdRng::seed_from_u64(0xCAFE); - let num_coeffs = 1usize << num_variables; - let extended_num_coeffs = 1usize << (num_variables + 1); - let masking_challenge = F::rand(&mut rng); - - (0..num_polynomials) - .map(|poly_idx| { - // Deterministic base polynomial (distinct per polynomial). - let base_coeffs: Vec = (0..num_coeffs) - .map(|coeff_idx| F::from((poly_idx * num_coeffs + coeff_idx + 1) as u64)) - .collect(); - - // f̂(x,y) = base(x) + y·msk(x) - let mut f_hat_coeffs = vec![F::from(0u64); extended_num_coeffs]; - for (coeff_idx, &coeff) in base_coeffs.iter().enumerate() { - f_hat_coeffs[coeff_idx] = coeff; - } - for coeff_idx in 0..num_coeffs { - f_hat_coeffs[num_coeffs + coeff_idx] = F::rand(&mut rng); - } - let f_hat = f_hat_coeffs; - - // Random g(x,y) - let g_coeffs: Vec = (0..extended_num_coeffs) - .map(|_| F::rand(&mut rng)) - .collect(); - let g_poly = g_coeffs; - - // P = masking·f̂ + g - let p_coeffs: Vec = f_hat - .iter() - .zip(g_poly.iter()) - .map(|(&f_hat_coeff, &g_coeff)| masking_challenge * f_hat_coeff + g_coeff) - .collect(); - let p_poly = p_coeffs; - - ZkV1Polys { - f_hat, - g_poly, - p_poly, - } - }) - .collect() -} - -/// Build weights + evaluations for multiple (μ+1)-variable P polynomials at (ā, 0). -/// Evaluations layout: row-major [w₀_P₀, w₀_P₁, …]. -fn make_zk_v1_weights_and_evaluations( - p_polys: &[Vec], - config: &Config, - num_variables: usize, -) -> (Vec>, Vec) { - let mut rng = StdRng::seed_from_u64(0xBEEF); - let base_point = MultilinearPoint::rand(&mut rng, num_variables); - let mut coords = base_point.0; - coords.push(F::from(0u64)); // y = 0 - let extended_point = MultilinearPoint(coords); - let linear_form = MultilinearExtension::new(extended_point.0); - let mut evaluations = Vec::with_capacity(p_polys.len()); - for p_poly in p_polys { - evaluations.push(linear_form.evaluate(config.embedding(), p_poly)); - } - (vec![linear_form], evaluations) -} - -// ──────────────────────────────────────────────────────────────────────────── -// ZK v2 helpers -// ──────────────────────────────────────────────────────────────────────────── - -/// Build a complete ZK v2 config for the given variable count and polynomial count. -fn make_zk_v2_config(num_variables: usize, num_polynomials: usize) -> whir_zk::Config { - whir_zk::Config::new( - 1 << num_variables, - &ProtocolParameters { - unique_decoding: false, - security_level: 32, - pow_bits: 0, - initial_folding_factor: 2, - folding_factor: 4, - starting_log_inv_rate: 1, - batch_size: 1, - hash_id: hash::SHA2, - }, - num_polynomials, - ) -} - -// ──────────────────────────────────────────────────────────────────────────── -// ZK v1 benchmarks – 2 polynomials (batched) -// ──────────────────────────────────────────────────────────────────────────── - -/// Commit [f̂₁, g₁, f̂₂, g₂] with batch_size=4. -#[divan::bench(args = SIZES)] -fn zk_v1_commit(bencher: Bencher, num_variables: usize) { - let bundles = make_zk_v1_polys(num_variables, NUM_POLYS); - let commit_config = zk_v1_commit_config(num_variables, NUM_POLYS); - let ds = DomainSeparator::protocol(&commit_config) - .session(&format!("bench-zk-v1-commit-{num_variables}")) - .instance(&Empty); - - // Flatten: [f̂₁, g₁, f̂₂, g₂] - let commit_polys: Vec<&[F]> = bundles - .iter() - .flat_map(|bundle| [bundle.f_hat.as_slice(), bundle.g_poly.as_slice()]) - .collect(); - - bencher - .with_inputs(|| ProverState::new_std(&ds)) - .bench_values(|mut prover_state| { - let _ = black_box(commit_config.commit(&mut prover_state, &commit_polys)); - }); -} - -/// Prove [P₁, P₂] with batch_size=2, μ+1 variables. -#[divan::bench(args = SIZES)] -fn zk_v1_prove(bencher: Bencher, num_variables: usize) { - let bundles = make_zk_v1_polys(num_variables, NUM_POLYS); - let prove_config = zk_v1_prove_config(num_variables, NUM_POLYS); - let p_polys: Vec> = bundles.iter().map(|bundle| bundle.p_poly.clone()).collect(); - let (weights, evaluations) = - make_zk_v1_weights_and_evaluations(&p_polys, &prove_config, num_variables); - - let ds = DomainSeparator::protocol(&prove_config) - .session(&format!("bench-zk-v1-prove-{num_variables}")) - .instance(&Empty); - - let p_refs: Vec<&[F]> = p_polys.iter().map(Vec::as_slice).collect(); - - bencher - .with_inputs(|| { - let mut prover_state = ProverState::new_std(&ds); - let witness = prove_config.commit(&mut prover_state, &p_refs); - (prover_state, witness) - }) - .bench_values(|(mut prover_state, witness)| { - let prove_forms: Vec>> = vec![Box::new( - MultilinearExtension::new(weights[0].point.clone()), - )]; - let _ = black_box(prove_config.prove( - &mut prover_state, - p_refs.iter().map(|v| Cow::Borrowed(*v)).collect(), - vec![Cow::Borrowed(&witness)], - prove_forms, - Cow::Borrowed(evaluations.as_slice()), - )); - }); -} - -/// Verify [P₁, P₂] via standard WHIR. -#[divan::bench(args = SIZES)] -fn zk_v1_verify(bencher: Bencher, num_variables: usize) { - let bundles = make_zk_v1_polys(num_variables, NUM_POLYS); - let prove_config = zk_v1_prove_config(num_variables, NUM_POLYS); - let p_polys: Vec> = bundles.iter().map(|bundle| bundle.p_poly.clone()).collect(); - let (weights, evaluations) = - make_zk_v1_weights_and_evaluations(&p_polys, &prove_config, num_variables); - - let ds = DomainSeparator::protocol(&prove_config) - .session(&format!("bench-zk-v1-verify-{num_variables}")) - .instance(&Empty); - - let p_refs: Vec<&[F]> = p_polys.iter().map(Vec::as_slice).collect(); - - // Generate a proof once. - let proof = { - let mut prover_state = ProverState::new_std(&ds); - let witness = prove_config.commit(&mut prover_state, &p_refs); - let prove_forms: Vec>> = vec![Box::new(MultilinearExtension::new( - weights[0].point.clone(), - ))]; - let _ = prove_config.prove( - &mut prover_state, - p_refs.iter().map(|v| Cow::Borrowed(*v)).collect(), - vec![Cow::Borrowed(&witness)], - prove_forms, - Cow::Borrowed(evaluations.as_slice()), - ); - prover_state.proof() - }; - - bencher - .with_inputs(|| { - let mut verifier_state = VerifierState::new_std(&ds, &proof); - let commitment = prove_config - .receive_commitment(&mut verifier_state) - .unwrap(); - (verifier_state, commitment) - }) - .bench_values(|(mut verifier_state, commitment)| { - prove_config - .verify(&mut verifier_state, &[&commitment], &evaluations) - .unwrap() - .verify([&weights[0] as &dyn LinearForm]) - .unwrap(); - }); -} - -// ──────────────────────────────────────────────────────────────────────────── -// ZK v2 benchmarks – 2 polynomials (batched) -// ──────────────────────────────────────────────────────────────────────────── - -#[divan::bench(args = SIZES)] -fn zk_v2_commit(bencher: Bencher, num_variables: usize) { - let polynomials = make_polynomials(num_variables, NUM_POLYS); - let zk_config = make_zk_v2_config(num_variables, NUM_POLYS); - - let ds = DomainSeparator::protocol(&zk_config) - .session(&format!("bench-zk-v2-commit-{num_variables}")) - .instance(&Empty); - - bencher - .with_inputs(|| ProverState::new_std(&ds)) - .bench_values(|mut prover_state| { - let poly_refs: Vec<&[F]> = polynomials.iter().map(Vec::as_slice).collect(); - black_box(zk_config.commit(&mut prover_state, &poly_refs)); - }); -} - -#[divan::bench(args = SIZES)] -fn zk_v2_prove(bencher: Bencher, num_variables: usize) { - let polynomials = make_polynomials(num_variables, NUM_POLYS); - let zk_config = make_zk_v2_config(num_variables, NUM_POLYS); - - let (weights, evaluations) = make_weights_and_evaluations_multi( - &polynomials, - &zk_config.blinded_commitment, - num_variables, - ); - - let ds = DomainSeparator::protocol(&zk_config) - .session(&format!("bench-zk-v2-prove-{num_variables}")) - .instance(&Empty); - - bencher - .with_inputs(|| { - let mut prover_state = ProverState::new_std(&ds); - let poly_refs: Vec<&[F]> = polynomials.iter().map(Vec::as_slice).collect(); - let zk_witness = zk_config.commit(&mut prover_state, &poly_refs); - (prover_state, zk_witness) - }) - .bench_values(|(mut prover_state, zk_witness)| { - let poly_refs: Vec<&[F]> = polynomials.iter().map(Vec::as_slice).collect(); - let prove_forms: Vec>> = vec![Box::new( - MultilinearExtension::new(weights[0].point.clone()), - )]; - let _ = black_box( - zk_config.prove( - &mut prover_state, - poly_refs - .iter() - .map(|v| Cow::Borrowed(*v)) - .collect::>(), - zk_witness, - prove_forms, - Cow::Borrowed(evaluations.as_slice()), - ), - ); - }); -} - -#[divan::bench(args = SIZES)] -fn zk_v2_verify(bencher: Bencher, num_variables: usize) { - let polynomials = make_polynomials(num_variables, NUM_POLYS); - let zk_config = make_zk_v2_config(num_variables, NUM_POLYS); - - let (weights, evaluations) = make_weights_and_evaluations_multi( - &polynomials, - &zk_config.blinded_commitment, - num_variables, - ); - - let ds = DomainSeparator::protocol(&zk_config) - .session(&format!("bench-zk-v2-verify-{num_variables}")) - .instance(&Empty); - - // Generate a proof once (outside the benchmark loop). - let proof = { - let mut prover_state = ProverState::new_std(&ds); - let poly_refs: Vec<&[F]> = polynomials.iter().map(Vec::as_slice).collect(); - let zk_witness = zk_config.commit(&mut prover_state, &poly_refs); - let prove_forms: Vec>> = vec![Box::new(MultilinearExtension::new( - weights[0].point.clone(), - ))]; - let _ = zk_config.prove( - &mut prover_state, - poly_refs - .iter() - .map(|v| Cow::Borrowed(*v)) - .collect::>(), - zk_witness, - prove_forms, - Cow::Borrowed(evaluations.as_slice()), - ); - prover_state.proof() - }; - - bencher - .with_inputs(|| { - let mut verifier_state = VerifierState::new_std(&ds, &proof); - let commitment = zk_config - .receive_commitments(&mut verifier_state, NUM_POLYS) - .unwrap(); - (verifier_state, commitment) - }) - .bench_values(|(mut verifier_state, commitment)| { - let weight_refs = [&weights[0] as &dyn LinearForm]; - zk_config - .verify(&mut verifier_state, &weight_refs, &evaluations, &commitment) - .unwrap(); - black_box(()); - }); -} - -fn main() { - divan::main(); -} diff --git a/benches/zook_vs_whir.rs b/benches/zook_vs_whir.rs new file mode 100644 index 00000000..a44e307e --- /dev/null +++ b/benches/zook_vs_whir.rs @@ -0,0 +1,413 @@ +//! Sweep benchmark: base WHIR (non-ZK), Zook (Standard mode), and Zook (ZK +//! mode) across polynomial sizes 2^8 .. 2^24. For each (protocol, size) we +//! run [`ITERATIONS`] passes and record commit / prove / verify wall-clock +//! separately. +//! +//! Outputs: +//! - Tabular summary printed to stdout (min / mean / max in ms per cell) +//! - `bench_results.csv` in the working directory — tidy long-form, one row +//! per (size, protocol, phase) with min/mean/max columns +//! +//! Run: `cargo bench --bench zook_vs_whir` +//! +//! NOTE: With `ITERATIONS = 3` the first iteration at each size carries +//! cold-cache / allocator / rayon-warmup overhead. We do not discard a +//! warmup pass — at 2^24 each prove can take ~1 min and the user opted for +//! tight iters. `max` therefore tends to be inflated by the first sample. + +use std::{ + borrow::Cow, + fs::File, + io::Write, + time::{Duration, Instant}, +}; + +use whir::{ + algebra::{ + embedding::Identity, + fields::Field64_3, + linear_form::{Evaluate, LinearForm, MultilinearExtension}, + random_vector, + }, + hash, + parameters::ProtocolParameters, + protocols::{ + params::{ + DecodingRegime, FoldingFactor, KneeWeight, Mode, PowBudget, ProtocolConfig, + RateSchedule, SecuritySpec, TuningSpec, + }, + whir as whir_non_zk, + }, + transcript::{codecs::Empty, DomainSeparator, ProverState, VerifierState}, +}; + +type F = Field64_3; +type Embed = Identity; + +/// Polynomial sizes swept by `main`. Each entry `k` means a witness of +/// length 2^k. 2^24 is ~16M F-elements ≈ 384 MB raw; the prover allocates +/// a few × that for the codeword + Merkle tree, so peak RSS at the top end +/// will be several GB. Trim the slice if you run into memory pressure. +const LOG_SIZES: &[u32] = &[ + 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, +]; + +/// Number of measured iterations per (size, protocol). User-chosen tight +/// budget — see module note about first-iteration cold-cache noise. +const ITERATIONS: usize = 3; + +/// Shared protocol params — held constant across sizes so we measure scaling +/// in vector size only. +const FOLDING_FACTOR: usize = 4; +const STARTING_LOG_INV_RATE: u32 = 2; +const SECURITY_LEVEL: u32 = 128; +const POW_BITS: u32 = 10; + +/// Fixed RNG seed so iteration-to-iteration variance comes only from runtime +/// noise, not from sampling a different witness. +const SEED: u64 = 0; + +/// Per-iteration timings for one (size, protocol) cell. `proof_bytes` is +/// deterministic given the spec + seeded witness, so it's recorded once +/// (overwritten on every iter — all values are equal). +#[derive(Default)] +struct PhaseSamples { + commit: Vec, + prove: Vec, + verify: Vec, + proof_bytes: usize, +} + +/// (min, mean, max). Caller guarantees `samples` is non-empty. +fn summarize(samples: &[Duration]) -> (Duration, Duration, Duration) { + let min = *samples.iter().min().expect("non-empty"); + let max = *samples.iter().max().expect("non-empty"); + let sum_ns: u128 = samples.iter().map(Duration::as_nanos).sum(); + let mean = Duration::from_nanos((sum_ns / samples.len() as u128) as u64); + (min, mean, max) +} + +fn ms(d: Duration) -> f64 { + d.as_secs_f64() * 1000.0 +} + +/// Deterministic per-size witness + a single linear form to evaluate. We +/// reuse the same `(witness, form, eval)` across iterations to keep CPU +/// caches as warm as they would be in a tight prove loop. +fn make_workload(log_size: u32) -> (Vec, MultilinearExtension, F) { + use ark_std::rand::{rngs::StdRng, SeedableRng}; + let size = 1usize << log_size; + let mut rng = StdRng::seed_from_u64(SEED); + let witness = random_vector::(&mut rng, size); + let form = MultilinearExtension:: { + point: random_vector::(&mut rng, log_size as usize), + }; + let embedding = Embed::default(); + let evaluation = form.evaluate(&embedding, &witness); + (witness, form, evaluation) +} + +fn bench_whir(log_size: u32) -> Result { + let size = 1usize << log_size; + let whir_params = ProtocolParameters { + decoding_regime: DecodingRegime::Johnson, + starting_log_inv_rate: STARTING_LOG_INV_RATE as usize, + initial_folding_factor: FOLDING_FACTOR, + folding_factor: FOLDING_FACTOR, + security_level: SECURITY_LEVEL as usize, + pow_bits: POW_BITS as usize, + batch_size: 1, + hash_id: hash::SHA2, + }; + let config = whir_non_zk::Config::::new(size, &whir_params); + let (witness, form, evaluation) = make_workload(log_size); + let mut samples = PhaseSamples::default(); + + for _ in 0..ITERATIONS { + // Fresh DS per iteration so the transcript is independent each run. + let ds = DomainSeparator::protocol(&config) + .session(&format!("bench-whir-non-zk-2^{log_size}")) + .instance(&Empty); + let mut ps = ProverState::new_std(&ds); + + // ---- commit ---- + let t = Instant::now(); + let committed = config.commit(&mut ps, &[&witness]); + samples.commit.push(t.elapsed()); + + // ---- prove ---- + // Clone witness/form into the prove call (it consumes them). Clones + // happen before the timer starts so they don't contaminate prove time. + let witness_owned = witness.clone(); + let form_box: Box> = Box::new(form.clone()); + let t = Instant::now(); + let _ = config.prove( + &mut ps, + vec![Cow::Owned(witness_owned)], + vec![Cow::Owned(committed)], + vec![form_box], + Cow::Owned(vec![evaluation]), + ); + samples.prove.push(t.elapsed()); + + // ---- verify ---- + let proof = ps.proof(); + samples.proof_bytes = proof.narg_string.len() + proof.hints.len(); + let mut vs = VerifierState::new_std(&ds, &proof); + let t = Instant::now(); + let commitment = config + .receive_commitment(&mut vs) + .map_err(|e| format!("whir receive_commitment: {e:?}"))?; + let final_claim = config + .verify(&mut vs, &[&commitment], &[evaluation]) + .map_err(|e| format!("whir verify: {e:?}"))?; + final_claim + .verify(std::iter::once(&form as &dyn LinearForm)) + .map_err(|e| format!("whir final_claim verify: {e:?}"))?; + samples.verify.push(t.elapsed()); + } + Ok(samples) +} + +fn bench_zook( + log_size: u32, + mode: Mode, + rate_schedule: RateSchedule, +) -> Result { + let size = 1usize << log_size; + let schedule_tag = match rate_schedule { + RateSchedule::Adaptive { .. } => "adaptive", + RateSchedule::Stepping => "stepping", + RateSchedule::Capped { .. } => "capped", + }; + let mode_tag = match mode { + Mode::Standard => "standard", + Mode::ZeroKnowledge => "zk", + }; + let label = format!("zook-{mode_tag}-{schedule_tag}"); + let spec = SecuritySpec { + mode, + decoding_regime: DecodingRegime::Johnson, + target_security_bits: SECURITY_LEVEL, + pow_budget: PowBudget::per_slot(POW_BITS), + hash_id: hash::SHA2, + }; + let tuning = TuningSpec { + vector_size: size, + starting_log_inv_rate: STARTING_LOG_INV_RATE, + folding_factor: FoldingFactor::Constant(FOLDING_FACTOR), + rate_schedule, + }; + let config = ProtocolConfig::::derive(spec, tuning) + .map_err(|e| format!("{label} derive: {e:?}"))?; + let (witness, form, evaluation) = make_workload(log_size); + let mut samples = PhaseSamples::default(); + + for _ in 0..ITERATIONS { + let ds = DomainSeparator::protocol(&format!("bench-{label}")) + .session(&format!("bench-{label}-2^{log_size}")) + .instance(&Empty); + let mut ps = ProverState::new_std(&ds); + + // ---- commit ---- + let t = Instant::now(); + let committed = config.commit(&mut ps, &witness); + samples.commit.push(t.elapsed()); + + // ---- prove ---- + let form_ref: &dyn LinearForm = &form; + let t = Instant::now(); + config.prove(&mut ps, committed, &[form_ref], &[evaluation]); + samples.prove.push(t.elapsed()); + + // ---- verify ---- + let proof = ps.proof(); + samples.proof_bytes = proof.narg_string.len() + proof.hints.len(); + let mut vs = VerifierState::new_std(&ds, &proof); + let t = Instant::now(); + let commitment = config + .receive_commitment(&mut vs) + .map_err(|e| format!("{label} receive_commitment: {e:?}"))?; + config + .verify(&mut vs, commitment, &[form_ref], &[evaluation]) + .map_err(|e| format!("{label} verify: {e:?}"))?; + samples.verify.push(t.elapsed()); + } + Ok(samples) +} + +fn print_header() { + println!( + "{:<9} {:<14} {:<8} {:>12} {:>12} {:>12} {:>12}", + "log_size", "protocol", "phase", "min (ms)", "mean (ms)", "max (ms)", "proof (B)" + ); + println!("{}", "-".repeat(85)); +} + +fn record_cell( + csv: &mut File, + log_size: u32, + size: usize, + protocol: &str, + phase: &str, + samples: &[Duration], + proof_bytes: usize, +) { + let (min, mean, max) = summarize(samples); + println!( + "{:<9} {:<14} {:<8} {:>12.3} {:>12.3} {:>12.3} {:>12}", + log_size, + protocol, + phase, + ms(min), + ms(mean), + ms(max), + proof_bytes, + ); + writeln!( + csv, + "{},{},{},{},{},{},{},{}", + log_size, + size, + protocol, + phase, + ms(min), + ms(mean), + ms(max), + proof_bytes, + ) + .expect("write CSV row"); +} + +#[allow(clippy::too_many_lines, reason = "bench harness — sweep + table format is naturally long")] +fn main() { + let mut csv = File::create("bench_results.csv").expect("create bench_results.csv"); + writeln!( + csv, + "log_size,size,protocol,phase,min_ms,mean_ms,max_ms,proof_bytes" + ) + .unwrap(); + print_header(); + + for &log_size in LOG_SIZES { + let size = 1usize << log_size; + + eprintln!("\n[2^{log_size} = {size} elements] whir_non_zk ..."); + match bench_whir(log_size) { + Ok(s) => { + let pb = s.proof_bytes; + record_cell( + &mut csv, + log_size, + size, + "whir_non_zk", + "commit", + &s.commit, + pb, + ); + record_cell( + &mut csv, + log_size, + size, + "whir_non_zk", + "prove", + &s.prove, + pb, + ); + record_cell( + &mut csv, + log_size, + size, + "whir_non_zk", + "verify", + &s.verify, + pb, + ); + } + Err(e) => eprintln!(" whir_non_zk failed: {e}"), + } + + // Zook Standard on `RateSchedule::Stepping` — same per-round step + // WHIR's legacy `Config::new` uses internally, so proof-size and + // prove-time comparisons isolate the *protocol* delta rather than + // the rate-schedule delta. + let stepping = RateSchedule::Stepping; + eprintln!("[2^{log_size}] zook_standard_stepping ..."); + match bench_zook(log_size, Mode::Standard, stepping) { + Ok(s) => { + let pb = s.proof_bytes; + record_cell( + &mut csv, + log_size, + size, + "zook_standard_stepping", + "commit", + &s.commit, + pb, + ); + record_cell( + &mut csv, + log_size, + size, + "zook_standard_stepping", + "prove", + &s.prove, + pb, + ); + record_cell( + &mut csv, + log_size, + size, + "zook_standard_stepping", + "verify", + &s.verify, + pb, + ); + } + Err(e) => eprintln!(" zook_standard_stepping failed: {e}"), + } + + eprintln!("[2^{log_size}] zook_zk_adaptive ..."); + match bench_zook( + log_size, + Mode::ZeroKnowledge, + RateSchedule::Adaptive { + knee_weight: KneeWeight::DEFAULT, + }, + ) { + Ok(s) => { + let pb = s.proof_bytes; + record_cell( + &mut csv, + log_size, + size, + "zook_zk_adaptive", + "commit", + &s.commit, + pb, + ); + record_cell( + &mut csv, + log_size, + size, + "zook_zk_adaptive", + "prove", + &s.prove, + pb, + ); + record_cell( + &mut csv, + log_size, + size, + "zook_zk_adaptive", + "verify", + &s.verify, + pb, + ); + } + Err(e) => eprintln!(" zook_zk failed: {e}"), + } + } + + println!("\nCSV written to bench_results.csv"); +} diff --git a/proptest-regressions/protocols/params/basecase.txt b/proptest-regressions/protocols/params/basecase.txt index c3b26074..5491a278 100644 --- a/proptest-regressions/protocols/params/basecase.txt +++ b/proptest-regressions/protocols/params/basecase.txt @@ -4,6 +4,7 @@ # # It is recommended to check this file in to source control so that # everyone who runs the test benefits from these saved cases. +cc 6c25e8c285bba26405888aae0efccc8d511f8c49a42c47a3d4c58234fc101f0e # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 30, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, (log_size, log_inv_rate) = (4, 3) cc b6e8ae0b3e6a9769901e0e0e489da34965bf0a8df7dd049aef66e0541bf10baf # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, decoding_regime: Johnson, target_security_bits: 30, pow_budget: Forbidden, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, (log_size, log_inv_rate) = (1, 1) cc a2f771fc5031440200810b95ea2d347da895f8eb2e1a87f53fd69ad224287e84 # shrinks to spec = SecuritySpec { mode: Standard, decoding_regime: Johnson, target_security_bits: 30, pow_budget: Forbidden, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, (log_size, log_inv_rate) = (2, 2) cc f66c89bc700c79bca5f4b7234f1345129962c78e5a2036a6430564f615f19b30 # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, decoding_regime: Johnson, target_security_bits: 30, pow_budget: Forbidden, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, (log_size, log_inv_rate) = (1, 1) diff --git a/proptest-regressions/protocols/params/code_switch.txt b/proptest-regressions/protocols/params/code_switch.txt index de6c40cf..e5425ca0 100644 --- a/proptest-regressions/protocols/params/code_switch.txt +++ b/proptest-regressions/protocols/params/code_switch.txt @@ -9,5 +9,7 @@ cc b42c982074a04c7110df07cf00f45156607be547e176b1ddd5f9d994ad491ddb # shrinks to cc eaf09a2b6bdffa86026264679f008326498ca800260dd2f17d4370df9fb3f801 # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 80, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 4, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, log_inv_rate = 1, folding_factor = 3, num_vars = 4 cc 3887a5fa698c99109e8262e843dbd24ea94b9c9d420791e4520b5c9211a3eca0 # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 100, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 4, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, (log_inv_rate, folding_factor, num_vars) = (3, 2, 7) cc b3e128084f721e6f43e263e05acf2e2de6fcd05dccf3811f063eeb0b63d78f8e # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 47, max_pow_bits: Some(15), hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, (log_inv_rate, folding_factor, num_vars) = (3, 2, 4) +cc 9f48d6342753ea1d21d0eec7f6d75d5d57910afd84eeea5ba543c463d1574068 # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 41, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, (log_inv_rate, folding_factor, num_vars) = (1, 1, 6) +cc 751b1f803194c82d56c6a3c890890d5fa09b581b4a682fdefb81300ea267dd1e # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 39, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, (log_inv_rate, folding_factor, num_vars) = (2, 1, 3) cc b71da9002ceac9e4a74af097a7b087557a5b916fe8da47e39c4682375d749f88 # shrinks to spec = SecuritySpec { mode: Standard, decoding_regime: Johnson, target_security_bits: 50, pow_budget: Forbidden, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, (log_inv_rate, folding_factor, num_vars) = (1, 2, 4) cc 1981509d857e56772dd4a79f8692619e968891aa3d84576ea1857f6d9a484a2d # shrinks to spec = SecuritySpec { mode: Standard, decoding_regime: Johnson, target_security_bits: 50, pow_budget: Forbidden, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, (log_inv_rate, folding_factor, num_vars) = (1, 2, 4) diff --git a/proptest-regressions/protocols/params/derive.txt b/proptest-regressions/protocols/params/derive.txt index ca83c1b3..a588b9f1 100644 --- a/proptest-regressions/protocols/params/derive.txt +++ b/proptest-regressions/protocols/params/derive.txt @@ -5,3 +5,5 @@ # It is recommended to check this file in to source control so that # everyone who runs the test benefits from these saved cases. cc 104921a4117ed8255308c1ea5d3e12c72356ef72ef0d93fc0f24ed29f93fdd3a # shrinks to tuning = TuningSpec { vector_size: 32, starting_log_inv_rate: 3, folding_factor: Constant(1) } +cc 7c7a5557796810a266e30fe992f4549f05be927806576e8e3d7e181def7bac6e # shrinks to tuning = TuningSpec { vector_size: 16, starting_log_inv_rate: 3, folding_factor: Constant(1) } +cc 471d5138a47f9794ade3cae88d1bca53a071dfe582c3daa6bbbdfe878ec5496a # shrinks to tuning = TuningSpec { vector_size: 32, starting_log_inv_rate: 3, folding_factor: ConstantFromSecondRound { initial: 3, rest: 2 } } diff --git a/proptest-regressions/protocols/params/irs_commit.txt b/proptest-regressions/protocols/params/irs_commit.txt index 4d35f3df..5629c687 100644 --- a/proptest-regressions/protocols/params/irs_commit.txt +++ b/proptest-regressions/protocols/params/irs_commit.txt @@ -5,4 +5,6 @@ # It is recommended to check this file in to source control so that # everyone who runs the test benefits from these saved cases. cc 0b6dd03179c9a4e38b29b34b241b88fba69348a2c8938af7253314b7035bea82 # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 80, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 4, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, ctx = RoundContext { round_index: 0, vector_size: 16, log_inv_rate: 1, folding_factor: 1, prev_round_in_domain_samples: 0, prev_round_query_error: 0.0 }, out_domain = 0, seed = 0 +cc 47b7400a2e58354ed9f725193e2ccfdfa05ff7801af994ded27f91e661f93a1a # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 100, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, ctx = RoundContext { vector_size: 128, log_inv_rate: 4, folding_factor: 2 }, out_domain = 2 cc 7e49f7a2d53f55cfa2f09114d17ab4123678b45ddf69e0cfbc646b246de2f042 # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 80, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, ctx = RoundContext { vector_size: 128, log_inv_rate: 2, folding_factor: 2 }, out_domain = 11 +cc 0269209a3273142bd36769a9bc4e73b2430411838d69a7242681348ad680745f # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, decoding_regime: Capacity, target_security_bits: 91, pow_budget: PerSlot { bits: 60 }, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, ctx = RoundContext { vector_size: 256, log_inv_rate: 2, folding_factor: 1 }, out_domain = 0 diff --git a/proptest-regressions/protocols/params/sumcheck.txt b/proptest-regressions/protocols/params/sumcheck.txt index d6f6e6ed..0f4a9a3e 100644 --- a/proptest-regressions/protocols/params/sumcheck.txt +++ b/proptest-regressions/protocols/params/sumcheck.txt @@ -10,3 +10,6 @@ cc 8c4300cc375640956f81e9da5aef9ea11ef476ddc4dd253dc560afa07609262d # shrinks to cc 8ea40f13c63b4c0021386369ce698a5d9289381a39dc85db43d2d69b9b4877bb # shrinks to spec = SecuritySpec { mode: Standard { unique_decoding: false }, target_security_bits: 88, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 3, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, ctx = RoundContext { round_index: 0, vector_size: 16, log_inv_rate: 3, folding_factor: 1, prev_round_in_domain_samples: 0, prev_round_query_error: 0.0 } cc f1dca600886474c74d857c547baea0c2b4faf45b2946036f21a008106396eb1c # shrinks to spec = SecuritySpec { mode: Standard { unique_decoding: false }, target_security_bits: 80, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 3, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, ctx = RoundContext { round_index: 0, vector_size: 256, log_inv_rate: 3, folding_factor: 1, prev_round_in_domain_samples: 0, prev_round_query_error: 0.0 } cc 36d0f5929e8099fa8644b0511229cf11634e5a7a66d99c06099c304f5f7a8c6e # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 47, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, ctx = RoundContext { round_index: 0, vector_size: 128, log_inv_rate: 4, folding_factor: 1, prev_round_in_domain_samples: 0, prev_round_query_error: 0.0 } +cc 806ab6a7eb30f5ac257fc3d8927cb6b98dd7772dddb5214adcbdeccd28b61bf7 # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 38, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, ctx = RoundContext { vector_size: 32, log_inv_rate: 4, folding_factor: 1 } +cc eba638a395b9baa0f71f4a7712c4c480bc4ebc2102fdd04d0d5139026f1cb9fc # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 30, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, ctx = RoundContext { vector_size: 32, log_inv_rate: 3, folding_factor: 1 } +cc 943cf34cf3f7a849cc4a1d7ef26e7ff3fde2a79b4a15287810bef709f0867131 # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 31, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, ctx = RoundContext { vector_size: 128, log_inv_rate: 3, folding_factor: 2 } diff --git a/src/algebra/ntt/cooley_tukey.rs b/src/algebra/ntt/cooley_tukey.rs index fde9dc5d..7f8907da 100644 --- a/src/algebra/ntt/cooley_tukey.rs +++ b/src/algebra/ntt/cooley_tukey.rs @@ -18,10 +18,7 @@ use super::{ }; #[cfg(not(feature = "rs_in_order"))] use crate::algebra::ntt::transpose::transpose_permute; -use crate::{ - algebra::ntt::utils::divisors, - utils::{chunks_exact_or_empty, zip_strict}, -}; +use crate::algebra::ntt::utils::divisors; // Supported primes const PRIMES: [usize; 2] = [2, 3]; @@ -374,17 +371,17 @@ impl ReedSolomon for NttEngine { fn evaluation_points( &self, - masked_message_length: usize, + poly_length: usize, codeword_length: usize, indices: &[usize], ) -> Vec { - assert!(masked_message_length <= codeword_length); + assert!(poly_length <= codeword_length); assert!(self.order.is_multiple_of(codeword_length)); let mut result = Vec::new(); let generator = self.generator(codeword_length); // Coset transformation - let mut coset_size = self.next_order(masked_message_length).unwrap(); + let mut coset_size = self.next_order(poly_length).unwrap(); while !codeword_length.is_multiple_of(coset_size) { coset_size = self.next_order(coset_size + 1).unwrap(); } @@ -401,26 +398,20 @@ impl ReedSolomon for NttEngine { result } - #[cfg_attr(feature = "tracing", instrument(skip(self, messages, masks), fields( - num_messages = messages.len(), - message_len = messages.first().map(|c| c.len()), + #[cfg_attr(feature = "tracing", instrument(skip(self, polys), fields( + num_polys = polys.len(), + poly_length = polys.first().map(|p| p.len()), codeword_length = codeword_length, - mask_len = masks.len().checked_div(messages.len()) - )))] - fn interleaved_encode(&self, messages: &[&[F]], masks: &[F], codeword_length: usize) -> Vec { + fn interleaved_encode(&self, polys: &[&[F]], codeword_length: usize) -> Vec { assert!(self.order.is_multiple_of(codeword_length)); - if messages.is_empty() { - assert!(masks.is_empty()); + if polys.is_empty() { return Vec::new(); } - let num_messages = messages.len(); - let message_len = messages[0].len(); - assert!(messages.iter().all(|m| m.len() == message_len)); - assert!(masks.len().is_multiple_of(num_messages)); - let mask_length = masks.len() / num_messages; - let masked_message_length = message_len + mask_length; - assert!(masked_message_length <= codeword_length); + let num_polys = polys.len(); + let poly_length = polys[0].len(); + assert!(polys.iter().all(|p| p.len() == poly_length)); + assert!(poly_length <= codeword_length); // Coset-NTT: instead of doing one codeword-length NTT on mostly zeros, // do `num_cosets` many `coset_size`-point NTTs on twisted coefficient @@ -433,28 +424,24 @@ impl ReedSolomon for NttEngine { // You can also see this as applying a first round of Cooley-Tukey with // N = coset_size × num_cosets, and solving it directly by observing that // only the first coset is non-zero. - let mut coset_size = self.next_order(masked_message_length).unwrap(); + let mut coset_size = self.next_order(poly_length).unwrap(); while !codeword_length.is_multiple_of(coset_size) { coset_size = self.next_order(coset_size + 1).unwrap(); } let num_cosets = codeword_length / coset_size; - let coset_padding = coset_size - masked_message_length; + let coset_padding = coset_size - poly_length; // Lay out twisted coefficients in contiguous coset blocks of length // `coset_size`, zero-padding each block as needed. - let mut result = Vec::with_capacity(num_messages * codeword_length); - for (message, mask) in zip_strict( - messages, - chunks_exact_or_empty(masks, mask_length, num_messages), - ) { + let mut result = Vec::with_capacity(num_polys * codeword_length); + for poly in polys { // FFT[a 0 0 0] = [a a a a], so just replicate input in coset dimension. for _ in 0..num_cosets { - result.extend_from_slice(message); - result.extend_from_slice(mask); + result.extend_from_slice(poly); result.resize(result.len() + coset_padding, F::ZERO); } } - assert_eq!(result.len(), num_messages * codeword_length); + assert_eq!(result.len(), num_polys * codeword_length); // NTT each coset block, then transpose each codeword block from // coset-major `(num_cosets × coset_size)` layout into standard codeword @@ -472,7 +459,7 @@ impl ReedSolomon for NttEngine { transpose(&mut result, num_cosets, coset_size); // Transpose to row-major order with vectors stacked horizontally. - transpose(&mut result, num_messages, codeword_length); + transpose(&mut result, num_polys, codeword_length); result } } diff --git a/src/algebra/ntt/mod.rs b/src/algebra/ntt/mod.rs index 9e3fc74d..5d7cad2b 100644 --- a/src/algebra/ntt/mod.rs +++ b/src/algebra/ntt/mod.rs @@ -61,42 +61,50 @@ impl type_map::Family for NttFamily { type Dyn = dyn ReedSolomon; } -/// Trait for a Reed-Solomon encoder implementation for a given field `F`. +/// Reed-Solomon encoder for a given field `F`. +/// +/// Pure-NTT abstraction: encodes polynomials, knows nothing about how callers +/// structure those polynomials (whir's IRS, for example, concatenates a +/// message and a mask into a single polynomial before calling this trait — +/// that split lives entirely on the caller side). pub trait ReedSolomon: Debug + Send + Sync { - /// Returns the next supported order equal or larger than `size`. - /// - /// The result will be an NTT-smooth number suitable for `codeword_length`. - /// - /// Returns `None` if `size` exceeds the largest supported order. + /// Smallest supported codeword length `≥ size`, or `None` if `size` + /// exceeds the engine's maximum order. The returned length is always + /// NTT-smooth for this engine. fn next_order(&self, size: usize) -> Option; + /// Generator of the multiplicative subgroup of order `codeword_length`. fn generator(&self, codeword_length: usize) -> F; - /// Returns the `index`th evaluation point. + /// Evaluation points for the requested codeword positions. /// - /// `masked_message_length`: the total message length including any mask values. + /// `result[i]` is the field point at which `codeword[indices[i]]` lives. + /// `poly_length` is the length of the polynomial whose codeword is being + /// queried — some engines (e.g. cooley_tukey) derive their internal coset + /// structure from it, so the same codeword index can map to different + /// points depending on `poly_length`. /// /// # Panics /// - /// Panics if any of the indices are `>= codeword_length` or `order` is not supported. + /// Panics if any index is `>= codeword_length` or `codeword_length` is + /// not supported. fn evaluation_points( &self, - masked_message_length: usize, + poly_length: usize, codeword_length: usize, indices: &[usize], ) -> Vec; - /// Compute a masked interleaved Reed-Solomon encoding. - /// - /// `messages` are `num_messages` slices of `message_length` elements. - /// `masks` is a `num_messages` × `mask_length` matrix of blinding coefficients. - /// `codeword_length` must be an NTT-smooth number >= `message_length + mask_length`. - /// returns an `codeword_length × num_messages` matrix. + /// Batch-encode polynomials in parallel. /// - /// Each output value is the univariate polynomial evaluation in the evaluation point - /// corresponding with the index of a coefficient list formed by concatenating message and mask. + /// All `polys[i]` must have the same length. Output is a flat buffer of + /// `polys.len() * codeword_length` elements in row-major + /// `(eval_index, poly)` layout: `result[i * polys.len() + j]` is poly + /// `j`'s value at the `i`-th evaluation point. /// - fn interleaved_encode(&self, messages: &[&[F]], masks: &[F], codeword_length: usize) -> Vec; + /// `codeword_length` must be NTT-smooth for this engine and at least the + /// polynomial length. + fn interleaved_encode(&self, polys: &[&[F]], codeword_length: usize) -> Vec; } assert_obj_safe!(ReedSolomon); @@ -108,23 +116,19 @@ pub fn next_order(size: usize) -> Option { } pub fn evaluation_points( - masked_message_length: usize, + poly_length: usize, codeword_length: usize, indices: &[usize], ) -> Vec { NTT.get::() .expect("Unsupported NTT field.") - .evaluation_points(masked_message_length, codeword_length, indices) + .evaluation_points(poly_length, codeword_length, indices) } -pub fn interleaved_rs_encode( - messages: &[&[F]], - masks: &[F], - codeword_length: usize, -) -> Vec { +pub fn interleaved_rs_encode(polys: &[&[F]], codeword_length: usize) -> Vec { NTT.get::() .expect("Unsupported NTT field.") - .interleaved_encode(messages, masks, codeword_length) + .interleaved_encode(polys, codeword_length) } pub fn generator(codeword_length: usize) -> F { @@ -145,7 +149,7 @@ mod tests { use super::*; use crate::{ algebra::{random_vector, univariate_evaluate}, - utils::{chunks_exact_or_empty, zip_strict}, + utils::zip_strict, }; fn valid_codeword_lengths(size: usize, count: usize) -> Vec { @@ -187,13 +191,17 @@ mod tests { let messages = (0..num_messages) .map(|_| random_vector(&mut rng, message_length)) .collect::>(); - let masks = random_vector(&mut rng, mask_length * num_messages); - let message_refs = messages.iter().map(|v| v.as_slice()).collect::>(); - let codeword = ntt.interleaved_encode( - &message_refs, - &masks, - codeword_length, - ); + let masks: Vec> = (0..num_messages) + .map(|_| random_vector(&mut rng, mask_length)) + .collect(); + // Build each polynomial as `message ‖ mask`. The engine takes + // unified polynomial slices; the message/mask split is purely a + // caller-side concept. + let polys: Vec> = (0..num_messages) + .map(|i| messages[i].iter().chain(masks[i].iter()).copied().collect()) + .collect(); + let poly_refs: Vec<&[F]> = polys.iter().map(Vec::as_slice).collect(); + let codeword = ntt.interleaved_encode(&poly_refs, codeword_length); // Output must be the right size. assert_eq!(codeword.len(), codeword_length * num_messages); @@ -202,8 +210,7 @@ mod tests { let mut evaluation_points = ntt.evaluation_points(message_length + mask_length, codeword_length, &sampled_indices); for (&index, &evaluation_point) in zip_strict(&sampled_indices, &evaluation_points) { let evaluations = &codeword[index * num_messages.. (index + 1) * num_messages]; - let masks = chunks_exact_or_empty(&masks, mask_length, num_messages); - for ((message, mask), value) in zip_strict(zip_strict(&messages, masks), evaluations) { + for ((message, mask), value) in zip_strict(zip_strict(&messages, &masks), evaluations) { assert_eq!(*value, univariate_evaluate(message, evaluation_point) + evaluation_point.pow([message_length as u64]) diff --git a/src/protocols/basecase.rs b/src/protocols/basecase.rs index 61e98766..1c09c1b0 100644 --- a/src/protocols/basecase.rs +++ b/src/protocols/basecase.rs @@ -106,7 +106,6 @@ impl Config { assert_eq!(self.commit.num_vectors, 1); assert_eq!(self.commit.vector_size, self.sumcheck.initial_size); assert_eq!(self.sumcheck.final_size(), 1.min(self.commit.vector_size)); - debug_assert_eq!(dot(&vector, &covector), sum); if self.size() == 0 { return Opening { evaluation_points: Vec::new(), diff --git a/src/protocols/challenge_indices.rs b/src/protocols/challenge_indices.rs index e11d3bbc..53c7f7af 100644 --- a/src/protocols/challenge_indices.rs +++ b/src/protocols/challenge_indices.rs @@ -16,32 +16,37 @@ where if count == 0 { return Vec::new(); } - // TODO: This is blocking non-power-of-two support. - assert!( - num_leaves.is_power_of_two(), - "Number of leaves must be a power of two for unbiased results." - ); - if num_leaves == 1 { - // `size_bytes` would be zero, making `chunks_exact` panic. + if num_leaves <= 1 { + // `size_bytes` would be zero; short-circuit before the entropy loop. return if deduplicate { vec![0] } else { vec![0; count] }; } - // Calculate the required bytes of entropy - // TODO: Round total to bytes, instead of per index. - let size_bytes = (num_leaves.ilog2() as usize).div_ceil(8); + // Size the entropy chunk so `2^(8·size_bytes) ≥ next_pow2(num_leaves)`. + // For pow2 `num_leaves`, the entropy space is an exact multiple of + // `num_leaves` and rejection never triggers — bit-identical to the + // pre-rejection implementation. For non-pow2 `num_leaves`, rejection + // sampling eliminates the modular bias of a plain `% num_leaves`. + let bits_needed = num_leaves.next_power_of_two().ilog2() as usize; + let size_bytes = bits_needed.div_ceil(8); - // Get required entropy bits. - let entropy: Vec = (0..count * size_bytes) - .map(|_| transcript.verifier_message()) - .collect(); + // Largest multiple of `num_leaves` below `2^(8·size_bytes)`. u128 + // accommodates `size_bytes ≤ 16` without shift overflow on 64-bit hosts. + let entropy_space: u128 = 1u128 << (8 * size_bytes); + let num_leaves_u = num_leaves as u128; + let threshold: u128 = (entropy_space / num_leaves_u) * num_leaves_u; - // Convert bytes into indices - let mut indices = entropy - .chunks_exact(size_bytes) - .map(|chunk| chunk.iter().fold(0usize, |acc, &b| (acc << 8) | b as usize) % num_leaves) - .collect::>(); + let mut indices = Vec::with_capacity(count); + while indices.len() < count { + let mut candidate: u128 = 0; + for _ in 0..size_bytes { + candidate = (candidate << 8) | u128::from(transcript.verifier_message::()); + } + if candidate < threshold { + indices.push((candidate % num_leaves_u) as usize); + } + // else: candidate falls in the biased tail — reject and redraw. + } - // Sort and deduplicate indices if requested if deduplicate { indices.sort_unstable(); indices.dedup(); @@ -201,4 +206,73 @@ mod tests { "Mismatch in computed indices for deduplication test" ); } + + /// Non-pow2 `num_leaves`. With `num_leaves = 589824 = 2^16 · 9`: + /// `bits_needed = 20`, `size_bytes = 3`, entropy_space = 2^24 = 16_777_216, + /// `threshold = floor(2^24 / 589824) · 589824 = 28 · 589824 = 16_515_072`. + /// + /// The first 3-byte chunk decodes to 0xFFFFFF = 16_777_215, which is ≥ + /// threshold and must be rejected. The second chunk decodes to a small + /// value < threshold and is returned. + #[test] + fn test_challenge_indices_non_pow2_rejection_retry() { + let num_leaves: usize = 589_824; + let ds = DomainSeparator::protocol(&module_path!()) + .session(&format!("Test at {}:{}", file!(), line!())) + .instance(&Empty); + let sponge = MockSponge { + absorb: None, + squeeze: &[ + 0xFF, 0xFF, 0xFF, // candidate = 16_777_215 ≥ threshold → reject + 0x00, 0x00, 0x05, // candidate = 5 < threshold → accept + ], + }; + let mut prover_state = ProverState::new(&ds, sponge); + + let result = challenge_indices(&mut prover_state, num_leaves, 1, false); + assert_eq!(result, vec![5]); + } + + /// All indices returned for a non-pow2 `num_leaves` lie in `[0, num_leaves)`, + /// and the function is deterministic (same sponge bytes → same result). + #[test] + fn test_challenge_indices_non_pow2_in_range_and_deterministic() { + let num_leaves: usize = 589_824; + let count = 5; + let bytes: &[u8] = &[ + 0x12, 0x34, 0x56, // candidate 1_193_046 < threshold + 0x78, 0x9A, 0xBC, // candidate 7_904_956 < threshold + 0xDE, 0xF0, 0x11, // candidate 14_610_449 < threshold + 0x22, 0x33, 0x44, // candidate 2_241_348 < threshold + 0x55, 0x66, 0x77, // candidate 5_596_791 < threshold + ]; + let make_state = || { + let ds = DomainSeparator::protocol(&module_path!()) + .session(&format!("Test at {}:{}", file!(), line!())) + .instance(&Empty); + ProverState::new( + &ds, + MockSponge { + absorb: None, + squeeze: bytes, + }, + ) + }; + + let mut first = make_state(); + let mut second = make_state(); + let r1 = challenge_indices(&mut first, num_leaves, count, false); + let r2 = challenge_indices(&mut second, num_leaves, count, false); + + assert_eq!(r1, r2, "challenge_indices must be deterministic"); + assert_eq!(r1.len(), count); + assert!( + r1.iter().all(|&i| i < num_leaves), + "all indices must lie in [0, {num_leaves}): got {r1:?}", + ); + + // Spot-check the first index: 0x123456 = 1_193_046 < threshold 16_515_072 + // → accepted, index = 1_193_046 % 589_824 = 13_398. + assert_eq!(r1[0], 1_193_046 % num_leaves); + } } diff --git a/src/protocols/code_switch.rs b/src/protocols/code_switch.rs index 84cd668f..a0d9908d 100644 --- a/src/protocols/code_switch.rs +++ b/src/protocols/code_switch.rs @@ -61,6 +61,14 @@ pub struct Witness { pub target_witness: IrsWitness, } +/// Mutable claim state threaded through `prove`/`verify`: a covector (extended +/// with `ℓ_zk` slack in ZK mode) paired with the running sum `μ` such that +/// `μ = ⟨vector, covector⟩` after each protocol step. +pub struct Claim<'a, F: Field> { + pub covector: &'a mut [F], + pub sum: &'a mut F, +} + /// Verifier output from the code-switch. pub type Commitment = IrsCommitment; @@ -193,8 +201,8 @@ impl Config { &self, prover_state: &mut ProverState, message: Vec, - witness: &IrsWitness, - covector: &mut [M::Target], + witness: IrsWitness, + claim: Claim<'_, M::Target>, folding_randomness: &[M::Target], mask: &[M::Target], ) -> Witness @@ -208,6 +216,7 @@ impl Config { U64: Codec<[H::U]>, Hash: ProverMessage<[H::U]>, { + let Claim { covector, sum } = claim; assert_eq!(message.len(), self.source.message_length()); assert_eq!(covector.len(), self.covector_length()); assert_eq!(mask.len(), self.message_mask_length()); @@ -227,10 +236,19 @@ impl Config { // Step 2-3: OOD challenge + answers — Construction 9.7 Steps 2-3, p.55 let ood_points: Vec = prover_state.verifier_message_vec(self.out_domain_samples); - self.maybe_send_ood_answers(prover_state, &message, mask, &ood_points); + let ood_answers = self.maybe_send_ood_answers(prover_state, &message, mask, &ood_points); // Step 4: in-domain queries — Construction 9.7 Step 4, p.55 - let source_evaluations = self.source.open(prover_state, &[witness]); + let source_evaluations = self.source.open(prover_state, &[&witness]); + // Source IRS matrix is no longer needed; release it before the trailing + // arithmetic and the caller's mask-discharge phase. + drop(witness); + let collapse_weights = eq_weights(folding_randomness); + let collapsed_values: Vec = source_evaluations + .matrix + .chunks_exact(self.source.interleaving_depth) + .map(|row| mixed_dot(self.source.embedding(), &collapse_weights, row)) + .collect(); // Step 4.1: batching — Construction 9.7 Step 4, p.55 let num_ood = self.out_domain_samples; @@ -240,6 +258,11 @@ impl Config { let (&original_sl_coeff, constraint_rlc_coeffs) = batching_coeffs.split_first().unwrap(); let (ood_rlc_coeffs, in_domain_rlc_coeffs) = constraint_rlc_coeffs.split_at(num_ood); + // Mirror verifier's sum update — Construction 9.7 Decision phase, p.55. + *sum = original_sl_coeff * *sum + + dot(ood_rlc_coeffs, &ood_answers) + + dot(in_domain_rlc_coeffs, &collapsed_values); + // Covector update — sl' from Completeness proof (p.55-56) let eval_points = lift(self.source.embedding(), &source_evaluations.points); scalar_mul(covector, original_sl_coeff); @@ -257,20 +280,23 @@ impl Config { } } - /// Send OOD answers `y_i = f(α_i) [+ α_i^ℓ · (r ‖ s)(α_i)]`. - /// In Standard mode the bracketed term is omitted. + /// Send OOD answers `y_i = f(α_i) [+ α_i^ℓ · (r ‖ s)(α_i)]` and return them + /// so the caller can reuse them for the sum update. In Standard mode the + /// bracketed term is omitted. fn maybe_send_ood_answers( &self, prover_state: &mut ProverState, message: &[M::Target], mask: &[M::Target], ood_points: &[M::Target], - ) where + ) -> Vec + where H: DuplexSpongeInterface, R: RngCore + CryptoRng, M::Target: Codec<[H::U]>, { let msg_len = message.len(); + let mut answers = Vec::with_capacity(ood_points.len()); for &point in ood_points { let f_eval = univariate_evaluate(message, point); let answer = match &self.mode { @@ -282,7 +308,9 @@ impl Config { } }; prover_state.prover_message(&answer); + answers.push(answer); } + answers } /// Accumulate OOD and in-domain weights into the covector. @@ -349,6 +377,7 @@ impl Config { &self, verifier_state: &mut VerifierState, sum: &mut M::Target, + covector: &mut [M::Target], folding_randomness: &[M::Target], commitment: &IrsCommitment, ) -> VerificationResult @@ -362,6 +391,7 @@ impl Config { Hash: ProverMessage<[H::U]>, { verify!(1 << folding_randomness.len() == self.source.interleaving_depth); + assert_eq!(covector.len(), self.covector_length()); let collapse_weights = eq_weights(folding_randomness); @@ -375,7 +405,7 @@ impl Config { // Step 2-3: OOD — Construction 9.7 Steps 2-3, p.55 // In ZK mode, ood_answers = f(α) + α^ℓ · (r,s)(α) where (r,s) is // the mask oracle message committed in the shared tree. - let _ood_points: Vec = + let ood_points: Vec = verifier_state.verifier_message_vec(self.out_domain_samples); let ood_answers: Vec = verifier_state.prover_messages_vec(self.out_domain_samples)?; @@ -399,6 +429,18 @@ impl Config { + dot(ood_rlc_coeffs, &ood_answers) + dot(in_domain_rlc_coeffs, &collapsed_values); + // Mirror prover's covector update so ` = sum` holds + // after this step (Completeness proof, p.55-56). + let eval_points = lift(self.source.embedding(), &source_evaluations.points); + scalar_mul(covector, original_sl_coeff); + self.update_covector( + covector, + ood_rlc_coeffs, + &ood_points, + in_domain_rlc_coeffs, + &eval_points, + ); + Ok(target_commitment) } } @@ -416,6 +458,27 @@ impl fmt::Display for Config { } } +/// Fold ι parallel chunks of length `chunk_len` into a single chunk. +/// +/// Uses `eq_weights(γ)` over the layout +/// `values = [chunk_0; chunk_1; ...; chunk_{ι−1}]` (each chunk of length +/// `chunk_len`) and returns `Σ_l eq_weights(γ)[l] · chunk_l`. +pub fn fold_chunks(values: &[F], chunk_len: usize, folding_randomness: &[F]) -> Vec { + let iota = 1 << folding_randomness.len(); + assert_eq!(values.len(), chunk_len * iota); + if iota == 1 { + return values.to_vec(); + } + let weights = eq_weights(folding_randomness); + (0..chunk_len) + .map(|j| { + (0..iota) + .map(|l| weights[l] * values[l * chunk_len + j]) + .sum() + }) + .collect() +} + #[cfg(test)] mod tests { use ark_std::rand::{ @@ -519,25 +582,6 @@ mod tests { } } - /// Fold ι parallel chunks of length `chunk_len` into a single chunk via - /// eq_weights(γ). Layout: values = [chunk_0; chunk_1; ...; chunk_{ι-1}], - /// each of length `chunk_len`. Returns Σ_l eq_weights(γ)[l] · chunk_l. - fn fold_chunks(values: &[F], chunk_len: usize, folding_randomness: &[F]) -> Vec { - let iota = 1 << folding_randomness.len(); - assert_eq!(values.len(), chunk_len * iota); - if iota == 1 { - return values.to_vec(); - } - let weights = eq_weights(folding_randomness); - (0..chunk_len) - .map(|j| { - (0..iota) - .map(|l| weights[l] * values[l * chunk_len + j]) - .sum() - }) - .collect() - } - /// Sample folding randomness of length log2(source.interleaving_depth). fn sample_folding_randomness( config: &Config>, @@ -565,7 +609,8 @@ mod tests { return Vec::new(); } // Lift ι parallel masks (total length source.mask_length × ι) and fold - // chunks of length source.mask_length down to a single chunk. + // chunks of length source.mask_length down to a single chunk. Masks + // are stored in whir's canonical per-poly contiguous layout. let raw = lift(config.source.embedding(), &source_witness.masks); let mut mask = fold_chunks(&raw, config.source.mask_length(), folding_randomness); // Append fresh padding s of length message_mask_length - source.mask_length. @@ -589,6 +634,8 @@ mod tests { let mut covector: Vec = random_vector(&mut rng, config.source.message_length()); covector.resize(config.covector_length(), F::ZERO); + let mut verifier_covector = covector.clone(); + let mut prover_sum = initial_sum; let instance = U64(seed); let ds = DomainSeparator::protocol(config) @@ -607,8 +654,11 @@ mod tests { let witness = config.prove( &mut prover_state, folded_message.clone(), - &source_witness, - &mut covector, + source_witness, + Claim { + covector: &mut covector, + sum: &mut prover_sum, + }, &folding_randomness, &mask_msg, ); @@ -624,12 +674,14 @@ mod tests { .verify( &mut verifier_state, &mut verifier_sum, + &mut verifier_covector, &folding_randomness, &source_commitment, ) .unwrap(); verifier_state.check_eof().unwrap(); assert_eq!(witness.message, folded_message); + assert_eq!(covector, verifier_covector); } fn test_ior_identity_config>(seed: u64, config: &Config>) @@ -642,6 +694,7 @@ mod tests { let mut covector: Vec = random_vector(&mut rng, config.source.message_length()); covector.resize(config.covector_length(), F::ZERO); + let mut verifier_covector = covector.clone(); let instance = U64(seed); let ds = DomainSeparator::protocol(config) @@ -669,12 +722,16 @@ mod tests { .collect() }; let initial_mu = dot(&h, &covector); + let mut prover_sum = initial_mu; let _witness = config.prove( &mut prover_state, folded_message, - &source_witness, - &mut covector, + source_witness, + Claim { + covector: &mut covector, + sum: &mut prover_sum, + }, &folding_randomness, &mask_msg, ); @@ -690,13 +747,15 @@ mod tests { .verify( &mut verifier_state, &mut verifier_sum, + &mut verifier_covector, &folding_randomness, &source_commitment, ) .unwrap(); verifier_state.check_eof().unwrap(); - assert_eq!(dot(&h, &covector), verifier_sum); + assert_eq!(covector, verifier_covector); + assert_eq!(dot(&h, &verifier_covector), verifier_sum); } fn test_tampered_ood_config>(seed: u64, config: &Config>) @@ -713,6 +772,7 @@ mod tests { let mut covector: Vec = random_vector(&mut rng, config.source.message_length()); covector.resize(config.covector_length(), F::ZERO); + let mut verifier_covector = covector.clone(); // Commit honest f_full, fold to get the honest post-fold message. let mut prover_state = ProverState::new_std(&ds); @@ -723,6 +783,7 @@ mod tests { // For non-ZK and source.mask_length == 0, h = folded_message and identity holds. let initial_mu = dot(&folded_message, &covector); + let mut prover_sum = initial_mu; // Tamper the post-fold message before proving. let mut tampered = folded_message.clone(); @@ -730,8 +791,11 @@ mod tests { let _witness = config.prove( &mut prover_state, tampered, - &source_witness, - &mut covector, + source_witness, + Claim { + covector: &mut covector, + sum: &mut prover_sum, + }, &folding_randomness, &[], ); @@ -747,6 +811,7 @@ mod tests { .verify( &mut verifier_state, &mut verifier_sum, + &mut verifier_covector, &folding_randomness, &source_commitment, ) @@ -754,7 +819,7 @@ mod tests { verifier_state.check_eof().unwrap(); // Sum diverges — downstream sumcheck would reject - assert_ne!(dot(&folded_message, &covector), verifier_sum); + assert_ne!(dot(&folded_message, &verifier_covector), verifier_sum); } fn test + 'static>() diff --git a/src/protocols/irs_commit.rs b/src/protocols/irs_commit.rs index 2c137fde..41d12910 100644 --- a/src/protocols/irs_commit.rs +++ b/src/protocols/irs_commit.rs @@ -138,14 +138,37 @@ impl Config { { assert!(vector_size.is_multiple_of(interleaving_depth)); assert!(rate > 0. && rate <= 1.); - let masked_message_length = vector_size / interleaving_depth + mode.mask_length(); - // `interleaved_encode` requires `codeword_length` to divide the NTT root - // order. `masked_message_length` is allowed to be arbitrary (the coset - // NTT zero-extends internally), so we only round the codeword side here. + let message_length = vector_size / interleaving_depth; + let masked_message_length = message_length + mode.mask_length(); + // Rate against the unmasked `message_length` when the mask is small + // relative to the message. The mask is `r ≈ in_domain + OOD` per + // Lemma 9.5 (~100 elements at λ=128) — typically much smaller than + // the rate slack of the NTT-smooth codeword, so it fits inside + // without forcing a `next_order` jump (e.g. avoids 524288 → 589824, + // a 12.5% jump on a 2^17-message commit). The resulting effective + // rate is `masked / codeword` ≈ requested + (mask/codeword), which + // for small-mask cases is a negligible degradation (≈1 extra + // in-domain query per round). When the mask is comparable to the + // message (tail rounds at very low rates), this optimization would + // blow the rate up and the parameter-selection fixpoint in + // `params::irs_commit::solve` wouldn't converge — fall back to rating + // the masked length there. Threshold `5·mask ≤ message` keeps the + // degradation under ~20%. #[allow(clippy::cast_sign_loss)] - let raw_codeword_length = (masked_message_length as f64 / rate).ceil() as usize; - let codeword_length = ntt::next_order::(raw_codeword_length) - .expect("codeword length exceeds NTT engine support"); + let codeword_length = { + let mask_len = mode.mask_length(); + let small_mask = 5 * mask_len <= message_length; + let unmasked_target = (message_length as f64 / rate).ceil() as usize; + let unmasked = ntt::next_order::(unmasked_target) + .expect("codeword length exceeds NTT engine support"); + if small_mask && unmasked >= masked_message_length { + unmasked + } else { + let masked_target = (masked_message_length as f64 / rate).ceil() as usize; + ntt::next_order::(masked_target) + .expect("codeword length exceeds NTT engine support") + } + }; let rate = masked_message_length as f64 / codeword_length as f64; let regime = DecodingRegimeParams::from_policy(decoding_regime, rate); @@ -265,15 +288,34 @@ impl Config { assert_eq!(vectors.len(), self.num_vectors); assert!(vectors.iter().all(|p| p.len() == self.vector_size)); - // Generate random mask - let masks = random_vector(prover_state.rng(), self.mask_length() * self.num_messages()); - - // Interleaved RS Encode the vectors - let messages = vectors - .iter() - .flat_map(|v| chunks_exact_or_empty(v, self.message_length(), self.interleaving_depth)) - .collect::>(); - let matrix = ntt::interleaved_rs_encode(&messages, &masks, self.codeword_length); + // Sample masks in whir's canonical per-polynomial-contiguous layout: + // `masks[i * mask_length + c]` is polynomial `i`'s coefficient at + // column `c`. Downstream sites (mask_proximity discharge, code_switch + // fold) read masks back via this layout. + let mask_length = self.mask_length(); + let num_polys = self.num_messages(); + let masks: Vec = random_vector(prover_state.rng(), mask_length * num_polys); + + // Engine takes unified polynomial slices (message ‖ mask) — the + // mask/message split is purely whir-side. Build a contiguous + // poly buffer to give the engine `polys: &[&[F]]` of identical + // length, then encode. + let message_length = self.message_length(); + let poly_length = message_length + mask_length; + let mut poly_buf = Vec::with_capacity(num_polys * poly_length); + let mut poly_idx = 0; + for vector in vectors { + for message in chunks_exact_or_empty(vector, message_length, self.interleaving_depth) { + poly_buf.extend_from_slice(message); + poly_buf.extend_from_slice( + &masks[poly_idx * mask_length..(poly_idx + 1) * mask_length], + ); + poly_idx += 1; + } + } + debug_assert_eq!(poly_idx, num_polys); + let polys: Vec<&[M::Source]> = poly_buf.chunks_exact(poly_length).collect(); + let matrix = ntt::interleaved_rs_encode(&polys, self.codeword_length); // Commit to the matrix let matrix_witness = self.matrix_commit.commit(prover_state, &matrix); diff --git a/src/protocols/mask_proximity.rs b/src/protocols/mask_proximity.rs index 0be304db..7c09222d 100644 --- a/src/protocols/mask_proximity.rs +++ b/src/protocols/mask_proximity.rs @@ -1,4 +1,4 @@ -//! Mask proximity verification via γ-combination. +//! Mask proximity verification via γ-combination + optional inner-product discharge. //! //! Implements Construction 7.2 (p.43-44) specialized for zero-constraint mask //! oracles. Given a shared Merkle tree containing 2n vectors — n original masks @@ -10,12 +10,19 @@ //! columns n..2n-1: mask-of-masks s_1, ..., s_n //! //! Protocol: -//! 1. Verifier sends γ (combination randomness) -//! 2. Prover sends combined polynomials ξ*_i = s_i + γ·ξ_i and +//! 1. (Optional discharge) Prover sends X_i = <ξ_i, w_i> and X'_i = +//! for each mask `i` with its own covector `w_i` — Construction 7.2 step 1 +//! target-claim setup, generalized to per-mask covectors. +//! 2. Verifier sends γ (combination randomness) +//! 3. Prover sends combined polynomials ξ*_i = s_i + γ·ξ_i and //! combined IRS randomness r*_i = r'_i + γ·r_i for each mask pair -//! 3. Shared tree is opened at random positions -//! 4. Verifier checks: Enc(ξ*_i, r*_i)(y_j) = s_i(y_j) + γ·ξ_i(y_j) +//! 4. Shared tree is opened at random positions +//! 5. Verifier checks: Enc(ξ*_i, r*_i)(y_j) = s_i(y_j) + γ·ξ_i(y_j) //! at each opened position, using linearity of the RS encoding +//! 6. (Optional discharge) Verifier target check: <ξ*_i, w_i> = X'_i + γ·X_i +//! for every mask `i`. Binds each X_i to the committed ξ_i via the same +//! γ-randomness — the orchestrator uses the returned X_i values to +//! reconcile sum offsets and project to f-only. //! //! ZK safety (follows the pattern of Construction 7.2, §7.1): //! - Only the combined ξ*_i = s_i + γ·ξ_i is revealed in full. Since s_i @@ -43,9 +50,11 @@ use std::fmt; use ark_ff::Field; use ark_std::rand::{distributions::Standard, prelude::Distribution, CryptoRng, RngCore}; use serde::{Deserialize, Serialize}; +#[cfg(feature = "tracing")] +use tracing::instrument; use crate::{ - algebra::{embedding::Identity, random_vector, scalar_mul_add_new, univariate_evaluate}, + algebra::{dot, embedding::Identity, random_vector, scalar_mul_add_new, univariate_evaluate}, hash::Hash, protocols::{ irs_commit::{Commitment as IrsCommitment, Config as IrsConfig, Witness as IrsWitness}, @@ -122,10 +131,11 @@ impl Config { /// /// Samples n fresh mask-of-mask polynomials, combines them with the /// provided original masks into a 2n-vector tree, and commits via IRS. + #[cfg_attr(feature = "tracing", instrument(skip_all, name = "mask_proximity::commit", fields(num_masks = self.num_masks, vector_size = self.c_zk_commit.vector_size)))] pub fn commit( &self, prover_state: &mut ProverState, - original_msgs: &[Vec], + original_msgs: &[&[F]], ) -> Witness where F: Codec<[H::U]>, @@ -147,8 +157,8 @@ impl Config { // Tree layout: [originals..., freshes...] let all_vectors: Vec<&[F]> = original_msgs .iter() - .chain(fresh_msgs.iter()) - .map(|v| v.as_slice()) + .copied() + .chain(fresh_msgs.iter().map(Vec::as_slice)) .collect(); let mask_witness = self.c_zk_commit.commit(prover_state, &all_vectors); @@ -160,6 +170,7 @@ impl Config { } /// Receive a mask proximity commitment + #[cfg_attr(feature = "tracing", instrument(skip_all, name = "mask_proximity::receive_commitment", fields(num_masks = self.num_masks)))] pub fn receive_commitment( &self, verifier_state: &mut VerifierState, @@ -172,12 +183,19 @@ impl Config { self.c_zk_commit.receive_commitment(verifier_state) } - /// Prove that each original mask is close to a C_zk codeword. + /// Prove that each original mask is close to a C_zk codeword. When + /// `mask_contribution_covectors` is `Some(per_mask)`, also runs the + /// Construction 7.2 target check for each mask `i` with its own covector + /// `per_mask[i]`, discharging `X_i = `. All + /// `X_i` and `X'_i` are sent before γ so γ binds them; the returned + /// values are bound by the verifier-side target checks. + #[cfg_attr(feature = "tracing", instrument(skip_all, name = "mask_proximity::prove", fields(num_masks = self.num_masks, vector_size = self.c_zk_commit.vector_size, with_discharge = mask_contribution_covectors.is_some())))] pub fn prove( &self, prover_state: &mut ProverState, - witness: &Witness, - original_msgs: &[Vec], + mut witness: Witness, + original_msgs: &[&[F]], + mask_contribution_covectors: Option<&[&[F]]>, ) where F: Codec<[H::U]>, H: DuplexSpongeInterface, @@ -191,6 +209,23 @@ impl Config { assert_eq!(original_msgs.len(), self.num_masks); assert_eq!(witness.fresh_msgs.len(), self.num_masks); + // Construction 7.2 step 1: per-mask discharge claims X_i and X'_i must + // be sent BEFORE γ so γ-randomness binds them. + if let Some(covectors) = mask_contribution_covectors { + assert_eq!( + covectors.len(), + self.num_masks, + "one covector per mask required" + ); + for (i, &covector) in covectors.iter().enumerate() { + assert_eq!(covector.len(), self.c_zk_commit.vector_size); + let claimed_value = dot(original_msgs[i], covector); + let fresh_claim = dot(&witness.fresh_msgs[i], covector); + prover_state.prover_message(&claimed_value); + prover_state.prover_message(&fresh_claim); + } + } + // Grind the Lemma 7.4 γ-combination gap before γ is sampled. self.pow.prove(prover_state); @@ -213,29 +248,39 @@ impl Config { let combined_msg = scalar_mul_add_new(fresh_msg, gamma, orig_msg); prover_state.prover_messages(&combined_msg); - // r*_i = r'_i + γ · r_i + // r*_i = r'_i + γ · r_i — IRS stores masks per-polynomial + // contiguous, so direct slicing reconstructs each poly's mask. if irs_masks_per_vector > 0 { - let orig_r = &witness.mask_witness.masks - [i * irs_masks_per_vector..(i + 1) * irs_masks_per_vector]; - let fresh_r = &witness.mask_witness.masks[(self.num_masks + i) - * irs_masks_per_vector - ..(self.num_masks + i + 1) * irs_masks_per_vector]; + let masks = &witness.mask_witness.masks; + let orig_r = &masks[i * irs_masks_per_vector..(i + 1) * irs_masks_per_vector]; + let fresh_offset = (self.num_masks + i) * irs_masks_per_vector; + let fresh_r = &masks[fresh_offset..fresh_offset + irs_masks_per_vector]; let combined_r = scalar_mul_add_new(fresh_r, gamma, orig_r); prover_state.prover_messages(&combined_r); } } + // fresh_msgs is consumed; release ~num_masks · vector_size field + // elements before the (potentially slow) tree-open below. + witness.fresh_msgs = Vec::new(); + // Step 3: open the shared tree at random in-domain positions self.c_zk_commit .open(prover_state, &[&witness.mask_witness]); } - /// Verify that each original mask is close to a C_zk codeword. + /// Verify that each original mask is close to a C_zk codeword. When + /// `mask_contribution_covectors` is `Some(per_mask)`, also verifies the + /// Construction 7.2 target check for each mask `i` against `per_mask[i]` + /// and returns `Some(Vec)` — the verified claimed values + /// ``. + #[cfg_attr(feature = "tracing", instrument(skip_all, name = "mask_proximity::verify", fields(num_masks = self.num_masks, vector_size = self.c_zk_commit.vector_size, with_discharge = mask_contribution_covectors.is_some())))] pub fn verify( &self, verifier_state: &mut VerifierState, commitment: &Commitment, - ) -> VerificationResult<()> + mask_contribution_covectors: Option<&[&[F]]>, + ) -> VerificationResult>> where F: Codec<[H::U]>, H: DuplexSpongeInterface, @@ -244,6 +289,22 @@ impl Config { U64: Codec<[H::U]>, Hash: ProverMessage<[H::U]>, { + // Construction 7.2 step 1: read X_i and X'_i for each mask before γ + // so γ binds them. + let mask_contribution_claims = if let Some(covectors) = mask_contribution_covectors { + verify!(covectors.len() == self.num_masks); + let mut claims = Vec::with_capacity(self.num_masks); + for &covector in covectors { + verify!(covector.len() == self.c_zk_commit.vector_size); + let claimed_value: F = verifier_state.prover_message()?; + let fresh_claim: F = verifier_state.prover_message()?; + claims.push((covector, claimed_value, fresh_claim)); + } + Some(claims) + } else { + None + }; + // Grind the Lemma 7.4 γ-combination gap before γ is sampled. self.pow.verify(verifier_state)?; @@ -291,7 +352,19 @@ impl Config { } } - Ok(()) + // Construction 7.2 target check per mask: <ξ*_i, w_i> = X'_i + γ·X_i. + mask_contribution_claims.map_or(Ok(None), |claims| { + let mut claimed_values = Vec::with_capacity(claims.len()); + for (i, (covector, claimed_value, fresh_claim)) in claims.into_iter().enumerate() { + let lhs = dot(&combined_msgs[i], covector); + let rhs = fresh_claim + gamma * claimed_value; + if lhs != rhs { + return Err(spongefish::VerificationError); + } + claimed_values.push(claimed_value); + } + Ok(Some(claimed_values)) + }) } } @@ -366,15 +439,18 @@ mod tests { let original_msgs: Vec> = (0..config.num_masks) .map(|_| random_vector(&mut rng, config.c_zk_commit.vector_size)) .collect(); + let original_refs: Vec<&[F]> = original_msgs.iter().map(Vec::as_slice).collect(); let mut prover_state = ProverState::new_std(&ds); - let witness = config.commit(&mut prover_state, &original_msgs); - config.prove(&mut prover_state, &witness, &original_msgs); + let witness = config.commit(&mut prover_state, &original_refs); + config.prove(&mut prover_state, witness, &original_refs, None); let proof = prover_state.proof(); let mut verifier_state = VerifierState::new_std(&ds, &proof); let commitment = config.receive_commitment(&mut verifier_state).unwrap(); - config.verify(&mut verifier_state, &commitment).unwrap(); + let _ = config + .verify(&mut verifier_state, &commitment, None) + .unwrap(); verifier_state.check_eof().unwrap(); } @@ -440,18 +516,20 @@ mod tests { let original_msgs: Vec> = (0..config.num_masks) .map(|_| random_vector(&mut rng, config.c_zk_commit.vector_size)) .collect(); + let original_refs: Vec<&[F]> = original_msgs.iter().map(Vec::as_slice).collect(); let mut prover_state = ProverState::new_std(&ds); - let witness = config.commit(&mut prover_state, &original_msgs); + let witness = config.commit(&mut prover_state, &original_refs); let mut tampered_msgs = original_msgs; tampered_msgs[0][0] += F::ONE; - config.prove(&mut prover_state, &witness, &tampered_msgs); + let tampered_refs: Vec<&[F]> = tampered_msgs.iter().map(Vec::as_slice).collect(); + config.prove(&mut prover_state, witness, &tampered_refs, None); let proof = prover_state.proof(); let mut verifier_state = VerifierState::new_std(&ds, &proof); let commitment = config.receive_commitment(&mut verifier_state).unwrap(); - assert_rejected(|| config.verify(&mut verifier_state, &commitment)); + assert_rejected(|| config.verify(&mut verifier_state, &commitment, None)); } /// Post-γ tamper: commit honestly, then corrupt the combined message to @@ -471,9 +549,10 @@ mod tests { let original_msgs: Vec> = (0..config.num_masks) .map(|_| random_vector(&mut rng, config.c_zk_commit.vector_size)) .collect(); + let original_refs: Vec<&[F]> = original_msgs.iter().map(Vec::as_slice).collect(); let mut prover_state = ProverState::new_std(&ds); - let witness = config.commit(&mut prover_state, &original_msgs); + let witness = config.commit(&mut prover_state, &original_refs); let gamma: F = prover_state.verifier_message(); let irs_masks_per_vector = @@ -508,7 +587,7 @@ mod tests { let mut verifier_state = VerifierState::new_std(&ds, &proof); let commitment = config.receive_commitment(&mut verifier_state).unwrap(); - assert_rejected(|| config.verify(&mut verifier_state, &commitment)); + assert_rejected(|| config.verify(&mut verifier_state, &commitment, None)); } fn assert_rejected(verify: impl FnOnce() -> VerificationResult) { @@ -545,4 +624,138 @@ mod tests { test_tampered_combined_msg_config(seed, &config); }); } + + /// Construction 7.2 target check: per-mask discharge round-trip. + fn test_discharge_config(seed: u64, config: &Config) + where + F: Field + Codec<[u8]> + 'static, + Standard: Distribution, + Hash: crate::transcript::ProverMessage<[u8]>, + { + let instance = U64(seed); + let ds = DomainSeparator::protocol(config) + .session(&format!("Test at {}:{}", file!(), line!())) + .instance(&instance); + let mut rng = StdRng::seed_from_u64(seed); + + let original_msgs: Vec> = (0..config.num_masks) + .map(|_| random_vector(&mut rng, config.c_zk_commit.vector_size)) + .collect(); + let original_refs: Vec<&[F]> = original_msgs.iter().map(Vec::as_slice).collect(); + // One distinct covector per mask. + let covectors: Vec> = (0..config.num_masks) + .map(|_| random_vector(&mut rng, config.c_zk_commit.vector_size)) + .collect(); + let expected_claims: Vec = original_msgs + .iter() + .zip(covectors.iter()) + .map(|(m, c)| dot(m, c)) + .collect(); + let covector_refs: Vec<&[F]> = covectors.iter().map(|v| v.as_slice()).collect(); + + let mut prover_state = ProverState::new_std(&ds); + let witness = config.commit(&mut prover_state, &original_refs); + config.prove( + &mut prover_state, + witness, + &original_refs, + Some(&covector_refs), + ); + let proof = prover_state.proof(); + + let mut verifier_state = VerifierState::new_std(&ds, &proof); + let commitment = config.receive_commitment(&mut verifier_state).unwrap(); + let returned = config + .verify(&mut verifier_state, &commitment, Some(&covector_refs)) + .unwrap(); + assert_eq!(returned, Some(expected_claims)); + verifier_state.check_eof().unwrap(); + } + + /// Discharge with the wrong `claimed_value` — target check must reject. + fn test_discharge_wrong_claim_config(seed: u64, config: &Config) + where + F: Field + Codec<[u8]> + 'static, + Standard: Distribution, + Hash: crate::transcript::ProverMessage<[u8]>, + { + let instance = U64(seed); + let ds = DomainSeparator::protocol(config) + .session(&format!("Test at {}:{}", file!(), line!())) + .instance(&instance); + let mut rng = StdRng::seed_from_u64(seed); + + let original_msgs: Vec> = (0..config.num_masks) + .map(|_| random_vector(&mut rng, config.c_zk_commit.vector_size)) + .collect(); + let original_refs: Vec<&[F]> = original_msgs.iter().map(Vec::as_slice).collect(); + let covectors: Vec> = (0..config.num_masks) + .map(|_| random_vector(&mut rng, config.c_zk_commit.vector_size)) + .collect(); + let covector_refs: Vec<&[F]> = covectors.iter().map(|v| v.as_slice()).collect(); + + let mut prover_state = ProverState::new_std(&ds); + let witness = config.commit(&mut prover_state, &original_refs); + + // Lie on mask 0's X; emit honest claims for the rest, then complete + // the protocol manually to exercise the verifier's target check. + let wrong_claim = dot(&original_msgs[0], &covectors[0]) + F::ONE; + let fresh_claim_0 = dot(&witness.fresh_msgs[0], &covectors[0]); + prover_state.prover_message(&wrong_claim); + prover_state.prover_message(&fresh_claim_0); + for i in 1..config.num_masks { + let x = dot(&original_msgs[i], &covectors[i]); + let x_prime = dot(&witness.fresh_msgs[i], &covectors[i]); + prover_state.prover_message(&x); + prover_state.prover_message(&x_prime); + } + + // Run the rest of the protocol manually (mirror prove without the + // discharge prelude we already emitted). + config.pow.prove(&mut prover_state); + let gamma: F = prover_state.verifier_message(); + let irs_masks_per_vector = + config.c_zk_commit.mask_length() * config.c_zk_commit.interleaving_depth; + for (i, (orig_msg, fresh_msg)) in original_msgs + .iter() + .zip(witness.fresh_msgs.iter()) + .enumerate() + { + let combined_msg = scalar_mul_add_new(fresh_msg, gamma, orig_msg); + prover_state.prover_messages(&combined_msg); + if irs_masks_per_vector > 0 { + let orig_r = &witness.mask_witness.masks + [i * irs_masks_per_vector..(i + 1) * irs_masks_per_vector]; + let fresh_r = &witness.mask_witness.masks[(config.num_masks + i) + * irs_masks_per_vector + ..(config.num_masks + i + 1) * irs_masks_per_vector]; + let combined_r = scalar_mul_add_new(fresh_r, gamma, orig_r); + prover_state.prover_messages(&combined_r); + } + } + config + .c_zk_commit + .open(&mut prover_state, &[&witness.mask_witness]); + let proof = prover_state.proof(); + + let mut verifier_state = VerifierState::new_std(&ds, &proof); + let commitment = config.receive_commitment(&mut verifier_state).unwrap(); + assert_rejected(|| config.verify(&mut verifier_state, &commitment, Some(&covector_refs))); + } + + #[test] + fn test_discharge_roundtrip() { + crate::tests::init(); + proptest!(|(seed: u64, config in Config::::arbitrary())| { + test_discharge_config(seed, &config); + }); + } + + #[test] + fn test_discharge_wrong_claim_rejected() { + crate::tests::init(); + proptest!(|(seed: u64, config in Config::::arbitrary())| { + test_discharge_wrong_claim_config(seed, &config); + }); + } } diff --git a/src/protocols/mod.rs b/src/protocols/mod.rs index 64f1e0b9..65340eba 100644 --- a/src/protocols/mod.rs +++ b/src/protocols/mod.rs @@ -21,3 +21,4 @@ pub mod proof_of_work; pub mod sumcheck; pub mod whir; pub mod whir_zk; +pub mod zook; diff --git a/src/protocols/params/adaptive.rs b/src/protocols/params/adaptive.rs new file mode 100644 index 00000000..edf2b13d --- /dev/null +++ b/src/protocols/params/adaptive.rs @@ -0,0 +1,413 @@ +//! Adaptive rate planner. +//! +//! Picks per-round `target_log_inv_rate` values by enumerating candidates, +//! scoring them on a (prover_time_proxy, proof_size_proxy) pareto knee, and +//! returning the schedule chosen by [`KneeWeight`]. Driven by +//! [`super::layout::round_layout`] when [`super::spec::RateSchedule::Adaptive`] +//! is selected. + +use std::collections::HashMap; + +use crate::{ + algebra::{ + embedding::{Embedding, Identity}, + fields::FieldWithSize, + }, + protocols::{ + irs_commit::Config as IrsConfig, + params::{ + basecase as basecase_params, + branch::{Branch, RoundBuildMode}, + build_round::{build_mask_oracle, solve_t_ood}, + code_switch as code_switch_params, + error::DeriveError, + irs_commit as irs_params, + layout::RoundShape, + protocol_config::MaskOracleInfo, + spec::{KneeWeight, Mode, OodSampleBudget, RoundContext, SecuritySpec, TuningSpec}, + sumcheck as sumcheck_params, + }, + }, +}; + +/// Hard cap on `log_inv_rate` searched. Beyond this the codeword would blow +/// past sane NTT-friendly sizes for typical witnesses. +const ADAPTIVE_MAX_LOG_INV_RATE: u32 = 20; + +/// Real per-IRS dimensions extracted from a built `IrsConfig`. Drives the +/// cost proxy with NTT-rounded codeword length and the actual decoding-regime +/// query count — no closed-form approximation. +#[derive(Clone, Copy, Debug)] +struct RoundDims { + codeword_length: usize, + in_domain_samples: usize, + interleaving_depth: usize, +} + +/// Fixed per-encode-call overhead in the cost-proxy's nominal units. +/// +/// Each `interleaved_encode` invocation pays a setup cost (allocator, +/// roots-table extension, rayon thread fan-out) that's amortized poorly for +/// small NTTs. The production trace showed sub-ms cold-cache outliers on +/// 4k-element NTTs that the pure `codeword · log codeword · interleaving` +/// model wouldn't see. This constant biases the planner against schedules +/// that fragment work into many small encodes — same direction as proof-size +/// pressure but for a different reason. +/// +/// Calibration: at ~1.4 ns / field-op (variable cost) on Apple Silicon, +/// per-call fixed overhead observed at ~10 μs ≈ 7000 "ops" worth. Round to +/// 4096 — within an order of magnitude is enough for correct pareto ordering; +/// the absolute number doesn't matter under log-scale knee normalization. +const ENCODE_FIXED_OVERHEAD: f64 = 4096.0; + +/// NTT/Merkle cost proxy for one encoded IRS. Uses real solver outputs so +/// per-round NTT smoothness rounding (`12288 = 4096·3` etc.) and per-regime +/// query counts (Johnson vs Unique vs Capacity) are accurate. +/// +/// Constants are nominal — pareto ordering is invariant under positive +/// rescaling of either axis, and the log-knee picker compounds that. +fn round_cost_from_dims(dims: RoundDims, field_bytes: f64, hash_bytes: f64) -> (f64, f64) { + let codeword = dims.codeword_length as f64; + let interleaving = dims.interleaving_depth as f64; + let queries = dims.in_domain_samples as f64; + let log_codeword = codeword.log2().max(1.0); + + let encode = ENCODE_FIXED_OVERHEAD + codeword * log_codeword * interleaving; + let proof = queries * (interleaving * field_bytes + log_codeword * hash_bytes); + (encode, proof) +} + +#[derive(Clone, Copy, Debug)] +struct Cost { + encode: f64, + proof: f64, +} + +impl Cost { + const ZERO: Self = Self { + encode: 0.0, + proof: 0.0, + }; + fn add(self, other: (f64, f64)) -> Self { + Self { + encode: self.encode + other.0, + proof: self.proof + other.1, + } + } + fn dominates(self, other: Self) -> bool { + self.encode <= other.encode + && self.proof <= other.proof + && (self.encode < other.encode || self.proof < other.proof) + } +} + +/// Bound on how aggressive Adaptive can be at a single round, expressed as a +/// multiple of the canonical WHIR per-round step `folding − 1` (the increment +/// [`super::spec::RateSchedule::Stepping`] applies, inherited from legacy +/// in-place WHIR's stepping invariant). This lets Adaptive explore schedules +/// slightly more aggressive than the legacy step, gated by the per-candidate +/// feasibility check. Pure search-space heuristic, not a correctness bound. +const ADAPTIVE_STEP_BUDGET: u32 = 2; + +/// Search per-round target rates that minimize a pareto-knee cost. The +/// skeleton is fixed (folding factors, message-length sequence); only +/// `target_log_inv_rate` per round is searched. Each candidate (round, +/// source_rate, target_rate) is checked against the per-round PoW budget via +/// the actual analytic-error formulas (memoized) — schedules whose PoW gap +/// at any round would exceed `pow_budget` are dropped before pareto. +/// +/// Returns `Err(DeriveError::AdaptiveNoFeasibleSchedule)` if no candidate +/// passes the per-slot PoW check — the spec is too tight for the planner's +/// search space. +pub(super) fn plan_adaptive_rates( + spec: &SecuritySpec, + tuning: &TuningSpec, + knee_weight: KneeWeight, + skeleton: &[RoundShape], + basecase_vector_size: usize, + mode: RoundBuildMode<'_>, +) -> Result, DeriveError> { + let mut planner: Planner<'_, M> = Planner::new(spec, skeleton, basecase_vector_size, mode); + let candidates = planner.search(tuning.starting_log_inv_rate); + // No feasible candidates means even the most conservative search choice + // (constant rate at every round) failed the per-slot PoW check. The spec + // is fundamentally too tight — surface this as a planner-level error + // rather than returning placeholder rates that would just fail downstream + // validation with a less informative message. + if candidates.is_empty() { + return Err(DeriveError::AdaptiveNoFeasibleSchedule); + } + Ok(pick_knee(pareto_frontier(candidates), knee_weight.get())) +} + +/// Adapter that returns `Some(())` when an analytic floor `bits` leaves a +/// gap small enough for `pow_budget` to close. Used via `?` so a too-low +/// floor short-circuits the surrounding `_dims` builder. +fn fits(bits: f64, deficit: f64) -> Option<()> { + (bits >= deficit).then_some(()) +} + +/// DFS state for the adaptive rate search. Owns the loop-invariants and the +/// two memoization caches so `recurse` doesn't have to thread a dozen +/// parameters down each level. +struct Planner<'a, M> { + spec: &'a SecuritySpec, + skeleton: &'a [RoundShape], + basecase_vector_size: usize, + mode: RoundBuildMode<'a>, + field_bytes: f64, + hash_bytes: f64, + /// Memoize per `(round_idx, source_rate, target_rate)`. `Some(dims)` is + /// feasible with real IRS dimensions; `None` is infeasible (PoW budget + /// can't close the analytic gap). + round_cache: HashMap<(usize, u32, u32), Option>, + /// Basecase candidates only vary by rate, so cache separately. + basecase_cache: HashMap>, + _m: std::marker::PhantomData, +} + +impl<'a, M: Embedding + Default> Planner<'a, M> { + fn new( + spec: &'a SecuritySpec, + skeleton: &'a [RoundShape], + basecase_vector_size: usize, + mode: RoundBuildMode<'a>, + ) -> Self { + Self { + spec, + skeleton, + basecase_vector_size, + mode, + field_bytes: ::field_size_bits() / 8.0, + hash_bytes: 32.0, + round_cache: HashMap::new(), + basecase_cache: HashMap::new(), + _m: std::marker::PhantomData, + } + } + + fn search(&mut self, starting_log_inv_rate: u32) -> Vec<(Vec, Cost)> { + let mut out = Vec::new(); + let mut chosen = Vec::with_capacity(self.skeleton.len()); + self.recurse(0, starting_log_inv_rate, &mut chosen, Cost::ZERO, &mut out); + out + } + + fn recurse( + &mut self, + idx: usize, + cur_rate: u32, + chosen: &mut Vec, + acc: Cost, + out: &mut Vec<(Vec, Cost)>, + ) { + // handle basecase + if idx == self.skeleton.len() { + let spec = self.spec; + let vector_size = self.basecase_vector_size; + let dims = *self + .basecase_cache + .entry(cur_rate) + .or_insert_with(|| basecase_dims::(spec, vector_size, cur_rate)); + if let Some(dims) = dims { + let (e, p) = round_cost_from_dims(dims, self.field_bytes, self.hash_bytes); + out.push((chosen.clone(), acc.add((e, p)))); + } + return; + } + let shape = self.skeleton[idx]; + // Per-round step capped at `ADAPTIVE_STEP_BUDGET · (folding − 1)` — + // the canonical WHIR per-round increment scaled by the search budget. + // Hard cap on absolute rate via `ADAPTIVE_MAX_LOG_INV_RATE`. Each + // candidate (src, tgt) is feasibility-checked below; exceeding the + // canonical increment is allowed when the PoW budget actually fits. + let max_step = shape + .source_folding_factor + .saturating_sub(1) + .saturating_mul(ADAPTIVE_STEP_BUDGET); + for delta in 0..=max_step { + let next_rate = cur_rate.saturating_add(delta); + if next_rate > ADAPTIVE_MAX_LOG_INV_RATE { + break; + } + let spec = self.spec; + let mode = self.mode; + let dims = *self + .round_cache + .entry((idx, cur_rate, next_rate)) + .or_insert_with(|| try_round_dims::(spec, &shape, cur_rate, next_rate, mode)); + let Some(dims) = dims else { continue }; + let (e, p) = round_cost_from_dims(dims, self.field_bytes, self.hash_bytes); + chosen.push(next_rate); + self.recurse(idx + 1, next_rate, chosen, acc.add((e, p)), out); + chosen.pop(); + } + } +} + +/// Drop any (schedule, cost) dominated by another. `O(|candidates|²)` — +/// fine at planner scale. +fn pareto_frontier(candidates: Vec<(Vec, Cost)>) -> Vec<(Vec, Cost)> { + let mut frontier: Vec<(Vec, Cost)> = Vec::new(); + 'outer: for cand in candidates { + for f in &frontier { + if f.1.dominates(cand.1) { + continue 'outer; + } + } + frontier.retain(|f| !cand.1.dominates(f.1)); + frontier.push(cand); + } + frontier +} + +/// Weighted log-scale pareto knee. Picks the schedule whose deficit from +/// per-axis minima — measured in log-space, weighted by `knee_weight` — is +/// smallest. Log-space normalizes the units mismatch between the +/// encode/proof proxies (ops vs bytes) and makes the picker invariant to any +/// constant rescaling of either axis. +/// +/// `knee_weight ∈ [0, 1]` is the encode-axis bias (see +/// [`super::spec::KneeWeight`]). +fn pick_knee(frontier: Vec<(Vec, Cost)>, knee_weight: f64) -> Vec { + let logs: Vec<(Vec, f64, f64)> = frontier + .into_iter() + .map(|(s, c)| (s, c.encode.max(1.0).log2(), c.proof.max(1.0).log2())) + .collect(); + let min_e = logs.iter().map(|f| f.1).fold(f64::INFINITY, f64::min); + let min_p = logs.iter().map(|f| f.2).fold(f64::INFINITY, f64::min); + let score = |le: f64, lp: f64| { + knee_weight * (le - min_e).powi(2) + (1.0 - knee_weight) * (lp - min_p).powi(2) + }; + logs.into_iter() + .min_by(|a, b| { + score(a.1, a.2) + .partial_cmp(&score(b.1, b.2)) + .unwrap_or(std::cmp::Ordering::Equal) + }) + .expect("frontier non-empty: candidates non-empty implies frontier non-empty") + .0 +} + +/// Build the basecase IRS at `cur_rate` and return its dimensions, or `None` +/// if its analytic floors can't be closed by `pow_budget`. +fn basecase_dims( + spec: &SecuritySpec, + vector_size: usize, + log_inv_rate: u32, +) -> Option { + let ctx = RoundContext { + vector_size, + log_inv_rate, + folding_factor: 0, + }; + let commit: IrsConfig> = + irs_params::solve(spec, &ctx, OodSampleBudget::ZERO); + + let max_deficit = f64::from(spec.target_security_bits) - f64::from(spec.pow_budget.bits()); + + fits( + f64::from(sumcheck_params::analytic_error_bits( + &commit, + Option::::None, + )), + max_deficit, + )?; + if matches!(spec.mode, Mode::ZeroKnowledge) { + fits( + f64::from(basecase_params::analytic_error_bits(&commit)), + max_deficit, + )?; + } + + Some(RoundDims { + codeword_length: commit.codeword_length, + in_domain_samples: commit.in_domain_samples, + interleaving_depth: commit.interleaving_depth, + }) +} + +/// Build a round at this (source_rate, target_rate, mode) and return the +/// **target IRS's** dimensions (`codeword_length`, `in_domain_samples`, +/// `interleaving_depth`) if every per-slot analytic floor fits inside +/// `pow_budget`. Returns `None` if any floor is too low — even max grinding +/// can't close the gap. +/// +/// The returned dims drive the cost proxy, so the planner uses NTT-rounded +/// codeword sizes and per-regime query counts — no closed-form approximation. +fn try_round_dims( + spec: &SecuritySpec, + shape: &RoundShape, + source_log_inv_rate: u32, + target_log_inv_rate: u32, + mode: RoundBuildMode<'_>, +) -> Option { + let max_deficit = f64::from(spec.target_security_bits) - f64::from(spec.pow_budget.bits()); + + let src_ctx = RoundContext { + vector_size: shape.source_vector_size, + log_inv_rate: source_log_inv_rate, + folding_factor: shape.source_folding_factor, + }; + let target_log_degree = f64::from( + shape + .source_vector_size + .trailing_zeros() + .saturating_sub(shape.source_folding_factor), + ); + let target_list_size = spec + .decoding_regime + .list_size_estimate(target_log_degree, f64::from(target_log_inv_rate)); + + let ood_mode = mode.map(|p| f64::from(p.c_zk_log_inv_rate.get())); + let (source, t_ood) = solve_t_ood::(spec, &src_ctx, target_list_size, ood_mode, 0).ok()?; + + let target_budget = match mode { + Branch::Standard => OodSampleBudget::ZERO, + Branch::ZeroKnowledge(_) => OodSampleBudget::new(t_ood), + }; + let tgt_ctx = RoundContext { + vector_size: source.message_length(), + log_inv_rate: target_log_inv_rate, + folding_factor: shape.target_folding_factor, + }; + let target_irs: IrsConfig> = + irs_params::solve(spec, &tgt_ctx, target_budget); + + let mask_info = match mode { + Branch::Standard => None, + Branch::ZeroKnowledge(payload) => { + let mo = build_mask_oracle::( + payload.zk_spec, + &src_ctx, + &source, + t_ood, + payload.c_zk_log_inv_rate, + shape.round_index, + ) + .ok()?; + fits(f64::from(mo.analytic_bits()), max_deficit)?; + Some(mo.info()) + } + }; + + fits( + f64::from(sumcheck_params::analytic_error_bits(&source, mask_info)), + max_deficit, + )?; + fits( + f64::from(code_switch_params::analytic_error_bits( + &source, + &target_irs, + t_ood, + mask_info, + )), + max_deficit, + )?; + + Some(RoundDims { + codeword_length: target_irs.codeword_length, + in_domain_samples: target_irs.in_domain_samples, + interleaving_depth: target_irs.interleaving_depth, + }) +} diff --git a/src/protocols/params/basecase.rs b/src/protocols/params/basecase.rs index fabcbb6c..89bf81f6 100644 --- a/src/protocols/params/basecase.rs +++ b/src/protocols/params/basecase.rs @@ -32,6 +32,17 @@ pub fn solve( folding_factor: 0, }; let commit = irs_params::solve(spec, &ctx, OodSampleBudget::ZERO); + solve_with_commit(spec, commit) +} + +/// Same as [`solve`] but with a pre-built IRS config — used by `derive` when +/// the last round's `code_switch.target` is being reused as the basecase +/// commit (Phase 2: no recommit of the folded message). +pub fn solve_with_commit( + spec: &SecuritySpec, + commit: IrsConfig>, +) -> Result, DeriveError> { + let vector_size = commit.vector_size; let sumcheck_analytic = sumcheck_params::analytic_error_bits(&commit, None); let sumcheck_pow = grind_to_at(spec, sumcheck_analytic, Pow::BasecaseSumcheck)?; diff --git a/src/protocols/params/build_round.rs b/src/protocols/params/build_round.rs index 420d3ddb..6fe8b3ba 100644 --- a/src/protocols/params/build_round.rs +++ b/src/protocols/params/build_round.rs @@ -2,7 +2,9 @@ //! //! Solves the `t_ood` fix-point, builds source/target IRS configs, and //! (in ZK) assembles the per-round mask oracle. Consumed by -//! [`super::derive`], which drives the per-round loop. +//! [`super::derive`], which drives the per-round loop. Also reachable from +//! [`super::adaptive`], which calls [`solve_t_ood`] and [`build_mask_oracle`] +//! during feasibility probing. use crate::{ algebra::{ @@ -52,19 +54,19 @@ pub(super) fn build_round_config( zk_spec, c_zk_log_inv_rate, }) => { - let num_masks = - sumcheck_params::masks_required(&ctx) + code_switch_params::masks_required(); let mask_oracle = build_mask_oracle::( zk_spec, + &ctx, &source, t_ood, - num_masks, c_zk_log_inv_rate, shape.round_index, )?; - let solve_mode = SolveMode::ZeroKnowledge(mask_oracle.info()); + let info = mask_oracle.info(); + let solve_mode = SolveMode::ZeroKnowledge(info); let round_mode = RoundMode::ZeroKnowledge { t_ood: OodSampleBudget::new(t_ood), + mask_oracle: info, }; ( OodSampleBudget::new(t_ood), @@ -89,13 +91,13 @@ pub(super) fn build_round_config( let code_switch = code_switch_params::solve(spec, source, target, t_ood, solve_mode, shape.round_index)?; - Ok(RoundConfig::new( - shape.round_index, + Ok(RoundConfig { + round_index: shape.round_index, sumcheck, code_switch, - round_mode, + mode: round_mode, mask_oracle, - )) + }) } fn solve_round_source( @@ -104,11 +106,7 @@ fn solve_round_source( ood_mode: OodMode, ) -> Result<(IrsConfig, usize), DeriveError> { let src_ctx = round_context(shape); - let target_log_inv_rate = f64::from( - shape - .source_log_inv_rate - .saturating_add(shape.source_folding_factor.saturating_sub(1)), - ); + let target_log_inv_rate = f64::from(shape.target_log_inv_rate); let target_log_degree = f64::from( shape .source_vector_size @@ -127,38 +125,69 @@ fn solve_round_source( ) } -/// ZK-only: assemble the per-round mask oracle (C_zk codeword + mask-proximity -/// check). -fn build_mask_oracle( +/// ZK-only: assemble the per-round mask oracle as **two** independent C_zk +/// trees (Zook split): +/// - `sumcheck_masks` tree: `2 · k` columns, vector size +/// `next_pow_2(sumcheck zk_mask_length)` (Lemma 6.4). Committed before +/// sumcheck. +/// - `cs_mask` tree: `2 · 1` columns, vector size `ℓ_zk = next_pow2(r + t_ood)` +/// (Lemma 9.3, Construction 9.7). Committed after sumcheck so its `r` part +/// can carry the folded source-IRS randomness. +/// +/// Both trees use the same C_zk code rate, so list sizes match and +/// `MaskOracleInfo` exposes a single shared value to downstream solvers. +pub(super) fn build_mask_oracle( zk_spec: ZkSpec<'_>, + ctx: &RoundContext, source: &IrsConfig, t_ood: usize, - num_masks: usize, c_zk_log_inv_rate: LogInvRate, round_index: usize, ) -> Result, DeriveError> { let spec = zk_spec.as_inner(); + let k = sumcheck_params::masks_required(ctx); let l_zk = compute_l_zk(source, t_ood); - let c_zk: IrsConfig> = irs_params::solve_mask_code( + + // Sumcheck-masks tree: tiny vector size (next_pow2(3) = 4), no padding to ℓ_zk. + let sumcheck_mask_vec_size = + MaskCodeMessageLen::new(sumcheck_params::zk_mask_length().next_power_of_two()); + let sumcheck_c_zk: IrsConfig> = irs_params::solve_mask_code( + zk_spec, + sumcheck_mask_vec_size, + 0, + c_zk_log_inv_rate, + MaskProximityConfig::::num_vectors_for(k), + ); + + // cs_mask tree: vector_size = ℓ_zk, holds the `(r ‖ s)` mask(s). + let cs_masks = code_switch_params::masks_required(); + let cs_c_zk: IrsConfig> = irs_params::solve_mask_code( zk_spec, l_zk, source.mask_length(), c_zk_log_inv_rate, - MaskProximityConfig::::num_vectors_for(num_masks), + MaskProximityConfig::::num_vectors_for(cs_masks), ); + let c_zk_list_size_estimate = spec.decoding_regime.list_size_estimate( (l_zk.get() as f64).log2(), f64::from(c_zk_log_inv_rate.get()), ); debug_assert!( - (c_zk.list_size() - c_zk_list_size_estimate).abs() + (cs_c_zk.list_size() - c_zk_list_size_estimate).abs() < 1e-9 * c_zk_list_size_estimate.max(1.0), - "c_zk.list_size() {} drifted from planner estimate {}", - c_zk.list_size(), + "cs_c_zk.list_size() {} drifted from planner estimate {}", + cs_c_zk.list_size(), c_zk_list_size_estimate, ); - let mask_proximity = mask_proximity_params::solve(spec, c_zk.clone(), num_masks, round_index)?; - Ok(MaskOracleConfig::new(c_zk, l_zk, mask_proximity)) + + let sumcheck_masks = mask_proximity_params::solve(spec, sumcheck_c_zk, k, round_index)?; + let cs_mask = mask_proximity_params::solve(spec, cs_c_zk, cs_masks, round_index)?; + Ok(MaskOracleConfig { + sumcheck_masks, + cs_mask, + l_zk, + }) } /// `ℓ_zk = next_pow2(r + t_ood)` (Theorem 9.6 + Lemma 9.3). diff --git a/src/protocols/params/derive.rs b/src/protocols/params/derive.rs index fec28001..144f9088 100644 --- a/src/protocols/params/derive.rs +++ b/src/protocols/params/derive.rs @@ -17,28 +17,49 @@ impl ProtocolConfig { /// Fails with [`DeriveError`] when the spec/tuning combination is /// infeasible. pub fn derive(spec: SecuritySpec, tuning: TuningSpec) -> Result { - let RoundLayout { - shapes, - basecase_vector_size, - basecase_log_inv_rate, - } = round_layout(&tuning)?; - let mode: RoundBuildMode<'_> = match spec.mode { Mode::Standard => Branch::Standard, Mode::ZeroKnowledge => Branch::ZeroKnowledge(RoundBuildPayload { zk_spec: ZkSpec::try_new(&spec).expect("matched Mode::ZeroKnowledge above"), - c_zk_log_inv_rate: LogInvRate::new(tuning.starting_log_inv_rate), + // Mask code rate is fixed at 2⁻⁴ (decoupled from + // `tuning.starting_log_inv_rate`). A higher inverse rate + // tightens `|Λ(C_zk, δ_zk)|`, giving more security headroom + // for the mask-proximity discharges at the cost of larger + // mask codewords (16× message length here). + c_zk_log_inv_rate: LogInvRate::new(4), }), }; + let RoundLayout { + shapes, + basecase_vector_size, + basecase_log_inv_rate, + } = round_layout::(&spec, &tuning, mode)?; + let rounds: Vec> = shapes .iter() .map(|shape| build_round_config::(&spec, shape, mode)) .collect::>()?; - let basecase = basecase_params::solve(&spec, basecase_vector_size, basecase_log_inv_rate)?; + // When at least one round exists, the last round's `code_switch.target` + // is exactly the IRS that holds the folded basecase message — it has + // `interleaving_depth = 1` thanks to the override in `round_layout`. + // Build basecase around that same IRS so `zook::prove` can pass the + // existing witness directly to `basecase.prove` without re-encoding. + // PoW (sumcheck + γ-combination) is re-solved against the swapped + // IRS so analytic + PoW still meets the security target. + let basecase = if let Some(last) = rounds.last() { + basecase_params::solve_with_commit(&spec, last.code_switch.target.clone())? + } else { + basecase_params::solve(&spec, basecase_vector_size, basecase_log_inv_rate)? + }; - let plan = Self::new(spec, tuning, rounds, basecase); + let plan = Self { + security: spec, + tuning, + rounds, + basecase, + }; plan.validate()?; Ok(plan) } @@ -53,6 +74,7 @@ mod tests { embedding::Embedding, fields::{Field64, FieldWithSize}, }, + bits::Bits, hash, protocols::{ basecase::BasecaseMode, @@ -61,7 +83,10 @@ mod tests { error::{ChainSource, ChainTarget, DeriveError, Pow}, mask_proximity as mask_proximity_params, protocol_config::{ProtocolConfig, RoundMode}, - spec::{DecodingRegime, FoldingFactor, Mode, PowBudget, SecuritySpec, TuningSpec}, + spec::{ + DecodingRegime, FoldingFactor, KneeWeight, Mode, PowBudget, RateSchedule, + SecuritySpec, TuningSpec, + }, sumcheck as sumcheck_params, test_utils::{assert_close, assert_pow_closes_gap, TestEmbedding}, }, @@ -75,13 +100,25 @@ mod tests { FoldingFactor::ConstantFromSecondRound { initial, rest } }), ]; - (4u32..=8, 1u32..=3, folding).prop_map(|(log_size, log_inv_rate, folding_factor)| { - TuningSpec { + // Sample three meaningful regimes: unbounded stepping, a tight cap, + // and the planner-picked schedule. The cap range (4..=10) brackets + // typical production tunings without forcing all rounds to a single + // rate. + let schedule = prop_oneof![ + Just(RateSchedule::Stepping), + (4u32..=10).prop_map(|max_log_inv_rate| RateSchedule::Capped { max_log_inv_rate }), + Just(RateSchedule::Adaptive { + knee_weight: KneeWeight::DEFAULT, + }), + ]; + (4u32..=8, 1u32..=3, folding, schedule).prop_map( + |(log_size, log_inv_rate, folding_factor, rate_schedule)| TuningSpec { vector_size: 1usize << log_size, starting_log_inv_rate: log_inv_rate, folding_factor, - } - }) + rate_schedule, + }, + ) } const FIXTURE_FOLDING_FACTOR: usize = 2; @@ -95,6 +132,7 @@ mod tests { vector_size, starting_log_inv_rate: FIXTURE_LOG_INV_RATE, folding_factor: FoldingFactor::Constant(FIXTURE_FOLDING_FACTOR), + rate_schedule: RateSchedule::Stepping, } } @@ -114,9 +152,10 @@ mod tests { fn derive_standard_with_no_rounds_uses_basecase_only() { let spec = test_spec(Mode::Standard); let vector_size = 1usize << LOG_VECTOR_SIZE_NO_ROUNDS; - let plan = ProtocolConfig::::derive(spec, tuning_with(vector_size)).unwrap(); - assert!(plan.rounds().is_empty()); - assert_eq!(plan.basecase().commit.vector_size, vector_size); + let plan = + ProtocolConfig::::derive(spec, tuning_with(vector_size)).unwrap(); + assert!(plan.rounds.is_empty()); + assert_eq!(plan.basecase.commit.vector_size, vector_size); } #[test] @@ -127,11 +166,88 @@ mod tests { tuning_with(1 << LOG_VECTOR_SIZE_NO_ROUNDS), ) .unwrap(); - assert!(plan.rounds().is_empty()); - assert!(matches!( - plan.basecase().mode, - BasecaseMode::ZeroKnowledge - )); + assert!(plan.rounds.is_empty()); + assert!(matches!(plan.basecase.mode, BasecaseMode::ZeroKnowledge)); + } + + /// Adaptive plan must validate end-to-end across modes and produce + /// a strictly less aggressive rate schedule than the unbounded `Capped` + /// baseline at the tail. + #[test] + fn adaptive_plan_validates_and_caps_tail_rate() { + for mode in [Mode::Standard, Mode::ZeroKnowledge] { + let spec = test_spec(mode); + let tuning = TuningSpec { + vector_size: 1 << LOG_VECTOR_SIZE_MULTI_ROUND, + starting_log_inv_rate: FIXTURE_LOG_INV_RATE, + folding_factor: FoldingFactor::Constant(FIXTURE_FOLDING_FACTOR), + rate_schedule: RateSchedule::Adaptive { + knee_weight: KneeWeight::DEFAULT, + }, + }; + let plan = + ProtocolConfig::::derive(spec.clone(), tuning.clone()).unwrap(); + assert!( + plan.check_all_invariants(), + "Adaptive plan must satisfy validate() in mode={mode:?}" + ); + + // Unbounded `Capped` is the maximally-aggressive baseline (the + // canonical WHIR step-forever schedule). Adaptive should never + // propose a schedule whose basecase rate exceeds it — that's the + // whole point: Adaptive may stop growing rate earlier. + let unbounded = ProtocolConfig::::derive( + spec, + TuningSpec { + rate_schedule: RateSchedule::Stepping, + ..tuning + }, + ) + .unwrap(); + assert!( + plan.basecase.commit.rate() >= unbounded.basecase.commit.rate(), + "Adaptive basecase rate ({}) must be ≥ unbounded baseline's ({}) — \ + Adaptive should pick a *smaller* `1/ρ` at the tail", + plan.basecase.commit.rate(), + unbounded.basecase.commit.rate(), + ); + } + } + + /// When rounds exist, basecase IRS must equal the last round's + /// `code_switch.target` IRS — `zook::prove` skips re-encoding the folded + /// message at basecase entry. + #[test] + fn basecase_irs_aliases_last_round_target_when_rounds_present() { + for mode in [Mode::Standard, Mode::ZeroKnowledge] { + let spec = test_spec(mode); + for schedule in [ + RateSchedule::Stepping, + RateSchedule::Capped { + max_log_inv_rate: 6, + }, + RateSchedule::Adaptive { + knee_weight: KneeWeight::DEFAULT, + }, + ] { + let tuning = TuningSpec { + vector_size: 1 << LOG_VECTOR_SIZE_MULTI_ROUND, + starting_log_inv_rate: FIXTURE_LOG_INV_RATE, + folding_factor: FoldingFactor::Constant(FIXTURE_FOLDING_FACTOR), + rate_schedule: schedule, + }; + let plan = ProtocolConfig::::derive(spec.clone(), tuning).unwrap(); + assert!(!plan.rounds.is_empty(), "fixture must have rounds"); + let last_target = &plan.rounds.last().unwrap().code_switch.target; + let basecase_commit = &plan.basecase.commit; + assert_eq!( + last_target, basecase_commit, + "basecase.commit must alias the last round's code_switch.target — \ + mode={mode:?} schedule={schedule:?}", + ); + assert_eq!(basecase_commit.interleaving_depth, 1); + } + } } #[test] @@ -145,7 +261,7 @@ mod tests { tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), ) .unwrap(); - for r in plan.rounds() { + for r in &plan.rounds { let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode() else { panic!("expected ZK round") }; @@ -164,7 +280,7 @@ mod tests { tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), ) .unwrap(); - for r in plan.rounds() { + for r in &plan.rounds { let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode() else { panic!("expected ZK round") }; @@ -183,11 +299,12 @@ mod tests { tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), ) .unwrap(); - for r in plan.rounds() { + for r in &plan.rounds { let mask_oracle = r.mask_oracle().expect("ZK round has a mask oracle"); let k = r.code_switch().source.interleaving_depth.trailing_zeros() as usize; - let expected_num_masks = k + 1; - assert_eq!(mask_oracle.c_zk().num_vectors, 2 * expected_num_masks); + // Two-tree split: sumcheck_masks tree has 2·k columns, cs_mask tree has 2. + assert_eq!(mask_oracle.sumcheck_masks.c_zk_commit.num_vectors, 2 * k); + assert_eq!(mask_oracle.cs_mask.c_zk_commit.num_vectors, 2); } } @@ -202,11 +319,11 @@ mod tests { let bits: f64 = plan.analytic_bits().into(); assert!(bits.is_finite() && bits > 0.0, "bits = {bits}"); let min_round = plan - .rounds() + .rounds .iter() .map(|r| f64::from(r.analytic_bits())) .fold(f64::INFINITY, f64::min); - let expected = min_round.min(f64::from(plan.basecase().analytic_bits())); + let expected = min_round.min(f64::from(plan.basecase.analytic_bits())); assert_close(bits, expected); } @@ -220,7 +337,7 @@ mod tests { .unwrap(); let plan_bits: f64 = plan.analytic_bits().into(); let mo_floor = plan - .rounds() + .rounds .iter() .filter_map(|r| r.mask_oracle().map(|mo| f64::from(mo.analytic_bits()))) .fold(f64::INFINITY, f64::min); @@ -229,13 +346,13 @@ mod tests { "ZK plan must contribute mask-oracle bits" ); let min_round = plan - .rounds() + .rounds .iter() .map(|r| f64::from(r.analytic_bits())) .fold(f64::INFINITY, f64::min); let expected = mo_floor .min(min_round) - .min(f64::from(plan.basecase().analytic_bits())); + .min(f64::from(plan.basecase.analytic_bits())); assert_close(plan_bits, expected); } @@ -247,12 +364,9 @@ mod tests { tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), ) .unwrap(); - assert!(matches!( - plan.basecase().mode, - BasecaseMode::ZeroKnowledge - )); - assert_eq!(plan.basecase().commit.interleaving_depth, 1); - assert_eq!(plan.basecase().sumcheck.final_size(), 1); + assert!(matches!(plan.basecase.mode, BasecaseMode::ZeroKnowledge)); + assert_eq!(plan.basecase.commit.interleaving_depth, 1); + assert_eq!(plan.basecase.sumcheck.final_size(), 1); } const LOOSE_POW_BUDGET_BITS: u32 = 60; @@ -270,7 +384,7 @@ mod tests { .unwrap(); let field_bits = ::field_size_bits(); let mut expected_total = 0.0_f64; - for r in plan.rounds() { + for r in &plan.rounds { let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode() else { panic!("expected ZK round"); }; @@ -308,7 +422,7 @@ mod tests { #[test] fn check_pow_bits_detects_over_budget_slot() { - use crate::{bits::Bits, protocols::proof_of_work::Config as PowConfig}; + use crate::protocols::proof_of_work::Config as PowConfig; const MODERATE_POW_BUDGET_BITS: u32 = 30; let spec = SecuritySpec { pow_budget: PowBudget::per_slot(MODERATE_POW_BUDGET_BITS), @@ -326,111 +440,228 @@ mod tests { } #[test] - fn validate_round_chaining_detects_adjacent_round_mismatch() { + fn validate_security_target_met_passes_on_fresh_plan() { let spec = test_spec(Mode::ZeroKnowledge); - let mut plan = ProtocolConfig::::derive( + let plan = ProtocolConfig::::derive( spec, tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), ) .unwrap(); - let n = plan.rounds().len(); - assert!(n >= 2, "need ≥ 2 rounds to break a mid-chain link"); - assert!(plan.check_all_invariants(), "fresh plan must validate"); - - let bad_size = plan.rounds()[0].code_switch().target.vector_size + 1; - plan.corrupt_round_target_vector_size_for_test(0, bad_size); + plan.validate_security_target_met() + .expect("fresh plan must satisfy per-slot target check"); + } + #[test] + fn validate_security_target_met_catches_recorded_analytic_drift() { + let spec = test_spec(Mode::ZeroKnowledge); + let mut plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + assert!(!plan.rounds.is_empty(), "need a round to corrupt"); + let recorded = plan + .rounds() + .first() + .and_then(|r| r.sumcheck().recorded_analytic) + .expect("params solver records sumcheck analytic"); + // Bump the recorded value far from the recompute → triggers drift. + plan.corrupt_round_sumcheck_recorded_analytic_for_test( + 0, + Bits::new(f64::from(recorded) + 10.0), + ); let err = plan - .validate_round_chaining() - .expect_err("adjacent-round mismatch must trip the chain check"); + .validate_security_target_met() + .expect_err("recorded vs recompute mismatch must trip drift check"); assert!( matches!( err, - DeriveError::RoundChainBroken { - from: ChainSource::Round(0), - to: ChainTarget::NextRound(1), + DeriveError::AnalyticDrift { + pow: Pow::RoundSumcheck { index: 0 }, .. } ), "got {err:?}", ); - assert!(!plan.check_all_invariants()); } #[test] - fn validate_round_chaining_detects_basecase_mismatch() { + fn validate_security_target_met_catches_zeroed_basecase_pow() { + // ZK basecase has a γ-combination PoW slot whose analytic floor is + // below the security target under the test fixture. Wiping the PoW + // must trip `validate_security_target_met`. let spec = test_spec(Mode::ZeroKnowledge); let mut plan = ProtocolConfig::::derive( spec, tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), ) .unwrap(); - let n = plan.rounds().len(); - assert!(n >= 2, "need ≥ 2 rounds to break the chain by truncation"); - assert!(plan.check_all_invariants(), "fresh plan must validate"); + assert!( + plan.check_all_invariants(), + "fresh plan must validate including soundness" + ); - plan.truncate_rounds_for_test(n - 1); + // Skip the test if this fixture's basecase γ-combination doesn't actually + // require PoW (i.e., analytic ≥ target already). Otherwise force a violation. + let pre_pow = plan.basecase.pow.difficulty(); + if f64::from(pre_pow) == 0.0 { + return; + } + + plan.zero_basecase_pow_for_test(); let err = plan - .validate_round_chaining() - .expect_err("truncated tail breaks basecase chaining"); + .validate_security_target_met() + .expect_err("zeroed γ-combination PoW must trip soundness validation"); assert!( matches!( err, - DeriveError::RoundChainBroken { - to: ChainTarget::Basecase, + DeriveError::SecurityTargetNotMet { + pow: Pow::BasecaseGammaCombination, .. } ), "got {err:?}", ); - assert!(!plan.check_all_invariants()); } #[test] - fn validate_security_target_met_passes_on_fresh_plan() { + fn derive_with_capped_rate_shrinks_basecase_codeword() { + // Same shape that historically inflates basecase via rate stepping: + // ~2^20 witness, folding 3, 5 rounds → stepped log_inv_rate ≈ 12. + const LOG_VECTOR_SIZE: u32 = 12; let spec = test_spec(Mode::ZeroKnowledge); - let plan = ProtocolConfig::::derive( + let folding = FoldingFactor::Constant(3); + + let stepped_tuning = TuningSpec { + vector_size: 1usize << LOG_VECTOR_SIZE, + starting_log_inv_rate: 2, + folding_factor: folding, + rate_schedule: RateSchedule::Stepping, + }; + let stepped_plan = + ProtocolConfig::::derive(spec.clone(), stepped_tuning).unwrap(); + + let capped_tuning = TuningSpec { + vector_size: 1usize << LOG_VECTOR_SIZE, + starting_log_inv_rate: 2, + folding_factor: folding, + rate_schedule: RateSchedule::Capped { + max_log_inv_rate: 4, + }, + }; + let capped_plan = ProtocolConfig::::derive(spec, capped_tuning).unwrap(); + + // Same number of rounds — cap only affects per-round rate, not layout shape. + assert_eq!(stepped_plan.rounds.len(), capped_plan.rounds.len()); + // Cap forces a strictly smaller basecase codeword (the saving we're after). + let stepped_codeword = stepped_plan.basecase.commit.codeword_length; + let capped_codeword = capped_plan.basecase.commit.codeword_length; + assert!( + capped_codeword < stepped_codeword, + "capped basecase codeword ({capped_codeword}) should be smaller than stepped \ + ({stepped_codeword})", + ); + // For this fixture the cap should produce at least a 4× reduction. + assert!( + capped_codeword * 4 <= stepped_codeword, + "expected ≥4× reduction; got stepped={stepped_codeword}, capped={capped_codeword}", + ); + } + + /// Adaptive must not be worse than the unbounded-`Capped` baseline on + /// basecase NTT work for the folding-3 / 2^12 ZK fixture that historically + /// inflates basecase via unbounded rate stepping. At minimum, Adaptive's + /// basecase codeword must be ≤ the baseline's. + #[test] + fn adaptive_basecase_no_worse_than_unbounded_on_inflating_fixture() { + const LOG_VECTOR_SIZE: u32 = 12; + let spec = test_spec(Mode::ZeroKnowledge); + let base = TuningSpec { + vector_size: 1usize << LOG_VECTOR_SIZE, + starting_log_inv_rate: 2, + folding_factor: FoldingFactor::Constant(3), + rate_schedule: RateSchedule::Stepping, + }; + let unbounded = + ProtocolConfig::::derive(spec.clone(), base.clone()).unwrap(); + let adaptive = ProtocolConfig::::derive( spec, - tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + TuningSpec { + rate_schedule: RateSchedule::Adaptive { + knee_weight: KneeWeight::DEFAULT, + }, + ..base + }, ) .unwrap(); - plan.validate_security_target_met() - .expect("fresh plan must satisfy per-slot target check"); + let unbounded_basecode = unbounded.basecase.commit.codeword_length; + let adaptive_basecode = adaptive.basecase.commit.codeword_length; + assert!( + adaptive_basecode <= unbounded_basecode, + "Adaptive basecase codeword ({adaptive_basecode}) must be ≤ unbounded \ + baseline's ({unbounded_basecode})", + ); } #[test] - fn validate_security_target_met_catches_recorded_analytic_drift() { - use crate::bits::Bits; + fn validate_round_chaining_detects_adjacent_round_mismatch() { let spec = test_spec(Mode::ZeroKnowledge); let mut plan = ProtocolConfig::::derive( spec, tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), ) .unwrap(); - assert!(!plan.rounds().is_empty(), "need a round to corrupt"); - let recorded = plan - .rounds() - .first() - .and_then(|r| r.sumcheck().recorded_analytic) - .expect("params solver records sumcheck analytic"); - // Bump the recorded value far from the recompute → triggers drift. - plan.corrupt_round_sumcheck_recorded_analytic_for_test( - 0, - Bits::new(f64::from(recorded) + 10.0), + let n = plan.rounds.len(); + assert!(n >= 2, "need ≥ 2 rounds to break a mid-chain link"); + assert!(plan.check_all_invariants(), "fresh plan must validate"); + + let bad_size = plan.rounds[0].code_switch().target.vector_size + 1; + plan.corrupt_round_target_vector_size_for_test(0, bad_size); + + let err = plan + .validate_round_chaining() + .expect_err("adjacent-round mismatch must trip the chain check"); + assert!( + matches!( + err, + DeriveError::RoundChainBroken { + from: ChainSource::Round(0), + to: ChainTarget::NextRound(1), + .. + } + ), + "got {err:?}", ); + assert!(!plan.check_all_invariants()); + } + + #[test] + fn validate_round_chaining_detects_basecase_mismatch() { + let spec = test_spec(Mode::ZeroKnowledge); + let mut plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + let n = plan.rounds.len(); + assert!(n >= 2, "need ≥ 2 rounds to break the chain by truncation"); + assert!(plan.check_all_invariants(), "fresh plan must validate"); + + plan.truncate_rounds_for_test(n - 1); let err = plan - .validate_security_target_met() - .expect_err("recorded vs recompute mismatch must trip drift check"); + .validate_round_chaining() + .expect_err("truncated tail breaks basecase chaining"); assert!( matches!( err, - DeriveError::AnalyticDrift { - pow: Pow::RoundSumcheck { index: 0 }, + DeriveError::RoundChainBroken { + to: ChainTarget::Basecase, .. } ), "got {err:?}", ); + assert!(!plan.check_all_invariants()); } #[test] @@ -480,8 +711,8 @@ mod tests { tuning_with(1 << LOG_VECTOR_SIZE_NO_ROUNDS), ) .unwrap(); - assert!(plan.rounds().is_empty()); - assert!(plan.basecase().commit.unique_decoding()); + assert!(plan.rounds.is_empty()); + assert!(plan.basecase.commit.unique_decoding()); } #[test] @@ -495,8 +726,8 @@ mod tests { tuning_with(1 << LOG_VECTOR_SIZE_NO_ROUNDS), ) .unwrap(); - assert!(plan.rounds().is_empty()); - assert!(plan.basecase().commit.unique_decoding()); + assert!(plan.rounds.is_empty()); + assert!(plan.basecase.commit.unique_decoding()); } #[test] @@ -510,14 +741,14 @@ mod tests { tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), ) .unwrap(); - assert!(!plan.rounds().is_empty(), "expected multi-round plan"); - for r in plan.rounds() { + assert!(!plan.rounds.is_empty(), "expected multi-round plan"); + for r in &plan.rounds { let cs = r.code_switch(); assert!(cs.source.unique_decoding()); assert!(cs.target.unique_decoding()); assert!(cs.out_domain_samples >= 1); } - assert!(plan.basecase().commit.unique_decoding()); + assert!(plan.basecase.commit.unique_decoding()); } #[test] @@ -531,14 +762,14 @@ mod tests { tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), ) .unwrap(); - assert!(!plan.rounds().is_empty(), "expected multi-round plan"); - for r in plan.rounds() { + assert!(!plan.rounds.is_empty(), "expected multi-round plan"); + for r in &plan.rounds { let mo = r.mask_oracle().expect("ZK round must own a mask oracle"); - assert!(mo.c_zk().unique_decoding()); + assert!(mo.cs_mask.c_zk_commit.unique_decoding()); assert!(r.code_switch().source.unique_decoding()); assert!(r.code_switch().out_domain_samples >= 1); } - assert!(plan.basecase().commit.unique_decoding()); + assert!(plan.basecase.commit.unique_decoding()); } #[test] @@ -552,8 +783,8 @@ mod tests { tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), ) .unwrap(); - assert!(!plan.rounds().is_empty(), "expected multi-round plan"); - for r in plan.rounds() { + assert!(!plan.rounds.is_empty(), "expected multi-round plan"); + for r in &plan.rounds { assert!(r.code_switch().out_domain_samples >= 1); } } @@ -569,8 +800,8 @@ mod tests { tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), ) .unwrap(); - assert!(!plan.rounds().is_empty(), "expected multi-round plan"); - for r in plan.rounds() { + assert!(!plan.rounds.is_empty(), "expected multi-round plan"); + for r in &plan.rounds { r.mask_oracle().expect("ZK round must own a mask oracle"); assert!(r.code_switch().out_domain_samples >= 1); } @@ -580,8 +811,8 @@ mod tests { spec: &SecuritySpec, plan: &ProtocolConfig, ) { - for r in plan.rounds() { - let mask_info = r.mask_oracle_info(); + for r in &plan.rounds { + let mask_info = r.mode().mask_oracle(); let cs = r.code_switch(); assert_pow_closes_gap( spec, @@ -599,27 +830,25 @@ mod tests { &cs.pow, ); if let Some(mo) = r.mask_oracle() { - let mp = mo.mask_proximity(); - assert_pow_closes_gap( - spec, - mask_proximity_params::analytic_error_bits(&mp.c_zk_commit, mp.num_masks), - &mp.pow, - ); + for mp in [&mo.sumcheck_masks, &mo.cs_mask] { + assert_pow_closes_gap( + spec, + mask_proximity_params::analytic_error_bits(&mp.c_zk_commit, mp.num_masks), + &mp.pow, + ); + } } } assert_pow_closes_gap( spec, - sumcheck_params::analytic_error_bits(&plan.basecase().commit, None), - &plan.basecase().sumcheck.round_pow, + sumcheck_params::analytic_error_bits(&plan.basecase.commit, None), + &plan.basecase.sumcheck.round_pow, ); - if matches!( - plan.basecase().mode, - BasecaseMode::ZeroKnowledge - ) { + if matches!(plan.basecase.mode, BasecaseMode::ZeroKnowledge) { assert_pow_closes_gap( spec, - basecase_params::analytic_error_bits(&plan.basecase().commit), - &plan.basecase().pow, + basecase_params::analytic_error_bits(&plan.basecase.commit), + &plan.basecase.pow, ); } } @@ -646,15 +875,12 @@ mod tests { fn derive_standard_succeeds_over_tunings(tuning in arb_tuning()) { let spec = test_spec(Mode::Standard); let plan = ProtocolConfig::::derive(spec, tuning).unwrap(); - for r in plan.rounds() { + for r in &plan.rounds { prop_assert!(matches!(r.mode(), RoundMode::Standard)); prop_assert!(r.mask_oracle().is_none()); } - prop_assert!(matches!( - plan.basecase().mode, - BasecaseMode::Standard - )); - prop_assert_eq!(plan.basecase().commit.interleaving_depth, 1); + prop_assert!(matches!(plan.basecase.mode, BasecaseMode::Standard)); + prop_assert_eq!(plan.basecase.commit.interleaving_depth, 1); } #[test] @@ -665,7 +891,7 @@ mod tests { let spec = test_spec(Mode::ZeroKnowledge); let plan = ProtocolConfig::::derive(spec, tuning).unwrap(); - for r in plan.rounds() { + for r in &plan.rounds { let mask_oracle = r .mask_oracle() .expect("ZK round must have a mask oracle"); @@ -674,14 +900,13 @@ mod tests { }; let cs = r.code_switch(); let k = cs.source.interleaving_depth.trailing_zeros() as usize; - let num_masks = k + 1; - prop_assert_eq!(mask_oracle.c_zk().num_vectors, 2 * num_masks); - prop_assert_eq!(mask_oracle.mask_proximity().num_masks, num_masks); + prop_assert_eq!(mask_oracle.sumcheck_masks.c_zk_commit.num_vectors, 2 * k); + prop_assert_eq!(mask_oracle.cs_mask.c_zk_commit.num_vectors, 2); let source_mask = cs.source.mask_length(); - prop_assert!(mask_oracle.l_zk().get() >= source_mask + t_ood.get()); + prop_assert!(mask_oracle.l_zk.get() >= source_mask + t_ood.get()); } prop_assert!(matches!( - plan.basecase().mode, + plan.basecase.mode, BasecaseMode::ZeroKnowledge )); } diff --git a/src/protocols/params/error.rs b/src/protocols/params/error.rs index 1e54af19..37844ec1 100644 --- a/src/protocols/params/error.rs +++ b/src/protocols/params/error.rs @@ -145,6 +145,16 @@ pub enum DeriveError { /// `tuning.folding_factor` must yield at least 1 at every round. #[error("tuning.folding_factor min ({min}) must be ≥ 1")] TuningFoldingFactorBelowOne { min: usize }, + + /// `RateSchedule::Adaptive`'s search exhausted with no feasible schedule. + /// Every per-round (source_rate, target_rate) candidate within + /// `ADAPTIVE_STEP_BUDGET · (folding − 1)` failed at least one per-slot + /// analytic floor against `spec.pow_budget`. The spec is fundamentally + /// too tight for any schedule the planner can express — pick a looser + /// `pow_budget`, a smaller `target_security_bits`, or a different + /// witness shape. + #[error("Adaptive planner found no feasible schedule under spec.pow_budget")] + AdaptiveNoFeasibleSchedule, } /// Lift `Result` into `Result` by attaching a diff --git a/src/protocols/params/irs_commit.rs b/src/protocols/params/irs_commit.rs index 9e0c8ea6..c43caba9 100644 --- a/src/protocols/params/irs_commit.rs +++ b/src/protocols/params/irs_commit.rs @@ -17,6 +17,9 @@ use crate::{ }, }; +/// Convergence cap for the mask-length / effective-rate fixpoint in [`solve`]. +const SOLVE_MAX_ITER: usize = 4; + pub fn solve( spec: &SecuritySpec, ctx: &RoundContext, @@ -26,25 +29,39 @@ pub fn solve( let rate = rate(f64::from(ctx.log_inv_rate)); let interleaving_depth = 1_usize << ctx.folding_factor; - let mode = match spec.mode { - Mode::Standard => IrsMode::Standard, - Mode::ZeroKnowledge => { - let mask_length = num_in_domain_queries(spec.decoding_regime, security_target, rate) - .saturating_add(out_domain_samples.get()); - IrsMode::ZeroKnowledge { mask_length } - } + let build = |mode: IrsMode| -> IrsConfig { + IrsConfig::new( + security_target, + spec.decoding_regime, + spec.hash_id, + 1, + ctx.vector_size, + interleaving_depth, + rate, + mode, + ) }; - IrsConfig::new( - security_target, - spec.decoding_regime, - spec.hash_id, - 1, - ctx.vector_size, - interleaving_depth, - rate, - mode, - ) + if matches!(spec.mode, Mode::Standard) { + return build(IrsMode::Standard); + } + + // ZK mode: Lemma 9.5 requires `r ≥ in-domain + OOD`. `in_domain_samples` + // is computed inside `IrsConfig::new` from the codeword's *effective* + // rate (which depends on `mask_length` via `masked_message_length`). + // Iterate to align the two: pick a mask, build the config, check whether + // the effective in_domain exceeds the assumed query count, and grow if + // needed. Converges in 1-2 iterations because rate shifts are small. + let mut q = num_in_domain_queries(spec.decoding_regime, security_target, rate); + for _ in 0..SOLVE_MAX_ITER { + let mask_length = q.saturating_add(out_domain_samples.get()); + let cfg = build(IrsMode::ZeroKnowledge { mask_length }); + if cfg.in_domain_samples <= q.get() { + return cfg; + } + q = std::num::NonZeroUsize::new(cfg.in_domain_samples).expect("in_domain ≥ 1"); + } + panic!("mask / effective-rate fixpoint did not converge in {SOLVE_MAX_ITER} iterations"); } /// Shared C_zk IRS config for mask polynomials. diff --git a/src/protocols/params/layout.rs b/src/protocols/params/layout.rs index 79df2f0c..69443fc4 100644 --- a/src/protocols/params/layout.rs +++ b/src/protocols/params/layout.rs @@ -4,14 +4,21 @@ //! independent of [`SecuritySpec`] and IRS solving. Consumed by //! [`super::build_round`] to instantiate per-round configs and by //! [`super::derive`] to drive the round/basecase split. +//! +//! When [`super::spec::RateSchedule::Adaptive`] is selected, this layer also +//! invokes [`super::adaptive::plan_adaptive_rates`] to overwrite the +//! per-round `target_log_inv_rate` chain. Adaptive runs only after the +//! skeleton is built (it needs the shape sequence to enumerate candidates). use crate::{ algebra::embedding::Embedding, protocols::{ irs_commit::Config as IrsConfig, params::{ + adaptive::plan_adaptive_rates, + branch::RoundBuildMode, error::DeriveError, - spec::{RoundContext, TuningSpec}, + spec::{RateSchedule, RoundContext, SecuritySpec, TuningSpec}, }, }, }; @@ -23,6 +30,10 @@ pub(super) struct RoundShape { pub(super) source_log_inv_rate: u32, pub(super) source_folding_factor: u32, pub(super) target_folding_factor: u32, + /// Inverse rate for this round's code-switch target (= next round's + /// source). Pre-computed by [`round_layout`] under the schedule so + /// [`target_context`] doesn't have to re-derive it. + pub(super) target_log_inv_rate: u32, } #[derive(Debug)] @@ -32,7 +43,11 @@ pub(super) struct RoundLayout { pub(super) basecase_log_inv_rate: u32, } -pub(super) fn round_layout(tuning: &TuningSpec) -> Result { +pub(super) fn round_layout( + spec: &SecuritySpec, + tuning: &TuningSpec, + mode: RoundBuildMode<'_>, +) -> Result { if !tuning.vector_size.is_power_of_two() { return Err(DeriveError::TuningVectorSizeNotPowerOfTwo { vector_size: tuning.vector_size, @@ -54,21 +69,56 @@ pub(super) fn round_layout(tuning: &TuningSpec) -> Result { + let new_rates = plan_adaptive_rates::( + spec, + tuning, + knee_weight, + &shapes, + basecase_vector_size, + mode, + )?; + for (shape, &t) in shapes.iter_mut().zip(&new_rates) { + shape.target_log_inv_rate = t; + } + for i in 1..shapes.len() { + shapes[i].source_log_inv_rate = shapes[i - 1].target_log_inv_rate; + } + *new_rates.last().expect("non-empty") + } + _ => log_inv_rate, + }; + + // Force the last round's target IRS to interleaving = 1 so it can serve + // directly as the basecase commit — eliminates the recommit of the folded + // message in `zook::prove`. `target_folding_factor` only affects the + // target IRS shape (it's a hint about the next round's source-folding, + // and there is no next round here). + if let Some(last) = shapes.last_mut() { + last.target_folding_factor = 0; } Ok(RoundLayout { shapes, - basecase_vector_size: 1usize << num_vars, - basecase_log_inv_rate: log_inv_rate, + basecase_vector_size, + basecase_log_inv_rate, }) } @@ -86,9 +136,7 @@ pub(super) fn target_context( ) -> RoundContext { RoundContext { vector_size: source.message_length(), - log_inv_rate: shape - .source_log_inv_rate - .saturating_add(shape.source_folding_factor.saturating_sub(1)), + log_inv_rate: shape.target_log_inv_rate, folding_factor: shape.target_folding_factor, } } @@ -96,7 +144,14 @@ pub(super) fn target_context( #[cfg(test)] mod tests { use super::*; - use crate::protocols::params::spec::FoldingFactor; + use crate::{ + hash, + protocols::params::{ + branch::Branch, + spec::{DecodingRegime, FoldingFactor, Mode, PowBudget}, + test_utils::TestEmbedding, + }, + }; const FIXTURE_FOLDING_FACTOR: usize = 2; const FIXTURE_LOG_INV_RATE: u32 = 1; @@ -110,14 +165,29 @@ mod tests { const RATE_STEPPING_STARTING_LOG_INV_RATE: u32 = 2; const MIN_ROUNDS_FOR_CHAINING_TEST: usize = 2; + fn fixture_spec() -> SecuritySpec { + SecuritySpec { + mode: Mode::Standard, + decoding_regime: DecodingRegime::Johnson, + target_security_bits: 40, + pow_budget: PowBudget::per_slot(60), + hash_id: hash::BLAKE3, + } + } + fn tuning_with(vector_size: usize) -> TuningSpec { TuningSpec { vector_size, starting_log_inv_rate: FIXTURE_LOG_INV_RATE, folding_factor: FoldingFactor::Constant(FIXTURE_FOLDING_FACTOR), + rate_schedule: RateSchedule::Stepping, } } + fn run_layout(tuning: &TuningSpec) -> Result { + round_layout::(&fixture_spec(), tuning, Branch::Standard) + } + #[test] fn round_layout_rate_steps_up_by_folding_minus_one() { let tuning = TuningSpec { @@ -127,8 +197,9 @@ mod tests { initial: VARIED_INITIAL_FOLDING, rest: VARIED_STEADY_FOLDING, }, + rate_schedule: RateSchedule::Stepping, }; - let layout = round_layout(&tuning).unwrap(); + let layout = run_layout(&tuning).unwrap(); let mut expected_log_inv_rate = RATE_STEPPING_STARTING_LOG_INV_RATE; for shape in &layout.shapes { @@ -147,13 +218,17 @@ mod tests { initial: VARIED_INITIAL_FOLDING, rest: VARIED_STEADY_FOLDING, }, + rate_schedule: RateSchedule::Stepping, }; - let layout = round_layout(&tuning).unwrap(); + let layout = run_layout(&tuning).unwrap(); assert!( layout.shapes.len() >= MIN_ROUNDS_FOR_CHAINING_TEST, "need ≥ {MIN_ROUNDS_FOR_CHAINING_TEST} rounds to test chaining", ); - for window in layout.shapes.windows(2) { + // Last round's target_folding_factor is force-zeroed by round_layout to + // align with the basecase commit's interleaving = 1. Only window over + // round pairs that aren't terminated by this zeroing. + for window in layout.shapes[..layout.shapes.len() - 1].windows(2) { assert_eq!( window[0].target_folding_factor, window[1].source_folding_factor @@ -164,7 +239,7 @@ mod tests { #[test] fn round_layout_basecase_size_consumes_remaining_num_vars() { let tuning = tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND); - let layout = round_layout(&tuning).unwrap(); + let layout = run_layout(&tuning).unwrap(); let consumed: u32 = layout.shapes.iter().map(|s| s.source_folding_factor).sum(); let initial_num_vars = tuning.vector_size.trailing_zeros(); let remaining = initial_num_vars - consumed; @@ -175,7 +250,7 @@ mod tests { fn round_layout_stops_when_no_room_for_source_plus_target() { let vector_size = 1usize << LOG_VECTOR_SIZE_NO_ROUNDS; let tuning = tuning_with(vector_size); - let layout = round_layout(&tuning).unwrap(); + let layout = run_layout(&tuning).unwrap(); assert!(layout.shapes.is_empty()); assert_eq!(layout.basecase_vector_size, vector_size); assert_eq!(layout.basecase_log_inv_rate, FIXTURE_LOG_INV_RATE); @@ -187,8 +262,9 @@ mod tests { vector_size: 12, starting_log_inv_rate: FIXTURE_LOG_INV_RATE, folding_factor: FoldingFactor::Constant(FIXTURE_FOLDING_FACTOR), + rate_schedule: RateSchedule::Stepping, }; - let err = round_layout(&tuning).expect_err("non-pow2 vector_size must fail"); + let err = run_layout(&tuning).expect_err("non-pow2 vector_size must fail"); assert!( matches!( err, @@ -204,8 +280,9 @@ mod tests { vector_size: 1 << LOG_VECTOR_SIZE_MULTI_ROUND, starting_log_inv_rate: FIXTURE_LOG_INV_RATE, folding_factor: FoldingFactor::Constant(0), + rate_schedule: RateSchedule::Stepping, }; - let err = round_layout(&tuning).expect_err("folding_factor = 0 must fail"); + let err = run_layout(&tuning).expect_err("folding_factor = 0 must fail"); assert!( matches!(err, DeriveError::TuningFoldingFactorBelowOne { min: 0 }), "got {err:?}", diff --git a/src/protocols/params/mod.rs b/src/protocols/params/mod.rs index 6f7ddf7a..d4d62fc1 100644 --- a/src/protocols/params/mod.rs +++ b/src/protocols/params/mod.rs @@ -4,6 +4,7 @@ //! "the bounds doc, §N") live at //! . +pub(crate) mod adaptive; pub(crate) mod basecase; pub(crate) mod bounds; pub(crate) mod branch; @@ -28,6 +29,6 @@ pub use protocol_config::{ MaskOracleConfig, MaskOracleInfo, ProtocolConfig, RoundConfig, RoundMode, }; pub use spec::{ - DecodingRegime, FoldingFactor, ListSize, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, - PowBudget, RoundContext, SecuritySpec, TuningSpec, ZkSpec, + DecodingRegime, FoldingFactor, KneeWeight, ListSize, LogInvRate, MaskCodeMessageLen, Mode, + OodSampleBudget, PowBudget, RateSchedule, RoundContext, SecuritySpec, TuningSpec, ZkSpec, }; diff --git a/src/protocols/params/protocol_config.rs b/src/protocols/params/protocol_config.rs index 41c8fb4b..a258f9b3 100644 --- a/src/protocols/params/protocol_config.rs +++ b/src/protocols/params/protocol_config.rs @@ -1,17 +1,19 @@ //! Output of [`super::derive`]: the assembled per-round and basecase configs. +//! +//! Each ZK round owns its mask oracle: a per-round C_zk codeword (sized for +//! `2·(k+1)` columns — `k` sumcheck masks + 1 code-switch `(r ‖ s)` mask, all +//! doubled by Construction 7.2's originals + fresh pairs) plus a per-round +//! mask-proximity check. Standard rounds carry no mask oracle. use ark_ff::Field; +use serde::{Deserialize, Serialize}; use crate::{ - algebra::{ - embedding::{Embedding, Identity}, - fields::FieldWithSize, - }, + algebra::{embedding::Embedding, fields::FieldWithSize}, bits::Bits, protocols::{ basecase::Config as BasecaseConfig, code_switch::Config as CodeSwitchConfig, - irs_commit::Config as IrsConfig, mask_proximity::Config as MaskProximityConfig, params::{ basecase as basecase_params, @@ -27,29 +29,16 @@ use crate::{ }, }; -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(bound = "")] pub struct ProtocolConfig { - security: SecuritySpec, - tuning: TuningSpec, - rounds: Vec>, - basecase: BasecaseConfig, + pub security: SecuritySpec, + pub tuning: TuningSpec, + pub rounds: Vec>, + pub basecase: BasecaseConfig, } impl ProtocolConfig { - pub(crate) const fn new( - security: SecuritySpec, - tuning: TuningSpec, - rounds: Vec>, - basecase: BasecaseConfig, - ) -> Self { - Self { - security, - tuning, - rounds, - basecase, - } - } - pub const fn security(&self) -> &SecuritySpec { &self.security } @@ -84,17 +73,16 @@ impl ProtocolConfig { Ok(()) } - /// For each PoW slot: verify (a) the analytic-bits floor recorded at - /// solve time still matches a fresh recompute from the config's current - /// state, and (b) `recorded_analytic + pow.difficulty() ≥ target_security_bits`. - /// - /// `grind_to_at` guarantees (b) at solve time. If (a) holds, (b) holds - /// trivially. If (a) drifts, (b) may fail — most often because a planner - /// regression overwrote an IRS field after the solver consumed it. + /// Re-verify that `analytic + pow.difficulty() ≥ target_security_bits` + /// for every PoW slot. `grind_to_at` guarantees this at solve time, so + /// failure here indicates the analytic formula returns a *different* + /// value at validate-time than the solver consumed — most often because + /// a planner-input change (rate schedule, mask sizing, …) flowed into + /// the IRS configs but not into the formula's inputs. /// - /// `EPS` matches the `assert_pow_closes_gap` slack used by the per-slot - /// proptest helper, so validation stays consistent with test-time - /// assertions. + /// `EPS = 1e-3` matches the `assert_pow_closes_gap` slack used by the + /// per-slot proptest helper, so `validate` stays consistent with the + /// test-time assertion. pub fn validate_security_target_met(&self) -> Result<(), DeriveError> { const EPS: f64 = 1e-3; let target = Bits::new(f64::from(self.security.target_security_bits)); @@ -126,7 +114,7 @@ impl ProtocolConfig { Ok(()) }; for r in &self.rounds { - let mask_info = r.mask_oracle_info(); + let mask_info = r.mode.mask_oracle(); check( Pow::RoundSumcheck { index: r.round_index, @@ -148,18 +136,20 @@ impl ProtocolConfig { ), &r.code_switch.pow, )?; - if let Some(mo) = r.mask_oracle() { - check( - Pow::RoundMaskProximity { - index: r.round_index, - }, - mo.mask_proximity.recorded_analytic, - mask_proximity_params::analytic_error_bits( - &mo.mask_proximity.c_zk_commit, - mo.mask_proximity.num_masks, - ), - &mo.mask_proximity.pow, - )?; + if let Some(mo) = &r.mask_oracle { + for tree in [&mo.sumcheck_masks, &mo.cs_mask] { + check( + Pow::RoundMaskProximity { + index: r.round_index, + }, + tree.recorded_analytic, + mask_proximity_params::analytic_error_bits( + &tree.c_zk_commit, + tree.num_masks, + ), + &tree.pow, + )?; + } } } check( @@ -203,12 +193,18 @@ impl ProtocolConfig { }, &r.code_switch.pow, )?; - if let Some(mo) = r.mask_oracle() { + if let Some(mo) = &r.mask_oracle { + check( + Pow::RoundMaskProximity { + index: r.round_index, + }, + &mo.sumcheck_masks.pow, + )?; check( Pow::RoundMaskProximity { index: r.round_index, }, - &mo.mask_proximity.pow, + &mo.cs_mask.pow, )?; } } @@ -259,12 +255,16 @@ impl ProtocolConfig { /// HVZK privacy error in bits, summed across ZK rounds: /// `−log Σ_r (t_ood_r² + t_ood_r) / (2|F|)` (bounds doc, §5.3 + §5.7). + /// Standard-mode plans return `target_security_bits` as a sentinel — + /// HVZK isn't claimed when there are no ZK rounds. pub fn privacy_error_bits(&self) -> Bits { let field_bits = ::field_size_bits(); let mut total_error = 0.0_f64; for r in &self.rounds { - if let RoundMode::ZeroKnowledge { t_ood, .. } = &r.mode { + if let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode { let t = usize_to_f64(t_ood.get()); + // ζ_ze ≤ (t_ood² + t_ood) / (2|F|). Compute in log space to + // stay numerically stable for large field_bits. let log_err = f64::midpoint(t * t, t).log2() - field_bits; total_error += 2_f64.powf(log_err); } @@ -274,14 +274,15 @@ impl ProtocolConfig { } Bits::new((-total_error.log2()).max(0.0)) } -} -impl ProtocolConfig { /// Analytic soundness bits (excluding PoW). pub fn analytic_bits(&self) -> Bits { let mut min_bits = f64::from(self.basecase.analytic_bits()); for round in &self.rounds { min_bits = min_bits.min(f64::from(round.analytic_bits())); + if let Some(mo) = &round.mask_oracle { + min_bits = min_bits.min(f64::from(mo.analytic_bits())); + } } Bits::new(min_bits.max(0.0)) } @@ -305,6 +306,12 @@ impl ProtocolConfig { self.rounds[round_idx].code_switch.target.vector_size = new_size; } + /// Drop the basecase γ-combination PoW to `0`, simulating a planner bug + /// that under-counts the analytic-to-target gap. + pub(crate) fn zero_basecase_pow_for_test(&mut self) { + self.basecase.pow = PowConfig::none(); + } + pub(crate) fn corrupt_round_sumcheck_recorded_analytic_for_test( &mut self, round_idx: usize, @@ -314,33 +321,19 @@ impl ProtocolConfig { } } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(bound = "")] pub struct RoundConfig { - round_index: usize, - sumcheck: SumcheckConfig, - code_switch: CodeSwitchConfig, - mode: RoundMode, - /// `Some` iff `mode.is_zk()`. Sized for this round's `k + 1` masks. - mask_oracle: Option>, + pub round_index: usize, + pub sumcheck: SumcheckConfig, + pub code_switch: CodeSwitchConfig, + pub mode: RoundMode, + /// `Some` iff this is a ZK round. Sized for this round's `k + 1` masks + /// (k sumcheck + 1 code-switch). + pub mask_oracle: Option>, } impl RoundConfig { - pub(crate) const fn new( - round_index: usize, - sumcheck: SumcheckConfig, - code_switch: CodeSwitchConfig, - mode: RoundMode, - mask_oracle: Option>, - ) -> Self { - Self { - round_index, - sumcheck, - code_switch, - mode, - mask_oracle, - } - } - pub const fn round_index(&self) -> usize { self.round_index } @@ -357,43 +350,16 @@ impl RoundConfig { &self.mode } - /// Borrow the round's mask oracle if this is a ZK round. pub const fn mask_oracle(&self) -> Option<&MaskOracleConfig> { self.mask_oracle.as_ref() } - /// Slim mask-oracle view derived from `mask_oracle()`. - pub fn mask_oracle_info(&self) -> Option { - self.mask_oracle().map(MaskOracleConfig::info) - } -} - -/// Standard vs. ZK round. -/// -/// Non-generic — the per-round `MaskOracleConfig` lives on -/// [`RoundConfig`] as a sibling field. -#[derive(Clone, Copy, Debug)] -pub enum RoundMode { - Standard, - ZeroKnowledge { - /// Lemma 9.9 OOD-sample budget (bounds doc §5.2). - t_ood: OodSampleBudget, - }, -} - -impl RoundMode { - pub const fn is_zk(&self) -> bool { - matches!(self, Self::ZeroKnowledge { .. }) - } -} - -impl RoundConfig { /// Round-level analytic floor: the smallest of `sumcheck`, `code_switch`, /// and (when present) the per-round mask-oracle proximity check. pub fn analytic_bits(&self) -> Bits { let source = &self.code_switch.source; let target = &self.code_switch.target; - let mask_info = self.mask_oracle_info(); + let mask_info = self.mode.mask_oracle(); let sumcheck_term = f64::from(sumcheck_params::analytic_error_bits(source, mask_info)); let code_switch_term = f64::from(code_switch_params::analytic_error_bits( @@ -402,71 +368,86 @@ impl RoundConfig { self.code_switch.out_domain_samples, mask_info, )); - let mask_oracle_term = self - .mask_oracle() - .map_or(f64::INFINITY, |mo| f64::from(mo.analytic_bits())); - - Bits::new( - sumcheck_term - .min(code_switch_term) - .min(mask_oracle_term) - .max(0.0), - ) + + Bits::new(sumcheck_term.min(code_switch_term).max(0.0)) } } -/// One round's mask oracle: a C_zk codeword + ℓ_zk + mask-proximity check. -#[derive(Clone, Debug)] -pub struct MaskOracleConfig { - c_zk: IrsConfig>, - /// `next_pow2(r + t_ood)` (Theorem 9.6 + Lemma 9.3). - l_zk: MaskCodeMessageLen, - mask_proximity: MaskProximityConfig, +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub enum RoundMode { + Standard, + ZeroKnowledge { + /// Lemma 9.9 OOD-sample budget (bounds doc §5.2). + t_ood: OodSampleBudget, + /// Slim view of this round's [`MaskOracleConfig`] (C_zk's list size + + /// ℓ_zk) — denormalized so soundness routines can read it without + /// chasing through `mask_oracle`. + mask_oracle: MaskOracleInfo, + }, } -impl MaskOracleConfig { - pub(crate) const fn new( - c_zk: IrsConfig>, - l_zk: MaskCodeMessageLen, - mask_proximity: MaskProximityConfig, - ) -> Self { - Self { - c_zk, - l_zk, - mask_proximity, - } - } - - pub const fn c_zk(&self) -> &IrsConfig> { - &self.c_zk +impl RoundMode { + pub const fn is_zk(&self) -> bool { + matches!(self, Self::ZeroKnowledge { .. }) } - pub const fn l_zk(&self) -> MaskCodeMessageLen { - self.l_zk + pub const fn mask_oracle(&self) -> Option { + match self { + Self::Standard => None, + Self::ZeroKnowledge { mask_oracle, .. } => Some(*mask_oracle), + } } +} - pub const fn mask_proximity(&self) -> &MaskProximityConfig { - &self.mask_proximity - } +/// One round's mask oracle, split across two independent C_zk trees: +/// - `sumcheck_masks`: the `k` sumcheck masks (Lemma 6.4), each of length +/// `next_pow_2(mask_length)`. Committed BEFORE sumcheck. +/// - `cs_mask`: the single `(r ‖ s)` code-switch mask (Construction 9.7), +/// length `ℓ_zk`. Committed AFTER sumcheck so its `r` part can carry the +/// folded source-IRS randomness. +/// +/// Each sub-tree carries its own `mask_proximity` instance (with its own +/// `c_zk_commit` IRS config). Both trees use the same C_zk code rate, so +/// `info()` exposes a single shared list-size to downstream solvers. +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct MaskOracleConfig { + /// Mask-proximity over the `k` sumcheck masks; `num_vectors = 2 · k`. + /// Vector size is `next_pow_2(sumcheck.mask_length)`, much smaller than ℓ_zk. + pub sumcheck_masks: MaskProximityConfig, + /// Mask-proximity over the single cs_mask; `num_vectors = 2`. Vector + /// size is ℓ_zk. + pub cs_mask: MaskProximityConfig, + /// `next_pow2(r + t_ood)` for this round (Lemma 9.3). + pub l_zk: MaskCodeMessageLen, +} +impl MaskOracleConfig { pub fn info(&self) -> MaskOracleInfo { + // Both sub-trees use the same C_zk rate, so either's list size is the + // shared C_zk list size. MaskOracleInfo { - c_zk_list_size: ListSize::new(self.c_zk.list_size()), + c_zk_list_size: ListSize::new(self.cs_mask.c_zk_commit.list_size()), l_zk: self.l_zk, } } + + /// Analytic soundness bits (excluding PoW) for this round's mask oracle. + /// Both trees must hold; the weaker one dominates. + pub fn analytic_bits(&self) -> Bits { + let a = self.sumcheck_masks.analytic_bits(); + let b = self.cs_mask.analytic_bits(); + if f64::from(a) < f64::from(b) { + a + } else { + b + } + } } /// Slim mask-oracle view (C_zk's list size + ℓ_zk). -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct MaskOracleInfo { pub c_zk_list_size: ListSize, pub l_zk: MaskCodeMessageLen, } - -impl MaskOracleConfig { - /// Analytic soundness bits (excluding PoW) for this round's mask oracle. - pub fn analytic_bits(&self) -> Bits { - self.mask_proximity.analytic_bits() - } -} diff --git a/src/protocols/params/spec.rs b/src/protocols/params/spec.rs index 42ff3104..663dd75e 100644 --- a/src/protocols/params/spec.rs +++ b/src/protocols/params/spec.rs @@ -17,7 +17,7 @@ use crate::{bits::Bits, engines::EngineId}; /// - **Planning credit**: subtracted from `target_security_bits` so solvers /// know the analytic floor they must reach. /// - **Validation cap**: rejects any per-slot PoW that exceeds `bits`. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum PowBudget { Forbidden, PerSlot { bits: NonZeroU32 }, @@ -41,8 +41,9 @@ impl PowBudget { } } -/// Phantom-typed newtype. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +/// Phantom-typed newtype — `Tagged` and `Tagged` are distinct types. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(bound = "T: Serialize + for<'a> Deserialize<'a>")] pub struct Tagged(T, PhantomData); impl Tagged { @@ -55,7 +56,7 @@ impl Tagged { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct SecuritySpec { pub mode: Mode, pub decoding_regime: DecodingRegime, @@ -72,7 +73,7 @@ impl SecuritySpec { } /// Per-round folding strategy. `at_round(i)` returns the factor for round `i`. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] pub enum FoldingFactor { /// Same folding factor across all rounds. Constant(usize), @@ -110,11 +111,104 @@ impl FoldingFactor { } /// Proof-size / prover-time / soundness-margin tradeoffs. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct TuningSpec { pub vector_size: usize, pub starting_log_inv_rate: u32, pub folding_factor: FoldingFactor, + /// Per-round inverse-rate schedule (see [`RateSchedule`]). + pub rate_schedule: RateSchedule, +} + +/// Pareto-knee bias for [`RateSchedule::Adaptive`]'s planner. Controls the +/// trade-off between encode (NTT) cost and proof-byte cost when picking +/// among pareto-frontier candidates. +/// +/// `α ∈ [0, 1]`: +/// - `α = 0.5` (default): pure geometric knee, balanced across both axes. +/// - `α > 0.5`: prover-biased — accepts more proof bytes for less NTT. +/// - `α < 0.5`: proof-biased — accepts more NTT for fewer proof bytes. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct KneeWeight(OrderedFloat); + +impl KneeWeight { + /// Pure geometric knee — equal weighting on both axes. + pub const DEFAULT: Self = Self(OrderedFloat(0.5)); + + /// Build a `KneeWeight`, clamping `value` to `[0, 1]`. Out-of-range + /// inputs (including NaN) are silently corrected — the planner has no + /// meaningful behavior outside `[0, 1]`, so reject-by-clamp at the API + /// boundary rather than propagating a `Result`. + pub const fn new(value: f64) -> Self { + let clamped = if value.is_nan() { + 0.5 + } else { + value.clamp(0.0, 1.0) + }; + Self(OrderedFloat(clamped)) + } + + pub const fn get(self) -> f64 { + self.0 .0 + } +} + +impl Default for KneeWeight { + fn default() -> Self { + Self::DEFAULT + } +} + +/// Per-round inverse-rate schedule. +/// +/// The Section 10 code-switching IOPP (Theorem 10.2) commits a fresh codeword +/// each round, so the per-round rate is a free parameter — the only structural +/// constraint is on message lengths (`2^{k_{i+1}} · ℓ_{i+1} ≥ ℓ_i`). +/// +/// Capping the inverse rate trades a small increase in per-round query count +/// (cheap on tiny late-round Merkle trees) for dramatically smaller late-round +/// and basecase NTTs. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub enum RateSchedule { + /// Step `log_inv_rate += folding − 1` per round, unbounded — the + /// canonical legacy in-place WHIR behavior. Late-round and basecase NTTs + /// grow with depth. + Stepping, + /// Step like [`Self::Stepping`], then clamp at `max_log_inv_rate`. Set + /// `max_log_inv_rate == tuning.starting_log_inv_rate` for a constant-rate + /// schedule. + Capped { max_log_inv_rate: u32 }, + /// Per-round rates chosen at `derive` time by a planner that minimises a + /// `(prover_time_proxy, proof_size_proxy)` pareto knee under the security + /// target. No user-supplied budget; the knee is scale-free. The + /// [`KneeWeight`] biases the picker between the two axes. See + /// [`super::derive::plan_adaptive_rates`]. + Adaptive { knee_weight: KneeWeight }, +} + +impl RateSchedule { + /// Compute the next round's inverse rate from the current one and the + /// current round's folding factor. + /// + /// `Adaptive` returns the unbounded step here; the planner runs after the + /// skeleton is built and overwrites every per-round rate. + pub const fn step(self, current_log_inv_rate: u32, folding_factor: u32) -> u32 { + let stepped = current_log_inv_rate.saturating_add(folding_factor.saturating_sub(1)); + match self { + Self::Stepping | Self::Adaptive { .. } => stepped, + Self::Capped { max_log_inv_rate } => { + if stepped < max_log_inv_rate { + stepped + } else { + max_log_inv_rate + } + } + } + } + + pub const fn is_adaptive(self) -> bool { + matches!(self, Self::Adaptive { .. }) + } } /// Per-round context handed to a sub-protocol builder. @@ -126,7 +220,7 @@ pub struct RoundContext { } /// Standard vs. zero-knowledge selection. Orthogonal to [`DecodingRegime`]. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] pub enum Mode { Standard, ZeroKnowledge, @@ -243,7 +337,7 @@ pub type MaskCodeMessageLen = Tagged; pub type LogInvRate = Tagged; /// Reed–Solomon list-decoding ball size `|Λ(C, δ)|`. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct ListSize(OrderedFloat); impl ListSize { @@ -312,4 +406,37 @@ mod tests { Bits::new(0.0), ); } + + #[test] + fn rate_schedule_stepping_unbounded() { + // Stepping adds `folding − 1` per round regardless of magnitude + // (legacy unbounded-stepping behavior). + assert_eq!(RateSchedule::Stepping.step(2, 3), 4); + assert_eq!(RateSchedule::Stepping.step(100, 5), 104); + } + + #[test] + fn rate_schedule_capped_clamps_above_cap() { + let cap = RateSchedule::Capped { + max_log_inv_rate: 5, + }; + assert_eq!(cap.step(2, 3), 4); // below cap → step normally + assert_eq!(cap.step(4, 3), 5); // would step to 6 → clamp to 5 + assert_eq!(cap.step(5, 3), 5); // already at cap → stays + assert_eq!(cap.step(10, 3), 5); // above cap → snaps back to cap + } + + #[test] + fn rate_schedule_folding_factor_one_never_steps() { + // folding == 1 means rate step is `1 − 1 = 0`. The IRS layer disallows + // folding < 1, but the math here should still be consistent. + assert_eq!(RateSchedule::Stepping.step(2, 1), 2); + assert_eq!( + RateSchedule::Capped { + max_log_inv_rate: 5 + } + .step(2, 1), + 2, + ); + } } diff --git a/src/protocols/params/sumcheck.rs b/src/protocols/params/sumcheck.rs index 68b263bf..ce795a6a 100644 --- a/src/protocols/params/sumcheck.rs +++ b/src/protocols/params/sumcheck.rs @@ -30,7 +30,7 @@ pub fn solve( SolveMode::ZeroKnowledge(mask_oracle) => ( Some(mask_oracle), sumcheck::SumcheckMode::ZeroKnowledge { - mask_length: zk_mask_length(), + mask_length: SumcheckMaskLen::new(zk_mask_length()), }, ), }; @@ -79,8 +79,8 @@ const fn num_sumcheck_rounds(ctx: &RoundContext) -> usize { /// Construction 6.3 step 4(a) sends `h_j ∈ F^{ SumcheckMaskLen { - SumcheckMaskLen::new(3) +pub const fn zk_mask_length() -> usize { + 3 } #[cfg(test)] @@ -105,7 +105,11 @@ mod tests { irs_params::solve(spec, ctx, OodSampleBudget::ZERO) } - const FIXTURE_LOG_VECTOR_SIZE: u32 = 4; + /// Smallest pow2 shape that still produces a non-degenerate IRS. Increased + /// from 2^4 to 2^8 so message_length is large enough that the Lemma 9.5 mask + /// (now sized to actual query count, not pow2-padded) doesn't shift the rate + /// enough to push the PoW gap past the 60-bit cap. + const FIXTURE_LOG_VECTOR_SIZE: u32 = 8; const FIXTURE_LOG_INV_RATE: u32 = 1; const FIXTURE_FOLDING_FACTOR: u32 = 2; diff --git a/src/protocols/sumcheck.rs b/src/protocols/sumcheck.rs index b6772baa..d149ff42 100644 --- a/src/protocols/sumcheck.rs +++ b/src/protocols/sumcheck.rs @@ -10,7 +10,6 @@ use tracing::instrument; use crate::{ algebra::{ - dot, sumcheck::{compute_sumcheck_polynomial, fold, fold_and_compute_polynomial}, univariate_evaluate, }, @@ -157,7 +156,6 @@ impl Config { ); assert_eq!(a.len(), self.initial_size); assert_eq!(b.len(), self.initial_size); - debug_assert_eq!(dot(a, b), *sum); assert_eq!(masks.len(), self.num_rounds * self.mask_length()); let half = F::from(2).inverse().unwrap(); @@ -357,6 +355,7 @@ mod tests { use super::*; use crate::{ algebra::{ + dot, fields::{self, Field64}, multilinear_extend, random_vector, }, diff --git a/src/protocols/whir_zk/mod.rs b/src/protocols/whir_zk/mod.rs index 18dacc42..32093048 100644 --- a/src/protocols/whir_zk/mod.rs +++ b/src/protocols/whir_zk/mod.rs @@ -536,6 +536,86 @@ mod tests { } } + /// Mirror of `zook::tests::roundtrip_2_pow_20_three_claims_zk`: 1 polynomial of + /// size 2^20 with 3 multilinear-extension claims on BN254, target_security 128 + /// with a 10-bit PoW budget. Used to compare prove/verify wall-clock against the + /// zook orchestrator under the same workload shape. + /// + /// Run with: `cargo test --release --features tracing,rs_in_order --lib \ + /// protocols::whir_zk::tests::roundtrip_2_pow_20_three_claims_whir_zk -- --nocapture` + #[test] + fn roundtrip_2_pow_20_three_claims_whir_zk() { + use crate::algebra::fields::Field256; + type FB = Field256; + + const NV: usize = 19; + const NC: usize = 1 << NV; + + crate::tests::init(); + let mut rng = ark_std::test_rng(); + + let whir_params = ProtocolParameters { + decoding_regime: DecodingRegime::Johnson, + security_level: 128, + pow_bits: 10, + initial_folding_factor: 3, + folding_factor: 3, + starting_log_inv_rate: 2, + batch_size: 1, + hash_id: hash::SHA2, + }; + let params = Config::::new(1 << NV, &whir_params, 1); + + let vector: Vec = random_vector(&mut rng, NC); + let f0 = MultilinearExtension { + point: random_vector::(&mut rng, NV), + }; + let f1 = MultilinearExtension { + point: random_vector::(&mut rng, NV), + }; + let f2 = MultilinearExtension { + point: random_vector::(&mut rng, NV), + }; + let embedding = params.blinded_commitment.embedding(); + let evaluations = vec![ + f0.evaluate(embedding, &vector), + f1.evaluate(embedding, &vector), + f2.evaluate(embedding, &vector), + ]; + + let forms: Vec>> = vec![Box::new(f0), Box::new(f1), Box::new(f2)]; + let refs: Vec<&dyn LinearForm> = forms.iter().map(|w| w.as_ref()).collect(); + let prove_forms: Vec>> = forms + .iter() + .map(|f| { + let mut cv = vec![FB::ZERO; params.blinded_commitment.initial_size()]; + f.accumulate(&mut cv, FB::ONE); + Box::new(Covector { vector: cv }) as Box> + }) + .collect(); + + let ds = DomainSeparator::protocol(¶ms) + .session(&format!("whir-zk-bench-2^20 {}:{}", file!(), line!())) + .instance(&Empty); + let mut prover_state = ProverState::new_std(&ds); + let witness = params.commit(&mut prover_state, &[&vector[..]]); + let _ = params.prove( + &mut prover_state, + vec![Cow::Borrowed(&vector[..])], + witness, + prove_forms, + Cow::Borrowed(&evaluations), + ); + let proof = prover_state.proof(); + let mut verifier_state = VerifierState::new_std(&ds, &proof); + let commitment = params + .receive_commitments(&mut verifier_state, 1) + .expect("receive_commitments"); + params + .verify(&mut verifier_state, &refs, &evaluations, &commitment) + .expect("verify"); + } + /// Soundness exploit: malicious prover generates proof for WRONG evaluation. /// A sound PCS must reject; if verify() returns Ok, g_eval freedom lets /// the prover forge arbitrary evaluation claims. diff --git a/src/protocols/zook/commit.rs b/src/protocols/zook/commit.rs new file mode 100644 index 00000000..99a5fce8 --- /dev/null +++ b/src/protocols/zook/commit.rs @@ -0,0 +1,239 @@ +//! Initial witness commitment for zook. +//! +//! Goes through `rounds[0].code_switch.source` when the plan has rounds; +//! through `basecase.commit` when it doesn't. + +use ark_std::rand::{distributions::Standard, prelude::Distribution, CryptoRng, RngCore}; +#[cfg(feature = "tracing")] +use tracing::instrument; + +use crate::{ + algebra::{embedding::Embedding, lift}, + hash::Hash, + protocols::{ + irs_commit::{Commitment as IrsCommitment, Witness as IrsWitness}, + params::protocol_config::ProtocolConfig, + }, + transcript::{ + Codec, DuplexSpongeInterface, ProverMessage, ProverState, VerificationResult, VerifierState, + }, +}; + +/// Prover handle from [`ProtocolConfig::commit`]; consumed by `prove`. +#[must_use] +#[derive(Clone, Debug)] +pub struct CommittedWitness { + pub(crate) state: CommittedState, +} + +/// Internal: the two branches differ in their IRS witness field type. +#[derive(Clone, Debug)] +pub(crate) enum CommittedState { + /// Plan has ≥ 1 round; committed through `rounds[0].code_switch.source`. + Round { + message: Vec, + irs_witness: IrsWitness, + }, + /// Basecase-only plan; witness was lifted into `M::Target` first. + Basecase { + message: Vec, + irs_witness: IrsWitness, + }, +} + +/// Verifier handle from [`ProtocolConfig::receive_commitment`]; consumed by `verify`. +#[must_use] +#[derive(Clone, Debug)] +pub struct Commitment { + pub(crate) irs_commitment: IrsCommitment, +} + +impl ProtocolConfig { + /// Commit the initial witness to the protocol's first IRS codeword. + #[cfg_attr(feature = "tracing", instrument(skip_all, name = "zook::commit", fields(vector_size = self.tuning.vector_size, num_rounds = self.rounds.len())))] + pub fn commit( + &self, + ps: &mut ProverState, + witness: &[M::Source], + ) -> CommittedWitness + where + Standard: Distribution + Distribution, + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + M::Target: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + assert_eq!( + witness.len(), + self.tuning.vector_size, + "zook witness length", + ); + + let state = if let Some(round) = self.rounds.first() { + let irs_witness = round.code_switch.source.commit(ps, &[witness]); + let message = lift(round.code_switch.source.embedding(), witness); + CommittedState::Round { + message, + irs_witness, + } + } else { + // Basecase IRS is over `M::Target`; lift before committing. + let embedding = M::default(); + let message = lift(&embedding, witness); + let irs_witness = self.basecase.commit.commit(ps, &[&message]); + CommittedState::Basecase { + message, + irs_witness, + } + }; + CommittedWitness { state } + } + + /// Verifier mirror of [`Self::commit`]. + #[cfg_attr(feature = "tracing", instrument(skip_all, name = "zook::receive_commitment", fields(vector_size = self.tuning.vector_size, num_rounds = self.rounds.len())))] + pub fn receive_commitment(&self, vs: &mut VerifierState) -> VerificationResult + where + H: DuplexSpongeInterface, + M::Target: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + let irs_commitment = match self.rounds.first() { + Some(round) => round.code_switch.source.receive_commitment(vs)?, + None => self.basecase.commit.receive_commitment(vs)?, + }; + Ok(Commitment { irs_commitment }) + } +} + +#[cfg(test)] +mod tests { + use ark_std::rand::{rngs::StdRng, SeedableRng}; + + use super::*; + use crate::{ + algebra::random_vector, + hash, + protocols::params::{ + spec::{ + DecodingRegime, FoldingFactor, Mode, PowBudget, RateSchedule, SecuritySpec, + TuningSpec, + }, + test_utils::TestEmbedding, + }, + transcript::{codecs::Empty, DomainSeparator}, + }; + + type F = ::Source; + + /// Keep PoW below the 60-bit cap during `derive` for small test tunings. + const TEST_TARGET_BITS: u32 = 40; + + fn test_spec(mode: Mode) -> SecuritySpec { + SecuritySpec { + mode, + decoding_regime: DecodingRegime::Johnson, + target_security_bits: TEST_TARGET_BITS, + pow_budget: PowBudget::per_slot(10), + hash_id: hash::BLAKE3, + } + } + + /// Vector size large enough for ≥ 1 round under folding_factor 2. + fn tuning_with_rounds() -> TuningSpec { + TuningSpec { + vector_size: 1 << 8, + starting_log_inv_rate: 1, + folding_factor: FoldingFactor::Constant(2), + rate_schedule: RateSchedule::Stepping, + } + } + + /// Vector size below `2 · folding_factor`; `round_layout` admits 0 rounds. + fn tuning_basecase_only() -> TuningSpec { + TuningSpec { + vector_size: 1 << 3, + starting_log_inv_rate: 1, + folding_factor: FoldingFactor::Constant(2), + rate_schedule: RateSchedule::Stepping, + } + } + + fn roundtrip( + config: &ProtocolConfig, + seed: u64, + ) -> CommittedWitness { + let mut rng = StdRng::seed_from_u64(seed); + let witness = random_vector::(&mut rng, config.tuning.vector_size); + + let ds = DomainSeparator::protocol(&"zook-commit-test") + .session(&format!("commit roundtrip {}:{}", file!(), line!())) + .instance(&Empty); + + let mut prover_state = ProverState::new_std(&ds); + let committed = config.commit(&mut prover_state, &witness); + let proof = prover_state.proof(); + + let mut verifier_state = VerifierState::new_std(&ds, &proof); + let _ = config + .receive_commitment(&mut verifier_state) + .expect("receive_commitment"); + verifier_state + .check_eof() + .expect("transcript fully consumed"); + + committed + } + + #[test] + fn commit_receive_roundtrip_with_rounds_zk() { + let config = ProtocolConfig::::derive( + test_spec(Mode::ZeroKnowledge), + tuning_with_rounds(), + ) + .unwrap(); + assert!(!config.rounds.is_empty()); + let committed = roundtrip(&config, 0); + assert!(matches!(committed.state, CommittedState::Round { .. })); + } + + #[test] + fn commit_receive_roundtrip_with_rounds_standard() { + let config = ProtocolConfig::::derive( + test_spec(Mode::Standard), + tuning_with_rounds(), + ) + .unwrap(); + let committed = roundtrip(&config, 1); + assert!(matches!(committed.state, CommittedState::Round { .. })); + } + + #[test] + fn commit_receive_roundtrip_basecase_only() { + let config = ProtocolConfig::::derive( + test_spec(Mode::ZeroKnowledge), + tuning_basecase_only(), + ) + .unwrap(); + assert!(config.rounds.is_empty()); + let committed = roundtrip(&config, 2); + assert!(matches!(committed.state, CommittedState::Basecase { .. })); + } + + #[test] + #[should_panic(expected = "zook witness length")] + fn commit_rejects_wrong_witness_size() { + let config = ProtocolConfig::::derive( + test_spec(Mode::ZeroKnowledge), + tuning_with_rounds(), + ) + .unwrap(); + let mut rng = StdRng::seed_from_u64(3); + let too_short = random_vector::(&mut rng, config.tuning.vector_size - 1); + + let ds = DomainSeparator::protocol(&"zook-commit-test") + .session(&format!("wrong size {}:{}", file!(), line!())) + .instance(&Empty); + let mut prover_state = ProverState::new_std(&ds); + let _ = config.commit(&mut prover_state, &too_short); + } +} diff --git a/src/protocols/zook/mod.rs b/src/protocols/zook/mod.rs new file mode 100644 index 00000000..cecc100b --- /dev/null +++ b/src/protocols/zook/mod.rs @@ -0,0 +1,120 @@ +//! Zook ZK protocol — Construction 9.7. + +pub mod commit; +pub mod prover; +pub mod verifier; + +pub use commit::{Commitment, CommittedWitness}; + +pub use crate::protocols::params::protocol_config::ProtocolConfig; + +#[cfg(test)] +mod tests { + use ark_std::rand::{rngs::StdRng, SeedableRng}; + + use super::ProtocolConfig; + use crate::{ + algebra::{ + embedding::Identity, + fields::Field256, + linear_form::{Evaluate, LinearForm, MultilinearExtension}, + random_vector, + }, + hash, + protocols::params::spec::{ + DecodingRegime, FoldingFactor, Mode, PowBudget, RateSchedule, SecuritySpec, TuningSpec, + }, + transcript::{codecs::Empty, DomainSeparator, ProverState, VerifierState}, + }; + + type F = Field256; + type Embed = Identity; + + /// Keep PoW below the 60-bit cap during `derive` for the test tuning. + const TEST_TARGET_BITS: u32 = 128; + + fn test_spec(mode: Mode) -> SecuritySpec { + SecuritySpec { + mode, + decoding_regime: DecodingRegime::Johnson, + target_security_bits: TEST_TARGET_BITS, + pow_budget: PowBudget::per_slot(10), + hash_id: hash::SHA2, + } + } + + /// 2^20-sized witness; folding_factor 4 ⇒ 5 rounds (keeps round count + /// modest at this size). + fn tuning_2_pow_20() -> TuningSpec { + TuningSpec { + vector_size: 1 << 19, + starting_log_inv_rate: 2, + folding_factor: FoldingFactor::ConstantFromSecondRound { + initial: 3, + rest: 3, + }, + rate_schedule: RateSchedule::Stepping, + } + } + + /// Commit → prove → verify roundtrip on a 2^20 witness with 3 multilinear + /// claims. Shared by the ZK and Standard tests below. + fn roundtrip_three_claims(mode: Mode, seed: u64) { + crate::tests::init(); + let config = ProtocolConfig::::derive(test_spec(mode), tuning_2_pow_20()).unwrap(); + + let mut rng = StdRng::seed_from_u64(seed); + let witness: Vec = random_vector(&mut rng, config.tuning.vector_size); + let mu = config.tuning.vector_size.trailing_zeros() as usize; + + // 3 random multilinear claims and their true evaluations against the + // witness. + let embedding = ::default(); + let forms: Vec> = (0..3) + .map(|_| MultilinearExtension { + point: random_vector::(&mut rng, mu), + }) + .collect(); + let values: Vec = forms + .iter() + .map(|f| f.evaluate(&embedding, &witness)) + .collect(); + let form_refs: Vec<&dyn LinearForm> = + forms.iter().map(|f| f as &dyn LinearForm).collect(); + + let ds = DomainSeparator::protocol(&"zook-mod-test") + .session(&format!( + "three-claims 2^20 mode={:?} {}:{}", + mode, + file!(), + line!() + )) + .instance(&Empty); + + let mut prover_state = ProverState::new_std(&ds); + let committed = config.commit(&mut prover_state, &witness); + config.prove(&mut prover_state, committed, &form_refs, &values); + let proof = prover_state.proof(); + + let mut verifier_state = VerifierState::new_std(&ds, &proof); + let commitment = config + .receive_commitment(&mut verifier_state) + .expect("receive_commitment"); + config + .verify(&mut verifier_state, commitment, &form_refs, &values) + .expect("verify"); + verifier_state + .check_eof() + .expect("transcript fully consumed"); + } + + #[test] + fn roundtrip_2_pow_20_three_claims_zk() { + roundtrip_three_claims(Mode::ZeroKnowledge, 0); + } + + #[test] + fn roundtrip_2_pow_20_three_claims_standard() { + roundtrip_three_claims(Mode::Standard, 1); + } +} diff --git a/src/protocols/zook/prover.rs b/src/protocols/zook/prover.rs new file mode 100644 index 00000000..bc6a6eb5 --- /dev/null +++ b/src/protocols/zook/prover.rs @@ -0,0 +1,586 @@ +//! Zook prover — Construction 9.7 (ZK Code-Switching). +//! +//! Per ZK round, the orchestrator wraps three sub-protocols: +//! - **ZK sumcheck** (Lemma 6.5): reduces ` = sum`, +//! consuming `k` short (degree-2) mask polynomials for hiding. +//! - **Code-switch** (Construction 9.7): reduces source IRS to a target +//! IRS, consuming one `cs_mask = (r ‖ s)` mask oracle of length ℓ_zk. +//! - **Mask proximity** (Construction 7.2): proves the committed masks are +//! close to C_zk codewords. +//! +//! The mask oracle is split across **two** C_zk trees: +//! - `sumcheck_masks` tree: k masks at vector_size = next_pow_2(mask_length) +//! (= 4 for degree-2 round polys). Committed BEFORE sumcheck — standard +//! ZK-sumcheck discipline. +//! - `cs_mask` tree: 1 mask at vector_size = ℓ_zk. Committed AFTER sumcheck +//! so cs_mask's `r` part can carry `fold(source.masks, folding_randomness)` +//! (Construction 9.7 `(r ‖ s)` structure). +//! +//! Splitting avoids padding the tiny sumcheck masks up to ℓ_zk; the NTT and +//! Merkle work at C_zk's rate drops by ~4× in typical configurations. +//! +//! Per-round flow: +//! 1. Sample sumcheck masks `M_0..M_{k−1}` (length `mask_length` each) and +//! `s_fresh` of length `ℓ_zk − source.mask_length()`. +//! 2. Commit `sumcheck_masks` tree (k masks padded to next_pow_2(mask_length)). +//! 3. Run ZK sumcheck. Post-sumcheck `state.sum = δ + γ_sumcheck · dot`, +//! where `δ = Σ M_i(r_i)`. +//! 4. Derive `r_folded = fold(lift(source.masks), r_0..r_{k−1})`, assemble +//! `cs_mask = (r_folded ‖ s_fresh)`, commit `cs_mask` tree. +//! 5. Send `δ` cleartext and reconcile +//! `state.sum := (sum − δ) · γ_sumcheck⁻¹` so the claim entering +//! code-switch is the unmasked `dot(folded_a, post_sumcheck_cov)`. +//! 6. Code-switch updates `state.{message, irs_witness, covector, sum}`. +//! 7. Prove sumcheck masks at `[1, r_i, …, r_i^{sumcheck_vec_size−1}]` +//! (gives X_i = M_i(r_i)); prove cs_mask at the post-cs covector mask +//! region (gives X_cs). The verifier checks `Σ X_{i ProtocolConfig> { + /// Prove `f(witness) == values[j]` for every form `f = forms[j]` against + /// the committed witness. Consumes `committed`. + #[cfg_attr(feature = "tracing", instrument(skip_all, name = "zook::prove", fields(vector_size = self.tuning.vector_size, num_rounds = self.rounds.len(), num_claims = forms.len())))] + pub fn prove( + &self, + ps: &mut ProverState, + committed: CommittedWitness>, + forms: &[&dyn LinearForm], + values: &[F], + ) where + Standard: Distribution, + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + F: Codec<[H::U]>, + u8: Decoding<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + assert_eq!(forms.len(), values.len(), "forms.len() != values.len()"); + assert!(!forms.is_empty(), "zook requires ≥ 1 (form, value) pair"); + + // RLC challenge binds the form/value set to the commitment. + let gamma: F = ps.verifier_message(); + let gamma_powers = geometric_sequence(gamma, forms.len()); + + // Materialize the combined covector = Σ γ^j · form_j and combined value. + let mut covector = vec![F::ZERO; self.tuning.vector_size]; + for (form, &g) in forms.iter().zip(&gamma_powers) { + form.accumulate(&mut covector, g); + } + let combined_value: F = values.iter().zip(&gamma_powers).map(|(v, g)| *v * g).sum(); + + // Reduce to basecase inputs `(message, witness, covector, sum)`. The + // two arms differ only in how those are obtained. + let (message, basecase_witness, covector, sum) = match committed.state { + // Basecase-only plan: use the committed witness directly. + CommittedState::Basecase { + message, + irs_witness, + } => (message, irs_witness, covector, combined_value), + CommittedState::Round { + message, + irs_witness, + } => { + let mut state = ProverRoundState { + message, + irs_witness, + covector, + sum: combined_value, + }; + for round in &self.rounds { + state = prove_round(round, state, ps); + } + // After per-round reconciliation, state.sum is bound to + // dot(state.message, state.covector) — no extra transcript + // send needed entering basecase. + (state.message, state.irs_witness, state.covector, state.sum) + } + }; + + let _ = self + .basecase + .prove(ps, message, &basecase_witness, covector, sum); + } +} + +/// Per-round transient state. `irs_witness` is `IrsWitness` throughout +/// because we restrict to `Identity` embeddings (M::Source = M::Target = F). +struct ProverRoundState { + message: Vec, + irs_witness: IrsWitness, + covector: Vec, + sum: F, +} + +#[cfg_attr(feature = "tracing", instrument(skip_all, name = "zook::prove_round", fields(msg_len = round.code_switch.source.message_length(), message_len = state.message.len())))] +fn prove_round( + round: &RoundConfig>, + mut state: ProverRoundState, + ps: &mut ProverState, +) -> ProverRoundState +where + F: Field + Default + Zeroize + Codec<[H::U]>, + Standard: Distribution, + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + u8: Decoding<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, +{ + let msg_len = round.code_switch.source.message_length(); + + // (1+2) Sample sumcheck masks (and pre-sample cs_mask's s_fresh) and + // pre-commit the sumcheck-masks tree. + let mut sumcheck_masks_state = round + .mask_oracle + .as_ref() + .map(|mo| SumcheckMasks::sample_and_commit(mo, round, ps)); + let sumcheck_masks_flat: &[F] = sumcheck_masks_state + .as_ref() + .map_or(&[][..], SumcheckMasks::flat); + + // (3) ZK sumcheck. Post-sumcheck `state.sum = δ + γ_sumcheck · dot`. + let opening = round.sumcheck.prove( + ps, + &mut state.message, + &mut state.covector, + &mut state.sum, + sumcheck_masks_flat, + ); + + // (4+5) Build cs_mask = (r_folded ‖ s_fresh), commit the cs_mask tree, + // send δ and reconcile state.sum to the unmasked dot. + let cs_mask_state = sumcheck_masks_state.as_mut().map(|sm| { + CsMask::commit_and_reconcile(sm, round, &state.irs_witness, &opening, &mut state.sum, ps) + }); + + // (6) Extend covector and run code-switch. Hand the source IRS witness over + // by value so code_switch can drop the source matrix immediately after + // `source.open` — frees ~268 MB in round 1 before cs_mask::prove_at. + if let Some(cm) = cs_mask_state.as_ref() { + state.covector.resize(msg_len + cm.l_zk, F::ZERO); + } + let cs_witness = round.code_switch.prove( + ps, + state.message, + std::mem::take(&mut state.irs_witness), + code_switch::Claim { + covector: &mut state.covector, + sum: &mut state.sum, + }, + &opening.round_challenges, + cs_mask_state.as_ref().map_or(&[][..], CsMask::coefficients), + ); + + // (7) Prove sumcheck masks at the round challenges, prove cs_mask at the + // post-cs covector, then subtract X_cs. + if let (Some(sm), Some(cm)) = (sumcheck_masks_state, cs_mask_state) { + sm.prove_at(&opening.round_challenges, ps); + let x_cs = cm.prove_at(&state.covector[msg_len..], ps); + state.sum -= x_cs; + } + // opening's round_challenges/mask_rlc are fully consumed; release before + // the next round's allocations begin. + drop(opening); + + state.message = cs_witness.message; + state.irs_witness = cs_witness.target_witness; + state.covector.truncate(state.message.len()); + + state +} + +/// Sumcheck-masks tree: k masks committed BEFORE sumcheck so each round-poly +/// is bound to its mask. Vector size is `next_pow_2(sumcheck.mask_length)` — +/// no padding to ℓ_zk. +/// +/// Also carries the `s_fresh` padding for cs_mask — sampled pre-sumcheck +/// (doesn't depend on folding randomness) and consumed by [`CsMask`]. +struct SumcheckMasks<'a, F: Field> { + mask_oracle: &'a MaskOracleConfig, + /// k padded sumcheck masks, each of length `vec_size`. `originals[i][0..mask_len]` + /// is the raw mask polynomial; the rest is zero padding for NTT. + originals: Vec>, + /// k * mask_len concatenated raw mask coefficients (sumcheck.prove input). + flat: Vec, + witness: mask_proximity::Witness, + /// Fresh padding for cs_mask. Length `ℓ_zk − source.mask_length()`. Sampled + /// here so the RNG draws happen during pre-sumcheck setup; consumed by + /// `CsMask::commit_and_reconcile`. + s_fresh: Vec, + /// Vector size of the sumcheck-masks tree (= `next_pow_2(mask_len)`). + vec_size: usize, + k: usize, +} + +impl<'a, F: Field + Zeroize> SumcheckMasks<'a, F> { + fn sample_and_commit( + mask_oracle: &'a MaskOracleConfig, + round: &RoundConfig>, + ps: &mut ProverState, + ) -> Self + where + F: Codec<[H::U]>, + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + Standard: Distribution, + Hash: ProverMessage<[H::U]>, + { + let k = round.sumcheck.num_rounds; + let mask_len = match round.sumcheck.mode { + SumcheckMode::Standard => 0, + SumcheckMode::ZeroKnowledge { mask_length } => mask_length.get(), + }; + let vec_size = mask_oracle.sumcheck_masks.c_zk_commit.vector_size; + debug_assert!(vec_size >= mask_len); + + let l_zk = mask_oracle.l_zk.get(); + let source_mask_len = round.code_switch.source.mask_length(); + assert!(source_mask_len <= l_zk); + + let flat: Vec = random_vector(ps.rng(), k * mask_len); + let originals: Vec> = (0..k) + .map(|i| { + let mut padded = vec![F::ZERO; vec_size]; + padded[..mask_len].copy_from_slice(&flat[i * mask_len..(i + 1) * mask_len]); + padded + }) + .collect(); + // s_fresh doesn't depend on folding randomness — sample it here. + let s_fresh: Vec = random_vector(ps.rng(), l_zk - source_mask_len); + let originals_refs: Vec<&[F]> = originals.iter().map(Vec::as_slice).collect(); + let witness = mask_oracle.sumcheck_masks.commit(ps, &originals_refs); + Self { + mask_oracle, + originals, + flat, + witness, + s_fresh, + vec_size, + k, + } + } + + fn flat(&self) -> &[F] { + &self.flat + } + + /// δ = Σ_i M_i(r_i). The orchestrator sends δ cleartext + reconciles + /// `state.sum` before code-switch. + fn delta(&self, round_challenges: &[F]) -> F { + round_challenges + .iter() + .zip(self.originals.iter()) + .map(|(&r, mask_padded)| univariate_evaluate(mask_padded, r)) + .sum() + } + + /// Prove each sumcheck mask at `cov_i = [1, r_i, …, r_i^{vec_size−1}]`, + /// emitting `X_i = M_i(r_i)` to the verifier. Soundness: the verifier + /// checks `Σ X_i == δ` against the cleartext δ sent earlier. + fn prove_at(mut self, round_challenges: &[F], ps: &mut ProverState) + where + F: Codec<[H::U]>, + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + Standard: Distribution, + u8: Decoding<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + debug_assert_eq!(round_challenges.len(), self.k); + let covectors: Vec> = round_challenges + .iter() + .map(|&r| geometric_sequence(r, self.vec_size)) + .collect(); + let covector_refs: Vec<&[F]> = covectors.iter().map(|v| v.as_slice()).collect(); + let originals_refs: Vec<&[F]> = self.originals.iter().map(Vec::as_slice).collect(); + self.mask_oracle.sumcheck_masks.prove( + ps, + self.witness, + &originals_refs, + Some(&covector_refs), + ); + // Wipe secret mask material before this struct drops, so it doesn't + // linger in the freed heap region. + drop(originals_refs); + self.flat.zeroize(); + for o in &mut self.originals { + o.zeroize(); + } + } +} + +/// cs_mask tree: 1 mask of length ℓ_zk, committed AFTER sumcheck so the `r` +/// part can carry `fold(source.masks, folding_randomness)`. +struct CsMask<'a, F: Field> { + mask_oracle: &'a MaskOracleConfig, + /// The mask polynomial coefficients, structured as `(r_folded ‖ s_fresh)`. + coefficients: Vec, + witness: mask_proximity::Witness, + l_zk: usize, +} + +impl<'a, F: Field + Zeroize> CsMask<'a, F> { + /// Build `cs_mask = (r_folded ‖ s_fresh)` (s_fresh was pre-sampled in + /// [`SumcheckMasks::sample_and_commit`] and is consumed via `mem::take` + /// here), commit the cs_mask tree, send δ cleartext, and reconcile + /// `*sum := (sum − δ) · γ_sumcheck⁻¹`. + fn commit_and_reconcile( + sumcheck_masks: &mut SumcheckMasks<'a, F>, + round: &RoundConfig>, + irs_witness: &IrsWitness, + opening: &SumcheckOpening, + sum: &mut F, + ps: &mut ProverState, + ) -> Self + where + F: Codec<[H::U]>, + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + Standard: Distribution, + Hash: ProverMessage<[H::U]>, + { + let mask_oracle = sumcheck_masks.mask_oracle; + let l_zk = mask_oracle.l_zk.get(); + let source_mask_len = round.code_switch.source.mask_length(); + debug_assert!(source_mask_len <= l_zk); + debug_assert_eq!(sumcheck_masks.s_fresh.len(), l_zk - source_mask_len); + + // r_folded = collapse of source IRS randomness by sumcheck challenges. + // Masks live in whir's canonical per-poly contiguous layout. + let raw = lift(round.code_switch.source.embedding(), &irs_witness.masks); + let r_folded = fold_chunks(&raw, source_mask_len, &opening.round_challenges); + debug_assert_eq!(r_folded.len(), source_mask_len); + + // Pre-size to l_zk so the `append` below doesn't reallocate. + let mut cs_mask = Vec::with_capacity(l_zk); + cs_mask.extend_from_slice(&r_folded); + // s_fresh is moved out — only consumed here, so no clone needed. + cs_mask.append(&mut sumcheck_masks.s_fresh); + debug_assert_eq!(cs_mask.len(), l_zk); + + let witness = mask_oracle.cs_mask.commit(ps, &[&cs_mask[..]]); + + // δ = Σ M_i(r_i) from the sumcheck-masks tree. Sent cleartext; bound + // by the verifier's `Σ X_{i &[F] { + &self.coefficients + } + + /// Prove cs_mask at the post-cs covector mask region. Returns + /// `X_cs = ` for the f-only projection. + fn prove_at(mut self, cs_mask_covector: &[F], ps: &mut ProverState) -> F + where + F: Codec<[H::U]>, + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + Standard: Distribution, + u8: Decoding<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + let covector_refs: [&[F]; 1] = [cs_mask_covector]; + let x_cs = dot(&self.coefficients, cs_mask_covector); + self.mask_oracle.cs_mask.prove( + ps, + self.witness, + &[&self.coefficients[..]], + Some(&covector_refs), + ); + // cs_mask is uniformly random secret randomness; wipe before drop. + self.coefficients.zeroize(); + x_cs + } +} + +#[cfg(test)] +mod tests { + use ark_std::rand::{rngs::StdRng, SeedableRng}; + + use super::*; + use crate::{ + algebra::{ + fields::Field64, + linear_form::{Evaluate, MultilinearExtension}, + }, + hash, + protocols::params::{ + spec::{ + DecodingRegime, FoldingFactor, Mode, PowBudget, RateSchedule, SecuritySpec, + TuningSpec, + }, + test_utils::TestEmbedding, + }, + transcript::{codecs::Empty, DomainSeparator}, + }; + + type F = Field64; + + const TEST_TARGET_BITS: u32 = 40; + + fn test_spec(mode: Mode) -> SecuritySpec { + SecuritySpec { + mode, + decoding_regime: DecodingRegime::Johnson, + target_security_bits: TEST_TARGET_BITS, + pow_budget: PowBudget::per_slot(10), + hash_id: hash::BLAKE3, + } + } + + fn tuning_with_rounds() -> TuningSpec { + TuningSpec { + vector_size: 1 << 8, + starting_log_inv_rate: 1, + folding_factor: FoldingFactor::Constant(2), + rate_schedule: RateSchedule::Stepping, + } + } + + fn tuning_basecase_only() -> TuningSpec { + TuningSpec { + vector_size: 1 << 3, + starting_log_inv_rate: 1, + folding_factor: FoldingFactor::Constant(2), + rate_schedule: RateSchedule::Stepping, + } + } + + /// Domain separator + ProverState setup shared by every test. + macro_rules! test_session { + ($tag:expr) => {{ + let ds = DomainSeparator::protocol(&"zook-prover-test") + .session(&format!("{} {}:{}", $tag, file!(), line!())) + .instance(&Empty); + ProverState::new_std(&ds) + }}; + } + + fn run_prove(config: &ProtocolConfig, seed: u64) { + let mut rng = StdRng::seed_from_u64(seed); + let witness: Vec = random_vector(&mut rng, config.tuning.vector_size); + let mu = config.tuning.vector_size.trailing_zeros() as usize; + let form = MultilinearExtension { + point: random_vector::(&mut rng, mu), + }; + let embedding = as Default>::default(); + let value = form.evaluate(&embedding, &witness); + + let mut prover_state = test_session!("smoke"); + let committed = config.commit(&mut prover_state, &witness); + config.prove(&mut prover_state, committed, &[&form], &[value]); + let _ = prover_state.proof(); + } + + fn smoke(mode: Mode, tuning: TuningSpec, seed: u64, expect_rounds: bool) { + let config = ProtocolConfig::::derive(test_spec(mode), tuning).unwrap(); + assert_eq!(!config.rounds.is_empty(), expect_rounds); + run_prove(&config, seed); + } + + #[test] + fn prove_completes_with_rounds_zk() { + smoke(Mode::ZeroKnowledge, tuning_with_rounds(), 0, true); + } + + #[test] + fn prove_completes_with_rounds_standard() { + smoke(Mode::Standard, tuning_with_rounds(), 1, true); + } + + #[test] + fn prove_completes_basecase_only_zk() { + smoke(Mode::ZeroKnowledge, tuning_basecase_only(), 2, false); + } + + #[test] + fn prove_completes_basecase_only_standard() { + smoke(Mode::Standard, tuning_basecase_only(), 3, false); + } + + #[test] + #[should_panic(expected = "forms.len() != values.len()")] + fn prove_rejects_count_mismatch() { + let config = ProtocolConfig::::derive( + test_spec(Mode::Standard), + tuning_with_rounds(), + ) + .unwrap(); + let mut rng = StdRng::seed_from_u64(4); + let witness: Vec = random_vector(&mut rng, config.tuning.vector_size); + let mu = config.tuning.vector_size.trailing_zeros() as usize; + let form = MultilinearExtension { + point: random_vector::(&mut rng, mu), + }; + let mut prover_state = test_session!("count-mismatch"); + let committed = config.commit(&mut prover_state, &witness); + // 1 form, 2 values → mismatch. + config.prove(&mut prover_state, committed, &[&form], &[F::ONE, F::ONE]); + } + + #[test] + #[should_panic(expected = "zook requires ≥ 1")] + fn prove_rejects_empty() { + let config = ProtocolConfig::::derive( + test_spec(Mode::Standard), + tuning_with_rounds(), + ) + .unwrap(); + let mut rng = StdRng::seed_from_u64(5); + let witness: Vec = random_vector(&mut rng, config.tuning.vector_size); + let mut prover_state = test_session!("empty"); + let committed = config.commit(&mut prover_state, &witness); + config.prove(&mut prover_state, committed, &[], &[]); + } +} diff --git a/src/protocols/zook/verifier.rs b/src/protocols/zook/verifier.rs new file mode 100644 index 00000000..4fa37326 --- /dev/null +++ b/src/protocols/zook/verifier.rs @@ -0,0 +1,317 @@ +//! Zook verifier — mirror of [`super::prover`] (Construction 9.7). +//! +//! Per ZK round, the verifier: +//! 1. Receives the **sumcheck-masks tree** commitment (pre-sumcheck — binds +//! each round poly's mask via Fiat-Shamir). +//! 2. Runs `sumcheck.verify` (orchestrator folds its public covector). +//! Post-sumcheck `state.sum = δ + γ_sumcheck · dot`. +//! 3. Receives the **cs_mask tree** commitment (post-sumcheck — cs_mask +//! carries `r_folded` from source IRS randomness) and reads `δ` +//! cleartext, reconciling `state.sum := (sum − δ) · γ_sumcheck⁻¹`. +//! 4. Extends the covector and runs `code_switch.verify`. +//! 5. Verifies sumcheck masks at `[1, r_i, …, r_i^{vec_size_A−1}]` (gives +//! X_i = M_i(r_i)); verifies cs_mask at the post-cs covector mask +//! region (gives X_cs). Checks `Σ X_i == δ` to bind δ; subtracts X_cs +//! to project to f-only. +//! +//! After all rounds: receives the basecase IRS commitment, runs +//! `basecase.verify`, and ties the result back to the public covector via a +//! final MLE check `multilinear_extend(covector, point) == +//! basecase.linear_form_evaluation`. + +use ark_ff::Field; +use ark_std::rand::{distributions::Standard, prelude::Distribution}; +#[cfg(feature = "tracing")] +use tracing::instrument; + +use crate::{ + algebra::{ + embedding::Identity, geometric_sequence, linear_form::LinearForm, multilinear_extend, + sumcheck::fold, + }, + hash::Hash, + protocols::{ + irs_commit::Commitment as IrsCommitment, + params::protocol_config::{ProtocolConfig, RoundConfig}, + zook::commit::Commitment, + }, + transcript::{ + codecs::U64, Codec, Decoding, DuplexSpongeInterface, ProverMessage, VerificationResult, + VerifierMessage, VerifierState, + }, + verify, +}; + +impl ProtocolConfig> { + /// Verify `f(witness) == values[j]` for every form `f = forms[j]` against + /// the received commitment. + #[cfg_attr(feature = "tracing", instrument(skip_all, name = "zook::verify", fields(vector_size = self.tuning.vector_size, num_rounds = self.rounds.len(), num_claims = forms.len())))] + pub fn verify( + &self, + vs: &mut VerifierState, + commitment: Commitment, + forms: &[&dyn LinearForm], + values: &[F], + ) -> VerificationResult<()> + where + Standard: Distribution, + H: DuplexSpongeInterface, + F: Codec<[H::U]>, + u8: Decoding<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + assert_eq!(forms.len(), values.len(), "forms.len() != values.len()"); + assert!(!forms.is_empty(), "zook requires ≥ 1 (form, value) pair"); + + // RLC challenge binds the form/value set to the commitment. + let gamma: F = vs.verifier_message(); + let gamma_powers = geometric_sequence(gamma, forms.len()); + + // Materialize the combined covector = Σ γ^j · form_j and combined value. + let mut covector = vec![F::ZERO; self.tuning.vector_size]; + for (form, &g) in forms.iter().zip(&gamma_powers) { + form.accumulate(&mut covector, g); + } + let combined_value: F = values.iter().zip(&gamma_powers).map(|(v, g)| *v * g).sum(); + + // Reduce to basecase inputs `(irs_commitment, covector, sum)`. After + // any rounds, the last round's target commitment IS the basecase + // commitment (`derive` makes `basecase.commit == last.code_switch.target`), + // so we reuse `state.irs_commitment` directly — no separate receive. + let (irs_commitment, covector, sum) = if self.rounds.is_empty() { + (commitment.irs_commitment, covector, combined_value) + } else { + let mut state = VerifierRoundState { + irs_commitment: commitment.irs_commitment, + covector, + sum: combined_value, + }; + for round in &self.rounds { + state = verify_round(round, state, vs)?; + } + (state.irs_commitment, state.covector, state.sum) + }; + + let opening = self.basecase.verify(vs, &irs_commitment, sum)?; + + // Final consistency: the implicit linear form revealed by basecase + // must equal the MLE of our covector at the sumcheck challenge point. + let expected = multilinear_extend(&covector, &opening.evaluation_points); + verify!(expected == opening.linear_form_evaluation); + Ok(()) + } +} + +struct VerifierRoundState { + irs_commitment: IrsCommitment, + covector: Vec, + sum: F, +} + +#[cfg_attr(feature = "tracing", instrument(skip_all, name = "zook::verify_round", fields(msg_len = round.code_switch.source.message_length(), covector_len = state.covector.len())))] +fn verify_round( + round: &RoundConfig>, + mut state: VerifierRoundState, + vs: &mut VerifierState, +) -> VerificationResult> +where + F: Field + Default + Codec<[H::U]>, + Standard: Distribution, + H: DuplexSpongeInterface, + u8: Decoding<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, +{ + let msg_len = round.code_switch.source.message_length(); + + // (1) Receive the sumcheck-masks tree commitment (pre-sumcheck) so each + // sumcheck round poly is bound to its mask. + let sumcheck_commitment = match round.mask_oracle.as_ref() { + Some(mo) => Some(mo.sumcheck_masks.receive_commitment(vs)?), + None => None, + }; + + // (2) ZK sumcheck. Mutates sum to `δ + γ_sumcheck · dot`; orchestrator + // folds its public covector by the round challenges. + let opening = round.sumcheck.verify(vs, &mut state.sum)?; + for &r in &opening.round_challenges { + fold(&mut state.covector, r); + } + + // (3) Receive cs_mask tree commitment (post-sumcheck) + δ cleartext. + // Reconcile sum to the unmasked dot. + let cs_mask_commitment = match round.mask_oracle.as_ref() { + Some(mo) => Some(mo.cs_mask.receive_commitment(vs)?), + None => None, + }; + let delta: Option = if round.mask_oracle.is_some() { + let d: F = vs.prover_message()?; + let gamma_inv = opening + .mask_rlc + .inverse() + .expect("sumcheck mask_rlc must be non-zero"); + state.sum = (state.sum - d) * gamma_inv; + Some(d) + } else { + None + }; + + // (4) Extend covector and run code-switch. + if let Some(mo) = round.mask_oracle.as_ref() { + state.covector.resize(msg_len + mo.l_zk.get(), F::ZERO); + } + let target_commitment = round.code_switch.verify( + vs, + &mut state.sum, + &mut state.covector, + &opening.round_challenges, + &state.irs_commitment, + )?; + + // (5) Discharge sumcheck masks (bind δ) and cs_mask (subtract X_cs). + if let (Some(mo), Some(scm), Some(cmc), Some(d)) = ( + round.mask_oracle.as_ref(), + sumcheck_commitment, + cs_mask_commitment, + delta, + ) { + // Sumcheck-masks tree: per-mask covectors [1, r_i, …, r_i^{vec_size−1}]. + let sumcheck_vec_size = mo.sumcheck_masks.c_zk_commit.vector_size; + let sm_covectors: Vec> = opening + .round_challenges + .iter() + .map(|&r| geometric_sequence(r, sumcheck_vec_size)) + .collect(); + let sm_refs: Vec<&[F]> = sm_covectors.iter().map(|v| v.as_slice()).collect(); + let sm_x_values = mo + .sumcheck_masks + .verify(vs, &scm, Some(&sm_refs))? + .expect("sumcheck-mask values always returned when covectors passed"); + let sumcheck_x_sum: F = sm_x_values.iter().copied().sum(); + verify!(sumcheck_x_sum == d); + + // cs_mask tree: single covector at the post-cs covector mask region. + let cs_cov: [&[F]; 1] = [&state.covector[msg_len..]]; + let cs_x_values = mo + .cs_mask + .verify(vs, &cmc, Some(&cs_cov))? + .expect("cs_mask value always returned when covector passed"); + state.sum -= cs_x_values[0]; + } + + state.covector.truncate(msg_len); + state.irs_commitment = target_commitment; + + Ok(state) +} + +#[cfg(test)] +mod tests { + use ark_std::rand::{rngs::StdRng, SeedableRng}; + + use super::*; + use crate::{ + algebra::{ + fields::Field64, + linear_form::{Evaluate, MultilinearExtension}, + random_vector, + }, + hash, + protocols::params::{ + spec::{ + DecodingRegime, FoldingFactor, Mode, PowBudget, RateSchedule, SecuritySpec, + TuningSpec, + }, + test_utils::TestEmbedding, + }, + transcript::{codecs::Empty, DomainSeparator, ProverState}, + }; + + type F = Field64; + + const TEST_TARGET_BITS: u32 = 40; + + fn test_spec(mode: Mode) -> SecuritySpec { + SecuritySpec { + mode, + decoding_regime: DecodingRegime::Johnson, + target_security_bits: TEST_TARGET_BITS, + pow_budget: PowBudget::per_slot(10), + hash_id: hash::BLAKE3, + } + } + + fn tuning_with_rounds() -> TuningSpec { + TuningSpec { + vector_size: 1 << 8, + starting_log_inv_rate: 1, + folding_factor: FoldingFactor::Constant(2), + rate_schedule: RateSchedule::Stepping, + } + } + + fn tuning_basecase_only() -> TuningSpec { + TuningSpec { + vector_size: 1 << 3, + starting_log_inv_rate: 1, + folding_factor: FoldingFactor::Constant(2), + rate_schedule: RateSchedule::Stepping, + } + } + + fn roundtrip(config: &ProtocolConfig, seed: u64) { + let mut rng = StdRng::seed_from_u64(seed); + let witness: Vec = random_vector(&mut rng, config.tuning.vector_size); + let mu = config.tuning.vector_size.trailing_zeros() as usize; + let form = MultilinearExtension { + point: random_vector::(&mut rng, mu), + }; + let embedding = as Default>::default(); + let value = form.evaluate(&embedding, &witness); + + let ds = DomainSeparator::protocol(&"zook-verifier-test") + .session(&format!("roundtrip {}:{}", file!(), line!())) + .instance(&Empty); + + let mut prover_state = ProverState::new_std(&ds); + let committed = config.commit(&mut prover_state, &witness); + config.prove(&mut prover_state, committed, &[&form], &[value]); + let proof = prover_state.proof(); + + let mut verifier_state = VerifierState::new_std(&ds, &proof); + let commitment = config.receive_commitment(&mut verifier_state).unwrap(); + config + .verify(&mut verifier_state, commitment, &[&form], &[value]) + .unwrap(); + verifier_state.check_eof().unwrap(); + } + + fn smoke(mode: Mode, tuning: TuningSpec, seed: u64, expect_rounds: bool) { + let config = ProtocolConfig::::derive(test_spec(mode), tuning).unwrap(); + assert_eq!(!config.rounds.is_empty(), expect_rounds); + roundtrip(&config, seed); + } + + #[test] + fn verify_completes_with_rounds_zk() { + smoke(Mode::ZeroKnowledge, tuning_with_rounds(), 0, true); + } + + #[test] + fn verify_completes_with_rounds_standard() { + smoke(Mode::Standard, tuning_with_rounds(), 1, true); + } + + #[test] + fn verify_completes_basecase_only_zk() { + smoke(Mode::ZeroKnowledge, tuning_basecase_only(), 2, false); + } + + #[test] + fn verify_completes_basecase_only_standard() { + smoke(Mode::Standard, tuning_basecase_only(), 3, false); + } +} From 3b48fe896c4275ac92c5c76a7d421ca69e477986 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Sat, 30 May 2026 09:40:23 +0530 Subject: [PATCH 2/6] feat : updated structure for derive --- benches/zook_vs_whir.rs | 5 ++++- src/protocols/params/derive.rs | 3 +-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/benches/zook_vs_whir.rs b/benches/zook_vs_whir.rs index a44e307e..6a14190e 100644 --- a/benches/zook_vs_whir.rs +++ b/benches/zook_vs_whir.rs @@ -279,7 +279,10 @@ fn record_cell( .expect("write CSV row"); } -#[allow(clippy::too_many_lines, reason = "bench harness — sweep + table format is naturally long")] +#[allow( + clippy::too_many_lines, + reason = "bench harness — sweep + table format is naturally long" +)] fn main() { let mut csv = File::create("bench_results.csv").expect("create bench_results.csv"); writeln!( diff --git a/src/protocols/params/derive.rs b/src/protocols/params/derive.rs index 144f9088..0ab58e63 100644 --- a/src/protocols/params/derive.rs +++ b/src/protocols/params/derive.rs @@ -152,8 +152,7 @@ mod tests { fn derive_standard_with_no_rounds_uses_basecase_only() { let spec = test_spec(Mode::Standard); let vector_size = 1usize << LOG_VECTOR_SIZE_NO_ROUNDS; - let plan = - ProtocolConfig::::derive(spec, tuning_with(vector_size)).unwrap(); + let plan = ProtocolConfig::::derive(spec, tuning_with(vector_size)).unwrap(); assert!(plan.rounds.is_empty()); assert_eq!(plan.basecase.commit.vector_size, vector_size); } From 1304c2439347e7a5d0eb967787c8b4c318c77788 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Sat, 30 May 2026 11:18:43 +0530 Subject: [PATCH 3/6] refactor : invariants --- src/protocols/params/build_round.rs | 1 - src/protocols/params/derive.rs | 70 +++++++------------------ src/protocols/params/protocol_config.rs | 67 ++++++----------------- 3 files changed, 37 insertions(+), 101 deletions(-) diff --git a/src/protocols/params/build_round.rs b/src/protocols/params/build_round.rs index 6fe8b3ba..e2e632fe 100644 --- a/src/protocols/params/build_round.rs +++ b/src/protocols/params/build_round.rs @@ -66,7 +66,6 @@ pub(super) fn build_round_config( let solve_mode = SolveMode::ZeroKnowledge(info); let round_mode = RoundMode::ZeroKnowledge { t_ood: OodSampleBudget::new(t_ood), - mask_oracle: info, }; ( OodSampleBudget::new(t_ood), diff --git a/src/protocols/params/derive.rs b/src/protocols/params/derive.rs index 0ab58e63..9800ff8d 100644 --- a/src/protocols/params/derive.rs +++ b/src/protocols/params/derive.rs @@ -261,7 +261,7 @@ mod tests { ) .unwrap(); for r in &plan.rounds { - let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode() else { + let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode else { panic!("expected ZK round") }; assert!(t_ood.get() >= 1); @@ -280,7 +280,7 @@ mod tests { ) .unwrap(); for r in &plan.rounds { - let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode() else { + let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode else { panic!("expected ZK round") }; assert_eq!(t_ood.get(), 1); @@ -300,7 +300,7 @@ mod tests { .unwrap(); for r in &plan.rounds { let mask_oracle = r.mask_oracle().expect("ZK round has a mask oracle"); - let k = r.code_switch().source.interleaving_depth.trailing_zeros() as usize; + let k = r.code_switch.source.interleaving_depth.trailing_zeros() as usize; // Two-tree split: sumcheck_masks tree has 2·k columns, cs_mask tree has 2. assert_eq!(mask_oracle.sumcheck_masks.c_zk_commit.num_vectors, 2 * k); assert_eq!(mask_oracle.cs_mask.c_zk_commit.num_vectors, 2); @@ -326,35 +326,6 @@ mod tests { assert_close(bits, expected); } - #[test] - fn analytic_bits_includes_mask_oracle_in_zk() { - let spec = test_spec(Mode::ZeroKnowledge); - let plan = ProtocolConfig::::derive( - spec, - tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), - ) - .unwrap(); - let plan_bits: f64 = plan.analytic_bits().into(); - let mo_floor = plan - .rounds - .iter() - .filter_map(|r| r.mask_oracle().map(|mo| f64::from(mo.analytic_bits()))) - .fold(f64::INFINITY, f64::min); - assert!( - mo_floor.is_finite(), - "ZK plan must contribute mask-oracle bits" - ); - let min_round = plan - .rounds - .iter() - .map(|r| f64::from(r.analytic_bits())) - .fold(f64::INFINITY, f64::min); - let expected = mo_floor - .min(min_round) - .min(f64::from(plan.basecase.analytic_bits())); - assert_close(plan_bits, expected); - } - #[test] fn derive_plans_basecase() { let spec = test_spec(Mode::ZeroKnowledge); @@ -384,7 +355,7 @@ mod tests { let field_bits = ::field_size_bits(); let mut expected_total = 0.0_f64; for r in &plan.rounds { - let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode() else { + let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode else { panic!("expected ZK round"); }; let t = t_ood.get() as f64; @@ -460,9 +431,9 @@ mod tests { .unwrap(); assert!(!plan.rounds.is_empty(), "need a round to corrupt"); let recorded = plan - .rounds() + .rounds .first() - .and_then(|r| r.sumcheck().recorded_analytic) + .and_then(|r| r.sumcheck.recorded_analytic) .expect("params solver records sumcheck analytic"); // Bump the recorded value far from the recompute → triggers drift. plan.corrupt_round_sumcheck_recorded_analytic_for_test( @@ -614,7 +585,7 @@ mod tests { assert!(n >= 2, "need ≥ 2 rounds to break a mid-chain link"); assert!(plan.check_all_invariants(), "fresh plan must validate"); - let bad_size = plan.rounds[0].code_switch().target.vector_size + 1; + let bad_size = plan.rounds[0].code_switch.target.vector_size + 1; plan.corrupt_round_target_vector_size_for_test(0, bad_size); let err = plan @@ -742,10 +713,9 @@ mod tests { .unwrap(); assert!(!plan.rounds.is_empty(), "expected multi-round plan"); for r in &plan.rounds { - let cs = r.code_switch(); - assert!(cs.source.unique_decoding()); - assert!(cs.target.unique_decoding()); - assert!(cs.out_domain_samples >= 1); + assert!(r.code_switch.source.unique_decoding()); + assert!(r.code_switch.target.unique_decoding()); + assert!(r.code_switch.out_domain_samples >= 1); } assert!(plan.basecase.commit.unique_decoding()); } @@ -765,8 +735,8 @@ mod tests { for r in &plan.rounds { let mo = r.mask_oracle().expect("ZK round must own a mask oracle"); assert!(mo.cs_mask.c_zk_commit.unique_decoding()); - assert!(r.code_switch().source.unique_decoding()); - assert!(r.code_switch().out_domain_samples >= 1); + assert!(r.code_switch.source.unique_decoding()); + assert!(r.code_switch.out_domain_samples >= 1); } assert!(plan.basecase.commit.unique_decoding()); } @@ -784,7 +754,7 @@ mod tests { .unwrap(); assert!(!plan.rounds.is_empty(), "expected multi-round plan"); for r in &plan.rounds { - assert!(r.code_switch().out_domain_samples >= 1); + assert!(r.code_switch.out_domain_samples >= 1); } } @@ -802,7 +772,7 @@ mod tests { assert!(!plan.rounds.is_empty(), "expected multi-round plan"); for r in &plan.rounds { r.mask_oracle().expect("ZK round must own a mask oracle"); - assert!(r.code_switch().out_domain_samples >= 1); + assert!(r.code_switch.out_domain_samples >= 1); } } @@ -811,12 +781,12 @@ mod tests { plan: &ProtocolConfig, ) { for r in &plan.rounds { - let mask_info = r.mode().mask_oracle(); - let cs = r.code_switch(); + let mask_info = r.mask_oracle_info(); + let cs = &r.code_switch; assert_pow_closes_gap( spec, sumcheck_params::analytic_error_bits(&cs.source, mask_info), - &r.sumcheck().round_pow, + &r.sumcheck.round_pow, ); assert_pow_closes_gap( spec, @@ -875,7 +845,7 @@ mod tests { let spec = test_spec(Mode::Standard); let plan = ProtocolConfig::::derive(spec, tuning).unwrap(); for r in &plan.rounds { - prop_assert!(matches!(r.mode(), RoundMode::Standard)); + prop_assert!(matches!(r.mode, RoundMode::Standard)); prop_assert!(r.mask_oracle().is_none()); } prop_assert!(matches!(plan.basecase.mode, BasecaseMode::Standard)); @@ -894,10 +864,10 @@ mod tests { let mask_oracle = r .mask_oracle() .expect("ZK round must have a mask oracle"); - let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode() else { + let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode else { panic!("expected ZK round"); }; - let cs = r.code_switch(); + let cs = &r.code_switch; let k = cs.source.interleaving_depth.trailing_zeros() as usize; prop_assert_eq!(mask_oracle.sumcheck_masks.c_zk_commit.num_vectors, 2 * k); prop_assert_eq!(mask_oracle.cs_mask.c_zk_commit.num_vectors, 2); diff --git a/src/protocols/params/protocol_config.rs b/src/protocols/params/protocol_config.rs index a258f9b3..4e18a83c 100644 --- a/src/protocols/params/protocol_config.rs +++ b/src/protocols/params/protocol_config.rs @@ -39,22 +39,6 @@ pub struct ProtocolConfig { } impl ProtocolConfig { - pub const fn security(&self) -> &SecuritySpec { - &self.security - } - - pub const fn tuning(&self) -> &TuningSpec { - &self.tuning - } - - pub fn rounds(&self) -> &[RoundConfig] { - &self.rounds - } - - pub const fn basecase(&self) -> &BasecaseConfig { - &self.basecase - } - /// `true` if every PoW slot's difficulty fits within `security.pow_budget`. pub fn check_pow_bits(&self) -> bool { self.validate_pow_budget().is_ok() @@ -114,7 +98,7 @@ impl ProtocolConfig { Ok(()) }; for r in &self.rounds { - let mask_info = r.mode.mask_oracle(); + let mask_info = r.mask_oracle_info(); check( Pow::RoundSumcheck { index: r.round_index, @@ -280,9 +264,6 @@ impl ProtocolConfig { let mut min_bits = f64::from(self.basecase.analytic_bits()); for round in &self.rounds { min_bits = min_bits.min(f64::from(round.analytic_bits())); - if let Some(mo) = &round.mask_oracle { - min_bits = min_bits.min(f64::from(mo.analytic_bits())); - } } Bits::new(min_bits.max(0.0)) } @@ -334,32 +315,20 @@ pub struct RoundConfig { } impl RoundConfig { - pub const fn round_index(&self) -> usize { - self.round_index - } - - pub const fn sumcheck(&self) -> &SumcheckConfig { - &self.sumcheck - } - - pub const fn code_switch(&self) -> &CodeSwitchConfig { - &self.code_switch - } - - pub const fn mode(&self) -> &RoundMode { - &self.mode - } - pub const fn mask_oracle(&self) -> Option<&MaskOracleConfig> { self.mask_oracle.as_ref() } + pub fn mask_oracle_info(&self) -> Option { + self.mask_oracle.as_ref().map(MaskOracleConfig::info) + } + /// Round-level analytic floor: the smallest of `sumcheck`, `code_switch`, /// and (when present) the per-round mask-oracle proximity check. pub fn analytic_bits(&self) -> Bits { let source = &self.code_switch.source; let target = &self.code_switch.target; - let mask_info = self.mode.mask_oracle(); + let mask_info = self.mask_oracle_info(); let sumcheck_term = f64::from(sumcheck_params::analytic_error_bits(source, mask_info)); let code_switch_term = f64::from(code_switch_params::analytic_error_bits( @@ -368,8 +337,17 @@ impl RoundConfig { self.code_switch.out_domain_samples, mask_info, )); - - Bits::new(sumcheck_term.min(code_switch_term).max(0.0)) + let mask_oracle_term = self + .mask_oracle + .as_ref() + .map_or(f64::INFINITY, |mo| f64::from(mo.analytic_bits())); + + Bits::new( + sumcheck_term + .min(code_switch_term) + .min(mask_oracle_term) + .max(0.0), + ) } } @@ -379,10 +357,6 @@ pub enum RoundMode { ZeroKnowledge { /// Lemma 9.9 OOD-sample budget (bounds doc §5.2). t_ood: OodSampleBudget, - /// Slim view of this round's [`MaskOracleConfig`] (C_zk's list size + - /// ℓ_zk) — denormalized so soundness routines can read it without - /// chasing through `mask_oracle`. - mask_oracle: MaskOracleInfo, }, } @@ -390,13 +364,6 @@ impl RoundMode { pub const fn is_zk(&self) -> bool { matches!(self, Self::ZeroKnowledge { .. }) } - - pub const fn mask_oracle(&self) -> Option { - match self { - Self::Standard => None, - Self::ZeroKnowledge { mask_oracle, .. } => Some(*mask_oracle), - } - } } /// One round's mask oracle, split across two independent C_zk trees: From ed095c0a43f6f8e1b16bc971da1faad64950ec6d Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Mon, 1 Jun 2026 13:29:26 +0530 Subject: [PATCH 4/6] feat : added wiring and refactor for safety in zook --- benches/zook_vs_whir.rs | 4 +- src/protocols/code_switch.rs | 127 ++++-- src/protocols/mask_proximity.rs | 24 +- src/protocols/params/build_round.rs | 37 +- src/protocols/params/code_switch.rs | 28 +- src/protocols/params/derive.rs | 21 +- src/protocols/params/irs_commit.rs | 17 +- src/protocols/params/mod.rs | 4 +- src/protocols/params/protocol_config.rs | 28 +- src/protocols/params/sumcheck.rs | 8 +- src/protocols/params/test_utils.rs | 10 +- src/protocols/zook/mod.rs | 56 ++- src/protocols/zook/prover.rs | 525 +++++++++++++++--------- src/protocols/zook/verifier.rs | 508 +++++++++++++++++------ 14 files changed, 958 insertions(+), 439 deletions(-) diff --git a/benches/zook_vs_whir.rs b/benches/zook_vs_whir.rs index 6a14190e..d750c3fa 100644 --- a/benches/zook_vs_whir.rs +++ b/benches/zook_vs_whir.rs @@ -54,7 +54,7 @@ const LOG_SIZES: &[u32] = &[ /// Number of measured iterations per (size, protocol). User-chosen tight /// budget — see module note about first-iteration cold-cache noise. -const ITERATIONS: usize = 3; +const ITERATIONS: usize = 25; /// Shared protocol params — held constant across sizes so we measure scaling /// in vector size only. @@ -228,7 +228,7 @@ fn bench_zook( let commitment = config .receive_commitment(&mut vs) .map_err(|e| format!("{label} receive_commitment: {e:?}"))?; - config + let _ = config .verify(&mut vs, commitment, &[form_ref], &[evaluation]) .map_err(|e| format!("{label} verify: {e:?}"))?; samples.verify.push(t.elapsed()); diff --git a/src/protocols/code_switch.rs b/src/protocols/code_switch.rs index a0d9908d..08e78ca8 100644 --- a/src/protocols/code_switch.rs +++ b/src/protocols/code_switch.rs @@ -72,6 +72,24 @@ pub struct Claim<'a, F: Field> { /// Verifier output from the code-switch. pub type Commitment = IrsCommitment; +/// Code-switch verification output for implicit-covector callers. +/// +/// Returned by [`Config::verify_for_implicit`]. The caller accumulates +/// constraint terms from these instead of updating an explicit `Vec`. +#[must_use] +pub struct CovectorUpdateParams { + /// The `original_sl_coeff` that would have scaled the covector. + pub original_sl_coeff: F, + /// RLC coefficients for each OOD constraint. + pub ood_rlc_coeffs: Vec, + /// OOD evaluation points (alpha_i). + pub ood_eval_points: Vec, + /// RLC coefficients for each in-domain constraint. + pub in_domain_rlc_coeffs: Vec, + /// In-domain evaluation points (omega_j, already lifted via embedding). + pub in_domain_eval_points: Vec, +} + impl Config { /// Create a code-switch config. pub fn new( @@ -115,15 +133,15 @@ impl Config { message_mask_length, } = &mode { - let l_zk = message_mask_length.get(); + let mask_oracle_len = message_mask_length.get(); // Theorem 9.6: ℓ_zk ≥ r (mask oracle must cover source randomness). assert!( - l_zk >= source_config.mask_length(), - "message_mask_length ({l_zk}) must be >= source randomness length ({})", + mask_oracle_len >= source_config.mask_length(), + "message_mask_length ({mask_oracle_len}) must be >= source randomness length ({})", source_config.mask_length(), ); assert!( - l_zk - source_config.mask_length() >= out_domain_samples, + mask_oracle_len - source_config.mask_length() >= out_domain_samples, "sampled randomness (s) length must cover all out-of-domain sample requests" ); // t' = target in-domain queries + OOD queries (Construction 9.7 step 4). @@ -372,15 +390,16 @@ impl Config { /// per-round mask tree containing `s` and is responsible for /// running `mask_proximity::verify` on that same tree before /// accepting the round. - #[cfg_attr(feature = "tracing", instrument(skip_all))] - pub fn verify( + /// Shared transcript work: receive target commitment, verify source opening, + /// sample batching coefficients, update `sum`. Returns the target commitment + /// and the parameters needed to update a covector (explicitly or implicitly). + fn verify_inner( &self, verifier_state: &mut VerifierState, sum: &mut M::Target, - covector: &mut [M::Target], folding_randomness: &[M::Target], commitment: &IrsCommitment, - ) -> VerificationResult + ) -> VerificationResult<(Commitment, CovectorUpdateParams)> where H: DuplexSpongeInterface, Standard: Distribution, @@ -390,27 +409,16 @@ impl Config { U64: Codec<[H::U]>, Hash: ProverMessage<[H::U]>, { - verify!(1 << folding_randomness.len() == self.source.interleaving_depth); - assert_eq!(covector.len(), self.covector_length()); - let collapse_weights = eq_weights(folding_randomness); - // Step 1: target commitment — Construction 9.7 Step 1, p.55 - // Mask oracle is committed in the shared mask tree by the orchestrator. let target_commitment = self.target.receive_commitment(verifier_state)?; - - // Grind Lemma 9.9 OOD gap before α is sampled. self.pow.verify(verifier_state)?; - // Step 2-3: OOD — Construction 9.7 Steps 2-3, p.55 - // In ZK mode, ood_answers = f(α) + α^ℓ · (r,s)(α) where (r,s) is - // the mask oracle message committed in the shared tree. - let ood_points: Vec = + let ood_eval_points: Vec = verifier_state.verifier_message_vec(self.out_domain_samples); let ood_answers: Vec = verifier_state.prover_messages_vec(self.out_domain_samples)?; - // Step 4: source opening — Construction 9.7 Step 4, p.55 let source_evaluations = self.source.verify(verifier_state, &[commitment])?; let collapsed_values: Vec = source_evaluations .matrix @@ -418,7 +426,6 @@ impl Config { .map(|row| mixed_dot(self.source.embedding(), &collapse_weights, row)) .collect(); - // Step 4.1: batching + μ' — Construction 9.7 Decision phase, p.55 let num_ood = self.out_domain_samples; let num_in_domain = source_evaluations.points.len(); let coeffs = geometric_challenge(verifier_state, 1 + num_ood + num_in_domain); @@ -429,20 +436,80 @@ impl Config { + dot(ood_rlc_coeffs, &ood_answers) + dot(in_domain_rlc_coeffs, &collapsed_values); - // Mirror prover's covector update so ` = sum` holds - // after this step (Completeness proof, p.55-56). - let eval_points = lift(self.source.embedding(), &source_evaluations.points); - scalar_mul(covector, original_sl_coeff); + let in_domain_eval_points = lift(self.source.embedding(), &source_evaluations.points); + + Ok(( + target_commitment, + CovectorUpdateParams { + original_sl_coeff, + ood_rlc_coeffs: ood_rlc_coeffs.to_vec(), + ood_eval_points, + in_domain_rlc_coeffs: in_domain_rlc_coeffs.to_vec(), + in_domain_eval_points, + }, + )) + } + + #[cfg_attr(feature = "tracing", instrument(skip_all))] + pub fn verify( + &self, + verifier_state: &mut VerifierState, + sum: &mut M::Target, + covector: &mut [M::Target], + folding_randomness: &[M::Target], + commitment: &IrsCommitment, + ) -> VerificationResult + where + H: DuplexSpongeInterface, + Standard: Distribution, + M::Target: Codec<[H::U]>, + u8: Decoding<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + verify!(1 << folding_randomness.len() == self.source.interleaving_depth); + assert_eq!(covector.len(), self.covector_length()); + + let (target_commitment, params) = + self.verify_inner(verifier_state, sum, folding_randomness, commitment)?; + + scalar_mul(covector, params.original_sl_coeff); self.update_covector( covector, - ood_rlc_coeffs, - &ood_points, - in_domain_rlc_coeffs, - &eval_points, + ¶ms.ood_rlc_coeffs, + ¶ms.ood_eval_points, + ¶ms.in_domain_rlc_coeffs, + ¶ms.in_domain_eval_points, ); Ok(target_commitment) } + + /// Like [`verify`] but does NOT update an explicit covector. Instead it + /// returns [`CovectorUpdateParams`] so the caller can accumulate constraint + /// terms implicitly. Also does NOT assert `covector.len() == covector_length()` + /// since there is no covector. The `sum` is still updated as normal. + #[cfg_attr(feature = "tracing", instrument(skip_all))] + pub fn verify_for_implicit( + &self, + verifier_state: &mut VerifierState, + sum: &mut M::Target, + folding_randomness: &[M::Target], + commitment: &IrsCommitment, + ) -> VerificationResult<(Commitment, CovectorUpdateParams)> + where + H: DuplexSpongeInterface, + Standard: Distribution, + M::Target: Codec<[H::U]>, + u8: Decoding<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + verify!(1 << folding_randomness.len() == self.source.interleaving_depth); + self.verify_inner(verifier_state, sum, folding_randomness, commitment) + } } impl fmt::Display for Config { @@ -711,7 +778,7 @@ mod tests { // h is the post-fold polynomial whose inner product with covector // should equal the verifier sum: // - non-ZK: h = folded_message (length message_length) - // - ZK: h = [folded_message; mask_msg] (length message_length + l_zk) + // - ZK: h = [folded_message; mask_msg] (length message_length + mask_oracle_len) let h: Vec = if mask_msg.is_empty() { folded_message.clone() } else { diff --git a/src/protocols/mask_proximity.rs b/src/protocols/mask_proximity.rs index 7c09222d..b10e9e18 100644 --- a/src/protocols/mask_proximity.rs +++ b/src/protocols/mask_proximity.rs @@ -40,7 +40,13 @@ //! is that the simulator can simulate all 2n values at each position //! independently (each pair (ξ_i, s_i) is simulatable from C_zk's ZK //! property, and the pairs are independent across i). Formal derivation -//! for the shared-tree case is pending. +//! for the shared-tree case: Bound 6 (§5.6) of the parameter +//! derivation document establishes SD ≤ ζ_C + 2n·ζ_{C_zk} for the shared tree +//! (vs Lemma 7.3's ζ_C + n·ζ_{C_zk} for separate oracles) by P8 (independent +//! sampling preserved by the Merkle authentication layer). Under perfect-ZK +//! encoder conditions (t_zk ≤ r_zk), both terms are 0 and the shared-tree +//! model is perfectly ZK. The 2n multiplier vs paper's n is the only permanent +//! artifact and is accounted for in the per-round mask-proximity error budget. //! //! Soundness: if ξ_i is far from C_zk, the spot-check fails with high //! probability over γ (Lemma 7.4, p.45). @@ -260,13 +266,25 @@ impl Config { } } - // fresh_msgs is consumed; release ~num_masks · vector_size field - // elements before the (potentially slow) tree-open below. + // Zeroize secret mask-of-masks polynomials before freeing. + for msg in &mut witness.fresh_msgs { + for elem in msg.iter_mut() { + elem.zeroize(); + } + } witness.fresh_msgs = Vec::new(); // Step 3: open the shared tree at random in-domain positions self.c_zk_commit .open(prover_state, &[&witness.mask_witness]); + + // Zeroize IRS mask secret material after the tree open. + for elem in &mut witness.mask_witness.masks { + elem.zeroize(); + } + for elem in &mut witness.mask_witness.matrix { + elem.zeroize(); + } } /// Verify that each original mask is close to a C_zk codeword. When diff --git a/src/protocols/params/build_round.rs b/src/protocols/params/build_round.rs index e2e632fe..6e10ad97 100644 --- a/src/protocols/params/build_round.rs +++ b/src/protocols/params/build_round.rs @@ -22,7 +22,7 @@ use crate::{ irs_commit as irs_params, layout::{round_context, target_context, RoundShape}, mask_proximity as mask_proximity_params, - protocol_config::{MaskOracleConfig, RoundConfig, RoundMode}, + protocol_config::{MaskOracleConfig, RoundConfig}, spec::{ DecodingRegime, LogInvRate, MaskCodeMessageLen, OodSampleBudget, RoundContext, SecuritySpec, ZkSpec, @@ -43,13 +43,8 @@ pub(super) fn build_round_config( let ood_mode = mode.map(|p| f64::from(p.c_zk_log_inv_rate.get())); let (source, t_ood) = solve_round_source::(spec, shape, ood_mode)?; - let (target_budget, solve_mode, round_mode, mask_oracle) = match mode { - Branch::Standard => ( - OodSampleBudget::ZERO, - SolveMode::Standard, - RoundMode::Standard, - None, - ), + let (target_budget, solve_mode, mask_oracle) = match mode { + Branch::Standard => (OodSampleBudget::ZERO, SolveMode::Standard, None), Branch::ZeroKnowledge(RoundBuildPayload { zk_spec, c_zk_log_inv_rate, @@ -64,15 +59,7 @@ pub(super) fn build_round_config( )?; let info = mask_oracle.info(); let solve_mode = SolveMode::ZeroKnowledge(info); - let round_mode = RoundMode::ZeroKnowledge { - t_ood: OodSampleBudget::new(t_ood), - }; - ( - OodSampleBudget::new(t_ood), - solve_mode, - round_mode, - Some(mask_oracle), - ) + (OodSampleBudget::new(t_ood), solve_mode, Some(mask_oracle)) } }; @@ -94,7 +81,6 @@ pub(super) fn build_round_config( round_index: shape.round_index, sumcheck, code_switch, - mode: round_mode, mask_oracle, }) } @@ -145,7 +131,7 @@ pub(super) fn build_mask_oracle( ) -> Result, DeriveError> { let spec = zk_spec.as_inner(); let k = sumcheck_params::masks_required(ctx); - let l_zk = compute_l_zk(source, t_ood); + let mask_oracle_len = compute_l_zk(source, t_ood); // Sumcheck-masks tree: tiny vector size (next_pow2(3) = 4), no padding to ℓ_zk. let sumcheck_mask_vec_size = @@ -162,14 +148,14 @@ pub(super) fn build_mask_oracle( let cs_masks = code_switch_params::masks_required(); let cs_c_zk: IrsConfig> = irs_params::solve_mask_code( zk_spec, - l_zk, + mask_oracle_len, source.mask_length(), c_zk_log_inv_rate, MaskProximityConfig::::num_vectors_for(cs_masks), ); let c_zk_list_size_estimate = spec.decoding_regime.list_size_estimate( - (l_zk.get() as f64).log2(), + (mask_oracle_len.get() as f64).log2(), f64::from(c_zk_log_inv_rate.get()), ); debug_assert!( @@ -185,7 +171,8 @@ pub(super) fn build_mask_oracle( Ok(MaskOracleConfig { sumcheck_masks, cs_mask, - l_zk, + mask_oracle_len, + t_ood: OodSampleBudget::new(t_ood), }) } @@ -250,15 +237,15 @@ fn ood_security_bits_at( target_list_size.log2(), ), Branch::ZeroKnowledge(c_zk_log_inv_rate) => { - let l_zk = source + let mask_oracle_len = source .mask_length() .saturating_add(t_ood) .next_power_of_two(); let c_zk_list = spec .decoding_regime - .list_size_estimate(usize_to_f64(l_zk).log2(), c_zk_log_inv_rate); + .list_size_estimate(usize_to_f64(mask_oracle_len).log2(), c_zk_log_inv_rate); ( - usize_to_f64(source.message_length().saturating_add(l_zk)).log2(), + usize_to_f64(source.message_length().saturating_add(mask_oracle_len)).log2(), (target_list_size * c_zk_list).log2(), ) } diff --git a/src/protocols/params/code_switch.rs b/src/protocols/params/code_switch.rs index b18f3cca..0c9e4f08 100644 --- a/src/protocols/params/code_switch.rs +++ b/src/protocols/params/code_switch.rs @@ -33,17 +33,17 @@ pub fn solve( let (mask_oracle, output_mode) = match mode { SolveMode::Standard => (None, code_switch::CodeSwitchMode::Standard), SolveMode::ZeroKnowledge(mask_oracle) => { - let l_zk = mask_oracle.l_zk.get(); + let mask_oracle_len = mask_oracle.mask_oracle_len.get(); assert!( - l_zk >= source.mask_length().saturating_add(t_ood), - "ℓ_zk ({l_zk}) < r + t_ood ({} + {}) — violates Theorem 9.6 witness sizing", + mask_oracle_len >= source.mask_length().saturating_add(t_ood), + "ℓ_zk ({mask_oracle_len}) < r + t_ood ({} + {}) — violates Theorem 9.6 witness sizing", source.mask_length(), t_ood, ); ( Some(mask_oracle), code_switch::CodeSwitchMode::ZeroKnowledge { - message_mask_length: NonZeroUsize::new(l_zk).expect("ℓ_zk > 0"), + message_mask_length: NonZeroUsize::new(mask_oracle_len).expect("ℓ_zk > 0"), }, ) } @@ -75,7 +75,11 @@ pub fn analytic_error_bits( // `ℓ` (Standard). let degree = mask_oracle.map_or_else( || source.message_length(), - |info| source.message_length().saturating_add(info.l_zk.get()), + |info| { + source + .message_length() + .saturating_add(info.mask_oracle_len.get()) + }, ); let t_ood_f = usize_to_f64(t_ood); @@ -187,7 +191,7 @@ mod tests { let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); let mask_oracle = MaskOracleInfo { c_zk_list_size: ListSize::new(C_ZK_LIST_SIZE), - l_zk: MaskCodeMessageLen::new(L_ZK_USIZE), + mask_oracle_len: MaskCodeMessageLen::new(L_ZK_USIZE), }; let (source, target, t_ood) = build_round_io::( &spec, @@ -280,18 +284,18 @@ mod tests { &spec, log_inv_rate, folding_factor, num_vars, Some(log_inv_rate), ); let r = source.mask_length(); - let l_zk = compute_l_zk(&source, t_ood); + let mask_oracle_len = compute_l_zk(&source, t_ood); let zk_spec = ZkSpec::try_new(&spec).expect("arb_zk_spec"); let c_zk = irs_params::solve_mask_code::( zk_spec, - l_zk, + mask_oracle_len, r, LogInvRate::new(log_inv_rate), 2, ); let mask_oracle = MaskOracleInfo { c_zk_list_size: ListSize::new(c_zk.list_size()), - l_zk, + mask_oracle_len, }; let config = solve( &spec, @@ -353,7 +357,7 @@ mod tests { let mask_oracle = MaskOracleInfo { c_zk_list_size: ListSize::new(SMOKE_C_ZK_LIST_SIZE), - l_zk: MaskCodeMessageLen::new(TOO_SMALL_L_ZK), + mask_oracle_len: MaskCodeMessageLen::new(TOO_SMALL_L_ZK), }; let _ = solve( &spec, @@ -419,7 +423,9 @@ mod tests { let mask_oracle = MaskOracleInfo { c_zk_list_size: ListSize::new(SMOKE_C_ZK_LIST_SIZE), - l_zk: MaskCodeMessageLen::new((source.mask_length() + t_ood).next_power_of_two()), + mask_oracle_len: MaskCodeMessageLen::new( + (source.mask_length() + t_ood).next_power_of_two(), + ), }; let config = solve( &spec, diff --git a/src/protocols/params/derive.rs b/src/protocols/params/derive.rs index 9800ff8d..4d6e668d 100644 --- a/src/protocols/params/derive.rs +++ b/src/protocols/params/derive.rs @@ -82,7 +82,7 @@ mod tests { basecase as basecase_params, code_switch as code_switch_params, error::{ChainSource, ChainTarget, DeriveError, Pow}, mask_proximity as mask_proximity_params, - protocol_config::{ProtocolConfig, RoundMode}, + protocol_config::ProtocolConfig, spec::{ DecodingRegime, FoldingFactor, KneeWeight, Mode, PowBudget, RateSchedule, SecuritySpec, TuningSpec, @@ -261,9 +261,7 @@ mod tests { ) .unwrap(); for r in &plan.rounds { - let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode else { - panic!("expected ZK round") - }; + let t_ood = r.mask_oracle().expect("expected ZK round").t_ood; assert!(t_ood.get() >= 1); } } @@ -280,9 +278,7 @@ mod tests { ) .unwrap(); for r in &plan.rounds { - let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode else { - panic!("expected ZK round") - }; + let t_ood = r.mask_oracle().expect("expected ZK round").t_ood; assert_eq!(t_ood.get(), 1); } } @@ -355,9 +351,7 @@ mod tests { let field_bits = ::field_size_bits(); let mut expected_total = 0.0_f64; for r in &plan.rounds { - let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode else { - panic!("expected ZK round"); - }; + let t_ood = r.mask_oracle().expect("expected ZK round").t_ood; let t = t_ood.get() as f64; expected_total += 2_f64.powf(f64::midpoint(t * t, t).log2() - field_bits); } @@ -845,7 +839,6 @@ mod tests { let spec = test_spec(Mode::Standard); let plan = ProtocolConfig::::derive(spec, tuning).unwrap(); for r in &plan.rounds { - prop_assert!(matches!(r.mode, RoundMode::Standard)); prop_assert!(r.mask_oracle().is_none()); } prop_assert!(matches!(plan.basecase.mode, BasecaseMode::Standard)); @@ -864,15 +857,13 @@ mod tests { let mask_oracle = r .mask_oracle() .expect("ZK round must have a mask oracle"); - let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode else { - panic!("expected ZK round"); - }; + let t_ood = mask_oracle.t_ood; let cs = &r.code_switch; let k = cs.source.interleaving_depth.trailing_zeros() as usize; prop_assert_eq!(mask_oracle.sumcheck_masks.c_zk_commit.num_vectors, 2 * k); prop_assert_eq!(mask_oracle.cs_mask.c_zk_commit.num_vectors, 2); let source_mask = cs.source.mask_length(); - prop_assert!(mask_oracle.l_zk.get() >= source_mask + t_ood.get()); + prop_assert!(mask_oracle.mask_oracle_len.get() >= source_mask + t_ood.get()); } prop_assert!(matches!( plan.basecase.mode, diff --git a/src/protocols/params/irs_commit.rs b/src/protocols/params/irs_commit.rs index c43caba9..e680cb88 100644 --- a/src/protocols/params/irs_commit.rs +++ b/src/protocols/params/irs_commit.rs @@ -66,22 +66,25 @@ pub fn solve( /// Shared C_zk IRS config for mask polynomials. /// -/// - `l_zk`: message length, must be a power of 2. +/// - `mask_oracle_len`: message length, must be a power of 2. /// - `source_mask_length`: `r` from Theorem 9.6. /// - `num_vectors`: `2 * num_masks` (Construction 7.2: originals + fresh). pub fn solve_mask_code( spec: ZkSpec<'_>, - l_zk: MaskCodeMessageLen, + mask_oracle_len: MaskCodeMessageLen, source_mask_length: usize, log_inv_rate: LogInvRate, num_vectors: usize, ) -> IrsConfig { - let l_zk = l_zk.get(); + let mask_oracle_len = mask_oracle_len.get(); assert!( - l_zk >= source_mask_length, - "Theorem 9.6: ℓ_zk ({l_zk}) ≥ source mask length ({source_mask_length})", + mask_oracle_len >= source_mask_length, + "Theorem 9.6: ℓ_zk ({mask_oracle_len}) ≥ source mask length ({source_mask_length})", + ); + assert!( + mask_oracle_len.is_power_of_two(), + "ℓ_zk ({mask_oracle_len}) must be a power of 2" ); - assert!(l_zk.is_power_of_two(), "ℓ_zk ({l_zk}) must be a power of 2"); assert!( num_vectors.is_multiple_of(2), "num_vectors ({num_vectors}) must be even (mask-proximity original/fresh pairs)", @@ -95,7 +98,7 @@ pub fn solve_mask_code( spec.decoding_regime, spec.hash_id, num_vectors, - l_zk, + mask_oracle_len, 1, rate, IrsMode::Standard, diff --git a/src/protocols/params/mod.rs b/src/protocols/params/mod.rs index d4d62fc1..a9f95207 100644 --- a/src/protocols/params/mod.rs +++ b/src/protocols/params/mod.rs @@ -25,9 +25,7 @@ pub(crate) mod test_utils; pub use branch::{Branch, SolveMode}; pub use error::{ChainSource, ChainTarget, DeriveError, Pow}; -pub use protocol_config::{ - MaskOracleConfig, MaskOracleInfo, ProtocolConfig, RoundConfig, RoundMode, -}; +pub use protocol_config::{MaskOracleConfig, MaskOracleInfo, ProtocolConfig, RoundConfig}; pub use spec::{ DecodingRegime, FoldingFactor, KneeWeight, ListSize, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, PowBudget, RateSchedule, RoundContext, SecuritySpec, TuningSpec, ZkSpec, diff --git a/src/protocols/params/protocol_config.rs b/src/protocols/params/protocol_config.rs index 4e18a83c..19c0bef8 100644 --- a/src/protocols/params/protocol_config.rs +++ b/src/protocols/params/protocol_config.rs @@ -245,8 +245,8 @@ impl ProtocolConfig { let field_bits = ::field_size_bits(); let mut total_error = 0.0_f64; for r in &self.rounds { - if let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode { - let t = usize_to_f64(t_ood.get()); + if let Some(mo) = &r.mask_oracle { + let t = usize_to_f64(mo.t_ood.get()); // ζ_ze ≤ (t_ood² + t_ood) / (2|F|). Compute in log space to // stay numerically stable for large field_bits. let log_err = f64::midpoint(t * t, t).log2() - field_bits; @@ -308,7 +308,6 @@ pub struct RoundConfig { pub round_index: usize, pub sumcheck: SumcheckConfig, pub code_switch: CodeSwitchConfig, - pub mode: RoundMode, /// `Some` iff this is a ZK round. Sized for this round's `k + 1` masks /// (k sumcheck + 1 code-switch). pub mask_oracle: Option>, @@ -351,21 +350,6 @@ impl RoundConfig { } } -#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] -pub enum RoundMode { - Standard, - ZeroKnowledge { - /// Lemma 9.9 OOD-sample budget (bounds doc §5.2). - t_ood: OodSampleBudget, - }, -} - -impl RoundMode { - pub const fn is_zk(&self) -> bool { - matches!(self, Self::ZeroKnowledge { .. }) - } -} - /// One round's mask oracle, split across two independent C_zk trees: /// - `sumcheck_masks`: the `k` sumcheck masks (Lemma 6.4), each of length /// `next_pow_2(mask_length)`. Committed BEFORE sumcheck. @@ -386,7 +370,9 @@ pub struct MaskOracleConfig { /// size is ℓ_zk. pub cs_mask: MaskProximityConfig, /// `next_pow2(r + t_ood)` for this round (Lemma 9.3). - pub l_zk: MaskCodeMessageLen, + pub mask_oracle_len: MaskCodeMessageLen, + /// Lemma 9.9 OOD-sample budget (bounds doc §5.2). + pub t_ood: OodSampleBudget, } impl MaskOracleConfig { @@ -395,7 +381,7 @@ impl MaskOracleConfig { // shared C_zk list size. MaskOracleInfo { c_zk_list_size: ListSize::new(self.cs_mask.c_zk_commit.list_size()), - l_zk: self.l_zk, + mask_oracle_len: self.mask_oracle_len, } } @@ -416,5 +402,5 @@ impl MaskOracleConfig { #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct MaskOracleInfo { pub c_zk_list_size: ListSize, - pub l_zk: MaskCodeMessageLen, + pub mask_oracle_len: MaskCodeMessageLen, } diff --git a/src/protocols/params/sumcheck.rs b/src/protocols/params/sumcheck.rs index ce795a6a..e6d4a4e0 100644 --- a/src/protocols/params/sumcheck.rs +++ b/src/protocols/params/sumcheck.rs @@ -59,7 +59,7 @@ pub fn analytic_error_bits( let poly_id = mask_oracle.map_or(field_bits - log_list_size - 1.0, |info| { let log_list_size_c_zk = info.c_zk_list_size.get().log2(); - let log_l_zk = usize_to_f64(info.l_zk.get()).log2(); + let log_l_zk = usize_to_f64(info.mask_oracle_len.get()).log2(); field_bits - log_list_size - log_list_size_c_zk - log_l_zk }); @@ -170,7 +170,7 @@ mod tests { let irs = build_source_irs(&spec, &ctx); let info = MaskOracleInfo { c_zk_list_size: ListSize::new(FIXTURE_C_ZK_LIST_SIZE), - l_zk: MaskCodeMessageLen::new(FIXTURE_L_ZK), + mask_oracle_len: MaskCodeMessageLen::new(FIXTURE_L_ZK), }; let got = f64::from(analytic_error_bits::(&irs, Some(info))); @@ -195,7 +195,7 @@ mod tests { let irs = build_source_irs(&spec, &ctx); let huge = MaskOracleInfo { c_zk_list_size: ListSize::new(2_f64.powi(OVERSIZED_LOG_C_ZK_LIST)), - l_zk: MaskCodeMessageLen::new(1 << OVERSIZED_LOG_L_ZK), + mask_oracle_len: MaskCodeMessageLen::new(1 << OVERSIZED_LOG_L_ZK), }; let bits = f64::from(analytic_error_bits::(&irs, Some(huge))); assert_close(bits, 0.0); @@ -267,7 +267,7 @@ mod tests { irs_params::solve(&spec, &ctx, OodSampleBudget::ZERO); let info = MaskOracleInfo { c_zk_list_size: ListSize::new(FIXTURE_C_ZK_LIST_SIZE), - l_zk: MaskCodeMessageLen::new(FIXTURE_L_ZK), + mask_oracle_len: MaskCodeMessageLen::new(FIXTURE_L_ZK), }; let config = solve( &spec, diff --git a/src/protocols/params/test_utils.rs b/src/protocols/params/test_utils.rs index ae3a0604..9c628215 100644 --- a/src/protocols/params/test_utils.rs +++ b/src/protocols/params/test_utils.rs @@ -94,12 +94,12 @@ pub fn arb_round_ctx() -> impl Strategy { /// `None` in Standard; `Some(ℓ_zk=2, c_zk rate 1/2)` in ZK. pub fn build_minimal_mask_oracle(spec: &SecuritySpec) -> Option { let zk_spec = ZkSpec::try_new(spec)?; - let l_zk = MaskCodeMessageLen::new(2); + let mask_oracle_len = MaskCodeMessageLen::new(2); let c_zk: IrsConfig = - irs_params::solve_mask_code(zk_spec, l_zk, 0, LogInvRate::new(1), 2); + irs_params::solve_mask_code(zk_spec, mask_oracle_len, 0, LogInvRate::new(1), 2); Some(MaskOracleInfo { c_zk_list_size: ListSize::new(c_zk.list_size()), - l_zk, + mask_oracle_len, }) } @@ -125,14 +125,14 @@ pub fn assert_close(got: f64, expected: f64) { /// C_zk fixture for `mask_proximity` tests. pub fn build_test_c_zk( spec: &SecuritySpec, - l_zk: usize, + mask_oracle_len: usize, log_inv_rate: u32, num_masks: usize, ) -> IrsConfig { let zk_spec = ZkSpec::try_new(spec).expect("build_test_c_zk requires a ZK spec"); irs_params::solve_mask_code( zk_spec, - MaskCodeMessageLen::new(l_zk), + MaskCodeMessageLen::new(mask_oracle_len), 0, LogInvRate::new(log_inv_rate), MaskProximityConfig::::num_vectors_for(num_masks), diff --git a/src/protocols/zook/mod.rs b/src/protocols/zook/mod.rs index cecc100b..9c1bb1c2 100644 --- a/src/protocols/zook/mod.rs +++ b/src/protocols/zook/mod.rs @@ -7,6 +7,57 @@ pub mod verifier; pub use commit::{Commitment, CommittedWitness}; pub use crate::protocols::params::protocol_config::ProtocolConfig; +use crate::{algebra::linear_form::LinearForm, transcript::VerificationResult, verify}; + +/// Output of [`ProtocolConfig::verify`]. +/// +/// The verifier has completed all round checks. The caller must finish +/// verification by checking that the input forms evaluate to the claimed value: +/// +/// ```text +/// initial_claim_scale × Σ_j rlc_coefficients[j] × form_j.mle_evaluate(evaluation_point) == linear_forms_contribution +/// ``` +#[must_use] +#[derive(Clone, Debug)] +pub struct FinalClaim { + /// All sumcheck challenges from all rounds concatenated with basecase + /// evaluation points. Length = log2(vector_size). + pub evaluation_point: Vec, + /// Cumulative product of all code-switch `original_sl_coeff` values. + pub initial_claim_scale: F, + /// `opening.linear_form_evaluation − Σ_c constraint_contributions`. + /// The portion of the protocol sum attributable to the input linear forms. + pub linear_forms_contribution: F, + /// Fiat-Shamir RLC coefficients for each input form. + /// `rlc_coefficients[0] == F::ONE` always. + pub rlc_coefficients: Vec, +} + +impl FinalClaim { + /// Complete the verification started by [`ProtocolConfig::verify`]. + /// + /// Checks `initial_claim_scale × Σ_j rlc_coefficients[j] × form_j.mle_evaluate(evaluation_point) == linear_forms_contribution`. + /// + /// For `MultilinearExtension` forms this runs in O(num_forms × log N). + /// For `Covector` forms the default `mle_evaluate` is O(N). + pub fn verify(&self, linear_forms: &[&dyn LinearForm]) -> VerificationResult<()> + where + F: Default, + { + assert_eq!( + linear_forms.len(), + self.rlc_coefficients.len(), + "linear_forms.len() must match rlc_coefficients.len()" + ); + let form_mle_sum: F = linear_forms + .iter() + .zip(&self.rlc_coefficients) + .map(|(form, &g)| g * form.mle_evaluate(&self.evaluation_point)) + .sum(); + verify!(self.initial_claim_scale * form_mle_sum == self.linear_forms_contribution); + Ok(()) + } +} #[cfg(test)] mod tests { @@ -100,9 +151,12 @@ mod tests { let commitment = config .receive_commitment(&mut verifier_state) .expect("receive_commitment"); - config + let claim = config .verify(&mut verifier_state, commitment, &form_refs, &values) .expect("verify"); + // Final claim check for 3 MultilinearExtension forms: + // initial_claim_scale × Σ_j γ^j × form_j.mle_evaluate(z_full) == linear_forms_contribution + claim.verify(&form_refs).expect("FinalClaim::verify failed"); verifier_state .check_eof() .expect("transcript fully consumed"); diff --git a/src/protocols/zook/prover.rs b/src/protocols/zook/prover.rs index bc6a6eb5..a0eea32f 100644 --- a/src/protocols/zook/prover.rs +++ b/src/protocols/zook/prover.rs @@ -21,12 +21,13 @@ //! //! Per-round flow: //! 1. Sample sumcheck masks `M_0..M_{k−1}` (length `mask_length` each) and -//! `s_fresh` of length `ℓ_zk − source.mask_length()`. +//! `cs_fresh_padding` of length `mask_oracle_len − source.mask_length()`. //! 2. Commit `sumcheck_masks` tree (k masks padded to next_pow_2(mask_length)). //! 3. Run ZK sumcheck. Post-sumcheck `state.sum = δ + γ_sumcheck · dot`, //! where `δ = Σ M_i(r_i)`. -//! 4. Derive `r_folded = fold(lift(source.masks), r_0..r_{k−1})`, assemble -//! `cs_mask = (r_folded ‖ s_fresh)`, commit `cs_mask` tree. +//! 4. Derive `folded_irs_masks = fold(source.masks, r_0..r_{k−1})` (Identity: +//! no lift needed), assemble `cs_mask = (folded_irs_masks ‖ cs_fresh_padding)`, +//! commit `cs_mask` tree. //! 5. Send `δ` cleartext and reconcile //! `state.sum := (sum − δ) · γ_sumcheck⁻¹` so the claim entering //! code-switch is the unmasked `dot(folded_a, post_sumcheck_cov)`. @@ -47,7 +48,7 @@ use zeroize::Zeroize; use crate::{ algebra::{ - dot, embedding::Identity, geometric_sequence, lift, linear_form::LinearForm, random_vector, + dot, embedding::Identity, geometric_sequence, linear_form::LinearForm, random_vector, univariate_evaluate, }, hash::Hash, @@ -55,6 +56,7 @@ use crate::{ code_switch::{self, fold_chunks}, irs_commit::Witness as IrsWitness, mask_proximity, + mask_proximity::Config as MaskProximityConfig, params::protocol_config::{MaskOracleConfig, ProtocolConfig, RoundConfig}, sumcheck::{SumcheckMode, SumcheckOpening}, zook::commit::{CommittedState, CommittedWitness}, @@ -66,15 +68,15 @@ use crate::{ }; impl ProtocolConfig> { - /// Prove `f(witness) == values[j]` for every form `f = forms[j]` against + /// Prove `f(witness) == evaluations[j]` for every linear_form `f = linear_forms[j]` against /// the committed witness. Consumes `committed`. - #[cfg_attr(feature = "tracing", instrument(skip_all, name = "zook::prove", fields(vector_size = self.tuning.vector_size, num_rounds = self.rounds.len(), num_claims = forms.len())))] + #[cfg_attr(feature = "tracing", instrument(skip_all, name = "zook::prove", fields(vector_size = self.tuning.vector_size, num_rounds = self.rounds.len(), num_claims = linear_forms.len())))] pub fn prove( &self, ps: &mut ProverState, committed: CommittedWitness>, - forms: &[&dyn LinearForm], - values: &[F], + linear_forms: &[&dyn LinearForm], + evaluations: &[F], ) where Standard: Distribution, H: DuplexSpongeInterface, @@ -85,19 +87,30 @@ impl ProtocolConfig> { U64: Codec<[H::U]>, Hash: ProverMessage<[H::U]>, { - assert_eq!(forms.len(), values.len(), "forms.len() != values.len()"); - assert!(!forms.is_empty(), "zook requires ≥ 1 (form, value) pair"); + assert_eq!( + linear_forms.len(), + evaluations.len(), + "linear_forms.len() != evaluations.len()" + ); + assert!( + !linear_forms.is_empty(), + "zook requires ≥ 1 (form, value) pair" + ); // RLC challenge binds the form/value set to the commitment. - let gamma: F = ps.verifier_message(); - let gamma_powers = geometric_sequence(gamma, forms.len()); + let batching_challenge: F = ps.verifier_message(); + let claim_weights = geometric_sequence(batching_challenge, linear_forms.len()); // Materialize the combined covector = Σ γ^j · form_j and combined value. let mut covector = vec![F::ZERO; self.tuning.vector_size]; - for (form, &g) in forms.iter().zip(&gamma_powers) { - form.accumulate(&mut covector, g); + for (form, &weight) in linear_forms.iter().zip(&claim_weights) { + form.accumulate(&mut covector, weight); } - let combined_value: F = values.iter().zip(&gamma_powers).map(|(v, g)| *v * g).sum(); + let batched_evaluation: F = evaluations + .iter() + .zip(&claim_weights) + .map(|(v, weight)| *v * weight) + .sum(); // Reduce to basecase inputs `(message, witness, covector, sum)`. The // two arms differ only in how those are obtained. @@ -106,7 +119,7 @@ impl ProtocolConfig> { CommittedState::Basecase { message, irs_witness, - } => (message, irs_witness, covector, combined_value), + } => (message, irs_witness, covector, batched_evaluation), CommittedState::Round { message, irs_witness, @@ -115,7 +128,7 @@ impl ProtocolConfig> { message, irs_witness, covector, - sum: combined_value, + sum: batched_evaluation, }; for round in &self.rounds { state = prove_round(round, state, ps); @@ -127,6 +140,9 @@ impl ProtocolConfig> { } }; + // Standard mode (BasecaseMode::Standard) sends the full witness vector + // and IRS randomness cleartext. Only call with Mode::ZeroKnowledge if + // end-to-end hiding is required. let _ = self .basecase .prove(ps, message, &basecase_witness, covector, sum); @@ -160,37 +176,38 @@ where { let msg_len = round.code_switch.source.message_length(); - // (1+2) Sample sumcheck masks (and pre-sample cs_mask's s_fresh) and - // pre-commit the sumcheck-masks tree. - let mut sumcheck_masks_state = round - .mask_oracle - .as_ref() - .map(|mo| SumcheckMasks::sample_and_commit(mo, round, ps)); - let sumcheck_masks_flat: &[F] = sumcheck_masks_state - .as_ref() - .map_or(&[][..], SumcheckMasks::flat); - - // (3) ZK sumcheck. Post-sumcheck `state.sum = δ + γ_sumcheck · dot`. + debug_assert_eq!( + dot(&state.message, &state.covector), + state.sum, + "prove_round entry: dot(message, covector) must equal sum" + ); + + // Samples and commits the sumcheck-masks tree (ZK) or is a no-op (Standard). + // cs_fresh_padding is pre-sampled here because it does not depend on folding randomness. + let mut masker = RoundMaskOracle::begin(round, ps); + let opening = round.sumcheck.prove( ps, &mut state.message, &mut state.covector, &mut state.sum, - sumcheck_masks_flat, + masker.sumcheck_blinding(), ); - // (4+5) Build cs_mask = (r_folded ‖ s_fresh), commit the cs_mask tree, - // send δ and reconcile state.sum to the unmasked dot. - let cs_mask_state = sumcheck_masks_state.as_mut().map(|sm| { - CsMask::commit_and_reconcile(sm, round, &state.irs_witness, &opening, &mut state.sum, ps) - }); - - // (6) Extend covector and run code-switch. Hand the source IRS witness over - // by value so code_switch can drop the source matrix immediately after - // `source.open` — frees ~268 MB in round 1 before cs_mask::prove_at. - if let Some(cm) = cs_mask_state.as_ref() { - state.covector.resize(msg_len + cm.l_zk, F::ZERO); - } + // Build cs_mask = (folded_irs_masks ‖ cs_fresh_padding), commit its tree, + // send mask_eval_sum cleartext, reconcile sum to the unmasked dot. + masker.bind_code_switch_mask(&state.irs_witness, &opening, &mut state.sum, ps); + + debug_assert_eq!( + dot(&state.message, &state.covector), + state.sum, + "post-reconcile: dot(message, covector) must equal sum" + ); + + // Extend covector for ZK mask region; +0 in Standard mode. + state + .covector + .resize(msg_len + masker.covector_extension(), F::ZERO); let cs_witness = round.code_switch.prove( ps, state.message, @@ -200,54 +217,46 @@ where sum: &mut state.sum, }, &opening.round_challenges, - cs_mask_state.as_ref().map_or(&[][..], CsMask::coefficients), + masker.code_switch_blinding(), ); - // (7) Prove sumcheck masks at the round challenges, prove cs_mask at the - // post-cs covector, then subtract X_cs. - if let (Some(sm), Some(cm)) = (sumcheck_masks_state, cs_mask_state) { - sm.prove_at(&opening.round_challenges, ps); - let x_cs = cm.prove_at(&state.covector[msg_len..], ps); - state.sum -= x_cs; - } - // opening's round_challenges/mask_rlc are fully consumed; release before - // the next round's allocations begin. + // Prove both mask trees; subtract cs_mask contribution to project sum to f-only. + masker.finish( + &opening.round_challenges, + &state.covector[msg_len..], + &mut state.sum, + ps, + ); drop(opening); state.message = cs_witness.message; state.irs_witness = cs_witness.target_witness; state.covector.truncate(state.message.len()); + debug_assert_eq!( + dot(&state.message, &state.covector), + state.sum, + "prove_round exit: dot(message, covector) must equal sum" + ); + state } -/// Sumcheck-masks tree: k masks committed BEFORE sumcheck so each round-poly -/// is bound to its mask. Vector size is `next_pow_2(sumcheck.mask_length)` — -/// no padding to ℓ_zk. -/// -/// Also carries the `s_fresh` padding for cs_mask — sampled pre-sumcheck -/// (doesn't depend on folding randomness) and consumed by [`CsMask`]. -struct SumcheckMasks<'a, F: Field> { - mask_oracle: &'a MaskOracleConfig, - /// k padded sumcheck masks, each of length `vec_size`. `originals[i][0..mask_len]` - /// is the raw mask polynomial; the rest is zero padding for NTT. - originals: Vec>, - /// k * mask_len concatenated raw mask coefficients (sumcheck.prove input). - flat: Vec, - witness: mask_proximity::Witness, - /// Fresh padding for cs_mask. Length `ℓ_zk − source.mask_length()`. Sampled - /// here so the RNG draws happen during pre-sumcheck setup; consumed by - /// `CsMask::commit_and_reconcile`. - s_fresh: Vec, - /// Vector size of the sumcheck-masks tree (= `next_pow_2(mask_len)`). - vec_size: usize, - k: usize, +/// The committed mask tree for the sumcheck sub-protocol. +/// Holds k blinding polynomials sampled before sumcheck and opened after code-switch. +struct SumcheckMaskTree<'a, F: Field> { + cfg: &'a MaskProximityConfig, + padded_masks: Vec>, + flat_coefficients: Vec, + tree_witness: mask_proximity::Witness, + padded_vec_size: usize, } -impl<'a, F: Field + Zeroize> SumcheckMasks<'a, F> { +impl<'a, F: Field + Zeroize> SumcheckMaskTree<'a, F> { fn sample_and_commit( - mask_oracle: &'a MaskOracleConfig, - round: &RoundConfig>, + cfg: &'a MaskProximityConfig, + num_masks: usize, + mask_poly_len: usize, ps: &mut ProverState, ) -> Self where @@ -257,59 +266,50 @@ impl<'a, F: Field + Zeroize> SumcheckMasks<'a, F> { Standard: Distribution, Hash: ProverMessage<[H::U]>, { - let k = round.sumcheck.num_rounds; - let mask_len = match round.sumcheck.mode { - SumcheckMode::Standard => 0, - SumcheckMode::ZeroKnowledge { mask_length } => mask_length.get(), - }; - let vec_size = mask_oracle.sumcheck_masks.c_zk_commit.vector_size; - debug_assert!(vec_size >= mask_len); - - let l_zk = mask_oracle.l_zk.get(); - let source_mask_len = round.code_switch.source.mask_length(); - assert!(source_mask_len <= l_zk); - - let flat: Vec = random_vector(ps.rng(), k * mask_len); - let originals: Vec> = (0..k) + let padded_vec_size = cfg.c_zk_commit.vector_size; + let flat_coefficients: Vec = random_vector(ps.rng(), num_masks * mask_poly_len); + let padded_masks: Vec> = (0..num_masks) .map(|i| { - let mut padded = vec![F::ZERO; vec_size]; - padded[..mask_len].copy_from_slice(&flat[i * mask_len..(i + 1) * mask_len]); + let mut padded = vec![F::ZERO; padded_vec_size]; + padded[..mask_poly_len].copy_from_slice( + &flat_coefficients[i * mask_poly_len..(i + 1) * mask_poly_len], + ); padded }) .collect(); - // s_fresh doesn't depend on folding randomness — sample it here. - let s_fresh: Vec = random_vector(ps.rng(), l_zk - source_mask_len); - let originals_refs: Vec<&[F]> = originals.iter().map(Vec::as_slice).collect(); - let witness = mask_oracle.sumcheck_masks.commit(ps, &originals_refs); + let padded_mask_refs: Vec<&[F]> = padded_masks.iter().map(Vec::as_slice).collect(); + let tree_witness = cfg.commit(ps, &padded_mask_refs); Self { - mask_oracle, - originals, - flat, - witness, - s_fresh, - vec_size, - k, + cfg, + padded_masks, + flat_coefficients, + tree_witness, + padded_vec_size, } } - fn flat(&self) -> &[F] { - &self.flat + /// Input to sumcheck.prove(): the flat blinding coefficients. + fn blinding(&self) -> &[F] { + &self.flat_coefficients + } + + /// Zeroize the flat blinding coefficients once sumcheck no longer needs them. + fn wipe_blinding(&mut self) { + self.flat_coefficients.zeroize(); } - /// δ = Σ_i M_i(r_i). The orchestrator sends δ cleartext + reconciles - /// `state.sum` before code-switch. - fn delta(&self, round_challenges: &[F]) -> F { - round_challenges + /// δ = Σ_i evaluate(padded_masks[i], challenges[i]) + fn eval_sum(&self, challenges: &[F]) -> F { + challenges .iter() - .zip(self.originals.iter()) - .map(|(&r, mask_padded)| univariate_evaluate(mask_padded, r)) + .zip(self.padded_masks.iter()) + .map(|(&c, mask)| univariate_evaluate(mask, c)) .sum() } - /// Prove each sumcheck mask at `cov_i = [1, r_i, …, r_i^{vec_size−1}]`, - /// emitting `X_i = M_i(r_i)` to the verifier. Soundness: the verifier - /// checks `Σ X_i == δ` against the cleartext δ sent earlier. - fn prove_at(mut self, round_challenges: &[F], ps: &mut ProverState) + /// Open the sumcheck-masks tree at the per-mask geometric covectors. + /// Zeroizes padded_masks before returning. + fn prove(mut self, round_challenges: &[F], ps: &mut ProverState) where F: Codec<[H::U]>, H: DuplexSpongeInterface, @@ -320,107 +320,247 @@ impl<'a, F: Field + Zeroize> SumcheckMasks<'a, F> { U64: Codec<[H::U]>, Hash: ProverMessage<[H::U]>, { - debug_assert_eq!(round_challenges.len(), self.k); - let covectors: Vec> = round_challenges + let evaluation_covectors: Vec> = round_challenges .iter() - .map(|&r| geometric_sequence(r, self.vec_size)) + .map(|&c| geometric_sequence(c, self.padded_vec_size)) .collect(); - let covector_refs: Vec<&[F]> = covectors.iter().map(|v| v.as_slice()).collect(); - let originals_refs: Vec<&[F]> = self.originals.iter().map(Vec::as_slice).collect(); - self.mask_oracle.sumcheck_masks.prove( + let covector_refs: Vec<&[F]> = evaluation_covectors.iter().map(Vec::as_slice).collect(); + let padded_mask_refs: Vec<&[F]> = self.padded_masks.iter().map(Vec::as_slice).collect(); + self.cfg.prove( ps, - self.witness, - &originals_refs, + self.tree_witness, + &padded_mask_refs, Some(&covector_refs), ); - // Wipe secret mask material before this struct drops, so it doesn't - // linger in the freed heap region. - drop(originals_refs); - self.flat.zeroize(); - for o in &mut self.originals { - o.zeroize(); + for mask in &mut self.padded_masks { + mask.zeroize(); } } } -/// cs_mask tree: 1 mask of length ℓ_zk, committed AFTER sumcheck so the `r` -/// part can carry `fold(source.masks, folding_randomness)`. -struct CsMask<'a, F: Field> { - mask_oracle: &'a MaskOracleConfig, - /// The mask polynomial coefficients, structured as `(r_folded ‖ s_fresh)`. - coefficients: Vec, - witness: mask_proximity::Witness, - l_zk: usize, +/// The committed mask tree for the code-switch sub-protocol. +/// Holds the single (r_folded ‖ fresh_padding) polynomial, built after sumcheck. +struct CodeSwitchMask<'a, F: Field> { + cfg: &'a MaskProximityConfig, + poly: Vec, + tree_witness: mask_proximity::Witness, +} + +impl<'a, F: Field + Zeroize> CodeSwitchMask<'a, F> { + /// Assemble poly = (r_folded ‖ fresh_padding) and commit the cs-mask tree. + fn build_and_commit( + r_folded: &[F], + mut fresh_padding: Vec, + cfg: &'a MaskProximityConfig, + ps: &mut ProverState, + ) -> Self + where + F: Codec<[H::U]>, + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + Standard: Distribution, + Hash: ProverMessage<[H::U]>, + { + let mut poly = Vec::with_capacity(r_folded.len() + fresh_padding.len()); + poly.extend_from_slice(r_folded); + poly.append(&mut fresh_padding); + let tree_witness = cfg.commit(ps, &[&poly[..]]); + Self { + cfg, + poly, + tree_witness, + } + } + + /// Code-switch blinding input: the full polynomial coefficients. + fn blinding(&self) -> &[F] { + &self.poly + } + + /// Length of the mask oracle polynomial. + const fn len(&self) -> usize { + self.poly.len() + } + + /// Open the cs-mask tree at `covector_region`; returns X_cs = ⟨poly, covector_region⟩. + /// Zeroizes poly before returning. + fn prove(mut self, covector_region: &[F], ps: &mut ProverState) -> F + where + F: Codec<[H::U]>, + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + Standard: Distribution, + u8: Decoding<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + let x_cs = dot(&self.poly, covector_region); + self.cfg.prove( + ps, + self.tree_witness, + &[&self.poly[..]], + Some(&[covector_region]), + ); + self.poly.zeroize(); + x_cs + } +} + +/// Manages all ZK mask oracle state for one round. +/// +/// Transitions: Disabled (Standard) or BeforeCodeSwitch → AfterCodeSwitch. +/// The Disabled variant is the Null Object: all methods on it are no-ops or +/// return empty slices, so prove_round has no ZK-specific branches. +enum RoundMaskOracle<'a, F: Field> { + /// Standard mode or basecase-only round: no mask oracle. + Disabled, + /// ZK round — sumcheck-masks tree committed, cs_mask not yet built. + BeforeCodeSwitch { + mask_oracle: &'a MaskOracleConfig, + sc_tree: SumcheckMaskTree<'a, F>, + cs_fresh_padding: Vec, + }, + /// ZK round — cs_mask tree committed, ready to discharge. + AfterCodeSwitch { + sc_tree: SumcheckMaskTree<'a, F>, + cs_mask: CodeSwitchMask<'a, F>, + }, } -impl<'a, F: Field + Zeroize> CsMask<'a, F> { - /// Build `cs_mask = (r_folded ‖ s_fresh)` (s_fresh was pre-sampled in - /// [`SumcheckMasks::sample_and_commit`] and is consumed via `mem::take` - /// here), commit the cs_mask tree, send δ cleartext, and reconcile - /// `*sum := (sum − δ) · γ_sumcheck⁻¹`. - fn commit_and_reconcile( - sumcheck_masks: &mut SumcheckMasks<'a, F>, - round: &RoundConfig>, +impl<'a, F: Field + Default + Zeroize> RoundMaskOracle<'a, F> { + /// Construct the oracle for this round: Disabled if no mask oracle, otherwise + /// sample and commit the sumcheck-masks tree (BeforeCodeSwitch). + fn begin(round: &'a RoundConfig>, ps: &mut ProverState) -> Self + where + F: Codec<[H::U]>, + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + Standard: Distribution, + Hash: ProverMessage<[H::U]>, + { + let Some(mo) = round.mask_oracle.as_ref() else { + return Self::Disabled; + }; + let num_masks = round.sumcheck.num_rounds; + let mask_poly_len = match round.sumcheck.mode { + SumcheckMode::ZeroKnowledge { mask_length } => mask_length.get(), + SumcheckMode::Standard => 0, + }; + let sc_tree = + SumcheckMaskTree::sample_and_commit(&mo.sumcheck_masks, num_masks, mask_poly_len, ps); + // cs_fresh_padding is the `s` in cs_mask = (r_folded ‖ s). Sampled here because + // it does not depend on folding randomness, preserving the Fiat–Shamir RNG ordering. + let cs_fresh_padding_len = + mo.mask_oracle_len.get() - round.code_switch.source.mask_length(); + let cs_fresh_padding: Vec = random_vector(ps.rng(), cs_fresh_padding_len); + Self::BeforeCodeSwitch { + mask_oracle: mo, + sc_tree, + cs_fresh_padding, + } + } + + /// Flat sumcheck blinding coefficients. Returns &[] for Disabled. + fn sumcheck_blinding(&self) -> &[F] { + match self { + Self::BeforeCodeSwitch { sc_tree, .. } => sc_tree.blinding(), + _ => &[], + } + } + + /// Build and commit the cs_mask tree, send δ cleartext, reconcile *sum. + /// Transitions BeforeCodeSwitch → AfterCodeSwitch in place. + /// No-op for Disabled. + fn bind_code_switch_mask( + &mut self, irs_witness: &IrsWitness, opening: &SumcheckOpening, sum: &mut F, ps: &mut ProverState, - ) -> Self - where + ) where F: Codec<[H::U]>, H: DuplexSpongeInterface, R: RngCore + CryptoRng, Standard: Distribution, Hash: ProverMessage<[H::U]>, { - let mask_oracle = sumcheck_masks.mask_oracle; - let l_zk = mask_oracle.l_zk.get(); - let source_mask_len = round.code_switch.source.mask_length(); - debug_assert!(source_mask_len <= l_zk); - debug_assert_eq!(sumcheck_masks.s_fresh.len(), l_zk - source_mask_len); - - // r_folded = collapse of source IRS randomness by sumcheck challenges. - // Masks live in whir's canonical per-poly contiguous layout. - let raw = lift(round.code_switch.source.embedding(), &irs_witness.masks); - let r_folded = fold_chunks(&raw, source_mask_len, &opening.round_challenges); - debug_assert_eq!(r_folded.len(), source_mask_len); - - // Pre-size to l_zk so the `append` below doesn't reallocate. - let mut cs_mask = Vec::with_capacity(l_zk); - cs_mask.extend_from_slice(&r_folded); - // s_fresh is moved out — only consumed here, so no clone needed. - cs_mask.append(&mut sumcheck_masks.s_fresh); - debug_assert_eq!(cs_mask.len(), l_zk); - - let witness = mask_oracle.cs_mask.commit(ps, &[&cs_mask[..]]); - - // δ = Σ M_i(r_i) from the sumcheck-masks tree. Sent cleartext; bound - // by the verifier's `Σ X_{i (mask_oracle, sc_tree, cs_fresh_padding), + other => { + *self = other; + return; + } + }; + + // Blinding coefficients are no longer needed after sumcheck. + sc_tree.wipe_blinding(); + + // cs_mask = (r_folded ‖ s): r folds the source IRS randomness by the sumcheck challenges. + let source_mask_len = mask_oracle.mask_oracle_len.get() - cs_fresh_padding.len(); + let r_folded = fold_chunks( + &irs_witness.masks, + source_mask_len, + &opening.round_challenges, + ); + + // Build and commit the cs-mask tree (transcript: commitment hash). + let cs_mask = + CodeSwitchMask::build_and_commit(&r_folded, cs_fresh_padding, &mask_oracle.cs_mask, ps); + + // Send δ = Σ Mᵢ(rᵢ) cleartext AFTER committing the cs-mask tree (Constr. 9.7 order). + let delta = sc_tree.eval_sum(&opening.round_challenges); ps.prover_message(&delta); - let gamma_inv = opening + // Reconcile sum to the unmasked dot product. + // mask_rlc is Fiat–Shamir; zero has negligible probability for large fields. + let mask_rlc_inv = opening .mask_rlc .inverse() - .expect("sumcheck mask_rlc must be non-zero"); - *sum = (*sum - delta) * gamma_inv; + .expect("mask_rlc non-zero (negligible probability for large fields)"); + *sum = (*sum - delta) * mask_rlc_inv; - Self { - mask_oracle, - coefficients: cs_mask, - witness, - l_zk, + *self = Self::AfterCodeSwitch { sc_tree, cs_mask }; + } + + /// Number of elements to extend the covector by for the ZK region. + /// Returns 0 for Disabled. + const fn covector_extension(&self) -> usize { + match self { + Self::AfterCodeSwitch { cs_mask, .. } => cs_mask.len(), + _ => 0, } } - fn coefficients(&self) -> &[F] { - &self.coefficients + /// Code-switch mask polynomial coefficients. Returns &[] for Disabled. + fn code_switch_blinding(&self) -> &[F] { + match self { + Self::AfterCodeSwitch { cs_mask, .. } => cs_mask.blinding(), + Self::Disabled => &[], + Self::BeforeCodeSwitch { .. } => { + debug_assert!( + false, + "code_switch_blinding called before bind_code_switch_mask" + ); + &[] + } + } } - /// Prove cs_mask at the post-cs covector mask region. Returns - /// `X_cs = ` for the f-only projection. - fn prove_at(mut self, cs_mask_covector: &[F], ps: &mut ProverState) -> F - where + /// Prove both mask trees and subtract the cs_mask contribution from *sum. + /// No-op for Disabled. + fn finish( + self, + round_challenges: &[F], + cs_mask_covector: &[F], + sum: &mut F, + ps: &mut ProverState, + ) where F: Codec<[H::U]>, H: DuplexSpongeInterface, R: RngCore + CryptoRng, @@ -430,17 +570,10 @@ impl<'a, F: Field + Zeroize> CsMask<'a, F> { U64: Codec<[H::U]>, Hash: ProverMessage<[H::U]>, { - let covector_refs: [&[F]; 1] = [cs_mask_covector]; - let x_cs = dot(&self.coefficients, cs_mask_covector); - self.mask_oracle.cs_mask.prove( - ps, - self.witness, - &[&self.coefficients[..]], - Some(&covector_refs), - ); - // cs_mask is uniformly random secret randomness; wipe before drop. - self.coefficients.zeroize(); - x_cs + if let Self::AfterCodeSwitch { sc_tree, cs_mask } = self { + sc_tree.prove(round_challenges, ps); + *sum -= cs_mask.prove(cs_mask_covector, ps); + } } } @@ -550,7 +683,7 @@ mod tests { } #[test] - #[should_panic(expected = "forms.len() != values.len()")] + #[should_panic(expected = "linear_forms.len() != evaluations.len()")] fn prove_rejects_count_mismatch() { let config = ProtocolConfig::::derive( test_spec(Mode::Standard), diff --git a/src/protocols/zook/verifier.rs b/src/protocols/zook/verifier.rs index 4fa37326..1b238b0d 100644 --- a/src/protocols/zook/verifier.rs +++ b/src/protocols/zook/verifier.rs @@ -3,21 +3,24 @@ //! Per ZK round, the verifier: //! 1. Receives the **sumcheck-masks tree** commitment (pre-sumcheck — binds //! each round poly's mask via Fiat-Shamir). -//! 2. Runs `sumcheck.verify` (orchestrator folds its public covector). +//! 2. Runs `sumcheck.verify` (orchestrator tracks sumcheck challenges implicitly). //! Post-sumcheck `state.sum = δ + γ_sumcheck · dot`. //! 3. Receives the **cs_mask tree** commitment (post-sumcheck — cs_mask //! carries `r_folded` from source IRS randomness) and reads `δ` //! cleartext, reconciling `state.sum := (sum − δ) · γ_sumcheck⁻¹`. -//! 4. Extends the covector and runs `code_switch.verify`. +//! 4. Runs `code_switch.verify_for_implicit` — accumulates OOD/in-domain +//! constraints as `ImplicitConstraint` entries instead of updating an +//! explicit covector. //! 5. Verifies sumcheck masks at `[1, r_i, …, r_i^{vec_size_A−1}]` (gives //! X_i = M_i(r_i)); verifies cs_mask at the post-cs covector mask //! region (gives X_cs). Checks `Σ X_i == δ` to bind δ; subtracts X_cs //! to project to f-only. //! //! After all rounds: receives the basecase IRS commitment, runs -//! `basecase.verify`, and ties the result back to the public covector via a -//! final MLE check `multilinear_extend(covector, point) == -//! basecase.linear_form_evaluation`. +//! `basecase.verify`, constructs `full_eval_point = all_round_challenges ++ evaluation_points`, +//! and checks that the implicit covector evaluates to `basecase.linear_form_evaluation` +//! in O((num_constraints + log N)) operations — eliminating the O(N) per-round +//! covector update bottleneck. use ark_ff::Field; use ark_std::rand::{distributions::Standard, prelude::Distribution}; @@ -26,14 +29,18 @@ use tracing::instrument; use crate::{ algebra::{ - embedding::Identity, geometric_sequence, linear_form::LinearForm, multilinear_extend, - sumcheck::fold, + embedding::Identity, + geometric_sequence, + linear_form::{LinearForm, UnivariateEvaluation}, }, hash::Hash, protocols::{ + code_switch::CovectorUpdateParams, irs_commit::Commitment as IrsCommitment, - params::protocol_config::{ProtocolConfig, RoundConfig}, - zook::commit::Commitment, + mask_proximity, + params::protocol_config::{MaskOracleConfig, ProtocolConfig, RoundConfig}, + sumcheck::SumcheckOpening, + zook::{commit::Commitment, FinalClaim}, }, transcript::{ codecs::U64, Codec, Decoding, DuplexSpongeInterface, ProverMessage, VerificationResult, @@ -43,16 +50,16 @@ use crate::{ }; impl ProtocolConfig> { - /// Verify `f(witness) == values[j]` for every form `f = forms[j]` against + /// Verify `f(witness) == evaluations[j]` for every linear_form `f = linear_forms[j]` against /// the received commitment. - #[cfg_attr(feature = "tracing", instrument(skip_all, name = "zook::verify", fields(vector_size = self.tuning.vector_size, num_rounds = self.rounds.len(), num_claims = forms.len())))] + #[cfg_attr(feature = "tracing", instrument(skip_all, name = "zook::verify", fields(vector_size = self.tuning.vector_size, num_rounds = self.rounds.len(), num_claims = linear_forms.len())))] pub fn verify( &self, vs: &mut VerifierState, commitment: Commitment, - forms: &[&dyn LinearForm], - values: &[F], - ) -> VerificationResult<()> + linear_forms: &[&dyn LinearForm], + evaluations: &[F], + ) -> VerificationResult> where Standard: Distribution, H: DuplexSpongeInterface, @@ -62,55 +69,338 @@ impl ProtocolConfig> { U64: Codec<[H::U]>, Hash: ProverMessage<[H::U]>, { - assert_eq!(forms.len(), values.len(), "forms.len() != values.len()"); - assert!(!forms.is_empty(), "zook requires ≥ 1 (form, value) pair"); + assert_eq!( + linear_forms.len(), + evaluations.len(), + "linear_forms.len() != evaluations.len()" + ); + assert!( + !linear_forms.is_empty(), + "zook requires ≥ 1 (form, value) pair" + ); // RLC challenge binds the form/value set to the commitment. - let gamma: F = vs.verifier_message(); - let gamma_powers = geometric_sequence(gamma, forms.len()); - - // Materialize the combined covector = Σ γ^j · form_j and combined value. - let mut covector = vec![F::ZERO; self.tuning.vector_size]; - for (form, &g) in forms.iter().zip(&gamma_powers) { - form.accumulate(&mut covector, g); + let batching_challenge: F = vs.verifier_message(); + let claim_weights = geometric_sequence(batching_challenge, linear_forms.len()); + let batched_evaluation: F = evaluations + .iter() + .zip(&claim_weights) + .map(|(v, weight)| *v * weight) + .sum(); + + // Basecase-only path: no rounds, evaluate directly. + if self.rounds.is_empty() { + let opening = + self.basecase + .verify(vs, &commitment.irs_commitment, batched_evaluation)?; + // No constraint terms, no round scalings: linear_forms_contribution = opening.linear_form_evaluation, + // initial_claim_scale = F::ONE. The caller checks: + // F::ONE × Σ_j claim_weight_j × form_j.mle_at(evaluation_point) == linear_forms_contribution + return Ok(FinalClaim { + evaluation_point: opening.evaluation_points, + initial_claim_scale: F::ONE, + linear_forms_contribution: opening.linear_form_evaluation, + rlc_coefficients: claim_weights, + }); } - let combined_value: F = values.iter().zip(&gamma_powers).map(|(v, g)| *v * g).sum(); - - // Reduce to basecase inputs `(irs_commitment, covector, sum)`. After - // any rounds, the last round's target commitment IS the basecase - // commitment (`derive` makes `basecase.commit == last.code_switch.target`), - // so we reuse `state.irs_commitment` directly — no separate receive. - let (irs_commitment, covector, sum) = if self.rounds.is_empty() { - (commitment.irs_commitment, covector, combined_value) - } else { - let mut state = VerifierRoundState { - irs_commitment: commitment.irs_commitment, - covector, - sum: combined_value, - }; - for round in &self.rounds { - state = verify_round(round, state, vs)?; - } - (state.irs_commitment, state.covector, state.sum) + + // Multi-round path: accumulate constraints implicitly, no initial_covector. + let mut state = VerifierRoundState { + irs_commitment: commitment.irs_commitment, + constraints: Vec::new(), + all_round_challenges: Vec::new(), + challenges_at: vec![0], + round_scale_factors: Vec::new(), + current_msg_len: self.tuning.vector_size, + sum: batched_evaluation, }; + for round in &self.rounds { + state = verify_round(round, state, vs)?; + } - let opening = self.basecase.verify(vs, &irs_commitment, sum)?; + let opening = self.basecase.verify(vs, &state.irs_commitment, state.sum)?; - // Final consistency: the implicit linear form revealed by basecase - // must equal the MLE of our covector at the sumcheck challenge point. - let expected = multilinear_extend(&covector, &opening.evaluation_points); - verify!(expected == opening.linear_form_evaluation); - Ok(()) + // full_eval_point = all round challenges ++ basecase evaluation points. + let full_eval_point: Vec = state + .all_round_challenges + .iter() + .chain(opening.evaluation_points.iter()) + .copied() + .collect(); + debug_assert_eq!( + full_eval_point.len(), + self.tuning.vector_size.trailing_zeros() as usize, + "full_eval_point length must equal log2(vector_size)" + ); + + // Compute scale_suffixes[round_idx] = Π_{r'=round_idx..num_completed_rounds-1} round_scale_factors[r']. + let num_completed_rounds = state.round_scale_factors.len(); + let mut scale_suffixes = vec![F::ONE; num_completed_rounds + 1]; + for round_idx in (0..num_completed_rounds).rev() { + scale_suffixes[round_idx] = + state.round_scale_factors[round_idx] * scale_suffixes[round_idx + 1]; + } + + // Compute constraint contributions (O(num_constraints × log N)). + // constraint_sum = Σ_c c.batching_weight × scale_suffixes[c.round+1] × mle_of_geom(c.eval_point, z_suffix) + let constraint_sum: F = state + .constraints + .iter() + .map(|c| { + let z_suffix_start = state.challenges_at[c.added_at_round + 1]; + let z_suffix = + &full_eval_point[z_suffix_start..z_suffix_start + c.domain_bits as usize]; + c.batching_weight + * scale_suffixes[c.added_at_round + 1] + * UnivariateEvaluation::new(c.eval_point, 1usize << c.domain_bits) + .mle_evaluate(z_suffix) + }) + .sum(); + + // linear_forms_contribution is what scale_suffixes[0] × initial_forms_mle must equal. + // The caller verifies: scale_suffixes[0] × Σ_j γ^j × form_j.mle_at(full_eval_point) == linear_forms_contribution + let linear_forms_contribution = opening.linear_form_evaluation - constraint_sum; + + Ok(FinalClaim { + evaluation_point: full_eval_point, + initial_claim_scale: scale_suffixes[0], + linear_forms_contribution, + rlc_coefficients: claim_weights, + }) } } struct VerifierRoundState { irs_commitment: IrsCommitment, - covector: Vec, + /// Deferred constraint terms from code_switch OOD and in-domain constraints. + constraints: Vec>, + /// All sumcheck round challenges seen so far, in order. + all_round_challenges: Vec, + /// `challenges_at[r]` = number of cumulative challenges before round r. + /// Length = number of rounds processed + 1 (initial entry is 0). + challenges_at: Vec, + /// `original_sl_coeff` from each round's code_switch, in order. + round_scale_factors: Vec, + /// Current message length (size of post-fold covector this round). + current_msg_len: usize, sum: F, } -#[cfg_attr(feature = "tracing", instrument(skip_all, name = "zook::verify_round", fields(msg_len = round.code_switch.source.message_length(), covector_len = state.covector.len())))] +struct ImplicitConstraint { + /// OOD alpha or in-domain omega. + eval_point: F, + /// RLC coefficient at the time this constraint was added. + batching_weight: F, + /// log2 of the effective domain size = log2(msg_len when added). + /// The final contribution is batching_weight * mle_evaluate(eval_point, z_suffix) + /// where z_suffix has exactly this many elements. + domain_bits: u32, + /// Index into `round_scale_factors`: which round added this constraint. + /// `scale_suffixes[added_at_round + 1]` is the product of all round_scale_factors + /// from the round after this one to the end. + added_at_round: usize, +} + +/// Manages ZK mask verification state for one round. +/// Mirrors `RoundMaskOracle` in `prover.rs` on the receive side. +/// `Disabled` is the Null Object for Standard mode — all methods return Ok(()) / &[]. +enum RoundMaskOracleCheck<'a, F: Field> { + /// Standard mode: no mask oracle. + Disabled, + /// ZK — sumcheck-masks commitment received; awaiting cs_mask. + SumcheckCommitmentReceived { + mo: &'a MaskOracleConfig, + sc_commitment: mask_proximity::Commitment, + /// source IRS randomness length — needed for zk_tail in verify_and_discharge. + source_mask_len: usize, + }, + /// ZK — both commitments received, sum reconciled; ready to verify and discharge. + ReadyForDischarge { + mo: &'a MaskOracleConfig, + sc_commitment: mask_proximity::Commitment, + cs_commitment: mask_proximity::Commitment, + /// δ = Σ Mᵢ(rᵢ) received from transcript; bound by the sumcheck-mask opening check. + mask_eval_sum: F, + source_mask_len: usize, + }, +} + +impl<'a, F: Field + Default> RoundMaskOracleCheck<'a, F> { + /// Receive the sumcheck-masks commitment (ZK) or construct Disabled (Standard). + fn begin( + round: &'a RoundConfig>, + vs: &mut VerifierState, + ) -> VerificationResult + where + F: Codec<[H::U]>, + H: DuplexSpongeInterface, + Hash: ProverMessage<[H::U]>, + { + match round.mask_oracle.as_ref() { + None => Ok(Self::Disabled), + Some(mo) => { + let sc_commitment = mo.sumcheck_masks.receive_commitment(vs)?; + let source_mask_len = round.code_switch.source.mask_length(); + Ok(Self::SumcheckCommitmentReceived { + mo, + sc_commitment, + source_mask_len, + }) + } + } + } + + /// Receive the cs_mask commitment + mask_eval_sum (δ) cleartext, then reconcile *sum. + /// Transitions SumcheckCommitmentReceived → ReadyForDischarge in place. + /// No-op for Disabled. + fn receive_cs_mask_and_reconcile( + &mut self, + opening: &SumcheckOpening, + vs: &mut VerifierState, + sum: &mut F, + ) -> VerificationResult<()> + where + F: Codec<[H::U]>, + H: DuplexSpongeInterface, + Hash: ProverMessage<[H::U]>, + { + let (mo, sc_commitment, source_mask_len) = match std::mem::replace(self, Self::Disabled) { + Self::SumcheckCommitmentReceived { + mo, + sc_commitment, + source_mask_len, + } => (mo, sc_commitment, source_mask_len), + other => { + *self = other; + return Ok(()); + } + }; + + let cs_commitment = mo.cs_mask.receive_commitment(vs)?; + let mask_eval_sum: F = vs.prover_message()?; + + // Reconcile: sum was (mask_eval_sum + mask_rlc · dot), now (sum − δ)/mask_rlc = dot. + // mask_rlc is Fiat–Shamir; zero has negligible probability for large fields. + let mask_rlc_inv = opening + .mask_rlc + .inverse() + .expect("mask_rlc non-zero (negligible probability for large fields)"); + *sum = (*sum - mask_eval_sum) * mask_rlc_inv; + + *self = Self::ReadyForDischarge { + mo, + sc_commitment, + cs_commitment, + mask_eval_sum, + source_mask_len, + }; + Ok(()) + } + + /// Verify both mask trees and subtract the cs_mask contribution from *sum. + /// No-op for Disabled. + /// + /// Soundness: `code_switch.verify_for_implicit` (step 4) checked OOD/in-domain + /// consistency of the target codeword but did NOT verify the masks are close to + /// C_zk. Both checks are load-bearing per Theorem 9.10 / Construction 7.2. + fn verify_and_discharge( + self, + round_challenges: &[F], + msg_len: usize, + update_params: &CovectorUpdateParams, + vs: &mut VerifierState, + sum: &mut F, + ) -> VerificationResult<()> + where + F: Codec<[H::U]>, + H: DuplexSpongeInterface, + u8: Decoding<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + let (mo, sc_commitment, cs_commitment, mask_eval_sum, source_mask_len) = match self { + Self::ReadyForDischarge { + mo, + sc_commitment, + cs_commitment, + mask_eval_sum, + source_mask_len, + } => ( + mo, + sc_commitment, + cs_commitment, + mask_eval_sum, + source_mask_len, + ), + Self::Disabled => return Ok(()), + Self::SumcheckCommitmentReceived { .. } => { + debug_assert!( + false, + "verify_and_discharge called before receive_cs_mask_and_reconcile" + ); + return Ok(()); + } + }; + + // --- Sumcheck-masks tree --- + // Verify each mask at its geometric covector [1, rᵢ, rᵢ², …]; check Σ Xᵢ == mask_eval_sum. + let sumcheck_vec_size = mo.sumcheck_masks.c_zk_commit.vector_size; + let sm_covectors: Vec> = round_challenges + .iter() + .map(|&r| geometric_sequence(r, sumcheck_vec_size)) + .collect(); + let sm_refs: Vec<&[F]> = sm_covectors.iter().map(Vec::as_slice).collect(); + let sm_x_values = mo + .sumcheck_masks + .verify(vs, &sc_commitment, Some(&sm_refs))? + .expect("sumcheck-mask values always returned when covectors passed"); + let sumcheck_x_sum: F = sm_x_values.iter().copied().sum(); + verify!(sumcheck_x_sum == mask_eval_sum); + + // --- cs_mask tree --- + // Reconstruct zk_tail: the covector region [msg_len .. msg_len + mask_oracle_len]. + // OOD points contribute over the full mask_oracle_len; in-domain only over source_mask_len. + let mask_oracle_len = mo.mask_oracle_len.get(); + let mut zk_tail = vec![F::ZERO; mask_oracle_len]; + + for (coeff, alpha) in update_params + .ood_rlc_coeffs + .iter() + .zip(&update_params.ood_eval_points) + { + let alpha_msg_pow = alpha.pow([msg_len as u64]); + let mut alpha_l = alpha_msg_pow; + for entry in &mut zk_tail { + *entry += *coeff * alpha_l; + alpha_l *= *alpha; + } + } + for (coeff, omega) in update_params + .in_domain_rlc_coeffs + .iter() + .zip(&update_params.in_domain_eval_points) + { + let omega_msg_pow = omega.pow([msg_len as u64]); + let mut omega_l = omega_msg_pow; + for entry in &mut zk_tail[..source_mask_len] { + *entry += *coeff * omega_l; + omega_l *= *omega; + } + } + + let cs_cov: [&[F]; 1] = [&zk_tail]; + let cs_x_values = mo + .cs_mask + .verify(vs, &cs_commitment, Some(&cs_cov))? + .expect("cs_mask value always returned when covector passed"); + *sum -= cs_x_values[0]; + + Ok(()) + } +} + +#[cfg_attr(feature = "tracing", instrument(skip_all, name = "zook::verify_round", fields(msg_len = round.code_switch.source.message_length())))] fn verify_round( round: &RoundConfig>, mut state: VerifierRoundState, @@ -127,84 +417,67 @@ where { let msg_len = round.code_switch.source.message_length(); - // (1) Receive the sumcheck-masks tree commitment (pre-sumcheck) so each - // sumcheck round poly is bound to its mask. - let sumcheck_commitment = match round.mask_oracle.as_ref() { - Some(mo) => Some(mo.sumcheck_masks.receive_commitment(vs)?), - None => None, - }; + // Receive sumcheck-masks commitment (ZK) or construct Disabled (Standard). + let mut masker = RoundMaskOracleCheck::begin(round, vs)?; - // (2) ZK sumcheck. Mutates sum to `δ + γ_sumcheck · dot`; orchestrator - // folds its public covector by the round challenges. + // Sumcheck: mutates sum, records round challenges for full_eval_point reconstruction. let opening = round.sumcheck.verify(vs, &mut state.sum)?; - for &r in &opening.round_challenges { - fold(&mut state.covector, r); - } + state + .all_round_challenges + .extend_from_slice(&opening.round_challenges); + state.current_msg_len = msg_len; - // (3) Receive cs_mask tree commitment (post-sumcheck) + δ cleartext. - // Reconcile sum to the unmasked dot. - let cs_mask_commitment = match round.mask_oracle.as_ref() { - Some(mo) => Some(mo.cs_mask.receive_commitment(vs)?), - None => None, - }; - let delta: Option = if round.mask_oracle.is_some() { - let d: F = vs.prover_message()?; - let gamma_inv = opening - .mask_rlc - .inverse() - .expect("sumcheck mask_rlc must be non-zero"); - state.sum = (state.sum - d) * gamma_inv; - Some(d) - } else { - None - }; + // Receive cs_mask commitment + mask_eval_sum (δ), reconcile sum to the unmasked dot. + masker.receive_cs_mask_and_reconcile(&opening, vs, &mut state.sum)?; - // (4) Extend covector and run code-switch. - if let Some(mo) = round.mask_oracle.as_ref() { - state.covector.resize(msg_len + mo.l_zk.get(), F::ZERO); - } - let target_commitment = round.code_switch.verify( + // Code-switch: accumulate implicit constraints; no explicit covector update. + let (target_commitment, update_params) = round.code_switch.verify_for_implicit( vs, &mut state.sum, - &mut state.covector, &opening.round_challenges, &state.irs_commitment, )?; - - // (5) Discharge sumcheck masks (bind δ) and cs_mask (subtract X_cs). - if let (Some(mo), Some(scm), Some(cmc), Some(d)) = ( - round.mask_oracle.as_ref(), - sumcheck_commitment, - cs_mask_commitment, - delta, - ) { - // Sumcheck-masks tree: per-mask covectors [1, r_i, …, r_i^{vec_size−1}]. - let sumcheck_vec_size = mo.sumcheck_masks.c_zk_commit.vector_size; - let sm_covectors: Vec> = opening - .round_challenges - .iter() - .map(|&r| geometric_sequence(r, sumcheck_vec_size)) - .collect(); - let sm_refs: Vec<&[F]> = sm_covectors.iter().map(|v| v.as_slice()).collect(); - let sm_x_values = mo - .sumcheck_masks - .verify(vs, &scm, Some(&sm_refs))? - .expect("sumcheck-mask values always returned when covectors passed"); - let sumcheck_x_sum: F = sm_x_values.iter().copied().sum(); - verify!(sumcheck_x_sum == d); - - // cs_mask tree: single covector at the post-cs covector mask region. - let cs_cov: [&[F]; 1] = [&state.covector[msg_len..]]; - let cs_x_values = mo - .cs_mask - .verify(vs, &cmc, Some(&cs_cov))? - .expect("cs_mask value always returned when covector passed"); - state.sum -= cs_x_values[0]; + let current_round = state.round_scale_factors.len(); + let domain_bits = msg_len.trailing_zeros(); + for (batching_weight, eval_point) in update_params + .ood_rlc_coeffs + .iter() + .zip(&update_params.ood_eval_points) + { + state.constraints.push(ImplicitConstraint { + eval_point: *eval_point, + batching_weight: *batching_weight, + domain_bits, + added_at_round: current_round, + }); } - - state.covector.truncate(msg_len); + for (batching_weight, eval_point) in update_params + .in_domain_rlc_coeffs + .iter() + .zip(&update_params.in_domain_eval_points) + { + state.constraints.push(ImplicitConstraint { + eval_point: *eval_point, + batching_weight: *batching_weight, + domain_bits, + added_at_round: current_round, + }); + } + state + .round_scale_factors + .push(update_params.original_sl_coeff); + state.challenges_at.push(state.all_round_challenges.len()); state.irs_commitment = target_commitment; + // Verify both mask trees and subtract cs_mask contribution from sum. + masker.verify_and_discharge( + &opening.round_challenges, + msg_len, + &update_params, + vs, + &mut state.sum, + )?; + Ok(state) } @@ -283,9 +556,12 @@ mod tests { let mut verifier_state = VerifierState::new_std(&ds, &proof); let commitment = config.receive_commitment(&mut verifier_state).unwrap(); - config + let claim = config .verify(&mut verifier_state, commitment, &[&form], &[value]) .unwrap(); + // Final claim check for MultilinearExtension form: + // initial_claim_scale × γ^0 × form.mle_evaluate(full_eval_point) == linear_forms_contribution + claim.verify(&[&form]).expect("FinalClaim::verify failed"); verifier_state.check_eof().unwrap(); } From 1b2529ce22078039a3e7b3bba37c683d03f401a5 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Mon, 1 Jun 2026 14:51:43 +0530 Subject: [PATCH 5/6] feat : added proper tests --- src/protocols/irs_commit.rs | 28 +-- src/protocols/zook/mod.rs | 437 +++++++++++++++++++++++++++++---- src/protocols/zook/prover.rs | 141 ----------- src/protocols/zook/verifier.rs | 111 --------- 4 files changed, 400 insertions(+), 317 deletions(-) diff --git a/src/protocols/irs_commit.rs b/src/protocols/irs_commit.rs index 41d12910..21b03436 100644 --- a/src/protocols/irs_commit.rs +++ b/src/protocols/irs_commit.rs @@ -140,28 +140,22 @@ impl Config { assert!(rate > 0. && rate <= 1.); let message_length = vector_size / interleaving_depth; let masked_message_length = message_length + mode.mask_length(); - // Rate against the unmasked `message_length` when the mask is small - // relative to the message. The mask is `r ≈ in_domain + OOD` per - // Lemma 9.5 (~100 elements at λ=128) — typically much smaller than - // the rate slack of the NTT-smooth codeword, so it fits inside - // without forcing a `next_order` jump (e.g. avoids 524288 → 589824, - // a 12.5% jump on a 2^17-message commit). The resulting effective - // rate is `masked / codeword` ≈ requested + (mask/codeword), which - // for small-mask cases is a negligible degradation (≈1 extra - // in-domain query per round). When the mask is comparable to the - // message (tail rounds at very low rates), this optimization would - // blow the rate up and the parameter-selection fixpoint in - // `params::irs_commit::solve` wouldn't converge — fall back to rating - // the masked length there. Threshold `5·mask ≤ message` keeps the - // degradation under ~20%. + // Use the unmasked codeword when it (a) fits the masked message and + // (b) keeps effective rate within 20% of the requested rate. This + // avoids a `next_order` jump caused by the small IRS randomness mask + // pushing the masked length past an NTT-smooth boundary (e.g. avoids + // 524288 → 589824, a 12.5% jump on a 2^17-message commit). + // When the mask is large enough that the effective rate would exceed + // 1.2 × requested (tail rounds at very low rates), fall back to + // sizing the codeword from the masked length. #[allow(clippy::cast_sign_loss)] let codeword_length = { - let mask_len = mode.mask_length(); - let small_mask = 5 * mask_len <= message_length; let unmasked_target = (message_length as f64 / rate).ceil() as usize; let unmasked = ntt::next_order::(unmasked_target) .expect("codeword length exceeds NTT engine support"); - if small_mask && unmasked >= masked_message_length { + let fits_masked = unmasked >= masked_message_length; + let effective_rate = masked_message_length as f64 / unmasked as f64; + if fits_masked && effective_rate <= 1.2 * rate { unmasked } else { let masked_target = (masked_message_length as f64 / rate).ceil() as usize; diff --git a/src/protocols/zook/mod.rs b/src/protocols/zook/mod.rs index 9c1bb1c2..a0c50891 100644 --- a/src/protocols/zook/mod.rs +++ b/src/protocols/zook/mod.rs @@ -67,7 +67,7 @@ mod tests { use crate::{ algebra::{ embedding::Identity, - fields::Field256, + fields::{Field256, Field64}, linear_form::{Evaluate, LinearForm, MultilinearExtension}, random_vector, }, @@ -78,25 +78,380 @@ mod tests { transcript::{codecs::Empty, DomainSeparator, ProverState, VerifierState}, }; - type F = Field256; - type Embed = Identity; + // ── Small-test field and config (Field64, λ=40, fast) ─────────────────── - /// Keep PoW below the 60-bit cap during `derive` for the test tuning. - const TEST_TARGET_BITS: u32 = 128; + type SmallF = Field64; + type SmallEmbed = Identity; - fn test_spec(mode: Mode) -> SecuritySpec { + fn small_spec(mode: Mode) -> SecuritySpec { SecuritySpec { mode, decoding_regime: DecodingRegime::Johnson, - target_security_bits: TEST_TARGET_BITS, + target_security_bits: 40, + pow_budget: PowBudget::per_slot(10), + hash_id: hash::BLAKE3, + } + } + + /// 2^8 witness, folding_factor 2 → multiple rounds. + fn multi_round_tuning() -> TuningSpec { + TuningSpec { + vector_size: 1 << 8, + starting_log_inv_rate: 1, + folding_factor: FoldingFactor::Constant(2), + rate_schedule: RateSchedule::Stepping, + } + } + + /// 2^3 witness → basecase-only (too small for any round). + fn basecase_tuning() -> TuningSpec { + TuningSpec { + vector_size: 1 << 3, + starting_log_inv_rate: 1, + folding_factor: FoldingFactor::Constant(2), + rate_schedule: RateSchedule::Stepping, + } + } + + // ── Test helpers ───────────────────────────────────────────────────────── + + /// A shared `&'static Empty` so that the returned `DomainSeparator` carries + /// a `'static` lifetime (required by `VerifierState::new_std`). + static EMPTY: Empty = Empty; + + /// Construct a deterministic `DomainSeparator` for the given label. + fn make_ds(label: &str) -> DomainSeparator<'static, Empty> { + DomainSeparator::protocol(&"zook-test") + .session(&label.to_string()) + .instance(&EMPTY) + } + + /// Build a witness, compute true evaluations, prove, and return the proof + /// along with the domain separator, forms, and values for further checks. + fn build_and_prove( + config: &ProtocolConfig, + num_claims: usize, + seed: u64, + label: &str, + ) -> ( + crate::transcript::Proof, + DomainSeparator<'static, Empty>, + Vec>, + Vec, + ) { + let embedding = ::default(); + let mut rng = StdRng::seed_from_u64(seed); + let witness: Vec = random_vector(&mut rng, config.tuning.vector_size); + let mu = config.tuning.vector_size.trailing_zeros() as usize; + + let forms: Vec> = (0..num_claims) + .map(|_| MultilinearExtension { + point: random_vector::(&mut rng, mu), + }) + .collect(); + let values: Vec = forms + .iter() + .map(|f| f.evaluate(&embedding, &witness)) + .collect(); + let form_refs: Vec<&dyn LinearForm> = + forms.iter().map(|f| f as &dyn LinearForm).collect(); + + let ds = make_ds(label); + let mut ps = ProverState::new_std(&ds); + let committed = config.commit(&mut ps, &witness); + config.prove(&mut ps, committed, &form_refs, &values); + let proof = ps.proof(); + + (proof, ds, forms, values) + } + + /// Run a full roundtrip: commit → prove → verify → FinalClaim::verify. + /// Panics if anything fails. + fn full_roundtrip( + config: &ProtocolConfig, + num_claims: usize, + seed: u64, + label: &str, + ) { + let (proof, ds, forms, values) = build_and_prove(config, num_claims, seed, label); + let form_refs: Vec<&dyn LinearForm> = + forms.iter().map(|f| f as &dyn LinearForm).collect(); + + let mut vs = VerifierState::new_std(&ds, &proof); + let commitment = config.receive_commitment(&mut vs).unwrap(); + let claim = config + .verify(&mut vs, commitment, &form_refs, &values) + .unwrap(); + claim.verify(&form_refs).expect("FinalClaim::verify failed"); + vs.check_eof().unwrap(); + } + + /// Expect verification to fail (handles both `verifier_panics` and normal builds). + fn assert_verify_rejected(verify: impl FnOnce() -> crate::transcript::VerificationResult<()>) { + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(verify)); + match result { + Err(_) | Ok(Err(_)) => {} // panicked or returned Err — failure as expected + Ok(Ok(())) => panic!("expected verification to fail but it succeeded"), + } + } + + // ── Positive tests: happy paths ────────────────────────────────────────── + + #[test] + fn roundtrip_zk_with_rounds() { + let config = ProtocolConfig::::derive( + small_spec(Mode::ZeroKnowledge), + multi_round_tuning(), + ) + .unwrap(); + assert!(!config.rounds.is_empty(), "expected at least one round"); + full_roundtrip(&config, 1, 0, "roundtrip_zk_with_rounds"); + } + + #[test] + fn roundtrip_standard_with_rounds() { + let config = + ProtocolConfig::::derive(small_spec(Mode::Standard), multi_round_tuning()) + .unwrap(); + assert!(!config.rounds.is_empty(), "expected at least one round"); + full_roundtrip(&config, 1, 1, "roundtrip_standard_with_rounds"); + } + + #[test] + fn roundtrip_zk_basecase_only() { + let config = ProtocolConfig::::derive( + small_spec(Mode::ZeroKnowledge), + basecase_tuning(), + ) + .unwrap(); + assert!(config.rounds.is_empty(), "expected basecase-only plan"); + full_roundtrip(&config, 1, 2, "roundtrip_zk_basecase_only"); + } + + #[test] + fn roundtrip_standard_basecase_only() { + let config = + ProtocolConfig::::derive(small_spec(Mode::Standard), basecase_tuning()) + .unwrap(); + assert!(config.rounds.is_empty(), "expected basecase-only plan"); + full_roundtrip(&config, 1, 3, "roundtrip_standard_basecase_only"); + } + + #[test] + fn roundtrip_multiple_claims_zk() { + let config = ProtocolConfig::::derive( + small_spec(Mode::ZeroKnowledge), + multi_round_tuning(), + ) + .unwrap(); + full_roundtrip(&config, 4, 4, "roundtrip_multiple_claims_zk"); + } + + #[test] + fn roundtrip_multiple_claims_standard() { + let config = + ProtocolConfig::::derive(small_spec(Mode::Standard), multi_round_tuning()) + .unwrap(); + full_roundtrip(&config, 4, 5, "roundtrip_multiple_claims_standard"); + } + + // ── Negative tests: input validation ───────────────────────────────────── + + #[test] + #[should_panic(expected = "linear_forms.len() != evaluations.len()")] + fn prove_rejects_form_evaluation_count_mismatch() { + let config = + ProtocolConfig::::derive(small_spec(Mode::Standard), multi_round_tuning()) + .unwrap(); + let mut rng = StdRng::seed_from_u64(0); + let witness: Vec = random_vector(&mut rng, config.tuning.vector_size); + let mu = config.tuning.vector_size.trailing_zeros() as usize; + let form: MultilinearExtension = MultilinearExtension { + point: random_vector(&mut rng, mu), + }; + let ds = DomainSeparator::protocol(&"zook-test") + .session(&"count-mismatch".to_string()) + .instance(&Empty); + let mut ps = ProverState::new_std(&ds); + let committed = config.commit(&mut ps, &witness); + // 1 form but 2 values — should panic + config.prove( + &mut ps, + committed, + &[&form as &dyn LinearForm], + &[SmallF::from(1u64), SmallF::from(2u64)], + ); + } + + #[test] + #[should_panic(expected = "zook requires ≥ 1")] + fn prove_rejects_empty_forms() { + let config = + ProtocolConfig::::derive(small_spec(Mode::Standard), multi_round_tuning()) + .unwrap(); + let mut rng = StdRng::seed_from_u64(0); + let witness: Vec = random_vector(&mut rng, config.tuning.vector_size); + let ds = DomainSeparator::protocol(&"zook-test") + .session(&"empty-forms".to_string()) + .instance(&Empty); + let mut ps = ProverState::new_std(&ds); + let committed = config.commit(&mut ps, &witness); + // No forms at all — should panic + config.prove(&mut ps, committed, &[], &[]); + } + + // ── Negative tests: wrong claimed values ────────────────────────────────── + + #[test] + fn verify_rejects_wrong_evaluation_zk() { + let config = ProtocolConfig::::derive( + small_spec(Mode::ZeroKnowledge), + multi_round_tuning(), + ) + .unwrap(); + let (proof, ds, forms, true_values) = + build_and_prove(&config, 1, 10, "verify_rejects_wrong_evaluation_zk"); + + let form_refs: Vec<&dyn LinearForm> = + forms.iter().map(|f| f as &dyn LinearForm).collect(); + // Shift the claimed value by 1 — the proof is now inconsistent with this claim + let wrong_values = vec![true_values[0] + SmallF::from(1u64)]; + + assert_verify_rejected(|| { + let mut vs = VerifierState::new_std(&ds, &proof); + let commitment = config.receive_commitment(&mut vs)?; + let claim = config.verify(&mut vs, commitment, &form_refs, &wrong_values)?; + claim.verify(&form_refs)?; + vs.check_eof() + }); + } + + #[test] + fn verify_rejects_wrong_evaluation_standard() { + let config = + ProtocolConfig::::derive(small_spec(Mode::Standard), multi_round_tuning()) + .unwrap(); + let (proof, ds, forms, true_values) = + build_and_prove(&config, 1, 11, "verify_rejects_wrong_evaluation_standard"); + + let form_refs: Vec<&dyn LinearForm> = + forms.iter().map(|f| f as &dyn LinearForm).collect(); + let wrong_values = vec![true_values[0] + SmallF::from(1u64)]; + + assert_verify_rejected(|| { + let mut vs = VerifierState::new_std(&ds, &proof); + let commitment = config.receive_commitment(&mut vs)?; + let claim = config.verify(&mut vs, commitment, &form_refs, &wrong_values)?; + claim.verify(&form_refs)?; + vs.check_eof() + }); + } + + #[test] + fn verify_rejects_wrong_evaluation_multiple_claims() { + let config = ProtocolConfig::::derive( + small_spec(Mode::ZeroKnowledge), + multi_round_tuning(), + ) + .unwrap(); + let (proof, ds, forms, mut values) = build_and_prove( + &config, + 3, + 12, + "verify_rejects_wrong_evaluation_multiple_claims", + ); + + let form_refs: Vec<&dyn LinearForm> = + forms.iter().map(|f| f as &dyn LinearForm).collect(); + // Corrupt the second claim + values[1] += SmallF::from(42u64); + + assert_verify_rejected(|| { + let mut vs = VerifierState::new_std(&ds, &proof); + let commitment = config.receive_commitment(&mut vs)?; + let claim = config.verify(&mut vs, commitment, &form_refs, &values)?; + claim.verify(&form_refs)?; + vs.check_eof() + }); + } + + #[test] + fn final_claim_rejects_wrong_form() { + // Correct proof + correct verify, but FinalClaim::verify called with a + // different form → must be detected. + let config = ProtocolConfig::::derive( + small_spec(Mode::ZeroKnowledge), + multi_round_tuning(), + ) + .unwrap(); + let (proof, ds, forms, values) = + build_and_prove(&config, 1, 20, "final_claim_rejects_wrong_form"); + + let form_refs: Vec<&dyn LinearForm> = + forms.iter().map(|f| f as &dyn LinearForm).collect(); + + let mut vs = VerifierState::new_std(&ds, &proof); + let commitment = config.receive_commitment(&mut vs).unwrap(); + let claim = config + .verify(&mut vs, commitment, &form_refs, &values) + .unwrap(); + + // Construct a different form at a different evaluation point + let mut rng = StdRng::seed_from_u64(999); + let mu = config.tuning.vector_size.trailing_zeros() as usize; + let wrong_form: MultilinearExtension = MultilinearExtension { + point: random_vector(&mut rng, mu), + }; + + assert_verify_rejected(|| claim.verify(&[&wrong_form as &dyn LinearForm])); + } + + // ── Negative tests: tampered proof ──────────────────────────────────────── + + #[test] + fn verify_rejects_tampered_proof() { + let config = ProtocolConfig::::derive( + small_spec(Mode::ZeroKnowledge), + multi_round_tuning(), + ) + .unwrap(); + let (mut proof, ds, forms, values) = + build_and_prove(&config, 1, 30, "verify_rejects_tampered_proof"); + + let form_refs: Vec<&dyn LinearForm> = + forms.iter().map(|f| f as &dyn LinearForm).collect(); + + // Flip a byte in the middle of the transcript to corrupt the proof + let mid = proof.narg_string.len() / 2; + proof.narg_string[mid] ^= 0xff; + + assert_verify_rejected(|| { + let mut vs = VerifierState::new_std(&ds, &proof); + let commitment = config.receive_commitment(&mut vs)?; + let claim = config.verify(&mut vs, commitment, &form_refs, &values)?; + claim.verify(&form_refs)?; + vs.check_eof() + }); + } + + // ── Large integration tests (Field256, λ=128, realistic) ───────────────── + // These cover the full security target with a 2^19 witness and multiple + // folding rounds. Kept for end-to-end regression coverage. + + type LargeF = Field256; + type LargeEmbed = Identity; + + fn large_spec(mode: Mode) -> SecuritySpec { + SecuritySpec { + mode, + decoding_regime: DecodingRegime::Johnson, + target_security_bits: 128, pow_budget: PowBudget::per_slot(10), hash_id: hash::SHA2, } } - /// 2^20-sized witness; folding_factor 4 ⇒ 5 rounds (keeps round count - /// modest at this size). - fn tuning_2_pow_20() -> TuningSpec { + fn large_tuning() -> TuningSpec { TuningSpec { vector_size: 1 << 19, starting_log_inv_rate: 2, @@ -108,67 +463,53 @@ mod tests { } } - /// Commit → prove → verify roundtrip on a 2^20 witness with 3 multilinear - /// claims. Shared by the ZK and Standard tests below. - fn roundtrip_three_claims(mode: Mode, seed: u64) { + fn large_roundtrip(mode: Mode, seed: u64) { crate::tests::init(); - let config = ProtocolConfig::::derive(test_spec(mode), tuning_2_pow_20()).unwrap(); + let config = + ProtocolConfig::::derive(large_spec(mode), large_tuning()).unwrap(); let mut rng = StdRng::seed_from_u64(seed); - let witness: Vec = random_vector(&mut rng, config.tuning.vector_size); + let witness: Vec = random_vector(&mut rng, config.tuning.vector_size); let mu = config.tuning.vector_size.trailing_zeros() as usize; + let embedding = ::default(); - // 3 random multilinear claims and their true evaluations against the - // witness. - let embedding = ::default(); - let forms: Vec> = (0..3) + let forms: Vec> = (0..3) .map(|_| MultilinearExtension { - point: random_vector::(&mut rng, mu), + point: random_vector::(&mut rng, mu), }) .collect(); - let values: Vec = forms + let values: Vec = forms .iter() .map(|f| f.evaluate(&embedding, &witness)) .collect(); - let form_refs: Vec<&dyn LinearForm> = - forms.iter().map(|f| f as &dyn LinearForm).collect(); - - let ds = DomainSeparator::protocol(&"zook-mod-test") - .session(&format!( - "three-claims 2^20 mode={:?} {}:{}", - mode, - file!(), - line!() - )) + let form_refs: Vec<&dyn LinearForm> = + forms.iter().map(|f| f as &dyn LinearForm).collect(); + + let ds = DomainSeparator::protocol(&"zook-large-test") + .session(&format!("three-claims mode={mode:?} seed={seed}")) .instance(&Empty); - let mut prover_state = ProverState::new_std(&ds); - let committed = config.commit(&mut prover_state, &witness); - config.prove(&mut prover_state, committed, &form_refs, &values); - let proof = prover_state.proof(); + let mut ps = ProverState::new_std(&ds); + let committed = config.commit(&mut ps, &witness); + config.prove(&mut ps, committed, &form_refs, &values); + let proof = ps.proof(); - let mut verifier_state = VerifierState::new_std(&ds, &proof); - let commitment = config - .receive_commitment(&mut verifier_state) - .expect("receive_commitment"); + let mut vs = VerifierState::new_std(&ds, &proof); + let commitment = config.receive_commitment(&mut vs).unwrap(); let claim = config - .verify(&mut verifier_state, commitment, &form_refs, &values) - .expect("verify"); - // Final claim check for 3 MultilinearExtension forms: - // initial_claim_scale × Σ_j γ^j × form_j.mle_evaluate(z_full) == linear_forms_contribution + .verify(&mut vs, commitment, &form_refs, &values) + .unwrap(); claim.verify(&form_refs).expect("FinalClaim::verify failed"); - verifier_state - .check_eof() - .expect("transcript fully consumed"); + vs.check_eof().unwrap(); } #[test] fn roundtrip_2_pow_20_three_claims_zk() { - roundtrip_three_claims(Mode::ZeroKnowledge, 0); + large_roundtrip(Mode::ZeroKnowledge, 0); } #[test] fn roundtrip_2_pow_20_three_claims_standard() { - roundtrip_three_claims(Mode::Standard, 1); + large_roundtrip(Mode::Standard, 1); } } diff --git a/src/protocols/zook/prover.rs b/src/protocols/zook/prover.rs index a0eea32f..a4135bd1 100644 --- a/src/protocols/zook/prover.rs +++ b/src/protocols/zook/prover.rs @@ -576,144 +576,3 @@ impl<'a, F: Field + Default + Zeroize> RoundMaskOracle<'a, F> { } } } - -#[cfg(test)] -mod tests { - use ark_std::rand::{rngs::StdRng, SeedableRng}; - - use super::*; - use crate::{ - algebra::{ - fields::Field64, - linear_form::{Evaluate, MultilinearExtension}, - }, - hash, - protocols::params::{ - spec::{ - DecodingRegime, FoldingFactor, Mode, PowBudget, RateSchedule, SecuritySpec, - TuningSpec, - }, - test_utils::TestEmbedding, - }, - transcript::{codecs::Empty, DomainSeparator}, - }; - - type F = Field64; - - const TEST_TARGET_BITS: u32 = 40; - - fn test_spec(mode: Mode) -> SecuritySpec { - SecuritySpec { - mode, - decoding_regime: DecodingRegime::Johnson, - target_security_bits: TEST_TARGET_BITS, - pow_budget: PowBudget::per_slot(10), - hash_id: hash::BLAKE3, - } - } - - fn tuning_with_rounds() -> TuningSpec { - TuningSpec { - vector_size: 1 << 8, - starting_log_inv_rate: 1, - folding_factor: FoldingFactor::Constant(2), - rate_schedule: RateSchedule::Stepping, - } - } - - fn tuning_basecase_only() -> TuningSpec { - TuningSpec { - vector_size: 1 << 3, - starting_log_inv_rate: 1, - folding_factor: FoldingFactor::Constant(2), - rate_schedule: RateSchedule::Stepping, - } - } - - /// Domain separator + ProverState setup shared by every test. - macro_rules! test_session { - ($tag:expr) => {{ - let ds = DomainSeparator::protocol(&"zook-prover-test") - .session(&format!("{} {}:{}", $tag, file!(), line!())) - .instance(&Empty); - ProverState::new_std(&ds) - }}; - } - - fn run_prove(config: &ProtocolConfig, seed: u64) { - let mut rng = StdRng::seed_from_u64(seed); - let witness: Vec = random_vector(&mut rng, config.tuning.vector_size); - let mu = config.tuning.vector_size.trailing_zeros() as usize; - let form = MultilinearExtension { - point: random_vector::(&mut rng, mu), - }; - let embedding = as Default>::default(); - let value = form.evaluate(&embedding, &witness); - - let mut prover_state = test_session!("smoke"); - let committed = config.commit(&mut prover_state, &witness); - config.prove(&mut prover_state, committed, &[&form], &[value]); - let _ = prover_state.proof(); - } - - fn smoke(mode: Mode, tuning: TuningSpec, seed: u64, expect_rounds: bool) { - let config = ProtocolConfig::::derive(test_spec(mode), tuning).unwrap(); - assert_eq!(!config.rounds.is_empty(), expect_rounds); - run_prove(&config, seed); - } - - #[test] - fn prove_completes_with_rounds_zk() { - smoke(Mode::ZeroKnowledge, tuning_with_rounds(), 0, true); - } - - #[test] - fn prove_completes_with_rounds_standard() { - smoke(Mode::Standard, tuning_with_rounds(), 1, true); - } - - #[test] - fn prove_completes_basecase_only_zk() { - smoke(Mode::ZeroKnowledge, tuning_basecase_only(), 2, false); - } - - #[test] - fn prove_completes_basecase_only_standard() { - smoke(Mode::Standard, tuning_basecase_only(), 3, false); - } - - #[test] - #[should_panic(expected = "linear_forms.len() != evaluations.len()")] - fn prove_rejects_count_mismatch() { - let config = ProtocolConfig::::derive( - test_spec(Mode::Standard), - tuning_with_rounds(), - ) - .unwrap(); - let mut rng = StdRng::seed_from_u64(4); - let witness: Vec = random_vector(&mut rng, config.tuning.vector_size); - let mu = config.tuning.vector_size.trailing_zeros() as usize; - let form = MultilinearExtension { - point: random_vector::(&mut rng, mu), - }; - let mut prover_state = test_session!("count-mismatch"); - let committed = config.commit(&mut prover_state, &witness); - // 1 form, 2 values → mismatch. - config.prove(&mut prover_state, committed, &[&form], &[F::ONE, F::ONE]); - } - - #[test] - #[should_panic(expected = "zook requires ≥ 1")] - fn prove_rejects_empty() { - let config = ProtocolConfig::::derive( - test_spec(Mode::Standard), - tuning_with_rounds(), - ) - .unwrap(); - let mut rng = StdRng::seed_from_u64(5); - let witness: Vec = random_vector(&mut rng, config.tuning.vector_size); - let mut prover_state = test_session!("empty"); - let committed = config.commit(&mut prover_state, &witness); - config.prove(&mut prover_state, committed, &[], &[]); - } -} diff --git a/src/protocols/zook/verifier.rs b/src/protocols/zook/verifier.rs index 1b238b0d..cc468087 100644 --- a/src/protocols/zook/verifier.rs +++ b/src/protocols/zook/verifier.rs @@ -480,114 +480,3 @@ where Ok(state) } - -#[cfg(test)] -mod tests { - use ark_std::rand::{rngs::StdRng, SeedableRng}; - - use super::*; - use crate::{ - algebra::{ - fields::Field64, - linear_form::{Evaluate, MultilinearExtension}, - random_vector, - }, - hash, - protocols::params::{ - spec::{ - DecodingRegime, FoldingFactor, Mode, PowBudget, RateSchedule, SecuritySpec, - TuningSpec, - }, - test_utils::TestEmbedding, - }, - transcript::{codecs::Empty, DomainSeparator, ProverState}, - }; - - type F = Field64; - - const TEST_TARGET_BITS: u32 = 40; - - fn test_spec(mode: Mode) -> SecuritySpec { - SecuritySpec { - mode, - decoding_regime: DecodingRegime::Johnson, - target_security_bits: TEST_TARGET_BITS, - pow_budget: PowBudget::per_slot(10), - hash_id: hash::BLAKE3, - } - } - - fn tuning_with_rounds() -> TuningSpec { - TuningSpec { - vector_size: 1 << 8, - starting_log_inv_rate: 1, - folding_factor: FoldingFactor::Constant(2), - rate_schedule: RateSchedule::Stepping, - } - } - - fn tuning_basecase_only() -> TuningSpec { - TuningSpec { - vector_size: 1 << 3, - starting_log_inv_rate: 1, - folding_factor: FoldingFactor::Constant(2), - rate_schedule: RateSchedule::Stepping, - } - } - - fn roundtrip(config: &ProtocolConfig, seed: u64) { - let mut rng = StdRng::seed_from_u64(seed); - let witness: Vec = random_vector(&mut rng, config.tuning.vector_size); - let mu = config.tuning.vector_size.trailing_zeros() as usize; - let form = MultilinearExtension { - point: random_vector::(&mut rng, mu), - }; - let embedding = as Default>::default(); - let value = form.evaluate(&embedding, &witness); - - let ds = DomainSeparator::protocol(&"zook-verifier-test") - .session(&format!("roundtrip {}:{}", file!(), line!())) - .instance(&Empty); - - let mut prover_state = ProverState::new_std(&ds); - let committed = config.commit(&mut prover_state, &witness); - config.prove(&mut prover_state, committed, &[&form], &[value]); - let proof = prover_state.proof(); - - let mut verifier_state = VerifierState::new_std(&ds, &proof); - let commitment = config.receive_commitment(&mut verifier_state).unwrap(); - let claim = config - .verify(&mut verifier_state, commitment, &[&form], &[value]) - .unwrap(); - // Final claim check for MultilinearExtension form: - // initial_claim_scale × γ^0 × form.mle_evaluate(full_eval_point) == linear_forms_contribution - claim.verify(&[&form]).expect("FinalClaim::verify failed"); - verifier_state.check_eof().unwrap(); - } - - fn smoke(mode: Mode, tuning: TuningSpec, seed: u64, expect_rounds: bool) { - let config = ProtocolConfig::::derive(test_spec(mode), tuning).unwrap(); - assert_eq!(!config.rounds.is_empty(), expect_rounds); - roundtrip(&config, seed); - } - - #[test] - fn verify_completes_with_rounds_zk() { - smoke(Mode::ZeroKnowledge, tuning_with_rounds(), 0, true); - } - - #[test] - fn verify_completes_with_rounds_standard() { - smoke(Mode::Standard, tuning_with_rounds(), 1, true); - } - - #[test] - fn verify_completes_basecase_only_zk() { - smoke(Mode::ZeroKnowledge, tuning_basecase_only(), 2, false); - } - - #[test] - fn verify_completes_basecase_only_standard() { - smoke(Mode::Standard, tuning_basecase_only(), 3, false); - } -} From 16afaa316f8ca8459a8f6625c4fd75ffe1d0fb77 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Fri, 5 Jun 2026 15:51:06 +0530 Subject: [PATCH 6/6] feat : updated challange indices logic for clear abstraction --- benches/zook_vs_whir.rs | 20 +++---- src/protocols/challenge_indices.rs | 89 ++++++++++++++++++++++-------- 2 files changed, 73 insertions(+), 36 deletions(-) diff --git a/benches/zook_vs_whir.rs b/benches/zook_vs_whir.rs index d750c3fa..c9c017af 100644 --- a/benches/zook_vs_whir.rs +++ b/benches/zook_vs_whir.rs @@ -33,8 +33,8 @@ use whir::{ parameters::ProtocolParameters, protocols::{ params::{ - DecodingRegime, FoldingFactor, KneeWeight, Mode, PowBudget, ProtocolConfig, - RateSchedule, SecuritySpec, TuningSpec, + DecodingRegime, FoldingFactor, Mode, PowBudget, ProtocolConfig, RateSchedule, + SecuritySpec, TuningSpec, }, whir as whir_non_zk, }, @@ -370,21 +370,15 @@ fn main() { Err(e) => eprintln!(" zook_standard_stepping failed: {e}"), } - eprintln!("[2^{log_size}] zook_zk_adaptive ..."); - match bench_zook( - log_size, - Mode::ZeroKnowledge, - RateSchedule::Adaptive { - knee_weight: KneeWeight::DEFAULT, - }, - ) { + eprintln!("[2^{log_size}] zook_zk_stepping ..."); + match bench_zook(log_size, Mode::ZeroKnowledge, RateSchedule::Stepping) { Ok(s) => { let pb = s.proof_bytes; record_cell( &mut csv, log_size, size, - "zook_zk_adaptive", + "zook_zk_stepping", "commit", &s.commit, pb, @@ -393,7 +387,7 @@ fn main() { &mut csv, log_size, size, - "zook_zk_adaptive", + "zook_zk_stepping", "prove", &s.prove, pb, @@ -402,7 +396,7 @@ fn main() { &mut csv, log_size, size, - "zook_zk_adaptive", + "zook_zk_stepping", "verify", &s.verify, pb, diff --git a/src/protocols/challenge_indices.rs b/src/protocols/challenge_indices.rs index 53c7f7af..01c6a028 100644 --- a/src/protocols/challenge_indices.rs +++ b/src/protocols/challenge_indices.rs @@ -21,30 +21,10 @@ where return if deduplicate { vec![0] } else { vec![0; count] }; } - // Size the entropy chunk so `2^(8·size_bytes) ≥ next_pow2(num_leaves)`. - // For pow2 `num_leaves`, the entropy space is an exact multiple of - // `num_leaves` and rejection never triggers — bit-identical to the - // pre-rejection implementation. For non-pow2 `num_leaves`, rejection - // sampling eliminates the modular bias of a plain `% num_leaves`. - let bits_needed = num_leaves.next_power_of_two().ilog2() as usize; - let size_bytes = bits_needed.div_ceil(8); - - // Largest multiple of `num_leaves` below `2^(8·size_bytes)`. u128 - // accommodates `size_bytes ≤ 16` without shift overflow on 64-bit hosts. - let entropy_space: u128 = 1u128 << (8 * size_bytes); - let num_leaves_u = num_leaves as u128; - let threshold: u128 = (entropy_space / num_leaves_u) * num_leaves_u; - + let domain = IndexDomain::new(num_leaves); let mut indices = Vec::with_capacity(count); - while indices.len() < count { - let mut candidate: u128 = 0; - for _ in 0..size_bytes { - candidate = (candidate << 8) | u128::from(transcript.verifier_message::()); - } - if candidate < threshold { - indices.push((candidate % num_leaves_u) as usize); - } - // else: candidate falls in the biased tail — reject and redraw. + for _ in 0..count { + indices.push(domain.sample(transcript)); } if deduplicate { @@ -54,6 +34,69 @@ where indices } +/// Rejection-sampling domain for transcript-derived row indices. +/// +/// For power-of-two domains, `entropy_space` is an exact multiple of +/// `num_leaves`, so rejection never triggers and sampling is bit-identical to +/// the legacy `% num_leaves` implementation. For non-power-of-two domains, +/// candidates in the biased tail are rejected and redrawn. +#[derive(Clone, Copy, Debug)] +struct IndexDomain { + num_leaves: usize, + size_bytes: usize, + threshold: u128, +} + +impl IndexDomain { + fn new(num_leaves: usize) -> Self { + debug_assert!(num_leaves > 1); + let bits_needed = ceil_log2(num_leaves); + let size_bytes = bits_needed.div_ceil(8); + let entropy_bits = 8 * size_bytes; + debug_assert!(entropy_bits < u128::BITS as usize); + + let entropy_space = 1u128 << entropy_bits; + let num_leaves_u = num_leaves as u128; + let threshold = (entropy_space / num_leaves_u) * num_leaves_u; + + Self { + num_leaves, + size_bytes, + threshold, + } + } + + fn sample(&self, transcript: &mut T) -> usize + where + T: VerifierMessage, + u8: Decoding<[T::U]>, + { + loop { + let candidate = self.draw_candidate(transcript); + if candidate < self.threshold { + return (candidate % self.num_leaves as u128) as usize; + } + } + } + + fn draw_candidate(&self, transcript: &mut T) -> u128 + where + T: VerifierMessage, + u8: Decoding<[T::U]>, + { + let mut candidate = 0u128; + for _ in 0..self.size_bytes { + candidate = (candidate << 8) | u128::from(transcript.verifier_message::()); + } + candidate + } +} + +fn ceil_log2(n: usize) -> usize { + debug_assert!(n > 1); + usize::BITS as usize - (n - 1).leading_zeros() as usize +} + #[cfg(test)] mod tests { use super::*;