From a58501a4d0b2a73549f7031c5c144d6e4839f2a8 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Thu, 7 May 2026 17:45:54 +0530 Subject: [PATCH 01/31] feat : added param selection for sumcheck and added param helpers, updated protocols for optimal parameter selection --- src/protocols/basecase.rs | 17 +- src/protocols/irs_commit.rs | 3 +- src/protocols/mod.rs | 1 + src/protocols/params/bounds.rs | 81 ++++++++++ src/protocols/params/irs_commit.rs | 1 + src/protocols/params/mod.rs | 6 + src/protocols/params/spec.rs | 45 ++++++ src/protocols/params/sumcheck.rs | 64 ++++++++ src/protocols/sumcheck.rs | 247 ++++++++++++++++++----------- src/protocols/whir/config.rs | 68 ++++---- 10 files changed, 395 insertions(+), 138 deletions(-) create mode 100644 src/protocols/params/bounds.rs create mode 100644 src/protocols/params/irs_commit.rs create mode 100644 src/protocols/params/mod.rs create mode 100644 src/protocols/params/spec.rs create mode 100644 src/protocols/params/sumcheck.rs diff --git a/src/protocols/basecase.rs b/src/protocols/basecase.rs index 32763271..95739ee4 100644 --- a/src/protocols/basecase.rs +++ b/src/protocols/basecase.rs @@ -241,9 +241,7 @@ mod tests { use tracing::instrument; use super::*; - use crate::{ - algebra::fields, protocols::proof_of_work, transcript::DomainSeparator, type_info::Type, - }; + use crate::{algebra::fields, protocols::proof_of_work, transcript::DomainSeparator}; impl Config { pub fn arbitrary(size: usize, mask_length: usize) -> impl Strategy { @@ -254,13 +252,12 @@ mod tests { out_domain_samples: 0, ..commit }, - sumcheck: sumcheck::Config { - field: Type::new(), - initial_size: size, - round_pow: proof_of_work::Config::none(), - num_rounds: size.next_power_of_two().trailing_zeros() as usize, - mask_length: 0, - }, + sumcheck: sumcheck::Config::new( + size, + proof_of_work::Config::none(), + size.next_power_of_two().trailing_zeros() as usize, + sumcheck::SumcheckMode::Standard, + ), masked, }) } diff --git a/src/protocols/irs_commit.rs b/src/protocols/irs_commit.rs index b5716f3a..db41c356 100644 --- a/src/protocols/irs_commit.rs +++ b/src/protocols/irs_commit.rs @@ -130,6 +130,7 @@ impl Config { vector_size: usize, interleaving_depth: usize, rate: f64, + mask_length: usize, ) -> Self where M: Default, @@ -180,7 +181,7 @@ impl Config { embedding: Typed::::default(), num_vectors, vector_size, - mask_length: 0, + mask_length, codeword_length, interleaving_depth, matrix_commit: matrix_commit::Config::with_hash( diff --git a/src/protocols/mod.rs b/src/protocols/mod.rs index 17e58040..64f1e0b9 100644 --- a/src/protocols/mod.rs +++ b/src/protocols/mod.rs @@ -16,6 +16,7 @@ pub mod irs_commit; pub mod mask_proximity; pub mod matrix_commit; pub mod merkle_tree; +pub mod params; pub mod proof_of_work; pub mod sumcheck; pub mod whir; diff --git a/src/protocols/params/bounds.rs b/src/protocols/params/bounds.rs new file mode 100644 index 00000000..6732b88a --- /dev/null +++ b/src/protocols/params/bounds.rs @@ -0,0 +1,81 @@ +//! Shared primitives for parameter selection: RS bounds + PoW sizing. + +use std::{f64::consts::LOG2_10, ops::Neg}; + +use crate::{ + algebra::{embedding::Embedding, fields::FieldWithSize}, + bits::Bits, + protocols::irs_commit, +}; + +/// `johnson_slack == 0.0` selects the unique-decoding regime. +#[derive(Debug, Clone, Copy)] +pub struct CodeParams { + pub log_inv_rate: f64, + pub johnson_slack: f64, + pub message_length: usize, + pub field_bits: f64, +} + +impl CodeParams { + pub fn from_irs(irs: &irs_commit::Config) -> Self { + Self { + log_inv_rate: irs.rate().log2().neg(), + johnson_slack: irs.johnson_slack.into_inner(), + message_length: irs.masked_message_length(), + field_bits: M::Target::field_size_bits(), + } + } + + pub fn rate(&self) -> f64 { + 2_f64.powf(-self.log_inv_rate) + } + + pub fn unique_decoding(&self) -> bool { + self.johnson_slack == 0.0 + } +} + +/// log2 |Λ(C, δ)|. +pub fn list_size_log2(p: &CodeParams) -> f64 { + if p.unique_decoding() { + 0.0 + } else { + // Johnson: |Λ| = 1 / (2 η √ρ). + -1.0 - p.johnson_slack.log2() + 0.5 * p.log_inv_rate + } +} + +/// log2 ε_mca(C, δ). +pub fn eps_mca_log2(p: &CodeParams) -> f64 { + let log_k = (p.message_length as f64).log2(); + + let error = if p.unique_decoding() { + log_k + p.log_inv_rate + } else { + debug_assert!(p.johnson_slack.log2() >= -(0.5 * p.log_inv_rate + LOG2_10 + 1.0) - 1e-6); + 7.0 * LOG2_10 + 3.5 * p.log_inv_rate + 2.0 * log_k + }; + + error - p.field_bits +} + +/// log2(1 - δ). +pub fn one_minus_distance_log2(p: &CodeParams) -> f64 { + let one_minus_delta = if p.unique_decoding() { + f64::midpoint(1.0, p.rate()) + } else { + p.rate().sqrt() + p.johnson_slack + }; + one_minus_delta.log2() +} + +/// log2 of the per-OOD-sample Schwartz-Zippel error: (k-1)/|F|. +pub fn ood_per_sample_log2(p: &CodeParams) -> f64 { + ((p.message_length - 1) as f64).log2() - p.field_bits +} + +/// PoW difficulty to close a soundness gap: max(0, target − achieved). +pub fn pow_bits_to_close_gap(target_security_bits: u32, achieved_security_bits: f64) -> Bits { + Bits::new((f64::from(target_security_bits) - achieved_security_bits).max(0.0)) +} diff --git a/src/protocols/params/irs_commit.rs b/src/protocols/params/irs_commit.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/src/protocols/params/irs_commit.rs @@ -0,0 +1 @@ + diff --git a/src/protocols/params/mod.rs b/src/protocols/params/mod.rs new file mode 100644 index 00000000..7d47774f --- /dev/null +++ b/src/protocols/params/mod.rs @@ -0,0 +1,6 @@ +// This module contains the parameter selection and security target logic. + +pub mod bounds; +pub mod irs_commit; +pub mod spec; +pub mod sumcheck; diff --git a/src/protocols/params/spec.rs b/src/protocols/params/spec.rs new file mode 100644 index 00000000..9d39ef1f --- /dev/null +++ b/src/protocols/params/spec.rs @@ -0,0 +1,45 @@ +use core::marker::PhantomData; + +use crate::{algebra::embedding::Embedding, engines::EngineId}; + +/// Security spec definition for the protocol +pub struct SecuritySpec { + /// Protocol Mode of operation + pub mode: Mode, + /// Target security bits + pub target_security_bits: u32, + /// Size of the input witness / vector + pub vector_size: usize, + /// Starting log inverse rate for RS code + pub starting_log_inv_rate: u32, + /// Initial Folding factor for the first round of sumcheck + pub initial_folding_factor: usize, + /// Folding factor for subsequent round of sumcheck + pub folding_factor: usize, + /// POW bits + pub max_pow_bits: Option, + /// Hash Engine + pub hash_id: EngineId, + pub _embedding: PhantomData, +} + +/// Per round context struct for calculating the bounds +pub struct RoundContext { + /// Round index + pub round_index: usize, + /// Vector size for the particular round + pub vector_size: usize, + /// rate for the RS encoding for the round vector + pub log_inv_rate: u32, + /// Forlding factor for sumcheck + pub folding_factor: u32, + /// Previous round's in domain samples count + pub prev_round_in_domain_samples: usize, + /// To keep track of the errors of all the rounds + pub prev_round_query_error: f64, +} + +pub enum Mode { + Standard, + ZeroKnowledge, +} diff --git a/src/protocols/params/sumcheck.rs b/src/protocols/params/sumcheck.rs new file mode 100644 index 00000000..1d7772ba --- /dev/null +++ b/src/protocols/params/sumcheck.rs @@ -0,0 +1,64 @@ +use crate::{ + algebra::embedding::Embedding, + protocols::{ + irs_commit, + params::{ + bounds::{self, CodeParams}, + spec::{Mode, RoundContext, SecuritySpec}, + }, + proof_of_work, sumcheck, + }, +}; + +pub fn solve( + spec: &SecuritySpec, + ctx: &RoundContext, + irs_source: &irs_commit::Config, +) -> sumcheck::Config { + let num_rounds = num_sumcheck_rounds(spec, ctx); + let mode = match spec.mode { + Mode::Standard => sumcheck::SumcheckMode::Standard, + Mode::ZeroKnowledge => sumcheck::SumcheckMode::ZeroKnowledge { + mask_length: zk_mask_length(), + }, + }; + let round_pow = solve_sumcheck_round_pow(spec, irs_source); + sumcheck::Config::new(ctx.vector_size, round_pow, num_rounds, mode) +} + +fn num_sumcheck_rounds(spec: &SecuritySpec, ctx: &RoundContext) -> usize { + if ctx.round_index == 0 { + spec.initial_folding_factor + } else { + spec.folding_factor + } +} + +pub fn masks_required(spec: &SecuritySpec, ctx: &RoundContext) -> usize { + match spec.mode { + Mode::Standard => 0, + Mode::ZeroKnowledge => num_sumcheck_rounds(spec, ctx), + } +} + +const fn zk_mask_length() -> usize { + 3 +} + +/// Sumcheck-specific PoW sizing: closes the per-round Lemma 6.5 soundness gap. +fn solve_sumcheck_round_pow( + spec: &SecuritySpec, + irs_source: &irs_commit::Config, +) -> proof_of_work::Config { + let code = CodeParams::from_irs(irs_source); + + // Lemma 6.5 per-round error has two terms; security in bits is the min. + // TODO: extend with `ℓ_zk · |Λ_C_zk|` factors in ZK mode once mask-code + // params are available (PR 2). + let sec_mca = -bounds::eps_mca_log2(&code); + let sec_combination = code.field_bits - bounds::list_size_log2(&code) - 1.0; + let achieved = sec_mca.min(sec_combination); + + let pow_bits = bounds::pow_bits_to_close_gap(spec.target_security_bits, achieved); + proof_of_work::Config::from_difficulty(pow_bits) +} diff --git a/src/protocols/sumcheck.rs b/src/protocols/sumcheck.rs index 5018e22d..2c04e01c 100644 --- a/src/protocols/sumcheck.rs +++ b/src/protocols/sumcheck.rs @@ -30,6 +30,12 @@ pub struct SumcheckOpening { pub mask_rlc: F, } +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub enum SumcheckMode { + Standard, + ZeroKnowledge { mask_length: usize }, +} + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(bound = "")] pub struct Config @@ -40,10 +46,42 @@ where pub initial_size: usize, pub round_pow: proof_of_work::Config, pub num_rounds: usize, - pub mask_length: usize, + pub mode: SumcheckMode, } impl Config { + pub fn new( + initial_size: usize, + round_pow: proof_of_work::Config, + num_rounds: usize, + mode: SumcheckMode, + ) -> Self { + assert!(num_rounds == 0 || initial_size.next_power_of_two() >= 1 << num_rounds); + if let SumcheckMode::ZeroKnowledge { mask_length } = &mode { + // Mask must cover all 3 sumcheck polynomial coefficients (c0, c1, c2). + assert!(*mask_length >= 3); + // Lemma 6.4 prerequisite. + assert!( + !F::ONE.double().is_zero(), + "ZK sumcheck requires char(F) ≠ 2" + ); + } + Self { + field: Type::new(), + initial_size, + round_pow, + num_rounds, + mode, + } + } + + fn mask_length(&self) -> usize { + match &self.mode { + SumcheckMode::Standard => 0, + SumcheckMode::ZeroKnowledge { mask_length } => *mask_length, + } + } + pub fn final_size(&self) -> usize { assert!( self.num_rounds == 0 || self.initial_size.next_power_of_two() >= 1 << self.num_rounds @@ -84,29 +122,21 @@ impl Config { assert!( self.num_rounds == 0 || self.initial_size.next_power_of_two() >= 1 << self.num_rounds ); - // Mask must cover all 3 sumcheck polynomial coefficients (c0, c1, c2) - assert!(self.mask_length == 0 || self.mask_length >= 3); assert_eq!(a.len(), self.initial_size); assert_eq!(b.len(), self.initial_size); debug_assert_eq!(dot(a, b), *sum); - assert_eq!(masks.len(), self.num_rounds * self.mask_length); + assert_eq!(masks.len(), self.num_rounds * self.mask_length()); + let half = F::from(2).inverse().unwrap(); + let polynomial_len = self.mask_length().max(3); - let mut mask_sum = F::ZERO; - let mut mask_rlc = F::ONE; - if self.mask_length > 0 && self.num_rounds > 0 { - let sum_multiple = F::from(1 << self.num_rounds.saturating_sub(1)); - mask_sum = masks.chunks_exact(self.mask_length).map(eval_01).sum::() * sum_multiple; - prover_state.prover_message(&mask_sum); - mask_rlc = prover_state.verifier_message(); - } + let (mut mask_sum, mask_rlc) = self.maybe_send_initial_mask_sum(prover_state, masks); - // We do a staggered Sumcheck loop so we can merge the inner fold+compute loops. - let mut univariate = Vec::new(); + let mut univariate = Vec::with_capacity(polynomial_len); let mut round_challenges = Vec::with_capacity(self.num_rounds); let mut prev_round_challenge = None; for (round, mask) in - chunks_exact_or_empty(masks, self.mask_length, self.num_rounds).enumerate() + chunks_exact_or_empty(masks, self.mask_length(), self.num_rounds).enumerate() { // Fold and compute sumcheck polynomial in one pass. let (c0, c2) = if let Some(w) = prev_round_challenge { @@ -116,40 +146,33 @@ impl Config { }; let c1 = *sum - c0.double() - c2; - // Optionally mask with univariate - if mask.is_empty() { - prover_state.prover_messages(&[c0, c2]); - } else { - // Initialize to round masking univariate polynomial. - univariate.clear(); - let sum_multiple = F::from(1 << self.num_rounds.saturating_sub(round + 1)); - univariate.extend(mask.iter().map(|m| sum_multiple * *m)); - - // Add constant term from previous and future masks. - univariate[0] += (mask_sum - sum_multiple * eval_01(mask)) * half; - - // Add plain sumcheck polynomial - univariate[0] += mask_rlc * c0; - univariate[1] += mask_rlc * c1; - univariate[2] += mask_rlc * c2; - - prover_state.prover_message(&univariate[0]); - prover_state.prover_messages(&univariate[2..]); + // Build round polynomial. In Standard (`mask = []`, `mask_rlc = 1`, + // `mask_sum = 0`) this collapses to `[c0, c1, c2]`. + univariate.clear(); + univariate.resize(polynomial_len, F::ZERO); + let sum_multiple = F::from(1 << self.num_rounds.saturating_sub(round + 1)); + for (u, m) in univariate.iter_mut().zip(mask.iter()) { + *u = sum_multiple * *m; } + univariate[0] += (mask_sum - sum_multiple * eval_01(mask)) * half; + univariate[0] += mask_rlc * c0; + univariate[1] += mask_rlc * c1; + univariate[2] += mask_rlc * c2; + + prover_state.prover_message(&univariate[0]); + prover_state.prover_messages(&univariate[2..]); - // Receive the random evaluation point and update the sum + // Receive the random evaluation point and update the sum. self.round_pow.prove(prover_state); let r = prover_state.verifier_message::(); round_challenges.push(r); *sum = (c2 * r + c1) * r + c0; - if self.mask_length > 0 && self.num_rounds > 0 { - let masked_sum = univariate_evaluate(&univariate, r); - mask_sum = masked_sum - mask_rlc * *sum; - } + + mask_sum = univariate_evaluate(&univariate, r) - mask_rlc * *sum; prev_round_challenge = Some(r); } if let Some(w) = prev_round_challenge { - // Final fold of the inputs (no polynomial computation) + // Final fold of the inputs (no polynomial computation). fold(a, w); fold(b, w); } @@ -161,6 +184,32 @@ impl Config { } } + fn maybe_send_initial_mask_sum( + &self, + prover_state: &mut ProverState, + masks: &[F], + ) -> (F, F) + where + H: DuplexSpongeInterface, + R: CryptoRng + RngCore, + F: Codec<[H::U]>, + { + match &self.mode { + SumcheckMode::Standard => (F::ZERO, F::ONE), + SumcheckMode::ZeroKnowledge { mask_length } => { + if self.num_rounds == 0 { + return (F::ZERO, F::ONE); + } + let sum_multiple = F::from(1 << self.num_rounds.saturating_sub(1)); + let mask_sum = + masks.chunks_exact(*mask_length).map(eval_01).sum::() * sum_multiple; + prover_state.prover_message(&mask_sum); + let mask_rlc = prover_state.verifier_message(); + (mask_sum, mask_rlc) + } + } + } + #[cfg_attr(feature = "tracing", instrument(skip_all))] pub fn verify( &self, @@ -176,17 +225,10 @@ impl Config { assert!( self.num_rounds == 0 || self.initial_size.next_power_of_two() >= 1 << self.num_rounds ); - // Mask must cover all 3 sumcheck polynomial coefficients (c0, c1, c2) - assert!(self.mask_length == 0 || self.mask_length >= 3); - - let mut mask_rlc = F::ONE; - if self.mask_length > 0 && self.num_rounds > 0 { - let mask_sum: F = verifier_state.prover_message()?; - mask_rlc = verifier_state.verifier_message(); - *sum = mask_sum + mask_rlc * *sum; - } - let mut univariate = vec![F::ZERO; self.mask_length.max(3)]; + let mask_rlc = self.maybe_receive_initial_mask_sum(verifier_state, sum)?; + + let mut univariate = vec![F::ZERO; self.mask_length().max(3)]; let mut round_challenges = Vec::with_capacity(self.num_rounds); for _ in 0..self.num_rounds { // Receive all but linear coefficient. @@ -195,17 +237,17 @@ impl Config { *c = verifier_state.prover_message()?; } - // Derive linear coefficient from relation `univariate(0) + univariate(1) = sum` + // Derive linear coefficient from relation `univariate(0) + univariate(1) = sum`. univariate[1] = *sum - univariate[0].double() - univariate[2..].iter().sum::(); - // Check proof of work (if any) + // Check proof of work (if any). self.round_pow.verify(verifier_state)?; - // Receive the random evaluation point + // Receive the random evaluation point. let round_challenge = verifier_state.verifier_message::(); round_challenges.push(round_challenge); - // Update the sum + // Update the sum. *sum = univariate_evaluate(&univariate, round_challenge); } Ok(SumcheckOpening { @@ -213,17 +255,44 @@ impl Config { mask_rlc, }) } + + fn maybe_receive_initial_mask_sum( + &self, + verifier_state: &mut VerifierState, + sum: &mut F, + ) -> VerificationResult + where + H: DuplexSpongeInterface, + F: Codec<[H::U]>, + { + match &self.mode { + SumcheckMode::Standard => Ok(F::ONE), + SumcheckMode::ZeroKnowledge { .. } => { + if self.num_rounds == 0 { + return Ok(F::ONE); + } + let mask_sum: F = verifier_state.prover_message()?; + let mask_rlc = verifier_state.verifier_message(); + *sum = mask_sum + mask_rlc * *sum; + Ok(mask_rlc) + } + } + } } impl fmt::Display for Config { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mode_str = match &self.mode { + SumcheckMode::Standard => "standard".to_string(), + SumcheckMode::ZeroKnowledge { mask_length } => format!("zk ℓ_zk={mask_length}"), + }; write!( f, - "size {} rounds {} pow {:.2} ℓ_zk {}", + "size {} rounds {} pow {:.2} {}", self.initial_size, self.num_rounds, self.round_pow.difficulty(), - self.mask_length + mode_str, ) } } @@ -261,21 +330,20 @@ mod tests { Standard: Distribution, { pub fn arbitrary() -> impl Strategy { - let mask_length = prop_oneof![ - 3 => Just(0_usize), - 7 => 3_usize..20, + let mode_strategy = prop_oneof![ + 3 => Just(SumcheckMode::Standard), + 7 => (3_usize..20).prop_map(|mask_length| SumcheckMode::ZeroKnowledge { mask_length }), ]; - (0_usize..(1 << 12), 0_usize..12, mask_length).prop_map( - |(initial_size, num_rounds, mask_length)| { + (0_usize..(1 << 12), 0_usize..12, mode_strategy).prop_map( + |(initial_size, num_rounds, mode)| { let num_rounds = num_rounds.min(initial_size.next_power_of_two().trailing_zeros() as usize); - Self { - field: Type::new(), + Config::new( initial_size, + proof_of_work::Config::none(), num_rounds, - round_pow: proof_of_work::Config::none(), - mask_length, - } + mode, + ) }, ) } @@ -296,7 +364,7 @@ mod tests { let initial_vector = random_vector(&mut rng, config.initial_size); let initial_covector = random_vector(&mut rng, config.initial_size); let initial_sum = dot(&initial_vector, &initial_covector); - let masks = random_vector(&mut rng, config.mask_length * config.num_rounds); + let masks = random_vector(&mut rng, config.mask_length() * config.num_rounds); // Prover let mut vector = initial_vector.clone(); @@ -323,7 +391,7 @@ mod tests { } let expected_mask_sum: F = - chunks_exact_or_empty(&masks, config.mask_length, config.num_rounds) + chunks_exact_or_empty(&masks, config.mask_length(), config.num_rounds) .zip(&point) .map(|(m, x)| univariate_evaluate(m, *x)) .sum(); @@ -345,8 +413,8 @@ mod tests { assert_eq!(verifier_sum, sum); verifier_state.check_eof().unwrap(); - // Non-ZK path: mask_rlc defaults to ONE (no combination randomness sampled) - if config.mask_length == 0 || config.num_rounds == 0 { + // Standard path: mask_rlc defaults to ONE (no combination randomness sampled). + if matches!(config.mode, SumcheckMode::Standard) || config.num_rounds == 0 { assert_eq!(mask_rlc, F::ONE); } } @@ -365,13 +433,12 @@ mod tests { fn test_single_round() { test_config( 0, - &Config:: { - field: Type::new(), - initial_size: 2, - round_pow: proof_of_work::Config::none(), - num_rounds: 1, - mask_length: 3, - }, + &Config::::new( + 2, + proof_of_work::Config::none(), + 1, + SumcheckMode::ZeroKnowledge { mask_length: 3 }, + ), ); } @@ -379,13 +446,12 @@ mod tests { fn test_two_rounds() { test_config( 0, - &Config:: { - field: Type::new(), - initial_size: 3, - round_pow: proof_of_work::Config::none(), - num_rounds: 2, - mask_length: 3, - }, + &Config::::new( + 3, + proof_of_work::Config::none(), + 2, + SumcheckMode::ZeroKnowledge { mask_length: 3 }, + ), ); } @@ -393,13 +459,12 @@ mod tests { fn test_three_rounds() { test_config( 0, - &Config:: { - field: Type::new(), - initial_size: 5, - round_pow: proof_of_work::Config::none(), - num_rounds: 3, - mask_length: 3, - }, + &Config::::new( + 5, + proof_of_work::Config::none(), + 3, + SumcheckMode::ZeroKnowledge { mask_length: 3 }, + ), ); } diff --git a/src/protocols/whir/config.rs b/src/protocols/whir/config.rs index 89892425..aaf4ce94 100644 --- a/src/protocols/whir/config.rs +++ b/src/protocols/whir/config.rs @@ -8,7 +8,6 @@ use crate::{ bits::Bits, parameters::ProtocolParameters, protocols::{irs_commit, proof_of_work, sumcheck}, - type_info::Type, }; impl Config { @@ -45,6 +44,7 @@ impl Config { size, 1 << whir_parameters.initial_folding_factor, 0.5_f64.powi(whir_parameters.starting_log_inv_rate as i32), + 0, ); // Initial sumcheck round pow bits. @@ -85,6 +85,7 @@ impl Config { 1 << num_variables, 1 << whir_parameters.folding_factor, 0.5_f64.powi(next_rate as i32), + 0, ); let combination_error = { let log_list_size = irs_committer.list_size().log2(); @@ -103,13 +104,12 @@ impl Config { let config = RoundConfig { irs_committer, - sumcheck: sumcheck::Config { - field: Type::new(), - initial_size: 1 << num_variables, - round_pow: pow(folding_pow_bits), - num_rounds: whir_parameters.folding_factor, - mask_length: 0, - }, + sumcheck: sumcheck::Config::new( + 1 << num_variables, + pow(folding_pow_bits), + whir_parameters.folding_factor, + sumcheck::SumcheckMode::Standard, + ), pow: pow(pow_bits), }; @@ -131,22 +131,20 @@ impl Config { Self { initial_committer, - initial_sumcheck: sumcheck::Config { - field: Type::new(), - initial_size: size, - round_pow: pow(starting_folding_pow_bits), - num_rounds: whir_parameters.initial_folding_factor, - mask_length: 0, - }, + initial_sumcheck: sumcheck::Config::new( + size, + pow(starting_folding_pow_bits), + whir_parameters.initial_folding_factor, + sumcheck::SumcheckMode::Standard, + ), initial_skip_pow: pow(initial_skip_pow_bits), round_configs, - final_sumcheck: sumcheck::Config { - field: Type::new(), - initial_size: 1 << num_variables, - round_pow: pow(final_folding_pow_bits), - num_rounds: num_variables, - mask_length: 0, - }, + final_sumcheck: sumcheck::Config::new( + 1 << num_variables, + pow(final_folding_pow_bits), + num_variables, + sumcheck::SumcheckMode::Standard, + ), final_pow: pow(final_pow_bits), } } @@ -539,13 +537,12 @@ mod tests { out_domain_samples: 2, deduplicate_in_domain: true, }, - sumcheck: sumcheck::Config { - field: Type::::new(), - initial_size: 1 << 10, - round_pow: proof_of_work::Config::from_difficulty(Bits::new(19.0)), - num_rounds: 2, - mask_length: 0, - }, + sumcheck: sumcheck::Config::::new( + 1 << 10, + proof_of_work::Config::from_difficulty(Bits::new(19.0)), + 2, + sumcheck::SumcheckMode::Standard, + ), pow: proof_of_work::Config::from_difficulty(Bits::new(17.0)), }, RoundConfig { @@ -562,13 +559,12 @@ mod tests { out_domain_samples: 2, deduplicate_in_domain: true, }, - sumcheck: sumcheck::Config { - field: Type::::new(), - initial_size: 1 << 10, - round_pow: proof_of_work::Config::from_difficulty(Bits::new(19.5)), - num_rounds: 2, - mask_length: 0, - }, + sumcheck: sumcheck::Config::::new( + 1 << 10, + proof_of_work::Config::from_difficulty(Bits::new(19.5)), + 2, + sumcheck::SumcheckMode::Standard, + ), pow: proof_of_work::Config::from_difficulty(Bits::new(18.0)), }, ]; From 425538a82575134746f9ba323b1ce3beab01ac00 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Thu, 7 May 2026 18:51:39 +0530 Subject: [PATCH 02/31] refactored OOD sample logic from irs commit --- src/protocols/basecase.rs | 7 +- src/protocols/code_switch.rs | 27 ++-- src/protocols/irs_commit.rs | 235 ++++++++++------------------- src/protocols/mask_proximity.rs | 18 +-- src/protocols/whir/config.rs | 76 ++++++++-- src/protocols/whir/mod.rs | 53 ++++++- src/protocols/whir/prover.rs | 30 ++-- src/protocols/whir/verifier.rs | 26 ++-- src/protocols/whir_zk/committer.rs | 4 +- 9 files changed, 243 insertions(+), 233 deletions(-) diff --git a/src/protocols/basecase.rs b/src/protocols/basecase.rs index 95739ee4..b5b251d3 100644 --- a/src/protocols/basecase.rs +++ b/src/protocols/basecase.rs @@ -143,7 +143,7 @@ impl Config { pub fn verify( &self, verifier_state: &mut VerifierState, - commitment: &irs_commit::Commitment, + commitment: &irs_commit::Commitment, mut sum: F, ) -> VerificationResult> where @@ -248,10 +248,7 @@ mod tests { let commit = irs_commit::Config::arbitrary(Identity::::new(), 1, size, mask_length, 1); (commit, bool::weighted(0.8)).prop_map(move |(commit, masked)| Self { - commit: irs_commit::Config { - out_domain_samples: 0, - ..commit - }, + commit: irs_commit::Config { ..commit }, sumcheck: sumcheck::Config::new( size, proof_of_work::Config::none(), diff --git a/src/protocols/code_switch.rs b/src/protocols/code_switch.rs index 62bbb602..73d60bec 100644 --- a/src/protocols/code_switch.rs +++ b/src/protocols/code_switch.rs @@ -48,7 +48,7 @@ pub struct Witness { } /// Verifier output from the code-switch. -pub type Commitment = IrsCommitment; +pub type Commitment = IrsCommitment; /// Mask input for the code-switch prover. // TODO : This may be removed after parameter selection PR @@ -101,9 +101,12 @@ impl Config { message_mask_length - source_config.mask_length >= out_domain_samples, "the sampled randomness (s) length must be covering all the out of domain sample requests" ); + // t' = (in-domain queries to g via target IRS) + // + (OOD queries to g via Construction 9.7's OOD step, count = out_domain_samples). + // Lemma 9.5 perfect-ZK: t' ≤ r' = target.mask_length. assert!( target_config.mask_length - >= target_config.in_domain_samples + target_config.out_domain_samples, + >= target_config.in_domain_samples + out_domain_samples, "target encoder violates: t' > r', number of queries should be covered by random mask" ); } @@ -154,7 +157,7 @@ impl Config { &self, prover_state: &mut ProverState, message: Vec, - witness: &IrsWitness, + witness: &IrsWitness, covector: &mut [M::Target], folding_randomness: &[M::Target], mask_input: &MaskInput<'_, M::Target>, @@ -283,8 +286,8 @@ impl Config { verifier_state: &mut VerifierState, sum: &mut M::Target, folding_randomness: &[M::Target], - commitment: &IrsCommitment, - ) -> VerificationResult> + commitment: &IrsCommitment, + ) -> VerificationResult where H: DuplexSpongeInterface, Standard: Distribution, @@ -376,19 +379,18 @@ mod tests { 0_usize..=5, // fresh_s_len (≥ ood for assumption (c)) select(vec![1_usize, 2, 4]), // ι_s (source interleaving) 0_usize..=10, // target.in_domain_samples (t'_in) - 0_usize..=10, // target.out_domain_samples (t'_out) ); scalars.prop_flat_map( - move |(size, src_mask_len, zk, ood, fresh_s_len, iota_s, t_in, t_out)| { + move |(size, src_mask_len, zk, ood, fresh_s_len, iota_s, t_in)| { // Bound 3 assumption (c): ℓ_zk - r ≥ t_ood ⇒ fresh_s_len ≥ ood. let fresh_s_len = if zk { fresh_s_len.max(ood) } else { fresh_s_len }; - // Bound 4 assumption (a): target.mask_length ≥ t' = t_in + t_out. - let target_mask = if zk { t_in + t_out } else { 0 }; + // Bound 4 assumption (a): target.mask_length ≥ t' = t_in + ood. + let target_mask = if zk { t_in + ood } else { 0 }; // ZK with source.mask_length = 0 is valid: the assert // `source.mask_length == 0 || message_mask_length > 0` // is trivially satisfied. Allows testing the corner @@ -416,13 +418,12 @@ mod tests { ); let source = source.clone(); target.prop_map(move |mut target| { - // IrsConfig::arbitrary samples query counts in - // [0,10] independently of mask_length; pin them - // to the values target_mask was sized for so + // IrsConfig::arbitrary samples in_domain_samples + // in [0,10] independently of mask_length; pin it + // to the value target_mask was sized for so // assumption (a) holds. if zk { target.in_domain_samples = t_in; - target.out_domain_samples = t_out; } // r = post-fold randomness length (ι_s parallel // masks fold to a single length-mask_length chunk). diff --git a/src/protocols/irs_commit.rs b/src/protocols/irs_commit.rs index db41c356..d1cc9828 100644 --- a/src/protocols/irs_commit.rs +++ b/src/protocols/irs_commit.rs @@ -7,15 +7,9 @@ //! using an NTT friendly Reed-Solomon code to produce a `num_vectors * interleaving_depth` //! by `codeword_size` matrix. This matrix is committed using the [`matrix_commit`] protocol. //! -//! After committing the encoded matrix, the protocol generates a random Reed-Solomon code of -//! length `out_domain_samples` over an extension field `G` of `F` and encodes the original -//! matrix using this code to produce a `num_vectors` by `out_domain_samples` matrix over `G`. -//! Together, these two encoded matrices form a commitment to the original matrix. -//! //! On opening the commitment, the protocol randomly selects `in_domain_samples` rows and opens -//! it using the [`matrix_commit`] protocol. Sampling is done with replacement, so may produce -//! fewer than `in_domain_samples` distinct rows. This produces `in_domain_samples` evaluation -//! points in `F` and `in_domain_samples` by `num_vectors * interleaving_depth`. +//! them using the [`matrix_commit`] protocol. Sampling is done with replacement, so may produce +//! fewer than `in_domain_samples` distinct rows. //! use std::{ f64::{self, consts::LOG2_10}, @@ -44,7 +38,6 @@ use crate::{ }, type_info::Typed, utils::{chunks_exact_or_empty, zip_strict}, - verify, }; /// Commit to vectors over an fft-friendly field F @@ -79,9 +72,6 @@ pub struct Config { /// The number of in-domain samples. pub in_domain_samples: usize, - /// The number of out-of-domain samples. - pub out_domain_samples: usize, - /// Whether to sort and deduplicate the in-domain samples. /// /// Deduplication can slightly reduce proof size and prover/verifier @@ -92,21 +82,16 @@ pub struct Config { #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Default, Serialize, Deserialize)] #[must_use] -pub struct Witness -where - G: Field, -{ +pub struct Witness { pub masks: Vec, pub matrix: Vec, pub matrix_witness: matrix_commit::Witness, - pub out_of_domain: Evaluations, } #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Default, Serialize, Deserialize)] #[must_use] -pub struct Commitment { - matrix_commitment: matrix_commit::Commitment, - out_of_domain: Evaluations, +pub struct Commitment { + pub matrix_commitment: matrix_commit::Commitment, } /// Interleaved Reed-Solomon code. @@ -150,16 +135,6 @@ impl Config { } else { rate.sqrt() / 20. }; - let out_domain_samples = { - let list_size = 1. / (2. * johnson_slack * rate.sqrt()); - num_ood_samples( - unique_decoding, - security_target, - M::Target::field_size_bits(), - list_size, - vector_size, - ) - }; #[allow(clippy::cast_sign_loss)] let in_domain_samples = { // Query error is (1 - δ)^q, so we compute 1 - δ @@ -191,7 +166,6 @@ impl Config { ), johnson_slack: OrderedFloat(johnson_slack), in_domain_samples, - out_domain_samples, deduplicate_in_domain: false, } } @@ -235,7 +209,7 @@ impl Config { } pub fn unique_decoding(&self) -> bool { - self.out_domain_samples == 0 && self.johnson_slack == 0.0 + self.johnson_slack == 0.0 } /// Compute a list size bound. @@ -248,16 +222,6 @@ impl Config { } } - /// Round-by-round soundness of the out-of-domain samples in bits. - pub fn rbr_ood_sample(&self) -> f64 { - let list_size = self.list_size(); - let log_field_size = M::Target::field_size_bits(); - // See [STIR] lemma 4.5. - let l_choose_2 = list_size * (list_size - 1.) / 2.; - let log_per_sample = ((self.vector_size - 1) as f64).log2() - log_field_size; - -l_choose_2.log2() - self.out_domain_samples as f64 * log_per_sample - } - /// Round-by-round soundness of the in-domain queries in bits. pub fn rbr_queries(&self) -> f64 { let per_sample = if self.unique_decoding() { @@ -294,7 +258,7 @@ impl Config { &self, prover_state: &mut ProverState, vectors: &[&[M::Source]], - ) -> Witness + ) -> Witness where Standard: Distribution, H: DuplexSpongeInterface, @@ -324,29 +288,10 @@ impl Config { // Commit to the matrix let matrix_witness = self.matrix_commit.commit(prover_state, &matrix); - // Handle out-of-domain points and values - // TODO : Remove this logic after main whir protocol is updated - // as this is not required in the new construction. This will be - // removed in next PR (Parameter Selection) - let oods_points: Vec = - prover_state.verifier_message_vec(self.out_domain_samples); - let mut oods_matrix = Vec::with_capacity(self.out_domain_samples * self.num_vectors); - for &point in &oods_points { - for &vector in vectors { - let value = mixed_univariate_evaluate(&*self.embedding, vector, point); - prover_state.prover_message(&value); - oods_matrix.push(value); - } - } - Witness { masks, matrix, matrix_witness, - out_of_domain: Evaluations { - points: oods_points, - matrix: oods_matrix, - }, } } @@ -355,24 +300,67 @@ impl Config { pub fn receive_commitment( &self, verifier_state: &mut VerifierState, - ) -> VerificationResult> + ) -> VerificationResult where H: DuplexSpongeInterface, Hash: ProverMessage<[H::U]>, M::Target: Codec<[H::U]>, { let matrix_commitment = self.matrix_commit.receive_commitment(verifier_state)?; - let oods_points: Vec = - verifier_state.verifier_message_vec(self.out_domain_samples); - let oods_matrix = - verifier_state.prover_messages_vec(self.out_domain_samples * self.num_vectors)?; - Ok(Commitment { - matrix_commitment, - out_of_domain: Evaluations { - points: oods_points, - matrix: oods_matrix, - }, - }) + Ok(Commitment { matrix_commitment }) + } + + /// Commit to vectors and run the legacy WHIR OOD step in one call. + /// + /// Layered helper bundling `commit` + the OOD message exchange (sample + /// `out_domain_samples` random points, send each vector's evaluation at + /// each point). Used by the legacy WHIR protocol while the OOD step + /// is still part of the per-commit protocol shape; the new construction + /// (Construction 9.7) handles OOD at the code-switch level instead. + #[cfg_attr(feature = "tracing", instrument(skip_all, fields(self = %self)))] + pub fn commit_with_ood( + &self, + prover_state: &mut ProverState, + vectors: &[&[M::Source]], + out_domain_samples: usize, + ) -> (Witness, Evaluations) + where + Standard: Distribution, + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + M::Target: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + let witness = self.commit(prover_state, vectors); + let points: Vec = prover_state.verifier_message_vec(out_domain_samples); + let mut matrix = Vec::with_capacity(out_domain_samples * vectors.len()); + for &point in &points { + for &vector in vectors { + let value = mixed_univariate_evaluate(&*self.embedding, vector, point); + prover_state.prover_message(&value); + matrix.push(value); + } + } + (witness, Evaluations { points, matrix }) + } + + /// Receive a commitment and the legacy WHIR OOD evaluations in one call. + /// Verifier mirror of `commit_with_ood`. + #[cfg_attr(feature = "tracing", instrument(skip_all, fields(self = %self)))] + pub fn receive_commitment_with_ood( + &self, + verifier_state: &mut VerifierState, + out_domain_samples: usize, + ) -> VerificationResult<(Commitment, Evaluations)> + where + H: DuplexSpongeInterface, + Hash: ProverMessage<[H::U]>, + M::Target: Codec<[H::U]>, + { + let commitment = self.receive_commitment(verifier_state)?; + let points: Vec = verifier_state.verifier_message_vec(out_domain_samples); + let matrix = verifier_state.prover_messages_vec(out_domain_samples * self.num_vectors)?; + Ok((commitment, Evaluations { points, matrix })) } /// Opens the commitment and returns the evaluations of the vectors. @@ -386,7 +374,7 @@ impl Config { pub fn open( &self, prover_state: &mut ProverState, - witnesses: &[&Witness], + witnesses: &[&Witness], ) -> Evaluations where H: DuplexSpongeInterface, @@ -396,11 +384,6 @@ impl Config { { for witness in witnesses { assert_eq!(witness.matrix.len(), self.size()); - assert_eq!(witness.out_of_domain.points.len(), self.out_domain_samples); - assert_eq!( - witness.out_of_domain.matrix.len(), - self.out_domain_samples * self.num_vectors - ); } // Get in-domain openings @@ -439,20 +422,13 @@ impl Config { pub fn verify( &self, verifier_state: &mut VerifierState, - commitments: &[&Commitment], + commitments: &[&Commitment], ) -> VerificationResult> where H: DuplexSpongeInterface, u8: Decoding<[H::U]>, Hash: ProverMessage<[H::U]>, { - for commitment in commitments { - verify!(commitment.out_of_domain.points.len() == self.out_domain_samples); - verify!( - commitment.out_of_domain.matrix.len() == self.num_vectors * self.out_domain_samples - ); - } - // Get in-domain openings let (indices, points) = self.in_domain_challenges(verifier_state); @@ -501,28 +477,6 @@ impl Config { } } -impl Commitment { - /// Returns the out-of-domain evaluations. - pub const fn out_of_domain(&self) -> &Evaluations { - &self.out_of_domain - } - - pub fn num_vectors(&self) -> usize { - self.out_of_domain().num_columns() - } -} - -impl Witness { - /// Returns the out-of-domain evaluations. - pub const fn out_of_domain(&self) -> &Evaluations { - &self.out_of_domain - } - - pub fn num_vectors(&self) -> usize { - self.out_of_domain().num_columns() - } -} - impl Evaluations { pub const fn num_points(&self) -> usize { self.points.len() @@ -569,11 +523,7 @@ impl fmt::Display for Config { self.num_vectors, self.vector_size, self.interleaving_depth, )?; write!(f, " rate 2⁻{:.2}", -self.rate().log2())?; - write!( - f, - " samples {} in- {} out-domain", - self.in_domain_samples, self.out_domain_samples - ) + write!(f, " samples {} in-domain", self.in_domain_samples) } } @@ -583,7 +533,7 @@ impl fmt::Display for Config { /// where `L` is the list size and `degree` is the polynomial degree bound. /// See [STIR] Lemma 4.5. #[allow(clippy::cast_sign_loss)] -pub(crate) fn num_ood_samples( +pub fn num_ood_samples( unique_decoding: bool, security_target: f64, field_size_bits: f64, @@ -684,24 +634,24 @@ pub(crate) mod tests { ) }); - (codeword_matrix, 0_usize..=10, 0_usize..=10, bool::ANY).prop_map( + (codeword_matrix, 0_usize..=10, bool::ANY).prop_map( move |( (codeword_length, matrix_commit), in_domain_samples, - out_domain_samples, - deduplicate_in_domain, - )| Self { - embedding: Typed::new(embedding.clone()), - num_vectors, - vector_size, - mask_length, - codeword_length, - interleaving_depth, - matrix_commit, - johnson_slack: OrderedFloat::default(), - in_domain_samples, - out_domain_samples, deduplicate_in_domain, + )| { + Self { + embedding: Typed::new(embedding.clone()), + num_vectors, + vector_size, + mask_length, + codeword_length, + interleaving_depth, + matrix_commit, + johnson_slack: OrderedFloat::default(), + in_domain_samples, + deduplicate_in_domain, + } }, ) } @@ -733,30 +683,6 @@ pub(crate) mod tests { &mut prover_state, &vectors.iter().map(|p| p.as_slice()).collect::>(), ); - assert_eq!( - witness.out_of_domain().points.len(), - config.out_domain_samples - ); - assert_eq!( - witness.out_of_domain().matrix.len(), - config.out_domain_samples * config.num_vectors - ); - if config.num_vectors > 0 { - for (point, evals) in zip_strict( - witness.out_of_domain().points.iter(), - witness - .out_of_domain() - .matrix - .chunks_exact(config.num_vectors), - ) { - for (vector, expected) in zip_strict(vectors.iter(), evals.iter()) { - assert_eq!( - mixed_univariate_evaluate(config.embedding(), vector, *point), - *expected - ); - } - } - } let in_domain_evals = config.open(&mut prover_state, &[&witness]); if config.deduplicate_in_domain { // Sorting is over index order, not points @@ -800,7 +726,6 @@ pub(crate) mod tests { // Verifier let mut verifier_state = VerifierState::new_std(&ds, &proof); let commitment = config.receive_commitment(&mut verifier_state).unwrap(); - assert_eq!(commitment.out_of_domain(), witness.out_of_domain()); let verifier_in_domain_evals = config.verify(&mut verifier_state, &[&commitment]).unwrap(); assert_eq!(&verifier_in_domain_evals, &in_domain_evals); verifier_state.check_eof().unwrap(); diff --git a/src/protocols/mask_proximity.rs b/src/protocols/mask_proximity.rs index 1d632ca0..7d625865 100644 --- a/src/protocols/mask_proximity.rs +++ b/src/protocols/mask_proximity.rs @@ -76,7 +76,7 @@ pub struct Witness { } /// Verifier output from the commit phase. -pub type Commitment = IrsCommitment; +pub type Commitment = IrsCommitment; impl Config { pub fn new(c_zk_commit: IrsConfig>, num_masks: usize) -> Self { @@ -89,14 +89,6 @@ impl Config { c_zk_commit.interleaving_depth, 1, "mask proximity requires interleaving_depth = 1" ); - // OOD evaluations are sent in the clear during IRS commit/receive, - // which would leak raw mask values before the γ-combination and - // break the ZK contract. The OOD path in irs_commit is slated for - // removal in the new construction; until then, enforce zero here. - assert_eq!( - c_zk_commit.out_domain_samples, 0, - "mask proximity requires out_domain_samples = 0 (OOD openings would leak raw mask evaluations)" - ); Self { c_zk_commit, num_masks, @@ -148,7 +140,7 @@ impl Config { pub fn receive_commitment( &self, verifier_state: &mut VerifierState, - ) -> VerificationResult> + ) -> VerificationResult where F: Codec<[H::U]>, H: DuplexSpongeInterface, @@ -214,7 +206,7 @@ impl Config { pub fn verify( &self, verifier_state: &mut VerifierState, - commitment: &Commitment, + commitment: &Commitment, ) -> VerificationResult<()> where F: Codec<[H::U]>, @@ -320,10 +312,6 @@ mod tests { ); (Just(num_masks), c_zk) }) - .prop_filter( - "mask proximity requires out_domain_samples = 0", - |(_, c_zk)| c_zk.out_domain_samples == 0, - ) .prop_map(|(num_masks, c_zk)| Self::new(c_zk, num_masks)) } } diff --git a/src/protocols/whir/config.rs b/src/protocols/whir/config.rs index aaf4ce94..a81bbb87 100644 --- a/src/protocols/whir/config.rs +++ b/src/protocols/whir/config.rs @@ -7,9 +7,26 @@ use crate::{ algebra::{embedding::Embedding, fields::FieldWithSize}, bits::Bits, parameters::ProtocolParameters, - protocols::{irs_commit, proof_of_work, sumcheck}, + protocols::{ + irs_commit::{self, num_ood_samples}, + proof_of_work, sumcheck, + }, }; +/// log2 round-by-round soundness of `t_ood` OOD samples against a code with +/// the given list size — formerly `irs_commit::Config::rbr_ood_sample`. +fn rbr_ood_sample( + list_size: f64, + log_field_size: f64, + vector_size: usize, + out_domain_samples: usize, +) -> f64 { + // [STIR] Lemma 4.5. + let l_choose_2 = list_size * (list_size - 1.) / 2.; + let log_per_sample = ((vector_size - 1) as f64).log2() - log_field_size; + -l_choose_2.log2() - out_domain_samples as f64 * log_per_sample +} + impl Config { #[allow(clippy::too_many_lines)] pub fn new(size: usize, whir_parameters: &ProtocolParameters) -> Self @@ -46,6 +63,13 @@ impl Config { 0.5_f64.powi(whir_parameters.starting_log_inv_rate as i32), 0, ); + let initial_out_domain_samples = num_ood_samples( + whir_parameters.unique_decoding, + protocol_security_level, + field_size_bits, + initial_committer.list_size(), + size, + ); // Initial sumcheck round pow bits. let starting_folding_pow_bits = { @@ -87,9 +111,16 @@ impl Config { 0.5_f64.powi(next_rate as i32), 0, ); + let round_out_domain_samples = num_ood_samples( + whir_parameters.unique_decoding, + protocol_security_level, + field_size_bits, + irs_committer.list_size(), + 1 << num_variables, + ); let combination_error = { let log_list_size = irs_committer.list_size().log2(); - let count = irs_committer.out_domain_samples + in_domain_samples; + let count = round_out_domain_samples + in_domain_samples; let log_combination = (count as f64).log2(); field_size_bits - (log_combination + log_list_size + 1.) }; @@ -104,6 +135,7 @@ impl Config { let config = RoundConfig { irs_committer, + out_domain_samples: round_out_domain_samples, sumcheck: sumcheck::Config::new( 1 << num_variables, pow(folding_pow_bits), @@ -131,6 +163,7 @@ impl Config { Self { initial_committer, + initial_out_domain_samples, initial_sumcheck: sumcheck::Config::new( size, pow(starting_folding_pow_bits), @@ -169,11 +202,15 @@ impl Config { security_level = security_level.min(field_size_bits - ((num_linear_forms - 1) as f64).log2()); } - let has_initial_constraints = - num_linear_forms > 0 || self.initial_committer.out_domain_samples > 0; + let has_initial_constraints = num_linear_forms > 0 || self.initial_out_domain_samples > 0; if !self.initial_committer.unique_decoding() { - security_level = security_level.min(self.initial_committer.rbr_ood_sample()); + security_level = security_level.min(rbr_ood_sample( + self.initial_committer.list_size(), + field_size_bits, + self.initial_committer.vector_size, + self.initial_out_domain_samples, + )); } // Initial sumcheck error (or the skipped version for LDT). @@ -198,13 +235,18 @@ impl Config { let new_unique_decoding = round.irs_committer.unique_decoding(); if !new_unique_decoding { - let ood_error = round.irs_committer.rbr_ood_sample(); + let ood_error = rbr_ood_sample( + round.irs_committer.list_size(), + field_size_bits, + round.irs_committer.vector_size, + round.out_domain_samples, + ); security_level = security_level.min(ood_error); } let log_list_size = round.irs_committer.list_size().log2(); let combination_error = { - let count = round.irs_committer.out_domain_samples + old_in_domain_samples; + let count = round.out_domain_samples + old_in_domain_samples; let log_combination = (count as f64).log2(); field_size_bits - (log_combination + log_list_size + 1.) }; @@ -345,7 +387,12 @@ impl Display for Config { writeln!( f, "{:.1} bits -- OOD commitment", - self.initial_committer.rbr_ood_sample() + rbr_ood_sample( + self.initial_committer.list_size(), + field_size_bits, + self.initial_committer.vector_size, + self.initial_out_domain_samples, + ) )?; } let prox_gaps_error = self.initial_committer.rbr_soundness_fold_prox_gaps(); @@ -370,13 +417,18 @@ impl Display for Config { writeln!( f, "{:.1} bits -- OOD sample", - r.irs_committer.rbr_ood_sample() + rbr_ood_sample( + r.irs_committer.list_size(), + field_size_bits, + r.irs_committer.vector_size, + r.out_domain_samples, + ) )?; } let log_list_size = r.irs_committer.list_size().log2(); let combination_error = { - let count = r.irs_committer.out_domain_samples + old_in_domain_samples; + let count = r.out_domain_samples + old_in_domain_samples; let log_combination = (count as f64).log2(); field_size_bits - (log_combination + log_list_size + 1.) }; @@ -534,9 +586,9 @@ mod tests { matrix_commit: matrix_commit::Config::::new(0, 0), johnson_slack: OrderedFloat::default(), in_domain_samples: 5, - out_domain_samples: 2, deduplicate_in_domain: true, }, + out_domain_samples: 2, sumcheck: sumcheck::Config::::new( 1 << 10, proof_of_work::Config::from_difficulty(Bits::new(19.0)), @@ -556,9 +608,9 @@ mod tests { matrix_commit: matrix_commit::Config::::new(0, 0), johnson_slack: OrderedFloat::default(), in_domain_samples: 6, - out_domain_samples: 2, deduplicate_in_domain: true, }, + out_domain_samples: 2, sumcheck: sumcheck::Config::::new( 1 << 10, proof_of_work::Config::from_difficulty(Bits::new(19.5)), diff --git a/src/protocols/whir/mod.rs b/src/protocols/whir/mod.rs index 805a206e..d5828018 100644 --- a/src/protocols/whir/mod.rs +++ b/src/protocols/whir/mod.rs @@ -30,6 +30,9 @@ use crate::{ #[serde(bound = "")] pub struct Config { pub initial_committer: irs_commit::Config, + /// OOD samples on the initial commit (Construction 9.7-style OOD step, + /// formerly inside `irs_commit::commit`). + pub initial_out_domain_samples: usize, pub initial_sumcheck: sumcheck::Config, pub initial_skip_pow: proof_of_work::Config, pub round_configs: Vec>, @@ -41,12 +44,42 @@ pub struct Config { #[serde(bound = "")] pub struct RoundConfig { pub irs_committer: irs_commit::Config>, + /// OOD samples for this round's commit. + pub out_domain_samples: usize, pub sumcheck: sumcheck::Config, pub pow: proof_of_work::Config, } -pub type Witness> = irs_commit::Witness; -pub type Commitment = irs_commit::Commitment; +/// WHIR-level witness: IRS witness + OOD evaluations sampled at commit time. +#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] +#[serde(bound( + serialize = "M::Source: Serialize, F: Serialize", + deserialize = "M::Source: Deserialize<'de>, F: Deserialize<'de>" +))] +pub struct Witness> { + pub irs: irs_commit::Witness, + pub out_of_domain: irs_commit::Evaluations, +} + +/// WHIR-level commitment: IRS commitment + OOD evaluations received at commit time. +#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] +#[serde(bound(serialize = "F: Serialize", deserialize = "F: Deserialize<'de>"))] +pub struct Commitment { + pub irs: irs_commit::Commitment, + pub out_of_domain: irs_commit::Evaluations, +} + +impl> Witness { + pub fn num_vectors(&self) -> usize { + self.out_of_domain.num_columns() + } +} + +impl Commitment { + pub fn num_vectors(&self) -> usize { + self.out_of_domain.num_columns() + } +} #[must_use = "The final claim must be checked if there where any linear forms."] #[derive(Debug, Clone, PartialEq, Eq, Default)] @@ -75,6 +108,10 @@ impl FinalClaim { impl Config { /// Commit to one or more vectors. + /// + /// After the IRS commit, runs the legacy WHIR OOD step: samples + /// `initial_out_domain_samples` random points from the verifier and sends + /// each vector's evaluation at each point. #[cfg_attr(feature = "tracing", instrument(skip_all, fields(size = vectors.first().unwrap().len())))] pub fn commit( &self, @@ -88,7 +125,12 @@ impl Config { M::Target: Codec<[H::U]>, Hash: ProverMessage<[H::U]>, { - self.initial_committer.commit(prover_state, vectors) + let (irs, out_of_domain) = self.initial_committer.commit_with_ood( + prover_state, + vectors, + self.initial_out_domain_samples, + ); + Witness { irs, out_of_domain } } /// Receive a commitment to vectors. @@ -101,7 +143,10 @@ impl Config { M::Target: Codec<[H::U]>, Hash: ProverMessage<[H::U]>, { - self.initial_committer.receive_commitment(verifier_state) + let (irs, out_of_domain) = self + .initial_committer + .receive_commitment_with_ood(verifier_state, self.initial_out_domain_samples)?; + Ok(Commitment { irs, out_of_domain }) } /// Disable proof-of-work for test. diff --git a/src/protocols/whir/prover.rs b/src/protocols/whir/prover.rs index 04b4473e..1a300cc0 100644 --- a/src/protocols/whir/prover.rs +++ b/src/protocols/whir/prover.rs @@ -26,8 +26,8 @@ use crate::{ }; enum RoundWitness<'a, F: Field, M: Embedding> { - Initial(Vec>>), - Round(irs_commit::Witness), + Initial(Vec>>), + Round(irs_commit::Witness), } impl Config { @@ -103,8 +103,8 @@ impl Config { let mut vector_offset = 0; for witness in &witnesses { for (oods_eval, oods_row) in zip_strict( - witness.out_of_domain().evaluators(self.initial_size()), - witness.out_of_domain().rows(), + witness.out_of_domain.evaluators(self.initial_size()), + witness.out_of_domain.rows(), ) { for (j, vector) in vectors.iter().enumerate() { if j >= vector_offset && j < oods_row.len() + vector_offset { @@ -219,8 +219,12 @@ impl Config { // Execute standard WHIR rounds on the batched vectors for (round_index, round_config) in self.round_configs.iter().enumerate() { - // Commit to the vector, this generates out-of-domain evaluations. - let new_witness = round_config.irs_committer.commit(prover_state, &[&vector]); + // Commit to the folded vector and run the per-round OOD step. + let (new_witness, out_of_domain) = round_config.irs_committer.commit_with_ood( + prover_state, + &[&vector], + round_config.out_domain_samples, + ); // Proof of work before in-domain challenges round_config.pow.prove(prover_state); @@ -228,9 +232,9 @@ impl Config { // Open the previous round's witness. let in_domain = match prev_witness { RoundWitness::Initial(init_witnesses) => { - let witness_refs: Vec<&_> = init_witnesses.iter().map(|c| &**c).collect(); + let irs_refs: Vec<&_> = init_witnesses.iter().map(|c| &c.irs).collect(); self.initial_committer - .open(prover_state, &witness_refs) + .open(prover_state, &irs_refs) .lift(self.embedding()) } RoundWitness::Round(old_witness) => { @@ -242,13 +246,11 @@ impl Config { }; // Collect constraints for this round and RLC them in - let stir_challenges = new_witness - .out_of_domain() + let stir_challenges = out_of_domain .evaluators(round_config.initial_size()) .chain(in_domain.evaluators(round_config.initial_size())) .collect::>(); - let stir_evaluations = new_witness - .out_of_domain() + let stir_evaluations = out_of_domain .values(&[M::Target::ONE]) .chain(in_domain.values(&tensor_product( &vector_rlc_coeffs, @@ -289,8 +291,8 @@ impl Config { // Open and consume the final previous witness. match prev_witness { RoundWitness::Initial(init_witnesses) => { - let witness_refs: Vec<&_> = init_witnesses.iter().map(|c| &**c).collect(); - let _in_domain = self.initial_committer.open(prover_state, &witness_refs); + let irs_refs: Vec<&_> = init_witnesses.iter().map(|c| &c.irs).collect(); + let _in_domain = self.initial_committer.open(prover_state, &irs_refs); } RoundWitness::Round(old_witness) => { let prev_config = self.round_configs.last().unwrap(); diff --git a/src/protocols/whir/verifier.rs b/src/protocols/whir/verifier.rs index 8421ceab..b7dc80a0 100644 --- a/src/protocols/whir/verifier.rs +++ b/src/protocols/whir/verifier.rs @@ -23,11 +23,11 @@ use crate::{ enum RoundCommitment<'a, F: Field> { Initial { - commitments: &'a [&'a irs_commit::Commitment], + commitments: &'a [&'a Commitment], batching_weights: Vec, }, Round { - commitment: irs_commit::Commitment, + commitment: irs_commit::Commitment, }, } @@ -72,8 +72,8 @@ impl Config { let mut vector_offset = 0; for commitment in commitments { for (weights, oods_row) in zip_strict( - commitment.out_of_domain().evaluators(self.initial_size()), - commitment.out_of_domain().rows(), + commitment.out_of_domain.evaluators(self.initial_size()), + commitment.out_of_domain.rows(), ) { for j in 0..num_vectors { if j >= vector_offset && j < oods_row.len() + vector_offset { @@ -133,10 +133,10 @@ impl Config { round_folding_randomness.push(folding_randomness); for (round_index, round_config) in self.round_configs.iter().enumerate() { - // Receive commitment to the folded vector, plus out-of-domain constraints - let commitment = round_config + // Receive commitment to the folded vector and the per-round OOD evaluations. + let (commitment, out_of_domain) = round_config .irs_committer - .receive_commitment(verifier_state)?; + .receive_commitment_with_ood(verifier_state, round_config.out_domain_samples)?; // Proof of work before in-domain challenges round_config.pow.verify(verifier_state)?; @@ -147,7 +147,8 @@ impl Config { commitments, batching_weights, } => { - let in_domain = self.initial_committer.verify(verifier_state, commitments)?; + let irs_refs: Vec<&_> = commitments.iter().map(|c| &c.irs).collect(); + let in_domain = self.initial_committer.verify(verifier_state, &irs_refs)?; // TODO: Skip lift and keep initial in-domain in subfield for evaluation. // This should be every so slightly more performant. (in_domain.lift(self.embedding()), batching_weights) @@ -162,13 +163,11 @@ impl Config { }; // Random linear combination of out- and in-domain constraints - let constraint_weights = commitment - .out_of_domain() + let constraint_weights = out_of_domain .evaluators(round_config.initial_size()) .chain(in_domain.evaluators(round_config.initial_size())) .collect::>(); - let constraint_values = commitment - .out_of_domain() + let constraint_values = out_of_domain .values(&[M::Target::ONE]) .chain(in_domain.values(&tensor_product( &poly_rlc, @@ -202,7 +201,8 @@ impl Config { commitments, batching_weights, } => { - let in_domain = self.initial_committer.verify(verifier_state, commitments)?; + let irs_refs: Vec<&_> = commitments.iter().map(|c| &c.irs).collect(); + let in_domain = self.initial_committer.verify(verifier_state, &irs_refs)?; (in_domain.lift(self.embedding()), batching_weights) } RoundCommitment::Round { commitment } => { diff --git a/src/protocols/whir_zk/committer.rs b/src/protocols/whir_zk/committer.rs index 942b87e1..1ea1a3ef 100644 --- a/src/protocols/whir_zk/committer.rs +++ b/src/protocols/whir_zk/committer.rs @@ -26,10 +26,10 @@ pub struct Commitment { #[derive(Clone, Debug)] pub struct Witness { pub f_hat_vectors: Vec>, - pub f_hat_witnesses: Vec>, + pub f_hat_witnesses: Vec>, pub blinding_polynomials: Vec>, pub blinding_vectors: Vec>, - pub blinding_witness: irs_commit::Witness, + pub blinding_witness: irs_commit::Witness, } impl Config { From 5d6402afb72091b4cd652f638652b757f4b8e8e0 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Fri, 8 May 2026 04:27:35 +0530 Subject: [PATCH 03/31] feat : added irs commit param selection --- src/protocols/basecase.rs | 6 +- src/protocols/code_switch.rs | 16 ++-- src/protocols/irs_commit.rs | 130 +++++++++++++---------------- src/protocols/mask_proximity.rs | 6 +- src/protocols/params/bounds.rs | 32 +++---- src/protocols/params/irs_commit.rs | 107 ++++++++++++++++++++++++ src/protocols/params/spec.rs | 4 + src/protocols/params/sumcheck.rs | 7 +- src/protocols/sumcheck.rs | 4 +- src/protocols/whir/config.rs | 10 +-- 10 files changed, 209 insertions(+), 113 deletions(-) diff --git a/src/protocols/basecase.rs b/src/protocols/basecase.rs index b5b251d3..cfac98e4 100644 --- a/src/protocols/basecase.rs +++ b/src/protocols/basecase.rs @@ -169,7 +169,7 @@ impl Config { if !self.masked { let vector = verifier_state.prover_messages_vec(self.commit.vector_size)?; let masks = verifier_state - .prover_messages_vec(self.commit.mask_length * self.commit.num_messages())?; + .prover_messages_vec(self.commit.mask_length() * self.commit.num_messages())?; let evals = self.commit.verify(verifier_state, &[commitment])?; let point = self .sumcheck @@ -197,7 +197,7 @@ impl Config { let mask_rlc: F = verifier_state.verifier_message(); verify!(!mask_rlc.is_zero()); let masked_vector: Vec = verifier_state.prover_messages_vec(self.commit.vector_size)?; - let masked_masks: Vec = verifier_state.prover_messages_vec(self.commit.mask_length)?; + let masked_masks: Vec = verifier_state.prover_messages_vec(self.commit.mask_length())?; // Open the commitment and mask simultaneously. let evals = self @@ -248,7 +248,7 @@ mod tests { let commit = irs_commit::Config::arbitrary(Identity::::new(), 1, size, mask_length, 1); (commit, bool::weighted(0.8)).prop_map(move |(commit, masked)| Self { - commit: irs_commit::Config { ..commit }, + commit, sumcheck: sumcheck::Config::new( size, proof_of_work::Config::none(), diff --git a/src/protocols/code_switch.rs b/src/protocols/code_switch.rs index 73d60bec..004861c5 100644 --- a/src/protocols/code_switch.rs +++ b/src/protocols/code_switch.rs @@ -93,25 +93,25 @@ impl Config { // Theorem 9.6: ℓ_zk ≥ r (mask oracle must cover source randomness). if message_mask_length > 0 { assert!( - message_mask_length >= source_config.mask_length, + message_mask_length >= source_config.mask_length(), "message_mask_length ({message_mask_length}) must be >= source randomness length ({})", - source_config.mask_length, + source_config.mask_length(), ); assert!( - message_mask_length - source_config.mask_length >= out_domain_samples, + message_mask_length - source_config.mask_length() >= out_domain_samples, "the sampled randomness (s) length must be covering all the out of domain sample requests" ); // t' = (in-domain queries to g via target IRS) // + (OOD queries to g via Construction 9.7's OOD step, count = out_domain_samples). // Lemma 9.5 perfect-ZK: t' ≤ r' = target.mask_length. assert!( - target_config.mask_length + target_config.mask_length() >= target_config.in_domain_samples + out_domain_samples, "target encoder violates: t' > r', number of queries should be covered by random mask" ); } assert!( - source_config.mask_length == 0 || message_mask_length > 0, + source_config.mask_length() == 0 || message_mask_length > 0, "source with mask_length > 0 (IRS randomness) requires ZK mode (message_mask_length > 0)" ); assert!( @@ -427,7 +427,7 @@ mod tests { } // r = post-fold randomness length (ι_s parallel // masks fold to a single length-mask_length chunk). - let r = source.mask_length; + let r = source.mask_length(); let message_mask_length = if zk { r + fresh_s_len } else { 0 }; Self::new(source.clone(), target, ood, message_mask_length) }) @@ -486,7 +486,7 @@ mod tests { // Lift ι parallel masks (total length source.mask_length × ι) and fold // chunks of length source.mask_length down to a single chunk. let raw = lift(config.source.embedding(), &source_witness.masks); - let mut mask = fold_chunks(&raw, config.source.mask_length, folding_randomness); + let mut mask = fold_chunks(&raw, config.source.mask_length(), folding_randomness); // Append fresh padding s of length message_mask_length - source.mask_length. mask.extend(random_vector::( rng, @@ -748,7 +748,7 @@ mod tests { "non-ZK with ood > 0", |config| { config.message_mask_length == 0 - && config.source.mask_length == 0 + && config.source.mask_length() == 0 && config.out_domain_samples > 0 }, ); diff --git a/src/protocols/irs_commit.rs b/src/protocols/irs_commit.rs index d1cc9828..2b53b875 100644 --- a/src/protocols/irs_commit.rs +++ b/src/protocols/irs_commit.rs @@ -11,11 +11,7 @@ //! them using the [`matrix_commit`] protocol. Sampling is done with replacement, so may produce //! fewer than `in_domain_samples` distinct rows. //! -use std::{ - f64::{self, consts::LOG2_10}, - fmt, - ops::Neg, -}; +use std::{f64, fmt}; use ark_ff::{AdditiveGroup, Field}; use ark_std::rand::{distributions::Standard, prelude::Distribution, CryptoRng, RngCore}; @@ -26,12 +22,18 @@ use tracing::instrument; use crate::{ algebra::{ - dot, embedding::Embedding, fields::FieldWithSize, lift, linear_form::UnivariateEvaluation, + dot, embedding::Embedding, lift, linear_form::UnivariateEvaluation, mixed_univariate_evaluate, ntt, random_vector, }, engines::EngineId, hash::Hash, - protocols::{challenge_indices::challenge_indices, matrix_commit}, + protocols::{ + challenge_indices::challenge_indices, + matrix_commit, + params::bounds::{ + eps_mca_log2, list_size_log2, one_minus_distance_log2, ood_per_sample_log2, CodeParams, + }, + }, transcript::{ Codec, Decoding, DuplexSpongeInterface, ProverMessage, ProverState, VerificationResult, VerifierMessage, VerifierState, @@ -40,6 +42,12 @@ use crate::{ utils::{chunks_exact_or_empty, zip_strict}, }; +#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] +pub enum IrsMode { + Standard, + ZeroKnowledge { mask_length: usize }, +} + /// Commit to vectors over an fft-friendly field F #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] #[serde(bound = "")] @@ -53,9 +61,6 @@ pub struct Config { /// The number of coefficients in each vector. pub vector_size: usize, - /// The number of masking values to add per codeword. - pub mask_length: usize, - /// The number of Reed-Solomon evaluation points. pub codeword_length: usize, @@ -78,6 +83,9 @@ pub struct Config { /// complexity, but it makes transcript pattern and control flow /// non-deterministic. pub deduplicate_in_domain: bool, + + /// Standard / ZeroKnowledge. + pub mode: IrsMode, } #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Default, Serialize, Deserialize)] @@ -107,6 +115,7 @@ pub struct Evaluations { } impl Config { + #[allow(clippy::too_many_arguments)] pub fn new( security_target: f64, unique_decoding: bool, @@ -115,7 +124,7 @@ impl Config { vector_size: usize, interleaving_depth: usize, rate: f64, - mask_length: usize, + mode: IrsMode, ) -> Self where M: Default, @@ -127,7 +136,6 @@ impl Config { let codeword_length = (message_length as f64 / rate).ceil() as usize; let rate = message_length as f64 / codeword_length as f64; - // Pick in- and out-of-domain samples. // η = slack to Johnson bound. We pick η = √ρ / 20. // TODO: Optimize picking η. let johnson_slack = if unique_decoding { @@ -135,28 +143,12 @@ impl Config { } else { rate.sqrt() / 20. }; - #[allow(clippy::cast_sign_loss)] - let in_domain_samples = { - // Query error is (1 - δ)^q, so we compute 1 - δ - let per_sample = if unique_decoding { - // Unique decoding bound: δ = (1 - ρ) / 2 - f64::midpoint(1., rate) - } else { - // Johnson bound: δ = 1 - √ρ - η - rate.sqrt() + johnson_slack - }; - (security_target / (-per_sample.log2())).ceil() as usize - }; - debug_assert_eq!( - in_domain_samples, - num_in_domain_queries(unique_decoding, security_target, rate) - ); + let in_domain_samples = num_in_domain_queries(unique_decoding, security_target, rate); Self { embedding: Typed::::default(), num_vectors, vector_size, - mask_length, codeword_length, interleaving_depth, matrix_commit: matrix_commit::Config::with_hash( @@ -167,6 +159,7 @@ impl Config { johnson_slack: OrderedFloat(johnson_slack), in_domain_samples, deduplicate_in_domain: false, + mode, } } @@ -191,9 +184,17 @@ impl Config { self.vector_size / self.interleaving_depth } + /// Per-polynomial IRS randomness length. Returns 0 in Standard mode. + pub const fn mask_length(&self) -> usize { + match &self.mode { + IrsMode::Standard => 0, + IrsMode::ZeroKnowledge { mask_length } => *mask_length, + } + } + /// Message length including mask coefficients. pub fn masked_message_length(&self) -> usize { - self.message_length() + self.mask_length + self.message_length() + self.mask_length() } pub fn evaluation_points(&self, indices: &[usize]) -> Vec { @@ -212,44 +213,29 @@ impl Config { self.johnson_slack == 0.0 } + fn log_inv_rate(&self) -> f64 { + -self.rate().log2() + } + /// Compute a list size bound. pub fn list_size(&self) -> f64 { - if self.unique_decoding() { - 1. - } else { - // This is the Johnson bound $1 / (2 η √ρ)$. - 1. / (2. * self.johnson_slack.into_inner() * self.rate().sqrt()) - } + 2_f64.powf(list_size_log2( + self.log_inv_rate(), + self.johnson_slack.into_inner(), + )) } /// Round-by-round soundness of the in-domain queries in bits. pub fn rbr_queries(&self) -> f64 { - let per_sample = if self.unique_decoding() { - // 1 - δ = 1 - (1 + ρ) / 2 - f64::midpoint(1., self.rate()) - } else { - // 1 - δ = sqrt(ρ) + η - self.rate().sqrt() + self.johnson_slack.into_inner() - }; - self.in_domain_samples as f64 * per_sample.log2().neg() + // Query error is (1 - δ)^q in bits = -q · log2(1 - δ). + -(self.in_domain_samples as f64) + * one_minus_distance_log2(self.log_inv_rate(), self.johnson_slack.into_inner()) } - // Compute the proximity gaps term of the fold + /// Round-by-round soundness of the proximity-gaps fold in bits. + /// See WHIR Theorem 4.8. pub fn rbr_soundness_fold_prox_gaps(&self) -> f64 { - let log_field_size = M::Target::field_size_bits(); - let log_inv_rate = self.rate().log2().neg(); - let log_k = (self.masked_message_length() as f64).log2(); - // See WHIR Theorem 4.8 - // Recall, at each round we are only folding by two at a time - let error = if self.unique_decoding() { - log_k + log_inv_rate - } else { - let log_eta = self.johnson_slack.into_inner().log2(); - // Make sure η hits the min bound. - assert!(log_eta >= -(0.5 * log_inv_rate + LOG2_10 + 1.0) - 1e-6); - 7. * LOG2_10 + 3.5 * log_inv_rate + 2. * log_k - }; - log_field_size - error + -eps_mca_log2(&CodeParams::from_irs(self)) } /// Commit to one or more vectors. @@ -276,7 +262,7 @@ impl Config { assert!(vectors.iter().all(|p| p.len() == self.vector_size)); // Generate random mask - let masks = random_vector(prover_state.rng(), self.mask_length * self.num_messages()); + let masks = random_vector(prover_state.rng(), self.mask_length() * self.num_messages()); // Interleaved RS Encode the vectors let messages = vectors @@ -543,9 +529,9 @@ pub fn num_ood_samples( if unique_decoding { return 0; } - let l_choose_2 = list_size * (list_size - 1.) / 2.; - let log_per_sample = field_size_bits - ((degree - 1) as f64).log2(); + let log_per_sample = -ood_per_sample_log2(degree, field_size_bits); assert!(log_per_sample > 0.); + let l_choose_2 = list_size * (list_size - 1.) / 2.; ((security_target + l_choose_2.log2()) / log_per_sample) .ceil() .max(1.) as usize @@ -561,7 +547,6 @@ pub(crate) fn num_in_domain_queries( security_target: f64, rate: f64, ) -> usize { - // Pick in- and out-of-domain samples. // η = slack to Johnson bound. We pick η = √ρ / 20. // TODO: Optimize picking η. let johnson_slack = if unique_decoding { @@ -569,15 +554,9 @@ pub(crate) fn num_in_domain_queries( } else { rate.sqrt() / 20. }; - // Query error is (1 - δ)^q, so we compute 1 - δ - let per_sample = if unique_decoding { - // Unique decoding bound: δ = (1 - ρ) / 2 - f64::midpoint(1., rate) - } else { - // Johnson bound: δ = 1 - √ρ - η - rate.sqrt() + johnson_slack - }; - (security_target / (-per_sample.log2())).ceil() as usize + // Query error is (1 - δ)^q in bits = -q · log2(1 - δ). + let log_one_minus_delta = one_minus_distance_log2(-rate.log2(), johnson_slack); + (security_target / -log_one_minus_delta).ceil() as usize } #[cfg(test)] @@ -640,17 +619,22 @@ pub(crate) mod tests { in_domain_samples, deduplicate_in_domain, )| { + let mode = if mask_length == 0 { + IrsMode::Standard + } else { + IrsMode::ZeroKnowledge { mask_length } + }; Self { embedding: Typed::new(embedding.clone()), num_vectors, vector_size, - mask_length, codeword_length, interleaving_depth, matrix_commit, johnson_slack: OrderedFloat::default(), in_domain_samples, deduplicate_in_domain, + mode, } }, ) diff --git a/src/protocols/mask_proximity.rs b/src/protocols/mask_proximity.rs index 7d625865..687b5341 100644 --- a/src/protocols/mask_proximity.rs +++ b/src/protocols/mask_proximity.rs @@ -171,7 +171,7 @@ impl Config { // Step 2: compute and send combined polynomials + IRS randomness let irs_masks_per_vector = - self.c_zk_commit.mask_length * self.c_zk_commit.interleaving_depth; + self.c_zk_commit.mask_length() * self.c_zk_commit.interleaving_depth; assert_eq!( witness.mask_witness.masks.len(), 2 * self.num_masks * irs_masks_per_vector @@ -220,7 +220,7 @@ impl Config { // Step 2: read combined polynomials + IRS randomness let msg_len = self.c_zk_commit.message_length(); let irs_masks_per_vector = - self.c_zk_commit.mask_length * self.c_zk_commit.interleaving_depth; + self.c_zk_commit.mask_length() * self.c_zk_commit.interleaving_depth; let has_irs_masks = irs_masks_per_vector > 0; let mut combined_msgs = Vec::with_capacity(self.num_masks); let mut combined_rs: Option>> = @@ -442,7 +442,7 @@ mod tests { let gamma: F = prover_state.verifier_message(); let irs_masks_per_vector = - config.c_zk_commit.mask_length * config.c_zk_commit.interleaving_depth; + config.c_zk_commit.mask_length() * config.c_zk_commit.interleaving_depth; for (i, (orig_msg, fresh_msg)) in original_msgs .iter() diff --git a/src/protocols/params/bounds.rs b/src/protocols/params/bounds.rs index 6732b88a..6f5380f9 100644 --- a/src/protocols/params/bounds.rs +++ b/src/protocols/params/bounds.rs @@ -26,23 +26,23 @@ impl CodeParams { field_bits: M::Target::field_size_bits(), } } +} - pub fn rate(&self) -> f64 { - 2_f64.powf(-self.log_inv_rate) - } +fn rate(log_inv_rate: f64) -> f64 { + 2_f64.powf(-log_inv_rate) +} - pub fn unique_decoding(&self) -> bool { - self.johnson_slack == 0.0 - } +fn unique_decoding(johnson_slack: f64) -> bool { + johnson_slack == 0.0 } /// log2 |Λ(C, δ)|. -pub fn list_size_log2(p: &CodeParams) -> f64 { - if p.unique_decoding() { +pub fn list_size_log2(log_inv_rate: f64, johnson_slack: f64) -> f64 { + if unique_decoding(johnson_slack) { 0.0 } else { // Johnson: |Λ| = 1 / (2 η √ρ). - -1.0 - p.johnson_slack.log2() + 0.5 * p.log_inv_rate + -1.0 - johnson_slack.log2() + 0.5 * log_inv_rate } } @@ -50,7 +50,7 @@ pub fn list_size_log2(p: &CodeParams) -> f64 { pub fn eps_mca_log2(p: &CodeParams) -> f64 { let log_k = (p.message_length as f64).log2(); - let error = if p.unique_decoding() { + let error = if unique_decoding(p.johnson_slack) { log_k + p.log_inv_rate } else { debug_assert!(p.johnson_slack.log2() >= -(0.5 * p.log_inv_rate + LOG2_10 + 1.0) - 1e-6); @@ -61,18 +61,18 @@ pub fn eps_mca_log2(p: &CodeParams) -> f64 { } /// log2(1 - δ). -pub fn one_minus_distance_log2(p: &CodeParams) -> f64 { - let one_minus_delta = if p.unique_decoding() { - f64::midpoint(1.0, p.rate()) +pub fn one_minus_distance_log2(log_inv_rate: f64, johnson_slack: f64) -> f64 { + let one_minus_delta = if unique_decoding(johnson_slack) { + f64::midpoint(1.0, rate(log_inv_rate)) } else { - p.rate().sqrt() + p.johnson_slack + rate(log_inv_rate).sqrt() + johnson_slack }; one_minus_delta.log2() } /// log2 of the per-OOD-sample Schwartz-Zippel error: (k-1)/|F|. -pub fn ood_per_sample_log2(p: &CodeParams) -> f64 { - ((p.message_length - 1) as f64).log2() - p.field_bits +pub fn ood_per_sample_log2(message_length: usize, field_bits: f64) -> f64 { + ((message_length - 1) as f64).log2() - field_bits } /// PoW difficulty to close a soundness gap: max(0, target − achieved). diff --git a/src/protocols/params/irs_commit.rs b/src/protocols/params/irs_commit.rs index 8b137891..55a11c72 100644 --- a/src/protocols/params/irs_commit.rs +++ b/src/protocols/params/irs_commit.rs @@ -1 +1,108 @@ +//! Parameter selection for the IRS commit protocol. +use crate::{ + algebra::embedding::Embedding, + protocols::{ + irs_commit::{self, num_in_domain_queries, IrsMode}, + params::spec::{Mode, RoundContext, SecuritySpec}, + }, +}; + +/// Solve IRS-commit parameters for a single round. +/// +/// `out_domain_samples` is the OOD-query budget owed by Construction 9.7 / +/// Bound 2; in ZK mode it is part of the per-row randomness budget (Lemma 9.5 +/// requires `mask_length ≥ in_domain_samples + out_domain_samples`). +pub fn solve( + spec: &SecuritySpec, + ctx: &RoundContext, + out_domain_samples: usize, +) -> irs_commit::Config { + assert!( + !(matches!(spec.mode, Mode::ZeroKnowledge) && spec.unique_decoding), + "ZK mode requires Johnson regime (code-switch needs OOD samples)" + ); + + let security_target = f64::from( + spec.target_security_bits + .saturating_sub(spec.max_pow_bits.unwrap_or(0)), + ); + let rate = 2_f64.powf(-f64::from(ctx.log_inv_rate)); + let interleaving_depth = 1_usize << ctx.folding_factor; + + let mode = match spec.mode { + Mode::Standard => IrsMode::Standard, + Mode::ZeroKnowledge => IrsMode::ZeroKnowledge { + mask_length: mask_length( + spec.unique_decoding, + security_target, + rate, + out_domain_samples, + ), + }, + }; + + irs_commit::Config::new( + security_target, + spec.unique_decoding, + spec.hash_id, + 1, + ctx.vector_size, + interleaving_depth, + rate, + mode, + ) +} + +/// Solve the shared C_zk IRS-commit config used to commit mask polynomials. +/// +/// C_zk itself carries no IRS randomness (`IrsMode::Standard`); the masks it +/// commits to already are the randomness. +/// +/// - `l_zk` — C_zk message length (Theorem 9.6: ℓ_zk ≥ r). +/// - `log_inv_rate` — C_zk's rate, chosen by the orchestrator. +/// - `num_vectors` — total mask polynomials per commit (e.g. `2 * num_masks` +/// for mask-proximity's original/fresh pairs). +pub fn solve_mask_code( + spec: &SecuritySpec, + l_zk: usize, + log_inv_rate: u32, + num_vectors: usize, +) -> irs_commit::Config { + assert!( + matches!(spec.mode, Mode::ZeroKnowledge), + "C_zk only exists in ZK mode" + ); + assert!( + !spec.unique_decoding, + "code-switch requires Johnson regime (OOD samples needed)" + ); + + let security_target = f64::from( + spec.target_security_bits + .saturating_sub(spec.max_pow_bits.unwrap_or(0)), + ); + let rate = 2_f64.powf(-f64::from(log_inv_rate)); + + irs_commit::Config::new( + security_target, + spec.unique_decoding, + spec.hash_id, + num_vectors, + l_zk, + 1, + rate, + IrsMode::Standard, + ) +} + +/// Lemma 9.5 ZK budget: cover every query that reveals a polynomial value. +fn mask_length( + unique_decoding: bool, + security_target: f64, + rate: f64, + out_domain_samples: usize, +) -> usize { + let in_domain = num_in_domain_queries(unique_decoding, security_target, rate); + in_domain + out_domain_samples +} diff --git a/src/protocols/params/spec.rs b/src/protocols/params/spec.rs index 9d39ef1f..43d604a3 100644 --- a/src/protocols/params/spec.rs +++ b/src/protocols/params/spec.rs @@ -8,6 +8,10 @@ pub struct SecuritySpec { pub mode: Mode, /// Target security bits pub target_security_bits: u32, + /// Use the unique-decoding regime (`true`) instead of the Johnson regime. + /// ZK mode requires Johnson — Construction 9.7 / Bound 2 needs OOD queries, + /// and `num_ood_samples` returns 0 in unique-decoding. + pub unique_decoding: bool, /// Size of the input witness / vector pub vector_size: usize, /// Starting log inverse rate for RS code diff --git a/src/protocols/params/sumcheck.rs b/src/protocols/params/sumcheck.rs index 1d7772ba..9036f482 100644 --- a/src/protocols/params/sumcheck.rs +++ b/src/protocols/params/sumcheck.rs @@ -26,7 +26,7 @@ pub fn solve( sumcheck::Config::new(ctx.vector_size, round_pow, num_rounds, mode) } -fn num_sumcheck_rounds(spec: &SecuritySpec, ctx: &RoundContext) -> usize { +const fn num_sumcheck_rounds(spec: &SecuritySpec, ctx: &RoundContext) -> usize { if ctx.round_index == 0 { spec.initial_folding_factor } else { @@ -34,7 +34,7 @@ fn num_sumcheck_rounds(spec: &SecuritySpec, ctx: &RoundContext) } } -pub fn masks_required(spec: &SecuritySpec, ctx: &RoundContext) -> usize { +pub const fn masks_required(spec: &SecuritySpec, ctx: &RoundContext) -> usize { match spec.mode { Mode::Standard => 0, Mode::ZeroKnowledge => num_sumcheck_rounds(spec, ctx), @@ -56,7 +56,8 @@ fn solve_sumcheck_round_pow( // TODO: extend with `ℓ_zk · |Λ_C_zk|` factors in ZK mode once mask-code // params are available (PR 2). let sec_mca = -bounds::eps_mca_log2(&code); - let sec_combination = code.field_bits - bounds::list_size_log2(&code) - 1.0; + let sec_combination = + code.field_bits - bounds::list_size_log2(code.log_inv_rate, code.johnson_slack) - 1.0; let achieved = sec_mca.min(sec_combination); let pow_bits = bounds::pow_bits_to_close_gap(spec.target_security_bits, achieved); diff --git a/src/protocols/sumcheck.rs b/src/protocols/sumcheck.rs index 2c04e01c..16a2988e 100644 --- a/src/protocols/sumcheck.rs +++ b/src/protocols/sumcheck.rs @@ -75,7 +75,7 @@ impl Config { } } - fn mask_length(&self) -> usize { + const fn mask_length(&self) -> usize { match &self.mode { SumcheckMode::Standard => 0, SumcheckMode::ZeroKnowledge { mask_length } => *mask_length, @@ -338,7 +338,7 @@ mod tests { |(initial_size, num_rounds, mode)| { let num_rounds = num_rounds.min(initial_size.next_power_of_two().trailing_zeros() as usize); - Config::new( + Self::new( initial_size, proof_of_work::Config::none(), num_rounds, diff --git a/src/protocols/whir/config.rs b/src/protocols/whir/config.rs index a81bbb87..45b138ff 100644 --- a/src/protocols/whir/config.rs +++ b/src/protocols/whir/config.rs @@ -8,7 +8,7 @@ use crate::{ bits::Bits, parameters::ProtocolParameters, protocols::{ - irs_commit::{self, num_ood_samples}, + irs_commit::{self, num_ood_samples, IrsMode}, proof_of_work, sumcheck, }, }; @@ -61,7 +61,7 @@ impl Config { size, 1 << whir_parameters.initial_folding_factor, 0.5_f64.powi(whir_parameters.starting_log_inv_rate as i32), - 0, + IrsMode::Standard, ); let initial_out_domain_samples = num_ood_samples( whir_parameters.unique_decoding, @@ -109,7 +109,7 @@ impl Config { 1 << num_variables, 1 << whir_parameters.folding_factor, 0.5_f64.powi(next_rate as i32), - 0, + IrsMode::Standard, ); let round_out_domain_samples = num_ood_samples( whir_parameters.unique_decoding, @@ -580,7 +580,7 @@ mod tests { embedding: Typed::new(embedding::Identity::new()), num_vectors: 1, vector_size: 1 << 10, - mask_length: 0, + mode: IrsMode::Standard, codeword_length: 1 << (10 + 3 - 2), interleaving_depth: 1 << 2, matrix_commit: matrix_commit::Config::::new(0, 0), @@ -602,7 +602,7 @@ mod tests { embedding: Typed::new(embedding::Identity::new()), num_vectors: 1, vector_size: 1 << 10, - mask_length: 0, + mode: IrsMode::Standard, codeword_length: 1 << (10 + 4 - 2), interleaving_depth: 1 << 2, matrix_commit: matrix_commit::Config::::new(0, 0), From 8e75809ef4270aefcdc17ca298dc25062a51d910 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Fri, 8 May 2026 06:02:19 +0530 Subject: [PATCH 04/31] feat : added typed params and test for irs_commit --- .../protocols/params/irs_commit.txt | 7 + src/protocols/basecase.rs | 1 + src/protocols/code_switch.rs | 1 + src/protocols/irs_commit.rs | 35 +-- src/protocols/mask_proximity.rs | 1 + src/protocols/params/irs_commit.rs | 217 ++++++++++++++---- src/protocols/params/spec.rs | 70 ++++-- src/protocols/params/sumcheck.rs | 10 +- src/protocols/sumcheck.rs | 1 + src/protocols/whir_zk/mod.rs | 6 +- 10 files changed, 261 insertions(+), 88 deletions(-) create mode 100644 proptest-regressions/protocols/params/irs_commit.txt diff --git a/proptest-regressions/protocols/params/irs_commit.txt b/proptest-regressions/protocols/params/irs_commit.txt new file mode 100644 index 00000000..2ca7e863 --- /dev/null +++ b/proptest-regressions/protocols/params/irs_commit.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 0b6dd03179c9a4e38b29b34b241b88fba69348a2c8938af7253314b7035bea82 # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 80, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 4, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, ctx = RoundContext { round_index: 0, vector_size: 16, log_inv_rate: 1, folding_factor: 1, prev_round_in_domain_samples: 0, prev_round_query_error: 0.0 }, out_domain = 0, seed = 0 diff --git a/src/protocols/basecase.rs b/src/protocols/basecase.rs index cfac98e4..1d40c242 100644 --- a/src/protocols/basecase.rs +++ b/src/protocols/basecase.rs @@ -31,6 +31,7 @@ pub struct Opening { pub linear_form_evaluation: F, } +#[must_use] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(bound = "")] pub struct Config { diff --git a/src/protocols/code_switch.rs b/src/protocols/code_switch.rs index 004861c5..b6f284ae 100644 --- a/src/protocols/code_switch.rs +++ b/src/protocols/code_switch.rs @@ -30,6 +30,7 @@ use crate::{ }; /// Code-switching IOR config with optional ZK. +#[must_use] #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] #[serde(bound = "")] pub struct Config { diff --git a/src/protocols/irs_commit.rs b/src/protocols/irs_commit.rs index 2b53b875..2fb0233c 100644 --- a/src/protocols/irs_commit.rs +++ b/src/protocols/irs_commit.rs @@ -11,7 +11,7 @@ //! them using the [`matrix_commit`] protocol. Sampling is done with replacement, so may produce //! fewer than `in_domain_samples` distinct rows. //! -use std::{f64, fmt}; +use std::{f64, fmt, num::NonZeroUsize}; use ark_ff::{AdditiveGroup, Field}; use ark_std::rand::{distributions::Standard, prelude::Distribution, CryptoRng, RngCore}; @@ -45,10 +45,11 @@ use crate::{ #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] pub enum IrsMode { Standard, - ZeroKnowledge { mask_length: usize }, + ZeroKnowledge { mask_length: NonZeroUsize }, } /// Commit to vectors over an fft-friendly field F +#[must_use] #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] #[serde(bound = "")] pub struct Config { @@ -131,10 +132,14 @@ impl Config { { assert!(vector_size.is_multiple_of(interleaving_depth)); assert!(rate > 0. && rate <= 1.); - let message_length = vector_size / interleaving_depth; + let mask_length = match &mode { + IrsMode::Standard => 0, + IrsMode::ZeroKnowledge { mask_length } => mask_length.get(), + }; + let masked_message_length = vector_size / interleaving_depth + mask_length; #[allow(clippy::cast_sign_loss)] - let codeword_length = (message_length as f64 / rate).ceil() as usize; - let rate = message_length as f64 / codeword_length as f64; + let codeword_length = (masked_message_length as f64 / rate).ceil() as usize; + let rate = masked_message_length as f64 / codeword_length as f64; // η = slack to Johnson bound. We pick η = √ρ / 20. // TODO: Optimize picking η. @@ -143,7 +148,7 @@ impl Config { } else { rate.sqrt() / 20. }; - let in_domain_samples = num_in_domain_queries(unique_decoding, security_target, rate); + let in_domain_samples = num_in_domain_queries(unique_decoding, security_target, rate).get(); Self { embedding: Typed::::default(), @@ -188,7 +193,7 @@ impl Config { pub const fn mask_length(&self) -> usize { match &self.mode { IrsMode::Standard => 0, - IrsMode::ZeroKnowledge { mask_length } => *mask_length, + IrsMode::ZeroKnowledge { mask_length } => mask_length.get(), } } @@ -539,6 +544,9 @@ pub fn num_ood_samples( /// Return the number of in-domain queries. /// +/// Always ≥ 1 — the type carries that invariant so callers don't need to +/// re-prove it locally. +/// /// This is used by [`whir_zk`]. // TODO: A method with cleaner abstraction. #[allow(clippy::cast_sign_loss)] @@ -546,7 +554,7 @@ pub(crate) fn num_in_domain_queries( unique_decoding: bool, security_target: f64, rate: f64, -) -> usize { +) -> NonZeroUsize { // η = slack to Johnson bound. We pick η = √ρ / 20. // TODO: Optimize picking η. let johnson_slack = if unique_decoding { @@ -556,7 +564,8 @@ pub(crate) fn num_in_domain_queries( }; // Query error is (1 - δ)^q in bits = -q · log2(1 - δ). let log_one_minus_delta = one_minus_distance_log2(-rate.log2(), johnson_slack); - (security_target / -log_one_minus_delta).ceil() as usize + let q = (security_target / -log_one_minus_delta).ceil() as usize; + NonZeroUsize::new(q).unwrap_or(NonZeroUsize::MIN) } #[cfg(test)] @@ -619,11 +628,9 @@ pub(crate) mod tests { in_domain_samples, deduplicate_in_domain, )| { - let mode = if mask_length == 0 { - IrsMode::Standard - } else { - IrsMode::ZeroKnowledge { mask_length } - }; + let mode = NonZeroUsize::new(mask_length).map_or(IrsMode::Standard, |n| { + IrsMode::ZeroKnowledge { mask_length: n } + }); Self { embedding: Typed::new(embedding.clone()), num_vectors, diff --git a/src/protocols/mask_proximity.rs b/src/protocols/mask_proximity.rs index 687b5341..4a733e92 100644 --- a/src/protocols/mask_proximity.rs +++ b/src/protocols/mask_proximity.rs @@ -61,6 +61,7 @@ use crate::{ /// Mask proximity configuration. /// /// Wraps an IRS config for the shared mask tree and the number of mask pairs. +#[must_use] #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] #[serde(bound = "")] pub struct Config { diff --git a/src/protocols/params/irs_commit.rs b/src/protocols/params/irs_commit.rs index 55a11c72..3b47c651 100644 --- a/src/protocols/params/irs_commit.rs +++ b/src/protocols/params/irs_commit.rs @@ -1,51 +1,55 @@ //! Parameter selection for the IRS commit protocol. +use std::iter; + use crate::{ - algebra::embedding::Embedding, + algebra::{embedding::Embedding, ntt}, protocols::{ irs_commit::{self, num_in_domain_queries, IrsMode}, - params::spec::{Mode, RoundContext, SecuritySpec}, + params::spec::{ + LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, RoundContext, SecuritySpec, + }, }, }; -/// Solve IRS-commit parameters for a single round. -/// -/// `out_domain_samples` is the OOD-query budget owed by Construction 9.7 / -/// Bound 2; in ZK mode it is part of the per-row randomness budget (Lemma 9.5 -/// requires `mask_length ≥ in_domain_samples + out_domain_samples`). +/// Solve per-round IRS-commit parameters. ZK mask sized per Lemma 9.5. pub fn solve( spec: &SecuritySpec, ctx: &RoundContext, - out_domain_samples: usize, + out_domain: OodSampleBudget, ) -> irs_commit::Config { - assert!( - !(matches!(spec.mode, Mode::ZeroKnowledge) && spec.unique_decoding), - "ZK mode requires Johnson regime (code-switch needs OOD samples)" - ); - let security_target = f64::from( spec.target_security_bits .saturating_sub(spec.max_pow_bits.unwrap_or(0)), ); - let rate = 2_f64.powf(-f64::from(ctx.log_inv_rate)); + let raw_rate = 2_f64.powf(-f64::from(ctx.log_inv_rate)); let interleaving_depth = 1_usize << ctx.folding_factor; + let unique_decoding = spec.mode.unique_decoding(); let mode = match spec.mode { - Mode::Standard => IrsMode::Standard, - Mode::ZeroKnowledge => IrsMode::ZeroKnowledge { - mask_length: mask_length( - spec.unique_decoding, - security_target, - rate, - out_domain_samples, - ), - }, + Mode::Standard { .. } => IrsMode::Standard, + Mode::ZeroKnowledge => { + // Lemma 9.5 ZK budget: every revealed evaluation counts. + let in_domain = num_in_domain_queries(unique_decoding, security_target, raw_rate); + let mask_length = in_domain + .checked_add(out_domain.get()) + .expect("usize overflow in mask_length"); + IrsMode::ZeroKnowledge { mask_length } + } }; + let mask_length_value = match &mode { + IrsMode::Standard => 0, + IrsMode::ZeroKnowledge { mask_length } => mask_length.get(), + }; + let masked_message_length = ctx.vector_size / interleaving_depth + mask_length_value; + let rate = snap_rate::(masked_message_length, raw_rate); + irs_commit::Config::new( security_target, - spec.unique_decoding, + unique_decoding, spec.hash_id, + // Orchestrator commits one vector per round. 1, ctx.vector_size, interleaving_depth, @@ -54,39 +58,46 @@ pub fn solve( ) } -/// Solve the shared C_zk IRS-commit config used to commit mask polynomials. -/// -/// C_zk itself carries no IRS randomness (`IrsMode::Standard`); the masks it -/// commits to already are the randomness. +/// Solve the shared C_zk IRS config for committing mask polynomials. /// -/// - `l_zk` — C_zk message length (Theorem 9.6: ℓ_zk ≥ r). -/// - `log_inv_rate` — C_zk's rate, chosen by the orchestrator. -/// - `num_vectors` — total mask polynomials per commit (e.g. `2 * num_masks` -/// for mask-proximity's original/fresh pairs). +/// - `l_zk` — message length (Theorem 9.6: ℓ_zk ≥ `source_mask_length`). +/// - `source_mask_length` — `r`, the source IRS mask length. +/// - `log_inv_rate` — C_zk rate. +/// - `num_vectors` — total masks per commit; must equal `2 * num_masks` to be +/// consumable by `mask_proximity::Config::new` (original/fresh pairs). pub fn solve_mask_code( spec: &SecuritySpec, - l_zk: usize, - log_inv_rate: u32, + l_zk: MaskCodeMessageLen, + source_mask_length: usize, + log_inv_rate: LogInvRate, num_vectors: usize, ) -> irs_commit::Config { + let l_zk = l_zk.get(); assert!( matches!(spec.mode, Mode::ZeroKnowledge), "C_zk only exists in ZK mode" ); assert!( - !spec.unique_decoding, - "code-switch requires Johnson regime (OOD samples needed)" + l_zk >= source_mask_length, + "Theorem 9.6: ℓ_zk ({l_zk}) must be ≥ source mask length ({source_mask_length})", + ); + assert!( + num_vectors.is_multiple_of(2), + "num_vectors ({num_vectors}) must be even — mask-proximity expects 2 · num_masks (original + fresh)", ); let security_target = f64::from( spec.target_security_bits .saturating_sub(spec.max_pow_bits.unwrap_or(0)), ); - let rate = 2_f64.powf(-f64::from(log_inv_rate)); + let raw_rate = 2_f64.powf(-f64::from(log_inv_rate.get())); + // C_zk has interleaving_depth = 1 and IrsMode::Standard, so masked_message_length = l_zk. + let rate = snap_rate::(l_zk, raw_rate); irs_commit::Config::new( security_target, - spec.unique_decoding, + // ZK ⇒ Johnson regime. + false, spec.hash_id, num_vectors, l_zk, @@ -96,13 +107,125 @@ pub fn solve_mask_code( ) } -/// Lemma 9.5 ZK budget: cover every query that reveals a polynomial value. -fn mask_length( - unique_decoding: bool, - security_target: f64, - rate: f64, - out_domain_samples: usize, -) -> usize { - let in_domain = num_in_domain_queries(unique_decoding, security_target, rate); - in_domain + out_domain_samples +/// Snap `rate` so `Config::new`'s codeword sizing lands on a valid power-of-two +/// NTT order. Returns a rate `≤ raw_rate`. +fn snap_rate(masked_message_length: usize, raw_rate: f64) -> f64 { + #[allow(clippy::cast_sign_loss)] + let desired = (masked_message_length as f64 / raw_rate).ceil() as usize; + let codeword_length = iter::successors(ntt::next_order::(desired), |&n| { + ntt::next_order::(n + 1) + }) + .find(|n| n.is_power_of_two()) + .expect("no valid power-of-two NTT order ≥ desired codeword length"); + masked_message_length as f64 / codeword_length as f64 +} + +#[cfg(test)] +mod tests { + use std::marker::PhantomData; + + use ark_std::rand::{rngs::StdRng, SeedableRng}; + use proptest::prelude::*; + + use super::*; + use crate::{ + algebra::{embedding::Identity, fields::Field64, random_vector}, + hash, + transcript::{DomainSeparator, ProverState, VerifierState}, + }; + + type F = Field64; + type M = Identity; + + fn arb_spec_with(mode: impl Strategy) -> impl Strategy> { + (mode, 80u32..=128, 1u32..=4, prop::option::of(0u32..=20)).prop_map( + |(mode, target_security_bits, starting_log_inv_rate, max_pow_bits)| SecuritySpec { + mode, + target_security_bits, + vector_size: 1 << 8, + starting_log_inv_rate, + initial_folding_factor: 4, + folding_factor: 4, + max_pow_bits, + hash_id: hash::BLAKE3, + _embedding: PhantomData, + }, + ) + } + + fn arb_zk_spec() -> impl Strategy> { + arb_spec_with(Just(Mode::ZeroKnowledge)) + } + + fn arb_standard_spec() -> impl Strategy> { + arb_spec_with(any::().prop_map(|unique_decoding| Mode::Standard { unique_decoding })) + } + + fn arb_any_spec() -> impl Strategy> { + prop_oneof![arb_zk_spec(), arb_standard_spec()] + } + + fn arb_ctx() -> impl Strategy { + (4u32..=8, 1u32..=4, 1u32..=3).prop_map(|(log_size, log_inv_rate, folding_factor)| { + RoundContext { + round_index: 0, + vector_size: 1_usize << log_size, + log_inv_rate, + folding_factor, + prev_round_in_domain_samples: 0, + prev_round_query_error: 0.0, + } + }) + } + + proptest! { + /// Lemma 9.5: ZK mask covers all revealed evaluations. + #[test] + fn zk_mask_covers_lemma_9_5( + spec in arb_zk_spec(), + ctx in arb_ctx(), + out_domain in 0usize..16, + ) { + let config = solve(&spec, &ctx, OodSampleBudget::new(out_domain)); + prop_assert!( + config.mask_length() >= config.in_domain_samples + out_domain, + "mask {} < in_domain {} + out_domain {}", + config.mask_length(), config.in_domain_samples, out_domain, + ); + } + + /// Standard mode produces no IRS randomness. + #[test] + fn standard_has_no_mask(spec in arb_standard_spec(), ctx in arb_ctx()) { + let config = solve(&spec, &ctx, OodSampleBudget::new(0)); + prop_assert_eq!(config.mask_length(), 0); + } + + /// Round-trip: solve → commit → verify with the produced config. + #[test] + fn solve_round_trips_through_irs_commit( + spec in arb_any_spec(), + ctx in arb_ctx(), + out_domain in 0usize..8, + seed: u64, + ) { + let config = solve(&spec, &ctx, OodSampleBudget::new(out_domain)); + + let ds = DomainSeparator::protocol(&config) + .session(&format!("Test at {}:{}", file!(), line!())) + .instance(&seed); + let mut rng = StdRng::seed_from_u64(seed); + let vector = random_vector::(&mut rng, config.vector_size); + + let mut prover_state = ProverState::new_std(&ds); + let witness = config.commit(&mut prover_state, &[&vector]); + let _ = config.open(&mut prover_state, &[&witness]); + let proof = prover_state.proof(); + + let mut verifier_state = VerifierState::new_std(&ds, &proof); + let commitment = config.receive_commitment(&mut verifier_state).unwrap(); + let _ = config.verify(&mut verifier_state, &[&commitment]).unwrap(); + verifier_state.check_eof().unwrap(); + } + } } diff --git a/src/protocols/params/spec.rs b/src/protocols/params/spec.rs index 43d604a3..d7205dc4 100644 --- a/src/protocols/params/spec.rs +++ b/src/protocols/params/spec.rs @@ -2,48 +2,76 @@ use core::marker::PhantomData; use crate::{algebra::embedding::Embedding, engines::EngineId}; -/// Security spec definition for the protocol +/// Phantom-typed primitive — `Tagged` and `Tagged` are distinct types. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct Tagged(T, PhantomData); + +impl Tagged { + pub const fn new(v: T) -> Self { + Self(v, PhantomData) + } + + pub const fn get(self) -> T { + self.0 + } +} + +/// Protocol-wide security spec. +#[derive(Debug, Clone)] pub struct SecuritySpec { - /// Protocol Mode of operation pub mode: Mode, - /// Target security bits pub target_security_bits: u32, - /// Use the unique-decoding regime (`true`) instead of the Johnson regime. - /// ZK mode requires Johnson — Construction 9.7 / Bound 2 needs OOD queries, - /// and `num_ood_samples` returns 0 in unique-decoding. - pub unique_decoding: bool, - /// Size of the input witness / vector pub vector_size: usize, - /// Starting log inverse rate for RS code pub starting_log_inv_rate: u32, - /// Initial Folding factor for the first round of sumcheck pub initial_folding_factor: usize, - /// Folding factor for subsequent round of sumcheck pub folding_factor: usize, - /// POW bits pub max_pow_bits: Option, - /// Hash Engine pub hash_id: EngineId, pub _embedding: PhantomData, } -/// Per round context struct for calculating the bounds +/// Per-round context for bound calculations. +#[derive(Debug, Clone)] pub struct RoundContext { - /// Round index pub round_index: usize, - /// Vector size for the particular round pub vector_size: usize, - /// rate for the RS encoding for the round vector pub log_inv_rate: u32, - /// Forlding factor for sumcheck pub folding_factor: u32, - /// Previous round's in domain samples count pub prev_round_in_domain_samples: usize, - /// To keep track of the errors of all the rounds pub prev_round_query_error: f64, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Mode { - Standard, + /// Regime is selectable. + Standard { unique_decoding: bool }, + /// Always Johnson regime — Construction 9.7 needs OOD queries. ZeroKnowledge, } + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum OodSampleBudgetTag {} +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum MaskCodeMessageLenTag {} +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum LogInvRateTag {} + +/// `t_ood` — Bound 2's OOD-sample budget (produced by code-switch). +pub type OodSampleBudget = Tagged; + +/// `ℓ_zk` — C_zk message length (Theorem 9.6: ℓ_zk ≥ source mask length). +pub type MaskCodeMessageLen = Tagged; + +/// `rate = 2^-log_inv_rate`. +pub type LogInvRate = Tagged; + +impl Mode { + pub const fn unique_decoding(&self) -> bool { + matches!( + self, + Self::Standard { + unique_decoding: true + } + ) + } +} diff --git a/src/protocols/params/sumcheck.rs b/src/protocols/params/sumcheck.rs index 9036f482..17f407ad 100644 --- a/src/protocols/params/sumcheck.rs +++ b/src/protocols/params/sumcheck.rs @@ -17,9 +17,9 @@ pub fn solve( ) -> sumcheck::Config { let num_rounds = num_sumcheck_rounds(spec, ctx); let mode = match spec.mode { - Mode::Standard => sumcheck::SumcheckMode::Standard, + Mode::Standard { .. } => sumcheck::SumcheckMode::Standard, Mode::ZeroKnowledge => sumcheck::SumcheckMode::ZeroKnowledge { - mask_length: zk_mask_length(), + mask_length: mask_length(), }, }; let round_pow = solve_sumcheck_round_pow(spec, irs_source); @@ -36,12 +36,14 @@ const fn num_sumcheck_rounds(spec: &SecuritySpec, ctx: &RoundCo pub const fn masks_required(spec: &SecuritySpec, ctx: &RoundContext) -> usize { match spec.mode { - Mode::Standard => 0, + Mode::Standard { .. } => 0, Mode::ZeroKnowledge => num_sumcheck_rounds(spec, ctx), } } -const fn zk_mask_length() -> usize { +/// 3 coefficients = constant + linear + quadratic, sufficient to mask each +/// degree-2 sumcheck round polynomial. +const fn mask_length() -> usize { 3 } diff --git a/src/protocols/sumcheck.rs b/src/protocols/sumcheck.rs index 16a2988e..bccd316f 100644 --- a/src/protocols/sumcheck.rs +++ b/src/protocols/sumcheck.rs @@ -36,6 +36,7 @@ pub enum SumcheckMode { ZeroKnowledge { mask_length: usize }, } +#[must_use] #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(bound = "")] pub struct Config diff --git a/src/protocols/whir_zk/mod.rs b/src/protocols/whir_zk/mod.rs index b5e41c6e..ce49b13f 100644 --- a/src/protocols/whir_zk/mod.rs +++ b/src/protocols/whir_zk/mod.rs @@ -46,13 +46,15 @@ impl BlindingSizePolicy { main_whir_params.unique_decoding, protocol_security_level_main as f64, 0.5_f64.powi(main_whir_params.starting_log_inv_rate as i32), - ); + ) + .get(); #[allow(clippy::cast_possible_wrap)] let q_delta_2 = irs_commit::num_in_domain_queries( main_whir_params.unique_decoding, main_whir_params.security_level as f64, 0.5_f64.powi(main_whir_params.starting_log_inv_rate as i32), - ); + ) + .get(); // Default send-in-clear thresholds match query complexities. Self { From 7477a8b43640908e946149da01ec746376a054e5 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Fri, 8 May 2026 15:09:31 +0530 Subject: [PATCH 05/31] fix : rs-in-order build --- src/protocols/params/irs_commit.rs | 10 ++-------- src/protocols/params/spec.rs | 9 +++++++++ src/protocols/params/sumcheck.rs | 3 ++- src/protocols/whir/verifier.rs | 7 +++++++ src/protocols/whir_zk/committer.rs | 7 ++++--- src/protocols/whir_zk/prover.rs | 2 +- src/protocols/whir_zk/verifier.rs | 3 ++- 7 files changed, 27 insertions(+), 14 deletions(-) diff --git a/src/protocols/params/irs_commit.rs b/src/protocols/params/irs_commit.rs index 3b47c651..3d860486 100644 --- a/src/protocols/params/irs_commit.rs +++ b/src/protocols/params/irs_commit.rs @@ -18,10 +18,7 @@ pub fn solve( ctx: &RoundContext, out_domain: OodSampleBudget, ) -> irs_commit::Config { - let security_target = f64::from( - spec.target_security_bits - .saturating_sub(spec.max_pow_bits.unwrap_or(0)), - ); + let security_target = f64::from(spec.protocol_security_target_bits()); let raw_rate = 2_f64.powf(-f64::from(ctx.log_inv_rate)); let interleaving_depth = 1_usize << ctx.folding_factor; let unique_decoding = spec.mode.unique_decoding(); @@ -86,10 +83,7 @@ pub fn solve_mask_code( "num_vectors ({num_vectors}) must be even — mask-proximity expects 2 · num_masks (original + fresh)", ); - let security_target = f64::from( - spec.target_security_bits - .saturating_sub(spec.max_pow_bits.unwrap_or(0)), - ); + let security_target = f64::from(spec.protocol_security_target_bits()); let raw_rate = 2_f64.powf(-f64::from(log_inv_rate.get())); // C_zk has interleaving_depth = 1 and IrsMode::Standard, so masked_message_length = l_zk. let rate = snap_rate::(l_zk, raw_rate); diff --git a/src/protocols/params/spec.rs b/src/protocols/params/spec.rs index d7205dc4..fa1a9f14 100644 --- a/src/protocols/params/spec.rs +++ b/src/protocols/params/spec.rs @@ -75,3 +75,12 @@ impl Mode { ) } } + +impl SecuritySpec { + /// Security bits the non-PoW parameters must deliver alone; the remaining + /// `max_pow_bits` are closed by PoW grinding. + pub fn protocol_security_target_bits(&self) -> u32 { + self.target_security_bits + .saturating_sub(self.max_pow_bits.unwrap_or(0)) + } +} diff --git a/src/protocols/params/sumcheck.rs b/src/protocols/params/sumcheck.rs index 17f407ad..51b82c1d 100644 --- a/src/protocols/params/sumcheck.rs +++ b/src/protocols/params/sumcheck.rs @@ -62,6 +62,7 @@ fn solve_sumcheck_round_pow( code.field_bits - bounds::list_size_log2(code.log_inv_rate, code.johnson_slack) - 1.0; let achieved = sec_mca.min(sec_combination); - let pow_bits = bounds::pow_bits_to_close_gap(spec.target_security_bits, achieved); + // protocol-level PoW closes that portion separately. + let pow_bits = bounds::pow_bits_to_close_gap(spec.protocol_security_target_bits(), achieved); proof_of_work::Config::from_difficulty(pow_bits) } diff --git a/src/protocols/whir/verifier.rs b/src/protocols/whir/verifier.rs index b7dc80a0..5f4e7b84 100644 --- a/src/protocols/whir/verifier.rs +++ b/src/protocols/whir/verifier.rs @@ -63,6 +63,13 @@ impl Config { return Ok(FinalClaim::default()); } + let expected_matrix_len = + self.initial_out_domain_samples * self.initial_committer.num_vectors; + for commitment in commitments { + verify!(commitment.out_of_domain.points.len() == self.initial_out_domain_samples); + verify!(commitment.out_of_domain.matrix.len() == expected_matrix_len); + } + // Complete the constraint and evaluation matrix with OODs and their cross-terms. let (oods_evals, oods_matrix) = { let mut oods_evals = Vec::new(); diff --git a/src/protocols/whir_zk/committer.rs b/src/protocols/whir_zk/committer.rs index 1ea1a3ef..44768014 100644 --- a/src/protocols/whir_zk/committer.rs +++ b/src/protocols/whir_zk/committer.rs @@ -5,8 +5,9 @@ use tracing::instrument; use super::{utils::BlindingPolynomials, Config}; use crate::{ + algebra::embedding::Identity, hash::Hash, - protocols::{irs_commit, whir}, + protocols::whir, transcript::{ Codec, DuplexSpongeInterface, ProverMessage, ProverState, VerificationResult, VerifierState, }, @@ -26,10 +27,10 @@ pub struct Commitment { #[derive(Clone, Debug)] pub struct Witness { pub f_hat_vectors: Vec>, - pub f_hat_witnesses: Vec>, + pub f_hat_witnesses: Vec>>, pub blinding_polynomials: Vec>, pub blinding_vectors: Vec>, - pub blinding_witness: irs_commit::Witness, + pub blinding_witness: whir::Witness>, } impl Config { diff --git a/src/protocols/whir_zk/prover.rs b/src/protocols/whir_zk/prover.rs index d1527c13..c2da218e 100644 --- a/src/protocols/whir_zk/prover.rs +++ b/src/protocols/whir_zk/prover.rs @@ -320,7 +320,7 @@ impl Config { let initial_in_domain = { #[cfg(feature = "tracing")] let _span = tracing::info_span!("open_f_hat").entered(); - let witness_refs: Vec<_> = f_hat_witnesses.iter().collect(); + let witness_refs: Vec<_> = f_hat_witnesses.iter().map(|w| &w.irs).collect(); self.blinded_commitment .initial_committer .open(prover_state, &witness_refs) diff --git a/src/protocols/whir_zk/verifier.rs b/src/protocols/whir_zk/verifier.rs index 7df12023..1ad7e0d6 100644 --- a/src/protocols/whir_zk/verifier.rs +++ b/src/protocols/whir_zk/verifier.rs @@ -73,10 +73,11 @@ impl Config { let masking_challenge: F = verifier_state.verifier_message(); verify!(masking_challenge != F::ZERO); let commitments = commitment.f_hat.iter().collect::>(); + let irs_commitments = commitments.iter().map(|c| &c.irs).collect::>(); let initial_in_domain = self .blinded_commitment .initial_committer - .verify(verifier_state, &commitments)?; + .verify(verifier_state, &irs_commitments)?; // Expand base queries into coset points for the first folding round. let h_gammas = self.all_gammas(&initial_in_domain.points); From 806e0ea581b7bb2c7a69f45fd7692993b6302e9f Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Wed, 13 May 2026 03:34:34 +0530 Subject: [PATCH 06/31] feat : added bounds and tests --- .../protocols/params/code_switch.txt | 10 + .../protocols/params/sumcheck.txt | 11 + src/protocols/code_switch.rs | 255 +++++++------- src/protocols/irs_commit.rs | 21 +- src/protocols/params/bounds.rs | 106 +++++- src/protocols/params/code_switch.rs | 313 ++++++++++++++++++ src/protocols/params/irs_commit.rs | 254 ++++++++------ src/protocols/params/mask_proximity.rs | 62 ++++ src/protocols/params/mod.rs | 8 +- src/protocols/params/plan.rs | 48 +++ src/protocols/params/spec.rs | 37 ++- src/protocols/params/sumcheck.rs | 149 ++++++--- src/protocols/params/test_utils.rs | 96 ++++++ 13 files changed, 1098 insertions(+), 272 deletions(-) create mode 100644 proptest-regressions/protocols/params/code_switch.txt create mode 100644 proptest-regressions/protocols/params/sumcheck.txt create mode 100644 src/protocols/params/code_switch.rs create mode 100644 src/protocols/params/mask_proximity.rs create mode 100644 src/protocols/params/plan.rs create mode 100644 src/protocols/params/test_utils.rs diff --git a/proptest-regressions/protocols/params/code_switch.txt b/proptest-regressions/protocols/params/code_switch.txt new file mode 100644 index 00000000..20850e33 --- /dev/null +++ b/proptest-regressions/protocols/params/code_switch.txt @@ -0,0 +1,10 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 7a7df094ea650db7a295d162b75dd9da9b52d1fc36947d2b07df8150cd9d906f # shrinks to spec = SecuritySpec { mode: Standard { unique_decoding: false }, target_security_bits: 80, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 4, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, log_inv_rate = 1, folding_factor = 3, num_vars = 4 +cc b42c982074a04c7110df07cf00f45156607be547e176b1ddd5f9d994ad491ddb # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 80, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 4, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, log_inv_rate = 1, folding_factor = 3, num_vars = 4 +cc eaf09a2b6bdffa86026264679f008326498ca800260dd2f17d4370df9fb3f801 # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 80, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 4, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, log_inv_rate = 1, folding_factor = 3, num_vars = 4 +cc 3887a5fa698c99109e8262e843dbd24ea94b9c9d420791e4520b5c9211a3eca0 # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 100, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 4, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, (log_inv_rate, folding_factor, num_vars) = (3, 2, 7) diff --git a/proptest-regressions/protocols/params/sumcheck.txt b/proptest-regressions/protocols/params/sumcheck.txt new file mode 100644 index 00000000..8fca101b --- /dev/null +++ b/proptest-regressions/protocols/params/sumcheck.txt @@ -0,0 +1,11 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 0ffdc71948ed0315f4cf55fb8f2dd25bf71f7e41f53cd4fe35ee9da6fb125a20 # shrinks to spec = SecuritySpec { mode: Standard { unique_decoding: false }, target_security_bits: 86, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 3, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, ctx = RoundContext { round_index: 0, vector_size: 32, log_inv_rate: 2, folding_factor: 1, prev_round_in_domain_samples: 0, prev_round_query_error: 0.0 } +cc e8ab6549772cf6bf4c3af116ebcba3dbf295ffbe2aee4a94be7df4b9f45d61ec # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 91, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 3, max_pow_bits: Some(10), hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, ctx = RoundContext { round_index: 0, vector_size: 16, log_inv_rate: 1, folding_factor: 1, prev_round_in_domain_samples: 0, prev_round_query_error: 0.0 } +cc 8c4300cc375640956f81e9da5aef9ea11ef476ddc4dd253dc560afa07609262d # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 98, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 3, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, ctx = RoundContext { round_index: 0, vector_size: 16, log_inv_rate: 1, folding_factor: 1, prev_round_in_domain_samples: 0, prev_round_query_error: 0.0 } +cc 8ea40f13c63b4c0021386369ce698a5d9289381a39dc85db43d2d69b9b4877bb # shrinks to spec = SecuritySpec { mode: Standard { unique_decoding: false }, target_security_bits: 88, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 3, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, ctx = RoundContext { round_index: 0, vector_size: 16, log_inv_rate: 3, folding_factor: 1, prev_round_in_domain_samples: 0, prev_round_query_error: 0.0 } +cc f1dca600886474c74d857c547baea0c2b4faf45b2946036f21a008106396eb1c # shrinks to spec = SecuritySpec { mode: Standard { unique_decoding: false }, target_security_bits: 80, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 3, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, ctx = RoundContext { round_index: 0, vector_size: 256, log_inv_rate: 3, folding_factor: 1, prev_round_in_domain_samples: 0, prev_round_query_error: 0.0 } diff --git a/src/protocols/code_switch.rs b/src/protocols/code_switch.rs index b6f284ae..ab5b03ab 100644 --- a/src/protocols/code_switch.rs +++ b/src/protocols/code_switch.rs @@ -3,7 +3,7 @@ //! Reduces a proximity claim about oracle f (source code C) to a proximity //! claim about oracle g (target code C'). Supports optional ZK via mask oracle. -use std::fmt; +use std::{fmt, num::NonZeroUsize}; use ark_ff::Field; use ark_std::rand::{distributions::Standard, prelude::Distribution, CryptoRng, RngCore}; @@ -29,6 +29,13 @@ use crate::{ verify, }; +/// Standard / ZeroKnowledge selector for code-switch. +#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] +pub enum Mode { + Standard, + ZeroKnowledge { message_mask_length: NonZeroUsize }, +} + /// Code-switching IOR config with optional ZK. #[must_use] #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] @@ -36,7 +43,7 @@ use crate::{ pub struct Config { pub source: IrsConfig, pub target: IrsConfig>, - pub message_mask_length: usize, // l_zk + pub mode: Mode, pub out_domain_samples: usize, } @@ -51,25 +58,13 @@ pub struct Witness { /// Verifier output from the code-switch. pub type Commitment = IrsCommitment; -/// Mask input for the code-switch prover. -// TODO : This may be removed after parameter selection PR -pub enum MaskInput<'a, F> { - Disabled, - Enabled(&'a [F]), -} - impl Config { /// Create a code-switch config. - /// - /// The orchestrator is responsible for: - /// - Setting `target_config.mask_length` for ZK mode before passing it in. - /// - Computing `out_domain_samples` from the security budget. - /// - Setting `message_mask_length` = mask oracle message length (0 for non-ZK). pub fn new( source_config: IrsConfig, target_config: IrsConfig>, out_domain_samples: usize, - message_mask_length: usize, + mode: Mode, ) -> Self { assert_eq!( source_config.num_vectors, 1, @@ -91,46 +86,65 @@ impl Config { target_config.interleaving_depth.is_power_of_two(), "target.interleaving_depth must be a power of 2" ); - // Theorem 9.6: ℓ_zk ≥ r (mask oracle must cover source randomness). - if message_mask_length > 0 { + assert!( + source_config.interleaving_depth.is_power_of_two(), + "source.interleaving_depth must be a power of 2" + ); + if let Mode::ZeroKnowledge { + message_mask_length, + } = &mode + { + let l_zk = message_mask_length.get(); + // Theorem 9.6: ℓ_zk ≥ r (mask oracle must cover source randomness). assert!( - message_mask_length >= source_config.mask_length(), - "message_mask_length ({message_mask_length}) must be >= source randomness length ({})", + l_zk >= source_config.mask_length(), + "message_mask_length ({l_zk}) must be >= source randomness length ({})", source_config.mask_length(), ); assert!( - message_mask_length - source_config.mask_length() >= out_domain_samples, - "the sampled randomness (s) length must be covering all the out of domain sample requests" + l_zk - source_config.mask_length() >= out_domain_samples, + "sampled randomness (s) length must cover all out-of-domain sample requests" ); - // t' = (in-domain queries to g via target IRS) - // + (OOD queries to g via Construction 9.7's OOD step, count = out_domain_samples). + // t' = target in-domain queries + OOD queries (Construction 9.7 step 4). // Lemma 9.5 perfect-ZK: t' ≤ r' = target.mask_length. assert!( - target_config.mask_length() - >= target_config.in_domain_samples + out_domain_samples, - "target encoder violates: t' > r', number of queries should be covered by random mask" + target_config.mask_length() >= target_config.in_domain_samples + out_domain_samples, + "target encoder violates t' ≤ r': queries must be covered by target mask" + ); + } else { + assert_eq!( + source_config.mask_length(), + 0, + "source with IRS randomness requires ZK mode", ); } - assert!( - source_config.mask_length() == 0 || message_mask_length > 0, - "source with mask_length > 0 (IRS randomness) requires ZK mode (message_mask_length > 0)" - ); - assert!( - source_config.interleaving_depth.is_power_of_two(), - "source.interleaving_depth must be a power of 2" - ); Self { source: source_config, target: target_config, - message_mask_length, + mode, out_domain_samples, } } + /// Mask oracle length `ℓ_zk`. Returns 0 in Standard mode. + pub const fn message_mask_length(&self) -> usize { + match &self.mode { + Mode::Standard => 0, + Mode::ZeroKnowledge { + message_mask_length, + } => message_mask_length.get(), + } + } + + /// `true` iff the protocol is configured for ZK. + pub const fn is_zk(&self) -> bool { + matches!(&self.mode, Mode::ZeroKnowledge { .. }) + } + /// Length of the covector for this code-switch. pub fn covector_length(&self) -> usize { - self.source.message_length() + self.message_mask_length + self.source.message_length() + self.message_mask_length() } /// Prove the code-switch. @@ -150,9 +164,9 @@ impl Config { /// `message` is `Fold(f, γ)`, the post-sumcheck polynomial of length /// `source.message_length()`. /// - /// `mask_input` is `(r || s)` from the orchestrator's shared mask tree - /// (see Construction 9.7 Step 1, p.55). Must be `None` when - /// `message_mask_length == 0`. + /// `mask` is `(r || s)` from the orchestrator's shared mask tree + /// (see Construction 9.7 Step 1, p.55). Length must equal + /// `self.message_mask_length()` — pass an empty slice in Standard mode. #[cfg_attr(feature = "tracing", instrument(skip_all))] pub fn prove( &self, @@ -161,7 +175,7 @@ impl Config { witness: &IrsWitness, covector: &mut [M::Target], folding_randomness: &[M::Target], - mask_input: &MaskInput<'_, M::Target>, + mask: &[M::Target], ) -> Witness where H: DuplexSpongeInterface, @@ -173,6 +187,7 @@ impl Config { { assert_eq!(message.len(), self.source.message_length()); assert_eq!(covector.len(), self.covector_length()); + assert_eq!(mask.len(), self.message_mask_length()); assert_eq!( 1 << folding_randomness.len(), self.source.interleaving_depth, @@ -180,41 +195,13 @@ impl Config { folding_randomness.len(), self.source.interleaving_depth, ); - let mask_msg: Option<&[M::Target]> = match &mask_input { - MaskInput::Disabled => { - assert_eq!( - self.message_mask_length, 0, - "MaskInput::Disabled requires message_mask_length == 0" - ); - None - } - MaskInput::Enabled(mask) => { - assert_eq!( - mask.len(), - self.message_mask_length, - "mask_msg length must equal message_mask_length" - ); - Some(mask) - } - }; // Step 1: g := Enc_{C'}(f, r') — Construction 9.7 Step 1, p.55 let target_witness = self.target.commit(prover_state, &[&message]); // Step 2-3: OOD challenge + answers — Construction 9.7 Steps 2-3, p.55 - // y := ze_ood(ρ) · [f; r; s] = f(α) + α^ℓ · (r,s)(α) let ood_points: Vec = prover_state.verifier_message_vec(self.out_domain_samples); - let msg_len = message.len(); - for &point in &ood_points { - let f_eval = univariate_evaluate(&message, point); - if let Some(mask) = mask_msg { - let mask_eval = univariate_evaluate(mask, point); - let shift = point.pow([msg_len as u64]); - prover_state.prover_message(&(f_eval + shift * mask_eval)); - } else { - prover_state.prover_message(&f_eval); - } - } + self.maybe_send_ood_answers(prover_state, &message, mask, &ood_points); // Step 4: in-domain queries — Construction 9.7 Step 4, p.55 let source_evaluations = self.source.open(prover_state, &[witness]); @@ -230,24 +217,13 @@ impl Config { // Covector update — sl' from Completeness proof (p.55-56) let eval_points = lift(self.source.embedding(), &source_evaluations.points); scalar_mul(covector, original_sl_coeff); - if self.message_mask_length == 0 { - // Non-ZK: single accumulate over all points - let all_points: Vec<_> = ood_points.iter().chain(&eval_points).copied().collect(); - let pows: Vec<_> = ood_rlc_coeffs - .iter() - .chain(in_domain_rlc_coeffs) - .copied() - .collect(); - geometric_accumulate(covector, pows, &all_points); - } else { - // ZK: OOD contributes to full [f; r; s], in-domain only to [f; r] - geometric_accumulate(covector, ood_rlc_coeffs.to_vec(), &ood_points); - geometric_accumulate( - &mut covector[..self.source.masked_message_length()], - in_domain_rlc_coeffs.to_vec(), - &eval_points, - ); - } + self.update_covector( + covector, + ood_rlc_coeffs, + &ood_points, + in_domain_rlc_coeffs, + &eval_points, + ); Witness { message, @@ -255,6 +231,67 @@ impl Config { } } + /// Send OOD answers `y_i = f(α_i) [+ α_i^ℓ · (r ‖ s)(α_i)]`. + /// In Standard mode the bracketed term is omitted. + fn maybe_send_ood_answers( + &self, + prover_state: &mut ProverState, + message: &[M::Target], + mask: &[M::Target], + ood_points: &[M::Target], + ) where + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + M::Target: Codec<[H::U]>, + { + let msg_len = message.len(); + for &point in ood_points { + let f_eval = univariate_evaluate(message, point); + let answer = match &self.mode { + Mode::Standard => f_eval, + Mode::ZeroKnowledge { .. } => { + let mask_eval = univariate_evaluate(mask, point); + let shift = point.pow([msg_len as u64]); + f_eval + shift * mask_eval + } + }; + prover_state.prover_message(&answer); + } + } + + /// Accumulate OOD and in-domain weights into the covector. + /// Standard mode treats all points uniformly; ZK mode applies OOD over + /// the full `[f; r; s]` and in-domain over the `[f; r]` prefix only. + fn update_covector( + &self, + covector: &mut [M::Target], + ood_rlc_coeffs: &[M::Target], + ood_points: &[M::Target], + in_domain_rlc_coeffs: &[M::Target], + in_domain_points: &[M::Target], + ) { + match &self.mode { + Mode::Standard => { + let all_points: Vec<_> = + ood_points.iter().chain(in_domain_points).copied().collect(); + let pows: Vec<_> = ood_rlc_coeffs + .iter() + .chain(in_domain_rlc_coeffs) + .copied() + .collect(); + geometric_accumulate(covector, pows, &all_points); + } + Mode::ZeroKnowledge { .. } => { + geometric_accumulate(covector, ood_rlc_coeffs.to_vec(), ood_points); + geometric_accumulate( + &mut covector[..self.source.masked_message_length()], + in_domain_rlc_coeffs.to_vec(), + in_domain_points, + ); + } + } + } + /// Verify the code-switch. /// /// `folding_randomness` is the **sumcheck folding randomness `γ`** the @@ -343,7 +380,7 @@ impl fmt::Display for Config { self.source, self.target, self.out_domain_samples, - self.message_mask_length != 0, + self.is_zk(), ) } } @@ -385,17 +422,16 @@ mod tests { scalars.prop_flat_map( move |(size, src_mask_len, zk, ood, fresh_s_len, iota_s, t_in)| { // Bound 3 assumption (c): ℓ_zk - r ≥ t_ood ⇒ fresh_s_len ≥ ood. + // Also enforce `ℓ_zk = r + fresh_s_len > 0` so NonZeroUsize + // construction below is total in ZK mode. let fresh_s_len = if zk { - fresh_s_len.max(ood) + let min_fresh = usize::from(src_mask_len == 0); + fresh_s_len.max(ood).max(min_fresh) } else { fresh_s_len }; // Bound 4 assumption (a): target.mask_length ≥ t' = t_in + ood. let target_mask = if zk { t_in + ood } else { 0 }; - // ZK with source.mask_length = 0 is valid: the assert - // `source.mask_length == 0 || message_mask_length > 0` - // is trivially satisfied. Allows testing the corner - // where the mask oracle has only fresh randomness. let source_mask = if zk { src_mask_len } else { 0 }; IrsConfig::arbitrary(embedding.clone(), 1, size, source_mask, iota_s) @@ -429,8 +465,15 @@ mod tests { // r = post-fold randomness length (ι_s parallel // masks fold to a single length-mask_length chunk). let r = source.mask_length(); - let message_mask_length = if zk { r + fresh_s_len } else { 0 }; - Self::new(source.clone(), target, ood, message_mask_length) + let mode = if zk { + Mode::ZeroKnowledge { + message_mask_length: NonZeroUsize::new(r + fresh_s_len) + .expect("ZK ⇒ r + fresh_s_len > 0"), + } + } else { + Mode::Standard + }; + Self::new(source.clone(), target, ood, mode) }) }) }) @@ -481,7 +524,7 @@ mod tests { where Standard: Distribution, { - if config.message_mask_length == 0 { + if !config.is_zk() { return Vec::new(); } // Lift ι parallel masks (total length source.mask_length × ι) and fold @@ -491,19 +534,11 @@ mod tests { // Append fresh padding s of length message_mask_length - source.mask_length. mask.extend(random_vector::( rng, - config.message_mask_length - mask.len(), + config.message_mask_length() - mask.len(), )); mask } - fn mask_input(mask_msg: &[F]) -> MaskInput<'_, F> { - if mask_msg.is_empty() { - MaskInput::Disabled - } else { - MaskInput::Enabled(mask_msg) - } - } - fn test_config>(seed: u64, config: &Config>) where Standard: Distribution, @@ -538,7 +573,7 @@ mod tests { &source_witness, &mut covector, &folding_randomness, - &mask_input(&mask_msg), + &mask_msg, ); let proof = prover_state.proof(); @@ -604,7 +639,7 @@ mod tests { &source_witness, &mut covector, &folding_randomness, - &mask_input(&mask_msg), + &mask_msg, ); let proof = prover_state.proof(); @@ -661,7 +696,7 @@ mod tests { &source_witness, &mut covector, &folding_randomness, - &MaskInput::Disabled, + &[], ); let proof = prover_state.proof(); @@ -748,9 +783,7 @@ mod tests { let configs = Config::arbitrary(Identity::::new()).prop_filter( "non-ZK with ood > 0", |config| { - config.message_mask_length == 0 - && config.source.mask_length() == 0 - && config.out_domain_samples > 0 + !config.is_zk() && config.source.mask_length() == 0 && config.out_domain_samples > 0 }, ); proptest!(|(seed: u64, config in configs)| { diff --git a/src/protocols/irs_commit.rs b/src/protocols/irs_commit.rs index 2fb0233c..83aedb03 100644 --- a/src/protocols/irs_commit.rs +++ b/src/protocols/irs_commit.rs @@ -48,6 +48,16 @@ pub enum IrsMode { ZeroKnowledge { mask_length: NonZeroUsize }, } +impl IrsMode { + /// Per-polynomial IRS randomness length. Returns 0 in Standard mode. + pub const fn mask_length(&self) -> usize { + match self { + Self::Standard => 0, + Self::ZeroKnowledge { mask_length } => mask_length.get(), + } + } +} + /// Commit to vectors over an fft-friendly field F #[must_use] #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] @@ -132,11 +142,7 @@ impl Config { { assert!(vector_size.is_multiple_of(interleaving_depth)); assert!(rate > 0. && rate <= 1.); - let mask_length = match &mode { - IrsMode::Standard => 0, - IrsMode::ZeroKnowledge { mask_length } => mask_length.get(), - }; - let masked_message_length = vector_size / interleaving_depth + mask_length; + let masked_message_length = vector_size / interleaving_depth + mode.mask_length(); #[allow(clippy::cast_sign_loss)] let codeword_length = (masked_message_length as f64 / rate).ceil() as usize; let rate = masked_message_length as f64 / codeword_length as f64; @@ -191,10 +197,7 @@ impl Config { /// Per-polynomial IRS randomness length. Returns 0 in Standard mode. pub const fn mask_length(&self) -> usize { - match &self.mode { - IrsMode::Standard => 0, - IrsMode::ZeroKnowledge { mask_length } => mask_length.get(), - } + self.mode.mask_length() } /// Message length including mask coefficients. diff --git a/src/protocols/params/bounds.rs b/src/protocols/params/bounds.rs index 6f5380f9..1a6a5d34 100644 --- a/src/protocols/params/bounds.rs +++ b/src/protocols/params/bounds.rs @@ -76,6 +76,108 @@ pub fn ood_per_sample_log2(message_length: usize, field_bits: f64) -> f64 { } /// PoW difficulty to close a soundness gap: max(0, target − achieved). -pub fn pow_bits_to_close_gap(target_security_bits: u32, achieved_security_bits: f64) -> Bits { - Bits::new((f64::from(target_security_bits) - achieved_security_bits).max(0.0)) +/// +/// Currently unused — solvers emit `Config::none()` PoW. Will be re-wired by +/// the cross-protocol PoW pass. +#[allow(dead_code)] +pub fn pow_bits_to_close_gap(target_security_bits: f64, achieved_security_bits: f64) -> Bits { + Bits::new((target_security_bits - achieved_security_bits).max(0.0)) +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Within 1e-9 of expected — formulas use `log2`, so floats are inexact. + fn approx_eq(a: f64, b: f64) { + assert!( + (a - b).abs() < 1e-9, + "expected ≈ {b}, got {a} (diff {})", + (a - b).abs() + ); + } + + #[test] + fn list_size_unique_decoding_is_one() { + // Unique decoding → |Λ| = 1 → log2 = 0. + approx_eq(list_size_log2(1.0, 0.0), 0.0); + approx_eq(list_size_log2(5.0, 0.0), 0.0); + } + + #[test] + fn list_size_johnson_grows_as_slack_shrinks() { + // Same rate, smaller slack → larger list. + let big_slack = list_size_log2(2.0, 0.5); + let small_slack = list_size_log2(2.0, 0.05); + assert!(small_slack > big_slack, "{small_slack} > {big_slack}"); + } + + #[test] + fn list_size_johnson_grows_as_rate_drops() { + // Lower rate (larger log_inv_rate) → larger list. + let high_rate = list_size_log2(1.0, 0.1); + let low_rate = list_size_log2(4.0, 0.1); + assert!(low_rate > high_rate, "{low_rate} > {high_rate}"); + } + + fn code(log_inv_rate: f64, johnson_slack: f64, message_length: usize) -> CodeParams { + CodeParams { + log_inv_rate, + johnson_slack, + message_length, + field_bits: 64.0, + } + } + + #[test] + fn eps_mca_grows_with_message_length() { + // Longer message → larger ε (less negative log) → less security. + let short = eps_mca_log2(&code(2.0, 0.1, 16)); + let long = eps_mca_log2(&code(2.0, 0.1, 1024)); + assert!(long > short, "{long} > {short}"); + } + + #[test] + fn eps_mca_grows_with_log_inv_rate() { + // Lower rate (larger log_inv_rate) → larger ε. + let high_rate = eps_mca_log2(&code(1.0, 0.1, 128)); + let low_rate = eps_mca_log2(&code(4.0, 0.1, 128)); + assert!(low_rate > high_rate, "{low_rate} > {high_rate}"); + } + + #[test] + fn one_minus_distance_unique_is_midpoint() { + // Unique decoding: 1 - δ = (1 + ρ) / 2. + // log_inv_rate = 1 → ρ = 0.5 → (1 + 0.5)/2 = 0.75. + approx_eq(one_minus_distance_log2(1.0, 0.0), 0.75_f64.log2()); + } + + #[test] + fn one_minus_distance_johnson_more_negative_than_unique() { + // Johnson allows larger δ than unique decoding → smaller (1-δ) → more + // negative log. + let unique = one_minus_distance_log2(2.0, 0.0); + let johnson = one_minus_distance_log2(2.0, 0.1); + assert!(johnson < unique, "{johnson} < {unique}"); + } + + #[test] + fn ood_per_sample_exact() { + // (k-1)/|F| with k=2, |F|=2^64 → log2 = -64. + approx_eq(ood_per_sample_log2(2, 64.0), -64.0); + // k=9, |F|=2^64 → (8)/2^64 → log2 = 3 - 64 = -61. + approx_eq(ood_per_sample_log2(9, 64.0), -61.0); + } + + #[test] + fn pow_bits_zero_when_achieved_meets_target() { + assert!(pow_bits_to_close_gap(80.0, 100.0).is_zero()); + assert!(pow_bits_to_close_gap(80.0, 80.0).is_zero()); + } + + #[test] + fn pow_bits_fills_gap_to_target() { + let bits = pow_bits_to_close_gap(80.0, 50.0); + approx_eq(f64::from(bits), 30.0); + } } diff --git a/src/protocols/params/code_switch.rs b/src/protocols/params/code_switch.rs new file mode 100644 index 00000000..0aafc182 --- /dev/null +++ b/src/protocols/params/code_switch.rs @@ -0,0 +1,313 @@ +//! Parameter selection for the code-switching IOR (Construction 9.7, p.55). +//! +//! Computes `t_ood` (Bound 2 / Lemma 9.9 first error term) and sizes the ZK +//! mask oracle `ℓ_zk` per Theorem 9.6 + Lemma 9.5. + +use std::num::NonZeroUsize; + +use crate::{ + algebra::{ + embedding::{Embedding, Identity}, + fields::FieldWithSize, + }, + protocols::{ + code_switch, + irs_commit::{self, Config as IrsConfig}, + params::{ + plan::RoundModeParams, + spec::{MaskCodeMessageLen, Mode, SecuritySpec}, + }, + }, +}; + +/// Assemble the [`code_switch::Config`] from precomputed `t_ood` and a +/// mode-typed `zk` context. +/// +/// In ZK mode, `zk` carries the `l_zk` produced by [`compute_l_zk`]; the +/// orchestrator must have used the same value for `irs_commit::solve_mask_code` +/// so both consumers see the same mask-oracle length. +pub fn solve( + source: IrsConfig, + target: IrsConfig>, + t_ood: usize, + zk: &RoundModeParams, +) -> code_switch::Config { + let mode = match zk { + RoundModeParams::Standard => code_switch::Mode::Standard, + RoundModeParams::ZeroKnowledge { l_zk, .. } => { + let l_zk = l_zk.get(); + assert!( + l_zk >= source.mask_length() + t_ood, + "ℓ_zk ({l_zk}) < r + t_ood ({} + {}) — violates Bound 3", + source.mask_length(), + t_ood, + ); + code_switch::Mode::ZeroKnowledge { + message_mask_length: NonZeroUsize::new(l_zk).expect("ℓ_zk > 0"), + } + } + }; + code_switch::Config::new(source, target, t_ood, mode) +} + +/// `ℓ_zk = next_power_of_two(r + t_ood)` — shared by code-switch and C_zk. +/// +/// Bound 3 / Lemma 9.3 requires `ℓ_zk ≥ r + t_ood`; pow2 padding lets the +/// same value drive `irs_commit::solve_mask_code`'s NTT-order assertion. +pub const fn compute_l_zk( + source: &irs_commit::Config, + t_ood: usize, +) -> MaskCodeMessageLen { + MaskCodeMessageLen::new((source.mask_length() + t_ood).next_power_of_two()) +} + +/// `t_ood` from Bound 2 / Lemma 9.9 first error term. +/// +/// Solves `(|Λ(C')| · |Λ(C_zk)|)² / 2 · ((ℓ + ℓ_zk - 1) / |F|)^{t_ood} ≤ +/// 2^{-security}`. In ZK mode `ℓ_zk = r + t_ood` is mutually dependent with +/// `t_ood`; iterate to the fixed point. +pub fn compute_t_ood( + spec: &SecuritySpec, + source: &IrsConfig, + target_list_size: f64, + c_zk_list_size: Option, +) -> usize { + const MAX_ITER: usize = 32; + + let security_target = spec.protocol_security_target_bits(); + let field_bits = M::Target::field_size_bits(); + let unique_decoding = spec.mode.unique_decoding(); + let combined_list_size = target_list_size * c_zk_list_size.unwrap_or(1.0); + let message_length = source.message_length(); + let source_mask_length = source.mask_length(); + + let solve_for_degree = |degree: usize| { + irs_commit::num_ood_samples( + unique_decoding, + security_target, + field_bits, + combined_list_size, + degree, + ) + }; + + if !matches!(spec.mode, Mode::ZeroKnowledge) { + return solve_for_degree(message_length); + } + + // ZK: t_ood = f(ℓ + r + t_ood); iterate. + let mut t_ood = 0; + for _ in 0..MAX_ITER { + let new_t_ood = solve_for_degree(message_length + source_mask_length + t_ood); + if new_t_ood == t_ood { + return t_ood; + } + t_ood = new_t_ood; + } + panic!("compute_t_ood did not converge in {MAX_ITER} iterations"); +} + +#[cfg(test)] +mod tests { + use proptest::prelude::*; + + use super::*; + use crate::protocols::params::{ + irs_commit as params_irs, + plan::RoundModeParams, + spec::{LogInvRate, OodSampleBudget, RoundContext}, + test_utils::{ + arb_standard_johnson_spec as utils_standard_spec, arb_zk_spec as utils_zk_spec, + deterministic_standard_spec, TestEmbedding, + }, + }; + + type M = TestEmbedding; + + fn arb_zk_spec() -> impl Strategy> { + utils_zk_spec(80..=128) + } + + fn arb_standard_johnson_spec() -> impl Strategy> { + utils_standard_spec(80..=128) + } + + /// Orchestrator-style: build source + target IRS and the matching `t_ood`. + /// Iterates target until its `codeword_length` stabilizes (target's realized + /// rate depends on its mask budget, which depends on `t_ood`). + fn build_inputs( + spec: &SecuritySpec, + log_inv_rate: u32, + folding_factor: u32, + num_vars: u32, + c_zk_list_size: Option, + ) -> (IrsConfig, IrsConfig, usize) { + let source_ctx = RoundContext { + round_index: 0, + vector_size: 1usize << num_vars, + log_inv_rate, + folding_factor, + prev_round_in_domain_samples: 0, + prev_round_query_error: 0.0, + }; + let source = params_irs::solve(spec, &source_ctx, OodSampleBudget::new(0)); + + let target_ctx = RoundContext { + round_index: 1, + vector_size: source.message_length(), + log_inv_rate: log_inv_rate + folding_factor - 1, + folding_factor, + prev_round_in_domain_samples: source.in_domain_samples, + prev_round_query_error: 0.0, + }; + + let mut target = params_irs::solve(spec, &target_ctx, OodSampleBudget::new(0)); + for _ in 0..8 { + let t_ood = compute_t_ood(spec, &source, target.list_size(), c_zk_list_size); + let new_target = params_irs::solve(spec, &target_ctx, OodSampleBudget::new(t_ood)); + if new_target.codeword_length == target.codeword_length { + return (source, new_target, t_ood); + } + target = new_target; + } + panic!("target IRS did not stabilize"); + } + + /// `num_vars ≥ 2 * folding_factor` so target's `vector_size` stays divisible + /// by target's `interleaving_depth = 1 << folding_factor`. + fn arb_dims() -> impl Strategy { + (1u32..=3, 1u32..=2).prop_flat_map(|(log_inv_rate, folding_factor)| { + let min_num_vars = 2 * folding_factor; + ( + Just(log_inv_rate), + Just(folding_factor), + min_num_vars..=(min_num_vars + 4), + ) + }) + } + + proptest! { + /// Standard mode: `Config::new` assertions pass, `t_ood ≥ 1` in Johnson. + #[test] + fn solve_standard_assembles( + spec in arb_standard_johnson_spec(), + (log_inv_rate, folding_factor, num_vars) in arb_dims(), + ) { + let (source, target, t_ood) = + build_inputs(&spec, log_inv_rate, folding_factor, num_vars, None); + let config = solve(source, target, t_ood, &RoundModeParams::Standard); + prop_assert!(matches!(config.mode, code_switch::Mode::Standard)); + prop_assert!(config.out_domain_samples >= 1); + } + + /// ZK mode: `ℓ_zk = next_power_of_two(r + t_ood)` shared with C_zk. + #[test] + fn solve_zk_mask_equals_padded_r_plus_t_ood( + spec in arb_zk_spec(), + (log_inv_rate, folding_factor, num_vars) in arb_dims(), + ) { + // Bootstrap C_zk with a placeholder t_ood to break the + // t_ood ↔ c_zk.list_size circular dependency. + let placeholder_source_ctx = RoundContext { + round_index: 0, + vector_size: 1usize << num_vars, + log_inv_rate, + folding_factor, + prev_round_in_domain_samples: 0, + prev_round_query_error: 0.0, + }; + let placeholder_source = params_irs::solve( + &spec, + &placeholder_source_ctx, + OodSampleBudget::new(0), + ); + let c_zk_placeholder = params_irs::solve_mask_code( + &spec, + compute_l_zk(&placeholder_source, 1), + placeholder_source.mask_length(), + LogInvRate::new(log_inv_rate), + 2, + ); + let (source, target, t_ood) = build_inputs( + &spec, log_inv_rate, folding_factor, num_vars, Some(c_zk_placeholder.list_size()), + ); + let r = source.mask_length(); + let l_zk = compute_l_zk(&source, t_ood); + let c_zk = params_irs::solve_mask_code( + &spec, + l_zk, + r, + LogInvRate::new(log_inv_rate), + 2, + ); + // Fixed-point check: t_ood was computed with the placeholder + // C_zk's list_size; the final C_zk's list_size must agree. + let recomputed_t_ood = + compute_t_ood(&spec, &source, target.list_size(), Some(c_zk.list_size())); + prop_assert_eq!( + t_ood, recomputed_t_ood, + "t_ood computed with placeholder C_zk must equal t_ood with final C_zk", + ); + let zk = RoundModeParams::ZeroKnowledge { c_zk, l_zk }; + let config = solve(source, target, t_ood, &zk); + prop_assert_eq!(config.message_mask_length(), (r + t_ood).next_power_of_two()); + } + + /// `compute_t_ood` converges and returns `t_ood ≥ 1` in Johnson regime. + #[test] + fn compute_t_ood_converges( + spec in arb_zk_spec(), + (log_inv_rate, folding_factor, num_vars) in arb_dims(), + ) { + let (_source, _target, t_ood) = + build_inputs(&spec, log_inv_rate, folding_factor, num_vars, None); + prop_assert!(t_ood >= 1); + } + } + + /// Smoke test: the generics compile and `solve` works end-to-end with a + /// non-identity embedding (`M::Source ≠ M::Target`). + #[test] + fn solve_works_with_basefield_embedding() { + use crate::algebra::{embedding::Basefield, fields::Field64_2}; + type NonIdM = Basefield; + + let spec_source: SecuritySpec = deterministic_standard_spec(); + let spec_target: SecuritySpec> = deterministic_standard_spec(); + + let source_ctx = RoundContext { + round_index: 0, + vector_size: 16, + log_inv_rate: 1, + folding_factor: 2, + prev_round_in_domain_samples: 0, + prev_round_query_error: 0.0, + }; + let source = params_irs::solve(&spec_source, &source_ctx, OodSampleBudget::new(0)); + + let target_ctx = RoundContext { + round_index: 1, + vector_size: source.message_length(), + log_inv_rate: source_ctx.log_inv_rate + source_ctx.folding_factor - 1, + folding_factor: source_ctx.folding_factor, + prev_round_in_domain_samples: source.in_domain_samples, + prev_round_query_error: 0.0, + }; + + let mut target = params_irs::solve(&spec_target, &target_ctx, OodSampleBudget::new(0)); + let mut t_ood = compute_t_ood(&spec_source, &source, target.list_size(), None); + for _ in 0..8 { + let new_target = + params_irs::solve(&spec_target, &target_ctx, OodSampleBudget::new(t_ood)); + if new_target.codeword_length == target.codeword_length { + target = new_target; + break; + } + target = new_target; + t_ood = compute_t_ood(&spec_source, &source, target.list_size(), None); + } + + let config = solve(source, target, t_ood, &RoundModeParams::Standard); + assert!(matches!(config.mode, code_switch::Mode::Standard)); + } +} diff --git a/src/protocols/params/irs_commit.rs b/src/protocols/params/irs_commit.rs index 3d860486..edbf26fc 100644 --- a/src/protocols/params/irs_commit.rs +++ b/src/protocols/params/irs_commit.rs @@ -1,9 +1,9 @@ //! Parameter selection for the IRS commit protocol. -use std::iter; +use std::num::NonZeroUsize; use crate::{ - algebra::{embedding::Embedding, ntt}, + algebra::embedding::Embedding, protocols::{ irs_commit::{self, num_in_domain_queries, IrsMode}, params::spec::{ @@ -12,41 +12,42 @@ use crate::{ }, }; -/// Solve per-round IRS-commit parameters. ZK mask sized per Lemma 9.5. +/// Solve per-round IRS-commit parameters. ZK mask sized per Lemma 9.5, +/// padded so `message + mask` is a power of 2 (NTT-valid codeword length). pub fn solve( spec: &SecuritySpec, ctx: &RoundContext, - out_domain: OodSampleBudget, + out_domain_samples: OodSampleBudget, ) -> irs_commit::Config { - let security_target = f64::from(spec.protocol_security_target_bits()); - let raw_rate = 2_f64.powf(-f64::from(ctx.log_inv_rate)); + let security_target = spec.protocol_security_target_bits(); + let rate = 2_f64.powf(-f64::from(ctx.log_inv_rate)); let interleaving_depth = 1_usize << ctx.folding_factor; let unique_decoding = spec.mode.unique_decoding(); + let message_length = ctx.vector_size / interleaving_depth; let mode = match spec.mode { Mode::Standard { .. } => IrsMode::Standard, Mode::ZeroKnowledge => { - // Lemma 9.5 ZK budget: every revealed evaluation counts. - let in_domain = num_in_domain_queries(unique_decoding, security_target, raw_rate); - let mask_length = in_domain - .checked_add(out_domain.get()) - .expect("usize overflow in mask_length"); + let min_mask = num_in_domain_queries(unique_decoding, security_target, rate) + .checked_add(out_domain_samples.get()) + .expect("usize overflow"); + // Pad to pow2: Lemma 9.5 is `≥` so over-allocating is safe. + let mask_length = message_length + .checked_add(min_mask.get()) + .expect("usize overflow") + .next_power_of_two() + .checked_sub(message_length) + .and_then(NonZeroUsize::new) + .expect("mask_length non-zero in ZK"); IrsMode::ZeroKnowledge { mask_length } } }; - let mask_length_value = match &mode { - IrsMode::Standard => 0, - IrsMode::ZeroKnowledge { mask_length } => mask_length.get(), - }; - let masked_message_length = ctx.vector_size / interleaving_depth + mask_length_value; - let rate = snap_rate::(masked_message_length, raw_rate); - irs_commit::Config::new( security_target, unique_decoding, spec.hash_id, - // Orchestrator commits one vector per round. + // num_vectors: orchestrator commits one vector per round. 1, ctx.vector_size, interleaving_depth, @@ -57,11 +58,10 @@ pub fn solve( /// Solve the shared C_zk IRS config for committing mask polynomials. /// -/// - `l_zk` — message length (Theorem 9.6: ℓ_zk ≥ `source_mask_length`). -/// - `source_mask_length` — `r`, the source IRS mask length. +/// - `l_zk` — message length. Must be a power of 2 (caller pads it; see assert). +/// - `source_mask_length` — `r`, the source IRS mask length (Theorem 9.6). /// - `log_inv_rate` — C_zk rate. -/// - `num_vectors` — total masks per commit; must equal `2 * num_masks` to be -/// consumable by `mask_proximity::Config::new` (original/fresh pairs). +/// - `num_vectors` — total masks per commit; `2 * num_masks` for mask-proximity. pub fn solve_mask_code( spec: &SecuritySpec, l_zk: MaskCodeMessageLen, @@ -76,17 +76,16 @@ pub fn solve_mask_code( ); assert!( l_zk >= source_mask_length, - "Theorem 9.6: ℓ_zk ({l_zk}) must be ≥ source mask length ({source_mask_length})", + "Theorem 9.6: ℓ_zk ({l_zk}) ≥ source mask length ({source_mask_length})", ); + assert!(l_zk.is_power_of_two(), "ℓ_zk ({l_zk}) must be a power of 2"); assert!( num_vectors.is_multiple_of(2), - "num_vectors ({num_vectors}) must be even — mask-proximity expects 2 · num_masks (original + fresh)", + "num_vectors ({num_vectors}) must be even (mask-proximity original/fresh pairs)", ); - let security_target = f64::from(spec.protocol_security_target_bits()); - let raw_rate = 2_f64.powf(-f64::from(log_inv_rate.get())); - // C_zk has interleaving_depth = 1 and IrsMode::Standard, so masked_message_length = l_zk. - let rate = snap_rate::(l_zk, raw_rate); + let security_target = spec.protocol_security_target_bits(); + let rate = 2_f64.powf(-f64::from(log_inv_rate.get())); irs_commit::Config::new( security_target, @@ -101,19 +100,6 @@ pub fn solve_mask_code( ) } -/// Snap `rate` so `Config::new`'s codeword sizing lands on a valid power-of-two -/// NTT order. Returns a rate `≤ raw_rate`. -fn snap_rate(masked_message_length: usize, raw_rate: f64) -> f64 { - #[allow(clippy::cast_sign_loss)] - let desired = (masked_message_length as f64 / raw_rate).ceil() as usize; - let codeword_length = iter::successors(ntt::next_order::(desired), |&n| { - ntt::next_order::(n + 1) - }) - .find(|n| n.is_power_of_two()) - .expect("no valid power-of-two NTT order ≥ desired codeword length"); - masked_message_length as f64 / codeword_length as f64 -} - #[cfg(test)] mod tests { use std::marker::PhantomData; @@ -123,61 +109,121 @@ mod tests { use super::*; use crate::{ - algebra::{embedding::Identity, fields::Field64, random_vector}, + algebra::random_vector, hash, + protocols::params::test_utils::{arb_round_ctx, arb_spec, arb_zk_spec, TestEmbedding}, transcript::{DomainSeparator, ProverState, VerifierState}, }; - type F = Field64; - type M = Identity; - - fn arb_spec_with(mode: impl Strategy) -> impl Strategy> { - (mode, 80u32..=128, 1u32..=4, prop::option::of(0u32..=20)).prop_map( - |(mode, target_security_bits, starting_log_inv_rate, max_pow_bits)| SecuritySpec { - mode, - target_security_bits, - vector_size: 1 << 8, - starting_log_inv_rate, - initial_folding_factor: 4, - folding_factor: 4, - max_pow_bits, - hash_id: hash::BLAKE3, - _embedding: PhantomData, + type M = TestEmbedding; + type F = ::Source; + + fn minimal_zk_spec() -> SecuritySpec { + SecuritySpec { + mode: Mode::ZeroKnowledge, + target_security_bits: 80, + max_pow_bits: None, + hash_id: hash::BLAKE3, + _embedding: PhantomData, + } + } + + fn minimal_standard_spec() -> SecuritySpec { + SecuritySpec { + mode: Mode::Standard { + unique_decoding: false, }, - ) + target_security_bits: 80, + max_pow_bits: None, + hash_id: hash::BLAKE3, + _embedding: PhantomData, + } } - fn arb_zk_spec() -> impl Strategy> { - arb_spec_with(Just(Mode::ZeroKnowledge)) + #[test] + #[should_panic(expected = "C_zk only exists in ZK mode")] + fn solve_mask_code_rejects_standard_spec() { + let _ = solve_mask_code( + &minimal_standard_spec(), + MaskCodeMessageLen::new(2), + 0, + LogInvRate::new(1), + 2, + ); } - fn arb_standard_spec() -> impl Strategy> { - arb_spec_with(any::().prop_map(|unique_decoding| Mode::Standard { unique_decoding })) + #[test] + #[should_panic(expected = "must be a power of 2")] + fn solve_mask_code_rejects_non_pow2_l_zk() { + let _ = solve_mask_code( + &minimal_zk_spec(), + MaskCodeMessageLen::new(3), + 0, + LogInvRate::new(1), + 2, + ); + } + + #[test] + #[should_panic(expected = "Theorem 9.6")] + fn solve_mask_code_rejects_l_zk_below_source_mask_length() { + let _ = solve_mask_code( + &minimal_zk_spec(), + MaskCodeMessageLen::new(2), + 4, + LogInvRate::new(1), + 2, + ); + } + + #[test] + #[should_panic(expected = "must be even")] + fn solve_mask_code_rejects_odd_num_vectors() { + let _ = solve_mask_code( + &minimal_zk_spec(), + MaskCodeMessageLen::new(2), + 0, + LogInvRate::new(1), + 3, + ); + } + + fn arb_zk_spec_default() -> impl Strategy> { + arb_zk_spec(80..=128) } - fn arb_any_spec() -> impl Strategy> { - prop_oneof![arb_zk_spec(), arb_standard_spec()] + /// IRS-specific: vary `unique_decoding` to exercise both regimes inside + /// `irs_commit::Config::new`. + fn arb_standard_spec() -> impl Strategy> { + any::() + .prop_flat_map(|unique_decoding| arb_spec(Mode::Standard { unique_decoding }, 80..=128)) } - fn arb_ctx() -> impl Strategy { - (4u32..=8, 1u32..=4, 1u32..=3).prop_map(|(log_size, log_inv_rate, folding_factor)| { - RoundContext { - round_index: 0, - vector_size: 1_usize << log_size, - log_inv_rate, - folding_factor, - prev_round_in_domain_samples: 0, - prev_round_query_error: 0.0, - } - }) + fn commit_open_verify(config: &irs_commit::Config, seed: u64) -> irs_commit::Witness { + let ds = DomainSeparator::protocol(config) + .session(&format!("Test at {}:{}", file!(), line!())) + .instance(&seed); + let mut rng = StdRng::seed_from_u64(seed); + let vector = random_vector::(&mut rng, config.vector_size); + + let mut prover_state = ProverState::new_std(&ds); + let witness = config.commit(&mut prover_state, &[&vector]); + let _ = config.open(&mut prover_state, &[&witness]); + let proof = prover_state.proof(); + + let mut verifier_state = VerifierState::new_std(&ds, &proof); + let commitment = config.receive_commitment(&mut verifier_state).unwrap(); + let _ = config.verify(&mut verifier_state, &[&commitment]).unwrap(); + verifier_state.check_eof().unwrap(); + witness } proptest! { /// Lemma 9.5: ZK mask covers all revealed evaluations. #[test] fn zk_mask_covers_lemma_9_5( - spec in arb_zk_spec(), - ctx in arb_ctx(), + spec in arb_zk_spec_default(), + ctx in arb_round_ctx(), out_domain in 0usize..16, ) { let config = solve(&spec, &ctx, OodSampleBudget::new(out_domain)); @@ -188,38 +234,46 @@ mod tests { ); } - /// Standard mode produces no IRS randomness. + /// Standard mode produces no IRS randomness regardless of input. #[test] - fn standard_has_no_mask(spec in arb_standard_spec(), ctx in arb_ctx()) { - let config = solve(&spec, &ctx, OodSampleBudget::new(0)); + fn standard_has_no_mask( + spec in arb_standard_spec(), + ctx in arb_round_ctx(), + out_domain in 0usize..8, + ) { + let config = solve(&spec, &ctx, OodSampleBudget::new(out_domain)); prop_assert_eq!(config.mask_length(), 0); } - /// Round-trip: solve → commit → verify with the produced config. + /// ZK round-trip + witness shape check. #[test] - fn solve_round_trips_through_irs_commit( - spec in arb_any_spec(), - ctx in arb_ctx(), + fn zk_round_trips( + spec in arb_zk_spec_default(), + ctx in arb_round_ctx(), out_domain in 0usize..8, seed: u64, ) { let config = solve(&spec, &ctx, OodSampleBudget::new(out_domain)); + prop_assert!(config.mask_length() > 0, "ZK mode must produce non-zero mask"); + let witness = commit_open_verify(&config, seed); + prop_assert_eq!( + witness.masks.len(), + config.mask_length() * config.num_messages(), + "witness mask vector size", + ); + } - let ds = DomainSeparator::protocol(&config) - .session(&format!("Test at {}:{}", file!(), line!())) - .instance(&seed); - let mut rng = StdRng::seed_from_u64(seed); - let vector = random_vector::(&mut rng, config.vector_size); - - let mut prover_state = ProverState::new_std(&ds); - let witness = config.commit(&mut prover_state, &[&vector]); - let _ = config.open(&mut prover_state, &[&witness]); - let proof = prover_state.proof(); - - let mut verifier_state = VerifierState::new_std(&ds, &proof); - let commitment = config.receive_commitment(&mut verifier_state).unwrap(); - let _ = config.verify(&mut verifier_state, &[&commitment]).unwrap(); - verifier_state.check_eof().unwrap(); + /// Standard round-trip + empty-mask check. + #[test] + fn standard_round_trips( + spec in arb_standard_spec(), + ctx in arb_round_ctx(), + seed: u64, + ) { + let config = solve(&spec, &ctx, OodSampleBudget::new(0)); + prop_assert_eq!(config.mask_length(), 0); + let witness = commit_open_verify(&config, seed); + prop_assert!(witness.masks.is_empty(), "Standard mode must produce no masks"); } } } diff --git a/src/protocols/params/mask_proximity.rs b/src/protocols/params/mask_proximity.rs new file mode 100644 index 00000000..198dd682 --- /dev/null +++ b/src/protocols/params/mask_proximity.rs @@ -0,0 +1,62 @@ +//! Parameter selection for the mask-proximity protocol (Construction 7.2). +//! +//! Mask-proximity spot-checks each committed mask oracle against C_zk via +//! γ-combination (Lemma 7.4). ZK-only — Standard mode never invokes it. + +use ark_ff::Field; + +use crate::{ + algebra::embedding::Identity, + protocols::{irs_commit, mask_proximity}, +}; + +/// Assemble a [`mask_proximity::Config`] from the shared C_zk IRS config and +/// the number of mask polynomials the protocol commits to. +/// +/// `c_zk` must be sized with `num_vectors == 2 * num_masks` (Construction 7.2 +/// commits originals and their fresh mask-of-masks side by side in the shared +/// tree). The orchestrator obtains `c_zk` via `irs_commit::solve_mask_code` +/// with that same `num_vectors`. +pub fn solve( + c_zk: irs_commit::Config>, + num_masks: usize, +) -> mask_proximity::Config { + mask_proximity::Config::new(c_zk, num_masks) +} + +#[cfg(test)] +mod tests { + use proptest::prelude::*; + + use super::*; + use crate::protocols::params::{ + irs_commit as params_irs, + spec::{LogInvRate, MaskCodeMessageLen}, + test_utils::arb_zk_spec, + }; + + proptest! { + /// `solve` produces a Config satisfying `mask_proximity::Config::new`'s + /// invariants (`num_vectors == 2 * num_masks`, `interleaving_depth == 1`). + #[test] + fn solve_assembles( + spec in arb_zk_spec(80..=128), + log_inv_rate in 1u32..=3, + num_masks in 1usize..=8, + l_zk_log in 1u32..=5, + ) { + let l_zk = MaskCodeMessageLen::new(1usize << l_zk_log); + let c_zk = params_irs::solve_mask_code( + &spec, + l_zk, + 0, + LogInvRate::new(log_inv_rate), + 2 * num_masks, + ); + let config = solve(c_zk, num_masks); + prop_assert_eq!(config.num_masks, num_masks); + prop_assert_eq!(config.c_zk_commit.num_vectors, 2 * num_masks); + prop_assert_eq!(config.c_zk_commit.interleaving_depth, 1); + } + } +} diff --git a/src/protocols/params/mod.rs b/src/protocols/params/mod.rs index 7d47774f..59391882 100644 --- a/src/protocols/params/mod.rs +++ b/src/protocols/params/mod.rs @@ -1,6 +1,12 @@ // This module contains the parameter selection and security target logic. -pub mod bounds; +pub(crate) mod bounds; +pub mod code_switch; pub mod irs_commit; +pub mod mask_proximity; +pub mod plan; pub mod spec; pub mod sumcheck; + +#[cfg(test)] +pub(crate) mod test_utils; diff --git a/src/protocols/params/plan.rs b/src/protocols/params/plan.rs new file mode 100644 index 00000000..ff7ed1df --- /dev/null +++ b/src/protocols/params/plan.rs @@ -0,0 +1,48 @@ +//! Derived parameter plan for the Construction 9.7 ZK protocol. +//! +//! Built by the orchestrator from a [`SecuritySpec`] + [`TuningSpec`]; owns +//! the cross-protocol resolved values (source/target IRS, C_zk, t_ood, ℓ_zk, +//! per-round sub-protocol configs) so downstream code doesn't coordinate them. + +use crate::{ + algebra::embedding::{Embedding, Identity}, + protocols::{ + code_switch, irs_commit, + params::spec::{MaskCodeMessageLen, SecuritySpec, TuningSpec}, + sumcheck, + }, +}; + +/// Full derived parameter plan for one protocol run. +#[derive(Clone, Debug)] +pub struct ParameterPlan { + pub security: SecuritySpec, + pub tuning: TuningSpec, + pub rounds: Vec>, +} + +/// Parameters for a single round (sumcheck + code-switch). +#[derive(Clone, Debug)] +pub struct RoundParams { + pub round_index: usize, + pub source_irs: irs_commit::Config, + pub target_irs: irs_commit::Config>, + pub sumcheck: sumcheck::Config, + pub code_switch: code_switch::Config, + pub zk: RoundModeParams, +} + +#[derive(Clone, Debug)] +pub enum RoundModeParams { + Standard, + ZeroKnowledge { + c_zk: irs_commit::Config>, + l_zk: MaskCodeMessageLen, + }, +} + +impl RoundModeParams { + pub const fn is_zk(&self) -> bool { + matches!(self, Self::ZeroKnowledge { .. }) + } +} diff --git a/src/protocols/params/spec.rs b/src/protocols/params/spec.rs index fa1a9f14..40c283a0 100644 --- a/src/protocols/params/spec.rs +++ b/src/protocols/params/spec.rs @@ -16,18 +16,30 @@ impl Tagged { } } -/// Protocol-wide security spec. +/// Security-target spec — *what* security the user wants. Tuning knobs live +/// in [`TuningSpec`]. #[derive(Debug, Clone)] pub struct SecuritySpec { pub mode: Mode, pub target_security_bits: u32, + // TODO: cross-protocol PoW pass; until then, set this to `None` or `Some(0)` + // to avoid silently surrendering `max_pow_bits` of security. + pub max_pow_bits: Option, + pub hash_id: EngineId, + pub _embedding: PhantomData, +} + +/// Tuning knobs — proof-size / prover-time / soundness-margin tradeoffs. +#[derive(Debug, Clone)] +pub struct TuningSpec { + /// Witness vector size (input polynomial coefficient count). pub vector_size: usize, + /// Starting log inverse rate for the initial RS code. pub starting_log_inv_rate: u32, + /// Folding factor for the first (initial) sumcheck round. pub initial_folding_factor: usize, + /// Folding factor for subsequent sumcheck rounds. pub folding_factor: usize, - pub max_pow_bits: Option, - pub hash_id: EngineId, - pub _embedding: PhantomData, } /// Per-round context for bound calculations. @@ -37,6 +49,8 @@ pub struct RoundContext { pub vector_size: usize, pub log_inv_rate: u32, pub folding_factor: u32, + // Reserved for the orchestrator's combination-error sizing; unused by + // current solvers. pub prev_round_in_domain_samples: usize, pub prev_round_query_error: f64, } @@ -79,8 +93,17 @@ impl Mode { impl SecuritySpec { /// Security bits the non-PoW parameters must deliver alone; the remaining /// `max_pow_bits` are closed by PoW grinding. - pub fn protocol_security_target_bits(&self) -> u32 { - self.target_security_bits - .saturating_sub(self.max_pow_bits.unwrap_or(0)) + /// + /// **Until the cross-protocol PoW pass lands**, solvers emit no PoW — + /// so subtracting `max_pow_bits` would silently under-target security. + /// This function therefore asserts `max_pow_bits` is zero. Re-enable the + /// subtraction when PoW grinding is wired in. + pub fn protocol_security_target_bits(&self) -> f64 { + assert!( + self.max_pow_bits.unwrap_or(0) == 0, + "max_pow_bits must be None or Some(0) until cross-protocol PoW grinding lands; \ + setting it nonzero now would silently surrender that many bits of security", + ); + f64::from(self.target_security_bits) } } diff --git a/src/protocols/params/sumcheck.rs b/src/protocols/params/sumcheck.rs index 51b82c1d..a0acad9e 100644 --- a/src/protocols/params/sumcheck.rs +++ b/src/protocols/params/sumcheck.rs @@ -1,68 +1,133 @@ +//! Parameter selection for the per-round sumcheck protocol. +//! +//! Produces a [`sumcheck::Config`] from a `RoundContext` and the ZK context. +//! ZK mode adds a degree-2 masking polynomial per round (Lemma 6.4, p.38). + use crate::{ algebra::embedding::Embedding, protocols::{ - irs_commit, - params::{ - bounds::{self, CodeParams}, - spec::{Mode, RoundContext, SecuritySpec}, - }, + params::{plan::RoundModeParams, spec::RoundContext}, proof_of_work, sumcheck, }, }; +/// Solve sumcheck parameters for one round. pub fn solve( - spec: &SecuritySpec, ctx: &RoundContext, - irs_source: &irs_commit::Config, + zk: &RoundModeParams, ) -> sumcheck::Config { - let num_rounds = num_sumcheck_rounds(spec, ctx); - let mode = match spec.mode { - Mode::Standard { .. } => sumcheck::SumcheckMode::Standard, - Mode::ZeroKnowledge => sumcheck::SumcheckMode::ZeroKnowledge { - mask_length: mask_length(), + let num_rounds = num_sumcheck_rounds(ctx); + let mode = match zk { + RoundModeParams::Standard => sumcheck::SumcheckMode::Standard, + RoundModeParams::ZeroKnowledge { .. } => sumcheck::SumcheckMode::ZeroKnowledge { + mask_length: zk_mask_length(), }, }; - let round_pow = solve_sumcheck_round_pow(spec, irs_source); - sumcheck::Config::new(ctx.vector_size, round_pow, num_rounds, mode) + sumcheck::Config::new( + ctx.vector_size, + proof_of_work::Config::none(), + num_rounds, + mode, + ) } -const fn num_sumcheck_rounds(spec: &SecuritySpec, ctx: &RoundContext) -> usize { - if ctx.round_index == 0 { - spec.initial_folding_factor +/// Number of mask polynomials required for one round of sumcheck. +pub const fn masks_required(zk: &RoundModeParams, ctx: &RoundContext) -> usize { + if zk.is_zk() { + num_sumcheck_rounds(ctx) } else { - spec.folding_factor + 0 } } -pub const fn masks_required(spec: &SecuritySpec, ctx: &RoundContext) -> usize { - match spec.mode { - Mode::Standard { .. } => 0, - Mode::ZeroKnowledge => num_sumcheck_rounds(spec, ctx), - } +const fn num_sumcheck_rounds(ctx: &RoundContext) -> usize { + ctx.folding_factor as usize } -/// 3 coefficients = constant + linear + quadratic, sufficient to mask each -/// degree-2 sumcheck round polynomial. -const fn mask_length() -> usize { +/// 3 coefficients suffice to mask the degree-2 sumcheck round polynomial — +/// Lemma 6.4, p.38. +const fn zk_mask_length() -> usize { 3 } -/// Sumcheck-specific PoW sizing: closes the per-round Lemma 6.5 soundness gap. -fn solve_sumcheck_round_pow( - spec: &SecuritySpec, - irs_source: &irs_commit::Config, -) -> proof_of_work::Config { - let code = CodeParams::from_irs(irs_source); +#[cfg(test)] +mod tests { + use proptest::prelude::*; + + use super::*; + use crate::protocols::params::test_utils::{ + arb_round_ctx, arb_standard_johnson_spec, arb_zk_spec, build_minimal_round_mode, + }; - // Lemma 6.5 per-round error has two terms; security in bits is the min. - // TODO: extend with `ℓ_zk · |Λ_C_zk|` factors in ZK mode once mask-code - // params are available (PR 2). - let sec_mca = -bounds::eps_mca_log2(&code); - let sec_combination = - code.field_bits - bounds::list_size_log2(code.log_inv_rate, code.johnson_slack) - 1.0; - let achieved = sec_mca.min(sec_combination); + proptest! { + /// Standard spec produces `SumcheckMode::Standard`. + #[test] + fn standard_mode_propagates( + spec in arb_standard_johnson_spec(80..=128), + ctx in arb_round_ctx(), + ) { + let zk = build_minimal_round_mode(&spec); + let config = solve(&ctx, &zk); + prop_assert!(matches!(config.mode, sumcheck::SumcheckMode::Standard)); + } - // protocol-level PoW closes that portion separately. - let pow_bits = bounds::pow_bits_to_close_gap(spec.protocol_security_target_bits(), achieved); - proof_of_work::Config::from_difficulty(pow_bits) + /// ZK spec produces `SumcheckMode::ZeroKnowledge { mask_length: 3 }` — Lemma 6.4. + #[test] + fn zk_mode_has_three_mask_coefficients( + spec in arb_zk_spec(80..=128), + ctx in arb_round_ctx(), + ) { + let zk = build_minimal_round_mode(&spec); + let config = solve(&ctx, &zk); + match config.mode { + sumcheck::SumcheckMode::ZeroKnowledge { mask_length } => { + prop_assert_eq!(mask_length, 3); + } + sumcheck::SumcheckMode::Standard => prop_assert!(false, "expected ZK"), + } + } + + /// `num_rounds = ctx.folding_factor`. + #[test] + fn num_rounds_matches_folding_factor( + spec in prop_oneof![ + arb_standard_johnson_spec(80..=128), + arb_zk_spec(80..=128), + ], + ctx in arb_round_ctx(), + ) { + let zk = build_minimal_round_mode(&spec); + let config = solve(&ctx, &zk); + prop_assert_eq!(config.num_rounds, ctx.folding_factor as usize); + } + + /// `masks_required` = 0 in Standard, = `ctx.folding_factor` in ZK. + #[test] + fn masks_required_matches_mode( + spec in prop_oneof![ + arb_standard_johnson_spec(80..=128), + arb_zk_spec(80..=128), + ], + ctx in arb_round_ctx(), + ) { + let zk = build_minimal_round_mode(&spec); + let required = masks_required(&zk, &ctx); + let expected = if zk.is_zk() { ctx.folding_factor as usize } else { 0 }; + prop_assert_eq!(required, expected); + } + + /// Smoke test: `solve` doesn't panic on assembly. + #[test] + fn solve_assembles_without_panic( + spec in prop_oneof![ + arb_standard_johnson_spec(80..=128), + arb_zk_spec(80..=128), + ], + ctx in arb_round_ctx(), + ) { + let zk = build_minimal_round_mode(&spec); + let config = solve(&ctx, &zk); + prop_assert_eq!(config.initial_size, ctx.vector_size); + } + } } diff --git a/src/protocols/params/test_utils.rs b/src/protocols/params/test_utils.rs new file mode 100644 index 00000000..3700778e --- /dev/null +++ b/src/protocols/params/test_utils.rs @@ -0,0 +1,96 @@ +//! Shared test fixtures for `params/` solvers. + +use std::{marker::PhantomData, ops::RangeInclusive}; + +use proptest::prelude::*; + +use crate::{ + algebra::{ + embedding::{Embedding, Identity}, + fields::Field64, + }, + hash, + protocols::params::{ + irs_commit as params_irs, + plan::RoundModeParams, + spec::{LogInvRate, MaskCodeMessageLen, Mode, RoundContext, SecuritySpec}, + }, +}; + +pub type TestField = Field64; +pub type TestEmbedding = Identity; + +/// Build a deterministic Standard-Johnson `SecuritySpec` for the given +/// embedding. Useful for one-shot smoke tests over non-identity embeddings. +pub fn deterministic_standard_spec() -> SecuritySpec { + SecuritySpec { + mode: Mode::Standard { + unique_decoding: false, + }, + target_security_bits: 80, + max_pow_bits: None, + hash_id: hash::BLAKE3, + _embedding: PhantomData, + } +} + +/// `SecuritySpec` strategy with `max_pow_bits ∈ {None, Some(0)}` (PoW deferred). +pub fn arb_spec( + mode: Mode, + target_range: RangeInclusive, +) -> impl Strategy> { + (target_range, prop_oneof![Just(None), Just(Some(0u32))]).prop_map(move |(target, max_pow)| { + SecuritySpec { + mode, + target_security_bits: target, + max_pow_bits: max_pow, + hash_id: hash::BLAKE3, + _embedding: PhantomData, + } + }) +} + +pub fn arb_zk_spec( + target_range: RangeInclusive, +) -> impl Strategy> { + arb_spec(Mode::ZeroKnowledge, target_range) +} + +pub fn arb_standard_johnson_spec( + target_range: RangeInclusive, +) -> impl Strategy> { + arb_spec( + Mode::Standard { + unique_decoding: false, + }, + target_range, + ) +} + +pub fn arb_round_ctx() -> impl Strategy { + (0usize..=3, 4u32..=8, 1u32..=4, 1u32..=3).prop_map( + |(round_index, log_size, log_inv_rate, folding_factor)| RoundContext { + round_index, + vector_size: 1usize << log_size, + log_inv_rate, + folding_factor, + prev_round_in_domain_samples: 0, + prev_round_query_error: 0.0, + }, + ) +} + +/// Minimal `RoundModeParams` matching `spec.mode`: +/// - `Mode::Standard` → `RoundModeParams::Standard`. +/// - `Mode::ZeroKnowledge` → `ZeroKnowledge { c_zk, l_zk }` with ℓ_zk = 2 and +/// C_zk at rate 1/2. +pub fn build_minimal_round_mode( + spec: &SecuritySpec, +) -> RoundModeParams { + if !matches!(spec.mode, Mode::ZeroKnowledge) { + return RoundModeParams::Standard; + } + let l_zk = MaskCodeMessageLen::new(2); + let c_zk = params_irs::solve_mask_code(spec, l_zk, 0, LogInvRate::new(1), 2); + RoundModeParams::ZeroKnowledge { c_zk, l_zk } +} From 359bfbc9f5455560a658080a65ab808f00a41c03 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Wed, 13 May 2026 04:05:47 +0530 Subject: [PATCH 07/31] clean up --- src/protocols/params/bounds.rs | 98 --------------------- src/protocols/params/code_switch.rs | 114 +++++++++++++++++++------ src/protocols/params/irs_commit.rs | 67 +++------------ src/protocols/params/mask_proximity.rs | 47 +++++++++- src/protocols/params/test_utils.rs | 20 +++-- 5 files changed, 154 insertions(+), 192 deletions(-) diff --git a/src/protocols/params/bounds.rs b/src/protocols/params/bounds.rs index 1a6a5d34..2ce14647 100644 --- a/src/protocols/params/bounds.rs +++ b/src/protocols/params/bounds.rs @@ -83,101 +83,3 @@ pub fn ood_per_sample_log2(message_length: usize, field_bits: f64) -> f64 { pub fn pow_bits_to_close_gap(target_security_bits: f64, achieved_security_bits: f64) -> Bits { Bits::new((target_security_bits - achieved_security_bits).max(0.0)) } - -#[cfg(test)] -mod tests { - use super::*; - - /// Within 1e-9 of expected — formulas use `log2`, so floats are inexact. - fn approx_eq(a: f64, b: f64) { - assert!( - (a - b).abs() < 1e-9, - "expected ≈ {b}, got {a} (diff {})", - (a - b).abs() - ); - } - - #[test] - fn list_size_unique_decoding_is_one() { - // Unique decoding → |Λ| = 1 → log2 = 0. - approx_eq(list_size_log2(1.0, 0.0), 0.0); - approx_eq(list_size_log2(5.0, 0.0), 0.0); - } - - #[test] - fn list_size_johnson_grows_as_slack_shrinks() { - // Same rate, smaller slack → larger list. - let big_slack = list_size_log2(2.0, 0.5); - let small_slack = list_size_log2(2.0, 0.05); - assert!(small_slack > big_slack, "{small_slack} > {big_slack}"); - } - - #[test] - fn list_size_johnson_grows_as_rate_drops() { - // Lower rate (larger log_inv_rate) → larger list. - let high_rate = list_size_log2(1.0, 0.1); - let low_rate = list_size_log2(4.0, 0.1); - assert!(low_rate > high_rate, "{low_rate} > {high_rate}"); - } - - fn code(log_inv_rate: f64, johnson_slack: f64, message_length: usize) -> CodeParams { - CodeParams { - log_inv_rate, - johnson_slack, - message_length, - field_bits: 64.0, - } - } - - #[test] - fn eps_mca_grows_with_message_length() { - // Longer message → larger ε (less negative log) → less security. - let short = eps_mca_log2(&code(2.0, 0.1, 16)); - let long = eps_mca_log2(&code(2.0, 0.1, 1024)); - assert!(long > short, "{long} > {short}"); - } - - #[test] - fn eps_mca_grows_with_log_inv_rate() { - // Lower rate (larger log_inv_rate) → larger ε. - let high_rate = eps_mca_log2(&code(1.0, 0.1, 128)); - let low_rate = eps_mca_log2(&code(4.0, 0.1, 128)); - assert!(low_rate > high_rate, "{low_rate} > {high_rate}"); - } - - #[test] - fn one_minus_distance_unique_is_midpoint() { - // Unique decoding: 1 - δ = (1 + ρ) / 2. - // log_inv_rate = 1 → ρ = 0.5 → (1 + 0.5)/2 = 0.75. - approx_eq(one_minus_distance_log2(1.0, 0.0), 0.75_f64.log2()); - } - - #[test] - fn one_minus_distance_johnson_more_negative_than_unique() { - // Johnson allows larger δ than unique decoding → smaller (1-δ) → more - // negative log. - let unique = one_minus_distance_log2(2.0, 0.0); - let johnson = one_minus_distance_log2(2.0, 0.1); - assert!(johnson < unique, "{johnson} < {unique}"); - } - - #[test] - fn ood_per_sample_exact() { - // (k-1)/|F| with k=2, |F|=2^64 → log2 = -64. - approx_eq(ood_per_sample_log2(2, 64.0), -64.0); - // k=9, |F|=2^64 → (8)/2^64 → log2 = 3 - 64 = -61. - approx_eq(ood_per_sample_log2(9, 64.0), -61.0); - } - - #[test] - fn pow_bits_zero_when_achieved_meets_target() { - assert!(pow_bits_to_close_gap(80.0, 100.0).is_zero()); - assert!(pow_bits_to_close_gap(80.0, 80.0).is_zero()); - } - - #[test] - fn pow_bits_fills_gap_to_target() { - let bits = pow_bits_to_close_gap(80.0, 50.0); - approx_eq(f64::from(bits), 30.0); - } -} diff --git a/src/protocols/params/code_switch.rs b/src/protocols/params/code_switch.rs index 0aafc182..d696f55f 100644 --- a/src/protocols/params/code_switch.rs +++ b/src/protocols/params/code_switch.rs @@ -118,7 +118,7 @@ mod tests { spec::{LogInvRate, OodSampleBudget, RoundContext}, test_utils::{ arb_standard_johnson_spec as utils_standard_spec, arb_zk_spec as utils_zk_spec, - deterministic_standard_spec, TestEmbedding, + deterministic_spec, TestEmbedding, TestExtensionField, TestNonIdentityEmbedding, }, }; @@ -265,49 +265,107 @@ mod tests { } } - /// Smoke test: the generics compile and `solve` works end-to-end with a - /// non-identity embedding (`M::Source ≠ M::Target`). - #[test] - fn solve_works_with_basefield_embedding() { - use crate::algebra::{embedding::Basefield, fields::Field64_2}; - type NonIdM = Basefield; - - let spec_source: SecuritySpec = deterministic_standard_spec(); - let spec_target: SecuritySpec> = deterministic_standard_spec(); - + /// Build the canonical `(source_ctx, target_ctx)` pair used by both + /// non-identity smoke tests. Single source of truth so Standard and ZK + /// exercise the same problem shape, only the mode differs. + fn non_identity_smoke_ctxs() -> (RoundContext, RoundContext) { let source_ctx = RoundContext { round_index: 0, - vector_size: 16, + vector_size: 64, log_inv_rate: 1, folding_factor: 2, prev_round_in_domain_samples: 0, prev_round_query_error: 0.0, }; - let source = params_irs::solve(&spec_source, &source_ctx, OodSampleBudget::new(0)); - let target_ctx = RoundContext { round_index: 1, - vector_size: source.message_length(), + vector_size: source_ctx.vector_size / (1 << source_ctx.folding_factor), log_inv_rate: source_ctx.log_inv_rate + source_ctx.folding_factor - 1, folding_factor: source_ctx.folding_factor, - prev_round_in_domain_samples: source.in_domain_samples, + prev_round_in_domain_samples: 0, prev_round_query_error: 0.0, }; + (source_ctx, target_ctx) + } - let mut target = params_irs::solve(&spec_target, &target_ctx, OodSampleBudget::new(0)); - let mut t_ood = compute_t_ood(&spec_source, &source, target.list_size(), None); - for _ in 0..8 { - let new_target = - params_irs::solve(&spec_target, &target_ctx, OodSampleBudget::new(t_ood)); - if new_target.codeword_length == target.codeword_length { - target = new_target; - break; - } - target = new_target; - t_ood = compute_t_ood(&spec_source, &source, target.list_size(), None); - } + /// Standard-mode smoke test with `M::Source ≠ M::Target`. + #[test] + fn solve_works_with_basefield_embedding_standard() { + let spec_source: SecuritySpec = + deterministic_spec(Mode::Standard { + unique_decoding: false, + }); + let spec_target: SecuritySpec> = + deterministic_spec(Mode::Standard { + unique_decoding: false, + }); + let (source_ctx, target_ctx) = non_identity_smoke_ctxs(); + + let source = params_irs::solve(&spec_source, &source_ctx, OodSampleBudget::new(0)); + // Standard target: codeword_length is independent of t_ood (mask = 0), + // so one solve is sufficient. + let target = params_irs::solve(&spec_target, &target_ctx, OodSampleBudget::new(0)); + let t_ood = compute_t_ood(&spec_source, &source, target.list_size(), None); let config = solve(source, target, t_ood, &RoundModeParams::Standard); assert!(matches!(config.mode, code_switch::Mode::Standard)); } + + /// ZK-mode smoke test with `M::Source ≠ M::Target`. Exercises the + /// `RoundModeParams::ZeroKnowledge { c_zk, l_zk }` type path with + /// `c_zk: irs_commit::Config>`. + #[test] + fn solve_works_with_basefield_embedding_zk() { + let spec_source: SecuritySpec = + deterministic_spec(Mode::ZeroKnowledge); + let spec_target: SecuritySpec> = + deterministic_spec(Mode::ZeroKnowledge); + let (source_ctx, target_ctx) = non_identity_smoke_ctxs(); + + let source = params_irs::solve(&spec_source, &source_ctx, OodSampleBudget::new(0)); + // Bootstrap C_zk's list_size with a placeholder ℓ_zk. + let c_zk_placeholder = params_irs::solve_mask_code( + &spec_target, + compute_l_zk(&source, 1), + source.mask_length(), + LogInvRate::new(1), + 2, + ); + let c_zk_list_size = c_zk_placeholder.list_size(); + // Two-solve target rebuild: placeholder t_ood, then final. + let target_placeholder = + params_irs::solve(&spec_target, &target_ctx, OodSampleBudget::new(0)); + let t_ood = compute_t_ood( + &spec_source, + &source, + target_placeholder.list_size(), + Some(c_zk_list_size), + ); + let target = params_irs::solve(&spec_target, &target_ctx, OodSampleBudget::new(t_ood)); + let t_ood_check = compute_t_ood( + &spec_source, + &source, + target.list_size(), + Some(c_zk_list_size), + ); + assert_eq!( + t_ood, t_ood_check, + "smoke-test params should converge in one iteration", + ); + + let l_zk = compute_l_zk(&source, t_ood); + let c_zk = params_irs::solve_mask_code( + &spec_target, + l_zk, + source.mask_length(), + LogInvRate::new(1), + 2, + ); + let zk = RoundModeParams::ZeroKnowledge { c_zk, l_zk }; + let config = solve(source, target, t_ood, &zk); + assert!(matches!( + config.mode, + code_switch::Mode::ZeroKnowledge { .. } + )); + } } diff --git a/src/protocols/params/irs_commit.rs b/src/protocols/params/irs_commit.rs index edbf26fc..fd7be808 100644 --- a/src/protocols/params/irs_commit.rs +++ b/src/protocols/params/irs_commit.rs @@ -102,90 +102,49 @@ pub fn solve_mask_code( #[cfg(test)] mod tests { - use std::marker::PhantomData; - use ark_std::rand::{rngs::StdRng, SeedableRng}; use proptest::prelude::*; use super::*; use crate::{ algebra::random_vector, - hash, - protocols::params::test_utils::{arb_round_ctx, arb_spec, arb_zk_spec, TestEmbedding}, + protocols::params::test_utils::{ + arb_round_ctx, arb_spec, arb_zk_spec, deterministic_spec, TestEmbedding, + }, transcript::{DomainSeparator, ProverState, VerifierState}, }; type M = TestEmbedding; type F = ::Source; - fn minimal_zk_spec() -> SecuritySpec { - SecuritySpec { - mode: Mode::ZeroKnowledge, - target_security_bits: 80, - max_pow_bits: None, - hash_id: hash::BLAKE3, - _embedding: PhantomData, - } - } - - fn minimal_standard_spec() -> SecuritySpec { - SecuritySpec { - mode: Mode::Standard { - unique_decoding: false, - }, - target_security_bits: 80, - max_pow_bits: None, - hash_id: hash::BLAKE3, - _embedding: PhantomData, - } - } - #[test] #[should_panic(expected = "C_zk only exists in ZK mode")] fn solve_mask_code_rejects_standard_spec() { - let _ = solve_mask_code( - &minimal_standard_spec(), - MaskCodeMessageLen::new(2), - 0, - LogInvRate::new(1), - 2, - ); + let spec: SecuritySpec = deterministic_spec(Mode::Standard { + unique_decoding: false, + }); + let _ = solve_mask_code(&spec, MaskCodeMessageLen::new(2), 0, LogInvRate::new(1), 2); } #[test] #[should_panic(expected = "must be a power of 2")] fn solve_mask_code_rejects_non_pow2_l_zk() { - let _ = solve_mask_code( - &minimal_zk_spec(), - MaskCodeMessageLen::new(3), - 0, - LogInvRate::new(1), - 2, - ); + let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); + let _ = solve_mask_code(&spec, MaskCodeMessageLen::new(3), 0, LogInvRate::new(1), 2); } #[test] #[should_panic(expected = "Theorem 9.6")] fn solve_mask_code_rejects_l_zk_below_source_mask_length() { - let _ = solve_mask_code( - &minimal_zk_spec(), - MaskCodeMessageLen::new(2), - 4, - LogInvRate::new(1), - 2, - ); + let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); + let _ = solve_mask_code(&spec, MaskCodeMessageLen::new(2), 4, LogInvRate::new(1), 2); } #[test] #[should_panic(expected = "must be even")] fn solve_mask_code_rejects_odd_num_vectors() { - let _ = solve_mask_code( - &minimal_zk_spec(), - MaskCodeMessageLen::new(2), - 0, - LogInvRate::new(1), - 3, - ); + let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); + let _ = solve_mask_code(&spec, MaskCodeMessageLen::new(2), 0, LogInvRate::new(1), 3); } fn arb_zk_spec_default() -> impl Strategy> { diff --git a/src/protocols/params/mask_proximity.rs b/src/protocols/params/mask_proximity.rs index 198dd682..a166b9e1 100644 --- a/src/protocols/params/mask_proximity.rs +++ b/src/protocols/params/mask_proximity.rs @@ -29,10 +29,17 @@ mod tests { use proptest::prelude::*; use super::*; - use crate::protocols::params::{ - irs_commit as params_irs, - spec::{LogInvRate, MaskCodeMessageLen}, - test_utils::arb_zk_spec, + use crate::{ + algebra::fields::Field64, + hash, + protocols::{ + irs_commit::IrsMode, + params::{ + irs_commit as params_irs, + spec::{LogInvRate, MaskCodeMessageLen, Mode}, + test_utils::{arb_zk_spec, deterministic_spec, TestEmbedding}, + }, + }, }; proptest! { @@ -59,4 +66,36 @@ mod tests { prop_assert_eq!(config.c_zk_commit.interleaving_depth, 1); } } + + /// `mask_proximity::Config::new` rejects `c_zk.num_vectors != 2 * num_masks`. + #[test] + #[should_panic(expected = "c_zk.num_vectors must be 2 * num_masks")] + fn solve_rejects_mismatched_num_vectors() { + let spec = deterministic_spec::(Mode::ZeroKnowledge); + // c_zk built for 2 masks (num_vectors = 4); caller passes num_masks = 3. + let c_zk = params_irs::solve_mask_code( + &spec, + MaskCodeMessageLen::new(2), + 0, + LogInvRate::new(1), + 4, + ); + let _ = solve(c_zk, 3); + } + + #[test] + #[should_panic(expected = "interleaving_depth = 1")] + fn solve_rejects_non_unit_interleaving() { + let c_zk = crate::protocols::irs_commit::Config::>::new( + 80.0, + false, + hash::BLAKE3, + 2, // num_vectors = 2 (= 2 * num_masks with num_masks=1) + 8, // vector_size + 2, // interleaving_depth ≠ 1 + 0.5, + IrsMode::Standard, + ); + let _ = solve(c_zk, 1); + } } diff --git a/src/protocols/params/test_utils.rs b/src/protocols/params/test_utils.rs index 3700778e..918b01d0 100644 --- a/src/protocols/params/test_utils.rs +++ b/src/protocols/params/test_utils.rs @@ -6,8 +6,8 @@ use proptest::prelude::*; use crate::{ algebra::{ - embedding::{Embedding, Identity}, - fields::Field64, + embedding::{Basefield, Embedding, Identity}, + fields::{Field64, Field64_2}, }, hash, protocols::params::{ @@ -20,13 +20,17 @@ use crate::{ pub type TestField = Field64; pub type TestEmbedding = Identity; -/// Build a deterministic Standard-Johnson `SecuritySpec` for the given -/// embedding. Useful for one-shot smoke tests over non-identity embeddings. -pub fn deterministic_standard_spec() -> SecuritySpec { +/// Extension field used by non-identity smoke tests. +pub type TestExtensionField = Field64_2; + +/// Non-identity embedding: `Source = Field64`, `Target = Field64_2`. +pub type TestNonIdentityEmbedding = Basefield; + +/// Build a deterministic `SecuritySpec` for the given embedding and mode. +/// Useful for one-shot smoke / negative tests. +pub fn deterministic_spec(mode: Mode) -> SecuritySpec { SecuritySpec { - mode: Mode::Standard { - unique_decoding: false, - }, + mode, target_security_bits: 80, max_pow_bits: None, hash_id: hash::BLAKE3, From 39339ccc0b91a3d000b573e06e28befdd849bf6e Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Thu, 14 May 2026 12:57:06 +0530 Subject: [PATCH 08/31] feat : added PoW for all the sub protocols --- .../protocols/params/sumcheck.txt | 1 + src/protocols/basecase.rs | 249 ++++---- src/protocols/code_switch.rs | 44 +- src/protocols/mask_proximity.rs | 31 +- src/protocols/params/basecase.rs | 212 ++++++ src/protocols/params/bounds.rs | 20 +- src/protocols/params/code_switch.rs | 241 ++++--- src/protocols/params/irs_commit.rs | 40 +- src/protocols/params/mask_proximity.rs | 91 ++- src/protocols/params/mod.rs | 4 +- src/protocols/params/plan.rs | 146 ++++- src/protocols/params/planner.rs | 601 ++++++++++++++++++ src/protocols/params/spec.rs | 106 +-- src/protocols/params/sumcheck.rs | 158 +++-- src/protocols/params/test_utils.rs | 48 +- src/protocols/proof_of_work.rs | 15 + 16 files changed, 1552 insertions(+), 455 deletions(-) create mode 100644 src/protocols/params/basecase.rs create mode 100644 src/protocols/params/planner.rs diff --git a/proptest-regressions/protocols/params/sumcheck.txt b/proptest-regressions/protocols/params/sumcheck.txt index 8fca101b..d6f6e6ed 100644 --- a/proptest-regressions/protocols/params/sumcheck.txt +++ b/proptest-regressions/protocols/params/sumcheck.txt @@ -9,3 +9,4 @@ cc e8ab6549772cf6bf4c3af116ebcba3dbf295ffbe2aee4a94be7df4b9f45d61ec # shrinks to cc 8c4300cc375640956f81e9da5aef9ea11ef476ddc4dd253dc560afa07609262d # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 98, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 3, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, ctx = RoundContext { round_index: 0, vector_size: 16, log_inv_rate: 1, folding_factor: 1, prev_round_in_domain_samples: 0, prev_round_query_error: 0.0 } cc 8ea40f13c63b4c0021386369ce698a5d9289381a39dc85db43d2d69b9b4877bb # shrinks to spec = SecuritySpec { mode: Standard { unique_decoding: false }, target_security_bits: 88, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 3, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, ctx = RoundContext { round_index: 0, vector_size: 16, log_inv_rate: 3, folding_factor: 1, prev_round_in_domain_samples: 0, prev_round_query_error: 0.0 } cc f1dca600886474c74d857c547baea0c2b4faf45b2946036f21a008106396eb1c # shrinks to spec = SecuritySpec { mode: Standard { unique_decoding: false }, target_security_bits: 80, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 3, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, ctx = RoundContext { round_index: 0, vector_size: 256, log_inv_rate: 3, folding_factor: 1, prev_round_in_domain_samples: 0, prev_round_query_error: 0.0 } +cc 36d0f5929e8099fa8644b0511229cf11634e5a7a66d99c06099c304f5f7a8c6e # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 47, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, ctx = RoundContext { round_index: 0, vector_size: 128, log_inv_rate: 4, folding_factor: 1, prev_round_in_domain_samples: 0, prev_round_query_error: 0.0 } diff --git a/src/protocols/basecase.rs b/src/protocols/basecase.rs index 1d40c242..589caba3 100644 --- a/src/protocols/basecase.rs +++ b/src/protocols/basecase.rs @@ -1,8 +1,4 @@ -//! Base Case Linear Opening Protocol -//! -//! It support honest verifier zero-knowledge (HVZK), but is not succinct. -//! -//! § 7. +//! Non-succinct linear opening (Construction 7.2, p.43). HVZK in ZK mode. use ark_ff::Field; use ark_std::rand::{distributions::Standard, prelude::Distribution, CryptoRng, RngCore}; @@ -15,7 +11,7 @@ use crate::{ univariate_evaluate, }, hash::Hash, - protocols::{irs_commit, sumcheck}, + protocols::{irs_commit, proof_of_work, sumcheck}, transcript::{ codecs::U64, Codec, DuplexSpongeInterface, ProverMessage, ProverState, VerifierMessage, VerifierState, @@ -24,22 +20,27 @@ use crate::{ verify, }; -/// Output from the base case protocol (shared by prover and verifier). #[must_use] pub struct Opening { pub evaluation_points: Vec, pub linear_form_evaluation: F, } +/// Standard / ZeroKnowledge selector for basecase. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub enum Mode { + Standard, + ZeroKnowledge, +} + #[must_use] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(bound = "")] pub struct Config { pub commit: irs_commit::Config>, pub sumcheck: sumcheck::Config, - - /// Whether to mask the vectors, which adds HVZK. - pub masked: bool, + pub mode: Mode, + pub pow: proof_of_work::Config, } impl Config { @@ -47,6 +48,10 @@ impl Config { self.sumcheck.initial_size } + pub const fn is_zk(&self) -> bool { + matches!(self.mode, Mode::ZeroKnowledge) + } + pub fn prove( &self, prover_state: &mut ProverState, @@ -77,63 +82,22 @@ impl Config { }; } - // Even more trivial non-zk protocol: send f and r directly. - if !self.masked { - prover_state.prover_messages(&vector); - prover_state.prover_messages(&witness.masks); - let _ = self.commit.open(prover_state, &[witness]); - let point = self - .sumcheck - .prove(prover_state, &mut vector, &mut covector, &mut sum, &[]) - .round_challenges; - assert!(!vector[0].is_zero(), "Proof failed"); - return Opening { - evaluation_points: point, - linear_form_evaluation: covector[0], - }; - } - - // Create masking vector. - let mask = random_vector(prover_state.rng(), vector.len()); - - // Commit to the masking vector. - let mask_witness = self.commit.commit(prover_state, &[&mask]); - - // Compute and send linear form of mask (μ' in paper). - let mask_sum = dot(&mask, &covector); - prover_state.prover_message(&mask_sum); - - // RLC the mask with the vector - let mask_rlc = prover_state.verifier_message::(); - assert!(!mask_rlc.is_zero(), "Proof failed"); - let mut masked_vector = scalar_mul_add_new(&mask, mask_rlc, &vector); - prover_state.prover_messages(&masked_vector); - - // Send combined IRS randomness. (r^* in paper) - let masked_masks = scalar_mul_add_new(&mask_witness.masks, mask_rlc, &witness.masks); - prover_state.prover_messages(&masked_masks); + let blinding_witness = + self.maybe_blind_prove(prover_state, &mut vector, witness, &covector, &mut sum); - // Open the commitment and mask simultaneously. - let _ = self.commit.open(prover_state, &[&mask_witness, witness]); + let witnesses: Vec<&irs_commit::Witness> = blinding_witness + .as_ref() + .map_or_else(|| vec![witness], |b| vec![b, witness]); + let _ = self.commit.open(prover_state, &witnesses); - // Run sumcheck to reduce linear form claim - let mut masked_sum = mask_sum + mask_rlc * sum; let point = self .sumcheck - .prove( - prover_state, - &mut masked_vector, - &mut covector, - &mut masked_sum, - &[], - ) + .prove(prover_state, &mut vector, &mut covector, &mut sum, &[]) .round_challenges; - // If the MLE of `masked_vector` evaluates to zero, the verifier can not proceed. - // Basically the sumcheck equation has degenerated to 0 * l(r) = 0, which provides - // no constraints on l(r) that the verifier can return. - // This event is cryptographically unlikely as `F` is challenge sized. - assert!(!masked_vector[0].is_zero(), "Proof failed"); + // Negligible event over a challenge-sized field; without it the verifier + // cannot derive `l(r) = sum / vector_mle(r)`. + assert!(!vector[0].is_zero(), "Proof failed"); Opening { evaluation_points: point, @@ -141,6 +105,60 @@ impl Config { } } + /// ZK: commits a blinding codeword, runs the RLC, mutates `vector`/`sum` to + /// the combined values, sends them cleartext. Standard: sends `vector` and + /// `witness.masks` cleartext (no ZK). + fn maybe_blind_prove( + &self, + prover_state: &mut ProverState, + vector: &mut Vec, + witness: &irs_commit::Witness, + covector: &[F], + sum: &mut F, + ) -> Option> + where + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + F: Codec<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, + Standard: Distribution, + { + match self.mode { + Mode::Standard => { + prover_state.prover_messages(vector); + prover_state.prover_messages(&witness.masks); + None + } + Mode::ZeroKnowledge => { + let blinding_vector = random_vector(prover_state.rng(), vector.len()); + let blinding_witness = self.commit.commit(prover_state, &[&blinding_vector]); + let blinding_inner_product = dot(&blinding_vector, covector); + prover_state.prover_message(&blinding_inner_product); + + // Grind the Theorem 7.1 γ-combination gap before γ is sampled. + self.pow.prove(prover_state); + + let combination_randomness = prover_state.verifier_message::(); + assert!(!combination_randomness.is_zero(), "Proof failed"); + + *vector = scalar_mul_add_new(&blinding_vector, combination_randomness, vector); + prover_state.prover_messages(vector); + + let combined_irs_randomness = scalar_mul_add_new( + &blinding_witness.masks, + combination_randomness, + &witness.masks, + ); + prover_state.prover_messages(&combined_irs_randomness); + + *sum = blinding_inner_product + combination_randomness * *sum; + Some(blinding_witness) + } + } + } + pub fn verify( &self, verifier_state: &mut VerifierState, @@ -166,72 +184,71 @@ impl Config { }); } - // Unmasked protocol - if !self.masked { - let vector = verifier_state.prover_messages_vec(self.commit.vector_size)?; - let masks = verifier_state - .prover_messages_vec(self.commit.mask_length() * self.commit.num_messages())?; - let evals = self.commit.verify(verifier_state, &[commitment])?; - let point = self - .sumcheck - .verify(verifier_state, &mut sum)? - .round_challenges; - - for (&point, value) in zip_strict(&evals.points, evals.values(&[F::ONE])) { - // We expected `f(x) + x^l · g(x)` where l = deg(f) + 1, f is the message and g the mask. - let expected = univariate_evaluate(&vector, point) - + point.pow([self.commit.message_length() as u64]) - * univariate_evaluate(&masks, point); - verify!(value == expected); - } - let mle = multilinear_extend(&vector, &point); - verify!(!mle.is_zero()); - let linear_mle = sum / mle; - return Ok(Opening { - evaluation_points: point, - linear_form_evaluation: linear_mle, - }); - } + let blind = self.maybe_receive_blind(verifier_state, &mut sum)?; - let mask_commitment = self.commit.receive_commitment(verifier_state)?; - let mask_sum: F = verifier_state.prover_message()?; - let mask_rlc: F = verifier_state.verifier_message(); - verify!(!mask_rlc.is_zero()); - let masked_vector: Vec = verifier_state.prover_messages_vec(self.commit.vector_size)?; - let masked_masks: Vec = verifier_state.prover_messages_vec(self.commit.mask_length())?; + let vector = verifier_state.prover_messages_vec(self.commit.vector_size)?; + let irs_randomness = verifier_state + .prover_messages_vec(self.commit.mask_length() * self.commit.num_messages())?; - // Open the commitment and mask simultaneously. - let evals = self - .commit - .verify(verifier_state, &[&mask_commitment, commitment])?; + let (commitments, weights): (Vec<&irs_commit::Commitment>, Vec) = match &blind { + Some((b, gamma)) => (vec![b, commitment], vec![F::ONE, *gamma]), + None => (vec![commitment], vec![F::ONE]), + }; + let evals = self.commit.verify(verifier_state, &commitments)?; - // Spot check evaluations. - for (&point, value) in zip_strict(&evals.points, evals.values(&[F::ONE, mask_rlc])) { - // We expected `f(x) + x^l · g(x)` where l = deg(f) + 1, f is the message and g the mask. - let expected = univariate_evaluate(&masked_vector, point) + // Spot-check: Enc_C(vector, irs_randomness)(x) = Σ weights · opened_row(x). + for (&point, value) in zip_strict(&evals.points, evals.values(&weights)) { + let expected = univariate_evaluate(&vector, point) + point.pow([self.commit.message_length() as u64]) - * univariate_evaluate(&masked_masks, point); + * univariate_evaluate(&irs_randomness, point); verify!(value == expected); } - // Sumcheck on masked inner product - let mut masked_sum = mask_sum + mask_rlc * sum; let point = self .sumcheck - .verify(verifier_state, &mut masked_sum)? + .verify(verifier_state, &mut sum)? .round_challenges; - // Compute implied MLE of the linear form - // f*(r) · l(r) = sum => l(r) = sum / f*(r) - let masked_mle = multilinear_extend(&masked_vector, &point); - verify!(!masked_mle.is_zero()); - let linear_mle = masked_sum / masked_mle; + // l(r) = sum / vector_mle(r), where l is the implicit linear form. + let mle = multilinear_extend(&vector, &point); + verify!(!mle.is_zero()); + let linear_mle = sum / mle; Ok(Opening { evaluation_points: point, linear_form_evaluation: linear_mle, }) } + + /// ZK: reads the blinding commitment + μ' + γ, mutates `sum` to the + /// combined value, returns `(commitment, γ)`. Standard: no-op. + fn maybe_receive_blind( + &self, + verifier_state: &mut VerifierState, + sum: &mut F, + ) -> VerificationResult> + where + H: DuplexSpongeInterface, + F: Codec<[H::U]>, + u8: Decoding<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + match self.mode { + Mode::Standard => Ok(None), + Mode::ZeroKnowledge => { + let blinding_commitment = self.commit.receive_commitment(verifier_state)?; + let blinding_inner_product: F = verifier_state.prover_message()?; + // Grind the Theorem 7.1 γ-combination gap before γ is sampled. + self.pow.verify(verifier_state)?; + let combination_randomness: F = verifier_state.verifier_message(); + verify!(!combination_randomness.is_zero()); + *sum = blinding_inner_product + combination_randomness * *sum; + Ok(Some((blinding_commitment, combination_randomness))) + } + } + } } #[cfg(test)] @@ -248,7 +265,7 @@ mod tests { pub fn arbitrary(size: usize, mask_length: usize) -> impl Strategy { let commit = irs_commit::Config::arbitrary(Identity::::new(), 1, size, mask_length, 1); - (commit, bool::weighted(0.8)).prop_map(move |(commit, masked)| Self { + (commit, bool::weighted(0.8)).prop_map(move |(commit, is_zk)| Self { commit, sumcheck: sumcheck::Config::new( size, @@ -256,7 +273,12 @@ mod tests { size.next_power_of_two().trailing_zeros() as usize, sumcheck::SumcheckMode::Standard, ), - masked, + mode: if is_zk { + Mode::ZeroKnowledge + } else { + Mode::Standard + }, + pow: proof_of_work::Config::none(), }) } } @@ -267,7 +289,6 @@ mod tests { F: Field + Codec, Standard: Distribution, { - // Pseudo-random Instance let instance = U64(seed); let ds = DomainSeparator::protocol(config) .session(&format!("Test at {}:{}", file!(), line!())) @@ -277,7 +298,6 @@ mod tests { let covector = random_vector(&mut rng, config.size()); let sum = dot(&vector, &covector); - // Prover let mut prover_state = ProverState::new_std(&ds); let witness = config.commit.commit(&mut prover_state, &[&vector]); let prover_result = config.prove( @@ -293,7 +313,6 @@ mod tests { ); let proof = prover_state.proof(); - // Verifier let mut verifier_state = VerifierState::new_std(&ds, &proof); let commitment = config .commit diff --git a/src/protocols/code_switch.rs b/src/protocols/code_switch.rs index ab5b03ab..3843b6be 100644 --- a/src/protocols/code_switch.rs +++ b/src/protocols/code_switch.rs @@ -21,10 +21,11 @@ use crate::{ protocols::{ geometric_challenge::geometric_challenge, irs_commit::{Commitment as IrsCommitment, Config as IrsConfig, Witness as IrsWitness}, + proof_of_work, }, transcript::{ - Codec, Decoding, DuplexSpongeInterface, ProverMessage, ProverState, VerificationResult, - VerifierMessage, VerifierState, + codecs::U64, Codec, Decoding, DuplexSpongeInterface, ProverMessage, ProverState, + VerificationResult, VerifierMessage, VerifierState, }, verify, }; @@ -45,6 +46,7 @@ pub struct Config { pub target: IrsConfig>, pub mode: Mode, pub out_domain_samples: usize, + pub pow: proof_of_work::Config, } /// Prover output from the code-switch. @@ -65,6 +67,7 @@ impl Config { target_config: IrsConfig>, out_domain_samples: usize, mode: Mode, + pow: proof_of_work::Config, ) -> Self { assert_eq!( source_config.num_vectors, 1, @@ -74,6 +77,12 @@ impl Config { target_config.num_vectors, 1, "code-switch requires a single target vector" ); + // Construction 9.7 needs at least one OOD challenge; unique-decoding + // Standard mode (`t_ood = 0`) is incompatible with code-switch. + assert!( + out_domain_samples > 0, + "code-switch requires t_ood ≥ 1 (Construction 9.7)", + ); // Target encodes one polynomial of length ℓ = source.message_length() // under C' = D^{ι_t}. The IRS splits the input of length ℓ into ι_t // parallel slices of length ℓ/ι_t, each encoded under D. @@ -124,6 +133,7 @@ impl Config { target: target_config, mode, out_domain_samples, + pow, } } @@ -183,6 +193,8 @@ impl Config { Standard: Distribution, M::Target: Codec<[H::U]>, u8: Decoding<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, Hash: ProverMessage<[H::U]>, { assert_eq!(message.len(), self.source.message_length()); @@ -199,6 +211,9 @@ impl Config { // Step 1: g := Enc_{C'}(f, r') — Construction 9.7 Step 1, p.55 let target_witness = self.target.commit(prover_state, &[&message]); + // Grind Lemma 9.9 OOD gap before α is sampled. + self.pow.prove(prover_state); + // Step 2-3: OOD challenge + answers — Construction 9.7 Steps 2-3, p.55 let ood_points: Vec = prover_state.verifier_message_vec(self.out_domain_samples); self.maybe_send_ood_answers(prover_state, &message, mask, &ood_points); @@ -331,6 +346,8 @@ impl Config { Standard: Distribution, M::Target: Codec<[H::U]>, u8: Decoding<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, Hash: ProverMessage<[H::U]>, { verify!(1 << folding_randomness.len() == self.source.interleaving_depth); @@ -341,6 +358,9 @@ impl Config { // Mask oracle is committed in the shared mask tree by the orchestrator. let target_commitment = self.target.receive_commitment(verifier_state)?; + // Grind Lemma 9.9 OOD gap before α is sampled. + self.pow.verify(verifier_state)?; + // Step 2-3: OOD — Construction 9.7 Steps 2-3, p.55 // In ZK mode, ood_answers = f(α) + α^ℓ · (r,s)(α) where (r,s) is // the mask oracle message committed in the shared tree. @@ -413,7 +433,7 @@ mod tests { select(valid_sizes), 0_usize..=3, // src_mask_len (source IRS randomness, post-fold) bool::ANY, // zk - 0_usize..=5, // ood (= code-switch t_ood) + 1_usize..=5, // ood (= code-switch t_ood; ≥ 1 per Construction 9.7) 0_usize..=5, // fresh_s_len (≥ ood for assumption (c)) select(vec![1_usize, 2, 4]), // ι_s (source interleaving) 0_usize..=10, // target.in_domain_samples (t'_in) @@ -473,7 +493,13 @@ mod tests { } else { Mode::Standard }; - Self::new(source.clone(), target, ood, mode) + Self::new( + source.clone(), + target, + ood, + mode, + proof_of_work::Config::none(), + ) }) }) }) @@ -780,12 +806,10 @@ mod tests { #[test] fn test_tampered_ood() { crate::tests::init(); - let configs = Config::arbitrary(Identity::::new()).prop_filter( - "non-ZK with ood > 0", - |config| { - !config.is_zk() && config.source.mask_length() == 0 && config.out_domain_samples > 0 - }, - ); + let configs = Config::arbitrary(Identity::::new()) + .prop_filter("non-ZK", |config| { + !config.is_zk() && config.source.mask_length() == 0 + }); proptest!(|(seed: u64, config in configs)| { test_tampered_ood_config(seed, &config); }); diff --git a/src/protocols/mask_proximity.rs b/src/protocols/mask_proximity.rs index 4a733e92..c5d90b67 100644 --- a/src/protocols/mask_proximity.rs +++ b/src/protocols/mask_proximity.rs @@ -47,12 +47,13 @@ use serde::{Deserialize, Serialize}; use crate::{ algebra::{embedding::Identity, random_vector, scalar_mul_add_new, univariate_evaluate}, hash::Hash, - protocols::irs_commit::{ - Commitment as IrsCommitment, Config as IrsConfig, Witness as IrsWitness, + protocols::{ + irs_commit::{Commitment as IrsCommitment, Config as IrsConfig, Witness as IrsWitness}, + proof_of_work, }, transcript::{ - Codec, Decoding, DuplexSpongeInterface, ProverMessage, ProverState, VerificationResult, - VerifierMessage, VerifierState, + codecs::U64, Codec, Decoding, DuplexSpongeInterface, ProverMessage, ProverState, + VerificationResult, VerifierMessage, VerifierState, }, utils::zip_strict, verify, @@ -67,6 +68,7 @@ use crate::{ pub struct Config { pub c_zk_commit: IrsConfig>, pub num_masks: usize, + pub pow: proof_of_work::Config, } /// Prover output from the commit phase. @@ -80,7 +82,11 @@ pub struct Witness { pub type Commitment = IrsCommitment; impl Config { - pub fn new(c_zk_commit: IrsConfig>, num_masks: usize) -> Self { + pub fn new( + c_zk_commit: IrsConfig>, + num_masks: usize, + pow: proof_of_work::Config, + ) -> Self { assert_eq!( c_zk_commit.num_vectors, 2 * num_masks, @@ -93,6 +99,7 @@ impl Config { Self { c_zk_commit, num_masks, + pow, } } @@ -162,11 +169,16 @@ impl Config { R: RngCore + CryptoRng, Standard: Distribution, u8: Decoding<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, Hash: ProverMessage<[H::U]>, { assert_eq!(original_msgs.len(), self.num_masks); assert_eq!(witness.fresh_msgs.len(), self.num_masks); + // Grind the Lemma 7.4 γ-combination gap before γ is sampled. + self.pow.prove(prover_state); + // Step 1: receive combination randomness γ let gamma: F = prover_state.verifier_message(); @@ -213,8 +225,13 @@ impl Config { F: Codec<[H::U]>, H: DuplexSpongeInterface, u8: Decoding<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, Hash: ProverMessage<[H::U]>, { + // Grind the Lemma 7.4 γ-combination gap before γ is sampled. + self.pow.verify(verifier_state)?; + // Step 1: send combination randomness γ let gamma: F = verifier_state.verifier_message(); @@ -313,7 +330,9 @@ mod tests { ); (Just(num_masks), c_zk) }) - .prop_map(|(num_masks, c_zk)| Self::new(c_zk, num_masks)) + .prop_map(|(num_masks, c_zk)| { + Self::new(c_zk, num_masks, proof_of_work::Config::none()) + }) } } diff --git a/src/protocols/params/basecase.rs b/src/protocols/params/basecase.rs new file mode 100644 index 00000000..91fc794e --- /dev/null +++ b/src/protocols/params/basecase.rs @@ -0,0 +1,212 @@ +//! Basecase (Construction 7.2, p.43) parameter selection + γ-combination bound. + +use ark_ff::Field; + +use crate::{ + algebra::{embedding::Identity, fields::FieldWithSize}, + bits::Bits, + protocols::{ + basecase, + irs_commit::Config as IrsConfig, + params::{ + irs_commit as irs_solver, + spec::{Mode as SpecMode, OodSampleBudget, RoundContext, SecuritySpec}, + sumcheck as sumcheck_solver, + }, + proof_of_work, sumcheck, + }, +}; + +/// PoW closes the Theorem 7.1 γ-slot gap to `spec.target_security_bits`; no +/// γ challenge in Standard mode ⇒ `Config::none()`. +pub fn solve( + spec: &SecuritySpec>, + vector_size: usize, + log_inv_rate: u32, +) -> basecase::Config { + assert!(vector_size > 0, "basecase requires vector_size ≥ 1"); + + let ctx = RoundContext { + round_index: 0, + vector_size, + log_inv_rate, + folding_factor: 0, + }; + let commit = irs_solver::solve(spec, &ctx, OodSampleBudget::new(0)); + + let target_bits = Bits::new(f64::from(spec.target_security_bits)); + let sumcheck_pow = proof_of_work::Config::grind_to( + target_bits, + sumcheck_solver::analytic_error_bits(&commit, None), + spec.hash_id, + ); + let sumcheck = sumcheck::Config::new( + vector_size, + sumcheck_pow, + vector_size.next_power_of_two().trailing_zeros() as usize, + sumcheck::SumcheckMode::Standard, + ); + + let mode = match spec.mode { + SpecMode::Standard { .. } => basecase::Mode::Standard, + SpecMode::ZeroKnowledge => basecase::Mode::ZeroKnowledge, + }; + + let pow = match mode { + basecase::Mode::Standard => proof_of_work::Config::none(), + basecase::Mode::ZeroKnowledge => { + proof_of_work::Config::grind_to(target_bits, analytic_error_bits(&commit), spec.hash_id) + } + }; + + basecase::Config { + commit, + sumcheck, + mode, + pow, + } +} + +/// γ-combination soundness (Theorem 7.1, n=0): `log|F| − log|Λ(C^≡2, δ)|`. +pub fn analytic_error_bits(commit: &IrsConfig>) -> Bits { + let field_bits = F::field_size_bits(); + let log_list = commit.list_size().log2(); + Bits::new((field_bits - log_list).max(0.0)) +} + +#[cfg(test)] +mod tests { + use ark_std::rand::{rngs::StdRng, SeedableRng}; + use proptest::prelude::*; + + use super::*; + use crate::{ + algebra::{dot, multilinear_extend, random_vector}, + protocols::params::test_utils::{ + arb_standard_johnson_spec, arb_zk_spec, deterministic_spec, TestEmbedding, + }, + transcript::{codecs::U64, DomainSeparator, ProverState, VerifierState}, + }; + + // Keeps `target − error ≤ 60`, the cap `proof_of_work::threshold` enforces. + const TEST_TARGET_RANGE: std::ops::RangeInclusive = 30..=50; + + fn arb_dims() -> impl Strategy { + (1u32..=4, 1u32..=3) + } + + proptest! { + #[test] + fn solve_standard_assembles( + spec in arb_standard_johnson_spec(TEST_TARGET_RANGE), + (log_size, log_inv_rate) in arb_dims(), + ) { + let config = solve(&spec, 1usize << log_size, log_inv_rate); + prop_assert!(matches!(config.mode, basecase::Mode::Standard)); + prop_assert_eq!(config.commit.interleaving_depth, 1); + prop_assert_eq!(config.commit.num_vectors, 1); + prop_assert_eq!(config.commit.vector_size, config.sumcheck.initial_size); + } + + #[test] + fn solve_zk_assembles( + spec in arb_zk_spec(TEST_TARGET_RANGE), + (log_size, log_inv_rate) in arb_dims(), + ) { + let config = solve(&spec, 1usize << log_size, log_inv_rate); + prop_assert!(matches!(config.mode, basecase::Mode::ZeroKnowledge)); + prop_assert!(config.commit.mask_length() > 0); + } + + #[test] + fn pow_closes_gap_to_target_zk( + spec in arb_zk_spec(TEST_TARGET_RANGE), + (log_size, log_inv_rate) in arb_dims(), + ) { + let config = solve(&spec, 1usize << log_size, log_inv_rate); + let error = f64::from(analytic_error_bits(&config.commit)); + let pow_bits = f64::from(config.pow.difficulty()); + prop_assert!( + error + pow_bits >= f64::from(spec.target_security_bits) - 1e-3, + "error {} + pow {} < target {}", + error, pow_bits, spec.target_security_bits, + ); + } + + #[test] + fn standard_mode_has_no_pow( + spec in arb_standard_johnson_spec(TEST_TARGET_RANGE), + (log_size, log_inv_rate) in arb_dims(), + ) { + let config = solve(&spec, 1usize << log_size, log_inv_rate); + prop_assert_eq!(config.pow, proof_of_work::Config::none()); + } + } + + fn round_trip(seed: u64, vector_size: usize, zk: bool) { + let spec: SecuritySpec = deterministic_spec(if zk { + SpecMode::ZeroKnowledge + } else { + SpecMode::Standard { + unique_decoding: false, + } + }); + let spec = SecuritySpec { + target_security_bits: 40, + ..spec + }; + let config = solve(&spec, vector_size, 1); + + let mut rng = StdRng::seed_from_u64(seed); + let vector = random_vector::<::Source>( + &mut rng, + vector_size, + ); + let covector = random_vector(&mut rng, vector_size); + let sum = dot(&vector, &covector); + + let instance = U64(seed); + let ds = DomainSeparator::protocol(&config) + .session(&format!("Test at {}:{}", file!(), line!())) + .instance(&instance); + + let mut prover_state = ProverState::new_std(&ds); + let witness = config.commit.commit(&mut prover_state, &[&vector]); + let prover_result = config.prove( + &mut prover_state, + vector.clone(), + &witness, + covector.clone(), + sum, + ); + assert_eq!( + multilinear_extend(&covector, &prover_result.evaluation_points), + prover_result.linear_form_evaluation, + ); + let proof = prover_state.proof(); + + let mut verifier_state = VerifierState::new_std(&ds, &proof); + let commitment = config + .commit + .receive_commitment(&mut verifier_state) + .unwrap(); + let verifier_result = config + .verify(&mut verifier_state, &commitment, sum) + .unwrap(); + verifier_state.check_eof().unwrap(); + assert_eq!( + verifier_result.linear_form_evaluation, + prover_result.linear_form_evaluation, + ); + } + + #[test] + fn round_trip_standard() { + round_trip(0x5EED_5EED, 8, false); + } + + #[test] + fn round_trip_zk() { + round_trip(0x5EED_5EED, 8, true); + } +} diff --git a/src/protocols/params/bounds.rs b/src/protocols/params/bounds.rs index 2ce14647..30054df1 100644 --- a/src/protocols/params/bounds.rs +++ b/src/protocols/params/bounds.rs @@ -1,4 +1,4 @@ -//! Shared primitives for parameter selection: RS bounds + PoW sizing. +//! Shared RS-code primitives + the [`SoundnessBounded`] abstraction. use std::{f64::consts::LOG2_10, ops::Neg}; @@ -8,6 +8,18 @@ use crate::{ protocols::irs_commit, }; +/// Analytic soundness bits (excluding PoW) delivered by a protocol-level unit. +/// +/// Implemented on [`RoundPlan`](super::plan::RoundPlan), +/// [`MaskOraclePlan`](super::plan::MaskOraclePlan), and +/// [`ParameterPlan`](super::plan::ParameterPlan). Sub-protocol `Config` types +/// lack the cross-protocol context to self-report. +// TODO(phase-6): wire `analytic + pow >= target` so this is called outside tests. +#[allow(dead_code)] +pub trait SoundnessBounded { + fn analytic_bits(&self) -> Bits; +} + /// `johnson_slack == 0.0` selects the unique-decoding regime. #[derive(Debug, Clone, Copy)] pub struct CodeParams { @@ -75,10 +87,8 @@ pub fn ood_per_sample_log2(message_length: usize, field_bits: f64) -> f64 { ((message_length - 1) as f64).log2() - field_bits } -/// PoW difficulty to close a soundness gap: max(0, target − achieved). -/// -/// Currently unused — solvers emit `Config::none()` PoW. Will be re-wired by -/// the cross-protocol PoW pass. +/// PoW difficulty to close a soundness gap: `max(0, target − achieved)`. +// TODO(phase-6): re-wire from the cross-protocol PoW pass. #[allow(dead_code)] pub fn pow_bits_to_close_gap(target_security_bits: f64, achieved_security_bits: f64) -> Bits { Bits::new((target_security_bits - achieved_security_bits).max(0.0)) diff --git a/src/protocols/params/code_switch.rs b/src/protocols/params/code_switch.rs index d696f55f..c350a2d9 100644 --- a/src/protocols/params/code_switch.rs +++ b/src/protocols/params/code_switch.rs @@ -1,7 +1,5 @@ -//! Parameter selection for the code-switching IOR (Construction 9.7, p.55). -//! -//! Computes `t_ood` (Bound 2 / Lemma 9.9 first error term) and sizes the ZK -//! mask oracle `ℓ_zk` per Theorem 9.6 + Lemma 9.5. +//! Code-switching IOR (Construction 9.7, p.55) builder + Lemma 9.9 OOD bound. +//! The `t_ood` / `ℓ_zk` fixed-points live in the planner. use std::num::NonZeroUsize; @@ -10,101 +8,86 @@ use crate::{ embedding::{Embedding, Identity}, fields::FieldWithSize, }, + bits::Bits, protocols::{ code_switch, - irs_commit::{self, Config as IrsConfig}, - params::{ - plan::RoundModeParams, - spec::{MaskCodeMessageLen, Mode, SecuritySpec}, - }, + irs_commit::Config as IrsConfig, + params::{plan::MaskOracleInfo, spec::SecuritySpec}, + proof_of_work, }, }; -/// Assemble the [`code_switch::Config`] from precomputed `t_ood` and a -/// mode-typed `zk` context. +/// `mask_oracle.l_zk` must have been used to size C_zk (planner's job). /// -/// In ZK mode, `zk` carries the `l_zk` produced by [`compute_l_zk`]; the -/// orchestrator must have used the same value for `irs_commit::solve_mask_code` -/// so both consumers see the same mask-oracle length. +/// PoW closes the Lemma 9.9 OOD gap to `spec.target_security_bits`. When +/// `t_ood == 0` no OOD challenge is drawn, so no grinding is required. pub fn solve( + spec: &SecuritySpec, source: IrsConfig, target: IrsConfig>, t_ood: usize, - zk: &RoundModeParams, + mask_oracle: Option, ) -> code_switch::Config { - let mode = match zk { - RoundModeParams::Standard => code_switch::Mode::Standard, - RoundModeParams::ZeroKnowledge { l_zk, .. } => { - let l_zk = l_zk.get(); - assert!( - l_zk >= source.mask_length() + t_ood, - "ℓ_zk ({l_zk}) < r + t_ood ({} + {}) — violates Bound 3", - source.mask_length(), - t_ood, - ); - code_switch::Mode::ZeroKnowledge { - message_mask_length: NonZeroUsize::new(l_zk).expect("ℓ_zk > 0"), - } + let mode = mask_oracle.map_or(code_switch::Mode::Standard, |info| { + let l_zk = info.l_zk.get(); + assert!( + l_zk >= source.mask_length() + t_ood, + "ℓ_zk ({l_zk}) < r + t_ood ({} + {}) — violates Bound 3", + source.mask_length(), + t_ood, + ); + code_switch::Mode::ZeroKnowledge { + message_mask_length: NonZeroUsize::new(l_zk).expect("ℓ_zk > 0"), } - }; - code_switch::Config::new(source, target, t_ood, mode) -} + }); -/// `ℓ_zk = next_power_of_two(r + t_ood)` — shared by code-switch and C_zk. -/// -/// Bound 3 / Lemma 9.3 requires `ℓ_zk ≥ r + t_ood`; pow2 padding lets the -/// same value drive `irs_commit::solve_mask_code`'s NTT-order assertion. -pub const fn compute_l_zk( - source: &irs_commit::Config, - t_ood: usize, -) -> MaskCodeMessageLen { - MaskCodeMessageLen::new((source.mask_length() + t_ood).next_power_of_two()) + let target_bits = Bits::new(f64::from(spec.target_security_bits)); + let analytic = analytic_error_bits(&source, &target, t_ood, mask_oracle); + let pow = proof_of_work::Config::grind_to(target_bits, analytic, spec.hash_id); + + code_switch::Config::new(source, target, t_ood, mode, pow) } -/// `t_ood` from Bound 2 / Lemma 9.9 first error term. +/// Dominant soundness gap that PoW must close: `min(OOD term, combination term)`. /// -/// Solves `(|Λ(C')| · |Λ(C_zk)|)² / 2 · ((ℓ + ℓ_zk - 1) / |F|)^{t_ood} ≤ -/// 2^{-security}`. In ZK mode `ℓ_zk = r + t_ood` is mutually dependent with -/// `t_ood`; iterate to the fixed point. -pub fn compute_t_ood( - spec: &SecuritySpec, +/// - OOD (Lemma 9.9, term 1): `t_ood · (log|F| − log(degree − 1)) − log(L choose 2)`, +/// with `L = target × c_zk` (ZK) or `target` (Standard), and +/// `degree = ℓ + r + t_ood` (ZK) or `ℓ` (Standard). +/// - Combination (Bound 1, γ-RLC): `log|F| − log(t_ood + t·ι) − log|Λ(target)| − [log|Λ(C_zk)|]`. +/// +/// `t_ood ≥ 1` per [`code_switch::Config::new`]. +pub fn analytic_error_bits( source: &IrsConfig, - target_list_size: f64, - c_zk_list_size: Option, -) -> usize { - const MAX_ITER: usize = 32; - - let security_target = spec.protocol_security_target_bits(); + target: &IrsConfig>, + t_ood: usize, + mask_oracle: Option, +) -> Bits { + assert!(t_ood > 0, "code-switch requires t_ood ≥ 1"); let field_bits = M::Target::field_size_bits(); - let unique_decoding = spec.mode.unique_decoding(); - let combined_list_size = target_list_size * c_zk_list_size.unwrap_or(1.0); - let message_length = source.message_length(); - let source_mask_length = source.mask_length(); + let target_list = target.list_size(); + let combined_list = mask_oracle.map_or(target_list, |info| target_list * info.c_zk_list_size); + let degree = mask_oracle.map_or_else( + || source.message_length(), + |_| source.masked_message_length() + t_ood, + ); - let solve_for_degree = |degree: usize| { - irs_commit::num_ood_samples( - unique_decoding, - security_target, - field_bits, - combined_list_size, - degree, - ) - }; + #[allow(clippy::cast_precision_loss)] + let log_degree_minus_1 = ((degree - 1) as f64).log2(); + let l_choose_2 = combined_list * (combined_list - 1.0) / 2.0; + #[allow(clippy::cast_precision_loss)] + let ood_term = (t_ood as f64) * (field_bits - log_degree_minus_1) - l_choose_2.log2(); - if !matches!(spec.mode, Mode::ZeroKnowledge) { - return solve_for_degree(message_length); - } + // Combination term: counts OOD samples plus the in-domain batch + // (t source queries, each contributing one column of the ι-interleaved + // source codeword to the geometric_challenge RLC). + let count = t_ood + source.in_domain_samples * source.interleaving_depth; + #[allow(clippy::cast_precision_loss)] + let log_count = (count as f64).log2(); + let log_target_list = target_list.log2(); + let log_c_zk_list = mask_oracle.map_or(0.0, |info| info.c_zk_list_size.log2()); + let combination_term = field_bits - log_count - log_target_list - log_c_zk_list; - // ZK: t_ood = f(ℓ + r + t_ood); iterate. - let mut t_ood = 0; - for _ in 0..MAX_ITER { - let new_t_ood = solve_for_degree(message_length + source_mask_length + t_ood); - if new_t_ood == t_ood { - return t_ood; - } - t_ood = new_t_ood; - } - panic!("compute_t_ood did not converge in {MAX_ITER} iterations"); + Bits::new(ood_term.min(combination_term).max(0.0)) } #[cfg(test)] @@ -114,8 +97,8 @@ mod tests { use super::*; use crate::protocols::params::{ irs_commit as params_irs, - plan::RoundModeParams, - spec::{LogInvRate, OodSampleBudget, RoundContext}, + planner::{compute_l_zk, compute_t_ood}, + spec::{LogInvRate, Mode, OodSampleBudget, RoundContext, SecuritySpec}, test_utils::{ arb_standard_johnson_spec as utils_standard_spec, arb_zk_spec as utils_zk_spec, deterministic_spec, TestEmbedding, TestExtensionField, TestNonIdentityEmbedding, @@ -124,17 +107,21 @@ mod tests { type M = TestEmbedding; + // Keeps `target − error ≤ 60`, the cap `proof_of_work::threshold` enforces. + // On Field64 the γ-RLC combination term sits at ~0 bits in ZK and ~30 bits + // in Standard, so the gap to target must stay under 60. + const TEST_TARGET_RANGE: std::ops::RangeInclusive = 30..=50; + fn arb_zk_spec() -> impl Strategy> { - utils_zk_spec(80..=128) + utils_zk_spec(TEST_TARGET_RANGE) } fn arb_standard_johnson_spec() -> impl Strategy> { - utils_standard_spec(80..=128) + utils_standard_spec(TEST_TARGET_RANGE) } - /// Orchestrator-style: build source + target IRS and the matching `t_ood`. - /// Iterates target until its `codeword_length` stabilizes (target's realized - /// rate depends on its mask budget, which depends on `t_ood`). + /// Iterates target until `codeword_length` stabilizes — its realized rate + /// depends on `mask_length`, which depends on `t_ood`. fn build_inputs( spec: &SecuritySpec, log_inv_rate: u32, @@ -147,8 +134,6 @@ mod tests { vector_size: 1usize << num_vars, log_inv_rate, folding_factor, - prev_round_in_domain_samples: 0, - prev_round_query_error: 0.0, }; let source = params_irs::solve(spec, &source_ctx, OodSampleBudget::new(0)); @@ -157,8 +142,6 @@ mod tests { vector_size: source.message_length(), log_inv_rate: log_inv_rate + folding_factor - 1, folding_factor, - prev_round_in_domain_samples: source.in_domain_samples, - prev_round_query_error: 0.0, }; let mut target = params_irs::solve(spec, &target_ctx, OodSampleBudget::new(0)); @@ -173,8 +156,7 @@ mod tests { panic!("target IRS did not stabilize"); } - /// `num_vars ≥ 2 * folding_factor` so target's `vector_size` stays divisible - /// by target's `interleaving_depth = 1 << folding_factor`. + /// `num_vars ≥ 2 * folding_factor` keeps target IRS valid. fn arb_dims() -> impl Strategy { (1u32..=3, 1u32..=2).prop_flat_map(|(log_inv_rate, folding_factor)| { let min_num_vars = 2 * folding_factor; @@ -187,7 +169,6 @@ mod tests { } proptest! { - /// Standard mode: `Config::new` assertions pass, `t_ood ≥ 1` in Johnson. #[test] fn solve_standard_assembles( spec in arb_standard_johnson_spec(), @@ -195,26 +176,23 @@ mod tests { ) { let (source, target, t_ood) = build_inputs(&spec, log_inv_rate, folding_factor, num_vars, None); - let config = solve(source, target, t_ood, &RoundModeParams::Standard); + let config = solve(&spec, source, target, t_ood, None); prop_assert!(matches!(config.mode, code_switch::Mode::Standard)); prop_assert!(config.out_domain_samples >= 1); } - /// ZK mode: `ℓ_zk = next_power_of_two(r + t_ood)` shared with C_zk. + /// ZK: `ℓ_zk = next_power_of_two(r + t_ood)`. #[test] fn solve_zk_mask_equals_padded_r_plus_t_ood( spec in arb_zk_spec(), (log_inv_rate, folding_factor, num_vars) in arb_dims(), ) { - // Bootstrap C_zk with a placeholder t_ood to break the - // t_ood ↔ c_zk.list_size circular dependency. + // Break the t_ood ↔ c_zk.list_size cycle with a placeholder C_zk. let placeholder_source_ctx = RoundContext { round_index: 0, vector_size: 1usize << num_vars, log_inv_rate, folding_factor, - prev_round_in_domain_samples: 0, - prev_round_query_error: 0.0, }; let placeholder_source = params_irs::solve( &spec, @@ -240,20 +218,17 @@ mod tests { LogInvRate::new(log_inv_rate), 2, ); - // Fixed-point check: t_ood was computed with the placeholder - // C_zk's list_size; the final C_zk's list_size must agree. let recomputed_t_ood = compute_t_ood(&spec, &source, target.list_size(), Some(c_zk.list_size())); - prop_assert_eq!( - t_ood, recomputed_t_ood, - "t_ood computed with placeholder C_zk must equal t_ood with final C_zk", - ); - let zk = RoundModeParams::ZeroKnowledge { c_zk, l_zk }; - let config = solve(source, target, t_ood, &zk); + prop_assert_eq!(t_ood, recomputed_t_ood, "placeholder ⇒ final C_zk fixed-point"); + let mask_oracle = MaskOracleInfo { + c_zk_list_size: c_zk.list_size(), + l_zk, + }; + let config = solve(&spec, source, target, t_ood, Some(mask_oracle)); prop_assert_eq!(config.message_mask_length(), (r + t_ood).next_power_of_two()); } - /// `compute_t_ood` converges and returns `t_ood ≥ 1` in Johnson regime. #[test] fn compute_t_ood_converges( spec in arb_zk_spec(), @@ -263,32 +238,44 @@ mod tests { build_inputs(&spec, log_inv_rate, folding_factor, num_vars, None); prop_assert!(t_ood >= 1); } + + /// `analytic_error + pow ≥ target` (Lemma 9.9 OOD term). + #[test] + fn pow_closes_gap_to_target_standard( + spec in arb_standard_johnson_spec(), + (log_inv_rate, folding_factor, num_vars) in arb_dims(), + ) { + let (source, target, t_ood) = + build_inputs(&spec, log_inv_rate, folding_factor, num_vars, None); + let config = solve(&spec, source.clone(), target.clone(), t_ood, None); + let error = f64::from(analytic_error_bits(&source, &target, t_ood, None)); + let pow_bits = f64::from(config.pow.difficulty()); + prop_assert!( + error + pow_bits >= f64::from(spec.target_security_bits) - 1e-3, + "error {} + pow {} < target {}", + error, pow_bits, spec.target_security_bits, + ); + } } - /// Build the canonical `(source_ctx, target_ctx)` pair used by both - /// non-identity smoke tests. Single source of truth so Standard and ZK - /// exercise the same problem shape, only the mode differs. + /// Shared shape so Standard and ZK smoke tests differ only in mode. fn non_identity_smoke_ctxs() -> (RoundContext, RoundContext) { let source_ctx = RoundContext { round_index: 0, vector_size: 64, log_inv_rate: 1, folding_factor: 2, - prev_round_in_domain_samples: 0, - prev_round_query_error: 0.0, }; let target_ctx = RoundContext { round_index: 1, vector_size: source_ctx.vector_size / (1 << source_ctx.folding_factor), log_inv_rate: source_ctx.log_inv_rate + source_ctx.folding_factor - 1, folding_factor: source_ctx.folding_factor, - prev_round_in_domain_samples: 0, - prev_round_query_error: 0.0, }; (source_ctx, target_ctx) } - /// Standard-mode smoke test with `M::Source ≠ M::Target`. + /// Smoke test: `M::Source ≠ M::Target`, Standard mode. #[test] fn solve_works_with_basefield_embedding_standard() { let spec_source: SecuritySpec = @@ -302,18 +289,15 @@ mod tests { let (source_ctx, target_ctx) = non_identity_smoke_ctxs(); let source = params_irs::solve(&spec_source, &source_ctx, OodSampleBudget::new(0)); - // Standard target: codeword_length is independent of t_ood (mask = 0), - // so one solve is sufficient. + // Standard target: codeword_length is t_ood-independent (mask = 0). let target = params_irs::solve(&spec_target, &target_ctx, OodSampleBudget::new(0)); let t_ood = compute_t_ood(&spec_source, &source, target.list_size(), None); - let config = solve(source, target, t_ood, &RoundModeParams::Standard); + let config = solve(&spec_source, source, target, t_ood, None); assert!(matches!(config.mode, code_switch::Mode::Standard)); } - /// ZK-mode smoke test with `M::Source ≠ M::Target`. Exercises the - /// `RoundModeParams::ZeroKnowledge { c_zk, l_zk }` type path with - /// `c_zk: irs_commit::Config>`. + /// Smoke test: `M::Source ≠ M::Target`, ZK mode with shared C_zk. #[test] fn solve_works_with_basefield_embedding_zk() { let spec_source: SecuritySpec = @@ -323,7 +307,7 @@ mod tests { let (source_ctx, target_ctx) = non_identity_smoke_ctxs(); let source = params_irs::solve(&spec_source, &source_ctx, OodSampleBudget::new(0)); - // Bootstrap C_zk's list_size with a placeholder ℓ_zk. + // Placeholder ℓ_zk to bootstrap c_zk.list_size. let c_zk_placeholder = params_irs::solve_mask_code( &spec_target, compute_l_zk(&source, 1), @@ -332,7 +316,6 @@ mod tests { 2, ); let c_zk_list_size = c_zk_placeholder.list_size(); - // Two-solve target rebuild: placeholder t_ood, then final. let target_placeholder = params_irs::solve(&spec_target, &target_ctx, OodSampleBudget::new(0)); let t_ood = compute_t_ood( @@ -348,10 +331,7 @@ mod tests { target.list_size(), Some(c_zk_list_size), ); - assert_eq!( - t_ood, t_ood_check, - "smoke-test params should converge in one iteration", - ); + assert_eq!(t_ood, t_ood_check, "fixed-point in one iteration"); let l_zk = compute_l_zk(&source, t_ood); let c_zk = params_irs::solve_mask_code( @@ -361,8 +341,11 @@ mod tests { LogInvRate::new(1), 2, ); - let zk = RoundModeParams::ZeroKnowledge { c_zk, l_zk }; - let config = solve(source, target, t_ood, &zk); + let mask_oracle = MaskOracleInfo { + c_zk_list_size: c_zk.list_size(), + l_zk, + }; + let config = solve(&spec_source, source, target, t_ood, Some(mask_oracle)); assert!(matches!( config.mode, code_switch::Mode::ZeroKnowledge { .. } diff --git a/src/protocols/params/irs_commit.rs b/src/protocols/params/irs_commit.rs index fd7be808..e3e7d84b 100644 --- a/src/protocols/params/irs_commit.rs +++ b/src/protocols/params/irs_commit.rs @@ -1,4 +1,5 @@ -//! Parameter selection for the IRS commit protocol. +//! IRS-commit parameter selection. ZK mask sized per Lemma 9.5, padded so +//! `message + mask` is a pow2 (NTT-valid codeword length). use std::num::NonZeroUsize; @@ -12,8 +13,6 @@ use crate::{ }, }; -/// Solve per-round IRS-commit parameters. ZK mask sized per Lemma 9.5, -/// padded so `message + mask` is a power of 2 (NTT-valid codeword length). pub fn solve( spec: &SecuritySpec, ctx: &RoundContext, @@ -31,7 +30,7 @@ pub fn solve( let min_mask = num_in_domain_queries(unique_decoding, security_target, rate) .checked_add(out_domain_samples.get()) .expect("usize overflow"); - // Pad to pow2: Lemma 9.5 is `≥` so over-allocating is safe. + // Lemma 9.5 is `≥`, so pow2 padding is safe. let mask_length = message_length .checked_add(min_mask.get()) .expect("usize overflow") @@ -47,8 +46,7 @@ pub fn solve( security_target, unique_decoding, spec.hash_id, - // num_vectors: orchestrator commits one vector per round. - 1, + 1, // one vector committed per round ctx.vector_size, interleaving_depth, rate, @@ -56,12 +54,11 @@ pub fn solve( ) } -/// Solve the shared C_zk IRS config for committing mask polynomials. +/// Shared C_zk IRS config for mask polynomials. /// -/// - `l_zk` — message length. Must be a power of 2 (caller pads it; see assert). -/// - `source_mask_length` — `r`, the source IRS mask length (Theorem 9.6). -/// - `log_inv_rate` — C_zk rate. -/// - `num_vectors` — total masks per commit; `2 * num_masks` for mask-proximity. +/// - `l_zk`: message length, must be a power of 2. +/// - `source_mask_length`: `r` from Theorem 9.6. +/// - `num_vectors`: `2 * num_masks` (Construction 7.2: originals + fresh). pub fn solve_mask_code( spec: &SecuritySpec, l_zk: MaskCodeMessageLen, @@ -89,8 +86,7 @@ pub fn solve_mask_code( irs_commit::Config::new( security_target, - // ZK ⇒ Johnson regime. - false, + false, // ZK ⇒ Johnson regime spec.hash_id, num_vectors, l_zk, @@ -151,8 +147,7 @@ mod tests { arb_zk_spec(80..=128) } - /// IRS-specific: vary `unique_decoding` to exercise both regimes inside - /// `irs_commit::Config::new`. + /// Varies `unique_decoding` to exercise both regimes. fn arb_standard_spec() -> impl Strategy> { any::() .prop_flat_map(|unique_decoding| arb_spec(Mode::Standard { unique_decoding }, 80..=128)) @@ -178,7 +173,7 @@ mod tests { } proptest! { - /// Lemma 9.5: ZK mask covers all revealed evaluations. + /// Lemma 9.5: mask covers all revealed evaluations. #[test] fn zk_mask_covers_lemma_9_5( spec in arb_zk_spec_default(), @@ -193,7 +188,6 @@ mod tests { ); } - /// Standard mode produces no IRS randomness regardless of input. #[test] fn standard_has_no_mask( spec in arb_standard_spec(), @@ -204,7 +198,6 @@ mod tests { prop_assert_eq!(config.mask_length(), 0); } - /// ZK round-trip + witness shape check. #[test] fn zk_round_trips( spec in arb_zk_spec_default(), @@ -213,16 +206,11 @@ mod tests { seed: u64, ) { let config = solve(&spec, &ctx, OodSampleBudget::new(out_domain)); - prop_assert!(config.mask_length() > 0, "ZK mode must produce non-zero mask"); + prop_assert!(config.mask_length() > 0); let witness = commit_open_verify(&config, seed); - prop_assert_eq!( - witness.masks.len(), - config.mask_length() * config.num_messages(), - "witness mask vector size", - ); + prop_assert_eq!(witness.masks.len(), config.mask_length() * config.num_messages()); } - /// Standard round-trip + empty-mask check. #[test] fn standard_round_trips( spec in arb_standard_spec(), @@ -232,7 +220,7 @@ mod tests { let config = solve(&spec, &ctx, OodSampleBudget::new(0)); prop_assert_eq!(config.mask_length(), 0); let witness = commit_open_verify(&config, seed); - prop_assert!(witness.masks.is_empty(), "Standard mode must produce no masks"); + prop_assert!(witness.masks.is_empty()); } } } diff --git a/src/protocols/params/mask_proximity.rs b/src/protocols/params/mask_proximity.rs index a166b9e1..48f671f9 100644 --- a/src/protocols/params/mask_proximity.rs +++ b/src/protocols/params/mask_proximity.rs @@ -1,27 +1,40 @@ -//! Parameter selection for the mask-proximity protocol (Construction 7.2). -//! -//! Mask-proximity spot-checks each committed mask oracle against C_zk via -//! γ-combination (Lemma 7.4). ZK-only — Standard mode never invokes it. +//! Mask-proximity (Construction 7.2) builder + Lemma 7.4 γ-combination bound. +//! ZK-only. use ark_ff::Field; use crate::{ - algebra::embedding::Identity, - protocols::{irs_commit, mask_proximity}, + algebra::{embedding::Identity, fields::FieldWithSize}, + bits::Bits, + protocols::{ + irs_commit::Config as IrsConfig, mask_proximity, params::spec::SecuritySpec, proof_of_work, + }, }; -/// Assemble a [`mask_proximity::Config`] from the shared C_zk IRS config and -/// the number of mask polynomials the protocol commits to. -/// -/// `c_zk` must be sized with `num_vectors == 2 * num_masks` (Construction 7.2 -/// commits originals and their fresh mask-of-masks side by side in the shared -/// tree). The orchestrator obtains `c_zk` via `irs_commit::solve_mask_code` -/// with that same `num_vectors`. +/// `c_zk.num_vectors` must equal `2 * num_masks` (originals + fresh). +/// PoW closes the Lemma 7.4 γ-combination gap to `spec.target_security_bits`. pub fn solve( - c_zk: irs_commit::Config>, + spec: &SecuritySpec>, + c_zk: IrsConfig>, num_masks: usize, ) -> mask_proximity::Config { - mask_proximity::Config::new(c_zk, num_masks) + let target_bits = Bits::new(f64::from(spec.target_security_bits)); + let analytic = analytic_error_bits(&c_zk, num_masks); + let pow = proof_of_work::Config::grind_to(target_bits, analytic, spec.hash_id); + mask_proximity::Config::new(c_zk, num_masks, pow) +} + +/// γ-combination soundness (Lemma 7.4): +/// `log|F| − log(num_masks · (deg − 1))`, with `deg = c_zk.masked_message_length()`. +pub fn analytic_error_bits(c_zk: &IrsConfig>, num_masks: usize) -> Bits { + let field_bits = F::field_size_bits(); + let deg = c_zk.masked_message_length(); + if deg <= 1 || num_masks == 0 { + return Bits::new(field_bits.max(0.0)); + } + #[allow(clippy::cast_precision_loss)] + let log_combined = ((num_masks * (deg - 1)) as f64).log2(); + Bits::new((field_bits - log_combined).max(0.0)) } #[cfg(test)] @@ -42,12 +55,13 @@ mod tests { }, }; + // Keeps `target − error ≤ 60`, the cap `proof_of_work::threshold` enforces. + const TEST_TARGET_RANGE: std::ops::RangeInclusive = 30..=50; + proptest! { - /// `solve` produces a Config satisfying `mask_proximity::Config::new`'s - /// invariants (`num_vectors == 2 * num_masks`, `interleaving_depth == 1`). #[test] fn solve_assembles( - spec in arb_zk_spec(80..=128), + spec in arb_zk_spec(TEST_TARGET_RANGE), log_inv_rate in 1u32..=3, num_masks in 1usize..=8, l_zk_log in 1u32..=5, @@ -60,19 +74,43 @@ mod tests { LogInvRate::new(log_inv_rate), 2 * num_masks, ); - let config = solve(c_zk, num_masks); + let config = solve(&spec, c_zk, num_masks); prop_assert_eq!(config.num_masks, num_masks); prop_assert_eq!(config.c_zk_commit.num_vectors, 2 * num_masks); prop_assert_eq!(config.c_zk_commit.interleaving_depth, 1); } + + /// `analytic_error + pow ≥ target` (Lemma 7.4 γ-combination). + #[test] + fn pow_closes_gap_to_target( + spec in arb_zk_spec(TEST_TARGET_RANGE), + log_inv_rate in 1u32..=3, + num_masks in 1usize..=8, + l_zk_log in 1u32..=5, + ) { + let l_zk = MaskCodeMessageLen::new(1usize << l_zk_log); + let c_zk = params_irs::solve_mask_code( + &spec, + l_zk, + 0, + LogInvRate::new(log_inv_rate), + 2 * num_masks, + ); + let analytic = f64::from(analytic_error_bits(&c_zk, num_masks)); + let config = solve(&spec, c_zk, num_masks); + let pow_bits = f64::from(config.pow.difficulty()); + prop_assert!( + analytic + pow_bits >= f64::from(spec.target_security_bits) - 1e-3, + "analytic {} + pow {} < target {}", + analytic, pow_bits, spec.target_security_bits, + ); + } } - /// `mask_proximity::Config::new` rejects `c_zk.num_vectors != 2 * num_masks`. #[test] #[should_panic(expected = "c_zk.num_vectors must be 2 * num_masks")] fn solve_rejects_mismatched_num_vectors() { let spec = deterministic_spec::(Mode::ZeroKnowledge); - // c_zk built for 2 masks (num_vectors = 4); caller passes num_masks = 3. let c_zk = params_irs::solve_mask_code( &spec, MaskCodeMessageLen::new(2), @@ -80,22 +118,23 @@ mod tests { LogInvRate::new(1), 4, ); - let _ = solve(c_zk, 3); + let _ = solve(&spec, c_zk, 3); } #[test] #[should_panic(expected = "interleaving_depth = 1")] fn solve_rejects_non_unit_interleaving() { + let spec = deterministic_spec::(Mode::ZeroKnowledge); let c_zk = crate::protocols::irs_commit::Config::>::new( 80.0, false, hash::BLAKE3, - 2, // num_vectors = 2 (= 2 * num_masks with num_masks=1) - 8, // vector_size - 2, // interleaving_depth ≠ 1 + 2, + 8, + 2, // interleaving_depth ≠ 1 — triggers the panic 0.5, IrsMode::Standard, ); - let _ = solve(c_zk, 1); + let _ = solve(&spec, c_zk, 1); } } diff --git a/src/protocols/params/mod.rs b/src/protocols/params/mod.rs index 59391882..4a56c6e6 100644 --- a/src/protocols/params/mod.rs +++ b/src/protocols/params/mod.rs @@ -1,10 +1,10 @@ -// This module contains the parameter selection and security target logic. - +pub mod basecase; pub(crate) mod bounds; pub mod code_switch; pub mod irs_commit; pub mod mask_proximity; pub mod plan; +pub mod planner; pub mod spec; pub mod sumcheck; diff --git a/src/protocols/params/plan.rs b/src/protocols/params/plan.rs index ff7ed1df..21b8ed48 100644 --- a/src/protocols/params/plan.rs +++ b/src/protocols/params/plan.rs @@ -1,48 +1,158 @@ -//! Derived parameter plan for the Construction 9.7 ZK protocol. +//! Output shape of the planner. //! -//! Built by the orchestrator from a [`SecuritySpec`] + [`TuningSpec`]; owns -//! the cross-protocol resolved values (source/target IRS, C_zk, t_ood, ℓ_zk, -//! per-round sub-protocol configs) so downstream code doesn't coordinate them. +//! C_zk and ℓ_zk are protocol-global (one shared Merkle tree across all +//! rounds) and live in [`SharedPlan`]; per-round sumcheck + code-switch live +//! in [`RoundPlan`]. Source/target IRS configs are accessed via +//! `round.code_switch` — not duplicated at the round level. + +use ark_ff::Field; use crate::{ algebra::embedding::{Embedding, Identity}, + bits::Bits, protocols::{ - code_switch, irs_commit, - params::spec::{MaskCodeMessageLen, SecuritySpec, TuningSpec}, + basecase, code_switch, irs_commit, mask_proximity, + params::{ + basecase as basecase_solver, + bounds::SoundnessBounded, + code_switch as code_switch_solver, mask_proximity as mask_proximity_solver, + spec::{MaskCodeMessageLen, OodSampleBudget, SecuritySpec, TuningSpec}, + sumcheck as sumcheck_solver, + }, sumcheck, }, }; -/// Full derived parameter plan for one protocol run. #[derive(Clone, Debug)] pub struct ParameterPlan { pub security: SecuritySpec, pub tuning: TuningSpec, - pub rounds: Vec>, + pub shared: SharedPlan, + pub rounds: Vec>, + pub basecase: basecase::Config, +} + +impl SoundnessBounded for ParameterPlan { + fn analytic_bits(&self) -> Bits { + let mut min_bits = f64::INFINITY; + for round in &self.rounds { + min_bits = min_bits.min(f64::from(round.analytic_bits())); + } + if let Some(mo) = &self.shared.mask_oracle { + min_bits = min_bits.min(f64::from(mo.analytic_bits())); + } + // Basecase sumcheck per-round bound applies in both modes; the γ-slot + // only contributes in ZK. + min_bits = min_bits.min(f64::from(sumcheck_solver::analytic_error_bits( + &self.basecase.commit, + None, + ))); + if matches!(self.basecase.mode, basecase::Mode::ZeroKnowledge) { + min_bits = min_bits.min(f64::from(basecase_solver::analytic_error_bits( + &self.basecase.commit, + ))); + } + if min_bits.is_infinite() { + return Bits::new(f64::from(self.security.target_security_bits)); + } + Bits::new(min_bits.max(0.0)) + } } -/// Parameters for a single round (sumcheck + code-switch). #[derive(Clone, Debug)] -pub struct RoundParams { +pub struct RoundPlan { pub round_index: usize, - pub source_irs: irs_commit::Config, - pub target_irs: irs_commit::Config>, pub sumcheck: sumcheck::Config, pub code_switch: code_switch::Config, - pub zk: RoundModeParams, + pub mode: RoundMode, } -#[derive(Clone, Debug)] -pub enum RoundModeParams { +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum RoundMode { Standard, ZeroKnowledge { - c_zk: irs_commit::Config>, - l_zk: MaskCodeMessageLen, + /// Bound 2 / Lemma 9.9. + t_ood: OodSampleBudget, + /// Cached view of the shared mask oracle (denormalized from + /// [`MaskOraclePlan`]) so each round is self-contained for soundness. + mask_oracle: MaskOracleInfo, }, } -impl RoundModeParams { +impl RoundMode { pub const fn is_zk(&self) -> bool { matches!(self, Self::ZeroKnowledge { .. }) } + + pub const fn mask_oracle(&self) -> Option { + match self { + Self::Standard => None, + Self::ZeroKnowledge { mask_oracle, .. } => Some(*mask_oracle), + } + } +} + +impl SoundnessBounded for RoundPlan { + fn analytic_bits(&self) -> Bits { + let source = &self.code_switch.source; + let target = &self.code_switch.target; + let mask_oracle = self.mode.mask_oracle(); + + let sumcheck_term = sumcheck_solver::analytic_error_bits(source, mask_oracle); + let code_switch_term = code_switch_solver::analytic_error_bits( + source, + target, + self.code_switch.out_domain_samples, + mask_oracle, + ); + + if f64::from(code_switch_term) < f64::from(sumcheck_term) { + code_switch_term + } else { + sumcheck_term + } + } +} + +#[derive(Clone, Debug)] +pub struct SharedPlan { + /// `Some` iff `Mode::ZeroKnowledge`. + pub mask_oracle: Option>, +} + +/// One C_zk codeword + one shared Merkle tree + one mask-proximity check, +/// covering every mask committed across all rounds. +#[derive(Clone, Debug)] +pub struct MaskOraclePlan { + /// `num_vectors = 2 * total_masks` (Construction 7.2: originals + fresh). + pub c_zk: irs_commit::Config>, + /// Dominates every round's `r + t_ood` (Lemma 9.3). + pub l_zk: MaskCodeMessageLen, + pub mask_proximity: mask_proximity::Config, +} + +/// Slim mask-oracle view (C_zk's list size + ℓ_zk) for builders that don't +/// need the full config. +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct MaskOracleInfo { + pub c_zk_list_size: f64, + pub l_zk: MaskCodeMessageLen, +} + +impl MaskOraclePlan { + pub fn info(&self) -> MaskOracleInfo { + MaskOracleInfo { + c_zk_list_size: self.c_zk.list_size(), + l_zk: self.l_zk, + } + } +} + +impl SoundnessBounded for MaskOraclePlan { + fn analytic_bits(&self) -> Bits { + mask_proximity_solver::analytic_error_bits( + &self.mask_proximity.c_zk_commit, + self.mask_proximity.num_masks, + ) + } } diff --git a/src/protocols/params/planner.rs b/src/protocols/params/planner.rs new file mode 100644 index 00000000..ed55a916 --- /dev/null +++ b/src/protocols/params/planner.rs @@ -0,0 +1,601 @@ +//! Derives a [`ParameterPlan`] from a spec + tuning. All cross-protocol +//! coordination — per-round loop, `t_ood ↔ r` and `ℓ_zk ↔ c_zk` fixed-points, +//! shared C_zk + mask-proximity — lives here. + +use std::marker::PhantomData; + +use crate::{ + algebra::{ + embedding::{Embedding, Identity}, + fields::FieldWithSize, + }, + protocols::{ + irs_commit::{self, Config as IrsConfig}, + params::{ + basecase as bc_solver, code_switch as cs_solver, irs_commit as irs_solver, + mask_proximity as mp_solver, + plan::{ + MaskOracleInfo, MaskOraclePlan, ParameterPlan, RoundMode, RoundPlan, SharedPlan, + }, + spec::{ + LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, RoundContext, SecuritySpec, + TuningSpec, + }, + sumcheck as sc_solver, + }, + }, +}; + +const L_ZK_MAX_ITER: usize = 16; +/// Smallest pow2 ≥ 1 — satisfies `solve_mask_code`'s pow2 assertion. +const L_ZK_BOOTSTRAP: usize = 2; + +impl ParameterPlan { + /// ZK mode runs a global ℓ_zk fixed-point so one C_zk covers every round. + pub fn derive(spec: SecuritySpec, tuning: TuningSpec) -> Self { + let RoundLayout { + shapes, + basecase_vector_size, + basecase_log_inv_rate, + } = round_layout(&tuning); + match spec.mode { + Mode::Standard { .. } => derive_standard( + spec, + tuning, + &shapes, + basecase_vector_size, + basecase_log_inv_rate, + ), + Mode::ZeroKnowledge => derive_zk( + spec, + tuning, + &shapes, + basecase_vector_size, + basecase_log_inv_rate, + ), + } + } +} + +// Round layout +// --------------------------------------------------------------------------- + +/// `target_folding_factor` is the next round's source folding — uniform +/// `tuning.folding_factor` — so `target_r → source_{r+1}` has matching +/// interleaving. +#[derive(Debug, Clone, Copy)] +struct RoundShape { + round_index: usize, + source_vector_size: usize, + source_log_inv_rate: u32, + source_folding_factor: u32, + target_folding_factor: u32, +} + +/// Shapes plus the basecase tail (size + rate of the message after every fold). +struct RoundLayout { + shapes: Vec, + basecase_vector_size: usize, + basecase_log_inv_rate: u32, +} + +struct RoundData { + source: IrsConfig, + target: IrsConfig>, + t_ood: usize, +} + +/// Stops when there's no room for both a valid source and a valid target IRS. +fn round_layout(tuning: &TuningSpec) -> RoundLayout { + assert!(tuning.vector_size.is_power_of_two()); + assert!(tuning.folding_factor >= 1); + assert!(tuning.initial_folding_factor >= 1); + + let mut num_vars = tuning.vector_size.trailing_zeros() as usize; + let mut log_inv_rate = tuning.starting_log_inv_rate; + let mut source_folding = tuning.initial_folding_factor; + let target_folding = tuning.folding_factor; + let mut shapes = Vec::new(); + + while num_vars >= source_folding + target_folding { + #[allow(clippy::cast_possible_truncation)] + shapes.push(RoundShape { + round_index: shapes.len(), + source_vector_size: 1usize << num_vars, + source_log_inv_rate: log_inv_rate, + source_folding_factor: source_folding as u32, + target_folding_factor: target_folding as u32, + }); + num_vars -= source_folding; + #[allow(clippy::cast_possible_truncation)] + { + log_inv_rate += (source_folding as u32).saturating_sub(1); + } + source_folding = target_folding; + } + + RoundLayout { + shapes, + basecase_vector_size: 1usize << num_vars, + basecase_log_inv_rate: log_inv_rate, + } +} + +const fn round_context(shape: &RoundShape) -> RoundContext { + RoundContext { + round_index: shape.round_index, + vector_size: shape.source_vector_size, + log_inv_rate: shape.source_log_inv_rate, + folding_factor: shape.source_folding_factor, + } +} + +fn target_context(shape: &RoundShape, source: &IrsConfig) -> RoundContext { + RoundContext { + round_index: shape.round_index, + vector_size: source.message_length(), + log_inv_rate: shape.source_log_inv_rate + shape.source_folding_factor.saturating_sub(1), + folding_factor: shape.target_folding_factor, + } +} + +// Standard mode +// --------------------------------------------------------------------------- + +fn derive_standard( + spec: SecuritySpec, + tuning: TuningSpec, + shapes: &[RoundShape], + basecase_vector_size: usize, + basecase_log_inv_rate: u32, +) -> ParameterPlan { + let target_spec = transfer_spec_to_target(&spec); + let rounds = shapes + .iter() + .map(|shape| build_round(&spec, shape, None)) + .collect(); + let basecase = bc_solver::solve(&target_spec, basecase_vector_size, basecase_log_inv_rate); + ParameterPlan { + security: spec, + tuning, + shared: SharedPlan { mask_oracle: None }, + rounds, + basecase, + } +} + +// Zero-knowledge mode — global ℓ_zk fixed-point + shared C_zk +// --------------------------------------------------------------------------- + +fn derive_zk( + spec: SecuritySpec, + tuning: TuningSpec, + shapes: &[RoundShape], + basecase_vector_size: usize, + basecase_log_inv_rate: u32, +) -> ParameterPlan { + let target_spec: SecuritySpec> = transfer_spec_to_target(&spec); + let c_zk_log_inv_rate = LogInvRate::new(tuning.starting_log_inv_rate); + + // Lemma 6.4: one mask polynomial per sumcheck round. C_zk holds 2×. + let total_masks: usize = shapes + .iter() + .map(|s| s.source_folding_factor as usize) + .sum(); + assert!(total_masks > 0, "ZK requires ≥ 1 mask polynomial"); + let c_zk_num_vectors = 2 * total_masks; + + let mut l_zk = MaskCodeMessageLen::new(L_ZK_BOOTSTRAP); + let mut c_zk = + irs_solver::solve_mask_code(&target_spec, l_zk, 0, c_zk_log_inv_rate, c_zk_num_vectors); + + let mut last_round_data: Vec> = Vec::new(); + + for _ in 0..L_ZK_MAX_ITER { + let round_data: Vec> = shapes + .iter() + .map(|shape| build_zk_round_data(&spec, shape, c_zk.list_size())) + .collect(); + + let max_r_plus_t_ood = round_data + .iter() + .map(|r| r.source.mask_length() + r.t_ood) + .max() + .expect("non-empty rounds"); + let new_l_zk = MaskCodeMessageLen::new(max_r_plus_t_ood.next_power_of_two()); + + if new_l_zk.get() == l_zk.get() { + last_round_data = round_data; + break; + } + + l_zk = new_l_zk; + // Solve_mask_code asserts `ℓ_zk ≥ r`; pass the max so it always holds. + let max_source_mask = round_data + .iter() + .map(|r| r.source.mask_length()) + .max() + .unwrap_or(0); + c_zk = irs_solver::solve_mask_code( + &target_spec, + l_zk, + max_source_mask, + c_zk_log_inv_rate, + c_zk_num_vectors, + ); + last_round_data = round_data; + } + + let mask_oracle_info = MaskOracleInfo { + c_zk_list_size: c_zk.list_size(), + l_zk, + }; + + let rounds = shapes + .iter() + .zip(last_round_data) + .map(|(shape, data)| finalize_zk_round(&spec, shape, data, mask_oracle_info)) + .collect(); + + let mask_proximity = mp_solver::solve(&target_spec, c_zk.clone(), total_masks); + let mask_oracle = MaskOraclePlan { + c_zk, + l_zk, + mask_proximity, + }; + + let basecase = bc_solver::solve(&target_spec, basecase_vector_size, basecase_log_inv_rate); + + ParameterPlan { + security: spec, + tuning, + shared: SharedPlan { + mask_oracle: Some(mask_oracle), + }, + rounds, + basecase, + } +} + +/// Local fixed point: `source.mask_length` covers `t_ood` queries; `t_ood` is +/// sized against `source.message + source.mask`. +fn build_zk_round_data( + spec: &SecuritySpec, + shape: &RoundShape, + c_zk_list_size: f64, +) -> RoundData { + const LOCAL_MAX_ITER: usize = 16; + + let src_ctx = round_context(shape); + let mut source = irs_solver::solve(spec, &src_ctx, OodSampleBudget::new(0)); + let mut t_ood = 0; + let mut target = irs_solver::solve( + &transfer_spec_to_target(spec), + &target_context(shape, &source), + OodSampleBudget::new(0), + ); + + for _ in 0..LOCAL_MAX_ITER { + let new_t_ood = compute_t_ood(spec, &source, target.list_size(), Some(c_zk_list_size)); + let new_source = irs_solver::solve(spec, &src_ctx, OodSampleBudget::new(new_t_ood)); + let new_target = irs_solver::solve( + &transfer_spec_to_target(spec), + &target_context(shape, &new_source), + OodSampleBudget::new(new_t_ood), + ); + + if new_t_ood == t_ood + && new_source.codeword_length == source.codeword_length + && new_target.codeword_length == target.codeword_length + { + return RoundData { + source: new_source, + target: new_target, + t_ood: new_t_ood, + }; + } + + source = new_source; + target = new_target; + t_ood = new_t_ood; + } + + panic!("per-round ZK fixed-point did not converge"); +} + +fn finalize_zk_round( + spec: &SecuritySpec, + shape: &RoundShape, + data: RoundData, + mask_oracle: MaskOracleInfo, +) -> RoundPlan { + let RoundData { + source, + target, + t_ood, + } = data; + let src_ctx = round_context(shape); + let sumcheck = sc_solver::solve(spec, &src_ctx, &source, Some(mask_oracle)); + let code_switch = cs_solver::solve(spec, source, target, t_ood, Some(mask_oracle)); + RoundPlan { + round_index: shape.round_index, + sumcheck, + code_switch, + mode: RoundMode::ZeroKnowledge { + t_ood: OodSampleBudget::new(t_ood), + mask_oracle, + }, + } +} + +// Standard mode per-round builder +// --------------------------------------------------------------------------- + +fn build_round( + spec: &SecuritySpec, + shape: &RoundShape, + mask_oracle: Option, +) -> RoundPlan { + debug_assert!(mask_oracle.is_none(), "ZK path uses finalize_zk_round"); + + let target_spec = transfer_spec_to_target(spec); + let src_ctx = round_context(shape); + let source = irs_solver::solve(spec, &src_ctx, OodSampleBudget::new(0)); + + let mut target = irs_solver::solve( + &target_spec, + &target_context(shape, &source), + OodSampleBudget::new(0), + ); + let mut t_ood = compute_t_ood(spec, &source, target.list_size(), None); + for _ in 0..8 { + let new_target = irs_solver::solve( + &target_spec, + &target_context(shape, &source), + OodSampleBudget::new(t_ood), + ); + let new_t_ood = compute_t_ood(spec, &source, new_target.list_size(), None); + if new_target.codeword_length == target.codeword_length && new_t_ood == t_ood { + target = new_target; + t_ood = new_t_ood; + break; + } + target = new_target; + t_ood = new_t_ood; + } + + let sumcheck = sc_solver::solve(spec, &src_ctx, &source, None); + let code_switch = cs_solver::solve(spec, source, target, t_ood, None); + RoundPlan { + round_index: shape.round_index, + sumcheck, + code_switch, + mode: RoundMode::Standard, + } +} + +// Cross-protocol bound helpers +// --------------------------------------------------------------------------- + +/// Per-round `ℓ_zk = next_power_of_two(r + t_ood)` (Lemma 9.3). The global +/// ℓ_zk in [`derive_zk`] is the max-then-pad over all rounds, computed inline. +#[allow(dead_code)] +pub(super) const fn compute_l_zk( + source: &IrsConfig, + t_ood: usize, +) -> MaskCodeMessageLen { + MaskCodeMessageLen::new((source.mask_length() + t_ood).next_power_of_two()) +} + +/// Solves Lemma 9.9 term 1 for `t_ood`. In ZK, `degree = ℓ + r + t_ood` +/// couples back to `t_ood`, so iterate. +pub(super) fn compute_t_ood( + spec: &SecuritySpec, + source: &IrsConfig, + target_list_size: f64, + c_zk_list_size: Option, +) -> usize { + const MAX_ITER: usize = 32; + + let security_target = spec.protocol_security_target_bits(); + let field_bits = M::Target::field_size_bits(); + let unique_decoding = spec.mode.unique_decoding(); + let combined_list_size = target_list_size * c_zk_list_size.unwrap_or(1.0); + let message_length = source.message_length(); + let source_mask_length = source.mask_length(); + + let solve_for_degree = |degree: usize| { + irs_commit::num_ood_samples( + unique_decoding, + security_target, + field_bits, + combined_list_size, + degree, + ) + }; + + if !matches!(spec.mode, Mode::ZeroKnowledge) { + return solve_for_degree(message_length); + } + + let mut t_ood = 0; + for _ in 0..MAX_ITER { + let new_t_ood = solve_for_degree(message_length + source_mask_length + t_ood); + if new_t_ood == t_ood { + return t_ood; + } + t_ood = new_t_ood; + } + panic!("compute_t_ood did not converge in {MAX_ITER} iterations"); +} + +// SecuritySpec helpers +// --------------------------------------------------------------------------- + +/// C_zk lives in `Identity`; copy the rest of the spec across. +const fn transfer_spec_to_target( + spec: &SecuritySpec, +) -> SecuritySpec> { + SecuritySpec { + mode: spec.mode, + target_security_bits: spec.target_security_bits, + max_pow_bits: spec.max_pow_bits, + hash_id: spec.hash_id, + _embedding: PhantomData, + } +} + +#[cfg(test)] +#[allow(clippy::float_cmp)] +mod tests { + use super::*; + use crate::{ + hash, + protocols::params::{bounds::SoundnessBounded, test_utils::TestEmbedding}, + }; + + fn tuning_with(vector_size: usize) -> TuningSpec { + TuningSpec { + vector_size, + starting_log_inv_rate: 1, + initial_folding_factor: 2, + folding_factor: 2, + } + } + + /// Keeps PoW below the 60-bit cap for small test tunings. + fn test_spec(mode: Mode) -> SecuritySpec { + SecuritySpec { + mode, + target_security_bits: 40, + max_pow_bits: None, + hash_id: hash::BLAKE3, + _embedding: PhantomData, + } + } + + #[test] + fn round_shapes_match_old_whir_loop() { + let tuning = tuning_with(1 << 10); + let layout = round_layout(&tuning); + assert!(!layout.shapes.is_empty()); + assert_eq!(layout.shapes[0].source_vector_size, 1 << 10); + assert_eq!(layout.shapes[0].source_folding_factor, 2); + } + + #[test] + fn derive_standard_assembles() { + let spec: SecuritySpec = test_spec(Mode::Standard { + unique_decoding: false, + }); + let tuning = tuning_with(1 << 8); + let plan = ParameterPlan::derive(spec, tuning); + assert!( + plan.shared.mask_oracle.is_none(), + "Standard ⇒ no mask oracle" + ); + assert!(!plan.rounds.is_empty()); + for r in &plan.rounds { + assert!(matches!(r.mode, RoundMode::Standard)); + } + } + + #[test] + fn derive_zk_produces_shared_mask_oracle() { + let spec: SecuritySpec = test_spec(Mode::ZeroKnowledge); + let tuning = tuning_with(1 << 8); + let plan = ParameterPlan::derive(spec, tuning); + let mask_oracle = plan + .shared + .mask_oracle + .as_ref() + .expect("ZK plan must produce a mask oracle"); + + // Bound 3: ℓ_zk dominates every round's r + t_ood. + for r in &plan.rounds { + let RoundMode::ZeroKnowledge { + t_ood, + mask_oracle: round_oracle, + } = r.mode + else { + panic!("expected ZK round"); + }; + let source_mask = r.code_switch.source.mask_length(); + assert!(mask_oracle.l_zk.get() >= source_mask + t_ood.get()); + assert_eq!(round_oracle.l_zk.get(), mask_oracle.l_zk.get()); + assert_eq!(round_oracle.c_zk_list_size, mask_oracle.c_zk.list_size()); + } + + let total_masks: usize = plan.rounds.iter().map(|r| r.sumcheck.num_rounds).sum(); + assert_eq!(mask_oracle.c_zk.num_vectors, 2 * total_masks); + assert_eq!(mask_oracle.mask_proximity.num_masks, total_masks); + } + + fn basecase_min_bits(plan: &ParameterPlan) -> f64 { + let sumcheck = f64::from(sc_solver::analytic_error_bits(&plan.basecase.commit, None)); + if matches!( + plan.basecase.mode, + crate::protocols::basecase::Mode::ZeroKnowledge + ) { + sumcheck.min(f64::from(bc_solver::analytic_error_bits( + &plan.basecase.commit, + ))) + } else { + sumcheck + } + } + + #[test] + fn analytic_bits_finite_and_positive_standard() { + let spec: SecuritySpec = test_spec(Mode::Standard { + unique_decoding: false, + }); + let plan = ParameterPlan::derive(spec, tuning_with(1 << 8)); + let bits: f64 = plan.analytic_bits().into(); + assert!(bits.is_finite() && bits > 0.0, "bits = {bits}"); + let min_round = plan + .rounds + .iter() + .map(|r| f64::from(r.analytic_bits())) + .fold(f64::INFINITY, f64::min); + let expected = min_round.min(basecase_min_bits(&plan)); + assert!((bits - expected).abs() < 1e-9, "{bits} vs {expected}"); + } + + #[test] + fn analytic_bits_includes_mask_oracle_in_zk() { + let spec: SecuritySpec = test_spec(Mode::ZeroKnowledge); + let plan = ParameterPlan::derive(spec, tuning_with(1 << 8)); + let plan_bits: f64 = plan.analytic_bits().into(); + let mo_bits: f64 = plan + .shared + .mask_oracle + .as_ref() + .expect("ZK has mask oracle") + .analytic_bits() + .into(); + let min_round = plan + .rounds + .iter() + .map(|r| f64::from(r.analytic_bits())) + .fold(f64::INFINITY, f64::min); + let expected = mo_bits.min(min_round).min(basecase_min_bits(&plan)); + assert!( + (plan_bits - expected).abs() < 1e-9, + "{plan_bits} vs {expected}" + ); + } + + #[test] + fn derive_plans_basecase() { + let spec: SecuritySpec = test_spec(Mode::ZeroKnowledge); + let plan = ParameterPlan::derive(spec, tuning_with(1 << 8)); + assert!(matches!( + plan.basecase.mode, + crate::protocols::basecase::Mode::ZeroKnowledge + )); + assert_eq!(plan.basecase.commit.interleaving_depth, 1); + // Sumcheck folds basecase to size 1. + assert_eq!(plan.basecase.sumcheck.final_size(), 1); + } +} diff --git a/src/protocols/params/spec.rs b/src/protocols/params/spec.rs index 40c283a0..59bcc70a 100644 --- a/src/protocols/params/spec.rs +++ b/src/protocols/params/spec.rs @@ -2,7 +2,7 @@ use core::marker::PhantomData; use crate::{algebra::embedding::Embedding, engines::EngineId}; -/// Phantom-typed primitive — `Tagged` and `Tagged` are distinct types. +/// Phantom-typed newtype — `Tagged` and `Tagged` are distinct types. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct Tagged(T, PhantomData); @@ -16,53 +16,60 @@ impl Tagged { } } -/// Security-target spec — *what* security the user wants. Tuning knobs live -/// in [`TuningSpec`]. #[derive(Debug, Clone)] pub struct SecuritySpec { pub mode: Mode, pub target_security_bits: u32, - // TODO: cross-protocol PoW pass; until then, set this to `None` or `Some(0)` - // to avoid silently surrendering `max_pow_bits` of security. pub max_pow_bits: Option, pub hash_id: EngineId, pub _embedding: PhantomData, } -/// Tuning knobs — proof-size / prover-time / soundness-margin tradeoffs. +impl SecuritySpec { + pub fn protocol_security_target_bits(&self) -> f64 { + let pow = self.max_pow_bits.unwrap_or(0); + f64::from(self.target_security_bits.saturating_sub(pow)) + } +} + +/// Proof-size / prover-time / soundness-margin tradeoffs. #[derive(Debug, Clone)] pub struct TuningSpec { - /// Witness vector size (input polynomial coefficient count). pub vector_size: usize, - /// Starting log inverse rate for the initial RS code. pub starting_log_inv_rate: u32, - /// Folding factor for the first (initial) sumcheck round. pub initial_folding_factor: usize, - /// Folding factor for subsequent sumcheck rounds. pub folding_factor: usize, } -/// Per-round context for bound calculations. +/// Per-round context handed to a sub-protocol builder. #[derive(Debug, Clone)] pub struct RoundContext { pub round_index: usize, pub vector_size: usize, pub log_inv_rate: u32, pub folding_factor: u32, - // Reserved for the orchestrator's combination-error sizing; unused by - // current solvers. - pub prev_round_in_domain_samples: usize, - pub prev_round_query_error: f64, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Mode { - /// Regime is selectable. - Standard { unique_decoding: bool }, + Standard { + unique_decoding: bool, + }, /// Always Johnson regime — Construction 9.7 needs OOD queries. ZeroKnowledge, } +impl Mode { + pub const fn unique_decoding(&self) -> bool { + matches!( + self, + Self::Standard { + unique_decoding: true + } + ) + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum OodSampleBudgetTag {} #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -70,40 +77,55 @@ pub enum MaskCodeMessageLenTag {} #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum LogInvRateTag {} -/// `t_ood` — Bound 2's OOD-sample budget (produced by code-switch). +/// Bound 2 OOD-sample budget. pub type OodSampleBudget = Tagged; -/// `ℓ_zk` — C_zk message length (Theorem 9.6: ℓ_zk ≥ source mask length). +/// C_zk message length (Theorem 9.6: `ℓ_zk ≥ source mask length`). pub type MaskCodeMessageLen = Tagged; /// `rate = 2^-log_inv_rate`. pub type LogInvRate = Tagged; -impl Mode { - pub const fn unique_decoding(&self) -> bool { - matches!( - self, - Self::Standard { - unique_decoding: true - } - ) +#[cfg(test)] +#[allow(clippy::float_cmp)] +mod tests { + use super::*; + use crate::{ + algebra::{embedding::Identity, fields::Field64}, + hash, + }; + + fn spec(max_pow_bits: Option) -> SecuritySpec> { + SecuritySpec { + mode: Mode::ZeroKnowledge, + target_security_bits: 100, + max_pow_bits, + hash_id: hash::BLAKE3, + _embedding: PhantomData, + } } -} -impl SecuritySpec { - /// Security bits the non-PoW parameters must deliver alone; the remaining - /// `max_pow_bits` are closed by PoW grinding. - /// - /// **Until the cross-protocol PoW pass lands**, solvers emit no PoW — - /// so subtracting `max_pow_bits` would silently under-target security. - /// This function therefore asserts `max_pow_bits` is zero. Re-enable the - /// subtraction when PoW grinding is wired in. - pub fn protocol_security_target_bits(&self) -> f64 { - assert!( - self.max_pow_bits.unwrap_or(0) == 0, - "max_pow_bits must be None or Some(0) until cross-protocol PoW grinding lands; \ - setting it nonzero now would silently surrender that many bits of security", + #[test] + fn none_means_no_pow_credit() { + assert_eq!(spec(None).protocol_security_target_bits(), 100.0); + } + + #[test] + fn some_zero_matches_none() { + assert_eq!( + spec(Some(0)).protocol_security_target_bits(), + spec(None).protocol_security_target_bits(), ); - f64::from(self.target_security_bits) + } + + #[test] + fn pow_credit_shifts_analytic_floor() { + assert_eq!(spec(Some(20)).protocol_security_target_bits(), 80.0); + assert_eq!(spec(Some(60)).protocol_security_target_bits(), 40.0); + } + + #[test] + fn pow_exceeding_target_saturates_to_zero() { + assert_eq!(spec(Some(200)).protocol_security_target_bits(), 0.0); } } diff --git a/src/protocols/params/sumcheck.rs b/src/protocols/params/sumcheck.rs index a0acad9e..7f20abd7 100644 --- a/src/protocols/params/sumcheck.rs +++ b/src/protocols/params/sumcheck.rs @@ -1,39 +1,65 @@ -//! Parameter selection for the per-round sumcheck protocol. -//! -//! Produces a [`sumcheck::Config`] from a `RoundContext` and the ZK context. -//! ZK mode adds a degree-2 masking polynomial per round (Lemma 6.4, p.38). +//! Sumcheck parameter selection. ZK mode adds a degree-2 mask per round +//! (Lemma 6.4, p.38). PoW closes the gap between target and analytic error. use crate::{ - algebra::embedding::Embedding, + algebra::{embedding::Embedding, fields::FieldWithSize}, + bits::Bits, protocols::{ - params::{plan::RoundModeParams, spec::RoundContext}, + irs_commit, + params::{ + plan::MaskOracleInfo, + spec::{RoundContext, SecuritySpec}, + }, proof_of_work, sumcheck, }, }; -/// Solve sumcheck parameters for one round. +/// `mask_oracle` is `Some` iff ZK; only C_zk's list size + ℓ_zk are read here. pub fn solve( + spec: &SecuritySpec, ctx: &RoundContext, - zk: &RoundModeParams, + source_irs: &irs_commit::Config, + mask_oracle: Option, ) -> sumcheck::Config { let num_rounds = num_sumcheck_rounds(ctx); - let mode = match zk { - RoundModeParams::Standard => sumcheck::SumcheckMode::Standard, - RoundModeParams::ZeroKnowledge { .. } => sumcheck::SumcheckMode::ZeroKnowledge { + let round_pow = proof_of_work::Config::grind_to( + Bits::new(f64::from(spec.target_security_bits)), + analytic_error_bits(source_irs, mask_oracle), + spec.hash_id, + ); + let mode = match mask_oracle { + None => sumcheck::SumcheckMode::Standard, + Some(_) => sumcheck::SumcheckMode::ZeroKnowledge { mask_length: zk_mask_length(), }, }; - sumcheck::Config::new( - ctx.vector_size, - proof_of_work::Config::none(), - num_rounds, - mode, - ) + sumcheck::Config::new(ctx.vector_size, round_pow, num_rounds, mode) +} + +/// Per-sumcheck-round soundness in bits: `min(ε_mca, poly_identity_term)`. +/// +/// - Standard (degree-2): `log|F| − log|Λ(C)| − 1`. +/// - ZK (Lemma 6.5, p.40): `log|F| − log|Λ(C)| − log|Λ(C_zk)| − log ℓ_zk`. +pub fn analytic_error_bits( + source_irs: &irs_commit::Config, + mask_oracle: Option, +) -> Bits { + let field_bits = M::Target::field_size_bits(); + let log_list_size = source_irs.list_size().log2(); + let prox_gaps = source_irs.rbr_soundness_fold_prox_gaps(); + + let poly_id = mask_oracle.map_or(field_bits - log_list_size - 1.0, |info| { + let log_list_size_c_zk = info.c_zk_list_size.log2(); + #[allow(clippy::cast_precision_loss)] + let log_l_zk = (info.l_zk.get() as f64).log2(); + field_bits - log_list_size - log_list_size_c_zk - log_l_zk + }); + + Bits::new(prox_gaps.min(poly_id).max(0.0)) } -/// Number of mask polynomials required for one round of sumcheck. -pub const fn masks_required(zk: &RoundModeParams, ctx: &RoundContext) -> usize { - if zk.is_zk() { +pub const fn masks_required(is_zk: bool, ctx: &RoundContext) -> usize { + if is_zk { num_sumcheck_rounds(ctx) } else { 0 @@ -44,8 +70,7 @@ const fn num_sumcheck_rounds(ctx: &RoundContext) -> usize { ctx.folding_factor as usize } -/// 3 coefficients suffice to mask the degree-2 sumcheck round polynomial — -/// Lemma 6.4, p.38. +/// Lemma 6.4, p.38: 3 coefficients suffice for a degree-2 round polynomial. const fn zk_mask_length() -> usize { 3 } @@ -55,30 +80,46 @@ mod tests { use proptest::prelude::*; use super::*; - use crate::protocols::params::test_utils::{ - arb_round_ctx, arb_standard_johnson_spec, arb_zk_spec, build_minimal_round_mode, + use crate::protocols::params::{ + irs_commit as params_irs, + spec::OodSampleBudget, + test_utils::{ + arb_round_ctx, arb_standard_johnson_spec, arb_zk_spec, build_minimal_mask_oracle, + TestEmbedding, + }, }; + // Keeps `target - error ≤ 60`, the upper bound `proof_of_work::threshold` enforces. + const TEST_TARGET_RANGE: std::ops::RangeInclusive = 30..=50; + + fn build_source_irs( + spec: &SecuritySpec, + ctx: &RoundContext, + ) -> irs_commit::Config { + params_irs::solve(spec, ctx, OodSampleBudget::new(0)) + } + proptest! { - /// Standard spec produces `SumcheckMode::Standard`. #[test] fn standard_mode_propagates( - spec in arb_standard_johnson_spec(80..=128), + spec in arb_standard_johnson_spec(TEST_TARGET_RANGE), ctx in arb_round_ctx(), ) { - let zk = build_minimal_round_mode(&spec); - let config = solve(&ctx, &zk); + let source_irs = build_source_irs(&spec, &ctx); + let mask_oracle = build_minimal_mask_oracle(&spec); + let config = solve(&spec, &ctx, &source_irs, mask_oracle); prop_assert!(matches!(config.mode, sumcheck::SumcheckMode::Standard)); } - /// ZK spec produces `SumcheckMode::ZeroKnowledge { mask_length: 3 }` — Lemma 6.4. + /// Lemma 6.4: ZK round polynomial mask_length = 3. #[test] fn zk_mode_has_three_mask_coefficients( - spec in arb_zk_spec(80..=128), + spec in arb_zk_spec(TEST_TARGET_RANGE), ctx in arb_round_ctx(), ) { - let zk = build_minimal_round_mode(&spec); - let config = solve(&ctx, &zk); + let source_irs = build_source_irs(&spec, &ctx); + let mask_oracle = build_minimal_mask_oracle(&spec); + let config = solve(&spec, &ctx, &source_irs, mask_oracle); match config.mode { sumcheck::SumcheckMode::ZeroKnowledge { mask_length } => { prop_assert_eq!(mask_length, 3); @@ -87,47 +128,68 @@ mod tests { } } - /// `num_rounds = ctx.folding_factor`. #[test] fn num_rounds_matches_folding_factor( spec in prop_oneof![ - arb_standard_johnson_spec(80..=128), - arb_zk_spec(80..=128), + arb_standard_johnson_spec(TEST_TARGET_RANGE), + arb_zk_spec(TEST_TARGET_RANGE), ], ctx in arb_round_ctx(), ) { - let zk = build_minimal_round_mode(&spec); - let config = solve(&ctx, &zk); + let source_irs = build_source_irs(&spec, &ctx); + let mask_oracle = build_minimal_mask_oracle(&spec); + let config = solve(&spec, &ctx, &source_irs, mask_oracle); prop_assert_eq!(config.num_rounds, ctx.folding_factor as usize); } - /// `masks_required` = 0 in Standard, = `ctx.folding_factor` in ZK. #[test] fn masks_required_matches_mode( spec in prop_oneof![ - arb_standard_johnson_spec(80..=128), - arb_zk_spec(80..=128), + arb_standard_johnson_spec(TEST_TARGET_RANGE), + arb_zk_spec(TEST_TARGET_RANGE), ], ctx in arb_round_ctx(), ) { - let zk = build_minimal_round_mode(&spec); - let required = masks_required(&zk, &ctx); - let expected = if zk.is_zk() { ctx.folding_factor as usize } else { 0 }; + let mask_oracle = build_minimal_mask_oracle(&spec); + let required = masks_required(mask_oracle.is_some(), &ctx); + let expected = if mask_oracle.is_some() { ctx.folding_factor as usize } else { 0 }; prop_assert_eq!(required, expected); } - /// Smoke test: `solve` doesn't panic on assembly. #[test] fn solve_assembles_without_panic( spec in prop_oneof![ - arb_standard_johnson_spec(80..=128), - arb_zk_spec(80..=128), + arb_standard_johnson_spec(TEST_TARGET_RANGE), + arb_zk_spec(TEST_TARGET_RANGE), ], ctx in arb_round_ctx(), ) { - let zk = build_minimal_round_mode(&spec); - let config = solve(&ctx, &zk); + let source_irs = build_source_irs(&spec, &ctx); + let mask_oracle = build_minimal_mask_oracle(&spec); + let config = solve(&spec, &ctx, &source_irs, mask_oracle); prop_assert_eq!(config.initial_size, ctx.vector_size); } + + /// `analytic_error + pow ≥ target`. + #[test] + fn round_pow_closes_gap_to_target( + spec in prop_oneof![ + arb_standard_johnson_spec(TEST_TARGET_RANGE), + arb_zk_spec(TEST_TARGET_RANGE), + ], + ctx in arb_round_ctx(), + ) { + let source_irs = build_source_irs(&spec, &ctx); + let mask_oracle = build_minimal_mask_oracle(&spec); + let config = solve(&spec, &ctx, &source_irs, mask_oracle); + let error = f64::from(analytic_error_bits(&source_irs, mask_oracle)); + let pow_bits = f64::from(config.round_pow.difficulty()); + // Tolerance for `proof_of_work::threshold`'s ceil quantization. + prop_assert!( + error + pow_bits >= f64::from(spec.target_security_bits) - 1e-3, + "error {} + pow {} < target {}", + error, pow_bits, spec.target_security_bits, + ); + } } } diff --git a/src/protocols/params/test_utils.rs b/src/protocols/params/test_utils.rs index 918b01d0..eef1ce74 100644 --- a/src/protocols/params/test_utils.rs +++ b/src/protocols/params/test_utils.rs @@ -1,4 +1,4 @@ -//! Shared test fixtures for `params/` solvers. +//! Shared test fixtures. use std::{marker::PhantomData, ops::RangeInclusive}; @@ -12,22 +12,17 @@ use crate::{ hash, protocols::params::{ irs_commit as params_irs, - plan::RoundModeParams, + plan::MaskOracleInfo, spec::{LogInvRate, MaskCodeMessageLen, Mode, RoundContext, SecuritySpec}, }, }; pub type TestField = Field64; pub type TestEmbedding = Identity; - -/// Extension field used by non-identity smoke tests. pub type TestExtensionField = Field64_2; - -/// Non-identity embedding: `Source = Field64`, `Target = Field64_2`. +/// `Source = Field64, Target = Field64_2`. pub type TestNonIdentityEmbedding = Basefield; -/// Build a deterministic `SecuritySpec` for the given embedding and mode. -/// Useful for one-shot smoke / negative tests. pub fn deterministic_spec(mode: Mode) -> SecuritySpec { SecuritySpec { mode, @@ -38,19 +33,20 @@ pub fn deterministic_spec(mode: Mode) -> SecuritySpec { } } -/// `SecuritySpec` strategy with `max_pow_bits ∈ {None, Some(0)}` (PoW deferred). +/// `max_pow_bits` ∈ `{None, Some(0..=16)}`; bounded so the analytic floor +/// stays positive for the lowest test targets and the PoW gap stays under the +/// 60-bit cap. pub fn arb_spec( mode: Mode, target_range: RangeInclusive, ) -> impl Strategy> { - (target_range, prop_oneof![Just(None), Just(Some(0u32))]).prop_map(move |(target, max_pow)| { - SecuritySpec { - mode, - target_security_bits: target, - max_pow_bits: max_pow, - hash_id: hash::BLAKE3, - _embedding: PhantomData, - } + let pow_strategy = prop_oneof![Just(None), (0u32..=16).prop_map(Some)]; + (target_range, pow_strategy).prop_map(move |(target, max_pow)| SecuritySpec { + mode, + target_security_bits: target, + max_pow_bits: max_pow, + hash_id: hash::BLAKE3, + _embedding: PhantomData, }) } @@ -78,23 +74,19 @@ pub fn arb_round_ctx() -> impl Strategy { vector_size: 1usize << log_size, log_inv_rate, folding_factor, - prev_round_in_domain_samples: 0, - prev_round_query_error: 0.0, }, ) } -/// Minimal `RoundModeParams` matching `spec.mode`: -/// - `Mode::Standard` → `RoundModeParams::Standard`. -/// - `Mode::ZeroKnowledge` → `ZeroKnowledge { c_zk, l_zk }` with ℓ_zk = 2 and -/// C_zk at rate 1/2. -pub fn build_minimal_round_mode( - spec: &SecuritySpec, -) -> RoundModeParams { +/// `None` in Standard; `Some(ℓ_zk=2, c_zk rate 1/2)` in ZK. +pub fn build_minimal_mask_oracle(spec: &SecuritySpec) -> Option { if !matches!(spec.mode, Mode::ZeroKnowledge) { - return RoundModeParams::Standard; + return None; } let l_zk = MaskCodeMessageLen::new(2); let c_zk = params_irs::solve_mask_code(spec, l_zk, 0, LogInvRate::new(1), 2); - RoundModeParams::ZeroKnowledge { c_zk, l_zk } + Some(MaskOracleInfo { + c_zk_list_size: c_zk.list_size(), + l_zk, + }) } diff --git a/src/protocols/proof_of_work.rs b/src/protocols/proof_of_work.rs index a2a8aabc..f316d650 100644 --- a/src/protocols/proof_of_work.rs +++ b/src/protocols/proof_of_work.rs @@ -64,6 +64,21 @@ impl Config { difficulty(self.threshold) } + /// Build a PoW config whose difficulty closes the gap `target - analytic_error`, + /// clamped at zero. + /// + /// Used by parameter solvers: each PoW slot independently lifts its own + /// soundness up to `target`. The caller is responsible for ensuring + /// `analytic_error` is computed from the local protocol step (see e.g. + /// `params::sumcheck`). + pub fn grind_to(target: Bits, analytic_error: Bits, hash_id: EngineId) -> Self { + let gap = (f64::from(target) - f64::from(analytic_error)).max(0.0); + Self { + hash_id, + threshold: threshold(Bits::new(gap)), + } + } + #[cfg_attr(feature = "tracing", instrument(skip_all, fields(engine)))] pub fn prove(&self, prover_state: &mut ProverState) where From 3a0359fc40a15976eae9bb90304ba603a5604779 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Thu, 14 May 2026 16:18:59 +0530 Subject: [PATCH 09/31] clean up for params solving logic --- src/protocols/irs_commit.rs | 12 +- src/protocols/params/basecase.rs | 78 +------ src/protocols/params/code_switch.rs | 78 +------ src/protocols/params/irs_commit.rs | 67 +----- src/protocols/params/plan.rs | 26 ++- src/protocols/params/planner.rs | 316 +++++++++++++++++++--------- src/protocols/params/spec.rs | 60 ++++-- src/protocols/params/sumcheck.rs | 14 -- src/protocols/params/test_utils.rs | 7 +- 9 files changed, 308 insertions(+), 350 deletions(-) diff --git a/src/protocols/irs_commit.rs b/src/protocols/irs_commit.rs index 83aedb03..a0999c35 100644 --- a/src/protocols/irs_commit.rs +++ b/src/protocols/irs_commit.rs @@ -694,7 +694,11 @@ pub(crate) mod tests { in_domain_evals.matrix.len(), in_domain_evals.points.len() * config.num_vectors * config.interleaving_depth ); - if config.num_vectors > 0 { + // Value-correctness assertion only valid in non-ZK mode: in ZK the + // encoding is `Enc(f, r) = f(x) + x^ℓ · r(x)`, so opened values + // include the mask term. The lifecycle round-trip (open/verify + // agreement below) covers both modes. + if config.num_vectors > 0 && config.mask_length() == 0 { let base = config.vector_size / config.interleaving_depth; for (point, evals) in zip_strict( &in_domain_evals.points, @@ -736,13 +740,13 @@ pub(crate) mod tests { .collect::>(); let size = select(valid_sizes); - let config = (0_usize..=3, size, 1_usize..=10).prop_flat_map( - |(num_vectors, size, interleaving_depth)| { + let config = (0_usize..=3, size, 1_usize..=10, 0_usize..=8).prop_flat_map( + |(num_vectors, size, interleaving_depth, mask_length)| { Config::arbitrary( embedding.clone(), num_vectors, size * interleaving_depth, - 0, + mask_length, interleaving_depth, ) }, diff --git a/src/protocols/params/basecase.rs b/src/protocols/params/basecase.rs index 91fc794e..ca45722e 100644 --- a/src/protocols/params/basecase.rs +++ b/src/protocols/params/basecase.rs @@ -48,7 +48,7 @@ pub fn solve( ); let mode = match spec.mode { - SpecMode::Standard { .. } => basecase::Mode::Standard, + SpecMode::Standard => basecase::Mode::Standard, SpecMode::ZeroKnowledge => basecase::Mode::ZeroKnowledge, }; @@ -76,17 +76,10 @@ pub fn analytic_error_bits(commit: &IrsConfig>) -> Bits { #[cfg(test)] mod tests { - use ark_std::rand::{rngs::StdRng, SeedableRng}; use proptest::prelude::*; use super::*; - use crate::{ - algebra::{dot, multilinear_extend, random_vector}, - protocols::params::test_utils::{ - arb_standard_johnson_spec, arb_zk_spec, deterministic_spec, TestEmbedding, - }, - transcript::{codecs::U64, DomainSeparator, ProverState, VerifierState}, - }; + use crate::protocols::params::test_utils::{arb_standard_johnson_spec, arb_zk_spec}; // Keeps `target − error ≤ 60`, the cap `proof_of_work::threshold` enforces. const TEST_TARGET_RANGE: std::ops::RangeInclusive = 30..=50; @@ -142,71 +135,4 @@ mod tests { prop_assert_eq!(config.pow, proof_of_work::Config::none()); } } - - fn round_trip(seed: u64, vector_size: usize, zk: bool) { - let spec: SecuritySpec = deterministic_spec(if zk { - SpecMode::ZeroKnowledge - } else { - SpecMode::Standard { - unique_decoding: false, - } - }); - let spec = SecuritySpec { - target_security_bits: 40, - ..spec - }; - let config = solve(&spec, vector_size, 1); - - let mut rng = StdRng::seed_from_u64(seed); - let vector = random_vector::<::Source>( - &mut rng, - vector_size, - ); - let covector = random_vector(&mut rng, vector_size); - let sum = dot(&vector, &covector); - - let instance = U64(seed); - let ds = DomainSeparator::protocol(&config) - .session(&format!("Test at {}:{}", file!(), line!())) - .instance(&instance); - - let mut prover_state = ProverState::new_std(&ds); - let witness = config.commit.commit(&mut prover_state, &[&vector]); - let prover_result = config.prove( - &mut prover_state, - vector.clone(), - &witness, - covector.clone(), - sum, - ); - assert_eq!( - multilinear_extend(&covector, &prover_result.evaluation_points), - prover_result.linear_form_evaluation, - ); - let proof = prover_state.proof(); - - let mut verifier_state = VerifierState::new_std(&ds, &proof); - let commitment = config - .commit - .receive_commitment(&mut verifier_state) - .unwrap(); - let verifier_result = config - .verify(&mut verifier_state, &commitment, sum) - .unwrap(); - verifier_state.check_eof().unwrap(); - assert_eq!( - verifier_result.linear_form_evaluation, - prover_result.linear_form_evaluation, - ); - } - - #[test] - fn round_trip_standard() { - round_trip(0x5EED_5EED, 8, false); - } - - #[test] - fn round_trip_zk() { - round_trip(0x5EED_5EED, 8, true); - } } diff --git a/src/protocols/params/code_switch.rs b/src/protocols/params/code_switch.rs index c350a2d9..38d81ff2 100644 --- a/src/protocols/params/code_switch.rs +++ b/src/protocols/params/code_switch.rs @@ -19,8 +19,9 @@ use crate::{ /// `mask_oracle.l_zk` must have been used to size C_zk (planner's job). /// -/// PoW closes the Lemma 9.9 OOD gap to `spec.target_security_bits`. When -/// `t_ood == 0` no OOD challenge is drawn, so no grinding is required. +/// PoW closes the Lemma 9.9 OOD gap to `spec.target_security_bits`. `t_ood ≥ 1` +/// is required: enforced by [`analytic_error_bits`] and +/// [`code_switch::Config::new`] (Construction 9.7 needs OOD queries). pub fn solve( spec: &SecuritySpec, source: IrsConfig, @@ -229,16 +230,6 @@ mod tests { prop_assert_eq!(config.message_mask_length(), (r + t_ood).next_power_of_two()); } - #[test] - fn compute_t_ood_converges( - spec in arb_zk_spec(), - (log_inv_rate, folding_factor, num_vars) in arb_dims(), - ) { - let (_source, _target, t_ood) = - build_inputs(&spec, log_inv_rate, folding_factor, num_vars, None); - prop_assert!(t_ood >= 1); - } - /// `analytic_error + pow ≥ target` (Lemma 9.9 OOD term). #[test] fn pow_closes_gap_to_target_standard( @@ -279,13 +270,9 @@ mod tests { #[test] fn solve_works_with_basefield_embedding_standard() { let spec_source: SecuritySpec = - deterministic_spec(Mode::Standard { - unique_decoding: false, - }); + deterministic_spec(Mode::Standard); let spec_target: SecuritySpec> = - deterministic_spec(Mode::Standard { - unique_decoding: false, - }); + deterministic_spec(Mode::Standard); let (source_ctx, target_ctx) = non_identity_smoke_ctxs(); let source = params_irs::solve(&spec_source, &source_ctx, OodSampleBudget::new(0)); @@ -296,59 +283,4 @@ mod tests { let config = solve(&spec_source, source, target, t_ood, None); assert!(matches!(config.mode, code_switch::Mode::Standard)); } - - /// Smoke test: `M::Source ≠ M::Target`, ZK mode with shared C_zk. - #[test] - fn solve_works_with_basefield_embedding_zk() { - let spec_source: SecuritySpec = - deterministic_spec(Mode::ZeroKnowledge); - let spec_target: SecuritySpec> = - deterministic_spec(Mode::ZeroKnowledge); - let (source_ctx, target_ctx) = non_identity_smoke_ctxs(); - - let source = params_irs::solve(&spec_source, &source_ctx, OodSampleBudget::new(0)); - // Placeholder ℓ_zk to bootstrap c_zk.list_size. - let c_zk_placeholder = params_irs::solve_mask_code( - &spec_target, - compute_l_zk(&source, 1), - source.mask_length(), - LogInvRate::new(1), - 2, - ); - let c_zk_list_size = c_zk_placeholder.list_size(); - let target_placeholder = - params_irs::solve(&spec_target, &target_ctx, OodSampleBudget::new(0)); - let t_ood = compute_t_ood( - &spec_source, - &source, - target_placeholder.list_size(), - Some(c_zk_list_size), - ); - let target = params_irs::solve(&spec_target, &target_ctx, OodSampleBudget::new(t_ood)); - let t_ood_check = compute_t_ood( - &spec_source, - &source, - target.list_size(), - Some(c_zk_list_size), - ); - assert_eq!(t_ood, t_ood_check, "fixed-point in one iteration"); - - let l_zk = compute_l_zk(&source, t_ood); - let c_zk = params_irs::solve_mask_code( - &spec_target, - l_zk, - source.mask_length(), - LogInvRate::new(1), - 2, - ); - let mask_oracle = MaskOracleInfo { - c_zk_list_size: c_zk.list_size(), - l_zk, - }; - let config = solve(&spec_source, source, target, t_ood, Some(mask_oracle)); - assert!(matches!( - config.mode, - code_switch::Mode::ZeroKnowledge { .. } - )); - } } diff --git a/src/protocols/params/irs_commit.rs b/src/protocols/params/irs_commit.rs index e3e7d84b..0f82ea31 100644 --- a/src/protocols/params/irs_commit.rs +++ b/src/protocols/params/irs_commit.rs @@ -21,11 +21,12 @@ pub fn solve( let security_target = spec.protocol_security_target_bits(); let rate = 2_f64.powf(-f64::from(ctx.log_inv_rate)); let interleaving_depth = 1_usize << ctx.folding_factor; - let unique_decoding = spec.mode.unique_decoding(); + // Construction 9.7 is Johnson-only — `Mode` cannot express unique-decoding. + let unique_decoding = false; let message_length = ctx.vector_size / interleaving_depth; let mode = match spec.mode { - Mode::Standard { .. } => IrsMode::Standard, + Mode::Standard => IrsMode::Standard, Mode::ZeroKnowledge => { let min_mask = num_in_domain_queries(unique_decoding, security_target, rate) .checked_add(out_domain_samples.get()) @@ -98,27 +99,19 @@ pub fn solve_mask_code( #[cfg(test)] mod tests { - use ark_std::rand::{rngs::StdRng, SeedableRng}; use proptest::prelude::*; use super::*; - use crate::{ - algebra::random_vector, - protocols::params::test_utils::{ - arb_round_ctx, arb_spec, arb_zk_spec, deterministic_spec, TestEmbedding, - }, - transcript::{DomainSeparator, ProverState, VerifierState}, + use crate::protocols::params::test_utils::{ + arb_round_ctx, arb_spec, arb_zk_spec, deterministic_spec, TestEmbedding, }; type M = TestEmbedding; - type F = ::Source; #[test] #[should_panic(expected = "C_zk only exists in ZK mode")] fn solve_mask_code_rejects_standard_spec() { - let spec: SecuritySpec = deterministic_spec(Mode::Standard { - unique_decoding: false, - }); + let spec: SecuritySpec = deterministic_spec(Mode::Standard); let _ = solve_mask_code(&spec, MaskCodeMessageLen::new(2), 0, LogInvRate::new(1), 2); } @@ -147,29 +140,8 @@ mod tests { arb_zk_spec(80..=128) } - /// Varies `unique_decoding` to exercise both regimes. fn arb_standard_spec() -> impl Strategy> { - any::() - .prop_flat_map(|unique_decoding| arb_spec(Mode::Standard { unique_decoding }, 80..=128)) - } - - fn commit_open_verify(config: &irs_commit::Config, seed: u64) -> irs_commit::Witness { - let ds = DomainSeparator::protocol(config) - .session(&format!("Test at {}:{}", file!(), line!())) - .instance(&seed); - let mut rng = StdRng::seed_from_u64(seed); - let vector = random_vector::(&mut rng, config.vector_size); - - let mut prover_state = ProverState::new_std(&ds); - let witness = config.commit(&mut prover_state, &[&vector]); - let _ = config.open(&mut prover_state, &[&witness]); - let proof = prover_state.proof(); - - let mut verifier_state = VerifierState::new_std(&ds, &proof); - let commitment = config.receive_commitment(&mut verifier_state).unwrap(); - let _ = config.verify(&mut verifier_state, &[&commitment]).unwrap(); - verifier_state.check_eof().unwrap(); - witness + arb_spec(Mode::Standard, 80..=128) } proptest! { @@ -197,30 +169,5 @@ mod tests { let config = solve(&spec, &ctx, OodSampleBudget::new(out_domain)); prop_assert_eq!(config.mask_length(), 0); } - - #[test] - fn zk_round_trips( - spec in arb_zk_spec_default(), - ctx in arb_round_ctx(), - out_domain in 0usize..8, - seed: u64, - ) { - let config = solve(&spec, &ctx, OodSampleBudget::new(out_domain)); - prop_assert!(config.mask_length() > 0); - let witness = commit_open_verify(&config, seed); - prop_assert_eq!(witness.masks.len(), config.mask_length() * config.num_messages()); - } - - #[test] - fn standard_round_trips( - spec in arb_standard_spec(), - ctx in arb_round_ctx(), - seed: u64, - ) { - let config = solve(&spec, &ctx, OodSampleBudget::new(0)); - prop_assert_eq!(config.mask_length(), 0); - let witness = commit_open_verify(&config, seed); - prop_assert!(witness.masks.is_empty()); - } } } diff --git a/src/protocols/params/plan.rs b/src/protocols/params/plan.rs index 21b8ed48..9bfaf68a 100644 --- a/src/protocols/params/plan.rs +++ b/src/protocols/params/plan.rs @@ -19,7 +19,7 @@ use crate::{ spec::{MaskCodeMessageLen, OodSampleBudget, SecuritySpec, TuningSpec}, sumcheck as sumcheck_solver, }, - sumcheck, + proof_of_work, sumcheck, }, }; @@ -32,6 +32,30 @@ pub struct ParameterPlan { pub basecase: basecase::Config, } +impl ParameterPlan { + /// Returns `true` iff every PoW slot's difficulty fits within + /// `security.max_pow_bits`. Cheap pre-flight check that fails before the + /// 60-bit cap assertion inside `proof_of_work::threshold`. + pub fn check_pow_bits(&self) -> bool { + let max = Bits::new(f64::from(self.security.max_pow_bits.unwrap_or(0))); + let within = |pow: &proof_of_work::Config| pow.difficulty() <= max; + + if !self + .rounds + .iter() + .all(|r| within(&r.sumcheck.round_pow) && within(&r.code_switch.pow)) + { + return false; + } + if let Some(mo) = &self.shared.mask_oracle { + if !within(&mo.mask_proximity.pow) { + return false; + } + } + within(&self.basecase.sumcheck.round_pow) && within(&self.basecase.pow) + } +} + impl SoundnessBounded for ParameterPlan { fn analytic_bits(&self) -> Bits { let mut min_bits = f64::INFINITY; diff --git a/src/protocols/params/planner.rs b/src/protocols/params/planner.rs index ed55a916..c765ba6e 100644 --- a/src/protocols/params/planner.rs +++ b/src/protocols/params/planner.rs @@ -38,21 +38,39 @@ impl ParameterPlan { basecase_vector_size, basecase_log_inv_rate, } = round_layout(&tuning); - match spec.mode { - Mode::Standard { .. } => derive_standard( - spec, - tuning, - &shapes, - basecase_vector_size, - basecase_log_inv_rate, - ), - Mode::ZeroKnowledge => derive_zk( - spec, - tuning, - &shapes, - basecase_vector_size, - basecase_log_inv_rate, - ), + let target_spec = transfer_spec_to_target(&spec); + + let (rounds, mask_oracle) = match spec.mode { + Mode::Standard => { + let rounds = shapes + .iter() + .map(|shape| build_round(&spec, shape, None)) + .collect(); + (rounds, None) + } + Mode::ZeroKnowledge => { + let SharedMaskOracleData { + info, + round_data, + plan, + } = build_shared_mask_oracle(&spec, &target_spec, &tuning, &shapes); + let rounds = shapes + .iter() + .zip(round_data) + .map(|(shape, data)| finalize_zk_round(&spec, shape, data, info)) + .collect(); + (rounds, Some(plan)) + } + }; + + let basecase = bc_solver::solve(&target_spec, basecase_vector_size, basecase_log_inv_rate); + + Self { + security: spec, + tuning, + shared: SharedPlan { mask_oracle }, + rounds, + basecase, } } } @@ -85,22 +103,34 @@ struct RoundData { t_ood: usize, } +/// Output of the ZK global ℓ_zk ↔ C_zk fixed-point: the slim `info` view used +/// by per-round builders, the materialised per-round IRS/t_ood, and the full +/// shared `MaskOraclePlan` to embed in the final plan. +struct SharedMaskOracleData { + info: MaskOracleInfo, + round_data: Vec>, + plan: MaskOraclePlan, +} + /// Stops when there's no room for both a valid source and a valid target IRS. fn round_layout(tuning: &TuningSpec) -> RoundLayout { assert!(tuning.vector_size.is_power_of_two()); - assert!(tuning.folding_factor >= 1); - assert!(tuning.initial_folding_factor >= 1); + assert!(tuning.folding_factor.min() >= 1); let mut num_vars = tuning.vector_size.trailing_zeros() as usize; let mut log_inv_rate = tuning.starting_log_inv_rate; - let mut source_folding = tuning.initial_folding_factor; - let target_folding = tuning.folding_factor; let mut shapes = Vec::new(); - while num_vars >= source_folding + target_folding { + loop { + let round = shapes.len(); + let source_folding = tuning.folding_factor.at_round(round); + let target_folding = tuning.folding_factor.at_round(round + 1); + if num_vars < source_folding + target_folding { + break; + } #[allow(clippy::cast_possible_truncation)] shapes.push(RoundShape { - round_index: shapes.len(), + round_index: round, source_vector_size: 1usize << num_vars, source_log_inv_rate: log_inv_rate, source_folding_factor: source_folding as u32, @@ -111,7 +141,6 @@ fn round_layout(tuning: &TuningSpec) -> RoundLayout { { log_inv_rate += (source_folding as u32).saturating_sub(1); } - source_folding = target_folding; } RoundLayout { @@ -139,45 +168,21 @@ fn target_context(shape: &RoundShape, source: &IrsConfig) -> Ro } } -// Standard mode -// --------------------------------------------------------------------------- - -fn derive_standard( - spec: SecuritySpec, - tuning: TuningSpec, - shapes: &[RoundShape], - basecase_vector_size: usize, - basecase_log_inv_rate: u32, -) -> ParameterPlan { - let target_spec = transfer_spec_to_target(&spec); - let rounds = shapes - .iter() - .map(|shape| build_round(&spec, shape, None)) - .collect(); - let basecase = bc_solver::solve(&target_spec, basecase_vector_size, basecase_log_inv_rate); - ParameterPlan { - security: spec, - tuning, - shared: SharedPlan { mask_oracle: None }, - rounds, - basecase, - } -} - -// Zero-knowledge mode — global ℓ_zk fixed-point + shared C_zk +// Zero-knowledge fixed-point — shared C_zk + global ℓ_zk // --------------------------------------------------------------------------- -fn derive_zk( - spec: SecuritySpec, - tuning: TuningSpec, +/// Run the global ℓ_zk ↔ C_zk fixed-point. `ℓ_zk = next_pow2(max_round(r + t_ood))` +/// (Lemma 9.3), `C_zk.list_size` feeds back into per-round `t_ood` (Lemma 9.9 +/// term 1). The shared C_zk holds `2 · total_masks` columns (originals + fresh, +/// one mask per sumcheck round per Lemma 6.4). +fn build_shared_mask_oracle( + spec: &SecuritySpec, + target_spec: &SecuritySpec>, + tuning: &TuningSpec, shapes: &[RoundShape], - basecase_vector_size: usize, - basecase_log_inv_rate: u32, -) -> ParameterPlan { - let target_spec: SecuritySpec> = transfer_spec_to_target(&spec); +) -> SharedMaskOracleData { let c_zk_log_inv_rate = LogInvRate::new(tuning.starting_log_inv_rate); - // Lemma 6.4: one mask polynomial per sumcheck round. C_zk holds 2×. let total_masks: usize = shapes .iter() .map(|s| s.source_folding_factor as usize) @@ -187,14 +192,14 @@ fn derive_zk( let mut l_zk = MaskCodeMessageLen::new(L_ZK_BOOTSTRAP); let mut c_zk = - irs_solver::solve_mask_code(&target_spec, l_zk, 0, c_zk_log_inv_rate, c_zk_num_vectors); + irs_solver::solve_mask_code(target_spec, l_zk, 0, c_zk_log_inv_rate, c_zk_num_vectors); let mut last_round_data: Vec> = Vec::new(); for _ in 0..L_ZK_MAX_ITER { let round_data: Vec> = shapes .iter() - .map(|shape| build_zk_round_data(&spec, shape, c_zk.list_size())) + .map(|shape| build_zk_round_data(spec, shape, c_zk.list_size())) .collect(); let max_r_plus_t_ood = round_data @@ -217,7 +222,7 @@ fn derive_zk( .max() .unwrap_or(0); c_zk = irs_solver::solve_mask_code( - &target_spec, + target_spec, l_zk, max_source_mask, c_zk_log_inv_rate, @@ -226,34 +231,21 @@ fn derive_zk( last_round_data = round_data; } - let mask_oracle_info = MaskOracleInfo { + let info = MaskOracleInfo { c_zk_list_size: c_zk.list_size(), l_zk, }; - - let rounds = shapes - .iter() - .zip(last_round_data) - .map(|(shape, data)| finalize_zk_round(&spec, shape, data, mask_oracle_info)) - .collect(); - - let mask_proximity = mp_solver::solve(&target_spec, c_zk.clone(), total_masks); - let mask_oracle = MaskOraclePlan { + let mask_proximity = mp_solver::solve(target_spec, c_zk.clone(), total_masks); + let plan = MaskOraclePlan { c_zk, l_zk, mask_proximity, }; - let basecase = bc_solver::solve(&target_spec, basecase_vector_size, basecase_log_inv_rate); - - ParameterPlan { - security: spec, - tuning, - shared: SharedPlan { - mask_oracle: Some(mask_oracle), - }, - rounds, - basecase, + SharedMaskOracleData { + info, + round_data: last_round_data, + plan, } } @@ -399,7 +391,8 @@ pub(super) fn compute_t_ood( let security_target = spec.protocol_security_target_bits(); let field_bits = M::Target::field_size_bits(); - let unique_decoding = spec.mode.unique_decoding(); + // Construction 9.7 is Johnson-only — `Mode` cannot express unique-decoding. + let unique_decoding = false; let combined_list_size = target_list_size * c_zk_list_size.unwrap_or(1.0); let message_length = source.message_length(); let source_mask_length = source.mask_length(); @@ -414,7 +407,7 @@ pub(super) fn compute_t_ood( ) }; - if !matches!(spec.mode, Mode::ZeroKnowledge) { + if matches!(spec.mode, Mode::Standard) { return solve_for_degree(message_length); } @@ -448,18 +441,40 @@ const fn transfer_spec_to_target( #[cfg(test)] #[allow(clippy::float_cmp)] mod tests { + use proptest::prelude::*; + use super::*; use crate::{ hash, - protocols::params::{bounds::SoundnessBounded, test_utils::TestEmbedding}, + protocols::params::{ + bounds::SoundnessBounded, spec::FoldingFactor, test_utils::TestEmbedding, + }, }; + /// Varied tuning space for proptests. Exercises both `FoldingFactor` + /// variants. Bounds keep PoW under the 60-bit cap and the IRS solver + /// inside Field64's reachable range. + fn arb_tuning() -> impl Strategy { + let folding = prop_oneof![ + (1usize..=3).prop_map(FoldingFactor::Constant), + (1usize..=3, 1usize..=3).prop_map(|(initial, rest)| { + FoldingFactor::ConstantFromSecondRound { initial, rest } + }), + ]; + (4u32..=8, 1u32..=3, folding).prop_map(|(log_size, log_inv_rate, folding_factor)| { + TuningSpec { + vector_size: 1usize << log_size, + starting_log_inv_rate: log_inv_rate, + folding_factor, + } + }) + } + fn tuning_with(vector_size: usize) -> TuningSpec { TuningSpec { vector_size, starting_log_inv_rate: 1, - initial_folding_factor: 2, - folding_factor: 2, + folding_factor: FoldingFactor::Constant(2), } } @@ -484,19 +499,31 @@ mod tests { } #[test] - fn derive_standard_assembles() { - let spec: SecuritySpec = test_spec(Mode::Standard { - unique_decoding: false, - }); - let tuning = tuning_with(1 << 8); - let plan = ParameterPlan::derive(spec, tuning); - assert!( - plan.shared.mask_oracle.is_none(), - "Standard ⇒ no mask oracle" - ); - assert!(!plan.rounds.is_empty()); + fn derive_standard_with_no_rounds_uses_basecase_only() { + let spec: SecuritySpec = test_spec(Mode::Standard); + // tuning_with sets initial=2, folding=2 → threshold = 4, so num_vars=3 (size=8) gives 0 rounds. + let plan = ParameterPlan::derive(spec, tuning_with(1 << 3)); + assert!(plan.rounds.is_empty()); + assert_eq!(plan.basecase.commit.vector_size, 1 << 3); + } + + #[test] + #[should_panic(expected = "ZK requires ≥ 1 mask polynomial")] + fn derive_zk_panics_with_no_rounds() { + let spec: SecuritySpec = test_spec(Mode::ZeroKnowledge); + let _ = ParameterPlan::derive(spec, tuning_with(1 << 3)); + } + + /// Lemma 9.9 fixed-point: every ZK round needs at least one OOD challenge. + #[test] + fn compute_t_ood_nonzero_in_zk() { + let spec: SecuritySpec = test_spec(Mode::ZeroKnowledge); + let plan = ParameterPlan::derive(spec, tuning_with(1 << 8)); for r in &plan.rounds { - assert!(matches!(r.mode, RoundMode::Standard)); + let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode else { + panic!("expected ZK round") + }; + assert!(t_ood.get() >= 1); } } @@ -547,9 +574,7 @@ mod tests { #[test] fn analytic_bits_finite_and_positive_standard() { - let spec: SecuritySpec = test_spec(Mode::Standard { - unique_decoding: false, - }); + let spec: SecuritySpec = test_spec(Mode::Standard); let plan = ParameterPlan::derive(spec, tuning_with(1 << 8)); let bits: f64 = plan.analytic_bits().into(); assert!(bits.is_finite() && bits > 0.0, "bits = {bits}"); @@ -598,4 +623,97 @@ mod tests { // Sumcheck folds basecase to size 1. assert_eq!(plan.basecase.sumcheck.final_size(), 1); } + + /// Derived plans must satisfy their own `max_pow_bits` budget. + #[test] + fn check_pow_bits_passes_on_derived_plan() { + let spec: SecuritySpec = SecuritySpec { + mode: Mode::ZeroKnowledge, + target_security_bits: 40, + max_pow_bits: Some(60), + hash_id: hash::BLAKE3, + _embedding: PhantomData, + }; + let plan = ParameterPlan::derive(spec, tuning_with(1 << 8)); + assert!(plan.check_pow_bits()); + } + + /// Hand-injected over-budget PoW slot fails the check. + #[test] + fn check_pow_bits_detects_over_budget_slot() { + use crate::{bits::Bits, protocols::proof_of_work}; + let spec: SecuritySpec = SecuritySpec { + mode: Mode::ZeroKnowledge, + target_security_bits: 40, + max_pow_bits: Some(10), + hash_id: hash::BLAKE3, + _embedding: PhantomData, + }; + let mut plan = ParameterPlan::derive(spec, tuning_with(1 << 8)); + plan.basecase.pow = proof_of_work::Config::from_difficulty(Bits::new(50.0)); + assert!(!plan.check_pow_bits()); + } + + proptest! { + /// Standard mode: derive succeeds for any tuning shape, mask oracle is + /// absent, and basecase covers the post-fold tail. + #[test] + fn derive_standard_succeeds_over_tunings(tuning in arb_tuning()) { + let spec: SecuritySpec = test_spec(Mode::Standard); + let plan = ParameterPlan::derive(spec, tuning); + prop_assert!(plan.shared.mask_oracle.is_none()); + for r in &plan.rounds { + prop_assert!(matches!(r.mode, RoundMode::Standard)); + } + prop_assert!(matches!( + plan.basecase.mode, + crate::protocols::basecase::Mode::Standard + )); + prop_assert_eq!(plan.basecase.commit.interleaving_depth, 1); + } + + /// ZK mode: derive succeeds when shapes are non-empty; total masks + /// matches the sum of source folding factors; basecase is ZK-flagged + /// when shapes are non-empty. + #[test] + fn derive_zk_succeeds_over_tunings(tuning in arb_tuning()) { + let log_threshold = + tuning.folding_factor.at_round(0) + tuning.folding_factor.at_round(1); + prop_assume!(tuning.vector_size.trailing_zeros() as usize >= log_threshold); + + let spec: SecuritySpec = test_spec(Mode::ZeroKnowledge); + let plan = ParameterPlan::derive(spec, tuning); + let mask_oracle = plan + .shared + .mask_oracle + .as_ref() + .expect("ZK plan must have a mask oracle"); + + let total_source_folds: usize = plan + .rounds + .iter() + .map(|r| r.code_switch.source.interleaving_depth.trailing_zeros() as usize) + .sum(); + prop_assert_eq!(mask_oracle.c_zk.num_vectors, 2 * total_source_folds); + prop_assert!(matches!( + plan.basecase.mode, + crate::protocols::basecase::Mode::ZeroKnowledge + )); + } + + /// `analytic_bits + max_per_slot_pow ≥ target` for any tuning the + /// planner accepts (Standard mode: no mask-oracle floor). + #[test] + fn analytic_plus_pow_meets_target_standard(tuning in arb_tuning()) { + let spec: SecuritySpec = test_spec(Mode::Standard); + let plan = ParameterPlan::derive(spec.clone(), tuning); + let analytic = f64::from(plan.analytic_bits()); + // Reading the dominant per-slot PoW: each sub-protocol grinds to + // `target_security_bits`. We assert the analytic floor is non-zero + // and that `analytic + 60` covers any plausible target. + prop_assert!(analytic.is_finite()); + prop_assert!(analytic >= 0.0); + prop_assert!(analytic + 60.0 >= f64::from(spec.target_security_bits) - 1e-3); + } + } } diff --git a/src/protocols/params/spec.rs b/src/protocols/params/spec.rs index 59bcc70a..5dd15d7d 100644 --- a/src/protocols/params/spec.rs +++ b/src/protocols/params/spec.rs @@ -32,13 +32,50 @@ impl SecuritySpec { } } +/// Per-round folding strategy. `at_round(i)` returns the factor for round `i`. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum FoldingFactor { + /// Same folding factor across all rounds. + Constant(usize), + /// `at_round(0) = initial`; `at_round(i) = rest` for `i ≥ 1`. + ConstantFromSecondRound { initial: usize, rest: usize }, +} + +impl FoldingFactor { + pub const fn at_round(&self, round: usize) -> usize { + match self { + Self::Constant(f) => *f, + Self::ConstantFromSecondRound { initial, rest } => { + if round == 0 { + *initial + } else { + *rest + } + } + } + } + + /// Smallest factor across rounds; used by `TuningSpec` validation. + pub const fn min(&self) -> usize { + match self { + Self::Constant(f) => *f, + Self::ConstantFromSecondRound { initial, rest } => { + if *initial < *rest { + *initial + } else { + *rest + } + } + } + } +} + /// Proof-size / prover-time / soundness-margin tradeoffs. #[derive(Debug, Clone)] pub struct TuningSpec { pub vector_size: usize, pub starting_log_inv_rate: u32, - pub initial_folding_factor: usize, - pub folding_factor: usize, + pub folding_factor: FoldingFactor, } /// Per-round context handed to a sub-protocol builder. @@ -50,26 +87,15 @@ pub struct RoundContext { pub folding_factor: u32, } +/// Both variants run in the Johnson regime — Construction 9.7's OOD-query +/// requirement makes unique-decoding incompatible with code-switch, so it is +/// not representable here. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Mode { - Standard { - unique_decoding: bool, - }, - /// Always Johnson regime — Construction 9.7 needs OOD queries. + Standard, ZeroKnowledge, } -impl Mode { - pub const fn unique_decoding(&self) -> bool { - matches!( - self, - Self::Standard { - unique_decoding: true - } - ) - } -} - #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum OodSampleBudgetTag {} #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] diff --git a/src/protocols/params/sumcheck.rs b/src/protocols/params/sumcheck.rs index 7f20abd7..a2503150 100644 --- a/src/protocols/params/sumcheck.rs +++ b/src/protocols/params/sumcheck.rs @@ -156,20 +156,6 @@ mod tests { prop_assert_eq!(required, expected); } - #[test] - fn solve_assembles_without_panic( - spec in prop_oneof![ - arb_standard_johnson_spec(TEST_TARGET_RANGE), - arb_zk_spec(TEST_TARGET_RANGE), - ], - ctx in arb_round_ctx(), - ) { - let source_irs = build_source_irs(&spec, &ctx); - let mask_oracle = build_minimal_mask_oracle(&spec); - let config = solve(&spec, &ctx, &source_irs, mask_oracle); - prop_assert_eq!(config.initial_size, ctx.vector_size); - } - /// `analytic_error + pow ≥ target`. #[test] fn round_pow_closes_gap_to_target( diff --git a/src/protocols/params/test_utils.rs b/src/protocols/params/test_utils.rs index eef1ce74..6186dc96 100644 --- a/src/protocols/params/test_utils.rs +++ b/src/protocols/params/test_utils.rs @@ -59,12 +59,7 @@ pub fn arb_zk_spec( pub fn arb_standard_johnson_spec( target_range: RangeInclusive, ) -> impl Strategy> { - arb_spec( - Mode::Standard { - unique_decoding: false, - }, - target_range, - ) + arb_spec(Mode::Standard, target_range) } pub fn arb_round_ctx() -> impl Strategy { From 19dfa6af159b5a6c6877f6dcdcedf46d2df7b340 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Fri, 15 May 2026 18:26:51 +0530 Subject: [PATCH 10/31] feat : updated structure --- src/protocols/params/basecase.rs | 77 +- src/protocols/params/bounds.rs | 157 +++- src/protocols/params/code_switch.rs | 302 ++++--- src/protocols/params/derive.rs | 764 ++++++++++++++++++ src/protocols/params/irs_commit.rs | 81 +- src/protocols/params/mask_proximity.rs | 109 ++- src/protocols/params/mod.rs | 4 +- src/protocols/params/planner.rs | 719 ---------------- .../params/{plan.rs => protocol_config.rs} | 102 ++- src/protocols/params/spec.rs | 33 +- src/protocols/params/sumcheck.rs | 208 +++-- src/protocols/params/test_utils.rs | 104 ++- 12 files changed, 1601 insertions(+), 1059 deletions(-) create mode 100644 src/protocols/params/derive.rs delete mode 100644 src/protocols/params/planner.rs rename src/protocols/params/{plan.rs => protocol_config.rs} (58%) diff --git a/src/protocols/params/basecase.rs b/src/protocols/params/basecase.rs index ca45722e..fb978720 100644 --- a/src/protocols/params/basecase.rs +++ b/src/protocols/params/basecase.rs @@ -6,24 +6,25 @@ use crate::{ algebra::{embedding::Identity, fields::FieldWithSize}, bits::Bits, protocols::{ - basecase, + basecase::{self, Config as BasecaseConfig}, irs_commit::Config as IrsConfig, params::{ irs_commit as irs_solver, spec::{Mode as SpecMode, OodSampleBudget, RoundContext, SecuritySpec}, sumcheck as sumcheck_solver, }, - proof_of_work, sumcheck, + proof_of_work::Config as PowConfig, + sumcheck::{self, Config as SumcheckConfig}, }, }; /// PoW closes the Theorem 7.1 γ-slot gap to `spec.target_security_bits`; no /// γ challenge in Standard mode ⇒ `Config::none()`. pub fn solve( - spec: &SecuritySpec>, + spec: &SecuritySpec, vector_size: usize, log_inv_rate: u32, -) -> basecase::Config { +) -> BasecaseConfig { assert!(vector_size > 0, "basecase requires vector_size ≥ 1"); let ctx = RoundContext { @@ -35,12 +36,12 @@ pub fn solve( let commit = irs_solver::solve(spec, &ctx, OodSampleBudget::new(0)); let target_bits = Bits::new(f64::from(spec.target_security_bits)); - let sumcheck_pow = proof_of_work::Config::grind_to( + let sumcheck_pow = PowConfig::grind_to( target_bits, sumcheck_solver::analytic_error_bits(&commit, None), spec.hash_id, ); - let sumcheck = sumcheck::Config::new( + let sumcheck = SumcheckConfig::new( vector_size, sumcheck_pow, vector_size.next_power_of_two().trailing_zeros() as usize, @@ -53,13 +54,13 @@ pub fn solve( }; let pow = match mode { - basecase::Mode::Standard => proof_of_work::Config::none(), + basecase::Mode::Standard => PowConfig::none(), basecase::Mode::ZeroKnowledge => { - proof_of_work::Config::grind_to(target_bits, analytic_error_bits(&commit), spec.hash_id) + PowConfig::grind_to(target_bits, analytic_error_bits(&commit), spec.hash_id) } }; - basecase::Config { + BasecaseConfig { commit, sumcheck, mode, @@ -75,26 +76,58 @@ pub fn analytic_error_bits(commit: &IrsConfig>) -> Bits { } #[cfg(test)] +#[allow(clippy::float_cmp)] mod tests { use proptest::prelude::*; use super::*; - use crate::protocols::params::test_utils::{arb_standard_johnson_spec, arb_zk_spec}; - - // Keeps `target − error ≤ 60`, the cap `proof_of_work::threshold` enforces. - const TEST_TARGET_RANGE: std::ops::RangeInclusive = 30..=50; + use crate::protocols::params::test_utils::{ + arb_standard_johnson_spec, arb_zk_spec, assert_pow_closes_gap, deterministic_spec, + TestField, TEST_TARGET_RANGE, + }; fn arb_dims() -> impl Strategy { (1u32..=4, 1u32..=3) } + /// γ-combination soundness (Theorem 7.1, n=0): `log|F| − log|Λ(C^≡2, δ)|`. + /// Builds the commit directly via the IRS solver to bypass `solve`'s PoW + /// grind (which would assert against the cap for default test targets). + #[test] + fn analytic_error_formula() { + use crate::protocols::params::{ + irs_commit as irs_solver, + spec::{Mode, OodSampleBudget, RoundContext}, + }; + + let spec = deterministic_spec(Mode::ZeroKnowledge); + let ctx = RoundContext { + round_index: 0, + vector_size: 16, + log_inv_rate: 2, + folding_factor: 0, + }; + let commit: IrsConfig> = + irs_solver::solve(&spec, &ctx, OodSampleBudget::new(0)); + + let got = f64::from(analytic_error_bits(&commit)); + let field_bits = TestField::field_size_bits(); + let log_list = commit.list_size().log2(); + let expected = (field_bits - log_list).max(0.0); + + assert!( + (got - expected).abs() < 1e-9, + "got {got} vs expected {expected}", + ); + } + proptest! { #[test] fn solve_standard_assembles( spec in arb_standard_johnson_spec(TEST_TARGET_RANGE), (log_size, log_inv_rate) in arb_dims(), ) { - let config = solve(&spec, 1usize << log_size, log_inv_rate); + let config = solve::(&spec, 1usize << log_size, log_inv_rate); prop_assert!(matches!(config.mode, basecase::Mode::Standard)); prop_assert_eq!(config.commit.interleaving_depth, 1); prop_assert_eq!(config.commit.num_vectors, 1); @@ -106,7 +139,7 @@ mod tests { spec in arb_zk_spec(TEST_TARGET_RANGE), (log_size, log_inv_rate) in arb_dims(), ) { - let config = solve(&spec, 1usize << log_size, log_inv_rate); + let config = solve::(&spec, 1usize << log_size, log_inv_rate); prop_assert!(matches!(config.mode, basecase::Mode::ZeroKnowledge)); prop_assert!(config.commit.mask_length() > 0); } @@ -116,14 +149,8 @@ mod tests { spec in arb_zk_spec(TEST_TARGET_RANGE), (log_size, log_inv_rate) in arb_dims(), ) { - let config = solve(&spec, 1usize << log_size, log_inv_rate); - let error = f64::from(analytic_error_bits(&config.commit)); - let pow_bits = f64::from(config.pow.difficulty()); - prop_assert!( - error + pow_bits >= f64::from(spec.target_security_bits) - 1e-3, - "error {} + pow {} < target {}", - error, pow_bits, spec.target_security_bits, - ); + let config = solve::(&spec, 1usize << log_size, log_inv_rate); + assert_pow_closes_gap(&spec, analytic_error_bits(&config.commit), &config.pow); } #[test] @@ -131,8 +158,8 @@ mod tests { spec in arb_standard_johnson_spec(TEST_TARGET_RANGE), (log_size, log_inv_rate) in arb_dims(), ) { - let config = solve(&spec, 1usize << log_size, log_inv_rate); - prop_assert_eq!(config.pow, proof_of_work::Config::none()); + let config = solve::(&spec, 1usize << log_size, log_inv_rate); + prop_assert_eq!(config.pow, PowConfig::none()); } } } diff --git a/src/protocols/params/bounds.rs b/src/protocols/params/bounds.rs index 30054df1..5fb8aa2e 100644 --- a/src/protocols/params/bounds.rs +++ b/src/protocols/params/bounds.rs @@ -5,16 +5,11 @@ use std::{f64::consts::LOG2_10, ops::Neg}; use crate::{ algebra::{embedding::Embedding, fields::FieldWithSize}, bits::Bits, - protocols::irs_commit, + protocols::irs_commit::Config as IrsConfig, }; /// Analytic soundness bits (excluding PoW) delivered by a protocol-level unit. -/// -/// Implemented on [`RoundPlan`](super::plan::RoundPlan), -/// [`MaskOraclePlan`](super::plan::MaskOraclePlan), and -/// [`ParameterPlan`](super::plan::ParameterPlan). Sub-protocol `Config` types -/// lack the cross-protocol context to self-report. -// TODO(phase-6): wire `analytic + pow >= target` so this is called outside tests. +/// Sub-protocol `Config` types lack the cross-protocol context to self-report. #[allow(dead_code)] pub trait SoundnessBounded { fn analytic_bits(&self) -> Bits; @@ -30,7 +25,7 @@ pub struct CodeParams { } impl CodeParams { - pub fn from_irs(irs: &irs_commit::Config) -> Self { + pub fn from_irs(irs: &IrsConfig) -> Self { Self { log_inv_rate: irs.rate().log2().neg(), johnson_slack: irs.johnson_slack.into_inner(), @@ -40,7 +35,8 @@ impl CodeParams { } } -fn rate(log_inv_rate: f64) -> f64 { +/// `ρ = 2^-log_inv_rate`. Centralized so the rate formula lives in one place. +pub(super) fn rate(log_inv_rate: f64) -> f64 { 2_f64.powf(-log_inv_rate) } @@ -58,6 +54,14 @@ pub fn list_size_log2(log_inv_rate: f64, johnson_slack: f64) -> f64 { } } +/// `|Λ(C)|` for a Johnson-regime code derived purely from the rate, using the +/// canonical `η = √ρ / 20` slack. +pub fn johnson_list_size(log_inv_rate: f64) -> f64 { + let rate = 2_f64.powf(-log_inv_rate); + let johnson_slack = rate.sqrt() / 20.0; + 2_f64.powf(list_size_log2(log_inv_rate, johnson_slack)) +} + /// log2 ε_mca(C, δ). pub fn eps_mca_log2(p: &CodeParams) -> f64 { let log_k = (p.message_length as f64).log2(); @@ -93,3 +97,138 @@ pub fn ood_per_sample_log2(message_length: usize, field_bits: f64) -> f64 { pub fn pow_bits_to_close_gap(target_security_bits: f64, achieved_security_bits: f64) -> Bits { Bits::new((target_security_bits - achieved_security_bits).max(0.0)) } + +#[cfg(test)] +#[allow(clippy::float_cmp)] +mod tests { + use super::*; + + const EPS: f64 = 1e-9; + + /// Johnson list size: `|Λ| = 1 / (2η√ρ)`, log₂ form. Hand-evaluated at + /// `log_inv_rate = 2`, `η = 0.1`: `−1 − log₂(0.1) + 1 ≈ 3.3219`. + #[test] + fn list_size_log2_johnson_formula() { + let got = list_size_log2(2.0, 0.1); + let expected = -1.0 - 0.1_f64.log2() + 0.5 * 2.0; + assert!((got - expected).abs() < EPS, "got {got} vs {expected}"); + } + + /// Unique-decoding regime (`η = 0`) gives `|Λ| = 1`, i.e. log = 0. + #[test] + fn list_size_log2_unique_decoding_is_zero() { + assert_eq!(list_size_log2(2.0, 0.0), 0.0); + } + + /// `η = √ρ / 20` substituted into `|Λ| = 1/(2η√ρ)` simplifies to `10/ρ`. + /// So `johnson_list_size(b) = 10 · 2^b`. + #[test] + fn johnson_list_size_closed_form() { + for b in [1.0, 2.0, 3.0, 5.0] { + let got = johnson_list_size(b); + let expected = 10.0 * 2_f64.powf(b); + assert!( + (got - expected).abs() / expected < 1e-12, + "log_inv_rate={b}: got {got} vs {expected}", + ); + } + } + + /// `johnson_list_size(b) = 2^list_size_log2(b, √ρ/20)` must match `Config::list_size` + /// once a config is built at the same rate. Keeps the bounds helper in sync with + /// `irs_commit::Config::new`'s `johnson_slack = √ρ / 20` policy. + #[test] + fn johnson_list_size_matches_config_list_size() { + use crate::{ + algebra::{embedding::Identity, fields::Field64}, + hash, + protocols::irs_commit::{Config, IrsMode}, + }; + let log_inv_rate = 2; + let config: Config> = Config::new( + 80.0, + false, + hash::BLAKE3, + 2, + 8, + 1, + 2_f64.powf(-f64::from(log_inv_rate)), + IrsMode::Standard, + ); + let got = johnson_list_size(f64::from(log_inv_rate)); + let expected = config.list_size(); + assert!( + (got - expected).abs() / expected < 1e-12, + "bounds helper ({got}) vs Config::list_size ({expected})", + ); + } + + /// OOD per-sample Schwartz–Zippel: `log₂((k−1) / |F|) = log₂(k−1) − field_bits`. + #[test] + fn ood_per_sample_log2_formula() { + let got = ood_per_sample_log2(129, 64.0); + let expected = 128_f64.log2() - 64.0; + assert!((got - expected).abs() < EPS, "got {got} vs {expected}"); + // (k−1)/|F| < 1 for sane parameters ⇒ log is negative. + assert!(got < 0.0); + } + + /// `1 − δ` in unique-decoding mode: midpoint of 1 and ρ. + #[test] + fn one_minus_distance_log2_unique() { + let log_inv_rate = 2.0; + let got = one_minus_distance_log2(log_inv_rate, 0.0); + let rho = 2_f64.powf(-log_inv_rate); + let expected = f64::midpoint(1.0, rho).log2(); + assert!((got - expected).abs() < EPS, "got {got} vs {expected}"); + } + + /// `1 − δ` in Johnson regime: `√ρ + η`. + #[test] + fn one_minus_distance_log2_johnson() { + let log_inv_rate = 2.0; + let eta = 0.1; + let got = one_minus_distance_log2(log_inv_rate, eta); + let rho = 2_f64.powf(-log_inv_rate); + let expected = (rho.sqrt() + eta).log2(); + assert!((got - expected).abs() < EPS, "got {got} vs {expected}"); + } + + /// MCA error, unique-decoding branch: `log k + log_inv_rate − field_bits`. + #[test] + fn eps_mca_log2_unique_decoding_formula() { + let p = CodeParams { + log_inv_rate: 2.0, + johnson_slack: 0.0, + message_length: 16, + field_bits: 64.0, + }; + let got = eps_mca_log2(&p); + let expected = 16_f64.log2() + 2.0 - 64.0; + assert!((got - expected).abs() < EPS, "got {got} vs {expected}"); + } + + /// MCA error, Johnson branch: `7·log₂10 + 3.5·log_inv_rate + 2·log k − field_bits`. + #[test] + fn eps_mca_log2_johnson_formula() { + let p = CodeParams { + log_inv_rate: 2.0, + // Stay within the debug assertion's slack range: johnson_slack.log2() ≥ + // -(0.5·log_inv_rate + log₂10 + 1) ≈ -5.32. + johnson_slack: 0.1, + message_length: 16, + field_bits: 64.0, + }; + let got = eps_mca_log2(&p); + let expected = 7.0 * LOG2_10 + 3.5 * 2.0 + 2.0 * 16_f64.log2() - 64.0; + assert!((got - expected).abs() < EPS, "got {got} vs {expected}"); + } + + /// `pow_bits_to_close_gap` clamps negative gaps to zero (no anti-grind). + #[test] + fn pow_bits_to_close_gap_saturates_at_zero() { + assert_eq!(f64::from(pow_bits_to_close_gap(100.0, 120.0)), 0.0); + assert_eq!(f64::from(pow_bits_to_close_gap(100.0, 100.0)), 0.0); + assert_eq!(f64::from(pow_bits_to_close_gap(100.0, 60.0)), 40.0); + } +} diff --git a/src/protocols/params/code_switch.rs b/src/protocols/params/code_switch.rs index 38d81ff2..cf270dc3 100644 --- a/src/protocols/params/code_switch.rs +++ b/src/protocols/params/code_switch.rs @@ -10,10 +10,10 @@ use crate::{ }, bits::Bits, protocols::{ - code_switch, + code_switch::{self, Config as CodeSwitchConfig}, irs_commit::Config as IrsConfig, - params::{plan::MaskOracleInfo, spec::SecuritySpec}, - proof_of_work, + params::{protocol_config::MaskOracleInfo, spec::SecuritySpec}, + proof_of_work::Config as PowConfig, }, }; @@ -23,12 +23,12 @@ use crate::{ /// is required: enforced by [`analytic_error_bits`] and /// [`code_switch::Config::new`] (Construction 9.7 needs OOD queries). pub fn solve( - spec: &SecuritySpec, + spec: &SecuritySpec, source: IrsConfig, target: IrsConfig>, t_ood: usize, mask_oracle: Option, -) -> code_switch::Config { +) -> CodeSwitchConfig { let mode = mask_oracle.map_or(code_switch::Mode::Standard, |info| { let l_zk = info.l_zk.get(); assert!( @@ -44,17 +44,19 @@ pub fn solve( let target_bits = Bits::new(f64::from(spec.target_security_bits)); let analytic = analytic_error_bits(&source, &target, t_ood, mask_oracle); - let pow = proof_of_work::Config::grind_to(target_bits, analytic, spec.hash_id); + let pow = PowConfig::grind_to(target_bits, analytic, spec.hash_id); - code_switch::Config::new(source, target, t_ood, mode, pow) + CodeSwitchConfig::new(source, target, t_ood, mode, pow) } -/// Dominant soundness gap that PoW must close: `min(OOD term, combination term)`. +/// Per-round code-switch soundness in bits: `min(ood_term, combination_term)`. /// -/// - OOD (Lemma 9.9, term 1): `t_ood · (log|F| − log(degree − 1)) − log(L choose 2)`, -/// with `L = target × c_zk` (ZK) or `target` (Standard), and -/// `degree = ℓ + r + t_ood` (ZK) or `ℓ` (Standard). -/// - Combination (Bound 1, γ-RLC): `log|F| − log(t_ood + t·ι) − log|Λ(target)| − [log|Λ(C_zk)|]`. +/// - OOD (Lemma 9.9, term 1): `t_ood · (log|F| − log(deg − 1)) − log(L choose 2)` +/// - Combination (Bound 1, γ-RLC): `log|F| − log(count) − log L` +/// +/// `L = |Λ(target)| · |Λ(C_zk)|` (mask-oracle absent ⇒ `|Λ(C_zk)| = 1`). +/// `deg = ℓ + r + t_ood` in ZK, `ℓ` in Standard. +/// `count = t_ood + t · ι` (OOD samples + in-domain ι-interleaved source queries). /// /// `t_ood ≥ 1` per [`code_switch::Config::new`]. pub fn analytic_error_bits( @@ -64,111 +66,160 @@ pub fn analytic_error_bits( mask_oracle: Option, ) -> Bits { assert!(t_ood > 0, "code-switch requires t_ood ≥ 1"); + let field_bits = M::Target::field_size_bits(); - let target_list = target.list_size(); - let combined_list = mask_oracle.map_or(target_list, |info| target_list * info.c_zk_list_size); - let degree = mask_oracle.map_or_else( - || source.message_length(), - |_| source.masked_message_length() + t_ood, - ); + let combined_list = target.list_size() * mask_oracle.map_or(1.0, |info| info.c_zk_list_size); + let degree = match mask_oracle { + Some(_) => source.masked_message_length() + t_ood, + None => source.message_length(), + }; + #[allow(clippy::cast_precision_loss)] + let t_ood_f = t_ood as f64; + // OOD term — Lemma 9.9, term 1. #[allow(clippy::cast_precision_loss)] let log_degree_minus_1 = ((degree - 1) as f64).log2(); - let l_choose_2 = combined_list * (combined_list - 1.0) / 2.0; - #[allow(clippy::cast_precision_loss)] - let ood_term = (t_ood as f64) * (field_bits - log_degree_minus_1) - l_choose_2.log2(); + let log_l_choose_2 = (combined_list * (combined_list - 1.0) / 2.0).log2(); + let ood_term = t_ood_f * (field_bits - log_degree_minus_1) - log_l_choose_2; - // Combination term: counts OOD samples plus the in-domain batch - // (t source queries, each contributing one column of the ι-interleaved - // source codeword to the geometric_challenge RLC). - let count = t_ood + source.in_domain_samples * source.interleaving_depth; + // Combination term — Bound 1 (γ-RLC): `t_ood` OOD samples plus the + // in-domain batch of `t · ι` source columns, all RLC'd into one target + // codeword. #[allow(clippy::cast_precision_loss)] - let log_count = (count as f64).log2(); - let log_target_list = target_list.log2(); - let log_c_zk_list = mask_oracle.map_or(0.0, |info| info.c_zk_list_size.log2()); - let combination_term = field_bits - log_count - log_target_list - log_c_zk_list; + let log_count = ((t_ood + source.in_domain_samples * source.interleaving_depth) as f64).log2(); + let combination_term = field_bits - log_count - combined_list.log2(); Bits::new(ood_term.min(combination_term).max(0.0)) } +/// Number of `(r ‖ s)` mask polynomials code-switch contributes to C_zk per +/// round. Mirrors [`super::sumcheck::masks_required`]. +pub const fn masks_required() -> usize { + 1 +} + #[cfg(test)] +#[allow(clippy::float_cmp)] mod tests { use proptest::prelude::*; use super::*; use crate::protocols::params::{ - irs_commit as params_irs, - planner::{compute_l_zk, compute_t_ood}, - spec::{LogInvRate, Mode, OodSampleBudget, RoundContext, SecuritySpec}, + irs_commit as irs_solver, + derive::{compute_l_zk, compute_t_ood}, + spec::{LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, RoundContext, SecuritySpec}, test_utils::{ arb_standard_johnson_spec as utils_standard_spec, arb_zk_spec as utils_zk_spec, - deterministic_spec, TestEmbedding, TestExtensionField, TestNonIdentityEmbedding, + assert_pow_closes_gap, build_round_io, deterministic_spec, TestEmbedding, + TestExtensionField, TestField, TestNonIdentityEmbedding, TEST_TARGET_RANGE, }, }; type M = TestEmbedding; - // Keeps `target − error ≤ 60`, the cap `proof_of_work::threshold` enforces. - // On Field64 the γ-RLC combination term sits at ~0 bits in ZK and ~30 bits - // in Standard, so the gap to target must stay under 60. - const TEST_TARGET_RANGE: std::ops::RangeInclusive = 30..=50; - - fn arb_zk_spec() -> impl Strategy> { + fn arb_zk_spec() -> impl Strategy { utils_zk_spec(TEST_TARGET_RANGE) } - fn arb_standard_johnson_spec() -> impl Strategy> { + fn arb_standard_johnson_spec() -> impl Strategy { utils_standard_spec(TEST_TARGET_RANGE) } - /// Iterates target until `codeword_length` stabilizes — its realized rate - /// depends on `mask_length`, which depends on `t_ood`. - fn build_inputs( - spec: &SecuritySpec, - log_inv_rate: u32, - folding_factor: u32, - num_vars: u32, - c_zk_list_size: Option, - ) -> (IrsConfig, IrsConfig, usize) { - let source_ctx = RoundContext { - round_index: 0, - vector_size: 1usize << num_vars, - log_inv_rate, - folding_factor, - }; - let source = params_irs::solve(spec, &source_ctx, OodSampleBudget::new(0)); - - let target_ctx = RoundContext { - round_index: 1, - vector_size: source.message_length(), - log_inv_rate: log_inv_rate + folding_factor - 1, - folding_factor, - }; + const NUM_VARS_HEADROOM: u32 = 4; - let mut target = params_irs::solve(spec, &target_ctx, OodSampleBudget::new(0)); - for _ in 0..8 { - let t_ood = compute_t_ood(spec, &source, target.list_size(), c_zk_list_size); - let new_target = params_irs::solve(spec, &target_ctx, OodSampleBudget::new(t_ood)); - if new_target.codeword_length == target.codeword_length { - return (source, new_target, t_ood); - } - target = new_target; - } - panic!("target IRS did not stabilize"); - } - - /// `num_vars ≥ 2 * folding_factor` keeps target IRS valid. + /// `(log_inv_rate, folding_factor, num_vars)`. `num_vars ≥ 2 · folding_factor` + /// keeps target IRS valid. fn arb_dims() -> impl Strategy { (1u32..=3, 1u32..=2).prop_flat_map(|(log_inv_rate, folding_factor)| { let min_num_vars = 2 * folding_factor; ( Just(log_inv_rate), Just(folding_factor), - min_num_vars..=(min_num_vars + 4), + min_num_vars..=(min_num_vars + NUM_VARS_HEADROOM), ) }) } + const FORMULA_LOG_INV_RATE: u32 = 1; + const FORMULA_FOLDING_FACTOR: u32 = 2; + const FORMULA_NUM_VARS: u32 = 6; + + /// Standard OOD bound (Lemma 9.9 first term, no mask): + /// `min(t_ood · (log|F| − log(ℓ−1)) − log(L choose 2), combination)`. + /// `L = target.list_size()`, `combination = log|F| − log(t_ood + t·ι) − log L`. + #[test] + fn analytic_error_standard_formula() { + let spec: SecuritySpec = deterministic_spec(Mode::Standard); + let (source, target, t_ood) = build_round_io::( + &spec, + FORMULA_LOG_INV_RATE, + FORMULA_FOLDING_FACTOR, + FORMULA_NUM_VARS, + None, + ); + let got = f64::from(analytic_error_bits(&source, &target, t_ood, None)); + + let field_bits = ::field_size_bits(); + let target_list = target.list_size(); + let degree = source.message_length(); + let log_deg_m1 = ((degree - 1) as f64).log2(); + let l_choose_2 = target_list * (target_list - 1.0) / 2.0; + let ood = (t_ood as f64) * (field_bits - log_deg_m1) - l_choose_2.log2(); + let count = t_ood + source.in_domain_samples * source.interleaving_depth; + let comb = field_bits - (count as f64).log2() - target_list.log2(); + let expected = ood.min(comb).max(0.0); + + assert!( + (got - expected).abs() < 1e-9, + "got {got} vs expected {expected}", + ); + } + + /// ZK OOD bound: combined list `L = target × c_zk`, masked degree `ℓ + r + t_ood`, + /// combination term also subtracts `log|Λ(C_zk)|`. + #[test] + fn analytic_error_zk_formula() { + // Both mask-oracle values are pow2 so `log2` is exact (avoids + // floating-point drift in the expected-vs-got comparison). + const C_ZK_LIST_SIZE: f64 = 4.0; // log2 = 2 + const L_ZK_USIZE: usize = 8; // log2 = 3 + + let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); + let mask_oracle = MaskOracleInfo { + c_zk_list_size: C_ZK_LIST_SIZE, + l_zk: MaskCodeMessageLen::new(L_ZK_USIZE), + }; + let (source, target, t_ood) = build_round_io::( + &spec, + FORMULA_LOG_INV_RATE, + FORMULA_FOLDING_FACTOR, + FORMULA_NUM_VARS, + Some(C_ZK_LIST_SIZE), + ); + let got = f64::from(analytic_error_bits( + &source, + &target, + t_ood, + Some(mask_oracle), + )); + + let field_bits = ::field_size_bits(); + let target_list = target.list_size(); + let combined_list = target_list * C_ZK_LIST_SIZE; + let degree = source.masked_message_length() + t_ood; + let log_deg_m1 = ((degree - 1) as f64).log2(); + let l_choose_2 = combined_list * (combined_list - 1.0) / 2.0; + let ood = (t_ood as f64) * (field_bits - log_deg_m1) - l_choose_2.log2(); + let count = t_ood + source.in_domain_samples * source.interleaving_depth; + let comb = field_bits - (count as f64).log2() - target_list.log2() - C_ZK_LIST_SIZE.log2(); + let expected = ood.min(comb).max(0.0); + + assert!( + (got - expected).abs() < 1e-9, + "got {got} vs expected {expected}", + ); + } + proptest! { #[test] fn solve_standard_assembles( @@ -176,7 +227,7 @@ mod tests { (log_inv_rate, folding_factor, num_vars) in arb_dims(), ) { let (source, target, t_ood) = - build_inputs(&spec, log_inv_rate, folding_factor, num_vars, None); + build_round_io::(&spec, log_inv_rate, folding_factor, num_vars, None); let config = solve(&spec, source, target, t_ood, None); prop_assert!(matches!(config.mode, code_switch::Mode::Standard)); prop_assert!(config.out_domain_samples >= 1); @@ -195,24 +246,24 @@ mod tests { log_inv_rate, folding_factor, }; - let placeholder_source = params_irs::solve( + let placeholder_source = irs_solver::solve::( &spec, &placeholder_source_ctx, OodSampleBudget::new(0), ); - let c_zk_placeholder = params_irs::solve_mask_code( + let c_zk_placeholder = irs_solver::solve_mask_code::( &spec, compute_l_zk(&placeholder_source, 1), placeholder_source.mask_length(), LogInvRate::new(log_inv_rate), 2, ); - let (source, target, t_ood) = build_inputs( + let (source, target, t_ood) = build_round_io::( &spec, log_inv_rate, folding_factor, num_vars, Some(c_zk_placeholder.list_size()), ); let r = source.mask_length(); let l_zk = compute_l_zk(&source, t_ood); - let c_zk = params_irs::solve_mask_code( + let c_zk = irs_solver::solve_mask_code::( &spec, l_zk, r, @@ -237,25 +288,25 @@ mod tests { (log_inv_rate, folding_factor, num_vars) in arb_dims(), ) { let (source, target, t_ood) = - build_inputs(&spec, log_inv_rate, folding_factor, num_vars, None); - let config = solve(&spec, source.clone(), target.clone(), t_ood, None); - let error = f64::from(analytic_error_bits(&source, &target, t_ood, None)); - let pow_bits = f64::from(config.pow.difficulty()); - prop_assert!( - error + pow_bits >= f64::from(spec.target_security_bits) - 1e-3, - "error {} + pow {} < target {}", - error, pow_bits, spec.target_security_bits, - ); + build_round_io::(&spec, log_inv_rate, folding_factor, num_vars, None); + let error = analytic_error_bits(&source, &target, t_ood, None); + let config = solve(&spec, source, target, t_ood, None); + assert_pow_closes_gap(&spec, error, &config.pow); } } - /// Shared shape so Standard and ZK smoke tests differ only in mode. + /// Shared shape for the `M::Source ≠ M::Target` smoke tests. + /// `target_ctx` mirrors the planner's per-round chaining. fn non_identity_smoke_ctxs() -> (RoundContext, RoundContext) { + const SOURCE_VECTOR_SIZE: usize = 64; + const SOURCE_LOG_INV_RATE: u32 = 1; + const FOLDING_FACTOR: u32 = 2; + let source_ctx = RoundContext { round_index: 0, - vector_size: 64, - log_inv_rate: 1, - folding_factor: 2, + vector_size: SOURCE_VECTOR_SIZE, + log_inv_rate: SOURCE_LOG_INV_RATE, + folding_factor: FOLDING_FACTOR, }; let target_ctx = RoundContext { round_index: 1, @@ -269,18 +320,63 @@ mod tests { /// Smoke test: `M::Source ≠ M::Target`, Standard mode. #[test] fn solve_works_with_basefield_embedding_standard() { - let spec_source: SecuritySpec = - deterministic_spec(Mode::Standard); - let spec_target: SecuritySpec> = - deterministic_spec(Mode::Standard); + let spec: SecuritySpec = deterministic_spec(Mode::Standard); let (source_ctx, target_ctx) = non_identity_smoke_ctxs(); - let source = params_irs::solve(&spec_source, &source_ctx, OodSampleBudget::new(0)); + let source = irs_solver::solve::( + &spec, + &source_ctx, + OodSampleBudget::new(0), + ); // Standard target: codeword_length is t_ood-independent (mask = 0). - let target = params_irs::solve(&spec_target, &target_ctx, OodSampleBudget::new(0)); - let t_ood = compute_t_ood(&spec_source, &source, target.list_size(), None); + let target = irs_solver::solve::>( + &spec, + &target_ctx, + OodSampleBudget::new(0), + ); + let t_ood = compute_t_ood(&spec, &source, target.list_size(), None); - let config = solve(&spec_source, source, target, t_ood, None); + let config = solve(&spec, source, target, t_ood, None); assert!(matches!(config.mode, code_switch::Mode::Standard)); } + + /// Smoke test: `M::Source ≠ M::Target`, ZK mode. + #[test] + fn solve_works_with_basefield_embedding_zk() { + let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); + let (source_ctx, target_ctx) = non_identity_smoke_ctxs(); + + let c_zk_list_size = 4.0; + // `build_round_io` for the non-identity embedding. + let mut t_ood = 0; + let mut source = irs_solver::solve::( + &spec, + &source_ctx, + OodSampleBudget::new(0), + ); + let mut target = irs_solver::solve::>( + &spec, + &target_ctx, + OodSampleBudget::new(0), + ); + for _ in 0..8 { + let new_t_ood = compute_t_ood(&spec, &source, target.list_size(), Some(c_zk_list_size)); + if new_t_ood == t_ood { + break; + } + t_ood = new_t_ood; + source = irs_solver::solve(&spec, &source_ctx, OodSampleBudget::new(t_ood)); + target = irs_solver::solve(&spec, &target_ctx, OodSampleBudget::new(t_ood)); + } + + let mask_oracle = MaskOracleInfo { + c_zk_list_size, + l_zk: MaskCodeMessageLen::new((source.mask_length() + t_ood).next_power_of_two()), + }; + let config = solve(&spec, source, target, t_ood, Some(mask_oracle)); + assert!(matches!( + config.mode, + code_switch::Mode::ZeroKnowledge { .. } + )); + } } diff --git a/src/protocols/params/derive.rs b/src/protocols/params/derive.rs new file mode 100644 index 00000000..ddadc311 --- /dev/null +++ b/src/protocols/params/derive.rs @@ -0,0 +1,764 @@ +//! Derives a [`ProtocolConfig`] from a spec + tuning. +//! +//! All cross-protocol coordination lives here: per-round `t_ood ↔ r` and +//! `ℓ_zk ↔ c_zk` fixed-points, plus the per-round mask oracle (C_zk + +//! mask-proximity sized for `k + 1` masks). + +use crate::{ + algebra::{ + embedding::{Embedding, Identity}, + fields::FieldWithSize, + }, + protocols::{ + irs_commit::{self, Config as IrsConfig}, + params::{ + basecase as basecase_solver, + bounds::johnson_list_size, + code_switch as code_switch_solver, irs_commit as irs_solver, + mask_proximity as mask_proximity_solver, + protocol_config::{ + MaskOracleConfig, MaskOracleInfo, ProtocolConfig, RoundConfig, RoundMode, + }, + spec::{ + LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, RoundContext, SecuritySpec, + TuningSpec, + }, + sumcheck as sumcheck_solver, + }, + }, +}; + +impl ProtocolConfig { + /// In ZK each round owns its mask oracle; the `ℓ_zk ↔ c_zk ↔ t_ood` + /// fixed-point runs independently per round. + pub fn derive(spec: SecuritySpec, tuning: TuningSpec) -> Self { + let RoundLayout { + shapes, + basecase_vector_size, + basecase_log_inv_rate, + } = round_layout(&tuning); + + let rounds: Vec> = match spec.mode { + Mode::Standard => shapes + .iter() + .map(|shape| build_round_config::(&spec, shape, None)) + .collect(), + Mode::ZeroKnowledge => { + let c_zk_log_inv_rate = LogInvRate::new(tuning.starting_log_inv_rate); + shapes + .iter() + .map(|shape| build_zk_round_config::(&spec, shape, c_zk_log_inv_rate)) + .collect() + } + }; + + let basecase = basecase_solver::solve(&spec, basecase_vector_size, basecase_log_inv_rate); + + Self { + security: spec, + tuning, + rounds, + basecase, + } + } +} + +/// `target_folding_factor` is the next round's source folding — uniform +/// `tuning.folding_factor` — so `target_r → source_{r+1}` has matching +/// interleaving. +#[derive(Debug, Clone, Copy)] +struct RoundShape { + round_index: usize, + source_vector_size: usize, + source_log_inv_rate: u32, + source_folding_factor: u32, + target_folding_factor: u32, +} + +struct RoundLayout { + shapes: Vec, + basecase_vector_size: usize, + basecase_log_inv_rate: u32, +} + +struct RoundData { + source: IrsConfig, + target: IrsConfig>, + t_ood: usize, +} + +/// Stops when there's no room for both a valid source and a valid target IRS. +fn round_layout(tuning: &TuningSpec) -> RoundLayout { + assert!(tuning.vector_size.is_power_of_two()); + assert!(tuning.folding_factor.min() >= 1); + + let mut num_vars = tuning.vector_size.trailing_zeros() as usize; + let mut log_inv_rate = tuning.starting_log_inv_rate; + let mut shapes = Vec::new(); + + loop { + let round = shapes.len(); + let source_folding = tuning.folding_factor.at_round(round); + let target_folding = tuning.folding_factor.at_round(round + 1); + if num_vars < source_folding + target_folding { + break; + } + #[allow(clippy::cast_possible_truncation)] + shapes.push(RoundShape { + round_index: round, + source_vector_size: 1usize << num_vars, + source_log_inv_rate: log_inv_rate, + source_folding_factor: source_folding as u32, + target_folding_factor: target_folding as u32, + }); + num_vars -= source_folding; + #[allow(clippy::cast_possible_truncation)] + { + log_inv_rate += (source_folding as u32).saturating_sub(1); + } + } + + RoundLayout { + shapes, + basecase_vector_size: 1usize << num_vars, + basecase_log_inv_rate: log_inv_rate, + } +} + +const fn round_context(shape: &RoundShape) -> RoundContext { + RoundContext { + round_index: shape.round_index, + vector_size: shape.source_vector_size, + log_inv_rate: shape.source_log_inv_rate, + folding_factor: shape.source_folding_factor, + } +} + +fn target_context(shape: &RoundShape, source: &IrsConfig) -> RoundContext { + RoundContext { + round_index: shape.round_index, + vector_size: source.message_length(), + log_inv_rate: shape.source_log_inv_rate + shape.source_folding_factor.saturating_sub(1), + folding_factor: shape.target_folding_factor, + } +} + +/// Per-round ZK builder. C_zk holds `2 · (k + 1)` columns (Construction 7.2 +/// originals + fresh): `k` sumcheck masks (Lemma 6.4) + one `(r ‖ s)` +/// code-switch mask (Construction 9.7). `ℓ_zk = next_pow2(r + t_ood)` per +/// Lemma 9.3; `t_ood` solves Lemma 9.9 term 1. +fn build_zk_round_config( + spec: &SecuritySpec, + shape: &RoundShape, + c_zk_log_inv_rate: LogInvRate, +) -> RoundConfig { + let ctx = round_context(shape); + let num_masks = sumcheck_solver::masks_required(&ctx) + code_switch_solver::masks_required(); + // C_zk.list_size depends only on rate — no IRS build needed for it. + let c_zk_list_size = johnson_list_size(f64::from(c_zk_log_inv_rate.get())); + + let RoundData { + source, + target, + t_ood, + } = build_zk_round_data::(spec, shape, c_zk_list_size); + + let l_zk = compute_l_zk(&source, t_ood); + let c_zk: IrsConfig> = irs_solver::solve_mask_code( + spec, + l_zk, + source.mask_length(), + c_zk_log_inv_rate, + 2 * num_masks, + ); + let mask_oracle = MaskOracleConfig { + mask_proximity: mask_proximity_solver::solve(spec, c_zk.clone(), num_masks), + c_zk, + l_zk, + }; + let info = mask_oracle.info(); + + let sumcheck = sumcheck_solver::solve(spec, &ctx, &source, Some(info)); + let code_switch = code_switch_solver::solve(spec, source, target, t_ood, Some(info)); + RoundConfig { + round_index: shape.round_index, + sumcheck, + code_switch, + mode: RoundMode::ZeroKnowledge { + t_ood: OodSampleBudget::new(t_ood), + mask_oracle: info, + }, + mask_oracle: Some(mask_oracle), + } +} + +/// Local `t_ood ↔ r` fixed-point. `r = source.mask_length()` is a step function +/// of `t_ood` (`next_pow2(ℓ + q + t_ood) − ℓ`); the loop re-iterates only when +/// `t_ood` pushes `r` into the next pow-of-2 bucket. +fn build_zk_round_data( + spec: &SecuritySpec, + shape: &RoundShape, + c_zk_list_size: f64, +) -> RoundData { + const LOCAL_MAX_ITER: usize = 16; + + let src_ctx = round_context(shape); + let target_log_inv_rate = + f64::from(shape.source_log_inv_rate + shape.source_folding_factor.saturating_sub(1)); + let target_list_size = johnson_list_size(target_log_inv_rate); + + let mut t_ood = 0; + let mut source: IrsConfig = irs_solver::solve(spec, &src_ctx, OodSampleBudget::new(0)); + for _ in 0..LOCAL_MAX_ITER { + let new_t_ood = compute_t_ood(spec, &source, target_list_size, Some(c_zk_list_size)); + if new_t_ood == t_ood { + let target: IrsConfig> = irs_solver::solve( + spec, + &target_context(shape, &source), + OodSampleBudget::new(t_ood), + ); + return RoundData { + source, + target, + t_ood, + }; + } + t_ood = new_t_ood; + source = irs_solver::solve(spec, &src_ctx, OodSampleBudget::new(t_ood)); + } + + panic!("per-round ZK fixed-point did not converge"); +} + +fn build_round_config( + spec: &SecuritySpec, + shape: &RoundShape, + mask_oracle: Option, +) -> RoundConfig { + debug_assert!(mask_oracle.is_none(), "ZK path uses build_zk_round_config"); + + let src_ctx = round_context(shape); + let source: IrsConfig = irs_solver::solve(spec, &src_ctx, OodSampleBudget::new(0)); + let target: IrsConfig> = irs_solver::solve( + spec, + &target_context(shape, &source), + OodSampleBudget::new(0), + ); + let t_ood = compute_t_ood(spec, &source, target.list_size(), None); + + let sumcheck = sumcheck_solver::solve(spec, &src_ctx, &source, None); + let code_switch = code_switch_solver::solve(spec, source, target, t_ood, None); + RoundConfig { + round_index: shape.round_index, + sumcheck, + code_switch, + mode: RoundMode::Standard, + mask_oracle: None, + } +} + +/// `ℓ_zk = next_pow2(r + t_ood)` (Lemma 9.3). +pub(super) const fn compute_l_zk( + source: &IrsConfig, + t_ood: usize, +) -> MaskCodeMessageLen { + MaskCodeMessageLen::new((source.mask_length() + t_ood).next_power_of_two()) +} + +/// Solves Lemma 9.9 term 1 for `t_ood`. In ZK, `degree = ℓ + r + t_ood` +/// couples back to `t_ood`, so iterate. +pub(super) fn compute_t_ood( + spec: &SecuritySpec, + source: &IrsConfig, + target_list_size: f64, + c_zk_list_size: Option, +) -> usize { + const MAX_ITER: usize = 32; + + let security_target = spec.protocol_security_target_bits(); + let field_bits = M::Target::field_size_bits(); + let combined_list_size = target_list_size * c_zk_list_size.unwrap_or(1.0); + let message_length = source.message_length(); + + let solve_for_degree = |degree: usize| { + irs_commit::num_ood_samples( + false, + security_target, + field_bits, + combined_list_size, + degree, + ) + }; + + let mut t_ood = solve_for_degree(message_length); + if matches!(spec.mode, Mode::Standard) { + return t_ood; + } + + let r = source.mask_length(); + for _ in 0..MAX_ITER { + let new_t_ood = solve_for_degree(message_length + r + t_ood); + if new_t_ood == t_ood { + return t_ood; + } + t_ood = new_t_ood; + } + panic!("compute_t_ood did not converge in {MAX_ITER} iterations"); +} + +#[cfg(test)] +#[allow(clippy::float_cmp)] +mod tests { + use proptest::prelude::*; + + use super::*; + use crate::{ + hash, + protocols::params::{ + bounds::SoundnessBounded, + spec::FoldingFactor, + test_utils::{assert_pow_closes_gap, TestEmbedding}, + }, + }; + + /// Varied tuning space for proptests. Exercises both `FoldingFactor` + /// variants. Bounds keep PoW under the 60-bit cap and the IRS solver + /// inside Field64's reachable range. + fn arb_tuning() -> impl Strategy { + let folding = prop_oneof![ + (1usize..=3).prop_map(FoldingFactor::Constant), + (1usize..=3, 1usize..=3).prop_map(|(initial, rest)| { + FoldingFactor::ConstantFromSecondRound { initial, rest } + }), + ]; + (4u32..=8, 1u32..=3, folding).prop_map(|(log_size, log_inv_rate, folding_factor)| { + TuningSpec { + vector_size: 1usize << log_size, + starting_log_inv_rate: log_inv_rate, + folding_factor, + } + }) + } + + /// `tuning_with` uses `FoldingFactor::Constant(FIXTURE_FOLDING_FACTOR)` so + /// each round folds by 2. With `target_folding == source_folding == 2`, + /// `round_layout` keeps a round only while `num_vars ≥ 4`. + const FIXTURE_FOLDING_FACTOR: usize = 2; + const FIXTURE_LOG_INV_RATE: u32 = 1; + + /// `log_vector_size` chosen to be below `2 · FIXTURE_FOLDING_FACTOR`, so + /// `round_layout` exits before adding any round → basecase-only plan. + const LOG_VECTOR_SIZE_NO_ROUNDS: u32 = 3; + /// Large enough to produce multiple rounds under + /// `FIXTURE_FOLDING_FACTOR`-uniform folding; used by every multi-round test. + const LOG_VECTOR_SIZE_MULTI_ROUND: u32 = 8; + + /// Folding pair used by tests that need round-to-round folding variation + /// (rate stepping, target→source chaining). The two values must differ + /// from each other so the variation across rounds is observable. + const VARIED_INITIAL_FOLDING: usize = 3; + const VARIED_STEADY_FOLDING: usize = 2; + + fn tuning_with(vector_size: usize) -> TuningSpec { + TuningSpec { + vector_size, + starting_log_inv_rate: FIXTURE_LOG_INV_RATE, + folding_factor: FoldingFactor::Constant(FIXTURE_FOLDING_FACTOR), + } + } + + /// Planner-level tests build full `ProtocolConfig`s, so we use a lower target + /// than `test_utils::FIXTURE_TARGET_BITS` (= 80). Keeps PoW below the 60-bit + /// cap when every sub-protocol grinds individually. 40 leaves + /// `target − analytic_error ≤ 60` on `Field64`. + const PLAN_FIXTURE_TARGET_BITS: u32 = 40; + + fn test_spec(mode: Mode) -> SecuritySpec { + SecuritySpec { + mode, + target_security_bits: PLAN_FIXTURE_TARGET_BITS, + max_pow_bits: None, + hash_id: hash::BLAKE3, + } + } + + /// `> 1` so the first round's rate is distinct from the boundary. + const RATE_STEPPING_STARTING_LOG_INV_RATE: u32 = 2; + /// Pairwise `windows(2)` chaining check needs ≥ 2 rounds. + const MIN_ROUNDS_FOR_CHAINING_TEST: usize = 2; + + /// Each round's source rate steps up by `source_folding - 1`. The basecase + /// inherits the rate after the final round. Uses varied folding so the + /// per-round step is non-uniform (initial step = 2, steady step = 1). + #[test] + fn round_layout_rate_steps_up_by_folding_minus_one() { + let tuning = TuningSpec { + vector_size: 1 << LOG_VECTOR_SIZE_MULTI_ROUND, + starting_log_inv_rate: RATE_STEPPING_STARTING_LOG_INV_RATE, + folding_factor: FoldingFactor::ConstantFromSecondRound { + initial: VARIED_INITIAL_FOLDING, + rest: VARIED_STEADY_FOLDING, + }, + }; + let layout = round_layout(&tuning); + + let mut expected_log_inv_rate = RATE_STEPPING_STARTING_LOG_INV_RATE; + for shape in &layout.shapes { + assert_eq!(shape.source_log_inv_rate, expected_log_inv_rate); + expected_log_inv_rate += shape.source_folding_factor.saturating_sub(1); + } + assert_eq!(layout.basecase_log_inv_rate, expected_log_inv_rate); + } + + /// Cross-round chaining: round `i`'s target folding factor must match + /// round `i+1`'s source folding factor (the doc-comment on `RoundShape` + /// codifies this). Varied folding makes the check non-vacuous. + #[test] + fn round_layout_chains_target_to_next_source_folding() { + let tuning = TuningSpec { + vector_size: 1 << LOG_VECTOR_SIZE_MULTI_ROUND, + starting_log_inv_rate: FIXTURE_LOG_INV_RATE, + folding_factor: FoldingFactor::ConstantFromSecondRound { + initial: VARIED_INITIAL_FOLDING, + rest: VARIED_STEADY_FOLDING, + }, + }; + let layout = round_layout(&tuning); + assert!( + layout.shapes.len() >= MIN_ROUNDS_FOR_CHAINING_TEST, + "need ≥ {MIN_ROUNDS_FOR_CHAINING_TEST} rounds to test chaining", + ); + for window in layout.shapes.windows(2) { + assert_eq!( + window[0].target_folding_factor, + window[1].source_folding_factor + ); + } + } + + /// Basecase consumes whatever `num_vars` the round loop left behind: + /// `basecase_vector_size = 2^(initial_num_vars - sum(source_folding_factor))`. + #[test] + fn round_layout_basecase_size_consumes_remaining_num_vars() { + let tuning = tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND); + let layout = round_layout(&tuning); + let consumed: u32 = layout.shapes.iter().map(|s| s.source_folding_factor).sum(); + let initial_num_vars = tuning.vector_size.trailing_zeros(); + let remaining = initial_num_vars - consumed; + assert_eq!(layout.basecase_vector_size, 1usize << remaining); + } + + /// Loop exits when `num_vars < source_folding + target_folding`. Below the + /// `2 · FIXTURE_FOLDING_FACTOR` threshold, no round is admitted and the + /// basecase carries the whole vector at the starting rate. + #[test] + fn round_layout_stops_when_no_room_for_source_plus_target() { + let vector_size = 1usize << LOG_VECTOR_SIZE_NO_ROUNDS; + let tuning = tuning_with(vector_size); + let layout = round_layout(&tuning); + assert!(layout.shapes.is_empty()); + assert_eq!(layout.basecase_vector_size, vector_size); + assert_eq!(layout.basecase_log_inv_rate, FIXTURE_LOG_INV_RATE); + } + + #[test] + fn derive_standard_with_no_rounds_uses_basecase_only() { + let spec = test_spec(Mode::Standard); + let vector_size = 1usize << LOG_VECTOR_SIZE_NO_ROUNDS; + let plan = ProtocolConfig::::derive(spec, tuning_with(vector_size)); + assert!(plan.rounds.is_empty()); + assert_eq!(plan.basecase.commit.vector_size, vector_size); + } + + /// ZK with zero WHIR rounds = ZK basecase only. Per-round mask oracles are + /// absent (there are no rounds); the basecase γ-slot PoW carries soundness. + #[test] + fn derive_zk_with_no_rounds_uses_zk_basecase_only() { + let spec = test_spec(Mode::ZeroKnowledge); + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_NO_ROUNDS), + ); + assert!(plan.rounds.is_empty()); + assert!(matches!( + plan.basecase.mode, + crate::protocols::basecase::Mode::ZeroKnowledge + )); + } + + /// Lemma 9.9 fixed-point: every ZK round needs at least one OOD challenge. + #[test] + fn compute_t_ood_nonzero_in_zk() { + let spec = test_spec(Mode::ZeroKnowledge); + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ); + for r in &plan.rounds { + let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode else { + panic!("expected ZK round") + }; + assert!(t_ood.get() >= 1); + } + } + + fn basecase_min_bits(plan: &ProtocolConfig) -> f64 { + let sumcheck = f64::from(sumcheck_solver::analytic_error_bits( + &plan.basecase.commit, + None, + )); + if matches!( + plan.basecase.mode, + crate::protocols::basecase::Mode::ZeroKnowledge + ) { + sumcheck.min(f64::from(basecase_solver::analytic_error_bits( + &plan.basecase.commit, + ))) + } else { + sumcheck + } + } + + #[test] + fn analytic_bits_finite_and_positive_standard() { + let spec = test_spec(Mode::Standard); + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ); + let bits: f64 = plan.analytic_bits().into(); + assert!(bits.is_finite() && bits > 0.0, "bits = {bits}"); + let min_round = plan + .rounds + .iter() + .map(|r| f64::from(r.analytic_bits())) + .fold(f64::INFINITY, f64::min); + let expected = min_round.min(basecase_min_bits(&plan)); + assert!((bits - expected).abs() < 1e-9, "{bits} vs {expected}"); + } + + #[test] + fn analytic_bits_includes_mask_oracle_in_zk() { + let spec = test_spec(Mode::ZeroKnowledge); + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ); + let plan_bits: f64 = plan.analytic_bits().into(); + let mo_floor = plan + .rounds + .iter() + .filter_map(|r| { + r.mask_oracle + .as_ref() + .map(|mo| f64::from(mo.analytic_bits())) + }) + .fold(f64::INFINITY, f64::min); + assert!( + mo_floor.is_finite(), + "ZK plan must contribute mask-oracle bits" + ); + let min_round = plan + .rounds + .iter() + .map(|r| f64::from(r.analytic_bits())) + .fold(f64::INFINITY, f64::min); + let expected = mo_floor.min(min_round).min(basecase_min_bits(&plan)); + assert!( + (plan_bits - expected).abs() < 1e-9, + "{plan_bits} vs {expected}" + ); + } + + #[test] + fn derive_plans_basecase() { + let spec = test_spec(Mode::ZeroKnowledge); + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ); + assert!(matches!( + plan.basecase.mode, + crate::protocols::basecase::Mode::ZeroKnowledge + )); + assert_eq!(plan.basecase.commit.interleaving_depth, 1); + // Sumcheck folds basecase to size 1. + assert_eq!(plan.basecase.sumcheck.final_size(), 1); + } + + /// Matches `proof_of_work::threshold`'s 60-bit cap. + const LOOSE_POW_BUDGET_BITS: u32 = 60; + /// Below any realistic analytic gap; forces `check_pow_bits` to reject + /// the injected slot in the negative test. + const TIGHT_POW_BUDGET_BITS: u32 = 10; + /// Comfortably above `TIGHT_POW_BUDGET_BITS`. + const OVER_BUDGET_INJECTED_BITS: f64 = 50.0; + + /// Derived plans must satisfy their own `max_pow_bits` budget. + #[test] + fn check_pow_bits_passes_on_derived_plan() { + let spec = SecuritySpec { + mode: Mode::ZeroKnowledge, + target_security_bits: PLAN_FIXTURE_TARGET_BITS, + max_pow_bits: Some(LOOSE_POW_BUDGET_BITS), + hash_id: hash::BLAKE3, + }; + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ); + assert!(plan.check_pow_bits()); + } + + /// Hand-injected over-budget PoW slot fails the check. + #[test] + fn check_pow_bits_detects_over_budget_slot() { + use crate::{bits::Bits, protocols::proof_of_work::Config as PowConfig}; + let spec = SecuritySpec { + mode: Mode::ZeroKnowledge, + target_security_bits: PLAN_FIXTURE_TARGET_BITS, + max_pow_bits: Some(TIGHT_POW_BUDGET_BITS), + hash_id: hash::BLAKE3, + }; + let mut plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ); + plan.basecase.pow = PowConfig::from_difficulty(Bits::new(OVER_BUDGET_INJECTED_BITS)); + assert!(!plan.check_pow_bits()); + } + + /// `analytic_error + pow ≥ target` for every PoW slot in the plan. + fn assert_plan_meets_target_per_slot( + spec: &SecuritySpec, + plan: &ProtocolConfig, + ) { + for r in &plan.rounds { + let mask_info = r.mode.mask_oracle(); + assert_pow_closes_gap( + spec, + sumcheck_solver::analytic_error_bits(&r.code_switch.source, mask_info), + &r.sumcheck.round_pow, + ); + assert_pow_closes_gap( + spec, + code_switch_solver::analytic_error_bits( + &r.code_switch.source, + &r.code_switch.target, + r.code_switch.out_domain_samples, + mask_info, + ), + &r.code_switch.pow, + ); + if let Some(mo) = &r.mask_oracle { + assert_pow_closes_gap( + spec, + mask_proximity_solver::analytic_error_bits( + &mo.mask_proximity.c_zk_commit, + mo.mask_proximity.num_masks, + ), + &mo.mask_proximity.pow, + ); + } + } + assert_pow_closes_gap( + spec, + sumcheck_solver::analytic_error_bits(&plan.basecase.commit, None), + &plan.basecase.sumcheck.round_pow, + ); + // γ-slot is ZK-only. + if matches!( + plan.basecase.mode, + crate::protocols::basecase::Mode::ZeroKnowledge + ) { + assert_pow_closes_gap( + spec, + basecase_solver::analytic_error_bits(&plan.basecase.commit), + &plan.basecase.pow, + ); + } + } + + proptest! { + /// End-to-end soundness (Standard): every PoW slot in the derived plan + /// closes the gap `analytic + pow ≥ target` against the spec target. + #[test] + fn derived_plan_meets_target_per_slot_standard(tuning in arb_tuning()) { + let spec = test_spec(Mode::Standard); + let plan = ProtocolConfig::::derive(spec.clone(), tuning); + assert_plan_meets_target_per_slot(&spec, &plan); + } + + /// End-to-end soundness (ZK): same as above, plus the per-round + /// mask-proximity slot and the basecase γ-slot. + #[test] + fn derived_plan_meets_target_per_slot_zk(tuning in arb_tuning()) { + let log_threshold = + tuning.folding_factor.at_round(0) + tuning.folding_factor.at_round(1); + prop_assume!(tuning.vector_size.trailing_zeros() as usize >= log_threshold); + let spec = test_spec(Mode::ZeroKnowledge); + let plan = ProtocolConfig::::derive(spec.clone(), tuning); + assert_plan_meets_target_per_slot(&spec, &plan); + } + + /// Standard mode: derive succeeds for any tuning shape, no per-round + /// mask oracle, and basecase covers the post-fold tail. + #[test] + fn derive_standard_succeeds_over_tunings(tuning in arb_tuning()) { + let spec = test_spec(Mode::Standard); + let plan = ProtocolConfig::::derive(spec, tuning); + for r in &plan.rounds { + prop_assert!(matches!(r.mode, RoundMode::Standard)); + prop_assert!(r.mask_oracle.is_none()); + } + prop_assert!(matches!( + plan.basecase.mode, + crate::protocols::basecase::Mode::Standard + )); + prop_assert_eq!(plan.basecase.commit.interleaving_depth, 1); + } + + /// ZK mode: each round has its own mask oracle sized for `k + 1` + /// masks; basecase is ZK-flagged when shapes are non-empty. + #[test] + fn derive_zk_succeeds_over_tunings(tuning in arb_tuning()) { + let log_threshold = + tuning.folding_factor.at_round(0) + tuning.folding_factor.at_round(1); + prop_assume!(tuning.vector_size.trailing_zeros() as usize >= log_threshold); + + let spec = test_spec(Mode::ZeroKnowledge); + let plan = ProtocolConfig::::derive(spec, tuning); + for r in &plan.rounds { + let mask_oracle = r + .mask_oracle + .as_ref() + .expect("ZK round must have a mask oracle"); + let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode else { + panic!("expected ZK round"); + }; + let k = r.code_switch.source.interleaving_depth.trailing_zeros() as usize; + let num_masks = k + 1; + prop_assert_eq!(mask_oracle.c_zk.num_vectors, 2 * num_masks); + prop_assert_eq!(mask_oracle.mask_proximity.num_masks, num_masks); + // Bound 3 (Lemma 9.3): ℓ_zk ≥ r + t_ood for this round. + let source_mask = r.code_switch.source.mask_length(); + prop_assert!(mask_oracle.l_zk.get() >= source_mask + t_ood.get()); + } + prop_assert!(matches!( + plan.basecase.mode, + crate::protocols::basecase::Mode::ZeroKnowledge + )); + } + + /// `analytic_bits` is finite and non-negative for any tuning the + /// planner accepts in Standard mode. + #[test] + fn analytic_bits_finite_and_non_negative_standard(tuning in arb_tuning()) { + let spec = test_spec(Mode::Standard); + let plan = ProtocolConfig::::derive(spec, tuning); + let analytic = f64::from(plan.analytic_bits()); + prop_assert!(analytic.is_finite()); + prop_assert!(analytic >= 0.0); + } + } +} diff --git a/src/protocols/params/irs_commit.rs b/src/protocols/params/irs_commit.rs index 0f82ea31..9dc18055 100644 --- a/src/protocols/params/irs_commit.rs +++ b/src/protocols/params/irs_commit.rs @@ -6,20 +6,23 @@ use std::num::NonZeroUsize; use crate::{ algebra::embedding::Embedding, protocols::{ - irs_commit::{self, num_in_domain_queries, IrsMode}, - params::spec::{ - LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, RoundContext, SecuritySpec, + irs_commit::{num_in_domain_queries, Config as IrsConfig, IrsMode}, + params::{ + bounds::rate, + spec::{ + LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, RoundContext, SecuritySpec, + }, }, }, }; pub fn solve( - spec: &SecuritySpec, + spec: &SecuritySpec, ctx: &RoundContext, out_domain_samples: OodSampleBudget, -) -> irs_commit::Config { +) -> IrsConfig { let security_target = spec.protocol_security_target_bits(); - let rate = 2_f64.powf(-f64::from(ctx.log_inv_rate)); + let rate = rate(f64::from(ctx.log_inv_rate)); let interleaving_depth = 1_usize << ctx.folding_factor; // Construction 9.7 is Johnson-only — `Mode` cannot express unique-decoding. let unique_decoding = false; @@ -31,19 +34,19 @@ pub fn solve( let min_mask = num_in_domain_queries(unique_decoding, security_target, rate) .checked_add(out_domain_samples.get()) .expect("usize overflow"); - // Lemma 9.5 is `≥`, so pow2 padding is safe. - let mask_length = message_length + // Lemma 9.5: mask covers in-domain + OOD queries. + // Pad masked length to a pow2 for NTT (the lemma is `≥`, so padding is safe). + let masked_message_length = message_length .checked_add(min_mask.get()) - .expect("usize overflow") - .next_power_of_two() - .checked_sub(message_length) - .and_then(NonZeroUsize::new) - .expect("mask_length non-zero in ZK"); + .expect("masked_message_length overflow") + .next_power_of_two(); + let mask_length = NonZeroUsize::new(masked_message_length - message_length) + .expect("min_mask ≥ 1 (NonZeroUsize) ⇒ next_pow2(ℓ + min_mask) > ℓ"); IrsMode::ZeroKnowledge { mask_length } } }; - irs_commit::Config::new( + IrsConfig::new( security_target, unique_decoding, spec.hash_id, @@ -61,12 +64,12 @@ pub fn solve( /// - `source_mask_length`: `r` from Theorem 9.6. /// - `num_vectors`: `2 * num_masks` (Construction 7.2: originals + fresh). pub fn solve_mask_code( - spec: &SecuritySpec, + spec: &SecuritySpec, l_zk: MaskCodeMessageLen, source_mask_length: usize, log_inv_rate: LogInvRate, num_vectors: usize, -) -> irs_commit::Config { +) -> IrsConfig { let l_zk = l_zk.get(); assert!( matches!(spec.mode, Mode::ZeroKnowledge), @@ -83,9 +86,9 @@ pub fn solve_mask_code( ); let security_target = spec.protocol_security_target_bits(); - let rate = 2_f64.powf(-f64::from(log_inv_rate.get())); + let rate = rate(f64::from(log_inv_rate.get())); - irs_commit::Config::new( + IrsConfig::new( security_target, false, // ZK ⇒ Johnson regime spec.hash_id, @@ -104,6 +107,7 @@ mod tests { use super::*; use crate::protocols::params::test_utils::{ arb_round_ctx, arb_spec, arb_zk_spec, deterministic_spec, TestEmbedding, + TestNonIdentityEmbedding, }; type M = TestEmbedding; @@ -111,36 +115,36 @@ mod tests { #[test] #[should_panic(expected = "C_zk only exists in ZK mode")] fn solve_mask_code_rejects_standard_spec() { - let spec: SecuritySpec = deterministic_spec(Mode::Standard); - let _ = solve_mask_code(&spec, MaskCodeMessageLen::new(2), 0, LogInvRate::new(1), 2); + let spec: SecuritySpec = deterministic_spec(Mode::Standard); + let _ = solve_mask_code::(&spec, MaskCodeMessageLen::new(2), 0, LogInvRate::new(1), 2); } #[test] #[should_panic(expected = "must be a power of 2")] fn solve_mask_code_rejects_non_pow2_l_zk() { - let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); - let _ = solve_mask_code(&spec, MaskCodeMessageLen::new(3), 0, LogInvRate::new(1), 2); + let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); + let _ = solve_mask_code::(&spec, MaskCodeMessageLen::new(3), 0, LogInvRate::new(1), 2); } #[test] #[should_panic(expected = "Theorem 9.6")] fn solve_mask_code_rejects_l_zk_below_source_mask_length() { - let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); - let _ = solve_mask_code(&spec, MaskCodeMessageLen::new(2), 4, LogInvRate::new(1), 2); + let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); + let _ = solve_mask_code::(&spec, MaskCodeMessageLen::new(2), 4, LogInvRate::new(1), 2); } #[test] #[should_panic(expected = "must be even")] fn solve_mask_code_rejects_odd_num_vectors() { - let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); - let _ = solve_mask_code(&spec, MaskCodeMessageLen::new(2), 0, LogInvRate::new(1), 3); + let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); + let _ = solve_mask_code::(&spec, MaskCodeMessageLen::new(2), 0, LogInvRate::new(1), 3); } - fn arb_zk_spec_default() -> impl Strategy> { + fn arb_zk_spec_default() -> impl Strategy { arb_zk_spec(80..=128) } - fn arb_standard_spec() -> impl Strategy> { + fn arb_standard_spec() -> impl Strategy { arb_spec(Mode::Standard, 80..=128) } @@ -152,7 +156,7 @@ mod tests { ctx in arb_round_ctx(), out_domain in 0usize..16, ) { - let config = solve(&spec, &ctx, OodSampleBudget::new(out_domain)); + let config = solve::(&spec, &ctx, OodSampleBudget::new(out_domain)); prop_assert!( config.mask_length() >= config.in_domain_samples + out_domain, "mask {} < in_domain {} + out_domain {}", @@ -166,8 +170,25 @@ mod tests { ctx in arb_round_ctx(), out_domain in 0usize..8, ) { - let config = solve(&spec, &ctx, OodSampleBudget::new(out_domain)); + let config = solve::(&spec, &ctx, OodSampleBudget::new(out_domain)); prop_assert_eq!(config.mask_length(), 0); } } + + /// Smoke test: `M::Source ≠ M::Target`, ZK path. Mask sizing depends only + /// on the target field (via `field_size_bits`), but the generic embedding + /// still flows through the Config and must compile + execute. + #[test] + fn solve_works_with_basefield_embedding_zk() { + let spec = deterministic_spec(Mode::ZeroKnowledge); + let ctx = RoundContext { + round_index: 0, + vector_size: 64, + log_inv_rate: 1, + folding_factor: 2, + }; + let config: IrsConfig = + solve(&spec, &ctx, OodSampleBudget::new(2)); + assert!(config.mask_length() > 0); + } } diff --git a/src/protocols/params/mask_proximity.rs b/src/protocols/params/mask_proximity.rs index 48f671f9..9c7236a2 100644 --- a/src/protocols/params/mask_proximity.rs +++ b/src/protocols/params/mask_proximity.rs @@ -7,21 +7,22 @@ use crate::{ algebra::{embedding::Identity, fields::FieldWithSize}, bits::Bits, protocols::{ - irs_commit::Config as IrsConfig, mask_proximity, params::spec::SecuritySpec, proof_of_work, + irs_commit::Config as IrsConfig, mask_proximity::Config as MaskProximityConfig, + params::spec::SecuritySpec, proof_of_work::Config as PowConfig, }, }; /// `c_zk.num_vectors` must equal `2 * num_masks` (originals + fresh). /// PoW closes the Lemma 7.4 γ-combination gap to `spec.target_security_bits`. pub fn solve( - spec: &SecuritySpec>, + spec: &SecuritySpec, c_zk: IrsConfig>, num_masks: usize, -) -> mask_proximity::Config { +) -> MaskProximityConfig { let target_bits = Bits::new(f64::from(spec.target_security_bits)); let analytic = analytic_error_bits(&c_zk, num_masks); - let pow = proof_of_work::Config::grind_to(target_bits, analytic, spec.hash_id); - mask_proximity::Config::new(c_zk, num_masks, pow) + let pow = PowConfig::grind_to(target_bits, analytic, spec.hash_id); + MaskProximityConfig::new(c_zk, num_masks, pow) } /// γ-combination soundness (Lemma 7.4): @@ -38,6 +39,7 @@ pub fn analytic_error_bits(c_zk: &IrsConfig>, num_masks: u } #[cfg(test)] +#[allow(clippy::float_cmp)] mod tests { use proptest::prelude::*; @@ -48,15 +50,59 @@ mod tests { protocols::{ irs_commit::IrsMode, params::{ - irs_commit as params_irs, + irs_commit as irs_solver, spec::{LogInvRate, MaskCodeMessageLen, Mode}, - test_utils::{arb_zk_spec, deterministic_spec, TestEmbedding}, + test_utils::{ + arb_zk_spec, assert_pow_closes_gap, deterministic_spec, TestEmbedding, + TEST_TARGET_RANGE, + }, }, }, }; - // Keeps `target − error ≤ 60`, the cap `proof_of_work::threshold` enforces. - const TEST_TARGET_RANGE: std::ops::RangeInclusive = 30..=50; + /// γ-combination (Lemma 7.4): `log|F| − log(num_masks · (deg − 1))`, + /// `deg = c_zk.masked_message_length()`. With `num_masks = 0` or `deg ≤ 1` + /// the bound saturates to `field_bits`. + #[test] + fn analytic_error_formula() { + let spec = deterministic_spec(Mode::ZeroKnowledge); + let num_masks = 3_usize; + let c_zk = irs_solver::solve_mask_code::( + &spec, + MaskCodeMessageLen::new(8), + 0, + LogInvRate::new(1), + 2 * num_masks, + ); + + let got = f64::from(analytic_error_bits(&c_zk, num_masks)); + + let field_bits = ::field_size_bits(); + let deg = c_zk.masked_message_length(); + let log_combined = ((num_masks * (deg - 1)) as f64).log2(); + let expected = (field_bits - log_combined).max(0.0); + + assert!( + (got - expected).abs() < 1e-9, + "got {got} vs expected {expected}", + ); + } + + /// Degenerate inputs (`num_masks == 0` or `deg ≤ 1`) saturate to `field_bits`. + #[test] + fn analytic_error_saturates_when_no_masks() { + let spec = deterministic_spec(Mode::ZeroKnowledge); + let c_zk = irs_solver::solve_mask_code::( + &spec, + MaskCodeMessageLen::new(2), + 0, + LogInvRate::new(1), + 2, + ); + let bits = f64::from(analytic_error_bits(&c_zk, 0)); + let field_bits = ::field_size_bits(); + assert_eq!(bits, field_bits.max(0.0)); + } proptest! { #[test] @@ -67,7 +113,7 @@ mod tests { l_zk_log in 1u32..=5, ) { let l_zk = MaskCodeMessageLen::new(1usize << l_zk_log); - let c_zk = params_irs::solve_mask_code( + let c_zk = irs_solver::solve_mask_code::( &spec, l_zk, 0, @@ -89,29 +135,24 @@ mod tests { l_zk_log in 1u32..=5, ) { let l_zk = MaskCodeMessageLen::new(1usize << l_zk_log); - let c_zk = params_irs::solve_mask_code( + let c_zk = irs_solver::solve_mask_code::( &spec, l_zk, 0, LogInvRate::new(log_inv_rate), 2 * num_masks, ); - let analytic = f64::from(analytic_error_bits(&c_zk, num_masks)); + let analytic = analytic_error_bits(&c_zk, num_masks); let config = solve(&spec, c_zk, num_masks); - let pow_bits = f64::from(config.pow.difficulty()); - prop_assert!( - analytic + pow_bits >= f64::from(spec.target_security_bits) - 1e-3, - "analytic {} + pow {} < target {}", - analytic, pow_bits, spec.target_security_bits, - ); + assert_pow_closes_gap(&spec, analytic, &config.pow); } } #[test] #[should_panic(expected = "c_zk.num_vectors must be 2 * num_masks")] fn solve_rejects_mismatched_num_vectors() { - let spec = deterministic_spec::(Mode::ZeroKnowledge); - let c_zk = params_irs::solve_mask_code( + let spec = deterministic_spec(Mode::ZeroKnowledge); + let c_zk = irs_solver::solve_mask_code::( &spec, MaskCodeMessageLen::new(2), 0, @@ -124,17 +165,27 @@ mod tests { #[test] #[should_panic(expected = "interleaving_depth = 1")] fn solve_rejects_non_unit_interleaving() { - let spec = deterministic_spec::(Mode::ZeroKnowledge); - let c_zk = crate::protocols::irs_commit::Config::>::new( - 80.0, - false, + // All values except `NON_UNIT_INTERLEAVING_DEPTH` are chosen to satisfy + // `Config::new`'s divisibility/pow2 constraints. + const SECURITY_TARGET_BITS: f64 = 80.0; + const UNIQUE_DECODING: bool = false; + const NUM_VECTORS: usize = 2; + const VECTOR_SIZE: usize = 8; + const NON_UNIT_INTERLEAVING_DEPTH: usize = 2; + const RATE: f64 = 0.5; + const NUM_MASKS: usize = 1; + + let spec = deterministic_spec(Mode::ZeroKnowledge); + let c_zk = IrsConfig::>::new( + SECURITY_TARGET_BITS, + UNIQUE_DECODING, hash::BLAKE3, - 2, - 8, - 2, // interleaving_depth ≠ 1 — triggers the panic - 0.5, + NUM_VECTORS, + VECTOR_SIZE, + NON_UNIT_INTERLEAVING_DEPTH, + RATE, IrsMode::Standard, ); - let _ = solve(&spec, c_zk, 1); + let _ = solve(&spec, c_zk, NUM_MASKS); } } diff --git a/src/protocols/params/mod.rs b/src/protocols/params/mod.rs index 4a56c6e6..26ecd371 100644 --- a/src/protocols/params/mod.rs +++ b/src/protocols/params/mod.rs @@ -1,10 +1,10 @@ pub mod basecase; pub(crate) mod bounds; pub mod code_switch; +pub mod derive; pub mod irs_commit; pub mod mask_proximity; -pub mod plan; -pub mod planner; +pub mod protocol_config; pub mod spec; pub mod sumcheck; diff --git a/src/protocols/params/planner.rs b/src/protocols/params/planner.rs deleted file mode 100644 index c765ba6e..00000000 --- a/src/protocols/params/planner.rs +++ /dev/null @@ -1,719 +0,0 @@ -//! Derives a [`ParameterPlan`] from a spec + tuning. All cross-protocol -//! coordination — per-round loop, `t_ood ↔ r` and `ℓ_zk ↔ c_zk` fixed-points, -//! shared C_zk + mask-proximity — lives here. - -use std::marker::PhantomData; - -use crate::{ - algebra::{ - embedding::{Embedding, Identity}, - fields::FieldWithSize, - }, - protocols::{ - irs_commit::{self, Config as IrsConfig}, - params::{ - basecase as bc_solver, code_switch as cs_solver, irs_commit as irs_solver, - mask_proximity as mp_solver, - plan::{ - MaskOracleInfo, MaskOraclePlan, ParameterPlan, RoundMode, RoundPlan, SharedPlan, - }, - spec::{ - LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, RoundContext, SecuritySpec, - TuningSpec, - }, - sumcheck as sc_solver, - }, - }, -}; - -const L_ZK_MAX_ITER: usize = 16; -/// Smallest pow2 ≥ 1 — satisfies `solve_mask_code`'s pow2 assertion. -const L_ZK_BOOTSTRAP: usize = 2; - -impl ParameterPlan { - /// ZK mode runs a global ℓ_zk fixed-point so one C_zk covers every round. - pub fn derive(spec: SecuritySpec, tuning: TuningSpec) -> Self { - let RoundLayout { - shapes, - basecase_vector_size, - basecase_log_inv_rate, - } = round_layout(&tuning); - let target_spec = transfer_spec_to_target(&spec); - - let (rounds, mask_oracle) = match spec.mode { - Mode::Standard => { - let rounds = shapes - .iter() - .map(|shape| build_round(&spec, shape, None)) - .collect(); - (rounds, None) - } - Mode::ZeroKnowledge => { - let SharedMaskOracleData { - info, - round_data, - plan, - } = build_shared_mask_oracle(&spec, &target_spec, &tuning, &shapes); - let rounds = shapes - .iter() - .zip(round_data) - .map(|(shape, data)| finalize_zk_round(&spec, shape, data, info)) - .collect(); - (rounds, Some(plan)) - } - }; - - let basecase = bc_solver::solve(&target_spec, basecase_vector_size, basecase_log_inv_rate); - - Self { - security: spec, - tuning, - shared: SharedPlan { mask_oracle }, - rounds, - basecase, - } - } -} - -// Round layout -// --------------------------------------------------------------------------- - -/// `target_folding_factor` is the next round's source folding — uniform -/// `tuning.folding_factor` — so `target_r → source_{r+1}` has matching -/// interleaving. -#[derive(Debug, Clone, Copy)] -struct RoundShape { - round_index: usize, - source_vector_size: usize, - source_log_inv_rate: u32, - source_folding_factor: u32, - target_folding_factor: u32, -} - -/// Shapes plus the basecase tail (size + rate of the message after every fold). -struct RoundLayout { - shapes: Vec, - basecase_vector_size: usize, - basecase_log_inv_rate: u32, -} - -struct RoundData { - source: IrsConfig, - target: IrsConfig>, - t_ood: usize, -} - -/// Output of the ZK global ℓ_zk ↔ C_zk fixed-point: the slim `info` view used -/// by per-round builders, the materialised per-round IRS/t_ood, and the full -/// shared `MaskOraclePlan` to embed in the final plan. -struct SharedMaskOracleData { - info: MaskOracleInfo, - round_data: Vec>, - plan: MaskOraclePlan, -} - -/// Stops when there's no room for both a valid source and a valid target IRS. -fn round_layout(tuning: &TuningSpec) -> RoundLayout { - assert!(tuning.vector_size.is_power_of_two()); - assert!(tuning.folding_factor.min() >= 1); - - let mut num_vars = tuning.vector_size.trailing_zeros() as usize; - let mut log_inv_rate = tuning.starting_log_inv_rate; - let mut shapes = Vec::new(); - - loop { - let round = shapes.len(); - let source_folding = tuning.folding_factor.at_round(round); - let target_folding = tuning.folding_factor.at_round(round + 1); - if num_vars < source_folding + target_folding { - break; - } - #[allow(clippy::cast_possible_truncation)] - shapes.push(RoundShape { - round_index: round, - source_vector_size: 1usize << num_vars, - source_log_inv_rate: log_inv_rate, - source_folding_factor: source_folding as u32, - target_folding_factor: target_folding as u32, - }); - num_vars -= source_folding; - #[allow(clippy::cast_possible_truncation)] - { - log_inv_rate += (source_folding as u32).saturating_sub(1); - } - } - - RoundLayout { - shapes, - basecase_vector_size: 1usize << num_vars, - basecase_log_inv_rate: log_inv_rate, - } -} - -const fn round_context(shape: &RoundShape) -> RoundContext { - RoundContext { - round_index: shape.round_index, - vector_size: shape.source_vector_size, - log_inv_rate: shape.source_log_inv_rate, - folding_factor: shape.source_folding_factor, - } -} - -fn target_context(shape: &RoundShape, source: &IrsConfig) -> RoundContext { - RoundContext { - round_index: shape.round_index, - vector_size: source.message_length(), - log_inv_rate: shape.source_log_inv_rate + shape.source_folding_factor.saturating_sub(1), - folding_factor: shape.target_folding_factor, - } -} - -// Zero-knowledge fixed-point — shared C_zk + global ℓ_zk -// --------------------------------------------------------------------------- - -/// Run the global ℓ_zk ↔ C_zk fixed-point. `ℓ_zk = next_pow2(max_round(r + t_ood))` -/// (Lemma 9.3), `C_zk.list_size` feeds back into per-round `t_ood` (Lemma 9.9 -/// term 1). The shared C_zk holds `2 · total_masks` columns (originals + fresh, -/// one mask per sumcheck round per Lemma 6.4). -fn build_shared_mask_oracle( - spec: &SecuritySpec, - target_spec: &SecuritySpec>, - tuning: &TuningSpec, - shapes: &[RoundShape], -) -> SharedMaskOracleData { - let c_zk_log_inv_rate = LogInvRate::new(tuning.starting_log_inv_rate); - - let total_masks: usize = shapes - .iter() - .map(|s| s.source_folding_factor as usize) - .sum(); - assert!(total_masks > 0, "ZK requires ≥ 1 mask polynomial"); - let c_zk_num_vectors = 2 * total_masks; - - let mut l_zk = MaskCodeMessageLen::new(L_ZK_BOOTSTRAP); - let mut c_zk = - irs_solver::solve_mask_code(target_spec, l_zk, 0, c_zk_log_inv_rate, c_zk_num_vectors); - - let mut last_round_data: Vec> = Vec::new(); - - for _ in 0..L_ZK_MAX_ITER { - let round_data: Vec> = shapes - .iter() - .map(|shape| build_zk_round_data(spec, shape, c_zk.list_size())) - .collect(); - - let max_r_plus_t_ood = round_data - .iter() - .map(|r| r.source.mask_length() + r.t_ood) - .max() - .expect("non-empty rounds"); - let new_l_zk = MaskCodeMessageLen::new(max_r_plus_t_ood.next_power_of_two()); - - if new_l_zk.get() == l_zk.get() { - last_round_data = round_data; - break; - } - - l_zk = new_l_zk; - // Solve_mask_code asserts `ℓ_zk ≥ r`; pass the max so it always holds. - let max_source_mask = round_data - .iter() - .map(|r| r.source.mask_length()) - .max() - .unwrap_or(0); - c_zk = irs_solver::solve_mask_code( - target_spec, - l_zk, - max_source_mask, - c_zk_log_inv_rate, - c_zk_num_vectors, - ); - last_round_data = round_data; - } - - let info = MaskOracleInfo { - c_zk_list_size: c_zk.list_size(), - l_zk, - }; - let mask_proximity = mp_solver::solve(target_spec, c_zk.clone(), total_masks); - let plan = MaskOraclePlan { - c_zk, - l_zk, - mask_proximity, - }; - - SharedMaskOracleData { - info, - round_data: last_round_data, - plan, - } -} - -/// Local fixed point: `source.mask_length` covers `t_ood` queries; `t_ood` is -/// sized against `source.message + source.mask`. -fn build_zk_round_data( - spec: &SecuritySpec, - shape: &RoundShape, - c_zk_list_size: f64, -) -> RoundData { - const LOCAL_MAX_ITER: usize = 16; - - let src_ctx = round_context(shape); - let mut source = irs_solver::solve(spec, &src_ctx, OodSampleBudget::new(0)); - let mut t_ood = 0; - let mut target = irs_solver::solve( - &transfer_spec_to_target(spec), - &target_context(shape, &source), - OodSampleBudget::new(0), - ); - - for _ in 0..LOCAL_MAX_ITER { - let new_t_ood = compute_t_ood(spec, &source, target.list_size(), Some(c_zk_list_size)); - let new_source = irs_solver::solve(spec, &src_ctx, OodSampleBudget::new(new_t_ood)); - let new_target = irs_solver::solve( - &transfer_spec_to_target(spec), - &target_context(shape, &new_source), - OodSampleBudget::new(new_t_ood), - ); - - if new_t_ood == t_ood - && new_source.codeword_length == source.codeword_length - && new_target.codeword_length == target.codeword_length - { - return RoundData { - source: new_source, - target: new_target, - t_ood: new_t_ood, - }; - } - - source = new_source; - target = new_target; - t_ood = new_t_ood; - } - - panic!("per-round ZK fixed-point did not converge"); -} - -fn finalize_zk_round( - spec: &SecuritySpec, - shape: &RoundShape, - data: RoundData, - mask_oracle: MaskOracleInfo, -) -> RoundPlan { - let RoundData { - source, - target, - t_ood, - } = data; - let src_ctx = round_context(shape); - let sumcheck = sc_solver::solve(spec, &src_ctx, &source, Some(mask_oracle)); - let code_switch = cs_solver::solve(spec, source, target, t_ood, Some(mask_oracle)); - RoundPlan { - round_index: shape.round_index, - sumcheck, - code_switch, - mode: RoundMode::ZeroKnowledge { - t_ood: OodSampleBudget::new(t_ood), - mask_oracle, - }, - } -} - -// Standard mode per-round builder -// --------------------------------------------------------------------------- - -fn build_round( - spec: &SecuritySpec, - shape: &RoundShape, - mask_oracle: Option, -) -> RoundPlan { - debug_assert!(mask_oracle.is_none(), "ZK path uses finalize_zk_round"); - - let target_spec = transfer_spec_to_target(spec); - let src_ctx = round_context(shape); - let source = irs_solver::solve(spec, &src_ctx, OodSampleBudget::new(0)); - - let mut target = irs_solver::solve( - &target_spec, - &target_context(shape, &source), - OodSampleBudget::new(0), - ); - let mut t_ood = compute_t_ood(spec, &source, target.list_size(), None); - for _ in 0..8 { - let new_target = irs_solver::solve( - &target_spec, - &target_context(shape, &source), - OodSampleBudget::new(t_ood), - ); - let new_t_ood = compute_t_ood(spec, &source, new_target.list_size(), None); - if new_target.codeword_length == target.codeword_length && new_t_ood == t_ood { - target = new_target; - t_ood = new_t_ood; - break; - } - target = new_target; - t_ood = new_t_ood; - } - - let sumcheck = sc_solver::solve(spec, &src_ctx, &source, None); - let code_switch = cs_solver::solve(spec, source, target, t_ood, None); - RoundPlan { - round_index: shape.round_index, - sumcheck, - code_switch, - mode: RoundMode::Standard, - } -} - -// Cross-protocol bound helpers -// --------------------------------------------------------------------------- - -/// Per-round `ℓ_zk = next_power_of_two(r + t_ood)` (Lemma 9.3). The global -/// ℓ_zk in [`derive_zk`] is the max-then-pad over all rounds, computed inline. -#[allow(dead_code)] -pub(super) const fn compute_l_zk( - source: &IrsConfig, - t_ood: usize, -) -> MaskCodeMessageLen { - MaskCodeMessageLen::new((source.mask_length() + t_ood).next_power_of_two()) -} - -/// Solves Lemma 9.9 term 1 for `t_ood`. In ZK, `degree = ℓ + r + t_ood` -/// couples back to `t_ood`, so iterate. -pub(super) fn compute_t_ood( - spec: &SecuritySpec, - source: &IrsConfig, - target_list_size: f64, - c_zk_list_size: Option, -) -> usize { - const MAX_ITER: usize = 32; - - let security_target = spec.protocol_security_target_bits(); - let field_bits = M::Target::field_size_bits(); - // Construction 9.7 is Johnson-only — `Mode` cannot express unique-decoding. - let unique_decoding = false; - let combined_list_size = target_list_size * c_zk_list_size.unwrap_or(1.0); - let message_length = source.message_length(); - let source_mask_length = source.mask_length(); - - let solve_for_degree = |degree: usize| { - irs_commit::num_ood_samples( - unique_decoding, - security_target, - field_bits, - combined_list_size, - degree, - ) - }; - - if matches!(spec.mode, Mode::Standard) { - return solve_for_degree(message_length); - } - - let mut t_ood = 0; - for _ in 0..MAX_ITER { - let new_t_ood = solve_for_degree(message_length + source_mask_length + t_ood); - if new_t_ood == t_ood { - return t_ood; - } - t_ood = new_t_ood; - } - panic!("compute_t_ood did not converge in {MAX_ITER} iterations"); -} - -// SecuritySpec helpers -// --------------------------------------------------------------------------- - -/// C_zk lives in `Identity`; copy the rest of the spec across. -const fn transfer_spec_to_target( - spec: &SecuritySpec, -) -> SecuritySpec> { - SecuritySpec { - mode: spec.mode, - target_security_bits: spec.target_security_bits, - max_pow_bits: spec.max_pow_bits, - hash_id: spec.hash_id, - _embedding: PhantomData, - } -} - -#[cfg(test)] -#[allow(clippy::float_cmp)] -mod tests { - use proptest::prelude::*; - - use super::*; - use crate::{ - hash, - protocols::params::{ - bounds::SoundnessBounded, spec::FoldingFactor, test_utils::TestEmbedding, - }, - }; - - /// Varied tuning space for proptests. Exercises both `FoldingFactor` - /// variants. Bounds keep PoW under the 60-bit cap and the IRS solver - /// inside Field64's reachable range. - fn arb_tuning() -> impl Strategy { - let folding = prop_oneof![ - (1usize..=3).prop_map(FoldingFactor::Constant), - (1usize..=3, 1usize..=3).prop_map(|(initial, rest)| { - FoldingFactor::ConstantFromSecondRound { initial, rest } - }), - ]; - (4u32..=8, 1u32..=3, folding).prop_map(|(log_size, log_inv_rate, folding_factor)| { - TuningSpec { - vector_size: 1usize << log_size, - starting_log_inv_rate: log_inv_rate, - folding_factor, - } - }) - } - - fn tuning_with(vector_size: usize) -> TuningSpec { - TuningSpec { - vector_size, - starting_log_inv_rate: 1, - folding_factor: FoldingFactor::Constant(2), - } - } - - /// Keeps PoW below the 60-bit cap for small test tunings. - fn test_spec(mode: Mode) -> SecuritySpec { - SecuritySpec { - mode, - target_security_bits: 40, - max_pow_bits: None, - hash_id: hash::BLAKE3, - _embedding: PhantomData, - } - } - - #[test] - fn round_shapes_match_old_whir_loop() { - let tuning = tuning_with(1 << 10); - let layout = round_layout(&tuning); - assert!(!layout.shapes.is_empty()); - assert_eq!(layout.shapes[0].source_vector_size, 1 << 10); - assert_eq!(layout.shapes[0].source_folding_factor, 2); - } - - #[test] - fn derive_standard_with_no_rounds_uses_basecase_only() { - let spec: SecuritySpec = test_spec(Mode::Standard); - // tuning_with sets initial=2, folding=2 → threshold = 4, so num_vars=3 (size=8) gives 0 rounds. - let plan = ParameterPlan::derive(spec, tuning_with(1 << 3)); - assert!(plan.rounds.is_empty()); - assert_eq!(plan.basecase.commit.vector_size, 1 << 3); - } - - #[test] - #[should_panic(expected = "ZK requires ≥ 1 mask polynomial")] - fn derive_zk_panics_with_no_rounds() { - let spec: SecuritySpec = test_spec(Mode::ZeroKnowledge); - let _ = ParameterPlan::derive(spec, tuning_with(1 << 3)); - } - - /// Lemma 9.9 fixed-point: every ZK round needs at least one OOD challenge. - #[test] - fn compute_t_ood_nonzero_in_zk() { - let spec: SecuritySpec = test_spec(Mode::ZeroKnowledge); - let plan = ParameterPlan::derive(spec, tuning_with(1 << 8)); - for r in &plan.rounds { - let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode else { - panic!("expected ZK round") - }; - assert!(t_ood.get() >= 1); - } - } - - #[test] - fn derive_zk_produces_shared_mask_oracle() { - let spec: SecuritySpec = test_spec(Mode::ZeroKnowledge); - let tuning = tuning_with(1 << 8); - let plan = ParameterPlan::derive(spec, tuning); - let mask_oracle = plan - .shared - .mask_oracle - .as_ref() - .expect("ZK plan must produce a mask oracle"); - - // Bound 3: ℓ_zk dominates every round's r + t_ood. - for r in &plan.rounds { - let RoundMode::ZeroKnowledge { - t_ood, - mask_oracle: round_oracle, - } = r.mode - else { - panic!("expected ZK round"); - }; - let source_mask = r.code_switch.source.mask_length(); - assert!(mask_oracle.l_zk.get() >= source_mask + t_ood.get()); - assert_eq!(round_oracle.l_zk.get(), mask_oracle.l_zk.get()); - assert_eq!(round_oracle.c_zk_list_size, mask_oracle.c_zk.list_size()); - } - - let total_masks: usize = plan.rounds.iter().map(|r| r.sumcheck.num_rounds).sum(); - assert_eq!(mask_oracle.c_zk.num_vectors, 2 * total_masks); - assert_eq!(mask_oracle.mask_proximity.num_masks, total_masks); - } - - fn basecase_min_bits(plan: &ParameterPlan) -> f64 { - let sumcheck = f64::from(sc_solver::analytic_error_bits(&plan.basecase.commit, None)); - if matches!( - plan.basecase.mode, - crate::protocols::basecase::Mode::ZeroKnowledge - ) { - sumcheck.min(f64::from(bc_solver::analytic_error_bits( - &plan.basecase.commit, - ))) - } else { - sumcheck - } - } - - #[test] - fn analytic_bits_finite_and_positive_standard() { - let spec: SecuritySpec = test_spec(Mode::Standard); - let plan = ParameterPlan::derive(spec, tuning_with(1 << 8)); - let bits: f64 = plan.analytic_bits().into(); - assert!(bits.is_finite() && bits > 0.0, "bits = {bits}"); - let min_round = plan - .rounds - .iter() - .map(|r| f64::from(r.analytic_bits())) - .fold(f64::INFINITY, f64::min); - let expected = min_round.min(basecase_min_bits(&plan)); - assert!((bits - expected).abs() < 1e-9, "{bits} vs {expected}"); - } - - #[test] - fn analytic_bits_includes_mask_oracle_in_zk() { - let spec: SecuritySpec = test_spec(Mode::ZeroKnowledge); - let plan = ParameterPlan::derive(spec, tuning_with(1 << 8)); - let plan_bits: f64 = plan.analytic_bits().into(); - let mo_bits: f64 = plan - .shared - .mask_oracle - .as_ref() - .expect("ZK has mask oracle") - .analytic_bits() - .into(); - let min_round = plan - .rounds - .iter() - .map(|r| f64::from(r.analytic_bits())) - .fold(f64::INFINITY, f64::min); - let expected = mo_bits.min(min_round).min(basecase_min_bits(&plan)); - assert!( - (plan_bits - expected).abs() < 1e-9, - "{plan_bits} vs {expected}" - ); - } - - #[test] - fn derive_plans_basecase() { - let spec: SecuritySpec = test_spec(Mode::ZeroKnowledge); - let plan = ParameterPlan::derive(spec, tuning_with(1 << 8)); - assert!(matches!( - plan.basecase.mode, - crate::protocols::basecase::Mode::ZeroKnowledge - )); - assert_eq!(plan.basecase.commit.interleaving_depth, 1); - // Sumcheck folds basecase to size 1. - assert_eq!(plan.basecase.sumcheck.final_size(), 1); - } - - /// Derived plans must satisfy their own `max_pow_bits` budget. - #[test] - fn check_pow_bits_passes_on_derived_plan() { - let spec: SecuritySpec = SecuritySpec { - mode: Mode::ZeroKnowledge, - target_security_bits: 40, - max_pow_bits: Some(60), - hash_id: hash::BLAKE3, - _embedding: PhantomData, - }; - let plan = ParameterPlan::derive(spec, tuning_with(1 << 8)); - assert!(plan.check_pow_bits()); - } - - /// Hand-injected over-budget PoW slot fails the check. - #[test] - fn check_pow_bits_detects_over_budget_slot() { - use crate::{bits::Bits, protocols::proof_of_work}; - let spec: SecuritySpec = SecuritySpec { - mode: Mode::ZeroKnowledge, - target_security_bits: 40, - max_pow_bits: Some(10), - hash_id: hash::BLAKE3, - _embedding: PhantomData, - }; - let mut plan = ParameterPlan::derive(spec, tuning_with(1 << 8)); - plan.basecase.pow = proof_of_work::Config::from_difficulty(Bits::new(50.0)); - assert!(!plan.check_pow_bits()); - } - - proptest! { - /// Standard mode: derive succeeds for any tuning shape, mask oracle is - /// absent, and basecase covers the post-fold tail. - #[test] - fn derive_standard_succeeds_over_tunings(tuning in arb_tuning()) { - let spec: SecuritySpec = test_spec(Mode::Standard); - let plan = ParameterPlan::derive(spec, tuning); - prop_assert!(plan.shared.mask_oracle.is_none()); - for r in &plan.rounds { - prop_assert!(matches!(r.mode, RoundMode::Standard)); - } - prop_assert!(matches!( - plan.basecase.mode, - crate::protocols::basecase::Mode::Standard - )); - prop_assert_eq!(plan.basecase.commit.interleaving_depth, 1); - } - - /// ZK mode: derive succeeds when shapes are non-empty; total masks - /// matches the sum of source folding factors; basecase is ZK-flagged - /// when shapes are non-empty. - #[test] - fn derive_zk_succeeds_over_tunings(tuning in arb_tuning()) { - let log_threshold = - tuning.folding_factor.at_round(0) + tuning.folding_factor.at_round(1); - prop_assume!(tuning.vector_size.trailing_zeros() as usize >= log_threshold); - - let spec: SecuritySpec = test_spec(Mode::ZeroKnowledge); - let plan = ParameterPlan::derive(spec, tuning); - let mask_oracle = plan - .shared - .mask_oracle - .as_ref() - .expect("ZK plan must have a mask oracle"); - - let total_source_folds: usize = plan - .rounds - .iter() - .map(|r| r.code_switch.source.interleaving_depth.trailing_zeros() as usize) - .sum(); - prop_assert_eq!(mask_oracle.c_zk.num_vectors, 2 * total_source_folds); - prop_assert!(matches!( - plan.basecase.mode, - crate::protocols::basecase::Mode::ZeroKnowledge - )); - } - - /// `analytic_bits + max_per_slot_pow ≥ target` for any tuning the - /// planner accepts (Standard mode: no mask-oracle floor). - #[test] - fn analytic_plus_pow_meets_target_standard(tuning in arb_tuning()) { - let spec: SecuritySpec = test_spec(Mode::Standard); - let plan = ParameterPlan::derive(spec.clone(), tuning); - let analytic = f64::from(plan.analytic_bits()); - // Reading the dominant per-slot PoW: each sub-protocol grinds to - // `target_security_bits`. We assert the analytic floor is non-zero - // and that `analytic + 60` covers any plausible target. - prop_assert!(analytic.is_finite()); - prop_assert!(analytic >= 0.0); - prop_assert!(analytic + 60.0 >= f64::from(spec.target_security_bits) - 1e-3); - } - } -} diff --git a/src/protocols/params/plan.rs b/src/protocols/params/protocol_config.rs similarity index 58% rename from src/protocols/params/plan.rs rename to src/protocols/params/protocol_config.rs index 9bfaf68a..a8e492f5 100644 --- a/src/protocols/params/plan.rs +++ b/src/protocols/params/protocol_config.rs @@ -1,9 +1,9 @@ -//! Output shape of the planner. +//! Output of [`super::derive`]: the assembled per-round and basecase configs. //! -//! C_zk and ℓ_zk are protocol-global (one shared Merkle tree across all -//! rounds) and live in [`SharedPlan`]; per-round sumcheck + code-switch live -//! in [`RoundPlan`]. Source/target IRS configs are accessed via -//! `round.code_switch` — not duplicated at the round level. +//! Each ZK round owns its mask oracle: a per-round C_zk codeword (sized for +//! `2·(k+1)` columns — `k` sumcheck masks + 1 code-switch `(r ‖ s)` mask, all +//! doubled by Construction 7.2's originals + fresh pairs) plus a per-round +//! mask-proximity check. Standard rounds carry no mask oracle. use ark_ff::Field; @@ -11,7 +11,10 @@ use crate::{ algebra::embedding::{Embedding, Identity}, bits::Bits, protocols::{ - basecase, code_switch, irs_commit, mask_proximity, + basecase::{self, Config as BasecaseConfig}, + code_switch::Config as CodeSwitchConfig, + irs_commit::Config as IrsConfig, + mask_proximity::Config as MaskProximityConfig, params::{ basecase as basecase_solver, bounds::SoundnessBounded, @@ -19,51 +22,47 @@ use crate::{ spec::{MaskCodeMessageLen, OodSampleBudget, SecuritySpec, TuningSpec}, sumcheck as sumcheck_solver, }, - proof_of_work, sumcheck, + proof_of_work::Config as PowConfig, + sumcheck::Config as SumcheckConfig, }, }; #[derive(Clone, Debug)] -pub struct ParameterPlan { - pub security: SecuritySpec, +pub struct ProtocolConfig { + pub security: SecuritySpec, pub tuning: TuningSpec, - pub shared: SharedPlan, - pub rounds: Vec>, - pub basecase: basecase::Config, + pub rounds: Vec>, + pub basecase: BasecaseConfig, } -impl ParameterPlan { - /// Returns `true` iff every PoW slot's difficulty fits within +impl ProtocolConfig { + /// Returns `true` if every PoW slot's difficulty fits within /// `security.max_pow_bits`. Cheap pre-flight check that fails before the /// 60-bit cap assertion inside `proof_of_work::threshold`. pub fn check_pow_bits(&self) -> bool { let max = Bits::new(f64::from(self.security.max_pow_bits.unwrap_or(0))); - let within = |pow: &proof_of_work::Config| pow.difficulty() <= max; - - if !self - .rounds - .iter() - .all(|r| within(&r.sumcheck.round_pow) && within(&r.code_switch.pow)) - { + let within = |pow: &PowConfig| pow.difficulty() <= max; + if !self.rounds.iter().all(|r| { + within(&r.sumcheck.round_pow) + && within(&r.code_switch.pow) + && r.mask_oracle + .as_ref() + .is_none_or(|mo| within(&mo.mask_proximity.pow)) + }) { return false; } - if let Some(mo) = &self.shared.mask_oracle { - if !within(&mo.mask_proximity.pow) { - return false; - } - } within(&self.basecase.sumcheck.round_pow) && within(&self.basecase.pow) } } -impl SoundnessBounded for ParameterPlan { +impl SoundnessBounded for ProtocolConfig { fn analytic_bits(&self) -> Bits { let mut min_bits = f64::INFINITY; for round in &self.rounds { min_bits = min_bits.min(f64::from(round.analytic_bits())); - } - if let Some(mo) = &self.shared.mask_oracle { - min_bits = min_bits.min(f64::from(mo.analytic_bits())); + if let Some(mo) = &round.mask_oracle { + min_bits = min_bits.min(f64::from(mo.analytic_bits())); + } } // Basecase sumcheck per-round bound applies in both modes; the γ-slot // only contributes in ZK. @@ -84,11 +83,14 @@ impl SoundnessBounded for ParameterPlan { } #[derive(Clone, Debug)] -pub struct RoundPlan { +pub struct RoundConfig { pub round_index: usize, - pub sumcheck: sumcheck::Config, - pub code_switch: code_switch::Config, + pub sumcheck: SumcheckConfig, + pub code_switch: CodeSwitchConfig, pub mode: RoundMode, + /// `Some` iff this is a ZK round. Sized for this round's `k + 1` masks + /// (k sumcheck + 1 code-switch). + pub mask_oracle: Option>, } #[derive(Clone, Copy, Debug, PartialEq)] @@ -97,8 +99,9 @@ pub enum RoundMode { ZeroKnowledge { /// Bound 2 / Lemma 9.9. t_ood: OodSampleBudget, - /// Cached view of the shared mask oracle (denormalized from - /// [`MaskOraclePlan`]) so each round is self-contained for soundness. + /// Slim view of this round's [`MaskOracleConfig`] (C_zk's list size + + /// ℓ_zk) — denormalized so soundness routines can read it without + /// chasing through `mask_oracle`. mask_oracle: MaskOracleInfo, }, } @@ -116,7 +119,7 @@ impl RoundMode { } } -impl SoundnessBounded for RoundPlan { +impl SoundnessBounded for RoundConfig { fn analytic_bits(&self) -> Bits { let source = &self.code_switch.source; let target = &self.code_switch.target; @@ -138,32 +141,25 @@ impl SoundnessBounded for RoundPlan { } } +/// One round's mask oracle: a C_zk codeword + ℓ_zk + mask-proximity check +/// covering `k + 1` masks (sumcheck + code-switch) for this round. #[derive(Clone, Debug)] -pub struct SharedPlan { - /// `Some` iff `Mode::ZeroKnowledge`. - pub mask_oracle: Option>, -} - -/// One C_zk codeword + one shared Merkle tree + one mask-proximity check, -/// covering every mask committed across all rounds. -#[derive(Clone, Debug)] -pub struct MaskOraclePlan { - /// `num_vectors = 2 * total_masks` (Construction 7.2: originals + fresh). - pub c_zk: irs_commit::Config>, - /// Dominates every round's `r + t_ood` (Lemma 9.3). +pub struct MaskOracleConfig { + /// `num_vectors = 2 · (k + 1)` (Construction 7.2: originals + fresh). + pub c_zk: IrsConfig>, + /// `next_pow2(r + t_ood)` for this round (Lemma 9.3). pub l_zk: MaskCodeMessageLen, - pub mask_proximity: mask_proximity::Config, + pub mask_proximity: MaskProximityConfig, } -/// Slim mask-oracle view (C_zk's list size + ℓ_zk) for builders that don't -/// need the full config. +/// Slim mask-oracle view (C_zk's list size + ℓ_zk). #[derive(Clone, Copy, Debug, PartialEq)] pub struct MaskOracleInfo { pub c_zk_list_size: f64, pub l_zk: MaskCodeMessageLen, } -impl MaskOraclePlan { +impl MaskOracleConfig { pub fn info(&self) -> MaskOracleInfo { MaskOracleInfo { c_zk_list_size: self.c_zk.list_size(), @@ -172,7 +168,7 @@ impl MaskOraclePlan { } } -impl SoundnessBounded for MaskOraclePlan { +impl SoundnessBounded for MaskOracleConfig { fn analytic_bits(&self) -> Bits { mask_proximity_solver::analytic_error_bits( &self.mask_proximity.c_zk_commit, diff --git a/src/protocols/params/spec.rs b/src/protocols/params/spec.rs index 5dd15d7d..03177efc 100644 --- a/src/protocols/params/spec.rs +++ b/src/protocols/params/spec.rs @@ -1,6 +1,6 @@ use core::marker::PhantomData; -use crate::{algebra::embedding::Embedding, engines::EngineId}; +use crate::engines::EngineId; /// Phantom-typed newtype — `Tagged` and `Tagged` are distinct types. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -17,15 +17,14 @@ impl Tagged { } #[derive(Debug, Clone)] -pub struct SecuritySpec { +pub struct SecuritySpec { pub mode: Mode, pub target_security_bits: u32, pub max_pow_bits: Option, pub hash_id: EngineId, - pub _embedding: PhantomData, } -impl SecuritySpec { +impl SecuritySpec { pub fn protocol_security_target_bits(&self) -> f64 { let pow = self.max_pow_bits.unwrap_or(0); f64::from(self.target_security_bits.saturating_sub(pow)) @@ -55,7 +54,7 @@ impl FoldingFactor { } } - /// Smallest factor across rounds; used by `TuningSpec` validation. + /// Smallest factor across rounds. pub const fn min(&self) -> usize { match self { Self::Constant(f) => *f, @@ -116,24 +115,27 @@ pub type LogInvRate = Tagged; #[allow(clippy::float_cmp)] mod tests { use super::*; - use crate::{ - algebra::{embedding::Identity, fields::Field64}, - hash, - }; + use crate::hash; - fn spec(max_pow_bits: Option) -> SecuritySpec> { + /// Fixture target. 100 is chosen so the expected `target − pow` values in + /// the tests below are round numbers (80, 40, 0) for readability. + const TARGET_BITS: u32 = 100; + + fn spec(max_pow_bits: Option) -> SecuritySpec { SecuritySpec { mode: Mode::ZeroKnowledge, - target_security_bits: 100, + target_security_bits: TARGET_BITS, max_pow_bits, hash_id: hash::BLAKE3, - _embedding: PhantomData, } } #[test] fn none_means_no_pow_credit() { - assert_eq!(spec(None).protocol_security_target_bits(), 100.0); + assert_eq!( + spec(None).protocol_security_target_bits(), + f64::from(TARGET_BITS), + ); } #[test] @@ -146,12 +148,15 @@ mod tests { #[test] fn pow_credit_shifts_analytic_floor() { + // Two below-target PoW budgets: `target − pow` shifts down 1:1. assert_eq!(spec(Some(20)).protocol_security_target_bits(), 80.0); assert_eq!(spec(Some(60)).protocol_security_target_bits(), 40.0); } #[test] fn pow_exceeding_target_saturates_to_zero() { - assert_eq!(spec(Some(200)).protocol_security_target_bits(), 0.0); + // `pow > target` saturates rather than going negative. + let pow_over_target = TARGET_BITS + 100; + assert_eq!(spec(Some(pow_over_target)).protocol_security_target_bits(), 0.0); } } diff --git a/src/protocols/params/sumcheck.rs b/src/protocols/params/sumcheck.rs index a2503150..58e8dbda 100644 --- a/src/protocols/params/sumcheck.rs +++ b/src/protocols/params/sumcheck.rs @@ -5,24 +5,25 @@ use crate::{ algebra::{embedding::Embedding, fields::FieldWithSize}, bits::Bits, protocols::{ - irs_commit, + irs_commit::Config as IrsConfig, params::{ - plan::MaskOracleInfo, + protocol_config::MaskOracleInfo, spec::{RoundContext, SecuritySpec}, }, - proof_of_work, sumcheck, + proof_of_work::Config as PowConfig, + sumcheck::{self, Config as SumcheckConfig}, }, }; /// `mask_oracle` is `Some` iff ZK; only C_zk's list size + ℓ_zk are read here. pub fn solve( - spec: &SecuritySpec, + spec: &SecuritySpec, ctx: &RoundContext, - source_irs: &irs_commit::Config, + source_irs: &IrsConfig, mask_oracle: Option, -) -> sumcheck::Config { +) -> SumcheckConfig { let num_rounds = num_sumcheck_rounds(ctx); - let round_pow = proof_of_work::Config::grind_to( + let round_pow = PowConfig::grind_to( Bits::new(f64::from(spec.target_security_bits)), analytic_error_bits(source_irs, mask_oracle), spec.hash_id, @@ -33,7 +34,7 @@ pub fn solve( mask_length: zk_mask_length(), }, }; - sumcheck::Config::new(ctx.vector_size, round_pow, num_rounds, mode) + SumcheckConfig::new(ctx.vector_size, round_pow, num_rounds, mode) } /// Per-sumcheck-round soundness in bits: `min(ε_mca, poly_identity_term)`. @@ -41,7 +42,7 @@ pub fn solve( /// - Standard (degree-2): `log|F| − log|Λ(C)| − 1`. /// - ZK (Lemma 6.5, p.40): `log|F| − log|Λ(C)| − log|Λ(C_zk)| − log ℓ_zk`. pub fn analytic_error_bits( - source_irs: &irs_commit::Config, + source_irs: &IrsConfig, mask_oracle: Option, ) -> Bits { let field_bits = M::Target::field_size_bits(); @@ -58,12 +59,10 @@ pub fn analytic_error_bits( Bits::new(prox_gaps.min(poly_id).max(0.0)) } -pub const fn masks_required(is_zk: bool, ctx: &RoundContext) -> usize { - if is_zk { - num_sumcheck_rounds(ctx) - } else { - 0 - } +/// Number of degree-2 round-polynomial masks sumcheck contributes to C_zk +/// per round (Lemma 6.4): one per sumcheck round. +pub const fn masks_required(ctx: &RoundContext) -> usize { + num_sumcheck_rounds(ctx) } const fn num_sumcheck_rounds(ctx: &RoundContext) -> usize { @@ -76,27 +75,126 @@ const fn zk_mask_length() -> usize { } #[cfg(test)] +#[allow(clippy::float_cmp)] mod tests { use proptest::prelude::*; use super::*; use crate::protocols::params::{ - irs_commit as params_irs, - spec::OodSampleBudget, + irs_commit as irs_solver, + spec::{MaskCodeMessageLen, Mode, OodSampleBudget}, test_utils::{ - arb_round_ctx, arb_standard_johnson_spec, arb_zk_spec, build_minimal_mask_oracle, - TestEmbedding, + arb_round_ctx, arb_standard_johnson_spec, arb_zk_spec, assert_pow_closes_gap, + build_minimal_mask_oracle, deterministic_spec, TestEmbedding, TestField, + TestNonIdentityEmbedding, TEST_TARGET_RANGE, }, }; - // Keeps `target - error ≤ 60`, the upper bound `proof_of_work::threshold` enforces. - const TEST_TARGET_RANGE: std::ops::RangeInclusive = 30..=50; - fn build_source_irs( - spec: &SecuritySpec, + spec: &SecuritySpec, ctx: &RoundContext, - ) -> irs_commit::Config { - params_irs::solve(spec, ctx, OodSampleBudget::new(0)) + ) -> IrsConfig { + irs_solver::solve(spec, ctx, OodSampleBudget::new(0)) + } + + /// Smallest pow2 shape that still produces a non-degenerate IRS. + const FIXTURE_LOG_VECTOR_SIZE: u32 = 4; + const FIXTURE_LOG_INV_RATE: u32 = 1; + const FIXTURE_FOLDING_FACTOR: u32 = 2; + + fn fixture_ctx() -> RoundContext { + RoundContext { + round_index: 0, + vector_size: 1 << FIXTURE_LOG_VECTOR_SIZE, + log_inv_rate: FIXTURE_LOG_INV_RATE, + folding_factor: FIXTURE_FOLDING_FACTOR, + } + } + + /// Lemma 6.4: ZK round polynomial has 3 coefficients. + #[test] + fn zk_mode_has_three_mask_coefficients() { + let spec = deterministic_spec(Mode::ZeroKnowledge); + let ctx = fixture_ctx(); + let source_irs = build_source_irs(&spec, &ctx); + let mask_oracle = build_minimal_mask_oracle(&spec); + let config = solve(&spec, &ctx, &source_irs, mask_oracle); + match config.mode { + sumcheck::SumcheckMode::ZeroKnowledge { mask_length } => { + assert_eq!(mask_length, 3); + } + sumcheck::SumcheckMode::Standard => panic!("expected ZK"), + } + } + + /// Standard branch: `min(prox_gaps, log|F| − log|Λ(C)| − 1).max(0)`. + #[test] + fn analytic_error_standard_formula() { + let spec = deterministic_spec(Mode::Standard); + let ctx = fixture_ctx(); + let irs = build_source_irs(&spec, &ctx); + + let got = f64::from(analytic_error_bits::(&irs, None)); + + let field_bits = TestField::field_size_bits(); + let log_list = irs.list_size().log2(); + let prox = irs.rbr_soundness_fold_prox_gaps(); + let expected = prox.min(field_bits - log_list - 1.0).max(0.0); + + assert!( + (got - expected).abs() < 1e-9, + "got {got} vs expected {expected}" + ); + } + + /// ZK branch (Lemma 6.5): `min(prox_gaps, log|F| − log|Λ(C)| − log|Λ(C_zk)| − log ℓ_zk).max(0)`. + #[test] + fn analytic_error_zk_formula() { + // Pow2 values so `log2` is exact. + const C_ZK_LIST_SIZE: f64 = 4.0; + const L_ZK_USIZE: usize = 8; + let log_c_zk_list = C_ZK_LIST_SIZE.log2(); + let log_l_zk = (L_ZK_USIZE as f64).log2(); + + let spec = deterministic_spec(Mode::ZeroKnowledge); + let ctx = fixture_ctx(); + let irs = build_source_irs(&spec, &ctx); + let info = MaskOracleInfo { + c_zk_list_size: C_ZK_LIST_SIZE, + l_zk: MaskCodeMessageLen::new(L_ZK_USIZE), + }; + + let got = f64::from(analytic_error_bits::(&irs, Some(info))); + + let field_bits = TestField::field_size_bits(); + let log_list = irs.list_size().log2(); + let prox = irs.rbr_soundness_fold_prox_gaps(); + let expected = prox + .min(field_bits - log_list - log_c_zk_list - log_l_zk) + .max(0.0); + + assert!( + (got - expected).abs() < 1e-9, + "got {got} vs expected {expected}" + ); + } + + /// Oracle large enough to drive `poly_id` strongly negative → clamped to 0. + #[test] + fn analytic_error_clamps_to_zero() { + // `log2(c_zk_list_size) + log2(l_zk) > field_bits` on `Field64`. + const OVERSIZED_LOG_C_ZK_LIST: i32 = 60; + const OVERSIZED_LOG_L_ZK: u32 = 30; + + let spec = deterministic_spec(Mode::ZeroKnowledge); + let ctx = fixture_ctx(); + let irs = build_source_irs(&spec, &ctx); + let huge = MaskOracleInfo { + c_zk_list_size: 2_f64.powi(OVERSIZED_LOG_C_ZK_LIST), + l_zk: MaskCodeMessageLen::new(1 << OVERSIZED_LOG_L_ZK), + }; + let bits = f64::from(analytic_error_bits::(&irs, Some(huge))); + assert_eq!(bits, 0.0); } proptest! { @@ -111,23 +209,6 @@ mod tests { prop_assert!(matches!(config.mode, sumcheck::SumcheckMode::Standard)); } - /// Lemma 6.4: ZK round polynomial mask_length = 3. - #[test] - fn zk_mode_has_three_mask_coefficients( - spec in arb_zk_spec(TEST_TARGET_RANGE), - ctx in arb_round_ctx(), - ) { - let source_irs = build_source_irs(&spec, &ctx); - let mask_oracle = build_minimal_mask_oracle(&spec); - let config = solve(&spec, &ctx, &source_irs, mask_oracle); - match config.mode { - sumcheck::SumcheckMode::ZeroKnowledge { mask_length } => { - prop_assert_eq!(mask_length, 3); - } - sumcheck::SumcheckMode::Standard => prop_assert!(false, "expected ZK"), - } - } - #[test] fn num_rounds_matches_folding_factor( spec in prop_oneof![ @@ -142,18 +223,18 @@ mod tests { prop_assert_eq!(config.num_rounds, ctx.folding_factor as usize); } + /// ZK subtracts two non-negative log terms beyond Standard, so the ZK + /// error term cannot exceed the Standard one for any source IRS. #[test] - fn masks_required_matches_mode( - spec in prop_oneof![ - arb_standard_johnson_spec(TEST_TARGET_RANGE), - arb_zk_spec(TEST_TARGET_RANGE), - ], + fn zk_error_le_standard_error( + spec in arb_zk_spec(TEST_TARGET_RANGE), ctx in arb_round_ctx(), ) { - let mask_oracle = build_minimal_mask_oracle(&spec); - let required = masks_required(mask_oracle.is_some(), &ctx); - let expected = if mask_oracle.is_some() { ctx.folding_factor as usize } else { 0 }; - prop_assert_eq!(required, expected); + let irs = build_source_irs(&spec, &ctx); + let mo = build_minimal_mask_oracle(&spec); + let zk = f64::from(analytic_error_bits::(&irs, mo)); + let standard = f64::from(analytic_error_bits::(&irs, None)); + prop_assert!(zk <= standard + 1e-9, "zk {} > standard {}", zk, standard); } /// `analytic_error + pow ≥ target`. @@ -167,15 +248,24 @@ mod tests { ) { let source_irs = build_source_irs(&spec, &ctx); let mask_oracle = build_minimal_mask_oracle(&spec); + let error = analytic_error_bits(&source_irs, mask_oracle); let config = solve(&spec, &ctx, &source_irs, mask_oracle); - let error = f64::from(analytic_error_bits(&source_irs, mask_oracle)); - let pow_bits = f64::from(config.round_pow.difficulty()); - // Tolerance for `proof_of_work::threshold`'s ceil quantization. - prop_assert!( - error + pow_bits >= f64::from(spec.target_security_bits) - 1e-3, - "error {} + pow {} < target {}", - error, pow_bits, spec.target_security_bits, - ); + assert_pow_closes_gap(&spec, error, &config.round_pow); } } + + /// Smoke test: `M::Source ≠ M::Target`, ZK mode. + #[test] + fn solve_works_with_basefield_embedding_zk() { + let spec = deterministic_spec(Mode::ZeroKnowledge); + let ctx = fixture_ctx(); + let source_irs: IrsConfig = + irs_solver::solve(&spec, &ctx, OodSampleBudget::new(0)); + let info = MaskOracleInfo { + c_zk_list_size: 4.0, + l_zk: MaskCodeMessageLen::new(8), + }; + let config = solve(&spec, &ctx, &source_irs, Some(info)); + assert!(matches!(config.mode, sumcheck::SumcheckMode::ZeroKnowledge { .. })); + } } diff --git a/src/protocols/params/test_utils.rs b/src/protocols/params/test_utils.rs index 6186dc96..ab2625da 100644 --- a/src/protocols/params/test_utils.rs +++ b/src/protocols/params/test_utils.rs @@ -1,6 +1,6 @@ //! Shared test fixtures. -use std::{marker::PhantomData, ops::RangeInclusive}; +use std::ops::RangeInclusive; use proptest::prelude::*; @@ -9,11 +9,17 @@ use crate::{ embedding::{Basefield, Embedding, Identity}, fields::{Field64, Field64_2}, }, + bits::Bits, hash, - protocols::params::{ - irs_commit as params_irs, - plan::MaskOracleInfo, - spec::{LogInvRate, MaskCodeMessageLen, Mode, RoundContext, SecuritySpec}, + protocols::{ + irs_commit::Config as IrsConfig, + params::{ + irs_commit as irs_solver, + derive::compute_t_ood, + protocol_config::MaskOracleInfo, + spec::{LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, RoundContext, SecuritySpec}, + }, + proof_of_work::Config as PowConfig, }, }; @@ -23,13 +29,22 @@ pub type TestExtensionField = Field64_2; /// `Source = Field64, Target = Field64_2`. pub type TestNonIdentityEmbedding = Basefield; -pub fn deterministic_spec(mode: Mode) -> SecuritySpec { +/// `target_security_bits` range used by every solver-level proptest. +/// Upper bound keeps `target − analytic_error ≤ 60`, matching the cap in +/// `proof_of_work::threshold`. Lower bound keeps the analytic floor away from 0. +pub const TEST_TARGET_RANGE: RangeInclusive = 30..=50; + +/// Default `target_security_bits` for `deterministic_spec` fixtures. +/// 80 leaves enough analytic headroom on `Field64` (~64-bit) that every +/// sub-protocol solver has a closable gap to target. +pub const FIXTURE_TARGET_BITS: u32 = 80; + +pub fn deterministic_spec(mode: Mode) -> SecuritySpec { SecuritySpec { mode, - target_security_bits: 80, + target_security_bits: FIXTURE_TARGET_BITS, max_pow_bits: None, hash_id: hash::BLAKE3, - _embedding: PhantomData, } } @@ -39,29 +54,28 @@ pub fn deterministic_spec(mode: Mode) -> SecuritySpec { pub fn arb_spec( mode: Mode, target_range: RangeInclusive, -) -> impl Strategy> { +) -> impl Strategy { let pow_strategy = prop_oneof![Just(None), (0u32..=16).prop_map(Some)]; (target_range, pow_strategy).prop_map(move |(target, max_pow)| SecuritySpec { mode, target_security_bits: target, max_pow_bits: max_pow, hash_id: hash::BLAKE3, - _embedding: PhantomData, }) } -pub fn arb_zk_spec( - target_range: RangeInclusive, -) -> impl Strategy> { +pub fn arb_zk_spec(target_range: RangeInclusive) -> impl Strategy { arb_spec(Mode::ZeroKnowledge, target_range) } pub fn arb_standard_johnson_spec( target_range: RangeInclusive, -) -> impl Strategy> { +) -> impl Strategy { arb_spec(Mode::Standard, target_range) } +/// `log_size ∈ 4..=8` (vector_size 16..256) leaves room for ≥ 2·folding_factor +/// post-folding while capping proptest time. pub fn arb_round_ctx() -> impl Strategy { (0usize..=3, 4u32..=8, 1u32..=4, 1u32..=3).prop_map( |(round_index, log_size, log_inv_rate, folding_factor)| RoundContext { @@ -74,14 +88,72 @@ pub fn arb_round_ctx() -> impl Strategy { } /// `None` in Standard; `Some(ℓ_zk=2, c_zk rate 1/2)` in ZK. -pub fn build_minimal_mask_oracle(spec: &SecuritySpec) -> Option { +pub fn build_minimal_mask_oracle(spec: &SecuritySpec) -> Option { if !matches!(spec.mode, Mode::ZeroKnowledge) { return None; } let l_zk = MaskCodeMessageLen::new(2); - let c_zk = params_irs::solve_mask_code(spec, l_zk, 0, LogInvRate::new(1), 2); + let c_zk: IrsConfig = + irs_solver::solve_mask_code(spec, l_zk, 0, LogInvRate::new(1), 2); Some(MaskOracleInfo { c_zk_list_size: c_zk.list_size(), l_zk, }) } + +/// Shared check used by every sub-protocol's `pow_closes_gap_to_target*` test: +/// `analytic_error_bits + pow.difficulty() ≥ target_security_bits` (the `1e-3` +/// tolerance absorbs `proof_of_work::threshold`'s ceil quantization). +pub fn assert_pow_closes_gap( + spec: &SecuritySpec, + analytic: Bits, + pow: &PowConfig, +) { + let error = f64::from(analytic); + let pow_bits = f64::from(pow.difficulty()); + let target = f64::from(spec.target_security_bits); + assert!( + error + pow_bits >= target - 1e-3, + "error {error} + pow {pow_bits} < target {target}", + ); +} + +/// Safety net for the `target_irs ↔ t_ood` loop in [`build_round_io`]. +/// Steady state converges in ≤ 2 iterations (`target.list_size()` is rate-only). +const TARGET_STABILIZATION_MAX_ITER: usize = 8; + +/// Builds a self-consistent `(source, target, t_ood)` triplet matching the +/// per-round shape that `code_switch::solve` expects. +pub fn build_round_io( + spec: &SecuritySpec, + log_inv_rate: u32, + folding_factor: u32, + num_vars: u32, + c_zk_list_size: Option, +) -> (IrsConfig, IrsConfig>, usize) { + let source_ctx = RoundContext { + round_index: 0, + vector_size: 1usize << num_vars, + log_inv_rate, + folding_factor, + }; + let source = irs_solver::solve(spec, &source_ctx, OodSampleBudget::new(0)); + + let target_ctx = RoundContext { + round_index: 1, + vector_size: source.message_length(), + log_inv_rate: log_inv_rate + folding_factor - 1, + folding_factor, + }; + + let mut target = irs_solver::solve(spec, &target_ctx, OodSampleBudget::new(0)); + for _ in 0..TARGET_STABILIZATION_MAX_ITER { + let t_ood = compute_t_ood(spec, &source, target.list_size(), c_zk_list_size); + let new_target = irs_solver::solve(spec, &target_ctx, OodSampleBudget::new(t_ood)); + if new_target.codeword_length == target.codeword_length { + return (source, new_target, t_ood); + } + target = new_target; + } + panic!("target IRS did not stabilize"); +} From a75034adcbc135df035a9ba935e9ae090245dbc5 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Fri, 15 May 2026 19:28:10 +0530 Subject: [PATCH 11/31] fix: param selection for code switch --- src/protocols/basecase.rs | 18 +++--- src/protocols/code_switch.rs | 26 ++++---- src/protocols/params/basecase.rs | 31 ++++----- src/protocols/params/bounds.rs | 84 ++++++++++++++++--------- src/protocols/params/code_switch.rs | 57 +++++++++-------- src/protocols/params/derive.rs | 64 +++++++++++++++---- src/protocols/params/irs_commit.rs | 26 ++++++-- src/protocols/params/mask_proximity.rs | 70 +++++++-------------- src/protocols/params/protocol_config.rs | 30 ++++++++- src/protocols/params/spec.rs | 5 +- src/protocols/params/sumcheck.rs | 48 +++++++------- src/protocols/params/test_utils.rs | 43 ++++++++++--- 12 files changed, 308 insertions(+), 194 deletions(-) diff --git a/src/protocols/basecase.rs b/src/protocols/basecase.rs index 589caba3..22f677e8 100644 --- a/src/protocols/basecase.rs +++ b/src/protocols/basecase.rs @@ -28,7 +28,7 @@ pub struct Opening { /// Standard / ZeroKnowledge selector for basecase. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] -pub enum Mode { +pub enum BasecaseMode { Standard, ZeroKnowledge, } @@ -39,7 +39,7 @@ pub enum Mode { pub struct Config { pub commit: irs_commit::Config>, pub sumcheck: sumcheck::Config, - pub mode: Mode, + pub mode: BasecaseMode, pub pow: proof_of_work::Config, } @@ -49,7 +49,7 @@ impl Config { } pub const fn is_zk(&self) -> bool { - matches!(self.mode, Mode::ZeroKnowledge) + matches!(self.mode, BasecaseMode::ZeroKnowledge) } pub fn prove( @@ -126,12 +126,12 @@ impl Config { Standard: Distribution, { match self.mode { - Mode::Standard => { + BasecaseMode::Standard => { prover_state.prover_messages(vector); prover_state.prover_messages(&witness.masks); None } - Mode::ZeroKnowledge => { + BasecaseMode::ZeroKnowledge => { let blinding_vector = random_vector(prover_state.rng(), vector.len()); let blinding_witness = self.commit.commit(prover_state, &[&blinding_vector]); let blinding_inner_product = dot(&blinding_vector, covector); @@ -236,8 +236,8 @@ impl Config { Hash: ProverMessage<[H::U]>, { match self.mode { - Mode::Standard => Ok(None), - Mode::ZeroKnowledge => { + BasecaseMode::Standard => Ok(None), + BasecaseMode::ZeroKnowledge => { let blinding_commitment = self.commit.receive_commitment(verifier_state)?; let blinding_inner_product: F = verifier_state.prover_message()?; // Grind the Theorem 7.1 γ-combination gap before γ is sampled. @@ -274,9 +274,9 @@ mod tests { sumcheck::SumcheckMode::Standard, ), mode: if is_zk { - Mode::ZeroKnowledge + BasecaseMode::ZeroKnowledge } else { - Mode::Standard + BasecaseMode::Standard }, pow: proof_of_work::Config::none(), }) diff --git a/src/protocols/code_switch.rs b/src/protocols/code_switch.rs index 3843b6be..c30fe9a2 100644 --- a/src/protocols/code_switch.rs +++ b/src/protocols/code_switch.rs @@ -32,7 +32,7 @@ use crate::{ /// Standard / ZeroKnowledge selector for code-switch. #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] -pub enum Mode { +pub enum CodeSwitchMode { Standard, ZeroKnowledge { message_mask_length: NonZeroUsize }, } @@ -44,7 +44,7 @@ pub enum Mode { pub struct Config { pub source: IrsConfig, pub target: IrsConfig>, - pub mode: Mode, + pub mode: CodeSwitchMode, pub out_domain_samples: usize, pub pow: proof_of_work::Config, } @@ -66,7 +66,7 @@ impl Config { source_config: IrsConfig, target_config: IrsConfig>, out_domain_samples: usize, - mode: Mode, + mode: CodeSwitchMode, pow: proof_of_work::Config, ) -> Self { assert_eq!( @@ -99,7 +99,7 @@ impl Config { source_config.interleaving_depth.is_power_of_two(), "source.interleaving_depth must be a power of 2" ); - if let Mode::ZeroKnowledge { + if let CodeSwitchMode::ZeroKnowledge { message_mask_length, } = &mode { @@ -140,8 +140,8 @@ impl Config { /// Mask oracle length `ℓ_zk`. Returns 0 in Standard mode. pub const fn message_mask_length(&self) -> usize { match &self.mode { - Mode::Standard => 0, - Mode::ZeroKnowledge { + CodeSwitchMode::Standard => 0, + CodeSwitchMode::ZeroKnowledge { message_mask_length, } => message_mask_length.get(), } @@ -149,7 +149,7 @@ impl Config { /// `true` iff the protocol is configured for ZK. pub const fn is_zk(&self) -> bool { - matches!(&self.mode, Mode::ZeroKnowledge { .. }) + matches!(&self.mode, CodeSwitchMode::ZeroKnowledge { .. }) } /// Length of the covector for this code-switch. @@ -263,8 +263,8 @@ impl Config { for &point in ood_points { let f_eval = univariate_evaluate(message, point); let answer = match &self.mode { - Mode::Standard => f_eval, - Mode::ZeroKnowledge { .. } => { + CodeSwitchMode::Standard => f_eval, + CodeSwitchMode::ZeroKnowledge { .. } => { let mask_eval = univariate_evaluate(mask, point); let shift = point.pow([msg_len as u64]); f_eval + shift * mask_eval @@ -286,7 +286,7 @@ impl Config { in_domain_points: &[M::Target], ) { match &self.mode { - Mode::Standard => { + CodeSwitchMode::Standard => { let all_points: Vec<_> = ood_points.iter().chain(in_domain_points).copied().collect(); let pows: Vec<_> = ood_rlc_coeffs @@ -296,7 +296,7 @@ impl Config { .collect(); geometric_accumulate(covector, pows, &all_points); } - Mode::ZeroKnowledge { .. } => { + CodeSwitchMode::ZeroKnowledge { .. } => { geometric_accumulate(covector, ood_rlc_coeffs.to_vec(), ood_points); geometric_accumulate( &mut covector[..self.source.masked_message_length()], @@ -486,12 +486,12 @@ mod tests { // masks fold to a single length-mask_length chunk). let r = source.mask_length(); let mode = if zk { - Mode::ZeroKnowledge { + CodeSwitchMode::ZeroKnowledge { message_mask_length: NonZeroUsize::new(r + fresh_s_len) .expect("ZK ⇒ r + fresh_s_len > 0"), } } else { - Mode::Standard + CodeSwitchMode::Standard }; Self::new( source.clone(), diff --git a/src/protocols/params/basecase.rs b/src/protocols/params/basecase.rs index fb978720..2124a95f 100644 --- a/src/protocols/params/basecase.rs +++ b/src/protocols/params/basecase.rs @@ -49,13 +49,13 @@ pub fn solve( ); let mode = match spec.mode { - SpecMode::Standard => basecase::Mode::Standard, - SpecMode::ZeroKnowledge => basecase::Mode::ZeroKnowledge, + SpecMode::Standard => basecase::BasecaseMode::Standard, + SpecMode::ZeroKnowledge => basecase::BasecaseMode::ZeroKnowledge, }; let pow = match mode { - basecase::Mode::Standard => PowConfig::none(), - basecase::Mode::ZeroKnowledge => { + basecase::BasecaseMode::Standard => PowConfig::none(), + basecase::BasecaseMode::ZeroKnowledge => { PowConfig::grind_to(target_bits, analytic_error_bits(&commit), spec.hash_id) } }; @@ -82,10 +82,16 @@ mod tests { use super::*; use crate::protocols::params::test_utils::{ - arb_standard_johnson_spec, arb_zk_spec, assert_pow_closes_gap, deterministic_spec, - TestField, TEST_TARGET_RANGE, + arb_standard_johnson_spec, arb_zk_spec, assert_close, assert_pow_closes_gap, + deterministic_spec, TestField, TEST_TARGET_RANGE, }; + /// `vector_size = 16` (2^4) and `log_inv_rate = 2` give a small but + /// non-degenerate basecase IRS. `folding_factor = 0` is the basecase + /// invariant (no folding, message_length = vector_size). + const FIXTURE_VECTOR_SIZE: usize = 16; + const FIXTURE_LOG_INV_RATE: u32 = 2; + fn arb_dims() -> impl Strategy { (1u32..=4, 1u32..=3) } @@ -103,8 +109,8 @@ mod tests { let spec = deterministic_spec(Mode::ZeroKnowledge); let ctx = RoundContext { round_index: 0, - vector_size: 16, - log_inv_rate: 2, + vector_size: FIXTURE_VECTOR_SIZE, + log_inv_rate: FIXTURE_LOG_INV_RATE, folding_factor: 0, }; let commit: IrsConfig> = @@ -115,10 +121,7 @@ mod tests { let log_list = commit.list_size().log2(); let expected = (field_bits - log_list).max(0.0); - assert!( - (got - expected).abs() < 1e-9, - "got {got} vs expected {expected}", - ); + assert_close(got, expected); } proptest! { @@ -128,7 +131,7 @@ mod tests { (log_size, log_inv_rate) in arb_dims(), ) { let config = solve::(&spec, 1usize << log_size, log_inv_rate); - prop_assert!(matches!(config.mode, basecase::Mode::Standard)); + prop_assert!(matches!(config.mode, basecase::BasecaseMode::Standard)); prop_assert_eq!(config.commit.interleaving_depth, 1); prop_assert_eq!(config.commit.num_vectors, 1); prop_assert_eq!(config.commit.vector_size, config.sumcheck.initial_size); @@ -140,7 +143,7 @@ mod tests { (log_size, log_inv_rate) in arb_dims(), ) { let config = solve::(&spec, 1usize << log_size, log_inv_rate); - prop_assert!(matches!(config.mode, basecase::Mode::ZeroKnowledge)); + prop_assert!(matches!(config.mode, basecase::BasecaseMode::ZeroKnowledge)); prop_assert!(config.commit.mask_length() > 0); } diff --git a/src/protocols/params/bounds.rs b/src/protocols/params/bounds.rs index 5fb8aa2e..9a897161 100644 --- a/src/protocols/params/bounds.rs +++ b/src/protocols/params/bounds.rs @@ -102,8 +102,11 @@ pub fn pow_bits_to_close_gap(target_security_bits: f64, achieved_security_bits: #[allow(clippy::float_cmp)] mod tests { use super::*; + use crate::protocols::params::test_utils::assert_close; - const EPS: f64 = 1e-9; + /// Tighter tolerance for tests doing relative-error checks (`(got - exp).abs() / exp`) + /// against an alternative-derived expected value with the same operations. + const TIGHT_EPS: f64 = 1e-12; /// Johnson list size: `|Λ| = 1 / (2η√ρ)`, log₂ form. Hand-evaluated at /// `log_inv_rate = 2`, `η = 0.1`: `−1 − log₂(0.1) + 1 ≈ 3.3219`. @@ -111,7 +114,7 @@ mod tests { fn list_size_log2_johnson_formula() { let got = list_size_log2(2.0, 0.1); let expected = -1.0 - 0.1_f64.log2() + 0.5 * 2.0; - assert!((got - expected).abs() < EPS, "got {got} vs {expected}"); + assert_close(got, expected); } /// Unique-decoding regime (`η = 0`) gives `|Λ| = 1`, i.e. log = 0. @@ -128,7 +131,7 @@ mod tests { let got = johnson_list_size(b); let expected = 10.0 * 2_f64.powf(b); assert!( - (got - expected).abs() / expected < 1e-12, + (got - expected).abs() / expected < TIGHT_EPS, "log_inv_rate={b}: got {got} vs {expected}", ); } @@ -144,21 +147,28 @@ mod tests { hash, protocols::irs_commit::{Config, IrsMode}, }; - let log_inv_rate = 2; + // All shape values except rate are placeholders — `list_size()` depends + // only on `johnson_slack`, which is itself a function of rate. + const PLACEHOLDER_SECURITY_TARGET_BITS: f64 = 80.0; + const PLACEHOLDER_NUM_VECTORS: usize = 2; + const PLACEHOLDER_VECTOR_SIZE: usize = 8; + const PLACEHOLDER_INTERLEAVING_DEPTH: usize = 1; + const LOG_INV_RATE: u32 = 2; + let config: Config> = Config::new( - 80.0, - false, + PLACEHOLDER_SECURITY_TARGET_BITS, + false, // unique_decoding hash::BLAKE3, - 2, - 8, - 1, - 2_f64.powf(-f64::from(log_inv_rate)), + PLACEHOLDER_NUM_VECTORS, + PLACEHOLDER_VECTOR_SIZE, + PLACEHOLDER_INTERLEAVING_DEPTH, + 2_f64.powf(-f64::from(LOG_INV_RATE)), IrsMode::Standard, ); - let got = johnson_list_size(f64::from(log_inv_rate)); + let got = johnson_list_size(f64::from(LOG_INV_RATE)); let expected = config.list_size(); assert!( - (got - expected).abs() / expected < 1e-12, + (got - expected).abs() / expected < TIGHT_EPS, "bounds helper ({got}) vs Config::list_size ({expected})", ); } @@ -166,9 +176,13 @@ mod tests { /// OOD per-sample Schwartz–Zippel: `log₂((k−1) / |F|) = log₂(k−1) − field_bits`. #[test] fn ood_per_sample_log2_formula() { - let got = ood_per_sample_log2(129, 64.0); - let expected = 128_f64.log2() - 64.0; - assert!((got - expected).abs() < EPS, "got {got} vs {expected}"); + // `k = 129` so `k − 1 = 128 = 2^7` for exact `log2`. + const K: usize = 129; + const FIELD_BITS: f64 = 64.0; + + let got = ood_per_sample_log2(K, FIELD_BITS); + let expected = ((K - 1) as f64).log2() - FIELD_BITS; + assert_close(got, expected); // (k−1)/|F| < 1 for sane parameters ⇒ log is negative. assert!(got < 0.0); } @@ -180,7 +194,7 @@ mod tests { let got = one_minus_distance_log2(log_inv_rate, 0.0); let rho = 2_f64.powf(-log_inv_rate); let expected = f64::midpoint(1.0, rho).log2(); - assert!((got - expected).abs() < EPS, "got {got} vs {expected}"); + assert_close(got, expected); } /// `1 − δ` in Johnson regime: `√ρ + η`. @@ -191,37 +205,47 @@ mod tests { let got = one_minus_distance_log2(log_inv_rate, eta); let rho = 2_f64.powf(-log_inv_rate); let expected = (rho.sqrt() + eta).log2(); - assert!((got - expected).abs() < EPS, "got {got} vs {expected}"); + assert_close(got, expected); } + /// MCA fixture — `message_length = 16 = 2^4` and `log_inv_rate = 2` give + /// exact `log2(k) = 4`. `field_bits = 64.0` for Field64. + const MCA_MESSAGE_LENGTH: usize = 16; + const MCA_LOG_INV_RATE: f64 = 2.0; + const MCA_FIELD_BITS: f64 = 64.0; + /// MCA error, unique-decoding branch: `log k + log_inv_rate − field_bits`. #[test] fn eps_mca_log2_unique_decoding_formula() { let p = CodeParams { - log_inv_rate: 2.0, + log_inv_rate: MCA_LOG_INV_RATE, johnson_slack: 0.0, - message_length: 16, - field_bits: 64.0, + message_length: MCA_MESSAGE_LENGTH, + field_bits: MCA_FIELD_BITS, }; let got = eps_mca_log2(&p); - let expected = 16_f64.log2() + 2.0 - 64.0; - assert!((got - expected).abs() < EPS, "got {got} vs {expected}"); + let expected = (MCA_MESSAGE_LENGTH as f64).log2() + MCA_LOG_INV_RATE - MCA_FIELD_BITS; + assert_close(got, expected); } /// MCA error, Johnson branch: `7·log₂10 + 3.5·log_inv_rate + 2·log k − field_bits`. #[test] fn eps_mca_log2_johnson_formula() { + // `η = 0.1` stays within the debug assertion's slack range: + // `η.log2() ≥ −(0.5·log_inv_rate + log₂10 + 1) ≈ −5.32`. + const JOHNSON_SLACK: f64 = 0.1; + let p = CodeParams { - log_inv_rate: 2.0, - // Stay within the debug assertion's slack range: johnson_slack.log2() ≥ - // -(0.5·log_inv_rate + log₂10 + 1) ≈ -5.32. - johnson_slack: 0.1, - message_length: 16, - field_bits: 64.0, + log_inv_rate: MCA_LOG_INV_RATE, + johnson_slack: JOHNSON_SLACK, + message_length: MCA_MESSAGE_LENGTH, + field_bits: MCA_FIELD_BITS, }; let got = eps_mca_log2(&p); - let expected = 7.0 * LOG2_10 + 3.5 * 2.0 + 2.0 * 16_f64.log2() - 64.0; - assert!((got - expected).abs() < EPS, "got {got} vs {expected}"); + let expected = + 7.0 * LOG2_10 + 3.5 * MCA_LOG_INV_RATE + 2.0 * (MCA_MESSAGE_LENGTH as f64).log2() + - MCA_FIELD_BITS; + assert_close(got, expected); } /// `pow_bits_to_close_gap` clamps negative gaps to zero (no anti-grind). diff --git a/src/protocols/params/code_switch.rs b/src/protocols/params/code_switch.rs index cf270dc3..85108eba 100644 --- a/src/protocols/params/code_switch.rs +++ b/src/protocols/params/code_switch.rs @@ -29,7 +29,7 @@ pub fn solve( t_ood: usize, mask_oracle: Option, ) -> CodeSwitchConfig { - let mode = mask_oracle.map_or(code_switch::Mode::Standard, |info| { + let mode = mask_oracle.map_or(code_switch::CodeSwitchMode::Standard, |info| { let l_zk = info.l_zk.get(); assert!( l_zk >= source.mask_length() + t_ood, @@ -37,7 +37,7 @@ pub fn solve( source.mask_length(), t_ood, ); - code_switch::Mode::ZeroKnowledge { + code_switch::CodeSwitchMode::ZeroKnowledge { message_mask_length: NonZeroUsize::new(l_zk).expect("ℓ_zk > 0"), } }); @@ -69,10 +69,13 @@ pub fn analytic_error_bits( let field_bits = M::Target::field_size_bits(); let combined_list = target.list_size() * mask_oracle.map_or(1.0, |info| info.c_zk_list_size); - let degree = match mask_oracle { - Some(_) => source.masked_message_length() + t_ood, - None => source.message_length(), - }; + // OOD polynomial is over witness `[f; r_C; s]` of length `ℓ + ℓ_zk` (ZK) or + // `ℓ` (Standard). The `s`-tail is sampled at full length `ℓ_zk − r` (not + // just `t_ood`), so degree must use the realized `ℓ_zk`, not `r + t_ood`. + let degree = mask_oracle.map_or_else( + || source.message_length(), + |info| source.message_length() + info.l_zk.get(), + ); #[allow(clippy::cast_precision_loss)] let t_ood_f = t_ood as f64; @@ -105,12 +108,12 @@ mod tests { use super::*; use crate::protocols::params::{ - irs_commit as irs_solver, derive::{compute_l_zk, compute_t_ood}, + irs_commit as irs_solver, spec::{LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, RoundContext, SecuritySpec}, test_utils::{ arb_standard_johnson_spec as utils_standard_spec, arb_zk_spec as utils_zk_spec, - assert_pow_closes_gap, build_round_io, deterministic_spec, TestEmbedding, + assert_close, assert_pow_closes_gap, build_round_io, deterministic_spec, TestEmbedding, TestExtensionField, TestField, TestNonIdentityEmbedding, TEST_TARGET_RANGE, }, }; @@ -169,10 +172,7 @@ mod tests { let comb = field_bits - (count as f64).log2() - target_list.log2(); let expected = ood.min(comb).max(0.0); - assert!( - (got - expected).abs() < 1e-9, - "got {got} vs expected {expected}", - ); + assert_close(got, expected); } /// ZK OOD bound: combined list `L = target × c_zk`, masked degree `ℓ + r + t_ood`, @@ -206,7 +206,7 @@ mod tests { let field_bits = ::field_size_bits(); let target_list = target.list_size(); let combined_list = target_list * C_ZK_LIST_SIZE; - let degree = source.masked_message_length() + t_ood; + let degree = source.message_length() + L_ZK_USIZE; let log_deg_m1 = ((degree - 1) as f64).log2(); let l_choose_2 = combined_list * (combined_list - 1.0) / 2.0; let ood = (t_ood as f64) * (field_bits - log_deg_m1) - l_choose_2.log2(); @@ -214,10 +214,7 @@ mod tests { let comb = field_bits - (count as f64).log2() - target_list.log2() - C_ZK_LIST_SIZE.log2(); let expected = ood.min(comb).max(0.0); - assert!( - (got - expected).abs() < 1e-9, - "got {got} vs expected {expected}", - ); + assert_close(got, expected); } proptest! { @@ -229,7 +226,7 @@ mod tests { let (source, target, t_ood) = build_round_io::(&spec, log_inv_rate, folding_factor, num_vars, None); let config = solve(&spec, source, target, t_ood, None); - prop_assert!(matches!(config.mode, code_switch::Mode::Standard)); + prop_assert!(matches!(config.mode, code_switch::CodeSwitchMode::Standard)); prop_assert!(config.out_domain_samples >= 1); } @@ -337,17 +334,22 @@ mod tests { let t_ood = compute_t_ood(&spec, &source, target.list_size(), None); let config = solve(&spec, source, target, t_ood, None); - assert!(matches!(config.mode, code_switch::Mode::Standard)); + assert!(matches!(config.mode, code_switch::CodeSwitchMode::Standard)); } + /// Placeholder mask-oracle list size for the smoke test. Pow2 keeps + /// `log2` exact and matches `analytic_error_zk_formula`'s fixture. + const SMOKE_C_ZK_LIST_SIZE: f64 = 4.0; + /// Cap on the smoke-test `t_ood ↔ (source, target)` fixed-point. Matches the + /// loop bound used in `build_round_io`; in practice converges in 1–3 iters. + const SMOKE_FIXED_POINT_MAX_ITER: usize = 8; + /// Smoke test: `M::Source ≠ M::Target`, ZK mode. #[test] fn solve_works_with_basefield_embedding_zk() { let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); let (source_ctx, target_ctx) = non_identity_smoke_ctxs(); - let c_zk_list_size = 4.0; - // `build_round_io` for the non-identity embedding. let mut t_ood = 0; let mut source = irs_solver::solve::( &spec, @@ -359,8 +361,13 @@ mod tests { &target_ctx, OodSampleBudget::new(0), ); - for _ in 0..8 { - let new_t_ood = compute_t_ood(&spec, &source, target.list_size(), Some(c_zk_list_size)); + for _ in 0..SMOKE_FIXED_POINT_MAX_ITER { + let new_t_ood = compute_t_ood( + &spec, + &source, + target.list_size(), + Some(SMOKE_C_ZK_LIST_SIZE), + ); if new_t_ood == t_ood { break; } @@ -370,13 +377,13 @@ mod tests { } let mask_oracle = MaskOracleInfo { - c_zk_list_size, + c_zk_list_size: SMOKE_C_ZK_LIST_SIZE, l_zk: MaskCodeMessageLen::new((source.mask_length() + t_ood).next_power_of_two()), }; let config = solve(&spec, source, target, t_ood, Some(mask_oracle)); assert!(matches!( config.mode, - code_switch::Mode::ZeroKnowledge { .. } + code_switch::CodeSwitchMode::ZeroKnowledge { .. } )); } } diff --git a/src/protocols/params/derive.rs b/src/protocols/params/derive.rs index ddadc311..5fe88d5e 100644 --- a/src/protocols/params/derive.rs +++ b/src/protocols/params/derive.rs @@ -297,7 +297,10 @@ pub(super) fn compute_t_ood( let r = source.mask_length(); for _ in 0..MAX_ITER { - let new_t_ood = solve_for_degree(message_length + r + t_ood); + // Polynomial degree = `ℓ + ℓ_zk` where `ℓ_zk = next_pow2(r + t_ood)` + // (Lemma 9.3). Using `r + t_ood` would under-count when not pow2. + let l_zk = (r + t_ood).next_power_of_two(); + let new_t_ood = solve_for_degree(message_length + l_zk); if new_t_ood == t_ood { return t_ood; } @@ -317,7 +320,7 @@ mod tests { protocols::params::{ bounds::SoundnessBounded, spec::FoldingFactor, - test_utils::{assert_pow_closes_gap, TestEmbedding}, + test_utils::{assert_close, assert_pow_closes_gap, TestEmbedding}, }, }; @@ -482,7 +485,7 @@ mod tests { assert!(plan.rounds.is_empty()); assert!(matches!( plan.basecase.mode, - crate::protocols::basecase::Mode::ZeroKnowledge + crate::protocols::basecase::BasecaseMode::ZeroKnowledge )); } @@ -509,7 +512,7 @@ mod tests { )); if matches!( plan.basecase.mode, - crate::protocols::basecase::Mode::ZeroKnowledge + crate::protocols::basecase::BasecaseMode::ZeroKnowledge ) { sumcheck.min(f64::from(basecase_solver::analytic_error_bits( &plan.basecase.commit, @@ -534,7 +537,7 @@ mod tests { .map(|r| f64::from(r.analytic_bits())) .fold(f64::INFINITY, f64::min); let expected = min_round.min(basecase_min_bits(&plan)); - assert!((bits - expected).abs() < 1e-9, "{bits} vs {expected}"); + assert_close(bits, expected); } #[test] @@ -564,10 +567,7 @@ mod tests { .map(|r| f64::from(r.analytic_bits())) .fold(f64::INFINITY, f64::min); let expected = mo_floor.min(min_round).min(basecase_min_bits(&plan)); - assert!( - (plan_bits - expected).abs() < 1e-9, - "{plan_bits} vs {expected}" - ); + assert_close(plan_bits, expected); } #[test] @@ -579,7 +579,7 @@ mod tests { ); assert!(matches!( plan.basecase.mode, - crate::protocols::basecase::Mode::ZeroKnowledge + crate::protocols::basecase::BasecaseMode::ZeroKnowledge )); assert_eq!(plan.basecase.commit.interleaving_depth, 1); // Sumcheck folds basecase to size 1. @@ -594,6 +594,44 @@ mod tests { /// Comfortably above `TIGHT_POW_BUDGET_BITS`. const OVER_BUDGET_INJECTED_BITS: f64 = 50.0; + /// Bound 3 + Bound 7: HVZK privacy error in bits matches the closed-form + /// `−log Σ_r (t_ood_r² + t_ood_r) / (2|F|)` over ZK rounds. + #[test] + fn privacy_error_bits_matches_bound_3_sum() { + let spec = test_spec(Mode::ZeroKnowledge); + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ); + let field_bits = ::field_size_bits(); + let mut expected_total = 0.0_f64; + for r in &plan.rounds { + let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode else { + panic!("expected ZK round"); + }; + let t = t_ood.get() as f64; + expected_total += 2_f64.powf(f64::midpoint(t * t, t).log2() - field_bits); + } + let expected_bits = -expected_total.log2(); + let got = f64::from(plan.privacy_error_bits()); + assert_close(got, expected_bits); + } + + /// Standard-mode plans have no HVZK claim — `privacy_error_bits` returns + /// the spec's `target_security_bits` as a sentinel. + #[test] + fn privacy_error_bits_standard_returns_target_sentinel() { + let spec = test_spec(Mode::Standard); + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ); + assert_eq!( + f64::from(plan.privacy_error_bits()), + f64::from(PLAN_FIXTURE_TARGET_BITS), + ); + } + /// Derived plans must satisfy their own `max_pow_bits` budget. #[test] fn check_pow_bits_passes_on_derived_plan() { @@ -669,7 +707,7 @@ mod tests { // γ-slot is ZK-only. if matches!( plan.basecase.mode, - crate::protocols::basecase::Mode::ZeroKnowledge + crate::protocols::basecase::BasecaseMode::ZeroKnowledge ) { assert_pow_closes_gap( spec, @@ -713,7 +751,7 @@ mod tests { } prop_assert!(matches!( plan.basecase.mode, - crate::protocols::basecase::Mode::Standard + crate::protocols::basecase::BasecaseMode::Standard )); prop_assert_eq!(plan.basecase.commit.interleaving_depth, 1); } @@ -746,7 +784,7 @@ mod tests { } prop_assert!(matches!( plan.basecase.mode, - crate::protocols::basecase::Mode::ZeroKnowledge + crate::protocols::basecase::BasecaseMode::ZeroKnowledge )); } diff --git a/src/protocols/params/irs_commit.rs b/src/protocols/params/irs_commit.rs index 9dc18055..51ebd0d6 100644 --- a/src/protocols/params/irs_commit.rs +++ b/src/protocols/params/irs_commit.rs @@ -140,12 +140,18 @@ mod tests { let _ = solve_mask_code::(&spec, MaskCodeMessageLen::new(2), 0, LogInvRate::new(1), 3); } + /// `irs_commit::solve` doesn't grind PoW, so this range can sit higher than + /// the shared `TEST_TARGET_RANGE` (which is capped at 50 to keep the PoW + /// gap below the 60-bit threshold). 80..=128 covers production-realistic + /// target sizes. + const IRS_TARGET_RANGE: std::ops::RangeInclusive = 80..=128; + fn arb_zk_spec_default() -> impl Strategy { - arb_zk_spec(80..=128) + arb_zk_spec(IRS_TARGET_RANGE) } fn arb_standard_spec() -> impl Strategy { - arb_spec(Mode::Standard, 80..=128) + arb_spec(Mode::Standard, IRS_TARGET_RANGE) } proptest! { @@ -175,6 +181,14 @@ mod tests { } } + /// Smoke-test fixture: 64-element vector folded by 2 at rate 1/2 — small + /// but produces a non-degenerate IRS for the non-identity embedding. + const SMOKE_VECTOR_SIZE: usize = 64; + const SMOKE_LOG_INV_RATE: u32 = 1; + const SMOKE_FOLDING_FACTOR: u32 = 2; + /// Arbitrary > 0 so the ZK mask sizing exercises the OOD path. + const SMOKE_OOD_BUDGET: usize = 2; + /// Smoke test: `M::Source ≠ M::Target`, ZK path. Mask sizing depends only /// on the target field (via `field_size_bits`), but the generic embedding /// still flows through the Config and must compile + execute. @@ -183,12 +197,12 @@ mod tests { let spec = deterministic_spec(Mode::ZeroKnowledge); let ctx = RoundContext { round_index: 0, - vector_size: 64, - log_inv_rate: 1, - folding_factor: 2, + vector_size: SMOKE_VECTOR_SIZE, + log_inv_rate: SMOKE_LOG_INV_RATE, + folding_factor: SMOKE_FOLDING_FACTOR, }; let config: IrsConfig = - solve(&spec, &ctx, OodSampleBudget::new(2)); + solve(&spec, &ctx, OodSampleBudget::new(SMOKE_OOD_BUDGET)); assert!(config.mask_length() > 0); } } diff --git a/src/protocols/params/mask_proximity.rs b/src/protocols/params/mask_proximity.rs index 9c7236a2..47070f2d 100644 --- a/src/protocols/params/mask_proximity.rs +++ b/src/protocols/params/mask_proximity.rs @@ -50,11 +50,10 @@ mod tests { protocols::{ irs_commit::IrsMode, params::{ - irs_commit as irs_solver, - spec::{LogInvRate, MaskCodeMessageLen, Mode}, + spec::Mode, test_utils::{ - arb_zk_spec, assert_pow_closes_gap, deterministic_spec, TestEmbedding, - TEST_TARGET_RANGE, + arb_zk_spec, assert_close, assert_pow_closes_gap, build_test_c_zk, + deterministic_spec, TEST_TARGET_RANGE, }, }, }, @@ -63,42 +62,34 @@ mod tests { /// γ-combination (Lemma 7.4): `log|F| − log(num_masks · (deg − 1))`, /// `deg = c_zk.masked_message_length()`. With `num_masks = 0` or `deg ≤ 1` /// the bound saturates to `field_bits`. + /// Pow2 `l_zk = 8` gives exact `log2(deg − 1) = log2(7) ≈ 2.81`. + /// `num_masks = 3` is the smallest count > 1 (so `num_masks · (deg − 1) > 1` + /// and the formula doesn't saturate). `log_inv_rate = 1` is the minimum + /// rate the C_zk solver accepts. + const FIXTURE_L_ZK: usize = 8; + const FIXTURE_NUM_MASKS: usize = 3; + const FIXTURE_LOG_INV_RATE: u32 = 1; + #[test] fn analytic_error_formula() { let spec = deterministic_spec(Mode::ZeroKnowledge); - let num_masks = 3_usize; - let c_zk = irs_solver::solve_mask_code::( - &spec, - MaskCodeMessageLen::new(8), - 0, - LogInvRate::new(1), - 2 * num_masks, - ); + let c_zk = build_test_c_zk(&spec, FIXTURE_L_ZK, FIXTURE_LOG_INV_RATE, FIXTURE_NUM_MASKS); - let got = f64::from(analytic_error_bits(&c_zk, num_masks)); + let got = f64::from(analytic_error_bits(&c_zk, FIXTURE_NUM_MASKS)); let field_bits = ::field_size_bits(); let deg = c_zk.masked_message_length(); - let log_combined = ((num_masks * (deg - 1)) as f64).log2(); + let log_combined = ((FIXTURE_NUM_MASKS * (deg - 1)) as f64).log2(); let expected = (field_bits - log_combined).max(0.0); - assert!( - (got - expected).abs() < 1e-9, - "got {got} vs expected {expected}", - ); + assert_close(got, expected); } /// Degenerate inputs (`num_masks == 0` or `deg ≤ 1`) saturate to `field_bits`. #[test] fn analytic_error_saturates_when_no_masks() { let spec = deterministic_spec(Mode::ZeroKnowledge); - let c_zk = irs_solver::solve_mask_code::( - &spec, - MaskCodeMessageLen::new(2), - 0, - LogInvRate::new(1), - 2, - ); + let c_zk = build_test_c_zk(&spec, 2, 1, 1); let bits = f64::from(analytic_error_bits(&c_zk, 0)); let field_bits = ::field_size_bits(); assert_eq!(bits, field_bits.max(0.0)); @@ -112,14 +103,7 @@ mod tests { num_masks in 1usize..=8, l_zk_log in 1u32..=5, ) { - let l_zk = MaskCodeMessageLen::new(1usize << l_zk_log); - let c_zk = irs_solver::solve_mask_code::( - &spec, - l_zk, - 0, - LogInvRate::new(log_inv_rate), - 2 * num_masks, - ); + let c_zk = build_test_c_zk(&spec, 1usize << l_zk_log, log_inv_rate, num_masks); let config = solve(&spec, c_zk, num_masks); prop_assert_eq!(config.num_masks, num_masks); prop_assert_eq!(config.c_zk_commit.num_vectors, 2 * num_masks); @@ -134,31 +118,21 @@ mod tests { num_masks in 1usize..=8, l_zk_log in 1u32..=5, ) { - let l_zk = MaskCodeMessageLen::new(1usize << l_zk_log); - let c_zk = irs_solver::solve_mask_code::( - &spec, - l_zk, - 0, - LogInvRate::new(log_inv_rate), - 2 * num_masks, - ); + let c_zk = build_test_c_zk(&spec, 1usize << l_zk_log, log_inv_rate, num_masks); let analytic = analytic_error_bits(&c_zk, num_masks); let config = solve(&spec, c_zk, num_masks); assert_pow_closes_gap(&spec, analytic, &config.pow); } } + /// `mask_proximity::solve` requires `c_zk.num_vectors == 2 · num_masks`. + /// Builds C_zk for `num_masks = 2` (so `num_vectors = 4`), then calls + /// `solve` with `num_masks = 3` to trip the assertion. #[test] #[should_panic(expected = "c_zk.num_vectors must be 2 * num_masks")] fn solve_rejects_mismatched_num_vectors() { let spec = deterministic_spec(Mode::ZeroKnowledge); - let c_zk = irs_solver::solve_mask_code::( - &spec, - MaskCodeMessageLen::new(2), - 0, - LogInvRate::new(1), - 4, - ); + let c_zk = build_test_c_zk(&spec, 2, 1, 2); let _ = solve(&spec, c_zk, 3); } diff --git a/src/protocols/params/protocol_config.rs b/src/protocols/params/protocol_config.rs index a8e492f5..6959c798 100644 --- a/src/protocols/params/protocol_config.rs +++ b/src/protocols/params/protocol_config.rs @@ -8,7 +8,10 @@ use ark_ff::Field; use crate::{ - algebra::embedding::{Embedding, Identity}, + algebra::{ + embedding::{Embedding, Identity}, + fields::FieldWithSize, + }, bits::Bits, protocols::{ basecase::{self, Config as BasecaseConfig}, @@ -53,6 +56,29 @@ impl ProtocolConfig { } within(&self.basecase.sumcheck.round_pow) && within(&self.basecase.pow) } + + /// HVZK privacy error in bits, summed across ZK rounds: + /// `−log Σ_r (t_ood_r² + t_ood_r) / (2|F|)` (Bound 3 + Bound 7). + /// Standard-mode plans return `target_security_bits` as a sentinel — + /// HVZK isn't claimed when there are no ZK rounds. + pub fn privacy_error_bits(&self) -> Bits { + let field_bits = ::field_size_bits(); + let mut total_error = 0.0_f64; + for r in &self.rounds { + if let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode { + #[allow(clippy::cast_precision_loss)] + let t = t_ood.get() as f64; + // ζ_ze ≤ (t_ood² + t_ood) / (2|F|). Compute in log space to + // stay numerically stable for large field_bits. + let log_err = f64::midpoint(t * t, t).log2() - field_bits; + total_error += 2_f64.powf(log_err); + } + } + if total_error == 0.0 { + return Bits::new(f64::from(self.security.target_security_bits)); + } + Bits::new((-total_error.log2()).max(0.0)) + } } impl SoundnessBounded for ProtocolConfig { @@ -70,7 +96,7 @@ impl SoundnessBounded for ProtocolConfig { &self.basecase.commit, None, ))); - if matches!(self.basecase.mode, basecase::Mode::ZeroKnowledge) { + if matches!(self.basecase.mode, basecase::BasecaseMode::ZeroKnowledge) { min_bits = min_bits.min(f64::from(basecase_solver::analytic_error_bits( &self.basecase.commit, ))); diff --git a/src/protocols/params/spec.rs b/src/protocols/params/spec.rs index 03177efc..504fb29d 100644 --- a/src/protocols/params/spec.rs +++ b/src/protocols/params/spec.rs @@ -157,6 +157,9 @@ mod tests { fn pow_exceeding_target_saturates_to_zero() { // `pow > target` saturates rather than going negative. let pow_over_target = TARGET_BITS + 100; - assert_eq!(spec(Some(pow_over_target)).protocol_security_target_bits(), 0.0); + assert_eq!( + spec(Some(pow_over_target)).protocol_security_target_bits(), + 0.0 + ); } } diff --git a/src/protocols/params/sumcheck.rs b/src/protocols/params/sumcheck.rs index 58e8dbda..009713b9 100644 --- a/src/protocols/params/sumcheck.rs +++ b/src/protocols/params/sumcheck.rs @@ -84,16 +84,18 @@ mod tests { irs_commit as irs_solver, spec::{MaskCodeMessageLen, Mode, OodSampleBudget}, test_utils::{ - arb_round_ctx, arb_standard_johnson_spec, arb_zk_spec, assert_pow_closes_gap, - build_minimal_mask_oracle, deterministic_spec, TestEmbedding, TestField, - TestNonIdentityEmbedding, TEST_TARGET_RANGE, + arb_round_ctx, arb_standard_johnson_spec, arb_zk_spec, assert_close, + assert_pow_closes_gap, build_minimal_mask_oracle, deterministic_spec, TestEmbedding, + TestField, TestNonIdentityEmbedding, EPS, TEST_TARGET_RANGE, }, }; - fn build_source_irs( - spec: &SecuritySpec, - ctx: &RoundContext, - ) -> IrsConfig { + /// Mask-oracle fixture used by the formula tests + the ZK smoke test. + /// Both values are pow2 so `log2` is exact (no f64 drift in expected-vs-got). + const FIXTURE_C_ZK_LIST_SIZE: f64 = 4.0; + const FIXTURE_L_ZK: usize = 8; + + fn build_source_irs(spec: &SecuritySpec, ctx: &RoundContext) -> IrsConfig { irs_solver::solve(spec, ctx, OodSampleBudget::new(0)) } @@ -141,27 +143,21 @@ mod tests { let prox = irs.rbr_soundness_fold_prox_gaps(); let expected = prox.min(field_bits - log_list - 1.0).max(0.0); - assert!( - (got - expected).abs() < 1e-9, - "got {got} vs expected {expected}" - ); + assert_close(got, expected); } /// ZK branch (Lemma 6.5): `min(prox_gaps, log|F| − log|Λ(C)| − log|Λ(C_zk)| − log ℓ_zk).max(0)`. #[test] fn analytic_error_zk_formula() { - // Pow2 values so `log2` is exact. - const C_ZK_LIST_SIZE: f64 = 4.0; - const L_ZK_USIZE: usize = 8; - let log_c_zk_list = C_ZK_LIST_SIZE.log2(); - let log_l_zk = (L_ZK_USIZE as f64).log2(); + let log_c_zk_list = FIXTURE_C_ZK_LIST_SIZE.log2(); + let log_l_zk = (FIXTURE_L_ZK as f64).log2(); let spec = deterministic_spec(Mode::ZeroKnowledge); let ctx = fixture_ctx(); let irs = build_source_irs(&spec, &ctx); let info = MaskOracleInfo { - c_zk_list_size: C_ZK_LIST_SIZE, - l_zk: MaskCodeMessageLen::new(L_ZK_USIZE), + c_zk_list_size: FIXTURE_C_ZK_LIST_SIZE, + l_zk: MaskCodeMessageLen::new(FIXTURE_L_ZK), }; let got = f64::from(analytic_error_bits::(&irs, Some(info))); @@ -173,10 +169,7 @@ mod tests { .min(field_bits - log_list - log_c_zk_list - log_l_zk) .max(0.0); - assert!( - (got - expected).abs() < 1e-9, - "got {got} vs expected {expected}" - ); + assert_close(got, expected); } /// Oracle large enough to drive `poly_id` strongly negative → clamped to 0. @@ -234,7 +227,7 @@ mod tests { let mo = build_minimal_mask_oracle(&spec); let zk = f64::from(analytic_error_bits::(&irs, mo)); let standard = f64::from(analytic_error_bits::(&irs, None)); - prop_assert!(zk <= standard + 1e-9, "zk {} > standard {}", zk, standard); + prop_assert!(zk <= standard + EPS, "zk {} > standard {}", zk, standard); } /// `analytic_error + pow ≥ target`. @@ -262,10 +255,13 @@ mod tests { let source_irs: IrsConfig = irs_solver::solve(&spec, &ctx, OodSampleBudget::new(0)); let info = MaskOracleInfo { - c_zk_list_size: 4.0, - l_zk: MaskCodeMessageLen::new(8), + c_zk_list_size: FIXTURE_C_ZK_LIST_SIZE, + l_zk: MaskCodeMessageLen::new(FIXTURE_L_ZK), }; let config = solve(&spec, &ctx, &source_irs, Some(info)); - assert!(matches!(config.mode, sumcheck::SumcheckMode::ZeroKnowledge { .. })); + assert!(matches!( + config.mode, + sumcheck::SumcheckMode::ZeroKnowledge { .. } + )); } } diff --git a/src/protocols/params/test_utils.rs b/src/protocols/params/test_utils.rs index ab2625da..d2a818a0 100644 --- a/src/protocols/params/test_utils.rs +++ b/src/protocols/params/test_utils.rs @@ -14,10 +14,12 @@ use crate::{ protocols::{ irs_commit::Config as IrsConfig, params::{ - irs_commit as irs_solver, derive::compute_t_ood, + irs_commit as irs_solver, protocol_config::MaskOracleInfo, - spec::{LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, RoundContext, SecuritySpec}, + spec::{ + LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, RoundContext, SecuritySpec, + }, }, proof_of_work::Config as PowConfig, }, @@ -39,6 +41,11 @@ pub const TEST_TARGET_RANGE: RangeInclusive = 30..=50; /// sub-protocol solver has a closable gap to target. pub const FIXTURE_TARGET_BITS: u32 = 80; +/// Tolerance for `(got - expected).abs() < EPS` checks on formula-reconstruction +/// tests. `1e-9` is well above the `f64` rounding noise on log/sum expressions +/// used in the analytic-error formulas. +pub const EPS: f64 = 1e-9; + pub fn deterministic_spec(mode: Mode) -> SecuritySpec { SecuritySpec { mode, @@ -104,11 +111,7 @@ pub fn build_minimal_mask_oracle(spec: &SecuritySpec) -> Option /// Shared check used by every sub-protocol's `pow_closes_gap_to_target*` test: /// `analytic_error_bits + pow.difficulty() ≥ target_security_bits` (the `1e-3` /// tolerance absorbs `proof_of_work::threshold`'s ceil quantization). -pub fn assert_pow_closes_gap( - spec: &SecuritySpec, - analytic: Bits, - pow: &PowConfig, -) { +pub fn assert_pow_closes_gap(spec: &SecuritySpec, analytic: Bits, pow: &PowConfig) { let error = f64::from(analytic); let pow_bits = f64::from(pow.difficulty()); let target = f64::from(spec.target_security_bits); @@ -118,6 +121,32 @@ pub fn assert_pow_closes_gap( ); } +/// `|got − expected| < EPS` with a uniform error message. Shared by every +/// `analytic_error_*_formula` test. +pub fn assert_close(got: f64, expected: f64) { + assert!( + (got - expected).abs() < EPS, + "got {got} vs expected {expected}", + ); +} + +/// C_zk fixture used by every `mask_proximity` test: source mask length 0, +/// `num_vectors = 2 · num_masks` (Construction 7.2 originals + fresh pairs). +pub fn build_test_c_zk( + spec: &SecuritySpec, + l_zk: usize, + log_inv_rate: u32, + num_masks: usize, +) -> IrsConfig { + irs_solver::solve_mask_code( + spec, + MaskCodeMessageLen::new(l_zk), + 0, + LogInvRate::new(log_inv_rate), + 2 * num_masks, + ) +} + /// Safety net for the `target_irs ↔ t_ood` loop in [`build_round_io`]. /// Steady state converges in ≤ 2 iterations (`target.list_size()` is rate-only). const TARGET_STABILIZATION_MAX_ITER: usize = 8; From a22445b7b1dcce4d485acfae833fe5782bf547e6 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Mon, 18 May 2026 14:34:02 +0530 Subject: [PATCH 12/31] refactor push minimal --- src/protocols/params/basecase.rs | 19 ++++++++++++++++-- src/protocols/params/bounds.rs | 16 +-------------- src/protocols/params/code_switch.rs | 3 --- src/protocols/params/derive.rs | 25 ++++-------------------- src/protocols/params/irs_commit.rs | 1 - src/protocols/params/mask_proximity.rs | 12 ++++++++++-- src/protocols/params/protocol_config.rs | 26 ++++--------------------- src/protocols/params/spec.rs | 4 +++- src/protocols/params/sumcheck.rs | 1 - src/protocols/params/test_utils.rs | 7 ++----- 10 files changed, 41 insertions(+), 73 deletions(-) diff --git a/src/protocols/params/basecase.rs b/src/protocols/params/basecase.rs index 2124a95f..02348cde 100644 --- a/src/protocols/params/basecase.rs +++ b/src/protocols/params/basecase.rs @@ -9,6 +9,7 @@ use crate::{ basecase::{self, Config as BasecaseConfig}, irs_commit::Config as IrsConfig, params::{ + bounds::SoundnessBounded, irs_commit as irs_solver, spec::{Mode as SpecMode, OodSampleBudget, RoundContext, SecuritySpec}, sumcheck as sumcheck_solver, @@ -28,7 +29,6 @@ pub fn solve( assert!(vector_size > 0, "basecase requires vector_size ≥ 1"); let ctx = RoundContext { - round_index: 0, vector_size, log_inv_rate, folding_factor: 0, @@ -75,6 +75,22 @@ pub fn analytic_error_bits(commit: &IrsConfig>) -> Bits { Bits::new((field_bits - log_list).max(0.0)) } +impl SoundnessBounded for BasecaseConfig { + /// `min(sumcheck round error, γ-slot error)`. The γ-slot only contributes + /// in ZK mode; Standard collapses to the sumcheck term. + fn analytic_bits(&self) -> Bits { + let sumcheck_term = + f64::from(sumcheck_solver::analytic_error_bits(&self.commit, None)); + let min_bits = match self.mode { + basecase::BasecaseMode::Standard => sumcheck_term, + basecase::BasecaseMode::ZeroKnowledge => { + sumcheck_term.min(f64::from(analytic_error_bits(&self.commit))) + } + }; + Bits::new(min_bits.max(0.0)) + } +} + #[cfg(test)] #[allow(clippy::float_cmp)] mod tests { @@ -108,7 +124,6 @@ mod tests { let spec = deterministic_spec(Mode::ZeroKnowledge); let ctx = RoundContext { - round_index: 0, vector_size: FIXTURE_VECTOR_SIZE, log_inv_rate: FIXTURE_LOG_INV_RATE, folding_factor: 0, diff --git a/src/protocols/params/bounds.rs b/src/protocols/params/bounds.rs index 9a897161..a036bddf 100644 --- a/src/protocols/params/bounds.rs +++ b/src/protocols/params/bounds.rs @@ -10,6 +10,7 @@ use crate::{ /// Analytic soundness bits (excluding PoW) delivered by a protocol-level unit. /// Sub-protocol `Config` types lack the cross-protocol context to self-report. +// Library-side callers land with protocol wiring; until then only tests use it. #[allow(dead_code)] pub trait SoundnessBounded { fn analytic_bits(&self) -> Bits; @@ -91,13 +92,6 @@ pub fn ood_per_sample_log2(message_length: usize, field_bits: f64) -> f64 { ((message_length - 1) as f64).log2() - field_bits } -/// PoW difficulty to close a soundness gap: `max(0, target − achieved)`. -// TODO(phase-6): re-wire from the cross-protocol PoW pass. -#[allow(dead_code)] -pub fn pow_bits_to_close_gap(target_security_bits: f64, achieved_security_bits: f64) -> Bits { - Bits::new((target_security_bits - achieved_security_bits).max(0.0)) -} - #[cfg(test)] #[allow(clippy::float_cmp)] mod tests { @@ -247,12 +241,4 @@ mod tests { - MCA_FIELD_BITS; assert_close(got, expected); } - - /// `pow_bits_to_close_gap` clamps negative gaps to zero (no anti-grind). - #[test] - fn pow_bits_to_close_gap_saturates_at_zero() { - assert_eq!(f64::from(pow_bits_to_close_gap(100.0, 120.0)), 0.0); - assert_eq!(f64::from(pow_bits_to_close_gap(100.0, 100.0)), 0.0); - assert_eq!(f64::from(pow_bits_to_close_gap(100.0, 60.0)), 40.0); - } } diff --git a/src/protocols/params/code_switch.rs b/src/protocols/params/code_switch.rs index 85108eba..793173ae 100644 --- a/src/protocols/params/code_switch.rs +++ b/src/protocols/params/code_switch.rs @@ -238,7 +238,6 @@ mod tests { ) { // Break the t_ood ↔ c_zk.list_size cycle with a placeholder C_zk. let placeholder_source_ctx = RoundContext { - round_index: 0, vector_size: 1usize << num_vars, log_inv_rate, folding_factor, @@ -300,13 +299,11 @@ mod tests { const FOLDING_FACTOR: u32 = 2; let source_ctx = RoundContext { - round_index: 0, vector_size: SOURCE_VECTOR_SIZE, log_inv_rate: SOURCE_LOG_INV_RATE, folding_factor: FOLDING_FACTOR, }; let target_ctx = RoundContext { - round_index: 1, vector_size: source_ctx.vector_size / (1 << source_ctx.folding_factor), log_inv_rate: source_ctx.log_inv_rate + source_ctx.folding_factor - 1, folding_factor: source_ctx.folding_factor, diff --git a/src/protocols/params/derive.rs b/src/protocols/params/derive.rs index 5fe88d5e..33c5ef0d 100644 --- a/src/protocols/params/derive.rs +++ b/src/protocols/params/derive.rs @@ -127,7 +127,6 @@ fn round_layout(tuning: &TuningSpec) -> RoundLayout { const fn round_context(shape: &RoundShape) -> RoundContext { RoundContext { - round_index: shape.round_index, vector_size: shape.source_vector_size, log_inv_rate: shape.source_log_inv_rate, folding_factor: shape.source_folding_factor, @@ -136,7 +135,6 @@ const fn round_context(shape: &RoundShape) -> RoundContext { fn target_context(shape: &RoundShape, source: &IrsConfig) -> RoundContext { RoundContext { - round_index: shape.round_index, vector_size: source.message_length(), log_inv_rate: shape.source_log_inv_rate + shape.source_folding_factor.saturating_sub(1), folding_factor: shape.target_folding_factor, @@ -505,23 +503,6 @@ mod tests { } } - fn basecase_min_bits(plan: &ProtocolConfig) -> f64 { - let sumcheck = f64::from(sumcheck_solver::analytic_error_bits( - &plan.basecase.commit, - None, - )); - if matches!( - plan.basecase.mode, - crate::protocols::basecase::BasecaseMode::ZeroKnowledge - ) { - sumcheck.min(f64::from(basecase_solver::analytic_error_bits( - &plan.basecase.commit, - ))) - } else { - sumcheck - } - } - #[test] fn analytic_bits_finite_and_positive_standard() { let spec = test_spec(Mode::Standard); @@ -536,7 +517,7 @@ mod tests { .iter() .map(|r| f64::from(r.analytic_bits())) .fold(f64::INFINITY, f64::min); - let expected = min_round.min(basecase_min_bits(&plan)); + let expected = min_round.min(f64::from(plan.basecase.analytic_bits())); assert_close(bits, expected); } @@ -566,7 +547,9 @@ mod tests { .iter() .map(|r| f64::from(r.analytic_bits())) .fold(f64::INFINITY, f64::min); - let expected = mo_floor.min(min_round).min(basecase_min_bits(&plan)); + let expected = mo_floor + .min(min_round) + .min(f64::from(plan.basecase.analytic_bits())); assert_close(plan_bits, expected); } diff --git a/src/protocols/params/irs_commit.rs b/src/protocols/params/irs_commit.rs index 51ebd0d6..549f0908 100644 --- a/src/protocols/params/irs_commit.rs +++ b/src/protocols/params/irs_commit.rs @@ -196,7 +196,6 @@ mod tests { fn solve_works_with_basefield_embedding_zk() { let spec = deterministic_spec(Mode::ZeroKnowledge); let ctx = RoundContext { - round_index: 0, vector_size: SMOKE_VECTOR_SIZE, log_inv_rate: SMOKE_LOG_INV_RATE, folding_factor: SMOKE_FOLDING_FACTOR, diff --git a/src/protocols/params/mask_proximity.rs b/src/protocols/params/mask_proximity.rs index 47070f2d..197bf994 100644 --- a/src/protocols/params/mask_proximity.rs +++ b/src/protocols/params/mask_proximity.rs @@ -7,8 +7,10 @@ use crate::{ algebra::{embedding::Identity, fields::FieldWithSize}, bits::Bits, protocols::{ - irs_commit::Config as IrsConfig, mask_proximity::Config as MaskProximityConfig, - params::spec::SecuritySpec, proof_of_work::Config as PowConfig, + irs_commit::Config as IrsConfig, + mask_proximity::Config as MaskProximityConfig, + params::{bounds::SoundnessBounded, spec::SecuritySpec}, + proof_of_work::Config as PowConfig, }, }; @@ -38,6 +40,12 @@ pub fn analytic_error_bits(c_zk: &IrsConfig>, num_masks: u Bits::new((field_bits - log_combined).max(0.0)) } +impl SoundnessBounded for MaskProximityConfig { + fn analytic_bits(&self) -> Bits { + analytic_error_bits(&self.c_zk_commit, self.num_masks) + } +} + #[cfg(test)] #[allow(clippy::float_cmp)] mod tests { diff --git a/src/protocols/params/protocol_config.rs b/src/protocols/params/protocol_config.rs index 6959c798..73a18b2d 100644 --- a/src/protocols/params/protocol_config.rs +++ b/src/protocols/params/protocol_config.rs @@ -14,14 +14,13 @@ use crate::{ }, bits::Bits, protocols::{ - basecase::{self, Config as BasecaseConfig}, + basecase::Config as BasecaseConfig, code_switch::Config as CodeSwitchConfig, irs_commit::Config as IrsConfig, mask_proximity::Config as MaskProximityConfig, params::{ - basecase as basecase_solver, bounds::SoundnessBounded, - code_switch as code_switch_solver, mask_proximity as mask_proximity_solver, + code_switch as code_switch_solver, spec::{MaskCodeMessageLen, OodSampleBudget, SecuritySpec, TuningSpec}, sumcheck as sumcheck_solver, }, @@ -83,27 +82,13 @@ impl ProtocolConfig { impl SoundnessBounded for ProtocolConfig { fn analytic_bits(&self) -> Bits { - let mut min_bits = f64::INFINITY; + let mut min_bits = f64::from(self.basecase.analytic_bits()); for round in &self.rounds { min_bits = min_bits.min(f64::from(round.analytic_bits())); if let Some(mo) = &round.mask_oracle { min_bits = min_bits.min(f64::from(mo.analytic_bits())); } } - // Basecase sumcheck per-round bound applies in both modes; the γ-slot - // only contributes in ZK. - min_bits = min_bits.min(f64::from(sumcheck_solver::analytic_error_bits( - &self.basecase.commit, - None, - ))); - if matches!(self.basecase.mode, basecase::BasecaseMode::ZeroKnowledge) { - min_bits = min_bits.min(f64::from(basecase_solver::analytic_error_bits( - &self.basecase.commit, - ))); - } - if min_bits.is_infinite() { - return Bits::new(f64::from(self.security.target_security_bits)); - } Bits::new(min_bits.max(0.0)) } } @@ -196,9 +181,6 @@ impl MaskOracleConfig { impl SoundnessBounded for MaskOracleConfig { fn analytic_bits(&self) -> Bits { - mask_proximity_solver::analytic_error_bits( - &self.mask_proximity.c_zk_commit, - self.mask_proximity.num_masks, - ) + self.mask_proximity.analytic_bits() } } diff --git a/src/protocols/params/spec.rs b/src/protocols/params/spec.rs index 504fb29d..2dd0ed4a 100644 --- a/src/protocols/params/spec.rs +++ b/src/protocols/params/spec.rs @@ -20,6 +20,9 @@ impl Tagged { pub struct SecuritySpec { pub mode: Mode, pub target_security_bits: u32, + /// Per-slot PoW budget — every grinding slot may close at most this many + /// bits of gap to `target_security_bits`. Not a cumulative budget across + /// slots; `check_pow_bits` enforces it per-slot. `None` ⇒ `Some(0)`. pub max_pow_bits: Option, pub hash_id: EngineId, } @@ -80,7 +83,6 @@ pub struct TuningSpec { /// Per-round context handed to a sub-protocol builder. #[derive(Debug, Clone)] pub struct RoundContext { - pub round_index: usize, pub vector_size: usize, pub log_inv_rate: u32, pub folding_factor: u32, diff --git a/src/protocols/params/sumcheck.rs b/src/protocols/params/sumcheck.rs index 009713b9..63833422 100644 --- a/src/protocols/params/sumcheck.rs +++ b/src/protocols/params/sumcheck.rs @@ -106,7 +106,6 @@ mod tests { fn fixture_ctx() -> RoundContext { RoundContext { - round_index: 0, vector_size: 1 << FIXTURE_LOG_VECTOR_SIZE, log_inv_rate: FIXTURE_LOG_INV_RATE, folding_factor: FIXTURE_FOLDING_FACTOR, diff --git a/src/protocols/params/test_utils.rs b/src/protocols/params/test_utils.rs index d2a818a0..545189ba 100644 --- a/src/protocols/params/test_utils.rs +++ b/src/protocols/params/test_utils.rs @@ -84,9 +84,8 @@ pub fn arb_standard_johnson_spec( /// `log_size ∈ 4..=8` (vector_size 16..256) leaves room for ≥ 2·folding_factor /// post-folding while capping proptest time. pub fn arb_round_ctx() -> impl Strategy { - (0usize..=3, 4u32..=8, 1u32..=4, 1u32..=3).prop_map( - |(round_index, log_size, log_inv_rate, folding_factor)| RoundContext { - round_index, + (4u32..=8, 1u32..=4, 1u32..=3).prop_map( + |(log_size, log_inv_rate, folding_factor)| RoundContext { vector_size: 1usize << log_size, log_inv_rate, folding_factor, @@ -161,7 +160,6 @@ pub fn build_round_io( c_zk_list_size: Option, ) -> (IrsConfig, IrsConfig>, usize) { let source_ctx = RoundContext { - round_index: 0, vector_size: 1usize << num_vars, log_inv_rate, folding_factor, @@ -169,7 +167,6 @@ pub fn build_round_io( let source = irs_solver::solve(spec, &source_ctx, OodSampleBudget::new(0)); let target_ctx = RoundContext { - round_index: 1, vector_size: source.message_length(), log_inv_rate: log_inv_rate + folding_factor - 1, folding_factor, From 69d4a45d036bb7716363e080593b26baf3ce1c7c Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Mon, 18 May 2026 14:36:07 +0530 Subject: [PATCH 13/31] lint --- src/protocols/params/basecase.rs | 3 +-- src/protocols/params/test_utils.rs | 8 ++++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/protocols/params/basecase.rs b/src/protocols/params/basecase.rs index 02348cde..1a0026dc 100644 --- a/src/protocols/params/basecase.rs +++ b/src/protocols/params/basecase.rs @@ -79,8 +79,7 @@ impl SoundnessBounded for BasecaseConfig { /// `min(sumcheck round error, γ-slot error)`. The γ-slot only contributes /// in ZK mode; Standard collapses to the sumcheck term. fn analytic_bits(&self) -> Bits { - let sumcheck_term = - f64::from(sumcheck_solver::analytic_error_bits(&self.commit, None)); + let sumcheck_term = f64::from(sumcheck_solver::analytic_error_bits(&self.commit, None)); let min_bits = match self.mode { basecase::BasecaseMode::Standard => sumcheck_term, basecase::BasecaseMode::ZeroKnowledge => { diff --git a/src/protocols/params/test_utils.rs b/src/protocols/params/test_utils.rs index 545189ba..51a9053c 100644 --- a/src/protocols/params/test_utils.rs +++ b/src/protocols/params/test_utils.rs @@ -84,13 +84,13 @@ pub fn arb_standard_johnson_spec( /// `log_size ∈ 4..=8` (vector_size 16..256) leaves room for ≥ 2·folding_factor /// post-folding while capping proptest time. pub fn arb_round_ctx() -> impl Strategy { - (4u32..=8, 1u32..=4, 1u32..=3).prop_map( - |(log_size, log_inv_rate, folding_factor)| RoundContext { + (4u32..=8, 1u32..=4, 1u32..=3).prop_map(|(log_size, log_inv_rate, folding_factor)| { + RoundContext { vector_size: 1usize << log_size, log_inv_rate, folding_factor, - }, - ) + } + }) } /// `None` in Standard; `Some(ℓ_zk=2, c_zk rate 1/2)` in ZK. From 669b144b634f4c0be6758659f86086d79872b94b Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Wed, 20 May 2026 10:14:26 +0530 Subject: [PATCH 14/31] fix : code switch mask length over estimation --- .../protocols/params/code_switch.txt | 1 + src/protocols/irs_commit.rs | 7 ++++- src/protocols/params/code_switch.rs | 10 ++++++- src/protocols/params/irs_commit.rs | 22 ++++++--------- src/protocols/params/test_utils.rs | 28 +++++++++---------- 5 files changed, 37 insertions(+), 31 deletions(-) diff --git a/proptest-regressions/protocols/params/code_switch.txt b/proptest-regressions/protocols/params/code_switch.txt index 20850e33..13a64288 100644 --- a/proptest-regressions/protocols/params/code_switch.txt +++ b/proptest-regressions/protocols/params/code_switch.txt @@ -8,3 +8,4 @@ cc 7a7df094ea650db7a295d162b75dd9da9b52d1fc36947d2b07df8150cd9d906f # shrinks to cc b42c982074a04c7110df07cf00f45156607be547e176b1ddd5f9d994ad491ddb # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 80, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 4, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, log_inv_rate = 1, folding_factor = 3, num_vars = 4 cc eaf09a2b6bdffa86026264679f008326498ca800260dd2f17d4370df9fb3f801 # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 80, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 4, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, log_inv_rate = 1, folding_factor = 3, num_vars = 4 cc 3887a5fa698c99109e8262e843dbd24ea94b9c9d420791e4520b5c9211a3eca0 # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 100, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 4, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, (log_inv_rate, folding_factor, num_vars) = (3, 2, 7) +cc b3e128084f721e6f43e263e05acf2e2de6fcd05dccf3811f063eeb0b63d78f8e # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 47, max_pow_bits: Some(15), hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, (log_inv_rate, folding_factor, num_vars) = (3, 2, 4) diff --git a/src/protocols/irs_commit.rs b/src/protocols/irs_commit.rs index a0999c35..f2c113f8 100644 --- a/src/protocols/irs_commit.rs +++ b/src/protocols/irs_commit.rs @@ -143,8 +143,13 @@ impl Config { assert!(vector_size.is_multiple_of(interleaving_depth)); assert!(rate > 0. && rate <= 1.); let masked_message_length = vector_size / interleaving_depth + mode.mask_length(); + // `interleaved_encode` requires `codeword_length` to divide the NTT root + // order. `masked_message_length` is allowed to be arbitrary (the coset + // NTT zero-extends internally), so we only round the codeword side here. #[allow(clippy::cast_sign_loss)] - let codeword_length = (masked_message_length as f64 / rate).ceil() as usize; + let raw_codeword_length = (masked_message_length as f64 / rate).ceil() as usize; + let codeword_length = ntt::next_order::(raw_codeword_length) + .expect("codeword length exceeds NTT engine support"); let rate = masked_message_length as f64 / codeword_length as f64; // η = slack to Johnson bound. We pick η = √ρ / 20. diff --git a/src/protocols/params/code_switch.rs b/src/protocols/params/code_switch.rs index 793173ae..7a0f316d 100644 --- a/src/protocols/params/code_switch.rs +++ b/src/protocols/params/code_switch.rs @@ -108,6 +108,7 @@ mod tests { use super::*; use crate::protocols::params::{ + bounds::johnson_list_size, derive::{compute_l_zk, compute_t_ood}, irs_commit as irs_solver, spec::{LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, RoundContext, SecuritySpec}, @@ -266,8 +267,15 @@ mod tests { LogInvRate::new(log_inv_rate), 2, ); + // Use the same rate-only Johnson list as the planner / `build_round_io`. + // `target.list_size()` here would read the *effective* rate after + // `next_order` rounding, which (post Lemma-9.5 tight masking) differs + // from the requested rate and would spuriously shift `t_ood`. The + // assertion isolates the c_zk fixed-point, not the rate-drift artifact. + let target_log_inv_rate = f64::from(log_inv_rate + folding_factor - 1); + let target_list_size = johnson_list_size(target_log_inv_rate); let recomputed_t_ood = - compute_t_ood(&spec, &source, target.list_size(), Some(c_zk.list_size())); + compute_t_ood(&spec, &source, target_list_size, Some(c_zk.list_size())); prop_assert_eq!(t_ood, recomputed_t_ood, "placeholder ⇒ final C_zk fixed-point"); let mask_oracle = MaskOracleInfo { c_zk_list_size: c_zk.list_size(), diff --git a/src/protocols/params/irs_commit.rs b/src/protocols/params/irs_commit.rs index 549f0908..b4567c88 100644 --- a/src/protocols/params/irs_commit.rs +++ b/src/protocols/params/irs_commit.rs @@ -1,7 +1,8 @@ -//! IRS-commit parameter selection. ZK mask sized per Lemma 9.5, padded so -//! `message + mask` is a pow2 (NTT-valid codeword length). - -use std::num::NonZeroUsize; +//! IRS-commit parameter selection. +//! +//! ZK mask is sized per Lemma 9.5 (paper p.53) at the tight bound +//! `in-domain + OOD` queries. Codeword NTT-smoothness is enforced inside +//! [`IrsConfig::new`] on `codeword_length`, not by inflating the mask. use crate::{ algebra::embedding::Embedding, @@ -26,22 +27,15 @@ pub fn solve( let interleaving_depth = 1_usize << ctx.folding_factor; // Construction 9.7 is Johnson-only — `Mode` cannot express unique-decoding. let unique_decoding = false; - let message_length = ctx.vector_size / interleaving_depth; let mode = match spec.mode { Mode::Standard => IrsMode::Standard, Mode::ZeroKnowledge => { - let min_mask = num_in_domain_queries(unique_decoding, security_target, rate) + // Lemma 9.5 (part ii): r-query perfect-ZK encoding requires + // `r ≥ in-domain + OOD`. Use the tight bound; do not pow2-pad here. + let mask_length = num_in_domain_queries(unique_decoding, security_target, rate) .checked_add(out_domain_samples.get()) .expect("usize overflow"); - // Lemma 9.5: mask covers in-domain + OOD queries. - // Pad masked length to a pow2 for NTT (the lemma is `≥`, so padding is safe). - let masked_message_length = message_length - .checked_add(min_mask.get()) - .expect("masked_message_length overflow") - .next_power_of_two(); - let mask_length = NonZeroUsize::new(masked_message_length - message_length) - .expect("min_mask ≥ 1 (NonZeroUsize) ⇒ next_pow2(ℓ + min_mask) > ℓ"); IrsMode::ZeroKnowledge { mask_length } } }; diff --git a/src/protocols/params/test_utils.rs b/src/protocols/params/test_utils.rs index 51a9053c..c5386f0f 100644 --- a/src/protocols/params/test_utils.rs +++ b/src/protocols/params/test_utils.rs @@ -14,6 +14,7 @@ use crate::{ protocols::{ irs_commit::Config as IrsConfig, params::{ + bounds::johnson_list_size, derive::compute_t_ood, irs_commit as irs_solver, protocol_config::MaskOracleInfo, @@ -146,12 +147,14 @@ pub fn build_test_c_zk( ) } -/// Safety net for the `target_irs ↔ t_ood` loop in [`build_round_io`]. -/// Steady state converges in ≤ 2 iterations (`target.list_size()` is rate-only). -const TARGET_STABILIZATION_MAX_ITER: usize = 8; - /// Builds a self-consistent `(source, target, t_ood)` triplet matching the /// per-round shape that `code_switch::solve` expects. +/// +/// `t_ood` is solved against the rate-only `johnson_list_size(target_log_inv_rate)`, +/// mirroring `derive::build_zk_round_data`. Using `target.list_size()` here +/// instead would couple `t_ood` to the target's effective rate (which itself +/// depends on `t_ood` via the mask), producing a non-monotone oscillation +/// once the mask is tight (Lemma 9.5) rather than pow2-padded. pub fn build_round_io( spec: &SecuritySpec, log_inv_rate: u32, @@ -166,20 +169,15 @@ pub fn build_round_io( }; let source = irs_solver::solve(spec, &source_ctx, OodSampleBudget::new(0)); + let target_log_inv_rate = log_inv_rate + folding_factor - 1; let target_ctx = RoundContext { vector_size: source.message_length(), - log_inv_rate: log_inv_rate + folding_factor - 1, + log_inv_rate: target_log_inv_rate, folding_factor, }; - let mut target = irs_solver::solve(spec, &target_ctx, OodSampleBudget::new(0)); - for _ in 0..TARGET_STABILIZATION_MAX_ITER { - let t_ood = compute_t_ood(spec, &source, target.list_size(), c_zk_list_size); - let new_target = irs_solver::solve(spec, &target_ctx, OodSampleBudget::new(t_ood)); - if new_target.codeword_length == target.codeword_length { - return (source, new_target, t_ood); - } - target = new_target; - } - panic!("target IRS did not stabilize"); + let target_list_size = johnson_list_size(f64::from(target_log_inv_rate)); + let t_ood = compute_t_ood(spec, &source, target_list_size, c_zk_list_size); + let target = irs_solver::solve(spec, &target_ctx, OodSampleBudget::new(t_ood)); + (source, target, t_ood) } From 4ca24b5ccba216ed6b59ae4323e1706e2111ee49 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Wed, 20 May 2026 13:12:34 +0530 Subject: [PATCH 15/31] feat : added missing terms in analytical bits (soundness fix) --- src/protocols/params/basecase.rs | 43 ++++++++++++++-- src/protocols/params/code_switch.rs | 77 +++++++++++++++++++++-------- 2 files changed, 96 insertions(+), 24 deletions(-) diff --git a/src/protocols/params/basecase.rs b/src/protocols/params/basecase.rs index 1a0026dc..4c8768c6 100644 --- a/src/protocols/params/basecase.rs +++ b/src/protocols/params/basecase.rs @@ -68,11 +68,14 @@ pub fn solve( } } -/// γ-combination soundness (Theorem 7.1, n=0): `log|F| − log|Λ(C^≡2, δ)|`. +/// γ-combination soundness (Lemma 7.4 combination-randomness slot, paper p.45). +/// At `n = 0` the `C_zk` factors vanish; `ε_mca(C, δ)` does not. pub fn analytic_error_bits(commit: &IrsConfig>) -> Bits { let field_bits = F::field_size_bits(); let log_list = commit.list_size().log2(); - Bits::new((field_bits - log_list).max(0.0)) + let prox_gaps = commit.rbr_soundness_fold_prox_gaps(); + let poly_id = field_bits - log_list; + Bits::new(prox_gaps.min(poly_id).max(0.0)) } impl SoundnessBounded for BasecaseConfig { @@ -111,7 +114,6 @@ mod tests { (1u32..=4, 1u32..=3) } - /// γ-combination soundness (Theorem 7.1, n=0): `log|F| − log|Λ(C^≡2, δ)|`. /// Builds the commit directly via the IRS solver to bypass `solve`'s PoW /// grind (which would assert against the cap for default test targets). #[test] @@ -133,11 +135,44 @@ mod tests { let got = f64::from(analytic_error_bits(&commit)); let field_bits = TestField::field_size_bits(); let log_list = commit.list_size().log2(); - let expected = (field_bits - log_list).max(0.0); + let prox_gaps = commit.rbr_soundness_fold_prox_gaps(); + let poly_id = field_bits - log_list; + let expected = prox_gaps.min(poly_id).max(0.0); assert_close(got, expected); } + /// At `log_inv_rate = 1` on `Field64`, `ε_mca` is below the poly-identity + /// term — pins the `min` to the arm that earlier returned `poly_id` alone. + #[test] + fn analytic_error_uses_eps_mca_when_limiting() { + use crate::protocols::params::{ + irs_commit as irs_solver, + spec::{Mode, OodSampleBudget, RoundContext}, + }; + + let spec = deterministic_spec(Mode::ZeroKnowledge); + let ctx = RoundContext { + vector_size: FIXTURE_VECTOR_SIZE, + log_inv_rate: 1, + folding_factor: 0, + }; + let commit: IrsConfig> = + irs_solver::solve(&spec, &ctx, OodSampleBudget::new(0)); + + let field_bits = TestField::field_size_bits(); + let log_list = commit.list_size().log2(); + let prox_gaps = commit.rbr_soundness_fold_prox_gaps(); + let poly_id = field_bits - log_list; + assert!( + prox_gaps < poly_id, + "fixture wants prox_gaps to bind: prox_gaps {prox_gaps} ≥ poly_id {poly_id}", + ); + + let got = f64::from(analytic_error_bits(&commit)); + assert_close(got, prox_gaps.max(0.0)); + } + proptest! { #[test] fn solve_standard_assembles( diff --git a/src/protocols/params/code_switch.rs b/src/protocols/params/code_switch.rs index 7a0f316d..82ef6ebc 100644 --- a/src/protocols/params/code_switch.rs +++ b/src/protocols/params/code_switch.rs @@ -49,16 +49,9 @@ pub fn solve( CodeSwitchConfig::new(source, target, t_ood, mode, pow) } -/// Per-round code-switch soundness in bits: `min(ood_term, combination_term)`. -/// -/// - OOD (Lemma 9.9, term 1): `t_ood · (log|F| − log(deg − 1)) − log(L choose 2)` -/// - Combination (Bound 1, γ-RLC): `log|F| − log(count) − log L` -/// -/// `L = |Λ(target)| · |Λ(C_zk)|` (mask-oracle absent ⇒ `|Λ(C_zk)| = 1`). -/// `deg = ℓ + r + t_ood` in ZK, `ℓ` in Standard. -/// `count = t_ood + t · ι` (OOD samples + in-domain ι-interleaved source queries). -/// -/// `t_ood ≥ 1` per [`code_switch::Config::new`]. +/// Per-round code-switch soundness in bits: `min` over Lemma 9.9's three RBR +/// error slots (OOD, in-domain, combination). `t_ood ≥ 1` per +/// [`code_switch::Config::new`]. pub fn analytic_error_bits( source: &IrsConfig, target: &IrsConfig>, @@ -85,14 +78,15 @@ pub fn analytic_error_bits( let log_l_choose_2 = (combined_list * (combined_list - 1.0) / 2.0).log2(); let ood_term = t_ood_f * (field_bits - log_degree_minus_1) - log_l_choose_2; - // Combination term — Bound 1 (γ-RLC): `t_ood` OOD samples plus the - // in-domain batch of `t · ι` source columns, all RLC'd into one target - // codeword. + // In-domain term — Lemma 9.9, term 2. + let in_domain_term = source.rbr_queries(); + + // Combination term — Lemma 9.9, term 3 (γ-RLC, bounds.md §5.1). #[allow(clippy::cast_precision_loss)] let log_count = ((t_ood + source.in_domain_samples * source.interleaving_depth) as f64).log2(); let combination_term = field_bits - log_count - combined_list.log2(); - Bits::new(ood_term.min(combination_term).max(0.0)) + Bits::new(ood_term.min(in_domain_term).min(combination_term).max(0.0)) } /// Number of `(r ‖ s)` mask polynomials code-switch contributes to C_zk per @@ -148,9 +142,8 @@ mod tests { const FORMULA_FOLDING_FACTOR: u32 = 2; const FORMULA_NUM_VARS: u32 = 6; - /// Standard OOD bound (Lemma 9.9 first term, no mask): - /// `min(t_ood · (log|F| − log(ℓ−1)) − log(L choose 2), combination)`. - /// `L = target.list_size()`, `combination = log|F| − log(t_ood + t·ι) − log L`. + /// Standard `min(ood, in_domain, comb)` from Lemma 9.9's three RBR error + /// slots; `L = target.list_size()`. #[test] fn analytic_error_standard_formula() { let spec: SecuritySpec = deterministic_spec(Mode::Standard); @@ -169,14 +162,15 @@ mod tests { let log_deg_m1 = ((degree - 1) as f64).log2(); let l_choose_2 = target_list * (target_list - 1.0) / 2.0; let ood = (t_ood as f64) * (field_bits - log_deg_m1) - l_choose_2.log2(); + let in_domain = source.rbr_queries(); let count = t_ood + source.in_domain_samples * source.interleaving_depth; let comb = field_bits - (count as f64).log2() - target_list.log2(); - let expected = ood.min(comb).max(0.0); + let expected = ood.min(in_domain).min(comb).max(0.0); assert_close(got, expected); } - /// ZK OOD bound: combined list `L = target × c_zk`, masked degree `ℓ + r + t_ood`, + /// ZK bound: combined list `L = target × c_zk`, masked degree `ℓ + ℓ_zk`, /// combination term also subtracts `log|Λ(C_zk)|`. #[test] fn analytic_error_zk_formula() { @@ -211,13 +205,56 @@ mod tests { let log_deg_m1 = ((degree - 1) as f64).log2(); let l_choose_2 = combined_list * (combined_list - 1.0) / 2.0; let ood = (t_ood as f64) * (field_bits - log_deg_m1) - l_choose_2.log2(); + let in_domain = source.rbr_queries(); let count = t_ood + source.in_domain_samples * source.interleaving_depth; let comb = field_bits - (count as f64).log2() - target_list.log2() - C_ZK_LIST_SIZE.log2(); - let expected = ood.min(comb).max(0.0); + let expected = ood.min(in_domain).min(comb).max(0.0); assert_close(got, expected); } + /// Low security target (16 bits) pins `source.rbr_queries()` below the + /// natural OOD and combination floors on `Field64`, forcing the `min` to + /// the arm + #[test] + fn analytic_error_uses_in_domain_when_limiting() { + const LIMITING_TARGET_BITS: u32 = 16; + const LIMITING_LOG_INV_RATE: u32 = 1; + const LIMITING_FOLDING_FACTOR: u32 = 1; + const LIMITING_NUM_VARS: u32 = 4; + + let spec = SecuritySpec { + mode: Mode::Standard, + target_security_bits: LIMITING_TARGET_BITS, + max_pow_bits: None, + hash_id: crate::hash::BLAKE3, + }; + let (source, target, t_ood) = build_round_io::( + &spec, + LIMITING_LOG_INV_RATE, + LIMITING_FOLDING_FACTOR, + LIMITING_NUM_VARS, + None, + ); + + let field_bits = ::field_size_bits(); + let target_list = target.list_size(); + let degree = source.message_length(); + let log_deg_m1 = ((degree - 1) as f64).log2(); + let l_choose_2 = target_list * (target_list - 1.0) / 2.0; + let ood = (t_ood as f64) * (field_bits - log_deg_m1) - l_choose_2.log2(); + let in_domain = source.rbr_queries(); + let count = t_ood + source.in_domain_samples * source.interleaving_depth; + let comb = field_bits - (count as f64).log2() - target_list.log2(); + assert!( + in_domain < ood && in_domain < comb, + "fixture wants in_domain to bind: in_domain {in_domain}, ood {ood}, comb {comb}", + ); + + let got = f64::from(analytic_error_bits(&source, &target, t_ood, None)); + assert_close(got, in_domain.max(0.0)); + } + proptest! { #[test] fn solve_standard_assembles( From 360efb84f94d87bfe94d1b1b7ed43f8285eeeaf2 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Wed, 20 May 2026 14:32:22 +0530 Subject: [PATCH 16/31] doc : fixed refs --- src/protocols/params/code_switch.rs | 4 ++-- src/protocols/params/derive.rs | 17 ++++++++++------- src/protocols/params/irs_commit.rs | 2 +- src/protocols/params/mod.rs | 6 ++++++ src/protocols/params/protocol_config.rs | 7 ++++--- src/protocols/params/spec.rs | 2 +- src/protocols/params/sumcheck.rs | 4 +++- src/protocols/params/test_utils.rs | 2 +- 8 files changed, 28 insertions(+), 16 deletions(-) diff --git a/src/protocols/params/code_switch.rs b/src/protocols/params/code_switch.rs index 82ef6ebc..a2975fe4 100644 --- a/src/protocols/params/code_switch.rs +++ b/src/protocols/params/code_switch.rs @@ -33,7 +33,7 @@ pub fn solve( let l_zk = info.l_zk.get(); assert!( l_zk >= source.mask_length() + t_ood, - "ℓ_zk ({l_zk}) < r + t_ood ({} + {}) — violates Bound 3", + "ℓ_zk ({l_zk}) < r + t_ood ({} + {}) — violates Theorem 9.6 witness sizing", source.mask_length(), t_ood, ); @@ -81,7 +81,7 @@ pub fn analytic_error_bits( // In-domain term — Lemma 9.9, term 2. let in_domain_term = source.rbr_queries(); - // Combination term — Lemma 9.9, term 3 (γ-RLC, bounds.md §5.1). + // Combination term — Lemma 9.9, term 3 (γ-RLC, bounds doc §5.1). #[allow(clippy::cast_precision_loss)] let log_count = ((t_ood + source.in_domain_samples * source.interleaving_depth) as f64).log2(); let combination_term = field_bits - log_count - combined_list.log2(); diff --git a/src/protocols/params/derive.rs b/src/protocols/params/derive.rs index 33c5ef0d..ae5ab76c 100644 --- a/src/protocols/params/derive.rs +++ b/src/protocols/params/derive.rs @@ -143,8 +143,9 @@ fn target_context(shape: &RoundShape, source: &IrsConfig) -> Ro /// Per-round ZK builder. C_zk holds `2 · (k + 1)` columns (Construction 7.2 /// originals + fresh): `k` sumcheck masks (Lemma 6.4) + one `(r ‖ s)` -/// code-switch mask (Construction 9.7). `ℓ_zk = next_pow2(r + t_ood)` per -/// Lemma 9.3; `t_ood` solves Lemma 9.9 term 1. +/// code-switch mask (Construction 9.7). `ℓ_zk = next_pow2(r + t_ood)` from +/// Theorem 9.6's witness layout + Lemma 9.3's `r ≥ t` privacy precondition; +/// `t_ood` solves Lemma 9.9 term 1. fn build_zk_round_config( spec: &SecuritySpec, shape: &RoundShape, @@ -255,7 +256,8 @@ fn build_round_config( } } -/// `ℓ_zk = next_pow2(r + t_ood)` (Lemma 9.3). +/// `ℓ_zk = next_pow2(r + t_ood)`: Theorem 9.6 witness layout `0^{ℓ_zk − r}` +/// combined with Lemma 9.3's `r ≥ t` privacy precondition. pub(super) const fn compute_l_zk( source: &IrsConfig, t_ood: usize, @@ -296,7 +298,8 @@ pub(super) fn compute_t_ood( let r = source.mask_length(); for _ in 0..MAX_ITER { // Polynomial degree = `ℓ + ℓ_zk` where `ℓ_zk = next_pow2(r + t_ood)` - // (Lemma 9.3). Using `r + t_ood` would under-count when not pow2. + // (Theorem 9.6 / Lemma 9.3). Using `r + t_ood` would under-count when + // not pow2. let l_zk = (r + t_ood).next_power_of_two(); let new_t_ood = solve_for_degree(message_length + l_zk); if new_t_ood == t_ood { @@ -577,8 +580,8 @@ mod tests { /// Comfortably above `TIGHT_POW_BUDGET_BITS`. const OVER_BUDGET_INJECTED_BITS: f64 = 50.0; - /// Bound 3 + Bound 7: HVZK privacy error in bits matches the closed-form - /// `−log Σ_r (t_ood_r² + t_ood_r) / (2|F|)` over ZK rounds. + /// Bounds doc §5.3 + §5.7: HVZK privacy error in bits matches the closed + /// form `−log Σ_r (t_ood_r² + t_ood_r) / (2|F|)` over ZK rounds. #[test] fn privacy_error_bits_matches_bound_3_sum() { let spec = test_spec(Mode::ZeroKnowledge); @@ -761,7 +764,7 @@ mod tests { let num_masks = k + 1; prop_assert_eq!(mask_oracle.c_zk.num_vectors, 2 * num_masks); prop_assert_eq!(mask_oracle.mask_proximity.num_masks, num_masks); - // Bound 3 (Lemma 9.3): ℓ_zk ≥ r + t_ood for this round. + // Theorem 9.6 / Lemma 9.3: ℓ_zk ≥ r + t_ood for this round. let source_mask = r.code_switch.source.mask_length(); prop_assert!(mask_oracle.l_zk.get() >= source_mask + t_ood.get()); } diff --git a/src/protocols/params/irs_commit.rs b/src/protocols/params/irs_commit.rs index b4567c88..0eae6fbf 100644 --- a/src/protocols/params/irs_commit.rs +++ b/src/protocols/params/irs_commit.rs @@ -149,7 +149,7 @@ mod tests { } proptest! { - /// Lemma 9.5: mask covers all revealed evaluations. + /// Lemma 9.5 (part ii): mask covers all revealed evaluations. #[test] fn zk_mask_covers_lemma_9_5( spec in arb_zk_spec_default(), diff --git a/src/protocols/params/mod.rs b/src/protocols/params/mod.rs index 26ecd371..c69f3fec 100644 --- a/src/protocols/params/mod.rs +++ b/src/protocols/params/mod.rs @@ -1,3 +1,9 @@ +//! Parameter selection for HVZK-WHIR. +//! +//! Soundness and ZK bound derivations (referred to in submodule comments as +//! "the bounds doc, §N") live at +//! . + pub mod basecase; pub(crate) mod bounds; pub mod code_switch; diff --git a/src/protocols/params/protocol_config.rs b/src/protocols/params/protocol_config.rs index 73a18b2d..edea01d9 100644 --- a/src/protocols/params/protocol_config.rs +++ b/src/protocols/params/protocol_config.rs @@ -57,7 +57,7 @@ impl ProtocolConfig { } /// HVZK privacy error in bits, summed across ZK rounds: - /// `−log Σ_r (t_ood_r² + t_ood_r) / (2|F|)` (Bound 3 + Bound 7). + /// `−log Σ_r (t_ood_r² + t_ood_r) / (2|F|)` (bounds doc, §5.3 + §5.7). /// Standard-mode plans return `target_security_bits` as a sentinel — /// HVZK isn't claimed when there are no ZK rounds. pub fn privacy_error_bits(&self) -> Bits { @@ -108,7 +108,7 @@ pub struct RoundConfig { pub enum RoundMode { Standard, ZeroKnowledge { - /// Bound 2 / Lemma 9.9. + /// Lemma 9.9 OOD-sample budget (bounds doc §5.2). t_ood: OodSampleBudget, /// Slim view of this round's [`MaskOracleConfig`] (C_zk's list size + /// ℓ_zk) — denormalized so soundness routines can read it without @@ -158,7 +158,8 @@ impl SoundnessBounded for RoundConfig { pub struct MaskOracleConfig { /// `num_vectors = 2 · (k + 1)` (Construction 7.2: originals + fresh). pub c_zk: IrsConfig>, - /// `next_pow2(r + t_ood)` for this round (Lemma 9.3). + /// `next_pow2(r + t_ood)` for this round: Theorem 9.6 witness layout + /// (`0^{ℓ_zk − r}` padding) + Lemma 9.3 `(ℓ_zk − r, 0)`-privacy precondition. pub l_zk: MaskCodeMessageLen, pub mask_proximity: MaskProximityConfig, } diff --git a/src/protocols/params/spec.rs b/src/protocols/params/spec.rs index 2dd0ed4a..aed80b5a 100644 --- a/src/protocols/params/spec.rs +++ b/src/protocols/params/spec.rs @@ -104,7 +104,7 @@ pub enum MaskCodeMessageLenTag {} #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum LogInvRateTag {} -/// Bound 2 OOD-sample budget. +/// OOD-sample budget (Lemma 9.9 / bounds doc §5.2). pub type OodSampleBudget = Tagged; /// C_zk message length (Theorem 9.6: `ℓ_zk ≥ source mask length`). diff --git a/src/protocols/params/sumcheck.rs b/src/protocols/params/sumcheck.rs index 63833422..22bb867d 100644 --- a/src/protocols/params/sumcheck.rs +++ b/src/protocols/params/sumcheck.rs @@ -69,7 +69,9 @@ const fn num_sumcheck_rounds(ctx: &RoundContext) -> usize { ctx.folding_factor as usize } -/// Lemma 6.4, p.38: 3 coefficients suffice for a degree-2 round polynomial. +/// Construction 6.3 step 4(a) sends `h_j ∈ F^{ usize { 3 } diff --git a/src/protocols/params/test_utils.rs b/src/protocols/params/test_utils.rs index c5386f0f..b8d2cc4c 100644 --- a/src/protocols/params/test_utils.rs +++ b/src/protocols/params/test_utils.rs @@ -154,7 +154,7 @@ pub fn build_test_c_zk( /// mirroring `derive::build_zk_round_data`. Using `target.list_size()` here /// instead would couple `t_ood` to the target's effective rate (which itself /// depends on `t_ood` via the mask), producing a non-monotone oscillation -/// once the mask is tight (Lemma 9.5) rather than pow2-padded. +/// once the mask is tight (Lemma 9.5 part ii) rather than pow2-padded. pub fn build_round_io( spec: &SecuritySpec, log_inv_rate: u32, From 939c8b319c984d1ed799313ac240b1ac508ebbda Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Wed, 20 May 2026 17:41:01 +0530 Subject: [PATCH 17/31] refactor : resolved refactoring comments --- .../protocols/params/derive.txt | 7 +++ .../protocols/params/irs_commit.txt | 1 + src/protocols/basecase.rs | 22 +++++++ src/protocols/params/basecase.rs | 13 ++-- src/protocols/params/bounds.rs | 7 +++ src/protocols/params/code_switch.rs | 36 +++++------ src/protocols/params/derive.rs | 13 ++-- src/protocols/params/irs_commit.rs | 4 +- src/protocols/params/mask_proximity.rs | 8 ++- src/protocols/params/mod.rs | 22 +++++-- src/protocols/params/protocol_config.rs | 41 +++++++------ src/protocols/params/spec.rs | 44 ++++++++++--- src/protocols/params/sumcheck.rs | 26 ++++---- src/protocols/params/test_utils.rs | 7 ++- src/protocols/sumcheck.rs | 61 +++++++++++++++---- 15 files changed, 214 insertions(+), 98 deletions(-) create mode 100644 proptest-regressions/protocols/params/derive.txt diff --git a/proptest-regressions/protocols/params/derive.txt b/proptest-regressions/protocols/params/derive.txt new file mode 100644 index 00000000..ca83c1b3 --- /dev/null +++ b/proptest-regressions/protocols/params/derive.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 104921a4117ed8255308c1ea5d3e12c72356ef72ef0d93fc0f24ed29f93fdd3a # shrinks to tuning = TuningSpec { vector_size: 32, starting_log_inv_rate: 3, folding_factor: Constant(1) } diff --git a/proptest-regressions/protocols/params/irs_commit.txt b/proptest-regressions/protocols/params/irs_commit.txt index 2ca7e863..4d35f3df 100644 --- a/proptest-regressions/protocols/params/irs_commit.txt +++ b/proptest-regressions/protocols/params/irs_commit.txt @@ -5,3 +5,4 @@ # It is recommended to check this file in to source control so that # everyone who runs the test benefits from these saved cases. cc 0b6dd03179c9a4e38b29b34b241b88fba69348a2c8938af7253314b7035bea82 # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 80, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 4, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, ctx = RoundContext { round_index: 0, vector_size: 16, log_inv_rate: 1, folding_factor: 1, prev_round_in_domain_samples: 0, prev_round_query_error: 0.0 }, out_domain = 0, seed = 0 +cc 7e49f7a2d53f55cfa2f09114d17ab4123678b45ddf69e0cfbc646b246de2f042 # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 80, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, ctx = RoundContext { vector_size: 128, log_inv_rate: 2, folding_factor: 2 }, out_domain = 11 diff --git a/src/protocols/basecase.rs b/src/protocols/basecase.rs index 22f677e8..bdd37731 100644 --- a/src/protocols/basecase.rs +++ b/src/protocols/basecase.rs @@ -44,6 +44,28 @@ pub struct Config { } impl Config { + /// `mode == ZeroKnowledge` iff `pow != PowConfig::none()`: ZK basecase has + /// a γ-combination PoW slot (Lemma 7.4); Standard has no γ challenge. + pub fn new( + commit: irs_commit::Config>, + sumcheck: sumcheck::Config, + mode: BasecaseMode, + pow: proof_of_work::Config, + ) -> Self { + let has_pow = pow != proof_of_work::Config::none(); + debug_assert_eq!( + matches!(mode, BasecaseMode::ZeroKnowledge), + has_pow, + "ZK basecase needs PoW; Standard basecase must have none", + ); + Self { + commit, + sumcheck, + mode, + pow, + } + } + pub const fn size(&self) -> usize { self.sumcheck.initial_size } diff --git a/src/protocols/params/basecase.rs b/src/protocols/params/basecase.rs index 4c8768c6..5c64b60e 100644 --- a/src/protocols/params/basecase.rs +++ b/src/protocols/params/basecase.rs @@ -33,7 +33,7 @@ pub fn solve( log_inv_rate, folding_factor: 0, }; - let commit = irs_solver::solve(spec, &ctx, OodSampleBudget::new(0)); + let commit = irs_solver::solve(spec, &ctx, OodSampleBudget::ZERO); let target_bits = Bits::new(f64::from(spec.target_security_bits)); let sumcheck_pow = PowConfig::grind_to( @@ -60,12 +60,7 @@ pub fn solve( } }; - BasecaseConfig { - commit, - sumcheck, - mode, - pow, - } + BasecaseConfig::new(commit, sumcheck, mode, pow) } /// γ-combination soundness (Lemma 7.4 combination-randomness slot, paper p.45). @@ -130,7 +125,7 @@ mod tests { folding_factor: 0, }; let commit: IrsConfig> = - irs_solver::solve(&spec, &ctx, OodSampleBudget::new(0)); + irs_solver::solve(&spec, &ctx, OodSampleBudget::ZERO); let got = f64::from(analytic_error_bits(&commit)); let field_bits = TestField::field_size_bits(); @@ -158,7 +153,7 @@ mod tests { folding_factor: 0, }; let commit: IrsConfig> = - irs_solver::solve(&spec, &ctx, OodSampleBudget::new(0)); + irs_solver::solve(&spec, &ctx, OodSampleBudget::ZERO); let field_bits = TestField::field_size_bits(); let log_list = commit.list_size().log2(); diff --git a/src/protocols/params/bounds.rs b/src/protocols/params/bounds.rs index a036bddf..db3a3321 100644 --- a/src/protocols/params/bounds.rs +++ b/src/protocols/params/bounds.rs @@ -41,6 +41,13 @@ pub(super) fn rate(log_inv_rate: f64) -> f64 { 2_f64.powf(-log_inv_rate) } +/// Lossy `usize → f64` for analytic-error formulas. Single allow-site for +/// `clippy::cast_precision_loss` so individual call sites can stay terse. +#[allow(clippy::cast_precision_loss)] +pub(super) const fn usize_to_f64(x: usize) -> f64 { + x as f64 +} + fn unique_decoding(johnson_slack: f64) -> bool { johnson_slack == 0.0 } diff --git a/src/protocols/params/code_switch.rs b/src/protocols/params/code_switch.rs index a2975fe4..e017be91 100644 --- a/src/protocols/params/code_switch.rs +++ b/src/protocols/params/code_switch.rs @@ -12,7 +12,7 @@ use crate::{ protocols::{ code_switch::{self, Config as CodeSwitchConfig}, irs_commit::Config as IrsConfig, - params::{protocol_config::MaskOracleInfo, spec::SecuritySpec}, + params::{bounds::usize_to_f64, protocol_config::MaskOracleInfo, spec::SecuritySpec}, proof_of_work::Config as PowConfig, }, }; @@ -61,7 +61,8 @@ pub fn analytic_error_bits( assert!(t_ood > 0, "code-switch requires t_ood ≥ 1"); let field_bits = M::Target::field_size_bits(); - let combined_list = target.list_size() * mask_oracle.map_or(1.0, |info| info.c_zk_list_size); + let combined_list = + target.list_size() * mask_oracle.map_or(1.0, |info| info.c_zk_list_size.get()); // OOD polynomial is over witness `[f; r_C; s]` of length `ℓ + ℓ_zk` (ZK) or // `ℓ` (Standard). The `s`-tail is sampled at full length `ℓ_zk − r` (not // just `t_ood`), so degree must use the realized `ℓ_zk`, not `r + t_ood`. @@ -69,12 +70,10 @@ pub fn analytic_error_bits( || source.message_length(), |info| source.message_length() + info.l_zk.get(), ); - #[allow(clippy::cast_precision_loss)] - let t_ood_f = t_ood as f64; + let t_ood_f = usize_to_f64(t_ood); // OOD term — Lemma 9.9, term 1. - #[allow(clippy::cast_precision_loss)] - let log_degree_minus_1 = ((degree - 1) as f64).log2(); + let log_degree_minus_1 = usize_to_f64(degree - 1).log2(); let log_l_choose_2 = (combined_list * (combined_list - 1.0) / 2.0).log2(); let ood_term = t_ood_f * (field_bits - log_degree_minus_1) - log_l_choose_2; @@ -82,8 +81,8 @@ pub fn analytic_error_bits( let in_domain_term = source.rbr_queries(); // Combination term — Lemma 9.9, term 3 (γ-RLC, bounds doc §5.1). - #[allow(clippy::cast_precision_loss)] - let log_count = ((t_ood + source.in_domain_samples * source.interleaving_depth) as f64).log2(); + let log_count = + usize_to_f64(t_ood + source.in_domain_samples * source.interleaving_depth).log2(); let combination_term = field_bits - log_count - combined_list.log2(); Bits::new(ood_term.min(in_domain_term).min(combination_term).max(0.0)) @@ -105,7 +104,10 @@ mod tests { bounds::johnson_list_size, derive::{compute_l_zk, compute_t_ood}, irs_commit as irs_solver, - spec::{LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, RoundContext, SecuritySpec}, + spec::{ + ListSize, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, RoundContext, + SecuritySpec, + }, test_utils::{ arb_standard_johnson_spec as utils_standard_spec, arb_zk_spec as utils_zk_spec, assert_close, assert_pow_closes_gap, build_round_io, deterministic_spec, TestEmbedding, @@ -181,7 +183,7 @@ mod tests { let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); let mask_oracle = MaskOracleInfo { - c_zk_list_size: C_ZK_LIST_SIZE, + c_zk_list_size: ListSize::new(C_ZK_LIST_SIZE), l_zk: MaskCodeMessageLen::new(L_ZK_USIZE), }; let (source, target, t_ood) = build_round_io::( @@ -283,7 +285,7 @@ mod tests { let placeholder_source = irs_solver::solve::( &spec, &placeholder_source_ctx, - OodSampleBudget::new(0), + OodSampleBudget::ZERO, ); let c_zk_placeholder = irs_solver::solve_mask_code::( &spec, @@ -315,7 +317,7 @@ mod tests { compute_t_ood(&spec, &source, target_list_size, Some(c_zk.list_size())); prop_assert_eq!(t_ood, recomputed_t_ood, "placeholder ⇒ final C_zk fixed-point"); let mask_oracle = MaskOracleInfo { - c_zk_list_size: c_zk.list_size(), + c_zk_list_size: ListSize::new(c_zk.list_size()), l_zk, }; let config = solve(&spec, source, target, t_ood, Some(mask_oracle)); @@ -365,13 +367,13 @@ mod tests { let source = irs_solver::solve::( &spec, &source_ctx, - OodSampleBudget::new(0), + OodSampleBudget::ZERO, ); // Standard target: codeword_length is t_ood-independent (mask = 0). let target = irs_solver::solve::>( &spec, &target_ctx, - OodSampleBudget::new(0), + OodSampleBudget::ZERO, ); let t_ood = compute_t_ood(&spec, &source, target.list_size(), None); @@ -396,12 +398,12 @@ mod tests { let mut source = irs_solver::solve::( &spec, &source_ctx, - OodSampleBudget::new(0), + OodSampleBudget::ZERO, ); let mut target = irs_solver::solve::>( &spec, &target_ctx, - OodSampleBudget::new(0), + OodSampleBudget::ZERO, ); for _ in 0..SMOKE_FIXED_POINT_MAX_ITER { let new_t_ood = compute_t_ood( @@ -419,7 +421,7 @@ mod tests { } let mask_oracle = MaskOracleInfo { - c_zk_list_size: SMOKE_C_ZK_LIST_SIZE, + c_zk_list_size: ListSize::new(SMOKE_C_ZK_LIST_SIZE), l_zk: MaskCodeMessageLen::new((source.mask_length() + t_ood).next_power_of_two()), }; let config = solve(&spec, source, target, t_ood, Some(mask_oracle)); diff --git a/src/protocols/params/derive.rs b/src/protocols/params/derive.rs index ae5ab76c..4a03b764 100644 --- a/src/protocols/params/derive.rs +++ b/src/protocols/params/derive.rs @@ -207,7 +207,7 @@ fn build_zk_round_data( let target_list_size = johnson_list_size(target_log_inv_rate); let mut t_ood = 0; - let mut source: IrsConfig = irs_solver::solve(spec, &src_ctx, OodSampleBudget::new(0)); + let mut source: IrsConfig = irs_solver::solve(spec, &src_ctx, OodSampleBudget::ZERO); for _ in 0..LOCAL_MAX_ITER { let new_t_ood = compute_t_ood(spec, &source, target_list_size, Some(c_zk_list_size)); if new_t_ood == t_ood { @@ -237,12 +237,9 @@ fn build_round_config( debug_assert!(mask_oracle.is_none(), "ZK path uses build_zk_round_config"); let src_ctx = round_context(shape); - let source: IrsConfig = irs_solver::solve(spec, &src_ctx, OodSampleBudget::new(0)); - let target: IrsConfig> = irs_solver::solve( - spec, - &target_context(shape, &source), - OodSampleBudget::new(0), - ); + let source: IrsConfig = irs_solver::solve(spec, &src_ctx, OodSampleBudget::ZERO); + let target: IrsConfig> = + irs_solver::solve(spec, &target_context(shape, &source), OodSampleBudget::ZERO); let t_ood = compute_t_ood(spec, &source, target.list_size(), None); let sumcheck = sumcheck_solver::solve(spec, &src_ctx, &source, None); @@ -275,7 +272,7 @@ pub(super) fn compute_t_ood( ) -> usize { const MAX_ITER: usize = 32; - let security_target = spec.protocol_security_target_bits(); + let security_target = f64::from(spec.protocol_security_target_bits()); let field_bits = M::Target::field_size_bits(); let combined_list_size = target_list_size * c_zk_list_size.unwrap_or(1.0); let message_length = source.message_length(); diff --git a/src/protocols/params/irs_commit.rs b/src/protocols/params/irs_commit.rs index 0eae6fbf..47027e25 100644 --- a/src/protocols/params/irs_commit.rs +++ b/src/protocols/params/irs_commit.rs @@ -22,7 +22,7 @@ pub fn solve( ctx: &RoundContext, out_domain_samples: OodSampleBudget, ) -> IrsConfig { - let security_target = spec.protocol_security_target_bits(); + let security_target = f64::from(spec.protocol_security_target_bits()); let rate = rate(f64::from(ctx.log_inv_rate)); let interleaving_depth = 1_usize << ctx.folding_factor; // Construction 9.7 is Johnson-only — `Mode` cannot express unique-decoding. @@ -79,7 +79,7 @@ pub fn solve_mask_code( "num_vectors ({num_vectors}) must be even (mask-proximity original/fresh pairs)", ); - let security_target = spec.protocol_security_target_bits(); + let security_target = f64::from(spec.protocol_security_target_bits()); let rate = rate(f64::from(log_inv_rate.get())); IrsConfig::new( diff --git a/src/protocols/params/mask_proximity.rs b/src/protocols/params/mask_proximity.rs index 197bf994..f43473b1 100644 --- a/src/protocols/params/mask_proximity.rs +++ b/src/protocols/params/mask_proximity.rs @@ -9,7 +9,10 @@ use crate::{ protocols::{ irs_commit::Config as IrsConfig, mask_proximity::Config as MaskProximityConfig, - params::{bounds::SoundnessBounded, spec::SecuritySpec}, + params::{ + bounds::{usize_to_f64, SoundnessBounded}, + spec::SecuritySpec, + }, proof_of_work::Config as PowConfig, }, }; @@ -35,8 +38,7 @@ pub fn analytic_error_bits(c_zk: &IrsConfig>, num_masks: u if deg <= 1 || num_masks == 0 { return Bits::new(field_bits.max(0.0)); } - #[allow(clippy::cast_precision_loss)] - let log_combined = ((num_masks * (deg - 1)) as f64).log2(); + let log_combined = usize_to_f64(num_masks * (deg - 1)).log2(); Bits::new((field_bits - log_combined).max(0.0)) } diff --git a/src/protocols/params/mod.rs b/src/protocols/params/mod.rs index c69f3fec..b0f9a1d6 100644 --- a/src/protocols/params/mod.rs +++ b/src/protocols/params/mod.rs @@ -3,16 +3,28 @@ //! Soundness and ZK bound derivations (referred to in submodule comments as //! "the bounds doc, §N") live at //! . +//! +//! `derive` is the public entry point; the sub-protocol solvers (`basecase`, +//! `code_switch`, `irs_commit`, `mask_proximity`, `sumcheck`) are crate-local +//! and reached only via `derive`. Output and spec types are re-exported below. -pub mod basecase; +pub(crate) mod basecase; pub(crate) mod bounds; -pub mod code_switch; +pub(crate) mod code_switch; pub mod derive; -pub mod irs_commit; -pub mod mask_proximity; +pub(crate) mod irs_commit; +pub(crate) mod mask_proximity; pub mod protocol_config; pub mod spec; -pub mod sumcheck; +pub(crate) mod sumcheck; #[cfg(test)] pub(crate) mod test_utils; + +pub use protocol_config::{ + MaskOracleConfig, MaskOracleInfo, ProtocolConfig, RoundConfig, RoundMode, +}; +pub use spec::{ + FoldingFactor, ListSize, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, RoundContext, + SecuritySpec, TuningSpec, +}; diff --git a/src/protocols/params/protocol_config.rs b/src/protocols/params/protocol_config.rs index edea01d9..3c96c377 100644 --- a/src/protocols/params/protocol_config.rs +++ b/src/protocols/params/protocol_config.rs @@ -19,9 +19,9 @@ use crate::{ irs_commit::Config as IrsConfig, mask_proximity::Config as MaskProximityConfig, params::{ - bounds::SoundnessBounded, + bounds::{usize_to_f64, SoundnessBounded}, code_switch as code_switch_solver, - spec::{MaskCodeMessageLen, OodSampleBudget, SecuritySpec, TuningSpec}, + spec::{ListSize, MaskCodeMessageLen, OodSampleBudget, SecuritySpec, TuningSpec}, sumcheck as sumcheck_solver, }, proof_of_work::Config as PowConfig, @@ -65,8 +65,7 @@ impl ProtocolConfig { let mut total_error = 0.0_f64; for r in &self.rounds { if let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode { - #[allow(clippy::cast_precision_loss)] - let t = t_ood.get() as f64; + let t = usize_to_f64(t_ood.get()); // ζ_ze ≤ (t_ood² + t_ood) / (2|F|). Compute in log space to // stay numerically stable for large field_bits. let log_err = f64::midpoint(t * t, t).log2() - field_bits; @@ -85,9 +84,6 @@ impl SoundnessBounded for ProtocolConfig { let mut min_bits = f64::from(self.basecase.analytic_bits()); for round in &self.rounds { min_bits = min_bits.min(f64::from(round.analytic_bits())); - if let Some(mo) = &round.mask_oracle { - min_bits = min_bits.min(f64::from(mo.analytic_bits())); - } } Bits::new(min_bits.max(0.0)) } @@ -104,7 +100,7 @@ pub struct RoundConfig { pub mask_oracle: Option>, } -#[derive(Clone, Copy, Debug, PartialEq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub enum RoundMode { Standard, ZeroKnowledge { @@ -131,24 +127,33 @@ impl RoundMode { } impl SoundnessBounded for RoundConfig { + /// Round-level analytic floor: the smallest of `sumcheck`, `code_switch`, + /// and (when present) the per-round mask-oracle proximity check. Folding + /// the mask-oracle term in here keeps `ProtocolConfig::analytic_bits` + /// a pure `min` over rounds + basecase. fn analytic_bits(&self) -> Bits { let source = &self.code_switch.source; let target = &self.code_switch.target; let mask_oracle = self.mode.mask_oracle(); - let sumcheck_term = sumcheck_solver::analytic_error_bits(source, mask_oracle); - let code_switch_term = code_switch_solver::analytic_error_bits( + let sumcheck_term = f64::from(sumcheck_solver::analytic_error_bits(source, mask_oracle)); + let code_switch_term = f64::from(code_switch_solver::analytic_error_bits( source, target, self.code_switch.out_domain_samples, mask_oracle, - ); + )); + let mask_oracle_term = self + .mask_oracle + .as_ref() + .map_or(f64::INFINITY, |mo| f64::from(mo.analytic_bits())); - if f64::from(code_switch_term) < f64::from(sumcheck_term) { - code_switch_term - } else { + Bits::new( sumcheck_term - } + .min(code_switch_term) + .min(mask_oracle_term) + .max(0.0), + ) } } @@ -165,16 +170,16 @@ pub struct MaskOracleConfig { } /// Slim mask-oracle view (C_zk's list size + ℓ_zk). -#[derive(Clone, Copy, Debug, PartialEq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct MaskOracleInfo { - pub c_zk_list_size: f64, + pub c_zk_list_size: ListSize, pub l_zk: MaskCodeMessageLen, } impl MaskOracleConfig { pub fn info(&self) -> MaskOracleInfo { MaskOracleInfo { - c_zk_list_size: self.c_zk.list_size(), + c_zk_list_size: ListSize::new(self.c_zk.list_size()), l_zk: self.l_zk, } } diff --git a/src/protocols/params/spec.rs b/src/protocols/params/spec.rs index aed80b5a..f42bc0ac 100644 --- a/src/protocols/params/spec.rs +++ b/src/protocols/params/spec.rs @@ -1,6 +1,8 @@ use core::marker::PhantomData; -use crate::engines::EngineId; +use ordered_float::OrderedFloat; + +use crate::{bits::Bits, engines::EngineId}; /// Phantom-typed newtype — `Tagged` and `Tagged` are distinct types. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -28,9 +30,9 @@ pub struct SecuritySpec { } impl SecuritySpec { - pub fn protocol_security_target_bits(&self) -> f64 { + pub fn protocol_security_target_bits(&self) -> Bits { let pow = self.max_pow_bits.unwrap_or(0); - f64::from(self.target_security_bits.saturating_sub(pow)) + Bits::new(f64::from(self.target_security_bits.saturating_sub(pow))) } } @@ -107,12 +109,34 @@ pub enum LogInvRateTag {} /// OOD-sample budget (Lemma 9.9 / bounds doc §5.2). pub type OodSampleBudget = Tagged; +impl Tagged { + /// Sentinel for "no OOD samples". Used by sub-protocols that don't + /// require an OOD challenge round (e.g. Standard mode, basecase). + pub const ZERO: Self = Self::new(0); +} + /// C_zk message length (Theorem 9.6: `ℓ_zk ≥ source mask length`). pub type MaskCodeMessageLen = Tagged; /// `rate = 2^-log_inv_rate`. pub type LogInvRate = Tagged; +/// Reed–Solomon list-decoding ball size `|Λ(C, δ)|`. Wraps `OrderedFloat` +/// so it can be stored alongside the `Tagged` integer newtypes without losing +/// `Eq`/`Hash`. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct ListSize(OrderedFloat); + +impl ListSize { + pub const fn new(v: f64) -> Self { + Self(OrderedFloat(v)) + } + + pub const fn get(self) -> f64 { + self.0 .0 + } +} + #[cfg(test)] #[allow(clippy::float_cmp)] mod tests { @@ -136,7 +160,7 @@ mod tests { fn none_means_no_pow_credit() { assert_eq!( spec(None).protocol_security_target_bits(), - f64::from(TARGET_BITS), + Bits::new(f64::from(TARGET_BITS)), ); } @@ -151,8 +175,14 @@ mod tests { #[test] fn pow_credit_shifts_analytic_floor() { // Two below-target PoW budgets: `target − pow` shifts down 1:1. - assert_eq!(spec(Some(20)).protocol_security_target_bits(), 80.0); - assert_eq!(spec(Some(60)).protocol_security_target_bits(), 40.0); + assert_eq!( + spec(Some(20)).protocol_security_target_bits(), + Bits::new(80.0), + ); + assert_eq!( + spec(Some(60)).protocol_security_target_bits(), + Bits::new(40.0), + ); } #[test] @@ -161,7 +191,7 @@ mod tests { let pow_over_target = TARGET_BITS + 100; assert_eq!( spec(Some(pow_over_target)).protocol_security_target_bits(), - 0.0 + Bits::new(0.0), ); } } diff --git a/src/protocols/params/sumcheck.rs b/src/protocols/params/sumcheck.rs index 22bb867d..fe0149dd 100644 --- a/src/protocols/params/sumcheck.rs +++ b/src/protocols/params/sumcheck.rs @@ -7,11 +7,12 @@ use crate::{ protocols::{ irs_commit::Config as IrsConfig, params::{ + bounds::usize_to_f64, protocol_config::MaskOracleInfo, spec::{RoundContext, SecuritySpec}, }, proof_of_work::Config as PowConfig, - sumcheck::{self, Config as SumcheckConfig}, + sumcheck::{self, Config as SumcheckConfig, SumcheckMaskLen}, }, }; @@ -50,9 +51,8 @@ pub fn analytic_error_bits( let prox_gaps = source_irs.rbr_soundness_fold_prox_gaps(); let poly_id = mask_oracle.map_or(field_bits - log_list_size - 1.0, |info| { - let log_list_size_c_zk = info.c_zk_list_size.log2(); - #[allow(clippy::cast_precision_loss)] - let log_l_zk = (info.l_zk.get() as f64).log2(); + let log_list_size_c_zk = info.c_zk_list_size.get().log2(); + let log_l_zk = usize_to_f64(info.l_zk.get()).log2(); field_bits - log_list_size - log_list_size_c_zk - log_l_zk }); @@ -72,8 +72,8 @@ const fn num_sumcheck_rounds(ctx: &RoundContext) -> usize { /// Construction 6.3 step 4(a) sends `h_j ∈ F^{ usize { - 3 +const fn zk_mask_length() -> SumcheckMaskLen { + SumcheckMaskLen::new(3) } #[cfg(test)] @@ -84,7 +84,7 @@ mod tests { use super::*; use crate::protocols::params::{ irs_commit as irs_solver, - spec::{MaskCodeMessageLen, Mode, OodSampleBudget}, + spec::{ListSize, MaskCodeMessageLen, Mode, OodSampleBudget}, test_utils::{ arb_round_ctx, arb_standard_johnson_spec, arb_zk_spec, assert_close, assert_pow_closes_gap, build_minimal_mask_oracle, deterministic_spec, TestEmbedding, @@ -98,7 +98,7 @@ mod tests { const FIXTURE_L_ZK: usize = 8; fn build_source_irs(spec: &SecuritySpec, ctx: &RoundContext) -> IrsConfig { - irs_solver::solve(spec, ctx, OodSampleBudget::new(0)) + irs_solver::solve(spec, ctx, OodSampleBudget::ZERO) } /// Smallest pow2 shape that still produces a non-degenerate IRS. @@ -124,7 +124,7 @@ mod tests { let config = solve(&spec, &ctx, &source_irs, mask_oracle); match config.mode { sumcheck::SumcheckMode::ZeroKnowledge { mask_length } => { - assert_eq!(mask_length, 3); + assert_eq!(mask_length.get(), 3); } sumcheck::SumcheckMode::Standard => panic!("expected ZK"), } @@ -157,7 +157,7 @@ mod tests { let ctx = fixture_ctx(); let irs = build_source_irs(&spec, &ctx); let info = MaskOracleInfo { - c_zk_list_size: FIXTURE_C_ZK_LIST_SIZE, + c_zk_list_size: ListSize::new(FIXTURE_C_ZK_LIST_SIZE), l_zk: MaskCodeMessageLen::new(FIXTURE_L_ZK), }; @@ -184,7 +184,7 @@ mod tests { let ctx = fixture_ctx(); let irs = build_source_irs(&spec, &ctx); let huge = MaskOracleInfo { - c_zk_list_size: 2_f64.powi(OVERSIZED_LOG_C_ZK_LIST), + c_zk_list_size: ListSize::new(2_f64.powi(OVERSIZED_LOG_C_ZK_LIST)), l_zk: MaskCodeMessageLen::new(1 << OVERSIZED_LOG_L_ZK), }; let bits = f64::from(analytic_error_bits::(&irs, Some(huge))); @@ -254,9 +254,9 @@ mod tests { let spec = deterministic_spec(Mode::ZeroKnowledge); let ctx = fixture_ctx(); let source_irs: IrsConfig = - irs_solver::solve(&spec, &ctx, OodSampleBudget::new(0)); + irs_solver::solve(&spec, &ctx, OodSampleBudget::ZERO); let info = MaskOracleInfo { - c_zk_list_size: FIXTURE_C_ZK_LIST_SIZE, + c_zk_list_size: ListSize::new(FIXTURE_C_ZK_LIST_SIZE), l_zk: MaskCodeMessageLen::new(FIXTURE_L_ZK), }; let config = solve(&spec, &ctx, &source_irs, Some(info)); diff --git a/src/protocols/params/test_utils.rs b/src/protocols/params/test_utils.rs index b8d2cc4c..4f671adc 100644 --- a/src/protocols/params/test_utils.rs +++ b/src/protocols/params/test_utils.rs @@ -19,7 +19,8 @@ use crate::{ irs_commit as irs_solver, protocol_config::MaskOracleInfo, spec::{ - LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, RoundContext, SecuritySpec, + ListSize, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, RoundContext, + SecuritySpec, }, }, proof_of_work::Config as PowConfig, @@ -103,7 +104,7 @@ pub fn build_minimal_mask_oracle(spec: &SecuritySpec) -> Option let c_zk: IrsConfig = irs_solver::solve_mask_code(spec, l_zk, 0, LogInvRate::new(1), 2); Some(MaskOracleInfo { - c_zk_list_size: c_zk.list_size(), + c_zk_list_size: ListSize::new(c_zk.list_size()), l_zk, }) } @@ -167,7 +168,7 @@ pub fn build_round_io( log_inv_rate, folding_factor, }; - let source = irs_solver::solve(spec, &source_ctx, OodSampleBudget::new(0)); + let source = irs_solver::solve(spec, &source_ctx, OodSampleBudget::ZERO); let target_log_inv_rate = log_inv_rate + folding_factor - 1; let target_ctx = RoundContext { diff --git a/src/protocols/sumcheck.rs b/src/protocols/sumcheck.rs index bccd316f..626e040f 100644 --- a/src/protocols/sumcheck.rs +++ b/src/protocols/sumcheck.rs @@ -30,10 +30,33 @@ pub struct SumcheckOpening { pub mask_rlc: F, } +/// ZK sumcheck mask polynomial dimension. +/// +/// Validated at construction to be at least `MIN = 3` — the round polynomial +/// has 3 coefficients (degree-2), so the mask must have at least as many to +/// hide it. Lemma 6.4 itself only requires `ℓ_zk ≥ 2`; the `3` floor is a +/// WHIR design choice tied to the degree-2 round polynomial (see +/// `params::sumcheck::zk_mask_length`). +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct SumcheckMaskLen(usize); + +impl SumcheckMaskLen { + pub const MIN: usize = 3; + + pub const fn new(n: usize) -> Self { + assert!(n >= Self::MIN); + Self(n) + } + + pub const fn get(self) -> usize { + self.0 + } +} + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum SumcheckMode { Standard, - ZeroKnowledge { mask_length: usize }, + ZeroKnowledge { mask_length: SumcheckMaskLen }, } #[must_use] @@ -58,10 +81,9 @@ impl Config { mode: SumcheckMode, ) -> Self { assert!(num_rounds == 0 || initial_size.next_power_of_two() >= 1 << num_rounds); - if let SumcheckMode::ZeroKnowledge { mask_length } = &mode { - // Mask must cover all 3 sumcheck polynomial coefficients (c0, c1, c2). - assert!(*mask_length >= 3); - // Lemma 6.4 prerequisite. + // `SumcheckMaskLen::new` already enforces the ≥ 3 floor at construction; + // here we only need the field-characteristic precondition from Lemma 6.4. + if matches!(mode, SumcheckMode::ZeroKnowledge { .. }) { assert!( !F::ONE.double().is_zero(), "ZK sumcheck requires char(F) ≠ 2" @@ -79,7 +101,7 @@ impl Config { const fn mask_length(&self) -> usize { match &self.mode { SumcheckMode::Standard => 0, - SumcheckMode::ZeroKnowledge { mask_length } => *mask_length, + SumcheckMode::ZeroKnowledge { mask_length } => mask_length.get(), } } @@ -202,8 +224,11 @@ impl Config { return (F::ZERO, F::ONE); } let sum_multiple = F::from(1 << self.num_rounds.saturating_sub(1)); - let mask_sum = - masks.chunks_exact(*mask_length).map(eval_01).sum::() * sum_multiple; + let mask_sum = masks + .chunks_exact(mask_length.get()) + .map(eval_01) + .sum::() + * sum_multiple; prover_state.prover_message(&mask_sum); let mask_rlc = prover_state.verifier_message(); (mask_sum, mask_rlc) @@ -285,7 +310,9 @@ impl fmt::Display for Config { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mode_str = match &self.mode { SumcheckMode::Standard => "standard".to_string(), - SumcheckMode::ZeroKnowledge { mask_length } => format!("zk ℓ_zk={mask_length}"), + SumcheckMode::ZeroKnowledge { mask_length } => { + format!("zk ℓ_zk={}", mask_length.get()) + } }; write!( f, @@ -333,7 +360,9 @@ mod tests { pub fn arbitrary() -> impl Strategy { let mode_strategy = prop_oneof![ 3 => Just(SumcheckMode::Standard), - 7 => (3_usize..20).prop_map(|mask_length| SumcheckMode::ZeroKnowledge { mask_length }), + 7 => (3_usize..20).prop_map(|n| SumcheckMode::ZeroKnowledge { + mask_length: SumcheckMaskLen::new(n), + }), ]; (0_usize..(1 << 12), 0_usize..12, mode_strategy).prop_map( |(initial_size, num_rounds, mode)| { @@ -438,7 +467,9 @@ mod tests { 2, proof_of_work::Config::none(), 1, - SumcheckMode::ZeroKnowledge { mask_length: 3 }, + SumcheckMode::ZeroKnowledge { + mask_length: SumcheckMaskLen::new(3), + }, ), ); } @@ -451,7 +482,9 @@ mod tests { 3, proof_of_work::Config::none(), 2, - SumcheckMode::ZeroKnowledge { mask_length: 3 }, + SumcheckMode::ZeroKnowledge { + mask_length: SumcheckMaskLen::new(3), + }, ), ); } @@ -464,7 +497,9 @@ mod tests { 5, proof_of_work::Config::none(), 3, - SumcheckMode::ZeroKnowledge { mask_length: 3 }, + SumcheckMode::ZeroKnowledge { + mask_length: SumcheckMaskLen::new(3), + }, ), ); } From 71fd7580072dab044002e9cc3b3d8a5013741ffd Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Thu, 21 May 2026 16:09:53 +0530 Subject: [PATCH 18/31] feat : added error handling for params module --- Cargo.lock | 21 +++ Cargo.toml | 1 + src/protocols/params/basecase.rs | 17 +- src/protocols/params/code_switch.rs | 36 +++-- src/protocols/params/derive.rs | 207 ++++++++++++++++++------ src/protocols/params/error.rs | 98 +++++++++++ src/protocols/params/mask_proximity.rs | 19 ++- src/protocols/params/mod.rs | 2 + src/protocols/params/protocol_config.rs | 68 ++++++-- src/protocols/params/sumcheck.rs | 46 +++++- src/protocols/params/test_utils.rs | 3 +- src/protocols/proof_of_work.rs | 33 +++- 12 files changed, 448 insertions(+), 103 deletions(-) create mode 100644 src/protocols/params/error.rs diff --git a/Cargo.lock b/Cargo.lock index acc6ca37..62e5a0d4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1360,6 +1360,26 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "thiserror" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "thread_local" version = "1.1.9" @@ -1556,6 +1576,7 @@ dependencies = [ "sha3", "spongefish", "static_assertions", + "thiserror", "tracing", "tracing-subscriber", "zerocopy", diff --git a/Cargo.toml b/Cargo.toml index 515ee2b8..52fc40e5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,6 +55,7 @@ const-oid = "0.9.6" arrayvec = "0.7.6" derive-where = { version = "1.6.0", features = ["safe"] } ordered-float = { version = "5.1.0", features = ["serde"] } +thiserror = "2.0" [dev-dependencies] proptest = "1.0" diff --git a/src/protocols/params/basecase.rs b/src/protocols/params/basecase.rs index 5c64b60e..262b1596 100644 --- a/src/protocols/params/basecase.rs +++ b/src/protocols/params/basecase.rs @@ -10,6 +10,7 @@ use crate::{ irs_commit::Config as IrsConfig, params::{ bounds::SoundnessBounded, + error::{BasecaseSlot, DeriveError, PowResultExt, PowSlot}, irs_commit as irs_solver, spec::{Mode as SpecMode, OodSampleBudget, RoundContext, SecuritySpec}, sumcheck as sumcheck_solver, @@ -25,7 +26,7 @@ pub fn solve( spec: &SecuritySpec, vector_size: usize, log_inv_rate: u32, -) -> BasecaseConfig { +) -> Result, DeriveError> { assert!(vector_size > 0, "basecase requires vector_size ≥ 1"); let ctx = RoundContext { @@ -40,7 +41,8 @@ pub fn solve( target_bits, sumcheck_solver::analytic_error_bits(&commit, None), spec.hash_id, - ); + ) + .at_slot(PowSlot::Basecase(BasecaseSlot::Sumcheck))?; let sumcheck = SumcheckConfig::new( vector_size, sumcheck_pow, @@ -57,10 +59,11 @@ pub fn solve( basecase::BasecaseMode::Standard => PowConfig::none(), basecase::BasecaseMode::ZeroKnowledge => { PowConfig::grind_to(target_bits, analytic_error_bits(&commit), spec.hash_id) + .at_slot(PowSlot::Basecase(BasecaseSlot::GammaCombination))? } }; - BasecaseConfig::new(commit, sumcheck, mode, pow) + Ok(BasecaseConfig::new(commit, sumcheck, mode, pow)) } /// γ-combination soundness (Lemma 7.4 combination-randomness slot, paper p.45). @@ -174,7 +177,7 @@ mod tests { spec in arb_standard_johnson_spec(TEST_TARGET_RANGE), (log_size, log_inv_rate) in arb_dims(), ) { - let config = solve::(&spec, 1usize << log_size, log_inv_rate); + let config = solve::(&spec, 1usize << log_size, log_inv_rate).unwrap(); prop_assert!(matches!(config.mode, basecase::BasecaseMode::Standard)); prop_assert_eq!(config.commit.interleaving_depth, 1); prop_assert_eq!(config.commit.num_vectors, 1); @@ -186,7 +189,7 @@ mod tests { spec in arb_zk_spec(TEST_TARGET_RANGE), (log_size, log_inv_rate) in arb_dims(), ) { - let config = solve::(&spec, 1usize << log_size, log_inv_rate); + let config = solve::(&spec, 1usize << log_size, log_inv_rate).unwrap(); prop_assert!(matches!(config.mode, basecase::BasecaseMode::ZeroKnowledge)); prop_assert!(config.commit.mask_length() > 0); } @@ -196,7 +199,7 @@ mod tests { spec in arb_zk_spec(TEST_TARGET_RANGE), (log_size, log_inv_rate) in arb_dims(), ) { - let config = solve::(&spec, 1usize << log_size, log_inv_rate); + let config = solve::(&spec, 1usize << log_size, log_inv_rate).unwrap(); assert_pow_closes_gap(&spec, analytic_error_bits(&config.commit), &config.pow); } @@ -205,7 +208,7 @@ mod tests { spec in arb_standard_johnson_spec(TEST_TARGET_RANGE), (log_size, log_inv_rate) in arb_dims(), ) { - let config = solve::(&spec, 1usize << log_size, log_inv_rate); + let config = solve::(&spec, 1usize << log_size, log_inv_rate).unwrap(); prop_assert_eq!(config.pow, PowConfig::none()); } } diff --git a/src/protocols/params/code_switch.rs b/src/protocols/params/code_switch.rs index e017be91..74d1ee5d 100644 --- a/src/protocols/params/code_switch.rs +++ b/src/protocols/params/code_switch.rs @@ -12,7 +12,12 @@ use crate::{ protocols::{ code_switch::{self, Config as CodeSwitchConfig}, irs_commit::Config as IrsConfig, - params::{bounds::usize_to_f64, protocol_config::MaskOracleInfo, spec::SecuritySpec}, + params::{ + bounds::usize_to_f64, + error::{DeriveError, PowResultExt, PowSlot, RoundSlot}, + protocol_config::MaskOracleInfo, + spec::SecuritySpec, + }, proof_of_work::Config as PowConfig, }, }; @@ -28,7 +33,8 @@ pub fn solve( target: IrsConfig>, t_ood: usize, mask_oracle: Option, -) -> CodeSwitchConfig { + round_index: usize, +) -> Result, DeriveError> { let mode = mask_oracle.map_or(code_switch::CodeSwitchMode::Standard, |info| { let l_zk = info.l_zk.get(); assert!( @@ -44,9 +50,12 @@ pub fn solve( let target_bits = Bits::new(f64::from(spec.target_security_bits)); let analytic = analytic_error_bits(&source, &target, t_ood, mask_oracle); - let pow = PowConfig::grind_to(target_bits, analytic, spec.hash_id); + let pow = PowConfig::grind_to(target_bits, analytic, spec.hash_id).at_slot(PowSlot::Round { + index: round_index, + kind: RoundSlot::CodeSwitch, + })?; - CodeSwitchConfig::new(source, target, t_ood, mode, pow) + Ok(CodeSwitchConfig::new(source, target, t_ood, mode, pow)) } /// Per-round code-switch soundness in bits: `min` over Lemma 9.9's three RBR @@ -265,7 +274,7 @@ mod tests { ) { let (source, target, t_ood) = build_round_io::(&spec, log_inv_rate, folding_factor, num_vars, None); - let config = solve(&spec, source, target, t_ood, None); + let config = solve(&spec, source, target, t_ood, None, 0).unwrap(); prop_assert!(matches!(config.mode, code_switch::CodeSwitchMode::Standard)); prop_assert!(config.out_domain_samples >= 1); } @@ -314,13 +323,14 @@ mod tests { let target_log_inv_rate = f64::from(log_inv_rate + folding_factor - 1); let target_list_size = johnson_list_size(target_log_inv_rate); let recomputed_t_ood = - compute_t_ood(&spec, &source, target_list_size, Some(c_zk.list_size())); + compute_t_ood(&spec, &source, target_list_size, Some(c_zk.list_size()), 0) + .unwrap(); prop_assert_eq!(t_ood, recomputed_t_ood, "placeholder ⇒ final C_zk fixed-point"); let mask_oracle = MaskOracleInfo { c_zk_list_size: ListSize::new(c_zk.list_size()), l_zk, }; - let config = solve(&spec, source, target, t_ood, Some(mask_oracle)); + let config = solve(&spec, source, target, t_ood, Some(mask_oracle), 0).unwrap(); prop_assert_eq!(config.message_mask_length(), (r + t_ood).next_power_of_two()); } @@ -333,7 +343,7 @@ mod tests { let (source, target, t_ood) = build_round_io::(&spec, log_inv_rate, folding_factor, num_vars, None); let error = analytic_error_bits(&source, &target, t_ood, None); - let config = solve(&spec, source, target, t_ood, None); + let config = solve(&spec, source, target, t_ood, None, 0).unwrap(); assert_pow_closes_gap(&spec, error, &config.pow); } } @@ -375,9 +385,9 @@ mod tests { &target_ctx, OodSampleBudget::ZERO, ); - let t_ood = compute_t_ood(&spec, &source, target.list_size(), None); + let t_ood = compute_t_ood(&spec, &source, target.list_size(), None, 0).unwrap(); - let config = solve(&spec, source, target, t_ood, None); + let config = solve(&spec, source, target, t_ood, None, 0).unwrap(); assert!(matches!(config.mode, code_switch::CodeSwitchMode::Standard)); } @@ -411,7 +421,9 @@ mod tests { &source, target.list_size(), Some(SMOKE_C_ZK_LIST_SIZE), - ); + 0, + ) + .unwrap(); if new_t_ood == t_ood { break; } @@ -424,7 +436,7 @@ mod tests { c_zk_list_size: ListSize::new(SMOKE_C_ZK_LIST_SIZE), l_zk: MaskCodeMessageLen::new((source.mask_length() + t_ood).next_power_of_two()), }; - let config = solve(&spec, source, target, t_ood, Some(mask_oracle)); + let config = solve(&spec, source, target, t_ood, Some(mask_oracle), 0).unwrap(); assert!(matches!( config.mode, code_switch::CodeSwitchMode::ZeroKnowledge { .. } diff --git a/src/protocols/params/derive.rs b/src/protocols/params/derive.rs index 4a03b764..64af6fb0 100644 --- a/src/protocols/params/derive.rs +++ b/src/protocols/params/derive.rs @@ -14,8 +14,9 @@ use crate::{ params::{ basecase as basecase_solver, bounds::johnson_list_size, - code_switch as code_switch_solver, irs_commit as irs_solver, - mask_proximity as mask_proximity_solver, + code_switch as code_switch_solver, + error::{DeriveError, PowSlot, RoundSlot}, + irs_commit as irs_solver, mask_proximity as mask_proximity_solver, protocol_config::{ MaskOracleConfig, MaskOracleInfo, ProtocolConfig, RoundConfig, RoundMode, }, @@ -31,7 +32,11 @@ use crate::{ impl ProtocolConfig { /// In ZK each round owns its mask oracle; the `ℓ_zk ↔ c_zk ↔ t_ood` /// fixed-point runs independently per round. - pub fn derive(spec: SecuritySpec, tuning: TuningSpec) -> Self { + /// + /// Fails with [`DeriveError`] when the spec/tuning combination is + /// infeasible: a PoW slot exceeds the grind cap, a fixed point diverges, + /// or any slot exceeds `spec.max_pow_bits` (post-derivation validation). + pub fn derive(spec: SecuritySpec, tuning: TuningSpec) -> Result { let RoundLayout { shapes, basecase_vector_size, @@ -42,24 +47,26 @@ impl ProtocolConfig { Mode::Standard => shapes .iter() .map(|shape| build_round_config::(&spec, shape, None)) - .collect(), + .collect::>()?, Mode::ZeroKnowledge => { let c_zk_log_inv_rate = LogInvRate::new(tuning.starting_log_inv_rate); shapes .iter() .map(|shape| build_zk_round_config::(&spec, shape, c_zk_log_inv_rate)) - .collect() + .collect::>()? } }; - let basecase = basecase_solver::solve(&spec, basecase_vector_size, basecase_log_inv_rate); + let basecase = basecase_solver::solve(&spec, basecase_vector_size, basecase_log_inv_rate)?; - Self { + let plan = Self { security: spec, tuning, rounds, basecase, - } + }; + plan.validate_pow_budget()?; + Ok(plan) } } @@ -150,7 +157,7 @@ fn build_zk_round_config( spec: &SecuritySpec, shape: &RoundShape, c_zk_log_inv_rate: LogInvRate, -) -> RoundConfig { +) -> Result, DeriveError> { let ctx = round_context(shape); let num_masks = sumcheck_solver::masks_required(&ctx) + code_switch_solver::masks_required(); // C_zk.list_size depends only on rate — no IRS build needed for it. @@ -160,7 +167,7 @@ fn build_zk_round_config( source, target, t_ood, - } = build_zk_round_data::(spec, shape, c_zk_list_size); + } = build_zk_round_data::(spec, shape, c_zk_list_size)?; let l_zk = compute_l_zk(&source, t_ood); let c_zk: IrsConfig> = irs_solver::solve_mask_code( @@ -171,15 +178,30 @@ fn build_zk_round_config( 2 * num_masks, ); let mask_oracle = MaskOracleConfig { - mask_proximity: mask_proximity_solver::solve(spec, c_zk.clone(), num_masks), + mask_proximity: mask_proximity_solver::solve( + spec, + c_zk.clone(), + num_masks, + shape.round_index, + )?, c_zk, l_zk, }; let info = mask_oracle.info(); - let sumcheck = sumcheck_solver::solve(spec, &ctx, &source, Some(info)); - let code_switch = code_switch_solver::solve(spec, source, target, t_ood, Some(info)); - RoundConfig { + let sumcheck = sumcheck_solver::solve( + spec, + &ctx, + &source, + Some(info), + PowSlot::Round { + index: shape.round_index, + kind: RoundSlot::Sumcheck, + }, + )?; + let code_switch = + code_switch_solver::solve(spec, source, target, t_ood, Some(info), shape.round_index)?; + Ok(RoundConfig { round_index: shape.round_index, sumcheck, code_switch, @@ -188,7 +210,7 @@ fn build_zk_round_config( mask_oracle: info, }, mask_oracle: Some(mask_oracle), - } + }) } /// Local `t_ood ↔ r` fixed-point. `r = source.mask_length()` is a step function @@ -198,7 +220,7 @@ fn build_zk_round_data( spec: &SecuritySpec, shape: &RoundShape, c_zk_list_size: f64, -) -> RoundData { +) -> Result, DeriveError> { const LOCAL_MAX_ITER: usize = 16; let src_ctx = round_context(shape); @@ -209,48 +231,66 @@ fn build_zk_round_data( let mut t_ood = 0; let mut source: IrsConfig = irs_solver::solve(spec, &src_ctx, OodSampleBudget::ZERO); for _ in 0..LOCAL_MAX_ITER { - let new_t_ood = compute_t_ood(spec, &source, target_list_size, Some(c_zk_list_size)); + let new_t_ood = compute_t_ood( + spec, + &source, + target_list_size, + Some(c_zk_list_size), + shape.round_index, + )?; if new_t_ood == t_ood { let target: IrsConfig> = irs_solver::solve( spec, &target_context(shape, &source), OodSampleBudget::new(t_ood), ); - return RoundData { + return Ok(RoundData { source, target, t_ood, - }; + }); } t_ood = new_t_ood; source = irs_solver::solve(spec, &src_ctx, OodSampleBudget::new(t_ood)); } - panic!("per-round ZK fixed-point did not converge"); + Err(DeriveError::PerRoundFixedPointDidNotConverge { + round_index: shape.round_index, + }) } fn build_round_config( spec: &SecuritySpec, shape: &RoundShape, mask_oracle: Option, -) -> RoundConfig { +) -> Result, DeriveError> { debug_assert!(mask_oracle.is_none(), "ZK path uses build_zk_round_config"); let src_ctx = round_context(shape); let source: IrsConfig = irs_solver::solve(spec, &src_ctx, OodSampleBudget::ZERO); let target: IrsConfig> = irs_solver::solve(spec, &target_context(shape, &source), OodSampleBudget::ZERO); - let t_ood = compute_t_ood(spec, &source, target.list_size(), None); + let t_ood = compute_t_ood(spec, &source, target.list_size(), None, shape.round_index)?; - let sumcheck = sumcheck_solver::solve(spec, &src_ctx, &source, None); - let code_switch = code_switch_solver::solve(spec, source, target, t_ood, None); - RoundConfig { + let sumcheck = sumcheck_solver::solve( + spec, + &src_ctx, + &source, + None, + PowSlot::Round { + index: shape.round_index, + kind: RoundSlot::Sumcheck, + }, + )?; + let code_switch = + code_switch_solver::solve(spec, source, target, t_ood, None, shape.round_index)?; + Ok(RoundConfig { round_index: shape.round_index, sumcheck, code_switch, mode: RoundMode::Standard, mask_oracle: None, - } + }) } /// `ℓ_zk = next_pow2(r + t_ood)`: Theorem 9.6 witness layout `0^{ℓ_zk − r}` @@ -269,7 +309,8 @@ pub(super) fn compute_t_ood( source: &IrsConfig, target_list_size: f64, c_zk_list_size: Option, -) -> usize { + round_index: usize, +) -> Result { const MAX_ITER: usize = 32; let security_target = f64::from(spec.protocol_security_target_bits()); @@ -289,7 +330,7 @@ pub(super) fn compute_t_ood( let mut t_ood = solve_for_degree(message_length); if matches!(spec.mode, Mode::Standard) { - return t_ood; + return Ok(t_ood); } let r = source.mask_length(); @@ -300,11 +341,11 @@ pub(super) fn compute_t_ood( let l_zk = (r + t_ood).next_power_of_two(); let new_t_ood = solve_for_degree(message_length + l_zk); if new_t_ood == t_ood { - return t_ood; + return Ok(t_ood); } t_ood = new_t_ood; } - panic!("compute_t_ood did not converge in {MAX_ITER} iterations"); + Err(DeriveError::TOodFixedPointDidNotConverge { round_index }) } #[cfg(test)] @@ -378,7 +419,9 @@ mod tests { SecuritySpec { mode, target_security_bits: PLAN_FIXTURE_TARGET_BITS, - max_pow_bits: None, + // Allow up to the grind cap; derive() auto-validates the budget + // and would reject configs that need any PoW under `None ⇒ 0`. + max_pow_bits: Some(LOOSE_POW_BUDGET_BITS), hash_id: hash::BLAKE3, } } @@ -466,7 +509,7 @@ mod tests { fn derive_standard_with_no_rounds_uses_basecase_only() { let spec = test_spec(Mode::Standard); let vector_size = 1usize << LOG_VECTOR_SIZE_NO_ROUNDS; - let plan = ProtocolConfig::::derive(spec, tuning_with(vector_size)); + let plan = ProtocolConfig::::derive(spec, tuning_with(vector_size)).unwrap(); assert!(plan.rounds.is_empty()); assert_eq!(plan.basecase.commit.vector_size, vector_size); } @@ -479,7 +522,8 @@ mod tests { let plan = ProtocolConfig::::derive( spec, tuning_with(1 << LOG_VECTOR_SIZE_NO_ROUNDS), - ); + ) + .unwrap(); assert!(plan.rounds.is_empty()); assert!(matches!( plan.basecase.mode, @@ -494,7 +538,8 @@ mod tests { let plan = ProtocolConfig::::derive( spec, tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), - ); + ) + .unwrap(); for r in &plan.rounds { let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode else { panic!("expected ZK round") @@ -509,7 +554,8 @@ mod tests { let plan = ProtocolConfig::::derive( spec, tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), - ); + ) + .unwrap(); let bits: f64 = plan.analytic_bits().into(); assert!(bits.is_finite() && bits > 0.0, "bits = {bits}"); let min_round = plan @@ -527,7 +573,8 @@ mod tests { let plan = ProtocolConfig::::derive( spec, tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), - ); + ) + .unwrap(); let plan_bits: f64 = plan.analytic_bits().into(); let mo_floor = plan .rounds @@ -559,7 +606,8 @@ mod tests { let plan = ProtocolConfig::::derive( spec, tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), - ); + ) + .unwrap(); assert!(matches!( plan.basecase.mode, crate::protocols::basecase::BasecaseMode::ZeroKnowledge @@ -571,10 +619,9 @@ mod tests { /// Matches `proof_of_work::threshold`'s 60-bit cap. const LOOSE_POW_BUDGET_BITS: u32 = 60; - /// Below any realistic analytic gap; forces `check_pow_bits` to reject - /// the injected slot in the negative test. - const TIGHT_POW_BUDGET_BITS: u32 = 10; - /// Comfortably above `TIGHT_POW_BUDGET_BITS`. + /// Sits between a moderate budget (30) and the grind cap (60) — used by + /// `check_pow_bits_detects_over_budget_slot` to inject a slot that fits + /// the cap but exceeds the test's `max_pow_bits`. const OVER_BUDGET_INJECTED_BITS: f64 = 50.0; /// Bounds doc §5.3 + §5.7: HVZK privacy error in bits matches the closed @@ -585,7 +632,8 @@ mod tests { let plan = ProtocolConfig::::derive( spec, tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), - ); + ) + .unwrap(); let field_bits = ::field_size_bits(); let mut expected_total = 0.0_f64; for r in &plan.rounds { @@ -608,7 +656,8 @@ mod tests { let plan = ProtocolConfig::::derive( spec, tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), - ); + ) + .unwrap(); assert_eq!( f64::from(plan.privacy_error_bits()), f64::from(PLAN_FIXTURE_TARGET_BITS), @@ -627,28 +676,82 @@ mod tests { let plan = ProtocolConfig::::derive( spec, tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), - ); + ) + .unwrap(); assert!(plan.check_pow_bits()); } - /// Hand-injected over-budget PoW slot fails the check. + /// Hand-injected over-budget PoW slot fails `check_pow_bits()`. + /// + /// Derive with a moderately tight budget (passes auto-validation because + /// the natural slot pow stays well below it), then mutate the basecase + /// pow to a value above that budget but still within the grind cap, and + /// verify the boolean check trips. #[test] fn check_pow_bits_detects_over_budget_slot() { use crate::{bits::Bits, protocols::proof_of_work::Config as PowConfig}; + const MODERATE_POW_BUDGET_BITS: u32 = 30; let spec = SecuritySpec { mode: Mode::ZeroKnowledge, target_security_bits: PLAN_FIXTURE_TARGET_BITS, - max_pow_bits: Some(TIGHT_POW_BUDGET_BITS), + max_pow_bits: Some(MODERATE_POW_BUDGET_BITS), hash_id: hash::BLAKE3, }; let mut plan = ProtocolConfig::::derive( spec, tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), - ); + ) + .unwrap(); plan.basecase.pow = PowConfig::from_difficulty(Bits::new(OVER_BUDGET_INJECTED_BITS)); assert!(!plan.check_pow_bits()); } + /// `derive()` reports `PowUngrindable` when the spec demands a per-slot + /// difficulty above the grind cap. `target_security_bits = 200` against + /// `analytic ≈ 64` on `Field64` gives `required ≈ 136` ≫ 60. + #[test] + fn derive_reports_pow_ungrindable() { + const UNREACHABLE_TARGET_BITS: u32 = 200; + let spec = SecuritySpec { + mode: Mode::Standard, + target_security_bits: UNREACHABLE_TARGET_BITS, + max_pow_bits: Some(LOOSE_POW_BUDGET_BITS), + hash_id: hash::BLAKE3, + }; + let err = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .expect_err("target above grind cap must fail"); + assert!( + matches!(err, DeriveError::PowUngrindable { .. }), + "got {err:?}", + ); + } + + /// `derive()` reports `PowBudgetExceeded` when a slot's required PoW + /// fits the grind cap but exceeds `spec.max_pow_bits`. `target = 40` + /// with `max_pow_bits = Some(5)` forces this on `Field64`. + #[test] + fn derive_reports_pow_budget_exceeded() { + const TIGHT_MAX_POW: u32 = 5; + let spec = SecuritySpec { + mode: Mode::ZeroKnowledge, + target_security_bits: PLAN_FIXTURE_TARGET_BITS, + max_pow_bits: Some(TIGHT_MAX_POW), + hash_id: hash::BLAKE3, + }; + let err = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .expect_err("tight max_pow_bits must trip auto-validation"); + assert!( + matches!(err, DeriveError::PowBudgetExceeded { .. }), + "got {err:?}", + ); + } + /// `analytic_error + pow ≥ target` for every PoW slot in the plan. fn assert_plan_meets_target_per_slot( spec: &SecuritySpec, @@ -706,7 +809,7 @@ mod tests { #[test] fn derived_plan_meets_target_per_slot_standard(tuning in arb_tuning()) { let spec = test_spec(Mode::Standard); - let plan = ProtocolConfig::::derive(spec.clone(), tuning); + let plan = ProtocolConfig::::derive(spec.clone(), tuning).unwrap(); assert_plan_meets_target_per_slot(&spec, &plan); } @@ -718,7 +821,7 @@ mod tests { tuning.folding_factor.at_round(0) + tuning.folding_factor.at_round(1); prop_assume!(tuning.vector_size.trailing_zeros() as usize >= log_threshold); let spec = test_spec(Mode::ZeroKnowledge); - let plan = ProtocolConfig::::derive(spec.clone(), tuning); + let plan = ProtocolConfig::::derive(spec.clone(), tuning).unwrap(); assert_plan_meets_target_per_slot(&spec, &plan); } @@ -727,7 +830,7 @@ mod tests { #[test] fn derive_standard_succeeds_over_tunings(tuning in arb_tuning()) { let spec = test_spec(Mode::Standard); - let plan = ProtocolConfig::::derive(spec, tuning); + let plan = ProtocolConfig::::derive(spec, tuning).unwrap(); for r in &plan.rounds { prop_assert!(matches!(r.mode, RoundMode::Standard)); prop_assert!(r.mask_oracle.is_none()); @@ -748,7 +851,7 @@ mod tests { prop_assume!(tuning.vector_size.trailing_zeros() as usize >= log_threshold); let spec = test_spec(Mode::ZeroKnowledge); - let plan = ProtocolConfig::::derive(spec, tuning); + let plan = ProtocolConfig::::derive(spec, tuning).unwrap(); for r in &plan.rounds { let mask_oracle = r .mask_oracle @@ -776,7 +879,7 @@ mod tests { #[test] fn analytic_bits_finite_and_non_negative_standard(tuning in arb_tuning()) { let spec = test_spec(Mode::Standard); - let plan = ProtocolConfig::::derive(spec, tuning); + let plan = ProtocolConfig::::derive(spec, tuning).unwrap(); let analytic = f64::from(plan.analytic_bits()); prop_assert!(analytic.is_finite()); prop_assert!(analytic >= 0.0); diff --git a/src/protocols/params/error.rs b/src/protocols/params/error.rs new file mode 100644 index 00000000..cfb173b2 --- /dev/null +++ b/src/protocols/params/error.rs @@ -0,0 +1,98 @@ +//! Errors raised by [`super::derive::ProtocolConfig::derive`] and the +//! sub-protocol solvers. +//! +//! Two layers: [`super::super::proof_of_work::PowError`] for grinding-cap +//! failures, [`DeriveError`] for everything `derive()` can surface. The latter +//! wraps the former via [`DeriveError::PowUngrindable::source`] so callers can +//! walk the `std::error::Error::source()` chain. + +use thiserror::Error; + +use crate::{bits::Bits, protocols::proof_of_work::PowError}; + +/// Coordinate of a PoW slot in the derived protocol. Two axes: where the +/// slot lives (basecase vs. a numbered round) and which sub-protocol owns it. +/// Only valid combinations are representable. +/// +/// `Error` is derived for the Display propagation it provides in +/// [`DeriveError`]'s `#[error("...")]` attributes; this type isn't itself a +/// failure (`source()` is always `None`). +#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)] +pub enum PowSlot { + #[error("basecase {0}")] + Basecase(BasecaseSlot), + #[error("round {index} {kind}")] + Round { index: usize, kind: RoundSlot }, +} + +/// Sub-protocols whose PoW lives in the basecase. `GammaCombination` is the +/// Lemma 7.4 γ-RLC slot, present only in ZK mode. +#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)] +pub enum BasecaseSlot { + #[error("γ-combination")] + GammaCombination, + #[error("sumcheck")] + Sumcheck, +} + +/// Sub-protocols whose PoW lives in a per-round shape. +#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)] +pub enum RoundSlot { + #[error("sumcheck")] + Sumcheck, + #[error("code-switch")] + CodeSwitch, + #[error("mask-proximity")] + MaskProximity, +} + +/// Failure modes for [`super::derive::ProtocolConfig::derive`] and the +/// sub-protocol solvers it calls. +#[derive(Debug, Error, Clone, PartialEq, Eq)] +pub enum DeriveError { + /// `compute_t_ood` failed to reach a fixed point in `MAX_ITER` iterations. + /// Indicates a pathological spec/tuning combo; should not happen under + /// realistic security targets on supported fields. + #[error("t_ood fixed-point did not converge for round {round_index}")] + TOodFixedPointDidNotConverge { round_index: usize }, + + /// The ZK per-round `t_ood ↔ source.mask_length()` loop failed to reach a + /// fixed point. Same caveat as `TOodFixedPointDidNotConverge`. + #[error("ZK per-round fixed-point did not converge for round {round_index}")] + PerRoundFixedPointDidNotConverge { round_index: usize }, + + /// A PoW slot cannot be ground at the chosen analytic floor — the spec is + /// too tight for any single grind slot to close the gap. + #[error("{slot} cannot be ground: {source}")] + PowUngrindable { + slot: PowSlot, + #[source] + source: PowError, + }, + + /// A PoW slot fits the grind cap but exceeds the per-slot budget set by + /// [`super::spec::SecuritySpec::max_pow_bits`]. + #[error("{slot} requires {required} bits, exceeds spec.max_pow_bits = {max}")] + PowBudgetExceeded { + slot: PowSlot, + required: Bits, + max: Bits, + }, + + /// Computed codeword length exceeds the NTT engine's supported order. + #[error("codeword length {length} exceeds the NTT engine's supported order")] + CodewordExceedsNtt { length: usize }, +} + +/// Lift `Result` into `Result` by attaching a +/// [`PowSlot`] label. Lets call sites stay single-line — no manual +/// `.map_err(|e| DeriveError::PowUngrindable { slot, source: e })` boilerplate. +pub(crate) trait PowResultExt { + fn at_slot(self, slot: PowSlot) -> Result; +} + +impl PowResultExt for Result { + fn at_slot(self, slot: PowSlot) -> Result { + self.map_err(|source| DeriveError::PowUngrindable { slot, source }) + } +} diff --git a/src/protocols/params/mask_proximity.rs b/src/protocols/params/mask_proximity.rs index f43473b1..068c9e2e 100644 --- a/src/protocols/params/mask_proximity.rs +++ b/src/protocols/params/mask_proximity.rs @@ -11,6 +11,7 @@ use crate::{ mask_proximity::Config as MaskProximityConfig, params::{ bounds::{usize_to_f64, SoundnessBounded}, + error::{DeriveError, PowResultExt, PowSlot, RoundSlot}, spec::SecuritySpec, }, proof_of_work::Config as PowConfig, @@ -23,11 +24,15 @@ pub fn solve( spec: &SecuritySpec, c_zk: IrsConfig>, num_masks: usize, -) -> MaskProximityConfig { + round_index: usize, +) -> Result, DeriveError> { let target_bits = Bits::new(f64::from(spec.target_security_bits)); let analytic = analytic_error_bits(&c_zk, num_masks); - let pow = PowConfig::grind_to(target_bits, analytic, spec.hash_id); - MaskProximityConfig::new(c_zk, num_masks, pow) + let pow = PowConfig::grind_to(target_bits, analytic, spec.hash_id).at_slot(PowSlot::Round { + index: round_index, + kind: RoundSlot::MaskProximity, + })?; + Ok(MaskProximityConfig::new(c_zk, num_masks, pow)) } /// γ-combination soundness (Lemma 7.4): @@ -114,7 +119,7 @@ mod tests { l_zk_log in 1u32..=5, ) { let c_zk = build_test_c_zk(&spec, 1usize << l_zk_log, log_inv_rate, num_masks); - let config = solve(&spec, c_zk, num_masks); + let config = solve(&spec, c_zk, num_masks, 0).unwrap(); prop_assert_eq!(config.num_masks, num_masks); prop_assert_eq!(config.c_zk_commit.num_vectors, 2 * num_masks); prop_assert_eq!(config.c_zk_commit.interleaving_depth, 1); @@ -130,7 +135,7 @@ mod tests { ) { let c_zk = build_test_c_zk(&spec, 1usize << l_zk_log, log_inv_rate, num_masks); let analytic = analytic_error_bits(&c_zk, num_masks); - let config = solve(&spec, c_zk, num_masks); + let config = solve(&spec, c_zk, num_masks, 0).unwrap(); assert_pow_closes_gap(&spec, analytic, &config.pow); } } @@ -143,7 +148,7 @@ mod tests { fn solve_rejects_mismatched_num_vectors() { let spec = deterministic_spec(Mode::ZeroKnowledge); let c_zk = build_test_c_zk(&spec, 2, 1, 2); - let _ = solve(&spec, c_zk, 3); + let _ = solve(&spec, c_zk, 3, 0); } #[test] @@ -170,6 +175,6 @@ mod tests { RATE, IrsMode::Standard, ); - let _ = solve(&spec, c_zk, NUM_MASKS); + let _ = solve(&spec, c_zk, NUM_MASKS, 0); } } diff --git a/src/protocols/params/mod.rs b/src/protocols/params/mod.rs index b0f9a1d6..34f113b9 100644 --- a/src/protocols/params/mod.rs +++ b/src/protocols/params/mod.rs @@ -12,6 +12,7 @@ pub(crate) mod basecase; pub(crate) mod bounds; pub(crate) mod code_switch; pub mod derive; +pub mod error; pub(crate) mod irs_commit; pub(crate) mod mask_proximity; pub mod protocol_config; @@ -21,6 +22,7 @@ pub(crate) mod sumcheck; #[cfg(test)] pub(crate) mod test_utils; +pub use error::{BasecaseSlot, DeriveError, PowSlot, RoundSlot}; pub use protocol_config::{ MaskOracleConfig, MaskOracleInfo, ProtocolConfig, RoundConfig, RoundMode, }; diff --git a/src/protocols/params/protocol_config.rs b/src/protocols/params/protocol_config.rs index 3c96c377..1face937 100644 --- a/src/protocols/params/protocol_config.rs +++ b/src/protocols/params/protocol_config.rs @@ -21,6 +21,7 @@ use crate::{ params::{ bounds::{usize_to_f64, SoundnessBounded}, code_switch as code_switch_solver, + error::{BasecaseSlot, DeriveError, PowSlot, RoundSlot}, spec::{ListSize, MaskCodeMessageLen, OodSampleBudget, SecuritySpec, TuningSpec}, sumcheck as sumcheck_solver, }, @@ -39,21 +40,64 @@ pub struct ProtocolConfig { impl ProtocolConfig { /// Returns `true` if every PoW slot's difficulty fits within - /// `security.max_pow_bits`. Cheap pre-flight check that fails before the - /// 60-bit cap assertion inside `proof_of_work::threshold`. + /// `security.max_pow_bits`. Boolean predicate kept for callers that want + /// to re-check after manual inspection; [`Self::validate_pow_budget`] is + /// the typed version used internally by [`super::derive::ProtocolConfig::derive`]. pub fn check_pow_bits(&self) -> bool { + self.validate_pow_budget().is_ok() + } + + /// Same check as [`Self::check_pow_bits`] but returns the specific slot + /// and required-vs-max difficulties on failure. Auto-invoked by + /// `derive()`; callers don't normally need to call this directly. + pub fn validate_pow_budget(&self) -> Result<(), DeriveError> { let max = Bits::new(f64::from(self.security.max_pow_bits.unwrap_or(0))); - let within = |pow: &PowConfig| pow.difficulty() <= max; - if !self.rounds.iter().all(|r| { - within(&r.sumcheck.round_pow) - && within(&r.code_switch.pow) - && r.mask_oracle - .as_ref() - .is_none_or(|mo| within(&mo.mask_proximity.pow)) - }) { - return false; + let check = |slot: PowSlot, pow: &PowConfig| -> Result<(), DeriveError> { + let required = pow.difficulty(); + if required > max { + Err(DeriveError::PowBudgetExceeded { + slot, + required, + max, + }) + } else { + Ok(()) + } + }; + for r in &self.rounds { + check( + PowSlot::Round { + index: r.round_index, + kind: RoundSlot::Sumcheck, + }, + &r.sumcheck.round_pow, + )?; + check( + PowSlot::Round { + index: r.round_index, + kind: RoundSlot::CodeSwitch, + }, + &r.code_switch.pow, + )?; + if let Some(mo) = &r.mask_oracle { + check( + PowSlot::Round { + index: r.round_index, + kind: RoundSlot::MaskProximity, + }, + &mo.mask_proximity.pow, + )?; + } } - within(&self.basecase.sumcheck.round_pow) && within(&self.basecase.pow) + check( + PowSlot::Basecase(BasecaseSlot::Sumcheck), + &self.basecase.sumcheck.round_pow, + )?; + check( + PowSlot::Basecase(BasecaseSlot::GammaCombination), + &self.basecase.pow, + )?; + Ok(()) } /// HVZK privacy error in bits, summed across ZK rounds: diff --git a/src/protocols/params/sumcheck.rs b/src/protocols/params/sumcheck.rs index fe0149dd..9e563c3f 100644 --- a/src/protocols/params/sumcheck.rs +++ b/src/protocols/params/sumcheck.rs @@ -8,6 +8,7 @@ use crate::{ irs_commit::Config as IrsConfig, params::{ bounds::usize_to_f64, + error::{DeriveError, PowResultExt, PowSlot}, protocol_config::MaskOracleInfo, spec::{RoundContext, SecuritySpec}, }, @@ -17,25 +18,33 @@ use crate::{ }; /// `mask_oracle` is `Some` iff ZK; only C_zk's list size + ℓ_zk are read here. +/// `slot` is the [`PowSlot`] that labels grinding failures (basecase or per-round). pub fn solve( spec: &SecuritySpec, ctx: &RoundContext, source_irs: &IrsConfig, mask_oracle: Option, -) -> SumcheckConfig { + slot: PowSlot, +) -> Result, DeriveError> { let num_rounds = num_sumcheck_rounds(ctx); let round_pow = PowConfig::grind_to( Bits::new(f64::from(spec.target_security_bits)), analytic_error_bits(source_irs, mask_oracle), spec.hash_id, - ); + ) + .at_slot(slot)?; let mode = match mask_oracle { None => sumcheck::SumcheckMode::Standard, Some(_) => sumcheck::SumcheckMode::ZeroKnowledge { mask_length: zk_mask_length(), }, }; - SumcheckConfig::new(ctx.vector_size, round_pow, num_rounds, mode) + Ok(SumcheckConfig::new( + ctx.vector_size, + round_pow, + num_rounds, + mode, + )) } /// Per-sumcheck-round soundness in bits: `min(ε_mca, poly_identity_term)`. @@ -83,6 +92,7 @@ mod tests { use super::*; use crate::protocols::params::{ + error::RoundSlot, irs_commit as irs_solver, spec::{ListSize, MaskCodeMessageLen, Mode, OodSampleBudget}, test_utils::{ @@ -121,7 +131,17 @@ mod tests { let ctx = fixture_ctx(); let source_irs = build_source_irs(&spec, &ctx); let mask_oracle = build_minimal_mask_oracle(&spec); - let config = solve(&spec, &ctx, &source_irs, mask_oracle); + let config = solve( + &spec, + &ctx, + &source_irs, + mask_oracle, + PowSlot::Round { + index: 0, + kind: RoundSlot::Sumcheck, + }, + ) + .unwrap(); match config.mode { sumcheck::SumcheckMode::ZeroKnowledge { mask_length } => { assert_eq!(mask_length.get(), 3); @@ -199,7 +219,7 @@ mod tests { ) { let source_irs = build_source_irs(&spec, &ctx); let mask_oracle = build_minimal_mask_oracle(&spec); - let config = solve(&spec, &ctx, &source_irs, mask_oracle); + let config = solve(&spec, &ctx, &source_irs, mask_oracle, PowSlot::Round { index: 0, kind: RoundSlot::Sumcheck }).unwrap(); prop_assert!(matches!(config.mode, sumcheck::SumcheckMode::Standard)); } @@ -213,7 +233,7 @@ mod tests { ) { let source_irs = build_source_irs(&spec, &ctx); let mask_oracle = build_minimal_mask_oracle(&spec); - let config = solve(&spec, &ctx, &source_irs, mask_oracle); + let config = solve(&spec, &ctx, &source_irs, mask_oracle, PowSlot::Round { index: 0, kind: RoundSlot::Sumcheck }).unwrap(); prop_assert_eq!(config.num_rounds, ctx.folding_factor as usize); } @@ -243,7 +263,7 @@ mod tests { let source_irs = build_source_irs(&spec, &ctx); let mask_oracle = build_minimal_mask_oracle(&spec); let error = analytic_error_bits(&source_irs, mask_oracle); - let config = solve(&spec, &ctx, &source_irs, mask_oracle); + let config = solve(&spec, &ctx, &source_irs, mask_oracle, PowSlot::Round { index: 0, kind: RoundSlot::Sumcheck }).unwrap(); assert_pow_closes_gap(&spec, error, &config.round_pow); } } @@ -259,7 +279,17 @@ mod tests { c_zk_list_size: ListSize::new(FIXTURE_C_ZK_LIST_SIZE), l_zk: MaskCodeMessageLen::new(FIXTURE_L_ZK), }; - let config = solve(&spec, &ctx, &source_irs, Some(info)); + let config = solve( + &spec, + &ctx, + &source_irs, + Some(info), + PowSlot::Round { + index: 0, + kind: RoundSlot::Sumcheck, + }, + ) + .unwrap(); assert!(matches!( config.mode, sumcheck::SumcheckMode::ZeroKnowledge { .. } diff --git a/src/protocols/params/test_utils.rs b/src/protocols/params/test_utils.rs index 4f671adc..a2409016 100644 --- a/src/protocols/params/test_utils.rs +++ b/src/protocols/params/test_utils.rs @@ -178,7 +178,8 @@ pub fn build_round_io( }; let target_list_size = johnson_list_size(f64::from(target_log_inv_rate)); - let t_ood = compute_t_ood(spec, &source, target_list_size, c_zk_list_size); + let t_ood = compute_t_ood(spec, &source, target_list_size, c_zk_list_size, 0) + .expect("compute_t_ood diverged in test fixture"); let target = irs_solver::solve(spec, &target_ctx, OodSampleBudget::new(t_ood)); (source, target, t_ood) } diff --git a/src/protocols/proof_of_work.rs b/src/protocols/proof_of_work.rs index f316d650..ccdc3a60 100644 --- a/src/protocols/proof_of_work.rs +++ b/src/protocols/proof_of_work.rs @@ -26,8 +26,20 @@ pub struct Config { pub threshold: u64, } +/// Failure modes for [`Config::grind_to`]. +#[derive(Debug, thiserror::Error, Clone, Copy, PartialEq, Eq)] +pub enum PowError { + /// `target − analytic_error` exceeds what a single grind slot can deliver + /// (`MAX_DIFFICULTY = 60` bits, matching the grinding engine's capacity). + #[error("required {required} bits exceeds the {max} grind cap")] + GapExceedsGrindCap { required: Bits, max: Bits }, +} + +/// Largest gap a single grind slot can close (in bits). +pub const MAX_DIFFICULTY: f64 = 60.0; + pub fn threshold(difficulty: Bits) -> u64 { - assert!((0.0..=60.0).contains(&difficulty.into())); + assert!((0.0..=MAX_DIFFICULTY).contains(&difficulty.into())); let threshold = (64.0 - f64::from(difficulty)).exp2().ceil(); #[allow(clippy::cast_sign_loss)] @@ -71,12 +83,25 @@ impl Config { /// soundness up to `target`. The caller is responsible for ensuring /// `analytic_error` is computed from the local protocol step (see e.g. /// `params::sumcheck`). - pub fn grind_to(target: Bits, analytic_error: Bits, hash_id: EngineId) -> Self { + /// + /// Returns [`PowError::GapExceedsGrindCap`] if the required difficulty + /// exceeds [`MAX_DIFFICULTY`] — the spec is too tight for any single slot. + pub fn grind_to( + target: Bits, + analytic_error: Bits, + hash_id: EngineId, + ) -> Result { let gap = (f64::from(target) - f64::from(analytic_error)).max(0.0); - Self { + if gap > MAX_DIFFICULTY { + return Err(PowError::GapExceedsGrindCap { + required: Bits::new(gap), + max: Bits::new(MAX_DIFFICULTY), + }); + } + Ok(Self { hash_id, threshold: threshold(Bits::new(gap)), - } + }) } #[cfg_attr(feature = "tracing", instrument(skip_all, fields(engine)))] From 7b8ec9c3367a379b81b98c8768f969f7c3f3109d Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Fri, 22 May 2026 09:41:06 +0530 Subject: [PATCH 19/31] refactor : added enums differentiating branches --- src/protocols/basecase.rs | 13 +-- src/protocols/params/code_switch.rs | 14 +-- src/protocols/params/derive.rs | 84 ++++++++++++--- src/protocols/params/error.rs | 4 +- src/protocols/params/irs_commit.rs | 27 +++-- src/protocols/params/mod.rs | 4 +- src/protocols/params/protocol_config.rs | 4 +- src/protocols/params/spec.rs | 138 ++++++++++++++++++++---- src/protocols/params/test_utils.rs | 35 +++--- 9 files changed, 238 insertions(+), 85 deletions(-) diff --git a/src/protocols/basecase.rs b/src/protocols/basecase.rs index bdd37731..8c8ef5f5 100644 --- a/src/protocols/basecase.rs +++ b/src/protocols/basecase.rs @@ -44,8 +44,10 @@ pub struct Config { } impl Config { - /// `mode == ZeroKnowledge` iff `pow != PowConfig::none()`: ZK basecase has - /// a γ-combination PoW slot (Lemma 7.4); Standard has no γ challenge. + /// Standard basecase has no γ challenge — PoW must be `none()`. ZK + /// basecase has a γ-combination slot (Lemma 7.4) and may or may not need + /// PoW depending on whether the analytic floor already clears the target + /// (under unique decoding it often does). pub fn new( commit: irs_commit::Config>, sumcheck: sumcheck::Config, @@ -53,10 +55,9 @@ impl Config { pow: proof_of_work::Config, ) -> Self { let has_pow = pow != proof_of_work::Config::none(); - debug_assert_eq!( - matches!(mode, BasecaseMode::ZeroKnowledge), - has_pow, - "ZK basecase needs PoW; Standard basecase must have none", + debug_assert!( + !matches!(mode, BasecaseMode::Standard) || !has_pow, + "Standard basecase has no γ challenge — pow must be none()", ); Self { commit, diff --git a/src/protocols/params/code_switch.rs b/src/protocols/params/code_switch.rs index 74d1ee5d..b0f0dd82 100644 --- a/src/protocols/params/code_switch.rs +++ b/src/protocols/params/code_switch.rs @@ -113,9 +113,9 @@ mod tests { bounds::johnson_list_size, derive::{compute_l_zk, compute_t_ood}, irs_commit as irs_solver, - spec::{ - ListSize, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, RoundContext, - SecuritySpec, + spec::{DecodingRegime, + ListSize, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, PowBudget, + RoundContext, SecuritySpec, ZkSpec, }, test_utils::{ arb_standard_johnson_spec as utils_standard_spec, arb_zk_spec as utils_zk_spec, @@ -236,8 +236,9 @@ mod tests { let spec = SecuritySpec { mode: Mode::Standard, + decoding_regime: DecodingRegime::Johnson, target_security_bits: LIMITING_TARGET_BITS, - max_pow_bits: None, + pow_budget: PowBudget::Forbidden, hash_id: crate::hash::BLAKE3, }; let (source, target, t_ood) = build_round_io::( @@ -296,8 +297,9 @@ mod tests { &placeholder_source_ctx, OodSampleBudget::ZERO, ); + let zk_spec = ZkSpec::try_new(&spec).expect("arb_zk_spec"); let c_zk_placeholder = irs_solver::solve_mask_code::( - &spec, + zk_spec, compute_l_zk(&placeholder_source, 1), placeholder_source.mask_length(), LogInvRate::new(log_inv_rate), @@ -309,7 +311,7 @@ mod tests { let r = source.mask_length(); let l_zk = compute_l_zk(&source, t_ood); let c_zk = irs_solver::solve_mask_code::( - &spec, + zk_spec, l_zk, r, LogInvRate::new(log_inv_rate), diff --git a/src/protocols/params/derive.rs b/src/protocols/params/derive.rs index 64af6fb0..11a5c8ba 100644 --- a/src/protocols/params/derive.rs +++ b/src/protocols/params/derive.rs @@ -22,7 +22,7 @@ use crate::{ }, spec::{ LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, RoundContext, SecuritySpec, - TuningSpec, + TuningSpec, ZkSpec, }, sumcheck as sumcheck_solver, }, @@ -35,7 +35,7 @@ impl ProtocolConfig { /// /// Fails with [`DeriveError`] when the spec/tuning combination is /// infeasible: a PoW slot exceeds the grind cap, a fixed point diverges, - /// or any slot exceeds `spec.max_pow_bits` (post-derivation validation). + /// or any slot exceeds `spec.pow_budget` (post-derivation validation). pub fn derive(spec: SecuritySpec, tuning: TuningSpec) -> Result { let RoundLayout { shapes, @@ -49,10 +49,12 @@ impl ProtocolConfig { .map(|shape| build_round_config::(&spec, shape, None)) .collect::>()?, Mode::ZeroKnowledge => { + let zk_spec = + ZkSpec::try_new(&spec).expect("matched Mode::ZeroKnowledge above"); let c_zk_log_inv_rate = LogInvRate::new(tuning.starting_log_inv_rate); shapes .iter() - .map(|shape| build_zk_round_config::(&spec, shape, c_zk_log_inv_rate)) + .map(|shape| build_zk_round_config::(zk_spec, shape, c_zk_log_inv_rate)) .collect::>()? } }; @@ -154,10 +156,11 @@ fn target_context(shape: &RoundShape, source: &IrsConfig) -> Ro /// Theorem 9.6's witness layout + Lemma 9.3's `r ≥ t` privacy precondition; /// `t_ood` solves Lemma 9.9 term 1. fn build_zk_round_config( - spec: &SecuritySpec, + zk_spec: ZkSpec<'_>, shape: &RoundShape, c_zk_log_inv_rate: LogInvRate, ) -> Result, DeriveError> { + let spec = zk_spec.get(); let ctx = round_context(shape); let num_masks = sumcheck_solver::masks_required(&ctx) + code_switch_solver::masks_required(); // C_zk.list_size depends only on rate — no IRS build needed for it. @@ -171,7 +174,7 @@ fn build_zk_round_config( let l_zk = compute_l_zk(&source, t_ood); let c_zk: IrsConfig> = irs_solver::solve_mask_code( - spec, + zk_spec, l_zk, source.mask_length(), c_zk_log_inv_rate, @@ -358,7 +361,7 @@ mod tests { hash, protocols::params::{ bounds::SoundnessBounded, - spec::FoldingFactor, + spec::{DecodingRegime, FoldingFactor, PowBudget}, test_utils::{assert_close, assert_pow_closes_gap, TestEmbedding}, }, }; @@ -418,10 +421,11 @@ mod tests { fn test_spec(mode: Mode) -> SecuritySpec { SecuritySpec { mode, + decoding_regime: DecodingRegime::Johnson, target_security_bits: PLAN_FIXTURE_TARGET_BITS, // Allow up to the grind cap; derive() auto-validates the budget - // and would reject configs that need any PoW under `None ⇒ 0`. - max_pow_bits: Some(LOOSE_POW_BUDGET_BITS), + // and would reject configs that need any PoW under `Forbidden`. + pow_budget: PowBudget::per_slot(LOOSE_POW_BUDGET_BITS), hash_id: hash::BLAKE3, } } @@ -621,7 +625,7 @@ mod tests { const LOOSE_POW_BUDGET_BITS: u32 = 60; /// Sits between a moderate budget (30) and the grind cap (60) — used by /// `check_pow_bits_detects_over_budget_slot` to inject a slot that fits - /// the cap but exceeds the test's `max_pow_bits`. + /// the cap but exceeds the test's `pow_budget`. const OVER_BUDGET_INJECTED_BITS: f64 = 50.0; /// Bounds doc §5.3 + §5.7: HVZK privacy error in bits matches the closed @@ -664,13 +668,14 @@ mod tests { ); } - /// Derived plans must satisfy their own `max_pow_bits` budget. + /// Derived plans must satisfy their own `pow_budget`. #[test] fn check_pow_bits_passes_on_derived_plan() { let spec = SecuritySpec { mode: Mode::ZeroKnowledge, + decoding_regime: DecodingRegime::Johnson, target_security_bits: PLAN_FIXTURE_TARGET_BITS, - max_pow_bits: Some(LOOSE_POW_BUDGET_BITS), + pow_budget: PowBudget::per_slot(LOOSE_POW_BUDGET_BITS), hash_id: hash::BLAKE3, }; let plan = ProtocolConfig::::derive( @@ -693,8 +698,9 @@ mod tests { const MODERATE_POW_BUDGET_BITS: u32 = 30; let spec = SecuritySpec { mode: Mode::ZeroKnowledge, + decoding_regime: DecodingRegime::Johnson, target_security_bits: PLAN_FIXTURE_TARGET_BITS, - max_pow_bits: Some(MODERATE_POW_BUDGET_BITS), + pow_budget: PowBudget::per_slot(MODERATE_POW_BUDGET_BITS), hash_id: hash::BLAKE3, }; let mut plan = ProtocolConfig::::derive( @@ -714,8 +720,9 @@ mod tests { const UNREACHABLE_TARGET_BITS: u32 = 200; let spec = SecuritySpec { mode: Mode::Standard, + decoding_regime: DecodingRegime::Johnson, target_security_bits: UNREACHABLE_TARGET_BITS, - max_pow_bits: Some(LOOSE_POW_BUDGET_BITS), + pow_budget: PowBudget::per_slot(LOOSE_POW_BUDGET_BITS), hash_id: hash::BLAKE3, }; let err = ProtocolConfig::::derive( @@ -730,28 +737,71 @@ mod tests { } /// `derive()` reports `PowBudgetExceeded` when a slot's required PoW - /// fits the grind cap but exceeds `spec.max_pow_bits`. `target = 40` - /// with `max_pow_bits = Some(5)` forces this on `Field64`. + /// fits the grind cap but exceeds `spec.pow_budget`. `target = 40` + /// with `pow_budget = PerSlot { bits: 5 }` forces this on `Field64`. #[test] fn derive_reports_pow_budget_exceeded() { const TIGHT_MAX_POW: u32 = 5; let spec = SecuritySpec { mode: Mode::ZeroKnowledge, + decoding_regime: DecodingRegime::Johnson, target_security_bits: PLAN_FIXTURE_TARGET_BITS, - max_pow_bits: Some(TIGHT_MAX_POW), + pow_budget: PowBudget::per_slot(TIGHT_MAX_POW), hash_id: hash::BLAKE3, }; let err = ProtocolConfig::::derive( spec, tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), ) - .expect_err("tight max_pow_bits must trip auto-validation"); + .expect_err("tight pow_budget must trip auto-validation"); assert!( matches!(err, DeriveError::PowBudgetExceeded { .. }), "got {err:?}", ); } + /// Unique decoding threads through to the basecase IRS in Standard mode. + /// Uses a basecase-only tuning so the regime is unambiguous (no rate + /// stepping across rounds). + #[test] + fn derive_threads_unique_decoding_standard() { + let spec = SecuritySpec { + mode: Mode::Standard, + decoding_regime: DecodingRegime::Unique, + target_security_bits: PLAN_FIXTURE_TARGET_BITS, + pow_budget: PowBudget::per_slot(LOOSE_POW_BUDGET_BITS), + hash_id: hash::BLAKE3, + }; + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_NO_ROUNDS), + ) + .unwrap(); + assert!(plan.rounds.is_empty()); + assert!(plan.basecase.commit.unique_decoding()); + } + + /// Same threading check under ZK mode. Basecase-only avoids the per-round + /// code-switch (which still requires `t_ood ≥ 1` until Stage 2 of the + /// regime work lands). + #[test] + fn derive_threads_unique_decoding_zk() { + let spec = SecuritySpec { + mode: Mode::ZeroKnowledge, + decoding_regime: DecodingRegime::Unique, + target_security_bits: PLAN_FIXTURE_TARGET_BITS, + pow_budget: PowBudget::per_slot(LOOSE_POW_BUDGET_BITS), + hash_id: hash::BLAKE3, + }; + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_NO_ROUNDS), + ) + .unwrap(); + assert!(plan.rounds.is_empty()); + assert!(plan.basecase.commit.unique_decoding()); + } + /// `analytic_error + pow ≥ target` for every PoW slot in the plan. fn assert_plan_meets_target_per_slot( spec: &SecuritySpec, diff --git a/src/protocols/params/error.rs b/src/protocols/params/error.rs index cfb173b2..0e9b4bbb 100644 --- a/src/protocols/params/error.rs +++ b/src/protocols/params/error.rs @@ -71,8 +71,8 @@ pub enum DeriveError { }, /// A PoW slot fits the grind cap but exceeds the per-slot budget set by - /// [`super::spec::SecuritySpec::max_pow_bits`]. - #[error("{slot} requires {required} bits, exceeds spec.max_pow_bits = {max}")] + /// [`super::spec::SecuritySpec::pow_budget`]. + #[error("{slot} requires {required} bits, exceeds spec.pow_budget = {max}")] PowBudgetExceeded { slot: PowSlot, required: Bits, diff --git a/src/protocols/params/irs_commit.rs b/src/protocols/params/irs_commit.rs index 47027e25..7e68683e 100644 --- a/src/protocols/params/irs_commit.rs +++ b/src/protocols/params/irs_commit.rs @@ -12,6 +12,7 @@ use crate::{ bounds::rate, spec::{ LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, RoundContext, SecuritySpec, + ZkSpec, }, }, }, @@ -25,8 +26,7 @@ pub fn solve( let security_target = f64::from(spec.protocol_security_target_bits()); let rate = rate(f64::from(ctx.log_inv_rate)); let interleaving_depth = 1_usize << ctx.folding_factor; - // Construction 9.7 is Johnson-only — `Mode` cannot express unique-decoding. - let unique_decoding = false; + let unique_decoding = spec.decoding_regime.unique_decoding(); let mode = match spec.mode { Mode::Standard => IrsMode::Standard, @@ -58,17 +58,13 @@ pub fn solve( /// - `source_mask_length`: `r` from Theorem 9.6. /// - `num_vectors`: `2 * num_masks` (Construction 7.2: originals + fresh). pub fn solve_mask_code( - spec: &SecuritySpec, + spec: ZkSpec<'_>, l_zk: MaskCodeMessageLen, source_mask_length: usize, log_inv_rate: LogInvRate, num_vectors: usize, ) -> IrsConfig { let l_zk = l_zk.get(); - assert!( - matches!(spec.mode, Mode::ZeroKnowledge), - "C_zk only exists in ZK mode" - ); assert!( l_zk >= source_mask_length, "Theorem 9.6: ℓ_zk ({l_zk}) ≥ source mask length ({source_mask_length})", @@ -79,12 +75,13 @@ pub fn solve_mask_code( "num_vectors ({num_vectors}) must be even (mask-proximity original/fresh pairs)", ); + let spec = spec.get(); let security_target = f64::from(spec.protocol_security_target_bits()); let rate = rate(f64::from(log_inv_rate.get())); IrsConfig::new( security_target, - false, // ZK ⇒ Johnson regime + spec.decoding_regime.unique_decoding(), spec.hash_id, num_vectors, l_zk, @@ -107,31 +104,33 @@ mod tests { type M = TestEmbedding; #[test] - #[should_panic(expected = "C_zk only exists in ZK mode")] - fn solve_mask_code_rejects_standard_spec() { + fn zk_spec_rejects_standard_mode() { let spec: SecuritySpec = deterministic_spec(Mode::Standard); - let _ = solve_mask_code::(&spec, MaskCodeMessageLen::new(2), 0, LogInvRate::new(1), 2); + assert!(ZkSpec::try_new(&spec).is_none()); } #[test] #[should_panic(expected = "must be a power of 2")] fn solve_mask_code_rejects_non_pow2_l_zk() { let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); - let _ = solve_mask_code::(&spec, MaskCodeMessageLen::new(3), 0, LogInvRate::new(1), 2); + let zk_spec = ZkSpec::try_new(&spec).unwrap(); + let _ = solve_mask_code::(zk_spec, MaskCodeMessageLen::new(3), 0, LogInvRate::new(1), 2); } #[test] #[should_panic(expected = "Theorem 9.6")] fn solve_mask_code_rejects_l_zk_below_source_mask_length() { let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); - let _ = solve_mask_code::(&spec, MaskCodeMessageLen::new(2), 4, LogInvRate::new(1), 2); + let zk_spec = ZkSpec::try_new(&spec).unwrap(); + let _ = solve_mask_code::(zk_spec, MaskCodeMessageLen::new(2), 4, LogInvRate::new(1), 2); } #[test] #[should_panic(expected = "must be even")] fn solve_mask_code_rejects_odd_num_vectors() { let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); - let _ = solve_mask_code::(&spec, MaskCodeMessageLen::new(2), 0, LogInvRate::new(1), 3); + let zk_spec = ZkSpec::try_new(&spec).unwrap(); + let _ = solve_mask_code::(zk_spec, MaskCodeMessageLen::new(2), 0, LogInvRate::new(1), 3); } /// `irs_commit::solve` doesn't grind PoW, so this range can sit higher than diff --git a/src/protocols/params/mod.rs b/src/protocols/params/mod.rs index 34f113b9..d5a9f530 100644 --- a/src/protocols/params/mod.rs +++ b/src/protocols/params/mod.rs @@ -27,6 +27,6 @@ pub use protocol_config::{ MaskOracleConfig, MaskOracleInfo, ProtocolConfig, RoundConfig, RoundMode, }; pub use spec::{ - FoldingFactor, ListSize, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, RoundContext, - SecuritySpec, TuningSpec, + DecodingRegime, FoldingFactor, ListSize, LogInvRate, MaskCodeMessageLen, Mode, + OodSampleBudget, PowBudget, RoundContext, SecuritySpec, TuningSpec, ZkSpec, }; diff --git a/src/protocols/params/protocol_config.rs b/src/protocols/params/protocol_config.rs index 1face937..5ec5c9b5 100644 --- a/src/protocols/params/protocol_config.rs +++ b/src/protocols/params/protocol_config.rs @@ -40,7 +40,7 @@ pub struct ProtocolConfig { impl ProtocolConfig { /// Returns `true` if every PoW slot's difficulty fits within - /// `security.max_pow_bits`. Boolean predicate kept for callers that want + /// `security.pow_budget`. Boolean predicate kept for callers that want /// to re-check after manual inspection; [`Self::validate_pow_budget`] is /// the typed version used internally by [`super::derive::ProtocolConfig::derive`]. pub fn check_pow_bits(&self) -> bool { @@ -51,7 +51,7 @@ impl ProtocolConfig { /// and required-vs-max difficulties on failure. Auto-invoked by /// `derive()`; callers don't normally need to call this directly. pub fn validate_pow_budget(&self) -> Result<(), DeriveError> { - let max = Bits::new(f64::from(self.security.max_pow_bits.unwrap_or(0))); + let max = Bits::new(f64::from(self.security.pow_budget.bits())); let check = |slot: PowSlot, pow: &PowConfig| -> Result<(), DeriveError> { let required = pow.difficulty(); if required > max { diff --git a/src/protocols/params/spec.rs b/src/protocols/params/spec.rs index f42bc0ac..a753f3f4 100644 --- a/src/protocols/params/spec.rs +++ b/src/protocols/params/spec.rs @@ -1,9 +1,50 @@ -use core::marker::PhantomData; +use core::{marker::PhantomData, num::NonZeroU32}; use ordered_float::OrderedFloat; use crate::{bits::Bits, engines::EngineId}; +/// Per-slot proof-of-work policy. +/// +/// The same `bits` value plays two roles, deliberately coupled: +/// - **Planning credit**: [`SecuritySpec::protocol_security_target_bits`] +/// subtracts `bits` from `target_security_bits` so solvers know the +/// analytic floor they must reach. +/// - **Validation cap**: [`super::protocol_config::ProtocolConfig::validate_pow_budget`] +/// rejects any per-slot PoW that exceeds `bits`. +/// +/// `Forbidden` is *not* `PerSlot { bits: 0 }`: the latter is unrepresentable +/// (the variant takes a [`NonZeroU32`]). Use [`PowBudget::per_slot`] when +/// converting from an arbitrary `u32` — it collapses `0` to `Forbidden`. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum PowBudget { + /// Per-slot grinding forbidden. Solvers still plan against the full + /// `target_security_bits`; any nonzero per-slot PoW the planner emits + /// is rejected by validation. + Forbidden, + /// Per-slot grinding allowed up to `bits`. Planning relaxes the + /// analytic target by `bits`; validation caps every slot at `bits`. + PerSlot { bits: NonZeroU32 }, +} + +impl PowBudget { + /// `Forbidden` when `bits == 0`, else `PerSlot { bits }`. + pub const fn per_slot(bits: u32) -> Self { + match NonZeroU32::new(bits) { + Some(bits) => Self::PerSlot { bits }, + None => Self::Forbidden, + } + } + + /// Bits of grinding allowed per slot. `0` for [`PowBudget::Forbidden`]. + pub const fn bits(self) -> u32 { + match self { + Self::Forbidden => 0, + Self::PerSlot { bits } => bits.get(), + } + } +} + /// Phantom-typed newtype — `Tagged` and `Tagged` are distinct types. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct Tagged(T, PhantomData); @@ -21,17 +62,21 @@ impl Tagged { #[derive(Debug, Clone)] pub struct SecuritySpec { pub mode: Mode, + /// Reed–Solomon decoding regime — selects the proximity radius `δ` and + /// slack policy. See [`DecodingRegime`]. + pub decoding_regime: DecodingRegime, pub target_security_bits: u32, - /// Per-slot PoW budget — every grinding slot may close at most this many - /// bits of gap to `target_security_bits`. Not a cumulative budget across - /// slots; `check_pow_bits` enforces it per-slot. `None` ⇒ `Some(0)`. - pub max_pow_bits: Option, + /// Per-slot PoW policy — both the planning credit subtracted from + /// `target_security_bits` and the per-slot cap enforced by + /// [`super::protocol_config::ProtocolConfig::validate_pow_budget`]. + /// See [`PowBudget`] for the dual role. + pub pow_budget: PowBudget, pub hash_id: EngineId, } impl SecuritySpec { pub fn protocol_security_target_bits(&self) -> Bits { - let pow = self.max_pow_bits.unwrap_or(0); + let pow = self.pow_budget.bits(); Bits::new(f64::from(self.target_security_bits.saturating_sub(pow))) } } @@ -90,15 +135,59 @@ pub struct RoundContext { pub folding_factor: u32, } -/// Both variants run in the Johnson regime — Construction 9.7's OOD-query -/// requirement makes unique-decoding incompatible with code-switch, so it is -/// not representable here. +/// Standard vs. zero-knowledge selection. Orthogonal to [`DecodingRegime`]. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Mode { Standard, ZeroKnowledge, } +/// A `SecuritySpec` borrow proven to be in [`Mode::ZeroKnowledge`]. +/// +/// Constructed only via [`ZkSpec::try_new`], which performs the mode check +/// once at the boundary. ZK-only solvers accept `ZkSpec` to make +/// "ZK mode required" a compile-time precondition instead of a runtime assert. +#[derive(Debug, Clone, Copy)] +pub struct ZkSpec<'a>(&'a SecuritySpec); + +impl<'a> ZkSpec<'a> { + /// Returns `Some` iff `spec.mode == Mode::ZeroKnowledge`. + pub fn try_new(spec: &'a SecuritySpec) -> Option { + matches!(spec.mode, Mode::ZeroKnowledge).then_some(Self(spec)) + } + + pub const fn get(self) -> &'a SecuritySpec { + self.0 + } +} + +/// Reed–Solomon decoding regime selection. +/// +/// Picks the proximity radius `δ` and slack policy used by the IRS and +/// downstream sub-protocols. `Johnson` uses the codebase's slack policy +/// `η = √ρ / 20`; the list-decoding ball can hold `~10/ρ` codewords. +/// `Unique` operates strictly inside the unique-decoding radius `(1 − ρ)/2`; +/// the ball holds at most one. +/// +/// WHIR's rate stepping (each round bumps `log_inv_rate` by +/// `folding_factor − 1`) pushes ρ → 1, shrinking the unique-decoding +/// radius. At high security targets or deep folding, `Unique` may exceed +/// the grind cap on per-round PoW and [`super::derive::ProtocolConfig::derive`] +/// will return `PowUngrindable`. Pick `Johnson` for those cases. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum DecodingRegime { + Unique, + Johnson, +} + +impl DecodingRegime { + /// Bridge to [`super::super::irs_commit::Config::new`]'s `unique_decoding` + /// parameter. + pub const fn unique_decoding(self) -> bool { + matches!(self, Self::Unique) + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum OodSampleBudgetTag {} #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -147,40 +236,47 @@ mod tests { /// the tests below are round numbers (80, 40, 0) for readability. const TARGET_BITS: u32 = 100; - fn spec(max_pow_bits: Option) -> SecuritySpec { + fn spec(pow_budget: PowBudget) -> SecuritySpec { SecuritySpec { mode: Mode::ZeroKnowledge, + decoding_regime: DecodingRegime::Johnson, target_security_bits: TARGET_BITS, - max_pow_bits, + pow_budget, hash_id: hash::BLAKE3, } } #[test] - fn none_means_no_pow_credit() { + fn forbidden_means_no_pow_credit() { assert_eq!( - spec(None).protocol_security_target_bits(), + spec(PowBudget::Forbidden).protocol_security_target_bits(), Bits::new(f64::from(TARGET_BITS)), ); } #[test] - fn some_zero_matches_none() { - assert_eq!( - spec(Some(0)).protocol_security_target_bits(), - spec(None).protocol_security_target_bits(), - ); + fn per_slot_zero_collapses_to_forbidden() { + // `per_slot(0)` is the only documented way to ask for "no grinding" + // from a `u32`; it must produce the `Forbidden` variant, not a + // `PerSlot { bits: 0 }` (which is unrepresentable). + assert_eq!(PowBudget::per_slot(0), PowBudget::Forbidden); + } + + #[test] + fn per_slot_bits_round_trip() { + assert_eq!(PowBudget::per_slot(20).bits(), 20); + assert_eq!(PowBudget::Forbidden.bits(), 0); } #[test] fn pow_credit_shifts_analytic_floor() { // Two below-target PoW budgets: `target − pow` shifts down 1:1. assert_eq!( - spec(Some(20)).protocol_security_target_bits(), + spec(PowBudget::per_slot(20)).protocol_security_target_bits(), Bits::new(80.0), ); assert_eq!( - spec(Some(60)).protocol_security_target_bits(), + spec(PowBudget::per_slot(60)).protocol_security_target_bits(), Bits::new(40.0), ); } @@ -190,7 +286,7 @@ mod tests { // `pow > target` saturates rather than going negative. let pow_over_target = TARGET_BITS + 100; assert_eq!( - spec(Some(pow_over_target)).protocol_security_target_bits(), + spec(PowBudget::per_slot(pow_over_target)).protocol_security_target_bits(), Bits::new(0.0), ); } diff --git a/src/protocols/params/test_utils.rs b/src/protocols/params/test_utils.rs index a2409016..23f2fa34 100644 --- a/src/protocols/params/test_utils.rs +++ b/src/protocols/params/test_utils.rs @@ -18,9 +18,9 @@ use crate::{ derive::compute_t_ood, irs_commit as irs_solver, protocol_config::MaskOracleInfo, - spec::{ - ListSize, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, RoundContext, - SecuritySpec, + spec::{DecodingRegime, + ListSize, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, PowBudget, + RoundContext, SecuritySpec, ZkSpec, }, }, proof_of_work::Config as PowConfig, @@ -51,24 +51,30 @@ pub const EPS: f64 = 1e-9; pub fn deterministic_spec(mode: Mode) -> SecuritySpec { SecuritySpec { mode, + decoding_regime: DecodingRegime::Johnson, target_security_bits: FIXTURE_TARGET_BITS, - max_pow_bits: None, + pow_budget: PowBudget::Forbidden, hash_id: hash::BLAKE3, } } -/// `max_pow_bits` ∈ `{None, Some(0..=16)}`; bounded so the analytic floor -/// stays positive for the lowest test targets and the PoW gap stays under the -/// 60-bit cap. +/// `pow_budget` ∈ `{Forbidden, PerSlot{1..=16}}`; bounded so the analytic +/// floor stays positive for the lowest test targets and the PoW gap stays +/// under the 60-bit cap. `PerSlot { bits: 0 }` is unrepresentable, so we +/// generate `Forbidden` for the "no grinding" case directly. pub fn arb_spec( mode: Mode, target_range: RangeInclusive, ) -> impl Strategy { - let pow_strategy = prop_oneof![Just(None), (0u32..=16).prop_map(Some)]; - (target_range, pow_strategy).prop_map(move |(target, max_pow)| SecuritySpec { + let pow_strategy = prop_oneof![ + Just(PowBudget::Forbidden), + (1u32..=16).prop_map(PowBudget::per_slot), + ]; + (target_range, pow_strategy).prop_map(move |(target, pow_budget)| SecuritySpec { mode, + decoding_regime: DecodingRegime::Johnson, target_security_bits: target, - max_pow_bits: max_pow, + pow_budget, hash_id: hash::BLAKE3, }) } @@ -97,12 +103,10 @@ pub fn arb_round_ctx() -> impl Strategy { /// `None` in Standard; `Some(ℓ_zk=2, c_zk rate 1/2)` in ZK. pub fn build_minimal_mask_oracle(spec: &SecuritySpec) -> Option { - if !matches!(spec.mode, Mode::ZeroKnowledge) { - return None; - } + let zk_spec = ZkSpec::try_new(spec)?; let l_zk = MaskCodeMessageLen::new(2); let c_zk: IrsConfig = - irs_solver::solve_mask_code(spec, l_zk, 0, LogInvRate::new(1), 2); + irs_solver::solve_mask_code(zk_spec, l_zk, 0, LogInvRate::new(1), 2); Some(MaskOracleInfo { c_zk_list_size: ListSize::new(c_zk.list_size()), l_zk, @@ -139,8 +143,9 @@ pub fn build_test_c_zk( log_inv_rate: u32, num_masks: usize, ) -> IrsConfig { + let zk_spec = ZkSpec::try_new(spec).expect("build_test_c_zk requires a ZK spec"); irs_solver::solve_mask_code( - spec, + zk_spec, MaskCodeMessageLen::new(l_zk), 0, LogInvRate::new(log_inv_rate), From cf774f8aeadfa1c8418166f4d09139d6c04fabeb Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Sat, 23 May 2026 00:27:45 +0530 Subject: [PATCH 20/31] feat : added proper error types and refactor of params --- src/bin/benchmark.rs | 13 +- src/bin/main.rs | 13 +- src/parameters.rs | 14 +- src/protocols/irs_commit.rs | 62 ++--- src/protocols/params/basecase.rs | 18 +- src/protocols/params/bounds.rs | 219 +--------------- src/protocols/params/code_switch.rs | 16 +- src/protocols/params/derive.rs | 240 +++++++++-------- src/protocols/params/error.rs | 167 ++++++++---- src/protocols/params/irs_commit.rs | 32 ++- src/protocols/params/mask_proximity.rs | 20 +- src/protocols/params/mod.rs | 7 +- src/protocols/params/protocol_config.rs | 334 ++++++++++++++++++------ src/protocols/params/regime.rs | 247 ++++++++++++++++++ src/protocols/params/spec.rs | 74 +++++- src/protocols/params/sumcheck.rs | 25 +- src/protocols/params/test_utils.rs | 12 +- src/protocols/whir/config.rs | 18 +- src/protocols/whir/mod.rs | 44 ++-- src/protocols/whir_zk/mod.rs | 7 +- 20 files changed, 955 insertions(+), 627 deletions(-) create mode 100644 src/protocols/params/regime.rs diff --git a/src/bin/benchmark.rs b/src/bin/benchmark.rs index d9e79ce6..77574c42 100644 --- a/src/bin/benchmark.rs +++ b/src/bin/benchmark.rs @@ -18,6 +18,7 @@ use whir::{ cmdline_utils::{AvailableFields, AvailableHash}, hash::HASH_COUNTER, parameters::ProtocolParameters, + protocols::params::DecodingRegime, transcript::{codecs::Empty, Codec, DomainSeparator, ProverState, VerifierState}, }; @@ -48,8 +49,8 @@ struct Args { #[arg(short = 'k', long = "fold", default_value = "4")] folding_factor: usize, - #[arg(long = "unique-decoding", default_value_t = false)] - unique_decoding: bool, + #[arg(long = "decoding-regime", default_value = "Johnson")] + decoding_regime: DecodingRegime, #[arg(short = 'f', long = "field", default_value = "Goldilocks3")] field: AvailableFields, @@ -67,7 +68,7 @@ struct BenchmarkOutput { repetitions: usize, initial_folding_factor: usize, folding_factor: usize, - unique_decoding: bool, + decoding_regime: DecodingRegime, field: AvailableFields, hash: AvailableHash, @@ -117,7 +118,7 @@ where let reps = args.verifier_repetitions; let folding_factor = args.folding_factor; let first_round_folding_factor = args.first_round_folding_factor; - let unique_decoding = args.unique_decoding; + let decoding_regime = args.decoding_regime; std::fs::create_dir_all("outputs").unwrap(); @@ -128,7 +129,7 @@ where pow_bits, initial_folding_factor: first_round_folding_factor, folding_factor, - unique_decoding, + decoding_regime, starting_log_inv_rate: starting_rate, batch_size: 1, hash_id: args.hash.hash_id(), @@ -298,7 +299,7 @@ where repetitions: reps, initial_folding_factor: first_round_folding_factor, folding_factor, - unique_decoding, + decoding_regime, field: args.field, hash: args.hash, diff --git a/src/bin/main.rs b/src/bin/main.rs index bafe16f5..0faf4b36 100644 --- a/src/bin/main.rs +++ b/src/bin/main.rs @@ -14,6 +14,7 @@ use whir::{ cmdline_utils::{AvailableFields, AvailableHash}, hash::HASH_COUNTER, parameters::ProtocolParameters, + protocols::params::DecodingRegime, transcript::{codecs::Empty, Codec, DomainSeparator, ProverState, VerifierState}, }; @@ -48,9 +49,9 @@ struct Args { #[arg(short = 'k', long = "fold", default_value = "4")] folding_factor: usize, - /// Restrict PCS to the Unique Decoding regime. LDT is always UD. - #[arg(long = "unique-decoding", default_value_t = false)] - unique_decoding: bool, + /// Reed–Solomon decoding regime: Unique or Johnson (list-decoding). + #[arg(long = "decoding-regime", default_value = "Johnson")] + decoding_regime: DecodingRegime, #[arg(short = 'f', long = "field", default_value = "Goldilocks3")] field: AvailableFields, @@ -109,7 +110,7 @@ where let reps = args.verifier_repetitions; let first_round_folding_factor = args.first_round_folding_factor; let folding_factor = args.folding_factor; - let unique_decoding = args.unique_decoding; + let decoding_regime = args.decoding_regime; let num_evaluations = args.num_evaluations; let num_linear_constraints = args.num_linear_constraints; let hash_id = args.hash.hash_id(); @@ -125,7 +126,7 @@ where pow_bits, initial_folding_factor: first_round_folding_factor, folding_factor, - unique_decoding, + decoding_regime, starting_log_inv_rate: starting_rate, batch_size: 1, hash_id, @@ -254,7 +255,7 @@ where let num_coeffs = 1 << num_variables; let whir_params = ProtocolParameters { - unique_decoding: args.unique_decoding, + decoding_regime: args.decoding_regime, security_level, pow_bits, initial_folding_factor: first_round_folding_factor, diff --git a/src/parameters.rs b/src/parameters.rs index 242b9749..8e257c79 100644 --- a/src/parameters.rs +++ b/src/parameters.rs @@ -2,13 +2,13 @@ use std::fmt::{Debug, Display}; use serde::{Deserialize, Serialize}; -use crate::engines::EngineId; +use crate::{engines::EngineId, protocols::params::DecodingRegime}; /// Configuration parameters for WHIR proofs. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct ProtocolParameters { - /// Whether to require unique decoding. - pub unique_decoding: bool, + /// Reed–Solomon decoding regime: `Unique` or `Johnson` (list-decoding). + pub decoding_regime: DecodingRegime, /// The logarithmic inverse rate for sampling. pub starting_log_inv_rate: usize, /// Folding factor for the initial round. @@ -30,13 +30,7 @@ impl Display for ProtocolParameters { writeln!( f, "Targeting {}-bits of security with {}-bits of PoW using {} decoding", - self.security_level, - self.pow_bits, - if self.unique_decoding { - "unique" - } else { - "list" - } + self.security_level, self.pow_bits, self.decoding_regime, )?; writeln!( f, diff --git a/src/protocols/irs_commit.rs b/src/protocols/irs_commit.rs index f2c113f8..4f1d2540 100644 --- a/src/protocols/irs_commit.rs +++ b/src/protocols/irs_commit.rs @@ -15,14 +15,13 @@ use std::{f64, fmt, num::NonZeroUsize}; use ark_ff::{AdditiveGroup, Field}; use ark_std::rand::{distributions::Standard, prelude::Distribution, CryptoRng, RngCore}; -use ordered_float::OrderedFloat; use serde::{Deserialize, Serialize}; #[cfg(feature = "tracing")] use tracing::instrument; use crate::{ algebra::{ - dot, embedding::Embedding, lift, linear_form::UnivariateEvaluation, + dot, embedding::Embedding, fields::FieldWithSize, lift, linear_form::UnivariateEvaluation, mixed_univariate_evaluate, ntt, random_vector, }, engines::EngineId, @@ -30,9 +29,7 @@ use crate::{ protocols::{ challenge_indices::challenge_indices, matrix_commit, - params::bounds::{ - eps_mca_log2, list_size_log2, one_minus_distance_log2, ood_per_sample_log2, CodeParams, - }, + params::{bounds::ood_per_sample_log2, regime::DecodingRegimeParams, spec::DecodingRegime}, }, transcript::{ Codec, Decoding, DuplexSpongeInterface, ProverMessage, ProverState, VerificationResult, @@ -81,9 +78,8 @@ pub struct Config { /// The matrix commitment configuration. pub matrix_commit: matrix_commit::Config, - /// Slack to the Jonhnson bound in list decoding. - /// Zero indicates unique decoding. - pub johnson_slack: OrderedFloat, + /// Materialized Reed–Solomon decoding regime (Unique / Johnson w/ slack). + pub regime: DecodingRegimeParams, /// The number of in-domain samples. pub in_domain_samples: usize, @@ -129,7 +125,7 @@ impl Config { #[allow(clippy::too_many_arguments)] pub fn new( security_target: f64, - unique_decoding: bool, + decoding_regime: DecodingRegime, hash_id: EngineId, num_vectors: usize, vector_size: usize, @@ -152,14 +148,8 @@ impl Config { .expect("codeword length exceeds NTT engine support"); let rate = masked_message_length as f64 / codeword_length as f64; - // η = slack to Johnson bound. We pick η = √ρ / 20. - // TODO: Optimize picking η. - let johnson_slack = if unique_decoding { - 0.0 - } else { - rate.sqrt() / 20. - }; - let in_domain_samples = num_in_domain_queries(unique_decoding, security_target, rate).get(); + let regime = DecodingRegimeParams::from_policy(decoding_regime, rate); + let in_domain_samples = num_in_domain_queries(decoding_regime, security_target, rate).get(); Self { embedding: Typed::::default(), @@ -172,7 +162,7 @@ impl Config { codeword_length, interleaving_depth * num_vectors, ), - johnson_slack: OrderedFloat(johnson_slack), + regime, in_domain_samples, deduplicate_in_domain: false, mode, @@ -222,8 +212,8 @@ impl Config { self.masked_message_length() as f64 / self.codeword_length as f64 } - pub fn unique_decoding(&self) -> bool { - self.johnson_slack == 0.0 + pub const fn unique_decoding(&self) -> bool { + self.regime.is_unique() } fn log_inv_rate(&self) -> f64 { @@ -232,23 +222,23 @@ impl Config { /// Compute a list size bound. pub fn list_size(&self) -> f64 { - 2_f64.powf(list_size_log2( - self.log_inv_rate(), - self.johnson_slack.into_inner(), - )) + self.regime.list_size(self.log_inv_rate()) } /// Round-by-round soundness of the in-domain queries in bits. pub fn rbr_queries(&self) -> f64 { // Query error is (1 - δ)^q in bits = -q · log2(1 - δ). - -(self.in_domain_samples as f64) - * one_minus_distance_log2(self.log_inv_rate(), self.johnson_slack.into_inner()) + -(self.in_domain_samples as f64) * self.regime.one_minus_distance_log2(self.log_inv_rate()) } /// Round-by-round soundness of the proximity-gaps fold in bits. /// See WHIR Theorem 4.8. pub fn rbr_soundness_fold_prox_gaps(&self) -> f64 { - -eps_mca_log2(&CodeParams::from_irs(self)) + -self.regime.eps_mca_log2( + self.log_inv_rate(), + self.masked_message_length(), + M::Target::field_size_bits(), + ) } /// Commit to one or more vectors. @@ -533,13 +523,13 @@ impl fmt::Display for Config { /// See [STIR] Lemma 4.5. #[allow(clippy::cast_sign_loss)] pub fn num_ood_samples( - unique_decoding: bool, + decoding_regime: DecodingRegime, security_target: f64, field_size_bits: f64, list_size: f64, degree: usize, ) -> usize { - if unique_decoding { + if matches!(decoding_regime, DecodingRegime::Unique) { return 0; } let log_per_sample = -ood_per_sample_log2(degree, field_size_bits); @@ -559,19 +549,13 @@ pub fn num_ood_samples( // TODO: A method with cleaner abstraction. #[allow(clippy::cast_sign_loss)] pub(crate) fn num_in_domain_queries( - unique_decoding: bool, + decoding_regime: DecodingRegime, security_target: f64, rate: f64, ) -> NonZeroUsize { - // η = slack to Johnson bound. We pick η = √ρ / 20. - // TODO: Optimize picking η. - let johnson_slack = if unique_decoding { - 0.0 - } else { - rate.sqrt() / 20. - }; + let regime = DecodingRegimeParams::from_policy(decoding_regime, rate); // Query error is (1 - δ)^q in bits = -q · log2(1 - δ). - let log_one_minus_delta = one_minus_distance_log2(-rate.log2(), johnson_slack); + let log_one_minus_delta = regime.one_minus_distance_log2(-rate.log2()); let q = (security_target / -log_one_minus_delta).ceil() as usize; NonZeroUsize::new(q).unwrap_or(NonZeroUsize::MIN) } @@ -646,7 +630,7 @@ pub(crate) mod tests { codeword_length, interleaving_depth, matrix_commit, - johnson_slack: OrderedFloat::default(), + regime: DecodingRegimeParams::Unique, in_domain_samples, deduplicate_in_domain, mode, diff --git a/src/protocols/params/basecase.rs b/src/protocols/params/basecase.rs index 262b1596..97a65f5d 100644 --- a/src/protocols/params/basecase.rs +++ b/src/protocols/params/basecase.rs @@ -9,8 +9,7 @@ use crate::{ basecase::{self, Config as BasecaseConfig}, irs_commit::Config as IrsConfig, params::{ - bounds::SoundnessBounded, - error::{BasecaseSlot, DeriveError, PowResultExt, PowSlot}, + error::{DeriveError, Pow, PowResultExt}, irs_commit as irs_solver, spec::{Mode as SpecMode, OodSampleBudget, RoundContext, SecuritySpec}, sumcheck as sumcheck_solver, @@ -42,7 +41,7 @@ pub fn solve( sumcheck_solver::analytic_error_bits(&commit, None), spec.hash_id, ) - .at_slot(PowSlot::Basecase(BasecaseSlot::Sumcheck))?; + .at(Pow::BasecaseSumcheck)?; let sumcheck = SumcheckConfig::new( vector_size, sumcheck_pow, @@ -59,7 +58,7 @@ pub fn solve( basecase::BasecaseMode::Standard => PowConfig::none(), basecase::BasecaseMode::ZeroKnowledge => { PowConfig::grind_to(target_bits, analytic_error_bits(&commit), spec.hash_id) - .at_slot(PowSlot::Basecase(BasecaseSlot::GammaCombination))? + .at(Pow::BasecaseGammaCombination)? } }; @@ -76,10 +75,11 @@ pub fn analytic_error_bits(commit: &IrsConfig>) -> Bits { Bits::new(prox_gaps.min(poly_id).max(0.0)) } -impl SoundnessBounded for BasecaseConfig { - /// `min(sumcheck round error, γ-slot error)`. The γ-slot only contributes - /// in ZK mode; Standard collapses to the sumcheck term. - fn analytic_bits(&self) -> Bits { +impl BasecaseConfig { + /// Analytic soundness bits (excluding PoW): `min(sumcheck round error, γ-slot error)`. + /// The γ-slot only contributes in ZK mode; Standard collapses to the + /// sumcheck term. + pub fn analytic_bits(&self) -> Bits { let sumcheck_term = f64::from(sumcheck_solver::analytic_error_bits(&self.commit, None)); let min_bits = match self.mode { basecase::BasecaseMode::Standard => sumcheck_term, @@ -141,7 +141,7 @@ mod tests { } /// At `log_inv_rate = 1` on `Field64`, `ε_mca` is below the poly-identity - /// term — pins the `min` to the arm that earlier returned `poly_id` alone. + /// term — pins the `min` to the prox-gaps arm rather than `poly_id`. #[test] fn analytic_error_uses_eps_mca_when_limiting() { use crate::protocols::params::{ diff --git a/src/protocols/params/bounds.rs b/src/protocols/params/bounds.rs index db3a3321..524b0955 100644 --- a/src/protocols/params/bounds.rs +++ b/src/protocols/params/bounds.rs @@ -1,40 +1,7 @@ -//! Shared RS-code primitives + the [`SoundnessBounded`] abstraction. - -use std::{f64::consts::LOG2_10, ops::Neg}; - -use crate::{ - algebra::{embedding::Embedding, fields::FieldWithSize}, - bits::Bits, - protocols::irs_commit::Config as IrsConfig, -}; - -/// Analytic soundness bits (excluding PoW) delivered by a protocol-level unit. -/// Sub-protocol `Config` types lack the cross-protocol context to self-report. -// Library-side callers land with protocol wiring; until then only tests use it. -#[allow(dead_code)] -pub trait SoundnessBounded { - fn analytic_bits(&self) -> Bits; -} - -/// `johnson_slack == 0.0` selects the unique-decoding regime. -#[derive(Debug, Clone, Copy)] -pub struct CodeParams { - pub log_inv_rate: f64, - pub johnson_slack: f64, - pub message_length: usize, - pub field_bits: f64, -} - -impl CodeParams { - pub fn from_irs(irs: &IrsConfig) -> Self { - Self { - log_inv_rate: irs.rate().log2().neg(), - johnson_slack: irs.johnson_slack.into_inner(), - message_length: irs.masked_message_length(), - field_bits: M::Target::field_size_bits(), - } - } -} +//! Regime-agnostic analytic primitives shared across the params solvers. +//! +//! Regime-specific math (Johnson / Unique branches) lives on +//! [`super::regime::DecodingRegimeParams`]. /// `ρ = 2^-log_inv_rate`. Centralized so the rate formula lives in one place. pub(super) fn rate(log_inv_rate: f64) -> f64 { @@ -48,53 +15,7 @@ pub(super) const fn usize_to_f64(x: usize) -> f64 { x as f64 } -fn unique_decoding(johnson_slack: f64) -> bool { - johnson_slack == 0.0 -} - -/// log2 |Λ(C, δ)|. -pub fn list_size_log2(log_inv_rate: f64, johnson_slack: f64) -> f64 { - if unique_decoding(johnson_slack) { - 0.0 - } else { - // Johnson: |Λ| = 1 / (2 η √ρ). - -1.0 - johnson_slack.log2() + 0.5 * log_inv_rate - } -} - -/// `|Λ(C)|` for a Johnson-regime code derived purely from the rate, using the -/// canonical `η = √ρ / 20` slack. -pub fn johnson_list_size(log_inv_rate: f64) -> f64 { - let rate = 2_f64.powf(-log_inv_rate); - let johnson_slack = rate.sqrt() / 20.0; - 2_f64.powf(list_size_log2(log_inv_rate, johnson_slack)) -} - -/// log2 ε_mca(C, δ). -pub fn eps_mca_log2(p: &CodeParams) -> f64 { - let log_k = (p.message_length as f64).log2(); - - let error = if unique_decoding(p.johnson_slack) { - log_k + p.log_inv_rate - } else { - debug_assert!(p.johnson_slack.log2() >= -(0.5 * p.log_inv_rate + LOG2_10 + 1.0) - 1e-6); - 7.0 * LOG2_10 + 3.5 * p.log_inv_rate + 2.0 * log_k - }; - - error - p.field_bits -} - -/// log2(1 - δ). -pub fn one_minus_distance_log2(log_inv_rate: f64, johnson_slack: f64) -> f64 { - let one_minus_delta = if unique_decoding(johnson_slack) { - f64::midpoint(1.0, rate(log_inv_rate)) - } else { - rate(log_inv_rate).sqrt() + johnson_slack - }; - one_minus_delta.log2() -} - -/// log2 of the per-OOD-sample Schwartz-Zippel error: (k-1)/|F|. +/// log2 of the per-OOD-sample Schwartz–Zippel error: `(k-1)/|F|`. pub fn ood_per_sample_log2(message_length: usize, field_bits: f64) -> f64 { ((message_length - 1) as f64).log2() - field_bits } @@ -105,75 +26,6 @@ mod tests { use super::*; use crate::protocols::params::test_utils::assert_close; - /// Tighter tolerance for tests doing relative-error checks (`(got - exp).abs() / exp`) - /// against an alternative-derived expected value with the same operations. - const TIGHT_EPS: f64 = 1e-12; - - /// Johnson list size: `|Λ| = 1 / (2η√ρ)`, log₂ form. Hand-evaluated at - /// `log_inv_rate = 2`, `η = 0.1`: `−1 − log₂(0.1) + 1 ≈ 3.3219`. - #[test] - fn list_size_log2_johnson_formula() { - let got = list_size_log2(2.0, 0.1); - let expected = -1.0 - 0.1_f64.log2() + 0.5 * 2.0; - assert_close(got, expected); - } - - /// Unique-decoding regime (`η = 0`) gives `|Λ| = 1`, i.e. log = 0. - #[test] - fn list_size_log2_unique_decoding_is_zero() { - assert_eq!(list_size_log2(2.0, 0.0), 0.0); - } - - /// `η = √ρ / 20` substituted into `|Λ| = 1/(2η√ρ)` simplifies to `10/ρ`. - /// So `johnson_list_size(b) = 10 · 2^b`. - #[test] - fn johnson_list_size_closed_form() { - for b in [1.0, 2.0, 3.0, 5.0] { - let got = johnson_list_size(b); - let expected = 10.0 * 2_f64.powf(b); - assert!( - (got - expected).abs() / expected < TIGHT_EPS, - "log_inv_rate={b}: got {got} vs {expected}", - ); - } - } - - /// `johnson_list_size(b) = 2^list_size_log2(b, √ρ/20)` must match `Config::list_size` - /// once a config is built at the same rate. Keeps the bounds helper in sync with - /// `irs_commit::Config::new`'s `johnson_slack = √ρ / 20` policy. - #[test] - fn johnson_list_size_matches_config_list_size() { - use crate::{ - algebra::{embedding::Identity, fields::Field64}, - hash, - protocols::irs_commit::{Config, IrsMode}, - }; - // All shape values except rate are placeholders — `list_size()` depends - // only on `johnson_slack`, which is itself a function of rate. - const PLACEHOLDER_SECURITY_TARGET_BITS: f64 = 80.0; - const PLACEHOLDER_NUM_VECTORS: usize = 2; - const PLACEHOLDER_VECTOR_SIZE: usize = 8; - const PLACEHOLDER_INTERLEAVING_DEPTH: usize = 1; - const LOG_INV_RATE: u32 = 2; - - let config: Config> = Config::new( - PLACEHOLDER_SECURITY_TARGET_BITS, - false, // unique_decoding - hash::BLAKE3, - PLACEHOLDER_NUM_VECTORS, - PLACEHOLDER_VECTOR_SIZE, - PLACEHOLDER_INTERLEAVING_DEPTH, - 2_f64.powf(-f64::from(LOG_INV_RATE)), - IrsMode::Standard, - ); - let got = johnson_list_size(f64::from(LOG_INV_RATE)); - let expected = config.list_size(); - assert!( - (got - expected).abs() / expected < TIGHT_EPS, - "bounds helper ({got}) vs Config::list_size ({expected})", - ); - } - /// OOD per-sample Schwartz–Zippel: `log₂((k−1) / |F|) = log₂(k−1) − field_bits`. #[test] fn ood_per_sample_log2_formula() { @@ -187,65 +39,4 @@ mod tests { // (k−1)/|F| < 1 for sane parameters ⇒ log is negative. assert!(got < 0.0); } - - /// `1 − δ` in unique-decoding mode: midpoint of 1 and ρ. - #[test] - fn one_minus_distance_log2_unique() { - let log_inv_rate = 2.0; - let got = one_minus_distance_log2(log_inv_rate, 0.0); - let rho = 2_f64.powf(-log_inv_rate); - let expected = f64::midpoint(1.0, rho).log2(); - assert_close(got, expected); - } - - /// `1 − δ` in Johnson regime: `√ρ + η`. - #[test] - fn one_minus_distance_log2_johnson() { - let log_inv_rate = 2.0; - let eta = 0.1; - let got = one_minus_distance_log2(log_inv_rate, eta); - let rho = 2_f64.powf(-log_inv_rate); - let expected = (rho.sqrt() + eta).log2(); - assert_close(got, expected); - } - - /// MCA fixture — `message_length = 16 = 2^4` and `log_inv_rate = 2` give - /// exact `log2(k) = 4`. `field_bits = 64.0` for Field64. - const MCA_MESSAGE_LENGTH: usize = 16; - const MCA_LOG_INV_RATE: f64 = 2.0; - const MCA_FIELD_BITS: f64 = 64.0; - - /// MCA error, unique-decoding branch: `log k + log_inv_rate − field_bits`. - #[test] - fn eps_mca_log2_unique_decoding_formula() { - let p = CodeParams { - log_inv_rate: MCA_LOG_INV_RATE, - johnson_slack: 0.0, - message_length: MCA_MESSAGE_LENGTH, - field_bits: MCA_FIELD_BITS, - }; - let got = eps_mca_log2(&p); - let expected = (MCA_MESSAGE_LENGTH as f64).log2() + MCA_LOG_INV_RATE - MCA_FIELD_BITS; - assert_close(got, expected); - } - - /// MCA error, Johnson branch: `7·log₂10 + 3.5·log_inv_rate + 2·log k − field_bits`. - #[test] - fn eps_mca_log2_johnson_formula() { - // `η = 0.1` stays within the debug assertion's slack range: - // `η.log2() ≥ −(0.5·log_inv_rate + log₂10 + 1) ≈ −5.32`. - const JOHNSON_SLACK: f64 = 0.1; - - let p = CodeParams { - log_inv_rate: MCA_LOG_INV_RATE, - johnson_slack: JOHNSON_SLACK, - message_length: MCA_MESSAGE_LENGTH, - field_bits: MCA_FIELD_BITS, - }; - let got = eps_mca_log2(&p); - let expected = - 7.0 * LOG2_10 + 3.5 * MCA_LOG_INV_RATE + 2.0 * (MCA_MESSAGE_LENGTH as f64).log2() - - MCA_FIELD_BITS; - assert_close(got, expected); - } } diff --git a/src/protocols/params/code_switch.rs b/src/protocols/params/code_switch.rs index b0f0dd82..7730518a 100644 --- a/src/protocols/params/code_switch.rs +++ b/src/protocols/params/code_switch.rs @@ -14,7 +14,7 @@ use crate::{ irs_commit::Config as IrsConfig, params::{ bounds::usize_to_f64, - error::{DeriveError, PowResultExt, PowSlot, RoundSlot}, + error::{DeriveError, Pow, PowResultExt}, protocol_config::MaskOracleInfo, spec::SecuritySpec, }, @@ -50,10 +50,8 @@ pub fn solve( let target_bits = Bits::new(f64::from(spec.target_security_bits)); let analytic = analytic_error_bits(&source, &target, t_ood, mask_oracle); - let pow = PowConfig::grind_to(target_bits, analytic, spec.hash_id).at_slot(PowSlot::Round { - index: round_index, - kind: RoundSlot::CodeSwitch, - })?; + let pow = PowConfig::grind_to(target_bits, analytic, spec.hash_id) + .at(Pow::RoundCodeSwitch { index: round_index })?; Ok(CodeSwitchConfig::new(source, target, t_ood, mode, pow)) } @@ -110,12 +108,12 @@ mod tests { use super::*; use crate::protocols::params::{ - bounds::johnson_list_size, derive::{compute_l_zk, compute_t_ood}, irs_commit as irs_solver, - spec::{DecodingRegime, - ListSize, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, PowBudget, - RoundContext, SecuritySpec, ZkSpec, + regime::johnson_list_size, + spec::{ + DecodingRegime, ListSize, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, + PowBudget, RoundContext, SecuritySpec, ZkSpec, }, test_utils::{ arb_standard_johnson_spec as utils_standard_spec, arb_zk_spec as utils_zk_spec, diff --git a/src/protocols/params/derive.rs b/src/protocols/params/derive.rs index 11a5c8ba..5785e37b 100644 --- a/src/protocols/params/derive.rs +++ b/src/protocols/params/derive.rs @@ -12,17 +12,14 @@ use crate::{ protocols::{ irs_commit::{self, Config as IrsConfig}, params::{ - basecase as basecase_solver, - bounds::johnson_list_size, - code_switch as code_switch_solver, - error::{DeriveError, PowSlot, RoundSlot}, + basecase as basecase_solver, code_switch as code_switch_solver, + error::{DeriveError, FixedPointLoop, Pow}, irs_commit as irs_solver, mask_proximity as mask_proximity_solver, - protocol_config::{ - MaskOracleConfig, MaskOracleInfo, ProtocolConfig, RoundConfig, RoundMode, - }, + protocol_config::{MaskOracleConfig, ProtocolConfig, RoundConfig, RoundMode}, + regime::johnson_list_size, spec::{ - LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, RoundContext, SecuritySpec, - TuningSpec, ZkSpec, + DecodingRegime, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, + RoundContext, SecuritySpec, TuningSpec, ZkSpec, }, sumcheck as sumcheck_solver, }, @@ -46,11 +43,10 @@ impl ProtocolConfig { let rounds: Vec> = match spec.mode { Mode::Standard => shapes .iter() - .map(|shape| build_round_config::(&spec, shape, None)) + .map(|shape| build_round_config::(&spec, shape)) .collect::>()?, Mode::ZeroKnowledge => { - let zk_spec = - ZkSpec::try_new(&spec).expect("matched Mode::ZeroKnowledge above"); + let zk_spec = ZkSpec::try_new(&spec).expect("matched Mode::ZeroKnowledge above"); let c_zk_log_inv_rate = LogInvRate::new(tuning.starting_log_inv_rate); shapes .iter() @@ -61,13 +57,8 @@ impl ProtocolConfig { let basecase = basecase_solver::solve(&spec, basecase_vector_size, basecase_log_inv_rate)?; - let plan = Self { - security: spec, - tuning, - rounds, - basecase, - }; - plan.validate_pow_budget()?; + let plan = Self::new(spec, tuning, rounds, basecase); + plan.validate()?; Ok(plan) } } @@ -160,7 +151,7 @@ fn build_zk_round_config( shape: &RoundShape, c_zk_log_inv_rate: LogInvRate, ) -> Result, DeriveError> { - let spec = zk_spec.get(); + let spec = zk_spec.as_inner(); let ctx = round_context(shape); let num_masks = sumcheck_solver::masks_required(&ctx) + code_switch_solver::masks_required(); // C_zk.list_size depends only on rate — no IRS build needed for it. @@ -180,16 +171,9 @@ fn build_zk_round_config( c_zk_log_inv_rate, 2 * num_masks, ); - let mask_oracle = MaskOracleConfig { - mask_proximity: mask_proximity_solver::solve( - spec, - c_zk.clone(), - num_masks, - shape.round_index, - )?, - c_zk, - l_zk, - }; + let mask_proximity = + mask_proximity_solver::solve(spec, c_zk.clone(), num_masks, shape.round_index)?; + let mask_oracle = MaskOracleConfig::new(c_zk, l_zk, mask_proximity); let info = mask_oracle.info(); let sumcheck = sumcheck_solver::solve( @@ -197,23 +181,21 @@ fn build_zk_round_config( &ctx, &source, Some(info), - PowSlot::Round { + Pow::RoundSumcheck { index: shape.round_index, - kind: RoundSlot::Sumcheck, }, )?; let code_switch = code_switch_solver::solve(spec, source, target, t_ood, Some(info), shape.round_index)?; - Ok(RoundConfig { - round_index: shape.round_index, + Ok(RoundConfig::new( + shape.round_index, sumcheck, code_switch, - mode: RoundMode::ZeroKnowledge { + RoundMode::ZeroKnowledge { t_ood: OodSampleBudget::new(t_ood), - mask_oracle: info, + mask_oracle, }, - mask_oracle: Some(mask_oracle), - }) + )) } /// Local `t_ood ↔ r` fixed-point. `r = source.mask_length()` is a step function @@ -257,18 +239,16 @@ fn build_zk_round_data( source = irs_solver::solve(spec, &src_ctx, OodSampleBudget::new(t_ood)); } - Err(DeriveError::PerRoundFixedPointDidNotConverge { + Err(DeriveError::FixedPointDidNotConverge { round_index: shape.round_index, + loop_kind: FixedPointLoop::ZkRound, }) } fn build_round_config( spec: &SecuritySpec, shape: &RoundShape, - mask_oracle: Option, ) -> Result, DeriveError> { - debug_assert!(mask_oracle.is_none(), "ZK path uses build_zk_round_config"); - let src_ctx = round_context(shape); let source: IrsConfig = irs_solver::solve(spec, &src_ctx, OodSampleBudget::ZERO); let target: IrsConfig> = @@ -280,20 +260,18 @@ fn build_round_config( &src_ctx, &source, None, - PowSlot::Round { + Pow::RoundSumcheck { index: shape.round_index, - kind: RoundSlot::Sumcheck, }, )?; let code_switch = code_switch_solver::solve(spec, source, target, t_ood, None, shape.round_index)?; - Ok(RoundConfig { - round_index: shape.round_index, + Ok(RoundConfig::new( + shape.round_index, sumcheck, code_switch, - mode: RoundMode::Standard, - mask_oracle: None, - }) + RoundMode::Standard, + )) } /// `ℓ_zk = next_pow2(r + t_ood)`: Theorem 9.6 witness layout `0^{ℓ_zk − r}` @@ -321,9 +299,13 @@ pub(super) fn compute_t_ood( let combined_list_size = target_list_size * c_zk_list_size.unwrap_or(1.0); let message_length = source.message_length(); + // `Johnson` force-computes OOD samples even when `spec.decoding_regime` + // is `Unique`. `compute_t_ood` is called from the ZK fixed-point loop, + // which needs the count for sizing the mask; bypassing the early-return + // in `num_ood_samples` is intentional. let solve_for_degree = |degree: usize| { irs_commit::num_ood_samples( - false, + DecodingRegime::Johnson, security_target, field_bits, combined_list_size, @@ -348,7 +330,10 @@ pub(super) fn compute_t_ood( } t_ood = new_t_ood; } - Err(DeriveError::TOodFixedPointDidNotConverge { round_index }) + Err(DeriveError::FixedPointDidNotConverge { + round_index, + loop_kind: FixedPointLoop::TOod, + }) } #[cfg(test)] @@ -360,7 +345,6 @@ mod tests { use crate::{ hash, protocols::params::{ - bounds::SoundnessBounded, spec::{DecodingRegime, FoldingFactor, PowBudget}, test_utils::{assert_close, assert_pow_closes_gap, TestEmbedding}, }, @@ -514,8 +498,8 @@ mod tests { let spec = test_spec(Mode::Standard); let vector_size = 1usize << LOG_VECTOR_SIZE_NO_ROUNDS; let plan = ProtocolConfig::::derive(spec, tuning_with(vector_size)).unwrap(); - assert!(plan.rounds.is_empty()); - assert_eq!(plan.basecase.commit.vector_size, vector_size); + assert!(plan.rounds().is_empty()); + assert_eq!(plan.basecase().commit.vector_size, vector_size); } /// ZK with zero WHIR rounds = ZK basecase only. Per-round mask oracles are @@ -528,9 +512,9 @@ mod tests { tuning_with(1 << LOG_VECTOR_SIZE_NO_ROUNDS), ) .unwrap(); - assert!(plan.rounds.is_empty()); + assert!(plan.rounds().is_empty()); assert!(matches!( - plan.basecase.mode, + plan.basecase().mode, crate::protocols::basecase::BasecaseMode::ZeroKnowledge )); } @@ -544,8 +528,8 @@ mod tests { tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), ) .unwrap(); - for r in &plan.rounds { - let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode else { + for r in plan.rounds() { + let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode() else { panic!("expected ZK round") }; assert!(t_ood.get() >= 1); @@ -563,11 +547,11 @@ mod tests { let bits: f64 = plan.analytic_bits().into(); assert!(bits.is_finite() && bits > 0.0, "bits = {bits}"); let min_round = plan - .rounds + .rounds() .iter() .map(|r| f64::from(r.analytic_bits())) .fold(f64::INFINITY, f64::min); - let expected = min_round.min(f64::from(plan.basecase.analytic_bits())); + let expected = min_round.min(f64::from(plan.basecase().analytic_bits())); assert_close(bits, expected); } @@ -581,26 +565,22 @@ mod tests { .unwrap(); let plan_bits: f64 = plan.analytic_bits().into(); let mo_floor = plan - .rounds + .rounds() .iter() - .filter_map(|r| { - r.mask_oracle - .as_ref() - .map(|mo| f64::from(mo.analytic_bits())) - }) + .filter_map(|r| r.mask_oracle().map(|mo| f64::from(mo.analytic_bits()))) .fold(f64::INFINITY, f64::min); assert!( mo_floor.is_finite(), "ZK plan must contribute mask-oracle bits" ); let min_round = plan - .rounds + .rounds() .iter() .map(|r| f64::from(r.analytic_bits())) .fold(f64::INFINITY, f64::min); let expected = mo_floor .min(min_round) - .min(f64::from(plan.basecase.analytic_bits())); + .min(f64::from(plan.basecase().analytic_bits())); assert_close(plan_bits, expected); } @@ -613,12 +593,12 @@ mod tests { ) .unwrap(); assert!(matches!( - plan.basecase.mode, + plan.basecase().mode, crate::protocols::basecase::BasecaseMode::ZeroKnowledge )); - assert_eq!(plan.basecase.commit.interleaving_depth, 1); + assert_eq!(plan.basecase().commit.interleaving_depth, 1); // Sumcheck folds basecase to size 1. - assert_eq!(plan.basecase.sumcheck.final_size(), 1); + assert_eq!(plan.basecase().sumcheck.final_size(), 1); } /// Matches `proof_of_work::threshold`'s 60-bit cap. @@ -640,8 +620,8 @@ mod tests { .unwrap(); let field_bits = ::field_size_bits(); let mut expected_total = 0.0_f64; - for r in &plan.rounds { - let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode else { + for r in plan.rounds() { + let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode() else { panic!("expected ZK round"); }; let t = t_ood.get() as f64; @@ -708,10 +688,44 @@ mod tests { tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), ) .unwrap(); - plan.basecase.pow = PowConfig::from_difficulty(Bits::new(OVER_BUDGET_INJECTED_BITS)); + plan.override_basecase_pow_for_test(PowConfig::from_difficulty(Bits::new( + OVER_BUDGET_INJECTED_BITS, + ))); assert!(!plan.check_pow_bits()); } + /// `validate_round_chaining` trips when the basecase no longer chains + /// to the (new) last round after the tail is dropped. Multi-round plan + /// is required so dropping the last leaves at least one round behind. + #[test] + fn validate_round_chaining_detects_basecase_mismatch() { + let spec = test_spec(Mode::ZeroKnowledge); + let mut plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + let n = plan.rounds().len(); + assert!(n >= 2, "need ≥ 2 rounds to break the chain by truncation"); + assert!(plan.check_all_invariants(), "fresh plan must validate"); + + plan.truncate_rounds_for_test(n - 1); + let err = plan + .validate_round_chaining() + .expect_err("truncated tail breaks basecase chaining"); + assert!( + matches!( + err, + DeriveError::RoundChainBroken { + to: crate::protocols::params::error::ChainTarget::Basecase, + .. + } + ), + "got {err:?}", + ); + assert!(!plan.check_all_invariants()); + } + /// `derive()` reports `PowUngrindable` when the spec demands a per-slot /// difficulty above the grind cap. `target_security_bits = 200` against /// `analytic ≈ 64` on `Field64` gives `required ≈ 136` ≫ 60. @@ -777,13 +791,12 @@ mod tests { tuning_with(1 << LOG_VECTOR_SIZE_NO_ROUNDS), ) .unwrap(); - assert!(plan.rounds.is_empty()); - assert!(plan.basecase.commit.unique_decoding()); + assert!(plan.rounds().is_empty()); + assert!(plan.basecase().commit.unique_decoding()); } /// Same threading check under ZK mode. Basecase-only avoids the per-round - /// code-switch (which still requires `t_ood ≥ 1` until Stage 2 of the - /// regime work lands). + /// code-switch (which requires `t_ood ≥ 1`). #[test] fn derive_threads_unique_decoding_zk() { let spec = SecuritySpec { @@ -798,8 +811,8 @@ mod tests { tuning_with(1 << LOG_VECTOR_SIZE_NO_ROUNDS), ) .unwrap(); - assert!(plan.rounds.is_empty()); - assert!(plan.basecase.commit.unique_decoding()); + assert!(plan.rounds().is_empty()); + assert!(plan.basecase().commit.unique_decoding()); } /// `analytic_error + pow ≥ target` for every PoW slot in the plan. @@ -807,48 +820,47 @@ mod tests { spec: &SecuritySpec, plan: &ProtocolConfig, ) { - for r in &plan.rounds { - let mask_info = r.mode.mask_oracle(); + for r in plan.rounds() { + let mask_info = r.mask_oracle_info(); + let cs = r.code_switch(); assert_pow_closes_gap( spec, - sumcheck_solver::analytic_error_bits(&r.code_switch.source, mask_info), - &r.sumcheck.round_pow, + sumcheck_solver::analytic_error_bits(&cs.source, mask_info), + &r.sumcheck().round_pow, ); assert_pow_closes_gap( spec, code_switch_solver::analytic_error_bits( - &r.code_switch.source, - &r.code_switch.target, - r.code_switch.out_domain_samples, + &cs.source, + &cs.target, + cs.out_domain_samples, mask_info, ), - &r.code_switch.pow, + &cs.pow, ); - if let Some(mo) = &r.mask_oracle { + if let Some(mo) = r.mask_oracle() { + let mp = mo.mask_proximity(); assert_pow_closes_gap( spec, - mask_proximity_solver::analytic_error_bits( - &mo.mask_proximity.c_zk_commit, - mo.mask_proximity.num_masks, - ), - &mo.mask_proximity.pow, + mask_proximity_solver::analytic_error_bits(&mp.c_zk_commit, mp.num_masks), + &mp.pow, ); } } assert_pow_closes_gap( spec, - sumcheck_solver::analytic_error_bits(&plan.basecase.commit, None), - &plan.basecase.sumcheck.round_pow, + sumcheck_solver::analytic_error_bits(&plan.basecase().commit, None), + &plan.basecase().sumcheck.round_pow, ); // γ-slot is ZK-only. if matches!( - plan.basecase.mode, + plan.basecase().mode, crate::protocols::basecase::BasecaseMode::ZeroKnowledge ) { assert_pow_closes_gap( spec, - basecase_solver::analytic_error_bits(&plan.basecase.commit), - &plan.basecase.pow, + basecase_solver::analytic_error_bits(&plan.basecase().commit), + &plan.basecase().pow, ); } } @@ -881,15 +893,15 @@ mod tests { fn derive_standard_succeeds_over_tunings(tuning in arb_tuning()) { let spec = test_spec(Mode::Standard); let plan = ProtocolConfig::::derive(spec, tuning).unwrap(); - for r in &plan.rounds { - prop_assert!(matches!(r.mode, RoundMode::Standard)); - prop_assert!(r.mask_oracle.is_none()); + for r in plan.rounds() { + prop_assert!(matches!(r.mode(), RoundMode::Standard)); + prop_assert!(r.mask_oracle().is_none()); } prop_assert!(matches!( - plan.basecase.mode, + plan.basecase().mode, crate::protocols::basecase::BasecaseMode::Standard )); - prop_assert_eq!(plan.basecase.commit.interleaving_depth, 1); + prop_assert_eq!(plan.basecase().commit.interleaving_depth, 1); } /// ZK mode: each round has its own mask oracle sized for `k + 1` @@ -902,24 +914,24 @@ mod tests { let spec = test_spec(Mode::ZeroKnowledge); let plan = ProtocolConfig::::derive(spec, tuning).unwrap(); - for r in &plan.rounds { + for r in plan.rounds() { let mask_oracle = r - .mask_oracle - .as_ref() + .mask_oracle() .expect("ZK round must have a mask oracle"); - let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode else { + let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode() else { panic!("expected ZK round"); }; - let k = r.code_switch.source.interleaving_depth.trailing_zeros() as usize; + let cs = r.code_switch(); + let k = cs.source.interleaving_depth.trailing_zeros() as usize; let num_masks = k + 1; - prop_assert_eq!(mask_oracle.c_zk.num_vectors, 2 * num_masks); - prop_assert_eq!(mask_oracle.mask_proximity.num_masks, num_masks); + prop_assert_eq!(mask_oracle.c_zk().num_vectors, 2 * num_masks); + prop_assert_eq!(mask_oracle.mask_proximity().num_masks, num_masks); // Theorem 9.6 / Lemma 9.3: ℓ_zk ≥ r + t_ood for this round. - let source_mask = r.code_switch.source.mask_length(); - prop_assert!(mask_oracle.l_zk.get() >= source_mask + t_ood.get()); + let source_mask = cs.source.mask_length(); + prop_assert!(mask_oracle.l_zk().get() >= source_mask + t_ood.get()); } prop_assert!(matches!( - plan.basecase.mode, + plan.basecase().mode, crate::protocols::basecase::BasecaseMode::ZeroKnowledge )); } diff --git a/src/protocols/params/error.rs b/src/protocols/params/error.rs index 0e9b4bbb..ac68bc5a 100644 --- a/src/protocols/params/error.rs +++ b/src/protocols/params/error.rs @@ -6,93 +6,148 @@ //! wraps the former via [`DeriveError::PowUngrindable::source`] so callers can //! walk the `std::error::Error::source()` chain. +use std::fmt::{self, Display, Formatter}; + use thiserror::Error; use crate::{bits::Bits, protocols::proof_of_work::PowError}; -/// Coordinate of a PoW slot in the derived protocol. Two axes: where the -/// slot lives (basecase vs. a numbered round) and which sub-protocol owns it. -/// Only valid combinations are representable. +/// Identifies a single PoW grind in the derived protocol — basecase +/// sub-protocol or a per-round sub-protocol at a specific round index. Used +/// to label grinding-cap and budget failures. /// -/// `Error` is derived for the Display propagation it provides in -/// [`DeriveError`]'s `#[error("...")]` attributes; this type isn't itself a -/// failure (`source()` is always `None`). -#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)] -pub enum PowSlot { - #[error("basecase {0}")] - Basecase(BasecaseSlot), - #[error("round {index} {kind}")] - Round { index: usize, kind: RoundSlot }, +/// Flat by design: each variant is one valid (where, sub-protocol) pair. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Pow { + /// Basecase γ-RLC grind (Lemma 7.4) — ZK mode only. + BasecaseGammaCombination, + /// Basecase sumcheck grind. + BasecaseSumcheck, + /// Per-round sumcheck grind at `index`. + RoundSumcheck { index: usize }, + /// Per-round code-switch grind at `index`. + RoundCodeSwitch { index: usize }, + /// Per-round mask-proximity grind at `index` — ZK mode only. + RoundMaskProximity { index: usize }, +} + +impl Display for Pow { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::BasecaseGammaCombination => f.write_str("basecase γ-combination"), + Self::BasecaseSumcheck => f.write_str("basecase sumcheck"), + Self::RoundSumcheck { index } => write!(f, "round {index} sumcheck"), + Self::RoundCodeSwitch { index } => write!(f, "round {index} code-switch"), + Self::RoundMaskProximity { index } => write!(f, "round {index} mask-proximity"), + } + } +} + +/// Origin side of a [`DeriveError::RoundChainBroken`]: either a numbered round +/// or the pre-round `tuning` shape (for plans with no rounds at all). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ChainSource { + Tuning, + Round(usize), +} + +impl Display for ChainSource { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::Tuning => f.write_str("tuning"), + Self::Round(i) => write!(f, "round {i}"), + } + } +} + +/// Destination side of a [`DeriveError::RoundChainBroken`]: the next round in +/// sequence, or the basecase. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ChainTarget { + NextRound(usize), + Basecase, } -/// Sub-protocols whose PoW lives in the basecase. `GammaCombination` is the -/// Lemma 7.4 γ-RLC slot, present only in ZK mode. -#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)] -pub enum BasecaseSlot { - #[error("γ-combination")] - GammaCombination, - #[error("sumcheck")] - Sumcheck, +impl Display for ChainTarget { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::NextRound(i) => write!(f, "round {i}"), + Self::Basecase => f.write_str("basecase"), + } + } } -/// Sub-protocols whose PoW lives in a per-round shape. -#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)] -pub enum RoundSlot { - #[error("sumcheck")] - Sumcheck, - #[error("code-switch")] - CodeSwitch, - #[error("mask-proximity")] - MaskProximity, +/// Which fixed-point loop failed to converge. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum FixedPointLoop { + /// `compute_t_ood`'s scalar iteration. + TOod, + /// `build_zk_round_data`'s outer `t_ood ↔ source.mask_length()` iteration. + ZkRound, +} + +impl Display for FixedPointLoop { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::TOod => f.write_str("t_ood"), + Self::ZkRound => f.write_str("ZK per-round"), + } + } } /// Failure modes for [`super::derive::ProtocolConfig::derive`] and the /// sub-protocol solvers it calls. #[derive(Debug, Error, Clone, PartialEq, Eq)] pub enum DeriveError { - /// `compute_t_ood` failed to reach a fixed point in `MAX_ITER` iterations. - /// Indicates a pathological spec/tuning combo; should not happen under - /// realistic security targets on supported fields. - #[error("t_ood fixed-point did not converge for round {round_index}")] - TOodFixedPointDidNotConverge { round_index: usize }, - - /// The ZK per-round `t_ood ↔ source.mask_length()` loop failed to reach a - /// fixed point. Same caveat as `TOodFixedPointDidNotConverge`. - #[error("ZK per-round fixed-point did not converge for round {round_index}")] - PerRoundFixedPointDidNotConverge { round_index: usize }, - - /// A PoW slot cannot be ground at the chosen analytic floor — the spec is - /// too tight for any single grind slot to close the gap. - #[error("{slot} cannot be ground: {source}")] + /// A fixed-point loop ran out of iterations. Indicates a pathological + /// spec/tuning combo; should not happen under realistic security targets + /// on supported fields. + #[error("{loop_kind} fixed-point did not converge for round {round_index}")] + FixedPointDidNotConverge { + round_index: usize, + loop_kind: FixedPointLoop, + }, + + /// A PoW grind cannot close the analytic-to-target gap — the spec is too + /// tight for any single grind to reach `target_security_bits`. + #[error("{pow} cannot be ground: {source}")] PowUngrindable { - slot: PowSlot, + pow: Pow, #[source] source: PowError, }, - /// A PoW slot fits the grind cap but exceeds the per-slot budget set by + /// A PoW grind fits the grind cap but exceeds the per-slot budget set by /// [`super::spec::SecuritySpec::pow_budget`]. - #[error("{slot} requires {required} bits, exceeds spec.pow_budget = {max}")] - PowBudgetExceeded { - slot: PowSlot, - required: Bits, - max: Bits, - }, + #[error("{pow} requires {required} bits, exceeds spec.pow_budget = {max}")] + PowBudgetExceeded { pow: Pow, required: Bits, max: Bits }, /// Computed codeword length exceeds the NTT engine's supported order. #[error("codeword length {length} exceeds the NTT engine's supported order")] CodewordExceedsNtt { length: usize }, + + /// Cross-round (or round → basecase) shape chain broken: the next + /// component's source `vector_size` does not match the previous + /// component's target `vector_size`. Surfaced by + /// [`super::protocol_config::ProtocolConfig::validate_round_chaining`]. + #[error("chain broken: {from} → {to} expected vector_size {expected}, found {found}")] + RoundChainBroken { + from: ChainSource, + to: ChainTarget, + expected: usize, + found: usize, + }, } /// Lift `Result` into `Result` by attaching a -/// [`PowSlot`] label. Lets call sites stay single-line — no manual -/// `.map_err(|e| DeriveError::PowUngrindable { slot, source: e })` boilerplate. +/// [`Pow`] label. Lets call sites stay single-line — no manual +/// `.map_err(|e| DeriveError::PowUngrindable { pow, source: e })` boilerplate. pub(crate) trait PowResultExt { - fn at_slot(self, slot: PowSlot) -> Result; + fn at(self, pow: Pow) -> Result; } impl PowResultExt for Result { - fn at_slot(self, slot: PowSlot) -> Result { - self.map_err(|source| DeriveError::PowUngrindable { slot, source }) + fn at(self, pow: Pow) -> Result { + self.map_err(|source| DeriveError::PowUngrindable { pow, source }) } } diff --git a/src/protocols/params/irs_commit.rs b/src/protocols/params/irs_commit.rs index 7e68683e..83756957 100644 --- a/src/protocols/params/irs_commit.rs +++ b/src/protocols/params/irs_commit.rs @@ -26,14 +26,13 @@ pub fn solve( let security_target = f64::from(spec.protocol_security_target_bits()); let rate = rate(f64::from(ctx.log_inv_rate)); let interleaving_depth = 1_usize << ctx.folding_factor; - let unique_decoding = spec.decoding_regime.unique_decoding(); let mode = match spec.mode { Mode::Standard => IrsMode::Standard, Mode::ZeroKnowledge => { // Lemma 9.5 (part ii): r-query perfect-ZK encoding requires // `r ≥ in-domain + OOD`. Use the tight bound; do not pow2-pad here. - let mask_length = num_in_domain_queries(unique_decoding, security_target, rate) + let mask_length = num_in_domain_queries(spec.decoding_regime, security_target, rate) .checked_add(out_domain_samples.get()) .expect("usize overflow"); IrsMode::ZeroKnowledge { mask_length } @@ -42,7 +41,7 @@ pub fn solve( IrsConfig::new( security_target, - unique_decoding, + spec.decoding_regime, spec.hash_id, 1, // one vector committed per round ctx.vector_size, @@ -75,13 +74,12 @@ pub fn solve_mask_code( "num_vectors ({num_vectors}) must be even (mask-proximity original/fresh pairs)", ); - let spec = spec.get(); let security_target = f64::from(spec.protocol_security_target_bits()); let rate = rate(f64::from(log_inv_rate.get())); IrsConfig::new( security_target, - spec.decoding_regime.unique_decoding(), + spec.decoding_regime, spec.hash_id, num_vectors, l_zk, @@ -114,7 +112,13 @@ mod tests { fn solve_mask_code_rejects_non_pow2_l_zk() { let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); let zk_spec = ZkSpec::try_new(&spec).unwrap(); - let _ = solve_mask_code::(zk_spec, MaskCodeMessageLen::new(3), 0, LogInvRate::new(1), 2); + let _ = solve_mask_code::( + zk_spec, + MaskCodeMessageLen::new(3), + 0, + LogInvRate::new(1), + 2, + ); } #[test] @@ -122,7 +126,13 @@ mod tests { fn solve_mask_code_rejects_l_zk_below_source_mask_length() { let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); let zk_spec = ZkSpec::try_new(&spec).unwrap(); - let _ = solve_mask_code::(zk_spec, MaskCodeMessageLen::new(2), 4, LogInvRate::new(1), 2); + let _ = solve_mask_code::( + zk_spec, + MaskCodeMessageLen::new(2), + 4, + LogInvRate::new(1), + 2, + ); } #[test] @@ -130,7 +140,13 @@ mod tests { fn solve_mask_code_rejects_odd_num_vectors() { let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); let zk_spec = ZkSpec::try_new(&spec).unwrap(); - let _ = solve_mask_code::(zk_spec, MaskCodeMessageLen::new(2), 0, LogInvRate::new(1), 3); + let _ = solve_mask_code::( + zk_spec, + MaskCodeMessageLen::new(2), + 0, + LogInvRate::new(1), + 3, + ); } /// `irs_commit::solve` doesn't grind PoW, so this range can sit higher than diff --git a/src/protocols/params/mask_proximity.rs b/src/protocols/params/mask_proximity.rs index 068c9e2e..b6481e5e 100644 --- a/src/protocols/params/mask_proximity.rs +++ b/src/protocols/params/mask_proximity.rs @@ -10,8 +10,8 @@ use crate::{ irs_commit::Config as IrsConfig, mask_proximity::Config as MaskProximityConfig, params::{ - bounds::{usize_to_f64, SoundnessBounded}, - error::{DeriveError, PowResultExt, PowSlot, RoundSlot}, + bounds::usize_to_f64, + error::{DeriveError, Pow, PowResultExt}, spec::SecuritySpec, }, proof_of_work::Config as PowConfig, @@ -28,10 +28,8 @@ pub fn solve( ) -> Result, DeriveError> { let target_bits = Bits::new(f64::from(spec.target_security_bits)); let analytic = analytic_error_bits(&c_zk, num_masks); - let pow = PowConfig::grind_to(target_bits, analytic, spec.hash_id).at_slot(PowSlot::Round { - index: round_index, - kind: RoundSlot::MaskProximity, - })?; + let pow = PowConfig::grind_to(target_bits, analytic, spec.hash_id) + .at(Pow::RoundMaskProximity { index: round_index })?; Ok(MaskProximityConfig::new(c_zk, num_masks, pow)) } @@ -47,8 +45,9 @@ pub fn analytic_error_bits(c_zk: &IrsConfig>, num_masks: u Bits::new((field_bits - log_combined).max(0.0)) } -impl SoundnessBounded for MaskProximityConfig { - fn analytic_bits(&self) -> Bits { +impl MaskProximityConfig { + /// Analytic soundness bits (excluding PoW) for the Lemma 7.4 γ-combination. + pub fn analytic_bits(&self) -> Bits { analytic_error_bits(&self.c_zk_commit, self.num_masks) } } @@ -65,7 +64,7 @@ mod tests { protocols::{ irs_commit::IrsMode, params::{ - spec::Mode, + spec::{DecodingRegime, Mode}, test_utils::{ arb_zk_spec, assert_close, assert_pow_closes_gap, build_test_c_zk, deterministic_spec, TEST_TARGET_RANGE, @@ -157,7 +156,6 @@ mod tests { // All values except `NON_UNIT_INTERLEAVING_DEPTH` are chosen to satisfy // `Config::new`'s divisibility/pow2 constraints. const SECURITY_TARGET_BITS: f64 = 80.0; - const UNIQUE_DECODING: bool = false; const NUM_VECTORS: usize = 2; const VECTOR_SIZE: usize = 8; const NON_UNIT_INTERLEAVING_DEPTH: usize = 2; @@ -167,7 +165,7 @@ mod tests { let spec = deterministic_spec(Mode::ZeroKnowledge); let c_zk = IrsConfig::>::new( SECURITY_TARGET_BITS, - UNIQUE_DECODING, + DecodingRegime::Johnson, hash::BLAKE3, NUM_VECTORS, VECTOR_SIZE, diff --git a/src/protocols/params/mod.rs b/src/protocols/params/mod.rs index d5a9f530..d257df72 100644 --- a/src/protocols/params/mod.rs +++ b/src/protocols/params/mod.rs @@ -16,17 +16,18 @@ pub mod error; pub(crate) mod irs_commit; pub(crate) mod mask_proximity; pub mod protocol_config; +pub(crate) mod regime; pub mod spec; pub(crate) mod sumcheck; #[cfg(test)] pub(crate) mod test_utils; -pub use error::{BasecaseSlot, DeriveError, PowSlot, RoundSlot}; +pub use error::{ChainSource, ChainTarget, DeriveError, FixedPointLoop, Pow}; pub use protocol_config::{ MaskOracleConfig, MaskOracleInfo, ProtocolConfig, RoundConfig, RoundMode, }; pub use spec::{ - DecodingRegime, FoldingFactor, ListSize, LogInvRate, MaskCodeMessageLen, Mode, - OodSampleBudget, PowBudget, RoundContext, SecuritySpec, TuningSpec, ZkSpec, + DecodingRegime, FoldingFactor, ListSize, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, + PowBudget, RoundContext, SecuritySpec, TuningSpec, ZkSpec, }; diff --git a/src/protocols/params/protocol_config.rs b/src/protocols/params/protocol_config.rs index 5ec5c9b5..9d299fad 100644 --- a/src/protocols/params/protocol_config.rs +++ b/src/protocols/params/protocol_config.rs @@ -4,6 +4,11 @@ //! `2·(k+1)` columns — `k` sumcheck masks + 1 code-switch `(r ‖ s)` mask, all //! doubled by Construction 7.2's originals + fresh pairs) plus a per-round //! mask-proximity check. Standard rounds carry no mask oracle. +//! +//! The post-construction structures (`ProtocolConfig`, `RoundConfig`, +//! `MaskOracleConfig`) expose only read accessors externally — invariants +//! validated by [`ProtocolConfig::validate`] survive past the call site +//! because there is no public mutation surface. use ark_ff::Field; @@ -19,9 +24,9 @@ use crate::{ irs_commit::Config as IrsConfig, mask_proximity::Config as MaskProximityConfig, params::{ - bounds::{usize_to_f64, SoundnessBounded}, + bounds::usize_to_f64, code_switch as code_switch_solver, - error::{BasecaseSlot, DeriveError, PowSlot, RoundSlot}, + error::{ChainSource, ChainTarget, DeriveError, Pow}, spec::{ListSize, MaskCodeMessageLen, OodSampleBudget, SecuritySpec, TuningSpec}, sumcheck as sumcheck_solver, }, @@ -32,71 +37,161 @@ use crate::{ #[derive(Clone, Debug)] pub struct ProtocolConfig { - pub security: SecuritySpec, - pub tuning: TuningSpec, - pub rounds: Vec>, - pub basecase: BasecaseConfig, + security: SecuritySpec, + tuning: TuningSpec, + rounds: Vec>, + basecase: BasecaseConfig, } impl ProtocolConfig { - /// Returns `true` if every PoW slot's difficulty fits within - /// `security.pow_budget`. Boolean predicate kept for callers that want - /// to re-check after manual inspection; [`Self::validate_pow_budget`] is - /// the typed version used internally by [`super::derive::ProtocolConfig::derive`]. + pub(crate) const fn new( + security: SecuritySpec, + tuning: TuningSpec, + rounds: Vec>, + basecase: BasecaseConfig, + ) -> Self { + Self { + security, + tuning, + rounds, + basecase, + } + } + + pub const fn security(&self) -> &SecuritySpec { + &self.security + } + + pub const fn tuning(&self) -> &TuningSpec { + &self.tuning + } + + pub fn rounds(&self) -> &[RoundConfig] { + &self.rounds + } + + pub const fn basecase(&self) -> &BasecaseConfig { + &self.basecase + } + + /// `#[cfg(test)]` escape hatch: lets the negative test in + /// `derive::tests` inject an over-budget basecase PoW slot so that + /// `validate_pow_budget` can be exercised on a corrupted plan. + /// Not for production use — there is no equivalent on the public API. + #[cfg(test)] + pub(crate) const fn override_basecase_pow_for_test(&mut self, pow: PowConfig) { + self.basecase.pow = pow; + } + + /// `#[cfg(test)]` escape hatch: lets chain-broken tests in + /// `derive::tests` drop the tail of `rounds` so the basecase's chained + /// `vector_size` no longer matches the new last round. + #[cfg(test)] + pub(crate) fn truncate_rounds_for_test(&mut self, len: usize) { + self.rounds.truncate(len); + } + + /// `true` if every PoW slot's difficulty fits within `security.pow_budget`. + /// Boolean form of [`Self::validate_pow_budget`]. pub fn check_pow_bits(&self) -> bool { self.validate_pow_budget().is_ok() } - /// Same check as [`Self::check_pow_bits`] but returns the specific slot - /// and required-vs-max difficulties on failure. Auto-invoked by - /// `derive()`; callers don't normally need to call this directly. + /// Returns `true` if every post-construction invariant holds: PoW + /// budget, mask-oracle coherence, and cross-round shape chaining. + pub fn check_all_invariants(&self) -> bool { + self.validate().is_ok() + } + + /// Run every post-construction invariant check. Auto-invoked by + /// `derive()`; callers only need this after manual inspection (and only + /// then through the `pub(crate)` test shim, since fields are private). + /// + /// Mask-oracle coherence is *not* a separate check: the per-round + /// `mask_oracle` lives inside `RoundMode::ZeroKnowledge`, so its + /// presence ↔ ZK-ness equivalence is enforced by the type system. + pub fn validate(&self) -> Result<(), DeriveError> { + self.validate_pow_budget()?; + self.validate_round_chaining()?; + Ok(()) + } + + /// PoW slot difficulty ≤ `security.pow_budget` for every slot. Auto-invoked + /// by `derive()` via [`Self::validate`]. pub fn validate_pow_budget(&self) -> Result<(), DeriveError> { let max = Bits::new(f64::from(self.security.pow_budget.bits())); - let check = |slot: PowSlot, pow: &PowConfig| -> Result<(), DeriveError> { - let required = pow.difficulty(); + let check = |pow: Pow, cfg: &PowConfig| -> Result<(), DeriveError> { + let required = cfg.difficulty(); if required > max { - Err(DeriveError::PowBudgetExceeded { - slot, - required, - max, - }) + Err(DeriveError::PowBudgetExceeded { pow, required, max }) } else { Ok(()) } }; for r in &self.rounds { check( - PowSlot::Round { + Pow::RoundSumcheck { index: r.round_index, - kind: RoundSlot::Sumcheck, }, &r.sumcheck.round_pow, )?; check( - PowSlot::Round { + Pow::RoundCodeSwitch { index: r.round_index, - kind: RoundSlot::CodeSwitch, }, &r.code_switch.pow, )?; - if let Some(mo) = &r.mask_oracle { + if let Some(mo) = r.mask_oracle() { check( - PowSlot::Round { + Pow::RoundMaskProximity { index: r.round_index, - kind: RoundSlot::MaskProximity, }, &mo.mask_proximity.pow, )?; } } - check( - PowSlot::Basecase(BasecaseSlot::Sumcheck), - &self.basecase.sumcheck.round_pow, - )?; - check( - PowSlot::Basecase(BasecaseSlot::GammaCombination), - &self.basecase.pow, - )?; + check(Pow::BasecaseSumcheck, &self.basecase.sumcheck.round_pow)?; + check(Pow::BasecaseGammaCombination, &self.basecase.pow)?; + Ok(()) + } + + /// Cross-round shape chaining: + /// - adjacent rounds: `round[i+1].source.vector_size == round[i].target.vector_size` + /// - last round → basecase: `basecase.commit.vector_size == last.target.vector_size` + /// - no rounds: `basecase.commit.vector_size == tuning.vector_size` + pub fn validate_round_chaining(&self) -> Result<(), DeriveError> { + for window in self.rounds.windows(2) { + let prev = &window[0]; + let next = &window[1]; + let expected = prev.code_switch.target.vector_size; + let found = next.code_switch.source.vector_size; + if expected != found { + return Err(DeriveError::RoundChainBroken { + from: ChainSource::Round(prev.round_index), + to: ChainTarget::NextRound(next.round_index), + expected, + found, + }); + } + } + + let basecase_vector_size = self.basecase.commit.vector_size; + let expected = self.rounds.last().map_or(self.tuning.vector_size, |last| { + last.code_switch.target.vector_size + }); + if expected != basecase_vector_size { + let from = self + .rounds + .last() + .map_or(ChainSource::Tuning, |r| ChainSource::Round(r.round_index)); + return Err(DeriveError::RoundChainBroken { + from, + to: ChainTarget::Basecase, + expected, + found: basecase_vector_size, + }); + } + Ok(()) } @@ -108,7 +203,7 @@ impl ProtocolConfig { let field_bits = ::field_size_bits(); let mut total_error = 0.0_f64; for r in &self.rounds { - if let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode { + if let RoundMode::ZeroKnowledge { t_ood, .. } = &r.mode { let t = usize_to_f64(t_ood.get()); // ζ_ze ≤ (t_ood² + t_ood) / (2|F|). Compute in log space to // stay numerically stable for large field_bits. @@ -123,8 +218,10 @@ impl ProtocolConfig { } } -impl SoundnessBounded for ProtocolConfig { - fn analytic_bits(&self) -> Bits { +impl ProtocolConfig { + /// Analytic soundness bits (excluding PoW): minimum over basecase and + /// every round. + pub fn analytic_bits(&self) -> Bits { let mut min_bits = f64::from(self.basecase.analytic_bits()); for round in &self.rounds { min_bits = min_bits.min(f64::from(round.analytic_bits())); @@ -135,61 +232,112 @@ impl SoundnessBounded for ProtocolConfig { #[derive(Clone, Debug)] pub struct RoundConfig { - pub round_index: usize, - pub sumcheck: SumcheckConfig, - pub code_switch: CodeSwitchConfig, - pub mode: RoundMode, - /// `Some` iff this is a ZK round. Sized for this round's `k + 1` masks - /// (k sumcheck + 1 code-switch). - pub mask_oracle: Option>, + round_index: usize, + sumcheck: SumcheckConfig, + code_switch: CodeSwitchConfig, + /// Standard vs. ZK — and in ZK mode, owns the round's full mask oracle + /// directly. No separate `mask_oracle` field on `RoundConfig`: the + /// variant tag is the single source of truth for both ZK-ness and the + /// oracle's presence/contents. + mode: RoundMode, } -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] -pub enum RoundMode { +impl RoundConfig { + pub(crate) const fn new( + round_index: usize, + sumcheck: SumcheckConfig, + code_switch: CodeSwitchConfig, + mode: RoundMode, + ) -> Self { + Self { + round_index, + sumcheck, + code_switch, + mode, + } + } + + pub const fn round_index(&self) -> usize { + self.round_index + } + + pub const fn sumcheck(&self) -> &SumcheckConfig { + &self.sumcheck + } + + pub const fn code_switch(&self) -> &CodeSwitchConfig { + &self.code_switch + } + + pub const fn mode(&self) -> &RoundMode { + &self.mode + } + + /// Convenience: borrow the round's mask oracle if this is a ZK round. + /// Equivalent to pattern-matching on `mode()`. + pub const fn mask_oracle(&self) -> Option<&MaskOracleConfig> { + match &self.mode { + RoundMode::Standard => None, + RoundMode::ZeroKnowledge { mask_oracle, .. } => Some(mask_oracle), + } + } + + /// Slim mask-oracle view derived from `mask_oracle()`. Produced on + /// demand — there is no stored copy. + pub fn mask_oracle_info(&self) -> Option { + self.mask_oracle().map(MaskOracleConfig::info) + } +} + +/// Standard vs. ZK round. ZK variant carries the full per-round mask oracle +/// — there is no longer a separate `MaskOracleInfo` slim view stored +/// alongside it (which would duplicate `mask_oracle.info()`). +/// +/// Not `Copy`: `MaskOracleConfig` owns a `MaskProximityConfig` and an +/// `IrsConfig`, neither of which is `Copy`. +/// +/// `large_enum_variant` allowed: the ZK variant carries `MaskOracleConfig` +/// (~330B) while `Standard` is 0B, but a proof holds O(rounds) RoundModes +/// (single-digit count) so the absolute overhead is a few KB. Boxing the +/// payload would add per-access indirection without measurable savings. +#[derive(Clone, Debug)] +#[allow(clippy::large_enum_variant)] +pub enum RoundMode { Standard, ZeroKnowledge { /// Lemma 9.9 OOD-sample budget (bounds doc §5.2). t_ood: OodSampleBudget, - /// Slim view of this round's [`MaskOracleConfig`] (C_zk's list size + - /// ℓ_zk) — denormalized so soundness routines can read it without - /// chasing through `mask_oracle`. - mask_oracle: MaskOracleInfo, + /// Per-round mask oracle: C_zk codeword (sized for `2·(k+1)` + /// columns) + ℓ_zk + mask-proximity check for `k+1` masks. + mask_oracle: MaskOracleConfig, }, } -impl RoundMode { +impl RoundMode { pub const fn is_zk(&self) -> bool { matches!(self, Self::ZeroKnowledge { .. }) } - - pub const fn mask_oracle(&self) -> Option { - match self { - Self::Standard => None, - Self::ZeroKnowledge { mask_oracle, .. } => Some(*mask_oracle), - } - } } -impl SoundnessBounded for RoundConfig { +impl RoundConfig { /// Round-level analytic floor: the smallest of `sumcheck`, `code_switch`, /// and (when present) the per-round mask-oracle proximity check. Folding /// the mask-oracle term in here keeps `ProtocolConfig::analytic_bits` /// a pure `min` over rounds + basecase. - fn analytic_bits(&self) -> Bits { + pub fn analytic_bits(&self) -> Bits { let source = &self.code_switch.source; let target = &self.code_switch.target; - let mask_oracle = self.mode.mask_oracle(); + let mask_info = self.mask_oracle_info(); - let sumcheck_term = f64::from(sumcheck_solver::analytic_error_bits(source, mask_oracle)); + let sumcheck_term = f64::from(sumcheck_solver::analytic_error_bits(source, mask_info)); let code_switch_term = f64::from(code_switch_solver::analytic_error_bits( source, target, self.code_switch.out_domain_samples, - mask_oracle, + mask_info, )); let mask_oracle_term = self - .mask_oracle - .as_ref() + .mask_oracle() .map_or(f64::INFINITY, |mo| f64::from(mo.analytic_bits())); Bits::new( @@ -206,21 +354,38 @@ impl SoundnessBounded for RoundConfig { #[derive(Clone, Debug)] pub struct MaskOracleConfig { /// `num_vectors = 2 · (k + 1)` (Construction 7.2: originals + fresh). - pub c_zk: IrsConfig>, + c_zk: IrsConfig>, /// `next_pow2(r + t_ood)` for this round: Theorem 9.6 witness layout /// (`0^{ℓ_zk − r}` padding) + Lemma 9.3 `(ℓ_zk − r, 0)`-privacy precondition. - pub l_zk: MaskCodeMessageLen, - pub mask_proximity: MaskProximityConfig, -} - -/// Slim mask-oracle view (C_zk's list size + ℓ_zk). -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] -pub struct MaskOracleInfo { - pub c_zk_list_size: ListSize, - pub l_zk: MaskCodeMessageLen, + l_zk: MaskCodeMessageLen, + mask_proximity: MaskProximityConfig, } impl MaskOracleConfig { + pub(crate) const fn new( + c_zk: IrsConfig>, + l_zk: MaskCodeMessageLen, + mask_proximity: MaskProximityConfig, + ) -> Self { + Self { + c_zk, + l_zk, + mask_proximity, + } + } + + pub const fn c_zk(&self) -> &IrsConfig> { + &self.c_zk + } + + pub const fn l_zk(&self) -> MaskCodeMessageLen { + self.l_zk + } + + pub const fn mask_proximity(&self) -> &MaskProximityConfig { + &self.mask_proximity + } + pub fn info(&self) -> MaskOracleInfo { MaskOracleInfo { c_zk_list_size: ListSize::new(self.c_zk.list_size()), @@ -229,8 +394,21 @@ impl MaskOracleConfig { } } -impl SoundnessBounded for MaskOracleConfig { - fn analytic_bits(&self) -> Bits { +/// Slim mask-oracle view (C_zk's list size + ℓ_zk). +/// +/// Reached only through `RoundMode::ZeroKnowledge`'s field, which is itself +/// accessible only via `RoundConfig::mode() -> &RoundMode`. The public +/// surface therefore stays read-only even though the variant fields are +/// nominally `pub`. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct MaskOracleInfo { + pub c_zk_list_size: ListSize, + pub l_zk: MaskCodeMessageLen, +} + +impl MaskOracleConfig { + /// Analytic soundness bits (excluding PoW) for this round's mask oracle. + pub fn analytic_bits(&self) -> Bits { self.mask_proximity.analytic_bits() } } diff --git a/src/protocols/params/regime.rs b/src/protocols/params/regime.rs new file mode 100644 index 00000000..f985f061 --- /dev/null +++ b/src/protocols/params/regime.rs @@ -0,0 +1,247 @@ +//! Reed–Solomon decoding regime — materialized per-round parameters and the +//! analytic helpers that depend on them. +//! +//! Spec-level policy lives in [`super::spec::DecodingRegime`] (rate-independent, +//! a user choice). The data-carrying [`DecodingRegimeParams`] is what gets +//! stored on per-round configs once a rate is known: [`Self::from_policy`] +//! is the single materialization point. + +use std::f64::consts::LOG2_10; + +use ordered_float::OrderedFloat; +use serde::{Deserialize, Serialize}; + +use crate::protocols::params::{ + bounds::{rate, usize_to_f64}, + spec::DecodingRegime, +}; + +/// Materialized decoding-regime parameters. +/// +/// `Unique` carries no data; `Johnson { slack }` carries `η`. The two variants +/// are statically distinct — there is no "Johnson with η = 0" representation, +/// so callers can pattern-match without a sentinel-comparison branch. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub enum DecodingRegimeParams { + Unique, + Johnson { slack: OrderedFloat }, +} + +impl DecodingRegimeParams { + /// Materialize spec policy at a known rate. The canonical Johnson slack + /// (`η = √ρ / 20`) is centralized here — any tuning of `η` lives at this + /// site and propagates to every per-round config. + // TODO: Optimize picking η. + pub fn from_policy(policy: DecodingRegime, rate: f64) -> Self { + match policy { + DecodingRegime::Unique => Self::Unique, + DecodingRegime::Johnson => Self::johnson_canonical(rate), + } + } + + /// Johnson regime with the canonical `η = √ρ / 20` slack. + pub fn johnson_canonical(rate: f64) -> Self { + Self::Johnson { + slack: OrderedFloat(rate.sqrt() / 20.0), + } + } + + pub const fn is_unique(self) -> bool { + matches!(self, Self::Unique) + } + + /// `log₂ |Λ(C, δ)|`. + pub fn list_size_log2(self, log_inv_rate: f64) -> f64 { + match self { + Self::Unique => 0.0, + // Johnson: |Λ| = 1 / (2 η √ρ). + Self::Johnson { slack } => -1.0 - slack.into_inner().log2() + 0.5 * log_inv_rate, + } + } + + /// `|Λ(C, δ)|`. + pub fn list_size(self, log_inv_rate: f64) -> f64 { + 2_f64.powf(self.list_size_log2(log_inv_rate)) + } + + /// `log₂(1 − δ)`. + pub fn one_minus_distance_log2(self, log_inv_rate: f64) -> f64 { + let one_minus_delta = match self { + Self::Unique => f64::midpoint(1.0, rate(log_inv_rate)), + Self::Johnson { slack } => rate(log_inv_rate).sqrt() + slack.into_inner(), + }; + one_minus_delta.log2() + } + + /// `log₂ ε_mca(C, δ)`. + pub fn eps_mca_log2(self, log_inv_rate: f64, message_length: usize, field_bits: f64) -> f64 { + let log_k = usize_to_f64(message_length).log2(); + let error = match self { + Self::Unique => log_k + log_inv_rate, + Self::Johnson { slack } => { + debug_assert!( + slack.into_inner().log2() >= -(0.5 * log_inv_rate + LOG2_10 + 1.0) - 1e-6 + ); + 7.0 * LOG2_10 + 3.5 * log_inv_rate + 2.0 * log_k + } + }; + error - field_bits + } +} + +/// Johnson list size at the canonical `η = √ρ / 20` slack, as a function of +/// `log_inv_rate` only. Used by planners that need a list-size estimate before +/// a target config exists. +pub fn johnson_list_size(log_inv_rate: f64) -> f64 { + DecodingRegimeParams::johnson_canonical(rate(log_inv_rate)).list_size(log_inv_rate) +} + +#[cfg(test)] +#[allow(clippy::float_cmp)] +mod tests { + use super::*; + use crate::protocols::params::test_utils::assert_close; + + /// Tighter tolerance for tests doing relative-error checks against an + /// alternative-derived expected value with the same operations. + const TIGHT_EPS: f64 = 1e-12; + + fn johnson(slack: f64) -> DecodingRegimeParams { + DecodingRegimeParams::Johnson { + slack: OrderedFloat(slack), + } + } + + /// Johnson list size: `|Λ| = 1 / (2η√ρ)`, log₂ form. Hand-evaluated at + /// `log_inv_rate = 2`, `η = 0.1`: `−1 − log₂(0.1) + 1 ≈ 3.3219`. + #[test] + fn list_size_log2_johnson_formula() { + let got = johnson(0.1).list_size_log2(2.0); + let expected = -1.0 - 0.1_f64.log2() + 0.5 * 2.0; + assert_close(got, expected); + } + + /// Unique-decoding regime gives `|Λ| = 1`, i.e. log = 0. + #[test] + fn list_size_log2_unique_decoding_is_zero() { + assert_eq!(DecodingRegimeParams::Unique.list_size_log2(2.0), 0.0); + } + + /// `η = √ρ / 20` substituted into `|Λ| = 1/(2η√ρ)` simplifies to `10/ρ`. + /// So `johnson_list_size(b) = 10 · 2^b`. + #[test] + fn johnson_list_size_closed_form() { + for b in [1.0, 2.0, 3.0, 5.0] { + let got = johnson_list_size(b); + let expected = 10.0 * 2_f64.powf(b); + assert!( + (got - expected).abs() / expected < TIGHT_EPS, + "log_inv_rate={b}: got {got} vs {expected}", + ); + } + } + + /// `johnson_list_size(b)` must match `Config::list_size` once a config is + /// built at the same rate. Keeps the rate-only helper in sync with + /// `irs_commit::Config::new`'s canonical-slack materialization. + #[test] + fn johnson_list_size_matches_config_list_size() { + use crate::{ + algebra::{embedding::Identity, fields::Field64}, + hash, + protocols::irs_commit::{Config, IrsMode}, + }; + const PLACEHOLDER_SECURITY_TARGET_BITS: f64 = 80.0; + const PLACEHOLDER_NUM_VECTORS: usize = 2; + const PLACEHOLDER_VECTOR_SIZE: usize = 8; + const PLACEHOLDER_INTERLEAVING_DEPTH: usize = 1; + const LOG_INV_RATE: u32 = 2; + + let config: Config> = Config::new( + PLACEHOLDER_SECURITY_TARGET_BITS, + DecodingRegime::Johnson, + hash::BLAKE3, + PLACEHOLDER_NUM_VECTORS, + PLACEHOLDER_VECTOR_SIZE, + PLACEHOLDER_INTERLEAVING_DEPTH, + 2_f64.powf(-f64::from(LOG_INV_RATE)), + IrsMode::Standard, + ); + let got = johnson_list_size(f64::from(LOG_INV_RATE)); + let expected = config.list_size(); + assert!( + (got - expected).abs() / expected < TIGHT_EPS, + "regime helper ({got}) vs Config::list_size ({expected})", + ); + } + + /// `1 − δ` in unique-decoding mode: midpoint of 1 and ρ. + #[test] + fn one_minus_distance_log2_unique() { + let log_inv_rate = 2.0; + let got = DecodingRegimeParams::Unique.one_minus_distance_log2(log_inv_rate); + let rho = 2_f64.powf(-log_inv_rate); + let expected = f64::midpoint(1.0, rho).log2(); + assert_close(got, expected); + } + + /// `1 − δ` in Johnson regime: `√ρ + η`. + #[test] + fn one_minus_distance_log2_johnson() { + let log_inv_rate = 2.0; + let eta = 0.1; + let got = johnson(eta).one_minus_distance_log2(log_inv_rate); + let rho = 2_f64.powf(-log_inv_rate); + let expected = (rho.sqrt() + eta).log2(); + assert_close(got, expected); + } + + /// MCA fixture — `message_length = 16 = 2^4` and `log_inv_rate = 2` give + /// exact `log2(k) = 4`. `field_bits = 64.0` for Field64. + const MCA_MESSAGE_LENGTH: usize = 16; + const MCA_LOG_INV_RATE: f64 = 2.0; + const MCA_FIELD_BITS: f64 = 64.0; + + /// MCA error, unique-decoding branch: `log k + log_inv_rate − field_bits`. + #[test] + fn eps_mca_log2_unique_decoding_formula() { + let got = DecodingRegimeParams::Unique.eps_mca_log2( + MCA_LOG_INV_RATE, + MCA_MESSAGE_LENGTH, + MCA_FIELD_BITS, + ); + let expected = (MCA_MESSAGE_LENGTH as f64).log2() + MCA_LOG_INV_RATE - MCA_FIELD_BITS; + assert_close(got, expected); + } + + /// MCA error, Johnson branch: `7·log₂10 + 3.5·log_inv_rate + 2·log k − field_bits`. + #[test] + fn eps_mca_log2_johnson_formula() { + // `η = 0.1` stays within the debug assertion's slack range. + const JOHNSON_SLACK: f64 = 0.1; + + let got = johnson(JOHNSON_SLACK).eps_mca_log2( + MCA_LOG_INV_RATE, + MCA_MESSAGE_LENGTH, + MCA_FIELD_BITS, + ); + let expected = + 7.0 * LOG2_10 + 3.5 * MCA_LOG_INV_RATE + 2.0 * (MCA_MESSAGE_LENGTH as f64).log2() + - MCA_FIELD_BITS; + assert_close(got, expected); + } + + /// `from_policy(Unique, _)` ignores rate; `from_policy(Johnson, rate)` + /// produces the same materialization as `johnson_canonical(rate)`. + #[test] + fn from_policy_matches_canonical() { + assert_eq!( + DecodingRegimeParams::from_policy(DecodingRegime::Unique, 0.25), + DecodingRegimeParams::Unique, + ); + assert_eq!( + DecodingRegimeParams::from_policy(DecodingRegime::Johnson, 0.25), + DecodingRegimeParams::johnson_canonical(0.25), + ); + } +} diff --git a/src/protocols/params/spec.rs b/src/protocols/params/spec.rs index a753f3f4..94f4fab3 100644 --- a/src/protocols/params/spec.rs +++ b/src/protocols/params/spec.rs @@ -1,6 +1,13 @@ -use core::{marker::PhantomData, num::NonZeroU32}; +use core::{ + fmt::{self, Display, Formatter}, + marker::PhantomData, + num::NonZeroU32, + ops::Deref, + str::FromStr, +}; use ordered_float::OrderedFloat; +use serde::{Deserialize, Serialize}; use crate::{bits::Bits, engines::EngineId}; @@ -147,6 +154,10 @@ pub enum Mode { /// Constructed only via [`ZkSpec::try_new`], which performs the mode check /// once at the boundary. ZK-only solvers accept `ZkSpec` to make /// "ZK mode required" a compile-time precondition instead of a runtime assert. +/// +/// `Deref` is implemented so fields and inherent +/// methods are reachable directly (`zk_spec.target_security_bits`). For sites +/// that need to pass `&SecuritySpec` explicitly, use [`Self::as_inner`]. #[derive(Debug, Clone, Copy)] pub struct ZkSpec<'a>(&'a SecuritySpec); @@ -156,7 +167,20 @@ impl<'a> ZkSpec<'a> { matches!(spec.mode, Mode::ZeroKnowledge).then_some(Self(spec)) } - pub const fn get(self) -> &'a SecuritySpec { + /// Explicit unwrap — `&SecuritySpec` with the wrapper's lifetime. + /// + /// Prefer field access through `Deref` for reads; reach for `as_inner` + /// when you specifically need to hand `&SecuritySpec` to a function whose + /// signature is not in deref-coercion position (e.g. trait method + /// dispatch). + pub const fn as_inner(self) -> &'a SecuritySpec { + self.0 + } +} + +impl Deref for ZkSpec<'_> { + type Target = SecuritySpec; + fn deref(&self) -> &SecuritySpec { self.0 } } @@ -174,17 +198,51 @@ impl<'a> ZkSpec<'a> { /// radius. At high security targets or deep folding, `Unique` may exceed /// the grind cap on per-round PoW and [`super::derive::ProtocolConfig::derive`] /// will return `PowUngrindable`. Pick `Johnson` for those cases. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum DecodingRegime { Unique, Johnson, } -impl DecodingRegime { - /// Bridge to [`super::super::irs_commit::Config::new`]'s `unique_decoding` - /// parameter. - pub const fn unique_decoding(self) -> bool { - matches!(self, Self::Unique) +impl Display for DecodingRegime { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::Unique => f.write_str("Unique"), + Self::Johnson => f.write_str("Johnson"), + } + } +} + +impl FromStr for DecodingRegime { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "Unique" => Ok(Self::Unique), + "Johnson" => Ok(Self::Johnson), + _ => Err(format!( + "invalid decoding regime: {s}, options are: Unique, Johnson" + )), + } + } +} + +#[cfg(test)] +mod decoding_regime_tests { + use super::*; + + #[test] + fn from_str_round_trips_display() { + for r in [DecodingRegime::Unique, DecodingRegime::Johnson] { + assert_eq!(r.to_string().parse::().unwrap(), r); + } + } + + #[test] + fn from_str_rejects_unknown() { + assert!("johnson".parse::().is_err()); // case-sensitive + assert!("".parse::().is_err()); + assert!("Capacity".parse::().is_err()); } } diff --git a/src/protocols/params/sumcheck.rs b/src/protocols/params/sumcheck.rs index 9e563c3f..d0da57cf 100644 --- a/src/protocols/params/sumcheck.rs +++ b/src/protocols/params/sumcheck.rs @@ -8,7 +8,7 @@ use crate::{ irs_commit::Config as IrsConfig, params::{ bounds::usize_to_f64, - error::{DeriveError, PowResultExt, PowSlot}, + error::{DeriveError, Pow, PowResultExt}, protocol_config::MaskOracleInfo, spec::{RoundContext, SecuritySpec}, }, @@ -18,13 +18,13 @@ use crate::{ }; /// `mask_oracle` is `Some` iff ZK; only C_zk's list size + ℓ_zk are read here. -/// `slot` is the [`PowSlot`] that labels grinding failures (basecase or per-round). +/// `pow` is the [`Pow`] that labels grinding failures (basecase or per-round). pub fn solve( spec: &SecuritySpec, ctx: &RoundContext, source_irs: &IrsConfig, mask_oracle: Option, - slot: PowSlot, + pow: Pow, ) -> Result, DeriveError> { let num_rounds = num_sumcheck_rounds(ctx); let round_pow = PowConfig::grind_to( @@ -32,7 +32,7 @@ pub fn solve( analytic_error_bits(source_irs, mask_oracle), spec.hash_id, ) - .at_slot(slot)?; + .at(pow)?; let mode = match mask_oracle { None => sumcheck::SumcheckMode::Standard, Some(_) => sumcheck::SumcheckMode::ZeroKnowledge { @@ -92,7 +92,6 @@ mod tests { use super::*; use crate::protocols::params::{ - error::RoundSlot, irs_commit as irs_solver, spec::{ListSize, MaskCodeMessageLen, Mode, OodSampleBudget}, test_utils::{ @@ -136,10 +135,7 @@ mod tests { &ctx, &source_irs, mask_oracle, - PowSlot::Round { - index: 0, - kind: RoundSlot::Sumcheck, - }, + Pow::RoundSumcheck { index: 0 }, ) .unwrap(); match config.mode { @@ -219,7 +215,7 @@ mod tests { ) { let source_irs = build_source_irs(&spec, &ctx); let mask_oracle = build_minimal_mask_oracle(&spec); - let config = solve(&spec, &ctx, &source_irs, mask_oracle, PowSlot::Round { index: 0, kind: RoundSlot::Sumcheck }).unwrap(); + let config = solve(&spec, &ctx, &source_irs, mask_oracle, Pow::RoundSumcheck { index: 0 }).unwrap(); prop_assert!(matches!(config.mode, sumcheck::SumcheckMode::Standard)); } @@ -233,7 +229,7 @@ mod tests { ) { let source_irs = build_source_irs(&spec, &ctx); let mask_oracle = build_minimal_mask_oracle(&spec); - let config = solve(&spec, &ctx, &source_irs, mask_oracle, PowSlot::Round { index: 0, kind: RoundSlot::Sumcheck }).unwrap(); + let config = solve(&spec, &ctx, &source_irs, mask_oracle, Pow::RoundSumcheck { index: 0 }).unwrap(); prop_assert_eq!(config.num_rounds, ctx.folding_factor as usize); } @@ -263,7 +259,7 @@ mod tests { let source_irs = build_source_irs(&spec, &ctx); let mask_oracle = build_minimal_mask_oracle(&spec); let error = analytic_error_bits(&source_irs, mask_oracle); - let config = solve(&spec, &ctx, &source_irs, mask_oracle, PowSlot::Round { index: 0, kind: RoundSlot::Sumcheck }).unwrap(); + let config = solve(&spec, &ctx, &source_irs, mask_oracle, Pow::RoundSumcheck { index: 0 }).unwrap(); assert_pow_closes_gap(&spec, error, &config.round_pow); } } @@ -284,10 +280,7 @@ mod tests { &ctx, &source_irs, Some(info), - PowSlot::Round { - index: 0, - kind: RoundSlot::Sumcheck, - }, + Pow::RoundSumcheck { index: 0 }, ) .unwrap(); assert!(matches!( diff --git a/src/protocols/params/test_utils.rs b/src/protocols/params/test_utils.rs index 23f2fa34..4ba803be 100644 --- a/src/protocols/params/test_utils.rs +++ b/src/protocols/params/test_utils.rs @@ -14,13 +14,13 @@ use crate::{ protocols::{ irs_commit::Config as IrsConfig, params::{ - bounds::johnson_list_size, derive::compute_t_ood, irs_commit as irs_solver, protocol_config::MaskOracleInfo, - spec::{DecodingRegime, - ListSize, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, PowBudget, - RoundContext, SecuritySpec, ZkSpec, + regime::johnson_list_size, + spec::{ + DecodingRegime, ListSize, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, + PowBudget, RoundContext, SecuritySpec, ZkSpec, }, }, proof_of_work::Config as PowConfig, @@ -51,7 +51,7 @@ pub const EPS: f64 = 1e-9; pub fn deterministic_spec(mode: Mode) -> SecuritySpec { SecuritySpec { mode, - decoding_regime: DecodingRegime::Johnson, + decoding_regime: DecodingRegime::Johnson, target_security_bits: FIXTURE_TARGET_BITS, pow_budget: PowBudget::Forbidden, hash_id: hash::BLAKE3, @@ -72,7 +72,7 @@ pub fn arb_spec( ]; (target_range, pow_strategy).prop_map(move |(target, pow_budget)| SecuritySpec { mode, - decoding_regime: DecodingRegime::Johnson, + decoding_regime: DecodingRegime::Johnson, target_security_bits: target, pow_budget, hash_id: hash::BLAKE3, diff --git a/src/protocols/whir/config.rs b/src/protocols/whir/config.rs index 45b138ff..2a2c2f24 100644 --- a/src/protocols/whir/config.rs +++ b/src/protocols/whir/config.rs @@ -55,7 +55,7 @@ impl Config { #[allow(clippy::cast_possible_wrap)] let initial_committer = irs_commit::Config::new( protocol_security_level, - whir_parameters.unique_decoding, + whir_parameters.decoding_regime, whir_parameters.hash_id, whir_parameters.batch_size, size, @@ -64,7 +64,7 @@ impl Config { IrsMode::Standard, ); let initial_out_domain_samples = num_ood_samples( - whir_parameters.unique_decoding, + whir_parameters.decoding_regime, protocol_security_level, field_size_bits, initial_committer.list_size(), @@ -103,7 +103,7 @@ impl Config { #[allow(clippy::cast_possible_wrap)] let irs_committer = irs_commit::Config::new( protocol_security_level, - whir_parameters.unique_decoding, + whir_parameters.decoding_regime, whir_parameters.hash_id, 1, 1 << num_variables, @@ -112,7 +112,7 @@ impl Config { IrsMode::Standard, ); let round_out_domain_samples = num_ood_samples( - whir_parameters.unique_decoding, + whir_parameters.decoding_regime, protocol_security_level, field_size_bits, irs_committer.list_size(), @@ -512,8 +512,6 @@ impl Display for RoundConfig { #[cfg(test)] mod tests { - use ordered_float::OrderedFloat; - use super::*; use crate::{ algebra::{ @@ -522,7 +520,7 @@ mod tests { }, bits::Bits, hash, - protocols::matrix_commit, + protocols::{matrix_commit, params::regime::DecodingRegimeParams}, type_info::Typed, utils::test_serde, }; @@ -534,7 +532,7 @@ mod tests { pow_bits: 20, initial_folding_factor: 4, folding_factor: 4, - unique_decoding: false, + decoding_regime: crate::protocols::params::DecodingRegime::Johnson, starting_log_inv_rate: 1, batch_size: 1, hash_id: hash::BLAKE3, @@ -584,7 +582,7 @@ mod tests { codeword_length: 1 << (10 + 3 - 2), interleaving_depth: 1 << 2, matrix_commit: matrix_commit::Config::::new(0, 0), - johnson_slack: OrderedFloat::default(), + regime: DecodingRegimeParams::Unique, in_domain_samples: 5, deduplicate_in_domain: true, }, @@ -606,7 +604,7 @@ mod tests { codeword_length: 1 << (10 + 4 - 2), interleaving_depth: 1 << 2, matrix_commit: matrix_commit::Config::::new(0, 0), - johnson_slack: OrderedFloat::default(), + regime: DecodingRegimeParams::Unique, in_domain_samples: 6, deduplicate_in_domain: true, }, diff --git a/src/protocols/whir/mod.rs b/src/protocols/whir/mod.rs index d5828018..3f338e92 100644 --- a/src/protocols/whir/mod.rs +++ b/src/protocols/whir/mod.rs @@ -180,6 +180,7 @@ mod tests { }, hash, parameters::ProtocolParameters, + protocols::params::DecodingRegime, transcript::{codecs::Empty, DomainSeparator, ProverState, VerifierState}, utils::test_serde, }; @@ -223,7 +224,7 @@ mod tests { initial_folding_factor: usize, folding_factor: usize, num_points: usize, - unique_decoding: bool, + decoding_regime: DecodingRegime, pow_bits: usize, ) { // Number of coefficients in the multilinear polynomial (2^num_variables) @@ -235,7 +236,7 @@ mod tests { pow_bits, initial_folding_factor, folding_factor, - unique_decoding, + decoding_regime, starting_log_inv_rate: 1, batch_size: 1, hash_id: hash::SHA2, @@ -321,14 +322,14 @@ mod tests { let num_variables = folding_factor..=3 * folding_factor; for num_variable in num_variables { for num_points in [0, 1, 2] { - for unique_decoding in [true, false] { + for decoding_regime in [DecodingRegime::Unique, DecodingRegime::Johnson] { for pow_bits in [0, 5, 10] { eprintln!(); dbg!( folding_factor, num_variable, num_points, - unique_decoding, + decoding_regime, pow_bits ); @@ -337,7 +338,7 @@ mod tests { folding_factor, folding_factor, num_points, - unique_decoding, + decoding_regime, pow_bits, ); } @@ -349,7 +350,7 @@ mod tests { #[test] fn test_fail() { - make_whir_things(3, 2, 2, 0, false, 0); + make_whir_things(3, 2, 2, 0, DecodingRegime::Johnson, 0); } #[test] @@ -379,7 +380,7 @@ mod tests { initial_folding_factor, folding_factor, num_points, - false, + DecodingRegime::Johnson, 5, ); } @@ -398,7 +399,7 @@ mod tests { folding_factor: usize, num_points_per_poly: usize, num_vectors: usize, - unique_decoding: bool, + decoding_regime: DecodingRegime, pow_bits: usize, ) { let num_coeffs = 1 << num_variables; @@ -408,7 +409,7 @@ mod tests { pow_bits, initial_folding_factor, folding_factor, - unique_decoding, + decoding_regime, starting_log_inv_rate: 1, batch_size: 1, hash_id: hash::SHA2, @@ -529,7 +530,7 @@ mod tests { folding_factor, num_points_per_poly, num_polys, - false, + DecodingRegime::Johnson, 0, // pow_bits ); } @@ -548,7 +549,8 @@ mod tests { 2, // folding_factor 2, // num_points_per_poly 1, // num_polynomials (single!) - false, 0, + DecodingRegime::Johnson, + 0, ); } @@ -576,7 +578,7 @@ mod tests { pow_bits: 0, initial_folding_factor, folding_factor, - unique_decoding: false, + decoding_regime: DecodingRegime::Johnson, starting_log_inv_rate: 1, batch_size: 1, hash_id: hash::SHA2, @@ -672,7 +674,7 @@ mod tests { num_points_per_poly: usize, num_witnesses: usize, batch_size: usize, - unique_decoding: bool, + decoding_regime: DecodingRegime, pow_bits: usize, ) { let num_coeffs = 1 << num_variables; @@ -682,7 +684,7 @@ mod tests { pow_bits, initial_folding_factor, folding_factor, - unique_decoding, + decoding_regime, starting_log_inv_rate: 1, batch_size, // KEY: batch_size > 1 hash_id: hash::SHA2, @@ -788,7 +790,7 @@ mod tests { 1, // num_points_per_poly num_witness, batch_size, - false, + DecodingRegime::Johnson, 0, // pow_bits ); } @@ -803,7 +805,7 @@ mod tests { initial_folding_factor: usize, folding_factor: usize, num_points: usize, - unique_decoding: bool, + decoding_regime: DecodingRegime, pow_bits: usize, ) { eprintln!("\n---------------------"); @@ -813,7 +815,7 @@ mod tests { eprintln!(" initial_folding : {initial_folding_factor}"); eprintln!(" folding_factor : {folding_factor}"); eprintln!(" num_points : {num_points:?}"); - eprintln!(" unique_decoding : {unique_decoding:?}"); + eprintln!(" decoding_regime : {decoding_regime:?}"); eprintln!(" pow_bits : {pow_bits}"); // Number of coefficients in the multilinear polynomial (2^num_variables) @@ -825,7 +827,7 @@ mod tests { pow_bits, initial_folding_factor, folding_factor, - unique_decoding, + decoding_regime, starting_log_inv_rate: 1, batch_size, hash_id: hash::SHA2, @@ -910,7 +912,7 @@ mod tests { #[test] fn test_batched_whir() { let folding_factors = [1, 4]; - let unique_decoding_options = [false, true]; + let decoding_regime_options = [DecodingRegime::Johnson, DecodingRegime::Unique]; let num_points = [0, 2]; let pow_bits = [0, 10]; @@ -918,7 +920,7 @@ mod tests { let num_variables = (2 * folding_factor)..=3 * folding_factor; for num_variable in num_variables { for num_points in num_points { - for unique_decoding in unique_decoding_options { + for decoding_regime in decoding_regime_options { for pow_bits in pow_bits { for batch_size in 1..=4 { make_batched_whir_things( @@ -927,7 +929,7 @@ mod tests { folding_factor, folding_factor, num_points, - unique_decoding, + decoding_regime, pow_bits, ); } diff --git a/src/protocols/whir_zk/mod.rs b/src/protocols/whir_zk/mod.rs index ce49b13f..18dacc42 100644 --- a/src/protocols/whir_zk/mod.rs +++ b/src/protocols/whir_zk/mod.rs @@ -43,14 +43,14 @@ impl BlindingSizePolicy { .saturating_sub(main_whir_params.pow_bits); #[allow(clippy::cast_possible_wrap)] let q_delta_1 = irs_commit::num_in_domain_queries( - main_whir_params.unique_decoding, + main_whir_params.decoding_regime, protocol_security_level_main as f64, 0.5_f64.powi(main_whir_params.starting_log_inv_rate as i32), ) .get(); #[allow(clippy::cast_possible_wrap)] let q_delta_2 = irs_commit::num_in_domain_queries( - main_whir_params.unique_decoding, + main_whir_params.decoding_regime, main_whir_params.security_level as f64, 0.5_f64.powi(main_whir_params.starting_log_inv_rate as i32), ) @@ -257,6 +257,7 @@ mod tests { }, hash, parameters::ProtocolParameters, + protocols::params::DecodingRegime, transcript::{codecs::Empty, DomainSeparator, ProverState, VerifierState}, }; @@ -277,7 +278,7 @@ mod tests { fn make_test_config(num_polynomials: usize) -> Config { let whir_params = ProtocolParameters { - unique_decoding: false, + decoding_regime: DecodingRegime::Johnson, security_level: 16, pow_bits: 0, initial_folding_factor: 2, From 4ccb1e37d11191aa17ff4e42d1c2637ae44b4af1 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Sat, 23 May 2026 04:23:36 +0530 Subject: [PATCH 21/31] refactor : clippy --- .../protocols/params/basecase.txt | 10 + .../protocols/params/code_switch.txt | 2 + src/protocols/params/basecase.rs | 22 +-- src/protocols/params/bounds.rs | 6 +- src/protocols/params/code_switch.rs | 84 ++++---- src/protocols/params/derive.rs | 187 ++++++++---------- src/protocols/params/error.rs | 36 +++- src/protocols/params/mask_proximity.rs | 14 +- src/protocols/params/protocol_config.rs | 22 +-- src/protocols/params/regime.rs | 9 +- src/protocols/params/spec.rs | 1 - src/protocols/params/sumcheck.rs | 81 +++++--- src/protocols/params/test_utils.rs | 39 ++-- 13 files changed, 276 insertions(+), 237 deletions(-) create mode 100644 proptest-regressions/protocols/params/basecase.txt diff --git a/proptest-regressions/protocols/params/basecase.txt b/proptest-regressions/protocols/params/basecase.txt new file mode 100644 index 00000000..c3b26074 --- /dev/null +++ b/proptest-regressions/protocols/params/basecase.txt @@ -0,0 +1,10 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc b6e8ae0b3e6a9769901e0e0e489da34965bf0a8df7dd049aef66e0541bf10baf # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, decoding_regime: Johnson, target_security_bits: 30, pow_budget: Forbidden, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, (log_size, log_inv_rate) = (1, 1) +cc a2f771fc5031440200810b95ea2d347da895f8eb2e1a87f53fd69ad224287e84 # shrinks to spec = SecuritySpec { mode: Standard, decoding_regime: Johnson, target_security_bits: 30, pow_budget: Forbidden, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, (log_size, log_inv_rate) = (2, 2) +cc f66c89bc700c79bca5f4b7234f1345129962c78e5a2036a6430564f615f19b30 # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, decoding_regime: Johnson, target_security_bits: 30, pow_budget: Forbidden, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, (log_size, log_inv_rate) = (1, 1) +cc 312fec8a96f6f55f5d3c0346bdb85690f23150aadf88888d453041e99b05d414 # shrinks to spec = SecuritySpec { mode: Standard, decoding_regime: Johnson, target_security_bits: 39, pow_budget: Forbidden, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, (log_size, log_inv_rate) = (1, 1) diff --git a/proptest-regressions/protocols/params/code_switch.txt b/proptest-regressions/protocols/params/code_switch.txt index 13a64288..de6c40cf 100644 --- a/proptest-regressions/protocols/params/code_switch.txt +++ b/proptest-regressions/protocols/params/code_switch.txt @@ -9,3 +9,5 @@ cc b42c982074a04c7110df07cf00f45156607be547e176b1ddd5f9d994ad491ddb # shrinks to cc eaf09a2b6bdffa86026264679f008326498ca800260dd2f17d4370df9fb3f801 # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 80, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 4, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, log_inv_rate = 1, folding_factor = 3, num_vars = 4 cc 3887a5fa698c99109e8262e843dbd24ea94b9c9d420791e4520b5c9211a3eca0 # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 100, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 4, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, (log_inv_rate, folding_factor, num_vars) = (3, 2, 7) cc b3e128084f721e6f43e263e05acf2e2de6fcd05dccf3811f063eeb0b63d78f8e # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 47, max_pow_bits: Some(15), hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, (log_inv_rate, folding_factor, num_vars) = (3, 2, 4) +cc b71da9002ceac9e4a74af097a7b087557a5b916fe8da47e39c4682375d749f88 # shrinks to spec = SecuritySpec { mode: Standard, decoding_regime: Johnson, target_security_bits: 50, pow_budget: Forbidden, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, (log_inv_rate, folding_factor, num_vars) = (1, 2, 4) +cc 1981509d857e56772dd4a79f8692619e968891aa3d84576ea1857f6d9a484a2d # shrinks to spec = SecuritySpec { mode: Standard, decoding_regime: Johnson, target_security_bits: 50, pow_budget: Forbidden, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, (log_inv_rate, folding_factor, num_vars) = (1, 2, 4) diff --git a/src/protocols/params/basecase.rs b/src/protocols/params/basecase.rs index 97a65f5d..b53a5b17 100644 --- a/src/protocols/params/basecase.rs +++ b/src/protocols/params/basecase.rs @@ -9,7 +9,7 @@ use crate::{ basecase::{self, Config as BasecaseConfig}, irs_commit::Config as IrsConfig, params::{ - error::{DeriveError, Pow, PowResultExt}, + error::{grind_to_at, DeriveError, Pow}, irs_commit as irs_solver, spec::{Mode as SpecMode, OodSampleBudget, RoundContext, SecuritySpec}, sumcheck as sumcheck_solver, @@ -35,13 +35,11 @@ pub fn solve( }; let commit = irs_solver::solve(spec, &ctx, OodSampleBudget::ZERO); - let target_bits = Bits::new(f64::from(spec.target_security_bits)); - let sumcheck_pow = PowConfig::grind_to( - target_bits, + let sumcheck_pow = grind_to_at( + spec, sumcheck_solver::analytic_error_bits(&commit, None), - spec.hash_id, - ) - .at(Pow::BasecaseSumcheck)?; + Pow::BasecaseSumcheck, + )?; let sumcheck = SumcheckConfig::new( vector_size, sumcheck_pow, @@ -56,10 +54,11 @@ pub fn solve( let pow = match mode { basecase::BasecaseMode::Standard => PowConfig::none(), - basecase::BasecaseMode::ZeroKnowledge => { - PowConfig::grind_to(target_bits, analytic_error_bits(&commit), spec.hash_id) - .at(Pow::BasecaseGammaCombination)? - } + basecase::BasecaseMode::ZeroKnowledge => grind_to_at( + spec, + analytic_error_bits(&commit), + Pow::BasecaseGammaCombination, + )?, }; Ok(BasecaseConfig::new(commit, sumcheck, mode, pow)) @@ -92,7 +91,6 @@ impl BasecaseConfig { } #[cfg(test)] -#[allow(clippy::float_cmp)] mod tests { use proptest::prelude::*; diff --git a/src/protocols/params/bounds.rs b/src/protocols/params/bounds.rs index 524b0955..b01bf1ba 100644 --- a/src/protocols/params/bounds.rs +++ b/src/protocols/params/bounds.rs @@ -8,9 +8,8 @@ pub(super) fn rate(log_inv_rate: f64) -> f64 { 2_f64.powf(-log_inv_rate) } -/// Lossy `usize → f64` for analytic-error formulas. Single allow-site for -/// `clippy::cast_precision_loss` so individual call sites can stay terse. -#[allow(clippy::cast_precision_loss)] +/// Lossy `usize → f64` for analytic-error formulas. Named so individual call +/// sites can stay terse and intent-tagged. pub(super) const fn usize_to_f64(x: usize) -> f64 { x as f64 } @@ -21,7 +20,6 @@ pub fn ood_per_sample_log2(message_length: usize, field_bits: f64) -> f64 { } #[cfg(test)] -#[allow(clippy::float_cmp)] mod tests { use super::*; use crate::protocols::params::test_utils::assert_close; diff --git a/src/protocols/params/code_switch.rs b/src/protocols/params/code_switch.rs index 7730518a..f5f657d9 100644 --- a/src/protocols/params/code_switch.rs +++ b/src/protocols/params/code_switch.rs @@ -14,44 +14,57 @@ use crate::{ irs_commit::Config as IrsConfig, params::{ bounds::usize_to_f64, - error::{DeriveError, Pow, PowResultExt}, + error::{grind_to_at, DeriveError, Pow}, protocol_config::MaskOracleInfo, spec::SecuritySpec, }, - proof_of_work::Config as PowConfig, }, }; -/// `mask_oracle.l_zk` must have been used to size C_zk (planner's job). -/// -/// PoW closes the Lemma 9.9 OOD gap to `spec.target_security_bits`. `t_ood ≥ 1` -/// is required: enforced by [`analytic_error_bits`] and -/// [`code_switch::Config::new`] (Construction 9.7 needs OOD queries). -pub fn solve( +/// Standard-mode code-switch builder. PoW closes the Lemma 9.9 OOD gap to +/// `spec.target_security_bits`; `t_ood ≥ 1` is required by Construction 9.7. +pub fn solve_standard( spec: &SecuritySpec, source: IrsConfig, target: IrsConfig>, t_ood: usize, - mask_oracle: Option, round_index: usize, ) -> Result, DeriveError> { - let mode = mask_oracle.map_or(code_switch::CodeSwitchMode::Standard, |info| { - let l_zk = info.l_zk.get(); - assert!( - l_zk >= source.mask_length() + t_ood, - "ℓ_zk ({l_zk}) < r + t_ood ({} + {}) — violates Theorem 9.6 witness sizing", - source.mask_length(), - t_ood, - ); - code_switch::CodeSwitchMode::ZeroKnowledge { - message_mask_length: NonZeroUsize::new(l_zk).expect("ℓ_zk > 0"), - } - }); + let analytic = analytic_error_bits(&source, &target, t_ood, None); + let pow = grind_to_at(spec, analytic, Pow::RoundCodeSwitch { index: round_index })?; + Ok(CodeSwitchConfig::new( + source, + target, + t_ood, + code_switch::CodeSwitchMode::Standard, + pow, + )) +} - let target_bits = Bits::new(f64::from(spec.target_security_bits)); - let analytic = analytic_error_bits(&source, &target, t_ood, mask_oracle); - let pow = PowConfig::grind_to(target_bits, analytic, spec.hash_id) - .at(Pow::RoundCodeSwitch { index: round_index })?; +/// ZK code-switch builder. `mask_oracle.l_zk` must have been used to size C_zk +/// (planner's job). PoW closes the Lemma 9.9 OOD gap; `t_ood ≥ 1` and +/// `ℓ_zk ≥ r + t_ood` (Theorem 9.6) are asserted here. +pub fn solve_zk( + spec: &SecuritySpec, + source: IrsConfig, + target: IrsConfig>, + t_ood: usize, + mask_oracle: MaskOracleInfo, + round_index: usize, +) -> Result, DeriveError> { + let l_zk = mask_oracle.l_zk.get(); + assert!( + l_zk >= source.mask_length() + t_ood, + "ℓ_zk ({l_zk}) < r + t_ood ({} + {}) — violates Theorem 9.6 witness sizing", + source.mask_length(), + t_ood, + ); + let mode = code_switch::CodeSwitchMode::ZeroKnowledge { + message_mask_length: NonZeroUsize::new(l_zk).expect("ℓ_zk > 0"), + }; + + let analytic = analytic_error_bits(&source, &target, t_ood, Some(mask_oracle)); + let pow = grind_to_at(spec, analytic, Pow::RoundCodeSwitch { index: round_index })?; Ok(CodeSwitchConfig::new(source, target, t_ood, mode, pow)) } @@ -102,7 +115,6 @@ pub const fn masks_required() -> usize { } #[cfg(test)] -#[allow(clippy::float_cmp)] mod tests { use proptest::prelude::*; @@ -273,7 +285,7 @@ mod tests { ) { let (source, target, t_ood) = build_round_io::(&spec, log_inv_rate, folding_factor, num_vars, None); - let config = solve(&spec, source, target, t_ood, None, 0).unwrap(); + let config = solve_standard(&spec, source, target, t_ood, 0).unwrap(); prop_assert!(matches!(config.mode, code_switch::CodeSwitchMode::Standard)); prop_assert!(config.out_domain_samples >= 1); } @@ -323,14 +335,13 @@ mod tests { let target_log_inv_rate = f64::from(log_inv_rate + folding_factor - 1); let target_list_size = johnson_list_size(target_log_inv_rate); let recomputed_t_ood = - compute_t_ood(&spec, &source, target_list_size, Some(c_zk.list_size()), 0) - .unwrap(); + compute_t_ood(&spec, &source, target_list_size, Some(c_zk.list_size()), t_ood); prop_assert_eq!(t_ood, recomputed_t_ood, "placeholder ⇒ final C_zk fixed-point"); let mask_oracle = MaskOracleInfo { c_zk_list_size: ListSize::new(c_zk.list_size()), l_zk, }; - let config = solve(&spec, source, target, t_ood, Some(mask_oracle), 0).unwrap(); + let config = solve_zk(&spec, source, target, t_ood, mask_oracle, 0).unwrap(); prop_assert_eq!(config.message_mask_length(), (r + t_ood).next_power_of_two()); } @@ -343,7 +354,7 @@ mod tests { let (source, target, t_ood) = build_round_io::(&spec, log_inv_rate, folding_factor, num_vars, None); let error = analytic_error_bits(&source, &target, t_ood, None); - let config = solve(&spec, source, target, t_ood, None, 0).unwrap(); + let config = solve_standard(&spec, source, target, t_ood, 0).unwrap(); assert_pow_closes_gap(&spec, error, &config.pow); } } @@ -385,9 +396,9 @@ mod tests { &target_ctx, OodSampleBudget::ZERO, ); - let t_ood = compute_t_ood(&spec, &source, target.list_size(), None, 0).unwrap(); + let t_ood = compute_t_ood(&spec, &source, target.list_size(), None, 0); - let config = solve(&spec, source, target, t_ood, None, 0).unwrap(); + let config = solve_standard(&spec, source, target, t_ood, 0).unwrap(); assert!(matches!(config.mode, code_switch::CodeSwitchMode::Standard)); } @@ -421,9 +432,8 @@ mod tests { &source, target.list_size(), Some(SMOKE_C_ZK_LIST_SIZE), - 0, - ) - .unwrap(); + t_ood, + ); if new_t_ood == t_ood { break; } @@ -436,7 +446,7 @@ mod tests { c_zk_list_size: ListSize::new(SMOKE_C_ZK_LIST_SIZE), l_zk: MaskCodeMessageLen::new((source.mask_length() + t_ood).next_power_of_two()), }; - let config = solve(&spec, source, target, t_ood, Some(mask_oracle), 0).unwrap(); + let config = solve_zk(&spec, source, target, t_ood, mask_oracle, 0).unwrap(); assert!(matches!( config.mode, code_switch::CodeSwitchMode::ZeroKnowledge { .. } diff --git a/src/protocols/params/derive.rs b/src/protocols/params/derive.rs index 5785e37b..cb1e6af2 100644 --- a/src/protocols/params/derive.rs +++ b/src/protocols/params/derive.rs @@ -26,6 +26,9 @@ use crate::{ }, }; +/// Paranoia guard on `solve_t_ood` — convergence proof on the function itself. +const T_OOD_MAX_ITER: usize = 32; + impl ProtocolConfig { /// In ZK each round owns its mask oracle; the `ℓ_zk ↔ c_zk ↔ t_ood` /// fixed-point runs independently per round. @@ -81,12 +84,6 @@ struct RoundLayout { basecase_log_inv_rate: u32, } -struct RoundData { - source: IrsConfig, - target: IrsConfig>, - t_ood: usize, -} - /// Stops when there's no room for both a valid source and a valid target IRS. fn round_layout(tuning: &TuningSpec) -> RoundLayout { assert!(tuning.vector_size.is_power_of_two()); @@ -103,7 +100,6 @@ fn round_layout(tuning: &TuningSpec) -> RoundLayout { if num_vars < source_folding + target_folding { break; } - #[allow(clippy::cast_possible_truncation)] shapes.push(RoundShape { round_index: round, source_vector_size: 1usize << num_vars, @@ -112,10 +108,7 @@ fn round_layout(tuning: &TuningSpec) -> RoundLayout { target_folding_factor: target_folding as u32, }); num_vars -= source_folding; - #[allow(clippy::cast_possible_truncation)] - { - log_inv_rate += (source_folding as u32).saturating_sub(1); - } + log_inv_rate += (source_folding as u32).saturating_sub(1); } RoundLayout { @@ -157,11 +150,23 @@ fn build_zk_round_config( // C_zk.list_size depends only on rate — no IRS build needed for it. let c_zk_list_size = johnson_list_size(f64::from(c_zk_log_inv_rate.get())); - let RoundData { - source, - target, - t_ood, - } = build_zk_round_data::(spec, shape, c_zk_list_size)?; + let src_ctx = round_context(shape); + let target_log_inv_rate = + f64::from(shape.source_log_inv_rate + shape.source_folding_factor.saturating_sub(1)); + let target_list_size = johnson_list_size(target_log_inv_rate); + + let (source, t_ood) = solve_t_ood::( + spec, + &src_ctx, + target_list_size, + Some(c_zk_list_size), + shape.round_index, + )?; + let target: IrsConfig> = irs_solver::solve( + spec, + &target_context(shape, &source), + OodSampleBudget::new(t_ood), + ); let l_zk = compute_l_zk(&source, t_ood); let c_zk: IrsConfig> = irs_solver::solve_mask_code( @@ -171,101 +176,64 @@ fn build_zk_round_config( c_zk_log_inv_rate, 2 * num_masks, ); + debug_assert!( + (c_zk.list_size() - c_zk_list_size).abs() < 1e-9 * c_zk_list_size.max(1.0), + "c_zk.list_size() {} drifted from rate-only planner estimate {} — \ + see `johnson_list_size` for the invariant", + c_zk.list_size(), + c_zk_list_size, + ); let mask_proximity = mask_proximity_solver::solve(spec, c_zk.clone(), num_masks, shape.round_index)?; let mask_oracle = MaskOracleConfig::new(c_zk, l_zk, mask_proximity); let info = mask_oracle.info(); - let sumcheck = sumcheck_solver::solve( + let sumcheck = sumcheck_solver::solve_zk( spec, &ctx, &source, - Some(info), + info, Pow::RoundSumcheck { index: shape.round_index, }, )?; let code_switch = - code_switch_solver::solve(spec, source, target, t_ood, Some(info), shape.round_index)?; + code_switch_solver::solve_zk(spec, source, target, t_ood, info, shape.round_index)?; Ok(RoundConfig::new( shape.round_index, sumcheck, code_switch, RoundMode::ZeroKnowledge { t_ood: OodSampleBudget::new(t_ood), - mask_oracle, + mask_oracle: Box::new(mask_oracle), }, )) } -/// Local `t_ood ↔ r` fixed-point. `r = source.mask_length()` is a step function -/// of `t_ood` (`next_pow2(ℓ + q + t_ood) − ℓ`); the loop re-iterates only when -/// `t_ood` pushes `r` into the next pow-of-2 bucket. -fn build_zk_round_data( +fn build_round_config( spec: &SecuritySpec, shape: &RoundShape, - c_zk_list_size: f64, -) -> Result, DeriveError> { - const LOCAL_MAX_ITER: usize = 16; - +) -> Result, DeriveError> { let src_ctx = round_context(shape); let target_log_inv_rate = f64::from(shape.source_log_inv_rate + shape.source_folding_factor.saturating_sub(1)); let target_list_size = johnson_list_size(target_log_inv_rate); - let mut t_ood = 0; - let mut source: IrsConfig = irs_solver::solve(spec, &src_ctx, OodSampleBudget::ZERO); - for _ in 0..LOCAL_MAX_ITER { - let new_t_ood = compute_t_ood( - spec, - &source, - target_list_size, - Some(c_zk_list_size), - shape.round_index, - )?; - if new_t_ood == t_ood { - let target: IrsConfig> = irs_solver::solve( - spec, - &target_context(shape, &source), - OodSampleBudget::new(t_ood), - ); - return Ok(RoundData { - source, - target, - t_ood, - }); - } - t_ood = new_t_ood; - source = irs_solver::solve(spec, &src_ctx, OodSampleBudget::new(t_ood)); - } - - Err(DeriveError::FixedPointDidNotConverge { - round_index: shape.round_index, - loop_kind: FixedPointLoop::ZkRound, - }) -} - -fn build_round_config( - spec: &SecuritySpec, - shape: &RoundShape, -) -> Result, DeriveError> { - let src_ctx = round_context(shape); - let source: IrsConfig = irs_solver::solve(spec, &src_ctx, OodSampleBudget::ZERO); + let (source, t_ood) = + solve_t_ood::(spec, &src_ctx, target_list_size, None, shape.round_index)?; let target: IrsConfig> = irs_solver::solve(spec, &target_context(shape, &source), OodSampleBudget::ZERO); - let t_ood = compute_t_ood(spec, &source, target.list_size(), None, shape.round_index)?; - let sumcheck = sumcheck_solver::solve( + let sumcheck = sumcheck_solver::solve_standard( spec, &src_ctx, &source, - None, Pow::RoundSumcheck { index: shape.round_index, }, )?; let code_switch = - code_switch_solver::solve(spec, source, target, t_ood, None, shape.round_index)?; + code_switch_solver::solve_standard(spec, source, target, t_ood, shape.round_index)?; Ok(RoundConfig::new( shape.round_index, sumcheck, @@ -283,52 +251,70 @@ pub(super) const fn compute_l_zk( MaskCodeMessageLen::new((source.mask_length() + t_ood).next_power_of_two()) } -/// Solves Lemma 9.9 term 1 for `t_ood`. In ZK, `degree = ℓ + r + t_ood` -/// couples back to `t_ood`, so iterate. +/// One application of the Lemma 9.9 OOD step. ZK: `degree = ℓ + ℓ_zk(t_ood)` +/// with `ℓ_zk = next_pow2(source.mask_length() + t_ood)`. Standard: +/// `degree = ℓ`. +/// +/// `Johnson` is forced even under `DecodingRegime::Unique` — Construction 9.7 +/// requires `t_ood ≥ 1` regardless of the decoding regime in `spec`. pub(super) fn compute_t_ood( spec: &SecuritySpec, source: &IrsConfig, target_list_size: f64, c_zk_list_size: Option, - round_index: usize, -) -> Result { - const MAX_ITER: usize = 32; - + t_ood: usize, +) -> usize { let security_target = f64::from(spec.protocol_security_target_bits()); let field_bits = M::Target::field_size_bits(); let combined_list_size = target_list_size * c_zk_list_size.unwrap_or(1.0); let message_length = source.message_length(); - // `Johnson` force-computes OOD samples even when `spec.decoding_regime` - // is `Unique`. `compute_t_ood` is called from the ZK fixed-point loop, - // which needs the count for sizing the mask; bypassing the early-return - // in `num_ood_samples` is intentional. - let solve_for_degree = |degree: usize| { - irs_commit::num_ood_samples( - DecodingRegime::Johnson, - security_target, - field_bits, - combined_list_size, - degree, - ) + let degree = if c_zk_list_size.is_some() { + let l_zk = (source.mask_length() + t_ood).next_power_of_two(); + message_length + l_zk + } else { + message_length }; - let mut t_ood = solve_for_degree(message_length); - if matches!(spec.mode, Mode::Standard) { - return Ok(t_ood); + irs_commit::num_ood_samples( + DecodingRegime::Johnson, + security_target, + field_bits, + combined_list_size, + degree, + ) +} + +/// Solves the per-round `t_ood` fixed-point and the source IRS together. +/// +/// Convergence: `Φ(t) = num_ood_samples(ℓ + next_pow2(in_domain + 2·t))` is +/// monotone non-decreasing on ℕ (`in_domain` depends only on the requested +/// rate; `next_pow2` and `num_ood_samples` are monotone) and bounded above, +/// so Kleene iteration from `t = 0` converges to the least fixed point in +/// finitely many steps. Standard mode (`c_zk_list_size = None`) has `Φ` +/// constant in `t`, so one application suffices. +pub(super) fn solve_t_ood( + spec: &SecuritySpec, + src_ctx: &RoundContext, + target_list_size: f64, + c_zk_list_size: Option, + round_index: usize, +) -> Result<(IrsConfig, usize), DeriveError> { + let mut source: IrsConfig = irs_solver::solve(spec, src_ctx, OodSampleBudget::ZERO); + + if c_zk_list_size.is_none() { + let t_ood = compute_t_ood(spec, &source, target_list_size, None, 0); + return Ok((source, t_ood)); } - let r = source.mask_length(); - for _ in 0..MAX_ITER { - // Polynomial degree = `ℓ + ℓ_zk` where `ℓ_zk = next_pow2(r + t_ood)` - // (Theorem 9.6 / Lemma 9.3). Using `r + t_ood` would under-count when - // not pow2. - let l_zk = (r + t_ood).next_power_of_two(); - let new_t_ood = solve_for_degree(message_length + l_zk); + let mut t_ood = 0; + for _ in 0..T_OOD_MAX_ITER { + let new_t_ood = compute_t_ood(spec, &source, target_list_size, c_zk_list_size, t_ood); if new_t_ood == t_ood { - return Ok(t_ood); + return Ok((source, t_ood)); } t_ood = new_t_ood; + source = irs_solver::solve(spec, src_ctx, OodSampleBudget::new(t_ood)); } Err(DeriveError::FixedPointDidNotConverge { round_index, @@ -337,7 +323,6 @@ pub(super) fn compute_t_ood( } #[cfg(test)] -#[allow(clippy::float_cmp)] mod tests { use proptest::prelude::*; @@ -642,7 +627,7 @@ mod tests { tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), ) .unwrap(); - assert_eq!( + assert_close( f64::from(plan.privacy_error_bits()), f64::from(PLAN_FIXTURE_TARGET_BITS), ); diff --git a/src/protocols/params/error.rs b/src/protocols/params/error.rs index ac68bc5a..c217036f 100644 --- a/src/protocols/params/error.rs +++ b/src/protocols/params/error.rs @@ -10,7 +10,13 @@ use std::fmt::{self, Display, Formatter}; use thiserror::Error; -use crate::{bits::Bits, protocols::proof_of_work::PowError}; +use crate::{ + bits::Bits, + protocols::{ + params::spec::SecuritySpec, + proof_of_work::{Config as PowConfig, PowError}, + }, +}; /// Identifies a single PoW grind in the derived protocol — basecase /// sub-protocol or a per-round sub-protocol at a specific round index. Used @@ -80,17 +86,14 @@ impl Display for ChainTarget { /// Which fixed-point loop failed to converge. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum FixedPointLoop { - /// `compute_t_ood`'s scalar iteration. + /// `derive::solve_t_ood` — combined `t_ood ↔ source` Kleene iteration. TOod, - /// `build_zk_round_data`'s outer `t_ood ↔ source.mask_length()` iteration. - ZkRound, } impl Display for FixedPointLoop { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self { Self::TOod => f.write_str("t_ood"), - Self::ZkRound => f.write_str("ZK per-round"), } } } @@ -151,3 +154,26 @@ impl PowResultExt for Result { self.map_err(|source| DeriveError::PowUngrindable { pow, source }) } } + +/// Grind `analytic → spec.target_security_bits`, then check the result against +/// `spec.pow_budget` — both failures attributed to `pow_kind` at the same site. +/// `ProtocolConfig::validate_pow_budget` remains as a defense-in-depth check +/// for hand-mutated plans. +pub(crate) fn grind_to_at( + spec: &SecuritySpec, + analytic: Bits, + pow_kind: Pow, +) -> Result { + let target = Bits::new(f64::from(spec.target_security_bits)); + let pow = PowConfig::grind_to(target, analytic, spec.hash_id).at(pow_kind)?; + let required = pow.difficulty(); + let max = Bits::new(f64::from(spec.pow_budget.bits())); + if required > max { + return Err(DeriveError::PowBudgetExceeded { + pow: pow_kind, + required, + max, + }); + } + Ok(pow) +} diff --git a/src/protocols/params/mask_proximity.rs b/src/protocols/params/mask_proximity.rs index b6481e5e..115121c1 100644 --- a/src/protocols/params/mask_proximity.rs +++ b/src/protocols/params/mask_proximity.rs @@ -11,10 +11,9 @@ use crate::{ mask_proximity::Config as MaskProximityConfig, params::{ bounds::usize_to_f64, - error::{DeriveError, Pow, PowResultExt}, + error::{grind_to_at, DeriveError, Pow}, spec::SecuritySpec, }, - proof_of_work::Config as PowConfig, }, }; @@ -26,10 +25,12 @@ pub fn solve( num_masks: usize, round_index: usize, ) -> Result, DeriveError> { - let target_bits = Bits::new(f64::from(spec.target_security_bits)); let analytic = analytic_error_bits(&c_zk, num_masks); - let pow = PowConfig::grind_to(target_bits, analytic, spec.hash_id) - .at(Pow::RoundMaskProximity { index: round_index })?; + let pow = grind_to_at( + spec, + analytic, + Pow::RoundMaskProximity { index: round_index }, + )?; Ok(MaskProximityConfig::new(c_zk, num_masks, pow)) } @@ -53,7 +54,6 @@ impl MaskProximityConfig { } #[cfg(test)] -#[allow(clippy::float_cmp)] mod tests { use proptest::prelude::*; @@ -106,7 +106,7 @@ mod tests { let c_zk = build_test_c_zk(&spec, 2, 1, 1); let bits = f64::from(analytic_error_bits(&c_zk, 0)); let field_bits = ::field_size_bits(); - assert_eq!(bits, field_bits.max(0.0)); + assert_close(bits, field_bits.max(0.0)); } proptest! { diff --git a/src/protocols/params/protocol_config.rs b/src/protocols/params/protocol_config.rs index 9d299fad..7022e311 100644 --- a/src/protocols/params/protocol_config.rs +++ b/src/protocols/params/protocol_config.rs @@ -275,10 +275,10 @@ impl RoundConfig { /// Convenience: borrow the round's mask oracle if this is a ZK round. /// Equivalent to pattern-matching on `mode()`. - pub const fn mask_oracle(&self) -> Option<&MaskOracleConfig> { + pub fn mask_oracle(&self) -> Option<&MaskOracleConfig> { match &self.mode { RoundMode::Standard => None, - RoundMode::ZeroKnowledge { mask_oracle, .. } => Some(mask_oracle), + RoundMode::ZeroKnowledge { mask_oracle, .. } => Some(mask_oracle.as_ref()), } } @@ -289,19 +289,13 @@ impl RoundConfig { } } -/// Standard vs. ZK round. ZK variant carries the full per-round mask oracle -/// — there is no longer a separate `MaskOracleInfo` slim view stored -/// alongside it (which would duplicate `mask_oracle.info()`). +/// Standard vs. ZK round. /// -/// Not `Copy`: `MaskOracleConfig` owns a `MaskProximityConfig` and an -/// `IrsConfig`, neither of which is `Copy`. -/// -/// `large_enum_variant` allowed: the ZK variant carries `MaskOracleConfig` -/// (~330B) while `Standard` is 0B, but a proof holds O(rounds) RoundModes -/// (single-digit count) so the absolute overhead is a few KB. Boxing the -/// payload would add per-access indirection without measurable savings. +/// The ZK payload is boxed so the enum stays small: `MaskOracleConfig` is +/// ~330 B while the `Standard` variant is 0 B, and proofs hold O(rounds) +/// `RoundMode`s. Accessors expose `&MaskOracleConfig` so call sites are +/// unaffected by the indirection. #[derive(Clone, Debug)] -#[allow(clippy::large_enum_variant)] pub enum RoundMode { Standard, ZeroKnowledge { @@ -309,7 +303,7 @@ pub enum RoundMode { t_ood: OodSampleBudget, /// Per-round mask oracle: C_zk codeword (sized for `2·(k+1)` /// columns) + ℓ_zk + mask-proximity check for `k+1` masks. - mask_oracle: MaskOracleConfig, + mask_oracle: Box>, }, } diff --git a/src/protocols/params/regime.rs b/src/protocols/params/regime.rs index f985f061..11409bee 100644 --- a/src/protocols/params/regime.rs +++ b/src/protocols/params/regime.rs @@ -92,12 +92,17 @@ impl DecodingRegimeParams { /// Johnson list size at the canonical `η = √ρ / 20` slack, as a function of /// `log_inv_rate` only. Used by planners that need a list-size estimate before /// a target config exists. +/// +/// Equals `IrsConfig::list_size()` exactly when the IRS has pow2 `vector_size`, +/// `interleaving_depth = 1`, integer `log_inv_rate`, and lives on a 2-adic NTT +/// field — the conditions `solve_mask_code` enforces for C_zk. Outside that +/// regime, `ntt::next_order` may shift the effective rate and the helper +/// underestimates. pub fn johnson_list_size(log_inv_rate: f64) -> f64 { DecodingRegimeParams::johnson_canonical(rate(log_inv_rate)).list_size(log_inv_rate) } #[cfg(test)] -#[allow(clippy::float_cmp)] mod tests { use super::*; use crate::protocols::params::test_utils::assert_close; @@ -124,7 +129,7 @@ mod tests { /// Unique-decoding regime gives `|Λ| = 1`, i.e. log = 0. #[test] fn list_size_log2_unique_decoding_is_zero() { - assert_eq!(DecodingRegimeParams::Unique.list_size_log2(2.0), 0.0); + assert_close(DecodingRegimeParams::Unique.list_size_log2(2.0), 0.0); } /// `η = √ρ / 20` substituted into `|Λ| = 1/(2η√ρ)` simplifies to `10/ρ`. diff --git a/src/protocols/params/spec.rs b/src/protocols/params/spec.rs index 94f4fab3..67fc8f51 100644 --- a/src/protocols/params/spec.rs +++ b/src/protocols/params/spec.rs @@ -285,7 +285,6 @@ impl ListSize { } #[cfg(test)] -#[allow(clippy::float_cmp)] mod tests { use super::*; use crate::hash; diff --git a/src/protocols/params/sumcheck.rs b/src/protocols/params/sumcheck.rs index d0da57cf..81b5b821 100644 --- a/src/protocols/params/sumcheck.rs +++ b/src/protocols/params/sumcheck.rs @@ -8,42 +8,52 @@ use crate::{ irs_commit::Config as IrsConfig, params::{ bounds::usize_to_f64, - error::{DeriveError, Pow, PowResultExt}, + error::{grind_to_at, DeriveError, Pow}, protocol_config::MaskOracleInfo, spec::{RoundContext, SecuritySpec}, }, - proof_of_work::Config as PowConfig, sumcheck::{self, Config as SumcheckConfig, SumcheckMaskLen}, }, }; -/// `mask_oracle` is `Some` iff ZK; only C_zk's list size + ℓ_zk are read here. -/// `pow` is the [`Pow`] that labels grinding failures (basecase or per-round). -pub fn solve( +/// Standard-mode sumcheck builder. `pow` labels grinding failures +/// (basecase or per-round). +pub fn solve_standard( spec: &SecuritySpec, ctx: &RoundContext, source_irs: &IrsConfig, - mask_oracle: Option, pow: Pow, ) -> Result, DeriveError> { - let num_rounds = num_sumcheck_rounds(ctx); - let round_pow = PowConfig::grind_to( - Bits::new(f64::from(spec.target_security_bits)), - analytic_error_bits(source_irs, mask_oracle), - spec.hash_id, - ) - .at(pow)?; - let mode = match mask_oracle { - None => sumcheck::SumcheckMode::Standard, - Some(_) => sumcheck::SumcheckMode::ZeroKnowledge { - mask_length: zk_mask_length(), - }, - }; + let round_pow = grind_to_at(spec, analytic_error_bits(source_irs, None), pow)?; Ok(SumcheckConfig::new( ctx.vector_size, round_pow, - num_rounds, - mode, + num_sumcheck_rounds(ctx), + sumcheck::SumcheckMode::Standard, + )) +} + +/// ZK sumcheck builder. `mask_oracle` carries C_zk's list size + ℓ_zk; only +/// those two values are read here. `pow` labels grinding failures. +pub fn solve_zk( + spec: &SecuritySpec, + ctx: &RoundContext, + source_irs: &IrsConfig, + mask_oracle: MaskOracleInfo, + pow: Pow, +) -> Result, DeriveError> { + let round_pow = grind_to_at( + spec, + analytic_error_bits(source_irs, Some(mask_oracle)), + pow, + )?; + Ok(SumcheckConfig::new( + ctx.vector_size, + round_pow, + num_sumcheck_rounds(ctx), + sumcheck::SumcheckMode::ZeroKnowledge { + mask_length: zk_mask_length(), + }, )) } @@ -86,7 +96,6 @@ const fn zk_mask_length() -> SumcheckMaskLen { } #[cfg(test)] -#[allow(clippy::float_cmp)] mod tests { use proptest::prelude::*; @@ -129,8 +138,9 @@ mod tests { let spec = deterministic_spec(Mode::ZeroKnowledge); let ctx = fixture_ctx(); let source_irs = build_source_irs(&spec, &ctx); - let mask_oracle = build_minimal_mask_oracle(&spec); - let config = solve( + let mask_oracle = + build_minimal_mask_oracle(&spec).expect("ZK spec must produce a mask oracle"); + let config = solve_zk( &spec, &ctx, &source_irs, @@ -204,7 +214,7 @@ mod tests { l_zk: MaskCodeMessageLen::new(1 << OVERSIZED_LOG_L_ZK), }; let bits = f64::from(analytic_error_bits::(&irs, Some(huge))); - assert_eq!(bits, 0.0); + assert_close(bits, 0.0); } proptest! { @@ -214,8 +224,8 @@ mod tests { ctx in arb_round_ctx(), ) { let source_irs = build_source_irs(&spec, &ctx); - let mask_oracle = build_minimal_mask_oracle(&spec); - let config = solve(&spec, &ctx, &source_irs, mask_oracle, Pow::RoundSumcheck { index: 0 }).unwrap(); + let pow = Pow::RoundSumcheck { index: 0 }; + let config = solve_standard(&spec, &ctx, &source_irs, pow).unwrap(); prop_assert!(matches!(config.mode, sumcheck::SumcheckMode::Standard)); } @@ -228,8 +238,11 @@ mod tests { ctx in arb_round_ctx(), ) { let source_irs = build_source_irs(&spec, &ctx); - let mask_oracle = build_minimal_mask_oracle(&spec); - let config = solve(&spec, &ctx, &source_irs, mask_oracle, Pow::RoundSumcheck { index: 0 }).unwrap(); + let pow = Pow::RoundSumcheck { index: 0 }; + let config = build_minimal_mask_oracle(&spec).map_or_else( + || solve_standard(&spec, &ctx, &source_irs, pow).unwrap(), + |info| solve_zk(&spec, &ctx, &source_irs, info, pow).unwrap(), + ); prop_assert_eq!(config.num_rounds, ctx.folding_factor as usize); } @@ -259,7 +272,11 @@ mod tests { let source_irs = build_source_irs(&spec, &ctx); let mask_oracle = build_minimal_mask_oracle(&spec); let error = analytic_error_bits(&source_irs, mask_oracle); - let config = solve(&spec, &ctx, &source_irs, mask_oracle, Pow::RoundSumcheck { index: 0 }).unwrap(); + let pow = Pow::RoundSumcheck { index: 0 }; + let config = mask_oracle.map_or_else( + || solve_standard(&spec, &ctx, &source_irs, pow).unwrap(), + |info| solve_zk(&spec, &ctx, &source_irs, info, pow).unwrap(), + ); assert_pow_closes_gap(&spec, error, &config.round_pow); } } @@ -275,11 +292,11 @@ mod tests { c_zk_list_size: ListSize::new(FIXTURE_C_ZK_LIST_SIZE), l_zk: MaskCodeMessageLen::new(FIXTURE_L_ZK), }; - let config = solve( + let config = solve_zk( &spec, &ctx, &source_irs, - Some(info), + info, Pow::RoundSumcheck { index: 0 }, ) .unwrap(); diff --git a/src/protocols/params/test_utils.rs b/src/protocols/params/test_utils.rs index 4ba803be..55f1b944 100644 --- a/src/protocols/params/test_utils.rs +++ b/src/protocols/params/test_utils.rs @@ -14,7 +14,7 @@ use crate::{ protocols::{ irs_commit::Config as IrsConfig, params::{ - derive::compute_t_ood, + derive::solve_t_ood, irs_commit as irs_solver, protocol_config::MaskOracleInfo, regime::johnson_list_size, @@ -48,33 +48,30 @@ pub const FIXTURE_TARGET_BITS: u32 = 80; /// used in the analytic-error formulas. pub const EPS: f64 = 1e-9; +/// Matches `proof_of_work::MAX_DIFFICULTY` so per-slot budget checks in +/// `grind_to_at` never bite. Tests exercising budget enforcement build their +/// own specs. +pub const FIXTURE_POW_BUDGET_BITS: u32 = 60; + pub fn deterministic_spec(mode: Mode) -> SecuritySpec { SecuritySpec { mode, decoding_regime: DecodingRegime::Johnson, target_security_bits: FIXTURE_TARGET_BITS, - pow_budget: PowBudget::Forbidden, + pow_budget: PowBudget::per_slot(FIXTURE_POW_BUDGET_BITS), hash_id: hash::BLAKE3, } } -/// `pow_budget` ∈ `{Forbidden, PerSlot{1..=16}}`; bounded so the analytic -/// floor stays positive for the lowest test targets and the PoW gap stays -/// under the 60-bit cap. `PerSlot { bits: 0 }` is unrepresentable, so we -/// generate `Forbidden` for the "no grinding" case directly. pub fn arb_spec( mode: Mode, target_range: RangeInclusive, ) -> impl Strategy { - let pow_strategy = prop_oneof![ - Just(PowBudget::Forbidden), - (1u32..=16).prop_map(PowBudget::per_slot), - ]; - (target_range, pow_strategy).prop_map(move |(target, pow_budget)| SecuritySpec { + target_range.prop_map(move |target| SecuritySpec { mode, decoding_regime: DecodingRegime::Johnson, target_security_bits: target, - pow_budget, + pow_budget: PowBudget::per_slot(FIXTURE_POW_BUDGET_BITS), hash_id: hash::BLAKE3, }) } @@ -157,10 +154,10 @@ pub fn build_test_c_zk( /// per-round shape that `code_switch::solve` expects. /// /// `t_ood` is solved against the rate-only `johnson_list_size(target_log_inv_rate)`, -/// mirroring `derive::build_zk_round_data`. Using `target.list_size()` here -/// instead would couple `t_ood` to the target's effective rate (which itself -/// depends on `t_ood` via the mask), producing a non-monotone oscillation -/// once the mask is tight (Lemma 9.5 part ii) rather than pow2-padded. +/// mirroring `derive::solve_t_ood`. Using `target.list_size()` here instead +/// would couple `t_ood` to the target's effective rate (which itself depends +/// on `t_ood` via the mask), producing a non-monotone oscillation once the +/// mask is tight (Lemma 9.5 part ii) rather than pow2-padded. pub fn build_round_io( spec: &SecuritySpec, log_inv_rate: u32, @@ -173,18 +170,16 @@ pub fn build_round_io( log_inv_rate, folding_factor, }; - let source = irs_solver::solve(spec, &source_ctx, OodSampleBudget::ZERO); - let target_log_inv_rate = log_inv_rate + folding_factor - 1; + let target_list_size = johnson_list_size(f64::from(target_log_inv_rate)); + let (source, t_ood) = solve_t_ood::(spec, &source_ctx, target_list_size, c_zk_list_size, 0) + .expect("solve_t_ood diverged in test fixture"); + let target_ctx = RoundContext { vector_size: source.message_length(), log_inv_rate: target_log_inv_rate, folding_factor, }; - - let target_list_size = johnson_list_size(f64::from(target_log_inv_rate)); - let t_ood = compute_t_ood(spec, &source, target_list_size, c_zk_list_size, 0) - .expect("compute_t_ood diverged in test fixture"); let target = irs_solver::solve(spec, &target_ctx, OodSampleBudget::new(t_ood)); (source, target, t_ood) } From 998571242bacadcdd78672bf457a6d6c1b58676e Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Sat, 23 May 2026 05:55:01 +0530 Subject: [PATCH 22/31] feat : added capacity bound --- src/protocols/irs_commit.rs | 3 +- src/protocols/params/basecase.rs | 8 +- src/protocols/params/code_switch.rs | 51 +++----- src/protocols/params/derive.rs | 176 ++++++++++++++++++++++----- src/protocols/params/regime.rs | 179 ++++++++++++++++++++++------ src/protocols/params/spec.rs | 25 ++-- src/protocols/params/sumcheck.rs | 12 +- src/protocols/params/test_utils.rs | 37 ++++-- 8 files changed, 360 insertions(+), 131 deletions(-) diff --git a/src/protocols/irs_commit.rs b/src/protocols/irs_commit.rs index 4f1d2540..2c137fde 100644 --- a/src/protocols/irs_commit.rs +++ b/src/protocols/irs_commit.rs @@ -222,7 +222,8 @@ impl Config { /// Compute a list size bound. pub fn list_size(&self) -> f64 { - self.regime.list_size(self.log_inv_rate()) + let log_degree = (self.masked_message_length() as f64).log2(); + self.regime.list_size(log_degree, self.log_inv_rate()) } /// Round-by-round soundness of the in-domain queries in bits. diff --git a/src/protocols/params/basecase.rs b/src/protocols/params/basecase.rs index b53a5b17..3875de04 100644 --- a/src/protocols/params/basecase.rs +++ b/src/protocols/params/basecase.rs @@ -96,8 +96,8 @@ mod tests { use super::*; use crate::protocols::params::test_utils::{ - arb_standard_johnson_spec, arb_zk_spec, assert_close, assert_pow_closes_gap, - deterministic_spec, TestField, TEST_TARGET_RANGE, + arb_standard_spec, arb_zk_spec, assert_close, assert_pow_closes_gap, deterministic_spec, + TestField, TEST_TARGET_RANGE, }; /// `vector_size = 16` (2^4) and `log_inv_rate = 2` give a small but @@ -172,7 +172,7 @@ mod tests { proptest! { #[test] fn solve_standard_assembles( - spec in arb_standard_johnson_spec(TEST_TARGET_RANGE), + spec in arb_standard_spec(TEST_TARGET_RANGE), (log_size, log_inv_rate) in arb_dims(), ) { let config = solve::(&spec, 1usize << log_size, log_inv_rate).unwrap(); @@ -203,7 +203,7 @@ mod tests { #[test] fn standard_mode_has_no_pow( - spec in arb_standard_johnson_spec(TEST_TARGET_RANGE), + spec in arb_standard_spec(TEST_TARGET_RANGE), (log_size, log_inv_rate) in arb_dims(), ) { let config = solve::(&spec, 1usize << log_size, log_inv_rate).unwrap(); diff --git a/src/protocols/params/code_switch.rs b/src/protocols/params/code_switch.rs index f5f657d9..f49607bc 100644 --- a/src/protocols/params/code_switch.rs +++ b/src/protocols/params/code_switch.rs @@ -122,14 +122,14 @@ mod tests { use crate::protocols::params::{ derive::{compute_l_zk, compute_t_ood}, irs_commit as irs_solver, - regime::johnson_list_size, + regime::list_size_estimate, spec::{ DecodingRegime, ListSize, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, PowBudget, RoundContext, SecuritySpec, ZkSpec, }, test_utils::{ - arb_standard_johnson_spec as utils_standard_spec, arb_zk_spec as utils_zk_spec, - assert_close, assert_pow_closes_gap, build_round_io, deterministic_spec, TestEmbedding, + arb_standard_spec as utils_standard_spec, arb_zk_spec as utils_zk_spec, assert_close, + assert_pow_closes_gap, build_round_io, deterministic_spec, TestEmbedding, TestExtensionField, TestField, TestNonIdentityEmbedding, TEST_TARGET_RANGE, }, }; @@ -140,7 +140,7 @@ mod tests { utils_zk_spec(TEST_TARGET_RANGE) } - fn arb_standard_johnson_spec() -> impl Strategy { + fn arb_standard_spec() -> impl Strategy { utils_standard_spec(TEST_TARGET_RANGE) } @@ -210,7 +210,7 @@ mod tests { FORMULA_LOG_INV_RATE, FORMULA_FOLDING_FACTOR, FORMULA_NUM_VARS, - Some(C_ZK_LIST_SIZE), + Some(FORMULA_LOG_INV_RATE), ); let got = f64::from(analytic_error_bits( &source, @@ -280,7 +280,7 @@ mod tests { proptest! { #[test] fn solve_standard_assembles( - spec in arb_standard_johnson_spec(), + spec in arb_standard_spec(), (log_inv_rate, folding_factor, num_vars) in arb_dims(), ) { let (source, target, t_ood) = @@ -296,30 +296,12 @@ mod tests { spec in arb_zk_spec(), (log_inv_rate, folding_factor, num_vars) in arb_dims(), ) { - // Break the t_ood ↔ c_zk.list_size cycle with a placeholder C_zk. - let placeholder_source_ctx = RoundContext { - vector_size: 1usize << num_vars, - log_inv_rate, - folding_factor, - }; - let placeholder_source = irs_solver::solve::( - &spec, - &placeholder_source_ctx, - OodSampleBudget::ZERO, - ); - let zk_spec = ZkSpec::try_new(&spec).expect("arb_zk_spec"); - let c_zk_placeholder = irs_solver::solve_mask_code::( - zk_spec, - compute_l_zk(&placeholder_source, 1), - placeholder_source.mask_length(), - LogInvRate::new(log_inv_rate), - 2, - ); let (source, target, t_ood) = build_round_io::( - &spec, log_inv_rate, folding_factor, num_vars, Some(c_zk_placeholder.list_size()), + &spec, log_inv_rate, folding_factor, num_vars, Some(log_inv_rate), ); let r = source.mask_length(); let l_zk = compute_l_zk(&source, t_ood); + let zk_spec = ZkSpec::try_new(&spec).expect("arb_zk_spec"); let c_zk = irs_solver::solve_mask_code::( zk_spec, l_zk, @@ -327,16 +309,17 @@ mod tests { LogInvRate::new(log_inv_rate), 2, ); - // Use the same rate-only Johnson list as the planner / `build_round_io`. - // `target.list_size()` here would read the *effective* rate after - // `next_order` rounding, which (post Lemma-9.5 tight masking) differs - // from the requested rate and would spuriously shift `t_ood`. The - // assertion isolates the c_zk fixed-point, not the rate-drift artifact. + // Same rate-only list estimate as the planner — `target.list_size()` + // would read the effective rate after `next_order` rounding and + // spuriously shift t_ood, masking the c_zk fixed-point under test. let target_log_inv_rate = f64::from(log_inv_rate + folding_factor - 1); - let target_list_size = johnson_list_size(target_log_inv_rate); + let target_log_degree = f64::from(num_vars - folding_factor); + let target_list_size = list_size_estimate( + spec.decoding_regime, target_log_degree, target_log_inv_rate, + ); let recomputed_t_ood = compute_t_ood(&spec, &source, target_list_size, Some(c_zk.list_size()), t_ood); - prop_assert_eq!(t_ood, recomputed_t_ood, "placeholder ⇒ final C_zk fixed-point"); + prop_assert_eq!(t_ood, recomputed_t_ood, "solve_t_ood ⇒ converged C_zk fixed-point"); let mask_oracle = MaskOracleInfo { c_zk_list_size: ListSize::new(c_zk.list_size()), l_zk, @@ -348,7 +331,7 @@ mod tests { /// `analytic_error + pow ≥ target` (Lemma 9.9 OOD term). #[test] fn pow_closes_gap_to_target_standard( - spec in arb_standard_johnson_spec(), + spec in arb_standard_spec(), (log_inv_rate, folding_factor, num_vars) in arb_dims(), ) { let (source, target, t_ood) = diff --git a/src/protocols/params/derive.rs b/src/protocols/params/derive.rs index cb1e6af2..c16c982d 100644 --- a/src/protocols/params/derive.rs +++ b/src/protocols/params/derive.rs @@ -16,10 +16,10 @@ use crate::{ error::{DeriveError, FixedPointLoop, Pow}, irs_commit as irs_solver, mask_proximity as mask_proximity_solver, protocol_config::{MaskOracleConfig, ProtocolConfig, RoundConfig, RoundMode}, - regime::johnson_list_size, + regime::list_size_estimate, spec::{ - DecodingRegime, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, - RoundContext, SecuritySpec, TuningSpec, ZkSpec, + LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, RoundContext, SecuritySpec, + TuningSpec, ZkSpec, }, sumcheck as sumcheck_solver, }, @@ -147,19 +147,23 @@ fn build_zk_round_config( let spec = zk_spec.as_inner(); let ctx = round_context(shape); let num_masks = sumcheck_solver::masks_required(&ctx) + code_switch_solver::masks_required(); - // C_zk.list_size depends only on rate — no IRS build needed for it. - let c_zk_list_size = johnson_list_size(f64::from(c_zk_log_inv_rate.get())); + let c_zk_log_inv_rate_f = f64::from(c_zk_log_inv_rate.get()); let src_ctx = round_context(shape); let target_log_inv_rate = f64::from(shape.source_log_inv_rate + shape.source_folding_factor.saturating_sub(1)); - let target_list_size = johnson_list_size(target_log_inv_rate); + // Target encodes one polynomial of length `source.message_length()` = + // `source_vector_size / 2^source_folding_factor`. + let target_log_degree = + f64::from(shape.source_vector_size.trailing_zeros() - shape.source_folding_factor); + let target_list_size = + list_size_estimate(spec.decoding_regime, target_log_degree, target_log_inv_rate); let (source, t_ood) = solve_t_ood::( spec, &src_ctx, target_list_size, - Some(c_zk_list_size), + Some(c_zk_log_inv_rate_f), shape.round_index, )?; let target: IrsConfig> = irs_solver::solve( @@ -176,12 +180,18 @@ fn build_zk_round_config( c_zk_log_inv_rate, 2 * num_masks, ); + let c_zk_list_size_estimate = list_size_estimate( + spec.decoding_regime, + (l_zk.get() as f64).log2(), + c_zk_log_inv_rate_f, + ); debug_assert!( - (c_zk.list_size() - c_zk_list_size).abs() < 1e-9 * c_zk_list_size.max(1.0), - "c_zk.list_size() {} drifted from rate-only planner estimate {} — \ - see `johnson_list_size` for the invariant", + (c_zk.list_size() - c_zk_list_size_estimate).abs() + < 1e-9 * c_zk_list_size_estimate.max(1.0), + "c_zk.list_size() {} drifted from planner estimate {} — \ + see `list_size_estimate` for the invariant", c_zk.list_size(), - c_zk_list_size, + c_zk_list_size_estimate, ); let mask_proximity = mask_proximity_solver::solve(spec, c_zk.clone(), num_masks, shape.round_index)?; @@ -217,7 +227,10 @@ fn build_round_config( let src_ctx = round_context(shape); let target_log_inv_rate = f64::from(shape.source_log_inv_rate + shape.source_folding_factor.saturating_sub(1)); - let target_list_size = johnson_list_size(target_log_inv_rate); + let target_log_degree = + f64::from(shape.source_vector_size.trailing_zeros() - shape.source_folding_factor); + let target_list_size = + list_size_estimate(spec.decoding_regime, target_log_degree, target_log_inv_rate); let (source, t_ood) = solve_t_ood::(spec, &src_ctx, target_list_size, None, shape.round_index)?; @@ -255,8 +268,9 @@ pub(super) const fn compute_l_zk( /// with `ℓ_zk = next_pow2(source.mask_length() + t_ood)`. Standard: /// `degree = ℓ`. /// -/// `Johnson` is forced even under `DecodingRegime::Unique` — Construction 9.7 -/// requires `t_ood ≥ 1` regardless of the decoding regime in `spec`. +/// Floored at `1`: Lemma 9.9's OOD term vanishes under `Unique` decoding +/// (`|Λ| = 1`), but Construction 9.7's code-switch still needs at least one +/// OOD point to bind the witness polynomial. pub(super) fn compute_t_ood( spec: &SecuritySpec, source: &IrsConfig, @@ -276,40 +290,51 @@ pub(super) fn compute_t_ood( message_length }; - irs_commit::num_ood_samples( - DecodingRegime::Johnson, + let soundness_t_ood = irs_commit::num_ood_samples( + spec.decoding_regime, security_target, field_bits, combined_list_size, degree, - ) + ); + soundness_t_ood.max(1) } /// Solves the per-round `t_ood` fixed-point and the source IRS together. /// /// Convergence: `Φ(t) = num_ood_samples(ℓ + next_pow2(in_domain + 2·t))` is /// monotone non-decreasing on ℕ (`in_domain` depends only on the requested -/// rate; `next_pow2` and `num_ood_samples` are monotone) and bounded above, -/// so Kleene iteration from `t = 0` converges to the least fixed point in -/// finitely many steps. Standard mode (`c_zk_list_size = None`) has `Φ` -/// constant in `t`, so one application suffices. +/// rate; `next_pow2` and `num_ood_samples` are monotone; under `Capacity` the +/// `c_zk_list_size(t)` factor is monotone too) and bounded above, so Kleene +/// iteration from `t = 0` converges to the least fixed point in finitely many +/// steps. Standard mode (`c_zk_log_inv_rate = None`) has `Φ` constant in +/// `t`, so one application suffices. pub(super) fn solve_t_ood( spec: &SecuritySpec, src_ctx: &RoundContext, target_list_size: f64, - c_zk_list_size: Option, + c_zk_log_inv_rate: Option, round_index: usize, ) -> Result<(IrsConfig, usize), DeriveError> { let mut source: IrsConfig = irs_solver::solve(spec, src_ctx, OodSampleBudget::ZERO); - if c_zk_list_size.is_none() { + let Some(c_zk_log_inv_rate) = c_zk_log_inv_rate else { let t_ood = compute_t_ood(spec, &source, target_list_size, None, 0); return Ok((source, t_ood)); - } + }; let mut t_ood = 0; for _ in 0..T_OOD_MAX_ITER { - let new_t_ood = compute_t_ood(spec, &source, target_list_size, c_zk_list_size, t_ood); + // Under `Capacity`, c_zk's list size depends on its message length + // ℓ_zk(t), so recompute per iteration. Under `Johnson`/`Unique` the + // result is t-independent — the recomputation is a no-op. + let l_zk = (source.mask_length() + t_ood).next_power_of_two(); + let c_zk_list_size = list_size_estimate( + spec.decoding_regime, + (l_zk as f64).log2(), + c_zk_log_inv_rate, + ); + let new_t_ood = compute_t_ood(spec, &source, target_list_size, Some(c_zk_list_size), t_ood); if new_t_ood == t_ood { return Ok((source, t_ood)); } @@ -780,8 +805,7 @@ mod tests { assert!(plan.basecase().commit.unique_decoding()); } - /// Same threading check under ZK mode. Basecase-only avoids the per-round - /// code-switch (which requires `t_ood ≥ 1`). + /// Same threading check under ZK mode (basecase-only fixture). #[test] fn derive_threads_unique_decoding_zk() { let spec = SecuritySpec { @@ -800,6 +824,104 @@ mod tests { assert!(plan.basecase().commit.unique_decoding()); } + /// Multi-round derivation under Unique: every round's IRS carries the + /// Unique regime and every code-switch slot satisfies the Construction + /// 9.7 `t_ood ≥ 1` floor. + #[test] + fn derive_multi_round_unique_decoding_succeeds() { + let spec = SecuritySpec { + mode: Mode::Standard, + decoding_regime: DecodingRegime::Unique, + target_security_bits: PLAN_FIXTURE_TARGET_BITS, + pow_budget: PowBudget::per_slot(LOOSE_POW_BUDGET_BITS), + hash_id: hash::BLAKE3, + }; + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + assert!(!plan.rounds().is_empty(), "expected multi-round plan"); + for r in plan.rounds() { + let cs = r.code_switch(); + assert!(cs.source.unique_decoding()); + assert!(cs.target.unique_decoding()); + assert!(cs.out_domain_samples >= 1, "Construction 9.7 floor"); + } + assert!(plan.basecase().commit.unique_decoding()); + } + + /// ZK + Unique multi-round: per-round mask oracle still assembled, C_zk + /// built under Unique, code-switch carries `t_ood ≥ 1` per floor. + #[test] + fn derive_multi_round_unique_decoding_zk_succeeds() { + let spec = SecuritySpec { + mode: Mode::ZeroKnowledge, + decoding_regime: DecodingRegime::Unique, + target_security_bits: PLAN_FIXTURE_TARGET_BITS, + pow_budget: PowBudget::per_slot(LOOSE_POW_BUDGET_BITS), + hash_id: hash::BLAKE3, + }; + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + assert!(!plan.rounds().is_empty(), "expected multi-round plan"); + for r in plan.rounds() { + let mo = r.mask_oracle().expect("ZK round must own a mask oracle"); + assert!(mo.c_zk().unique_decoding()); + assert!(r.code_switch().source.unique_decoding()); + assert!(r.code_switch().out_domain_samples >= 1); + } + assert!(plan.basecase().commit.unique_decoding()); + } + + /// Multi-round Capacity (Standard): IRS configs carry the Capacity regime + /// and the `c_zk_list_size(t)` fixed-point resolves inside `solve_t_ood`. + #[test] + fn derive_multi_round_capacity_decoding_succeeds() { + let spec = SecuritySpec { + mode: Mode::Standard, + decoding_regime: DecodingRegime::Capacity, + target_security_bits: PLAN_FIXTURE_TARGET_BITS, + pow_budget: PowBudget::per_slot(LOOSE_POW_BUDGET_BITS), + hash_id: hash::BLAKE3, + }; + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + assert!(!plan.rounds().is_empty(), "expected multi-round plan"); + for r in plan.rounds() { + assert!(r.code_switch().out_domain_samples >= 1); + } + } + + /// ZK + Capacity multi-round: exercises the degree-dependent c_zk list + /// size inside the t_ood fixed-point. + #[test] + fn derive_multi_round_capacity_decoding_zk_succeeds() { + let spec = SecuritySpec { + mode: Mode::ZeroKnowledge, + decoding_regime: DecodingRegime::Capacity, + target_security_bits: PLAN_FIXTURE_TARGET_BITS, + pow_budget: PowBudget::per_slot(LOOSE_POW_BUDGET_BITS), + hash_id: hash::BLAKE3, + }; + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + assert!(!plan.rounds().is_empty(), "expected multi-round plan"); + for r in plan.rounds() { + r.mask_oracle().expect("ZK round must own a mask oracle"); + assert!(r.code_switch().out_domain_samples >= 1); + } + } + /// `analytic_error + pow ≥ target` for every PoW slot in the plan. fn assert_plan_meets_target_per_slot( spec: &SecuritySpec, diff --git a/src/protocols/params/regime.rs b/src/protocols/params/regime.rs index 11409bee..e3c6aefc 100644 --- a/src/protocols/params/regime.rs +++ b/src/protocols/params/regime.rs @@ -5,6 +5,14 @@ //! a user choice). The data-carrying [`DecodingRegimeParams`] is what gets //! stored on per-round configs once a rate is known: [`Self::from_policy`] //! is the single materialization point. +//! +//! # References +//! +//! - Johnson proximity-gap error follows the BCSS25 improvement +//! (`O(n/η^5)`, m=10 at canonical slack) over BCIKS '20. +//! - Capacity bound follows STIR Conjecture 5.6: `(1 − ρ − η, d/(ρ·η))`-list +//! decodability for RS codes. +//! - Aligned with the Plonky3 `SecurityAssumption` parametrization. use std::f64::consts::LOG2_10; @@ -16,26 +24,27 @@ use crate::protocols::params::{ spec::DecodingRegime, }; -/// Materialized decoding-regime parameters. +/// Materialized decoding-regime parameters at a known rate. /// -/// `Unique` carries no data; `Johnson { slack }` carries `η`. The two variants -/// are statically distinct — there is no "Johnson with η = 0" representation, -/// so callers can pattern-match without a sentinel-comparison branch. +/// `Unique` carries no data; `Johnson` and `Capacity` each carry the slack `η` +/// from their respective proximity boundary. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum DecodingRegimeParams { Unique, Johnson { slack: OrderedFloat }, + Capacity { slack: OrderedFloat }, } impl DecodingRegimeParams { - /// Materialize spec policy at a known rate. The canonical Johnson slack - /// (`η = √ρ / 20`) is centralized here — any tuning of `η` lives at this - /// site and propagates to every per-round config. + /// Materialize spec policy at a known rate. Canonical slacks (`√ρ/20` for + /// Johnson, `ρ/20` for Capacity) live here — any tuning of `η` propagates + /// to every per-round config through this single site. // TODO: Optimize picking η. pub fn from_policy(policy: DecodingRegime, rate: f64) -> Self { match policy { DecodingRegime::Unique => Self::Unique, DecodingRegime::Johnson => Self::johnson_canonical(rate), + DecodingRegime::Capacity => Self::capacity_canonical(rate), } } @@ -46,22 +55,35 @@ impl DecodingRegimeParams { } } + /// Capacity regime with the canonical `η = ρ / 20` slack. + pub fn capacity_canonical(rate: f64) -> Self { + Self::Capacity { + slack: OrderedFloat(rate / 20.0), + } + } + pub const fn is_unique(self) -> bool { matches!(self, Self::Unique) } /// `log₂ |Λ(C, δ)|`. - pub fn list_size_log2(self, log_inv_rate: f64) -> f64 { + /// + /// `log_degree` is `log₂` of the code's message length; it's only read in + /// the `Capacity` branch (STIR Conj 5.6 gives `|Λ| = d/(ρ·η)`). Johnson + /// and Unique ignore it. + pub fn list_size_log2(self, log_degree: f64, log_inv_rate: f64) -> f64 { match self { Self::Unique => 0.0, // Johnson: |Λ| = 1 / (2 η √ρ). Self::Johnson { slack } => -1.0 - slack.into_inner().log2() + 0.5 * log_inv_rate, + // Capacity (STIR Conj 5.6): |Λ| = d / (ρ · η). + Self::Capacity { slack } => log_degree + log_inv_rate - slack.into_inner().log2(), } } /// `|Λ(C, δ)|`. - pub fn list_size(self, log_inv_rate: f64) -> f64 { - 2_f64.powf(self.list_size_log2(log_inv_rate)) + pub fn list_size(self, log_degree: f64, log_inv_rate: f64) -> f64 { + 2_f64.powf(self.list_size_log2(log_degree, log_inv_rate)) } /// `log₂(1 − δ)`. @@ -69,11 +91,22 @@ impl DecodingRegimeParams { let one_minus_delta = match self { Self::Unique => f64::midpoint(1.0, rate(log_inv_rate)), Self::Johnson { slack } => rate(log_inv_rate).sqrt() + slack.into_inner(), + Self::Capacity { slack } => rate(log_inv_rate) + slack.into_inner(), }; one_minus_delta.log2() } - /// `log₂ ε_mca(C, δ)`. + /// `log₂ ε_mca(C, δ)` for the per-step proximity-gaps error (bare, no + /// arity factor — callers apply their own). + /// + /// - Unique: `(k − 1) / |F|`, log = `log k − |F|` (with `+ log ρ⁻¹` to + /// pick up the `n/|F|` factor). + /// - Johnson: BCSS25 Theorem 1.5 at canonical `η = √ρ/20`, `m = 10`: + /// `ε ≈ (2·10.5⁵/3) · n · ρ^{−3/2} / |F|`. + /// - Capacity: STIR Conj 5.6, `ε ≈ d / (η · ρ²) / |F|`. + /// + /// The formula expressions hardcode the canonical slack; debug-asserts + /// catch a non-canonical `slack` that would invalidate the constants. pub fn eps_mca_log2(self, log_inv_rate: f64, message_length: usize, field_bits: f64) -> f64 { let log_k = usize_to_f64(message_length).log2(); let error = match self { @@ -82,24 +115,31 @@ impl DecodingRegimeParams { debug_assert!( slack.into_inner().log2() >= -(0.5 * log_inv_rate + LOG2_10 + 1.0) - 1e-6 ); - 7.0 * LOG2_10 + 3.5 * log_inv_rate + 2.0 * log_k + // BCSS25 with m = 10: log_2(2·10.5⁵/3) + log n + 1.5·log ρ⁻¹. + let bcss25_const = (2.0 * 10.5_f64.powi(5) / 3.0).log2(); + bcss25_const + log_k + 2.5 * log_inv_rate + } + Self::Capacity { slack } => { + debug_assert!(slack.into_inner().log2() >= -(log_inv_rate + LOG2_10 + 1.0) - 1e-6); + // d / (η · ρ²) at canonical η = ρ/20: log d + log 20 + 3·log ρ⁻¹. + log_k + 3.0 * log_inv_rate + LOG2_10 + 1.0 } }; error - field_bits } } -/// Johnson list size at the canonical `η = √ρ / 20` slack, as a function of -/// `log_inv_rate` only. Used by planners that need a list-size estimate before -/// a target config exists. +/// `|Λ|` at the given degree + rate under `regime`. Used before an IRS config +/// exists. /// -/// Equals `IrsConfig::list_size()` exactly when the IRS has pow2 `vector_size`, -/// `interleaving_depth = 1`, integer `log_inv_rate`, and lives on a 2-adic NTT -/// field — the conditions `solve_mask_code` enforces for C_zk. Outside that -/// regime, `ntt::next_order` may shift the effective rate and the helper -/// underestimates. -pub fn johnson_list_size(log_inv_rate: f64) -> f64 { - DecodingRegimeParams::johnson_canonical(rate(log_inv_rate)).list_size(log_inv_rate) +/// Matches `IrsConfig::list_size()` when the IRS is built under the same +/// regime, with the same `masked_message_length`, and `ntt::next_order` +/// doesn't pad the codeword (pow2 `vector_size`, `interleaving_depth = 1`, +/// integer `log_inv_rate`, 2-adic field — the conditions `solve_mask_code` +/// enforces for C_zk). +pub fn list_size_estimate(regime: DecodingRegime, log_degree: f64, log_inv_rate: f64) -> f64 { + DecodingRegimeParams::from_policy(regime, rate(log_inv_rate)) + .list_size(log_degree, log_inv_rate) } #[cfg(test)] @@ -117,27 +157,43 @@ mod tests { } } + fn capacity(slack: f64) -> DecodingRegimeParams { + DecodingRegimeParams::Capacity { + slack: OrderedFloat(slack), + } + } + /// Johnson list size: `|Λ| = 1 / (2η√ρ)`, log₂ form. Hand-evaluated at /// `log_inv_rate = 2`, `η = 0.1`: `−1 − log₂(0.1) + 1 ≈ 3.3219`. + /// `log_degree` is ignored by the Johnson branch. #[test] fn list_size_log2_johnson_formula() { - let got = johnson(0.1).list_size_log2(2.0); + let got = johnson(0.1).list_size_log2(/* log_degree */ 4.0, 2.0); let expected = -1.0 - 0.1_f64.log2() + 0.5 * 2.0; assert_close(got, expected); } + /// Capacity list size: `|Λ| = d / (ρ · η)`, log₂ form. At `log_degree = 4`, + /// `log_inv_rate = 2`, `η = 1/8`: `4 + 2 − log₂(1/8) = 4 + 2 + 3 = 9`. + #[test] + fn list_size_log2_capacity_formula() { + let got = capacity(0.125).list_size_log2(4.0, 2.0); + let expected = 4.0 + 2.0 - 0.125_f64.log2(); + assert_close(got, expected); + } + /// Unique-decoding regime gives `|Λ| = 1`, i.e. log = 0. #[test] fn list_size_log2_unique_decoding_is_zero() { - assert_close(DecodingRegimeParams::Unique.list_size_log2(2.0), 0.0); + assert_close(DecodingRegimeParams::Unique.list_size_log2(4.0, 2.0), 0.0); } /// `η = √ρ / 20` substituted into `|Λ| = 1/(2η√ρ)` simplifies to `10/ρ`. - /// So `johnson_list_size(b) = 10 · 2^b`. + /// So `list_size_estimate(Johnson, _, b) = 10 · 2^b`. #[test] fn johnson_list_size_closed_form() { for b in [1.0, 2.0, 3.0, 5.0] { - let got = johnson_list_size(b); + let got = list_size_estimate(DecodingRegime::Johnson, /* log_degree */ 4.0, b); let expected = 10.0 * 2_f64.powf(b); assert!( (got - expected).abs() / expected < TIGHT_EPS, @@ -146,9 +202,22 @@ mod tests { } } - /// `johnson_list_size(b)` must match `Config::list_size` once a config is - /// built at the same rate. Keeps the rate-only helper in sync with - /// `irs_commit::Config::new`'s canonical-slack materialization. + /// `η = ρ / 20` substituted into `|Λ| = d/(ρ · η)` simplifies to `20 · d / ρ²`. + #[test] + fn capacity_list_size_closed_form() { + for (log_d, b) in [(4.0, 1.0), (6.0, 2.0), (8.0, 3.0)] { + let got = list_size_estimate(DecodingRegime::Capacity, log_d, b); + let expected = 20.0 * 2_f64.powf(log_d) * 2_f64.powf(2.0 * b); + assert!( + (got - expected).abs() / expected < TIGHT_EPS, + "log_d={log_d}, log_inv_rate={b}: got {got} vs {expected}", + ); + } + } + + /// `list_size_estimate(Johnson, _, b)` must match `Config::list_size` once + /// a config is built at the same rate. Keeps the rate-only helper in sync + /// with `irs_commit::Config::new`'s canonical-slack materialization. #[test] fn johnson_list_size_matches_config_list_size() { use crate::{ @@ -172,7 +241,8 @@ mod tests { 2_f64.powf(-f64::from(LOG_INV_RATE)), IrsMode::Standard, ); - let got = johnson_list_size(f64::from(LOG_INV_RATE)); + let log_degree = (config.masked_message_length() as f64).log2(); + let got = list_size_estimate(DecodingRegime::Johnson, log_degree, f64::from(LOG_INV_RATE)); let expected = config.list_size(); assert!( (got - expected).abs() / expected < TIGHT_EPS, @@ -201,6 +271,17 @@ mod tests { assert_close(got, expected); } + /// `1 − δ` in Capacity regime: `ρ + η`. + #[test] + fn one_minus_distance_log2_capacity() { + let log_inv_rate = 2.0; + let eta = 0.05; + let got = capacity(eta).one_minus_distance_log2(log_inv_rate); + let rho = 2_f64.powf(-log_inv_rate); + let expected = (rho + eta).log2(); + assert_close(got, expected); + } + /// MCA fixture — `message_length = 16 = 2^4` and `log_inv_rate = 2` give /// exact `log2(k) = 4`. `field_bits = 64.0` for Field64. const MCA_MESSAGE_LENGTH: usize = 16; @@ -219,25 +300,41 @@ mod tests { assert_close(got, expected); } - /// MCA error, Johnson branch: `7·log₂10 + 3.5·log_inv_rate + 2·log k − field_bits`. + /// MCA error, Johnson (BCSS25): `log₂(2·10.5⁵/3) + log k + 2.5·log_inv_rate − field_bits`. #[test] fn eps_mca_log2_johnson_formula() { - // `η = 0.1` stays within the debug assertion's slack range. - const JOHNSON_SLACK: f64 = 0.1; + // `η = √ρ/20 ≈ 0.025` at `log_inv_rate = 2`. Use the canonical slack + // so the debug-assert in the formula is satisfied. + let canonical_slack = 2_f64.powf(-MCA_LOG_INV_RATE).sqrt() / 20.0; - let got = johnson(JOHNSON_SLACK).eps_mca_log2( + let got = johnson(canonical_slack).eps_mca_log2( MCA_LOG_INV_RATE, MCA_MESSAGE_LENGTH, MCA_FIELD_BITS, ); - let expected = - 7.0 * LOG2_10 + 3.5 * MCA_LOG_INV_RATE + 2.0 * (MCA_MESSAGE_LENGTH as f64).log2() - - MCA_FIELD_BITS; + let bcss25_const = (2.0 * 10.5_f64.powi(5) / 3.0).log2(); + let expected = bcss25_const + (MCA_MESSAGE_LENGTH as f64).log2() + 2.5 * MCA_LOG_INV_RATE + - MCA_FIELD_BITS; assert_close(got, expected); } - /// `from_policy(Unique, _)` ignores rate; `from_policy(Johnson, rate)` - /// produces the same materialization as `johnson_canonical(rate)`. + /// MCA error, Capacity (STIR Conj 5.6 at canonical η): + /// `log k + 3·log_inv_rate + log₂10 + 1 − field_bits`. + #[test] + fn eps_mca_log2_capacity_formula() { + let canonical_slack = 2_f64.powf(-MCA_LOG_INV_RATE) / 20.0; + + let got = capacity(canonical_slack).eps_mca_log2( + MCA_LOG_INV_RATE, + MCA_MESSAGE_LENGTH, + MCA_FIELD_BITS, + ); + let expected = (MCA_MESSAGE_LENGTH as f64).log2() + 3.0 * MCA_LOG_INV_RATE + LOG2_10 + 1.0 + - MCA_FIELD_BITS; + assert_close(got, expected); + } + + /// `from_policy` dispatches to the canonical constructor for each regime. #[test] fn from_policy_matches_canonical() { assert_eq!( @@ -248,5 +345,9 @@ mod tests { DecodingRegimeParams::from_policy(DecodingRegime::Johnson, 0.25), DecodingRegimeParams::johnson_canonical(0.25), ); + assert_eq!( + DecodingRegimeParams::from_policy(DecodingRegime::Capacity, 0.25), + DecodingRegimeParams::capacity_canonical(0.25), + ); } } diff --git a/src/protocols/params/spec.rs b/src/protocols/params/spec.rs index 67fc8f51..e4fd1d37 100644 --- a/src/protocols/params/spec.rs +++ b/src/protocols/params/spec.rs @@ -187,21 +187,22 @@ impl Deref for ZkSpec<'_> { /// Reed–Solomon decoding regime selection. /// -/// Picks the proximity radius `δ` and slack policy used by the IRS and -/// downstream sub-protocols. `Johnson` uses the codebase's slack policy -/// `η = √ρ / 20`; the list-decoding ball can hold `~10/ρ` codewords. -/// `Unique` operates strictly inside the unique-decoding radius `(1 − ρ)/2`; -/// the ball holds at most one. +/// - `Unique`: `δ < (1 − ρ)/2`, list size 1, no conjectures. +/// - `Johnson`: `δ < 1 − √ρ − η`, canonical `η = √ρ/20`. Proximity-gap error +/// per the BCSS25 improvement to BCIKS '20. +/// - `Capacity`: `δ < 1 − ρ − η`, canonical `η = ρ/20`. Conjectured list size +/// `d/(ρ·η)` and proximity-gap error per STIR Conjecture 5.6. /// /// WHIR's rate stepping (each round bumps `log_inv_rate` by /// `folding_factor − 1`) pushes ρ → 1, shrinking the unique-decoding /// radius. At high security targets or deep folding, `Unique` may exceed /// the grind cap on per-round PoW and [`super::derive::ProtocolConfig::derive`] -/// will return `PowUngrindable`. Pick `Johnson` for those cases. +/// will return `PowUngrindable` — pick `Johnson` or `Capacity` for those. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum DecodingRegime { Unique, Johnson, + Capacity, } impl Display for DecodingRegime { @@ -209,6 +210,7 @@ impl Display for DecodingRegime { match self { Self::Unique => f.write_str("Unique"), Self::Johnson => f.write_str("Johnson"), + Self::Capacity => f.write_str("Capacity"), } } } @@ -220,8 +222,9 @@ impl FromStr for DecodingRegime { match s { "Unique" => Ok(Self::Unique), "Johnson" => Ok(Self::Johnson), + "Capacity" => Ok(Self::Capacity), _ => Err(format!( - "invalid decoding regime: {s}, options are: Unique, Johnson" + "invalid decoding regime: {s}, options are: Unique, Johnson, Capacity" )), } } @@ -233,7 +236,11 @@ mod decoding_regime_tests { #[test] fn from_str_round_trips_display() { - for r in [DecodingRegime::Unique, DecodingRegime::Johnson] { + for r in [ + DecodingRegime::Unique, + DecodingRegime::Johnson, + DecodingRegime::Capacity, + ] { assert_eq!(r.to_string().parse::().unwrap(), r); } } @@ -242,7 +249,7 @@ mod decoding_regime_tests { fn from_str_rejects_unknown() { assert!("johnson".parse::().is_err()); // case-sensitive assert!("".parse::().is_err()); - assert!("Capacity".parse::().is_err()); + assert!("capacity".parse::().is_err()); } } diff --git a/src/protocols/params/sumcheck.rs b/src/protocols/params/sumcheck.rs index 81b5b821..3e418957 100644 --- a/src/protocols/params/sumcheck.rs +++ b/src/protocols/params/sumcheck.rs @@ -104,9 +104,9 @@ mod tests { irs_commit as irs_solver, spec::{ListSize, MaskCodeMessageLen, Mode, OodSampleBudget}, test_utils::{ - arb_round_ctx, arb_standard_johnson_spec, arb_zk_spec, assert_close, - assert_pow_closes_gap, build_minimal_mask_oracle, deterministic_spec, TestEmbedding, - TestField, TestNonIdentityEmbedding, EPS, TEST_TARGET_RANGE, + arb_round_ctx, arb_standard_spec, arb_zk_spec, assert_close, assert_pow_closes_gap, + build_minimal_mask_oracle, deterministic_spec, TestEmbedding, TestField, + TestNonIdentityEmbedding, EPS, TEST_TARGET_RANGE, }, }; @@ -220,7 +220,7 @@ mod tests { proptest! { #[test] fn standard_mode_propagates( - spec in arb_standard_johnson_spec(TEST_TARGET_RANGE), + spec in arb_standard_spec(TEST_TARGET_RANGE), ctx in arb_round_ctx(), ) { let source_irs = build_source_irs(&spec, &ctx); @@ -232,7 +232,7 @@ mod tests { #[test] fn num_rounds_matches_folding_factor( spec in prop_oneof![ - arb_standard_johnson_spec(TEST_TARGET_RANGE), + arb_standard_spec(TEST_TARGET_RANGE), arb_zk_spec(TEST_TARGET_RANGE), ], ctx in arb_round_ctx(), @@ -264,7 +264,7 @@ mod tests { #[test] fn round_pow_closes_gap_to_target( spec in prop_oneof![ - arb_standard_johnson_spec(TEST_TARGET_RANGE), + arb_standard_spec(TEST_TARGET_RANGE), arb_zk_spec(TEST_TARGET_RANGE), ], ctx in arb_round_ctx(), diff --git a/src/protocols/params/test_utils.rs b/src/protocols/params/test_utils.rs index 55f1b944..cebbe080 100644 --- a/src/protocols/params/test_utils.rs +++ b/src/protocols/params/test_utils.rs @@ -17,7 +17,7 @@ use crate::{ derive::solve_t_ood, irs_commit as irs_solver, protocol_config::MaskOracleInfo, - regime::johnson_list_size, + regime::list_size_estimate, spec::{ DecodingRegime, ListSize, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, PowBudget, RoundContext, SecuritySpec, ZkSpec, @@ -63,13 +63,23 @@ pub fn deterministic_spec(mode: Mode) -> SecuritySpec { } } +/// Both decoding regimes, equally weighted. Used by `arb_spec` so proptests +/// sweep all three regimes. +fn arb_decoding_regime() -> impl Strategy { + prop_oneof![ + Just(DecodingRegime::Johnson), + Just(DecodingRegime::Unique), + Just(DecodingRegime::Capacity), + ] +} + pub fn arb_spec( mode: Mode, target_range: RangeInclusive, ) -> impl Strategy { - target_range.prop_map(move |target| SecuritySpec { + (target_range, arb_decoding_regime()).prop_map(move |(target, decoding_regime)| SecuritySpec { mode, - decoding_regime: DecodingRegime::Johnson, + decoding_regime, target_security_bits: target, pow_budget: PowBudget::per_slot(FIXTURE_POW_BUDGET_BITS), hash_id: hash::BLAKE3, @@ -80,9 +90,7 @@ pub fn arb_zk_spec(target_range: RangeInclusive) -> impl Strategy, -) -> impl Strategy { +pub fn arb_standard_spec(target_range: RangeInclusive) -> impl Strategy { arb_spec(Mode::Standard, target_range) } @@ -153,7 +161,7 @@ pub fn build_test_c_zk( /// Builds a self-consistent `(source, target, t_ood)` triplet matching the /// per-round shape that `code_switch::solve` expects. /// -/// `t_ood` is solved against the rate-only `johnson_list_size(target_log_inv_rate)`, +/// `t_ood` is solved against the rate-only `list_size_estimate(...)`, /// mirroring `derive::solve_t_ood`. Using `target.list_size()` here instead /// would couple `t_ood` to the target's effective rate (which itself depends /// on `t_ood` via the mask), producing a non-monotone oscillation once the @@ -163,7 +171,7 @@ pub fn build_round_io( log_inv_rate: u32, folding_factor: u32, num_vars: u32, - c_zk_list_size: Option, + c_zk_log_inv_rate: Option, ) -> (IrsConfig, IrsConfig>, usize) { let source_ctx = RoundContext { vector_size: 1usize << num_vars, @@ -171,9 +179,16 @@ pub fn build_round_io( folding_factor, }; let target_log_inv_rate = log_inv_rate + folding_factor - 1; - let target_list_size = johnson_list_size(f64::from(target_log_inv_rate)); - let (source, t_ood) = solve_t_ood::(spec, &source_ctx, target_list_size, c_zk_list_size, 0) - .expect("solve_t_ood diverged in test fixture"); + let target_log_degree = f64::from(num_vars - folding_factor); + let target_list_size = list_size_estimate( + spec.decoding_regime, + target_log_degree, + f64::from(target_log_inv_rate), + ); + let c_zk_log_inv_rate = c_zk_log_inv_rate.map(f64::from); + let (source, t_ood) = + solve_t_ood::(spec, &source_ctx, target_list_size, c_zk_log_inv_rate, 0) + .expect("solve_t_ood diverged in test fixture"); let target_ctx = RoundContext { vector_size: source.message_length(), From 4c8515b220c60fbaed4c053343038cfd21f68734 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Sat, 23 May 2026 17:18:12 +0530 Subject: [PATCH 23/31] feat : added capacity bound and ood point calc refactor --- src/protocols/params/code_switch.rs | 72 +++++++---------- src/protocols/params/derive.rs | 121 ++++++++++++---------------- src/protocols/params/regime.rs | 49 +++++++++++ 3 files changed, 129 insertions(+), 113 deletions(-) diff --git a/src/protocols/params/code_switch.rs b/src/protocols/params/code_switch.rs index f49607bc..6886426b 100644 --- a/src/protocols/params/code_switch.rs +++ b/src/protocols/params/code_switch.rs @@ -120,7 +120,7 @@ mod tests { use super::*; use crate::protocols::params::{ - derive::{compute_l_zk, compute_t_ood}, + derive::{compute_l_zk, solve_t_ood}, irs_commit as irs_solver, regime::list_size_estimate, spec::{ @@ -309,17 +309,6 @@ mod tests { LogInvRate::new(log_inv_rate), 2, ); - // Same rate-only list estimate as the planner — `target.list_size()` - // would read the effective rate after `next_order` rounding and - // spuriously shift t_ood, masking the c_zk fixed-point under test. - let target_log_inv_rate = f64::from(log_inv_rate + folding_factor - 1); - let target_log_degree = f64::from(num_vars - folding_factor); - let target_list_size = list_size_estimate( - spec.decoding_regime, target_log_degree, target_log_inv_rate, - ); - let recomputed_t_ood = - compute_t_ood(&spec, &source, target_list_size, Some(c_zk.list_size()), t_ood); - prop_assert_eq!(t_ood, recomputed_t_ood, "solve_t_ood ⇒ converged C_zk fixed-point"); let mask_oracle = MaskOracleInfo { c_zk_list_size: ListSize::new(c_zk.list_size()), l_zk, @@ -367,63 +356,56 @@ mod tests { fn solve_works_with_basefield_embedding_standard() { let spec: SecuritySpec = deterministic_spec(Mode::Standard); let (source_ctx, target_ctx) = non_identity_smoke_ctxs(); - - let source = irs_solver::solve::( - &spec, - &source_ctx, - OodSampleBudget::ZERO, + let target_log_degree = + f64::from((source_ctx.vector_size / (1 << source_ctx.folding_factor)).trailing_zeros()); + let target_list_size = list_size_estimate( + spec.decoding_regime, + target_log_degree, + f64::from(target_ctx.log_inv_rate), ); + let (source, t_ood) = + solve_t_ood::(&spec, &source_ctx, target_list_size, None, 0) + .unwrap(); // Standard target: codeword_length is t_ood-independent (mask = 0). let target = irs_solver::solve::>( &spec, &target_ctx, OodSampleBudget::ZERO, ); - let t_ood = compute_t_ood(&spec, &source, target.list_size(), None, 0); let config = solve_standard(&spec, source, target, t_ood, 0).unwrap(); assert!(matches!(config.mode, code_switch::CodeSwitchMode::Standard)); } - /// Placeholder mask-oracle list size for the smoke test. Pow2 keeps - /// `log2` exact and matches `analytic_error_zk_formula`'s fixture. + /// Placeholder mask-oracle list size for the smoke test — pow2 so `log2` + /// is exact and matches `analytic_error_zk_formula`'s fixture. const SMOKE_C_ZK_LIST_SIZE: f64 = 4.0; - /// Cap on the smoke-test `t_ood ↔ (source, target)` fixed-point. Matches the - /// loop bound used in `build_round_io`; in practice converges in 1–3 iters. - const SMOKE_FIXED_POINT_MAX_ITER: usize = 8; /// Smoke test: `M::Source ≠ M::Target`, ZK mode. #[test] fn solve_works_with_basefield_embedding_zk() { let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); let (source_ctx, target_ctx) = non_identity_smoke_ctxs(); - - let mut t_ood = 0; - let mut source = irs_solver::solve::( + let target_log_degree = + f64::from((source_ctx.vector_size / (1 << source_ctx.folding_factor)).trailing_zeros()); + let target_list_size = list_size_estimate( + spec.decoding_regime, + target_log_degree, + f64::from(target_ctx.log_inv_rate), + ); + let (source, t_ood) = solve_t_ood::( &spec, &source_ctx, - OodSampleBudget::ZERO, - ); - let mut target = irs_solver::solve::>( + target_list_size, + Some(f64::from(source_ctx.log_inv_rate)), + 0, + ) + .unwrap(); + let target = irs_solver::solve::>( &spec, &target_ctx, - OodSampleBudget::ZERO, + OodSampleBudget::new(t_ood), ); - for _ in 0..SMOKE_FIXED_POINT_MAX_ITER { - let new_t_ood = compute_t_ood( - &spec, - &source, - target.list_size(), - Some(SMOKE_C_ZK_LIST_SIZE), - t_ood, - ); - if new_t_ood == t_ood { - break; - } - t_ood = new_t_ood; - source = irs_solver::solve(&spec, &source_ctx, OodSampleBudget::new(t_ood)); - target = irs_solver::solve(&spec, &target_ctx, OodSampleBudget::new(t_ood)); - } let mask_oracle = MaskOracleInfo { c_zk_list_size: ListSize::new(SMOKE_C_ZK_LIST_SIZE), diff --git a/src/protocols/params/derive.rs b/src/protocols/params/derive.rs index c16c982d..66c10187 100644 --- a/src/protocols/params/derive.rs +++ b/src/protocols/params/derive.rs @@ -10,16 +10,18 @@ use crate::{ fields::FieldWithSize, }, protocols::{ - irs_commit::{self, Config as IrsConfig}, + irs_commit::Config as IrsConfig, params::{ - basecase as basecase_solver, code_switch as code_switch_solver, + basecase as basecase_solver, + bounds::usize_to_f64, + code_switch as code_switch_solver, error::{DeriveError, FixedPointLoop, Pow}, irs_commit as irs_solver, mask_proximity as mask_proximity_solver, protocol_config::{MaskOracleConfig, ProtocolConfig, RoundConfig, RoundMode}, regime::list_size_estimate, spec::{ - LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, RoundContext, SecuritySpec, - TuningSpec, ZkSpec, + DecodingRegime, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, + RoundContext, SecuritySpec, TuningSpec, ZkSpec, }, sumcheck as sumcheck_solver, }, @@ -264,51 +266,20 @@ pub(super) const fn compute_l_zk( MaskCodeMessageLen::new((source.mask_length() + t_ood).next_power_of_two()) } -/// One application of the Lemma 9.9 OOD step. ZK: `degree = ℓ + ℓ_zk(t_ood)` -/// with `ℓ_zk = next_pow2(source.mask_length() + t_ood)`. Standard: -/// `degree = ℓ`. +/// Per-round `(source, t_ood)` from a linear search over `t_ood`. /// -/// Floored at `1`: Lemma 9.9's OOD term vanishes under `Unique` decoding -/// (`|Λ| = 1`), but Construction 9.7's code-switch still needs at least one -/// OOD point to bind the witness polynomial. -pub(super) fn compute_t_ood( - spec: &SecuritySpec, - source: &IrsConfig, - target_list_size: f64, - c_zk_list_size: Option, - t_ood: usize, -) -> usize { - let security_target = f64::from(spec.protocol_security_target_bits()); - let field_bits = M::Target::field_size_bits(); - let combined_list_size = target_list_size * c_zk_list_size.unwrap_or(1.0); - let message_length = source.message_length(); - - let degree = if c_zk_list_size.is_some() { - let l_zk = (source.mask_length() + t_ood).next_power_of_two(); - message_length + l_zk - } else { - message_length - }; - - let soundness_t_ood = irs_commit::num_ood_samples( - spec.decoding_regime, - security_target, - field_bits, - combined_list_size, - degree, - ); - soundness_t_ood.max(1) -} - -/// Solves the per-round `t_ood` fixed-point and the source IRS together. +/// Mirrors Plonky3's `determine_ood_samples`: under `Unique`, OOD contributes +/// no soundness (`|Λ| = 1`), so the search is skipped and `t_ood = 1` is +/// returned for the Construction 9.7 binding floor. Under `Johnson`/`Capacity`, +/// linearly searches `t_ood = 1..=T_OOD_MAX_ITER` and returns the smallest +/// value where the OOD security bound `t · (|F| − log d) − 2·log|Λ_combined| +/// + 1` meets `protocol_security_target_bits` (STIR Lemma 4.5 / Plonky3 +/// `ood_error`). /// -/// Convergence: `Φ(t) = num_ood_samples(ℓ + next_pow2(in_domain + 2·t))` is -/// monotone non-decreasing on ℕ (`in_domain` depends only on the requested -/// rate; `next_pow2` and `num_ood_samples` are monotone; under `Capacity` the -/// `c_zk_list_size(t)` factor is monotone too) and bounded above, so Kleene -/// iteration from `t = 0` converges to the least fixed point in finitely many -/// steps. Standard mode (`c_zk_log_inv_rate = None`) has `Φ` constant in -/// `t`, so one application suffices. +/// The bound is monotone-increasing in `t` for `|F| ≫ log d` (always the case +/// here), so the first match is the minimum. Source is rebuilt per iteration +/// because `source.mask_length()` depends on `t_ood` in ZK (Lemma 9.5 ii); +/// the rebuild is cheap (struct fields only). pub(super) fn solve_t_ood( spec: &SecuritySpec, src_ctx: &RoundContext, @@ -316,30 +287,44 @@ pub(super) fn solve_t_ood( c_zk_log_inv_rate: Option, round_index: usize, ) -> Result<(IrsConfig, usize), DeriveError> { - let mut source: IrsConfig = irs_solver::solve(spec, src_ctx, OodSampleBudget::ZERO); + if matches!(spec.decoding_regime, DecodingRegime::Unique) { + let source = irs_solver::solve(spec, src_ctx, OodSampleBudget::new(1)); + return Ok((source, 1)); + } - let Some(c_zk_log_inv_rate) = c_zk_log_inv_rate else { - let t_ood = compute_t_ood(spec, &source, target_list_size, None, 0); - return Ok((source, t_ood)); - }; + let security_target = f64::from(spec.protocol_security_target_bits()); + let field_bits = M::Target::field_size_bits(); - let mut t_ood = 0; - for _ in 0..T_OOD_MAX_ITER { - // Under `Capacity`, c_zk's list size depends on its message length - // ℓ_zk(t), so recompute per iteration. Under `Johnson`/`Unique` the - // result is t-independent — the recomputation is a no-op. - let l_zk = (source.mask_length() + t_ood).next_power_of_two(); - let c_zk_list_size = list_size_estimate( - spec.decoding_regime, - (l_zk as f64).log2(), - c_zk_log_inv_rate, + for t_ood in 1..=T_OOD_MAX_ITER { + let source: IrsConfig = irs_solver::solve(spec, src_ctx, OodSampleBudget::new(t_ood)); + + // `degree` and `log_combined_list` depend on `t_ood` in ZK via ℓ_zk; + // Standard collapses to `degree = ℓ` and `combined = target.list_size`. + let (log_degree, log_combined_list) = c_zk_log_inv_rate.map_or_else( + || { + ( + usize_to_f64(source.message_length()).log2(), + target_list_size.log2(), + ) + }, + |c_zk_rate| { + let l_zk = (source.mask_length() + t_ood).next_power_of_two(); + let c_zk_list = + list_size_estimate(spec.decoding_regime, usize_to_f64(l_zk).log2(), c_zk_rate); + ( + usize_to_f64(source.message_length() + l_zk).log2(), + (target_list_size * c_zk_list).log2(), + ) + }, ); - let new_t_ood = compute_t_ood(spec, &source, target_list_size, Some(c_zk_list_size), t_ood); - if new_t_ood == t_ood { + + // OOD security at MCA arity 2 (STIR Lemma 4.5 / Plonky3 `ood_error`): + // bits = t · (|F| − log d) − 2·log|Λ_combined| + 1 + let ood = usize_to_f64(t_ood); + let bits = ood * (field_bits - log_degree) - 2.0 * log_combined_list + 1.0; + if bits >= security_target { return Ok((source, t_ood)); } - t_ood = new_t_ood; - source = irs_solver::solve(spec, src_ctx, OodSampleBudget::new(t_ood)); } Err(DeriveError::FixedPointDidNotConverge { round_index, @@ -529,9 +514,9 @@ mod tests { )); } - /// Lemma 9.9 fixed-point: every ZK round needs at least one OOD challenge. + /// Construction 9.7: every ZK round needs at least one OOD challenge. #[test] - fn compute_t_ood_nonzero_in_zk() { + fn t_ood_nonzero_in_zk() { let spec = test_spec(Mode::ZeroKnowledge); let plan = ProtocolConfig::::derive( spec, diff --git a/src/protocols/params/regime.rs b/src/protocols/params/regime.rs index e3c6aefc..09b9d3d9 100644 --- a/src/protocols/params/regime.rs +++ b/src/protocols/params/regime.rs @@ -96,6 +96,28 @@ impl DecodingRegimeParams { one_minus_delta.log2() } + /// Bits of security delivered by `ood_samples` OOD challenges on a code + /// of given `log_degree` and `log_inv_rate` at MCA arity 2. + /// + /// Mirrors Plonky3's `ood_error` / STIR Lemma 4.5: the error is + /// `(L choose 2) · ((d − 1)/|F|)^{ood_samples}`, giving security + /// `ood · (|F| − log d) − 2·log|Λ| + 1` bits. Returns `0` under + /// `Unique` — OOD contributes no soundness when `|Λ| = 1`. + pub fn ood_security_bits( + self, + log_degree: f64, + log_inv_rate: f64, + field_bits: f64, + ood_samples: usize, + ) -> f64 { + if self.is_unique() { + return 0.0; + } + let log_list = self.list_size_log2(log_degree, log_inv_rate); + let ood = usize_to_f64(ood_samples); + ood * (field_bits - log_degree) - 2.0 * log_list + 1.0 + } + /// `log₂ ε_mca(C, δ)` for the per-step proximity-gaps error (bare, no /// arity factor — callers apply their own). /// @@ -282,6 +304,33 @@ mod tests { assert_close(got, expected); } + /// `ood_security_bits` mirrors Plonky3 `ood_error`: + /// `t · (|F| − log d) − 2·log|Λ| + 1`. Returns 0 under Unique. + #[test] + fn ood_security_bits_formula() { + const LOG_DEGREE: f64 = 6.0; + const LOG_INV_RATE: f64 = 2.0; + const FIELD_BITS: f64 = 64.0; + const OOD: usize = 3; + + // Unique → 0 (no soundness from OOD). + let unique = DecodingRegimeParams::Unique.ood_security_bits( + LOG_DEGREE, + LOG_INV_RATE, + FIELD_BITS, + OOD, + ); + assert_close(unique, 0.0); + + // Johnson at canonical slack: list_size matches the formula in + // `list_size_log2_johnson_formula`. + let slack = 2_f64.powf(-LOG_INV_RATE).sqrt() / 20.0; + let got = johnson(slack).ood_security_bits(LOG_DEGREE, LOG_INV_RATE, FIELD_BITS, OOD); + let log_list = johnson(slack).list_size_log2(LOG_DEGREE, LOG_INV_RATE); + let expected = (OOD as f64) * (FIELD_BITS - LOG_DEGREE) - 2.0 * log_list + 1.0; + assert_close(got, expected); + } + /// MCA fixture — `message_length = 16 = 2^4` and `log_inv_rate = 2` give /// exact `log2(k) = 4`. `field_bits = 64.0` for Field64. const MCA_MESSAGE_LENGTH: usize = 16; From bd7ea191831bb93338e132fc27225cb534ebafe9 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Sat, 23 May 2026 21:26:24 +0530 Subject: [PATCH 24/31] feat : t_ood = 1 for unique decoding clarification, tradeoff slight proof size increase but code simple --- src/protocols/params/code_switch.rs | 35 ++++- src/protocols/params/derive.rs | 188 ++++++++++++++++-------- src/protocols/params/protocol_config.rs | 48 +++--- src/protocols/params/regime.rs | 14 +- src/protocols/params/test_utils.rs | 10 +- 5 files changed, 200 insertions(+), 95 deletions(-) diff --git a/src/protocols/params/code_switch.rs b/src/protocols/params/code_switch.rs index 6886426b..4831b224 100644 --- a/src/protocols/params/code_switch.rs +++ b/src/protocols/params/code_switch.rs @@ -109,7 +109,7 @@ pub fn analytic_error_bits( } /// Number of `(r ‖ s)` mask polynomials code-switch contributes to C_zk per -/// round. Mirrors [`super::sumcheck::masks_required`]. +/// round. pub const fn masks_required() -> usize { 1 } @@ -331,8 +331,8 @@ mod tests { } } - /// Shared shape for the `M::Source ≠ M::Target` smoke tests. - /// `target_ctx` mirrors the planner's per-round chaining. + /// Shared shape for the `M::Source ≠ M::Target` smoke tests. `target_ctx` + /// uses the same per-round chaining the planner does. fn non_identity_smoke_ctxs() -> (RoundContext, RoundContext) { const SOURCE_VECTOR_SIZE: usize = 64; const SOURCE_LOG_INV_RATE: u32 = 1; @@ -351,6 +351,33 @@ mod tests { (source_ctx, target_ctx) } + /// `solve_zk` asserts `ℓ_zk ≥ source.mask_length() + t_ood` (Theorem 9.6 + /// witness sizing). Build a self-consistent `(source, target, t_ood)` + /// and pass a deliberately-too-small `l_zk = 1` to trip the precondition. + #[test] + #[should_panic(expected = "violates Theorem 9.6")] + fn solve_zk_rejects_l_zk_below_r_plus_t_ood() { + const TOO_SMALL_L_ZK: usize = 1; + + let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); + let (source, target, t_ood) = build_round_io::( + &spec, + FORMULA_LOG_INV_RATE, + FORMULA_FOLDING_FACTOR, + FORMULA_NUM_VARS, + Some(FORMULA_LOG_INV_RATE), + ); + // `source.mask_length() + t_ood ≥ 1 + 1 > TOO_SMALL_L_ZK` in ZK, + // so the assert in solve_zk fires. + assert!(source.mask_length() + t_ood > TOO_SMALL_L_ZK); + + let mask_oracle = MaskOracleInfo { + c_zk_list_size: ListSize::new(SMOKE_C_ZK_LIST_SIZE), + l_zk: MaskCodeMessageLen::new(TOO_SMALL_L_ZK), + }; + let _ = solve_zk(&spec, source, target, t_ood, mask_oracle, 0); + } + /// Smoke test: `M::Source ≠ M::Target`, Standard mode. #[test] fn solve_works_with_basefield_embedding_standard() { @@ -378,7 +405,7 @@ mod tests { } /// Placeholder mask-oracle list size for the smoke test — pow2 so `log2` - /// is exact and matches `analytic_error_zk_formula`'s fixture. + /// is exact. const SMOKE_C_ZK_LIST_SIZE: f64 = 4.0; /// Smoke test: `M::Source ≠ M::Target`, ZK mode. diff --git a/src/protocols/params/derive.rs b/src/protocols/params/derive.rs index 66c10187..ac13ea16 100644 --- a/src/protocols/params/derive.rs +++ b/src/protocols/params/derive.rs @@ -268,18 +268,19 @@ pub(super) const fn compute_l_zk( /// Per-round `(source, t_ood)` from a linear search over `t_ood`. /// -/// Mirrors Plonky3's `determine_ood_samples`: under `Unique`, OOD contributes -/// no soundness (`|Λ| = 1`), so the search is skipped and `t_ood = 1` is -/// returned for the Construction 9.7 binding floor. Under `Johnson`/`Capacity`, -/// linearly searches `t_ood = 1..=T_OOD_MAX_ITER` and returns the smallest -/// value where the OOD security bound `t · (|F| − log d) − 2·log|Λ_combined| -/// + 1` meets `protocol_security_target_bits` (STIR Lemma 4.5 / Plonky3 -/// `ood_error`). +/// Under `Unique`, OOD contributes no soundness (`|Λ| = 1` ⇒ `(L choose 2) = 0`). +/// The short-circuit returns `t_ood = 1` — the Construction 9.7 protocol-layer +/// minimum, since [`crate::protocols::code_switch::Config::new`] asserts +/// `out_domain_samples ≥ 1` (Steps 2-3 always execute). /// -/// The bound is monotone-increasing in `t` for `|F| ≫ log d` (always the case -/// here), so the first match is the minimum. Source is rebuilt per iteration -/// because `source.mask_length()` depends on `t_ood` in ZK (Lemma 9.5 ii); -/// the rebuild is cheap (struct fields only). +/// Under `Johnson`, searches `t_ood = 1..=T_OOD_MAX_ITER` for the smallest +/// value where the OOD security bound (STIR Lemma 4.5) +/// `t · (|F| − log d) − 2·log|Λ_combined| + 1` meets +/// `protocol_security_target_bits`. The bound is monotone-increasing in `t` +/// for `|F| ≫ log d` (always the case here), so the first match is the +/// minimum. Source is rebuilt per iteration because `source.mask_length()` +/// depends on `t_ood` in ZK (Lemma 9.5 ii); the rebuild is cheap (struct +/// fields only). pub(super) fn solve_t_ood( spec: &SecuritySpec, src_ctx: &RoundContext, @@ -288,6 +289,12 @@ pub(super) fn solve_t_ood( round_index: usize, ) -> Result<(IrsConfig, usize), DeriveError> { if matches!(spec.decoding_regime, DecodingRegime::Unique) { + // The short-circuit is not an optimization — the linear-search formula + // uses `log(L·(L−1)/2) ≈ 2·log L − 1`, which is exact for `L ≥ 2` but + // *underestimates* security by `+∞` when `L = 1` (the true `(L choose 2)` + // is 0, not L²/2). Letting the loop run would falsely demand `t_ood > 1` + // at high security targets even though OOD provides infinite soundness + // headroom under Unique. Pin `t_ood = 1` directly. let source = irs_solver::solve(spec, src_ctx, OodSampleBudget::new(1)); return Ok((source, 1)); } @@ -318,8 +325,10 @@ pub(super) fn solve_t_ood( }, ); - // OOD security at MCA arity 2 (STIR Lemma 4.5 / Plonky3 `ood_error`): - // bits = t · (|F| − log d) − 2·log|Λ_combined| + 1 + // STIR Lemma 4.5 (single-MCA OOD): + // bits = t · (|F| − log d) − log(L · (L − 1) / 2) + // Approximate `log(L · (L − 1) / 2) ≈ 2·log L − 1` (exact-ish for L ≥ 2; + // the L = 1 case is handled by the Unique short-circuit above). let ood = usize_to_f64(t_ood); let bits = ood * (field_bits - log_degree) - 2.0 * log_combined_list + 1.0; if bits >= security_target { @@ -514,10 +523,15 @@ mod tests { )); } - /// Construction 9.7: every ZK round needs at least one OOD challenge. + /// Johnson + ZK: each round runs a non-trivial OOD challenge to amplify + /// the list-decoding soundness gap (Lemma 9.9). `solve_t_ood`'s linear + /// search lands at the smallest `t_ood` clearing the security target. #[test] - fn t_ood_nonzero_in_zk() { - let spec = test_spec(Mode::ZeroKnowledge); + fn t_ood_nonzero_in_johnson_zk() { + let spec = SecuritySpec { + decoding_regime: DecodingRegime::Johnson, + ..test_spec(Mode::ZeroKnowledge) + }; let plan = ProtocolConfig::::derive( spec, tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), @@ -531,6 +545,55 @@ mod tests { } } + /// Unique + ZK: OOD contributes no soundness (`|Λ| = 1`), but + /// `protocols::code_switch::Config::new` requires `out_domain_samples ≥ 1` + /// to run Construction 9.7 Steps 2-3. `solve_t_ood` pins `t_ood = 1` + /// exactly — sharper than the Johnson-side `≥ 1` invariant. + #[test] + fn t_ood_pinned_to_one_in_unique_zk() { + let spec = SecuritySpec { + decoding_regime: DecodingRegime::Unique, + ..test_spec(Mode::ZeroKnowledge) + }; + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + for r in plan.rounds() { + let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode() else { + panic!("expected ZK round") + }; + assert_eq!(t_ood.get(), 1); + } + } + + /// Under Unique decoding, the C_zk mask oracle still carries the full + /// `2 · (k + 1)` columns — `k` sumcheck masks (Lemma 6.4) plus the + /// `(r ‖ s)` code-switch mask (Construction 9.7). With `t_ood = 1`, the + /// `s`-tail has length `ℓ_zk − r ≥ 1` and supports the + /// Vandermonde-surjectivity ZK argument (bounds doc §5.3 / Bound 3). + /// Pins the shape so an accidental "drop code-switch mask under Unique" + /// optimization can't slip in unnoticed. + #[test] + fn c_zk_keeps_code_switch_mask_under_unique() { + let spec = SecuritySpec { + decoding_regime: DecodingRegime::Unique, + ..test_spec(Mode::ZeroKnowledge) + }; + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + for r in plan.rounds() { + let mask_oracle = r.mask_oracle().expect("ZK round has a mask oracle"); + let k = r.code_switch().source.interleaving_depth.trailing_zeros() as usize; + let expected_num_masks = k + 1; // k sumcheck + 1 code-switch + assert_eq!(mask_oracle.c_zk().num_vectors, 2 * expected_num_masks); + } + } + #[test] fn analytic_bits_finite_and_positive_standard() { let spec = test_spec(Mode::Standard); @@ -646,15 +709,8 @@ mod tests { /// Derived plans must satisfy their own `pow_budget`. #[test] fn check_pow_bits_passes_on_derived_plan() { - let spec = SecuritySpec { - mode: Mode::ZeroKnowledge, - decoding_regime: DecodingRegime::Johnson, - target_security_bits: PLAN_FIXTURE_TARGET_BITS, - pow_budget: PowBudget::per_slot(LOOSE_POW_BUDGET_BITS), - hash_id: hash::BLAKE3, - }; let plan = ProtocolConfig::::derive( - spec, + test_spec(Mode::ZeroKnowledge), tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), ) .unwrap(); @@ -672,11 +728,8 @@ mod tests { use crate::{bits::Bits, protocols::proof_of_work::Config as PowConfig}; const MODERATE_POW_BUDGET_BITS: u32 = 30; let spec = SecuritySpec { - mode: Mode::ZeroKnowledge, - decoding_regime: DecodingRegime::Johnson, - target_security_bits: PLAN_FIXTURE_TARGET_BITS, pow_budget: PowBudget::per_slot(MODERATE_POW_BUDGET_BITS), - hash_id: hash::BLAKE3, + ..test_spec(Mode::ZeroKnowledge) }; let mut plan = ProtocolConfig::::derive( spec, @@ -689,6 +742,45 @@ mod tests { assert!(!plan.check_pow_bits()); } + /// `validate_round_chaining` trips when round `i`'s target `vector_size` + /// no longer matches round `i+1`'s source. Covers the adjacent-rounds + /// `windows(2)` branch — distinct from the basecase branch, which is + /// covered by `validate_round_chaining_detects_basecase_mismatch`. + #[test] + fn validate_round_chaining_detects_adjacent_round_mismatch() { + let spec = test_spec(Mode::ZeroKnowledge); + let mut plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + let n = plan.rounds().len(); + assert!(n >= 2, "need ≥ 2 rounds to break a mid-chain link"); + assert!(plan.check_all_invariants(), "fresh plan must validate"); + + // Round 0's natural target.vector_size is some power of 2; bumping + // it to a value the next round's source can't match (the source + // still carries the originally-derived size) breaks the chain. + let bad_size = plan.rounds()[0].code_switch().target.vector_size + 1; + plan.corrupt_round_target_vector_size_for_test(0, bad_size); + + let err = plan + .validate_round_chaining() + .expect_err("adjacent-round mismatch must trip the chain check"); + assert!( + matches!( + err, + DeriveError::RoundChainBroken { + from: crate::protocols::params::error::ChainSource::Round(0), + to: crate::protocols::params::error::ChainTarget::NextRound(1), + .. + } + ), + "got {err:?}", + ); + assert!(!plan.check_all_invariants()); + } + /// `validate_round_chaining` trips when the basecase no longer chains /// to the (new) last round after the tail is dropped. Multi-round plan /// is required so dropping the last leaves at least one round behind. @@ -728,11 +820,8 @@ mod tests { fn derive_reports_pow_ungrindable() { const UNREACHABLE_TARGET_BITS: u32 = 200; let spec = SecuritySpec { - mode: Mode::Standard, - decoding_regime: DecodingRegime::Johnson, target_security_bits: UNREACHABLE_TARGET_BITS, - pow_budget: PowBudget::per_slot(LOOSE_POW_BUDGET_BITS), - hash_id: hash::BLAKE3, + ..test_spec(Mode::Standard) }; let err = ProtocolConfig::::derive( spec, @@ -752,11 +841,8 @@ mod tests { fn derive_reports_pow_budget_exceeded() { const TIGHT_MAX_POW: u32 = 5; let spec = SecuritySpec { - mode: Mode::ZeroKnowledge, - decoding_regime: DecodingRegime::Johnson, - target_security_bits: PLAN_FIXTURE_TARGET_BITS, pow_budget: PowBudget::per_slot(TIGHT_MAX_POW), - hash_id: hash::BLAKE3, + ..test_spec(Mode::ZeroKnowledge) }; let err = ProtocolConfig::::derive( spec, @@ -775,11 +861,8 @@ mod tests { #[test] fn derive_threads_unique_decoding_standard() { let spec = SecuritySpec { - mode: Mode::Standard, decoding_regime: DecodingRegime::Unique, - target_security_bits: PLAN_FIXTURE_TARGET_BITS, - pow_budget: PowBudget::per_slot(LOOSE_POW_BUDGET_BITS), - hash_id: hash::BLAKE3, + ..test_spec(Mode::Standard) }; let plan = ProtocolConfig::::derive( spec, @@ -794,11 +877,8 @@ mod tests { #[test] fn derive_threads_unique_decoding_zk() { let spec = SecuritySpec { - mode: Mode::ZeroKnowledge, decoding_regime: DecodingRegime::Unique, - target_security_bits: PLAN_FIXTURE_TARGET_BITS, - pow_budget: PowBudget::per_slot(LOOSE_POW_BUDGET_BITS), - hash_id: hash::BLAKE3, + ..test_spec(Mode::ZeroKnowledge) }; let plan = ProtocolConfig::::derive( spec, @@ -815,11 +895,8 @@ mod tests { #[test] fn derive_multi_round_unique_decoding_succeeds() { let spec = SecuritySpec { - mode: Mode::Standard, decoding_regime: DecodingRegime::Unique, - target_security_bits: PLAN_FIXTURE_TARGET_BITS, - pow_budget: PowBudget::per_slot(LOOSE_POW_BUDGET_BITS), - hash_id: hash::BLAKE3, + ..test_spec(Mode::Standard) }; let plan = ProtocolConfig::::derive( spec, @@ -841,11 +918,8 @@ mod tests { #[test] fn derive_multi_round_unique_decoding_zk_succeeds() { let spec = SecuritySpec { - mode: Mode::ZeroKnowledge, decoding_regime: DecodingRegime::Unique, - target_security_bits: PLAN_FIXTURE_TARGET_BITS, - pow_budget: PowBudget::per_slot(LOOSE_POW_BUDGET_BITS), - hash_id: hash::BLAKE3, + ..test_spec(Mode::ZeroKnowledge) }; let plan = ProtocolConfig::::derive( spec, @@ -867,11 +941,8 @@ mod tests { #[test] fn derive_multi_round_capacity_decoding_succeeds() { let spec = SecuritySpec { - mode: Mode::Standard, decoding_regime: DecodingRegime::Capacity, - target_security_bits: PLAN_FIXTURE_TARGET_BITS, - pow_budget: PowBudget::per_slot(LOOSE_POW_BUDGET_BITS), - hash_id: hash::BLAKE3, + ..test_spec(Mode::Standard) }; let plan = ProtocolConfig::::derive( spec, @@ -889,11 +960,8 @@ mod tests { #[test] fn derive_multi_round_capacity_decoding_zk_succeeds() { let spec = SecuritySpec { - mode: Mode::ZeroKnowledge, decoding_regime: DecodingRegime::Capacity, - target_security_bits: PLAN_FIXTURE_TARGET_BITS, - pow_budget: PowBudget::per_slot(LOOSE_POW_BUDGET_BITS), - hash_id: hash::BLAKE3, + ..test_spec(Mode::ZeroKnowledge) }; let plan = ProtocolConfig::::derive( spec, diff --git a/src/protocols/params/protocol_config.rs b/src/protocols/params/protocol_config.rs index 7022e311..6dfa9273 100644 --- a/src/protocols/params/protocol_config.rs +++ b/src/protocols/params/protocol_config.rs @@ -74,23 +74,6 @@ impl ProtocolConfig { &self.basecase } - /// `#[cfg(test)]` escape hatch: lets the negative test in - /// `derive::tests` inject an over-budget basecase PoW slot so that - /// `validate_pow_budget` can be exercised on a corrupted plan. - /// Not for production use — there is no equivalent on the public API. - #[cfg(test)] - pub(crate) const fn override_basecase_pow_for_test(&mut self, pow: PowConfig) { - self.basecase.pow = pow; - } - - /// `#[cfg(test)]` escape hatch: lets chain-broken tests in - /// `derive::tests` drop the tail of `rounds` so the basecase's chained - /// `vector_size` no longer matches the new last round. - #[cfg(test)] - pub(crate) fn truncate_rounds_for_test(&mut self, len: usize) { - self.rounds.truncate(len); - } - /// `true` if every PoW slot's difficulty fits within `security.pow_budget`. /// Boolean form of [`Self::validate_pow_budget`]. pub fn check_pow_bits(&self) -> bool { @@ -230,6 +213,37 @@ impl ProtocolConfig { } } +/// Test-only mutators. Grouped here so the production `impl` block above +/// reads as the public API surface and these escape hatches aren't easily +/// mistaken for it. Each one supports a specific negative test in +/// `derive::tests`; there is no equivalent on the public API. +#[cfg(test)] +impl ProtocolConfig { + /// Inject an over-budget basecase PoW slot so `validate_pow_budget` can + /// be exercised on a corrupted plan. + pub(crate) const fn override_basecase_pow_for_test(&mut self, pow: PowConfig) { + self.basecase.pow = pow; + } + + /// Drop the tail of `rounds` so the basecase's chained `vector_size` no + /// longer matches the (new) last round — trips the basecase branch of + /// `validate_round_chaining`. + pub(crate) fn truncate_rounds_for_test(&mut self, len: usize) { + self.rounds.truncate(len); + } + + /// Overwrite a round's code-switch target `vector_size` so the next + /// round's source no longer chains — trips the adjacent `windows(2)` + /// branch of `validate_round_chaining`, which truncation cannot reach. + pub(crate) fn corrupt_round_target_vector_size_for_test( + &mut self, + round_idx: usize, + new_size: usize, + ) { + self.rounds[round_idx].code_switch.target.vector_size = new_size; + } +} + #[derive(Clone, Debug)] pub struct RoundConfig { round_index: usize, diff --git a/src/protocols/params/regime.rs b/src/protocols/params/regime.rs index 09b9d3d9..d1bc0a30 100644 --- a/src/protocols/params/regime.rs +++ b/src/protocols/params/regime.rs @@ -12,7 +12,6 @@ //! (`O(n/η^5)`, m=10 at canonical slack) over BCIKS '20. //! - Capacity bound follows STIR Conjecture 5.6: `(1 − ρ − η, d/(ρ·η))`-list //! decodability for RS codes. -//! - Aligned with the Plonky3 `SecurityAssumption` parametrization. use std::f64::consts::LOG2_10; @@ -99,10 +98,9 @@ impl DecodingRegimeParams { /// Bits of security delivered by `ood_samples` OOD challenges on a code /// of given `log_degree` and `log_inv_rate` at MCA arity 2. /// - /// Mirrors Plonky3's `ood_error` / STIR Lemma 4.5: the error is - /// `(L choose 2) · ((d − 1)/|F|)^{ood_samples}`, giving security - /// `ood · (|F| − log d) − 2·log|Λ| + 1` bits. Returns `0` under - /// `Unique` — OOD contributes no soundness when `|Λ| = 1`. + /// STIR Lemma 4.5: the error is `(L choose 2) · ((d − 1)/|F|)^{ood_samples}`, + /// giving security `ood · (|F| − log d) − 2·log|Λ| + 1` bits. Returns `0` + /// under `Unique` — OOD contributes no soundness when `|Λ| = 1`. pub fn ood_security_bits( self, log_degree: f64, @@ -304,8 +302,8 @@ mod tests { assert_close(got, expected); } - /// `ood_security_bits` mirrors Plonky3 `ood_error`: - /// `t · (|F| − log d) − 2·log|Λ| + 1`. Returns 0 under Unique. + /// `ood_security_bits = t · (|F| − log d) − 2·log|Λ| + 1`. Returns 0 + /// under Unique. #[test] fn ood_security_bits_formula() { const LOG_DEGREE: f64 = 6.0; @@ -322,8 +320,6 @@ mod tests { ); assert_close(unique, 0.0); - // Johnson at canonical slack: list_size matches the formula in - // `list_size_log2_johnson_formula`. let slack = 2_f64.powf(-LOG_INV_RATE).sqrt() / 20.0; let got = johnson(slack).ood_security_bits(LOG_DEGREE, LOG_INV_RATE, FIELD_BITS, OOD); let log_list = johnson(slack).list_size_log2(LOG_DEGREE, LOG_INV_RATE); diff --git a/src/protocols/params/test_utils.rs b/src/protocols/params/test_utils.rs index cebbe080..cf6d8ea0 100644 --- a/src/protocols/params/test_utils.rs +++ b/src/protocols/params/test_utils.rs @@ -161,11 +161,11 @@ pub fn build_test_c_zk( /// Builds a self-consistent `(source, target, t_ood)` triplet matching the /// per-round shape that `code_switch::solve` expects. /// -/// `t_ood` is solved against the rate-only `list_size_estimate(...)`, -/// mirroring `derive::solve_t_ood`. Using `target.list_size()` here instead -/// would couple `t_ood` to the target's effective rate (which itself depends -/// on `t_ood` via the mask), producing a non-monotone oscillation once the -/// mask is tight (Lemma 9.5 part ii) rather than pow2-padded. +/// `t_ood` is solved against the rate-only `list_size_estimate(...)` rather +/// than `target.list_size()`: the latter reads the target's effective rate +/// (which itself depends on `t_ood` via the mask), producing a non-monotone +/// oscillation once the mask is tight (Lemma 9.5 part ii) rather than +/// pow2-padded. pub fn build_round_io( spec: &SecuritySpec, log_inv_rate: u32, From 5c40c06ad81c478e24f7c020afc7c450d3ded263 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Mon, 25 May 2026 00:35:38 +0530 Subject: [PATCH 25/31] feat : updated branching structure --- src/protocols/code_switch.rs | 3 +- src/protocols/params/basecase.rs | 18 +- src/protocols/params/bounds.rs | 2 +- src/protocols/params/code_switch.rs | 153 ++++++---- src/protocols/params/derive.rs | 387 ++++++++++++++---------- src/protocols/params/error.rs | 28 +- src/protocols/params/irs_commit.rs | 27 +- src/protocols/params/mask_proximity.rs | 2 +- src/protocols/params/mod.rs | 18 +- src/protocols/params/protocol_config.rs | 12 +- src/protocols/params/regime.rs | 35 ++- src/protocols/params/spec.rs | 3 +- src/protocols/params/sumcheck.rs | 79 +++-- src/protocols/params/test_utils.rs | 38 ++- 14 files changed, 446 insertions(+), 359 deletions(-) diff --git a/src/protocols/code_switch.rs b/src/protocols/code_switch.rs index c30fe9a2..99921ef4 100644 --- a/src/protocols/code_switch.rs +++ b/src/protocols/code_switch.rs @@ -115,7 +115,8 @@ impl Config { "sampled randomness (s) length must cover all out-of-domain sample requests" ); // t' = target in-domain queries + OOD queries (Construction 9.7 step 4). - // Lemma 9.5 perfect-ZK: t' ≤ r' = target.mask_length. + // Definition 3.16: a t'-query ZK encoding requires r' ≥ t'; here + // r' = target.mask_length. assert!( target_config.mask_length() >= target_config.in_domain_samples + out_domain_samples, "target encoder violates t' ≤ r': queries must be covered by target mask" diff --git a/src/protocols/params/basecase.rs b/src/protocols/params/basecase.rs index 3875de04..9b875079 100644 --- a/src/protocols/params/basecase.rs +++ b/src/protocols/params/basecase.rs @@ -10,9 +10,9 @@ use crate::{ irs_commit::Config as IrsConfig, params::{ error::{grind_to_at, DeriveError, Pow}, - irs_commit as irs_solver, + irs_commit as irs_params, spec::{Mode as SpecMode, OodSampleBudget, RoundContext, SecuritySpec}, - sumcheck as sumcheck_solver, + sumcheck as sumcheck_params, }, proof_of_work::Config as PowConfig, sumcheck::{self, Config as SumcheckConfig}, @@ -33,11 +33,11 @@ pub fn solve( log_inv_rate, folding_factor: 0, }; - let commit = irs_solver::solve(spec, &ctx, OodSampleBudget::ZERO); + let commit = irs_params::solve(spec, &ctx, OodSampleBudget::ZERO); let sumcheck_pow = grind_to_at( spec, - sumcheck_solver::analytic_error_bits(&commit, None), + sumcheck_params::analytic_error_bits(&commit, None), Pow::BasecaseSumcheck, )?; let sumcheck = SumcheckConfig::new( @@ -79,7 +79,7 @@ impl BasecaseConfig { /// The γ-slot only contributes in ZK mode; Standard collapses to the /// sumcheck term. pub fn analytic_bits(&self) -> Bits { - let sumcheck_term = f64::from(sumcheck_solver::analytic_error_bits(&self.commit, None)); + let sumcheck_term = f64::from(sumcheck_params::analytic_error_bits(&self.commit, None)); let min_bits = match self.mode { basecase::BasecaseMode::Standard => sumcheck_term, basecase::BasecaseMode::ZeroKnowledge => { @@ -115,7 +115,7 @@ mod tests { #[test] fn analytic_error_formula() { use crate::protocols::params::{ - irs_commit as irs_solver, + irs_commit as irs_params, spec::{Mode, OodSampleBudget, RoundContext}, }; @@ -126,7 +126,7 @@ mod tests { folding_factor: 0, }; let commit: IrsConfig> = - irs_solver::solve(&spec, &ctx, OodSampleBudget::ZERO); + irs_params::solve(&spec, &ctx, OodSampleBudget::ZERO); let got = f64::from(analytic_error_bits(&commit)); let field_bits = TestField::field_size_bits(); @@ -143,7 +143,7 @@ mod tests { #[test] fn analytic_error_uses_eps_mca_when_limiting() { use crate::protocols::params::{ - irs_commit as irs_solver, + irs_commit as irs_params, spec::{Mode, OodSampleBudget, RoundContext}, }; @@ -154,7 +154,7 @@ mod tests { folding_factor: 0, }; let commit: IrsConfig> = - irs_solver::solve(&spec, &ctx, OodSampleBudget::ZERO); + irs_params::solve(&spec, &ctx, OodSampleBudget::ZERO); let field_bits = TestField::field_size_bits(); let log_list = commit.list_size().log2(); diff --git a/src/protocols/params/bounds.rs b/src/protocols/params/bounds.rs index b01bf1ba..c8cad942 100644 --- a/src/protocols/params/bounds.rs +++ b/src/protocols/params/bounds.rs @@ -1,6 +1,6 @@ //! Regime-agnostic analytic primitives shared across the params solvers. //! -//! Regime-specific math (Johnson / Unique branches) lives on +//! Regime-specific math (Unique / Johnson / Capacity branches) lives on //! [`super::regime::DecodingRegimeParams`]. /// `ρ = 2^-log_inv_rate`. Centralized so the rate formula lives in one place. diff --git a/src/protocols/params/code_switch.rs b/src/protocols/params/code_switch.rs index 4831b224..00faff2f 100644 --- a/src/protocols/params/code_switch.rs +++ b/src/protocols/params/code_switch.rs @@ -17,58 +17,54 @@ use crate::{ error::{grind_to_at, DeriveError, Pow}, protocol_config::MaskOracleInfo, spec::SecuritySpec, + SolveMode, }, }, }; -/// Standard-mode code-switch builder. PoW closes the Lemma 9.9 OOD gap to +/// Per-round code-switch builder. PoW closes the Lemma 9.9 OOD gap to /// `spec.target_security_bits`; `t_ood ≥ 1` is required by Construction 9.7. -pub fn solve_standard( +/// In ZK mode, `ℓ_zk ≥ r + t_ood` (Theorem 9.6 witness sizing) is asserted +/// against the mask oracle carried by [`super::SolveMode::ZeroKnowledge`]. +pub fn solve( spec: &SecuritySpec, source: IrsConfig, target: IrsConfig>, t_ood: usize, + mode: SolveMode, round_index: usize, ) -> Result, DeriveError> { - let analytic = analytic_error_bits(&source, &target, t_ood, None); + let (mask_oracle, output_mode) = match mode { + SolveMode::Standard => (None, code_switch::CodeSwitchMode::Standard), + SolveMode::ZeroKnowledge { mask_oracle } => { + let l_zk = mask_oracle.l_zk.get(); + assert!( + l_zk >= source.mask_length().saturating_add(t_ood), + "ℓ_zk ({l_zk}) < r + t_ood ({} + {}) — violates Theorem 9.6 witness sizing", + source.mask_length(), + t_ood, + ); + ( + Some(mask_oracle), + code_switch::CodeSwitchMode::ZeroKnowledge { + message_mask_length: NonZeroUsize::new(l_zk).expect("ℓ_zk > 0"), + }, + ) + } + }; + + let analytic = analytic_error_bits(&source, &target, t_ood, mask_oracle); let pow = grind_to_at(spec, analytic, Pow::RoundCodeSwitch { index: round_index })?; + Ok(CodeSwitchConfig::new( source, target, t_ood, - code_switch::CodeSwitchMode::Standard, + output_mode, pow, )) } -/// ZK code-switch builder. `mask_oracle.l_zk` must have been used to size C_zk -/// (planner's job). PoW closes the Lemma 9.9 OOD gap; `t_ood ≥ 1` and -/// `ℓ_zk ≥ r + t_ood` (Theorem 9.6) are asserted here. -pub fn solve_zk( - spec: &SecuritySpec, - source: IrsConfig, - target: IrsConfig>, - t_ood: usize, - mask_oracle: MaskOracleInfo, - round_index: usize, -) -> Result, DeriveError> { - let l_zk = mask_oracle.l_zk.get(); - assert!( - l_zk >= source.mask_length() + t_ood, - "ℓ_zk ({l_zk}) < r + t_ood ({} + {}) — violates Theorem 9.6 witness sizing", - source.mask_length(), - t_ood, - ); - let mode = code_switch::CodeSwitchMode::ZeroKnowledge { - message_mask_length: NonZeroUsize::new(l_zk).expect("ℓ_zk > 0"), - }; - - let analytic = analytic_error_bits(&source, &target, t_ood, Some(mask_oracle)); - let pow = grind_to_at(spec, analytic, Pow::RoundCodeSwitch { index: round_index })?; - - Ok(CodeSwitchConfig::new(source, target, t_ood, mode, pow)) -} - /// Per-round code-switch soundness in bits: `min` over Lemma 9.9's three RBR /// error slots (OOD, in-domain, combination). `t_ood ≥ 1` per /// [`code_switch::Config::new`]. @@ -88,12 +84,12 @@ pub fn analytic_error_bits( // just `t_ood`), so degree must use the realized `ℓ_zk`, not `r + t_ood`. let degree = mask_oracle.map_or_else( || source.message_length(), - |info| source.message_length() + info.l_zk.get(), + |info| source.message_length().saturating_add(info.l_zk.get()), ); let t_ood_f = usize_to_f64(t_ood); // OOD term — Lemma 9.9, term 1. - let log_degree_minus_1 = usize_to_f64(degree - 1).log2(); + let log_degree_minus_1 = usize_to_f64(degree.saturating_sub(1)).log2(); let log_l_choose_2 = (combined_list * (combined_list - 1.0) / 2.0).log2(); let ood_term = t_ood_f * (field_bits - log_degree_minus_1) - log_l_choose_2; @@ -102,7 +98,8 @@ pub fn analytic_error_bits( // Combination term — Lemma 9.9, term 3 (γ-RLC, bounds doc §5.1). let log_count = - usize_to_f64(t_ood + source.in_domain_samples * source.interleaving_depth).log2(); + usize_to_f64(t_ood.saturating_add(source.in_domain_samples * source.interleaving_depth)) + .log2(); let combination_term = field_bits - log_count - combined_list.log2(); Bits::new(ood_term.min(in_domain_term).min(combination_term).max(0.0)) @@ -120,9 +117,8 @@ mod tests { use super::*; use crate::protocols::params::{ - derive::{compute_l_zk, solve_t_ood}, - irs_commit as irs_solver, - regime::list_size_estimate, + derive::{compute_l_zk, solve_t_ood, OodMode}, + irs_commit as irs_params, spec::{ DecodingRegime, ListSize, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, PowBudget, RoundContext, SecuritySpec, ZkSpec, @@ -285,7 +281,7 @@ mod tests { ) { let (source, target, t_ood) = build_round_io::(&spec, log_inv_rate, folding_factor, num_vars, None); - let config = solve_standard(&spec, source, target, t_ood, 0).unwrap(); + let config = solve(&spec, source, target, t_ood, SolveMode::Standard, 0).unwrap(); prop_assert!(matches!(config.mode, code_switch::CodeSwitchMode::Standard)); prop_assert!(config.out_domain_samples >= 1); } @@ -302,7 +298,7 @@ mod tests { let r = source.mask_length(); let l_zk = compute_l_zk(&source, t_ood); let zk_spec = ZkSpec::try_new(&spec).expect("arb_zk_spec"); - let c_zk = irs_solver::solve_mask_code::( + let c_zk = irs_params::solve_mask_code::( zk_spec, l_zk, r, @@ -313,7 +309,15 @@ mod tests { c_zk_list_size: ListSize::new(c_zk.list_size()), l_zk, }; - let config = solve_zk(&spec, source, target, t_ood, mask_oracle, 0).unwrap(); + let config = solve( + &spec, + source, + target, + t_ood, + SolveMode::ZeroKnowledge { mask_oracle }, + 0, + ) + .unwrap(); prop_assert_eq!(config.message_mask_length(), (r + t_ood).next_power_of_two()); } @@ -326,7 +330,7 @@ mod tests { let (source, target, t_ood) = build_round_io::(&spec, log_inv_rate, folding_factor, num_vars, None); let error = analytic_error_bits(&source, &target, t_ood, None); - let config = solve_standard(&spec, source, target, t_ood, 0).unwrap(); + let config = solve(&spec, source, target, t_ood, SolveMode::Standard, 0).unwrap(); assert_pow_closes_gap(&spec, error, &config.pow); } } @@ -351,9 +355,10 @@ mod tests { (source_ctx, target_ctx) } - /// `solve_zk` asserts `ℓ_zk ≥ source.mask_length() + t_ood` (Theorem 9.6 - /// witness sizing). Build a self-consistent `(source, target, t_ood)` - /// and pass a deliberately-too-small `l_zk = 1` to trip the precondition. + /// `solve` asserts `ℓ_zk ≥ source.mask_length() + t_ood` (Theorem 9.6 + /// witness sizing) under [`SolveMode::ZeroKnowledge`]. Build a + /// self-consistent `(source, target, t_ood)` and pass a too-small + /// `l_zk = 1` to trip the precondition. #[test] #[should_panic(expected = "violates Theorem 9.6")] fn solve_zk_rejects_l_zk_below_r_plus_t_ood() { @@ -368,14 +373,21 @@ mod tests { Some(FORMULA_LOG_INV_RATE), ); // `source.mask_length() + t_ood ≥ 1 + 1 > TOO_SMALL_L_ZK` in ZK, - // so the assert in solve_zk fires. + // so the assert in `solve` fires. assert!(source.mask_length() + t_ood > TOO_SMALL_L_ZK); let mask_oracle = MaskOracleInfo { c_zk_list_size: ListSize::new(SMOKE_C_ZK_LIST_SIZE), l_zk: MaskCodeMessageLen::new(TOO_SMALL_L_ZK), }; - let _ = solve_zk(&spec, source, target, t_ood, mask_oracle, 0); + let _ = solve( + &spec, + source, + target, + t_ood, + SolveMode::ZeroKnowledge { mask_oracle }, + 0, + ); } /// Smoke test: `M::Source ≠ M::Target`, Standard mode. @@ -385,22 +397,25 @@ mod tests { let (source_ctx, target_ctx) = non_identity_smoke_ctxs(); let target_log_degree = f64::from((source_ctx.vector_size / (1 << source_ctx.folding_factor)).trailing_zeros()); - let target_list_size = list_size_estimate( - spec.decoding_regime, - target_log_degree, - f64::from(target_ctx.log_inv_rate), - ); - let (source, t_ood) = - solve_t_ood::(&spec, &source_ctx, target_list_size, None, 0) - .unwrap(); + let target_list_size = spec + .decoding_regime + .list_size_estimate(target_log_degree, f64::from(target_ctx.log_inv_rate)); + let (source, t_ood) = solve_t_ood::( + &spec, + &source_ctx, + target_list_size, + OodMode::Standard, + 0, + ) + .unwrap(); // Standard target: codeword_length is t_ood-independent (mask = 0). - let target = irs_solver::solve::>( + let target = irs_params::solve::>( &spec, &target_ctx, OodSampleBudget::ZERO, ); - let config = solve_standard(&spec, source, target, t_ood, 0).unwrap(); + let config = solve(&spec, source, target, t_ood, SolveMode::Standard, 0).unwrap(); assert!(matches!(config.mode, code_switch::CodeSwitchMode::Standard)); } @@ -415,20 +430,20 @@ mod tests { let (source_ctx, target_ctx) = non_identity_smoke_ctxs(); let target_log_degree = f64::from((source_ctx.vector_size / (1 << source_ctx.folding_factor)).trailing_zeros()); - let target_list_size = list_size_estimate( - spec.decoding_regime, - target_log_degree, - f64::from(target_ctx.log_inv_rate), - ); + let target_list_size = spec + .decoding_regime + .list_size_estimate(target_log_degree, f64::from(target_ctx.log_inv_rate)); let (source, t_ood) = solve_t_ood::( &spec, &source_ctx, target_list_size, - Some(f64::from(source_ctx.log_inv_rate)), + OodMode::ZeroKnowledge { + c_zk_log_inv_rate: f64::from(source_ctx.log_inv_rate), + }, 0, ) .unwrap(); - let target = irs_solver::solve::>( + let target = irs_params::solve::>( &spec, &target_ctx, OodSampleBudget::new(t_ood), @@ -438,7 +453,15 @@ mod tests { c_zk_list_size: ListSize::new(SMOKE_C_ZK_LIST_SIZE), l_zk: MaskCodeMessageLen::new((source.mask_length() + t_ood).next_power_of_two()), }; - let config = solve_zk(&spec, source, target, t_ood, mask_oracle, 0).unwrap(); + let config = solve( + &spec, + source, + target, + t_ood, + SolveMode::ZeroKnowledge { mask_oracle }, + 0, + ) + .unwrap(); assert!(matches!( config.mode, code_switch::CodeSwitchMode::ZeroKnowledge { .. } diff --git a/src/protocols/params/derive.rs b/src/protocols/params/derive.rs index ac13ea16..e74d1056 100644 --- a/src/protocols/params/derive.rs +++ b/src/protocols/params/derive.rs @@ -12,18 +12,17 @@ use crate::{ protocols::{ irs_commit::Config as IrsConfig, params::{ - basecase as basecase_solver, + basecase as basecase_params, bounds::usize_to_f64, - code_switch as code_switch_solver, - error::{DeriveError, FixedPointLoop, Pow}, - irs_commit as irs_solver, mask_proximity as mask_proximity_solver, + code_switch as code_switch_params, + error::{DeriveError, Pow}, + irs_commit as irs_params, mask_proximity as mask_proximity_params, protocol_config::{MaskOracleConfig, ProtocolConfig, RoundConfig, RoundMode}, - regime::list_size_estimate, spec::{ DecodingRegime, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, RoundContext, SecuritySpec, TuningSpec, ZkSpec, }, - sumcheck as sumcheck_solver, + sumcheck as sumcheck_params, SolveMode, }, }, }; @@ -31,6 +30,21 @@ use crate::{ /// Paranoia guard on `solve_t_ood` — convergence proof on the function itself. const T_OOD_MAX_ITER: usize = 32; +/// Mode flag for the OOD security bound in [`solve_t_ood`] / +/// [`ood_security_bits_at`]. +/// +/// In `Standard`, the bound uses `L = target.list_size` and +/// `d = source.message_length()`. +/// +/// In `ZeroKnowledge`, the carried `c_zk_log_inv_rate` lets the bound include +/// the combined list size `target.list_size · c_zk.list_size` and the masked +/// degree `source.message_length() + ℓ_zk` (Lemma 9.9 witness layout). +#[derive(Clone, Copy)] +pub(super) enum OodMode { + Standard, + ZeroKnowledge { c_zk_log_inv_rate: f64 }, +} + impl ProtocolConfig { /// In ZK each round owns its mask oracle; the `ℓ_zk ↔ c_zk ↔ t_ood` /// fixed-point runs independently per round. @@ -45,22 +59,20 @@ impl ProtocolConfig { basecase_log_inv_rate, } = round_layout(&tuning); - let rounds: Vec> = match spec.mode { - Mode::Standard => shapes - .iter() - .map(|shape| build_round_config::(&spec, shape)) - .collect::>()?, - Mode::ZeroKnowledge => { - let zk_spec = ZkSpec::try_new(&spec).expect("matched Mode::ZeroKnowledge above"); - let c_zk_log_inv_rate = LogInvRate::new(tuning.starting_log_inv_rate); - shapes - .iter() - .map(|shape| build_zk_round_config::(zk_spec, shape, c_zk_log_inv_rate)) - .collect::>()? - } + let mode = match spec.mode { + Mode::Standard => RoundBuildMode::Standard, + Mode::ZeroKnowledge => RoundBuildMode::ZeroKnowledge { + zk_spec: ZkSpec::try_new(&spec).expect("matched Mode::ZeroKnowledge above"), + c_zk_log_inv_rate: LogInvRate::new(tuning.starting_log_inv_rate), + }, }; - let basecase = basecase_solver::solve(&spec, basecase_vector_size, basecase_log_inv_rate)?; + let rounds: Vec> = shapes + .iter() + .map(|shape| build_round_config::(&spec, shape, mode)) + .collect::>()?; + + let basecase = basecase_params::solve(&spec, basecase_vector_size, basecase_log_inv_rate)?; let plan = Self::new(spec, tuning, rounds, basecase); plan.validate()?; @@ -68,6 +80,33 @@ impl ProtocolConfig { } } +/// Mode-dispatch input for [`build_round_config`]. The ZK variant borrows the +/// spec via `ZkSpec` and carries the planner-level `c_zk_log_inv_rate`; the +/// rest of the round's behavior (target OOD budget, sub-protocol `SolveMode`, +/// output `RoundMode` payload) is derived from this single discriminator. +#[derive(Clone, Copy)] +enum RoundBuildMode<'a> { + Standard, + ZeroKnowledge { + zk_spec: ZkSpec<'a>, + c_zk_log_inv_rate: LogInvRate, + }, +} + +impl RoundBuildMode<'_> { + /// Project to the OOD-search mode consumed by [`solve_round_source`]. + fn to_ood_mode(self) -> OodMode { + match self { + Self::Standard => OodMode::Standard, + Self::ZeroKnowledge { + c_zk_log_inv_rate, .. + } => OodMode::ZeroKnowledge { + c_zk_log_inv_rate: f64::from(c_zk_log_inv_rate.get()), + }, + } + } +} + /// `target_folding_factor` is the next round's source folding — uniform /// `tuning.folding_factor` — so `target_r → source_{r+1}` has matching /// interleaving. @@ -86,7 +125,6 @@ struct RoundLayout { basecase_log_inv_rate: u32, } -/// Stops when there's no room for both a valid source and a valid target IRS. fn round_layout(tuning: &TuningSpec) -> RoundLayout { assert!(tuning.vector_size.is_power_of_two()); assert!(tuning.folding_factor.min() >= 1); @@ -98,8 +136,8 @@ fn round_layout(tuning: &TuningSpec) -> RoundLayout { loop { let round = shapes.len(); let source_folding = tuning.folding_factor.at_round(round); - let target_folding = tuning.folding_factor.at_round(round + 1); - if num_vars < source_folding + target_folding { + let target_folding = tuning.folding_factor.at_round(round.saturating_add(1)); + if num_vars < source_folding.saturating_add(target_folding) { break; } shapes.push(RoundShape { @@ -109,8 +147,8 @@ fn round_layout(tuning: &TuningSpec) -> RoundLayout { source_folding_factor: source_folding as u32, target_folding_factor: target_folding as u32, }); - num_vars -= source_folding; - log_inv_rate += (source_folding as u32).saturating_sub(1); + num_vars = num_vars.saturating_sub(source_folding); + log_inv_rate = log_inv_rate.saturating_add((source_folding as u32).saturating_sub(1)); } RoundLayout { @@ -131,129 +169,148 @@ const fn round_context(shape: &RoundShape) -> RoundContext { fn target_context(shape: &RoundShape, source: &IrsConfig) -> RoundContext { RoundContext { vector_size: source.message_length(), - log_inv_rate: shape.source_log_inv_rate + shape.source_folding_factor.saturating_sub(1), + log_inv_rate: shape + .source_log_inv_rate + .saturating_add(shape.source_folding_factor.saturating_sub(1)), folding_factor: shape.target_folding_factor, } } -/// Per-round ZK builder. C_zk holds `2 · (k + 1)` columns (Construction 7.2 -/// originals + fresh): `k` sumcheck masks (Lemma 6.4) + one `(r ‖ s)` -/// code-switch mask (Construction 9.7). `ℓ_zk = next_pow2(r + t_ood)` from -/// Theorem 9.6's witness layout + Lemma 9.3's `r ≥ t` privacy precondition; -/// `t_ood` solves Lemma 9.9 term 1. -fn build_zk_round_config( - zk_spec: ZkSpec<'_>, +/// Per-round `(source, t_ood)` for either mode. The target dimensions +/// (`log_inv_rate` after the round's rate step, post-fold `log_degree`) feed +/// the regime's canonical-slack list-size estimate, which `solve_t_ood` uses +/// inside its OOD bound. Shared by Standard and ZK builders. +fn solve_round_source( + spec: &SecuritySpec, shape: &RoundShape, - c_zk_log_inv_rate: LogInvRate, -) -> Result, DeriveError> { - let spec = zk_spec.as_inner(); - let ctx = round_context(shape); - let num_masks = sumcheck_solver::masks_required(&ctx) + code_switch_solver::masks_required(); - let c_zk_log_inv_rate_f = f64::from(c_zk_log_inv_rate.get()); - + ood_mode: OodMode, +) -> Result<(IrsConfig, usize), DeriveError> { let src_ctx = round_context(shape); - let target_log_inv_rate = - f64::from(shape.source_log_inv_rate + shape.source_folding_factor.saturating_sub(1)); + let target_log_inv_rate = f64::from( + shape + .source_log_inv_rate + .saturating_add(shape.source_folding_factor.saturating_sub(1)), + ); // Target encodes one polynomial of length `source.message_length()` = // `source_vector_size / 2^source_folding_factor`. - let target_log_degree = - f64::from(shape.source_vector_size.trailing_zeros() - shape.source_folding_factor); - let target_list_size = - list_size_estimate(spec.decoding_regime, target_log_degree, target_log_inv_rate); - - let (source, t_ood) = solve_t_ood::( + let target_log_degree = f64::from( + shape + .source_vector_size + .trailing_zeros() + .saturating_sub(shape.source_folding_factor), + ); + let target_list_size = spec + .decoding_regime + .list_size_estimate(target_log_degree, target_log_inv_rate); + solve_t_ood::( spec, &src_ctx, target_list_size, - Some(c_zk_log_inv_rate_f), + ood_mode, shape.round_index, - )?; - let target: IrsConfig> = irs_solver::solve( - spec, - &target_context(shape, &source), - OodSampleBudget::new(t_ood), - ); + ) +} - let l_zk = compute_l_zk(&source, t_ood); - let c_zk: IrsConfig> = irs_solver::solve_mask_code( +/// ZK-only: assemble the per-round mask oracle (C_zk codeword + mask-proximity +/// check). `ℓ_zk = next_pow2(r + t_ood)` from Theorem 9.6's witness layout + +/// Lemma 9.3's `r ≥ t` privacy precondition; C_zk holds `2 · num_masks` +/// columns (Construction 7.2: originals + fresh). +fn build_mask_oracle( + zk_spec: ZkSpec<'_>, + source: &IrsConfig, + t_ood: usize, + num_masks: usize, + c_zk_log_inv_rate: LogInvRate, + round_index: usize, +) -> Result, DeriveError> { + let spec = zk_spec.as_inner(); + let l_zk = compute_l_zk(source, t_ood); + let c_zk: IrsConfig> = irs_params::solve_mask_code( zk_spec, l_zk, source.mask_length(), c_zk_log_inv_rate, 2 * num_masks, ); - let c_zk_list_size_estimate = list_size_estimate( - spec.decoding_regime, + let c_zk_list_size_estimate = spec.decoding_regime.list_size_estimate( (l_zk.get() as f64).log2(), - c_zk_log_inv_rate_f, + f64::from(c_zk_log_inv_rate.get()), ); debug_assert!( (c_zk.list_size() - c_zk_list_size_estimate).abs() < 1e-9 * c_zk_list_size_estimate.max(1.0), "c_zk.list_size() {} drifted from planner estimate {} — \ - see `list_size_estimate` for the invariant", + see `DecodingRegime::list_size_estimate` for the invariant", c_zk.list_size(), c_zk_list_size_estimate, ); - let mask_proximity = - mask_proximity_solver::solve(spec, c_zk.clone(), num_masks, shape.round_index)?; - let mask_oracle = MaskOracleConfig::new(c_zk, l_zk, mask_proximity); - let info = mask_oracle.info(); - - let sumcheck = sumcheck_solver::solve_zk( - spec, - &ctx, - &source, - info, - Pow::RoundSumcheck { - index: shape.round_index, - }, - )?; - let code_switch = - code_switch_solver::solve_zk(spec, source, target, t_ood, info, shape.round_index)?; - Ok(RoundConfig::new( - shape.round_index, - sumcheck, - code_switch, - RoundMode::ZeroKnowledge { - t_ood: OodSampleBudget::new(t_ood), - mask_oracle: Box::new(mask_oracle), - }, - )) + let mask_proximity = mask_proximity_params::solve(spec, c_zk.clone(), num_masks, round_index)?; + Ok(MaskOracleConfig::new(c_zk, l_zk, mask_proximity)) } +/// Per-round builder. Under [`RoundBuildMode::ZeroKnowledge`], C_zk holds +/// `2 · (k + 1)` columns: `k` sumcheck masks (Lemma 6.4) + one `(r ‖ s)` +/// code-switch mask (Construction 9.7). `t_ood` solves Lemma 9.9 term 1 in +/// both modes. fn build_round_config( spec: &SecuritySpec, shape: &RoundShape, + mode: RoundBuildMode<'_>, ) -> Result, DeriveError> { - let src_ctx = round_context(shape); - let target_log_inv_rate = - f64::from(shape.source_log_inv_rate + shape.source_folding_factor.saturating_sub(1)); - let target_log_degree = - f64::from(shape.source_vector_size.trailing_zeros() - shape.source_folding_factor); - let target_list_size = - list_size_estimate(spec.decoding_regime, target_log_degree, target_log_inv_rate); - - let (source, t_ood) = - solve_t_ood::(spec, &src_ctx, target_list_size, None, shape.round_index)?; - let target: IrsConfig> = - irs_solver::solve(spec, &target_context(shape, &source), OodSampleBudget::ZERO); + let ctx = round_context(shape); + let (source, t_ood) = solve_round_source::(spec, shape, mode.to_ood_mode())?; + + // Single mode-dispatch site; ZK additionally builds the mask oracle. + let (target_budget, solve_mode, round_mode) = match mode { + RoundBuildMode::Standard => ( + OodSampleBudget::ZERO, + SolveMode::Standard, + RoundMode::Standard, + ), + RoundBuildMode::ZeroKnowledge { + zk_spec, + c_zk_log_inv_rate, + } => { + let num_masks = + sumcheck_params::masks_required(&ctx) + code_switch_params::masks_required(); + let mask_oracle = build_mask_oracle::( + zk_spec, + &source, + t_ood, + num_masks, + c_zk_log_inv_rate, + shape.round_index, + )?; + let solve_mode = SolveMode::ZeroKnowledge { + mask_oracle: mask_oracle.info(), + }; + let round_mode = RoundMode::ZeroKnowledge { + t_ood: OodSampleBudget::new(t_ood), + mask_oracle: Box::new(mask_oracle), + }; + (OodSampleBudget::new(t_ood), solve_mode, round_mode) + } + }; - let sumcheck = sumcheck_solver::solve_standard( + let target: IrsConfig> = + irs_params::solve(spec, &target_context(shape, &source), target_budget); + let sumcheck = sumcheck_params::solve( spec, - &src_ctx, + &ctx, &source, + solve_mode, Pow::RoundSumcheck { index: shape.round_index, }, )?; let code_switch = - code_switch_solver::solve_standard(spec, source, target, t_ood, shape.round_index)?; + code_switch_params::solve(spec, source, target, t_ood, solve_mode, shape.round_index)?; + Ok(RoundConfig::new( shape.round_index, sumcheck, code_switch, - RoundMode::Standard, + round_mode, )) } @@ -263,39 +320,39 @@ pub(super) const fn compute_l_zk( source: &IrsConfig, t_ood: usize, ) -> MaskCodeMessageLen { - MaskCodeMessageLen::new((source.mask_length() + t_ood).next_power_of_two()) + MaskCodeMessageLen::new( + source + .mask_length() + .saturating_add(t_ood) + .next_power_of_two(), + ) } /// Per-round `(source, t_ood)` from a linear search over `t_ood`. /// -/// Under `Unique`, OOD contributes no soundness (`|Λ| = 1` ⇒ `(L choose 2) = 0`). -/// The short-circuit returns `t_ood = 1` — the Construction 9.7 protocol-layer +/// Under `Unique`, OOD contributes no soundness (`|Λ| = 1` ⇒ `(L choose 2) = 0`) +/// and the short-circuit pins `t_ood = 1` — the Construction 9.7 protocol-layer /// minimum, since [`crate::protocols::code_switch::Config::new`] asserts -/// `out_domain_samples ≥ 1` (Steps 2-3 always execute). +/// `out_domain_samples ≥ 1`. Letting the loop run would mis-evaluate the +/// `log(L·(L−1)/2) ≈ 2·log L − 1` approximation, which is `+∞` off the true +/// value when `L = 1`. /// -/// Under `Johnson`, searches `t_ood = 1..=T_OOD_MAX_ITER` for the smallest -/// value where the OOD security bound (STIR Lemma 4.5) -/// `t · (|F| − log d) − 2·log|Λ_combined| + 1` meets -/// `protocol_security_target_bits`. The bound is monotone-increasing in `t` -/// for `|F| ≫ log d` (always the case here), so the first match is the -/// minimum. Source is rebuilt per iteration because `source.mask_length()` -/// depends on `t_ood` in ZK (Lemma 9.5 ii); the rebuild is cheap (struct -/// fields only). +/// Under `Johnson`/`Capacity`, searches `t_ood = 1..=T_OOD_MAX_ITER` for the +/// smallest value where [`ood_security_bits_at`] meets +/// `protocol_security_target_bits`. The bound is monotone-increasing in `t` for +/// `|F| ≫ log d` (always the case here), so the first match is the minimum. +/// Source is rebuilt per iteration because `source.mask_length()` depends on +/// `t_ood` in ZK (`mask_length = in_domain + t_ood` per Construction 9.7 / +/// Theorem 9.6); the rebuild is cheap (struct fields only). pub(super) fn solve_t_ood( spec: &SecuritySpec, src_ctx: &RoundContext, target_list_size: f64, - c_zk_log_inv_rate: Option, + ood_mode: OodMode, round_index: usize, ) -> Result<(IrsConfig, usize), DeriveError> { if matches!(spec.decoding_regime, DecodingRegime::Unique) { - // The short-circuit is not an optimization — the linear-search formula - // uses `log(L·(L−1)/2) ≈ 2·log L − 1`, which is exact for `L ≥ 2` but - // *underestimates* security by `+∞` when `L = 1` (the true `(L choose 2)` - // is 0, not L²/2). Letting the loop run would falsely demand `t_ood > 1` - // at high security targets even though OOD provides infinite soundness - // headroom under Unique. Pin `t_ood = 1` directly. - let source = irs_solver::solve(spec, src_ctx, OodSampleBudget::new(1)); + let source = irs_params::solve(spec, src_ctx, OodSampleBudget::new(1)); return Ok((source, 1)); } @@ -303,42 +360,54 @@ pub(super) fn solve_t_ood( let field_bits = M::Target::field_size_bits(); for t_ood in 1..=T_OOD_MAX_ITER { - let source: IrsConfig = irs_solver::solve(spec, src_ctx, OodSampleBudget::new(t_ood)); - - // `degree` and `log_combined_list` depend on `t_ood` in ZK via ℓ_zk; - // Standard collapses to `degree = ℓ` and `combined = target.list_size`. - let (log_degree, log_combined_list) = c_zk_log_inv_rate.map_or_else( - || { - ( - usize_to_f64(source.message_length()).log2(), - target_list_size.log2(), - ) - }, - |c_zk_rate| { - let l_zk = (source.mask_length() + t_ood).next_power_of_two(); - let c_zk_list = - list_size_estimate(spec.decoding_regime, usize_to_f64(l_zk).log2(), c_zk_rate); - ( - usize_to_f64(source.message_length() + l_zk).log2(), - (target_list_size * c_zk_list).log2(), - ) - }, - ); - - // STIR Lemma 4.5 (single-MCA OOD): - // bits = t · (|F| − log d) − log(L · (L − 1) / 2) - // Approximate `log(L · (L − 1) / 2) ≈ 2·log L − 1` (exact-ish for L ≥ 2; - // the L = 1 case is handled by the Unique short-circuit above). - let ood = usize_to_f64(t_ood); - let bits = ood * (field_bits - log_degree) - 2.0 * log_combined_list + 1.0; + let source: IrsConfig = irs_params::solve(spec, src_ctx, OodSampleBudget::new(t_ood)); + let bits = + ood_security_bits_at(spec, &source, t_ood, target_list_size, ood_mode, field_bits); if bits >= security_target { return Ok((source, t_ood)); } } - Err(DeriveError::FixedPointDidNotConverge { - round_index, - loop_kind: FixedPointLoop::TOod, - }) + Err(DeriveError::FixedPointDidNotConverge { round_index }) +} + +/// OOD security bits at candidate `t_ood`, per STIR Lemma 4.5: +/// `bits = t · (|F| − log d) − log(L · (L − 1) / 2) ≈ t·(|F| − log d) − 2·log L + 1`. +/// +/// In [`OodMode::Standard`], `L = target.list_size` and `d = source.message_length()`. +/// In [`OodMode::ZeroKnowledge`], `L = target.list_size · c_zk.list_size` and +/// `d = source.message_length() + ℓ_zk` (Lemma 9.9 witness layout). +/// +/// The approximation `log(L·(L−1)/2) ≈ 2·log L − 1` is exact-ish for `L ≥ 2`; +/// the `L = 1` case is handled by [`solve_t_ood`]'s `Unique` short-circuit. +fn ood_security_bits_at( + spec: &SecuritySpec, + source: &IrsConfig, + t_ood: usize, + target_list_size: f64, + ood_mode: OodMode, + field_bits: f64, +) -> f64 { + let (log_degree, log_combined_list) = match ood_mode { + OodMode::Standard => ( + usize_to_f64(source.message_length()).log2(), + target_list_size.log2(), + ), + OodMode::ZeroKnowledge { c_zk_log_inv_rate } => { + let l_zk = source + .mask_length() + .saturating_add(t_ood) + .next_power_of_two(); + let c_zk_list = spec + .decoding_regime + .list_size_estimate(usize_to_f64(l_zk).log2(), c_zk_log_inv_rate); + ( + usize_to_f64(source.message_length().saturating_add(l_zk)).log2(), + (target_list_size * c_zk_list).log2(), + ) + } + }; + let ood = usize_to_f64(t_ood); + ood * (field_bits - log_degree) - 2.0 * log_combined_list + 1.0 } #[cfg(test)] @@ -985,12 +1054,12 @@ mod tests { let cs = r.code_switch(); assert_pow_closes_gap( spec, - sumcheck_solver::analytic_error_bits(&cs.source, mask_info), + sumcheck_params::analytic_error_bits(&cs.source, mask_info), &r.sumcheck().round_pow, ); assert_pow_closes_gap( spec, - code_switch_solver::analytic_error_bits( + code_switch_params::analytic_error_bits( &cs.source, &cs.target, cs.out_domain_samples, @@ -1002,14 +1071,14 @@ mod tests { let mp = mo.mask_proximity(); assert_pow_closes_gap( spec, - mask_proximity_solver::analytic_error_bits(&mp.c_zk_commit, mp.num_masks), + mask_proximity_params::analytic_error_bits(&mp.c_zk_commit, mp.num_masks), &mp.pow, ); } } assert_pow_closes_gap( spec, - sumcheck_solver::analytic_error_bits(&plan.basecase().commit, None), + sumcheck_params::analytic_error_bits(&plan.basecase().commit, None), &plan.basecase().sumcheck.round_pow, ); // γ-slot is ZK-only. @@ -1019,7 +1088,7 @@ mod tests { ) { assert_pow_closes_gap( spec, - basecase_solver::analytic_error_bits(&plan.basecase().commit), + basecase_params::analytic_error_bits(&plan.basecase().commit), &plan.basecase().pow, ); } diff --git a/src/protocols/params/error.rs b/src/protocols/params/error.rs index c217036f..d054cc5b 100644 --- a/src/protocols/params/error.rs +++ b/src/protocols/params/error.rs @@ -83,33 +83,15 @@ impl Display for ChainTarget { } } -/// Which fixed-point loop failed to converge. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum FixedPointLoop { - /// `derive::solve_t_ood` — combined `t_ood ↔ source` Kleene iteration. - TOod, -} - -impl Display for FixedPointLoop { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match self { - Self::TOod => f.write_str("t_ood"), - } - } -} - /// Failure modes for [`super::derive::ProtocolConfig::derive`] and the /// sub-protocol solvers it calls. #[derive(Debug, Error, Clone, PartialEq, Eq)] pub enum DeriveError { - /// A fixed-point loop ran out of iterations. Indicates a pathological - /// spec/tuning combo; should not happen under realistic security targets - /// on supported fields. - #[error("{loop_kind} fixed-point did not converge for round {round_index}")] - FixedPointDidNotConverge { - round_index: usize, - loop_kind: FixedPointLoop, - }, + /// The `t_ood` fixed-point in [`super::derive::solve_t_ood`] ran out of + /// iterations. Indicates a pathological spec/tuning combo; should not + /// happen under realistic security targets on supported fields. + #[error("t_ood fixed-point did not converge for round {round_index}")] + FixedPointDidNotConverge { round_index: usize }, /// A PoW grind cannot close the analytic-to-target gap — the spec is too /// tight for any single grind to reach `target_security_bits`. diff --git a/src/protocols/params/irs_commit.rs b/src/protocols/params/irs_commit.rs index 83756957..430dea57 100644 --- a/src/protocols/params/irs_commit.rs +++ b/src/protocols/params/irs_commit.rs @@ -1,8 +1,13 @@ //! IRS-commit parameter selection. //! -//! ZK mask is sized per Lemma 9.5 (paper p.53) at the tight bound -//! `in-domain + OOD` queries. Codeword NTT-smoothness is enforced inside -//! [`IrsConfig::new`] on `codeword_length`, not by inflating the mask. +//! ZK mask is sized at the tight bound `in-domain + OOD` queries. Construction +//! 9.7 / Theorem 9.6 (paper p.54-55) reveals `in_domain + t_ood` source +//! positions per round (in-domain queries + OOD linear combinations via +//! ze_ood), so the source encoding must be `(in_domain + t_ood)`-query ZK +//! (Definition 3.16, p.29). For the Reed–Solomon code that means +//! `mask_length = in_domain + t_ood` (Proposition 3.19, p.30). Codeword +//! NTT-smoothness is enforced inside [`IrsConfig::new`] on `codeword_length`, +//! not by inflating the mask. use crate::{ algebra::embedding::Embedding, @@ -30,11 +35,14 @@ pub fn solve( let mode = match spec.mode { Mode::Standard => IrsMode::Standard, Mode::ZeroKnowledge => { - // Lemma 9.5 (part ii): r-query perfect-ZK encoding requires - // `r ≥ in-domain + OOD`. Use the tight bound; do not pow2-pad here. + // Construction 9.7 / Theorem 9.6: the verifier reveals + // `in_domain + t_ood` source positions (in-domain queries + + // ze_ood linear combinations), so the source encoding must be + // (in_domain + t_ood)-query ZK (Definition 3.16). Use the tight + // RS bound `mask_length = t` from Proposition 3.19; do not + // pow2-pad here. let mask_length = num_in_domain_queries(spec.decoding_regime, security_target, rate) - .checked_add(out_domain_samples.get()) - .expect("usize overflow"); + .saturating_add(out_domain_samples.get()); IrsMode::ZeroKnowledge { mask_length } } }; @@ -164,9 +172,10 @@ mod tests { } proptest! { - /// Lemma 9.5 (part ii): mask covers all revealed evaluations. + /// Construction 9.7 / Theorem 9.6: mask covers all revealed source + /// positions (in-domain queries + OOD linear combinations). #[test] - fn zk_mask_covers_lemma_9_5( + fn zk_mask_covers_in_domain_plus_ood( spec in arb_zk_spec_default(), ctx in arb_round_ctx(), out_domain in 0usize..16, diff --git a/src/protocols/params/mask_proximity.rs b/src/protocols/params/mask_proximity.rs index 115121c1..0f917d2b 100644 --- a/src/protocols/params/mask_proximity.rs +++ b/src/protocols/params/mask_proximity.rs @@ -42,7 +42,7 @@ pub fn analytic_error_bits(c_zk: &IrsConfig>, num_masks: u if deg <= 1 || num_masks == 0 { return Bits::new(field_bits.max(0.0)); } - let log_combined = usize_to_f64(num_masks * (deg - 1)).log2(); + let log_combined = usize_to_f64(num_masks * deg.saturating_sub(1)).log2(); Bits::new((field_bits - log_combined).max(0.0)) } diff --git a/src/protocols/params/mod.rs b/src/protocols/params/mod.rs index d257df72..d3229b3d 100644 --- a/src/protocols/params/mod.rs +++ b/src/protocols/params/mod.rs @@ -23,7 +23,7 @@ pub(crate) mod sumcheck; #[cfg(test)] pub(crate) mod test_utils; -pub use error::{ChainSource, ChainTarget, DeriveError, FixedPointLoop, Pow}; +pub use error::{ChainSource, ChainTarget, DeriveError, Pow}; pub use protocol_config::{ MaskOracleConfig, MaskOracleInfo, ProtocolConfig, RoundConfig, RoundMode, }; @@ -31,3 +31,19 @@ pub use spec::{ DecodingRegime, FoldingFactor, ListSize, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, PowBudget, RoundContext, SecuritySpec, TuningSpec, ZkSpec, }; + +/// Solver-input mode for the per-round sumcheck and code-switch builders. +/// +/// Both sub-protocols branch on the same Standard vs. ZK distinction with the +/// same `MaskOracleInfo` payload, so a shared vocabulary keeps call sites +/// uniform. +/// +/// Distinct from [`Mode`] (the spec-level policy enum, which carries no +/// payload) and from sub-protocol *output* modes +/// (`sumcheck::SumcheckMode`, `code_switch::CodeSwitchMode`) whose payloads +/// describe the configured round rather than its solver input. +#[derive(Clone, Copy)] +pub enum SolveMode { + Standard, + ZeroKnowledge { mask_oracle: MaskOracleInfo }, +} diff --git a/src/protocols/params/protocol_config.rs b/src/protocols/params/protocol_config.rs index 6dfa9273..0d60d210 100644 --- a/src/protocols/params/protocol_config.rs +++ b/src/protocols/params/protocol_config.rs @@ -25,10 +25,10 @@ use crate::{ mask_proximity::Config as MaskProximityConfig, params::{ bounds::usize_to_f64, - code_switch as code_switch_solver, + code_switch as code_switch_params, error::{ChainSource, ChainTarget, DeriveError, Pow}, spec::{ListSize, MaskCodeMessageLen, OodSampleBudget, SecuritySpec, TuningSpec}, - sumcheck as sumcheck_solver, + sumcheck as sumcheck_params, }, proof_of_work::Config as PowConfig, sumcheck::Config as SumcheckConfig, @@ -214,9 +214,7 @@ impl ProtocolConfig { } /// Test-only mutators. Grouped here so the production `impl` block above -/// reads as the public API surface and these escape hatches aren't easily -/// mistaken for it. Each one supports a specific negative test in -/// `derive::tests`; there is no equivalent on the public API. +/// reads as the public API surface; no equivalent on the public API. #[cfg(test)] impl ProtocolConfig { /// Inject an over-budget basecase PoW slot so `validate_pow_budget` can @@ -337,8 +335,8 @@ impl RoundConfig { let target = &self.code_switch.target; let mask_info = self.mask_oracle_info(); - let sumcheck_term = f64::from(sumcheck_solver::analytic_error_bits(source, mask_info)); - let code_switch_term = f64::from(code_switch_solver::analytic_error_bits( + let sumcheck_term = f64::from(sumcheck_params::analytic_error_bits(source, mask_info)); + let code_switch_term = f64::from(code_switch_params::analytic_error_bits( source, target, self.code_switch.out_domain_samples, diff --git a/src/protocols/params/regime.rs b/src/protocols/params/regime.rs index d1bc0a30..fcb52c08 100644 --- a/src/protocols/params/regime.rs +++ b/src/protocols/params/regime.rs @@ -149,17 +149,20 @@ impl DecodingRegimeParams { } } -/// `|Λ|` at the given degree + rate under `regime`. Used before an IRS config -/// exists. -/// -/// Matches `IrsConfig::list_size()` when the IRS is built under the same -/// regime, with the same `masked_message_length`, and `ntt::next_order` -/// doesn't pad the codeword (pow2 `vector_size`, `interleaving_depth = 1`, -/// integer `log_inv_rate`, 2-adic field — the conditions `solve_mask_code` -/// enforces for C_zk). -pub fn list_size_estimate(regime: DecodingRegime, log_degree: f64, log_inv_rate: f64) -> f64 { - DecodingRegimeParams::from_policy(regime, rate(log_inv_rate)) - .list_size(log_degree, log_inv_rate) +impl DecodingRegime { + /// `|Λ|` at canonical slack, before an IRS config exists. Use the + /// `DecodingRegimeParams::list_size` method when a non-canonical slack + /// has already been materialized. + /// + /// Matches `IrsConfig::list_size()` when the IRS is built under the same + /// regime, with the same `masked_message_length`, and `ntt::next_order` + /// doesn't pad the codeword (pow2 `vector_size`, `interleaving_depth = 1`, + /// integer `log_inv_rate`, 2-adic field — the conditions `solve_mask_code` + /// enforces for C_zk). + pub fn list_size_estimate(self, log_degree: f64, log_inv_rate: f64) -> f64 { + DecodingRegimeParams::from_policy(self, rate(log_inv_rate)) + .list_size(log_degree, log_inv_rate) + } } #[cfg(test)] @@ -209,11 +212,11 @@ mod tests { } /// `η = √ρ / 20` substituted into `|Λ| = 1/(2η√ρ)` simplifies to `10/ρ`. - /// So `list_size_estimate(Johnson, _, b) = 10 · 2^b`. + /// So `DecodingRegime::Johnson.list_size_estimate(_, b) = 10 · 2^b`. #[test] fn johnson_list_size_closed_form() { for b in [1.0, 2.0, 3.0, 5.0] { - let got = list_size_estimate(DecodingRegime::Johnson, /* log_degree */ 4.0, b); + let got = DecodingRegime::Johnson.list_size_estimate(/* log_degree */ 4.0, b); let expected = 10.0 * 2_f64.powf(b); assert!( (got - expected).abs() / expected < TIGHT_EPS, @@ -226,7 +229,7 @@ mod tests { #[test] fn capacity_list_size_closed_form() { for (log_d, b) in [(4.0, 1.0), (6.0, 2.0), (8.0, 3.0)] { - let got = list_size_estimate(DecodingRegime::Capacity, log_d, b); + let got = DecodingRegime::Capacity.list_size_estimate(log_d, b); let expected = 20.0 * 2_f64.powf(log_d) * 2_f64.powf(2.0 * b); assert!( (got - expected).abs() / expected < TIGHT_EPS, @@ -235,7 +238,7 @@ mod tests { } } - /// `list_size_estimate(Johnson, _, b)` must match `Config::list_size` once + /// `DecodingRegime::Johnson.list_size_estimate(_, b)` must match `Config::list_size` once /// a config is built at the same rate. Keeps the rate-only helper in sync /// with `irs_commit::Config::new`'s canonical-slack materialization. #[test] @@ -262,7 +265,7 @@ mod tests { IrsMode::Standard, ); let log_degree = (config.masked_message_length() as f64).log2(); - let got = list_size_estimate(DecodingRegime::Johnson, log_degree, f64::from(LOG_INV_RATE)); + let got = DecodingRegime::Johnson.list_size_estimate(log_degree, f64::from(LOG_INV_RATE)); let expected = config.list_size(); assert!( (got - expected).abs() / expected < TIGHT_EPS, diff --git a/src/protocols/params/spec.rs b/src/protocols/params/spec.rs index e4fd1d37..dbc4ee3c 100644 --- a/src/protocols/params/spec.rs +++ b/src/protocols/params/spec.rs @@ -264,8 +264,7 @@ pub enum LogInvRateTag {} pub type OodSampleBudget = Tagged; impl Tagged { - /// Sentinel for "no OOD samples". Used by sub-protocols that don't - /// require an OOD challenge round (e.g. Standard mode, basecase). + /// Sentinel for "no OOD samples". pub const ZERO: Self = Self::new(0); } diff --git a/src/protocols/params/sumcheck.rs b/src/protocols/params/sumcheck.rs index 3e418957..c5f41287 100644 --- a/src/protocols/params/sumcheck.rs +++ b/src/protocols/params/sumcheck.rs @@ -11,49 +11,37 @@ use crate::{ error::{grind_to_at, DeriveError, Pow}, protocol_config::MaskOracleInfo, spec::{RoundContext, SecuritySpec}, + SolveMode, }, sumcheck::{self, Config as SumcheckConfig, SumcheckMaskLen}, }, }; -/// Standard-mode sumcheck builder. `pow` labels grinding failures -/// (basecase or per-round). -pub fn solve_standard( +/// Per-round sumcheck builder. `mode` carries the optional mask oracle (see +/// [`super::SolveMode`]); `pow` labels grinding failures (basecase or +/// per-round). +pub fn solve( spec: &SecuritySpec, ctx: &RoundContext, source_irs: &IrsConfig, + mode: SolveMode, pow: Pow, ) -> Result, DeriveError> { - let round_pow = grind_to_at(spec, analytic_error_bits(source_irs, None), pow)?; - Ok(SumcheckConfig::new( - ctx.vector_size, - round_pow, - num_sumcheck_rounds(ctx), - sumcheck::SumcheckMode::Standard, - )) -} - -/// ZK sumcheck builder. `mask_oracle` carries C_zk's list size + ℓ_zk; only -/// those two values are read here. `pow` labels grinding failures. -pub fn solve_zk( - spec: &SecuritySpec, - ctx: &RoundContext, - source_irs: &IrsConfig, - mask_oracle: MaskOracleInfo, - pow: Pow, -) -> Result, DeriveError> { - let round_pow = grind_to_at( - spec, - analytic_error_bits(source_irs, Some(mask_oracle)), - pow, - )?; + let (mask_oracle, output_mode) = match mode { + SolveMode::Standard => (None, sumcheck::SumcheckMode::Standard), + SolveMode::ZeroKnowledge { mask_oracle } => ( + Some(mask_oracle), + sumcheck::SumcheckMode::ZeroKnowledge { + mask_length: zk_mask_length(), + }, + ), + }; + let round_pow = grind_to_at(spec, analytic_error_bits(source_irs, mask_oracle), pow)?; Ok(SumcheckConfig::new( ctx.vector_size, round_pow, num_sumcheck_rounds(ctx), - sumcheck::SumcheckMode::ZeroKnowledge { - mask_length: zk_mask_length(), - }, + output_mode, )) } @@ -101,7 +89,7 @@ mod tests { use super::*; use crate::protocols::params::{ - irs_commit as irs_solver, + irs_commit as irs_params, spec::{ListSize, MaskCodeMessageLen, Mode, OodSampleBudget}, test_utils::{ arb_round_ctx, arb_standard_spec, arb_zk_spec, assert_close, assert_pow_closes_gap, @@ -116,7 +104,7 @@ mod tests { const FIXTURE_L_ZK: usize = 8; fn build_source_irs(spec: &SecuritySpec, ctx: &RoundContext) -> IrsConfig { - irs_solver::solve(spec, ctx, OodSampleBudget::ZERO) + irs_params::solve(spec, ctx, OodSampleBudget::ZERO) } /// Smallest pow2 shape that still produces a non-degenerate IRS. @@ -140,11 +128,11 @@ mod tests { let source_irs = build_source_irs(&spec, &ctx); let mask_oracle = build_minimal_mask_oracle(&spec).expect("ZK spec must produce a mask oracle"); - let config = solve_zk( + let config = solve( &spec, &ctx, &source_irs, - mask_oracle, + SolveMode::ZeroKnowledge { mask_oracle }, Pow::RoundSumcheck { index: 0 }, ) .unwrap(); @@ -225,7 +213,7 @@ mod tests { ) { let source_irs = build_source_irs(&spec, &ctx); let pow = Pow::RoundSumcheck { index: 0 }; - let config = solve_standard(&spec, &ctx, &source_irs, pow).unwrap(); + let config = solve(&spec, &ctx, &source_irs, SolveMode::Standard, pow).unwrap(); prop_assert!(matches!(config.mode, sumcheck::SumcheckMode::Standard)); } @@ -239,10 +227,11 @@ mod tests { ) { let source_irs = build_source_irs(&spec, &ctx); let pow = Pow::RoundSumcheck { index: 0 }; - let config = build_minimal_mask_oracle(&spec).map_or_else( - || solve_standard(&spec, &ctx, &source_irs, pow).unwrap(), - |info| solve_zk(&spec, &ctx, &source_irs, info, pow).unwrap(), - ); + let mode = build_minimal_mask_oracle(&spec) + .map_or(SolveMode::Standard, |mask_oracle| { + SolveMode::ZeroKnowledge { mask_oracle } + }); + let config = solve(&spec, &ctx, &source_irs, mode, pow).unwrap(); prop_assert_eq!(config.num_rounds, ctx.folding_factor as usize); } @@ -273,10 +262,10 @@ mod tests { let mask_oracle = build_minimal_mask_oracle(&spec); let error = analytic_error_bits(&source_irs, mask_oracle); let pow = Pow::RoundSumcheck { index: 0 }; - let config = mask_oracle.map_or_else( - || solve_standard(&spec, &ctx, &source_irs, pow).unwrap(), - |info| solve_zk(&spec, &ctx, &source_irs, info, pow).unwrap(), - ); + let mode = mask_oracle.map_or(SolveMode::Standard, |mask_oracle| { + SolveMode::ZeroKnowledge { mask_oracle } + }); + let config = solve(&spec, &ctx, &source_irs, mode, pow).unwrap(); assert_pow_closes_gap(&spec, error, &config.round_pow); } } @@ -287,16 +276,16 @@ mod tests { let spec = deterministic_spec(Mode::ZeroKnowledge); let ctx = fixture_ctx(); let source_irs: IrsConfig = - irs_solver::solve(&spec, &ctx, OodSampleBudget::ZERO); + irs_params::solve(&spec, &ctx, OodSampleBudget::ZERO); let info = MaskOracleInfo { c_zk_list_size: ListSize::new(FIXTURE_C_ZK_LIST_SIZE), l_zk: MaskCodeMessageLen::new(FIXTURE_L_ZK), }; - let config = solve_zk( + let config = solve( &spec, &ctx, &source_irs, - info, + SolveMode::ZeroKnowledge { mask_oracle: info }, Pow::RoundSumcheck { index: 0 }, ) .unwrap(); diff --git a/src/protocols/params/test_utils.rs b/src/protocols/params/test_utils.rs index cf6d8ea0..807b0e98 100644 --- a/src/protocols/params/test_utils.rs +++ b/src/protocols/params/test_utils.rs @@ -14,10 +14,9 @@ use crate::{ protocols::{ irs_commit::Config as IrsConfig, params::{ - derive::solve_t_ood, - irs_commit as irs_solver, + derive::{solve_t_ood, OodMode}, + irs_commit as irs_params, protocol_config::MaskOracleInfo, - regime::list_size_estimate, spec::{ DecodingRegime, ListSize, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, PowBudget, RoundContext, SecuritySpec, ZkSpec, @@ -111,7 +110,7 @@ pub fn build_minimal_mask_oracle(spec: &SecuritySpec) -> Option let zk_spec = ZkSpec::try_new(spec)?; let l_zk = MaskCodeMessageLen::new(2); let c_zk: IrsConfig = - irs_solver::solve_mask_code(zk_spec, l_zk, 0, LogInvRate::new(1), 2); + irs_params::solve_mask_code(zk_spec, l_zk, 0, LogInvRate::new(1), 2); Some(MaskOracleInfo { c_zk_list_size: ListSize::new(c_zk.list_size()), l_zk, @@ -149,7 +148,7 @@ pub fn build_test_c_zk( num_masks: usize, ) -> IrsConfig { let zk_spec = ZkSpec::try_new(spec).expect("build_test_c_zk requires a ZK spec"); - irs_solver::solve_mask_code( + irs_params::solve_mask_code( zk_spec, MaskCodeMessageLen::new(l_zk), 0, @@ -161,11 +160,11 @@ pub fn build_test_c_zk( /// Builds a self-consistent `(source, target, t_ood)` triplet matching the /// per-round shape that `code_switch::solve` expects. /// -/// `t_ood` is solved against the rate-only `list_size_estimate(...)` rather -/// than `target.list_size()`: the latter reads the target's effective rate -/// (which itself depends on `t_ood` via the mask), producing a non-monotone -/// oscillation once the mask is tight (Lemma 9.5 part ii) rather than -/// pow2-padded. +/// `t_ood` is solved against the rate-only `DecodingRegime::list_size_estimate` +/// rather than `target.list_size()`: the latter reads the target's effective +/// rate (which itself depends on `t_ood` via the mask), producing a +/// non-monotone oscillation once the mask is tight (`mask_length = in_domain +/// + t_ood` per Construction 9.7 / Theorem 9.6) rather than pow2-padded. pub fn build_round_io( spec: &SecuritySpec, log_inv_rate: u32, @@ -180,21 +179,20 @@ pub fn build_round_io( }; let target_log_inv_rate = log_inv_rate + folding_factor - 1; let target_log_degree = f64::from(num_vars - folding_factor); - let target_list_size = list_size_estimate( - spec.decoding_regime, - target_log_degree, - f64::from(target_log_inv_rate), - ); - let c_zk_log_inv_rate = c_zk_log_inv_rate.map(f64::from); - let (source, t_ood) = - solve_t_ood::(spec, &source_ctx, target_list_size, c_zk_log_inv_rate, 0) - .expect("solve_t_ood diverged in test fixture"); + let target_list_size = spec + .decoding_regime + .list_size_estimate(target_log_degree, f64::from(target_log_inv_rate)); + let ood_mode = c_zk_log_inv_rate.map_or(OodMode::Standard, |rate| OodMode::ZeroKnowledge { + c_zk_log_inv_rate: f64::from(rate), + }); + let (source, t_ood) = solve_t_ood::(spec, &source_ctx, target_list_size, ood_mode, 0) + .expect("solve_t_ood diverged in test fixture"); let target_ctx = RoundContext { vector_size: source.message_length(), log_inv_rate: target_log_inv_rate, folding_factor, }; - let target = irs_solver::solve(spec, &target_ctx, OodSampleBudget::new(t_ood)); + let target = irs_params::solve(spec, &target_ctx, OodSampleBudget::new(t_ood)); (source, target, t_ood) } From 1c66bb5afca61bff4b9aef2dcbdca2c16c9bb22d Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Mon, 25 May 2026 11:15:43 +0530 Subject: [PATCH 26/31] lint: comments --- src/protocols/params/basecase.rs | 12 -- src/protocols/params/code_switch.rs | 42 +----- src/protocols/params/derive.rs | 183 ++---------------------- src/protocols/params/error.rs | 21 +-- src/protocols/params/irs_commit.rs | 30 +--- src/protocols/params/mask_proximity.rs | 15 -- src/protocols/params/mod.rs | 13 -- src/protocols/params/protocol_config.rs | 76 ++-------- src/protocols/params/regime.rs | 65 ++------- src/protocols/params/spec.rs | 53 +------ src/protocols/params/sumcheck.rs | 20 +-- src/protocols/params/test_utils.rs | 32 +---- 12 files changed, 58 insertions(+), 504 deletions(-) diff --git a/src/protocols/params/basecase.rs b/src/protocols/params/basecase.rs index 9b875079..dfa221d1 100644 --- a/src/protocols/params/basecase.rs +++ b/src/protocols/params/basecase.rs @@ -19,8 +19,6 @@ use crate::{ }, }; -/// PoW closes the Theorem 7.1 γ-slot gap to `spec.target_security_bits`; no -/// γ challenge in Standard mode ⇒ `Config::none()`. pub fn solve( spec: &SecuritySpec, vector_size: usize, @@ -65,7 +63,6 @@ pub fn solve( } /// γ-combination soundness (Lemma 7.4 combination-randomness slot, paper p.45). -/// At `n = 0` the `C_zk` factors vanish; `ε_mca(C, δ)` does not. pub fn analytic_error_bits(commit: &IrsConfig>) -> Bits { let field_bits = F::field_size_bits(); let log_list = commit.list_size().log2(); @@ -76,8 +73,6 @@ pub fn analytic_error_bits(commit: &IrsConfig>) -> Bits { impl BasecaseConfig { /// Analytic soundness bits (excluding PoW): `min(sumcheck round error, γ-slot error)`. - /// The γ-slot only contributes in ZK mode; Standard collapses to the - /// sumcheck term. pub fn analytic_bits(&self) -> Bits { let sumcheck_term = f64::from(sumcheck_params::analytic_error_bits(&self.commit, None)); let min_bits = match self.mode { @@ -100,9 +95,6 @@ mod tests { TestField, TEST_TARGET_RANGE, }; - /// `vector_size = 16` (2^4) and `log_inv_rate = 2` give a small but - /// non-degenerate basecase IRS. `folding_factor = 0` is the basecase - /// invariant (no folding, message_length = vector_size). const FIXTURE_VECTOR_SIZE: usize = 16; const FIXTURE_LOG_INV_RATE: u32 = 2; @@ -110,8 +102,6 @@ mod tests { (1u32..=4, 1u32..=3) } - /// Builds the commit directly via the IRS solver to bypass `solve`'s PoW - /// grind (which would assert against the cap for default test targets). #[test] fn analytic_error_formula() { use crate::protocols::params::{ @@ -138,8 +128,6 @@ mod tests { assert_close(got, expected); } - /// At `log_inv_rate = 1` on `Field64`, `ε_mca` is below the poly-identity - /// term — pins the `min` to the prox-gaps arm rather than `poly_id`. #[test] fn analytic_error_uses_eps_mca_when_limiting() { use crate::protocols::params::{ diff --git a/src/protocols/params/code_switch.rs b/src/protocols/params/code_switch.rs index 00faff2f..f539469d 100644 --- a/src/protocols/params/code_switch.rs +++ b/src/protocols/params/code_switch.rs @@ -1,5 +1,4 @@ //! Code-switching IOR (Construction 9.7, p.55) builder + Lemma 9.9 OOD bound. -//! The `t_ood` / `ℓ_zk` fixed-points live in the planner. use std::num::NonZeroUsize; @@ -22,10 +21,7 @@ use crate::{ }, }; -/// Per-round code-switch builder. PoW closes the Lemma 9.9 OOD gap to -/// `spec.target_security_bits`; `t_ood ≥ 1` is required by Construction 9.7. -/// In ZK mode, `ℓ_zk ≥ r + t_ood` (Theorem 9.6 witness sizing) is asserted -/// against the mask oracle carried by [`super::SolveMode::ZeroKnowledge`]. +/// Per-round code-switch builder. pub fn solve( spec: &SecuritySpec, source: IrsConfig, @@ -66,8 +62,7 @@ pub fn solve( } /// Per-round code-switch soundness in bits: `min` over Lemma 9.9's three RBR -/// error slots (OOD, in-domain, combination). `t_ood ≥ 1` per -/// [`code_switch::Config::new`]. +/// error slots (OOD, in-domain, combination). pub fn analytic_error_bits( source: &IrsConfig, target: &IrsConfig>, @@ -80,8 +75,7 @@ pub fn analytic_error_bits( let combined_list = target.list_size() * mask_oracle.map_or(1.0, |info| info.c_zk_list_size.get()); // OOD polynomial is over witness `[f; r_C; s]` of length `ℓ + ℓ_zk` (ZK) or - // `ℓ` (Standard). The `s`-tail is sampled at full length `ℓ_zk − r` (not - // just `t_ood`), so degree must use the realized `ℓ_zk`, not `r + t_ood`. + // `ℓ` (Standard). let degree = mask_oracle.map_or_else( || source.message_length(), |info| source.message_length().saturating_add(info.l_zk.get()), @@ -142,8 +136,6 @@ mod tests { const NUM_VARS_HEADROOM: u32 = 4; - /// `(log_inv_rate, folding_factor, num_vars)`. `num_vars ≥ 2 · folding_factor` - /// keeps target IRS valid. fn arb_dims() -> impl Strategy { (1u32..=3, 1u32..=2).prop_flat_map(|(log_inv_rate, folding_factor)| { let min_num_vars = 2 * folding_factor; @@ -159,8 +151,6 @@ mod tests { const FORMULA_FOLDING_FACTOR: u32 = 2; const FORMULA_NUM_VARS: u32 = 6; - /// Standard `min(ood, in_domain, comb)` from Lemma 9.9's three RBR error - /// slots; `L = target.list_size()`. #[test] fn analytic_error_standard_formula() { let spec: SecuritySpec = deterministic_spec(Mode::Standard); @@ -187,14 +177,10 @@ mod tests { assert_close(got, expected); } - /// ZK bound: combined list `L = target × c_zk`, masked degree `ℓ + ℓ_zk`, - /// combination term also subtracts `log|Λ(C_zk)|`. #[test] fn analytic_error_zk_formula() { - // Both mask-oracle values are pow2 so `log2` is exact (avoids - // floating-point drift in the expected-vs-got comparison). - const C_ZK_LIST_SIZE: f64 = 4.0; // log2 = 2 - const L_ZK_USIZE: usize = 8; // log2 = 3 + const C_ZK_LIST_SIZE: f64 = 4.0; + const L_ZK_USIZE: usize = 8; let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); let mask_oracle = MaskOracleInfo { @@ -230,9 +216,6 @@ mod tests { assert_close(got, expected); } - /// Low security target (16 bits) pins `source.rbr_queries()` below the - /// natural OOD and combination floors on `Field64`, forcing the `min` to - /// the arm #[test] fn analytic_error_uses_in_domain_when_limiting() { const LIMITING_TARGET_BITS: u32 = 16; @@ -286,7 +269,6 @@ mod tests { prop_assert!(config.out_domain_samples >= 1); } - /// ZK: `ℓ_zk = next_power_of_two(r + t_ood)`. #[test] fn solve_zk_mask_equals_padded_r_plus_t_ood( spec in arb_zk_spec(), @@ -321,7 +303,6 @@ mod tests { prop_assert_eq!(config.message_mask_length(), (r + t_ood).next_power_of_two()); } - /// `analytic_error + pow ≥ target` (Lemma 9.9 OOD term). #[test] fn pow_closes_gap_to_target_standard( spec in arb_standard_spec(), @@ -335,8 +316,6 @@ mod tests { } } - /// Shared shape for the `M::Source ≠ M::Target` smoke tests. `target_ctx` - /// uses the same per-round chaining the planner does. fn non_identity_smoke_ctxs() -> (RoundContext, RoundContext) { const SOURCE_VECTOR_SIZE: usize = 64; const SOURCE_LOG_INV_RATE: u32 = 1; @@ -355,10 +334,6 @@ mod tests { (source_ctx, target_ctx) } - /// `solve` asserts `ℓ_zk ≥ source.mask_length() + t_ood` (Theorem 9.6 - /// witness sizing) under [`SolveMode::ZeroKnowledge`]. Build a - /// self-consistent `(source, target, t_ood)` and pass a too-small - /// `l_zk = 1` to trip the precondition. #[test] #[should_panic(expected = "violates Theorem 9.6")] fn solve_zk_rejects_l_zk_below_r_plus_t_ood() { @@ -372,8 +347,6 @@ mod tests { FORMULA_NUM_VARS, Some(FORMULA_LOG_INV_RATE), ); - // `source.mask_length() + t_ood ≥ 1 + 1 > TOO_SMALL_L_ZK` in ZK, - // so the assert in `solve` fires. assert!(source.mask_length() + t_ood > TOO_SMALL_L_ZK); let mask_oracle = MaskOracleInfo { @@ -390,7 +363,6 @@ mod tests { ); } - /// Smoke test: `M::Source ≠ M::Target`, Standard mode. #[test] fn solve_works_with_basefield_embedding_standard() { let spec: SecuritySpec = deterministic_spec(Mode::Standard); @@ -408,7 +380,6 @@ mod tests { 0, ) .unwrap(); - // Standard target: codeword_length is t_ood-independent (mask = 0). let target = irs_params::solve::>( &spec, &target_ctx, @@ -419,11 +390,8 @@ mod tests { assert!(matches!(config.mode, code_switch::CodeSwitchMode::Standard)); } - /// Placeholder mask-oracle list size for the smoke test — pow2 so `log2` - /// is exact. const SMOKE_C_ZK_LIST_SIZE: f64 = 4.0; - /// Smoke test: `M::Source ≠ M::Target`, ZK mode. #[test] fn solve_works_with_basefield_embedding_zk() { let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); diff --git a/src/protocols/params/derive.rs b/src/protocols/params/derive.rs index e74d1056..922da696 100644 --- a/src/protocols/params/derive.rs +++ b/src/protocols/params/derive.rs @@ -1,8 +1,4 @@ //! Derives a [`ProtocolConfig`] from a spec + tuning. -//! -//! All cross-protocol coordination lives here: per-round `t_ood ↔ r` and -//! `ℓ_zk ↔ c_zk` fixed-points, plus the per-round mask oracle (C_zk + -//! mask-proximity sized for `k + 1` masks). use crate::{ algebra::{ @@ -27,18 +23,10 @@ use crate::{ }, }; -/// Paranoia guard on `solve_t_ood` — convergence proof on the function itself. const T_OOD_MAX_ITER: usize = 32; /// Mode flag for the OOD security bound in [`solve_t_ood`] / /// [`ood_security_bits_at`]. -/// -/// In `Standard`, the bound uses `L = target.list_size` and -/// `d = source.message_length()`. -/// -/// In `ZeroKnowledge`, the carried `c_zk_log_inv_rate` lets the bound include -/// the combined list size `target.list_size · c_zk.list_size` and the masked -/// degree `source.message_length() + ℓ_zk` (Lemma 9.9 witness layout). #[derive(Clone, Copy)] pub(super) enum OodMode { Standard, @@ -46,12 +34,8 @@ pub(super) enum OodMode { } impl ProtocolConfig { - /// In ZK each round owns its mask oracle; the `ℓ_zk ↔ c_zk ↔ t_ood` - /// fixed-point runs independently per round. - /// /// Fails with [`DeriveError`] when the spec/tuning combination is - /// infeasible: a PoW slot exceeds the grind cap, a fixed point diverges, - /// or any slot exceeds `spec.pow_budget` (post-derivation validation). + /// infeasible. pub fn derive(spec: SecuritySpec, tuning: TuningSpec) -> Result { let RoundLayout { shapes, @@ -80,10 +64,7 @@ impl ProtocolConfig { } } -/// Mode-dispatch input for [`build_round_config`]. The ZK variant borrows the -/// spec via `ZkSpec` and carries the planner-level `c_zk_log_inv_rate`; the -/// rest of the round's behavior (target OOD budget, sub-protocol `SolveMode`, -/// output `RoundMode` payload) is derived from this single discriminator. +/// Mode-dispatch input for [`build_round_config`]. #[derive(Clone, Copy)] enum RoundBuildMode<'a> { Standard, @@ -94,7 +75,6 @@ enum RoundBuildMode<'a> { } impl RoundBuildMode<'_> { - /// Project to the OOD-search mode consumed by [`solve_round_source`]. fn to_ood_mode(self) -> OodMode { match self { Self::Standard => OodMode::Standard, @@ -107,9 +87,6 @@ impl RoundBuildMode<'_> { } } -/// `target_folding_factor` is the next round's source folding — uniform -/// `tuning.folding_factor` — so `target_r → source_{r+1}` has matching -/// interleaving. #[derive(Debug, Clone, Copy)] struct RoundShape { round_index: usize, @@ -176,10 +153,6 @@ fn target_context(shape: &RoundShape, source: &IrsConfig) -> Ro } } -/// Per-round `(source, t_ood)` for either mode. The target dimensions -/// (`log_inv_rate` after the round's rate step, post-fold `log_degree`) feed -/// the regime's canonical-slack list-size estimate, which `solve_t_ood` uses -/// inside its OOD bound. Shared by Standard and ZK builders. fn solve_round_source( spec: &SecuritySpec, shape: &RoundShape, @@ -191,8 +164,6 @@ fn solve_round_source( .source_log_inv_rate .saturating_add(shape.source_folding_factor.saturating_sub(1)), ); - // Target encodes one polynomial of length `source.message_length()` = - // `source_vector_size / 2^source_folding_factor`. let target_log_degree = f64::from( shape .source_vector_size @@ -212,9 +183,7 @@ fn solve_round_source( } /// ZK-only: assemble the per-round mask oracle (C_zk codeword + mask-proximity -/// check). `ℓ_zk = next_pow2(r + t_ood)` from Theorem 9.6's witness layout + -/// Lemma 9.3's `r ≥ t` privacy precondition; C_zk holds `2 · num_masks` -/// columns (Construction 7.2: originals + fresh). +/// check). fn build_mask_oracle( zk_spec: ZkSpec<'_>, source: &IrsConfig, @@ -239,8 +208,7 @@ fn build_mask_oracle( debug_assert!( (c_zk.list_size() - c_zk_list_size_estimate).abs() < 1e-9 * c_zk_list_size_estimate.max(1.0), - "c_zk.list_size() {} drifted from planner estimate {} — \ - see `DecodingRegime::list_size_estimate` for the invariant", + "c_zk.list_size() {} drifted from planner estimate {}", c_zk.list_size(), c_zk_list_size_estimate, ); @@ -248,10 +216,6 @@ fn build_mask_oracle( Ok(MaskOracleConfig::new(c_zk, l_zk, mask_proximity)) } -/// Per-round builder. Under [`RoundBuildMode::ZeroKnowledge`], C_zk holds -/// `2 · (k + 1)` columns: `k` sumcheck masks (Lemma 6.4) + one `(r ‖ s)` -/// code-switch mask (Construction 9.7). `t_ood` solves Lemma 9.9 term 1 in -/// both modes. fn build_round_config( spec: &SecuritySpec, shape: &RoundShape, @@ -260,7 +224,6 @@ fn build_round_config( let ctx = round_context(shape); let (source, t_ood) = solve_round_source::(spec, shape, mode.to_ood_mode())?; - // Single mode-dispatch site; ZK additionally builds the mask oracle. let (target_budget, solve_mode, round_mode) = match mode { RoundBuildMode::Standard => ( OodSampleBudget::ZERO, @@ -314,8 +277,7 @@ fn build_round_config( )) } -/// `ℓ_zk = next_pow2(r + t_ood)`: Theorem 9.6 witness layout `0^{ℓ_zk − r}` -/// combined with Lemma 9.3's `r ≥ t` privacy precondition. +/// `ℓ_zk = next_pow2(r + t_ood)` (Theorem 9.6 + Lemma 9.3). pub(super) const fn compute_l_zk( source: &IrsConfig, t_ood: usize, @@ -328,22 +290,12 @@ pub(super) const fn compute_l_zk( ) } -/// Per-round `(source, t_ood)` from a linear search over `t_ood`. -/// -/// Under `Unique`, OOD contributes no soundness (`|Λ| = 1` ⇒ `(L choose 2) = 0`) -/// and the short-circuit pins `t_ood = 1` — the Construction 9.7 protocol-layer -/// minimum, since [`crate::protocols::code_switch::Config::new`] asserts -/// `out_domain_samples ≥ 1`. Letting the loop run would mis-evaluate the -/// `log(L·(L−1)/2) ≈ 2·log L − 1` approximation, which is `+∞` off the true -/// value when `L = 1`. +/// Per-round `(source, t_ood)`. /// -/// Under `Johnson`/`Capacity`, searches `t_ood = 1..=T_OOD_MAX_ITER` for the -/// smallest value where [`ood_security_bits_at`] meets -/// `protocol_security_target_bits`. The bound is monotone-increasing in `t` for -/// `|F| ≫ log d` (always the case here), so the first match is the minimum. -/// Source is rebuilt per iteration because `source.mask_length()` depends on -/// `t_ood` in ZK (`mask_length = in_domain + t_ood` per Construction 9.7 / -/// Theorem 9.6); the rebuild is cheap (struct fields only). +/// Under `Unique`, `t_ood = 1` is pinned (the `log(L·(L−1)/2)` term degenerates +/// when `L = 1`, and Construction 9.7 requires `out_domain_samples ≥ 1`). +/// Otherwise linear search over `t_ood = 1..=T_OOD_MAX_ITER` for the smallest +/// value where [`ood_security_bits_at`] meets `protocol_security_target_bits`. pub(super) fn solve_t_ood( spec: &SecuritySpec, src_ctx: &RoundContext, @@ -372,13 +324,6 @@ pub(super) fn solve_t_ood( /// OOD security bits at candidate `t_ood`, per STIR Lemma 4.5: /// `bits = t · (|F| − log d) − log(L · (L − 1) / 2) ≈ t·(|F| − log d) − 2·log L + 1`. -/// -/// In [`OodMode::Standard`], `L = target.list_size` and `d = source.message_length()`. -/// In [`OodMode::ZeroKnowledge`], `L = target.list_size · c_zk.list_size` and -/// `d = source.message_length() + ℓ_zk` (Lemma 9.9 witness layout). -/// -/// The approximation `log(L·(L−1)/2) ≈ 2·log L − 1` is exact-ish for `L ≥ 2`; -/// the `L = 1` case is handled by [`solve_t_ood`]'s `Unique` short-circuit. fn ood_security_bits_at( spec: &SecuritySpec, source: &IrsConfig, @@ -423,9 +368,6 @@ mod tests { }, }; - /// Varied tuning space for proptests. Exercises both `FoldingFactor` - /// variants. Bounds keep PoW under the 60-bit cap and the IRS solver - /// inside Field64's reachable range. fn arb_tuning() -> impl Strategy { let folding = prop_oneof![ (1usize..=3).prop_map(FoldingFactor::Constant), @@ -442,22 +384,12 @@ mod tests { }) } - /// `tuning_with` uses `FoldingFactor::Constant(FIXTURE_FOLDING_FACTOR)` so - /// each round folds by 2. With `target_folding == source_folding == 2`, - /// `round_layout` keeps a round only while `num_vars ≥ 4`. const FIXTURE_FOLDING_FACTOR: usize = 2; const FIXTURE_LOG_INV_RATE: u32 = 1; - /// `log_vector_size` chosen to be below `2 · FIXTURE_FOLDING_FACTOR`, so - /// `round_layout` exits before adding any round → basecase-only plan. const LOG_VECTOR_SIZE_NO_ROUNDS: u32 = 3; - /// Large enough to produce multiple rounds under - /// `FIXTURE_FOLDING_FACTOR`-uniform folding; used by every multi-round test. const LOG_VECTOR_SIZE_MULTI_ROUND: u32 = 8; - /// Folding pair used by tests that need round-to-round folding variation - /// (rate stepping, target→source chaining). The two values must differ - /// from each other so the variation across rounds is observable. const VARIED_INITIAL_FOLDING: usize = 3; const VARIED_STEADY_FOLDING: usize = 2; @@ -469,10 +401,6 @@ mod tests { } } - /// Planner-level tests build full `ProtocolConfig`s, so we use a lower target - /// than `test_utils::FIXTURE_TARGET_BITS` (= 80). Keeps PoW below the 60-bit - /// cap when every sub-protocol grinds individually. 40 leaves - /// `target − analytic_error ≤ 60` on `Field64`. const PLAN_FIXTURE_TARGET_BITS: u32 = 40; fn test_spec(mode: Mode) -> SecuritySpec { @@ -480,21 +408,14 @@ mod tests { mode, decoding_regime: DecodingRegime::Johnson, target_security_bits: PLAN_FIXTURE_TARGET_BITS, - // Allow up to the grind cap; derive() auto-validates the budget - // and would reject configs that need any PoW under `Forbidden`. pow_budget: PowBudget::per_slot(LOOSE_POW_BUDGET_BITS), hash_id: hash::BLAKE3, } } - /// `> 1` so the first round's rate is distinct from the boundary. const RATE_STEPPING_STARTING_LOG_INV_RATE: u32 = 2; - /// Pairwise `windows(2)` chaining check needs ≥ 2 rounds. const MIN_ROUNDS_FOR_CHAINING_TEST: usize = 2; - /// Each round's source rate steps up by `source_folding - 1`. The basecase - /// inherits the rate after the final round. Uses varied folding so the - /// per-round step is non-uniform (initial step = 2, steady step = 1). #[test] fn round_layout_rate_steps_up_by_folding_minus_one() { let tuning = TuningSpec { @@ -515,9 +436,6 @@ mod tests { assert_eq!(layout.basecase_log_inv_rate, expected_log_inv_rate); } - /// Cross-round chaining: round `i`'s target folding factor must match - /// round `i+1`'s source folding factor (the doc-comment on `RoundShape` - /// codifies this). Varied folding makes the check non-vacuous. #[test] fn round_layout_chains_target_to_next_source_folding() { let tuning = TuningSpec { @@ -541,8 +459,6 @@ mod tests { } } - /// Basecase consumes whatever `num_vars` the round loop left behind: - /// `basecase_vector_size = 2^(initial_num_vars - sum(source_folding_factor))`. #[test] fn round_layout_basecase_size_consumes_remaining_num_vars() { let tuning = tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND); @@ -553,9 +469,6 @@ mod tests { assert_eq!(layout.basecase_vector_size, 1usize << remaining); } - /// Loop exits when `num_vars < source_folding + target_folding`. Below the - /// `2 · FIXTURE_FOLDING_FACTOR` threshold, no round is admitted and the - /// basecase carries the whole vector at the starting rate. #[test] fn round_layout_stops_when_no_room_for_source_plus_target() { let vector_size = 1usize << LOG_VECTOR_SIZE_NO_ROUNDS; @@ -575,8 +488,6 @@ mod tests { assert_eq!(plan.basecase().commit.vector_size, vector_size); } - /// ZK with zero WHIR rounds = ZK basecase only. Per-round mask oracles are - /// absent (there are no rounds); the basecase γ-slot PoW carries soundness. #[test] fn derive_zk_with_no_rounds_uses_zk_basecase_only() { let spec = test_spec(Mode::ZeroKnowledge); @@ -592,9 +503,6 @@ mod tests { )); } - /// Johnson + ZK: each round runs a non-trivial OOD challenge to amplify - /// the list-decoding soundness gap (Lemma 9.9). `solve_t_ood`'s linear - /// search lands at the smallest `t_ood` clearing the security target. #[test] fn t_ood_nonzero_in_johnson_zk() { let spec = SecuritySpec { @@ -614,10 +522,6 @@ mod tests { } } - /// Unique + ZK: OOD contributes no soundness (`|Λ| = 1`), but - /// `protocols::code_switch::Config::new` requires `out_domain_samples ≥ 1` - /// to run Construction 9.7 Steps 2-3. `solve_t_ood` pins `t_ood = 1` - /// exactly — sharper than the Johnson-side `≥ 1` invariant. #[test] fn t_ood_pinned_to_one_in_unique_zk() { let spec = SecuritySpec { @@ -637,13 +541,6 @@ mod tests { } } - /// Under Unique decoding, the C_zk mask oracle still carries the full - /// `2 · (k + 1)` columns — `k` sumcheck masks (Lemma 6.4) plus the - /// `(r ‖ s)` code-switch mask (Construction 9.7). With `t_ood = 1`, the - /// `s`-tail has length `ℓ_zk − r ≥ 1` and supports the - /// Vandermonde-surjectivity ZK argument (bounds doc §5.3 / Bound 3). - /// Pins the shape so an accidental "drop code-switch mask under Unique" - /// optimization can't slip in unnoticed. #[test] fn c_zk_keeps_code_switch_mask_under_unique() { let spec = SecuritySpec { @@ -658,7 +555,7 @@ mod tests { for r in plan.rounds() { let mask_oracle = r.mask_oracle().expect("ZK round has a mask oracle"); let k = r.code_switch().source.interleaving_depth.trailing_zeros() as usize; - let expected_num_masks = k + 1; // k sumcheck + 1 code-switch + let expected_num_masks = k + 1; assert_eq!(mask_oracle.c_zk().num_vectors, 2 * expected_num_masks); } } @@ -724,15 +621,10 @@ mod tests { crate::protocols::basecase::BasecaseMode::ZeroKnowledge )); assert_eq!(plan.basecase().commit.interleaving_depth, 1); - // Sumcheck folds basecase to size 1. assert_eq!(plan.basecase().sumcheck.final_size(), 1); } - /// Matches `proof_of_work::threshold`'s 60-bit cap. const LOOSE_POW_BUDGET_BITS: u32 = 60; - /// Sits between a moderate budget (30) and the grind cap (60) — used by - /// `check_pow_bits_detects_over_budget_slot` to inject a slot that fits - /// the cap but exceeds the test's `pow_budget`. const OVER_BUDGET_INJECTED_BITS: f64 = 50.0; /// Bounds doc §5.3 + §5.7: HVZK privacy error in bits matches the closed @@ -759,8 +651,6 @@ mod tests { assert_close(got, expected_bits); } - /// Standard-mode plans have no HVZK claim — `privacy_error_bits` returns - /// the spec's `target_security_bits` as a sentinel. #[test] fn privacy_error_bits_standard_returns_target_sentinel() { let spec = test_spec(Mode::Standard); @@ -775,7 +665,6 @@ mod tests { ); } - /// Derived plans must satisfy their own `pow_budget`. #[test] fn check_pow_bits_passes_on_derived_plan() { let plan = ProtocolConfig::::derive( @@ -786,12 +675,6 @@ mod tests { assert!(plan.check_pow_bits()); } - /// Hand-injected over-budget PoW slot fails `check_pow_bits()`. - /// - /// Derive with a moderately tight budget (passes auto-validation because - /// the natural slot pow stays well below it), then mutate the basecase - /// pow to a value above that budget but still within the grind cap, and - /// verify the boolean check trips. #[test] fn check_pow_bits_detects_over_budget_slot() { use crate::{bits::Bits, protocols::proof_of_work::Config as PowConfig}; @@ -811,10 +694,6 @@ mod tests { assert!(!plan.check_pow_bits()); } - /// `validate_round_chaining` trips when round `i`'s target `vector_size` - /// no longer matches round `i+1`'s source. Covers the adjacent-rounds - /// `windows(2)` branch — distinct from the basecase branch, which is - /// covered by `validate_round_chaining_detects_basecase_mismatch`. #[test] fn validate_round_chaining_detects_adjacent_round_mismatch() { let spec = test_spec(Mode::ZeroKnowledge); @@ -827,9 +706,6 @@ mod tests { assert!(n >= 2, "need ≥ 2 rounds to break a mid-chain link"); assert!(plan.check_all_invariants(), "fresh plan must validate"); - // Round 0's natural target.vector_size is some power of 2; bumping - // it to a value the next round's source can't match (the source - // still carries the originally-derived size) breaks the chain. let bad_size = plan.rounds()[0].code_switch().target.vector_size + 1; plan.corrupt_round_target_vector_size_for_test(0, bad_size); @@ -850,9 +726,6 @@ mod tests { assert!(!plan.check_all_invariants()); } - /// `validate_round_chaining` trips when the basecase no longer chains - /// to the (new) last round after the tail is dropped. Multi-round plan - /// is required so dropping the last leaves at least one round behind. #[test] fn validate_round_chaining_detects_basecase_mismatch() { let spec = test_spec(Mode::ZeroKnowledge); @@ -882,9 +755,6 @@ mod tests { assert!(!plan.check_all_invariants()); } - /// `derive()` reports `PowUngrindable` when the spec demands a per-slot - /// difficulty above the grind cap. `target_security_bits = 200` against - /// `analytic ≈ 64` on `Field64` gives `required ≈ 136` ≫ 60. #[test] fn derive_reports_pow_ungrindable() { const UNREACHABLE_TARGET_BITS: u32 = 200; @@ -903,9 +773,6 @@ mod tests { ); } - /// `derive()` reports `PowBudgetExceeded` when a slot's required PoW - /// fits the grind cap but exceeds `spec.pow_budget`. `target = 40` - /// with `pow_budget = PerSlot { bits: 5 }` forces this on `Field64`. #[test] fn derive_reports_pow_budget_exceeded() { const TIGHT_MAX_POW: u32 = 5; @@ -924,9 +791,6 @@ mod tests { ); } - /// Unique decoding threads through to the basecase IRS in Standard mode. - /// Uses a basecase-only tuning so the regime is unambiguous (no rate - /// stepping across rounds). #[test] fn derive_threads_unique_decoding_standard() { let spec = SecuritySpec { @@ -942,7 +806,6 @@ mod tests { assert!(plan.basecase().commit.unique_decoding()); } - /// Same threading check under ZK mode (basecase-only fixture). #[test] fn derive_threads_unique_decoding_zk() { let spec = SecuritySpec { @@ -958,9 +821,6 @@ mod tests { assert!(plan.basecase().commit.unique_decoding()); } - /// Multi-round derivation under Unique: every round's IRS carries the - /// Unique regime and every code-switch slot satisfies the Construction - /// 9.7 `t_ood ≥ 1` floor. #[test] fn derive_multi_round_unique_decoding_succeeds() { let spec = SecuritySpec { @@ -977,13 +837,11 @@ mod tests { let cs = r.code_switch(); assert!(cs.source.unique_decoding()); assert!(cs.target.unique_decoding()); - assert!(cs.out_domain_samples >= 1, "Construction 9.7 floor"); + assert!(cs.out_domain_samples >= 1); } assert!(plan.basecase().commit.unique_decoding()); } - /// ZK + Unique multi-round: per-round mask oracle still assembled, C_zk - /// built under Unique, code-switch carries `t_ood ≥ 1` per floor. #[test] fn derive_multi_round_unique_decoding_zk_succeeds() { let spec = SecuritySpec { @@ -1005,8 +863,6 @@ mod tests { assert!(plan.basecase().commit.unique_decoding()); } - /// Multi-round Capacity (Standard): IRS configs carry the Capacity regime - /// and the `c_zk_list_size(t)` fixed-point resolves inside `solve_t_ood`. #[test] fn derive_multi_round_capacity_decoding_succeeds() { let spec = SecuritySpec { @@ -1024,8 +880,6 @@ mod tests { } } - /// ZK + Capacity multi-round: exercises the degree-dependent c_zk list - /// size inside the t_ood fixed-point. #[test] fn derive_multi_round_capacity_decoding_zk_succeeds() { let spec = SecuritySpec { @@ -1044,7 +898,6 @@ mod tests { } } - /// `analytic_error + pow ≥ target` for every PoW slot in the plan. fn assert_plan_meets_target_per_slot( spec: &SecuritySpec, plan: &ProtocolConfig, @@ -1081,7 +934,6 @@ mod tests { sumcheck_params::analytic_error_bits(&plan.basecase().commit, None), &plan.basecase().sumcheck.round_pow, ); - // γ-slot is ZK-only. if matches!( plan.basecase().mode, crate::protocols::basecase::BasecaseMode::ZeroKnowledge @@ -1095,8 +947,6 @@ mod tests { } proptest! { - /// End-to-end soundness (Standard): every PoW slot in the derived plan - /// closes the gap `analytic + pow ≥ target` against the spec target. #[test] fn derived_plan_meets_target_per_slot_standard(tuning in arb_tuning()) { let spec = test_spec(Mode::Standard); @@ -1104,8 +954,6 @@ mod tests { assert_plan_meets_target_per_slot(&spec, &plan); } - /// End-to-end soundness (ZK): same as above, plus the per-round - /// mask-proximity slot and the basecase γ-slot. #[test] fn derived_plan_meets_target_per_slot_zk(tuning in arb_tuning()) { let log_threshold = @@ -1116,8 +964,6 @@ mod tests { assert_plan_meets_target_per_slot(&spec, &plan); } - /// Standard mode: derive succeeds for any tuning shape, no per-round - /// mask oracle, and basecase covers the post-fold tail. #[test] fn derive_standard_succeeds_over_tunings(tuning in arb_tuning()) { let spec = test_spec(Mode::Standard); @@ -1133,8 +979,6 @@ mod tests { prop_assert_eq!(plan.basecase().commit.interleaving_depth, 1); } - /// ZK mode: each round has its own mask oracle sized for `k + 1` - /// masks; basecase is ZK-flagged when shapes are non-empty. #[test] fn derive_zk_succeeds_over_tunings(tuning in arb_tuning()) { let log_threshold = @@ -1155,7 +999,6 @@ mod tests { let num_masks = k + 1; prop_assert_eq!(mask_oracle.c_zk().num_vectors, 2 * num_masks); prop_assert_eq!(mask_oracle.mask_proximity().num_masks, num_masks); - // Theorem 9.6 / Lemma 9.3: ℓ_zk ≥ r + t_ood for this round. let source_mask = cs.source.mask_length(); prop_assert!(mask_oracle.l_zk().get() >= source_mask + t_ood.get()); } @@ -1165,8 +1008,6 @@ mod tests { )); } - /// `analytic_bits` is finite and non-negative for any tuning the - /// planner accepts in Standard mode. #[test] fn analytic_bits_finite_and_non_negative_standard(tuning in arb_tuning()) { let spec = test_spec(Mode::Standard); diff --git a/src/protocols/params/error.rs b/src/protocols/params/error.rs index d054cc5b..f753ffda 100644 --- a/src/protocols/params/error.rs +++ b/src/protocols/params/error.rs @@ -1,10 +1,5 @@ //! Errors raised by [`super::derive::ProtocolConfig::derive`] and the //! sub-protocol solvers. -//! -//! Two layers: [`super::super::proof_of_work::PowError`] for grinding-cap -//! failures, [`DeriveError`] for everything `derive()` can surface. The latter -//! wraps the former via [`DeriveError::PowUngrindable::source`] so callers can -//! walk the `std::error::Error::source()` chain. use std::fmt::{self, Display, Formatter}; @@ -18,11 +13,7 @@ use crate::{ }, }; -/// Identifies a single PoW grind in the derived protocol — basecase -/// sub-protocol or a per-round sub-protocol at a specific round index. Used -/// to label grinding-cap and budget failures. -/// -/// Flat by design: each variant is one valid (where, sub-protocol) pair. +/// Identifies a single PoW grind in the derived protocol. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Pow { /// Basecase γ-RLC grind (Lemma 7.4) — ZK mode only. @@ -88,8 +79,7 @@ impl Display for ChainTarget { #[derive(Debug, Error, Clone, PartialEq, Eq)] pub enum DeriveError { /// The `t_ood` fixed-point in [`super::derive::solve_t_ood`] ran out of - /// iterations. Indicates a pathological spec/tuning combo; should not - /// happen under realistic security targets on supported fields. + /// iterations. #[error("t_ood fixed-point did not converge for round {round_index}")] FixedPointDidNotConverge { round_index: usize }, @@ -125,8 +115,7 @@ pub enum DeriveError { } /// Lift `Result` into `Result` by attaching a -/// [`Pow`] label. Lets call sites stay single-line — no manual -/// `.map_err(|e| DeriveError::PowUngrindable { pow, source: e })` boilerplate. +/// [`Pow`] label. pub(crate) trait PowResultExt { fn at(self, pow: Pow) -> Result; } @@ -138,9 +127,7 @@ impl PowResultExt for Result { } /// Grind `analytic → spec.target_security_bits`, then check the result against -/// `spec.pow_budget` — both failures attributed to `pow_kind` at the same site. -/// `ProtocolConfig::validate_pow_budget` remains as a defense-in-depth check -/// for hand-mutated plans. +/// `spec.pow_budget`. pub(crate) fn grind_to_at( spec: &SecuritySpec, analytic: Bits, diff --git a/src/protocols/params/irs_commit.rs b/src/protocols/params/irs_commit.rs index 430dea57..9e0c8ea6 100644 --- a/src/protocols/params/irs_commit.rs +++ b/src/protocols/params/irs_commit.rs @@ -1,13 +1,7 @@ //! IRS-commit parameter selection. //! -//! ZK mask is sized at the tight bound `in-domain + OOD` queries. Construction -//! 9.7 / Theorem 9.6 (paper p.54-55) reveals `in_domain + t_ood` source -//! positions per round (in-domain queries + OOD linear combinations via -//! ze_ood), so the source encoding must be `(in_domain + t_ood)`-query ZK -//! (Definition 3.16, p.29). For the Reed–Solomon code that means -//! `mask_length = in_domain + t_ood` (Proposition 3.19, p.30). Codeword -//! NTT-smoothness is enforced inside [`IrsConfig::new`] on `codeword_length`, -//! not by inflating the mask. +//! ZK mask sizing follows Construction 9.7 / Theorem 9.6: +//! `mask_length = in_domain + t_ood` (Proposition 3.19). use crate::{ algebra::embedding::Embedding, @@ -35,12 +29,6 @@ pub fn solve( let mode = match spec.mode { Mode::Standard => IrsMode::Standard, Mode::ZeroKnowledge => { - // Construction 9.7 / Theorem 9.6: the verifier reveals - // `in_domain + t_ood` source positions (in-domain queries + - // ze_ood linear combinations), so the source encoding must be - // (in_domain + t_ood)-query ZK (Definition 3.16). Use the tight - // RS bound `mask_length = t` from Proposition 3.19; do not - // pow2-pad here. let mask_length = num_in_domain_queries(spec.decoding_regime, security_target, rate) .saturating_add(out_domain_samples.get()); IrsMode::ZeroKnowledge { mask_length } @@ -51,7 +39,7 @@ pub fn solve( security_target, spec.decoding_regime, spec.hash_id, - 1, // one vector committed per round + 1, ctx.vector_size, interleaving_depth, rate, @@ -157,10 +145,6 @@ mod tests { ); } - /// `irs_commit::solve` doesn't grind PoW, so this range can sit higher than - /// the shared `TEST_TARGET_RANGE` (which is capped at 50 to keep the PoW - /// gap below the 60-bit threshold). 80..=128 covers production-realistic - /// target sizes. const IRS_TARGET_RANGE: std::ops::RangeInclusive = 80..=128; fn arb_zk_spec_default() -> impl Strategy { @@ -172,8 +156,6 @@ mod tests { } proptest! { - /// Construction 9.7 / Theorem 9.6: mask covers all revealed source - /// positions (in-domain queries + OOD linear combinations). #[test] fn zk_mask_covers_in_domain_plus_ood( spec in arb_zk_spec_default(), @@ -199,17 +181,11 @@ mod tests { } } - /// Smoke-test fixture: 64-element vector folded by 2 at rate 1/2 — small - /// but produces a non-degenerate IRS for the non-identity embedding. const SMOKE_VECTOR_SIZE: usize = 64; const SMOKE_LOG_INV_RATE: u32 = 1; const SMOKE_FOLDING_FACTOR: u32 = 2; - /// Arbitrary > 0 so the ZK mask sizing exercises the OOD path. const SMOKE_OOD_BUDGET: usize = 2; - /// Smoke test: `M::Source ≠ M::Target`, ZK path. Mask sizing depends only - /// on the target field (via `field_size_bits`), but the generic embedding - /// still flows through the Config and must compile + execute. #[test] fn solve_works_with_basefield_embedding_zk() { let spec = deterministic_spec(Mode::ZeroKnowledge); diff --git a/src/protocols/params/mask_proximity.rs b/src/protocols/params/mask_proximity.rs index 0f917d2b..f60b0476 100644 --- a/src/protocols/params/mask_proximity.rs +++ b/src/protocols/params/mask_proximity.rs @@ -18,7 +18,6 @@ use crate::{ }; /// `c_zk.num_vectors` must equal `2 * num_masks` (originals + fresh). -/// PoW closes the Lemma 7.4 γ-combination gap to `spec.target_security_bits`. pub fn solve( spec: &SecuritySpec, c_zk: IrsConfig>, @@ -73,13 +72,6 @@ mod tests { }, }; - /// γ-combination (Lemma 7.4): `log|F| − log(num_masks · (deg − 1))`, - /// `deg = c_zk.masked_message_length()`. With `num_masks = 0` or `deg ≤ 1` - /// the bound saturates to `field_bits`. - /// Pow2 `l_zk = 8` gives exact `log2(deg − 1) = log2(7) ≈ 2.81`. - /// `num_masks = 3` is the smallest count > 1 (so `num_masks · (deg − 1) > 1` - /// and the formula doesn't saturate). `log_inv_rate = 1` is the minimum - /// rate the C_zk solver accepts. const FIXTURE_L_ZK: usize = 8; const FIXTURE_NUM_MASKS: usize = 3; const FIXTURE_LOG_INV_RATE: u32 = 1; @@ -99,7 +91,6 @@ mod tests { assert_close(got, expected); } - /// Degenerate inputs (`num_masks == 0` or `deg ≤ 1`) saturate to `field_bits`. #[test] fn analytic_error_saturates_when_no_masks() { let spec = deterministic_spec(Mode::ZeroKnowledge); @@ -124,7 +115,6 @@ mod tests { prop_assert_eq!(config.c_zk_commit.interleaving_depth, 1); } - /// `analytic_error + pow ≥ target` (Lemma 7.4 γ-combination). #[test] fn pow_closes_gap_to_target( spec in arb_zk_spec(TEST_TARGET_RANGE), @@ -139,9 +129,6 @@ mod tests { } } - /// `mask_proximity::solve` requires `c_zk.num_vectors == 2 · num_masks`. - /// Builds C_zk for `num_masks = 2` (so `num_vectors = 4`), then calls - /// `solve` with `num_masks = 3` to trip the assertion. #[test] #[should_panic(expected = "c_zk.num_vectors must be 2 * num_masks")] fn solve_rejects_mismatched_num_vectors() { @@ -153,8 +140,6 @@ mod tests { #[test] #[should_panic(expected = "interleaving_depth = 1")] fn solve_rejects_non_unit_interleaving() { - // All values except `NON_UNIT_INTERLEAVING_DEPTH` are chosen to satisfy - // `Config::new`'s divisibility/pow2 constraints. const SECURITY_TARGET_BITS: f64 = 80.0; const NUM_VECTORS: usize = 2; const VECTOR_SIZE: usize = 8; diff --git a/src/protocols/params/mod.rs b/src/protocols/params/mod.rs index d3229b3d..3f447c46 100644 --- a/src/protocols/params/mod.rs +++ b/src/protocols/params/mod.rs @@ -3,10 +3,6 @@ //! Soundness and ZK bound derivations (referred to in submodule comments as //! "the bounds doc, §N") live at //! . -//! -//! `derive` is the public entry point; the sub-protocol solvers (`basecase`, -//! `code_switch`, `irs_commit`, `mask_proximity`, `sumcheck`) are crate-local -//! and reached only via `derive`. Output and spec types are re-exported below. pub(crate) mod basecase; pub(crate) mod bounds; @@ -33,15 +29,6 @@ pub use spec::{ }; /// Solver-input mode for the per-round sumcheck and code-switch builders. -/// -/// Both sub-protocols branch on the same Standard vs. ZK distinction with the -/// same `MaskOracleInfo` payload, so a shared vocabulary keeps call sites -/// uniform. -/// -/// Distinct from [`Mode`] (the spec-level policy enum, which carries no -/// payload) and from sub-protocol *output* modes -/// (`sumcheck::SumcheckMode`, `code_switch::CodeSwitchMode`) whose payloads -/// describe the configured round rather than its solver input. #[derive(Clone, Copy)] pub enum SolveMode { Standard, diff --git a/src/protocols/params/protocol_config.rs b/src/protocols/params/protocol_config.rs index 0d60d210..8d4c0bc9 100644 --- a/src/protocols/params/protocol_config.rs +++ b/src/protocols/params/protocol_config.rs @@ -1,14 +1,4 @@ //! Output of [`super::derive`]: the assembled per-round and basecase configs. -//! -//! Each ZK round owns its mask oracle: a per-round C_zk codeword (sized for -//! `2·(k+1)` columns — `k` sumcheck masks + 1 code-switch `(r ‖ s)` mask, all -//! doubled by Construction 7.2's originals + fresh pairs) plus a per-round -//! mask-proximity check. Standard rounds carry no mask oracle. -//! -//! The post-construction structures (`ProtocolConfig`, `RoundConfig`, -//! `MaskOracleConfig`) expose only read accessors externally — invariants -//! validated by [`ProtocolConfig::validate`] survive past the call site -//! because there is no public mutation surface. use ark_ff::Field; @@ -75,32 +65,23 @@ impl ProtocolConfig { } /// `true` if every PoW slot's difficulty fits within `security.pow_budget`. - /// Boolean form of [`Self::validate_pow_budget`]. pub fn check_pow_bits(&self) -> bool { self.validate_pow_budget().is_ok() } - /// Returns `true` if every post-construction invariant holds: PoW - /// budget, mask-oracle coherence, and cross-round shape chaining. + /// Returns `true` if every post-construction invariant holds. pub fn check_all_invariants(&self) -> bool { self.validate().is_ok() } - /// Run every post-construction invariant check. Auto-invoked by - /// `derive()`; callers only need this after manual inspection (and only - /// then through the `pub(crate)` test shim, since fields are private). - /// - /// Mask-oracle coherence is *not* a separate check: the per-round - /// `mask_oracle` lives inside `RoundMode::ZeroKnowledge`, so its - /// presence ↔ ZK-ness equivalence is enforced by the type system. + /// Run every post-construction invariant check. pub fn validate(&self) -> Result<(), DeriveError> { self.validate_pow_budget()?; self.validate_round_chaining()?; Ok(()) } - /// PoW slot difficulty ≤ `security.pow_budget` for every slot. Auto-invoked - /// by `derive()` via [`Self::validate`]. + /// PoW slot difficulty ≤ `security.pow_budget` for every slot. pub fn validate_pow_budget(&self) -> Result<(), DeriveError> { let max = Bits::new(f64::from(self.security.pow_budget.bits())); let check = |pow: Pow, cfg: &PowConfig| -> Result<(), DeriveError> { @@ -180,16 +161,12 @@ impl ProtocolConfig { /// HVZK privacy error in bits, summed across ZK rounds: /// `−log Σ_r (t_ood_r² + t_ood_r) / (2|F|)` (bounds doc, §5.3 + §5.7). - /// Standard-mode plans return `target_security_bits` as a sentinel — - /// HVZK isn't claimed when there are no ZK rounds. pub fn privacy_error_bits(&self) -> Bits { let field_bits = ::field_size_bits(); let mut total_error = 0.0_f64; for r in &self.rounds { if let RoundMode::ZeroKnowledge { t_ood, .. } = &r.mode { let t = usize_to_f64(t_ood.get()); - // ζ_ze ≤ (t_ood² + t_ood) / (2|F|). Compute in log space to - // stay numerically stable for large field_bits. let log_err = f64::midpoint(t * t, t).log2() - field_bits; total_error += 2_f64.powf(log_err); } @@ -202,8 +179,7 @@ impl ProtocolConfig { } impl ProtocolConfig { - /// Analytic soundness bits (excluding PoW): minimum over basecase and - /// every round. + /// Analytic soundness bits (excluding PoW). pub fn analytic_bits(&self) -> Bits { let mut min_bits = f64::from(self.basecase.analytic_bits()); for round in &self.rounds { @@ -213,26 +189,16 @@ impl ProtocolConfig { } } -/// Test-only mutators. Grouped here so the production `impl` block above -/// reads as the public API surface; no equivalent on the public API. #[cfg(test)] impl ProtocolConfig { - /// Inject an over-budget basecase PoW slot so `validate_pow_budget` can - /// be exercised on a corrupted plan. pub(crate) const fn override_basecase_pow_for_test(&mut self, pow: PowConfig) { self.basecase.pow = pow; } - /// Drop the tail of `rounds` so the basecase's chained `vector_size` no - /// longer matches the (new) last round — trips the basecase branch of - /// `validate_round_chaining`. pub(crate) fn truncate_rounds_for_test(&mut self, len: usize) { self.rounds.truncate(len); } - /// Overwrite a round's code-switch target `vector_size` so the next - /// round's source no longer chains — trips the adjacent `windows(2)` - /// branch of `validate_round_chaining`, which truncation cannot reach. pub(crate) fn corrupt_round_target_vector_size_for_test( &mut self, round_idx: usize, @@ -247,10 +213,6 @@ pub struct RoundConfig { round_index: usize, sumcheck: SumcheckConfig, code_switch: CodeSwitchConfig, - /// Standard vs. ZK — and in ZK mode, owns the round's full mask oracle - /// directly. No separate `mask_oracle` field on `RoundConfig`: the - /// variant tag is the single source of truth for both ZK-ness and the - /// oracle's presence/contents. mode: RoundMode, } @@ -285,8 +247,7 @@ impl RoundConfig { &self.mode } - /// Convenience: borrow the round's mask oracle if this is a ZK round. - /// Equivalent to pattern-matching on `mode()`. + /// Borrow the round's mask oracle if this is a ZK round. pub fn mask_oracle(&self) -> Option<&MaskOracleConfig> { match &self.mode { RoundMode::Standard => None, @@ -294,8 +255,7 @@ impl RoundConfig { } } - /// Slim mask-oracle view derived from `mask_oracle()`. Produced on - /// demand — there is no stored copy. + /// Slim mask-oracle view derived from `mask_oracle()`. pub fn mask_oracle_info(&self) -> Option { self.mask_oracle().map(MaskOracleConfig::info) } @@ -303,18 +263,14 @@ impl RoundConfig { /// Standard vs. ZK round. /// -/// The ZK payload is boxed so the enum stays small: `MaskOracleConfig` is -/// ~330 B while the `Standard` variant is 0 B, and proofs hold O(rounds) -/// `RoundMode`s. Accessors expose `&MaskOracleConfig` so call sites are -/// unaffected by the indirection. +/// The ZK payload is boxed so the enum stays small. #[derive(Clone, Debug)] pub enum RoundMode { Standard, ZeroKnowledge { /// Lemma 9.9 OOD-sample budget (bounds doc §5.2). t_ood: OodSampleBudget, - /// Per-round mask oracle: C_zk codeword (sized for `2·(k+1)` - /// columns) + ℓ_zk + mask-proximity check for `k+1` masks. + /// Per-round mask oracle. mask_oracle: Box>, }, } @@ -327,9 +283,7 @@ impl RoundMode { impl RoundConfig { /// Round-level analytic floor: the smallest of `sumcheck`, `code_switch`, - /// and (when present) the per-round mask-oracle proximity check. Folding - /// the mask-oracle term in here keeps `ProtocolConfig::analytic_bits` - /// a pure `min` over rounds + basecase. + /// and (when present) the per-round mask-oracle proximity check. pub fn analytic_bits(&self) -> Bits { let source = &self.code_switch.source; let target = &self.code_switch.target; @@ -355,14 +309,11 @@ impl RoundConfig { } } -/// One round's mask oracle: a C_zk codeword + ℓ_zk + mask-proximity check -/// covering `k + 1` masks (sumcheck + code-switch) for this round. +/// One round's mask oracle: a C_zk codeword + ℓ_zk + mask-proximity check. #[derive(Clone, Debug)] pub struct MaskOracleConfig { - /// `num_vectors = 2 · (k + 1)` (Construction 7.2: originals + fresh). c_zk: IrsConfig>, - /// `next_pow2(r + t_ood)` for this round: Theorem 9.6 witness layout - /// (`0^{ℓ_zk − r}` padding) + Lemma 9.3 `(ℓ_zk − r, 0)`-privacy precondition. + /// `next_pow2(r + t_ood)` (Theorem 9.6 + Lemma 9.3). l_zk: MaskCodeMessageLen, mask_proximity: MaskProximityConfig, } @@ -401,11 +352,6 @@ impl MaskOracleConfig { } /// Slim mask-oracle view (C_zk's list size + ℓ_zk). -/// -/// Reached only through `RoundMode::ZeroKnowledge`'s field, which is itself -/// accessible only via `RoundConfig::mode() -> &RoundMode`. The public -/// surface therefore stays read-only even though the variant fields are -/// nominally `pub`. #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct MaskOracleInfo { pub c_zk_list_size: ListSize, diff --git a/src/protocols/params/regime.rs b/src/protocols/params/regime.rs index fcb52c08..478a71bd 100644 --- a/src/protocols/params/regime.rs +++ b/src/protocols/params/regime.rs @@ -1,11 +1,6 @@ //! Reed–Solomon decoding regime — materialized per-round parameters and the //! analytic helpers that depend on them. //! -//! Spec-level policy lives in [`super::spec::DecodingRegime`] (rate-independent, -//! a user choice). The data-carrying [`DecodingRegimeParams`] is what gets -//! stored on per-round configs once a rate is known: [`Self::from_policy`] -//! is the single materialization point. -//! //! # References //! //! - Johnson proximity-gap error follows the BCSS25 improvement @@ -24,9 +19,6 @@ use crate::protocols::params::{ }; /// Materialized decoding-regime parameters at a known rate. -/// -/// `Unique` carries no data; `Johnson` and `Capacity` each carry the slack `η` -/// from their respective proximity boundary. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum DecodingRegimeParams { Unique, @@ -35,9 +27,8 @@ pub enum DecodingRegimeParams { } impl DecodingRegimeParams { - /// Materialize spec policy at a known rate. Canonical slacks (`√ρ/20` for - /// Johnson, `ρ/20` for Capacity) live here — any tuning of `η` propagates - /// to every per-round config through this single site. + /// Materialize spec policy at a known rate. Canonical slacks: `√ρ/20` for + /// Johnson, `ρ/20` for Capacity. // TODO: Optimize picking η. pub fn from_policy(policy: DecodingRegime, rate: f64) -> Self { match policy { @@ -66,10 +57,6 @@ impl DecodingRegimeParams { } /// `log₂ |Λ(C, δ)|`. - /// - /// `log_degree` is `log₂` of the code's message length; it's only read in - /// the `Capacity` branch (STIR Conj 5.6 gives `|Λ| = d/(ρ·η)`). Johnson - /// and Unique ignore it. pub fn list_size_log2(self, log_degree: f64, log_inv_rate: f64) -> f64 { match self { Self::Unique => 0.0, @@ -95,12 +82,8 @@ impl DecodingRegimeParams { one_minus_delta.log2() } - /// Bits of security delivered by `ood_samples` OOD challenges on a code - /// of given `log_degree` and `log_inv_rate` at MCA arity 2. - /// - /// STIR Lemma 4.5: the error is `(L choose 2) · ((d − 1)/|F|)^{ood_samples}`, - /// giving security `ood · (|F| − log d) − 2·log|Λ| + 1` bits. Returns `0` - /// under `Unique` — OOD contributes no soundness when `|Λ| = 1`. + /// Bits of security delivered by `ood_samples` OOD challenges (STIR Lemma 4.5): + /// `ood · (|F| − log d) − 2·log|Λ| + 1`. Returns `0` under `Unique`. pub fn ood_security_bits( self, log_degree: f64, @@ -116,17 +99,12 @@ impl DecodingRegimeParams { ood * (field_bits - log_degree) - 2.0 * log_list + 1.0 } - /// `log₂ ε_mca(C, δ)` for the per-step proximity-gaps error (bare, no - /// arity factor — callers apply their own). + /// `log₂ ε_mca(C, δ)` for the per-step proximity-gaps error. /// - /// - Unique: `(k − 1) / |F|`, log = `log k − |F|` (with `+ log ρ⁻¹` to - /// pick up the `n/|F|` factor). + /// - Unique: `(k − 1) / |F|`, log = `log k − |F|` (with `+ log ρ⁻¹`). /// - Johnson: BCSS25 Theorem 1.5 at canonical `η = √ρ/20`, `m = 10`: /// `ε ≈ (2·10.5⁵/3) · n · ρ^{−3/2} / |F|`. /// - Capacity: STIR Conj 5.6, `ε ≈ d / (η · ρ²) / |F|`. - /// - /// The formula expressions hardcode the canonical slack; debug-asserts - /// catch a non-canonical `slack` that would invalidate the constants. pub fn eps_mca_log2(self, log_inv_rate: f64, message_length: usize, field_bits: f64) -> f64 { let log_k = usize_to_f64(message_length).log2(); let error = match self { @@ -150,15 +128,7 @@ impl DecodingRegimeParams { } impl DecodingRegime { - /// `|Λ|` at canonical slack, before an IRS config exists. Use the - /// `DecodingRegimeParams::list_size` method when a non-canonical slack - /// has already been materialized. - /// - /// Matches `IrsConfig::list_size()` when the IRS is built under the same - /// regime, with the same `masked_message_length`, and `ntt::next_order` - /// doesn't pad the codeword (pow2 `vector_size`, `interleaving_depth = 1`, - /// integer `log_inv_rate`, 2-adic field — the conditions `solve_mask_code` - /// enforces for C_zk). + /// `|Λ|` at canonical slack, before an IRS config exists. pub fn list_size_estimate(self, log_degree: f64, log_inv_rate: f64) -> f64 { DecodingRegimeParams::from_policy(self, rate(log_inv_rate)) .list_size(log_degree, log_inv_rate) @@ -170,8 +140,6 @@ mod tests { use super::*; use crate::protocols::params::test_utils::assert_close; - /// Tighter tolerance for tests doing relative-error checks against an - /// alternative-derived expected value with the same operations. const TIGHT_EPS: f64 = 1e-12; fn johnson(slack: f64) -> DecodingRegimeParams { @@ -186,9 +154,7 @@ mod tests { } } - /// Johnson list size: `|Λ| = 1 / (2η√ρ)`, log₂ form. Hand-evaluated at - /// `log_inv_rate = 2`, `η = 0.1`: `−1 − log₂(0.1) + 1 ≈ 3.3219`. - /// `log_degree` is ignored by the Johnson branch. + /// Johnson list size: `|Λ| = 1 / (2η√ρ)`, log₂ form. #[test] fn list_size_log2_johnson_formula() { let got = johnson(0.1).list_size_log2(/* log_degree */ 4.0, 2.0); @@ -196,8 +162,7 @@ mod tests { assert_close(got, expected); } - /// Capacity list size: `|Λ| = d / (ρ · η)`, log₂ form. At `log_degree = 4`, - /// `log_inv_rate = 2`, `η = 1/8`: `4 + 2 − log₂(1/8) = 4 + 2 + 3 = 9`. + /// Capacity list size: `|Λ| = d / (ρ · η)`, log₂ form. #[test] fn list_size_log2_capacity_formula() { let got = capacity(0.125).list_size_log2(4.0, 2.0); @@ -211,8 +176,7 @@ mod tests { assert_close(DecodingRegimeParams::Unique.list_size_log2(4.0, 2.0), 0.0); } - /// `η = √ρ / 20` substituted into `|Λ| = 1/(2η√ρ)` simplifies to `10/ρ`. - /// So `DecodingRegime::Johnson.list_size_estimate(_, b) = 10 · 2^b`. + /// `η = √ρ/20` ⇒ `|Λ| = 10/ρ` ⇒ `list_size_estimate(_, b) = 10 · 2^b`. #[test] fn johnson_list_size_closed_form() { for b in [1.0, 2.0, 3.0, 5.0] { @@ -225,7 +189,7 @@ mod tests { } } - /// `η = ρ / 20` substituted into `|Λ| = d/(ρ · η)` simplifies to `20 · d / ρ²`. + /// `η = ρ/20` ⇒ `|Λ| = 20 · d / ρ²`. #[test] fn capacity_list_size_closed_form() { for (log_d, b) in [(4.0, 1.0), (6.0, 2.0), (8.0, 3.0)] { @@ -238,9 +202,6 @@ mod tests { } } - /// `DecodingRegime::Johnson.list_size_estimate(_, b)` must match `Config::list_size` once - /// a config is built at the same rate. Keeps the rate-only helper in sync - /// with `irs_commit::Config::new`'s canonical-slack materialization. #[test] fn johnson_list_size_matches_config_list_size() { use crate::{ @@ -330,8 +291,6 @@ mod tests { assert_close(got, expected); } - /// MCA fixture — `message_length = 16 = 2^4` and `log_inv_rate = 2` give - /// exact `log2(k) = 4`. `field_bits = 64.0` for Field64. const MCA_MESSAGE_LENGTH: usize = 16; const MCA_LOG_INV_RATE: f64 = 2.0; const MCA_FIELD_BITS: f64 = 64.0; @@ -351,8 +310,6 @@ mod tests { /// MCA error, Johnson (BCSS25): `log₂(2·10.5⁵/3) + log k + 2.5·log_inv_rate − field_bits`. #[test] fn eps_mca_log2_johnson_formula() { - // `η = √ρ/20 ≈ 0.025` at `log_inv_rate = 2`. Use the canonical slack - // so the debug-assert in the formula is satisfied. let canonical_slack = 2_f64.powf(-MCA_LOG_INV_RATE).sqrt() / 20.0; let got = johnson(canonical_slack).eps_mca_log2( diff --git a/src/protocols/params/spec.rs b/src/protocols/params/spec.rs index dbc4ee3c..2b4e203b 100644 --- a/src/protocols/params/spec.rs +++ b/src/protocols/params/spec.rs @@ -13,24 +13,13 @@ use crate::{bits::Bits, engines::EngineId}; /// Per-slot proof-of-work policy. /// -/// The same `bits` value plays two roles, deliberately coupled: -/// - **Planning credit**: [`SecuritySpec::protocol_security_target_bits`] -/// subtracts `bits` from `target_security_bits` so solvers know the -/// analytic floor they must reach. -/// - **Validation cap**: [`super::protocol_config::ProtocolConfig::validate_pow_budget`] -/// rejects any per-slot PoW that exceeds `bits`. -/// -/// `Forbidden` is *not* `PerSlot { bits: 0 }`: the latter is unrepresentable -/// (the variant takes a [`NonZeroU32`]). Use [`PowBudget::per_slot`] when -/// converting from an arbitrary `u32` — it collapses `0` to `Forbidden`. +/// `bits` plays two roles: +/// - **Planning credit**: subtracted from `target_security_bits` so solvers +/// know the analytic floor they must reach. +/// - **Validation cap**: rejects any per-slot PoW that exceeds `bits`. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum PowBudget { - /// Per-slot grinding forbidden. Solvers still plan against the full - /// `target_security_bits`; any nonzero per-slot PoW the planner emits - /// is rejected by validation. Forbidden, - /// Per-slot grinding allowed up to `bits`. Planning relaxes the - /// analytic target by `bits`; validation caps every slot at `bits`. PerSlot { bits: NonZeroU32 }, } @@ -52,7 +41,7 @@ impl PowBudget { } } -/// Phantom-typed newtype — `Tagged` and `Tagged` are distinct types. +/// Phantom-typed newtype. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct Tagged(T, PhantomData); @@ -69,14 +58,8 @@ impl Tagged { #[derive(Debug, Clone)] pub struct SecuritySpec { pub mode: Mode, - /// Reed–Solomon decoding regime — selects the proximity radius `δ` and - /// slack policy. See [`DecodingRegime`]. pub decoding_regime: DecodingRegime, pub target_security_bits: u32, - /// Per-slot PoW policy — both the planning credit subtracted from - /// `target_security_bits` and the per-slot cap enforced by - /// [`super::protocol_config::ProtocolConfig::validate_pow_budget`]. - /// See [`PowBudget`] for the dual role. pub pow_budget: PowBudget, pub hash_id: EngineId, } @@ -150,29 +133,14 @@ pub enum Mode { } /// A `SecuritySpec` borrow proven to be in [`Mode::ZeroKnowledge`]. -/// -/// Constructed only via [`ZkSpec::try_new`], which performs the mode check -/// once at the boundary. ZK-only solvers accept `ZkSpec` to make -/// "ZK mode required" a compile-time precondition instead of a runtime assert. -/// -/// `Deref` is implemented so fields and inherent -/// methods are reachable directly (`zk_spec.target_security_bits`). For sites -/// that need to pass `&SecuritySpec` explicitly, use [`Self::as_inner`]. #[derive(Debug, Clone, Copy)] pub struct ZkSpec<'a>(&'a SecuritySpec); impl<'a> ZkSpec<'a> { - /// Returns `Some` iff `spec.mode == Mode::ZeroKnowledge`. pub fn try_new(spec: &'a SecuritySpec) -> Option { matches!(spec.mode, Mode::ZeroKnowledge).then_some(Self(spec)) } - /// Explicit unwrap — `&SecuritySpec` with the wrapper's lifetime. - /// - /// Prefer field access through `Deref` for reads; reach for `as_inner` - /// when you specifically need to hand `&SecuritySpec` to a function whose - /// signature is not in deref-coercion position (e.g. trait method - /// dispatch). pub const fn as_inner(self) -> &'a SecuritySpec { self.0 } @@ -274,9 +242,7 @@ pub type MaskCodeMessageLen = Tagged; /// `rate = 2^-log_inv_rate`. pub type LogInvRate = Tagged; -/// Reed–Solomon list-decoding ball size `|Λ(C, δ)|`. Wraps `OrderedFloat` -/// so it can be stored alongside the `Tagged` integer newtypes without losing -/// `Eq`/`Hash`. +/// Reed–Solomon list-decoding ball size `|Λ(C, δ)|`. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct ListSize(OrderedFloat); @@ -295,8 +261,6 @@ mod tests { use super::*; use crate::hash; - /// Fixture target. 100 is chosen so the expected `target − pow` values in - /// the tests below are round numbers (80, 40, 0) for readability. const TARGET_BITS: u32 = 100; fn spec(pow_budget: PowBudget) -> SecuritySpec { @@ -319,9 +283,6 @@ mod tests { #[test] fn per_slot_zero_collapses_to_forbidden() { - // `per_slot(0)` is the only documented way to ask for "no grinding" - // from a `u32`; it must produce the `Forbidden` variant, not a - // `PerSlot { bits: 0 }` (which is unrepresentable). assert_eq!(PowBudget::per_slot(0), PowBudget::Forbidden); } @@ -333,7 +294,6 @@ mod tests { #[test] fn pow_credit_shifts_analytic_floor() { - // Two below-target PoW budgets: `target − pow` shifts down 1:1. assert_eq!( spec(PowBudget::per_slot(20)).protocol_security_target_bits(), Bits::new(80.0), @@ -346,7 +306,6 @@ mod tests { #[test] fn pow_exceeding_target_saturates_to_zero() { - // `pow > target` saturates rather than going negative. let pow_over_target = TARGET_BITS + 100; assert_eq!( spec(PowBudget::per_slot(pow_over_target)).protocol_security_target_bits(), diff --git a/src/protocols/params/sumcheck.rs b/src/protocols/params/sumcheck.rs index c5f41287..43ca202a 100644 --- a/src/protocols/params/sumcheck.rs +++ b/src/protocols/params/sumcheck.rs @@ -1,5 +1,5 @@ //! Sumcheck parameter selection. ZK mode adds a degree-2 mask per round -//! (Lemma 6.4, p.38). PoW closes the gap between target and analytic error. +//! (Lemma 6.4, p.38). use crate::{ algebra::{embedding::Embedding, fields::FieldWithSize}, @@ -17,9 +17,7 @@ use crate::{ }, }; -/// Per-round sumcheck builder. `mode` carries the optional mask oracle (see -/// [`super::SolveMode`]); `pow` labels grinding failures (basecase or -/// per-round). +/// Per-round sumcheck builder. pub fn solve( spec: &SecuritySpec, ctx: &RoundContext, @@ -67,7 +65,7 @@ pub fn analytic_error_bits( } /// Number of degree-2 round-polynomial masks sumcheck contributes to C_zk -/// per round (Lemma 6.4): one per sumcheck round. +/// per round (Lemma 6.4). pub const fn masks_required(ctx: &RoundContext) -> usize { num_sumcheck_rounds(ctx) } @@ -98,8 +96,6 @@ mod tests { }, }; - /// Mask-oracle fixture used by the formula tests + the ZK smoke test. - /// Both values are pow2 so `log2` is exact (no f64 drift in expected-vs-got). const FIXTURE_C_ZK_LIST_SIZE: f64 = 4.0; const FIXTURE_L_ZK: usize = 8; @@ -107,7 +103,6 @@ mod tests { irs_params::solve(spec, ctx, OodSampleBudget::ZERO) } - /// Smallest pow2 shape that still produces a non-degenerate IRS. const FIXTURE_LOG_VECTOR_SIZE: u32 = 4; const FIXTURE_LOG_INV_RATE: u32 = 1; const FIXTURE_FOLDING_FACTOR: u32 = 2; @@ -120,7 +115,6 @@ mod tests { } } - /// Lemma 6.4: ZK round polynomial has 3 coefficients. #[test] fn zk_mode_has_three_mask_coefficients() { let spec = deterministic_spec(Mode::ZeroKnowledge); @@ -144,7 +138,6 @@ mod tests { } } - /// Standard branch: `min(prox_gaps, log|F| − log|Λ(C)| − 1).max(0)`. #[test] fn analytic_error_standard_formula() { let spec = deterministic_spec(Mode::Standard); @@ -161,7 +154,6 @@ mod tests { assert_close(got, expected); } - /// ZK branch (Lemma 6.5): `min(prox_gaps, log|F| − log|Λ(C)| − log|Λ(C_zk)| − log ℓ_zk).max(0)`. #[test] fn analytic_error_zk_formula() { let log_c_zk_list = FIXTURE_C_ZK_LIST_SIZE.log2(); @@ -187,10 +179,8 @@ mod tests { assert_close(got, expected); } - /// Oracle large enough to drive `poly_id` strongly negative → clamped to 0. #[test] fn analytic_error_clamps_to_zero() { - // `log2(c_zk_list_size) + log2(l_zk) > field_bits` on `Field64`. const OVERSIZED_LOG_C_ZK_LIST: i32 = 60; const OVERSIZED_LOG_L_ZK: u32 = 30; @@ -235,8 +225,6 @@ mod tests { prop_assert_eq!(config.num_rounds, ctx.folding_factor as usize); } - /// ZK subtracts two non-negative log terms beyond Standard, so the ZK - /// error term cannot exceed the Standard one for any source IRS. #[test] fn zk_error_le_standard_error( spec in arb_zk_spec(TEST_TARGET_RANGE), @@ -249,7 +237,6 @@ mod tests { prop_assert!(zk <= standard + EPS, "zk {} > standard {}", zk, standard); } - /// `analytic_error + pow ≥ target`. #[test] fn round_pow_closes_gap_to_target( spec in prop_oneof![ @@ -270,7 +257,6 @@ mod tests { } } - /// Smoke test: `M::Source ≠ M::Target`, ZK mode. #[test] fn solve_works_with_basefield_embedding_zk() { let spec = deterministic_spec(Mode::ZeroKnowledge); diff --git a/src/protocols/params/test_utils.rs b/src/protocols/params/test_utils.rs index 807b0e98..ef55a9d4 100644 --- a/src/protocols/params/test_utils.rs +++ b/src/protocols/params/test_utils.rs @@ -32,24 +32,12 @@ pub type TestExtensionField = Field64_2; /// `Source = Field64, Target = Field64_2`. pub type TestNonIdentityEmbedding = Basefield; -/// `target_security_bits` range used by every solver-level proptest. -/// Upper bound keeps `target − analytic_error ≤ 60`, matching the cap in -/// `proof_of_work::threshold`. Lower bound keeps the analytic floor away from 0. pub const TEST_TARGET_RANGE: RangeInclusive = 30..=50; -/// Default `target_security_bits` for `deterministic_spec` fixtures. -/// 80 leaves enough analytic headroom on `Field64` (~64-bit) that every -/// sub-protocol solver has a closable gap to target. pub const FIXTURE_TARGET_BITS: u32 = 80; -/// Tolerance for `(got - expected).abs() < EPS` checks on formula-reconstruction -/// tests. `1e-9` is well above the `f64` rounding noise on log/sum expressions -/// used in the analytic-error formulas. pub const EPS: f64 = 1e-9; -/// Matches `proof_of_work::MAX_DIFFICULTY` so per-slot budget checks in -/// `grind_to_at` never bite. Tests exercising budget enforcement build their -/// own specs. pub const FIXTURE_POW_BUDGET_BITS: u32 = 60; pub fn deterministic_spec(mode: Mode) -> SecuritySpec { @@ -62,8 +50,6 @@ pub fn deterministic_spec(mode: Mode) -> SecuritySpec { } } -/// Both decoding regimes, equally weighted. Used by `arb_spec` so proptests -/// sweep all three regimes. fn arb_decoding_regime() -> impl Strategy { prop_oneof![ Just(DecodingRegime::Johnson), @@ -93,8 +79,6 @@ pub fn arb_standard_spec(target_range: RangeInclusive) -> impl Strategy impl Strategy { (4u32..=8, 1u32..=4, 1u32..=3).prop_map(|(log_size, log_inv_rate, folding_factor)| { RoundContext { @@ -117,9 +101,7 @@ pub fn build_minimal_mask_oracle(spec: &SecuritySpec) -> Option }) } -/// Shared check used by every sub-protocol's `pow_closes_gap_to_target*` test: -/// `analytic_error_bits + pow.difficulty() ≥ target_security_bits` (the `1e-3` -/// tolerance absorbs `proof_of_work::threshold`'s ceil quantization). +/// `analytic_error_bits + pow.difficulty() ≥ target_security_bits`. pub fn assert_pow_closes_gap(spec: &SecuritySpec, analytic: Bits, pow: &PowConfig) { let error = f64::from(analytic); let pow_bits = f64::from(pow.difficulty()); @@ -130,8 +112,7 @@ pub fn assert_pow_closes_gap(spec: &SecuritySpec, analytic: Bits, pow: &PowConfi ); } -/// `|got − expected| < EPS` with a uniform error message. Shared by every -/// `analytic_error_*_formula` test. +/// `|got − expected| < EPS`. pub fn assert_close(got: f64, expected: f64) { assert!( (got - expected).abs() < EPS, @@ -139,8 +120,7 @@ pub fn assert_close(got: f64, expected: f64) { ); } -/// C_zk fixture used by every `mask_proximity` test: source mask length 0, -/// `num_vectors = 2 · num_masks` (Construction 7.2 originals + fresh pairs). +/// C_zk fixture for `mask_proximity` tests. pub fn build_test_c_zk( spec: &SecuritySpec, l_zk: usize, @@ -159,12 +139,6 @@ pub fn build_test_c_zk( /// Builds a self-consistent `(source, target, t_ood)` triplet matching the /// per-round shape that `code_switch::solve` expects. -/// -/// `t_ood` is solved against the rate-only `DecodingRegime::list_size_estimate` -/// rather than `target.list_size()`: the latter reads the target's effective -/// rate (which itself depends on `t_ood` via the mask), producing a -/// non-monotone oscillation once the mask is tight (`mask_length = in_domain -/// + t_ood` per Construction 9.7 / Theorem 9.6) rather than pow2-padded. pub fn build_round_io( spec: &SecuritySpec, log_inv_rate: u32, From ff18333980e32ef740d4fb5b9a594e0b2d10f74f Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Sat, 30 May 2026 07:41:48 +0530 Subject: [PATCH 27/31] refactor : destruct derive.rs file --- src/protocols/params/build_round.rs | 264 ++++++++++++++++++ src/protocols/params/code_switch.rs | 2 +- src/protocols/params/derive.rs | 406 +--------------------------- src/protocols/params/error.rs | 2 +- src/protocols/params/layout.rs | 172 ++++++++++++ src/protocols/params/mod.rs | 2 + src/protocols/params/test_utils.rs | 2 +- 7 files changed, 456 insertions(+), 394 deletions(-) create mode 100644 src/protocols/params/build_round.rs create mode 100644 src/protocols/params/layout.rs diff --git a/src/protocols/params/build_round.rs b/src/protocols/params/build_round.rs new file mode 100644 index 00000000..311cb6ea --- /dev/null +++ b/src/protocols/params/build_round.rs @@ -0,0 +1,264 @@ +//! Per-round build: turns a [`RoundShape`] into a [`RoundConfig`]. +//! +//! Solves the `t_ood` fix-point, builds source/target IRS configs, and +//! (in ZK) assembles the per-round mask oracle. Consumed by +//! [`super::derive`], which drives the per-round loop. + +use crate::{ + algebra::{ + embedding::{Embedding, Identity}, + fields::FieldWithSize, + }, + protocols::{ + irs_commit::Config as IrsConfig, + params::{ + bounds::usize_to_f64, + code_switch as code_switch_params, + error::{DeriveError, Pow}, + irs_commit as irs_params, + layout::{round_context, target_context, RoundShape}, + mask_proximity as mask_proximity_params, + protocol_config::{MaskOracleConfig, RoundConfig, RoundMode}, + spec::{ + DecodingRegime, LogInvRate, MaskCodeMessageLen, OodSampleBudget, RoundContext, + SecuritySpec, ZkSpec, + }, + sumcheck as sumcheck_params, SolveMode, + }, + }, +}; + +const T_OOD_MAX_ITER: usize = 32; + +/// Mode flag for the OOD security bound in [`solve_t_ood`] / +/// [`ood_security_bits_at`]. +#[derive(Clone, Copy)] +pub(super) enum OodMode { + Standard, + ZeroKnowledge { c_zk_log_inv_rate: f64 }, +} + +/// Mode-dispatch input for [`build_round_config`]. +#[derive(Clone, Copy)] +pub(super) enum RoundBuildMode<'a> { + Standard, + ZeroKnowledge { + zk_spec: ZkSpec<'a>, + c_zk_log_inv_rate: LogInvRate, + }, +} + +impl RoundBuildMode<'_> { + fn to_ood_mode(self) -> OodMode { + match self { + Self::Standard => OodMode::Standard, + Self::ZeroKnowledge { + c_zk_log_inv_rate, .. + } => OodMode::ZeroKnowledge { + c_zk_log_inv_rate: f64::from(c_zk_log_inv_rate.get()), + }, + } + } +} + +pub(super) fn build_round_config( + spec: &SecuritySpec, + shape: &RoundShape, + mode: RoundBuildMode<'_>, +) -> Result, DeriveError> { + let ctx = round_context(shape); + let (source, t_ood) = solve_round_source::(spec, shape, mode.to_ood_mode())?; + + let (target_budget, solve_mode, round_mode) = match mode { + RoundBuildMode::Standard => ( + OodSampleBudget::ZERO, + SolveMode::Standard, + RoundMode::Standard, + ), + RoundBuildMode::ZeroKnowledge { + zk_spec, + c_zk_log_inv_rate, + } => { + let num_masks = + sumcheck_params::masks_required(&ctx) + code_switch_params::masks_required(); + let mask_oracle = build_mask_oracle::( + zk_spec, + &source, + t_ood, + num_masks, + c_zk_log_inv_rate, + shape.round_index, + )?; + let solve_mode = SolveMode::ZeroKnowledge { + mask_oracle: mask_oracle.info(), + }; + let round_mode = RoundMode::ZeroKnowledge { + t_ood: OodSampleBudget::new(t_ood), + mask_oracle: Box::new(mask_oracle), + }; + (OodSampleBudget::new(t_ood), solve_mode, round_mode) + } + }; + + let target: IrsConfig> = + irs_params::solve(spec, &target_context(shape, &source), target_budget); + let sumcheck = sumcheck_params::solve( + spec, + &ctx, + &source, + solve_mode, + Pow::RoundSumcheck { + index: shape.round_index, + }, + )?; + let code_switch = + code_switch_params::solve(spec, source, target, t_ood, solve_mode, shape.round_index)?; + + Ok(RoundConfig::new( + shape.round_index, + sumcheck, + code_switch, + round_mode, + )) +} + +fn solve_round_source( + spec: &SecuritySpec, + shape: &RoundShape, + ood_mode: OodMode, +) -> Result<(IrsConfig, usize), DeriveError> { + let src_ctx = round_context(shape); + let target_log_inv_rate = f64::from( + shape + .source_log_inv_rate + .saturating_add(shape.source_folding_factor.saturating_sub(1)), + ); + let target_log_degree = f64::from( + shape + .source_vector_size + .trailing_zeros() + .saturating_sub(shape.source_folding_factor), + ); + let target_list_size = spec + .decoding_regime + .list_size_estimate(target_log_degree, target_log_inv_rate); + solve_t_ood::( + spec, + &src_ctx, + target_list_size, + ood_mode, + shape.round_index, + ) +} + +/// ZK-only: assemble the per-round mask oracle (C_zk codeword + mask-proximity +/// check). +fn build_mask_oracle( + zk_spec: ZkSpec<'_>, + source: &IrsConfig, + t_ood: usize, + num_masks: usize, + c_zk_log_inv_rate: LogInvRate, + round_index: usize, +) -> Result, DeriveError> { + let spec = zk_spec.as_inner(); + let l_zk = compute_l_zk(source, t_ood); + let c_zk: IrsConfig> = irs_params::solve_mask_code( + zk_spec, + l_zk, + source.mask_length(), + c_zk_log_inv_rate, + 2 * num_masks, + ); + let c_zk_list_size_estimate = spec.decoding_regime.list_size_estimate( + (l_zk.get() as f64).log2(), + f64::from(c_zk_log_inv_rate.get()), + ); + debug_assert!( + (c_zk.list_size() - c_zk_list_size_estimate).abs() + < 1e-9 * c_zk_list_size_estimate.max(1.0), + "c_zk.list_size() {} drifted from planner estimate {}", + c_zk.list_size(), + c_zk_list_size_estimate, + ); + let mask_proximity = mask_proximity_params::solve(spec, c_zk.clone(), num_masks, round_index)?; + Ok(MaskOracleConfig::new(c_zk, l_zk, mask_proximity)) +} + +/// `ℓ_zk = next_pow2(r + t_ood)` (Theorem 9.6 + Lemma 9.3). +pub(super) const fn compute_l_zk( + source: &IrsConfig, + t_ood: usize, +) -> MaskCodeMessageLen { + MaskCodeMessageLen::new( + source + .mask_length() + .saturating_add(t_ood) + .next_power_of_two(), + ) +} + +/// Per-round `(source, t_ood)`. +/// +/// Under `Unique`, `t_ood = 1` is pinned (the `log(L·(L−1)/2)` term degenerates +/// when `L = 1`, and Construction 9.7 requires `out_domain_samples ≥ 1`). +/// Otherwise linear search over `t_ood = 1..=T_OOD_MAX_ITER` for the smallest +/// value where [`ood_security_bits_at`] meets `protocol_security_target_bits`. +pub(super) fn solve_t_ood( + spec: &SecuritySpec, + src_ctx: &RoundContext, + target_list_size: f64, + ood_mode: OodMode, + round_index: usize, +) -> Result<(IrsConfig, usize), DeriveError> { + if matches!(spec.decoding_regime, DecodingRegime::Unique) { + let source = irs_params::solve(spec, src_ctx, OodSampleBudget::new(1)); + return Ok((source, 1)); + } + + let security_target = f64::from(spec.protocol_security_target_bits()); + let field_bits = M::Target::field_size_bits(); + + for t_ood in 1..=T_OOD_MAX_ITER { + let source: IrsConfig = irs_params::solve(spec, src_ctx, OodSampleBudget::new(t_ood)); + let bits = + ood_security_bits_at(spec, &source, t_ood, target_list_size, ood_mode, field_bits); + if bits >= security_target { + return Ok((source, t_ood)); + } + } + Err(DeriveError::FixedPointDidNotConverge { round_index }) +} + +/// OOD security bits at candidate `t_ood`, per STIR Lemma 4.5: +/// `bits = t · (|F| − log d) − log(L · (L − 1) / 2) ≈ t·(|F| − log d) − 2·log L + 1`. +fn ood_security_bits_at( + spec: &SecuritySpec, + source: &IrsConfig, + t_ood: usize, + target_list_size: f64, + ood_mode: OodMode, + field_bits: f64, +) -> f64 { + let (log_degree, log_combined_list) = match ood_mode { + OodMode::Standard => ( + usize_to_f64(source.message_length()).log2(), + target_list_size.log2(), + ), + OodMode::ZeroKnowledge { c_zk_log_inv_rate } => { + let l_zk = source + .mask_length() + .saturating_add(t_ood) + .next_power_of_two(); + let c_zk_list = spec + .decoding_regime + .list_size_estimate(usize_to_f64(l_zk).log2(), c_zk_log_inv_rate); + ( + usize_to_f64(source.message_length().saturating_add(l_zk)).log2(), + (target_list_size * c_zk_list).log2(), + ) + } + }; + let ood = usize_to_f64(t_ood); + ood * (field_bits - log_degree) - 2.0 * log_combined_list + 1.0 +} diff --git a/src/protocols/params/code_switch.rs b/src/protocols/params/code_switch.rs index f539469d..530d6296 100644 --- a/src/protocols/params/code_switch.rs +++ b/src/protocols/params/code_switch.rs @@ -111,7 +111,7 @@ mod tests { use super::*; use crate::protocols::params::{ - derive::{compute_l_zk, solve_t_ood, OodMode}, + build_round::{compute_l_zk, solve_t_ood, OodMode}, irs_commit as irs_params, spec::{ DecodingRegime, ListSize, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, diff --git a/src/protocols/params/derive.rs b/src/protocols/params/derive.rs index 922da696..23d05a02 100644 --- a/src/protocols/params/derive.rs +++ b/src/protocols/params/derive.rs @@ -1,38 +1,17 @@ //! Derives a [`ProtocolConfig`] from a spec + tuning. use crate::{ - algebra::{ - embedding::{Embedding, Identity}, - fields::FieldWithSize, - }, - protocols::{ - irs_commit::Config as IrsConfig, - params::{ - basecase as basecase_params, - bounds::usize_to_f64, - code_switch as code_switch_params, - error::{DeriveError, Pow}, - irs_commit as irs_params, mask_proximity as mask_proximity_params, - protocol_config::{MaskOracleConfig, ProtocolConfig, RoundConfig, RoundMode}, - spec::{ - DecodingRegime, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, - RoundContext, SecuritySpec, TuningSpec, ZkSpec, - }, - sumcheck as sumcheck_params, SolveMode, - }, + algebra::embedding::Embedding, + protocols::params::{ + basecase as basecase_params, + build_round::{build_round_config, RoundBuildMode}, + error::DeriveError, + layout::{round_layout, RoundLayout}, + protocol_config::{ProtocolConfig, RoundConfig}, + spec::{LogInvRate, Mode, SecuritySpec, TuningSpec, ZkSpec}, }, }; -const T_OOD_MAX_ITER: usize = 32; - -/// Mode flag for the OOD security bound in [`solve_t_ood`] / -/// [`ood_security_bits_at`]. -#[derive(Clone, Copy)] -pub(super) enum OodMode { - Standard, - ZeroKnowledge { c_zk_log_inv_rate: f64 }, -} - impl ProtocolConfig { /// Fails with [`DeriveError`] when the spec/tuning combination is /// infeasible. @@ -64,306 +43,20 @@ impl ProtocolConfig { } } -/// Mode-dispatch input for [`build_round_config`]. -#[derive(Clone, Copy)] -enum RoundBuildMode<'a> { - Standard, - ZeroKnowledge { - zk_spec: ZkSpec<'a>, - c_zk_log_inv_rate: LogInvRate, - }, -} - -impl RoundBuildMode<'_> { - fn to_ood_mode(self) -> OodMode { - match self { - Self::Standard => OodMode::Standard, - Self::ZeroKnowledge { - c_zk_log_inv_rate, .. - } => OodMode::ZeroKnowledge { - c_zk_log_inv_rate: f64::from(c_zk_log_inv_rate.get()), - }, - } - } -} - -#[derive(Debug, Clone, Copy)] -struct RoundShape { - round_index: usize, - source_vector_size: usize, - source_log_inv_rate: u32, - source_folding_factor: u32, - target_folding_factor: u32, -} - -struct RoundLayout { - shapes: Vec, - basecase_vector_size: usize, - basecase_log_inv_rate: u32, -} - -fn round_layout(tuning: &TuningSpec) -> RoundLayout { - assert!(tuning.vector_size.is_power_of_two()); - assert!(tuning.folding_factor.min() >= 1); - - let mut num_vars = tuning.vector_size.trailing_zeros() as usize; - let mut log_inv_rate = tuning.starting_log_inv_rate; - let mut shapes = Vec::new(); - - loop { - let round = shapes.len(); - let source_folding = tuning.folding_factor.at_round(round); - let target_folding = tuning.folding_factor.at_round(round.saturating_add(1)); - if num_vars < source_folding.saturating_add(target_folding) { - break; - } - shapes.push(RoundShape { - round_index: round, - source_vector_size: 1usize << num_vars, - source_log_inv_rate: log_inv_rate, - source_folding_factor: source_folding as u32, - target_folding_factor: target_folding as u32, - }); - num_vars = num_vars.saturating_sub(source_folding); - log_inv_rate = log_inv_rate.saturating_add((source_folding as u32).saturating_sub(1)); - } - - RoundLayout { - shapes, - basecase_vector_size: 1usize << num_vars, - basecase_log_inv_rate: log_inv_rate, - } -} - -const fn round_context(shape: &RoundShape) -> RoundContext { - RoundContext { - vector_size: shape.source_vector_size, - log_inv_rate: shape.source_log_inv_rate, - folding_factor: shape.source_folding_factor, - } -} - -fn target_context(shape: &RoundShape, source: &IrsConfig) -> RoundContext { - RoundContext { - vector_size: source.message_length(), - log_inv_rate: shape - .source_log_inv_rate - .saturating_add(shape.source_folding_factor.saturating_sub(1)), - folding_factor: shape.target_folding_factor, - } -} - -fn solve_round_source( - spec: &SecuritySpec, - shape: &RoundShape, - ood_mode: OodMode, -) -> Result<(IrsConfig, usize), DeriveError> { - let src_ctx = round_context(shape); - let target_log_inv_rate = f64::from( - shape - .source_log_inv_rate - .saturating_add(shape.source_folding_factor.saturating_sub(1)), - ); - let target_log_degree = f64::from( - shape - .source_vector_size - .trailing_zeros() - .saturating_sub(shape.source_folding_factor), - ); - let target_list_size = spec - .decoding_regime - .list_size_estimate(target_log_degree, target_log_inv_rate); - solve_t_ood::( - spec, - &src_ctx, - target_list_size, - ood_mode, - shape.round_index, - ) -} - -/// ZK-only: assemble the per-round mask oracle (C_zk codeword + mask-proximity -/// check). -fn build_mask_oracle( - zk_spec: ZkSpec<'_>, - source: &IrsConfig, - t_ood: usize, - num_masks: usize, - c_zk_log_inv_rate: LogInvRate, - round_index: usize, -) -> Result, DeriveError> { - let spec = zk_spec.as_inner(); - let l_zk = compute_l_zk(source, t_ood); - let c_zk: IrsConfig> = irs_params::solve_mask_code( - zk_spec, - l_zk, - source.mask_length(), - c_zk_log_inv_rate, - 2 * num_masks, - ); - let c_zk_list_size_estimate = spec.decoding_regime.list_size_estimate( - (l_zk.get() as f64).log2(), - f64::from(c_zk_log_inv_rate.get()), - ); - debug_assert!( - (c_zk.list_size() - c_zk_list_size_estimate).abs() - < 1e-9 * c_zk_list_size_estimate.max(1.0), - "c_zk.list_size() {} drifted from planner estimate {}", - c_zk.list_size(), - c_zk_list_size_estimate, - ); - let mask_proximity = mask_proximity_params::solve(spec, c_zk.clone(), num_masks, round_index)?; - Ok(MaskOracleConfig::new(c_zk, l_zk, mask_proximity)) -} - -fn build_round_config( - spec: &SecuritySpec, - shape: &RoundShape, - mode: RoundBuildMode<'_>, -) -> Result, DeriveError> { - let ctx = round_context(shape); - let (source, t_ood) = solve_round_source::(spec, shape, mode.to_ood_mode())?; - - let (target_budget, solve_mode, round_mode) = match mode { - RoundBuildMode::Standard => ( - OodSampleBudget::ZERO, - SolveMode::Standard, - RoundMode::Standard, - ), - RoundBuildMode::ZeroKnowledge { - zk_spec, - c_zk_log_inv_rate, - } => { - let num_masks = - sumcheck_params::masks_required(&ctx) + code_switch_params::masks_required(); - let mask_oracle = build_mask_oracle::( - zk_spec, - &source, - t_ood, - num_masks, - c_zk_log_inv_rate, - shape.round_index, - )?; - let solve_mode = SolveMode::ZeroKnowledge { - mask_oracle: mask_oracle.info(), - }; - let round_mode = RoundMode::ZeroKnowledge { - t_ood: OodSampleBudget::new(t_ood), - mask_oracle: Box::new(mask_oracle), - }; - (OodSampleBudget::new(t_ood), solve_mode, round_mode) - } - }; - - let target: IrsConfig> = - irs_params::solve(spec, &target_context(shape, &source), target_budget); - let sumcheck = sumcheck_params::solve( - spec, - &ctx, - &source, - solve_mode, - Pow::RoundSumcheck { - index: shape.round_index, - }, - )?; - let code_switch = - code_switch_params::solve(spec, source, target, t_ood, solve_mode, shape.round_index)?; - - Ok(RoundConfig::new( - shape.round_index, - sumcheck, - code_switch, - round_mode, - )) -} - -/// `ℓ_zk = next_pow2(r + t_ood)` (Theorem 9.6 + Lemma 9.3). -pub(super) const fn compute_l_zk( - source: &IrsConfig, - t_ood: usize, -) -> MaskCodeMessageLen { - MaskCodeMessageLen::new( - source - .mask_length() - .saturating_add(t_ood) - .next_power_of_two(), - ) -} - -/// Per-round `(source, t_ood)`. -/// -/// Under `Unique`, `t_ood = 1` is pinned (the `log(L·(L−1)/2)` term degenerates -/// when `L = 1`, and Construction 9.7 requires `out_domain_samples ≥ 1`). -/// Otherwise linear search over `t_ood = 1..=T_OOD_MAX_ITER` for the smallest -/// value where [`ood_security_bits_at`] meets `protocol_security_target_bits`. -pub(super) fn solve_t_ood( - spec: &SecuritySpec, - src_ctx: &RoundContext, - target_list_size: f64, - ood_mode: OodMode, - round_index: usize, -) -> Result<(IrsConfig, usize), DeriveError> { - if matches!(spec.decoding_regime, DecodingRegime::Unique) { - let source = irs_params::solve(spec, src_ctx, OodSampleBudget::new(1)); - return Ok((source, 1)); - } - - let security_target = f64::from(spec.protocol_security_target_bits()); - let field_bits = M::Target::field_size_bits(); - - for t_ood in 1..=T_OOD_MAX_ITER { - let source: IrsConfig = irs_params::solve(spec, src_ctx, OodSampleBudget::new(t_ood)); - let bits = - ood_security_bits_at(spec, &source, t_ood, target_list_size, ood_mode, field_bits); - if bits >= security_target { - return Ok((source, t_ood)); - } - } - Err(DeriveError::FixedPointDidNotConverge { round_index }) -} - -/// OOD security bits at candidate `t_ood`, per STIR Lemma 4.5: -/// `bits = t · (|F| − log d) − log(L · (L − 1) / 2) ≈ t·(|F| − log d) − 2·log L + 1`. -fn ood_security_bits_at( - spec: &SecuritySpec, - source: &IrsConfig, - t_ood: usize, - target_list_size: f64, - ood_mode: OodMode, - field_bits: f64, -) -> f64 { - let (log_degree, log_combined_list) = match ood_mode { - OodMode::Standard => ( - usize_to_f64(source.message_length()).log2(), - target_list_size.log2(), - ), - OodMode::ZeroKnowledge { c_zk_log_inv_rate } => { - let l_zk = source - .mask_length() - .saturating_add(t_ood) - .next_power_of_two(); - let c_zk_list = spec - .decoding_regime - .list_size_estimate(usize_to_f64(l_zk).log2(), c_zk_log_inv_rate); - ( - usize_to_f64(source.message_length().saturating_add(l_zk)).log2(), - (target_list_size * c_zk_list).log2(), - ) - } - }; - let ood = usize_to_f64(t_ood); - ood * (field_bits - log_degree) - 2.0 * log_combined_list + 1.0 -} - #[cfg(test)] mod tests { use proptest::prelude::*; - use super::*; use crate::{ + algebra::{embedding::Embedding, fields::FieldWithSize}, hash, protocols::params::{ - spec::{DecodingRegime, FoldingFactor, PowBudget}, + basecase as basecase_params, code_switch as code_switch_params, + error::DeriveError, + mask_proximity as mask_proximity_params, + protocol_config::{ProtocolConfig, RoundMode}, + spec::{DecodingRegime, FoldingFactor, Mode, PowBudget, SecuritySpec, TuningSpec}, + sumcheck as sumcheck_params, test_utils::{assert_close, assert_pow_closes_gap, TestEmbedding}, }, }; @@ -390,9 +83,6 @@ mod tests { const LOG_VECTOR_SIZE_NO_ROUNDS: u32 = 3; const LOG_VECTOR_SIZE_MULTI_ROUND: u32 = 8; - const VARIED_INITIAL_FOLDING: usize = 3; - const VARIED_STEADY_FOLDING: usize = 2; - fn tuning_with(vector_size: usize) -> TuningSpec { TuningSpec { vector_size, @@ -413,72 +103,6 @@ mod tests { } } - const RATE_STEPPING_STARTING_LOG_INV_RATE: u32 = 2; - const MIN_ROUNDS_FOR_CHAINING_TEST: usize = 2; - - #[test] - fn round_layout_rate_steps_up_by_folding_minus_one() { - let tuning = TuningSpec { - vector_size: 1 << LOG_VECTOR_SIZE_MULTI_ROUND, - starting_log_inv_rate: RATE_STEPPING_STARTING_LOG_INV_RATE, - folding_factor: FoldingFactor::ConstantFromSecondRound { - initial: VARIED_INITIAL_FOLDING, - rest: VARIED_STEADY_FOLDING, - }, - }; - let layout = round_layout(&tuning); - - let mut expected_log_inv_rate = RATE_STEPPING_STARTING_LOG_INV_RATE; - for shape in &layout.shapes { - assert_eq!(shape.source_log_inv_rate, expected_log_inv_rate); - expected_log_inv_rate += shape.source_folding_factor.saturating_sub(1); - } - assert_eq!(layout.basecase_log_inv_rate, expected_log_inv_rate); - } - - #[test] - fn round_layout_chains_target_to_next_source_folding() { - let tuning = TuningSpec { - vector_size: 1 << LOG_VECTOR_SIZE_MULTI_ROUND, - starting_log_inv_rate: FIXTURE_LOG_INV_RATE, - folding_factor: FoldingFactor::ConstantFromSecondRound { - initial: VARIED_INITIAL_FOLDING, - rest: VARIED_STEADY_FOLDING, - }, - }; - let layout = round_layout(&tuning); - assert!( - layout.shapes.len() >= MIN_ROUNDS_FOR_CHAINING_TEST, - "need ≥ {MIN_ROUNDS_FOR_CHAINING_TEST} rounds to test chaining", - ); - for window in layout.shapes.windows(2) { - assert_eq!( - window[0].target_folding_factor, - window[1].source_folding_factor - ); - } - } - - #[test] - fn round_layout_basecase_size_consumes_remaining_num_vars() { - let tuning = tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND); - let layout = round_layout(&tuning); - let consumed: u32 = layout.shapes.iter().map(|s| s.source_folding_factor).sum(); - let initial_num_vars = tuning.vector_size.trailing_zeros(); - let remaining = initial_num_vars - consumed; - assert_eq!(layout.basecase_vector_size, 1usize << remaining); - } - - #[test] - fn round_layout_stops_when_no_room_for_source_plus_target() { - let vector_size = 1usize << LOG_VECTOR_SIZE_NO_ROUNDS; - let tuning = tuning_with(vector_size); - let layout = round_layout(&tuning); - assert!(layout.shapes.is_empty()); - assert_eq!(layout.basecase_vector_size, vector_size); - assert_eq!(layout.basecase_log_inv_rate, FIXTURE_LOG_INV_RATE); - } - #[test] fn derive_standard_with_no_rounds_uses_basecase_only() { let spec = test_spec(Mode::Standard); diff --git a/src/protocols/params/error.rs b/src/protocols/params/error.rs index f753ffda..d1211e1d 100644 --- a/src/protocols/params/error.rs +++ b/src/protocols/params/error.rs @@ -78,7 +78,7 @@ impl Display for ChainTarget { /// sub-protocol solvers it calls. #[derive(Debug, Error, Clone, PartialEq, Eq)] pub enum DeriveError { - /// The `t_ood` fixed-point in [`super::derive::solve_t_ood`] ran out of + /// The `t_ood` fixed-point in [`super::build_round::solve_t_ood`] ran out of /// iterations. #[error("t_ood fixed-point did not converge for round {round_index}")] FixedPointDidNotConverge { round_index: usize }, diff --git a/src/protocols/params/layout.rs b/src/protocols/params/layout.rs new file mode 100644 index 00000000..122e8938 --- /dev/null +++ b/src/protocols/params/layout.rs @@ -0,0 +1,172 @@ +//! Round-skeleton layout: pure-data walk over the witness shape. +//! +//! Produces per-round shapes (vector size, log_inv_rate, folding factors) +//! independent of [`SecuritySpec`] and IRS solving. Consumed by +//! [`super::build_round`] to instantiate per-round configs and by +//! [`super::derive`] to drive the round/basecase split. + +use crate::{ + algebra::embedding::Embedding, + protocols::{ + irs_commit::Config as IrsConfig, + params::spec::{RoundContext, TuningSpec}, + }, +}; + +#[derive(Debug, Clone, Copy)] +pub(super) struct RoundShape { + pub(super) round_index: usize, + pub(super) source_vector_size: usize, + pub(super) source_log_inv_rate: u32, + pub(super) source_folding_factor: u32, + pub(super) target_folding_factor: u32, +} + +pub(super) struct RoundLayout { + pub(super) shapes: Vec, + pub(super) basecase_vector_size: usize, + pub(super) basecase_log_inv_rate: u32, +} + +pub(super) fn round_layout(tuning: &TuningSpec) -> RoundLayout { + assert!(tuning.vector_size.is_power_of_two()); + assert!(tuning.folding_factor.min() >= 1); + + let mut num_vars = tuning.vector_size.trailing_zeros() as usize; + let mut log_inv_rate = tuning.starting_log_inv_rate; + let mut shapes = Vec::new(); + + loop { + let round = shapes.len(); + let source_folding = tuning.folding_factor.at_round(round); + let target_folding = tuning.folding_factor.at_round(round.saturating_add(1)); + if num_vars < source_folding.saturating_add(target_folding) { + break; + } + shapes.push(RoundShape { + round_index: round, + source_vector_size: 1usize << num_vars, + source_log_inv_rate: log_inv_rate, + source_folding_factor: source_folding as u32, + target_folding_factor: target_folding as u32, + }); + num_vars = num_vars.saturating_sub(source_folding); + log_inv_rate = log_inv_rate.saturating_add((source_folding as u32).saturating_sub(1)); + } + + RoundLayout { + shapes, + basecase_vector_size: 1usize << num_vars, + basecase_log_inv_rate: log_inv_rate, + } +} + +pub(super) const fn round_context(shape: &RoundShape) -> RoundContext { + RoundContext { + vector_size: shape.source_vector_size, + log_inv_rate: shape.source_log_inv_rate, + folding_factor: shape.source_folding_factor, + } +} + +pub(super) fn target_context( + shape: &RoundShape, + source: &IrsConfig, +) -> RoundContext { + RoundContext { + vector_size: source.message_length(), + log_inv_rate: shape + .source_log_inv_rate + .saturating_add(shape.source_folding_factor.saturating_sub(1)), + folding_factor: shape.target_folding_factor, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::protocols::params::spec::FoldingFactor; + + const FIXTURE_FOLDING_FACTOR: usize = 2; + const FIXTURE_LOG_INV_RATE: u32 = 1; + + const LOG_VECTOR_SIZE_NO_ROUNDS: u32 = 3; + const LOG_VECTOR_SIZE_MULTI_ROUND: u32 = 8; + + const VARIED_INITIAL_FOLDING: usize = 3; + const VARIED_STEADY_FOLDING: usize = 2; + + const RATE_STEPPING_STARTING_LOG_INV_RATE: u32 = 2; + const MIN_ROUNDS_FOR_CHAINING_TEST: usize = 2; + + fn tuning_with(vector_size: usize) -> TuningSpec { + TuningSpec { + vector_size, + starting_log_inv_rate: FIXTURE_LOG_INV_RATE, + folding_factor: FoldingFactor::Constant(FIXTURE_FOLDING_FACTOR), + } + } + + #[test] + fn round_layout_rate_steps_up_by_folding_minus_one() { + let tuning = TuningSpec { + vector_size: 1 << LOG_VECTOR_SIZE_MULTI_ROUND, + starting_log_inv_rate: RATE_STEPPING_STARTING_LOG_INV_RATE, + folding_factor: FoldingFactor::ConstantFromSecondRound { + initial: VARIED_INITIAL_FOLDING, + rest: VARIED_STEADY_FOLDING, + }, + }; + let layout = round_layout(&tuning); + + let mut expected_log_inv_rate = RATE_STEPPING_STARTING_LOG_INV_RATE; + for shape in &layout.shapes { + assert_eq!(shape.source_log_inv_rate, expected_log_inv_rate); + expected_log_inv_rate += shape.source_folding_factor.saturating_sub(1); + } + assert_eq!(layout.basecase_log_inv_rate, expected_log_inv_rate); + } + + #[test] + fn round_layout_chains_target_to_next_source_folding() { + let tuning = TuningSpec { + vector_size: 1 << LOG_VECTOR_SIZE_MULTI_ROUND, + starting_log_inv_rate: FIXTURE_LOG_INV_RATE, + folding_factor: FoldingFactor::ConstantFromSecondRound { + initial: VARIED_INITIAL_FOLDING, + rest: VARIED_STEADY_FOLDING, + }, + }; + let layout = round_layout(&tuning); + assert!( + layout.shapes.len() >= MIN_ROUNDS_FOR_CHAINING_TEST, + "need ≥ {MIN_ROUNDS_FOR_CHAINING_TEST} rounds to test chaining", + ); + for window in layout.shapes.windows(2) { + assert_eq!( + window[0].target_folding_factor, + window[1].source_folding_factor + ); + } + } + + #[test] + fn round_layout_basecase_size_consumes_remaining_num_vars() { + let tuning = tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND); + let layout = round_layout(&tuning); + let consumed: u32 = layout.shapes.iter().map(|s| s.source_folding_factor).sum(); + let initial_num_vars = tuning.vector_size.trailing_zeros(); + let remaining = initial_num_vars - consumed; + assert_eq!(layout.basecase_vector_size, 1usize << remaining); + } + + #[test] + fn round_layout_stops_when_no_room_for_source_plus_target() { + let vector_size = 1usize << LOG_VECTOR_SIZE_NO_ROUNDS; + let tuning = tuning_with(vector_size); + let layout = round_layout(&tuning); + assert!(layout.shapes.is_empty()); + assert_eq!(layout.basecase_vector_size, vector_size); + assert_eq!(layout.basecase_log_inv_rate, FIXTURE_LOG_INV_RATE); + } +} diff --git a/src/protocols/params/mod.rs b/src/protocols/params/mod.rs index 3f447c46..c5a2bcf8 100644 --- a/src/protocols/params/mod.rs +++ b/src/protocols/params/mod.rs @@ -6,10 +6,12 @@ pub(crate) mod basecase; pub(crate) mod bounds; +pub(crate) mod build_round; pub(crate) mod code_switch; pub mod derive; pub mod error; pub(crate) mod irs_commit; +pub(crate) mod layout; pub(crate) mod mask_proximity; pub mod protocol_config; pub(crate) mod regime; diff --git a/src/protocols/params/test_utils.rs b/src/protocols/params/test_utils.rs index ef55a9d4..af835293 100644 --- a/src/protocols/params/test_utils.rs +++ b/src/protocols/params/test_utils.rs @@ -14,7 +14,7 @@ use crate::{ protocols::{ irs_commit::Config as IrsConfig, params::{ - derive::{solve_t_ood, OodMode}, + build_round::{solve_t_ood, OodMode}, irs_commit as irs_params, protocol_config::MaskOracleInfo, spec::{ From c5e223d2777347bf568fac28d8e90853544fe28a Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Sat, 30 May 2026 07:56:27 +0530 Subject: [PATCH 28/31] refactor : shrink enums in params --- src/protocols/params/branch.rs | 65 +++++++++++++++++++++++++ src/protocols/params/build_round.rs | 61 +++++++---------------- src/protocols/params/code_switch.rs | 15 +++--- src/protocols/params/derive.rs | 11 +++-- src/protocols/params/mod.rs | 9 +--- src/protocols/params/protocol_config.rs | 28 +++++------ src/protocols/params/sumcheck.rs | 14 ++---- src/protocols/params/test_utils.rs | 7 +-- 8 files changed, 121 insertions(+), 89 deletions(-) create mode 100644 src/protocols/params/branch.rs diff --git a/src/protocols/params/branch.rs b/src/protocols/params/branch.rs new file mode 100644 index 00000000..0797952f --- /dev/null +++ b/src/protocols/params/branch.rs @@ -0,0 +1,65 @@ +//! Standard-vs-ZK branching used to thread mode through the build pipeline. +//! +//! [`Branch`] is the shared shape: a transient choice between the standard +//! path (no payload) and the zero-knowledge path (payload `T`). Concrete +//! pipeline stages alias it with the payload they carry: +//! +//! - [`RoundBuildMode`] — input to [`super::build_round::build_round_config`]. +//! - [`OodMode`] — input to OOD-bound helpers. +//! - [`SolveMode`] — input to the per-sub-protocol solvers (`sumcheck`, `code_switch`). +//! +//! Sharing one enum gives us a free [`Branch::map`] for stage-to-stage +//! payload conversions, replacing one-off `to_ood_mode`-style helpers. + +use crate::protocols::params::{ + protocol_config::MaskOracleInfo, + spec::{LogInvRate, ZkSpec}, +}; + +/// Standard (no payload) vs. zero-knowledge (payload `T`). +#[derive(Clone, Copy, Debug)] +pub enum Branch { + Standard, + ZeroKnowledge(T), +} + +impl Branch { + pub const fn is_zk(&self) -> bool { + matches!(self, Self::ZeroKnowledge(_)) + } + + /// Transform the ZK payload, leaving `Standard` unchanged. Replaces + /// per-stage `to_*` conversion helpers. + pub fn map(self, f: impl FnOnce(T) -> U) -> Branch { + match self { + Self::Standard => Branch::Standard, + Self::ZeroKnowledge(t) => Branch::ZeroKnowledge(f(t)), + } + } + + pub const fn as_ref(&self) -> Branch<&T> { + match self { + Self::Standard => Branch::Standard, + Self::ZeroKnowledge(t) => Branch::ZeroKnowledge(t), + } + } +} + +/// Payload carried by [`RoundBuildMode::ZeroKnowledge`] — references the +/// `SecuritySpec` (so its lifetime threads through) plus the planner-chosen +/// `C_zk` rate. +#[derive(Clone, Copy, Debug)] +pub struct RoundBuildPayload<'a> { + pub zk_spec: ZkSpec<'a>, + pub c_zk_log_inv_rate: LogInvRate, +} + +/// Mode-dispatch input for [`super::build_round::build_round_config`]. +pub type RoundBuildMode<'a> = Branch>; + +/// Mode flag for the OOD security bound. Payload is the `C_zk` log-inverse +/// rate as `f64` (already coerced for the analytic formula). +pub type OodMode = Branch; + +/// Solver-input mode for the per-round sumcheck and code-switch builders. +pub type SolveMode = Branch; diff --git a/src/protocols/params/build_round.rs b/src/protocols/params/build_round.rs index 311cb6ea..e3a26383 100644 --- a/src/protocols/params/build_round.rs +++ b/src/protocols/params/build_round.rs @@ -13,6 +13,7 @@ use crate::{ irs_commit::Config as IrsConfig, params::{ bounds::usize_to_f64, + branch::{Branch, OodMode, RoundBuildMode, RoundBuildPayload}, code_switch as code_switch_params, error::{DeriveError, Pow}, irs_commit as irs_params, @@ -30,55 +31,26 @@ use crate::{ const T_OOD_MAX_ITER: usize = 32; -/// Mode flag for the OOD security bound in [`solve_t_ood`] / -/// [`ood_security_bits_at`]. -#[derive(Clone, Copy)] -pub(super) enum OodMode { - Standard, - ZeroKnowledge { c_zk_log_inv_rate: f64 }, -} - -/// Mode-dispatch input for [`build_round_config`]. -#[derive(Clone, Copy)] -pub(super) enum RoundBuildMode<'a> { - Standard, - ZeroKnowledge { - zk_spec: ZkSpec<'a>, - c_zk_log_inv_rate: LogInvRate, - }, -} - -impl RoundBuildMode<'_> { - fn to_ood_mode(self) -> OodMode { - match self { - Self::Standard => OodMode::Standard, - Self::ZeroKnowledge { - c_zk_log_inv_rate, .. - } => OodMode::ZeroKnowledge { - c_zk_log_inv_rate: f64::from(c_zk_log_inv_rate.get()), - }, - } - } -} - pub(super) fn build_round_config( spec: &SecuritySpec, shape: &RoundShape, mode: RoundBuildMode<'_>, ) -> Result, DeriveError> { let ctx = round_context(shape); - let (source, t_ood) = solve_round_source::(spec, shape, mode.to_ood_mode())?; + let ood_mode = mode.map(|p| f64::from(p.c_zk_log_inv_rate.get())); + let (source, t_ood) = solve_round_source::(spec, shape, ood_mode)?; - let (target_budget, solve_mode, round_mode) = match mode { - RoundBuildMode::Standard => ( + let (target_budget, solve_mode, round_mode, mask_oracle) = match mode { + Branch::Standard => ( OodSampleBudget::ZERO, SolveMode::Standard, RoundMode::Standard, + None, ), - RoundBuildMode::ZeroKnowledge { + Branch::ZeroKnowledge(RoundBuildPayload { zk_spec, c_zk_log_inv_rate, - } => { + }) => { let num_masks = sumcheck_params::masks_required(&ctx) + code_switch_params::masks_required(); let mask_oracle = build_mask_oracle::( @@ -89,14 +61,16 @@ pub(super) fn build_round_config( c_zk_log_inv_rate, shape.round_index, )?; - let solve_mode = SolveMode::ZeroKnowledge { - mask_oracle: mask_oracle.info(), - }; + let solve_mode = SolveMode::ZeroKnowledge(mask_oracle.info()); let round_mode = RoundMode::ZeroKnowledge { t_ood: OodSampleBudget::new(t_ood), - mask_oracle: Box::new(mask_oracle), }; - (OodSampleBudget::new(t_ood), solve_mode, round_mode) + ( + OodSampleBudget::new(t_ood), + solve_mode, + round_mode, + Some(mask_oracle), + ) } }; @@ -119,6 +93,7 @@ pub(super) fn build_round_config( sumcheck, code_switch, round_mode, + mask_oracle, )) } @@ -241,11 +216,11 @@ fn ood_security_bits_at( field_bits: f64, ) -> f64 { let (log_degree, log_combined_list) = match ood_mode { - OodMode::Standard => ( + Branch::Standard => ( usize_to_f64(source.message_length()).log2(), target_list_size.log2(), ), - OodMode::ZeroKnowledge { c_zk_log_inv_rate } => { + Branch::ZeroKnowledge(c_zk_log_inv_rate) => { let l_zk = source .mask_length() .saturating_add(t_ood) diff --git a/src/protocols/params/code_switch.rs b/src/protocols/params/code_switch.rs index 530d6296..2aaf58f3 100644 --- a/src/protocols/params/code_switch.rs +++ b/src/protocols/params/code_switch.rs @@ -32,7 +32,7 @@ pub fn solve( ) -> Result, DeriveError> { let (mask_oracle, output_mode) = match mode { SolveMode::Standard => (None, code_switch::CodeSwitchMode::Standard), - SolveMode::ZeroKnowledge { mask_oracle } => { + SolveMode::ZeroKnowledge(mask_oracle) => { let l_zk = mask_oracle.l_zk.get(); assert!( l_zk >= source.mask_length().saturating_add(t_ood), @@ -111,7 +111,8 @@ mod tests { use super::*; use crate::protocols::params::{ - build_round::{compute_l_zk, solve_t_ood, OodMode}, + branch::OodMode, + build_round::{compute_l_zk, solve_t_ood}, irs_commit as irs_params, spec::{ DecodingRegime, ListSize, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, @@ -296,7 +297,7 @@ mod tests { source, target, t_ood, - SolveMode::ZeroKnowledge { mask_oracle }, + SolveMode::ZeroKnowledge(mask_oracle), 0, ) .unwrap(); @@ -358,7 +359,7 @@ mod tests { source, target, t_ood, - SolveMode::ZeroKnowledge { mask_oracle }, + SolveMode::ZeroKnowledge(mask_oracle), 0, ); } @@ -405,9 +406,7 @@ mod tests { &spec, &source_ctx, target_list_size, - OodMode::ZeroKnowledge { - c_zk_log_inv_rate: f64::from(source_ctx.log_inv_rate), - }, + OodMode::ZeroKnowledge(f64::from(source_ctx.log_inv_rate)), 0, ) .unwrap(); @@ -426,7 +425,7 @@ mod tests { source, target, t_ood, - SolveMode::ZeroKnowledge { mask_oracle }, + SolveMode::ZeroKnowledge(mask_oracle), 0, ) .unwrap(); diff --git a/src/protocols/params/derive.rs b/src/protocols/params/derive.rs index 23d05a02..ba9c9e82 100644 --- a/src/protocols/params/derive.rs +++ b/src/protocols/params/derive.rs @@ -4,7 +4,8 @@ use crate::{ algebra::embedding::Embedding, protocols::params::{ basecase as basecase_params, - build_round::{build_round_config, RoundBuildMode}, + branch::{Branch, RoundBuildMode, RoundBuildPayload}, + build_round::build_round_config, error::DeriveError, layout::{round_layout, RoundLayout}, protocol_config::{ProtocolConfig, RoundConfig}, @@ -22,12 +23,12 @@ impl ProtocolConfig { basecase_log_inv_rate, } = round_layout(&tuning); - let mode = match spec.mode { - Mode::Standard => RoundBuildMode::Standard, - Mode::ZeroKnowledge => RoundBuildMode::ZeroKnowledge { + let mode: RoundBuildMode<'_> = match spec.mode { + Mode::Standard => Branch::Standard, + Mode::ZeroKnowledge => Branch::ZeroKnowledge(RoundBuildPayload { zk_spec: ZkSpec::try_new(&spec).expect("matched Mode::ZeroKnowledge above"), c_zk_log_inv_rate: LogInvRate::new(tuning.starting_log_inv_rate), - }, + }), }; let rounds: Vec> = shapes diff --git a/src/protocols/params/mod.rs b/src/protocols/params/mod.rs index c5a2bcf8..b1b7f0aa 100644 --- a/src/protocols/params/mod.rs +++ b/src/protocols/params/mod.rs @@ -6,6 +6,7 @@ pub(crate) mod basecase; pub(crate) mod bounds; +pub(crate) mod branch; pub(crate) mod build_round; pub(crate) mod code_switch; pub mod derive; @@ -21,6 +22,7 @@ pub(crate) mod sumcheck; #[cfg(test)] pub(crate) mod test_utils; +pub use branch::{Branch, SolveMode}; pub use error::{ChainSource, ChainTarget, DeriveError, Pow}; pub use protocol_config::{ MaskOracleConfig, MaskOracleInfo, ProtocolConfig, RoundConfig, RoundMode, @@ -29,10 +31,3 @@ pub use spec::{ DecodingRegime, FoldingFactor, ListSize, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, PowBudget, RoundContext, SecuritySpec, TuningSpec, ZkSpec, }; - -/// Solver-input mode for the per-round sumcheck and code-switch builders. -#[derive(Clone, Copy)] -pub enum SolveMode { - Standard, - ZeroKnowledge { mask_oracle: MaskOracleInfo }, -} diff --git a/src/protocols/params/protocol_config.rs b/src/protocols/params/protocol_config.rs index 8d4c0bc9..8b4e0182 100644 --- a/src/protocols/params/protocol_config.rs +++ b/src/protocols/params/protocol_config.rs @@ -213,7 +213,9 @@ pub struct RoundConfig { round_index: usize, sumcheck: SumcheckConfig, code_switch: CodeSwitchConfig, - mode: RoundMode, + mode: RoundMode, + /// `Some` iff `mode.is_zk()`. Sized for this round's `k + 1` masks. + mask_oracle: Option>, } impl RoundConfig { @@ -221,13 +223,15 @@ impl RoundConfig { round_index: usize, sumcheck: SumcheckConfig, code_switch: CodeSwitchConfig, - mode: RoundMode, + mode: RoundMode, + mask_oracle: Option>, ) -> Self { Self { round_index, sumcheck, code_switch, mode, + mask_oracle, } } @@ -243,16 +247,13 @@ impl RoundConfig { &self.code_switch } - pub const fn mode(&self) -> &RoundMode { + pub const fn mode(&self) -> &RoundMode { &self.mode } /// Borrow the round's mask oracle if this is a ZK round. - pub fn mask_oracle(&self) -> Option<&MaskOracleConfig> { - match &self.mode { - RoundMode::Standard => None, - RoundMode::ZeroKnowledge { mask_oracle, .. } => Some(mask_oracle.as_ref()), - } + pub const fn mask_oracle(&self) -> Option<&MaskOracleConfig> { + self.mask_oracle.as_ref() } /// Slim mask-oracle view derived from `mask_oracle()`. @@ -263,19 +264,18 @@ impl RoundConfig { /// Standard vs. ZK round. /// -/// The ZK payload is boxed so the enum stays small. -#[derive(Clone, Debug)] -pub enum RoundMode { +/// Non-generic — the per-round `MaskOracleConfig` lives on +/// [`RoundConfig`] as a sibling field. +#[derive(Clone, Copy, Debug)] +pub enum RoundMode { Standard, ZeroKnowledge { /// Lemma 9.9 OOD-sample budget (bounds doc §5.2). t_ood: OodSampleBudget, - /// Per-round mask oracle. - mask_oracle: Box>, }, } -impl RoundMode { +impl RoundMode { pub const fn is_zk(&self) -> bool { matches!(self, Self::ZeroKnowledge { .. }) } diff --git a/src/protocols/params/sumcheck.rs b/src/protocols/params/sumcheck.rs index 43ca202a..3c86d8bb 100644 --- a/src/protocols/params/sumcheck.rs +++ b/src/protocols/params/sumcheck.rs @@ -27,7 +27,7 @@ pub fn solve( ) -> Result, DeriveError> { let (mask_oracle, output_mode) = match mode { SolveMode::Standard => (None, sumcheck::SumcheckMode::Standard), - SolveMode::ZeroKnowledge { mask_oracle } => ( + SolveMode::ZeroKnowledge(mask_oracle) => ( Some(mask_oracle), sumcheck::SumcheckMode::ZeroKnowledge { mask_length: zk_mask_length(), @@ -126,7 +126,7 @@ mod tests { &spec, &ctx, &source_irs, - SolveMode::ZeroKnowledge { mask_oracle }, + SolveMode::ZeroKnowledge(mask_oracle), Pow::RoundSumcheck { index: 0 }, ) .unwrap(); @@ -218,9 +218,7 @@ mod tests { let source_irs = build_source_irs(&spec, &ctx); let pow = Pow::RoundSumcheck { index: 0 }; let mode = build_minimal_mask_oracle(&spec) - .map_or(SolveMode::Standard, |mask_oracle| { - SolveMode::ZeroKnowledge { mask_oracle } - }); + .map_or(SolveMode::Standard, SolveMode::ZeroKnowledge); let config = solve(&spec, &ctx, &source_irs, mode, pow).unwrap(); prop_assert_eq!(config.num_rounds, ctx.folding_factor as usize); } @@ -249,9 +247,7 @@ mod tests { let mask_oracle = build_minimal_mask_oracle(&spec); let error = analytic_error_bits(&source_irs, mask_oracle); let pow = Pow::RoundSumcheck { index: 0 }; - let mode = mask_oracle.map_or(SolveMode::Standard, |mask_oracle| { - SolveMode::ZeroKnowledge { mask_oracle } - }); + let mode = mask_oracle.map_or(SolveMode::Standard, SolveMode::ZeroKnowledge); let config = solve(&spec, &ctx, &source_irs, mode, pow).unwrap(); assert_pow_closes_gap(&spec, error, &config.round_pow); } @@ -271,7 +267,7 @@ mod tests { &spec, &ctx, &source_irs, - SolveMode::ZeroKnowledge { mask_oracle: info }, + SolveMode::ZeroKnowledge(info), Pow::RoundSumcheck { index: 0 }, ) .unwrap(); diff --git a/src/protocols/params/test_utils.rs b/src/protocols/params/test_utils.rs index af835293..cef71ce6 100644 --- a/src/protocols/params/test_utils.rs +++ b/src/protocols/params/test_utils.rs @@ -14,7 +14,8 @@ use crate::{ protocols::{ irs_commit::Config as IrsConfig, params::{ - build_round::{solve_t_ood, OodMode}, + branch::OodMode, + build_round::solve_t_ood, irs_commit as irs_params, protocol_config::MaskOracleInfo, spec::{ @@ -156,8 +157,8 @@ pub fn build_round_io( let target_list_size = spec .decoding_regime .list_size_estimate(target_log_degree, f64::from(target_log_inv_rate)); - let ood_mode = c_zk_log_inv_rate.map_or(OodMode::Standard, |rate| OodMode::ZeroKnowledge { - c_zk_log_inv_rate: f64::from(rate), + let ood_mode = c_zk_log_inv_rate.map_or(OodMode::Standard, |rate| { + OodMode::ZeroKnowledge(f64::from(rate)) }); let (source, t_ood) = solve_t_ood::(spec, &source_ctx, target_list_size, ood_mode, 0) .expect("solve_t_ood diverged in test fixture"); From b56a847877645936a90f3e9412d96e67ccd22b77 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Sat, 30 May 2026 08:11:31 +0530 Subject: [PATCH 29/31] feat : added with_recorded_analytic helper for each sub protocol config --- src/protocols/basecase.rs | 10 +++ src/protocols/code_switch.rs | 10 +++ src/protocols/mask_proximity.rs | 9 ++ src/protocols/params/basecase.rs | 32 +++---- src/protocols/params/code_switch.rs | 11 +-- src/protocols/params/derive.rs | 47 +++++++++++ src/protocols/params/error.rs | 25 ++++++ src/protocols/params/mask_proximity.rs | 2 +- src/protocols/params/protocol_config.rs | 106 ++++++++++++++++++++++++ src/protocols/params/sumcheck.rs | 6 +- src/protocols/sumcheck.rs | 10 +++ 11 files changed, 244 insertions(+), 24 deletions(-) diff --git a/src/protocols/basecase.rs b/src/protocols/basecase.rs index 8c8ef5f5..61e98766 100644 --- a/src/protocols/basecase.rs +++ b/src/protocols/basecase.rs @@ -41,6 +41,9 @@ pub struct Config { pub sumcheck: sumcheck::Config, pub mode: BasecaseMode, pub pow: proof_of_work::Config, + /// γ-combination analytic floor recorded by `params::basecase::solve` + /// (ZK only). `None` for Standard mode or ad-hoc construction. + pub recorded_analytic: Option, } impl Config { @@ -64,9 +67,15 @@ impl Config { sumcheck, mode, pow, + recorded_analytic: None, } } + pub const fn with_recorded_analytic(mut self, analytic: crate::bits::Bits) -> Self { + self.recorded_analytic = Some(analytic); + self + } + pub const fn size(&self) -> usize { self.sumcheck.initial_size } @@ -296,6 +305,7 @@ mod tests { size.next_power_of_two().trailing_zeros() as usize, sumcheck::SumcheckMode::Standard, ), + recorded_analytic: None, mode: if is_zk { BasecaseMode::ZeroKnowledge } else { diff --git a/src/protocols/code_switch.rs b/src/protocols/code_switch.rs index 99921ef4..84cd668f 100644 --- a/src/protocols/code_switch.rs +++ b/src/protocols/code_switch.rs @@ -47,6 +47,10 @@ pub struct Config { pub mode: CodeSwitchMode, pub out_domain_samples: usize, pub pow: proof_of_work::Config, + /// Analytic-error floor recorded by `params::code_switch::solve`. `None` + /// for configs built via ad-hoc paths. Drift checks compare against a + /// recompute. + pub recorded_analytic: Option, } /// Prover output from the code-switch. @@ -135,9 +139,15 @@ impl Config { mode, out_domain_samples, pow, + recorded_analytic: None, } } + pub const fn with_recorded_analytic(mut self, analytic: crate::bits::Bits) -> Self { + self.recorded_analytic = Some(analytic); + self + } + /// Mask oracle length `ℓ_zk`. Returns 0 in Standard mode. pub const fn message_mask_length(&self) -> usize { match &self.mode { diff --git a/src/protocols/mask_proximity.rs b/src/protocols/mask_proximity.rs index c5d90b67..9b63ef0f 100644 --- a/src/protocols/mask_proximity.rs +++ b/src/protocols/mask_proximity.rs @@ -69,6 +69,9 @@ pub struct Config { pub c_zk_commit: IrsConfig>, pub num_masks: usize, pub pow: proof_of_work::Config, + /// γ-combination analytic floor recorded by `params::mask_proximity::solve`. + /// `None` for ad-hoc construction. + pub recorded_analytic: Option, } /// Prover output from the commit phase. @@ -100,9 +103,15 @@ impl Config { c_zk_commit, num_masks, pow, + recorded_analytic: None, } } + pub const fn with_recorded_analytic(mut self, analytic: crate::bits::Bits) -> Self { + self.recorded_analytic = Some(analytic); + self + } + /// Commit all masks and their mask-of-masks in a single shared tree. /// /// Samples n fresh mask-of-mask polynomials, combines them with the diff --git a/src/protocols/params/basecase.rs b/src/protocols/params/basecase.rs index dfa221d1..fabcbb6c 100644 --- a/src/protocols/params/basecase.rs +++ b/src/protocols/params/basecase.rs @@ -33,33 +33,37 @@ pub fn solve( }; let commit = irs_params::solve(spec, &ctx, OodSampleBudget::ZERO); - let sumcheck_pow = grind_to_at( - spec, - sumcheck_params::analytic_error_bits(&commit, None), - Pow::BasecaseSumcheck, - )?; + let sumcheck_analytic = sumcheck_params::analytic_error_bits(&commit, None); + let sumcheck_pow = grind_to_at(spec, sumcheck_analytic, Pow::BasecaseSumcheck)?; let sumcheck = SumcheckConfig::new( vector_size, sumcheck_pow, vector_size.next_power_of_two().trailing_zeros() as usize, sumcheck::SumcheckMode::Standard, - ); + ) + .with_recorded_analytic(sumcheck_analytic); let mode = match spec.mode { SpecMode::Standard => basecase::BasecaseMode::Standard, SpecMode::ZeroKnowledge => basecase::BasecaseMode::ZeroKnowledge, }; - let pow = match mode { - basecase::BasecaseMode::Standard => PowConfig::none(), - basecase::BasecaseMode::ZeroKnowledge => grind_to_at( - spec, - analytic_error_bits(&commit), - Pow::BasecaseGammaCombination, - )?, + let (pow, gamma_analytic) = match mode { + basecase::BasecaseMode::Standard => (PowConfig::none(), None), + basecase::BasecaseMode::ZeroKnowledge => { + let a = analytic_error_bits(&commit); + ( + grind_to_at(spec, a, Pow::BasecaseGammaCombination)?, + Some(a), + ) + } }; - Ok(BasecaseConfig::new(commit, sumcheck, mode, pow)) + let mut cfg = BasecaseConfig::new(commit, sumcheck, mode, pow); + if let Some(a) = gamma_analytic { + cfg = cfg.with_recorded_analytic(a); + } + Ok(cfg) } /// γ-combination soundness (Lemma 7.4 combination-randomness slot, paper p.45). diff --git a/src/protocols/params/code_switch.rs b/src/protocols/params/code_switch.rs index 2aaf58f3..bedf8036 100644 --- a/src/protocols/params/code_switch.rs +++ b/src/protocols/params/code_switch.rs @@ -52,13 +52,10 @@ pub fn solve( let analytic = analytic_error_bits(&source, &target, t_ood, mask_oracle); let pow = grind_to_at(spec, analytic, Pow::RoundCodeSwitch { index: round_index })?; - Ok(CodeSwitchConfig::new( - source, - target, - t_ood, - output_mode, - pow, - )) + Ok( + CodeSwitchConfig::new(source, target, t_ood, output_mode, pow) + .with_recorded_analytic(analytic), + ) } /// Per-round code-switch soundness in bits: `min` over Lemma 9.9's three RBR diff --git a/src/protocols/params/derive.rs b/src/protocols/params/derive.rs index ba9c9e82..9b859c12 100644 --- a/src/protocols/params/derive.rs +++ b/src/protocols/params/derive.rs @@ -380,6 +380,53 @@ mod tests { assert!(!plan.check_all_invariants()); } + #[test] + fn validate_security_target_met_passes_on_fresh_plan() { + let spec = test_spec(Mode::ZeroKnowledge); + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + plan.validate_security_target_met() + .expect("fresh plan must satisfy per-slot target check"); + } + + #[test] + fn validate_security_target_met_catches_recorded_analytic_drift() { + use crate::bits::Bits; + let spec = test_spec(Mode::ZeroKnowledge); + let mut plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + assert!(!plan.rounds().is_empty(), "need a round to corrupt"); + let recorded = plan + .rounds() + .first() + .and_then(|r| r.sumcheck().recorded_analytic) + .expect("params solver records sumcheck analytic"); + // Bump the recorded value far from the recompute → triggers drift. + plan.corrupt_round_sumcheck_recorded_analytic_for_test( + 0, + Bits::new(f64::from(recorded) + 10.0), + ); + let err = plan + .validate_security_target_met() + .expect_err("recorded vs recompute mismatch must trip drift check"); + assert!( + matches!( + err, + DeriveError::AnalyticDrift { + pow: crate::protocols::params::error::Pow::RoundSumcheck { index: 0 }, + .. + } + ), + "got {err:?}", + ); + } + #[test] fn derive_reports_pow_ungrindable() { const UNREACHABLE_TARGET_BITS: u32 = 200; diff --git a/src/protocols/params/error.rs b/src/protocols/params/error.rs index d1211e1d..c2d522f8 100644 --- a/src/protocols/params/error.rs +++ b/src/protocols/params/error.rs @@ -112,6 +112,31 @@ pub enum DeriveError { expected: usize, found: usize, }, + + /// A PoW slot's `analytic + pow.difficulty()` is below `target_security_bits`. + /// `grind_to_at` guarantees this at construction; this error fires only if + /// the analytic-error formulas applied at validate-time disagree with what + /// the per-protocol `solve` functions consumed (e.g., a planner regression + /// drifts away from the actual configured IRS rate). Catches the case where + /// the rate-schedule plumbing under-reports a per-slot rate. + #[error("{pow} soundness gap: analytic {analytic} + pow {pow_bits} < target {target}")] + SecurityTargetNotMet { + pow: Pow, + analytic: Bits, + pow_bits: Bits, + target: Bits, + }, + + /// The analytic floor recorded at solve time disagrees with a fresh + /// recompute from the same config's state. Indicates that the inputs to + /// the `analytic_error_bits` formula drifted between solve and validate + /// (e.g., an IRS field was overwritten after construction). + #[error("{pow} analytic drift: recorded {recorded} vs recompute {recompute}")] + AnalyticDrift { + pow: Pow, + recorded: Bits, + recompute: Bits, + }, } /// Lift `Result` into `Result` by attaching a diff --git a/src/protocols/params/mask_proximity.rs b/src/protocols/params/mask_proximity.rs index f60b0476..4293796e 100644 --- a/src/protocols/params/mask_proximity.rs +++ b/src/protocols/params/mask_proximity.rs @@ -30,7 +30,7 @@ pub fn solve( analytic, Pow::RoundMaskProximity { index: round_index }, )?; - Ok(MaskProximityConfig::new(c_zk, num_masks, pow)) + Ok(MaskProximityConfig::new(c_zk, num_masks, pow).with_recorded_analytic(analytic)) } /// γ-combination soundness (Lemma 7.4): diff --git a/src/protocols/params/protocol_config.rs b/src/protocols/params/protocol_config.rs index 8b4e0182..41c8fb4b 100644 --- a/src/protocols/params/protocol_config.rs +++ b/src/protocols/params/protocol_config.rs @@ -14,9 +14,11 @@ use crate::{ irs_commit::Config as IrsConfig, mask_proximity::Config as MaskProximityConfig, params::{ + basecase as basecase_params, bounds::usize_to_f64, code_switch as code_switch_params, error::{ChainSource, ChainTarget, DeriveError, Pow}, + mask_proximity as mask_proximity_params, spec::{ListSize, MaskCodeMessageLen, OodSampleBudget, SecuritySpec, TuningSpec}, sumcheck as sumcheck_params, }, @@ -78,6 +80,102 @@ impl ProtocolConfig { pub fn validate(&self) -> Result<(), DeriveError> { self.validate_pow_budget()?; self.validate_round_chaining()?; + self.validate_security_target_met()?; + Ok(()) + } + + /// For each PoW slot: verify (a) the analytic-bits floor recorded at + /// solve time still matches a fresh recompute from the config's current + /// state, and (b) `recorded_analytic + pow.difficulty() ≥ target_security_bits`. + /// + /// `grind_to_at` guarantees (b) at solve time. If (a) holds, (b) holds + /// trivially. If (a) drifts, (b) may fail — most often because a planner + /// regression overwrote an IRS field after the solver consumed it. + /// + /// `EPS` matches the `assert_pow_closes_gap` slack used by the per-slot + /// proptest helper, so validation stays consistent with test-time + /// assertions. + pub fn validate_security_target_met(&self) -> Result<(), DeriveError> { + const EPS: f64 = 1e-3; + let target = Bits::new(f64::from(self.security.target_security_bits)); + let check = |pow_kind: Pow, + recorded: Option, + recompute: Bits, + pow_cfg: &PowConfig| + -> Result<(), DeriveError> { + if let Some(recorded) = recorded { + if (f64::from(recorded) - f64::from(recompute)).abs() > EPS { + return Err(DeriveError::AnalyticDrift { + pow: pow_kind, + recorded, + recompute, + }); + } + } + let analytic = recorded.unwrap_or(recompute); + let pow_bits = pow_cfg.difficulty(); + let sum = f64::from(analytic) + f64::from(pow_bits); + if sum + EPS < f64::from(target) { + return Err(DeriveError::SecurityTargetNotMet { + pow: pow_kind, + analytic, + pow_bits, + target, + }); + } + Ok(()) + }; + for r in &self.rounds { + let mask_info = r.mask_oracle_info(); + check( + Pow::RoundSumcheck { + index: r.round_index, + }, + r.sumcheck.recorded_analytic, + sumcheck_params::analytic_error_bits(&r.code_switch.source, mask_info), + &r.sumcheck.round_pow, + )?; + check( + Pow::RoundCodeSwitch { + index: r.round_index, + }, + r.code_switch.recorded_analytic, + code_switch_params::analytic_error_bits( + &r.code_switch.source, + &r.code_switch.target, + r.code_switch.out_domain_samples, + mask_info, + ), + &r.code_switch.pow, + )?; + if let Some(mo) = r.mask_oracle() { + check( + Pow::RoundMaskProximity { + index: r.round_index, + }, + mo.mask_proximity.recorded_analytic, + mask_proximity_params::analytic_error_bits( + &mo.mask_proximity.c_zk_commit, + mo.mask_proximity.num_masks, + ), + &mo.mask_proximity.pow, + )?; + } + } + check( + Pow::BasecaseSumcheck, + self.basecase.sumcheck.recorded_analytic, + sumcheck_params::analytic_error_bits(&self.basecase.commit, None), + &self.basecase.sumcheck.round_pow, + )?; + if self.basecase.is_zk() { + check( + Pow::BasecaseGammaCombination, + self.basecase.recorded_analytic, + basecase_params::analytic_error_bits(&self.basecase.commit), + &self.basecase.pow, + )?; + } Ok(()) } @@ -206,6 +304,14 @@ impl ProtocolConfig { ) { self.rounds[round_idx].code_switch.target.vector_size = new_size; } + + pub(crate) fn corrupt_round_sumcheck_recorded_analytic_for_test( + &mut self, + round_idx: usize, + new_value: Bits, + ) { + self.rounds[round_idx].sumcheck.recorded_analytic = Some(new_value); + } } #[derive(Clone, Debug)] diff --git a/src/protocols/params/sumcheck.rs b/src/protocols/params/sumcheck.rs index 3c86d8bb..1a1120a7 100644 --- a/src/protocols/params/sumcheck.rs +++ b/src/protocols/params/sumcheck.rs @@ -34,13 +34,15 @@ pub fn solve( }, ), }; - let round_pow = grind_to_at(spec, analytic_error_bits(source_irs, mask_oracle), pow)?; + let analytic = analytic_error_bits(source_irs, mask_oracle); + let round_pow = grind_to_at(spec, analytic, pow)?; Ok(SumcheckConfig::new( ctx.vector_size, round_pow, num_sumcheck_rounds(ctx), output_mode, - )) + ) + .with_recorded_analytic(analytic)) } /// Per-sumcheck-round soundness in bits: `min(ε_mca, poly_identity_term)`. diff --git a/src/protocols/sumcheck.rs b/src/protocols/sumcheck.rs index 626e040f..b6772baa 100644 --- a/src/protocols/sumcheck.rs +++ b/src/protocols/sumcheck.rs @@ -71,6 +71,10 @@ where pub round_pow: proof_of_work::Config, pub num_rounds: usize, pub mode: SumcheckMode, + /// Analytic-error floor recorded by the params solver that produced this + /// config. `None` when the config wasn't built via `params::sumcheck::solve` + /// (legacy/test paths). Drift checks compare this against a recompute. + pub recorded_analytic: Option, } impl Config { @@ -95,9 +99,15 @@ impl Config { round_pow, num_rounds, mode, + recorded_analytic: None, } } + pub const fn with_recorded_analytic(mut self, analytic: crate::bits::Bits) -> Self { + self.recorded_analytic = Some(analytic); + self + } + const fn mask_length(&self) -> usize { match &self.mode { SumcheckMode::Standard => 0, From fb5b160dacc0e70fa906ec0ee77137be7c6a2b5b Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Sat, 30 May 2026 08:16:02 +0530 Subject: [PATCH 30/31] added eror variants --- src/protocols/params/derive.rs | 2 +- src/protocols/params/error.rs | 8 +++++ src/protocols/params/layout.rs | 62 ++++++++++++++++++++++++++++------ 3 files changed, 61 insertions(+), 11 deletions(-) diff --git a/src/protocols/params/derive.rs b/src/protocols/params/derive.rs index 9b859c12..39734308 100644 --- a/src/protocols/params/derive.rs +++ b/src/protocols/params/derive.rs @@ -21,7 +21,7 @@ impl ProtocolConfig { shapes, basecase_vector_size, basecase_log_inv_rate, - } = round_layout(&tuning); + } = round_layout(&tuning)?; let mode: RoundBuildMode<'_> = match spec.mode { Mode::Standard => Branch::Standard, diff --git a/src/protocols/params/error.rs b/src/protocols/params/error.rs index c2d522f8..f44ca022 100644 --- a/src/protocols/params/error.rs +++ b/src/protocols/params/error.rs @@ -137,6 +137,14 @@ pub enum DeriveError { recorded: Bits, recompute: Bits, }, + + /// `tuning.vector_size` must be a power of 2. + #[error("tuning.vector_size ({vector_size}) must be a power of 2")] + TuningVectorSizeNotPowerOfTwo { vector_size: usize }, + + /// `tuning.folding_factor` must yield at least 1 at every round. + #[error("tuning.folding_factor min ({min}) must be ≥ 1")] + TuningFoldingFactorBelowOne { min: usize }, } /// Lift `Result` into `Result` by attaching a diff --git a/src/protocols/params/layout.rs b/src/protocols/params/layout.rs index 122e8938..79df2f0c 100644 --- a/src/protocols/params/layout.rs +++ b/src/protocols/params/layout.rs @@ -9,7 +9,10 @@ use crate::{ algebra::embedding::Embedding, protocols::{ irs_commit::Config as IrsConfig, - params::spec::{RoundContext, TuningSpec}, + params::{ + error::DeriveError, + spec::{RoundContext, TuningSpec}, + }, }, }; @@ -22,15 +25,23 @@ pub(super) struct RoundShape { pub(super) target_folding_factor: u32, } +#[derive(Debug)] pub(super) struct RoundLayout { pub(super) shapes: Vec, pub(super) basecase_vector_size: usize, pub(super) basecase_log_inv_rate: u32, } -pub(super) fn round_layout(tuning: &TuningSpec) -> RoundLayout { - assert!(tuning.vector_size.is_power_of_two()); - assert!(tuning.folding_factor.min() >= 1); +pub(super) fn round_layout(tuning: &TuningSpec) -> Result { + if !tuning.vector_size.is_power_of_two() { + return Err(DeriveError::TuningVectorSizeNotPowerOfTwo { + vector_size: tuning.vector_size, + }); + } + let min_folding = tuning.folding_factor.min(); + if min_folding < 1 { + return Err(DeriveError::TuningFoldingFactorBelowOne { min: min_folding }); + } let mut num_vars = tuning.vector_size.trailing_zeros() as usize; let mut log_inv_rate = tuning.starting_log_inv_rate; @@ -54,11 +65,11 @@ pub(super) fn round_layout(tuning: &TuningSpec) -> RoundLayout { log_inv_rate = log_inv_rate.saturating_add((source_folding as u32).saturating_sub(1)); } - RoundLayout { + Ok(RoundLayout { shapes, basecase_vector_size: 1usize << num_vars, basecase_log_inv_rate: log_inv_rate, - } + }) } pub(super) const fn round_context(shape: &RoundShape) -> RoundContext { @@ -117,7 +128,7 @@ mod tests { rest: VARIED_STEADY_FOLDING, }, }; - let layout = round_layout(&tuning); + let layout = round_layout(&tuning).unwrap(); let mut expected_log_inv_rate = RATE_STEPPING_STARTING_LOG_INV_RATE; for shape in &layout.shapes { @@ -137,7 +148,7 @@ mod tests { rest: VARIED_STEADY_FOLDING, }, }; - let layout = round_layout(&tuning); + let layout = round_layout(&tuning).unwrap(); assert!( layout.shapes.len() >= MIN_ROUNDS_FOR_CHAINING_TEST, "need ≥ {MIN_ROUNDS_FOR_CHAINING_TEST} rounds to test chaining", @@ -153,7 +164,7 @@ mod tests { #[test] fn round_layout_basecase_size_consumes_remaining_num_vars() { let tuning = tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND); - let layout = round_layout(&tuning); + let layout = round_layout(&tuning).unwrap(); let consumed: u32 = layout.shapes.iter().map(|s| s.source_folding_factor).sum(); let initial_num_vars = tuning.vector_size.trailing_zeros(); let remaining = initial_num_vars - consumed; @@ -164,9 +175,40 @@ mod tests { fn round_layout_stops_when_no_room_for_source_plus_target() { let vector_size = 1usize << LOG_VECTOR_SIZE_NO_ROUNDS; let tuning = tuning_with(vector_size); - let layout = round_layout(&tuning); + let layout = round_layout(&tuning).unwrap(); assert!(layout.shapes.is_empty()); assert_eq!(layout.basecase_vector_size, vector_size); assert_eq!(layout.basecase_log_inv_rate, FIXTURE_LOG_INV_RATE); } + + #[test] + fn round_layout_rejects_non_pow2_vector_size() { + let tuning = TuningSpec { + vector_size: 12, + starting_log_inv_rate: FIXTURE_LOG_INV_RATE, + folding_factor: FoldingFactor::Constant(FIXTURE_FOLDING_FACTOR), + }; + let err = round_layout(&tuning).expect_err("non-pow2 vector_size must fail"); + assert!( + matches!( + err, + DeriveError::TuningVectorSizeNotPowerOfTwo { vector_size: 12 } + ), + "got {err:?}", + ); + } + + #[test] + fn round_layout_rejects_zero_folding_factor() { + let tuning = TuningSpec { + vector_size: 1 << LOG_VECTOR_SIZE_MULTI_ROUND, + starting_log_inv_rate: FIXTURE_LOG_INV_RATE, + folding_factor: FoldingFactor::Constant(0), + }; + let err = round_layout(&tuning).expect_err("folding_factor = 0 must fail"); + assert!( + matches!(err, DeriveError::TuningFoldingFactorBelowOne { min: 0 }), + "got {err:?}", + ); + } } From 1e78338428c8721bec24c27d6a7b32b03303e9c6 Mon Sep 17 00:00:00 2001 From: ocdbytes Date: Sat, 30 May 2026 08:57:50 +0530 Subject: [PATCH 31/31] lint --- src/protocols/mask_proximity.rs | 10 +++++-- src/protocols/params/build_round.rs | 7 +++-- src/protocols/params/code_switch.rs | 32 ++++++++++++--------- src/protocols/params/derive.rs | 44 ++++++++++++++++------------- src/protocols/params/error.rs | 4 +-- src/protocols/params/mod.rs | 8 +++--- src/protocols/params/spec.rs | 2 +- src/protocols/params/sumcheck.rs | 2 +- src/protocols/params/test_utils.rs | 3 +- 9 files changed, 65 insertions(+), 47 deletions(-) diff --git a/src/protocols/mask_proximity.rs b/src/protocols/mask_proximity.rs index 9b63ef0f..0be304db 100644 --- a/src/protocols/mask_proximity.rs +++ b/src/protocols/mask_proximity.rs @@ -85,6 +85,12 @@ pub struct Witness { pub type Commitment = IrsCommitment; impl Config { + /// Required `c_zk.num_vectors` for `num_masks` originals: one fresh + /// mask-of-mask per original (Construction 7.2 originals + fresh pairs). + pub const fn num_vectors_for(num_masks: usize) -> usize { + 2 * num_masks + } + pub fn new( c_zk_commit: IrsConfig>, num_masks: usize, @@ -92,7 +98,7 @@ impl Config { ) -> Self { assert_eq!( c_zk_commit.num_vectors, - 2 * num_masks, + Self::num_vectors_for(num_masks), "c_zk.num_vectors must be 2 * num_masks" ); assert_eq!( @@ -332,7 +338,7 @@ mod tests { .prop_flat_map(|(num_masks, vector_size, mask_length)| { let c_zk = IrsConfig::>::arbitrary( Identity::new(), - 2 * num_masks, + Self::num_vectors_for(num_masks), vector_size, mask_length, 1, diff --git a/src/protocols/params/build_round.rs b/src/protocols/params/build_round.rs index e3a26383..420d3ddb 100644 --- a/src/protocols/params/build_round.rs +++ b/src/protocols/params/build_round.rs @@ -11,9 +11,10 @@ use crate::{ }, protocols::{ irs_commit::Config as IrsConfig, + mask_proximity::Config as MaskProximityConfig, params::{ bounds::usize_to_f64, - branch::{Branch, OodMode, RoundBuildMode, RoundBuildPayload}, + branch::{Branch, OodMode, RoundBuildMode, RoundBuildPayload, SolveMode}, code_switch as code_switch_params, error::{DeriveError, Pow}, irs_commit as irs_params, @@ -24,7 +25,7 @@ use crate::{ DecodingRegime, LogInvRate, MaskCodeMessageLen, OodSampleBudget, RoundContext, SecuritySpec, ZkSpec, }, - sumcheck as sumcheck_params, SolveMode, + sumcheck as sumcheck_params, }, }, }; @@ -143,7 +144,7 @@ fn build_mask_oracle( l_zk, source.mask_length(), c_zk_log_inv_rate, - 2 * num_masks, + MaskProximityConfig::::num_vectors_for(num_masks), ); let c_zk_list_size_estimate = spec.decoding_regime.list_size_estimate( (l_zk.get() as f64).log2(), diff --git a/src/protocols/params/code_switch.rs b/src/protocols/params/code_switch.rs index bedf8036..b18f3cca 100644 --- a/src/protocols/params/code_switch.rs +++ b/src/protocols/params/code_switch.rs @@ -13,10 +13,10 @@ use crate::{ irs_commit::Config as IrsConfig, params::{ bounds::usize_to_f64, + branch::SolveMode, error::{grind_to_at, DeriveError, Pow}, protocol_config::MaskOracleInfo, spec::SecuritySpec, - SolveMode, }, }, }; @@ -107,18 +107,22 @@ mod tests { use proptest::prelude::*; use super::*; - use crate::protocols::params::{ - branch::OodMode, - build_round::{compute_l_zk, solve_t_ood}, - irs_commit as irs_params, - spec::{ - DecodingRegime, ListSize, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, - PowBudget, RoundContext, SecuritySpec, ZkSpec, - }, - test_utils::{ - arb_standard_spec as utils_standard_spec, arb_zk_spec as utils_zk_spec, assert_close, - assert_pow_closes_gap, build_round_io, deterministic_spec, TestEmbedding, - TestExtensionField, TestField, TestNonIdentityEmbedding, TEST_TARGET_RANGE, + use crate::{ + hash, + protocols::params::{ + branch::OodMode, + build_round::{compute_l_zk, solve_t_ood}, + irs_commit as irs_params, + spec::{ + DecodingRegime, ListSize, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, + PowBudget, RoundContext, SecuritySpec, ZkSpec, + }, + test_utils::{ + arb_standard_spec as utils_standard_spec, arb_zk_spec as utils_zk_spec, + assert_close, assert_pow_closes_gap, build_round_io, deterministic_spec, + TestEmbedding, TestExtensionField, TestField, TestNonIdentityEmbedding, + TEST_TARGET_RANGE, + }, }, }; @@ -226,7 +230,7 @@ mod tests { decoding_regime: DecodingRegime::Johnson, target_security_bits: LIMITING_TARGET_BITS, pow_budget: PowBudget::Forbidden, - hash_id: crate::hash::BLAKE3, + hash_id: hash::BLAKE3, }; let (source, target, t_ood) = build_round_io::( &spec, diff --git a/src/protocols/params/derive.rs b/src/protocols/params/derive.rs index 39734308..fec28001 100644 --- a/src/protocols/params/derive.rs +++ b/src/protocols/params/derive.rs @@ -49,16 +49,22 @@ mod tests { use proptest::prelude::*; use crate::{ - algebra::{embedding::Embedding, fields::FieldWithSize}, + algebra::{ + embedding::Embedding, + fields::{Field64, FieldWithSize}, + }, hash, - protocols::params::{ - basecase as basecase_params, code_switch as code_switch_params, - error::DeriveError, - mask_proximity as mask_proximity_params, - protocol_config::{ProtocolConfig, RoundMode}, - spec::{DecodingRegime, FoldingFactor, Mode, PowBudget, SecuritySpec, TuningSpec}, - sumcheck as sumcheck_params, - test_utils::{assert_close, assert_pow_closes_gap, TestEmbedding}, + protocols::{ + basecase::BasecaseMode, + params::{ + basecase as basecase_params, code_switch as code_switch_params, + error::{ChainSource, ChainTarget, DeriveError, Pow}, + mask_proximity as mask_proximity_params, + protocol_config::{ProtocolConfig, RoundMode}, + spec::{DecodingRegime, FoldingFactor, Mode, PowBudget, SecuritySpec, TuningSpec}, + sumcheck as sumcheck_params, + test_utils::{assert_close, assert_pow_closes_gap, TestEmbedding}, + }, }, }; @@ -124,7 +130,7 @@ mod tests { assert!(plan.rounds().is_empty()); assert!(matches!( plan.basecase().mode, - crate::protocols::basecase::BasecaseMode::ZeroKnowledge + BasecaseMode::ZeroKnowledge )); } @@ -243,7 +249,7 @@ mod tests { .unwrap(); assert!(matches!( plan.basecase().mode, - crate::protocols::basecase::BasecaseMode::ZeroKnowledge + BasecaseMode::ZeroKnowledge )); assert_eq!(plan.basecase().commit.interleaving_depth, 1); assert_eq!(plan.basecase().sumcheck.final_size(), 1); @@ -262,7 +268,7 @@ mod tests { tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), ) .unwrap(); - let field_bits = ::field_size_bits(); + let field_bits = ::field_size_bits(); let mut expected_total = 0.0_f64; for r in plan.rounds() { let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode() else { @@ -341,8 +347,8 @@ mod tests { matches!( err, DeriveError::RoundChainBroken { - from: crate::protocols::params::error::ChainSource::Round(0), - to: crate::protocols::params::error::ChainTarget::NextRound(1), + from: ChainSource::Round(0), + to: ChainTarget::NextRound(1), .. } ), @@ -371,7 +377,7 @@ mod tests { matches!( err, DeriveError::RoundChainBroken { - to: crate::protocols::params::error::ChainTarget::Basecase, + to: ChainTarget::Basecase, .. } ), @@ -419,7 +425,7 @@ mod tests { matches!( err, DeriveError::AnalyticDrift { - pow: crate::protocols::params::error::Pow::RoundSumcheck { index: 0 }, + pow: Pow::RoundSumcheck { index: 0 }, .. } ), @@ -608,7 +614,7 @@ mod tests { ); if matches!( plan.basecase().mode, - crate::protocols::basecase::BasecaseMode::ZeroKnowledge + BasecaseMode::ZeroKnowledge ) { assert_pow_closes_gap( spec, @@ -646,7 +652,7 @@ mod tests { } prop_assert!(matches!( plan.basecase().mode, - crate::protocols::basecase::BasecaseMode::Standard + BasecaseMode::Standard )); prop_assert_eq!(plan.basecase().commit.interleaving_depth, 1); } @@ -676,7 +682,7 @@ mod tests { } prop_assert!(matches!( plan.basecase().mode, - crate::protocols::basecase::BasecaseMode::ZeroKnowledge + BasecaseMode::ZeroKnowledge )); } diff --git a/src/protocols/params/error.rs b/src/protocols/params/error.rs index f44ca022..1e54af19 100644 --- a/src/protocols/params/error.rs +++ b/src/protocols/params/error.rs @@ -149,7 +149,7 @@ pub enum DeriveError { /// Lift `Result` into `Result` by attaching a /// [`Pow`] label. -pub(crate) trait PowResultExt { +pub trait PowResultExt { fn at(self, pow: Pow) -> Result; } @@ -161,7 +161,7 @@ impl PowResultExt for Result { /// Grind `analytic → spec.target_security_bits`, then check the result against /// `spec.pow_budget`. -pub(crate) fn grind_to_at( +pub fn grind_to_at( spec: &SecuritySpec, analytic: Bits, pow_kind: Pow, diff --git a/src/protocols/params/mod.rs b/src/protocols/params/mod.rs index b1b7f0aa..6f7ddf7a 100644 --- a/src/protocols/params/mod.rs +++ b/src/protocols/params/mod.rs @@ -9,14 +9,14 @@ pub(crate) mod bounds; pub(crate) mod branch; pub(crate) mod build_round; pub(crate) mod code_switch; -pub mod derive; -pub mod error; +pub(crate) mod derive; +pub(crate) mod error; pub(crate) mod irs_commit; pub(crate) mod layout; pub(crate) mod mask_proximity; -pub mod protocol_config; +pub(crate) mod protocol_config; pub(crate) mod regime; -pub mod spec; +pub(crate) mod spec; pub(crate) mod sumcheck; #[cfg(test)] diff --git a/src/protocols/params/spec.rs b/src/protocols/params/spec.rs index 2b4e203b..42ff3104 100644 --- a/src/protocols/params/spec.rs +++ b/src/protocols/params/spec.rs @@ -1,4 +1,4 @@ -use core::{ +use std::{ fmt::{self, Display, Formatter}, marker::PhantomData, num::NonZeroU32, diff --git a/src/protocols/params/sumcheck.rs b/src/protocols/params/sumcheck.rs index 1a1120a7..68b263bf 100644 --- a/src/protocols/params/sumcheck.rs +++ b/src/protocols/params/sumcheck.rs @@ -8,10 +8,10 @@ use crate::{ irs_commit::Config as IrsConfig, params::{ bounds::usize_to_f64, + branch::SolveMode, error::{grind_to_at, DeriveError, Pow}, protocol_config::MaskOracleInfo, spec::{RoundContext, SecuritySpec}, - SolveMode, }, sumcheck::{self, Config as SumcheckConfig, SumcheckMaskLen}, }, diff --git a/src/protocols/params/test_utils.rs b/src/protocols/params/test_utils.rs index cef71ce6..ae3a0604 100644 --- a/src/protocols/params/test_utils.rs +++ b/src/protocols/params/test_utils.rs @@ -13,6 +13,7 @@ use crate::{ hash, protocols::{ irs_commit::Config as IrsConfig, + mask_proximity::Config as MaskProximityConfig, params::{ branch::OodMode, build_round::solve_t_ood, @@ -134,7 +135,7 @@ pub fn build_test_c_zk( MaskCodeMessageLen::new(l_zk), 0, LogInvRate::new(log_inv_rate), - 2 * num_masks, + MaskProximityConfig::::num_vectors_for(num_masks), ) }