diff --git a/Cargo.lock b/Cargo.lock index acc6ca37..62e5a0d4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1360,6 +1360,26 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "thiserror" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "thread_local" version = "1.1.9" @@ -1556,6 +1576,7 @@ dependencies = [ "sha3", "spongefish", "static_assertions", + "thiserror", "tracing", "tracing-subscriber", "zerocopy", diff --git a/Cargo.toml b/Cargo.toml index 515ee2b8..52fc40e5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,6 +55,7 @@ const-oid = "0.9.6" arrayvec = "0.7.6" derive-where = { version = "1.6.0", features = ["safe"] } ordered-float = { version = "5.1.0", features = ["serde"] } +thiserror = "2.0" [dev-dependencies] proptest = "1.0" diff --git a/proptest-regressions/protocols/params/basecase.txt b/proptest-regressions/protocols/params/basecase.txt new file mode 100644 index 00000000..c3b26074 --- /dev/null +++ b/proptest-regressions/protocols/params/basecase.txt @@ -0,0 +1,10 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc b6e8ae0b3e6a9769901e0e0e489da34965bf0a8df7dd049aef66e0541bf10baf # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, decoding_regime: Johnson, target_security_bits: 30, pow_budget: Forbidden, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, (log_size, log_inv_rate) = (1, 1) +cc a2f771fc5031440200810b95ea2d347da895f8eb2e1a87f53fd69ad224287e84 # shrinks to spec = SecuritySpec { mode: Standard, decoding_regime: Johnson, target_security_bits: 30, pow_budget: Forbidden, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, (log_size, log_inv_rate) = (2, 2) +cc f66c89bc700c79bca5f4b7234f1345129962c78e5a2036a6430564f615f19b30 # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, decoding_regime: Johnson, target_security_bits: 30, pow_budget: Forbidden, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, (log_size, log_inv_rate) = (1, 1) +cc 312fec8a96f6f55f5d3c0346bdb85690f23150aadf88888d453041e99b05d414 # shrinks to spec = SecuritySpec { mode: Standard, decoding_regime: Johnson, target_security_bits: 39, pow_budget: Forbidden, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, (log_size, log_inv_rate) = (1, 1) diff --git a/proptest-regressions/protocols/params/code_switch.txt b/proptest-regressions/protocols/params/code_switch.txt new file mode 100644 index 00000000..de6c40cf --- /dev/null +++ b/proptest-regressions/protocols/params/code_switch.txt @@ -0,0 +1,13 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 7a7df094ea650db7a295d162b75dd9da9b52d1fc36947d2b07df8150cd9d906f # shrinks to spec = SecuritySpec { mode: Standard { unique_decoding: false }, target_security_bits: 80, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 4, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, log_inv_rate = 1, folding_factor = 3, num_vars = 4 +cc b42c982074a04c7110df07cf00f45156607be547e176b1ddd5f9d994ad491ddb # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 80, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 4, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, log_inv_rate = 1, folding_factor = 3, num_vars = 4 +cc eaf09a2b6bdffa86026264679f008326498ca800260dd2f17d4370df9fb3f801 # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 80, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 4, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, log_inv_rate = 1, folding_factor = 3, num_vars = 4 +cc 3887a5fa698c99109e8262e843dbd24ea94b9c9d420791e4520b5c9211a3eca0 # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 100, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 4, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, (log_inv_rate, folding_factor, num_vars) = (3, 2, 7) +cc b3e128084f721e6f43e263e05acf2e2de6fcd05dccf3811f063eeb0b63d78f8e # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 47, max_pow_bits: Some(15), hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, (log_inv_rate, folding_factor, num_vars) = (3, 2, 4) +cc b71da9002ceac9e4a74af097a7b087557a5b916fe8da47e39c4682375d749f88 # shrinks to spec = SecuritySpec { mode: Standard, decoding_regime: Johnson, target_security_bits: 50, pow_budget: Forbidden, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, (log_inv_rate, folding_factor, num_vars) = (1, 2, 4) +cc 1981509d857e56772dd4a79f8692619e968891aa3d84576ea1857f6d9a484a2d # shrinks to spec = SecuritySpec { mode: Standard, decoding_regime: Johnson, target_security_bits: 50, pow_budget: Forbidden, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, (log_inv_rate, folding_factor, num_vars) = (1, 2, 4) diff --git a/proptest-regressions/protocols/params/derive.txt b/proptest-regressions/protocols/params/derive.txt new file mode 100644 index 00000000..ca83c1b3 --- /dev/null +++ b/proptest-regressions/protocols/params/derive.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 104921a4117ed8255308c1ea5d3e12c72356ef72ef0d93fc0f24ed29f93fdd3a # shrinks to tuning = TuningSpec { vector_size: 32, starting_log_inv_rate: 3, folding_factor: Constant(1) } diff --git a/proptest-regressions/protocols/params/irs_commit.txt b/proptest-regressions/protocols/params/irs_commit.txt new file mode 100644 index 00000000..4d35f3df --- /dev/null +++ b/proptest-regressions/protocols/params/irs_commit.txt @@ -0,0 +1,8 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 0b6dd03179c9a4e38b29b34b241b88fba69348a2c8938af7253314b7035bea82 # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 80, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 4, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, ctx = RoundContext { round_index: 0, vector_size: 16, log_inv_rate: 1, folding_factor: 1, prev_round_in_domain_samples: 0, prev_round_query_error: 0.0 }, out_domain = 0, seed = 0 +cc 7e49f7a2d53f55cfa2f09114d17ab4123678b45ddf69e0cfbc646b246de2f042 # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 80, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c }, ctx = RoundContext { vector_size: 128, log_inv_rate: 2, folding_factor: 2 }, out_domain = 11 diff --git a/proptest-regressions/protocols/params/sumcheck.txt b/proptest-regressions/protocols/params/sumcheck.txt new file mode 100644 index 00000000..d6f6e6ed --- /dev/null +++ b/proptest-regressions/protocols/params/sumcheck.txt @@ -0,0 +1,12 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 0ffdc71948ed0315f4cf55fb8f2dd25bf71f7e41f53cd4fe35ee9da6fb125a20 # shrinks to spec = SecuritySpec { mode: Standard { unique_decoding: false }, target_security_bits: 86, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 3, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, ctx = RoundContext { round_index: 0, vector_size: 32, log_inv_rate: 2, folding_factor: 1, prev_round_in_domain_samples: 0, prev_round_query_error: 0.0 } +cc e8ab6549772cf6bf4c3af116ebcba3dbf295ffbe2aee4a94be7df4b9f45d61ec # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 91, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 3, max_pow_bits: Some(10), hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, ctx = RoundContext { round_index: 0, vector_size: 16, log_inv_rate: 1, folding_factor: 1, prev_round_in_domain_samples: 0, prev_round_query_error: 0.0 } +cc 8c4300cc375640956f81e9da5aef9ea11ef476ddc4dd253dc560afa07609262d # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 98, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 3, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, ctx = RoundContext { round_index: 0, vector_size: 16, log_inv_rate: 1, folding_factor: 1, prev_round_in_domain_samples: 0, prev_round_query_error: 0.0 } +cc 8ea40f13c63b4c0021386369ce698a5d9289381a39dc85db43d2d69b9b4877bb # shrinks to spec = SecuritySpec { mode: Standard { unique_decoding: false }, target_security_bits: 88, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 3, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, ctx = RoundContext { round_index: 0, vector_size: 16, log_inv_rate: 3, folding_factor: 1, prev_round_in_domain_samples: 0, prev_round_query_error: 0.0 } +cc f1dca600886474c74d857c547baea0c2b4faf45b2946036f21a008106396eb1c # shrinks to spec = SecuritySpec { mode: Standard { unique_decoding: false }, target_security_bits: 80, vector_size: 256, starting_log_inv_rate: 1, initial_folding_factor: 4, folding_factor: 3, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, ctx = RoundContext { round_index: 0, vector_size: 256, log_inv_rate: 3, folding_factor: 1, prev_round_in_domain_samples: 0, prev_round_query_error: 0.0 } +cc 36d0f5929e8099fa8644b0511229cf11634e5a7a66d99c06099c304f5f7a8c6e # shrinks to spec = SecuritySpec { mode: ZeroKnowledge, target_security_bits: 47, max_pow_bits: None, hash_id: 03e01749ebcc0477924254eb482066b864a8dd4d77252464ca6f5b6f5cc05b4c, _embedding: PhantomData, 1>>> }, ctx = RoundContext { round_index: 0, vector_size: 128, log_inv_rate: 4, folding_factor: 1, prev_round_in_domain_samples: 0, prev_round_query_error: 0.0 } diff --git a/src/bin/benchmark.rs b/src/bin/benchmark.rs index d9e79ce6..77574c42 100644 --- a/src/bin/benchmark.rs +++ b/src/bin/benchmark.rs @@ -18,6 +18,7 @@ use whir::{ cmdline_utils::{AvailableFields, AvailableHash}, hash::HASH_COUNTER, parameters::ProtocolParameters, + protocols::params::DecodingRegime, transcript::{codecs::Empty, Codec, DomainSeparator, ProverState, VerifierState}, }; @@ -48,8 +49,8 @@ struct Args { #[arg(short = 'k', long = "fold", default_value = "4")] folding_factor: usize, - #[arg(long = "unique-decoding", default_value_t = false)] - unique_decoding: bool, + #[arg(long = "decoding-regime", default_value = "Johnson")] + decoding_regime: DecodingRegime, #[arg(short = 'f', long = "field", default_value = "Goldilocks3")] field: AvailableFields, @@ -67,7 +68,7 @@ struct BenchmarkOutput { repetitions: usize, initial_folding_factor: usize, folding_factor: usize, - unique_decoding: bool, + decoding_regime: DecodingRegime, field: AvailableFields, hash: AvailableHash, @@ -117,7 +118,7 @@ where let reps = args.verifier_repetitions; let folding_factor = args.folding_factor; let first_round_folding_factor = args.first_round_folding_factor; - let unique_decoding = args.unique_decoding; + let decoding_regime = args.decoding_regime; std::fs::create_dir_all("outputs").unwrap(); @@ -128,7 +129,7 @@ where pow_bits, initial_folding_factor: first_round_folding_factor, folding_factor, - unique_decoding, + decoding_regime, starting_log_inv_rate: starting_rate, batch_size: 1, hash_id: args.hash.hash_id(), @@ -298,7 +299,7 @@ where repetitions: reps, initial_folding_factor: first_round_folding_factor, folding_factor, - unique_decoding, + decoding_regime, field: args.field, hash: args.hash, diff --git a/src/bin/main.rs b/src/bin/main.rs index bafe16f5..0faf4b36 100644 --- a/src/bin/main.rs +++ b/src/bin/main.rs @@ -14,6 +14,7 @@ use whir::{ cmdline_utils::{AvailableFields, AvailableHash}, hash::HASH_COUNTER, parameters::ProtocolParameters, + protocols::params::DecodingRegime, transcript::{codecs::Empty, Codec, DomainSeparator, ProverState, VerifierState}, }; @@ -48,9 +49,9 @@ struct Args { #[arg(short = 'k', long = "fold", default_value = "4")] folding_factor: usize, - /// Restrict PCS to the Unique Decoding regime. LDT is always UD. - #[arg(long = "unique-decoding", default_value_t = false)] - unique_decoding: bool, + /// Reed–Solomon decoding regime: Unique or Johnson (list-decoding). + #[arg(long = "decoding-regime", default_value = "Johnson")] + decoding_regime: DecodingRegime, #[arg(short = 'f', long = "field", default_value = "Goldilocks3")] field: AvailableFields, @@ -109,7 +110,7 @@ where let reps = args.verifier_repetitions; let first_round_folding_factor = args.first_round_folding_factor; let folding_factor = args.folding_factor; - let unique_decoding = args.unique_decoding; + let decoding_regime = args.decoding_regime; let num_evaluations = args.num_evaluations; let num_linear_constraints = args.num_linear_constraints; let hash_id = args.hash.hash_id(); @@ -125,7 +126,7 @@ where pow_bits, initial_folding_factor: first_round_folding_factor, folding_factor, - unique_decoding, + decoding_regime, starting_log_inv_rate: starting_rate, batch_size: 1, hash_id, @@ -254,7 +255,7 @@ where let num_coeffs = 1 << num_variables; let whir_params = ProtocolParameters { - unique_decoding: args.unique_decoding, + decoding_regime: args.decoding_regime, security_level, pow_bits, initial_folding_factor: first_round_folding_factor, diff --git a/src/parameters.rs b/src/parameters.rs index 242b9749..8e257c79 100644 --- a/src/parameters.rs +++ b/src/parameters.rs @@ -2,13 +2,13 @@ use std::fmt::{Debug, Display}; use serde::{Deserialize, Serialize}; -use crate::engines::EngineId; +use crate::{engines::EngineId, protocols::params::DecodingRegime}; /// Configuration parameters for WHIR proofs. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct ProtocolParameters { - /// Whether to require unique decoding. - pub unique_decoding: bool, + /// Reed–Solomon decoding regime: `Unique` or `Johnson` (list-decoding). + pub decoding_regime: DecodingRegime, /// The logarithmic inverse rate for sampling. pub starting_log_inv_rate: usize, /// Folding factor for the initial round. @@ -30,13 +30,7 @@ impl Display for ProtocolParameters { writeln!( f, "Targeting {}-bits of security with {}-bits of PoW using {} decoding", - self.security_level, - self.pow_bits, - if self.unique_decoding { - "unique" - } else { - "list" - } + self.security_level, self.pow_bits, self.decoding_regime, )?; writeln!( f, diff --git a/src/protocols/basecase.rs b/src/protocols/basecase.rs index 32763271..61e98766 100644 --- a/src/protocols/basecase.rs +++ b/src/protocols/basecase.rs @@ -1,8 +1,4 @@ -//! Base Case Linear Opening Protocol -//! -//! It support honest verifier zero-knowledge (HVZK), but is not succinct. -//! -//! § 7. +//! Non-succinct linear opening (Construction 7.2, p.43). HVZK in ZK mode. use ark_ff::Field; use ark_std::rand::{distributions::Standard, prelude::Distribution, CryptoRng, RngCore}; @@ -15,7 +11,7 @@ use crate::{ univariate_evaluate, }, hash::Hash, - protocols::{irs_commit, sumcheck}, + protocols::{irs_commit, proof_of_work, sumcheck}, transcript::{ codecs::U64, Codec, DuplexSpongeInterface, ProverMessage, ProverState, VerifierMessage, VerifierState, @@ -24,28 +20,70 @@ use crate::{ verify, }; -/// Output from the base case protocol (shared by prover and verifier). #[must_use] pub struct Opening { pub evaluation_points: Vec, pub linear_form_evaluation: F, } +/// Standard / ZeroKnowledge selector for basecase. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub enum BasecaseMode { + Standard, + ZeroKnowledge, +} + +#[must_use] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(bound = "")] pub struct Config { pub commit: irs_commit::Config>, pub sumcheck: sumcheck::Config, - - /// Whether to mask the vectors, which adds HVZK. - pub masked: bool, + pub mode: BasecaseMode, + pub pow: proof_of_work::Config, + /// γ-combination analytic floor recorded by `params::basecase::solve` + /// (ZK only). `None` for Standard mode or ad-hoc construction. + pub recorded_analytic: Option, } impl Config { + /// Standard basecase has no γ challenge — PoW must be `none()`. ZK + /// basecase has a γ-combination slot (Lemma 7.4) and may or may not need + /// PoW depending on whether the analytic floor already clears the target + /// (under unique decoding it often does). + pub fn new( + commit: irs_commit::Config>, + sumcheck: sumcheck::Config, + mode: BasecaseMode, + pow: proof_of_work::Config, + ) -> Self { + let has_pow = pow != proof_of_work::Config::none(); + debug_assert!( + !matches!(mode, BasecaseMode::Standard) || !has_pow, + "Standard basecase has no γ challenge — pow must be none()", + ); + Self { + commit, + sumcheck, + mode, + pow, + recorded_analytic: None, + } + } + + pub const fn with_recorded_analytic(mut self, analytic: crate::bits::Bits) -> Self { + self.recorded_analytic = Some(analytic); + self + } + pub const fn size(&self) -> usize { self.sumcheck.initial_size } + pub const fn is_zk(&self) -> bool { + matches!(self.mode, BasecaseMode::ZeroKnowledge) + } + pub fn prove( &self, prover_state: &mut ProverState, @@ -76,63 +114,22 @@ impl Config { }; } - // Even more trivial non-zk protocol: send f and r directly. - if !self.masked { - prover_state.prover_messages(&vector); - prover_state.prover_messages(&witness.masks); - let _ = self.commit.open(prover_state, &[witness]); - let point = self - .sumcheck - .prove(prover_state, &mut vector, &mut covector, &mut sum, &[]) - .round_challenges; - assert!(!vector[0].is_zero(), "Proof failed"); - return Opening { - evaluation_points: point, - linear_form_evaluation: covector[0], - }; - } + let blinding_witness = + self.maybe_blind_prove(prover_state, &mut vector, witness, &covector, &mut sum); - // Create masking vector. - let mask = random_vector(prover_state.rng(), vector.len()); + let witnesses: Vec<&irs_commit::Witness> = blinding_witness + .as_ref() + .map_or_else(|| vec![witness], |b| vec![b, witness]); + let _ = self.commit.open(prover_state, &witnesses); - // Commit to the masking vector. - let mask_witness = self.commit.commit(prover_state, &[&mask]); - - // Compute and send linear form of mask (μ' in paper). - let mask_sum = dot(&mask, &covector); - prover_state.prover_message(&mask_sum); - - // RLC the mask with the vector - let mask_rlc = prover_state.verifier_message::(); - assert!(!mask_rlc.is_zero(), "Proof failed"); - let mut masked_vector = scalar_mul_add_new(&mask, mask_rlc, &vector); - prover_state.prover_messages(&masked_vector); - - // Send combined IRS randomness. (r^* in paper) - let masked_masks = scalar_mul_add_new(&mask_witness.masks, mask_rlc, &witness.masks); - prover_state.prover_messages(&masked_masks); - - // Open the commitment and mask simultaneously. - let _ = self.commit.open(prover_state, &[&mask_witness, witness]); - - // Run sumcheck to reduce linear form claim - let mut masked_sum = mask_sum + mask_rlc * sum; let point = self .sumcheck - .prove( - prover_state, - &mut masked_vector, - &mut covector, - &mut masked_sum, - &[], - ) + .prove(prover_state, &mut vector, &mut covector, &mut sum, &[]) .round_challenges; - // If the MLE of `masked_vector` evaluates to zero, the verifier can not proceed. - // Basically the sumcheck equation has degenerated to 0 * l(r) = 0, which provides - // no constraints on l(r) that the verifier can return. - // This event is cryptographically unlikely as `F` is challenge sized. - assert!(!masked_vector[0].is_zero(), "Proof failed"); + // Negligible event over a challenge-sized field; without it the verifier + // cannot derive `l(r) = sum / vector_mle(r)`. + assert!(!vector[0].is_zero(), "Proof failed"); Opening { evaluation_points: point, @@ -140,10 +137,64 @@ impl Config { } } + /// ZK: commits a blinding codeword, runs the RLC, mutates `vector`/`sum` to + /// the combined values, sends them cleartext. Standard: sends `vector` and + /// `witness.masks` cleartext (no ZK). + fn maybe_blind_prove( + &self, + prover_state: &mut ProverState, + vector: &mut Vec, + witness: &irs_commit::Witness, + covector: &[F], + sum: &mut F, + ) -> Option> + where + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + F: Codec<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, + Standard: Distribution, + { + match self.mode { + BasecaseMode::Standard => { + prover_state.prover_messages(vector); + prover_state.prover_messages(&witness.masks); + None + } + BasecaseMode::ZeroKnowledge => { + let blinding_vector = random_vector(prover_state.rng(), vector.len()); + let blinding_witness = self.commit.commit(prover_state, &[&blinding_vector]); + let blinding_inner_product = dot(&blinding_vector, covector); + prover_state.prover_message(&blinding_inner_product); + + // Grind the Theorem 7.1 γ-combination gap before γ is sampled. + self.pow.prove(prover_state); + + let combination_randomness = prover_state.verifier_message::(); + assert!(!combination_randomness.is_zero(), "Proof failed"); + + *vector = scalar_mul_add_new(&blinding_vector, combination_randomness, vector); + prover_state.prover_messages(vector); + + let combined_irs_randomness = scalar_mul_add_new( + &blinding_witness.masks, + combination_randomness, + &witness.masks, + ); + prover_state.prover_messages(&combined_irs_randomness); + + *sum = blinding_inner_product + combination_randomness * *sum; + Some(blinding_witness) + } + } + } + pub fn verify( &self, verifier_state: &mut VerifierState, - commitment: &irs_commit::Commitment, + commitment: &irs_commit::Commitment, mut sum: F, ) -> VerificationResult> where @@ -165,72 +216,71 @@ impl Config { }); } - // Unmasked protocol - if !self.masked { - let vector = verifier_state.prover_messages_vec(self.commit.vector_size)?; - let masks = verifier_state - .prover_messages_vec(self.commit.mask_length * self.commit.num_messages())?; - let evals = self.commit.verify(verifier_state, &[commitment])?; - let point = self - .sumcheck - .verify(verifier_state, &mut sum)? - .round_challenges; - - for (&point, value) in zip_strict(&evals.points, evals.values(&[F::ONE])) { - // We expected `f(x) + x^l · g(x)` where l = deg(f) + 1, f is the message and g the mask. - let expected = univariate_evaluate(&vector, point) - + point.pow([self.commit.message_length() as u64]) - * univariate_evaluate(&masks, point); - verify!(value == expected); - } - let mle = multilinear_extend(&vector, &point); - verify!(!mle.is_zero()); - let linear_mle = sum / mle; - return Ok(Opening { - evaluation_points: point, - linear_form_evaluation: linear_mle, - }); - } + let blind = self.maybe_receive_blind(verifier_state, &mut sum)?; - let mask_commitment = self.commit.receive_commitment(verifier_state)?; - let mask_sum: F = verifier_state.prover_message()?; - let mask_rlc: F = verifier_state.verifier_message(); - verify!(!mask_rlc.is_zero()); - let masked_vector: Vec = verifier_state.prover_messages_vec(self.commit.vector_size)?; - let masked_masks: Vec = verifier_state.prover_messages_vec(self.commit.mask_length)?; + let vector = verifier_state.prover_messages_vec(self.commit.vector_size)?; + let irs_randomness = verifier_state + .prover_messages_vec(self.commit.mask_length() * self.commit.num_messages())?; - // Open the commitment and mask simultaneously. - let evals = self - .commit - .verify(verifier_state, &[&mask_commitment, commitment])?; + let (commitments, weights): (Vec<&irs_commit::Commitment>, Vec) = match &blind { + Some((b, gamma)) => (vec![b, commitment], vec![F::ONE, *gamma]), + None => (vec![commitment], vec![F::ONE]), + }; + let evals = self.commit.verify(verifier_state, &commitments)?; - // Spot check evaluations. - for (&point, value) in zip_strict(&evals.points, evals.values(&[F::ONE, mask_rlc])) { - // We expected `f(x) + x^l · g(x)` where l = deg(f) + 1, f is the message and g the mask. - let expected = univariate_evaluate(&masked_vector, point) + // Spot-check: Enc_C(vector, irs_randomness)(x) = Σ weights · opened_row(x). + for (&point, value) in zip_strict(&evals.points, evals.values(&weights)) { + let expected = univariate_evaluate(&vector, point) + point.pow([self.commit.message_length() as u64]) - * univariate_evaluate(&masked_masks, point); + * univariate_evaluate(&irs_randomness, point); verify!(value == expected); } - // Sumcheck on masked inner product - let mut masked_sum = mask_sum + mask_rlc * sum; let point = self .sumcheck - .verify(verifier_state, &mut masked_sum)? + .verify(verifier_state, &mut sum)? .round_challenges; - // Compute implied MLE of the linear form - // f*(r) · l(r) = sum => l(r) = sum / f*(r) - let masked_mle = multilinear_extend(&masked_vector, &point); - verify!(!masked_mle.is_zero()); - let linear_mle = masked_sum / masked_mle; + // l(r) = sum / vector_mle(r), where l is the implicit linear form. + let mle = multilinear_extend(&vector, &point); + verify!(!mle.is_zero()); + let linear_mle = sum / mle; Ok(Opening { evaluation_points: point, linear_form_evaluation: linear_mle, }) } + + /// ZK: reads the blinding commitment + μ' + γ, mutates `sum` to the + /// combined value, returns `(commitment, γ)`. Standard: no-op. + fn maybe_receive_blind( + &self, + verifier_state: &mut VerifierState, + sum: &mut F, + ) -> VerificationResult> + where + H: DuplexSpongeInterface, + F: Codec<[H::U]>, + u8: Decoding<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + match self.mode { + BasecaseMode::Standard => Ok(None), + BasecaseMode::ZeroKnowledge => { + let blinding_commitment = self.commit.receive_commitment(verifier_state)?; + let blinding_inner_product: F = verifier_state.prover_message()?; + // Grind the Theorem 7.1 γ-combination gap before γ is sampled. + self.pow.verify(verifier_state)?; + let combination_randomness: F = verifier_state.verifier_message(); + verify!(!combination_randomness.is_zero()); + *sum = blinding_inner_product + combination_randomness * *sum; + Ok(Some((blinding_commitment, combination_randomness))) + } + } + } } #[cfg(test)] @@ -241,27 +291,27 @@ mod tests { use tracing::instrument; use super::*; - use crate::{ - algebra::fields, protocols::proof_of_work, transcript::DomainSeparator, type_info::Type, - }; + use crate::{algebra::fields, protocols::proof_of_work, transcript::DomainSeparator}; impl Config { pub fn arbitrary(size: usize, mask_length: usize) -> impl Strategy { let commit = irs_commit::Config::arbitrary(Identity::::new(), 1, size, mask_length, 1); - (commit, bool::weighted(0.8)).prop_map(move |(commit, masked)| Self { - commit: irs_commit::Config { - out_domain_samples: 0, - ..commit - }, - sumcheck: sumcheck::Config { - field: Type::new(), - initial_size: size, - round_pow: proof_of_work::Config::none(), - num_rounds: size.next_power_of_two().trailing_zeros() as usize, - mask_length: 0, + (commit, bool::weighted(0.8)).prop_map(move |(commit, is_zk)| Self { + commit, + sumcheck: sumcheck::Config::new( + size, + proof_of_work::Config::none(), + size.next_power_of_two().trailing_zeros() as usize, + sumcheck::SumcheckMode::Standard, + ), + recorded_analytic: None, + mode: if is_zk { + BasecaseMode::ZeroKnowledge + } else { + BasecaseMode::Standard }, - masked, + pow: proof_of_work::Config::none(), }) } } @@ -272,7 +322,6 @@ mod tests { F: Field + Codec, Standard: Distribution, { - // Pseudo-random Instance let instance = U64(seed); let ds = DomainSeparator::protocol(config) .session(&format!("Test at {}:{}", file!(), line!())) @@ -282,7 +331,6 @@ mod tests { let covector = random_vector(&mut rng, config.size()); let sum = dot(&vector, &covector); - // Prover let mut prover_state = ProverState::new_std(&ds); let witness = config.commit.commit(&mut prover_state, &[&vector]); let prover_result = config.prove( @@ -298,7 +346,6 @@ mod tests { ); let proof = prover_state.proof(); - // Verifier let mut verifier_state = VerifierState::new_std(&ds, &proof); let commitment = config .commit diff --git a/src/protocols/code_switch.rs b/src/protocols/code_switch.rs index 62bbb602..84cd668f 100644 --- a/src/protocols/code_switch.rs +++ b/src/protocols/code_switch.rs @@ -3,7 +3,7 @@ //! Reduces a proximity claim about oracle f (source code C) to a proximity //! claim about oracle g (target code C'). Supports optional ZK via mask oracle. -use std::fmt; +use std::{fmt, num::NonZeroUsize}; use ark_ff::Field; use ark_std::rand::{distributions::Standard, prelude::Distribution, CryptoRng, RngCore}; @@ -21,22 +21,36 @@ use crate::{ protocols::{ geometric_challenge::geometric_challenge, irs_commit::{Commitment as IrsCommitment, Config as IrsConfig, Witness as IrsWitness}, + proof_of_work, }, transcript::{ - Codec, Decoding, DuplexSpongeInterface, ProverMessage, ProverState, VerificationResult, - VerifierMessage, VerifierState, + codecs::U64, Codec, Decoding, DuplexSpongeInterface, ProverMessage, ProverState, + VerificationResult, VerifierMessage, VerifierState, }, verify, }; +/// Standard / ZeroKnowledge selector for code-switch. +#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] +pub enum CodeSwitchMode { + Standard, + ZeroKnowledge { message_mask_length: NonZeroUsize }, +} + /// Code-switching IOR config with optional ZK. +#[must_use] #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] #[serde(bound = "")] pub struct Config { pub source: IrsConfig, pub target: IrsConfig>, - pub message_mask_length: usize, // l_zk + pub mode: CodeSwitchMode, pub out_domain_samples: usize, + pub pow: proof_of_work::Config, + /// Analytic-error floor recorded by `params::code_switch::solve`. `None` + /// for configs built via ad-hoc paths. Drift checks compare against a + /// recompute. + pub recorded_analytic: Option, } /// Prover output from the code-switch. @@ -48,27 +62,16 @@ pub struct Witness { } /// Verifier output from the code-switch. -pub type Commitment = IrsCommitment; - -/// Mask input for the code-switch prover. -// TODO : This may be removed after parameter selection PR -pub enum MaskInput<'a, F> { - Disabled, - Enabled(&'a [F]), -} +pub type Commitment = IrsCommitment; impl Config { /// Create a code-switch config. - /// - /// The orchestrator is responsible for: - /// - Setting `target_config.mask_length` for ZK mode before passing it in. - /// - Computing `out_domain_samples` from the security budget. - /// - Setting `message_mask_length` = mask oracle message length (0 for non-ZK). pub fn new( source_config: IrsConfig, target_config: IrsConfig>, out_domain_samples: usize, - message_mask_length: usize, + mode: CodeSwitchMode, + pow: proof_of_work::Config, ) -> Self { assert_eq!( source_config.num_vectors, 1, @@ -78,6 +81,12 @@ impl Config { target_config.num_vectors, 1, "code-switch requires a single target vector" ); + // Construction 9.7 needs at least one OOD challenge; unique-decoding + // Standard mode (`t_ood = 0`) is incompatible with code-switch. + assert!( + out_domain_samples > 0, + "code-switch requires t_ood ≥ 1 (Construction 9.7)", + ); // Target encodes one polynomial of length ℓ = source.message_length() // under C' = D^{ι_t}. The IRS splits the input of length ℓ into ι_t // parallel slices of length ℓ/ι_t, each encoded under D. @@ -90,43 +99,73 @@ impl Config { target_config.interleaving_depth.is_power_of_two(), "target.interleaving_depth must be a power of 2" ); - // Theorem 9.6: ℓ_zk ≥ r (mask oracle must cover source randomness). - if message_mask_length > 0 { + assert!( + source_config.interleaving_depth.is_power_of_two(), + "source.interleaving_depth must be a power of 2" + ); + if let CodeSwitchMode::ZeroKnowledge { + message_mask_length, + } = &mode + { + let l_zk = message_mask_length.get(); + // Theorem 9.6: ℓ_zk ≥ r (mask oracle must cover source randomness). assert!( - message_mask_length >= source_config.mask_length, - "message_mask_length ({message_mask_length}) must be >= source randomness length ({})", - source_config.mask_length, + l_zk >= source_config.mask_length(), + "message_mask_length ({l_zk}) must be >= source randomness length ({})", + source_config.mask_length(), ); assert!( - message_mask_length - source_config.mask_length >= out_domain_samples, - "the sampled randomness (s) length must be covering all the out of domain sample requests" + l_zk - source_config.mask_length() >= out_domain_samples, + "sampled randomness (s) length must cover all out-of-domain sample requests" ); + // t' = target in-domain queries + OOD queries (Construction 9.7 step 4). + // Definition 3.16: a t'-query ZK encoding requires r' ≥ t'; here + // r' = target.mask_length. assert!( - target_config.mask_length - >= target_config.in_domain_samples + target_config.out_domain_samples, - "target encoder violates: t' > r', number of queries should be covered by random mask" + target_config.mask_length() >= target_config.in_domain_samples + out_domain_samples, + "target encoder violates t' ≤ r': queries must be covered by target mask" + ); + } else { + assert_eq!( + source_config.mask_length(), + 0, + "source with IRS randomness requires ZK mode", ); } - assert!( - source_config.mask_length == 0 || message_mask_length > 0, - "source with mask_length > 0 (IRS randomness) requires ZK mode (message_mask_length > 0)" - ); - assert!( - source_config.interleaving_depth.is_power_of_two(), - "source.interleaving_depth must be a power of 2" - ); Self { source: source_config, target: target_config, - message_mask_length, + mode, out_domain_samples, + pow, + recorded_analytic: None, } } + pub const fn with_recorded_analytic(mut self, analytic: crate::bits::Bits) -> Self { + self.recorded_analytic = Some(analytic); + self + } + + /// Mask oracle length `ℓ_zk`. Returns 0 in Standard mode. + pub const fn message_mask_length(&self) -> usize { + match &self.mode { + CodeSwitchMode::Standard => 0, + CodeSwitchMode::ZeroKnowledge { + message_mask_length, + } => message_mask_length.get(), + } + } + + /// `true` iff the protocol is configured for ZK. + pub const fn is_zk(&self) -> bool { + matches!(&self.mode, CodeSwitchMode::ZeroKnowledge { .. }) + } + /// Length of the covector for this code-switch. pub fn covector_length(&self) -> usize { - self.source.message_length() + self.message_mask_length + self.source.message_length() + self.message_mask_length() } /// Prove the code-switch. @@ -146,18 +185,18 @@ impl Config { /// `message` is `Fold(f, γ)`, the post-sumcheck polynomial of length /// `source.message_length()`. /// - /// `mask_input` is `(r || s)` from the orchestrator's shared mask tree - /// (see Construction 9.7 Step 1, p.55). Must be `None` when - /// `message_mask_length == 0`. + /// `mask` is `(r || s)` from the orchestrator's shared mask tree + /// (see Construction 9.7 Step 1, p.55). Length must equal + /// `self.message_mask_length()` — pass an empty slice in Standard mode. #[cfg_attr(feature = "tracing", instrument(skip_all))] pub fn prove( &self, prover_state: &mut ProverState, message: Vec, - witness: &IrsWitness, + witness: &IrsWitness, covector: &mut [M::Target], folding_randomness: &[M::Target], - mask_input: &MaskInput<'_, M::Target>, + mask: &[M::Target], ) -> Witness where H: DuplexSpongeInterface, @@ -165,10 +204,13 @@ impl Config { Standard: Distribution, M::Target: Codec<[H::U]>, u8: Decoding<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, Hash: ProverMessage<[H::U]>, { assert_eq!(message.len(), self.source.message_length()); assert_eq!(covector.len(), self.covector_length()); + assert_eq!(mask.len(), self.message_mask_length()); assert_eq!( 1 << folding_randomness.len(), self.source.interleaving_depth, @@ -176,41 +218,16 @@ impl Config { folding_randomness.len(), self.source.interleaving_depth, ); - let mask_msg: Option<&[M::Target]> = match &mask_input { - MaskInput::Disabled => { - assert_eq!( - self.message_mask_length, 0, - "MaskInput::Disabled requires message_mask_length == 0" - ); - None - } - MaskInput::Enabled(mask) => { - assert_eq!( - mask.len(), - self.message_mask_length, - "mask_msg length must equal message_mask_length" - ); - Some(mask) - } - }; // Step 1: g := Enc_{C'}(f, r') — Construction 9.7 Step 1, p.55 let target_witness = self.target.commit(prover_state, &[&message]); + // Grind Lemma 9.9 OOD gap before α is sampled. + self.pow.prove(prover_state); + // Step 2-3: OOD challenge + answers — Construction 9.7 Steps 2-3, p.55 - // y := ze_ood(ρ) · [f; r; s] = f(α) + α^ℓ · (r,s)(α) let ood_points: Vec = prover_state.verifier_message_vec(self.out_domain_samples); - let msg_len = message.len(); - for &point in &ood_points { - let f_eval = univariate_evaluate(&message, point); - if let Some(mask) = mask_msg { - let mask_eval = univariate_evaluate(mask, point); - let shift = point.pow([msg_len as u64]); - prover_state.prover_message(&(f_eval + shift * mask_eval)); - } else { - prover_state.prover_message(&f_eval); - } - } + self.maybe_send_ood_answers(prover_state, &message, mask, &ood_points); // Step 4: in-domain queries — Construction 9.7 Step 4, p.55 let source_evaluations = self.source.open(prover_state, &[witness]); @@ -226,24 +243,13 @@ impl Config { // Covector update — sl' from Completeness proof (p.55-56) let eval_points = lift(self.source.embedding(), &source_evaluations.points); scalar_mul(covector, original_sl_coeff); - if self.message_mask_length == 0 { - // Non-ZK: single accumulate over all points - let all_points: Vec<_> = ood_points.iter().chain(&eval_points).copied().collect(); - let pows: Vec<_> = ood_rlc_coeffs - .iter() - .chain(in_domain_rlc_coeffs) - .copied() - .collect(); - geometric_accumulate(covector, pows, &all_points); - } else { - // ZK: OOD contributes to full [f; r; s], in-domain only to [f; r] - geometric_accumulate(covector, ood_rlc_coeffs.to_vec(), &ood_points); - geometric_accumulate( - &mut covector[..self.source.masked_message_length()], - in_domain_rlc_coeffs.to_vec(), - &eval_points, - ); - } + self.update_covector( + covector, + ood_rlc_coeffs, + &ood_points, + in_domain_rlc_coeffs, + &eval_points, + ); Witness { message, @@ -251,6 +257,67 @@ impl Config { } } + /// Send OOD answers `y_i = f(α_i) [+ α_i^ℓ · (r ‖ s)(α_i)]`. + /// In Standard mode the bracketed term is omitted. + fn maybe_send_ood_answers( + &self, + prover_state: &mut ProverState, + message: &[M::Target], + mask: &[M::Target], + ood_points: &[M::Target], + ) where + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + M::Target: Codec<[H::U]>, + { + let msg_len = message.len(); + for &point in ood_points { + let f_eval = univariate_evaluate(message, point); + let answer = match &self.mode { + CodeSwitchMode::Standard => f_eval, + CodeSwitchMode::ZeroKnowledge { .. } => { + let mask_eval = univariate_evaluate(mask, point); + let shift = point.pow([msg_len as u64]); + f_eval + shift * mask_eval + } + }; + prover_state.prover_message(&answer); + } + } + + /// Accumulate OOD and in-domain weights into the covector. + /// Standard mode treats all points uniformly; ZK mode applies OOD over + /// the full `[f; r; s]` and in-domain over the `[f; r]` prefix only. + fn update_covector( + &self, + covector: &mut [M::Target], + ood_rlc_coeffs: &[M::Target], + ood_points: &[M::Target], + in_domain_rlc_coeffs: &[M::Target], + in_domain_points: &[M::Target], + ) { + match &self.mode { + CodeSwitchMode::Standard => { + let all_points: Vec<_> = + ood_points.iter().chain(in_domain_points).copied().collect(); + let pows: Vec<_> = ood_rlc_coeffs + .iter() + .chain(in_domain_rlc_coeffs) + .copied() + .collect(); + geometric_accumulate(covector, pows, &all_points); + } + CodeSwitchMode::ZeroKnowledge { .. } => { + geometric_accumulate(covector, ood_rlc_coeffs.to_vec(), ood_points); + geometric_accumulate( + &mut covector[..self.source.masked_message_length()], + in_domain_rlc_coeffs.to_vec(), + in_domain_points, + ); + } + } + } + /// Verify the code-switch. /// /// `folding_randomness` is the **sumcheck folding randomness `γ`** the @@ -283,13 +350,15 @@ impl Config { verifier_state: &mut VerifierState, sum: &mut M::Target, folding_randomness: &[M::Target], - commitment: &IrsCommitment, - ) -> VerificationResult> + commitment: &IrsCommitment, + ) -> VerificationResult where H: DuplexSpongeInterface, Standard: Distribution, M::Target: Codec<[H::U]>, u8: Decoding<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, Hash: ProverMessage<[H::U]>, { verify!(1 << folding_randomness.len() == self.source.interleaving_depth); @@ -300,6 +369,9 @@ impl Config { // Mask oracle is committed in the shared mask tree by the orchestrator. let target_commitment = self.target.receive_commitment(verifier_state)?; + // Grind Lemma 9.9 OOD gap before α is sampled. + self.pow.verify(verifier_state)?; + // Step 2-3: OOD — Construction 9.7 Steps 2-3, p.55 // In ZK mode, ood_answers = f(α) + α^ℓ · (r,s)(α) where (r,s) is // the mask oracle message committed in the shared tree. @@ -339,7 +411,7 @@ impl fmt::Display for Config { self.source, self.target, self.out_domain_samples, - self.message_mask_length != 0, + self.is_zk(), ) } } @@ -372,27 +444,25 @@ mod tests { select(valid_sizes), 0_usize..=3, // src_mask_len (source IRS randomness, post-fold) bool::ANY, // zk - 0_usize..=5, // ood (= code-switch t_ood) + 1_usize..=5, // ood (= code-switch t_ood; ≥ 1 per Construction 9.7) 0_usize..=5, // fresh_s_len (≥ ood for assumption (c)) select(vec![1_usize, 2, 4]), // ι_s (source interleaving) 0_usize..=10, // target.in_domain_samples (t'_in) - 0_usize..=10, // target.out_domain_samples (t'_out) ); scalars.prop_flat_map( - move |(size, src_mask_len, zk, ood, fresh_s_len, iota_s, t_in, t_out)| { + move |(size, src_mask_len, zk, ood, fresh_s_len, iota_s, t_in)| { // Bound 3 assumption (c): ℓ_zk - r ≥ t_ood ⇒ fresh_s_len ≥ ood. + // Also enforce `ℓ_zk = r + fresh_s_len > 0` so NonZeroUsize + // construction below is total in ZK mode. let fresh_s_len = if zk { - fresh_s_len.max(ood) + let min_fresh = usize::from(src_mask_len == 0); + fresh_s_len.max(ood).max(min_fresh) } else { fresh_s_len }; - // Bound 4 assumption (a): target.mask_length ≥ t' = t_in + t_out. - let target_mask = if zk { t_in + t_out } else { 0 }; - // ZK with source.mask_length = 0 is valid: the assert - // `source.mask_length == 0 || message_mask_length > 0` - // is trivially satisfied. Allows testing the corner - // where the mask oracle has only fresh randomness. + // Bound 4 assumption (a): target.mask_length ≥ t' = t_in + ood. + let target_mask = if zk { t_in + ood } else { 0 }; let source_mask = if zk { src_mask_len } else { 0 }; IrsConfig::arbitrary(embedding.clone(), 1, size, source_mask, iota_s) @@ -416,19 +486,31 @@ mod tests { ); let source = source.clone(); target.prop_map(move |mut target| { - // IrsConfig::arbitrary samples query counts in - // [0,10] independently of mask_length; pin them - // to the values target_mask was sized for so + // IrsConfig::arbitrary samples in_domain_samples + // in [0,10] independently of mask_length; pin it + // to the value target_mask was sized for so // assumption (a) holds. if zk { target.in_domain_samples = t_in; - target.out_domain_samples = t_out; } // r = post-fold randomness length (ι_s parallel // masks fold to a single length-mask_length chunk). - let r = source.mask_length; - let message_mask_length = if zk { r + fresh_s_len } else { 0 }; - Self::new(source.clone(), target, ood, message_mask_length) + let r = source.mask_length(); + let mode = if zk { + CodeSwitchMode::ZeroKnowledge { + message_mask_length: NonZeroUsize::new(r + fresh_s_len) + .expect("ZK ⇒ r + fresh_s_len > 0"), + } + } else { + CodeSwitchMode::Standard + }; + Self::new( + source.clone(), + target, + ood, + mode, + proof_of_work::Config::none(), + ) }) }) }) @@ -479,29 +561,21 @@ mod tests { where Standard: Distribution, { - if config.message_mask_length == 0 { + if !config.is_zk() { return Vec::new(); } // Lift ι parallel masks (total length source.mask_length × ι) and fold // chunks of length source.mask_length down to a single chunk. let raw = lift(config.source.embedding(), &source_witness.masks); - let mut mask = fold_chunks(&raw, config.source.mask_length, folding_randomness); + let mut mask = fold_chunks(&raw, config.source.mask_length(), folding_randomness); // Append fresh padding s of length message_mask_length - source.mask_length. mask.extend(random_vector::( rng, - config.message_mask_length - mask.len(), + config.message_mask_length() - mask.len(), )); mask } - fn mask_input(mask_msg: &[F]) -> MaskInput<'_, F> { - if mask_msg.is_empty() { - MaskInput::Disabled - } else { - MaskInput::Enabled(mask_msg) - } - } - fn test_config>(seed: u64, config: &Config>) where Standard: Distribution, @@ -536,7 +610,7 @@ mod tests { &source_witness, &mut covector, &folding_randomness, - &mask_input(&mask_msg), + &mask_msg, ); let proof = prover_state.proof(); @@ -602,7 +676,7 @@ mod tests { &source_witness, &mut covector, &folding_randomness, - &mask_input(&mask_msg), + &mask_msg, ); let proof = prover_state.proof(); @@ -659,7 +733,7 @@ mod tests { &source_witness, &mut covector, &folding_randomness, - &MaskInput::Disabled, + &[], ); let proof = prover_state.proof(); @@ -743,14 +817,10 @@ mod tests { #[test] fn test_tampered_ood() { crate::tests::init(); - let configs = Config::arbitrary(Identity::::new()).prop_filter( - "non-ZK with ood > 0", - |config| { - config.message_mask_length == 0 - && config.source.mask_length == 0 - && config.out_domain_samples > 0 - }, - ); + let configs = Config::arbitrary(Identity::::new()) + .prop_filter("non-ZK", |config| { + !config.is_zk() && config.source.mask_length() == 0 + }); proptest!(|(seed: u64, config in configs)| { test_tampered_ood_config(seed, &config); }); diff --git a/src/protocols/irs_commit.rs b/src/protocols/irs_commit.rs index b5716f3a..2c137fde 100644 --- a/src/protocols/irs_commit.rs +++ b/src/protocols/irs_commit.rs @@ -7,25 +7,14 @@ //! using an NTT friendly Reed-Solomon code to produce a `num_vectors * interleaving_depth` //! by `codeword_size` matrix. This matrix is committed using the [`matrix_commit`] protocol. //! -//! After committing the encoded matrix, the protocol generates a random Reed-Solomon code of -//! length `out_domain_samples` over an extension field `G` of `F` and encodes the original -//! matrix using this code to produce a `num_vectors` by `out_domain_samples` matrix over `G`. -//! Together, these two encoded matrices form a commitment to the original matrix. -//! //! On opening the commitment, the protocol randomly selects `in_domain_samples` rows and opens -//! it using the [`matrix_commit`] protocol. Sampling is done with replacement, so may produce -//! fewer than `in_domain_samples` distinct rows. This produces `in_domain_samples` evaluation -//! points in `F` and `in_domain_samples` by `num_vectors * interleaving_depth`. +//! them using the [`matrix_commit`] protocol. Sampling is done with replacement, so may produce +//! fewer than `in_domain_samples` distinct rows. //! -use std::{ - f64::{self, consts::LOG2_10}, - fmt, - ops::Neg, -}; +use std::{f64, fmt, num::NonZeroUsize}; use ark_ff::{AdditiveGroup, Field}; use ark_std::rand::{distributions::Standard, prelude::Distribution, CryptoRng, RngCore}; -use ordered_float::OrderedFloat; use serde::{Deserialize, Serialize}; #[cfg(feature = "tracing")] use tracing::instrument; @@ -37,17 +26,37 @@ use crate::{ }, engines::EngineId, hash::Hash, - protocols::{challenge_indices::challenge_indices, matrix_commit}, + protocols::{ + challenge_indices::challenge_indices, + matrix_commit, + params::{bounds::ood_per_sample_log2, regime::DecodingRegimeParams, spec::DecodingRegime}, + }, transcript::{ Codec, Decoding, DuplexSpongeInterface, ProverMessage, ProverState, VerificationResult, VerifierMessage, VerifierState, }, type_info::Typed, utils::{chunks_exact_or_empty, zip_strict}, - verify, }; +#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] +pub enum IrsMode { + Standard, + ZeroKnowledge { mask_length: NonZeroUsize }, +} + +impl IrsMode { + /// Per-polynomial IRS randomness length. Returns 0 in Standard mode. + pub const fn mask_length(&self) -> usize { + match self { + Self::Standard => 0, + Self::ZeroKnowledge { mask_length } => mask_length.get(), + } + } +} + /// Commit to vectors over an fft-friendly field F +#[must_use] #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] #[serde(bound = "")] pub struct Config { @@ -60,9 +69,6 @@ pub struct Config { /// The number of coefficients in each vector. pub vector_size: usize, - /// The number of masking values to add per codeword. - pub mask_length: usize, - /// The number of Reed-Solomon evaluation points. pub codeword_length: usize, @@ -72,41 +78,35 @@ pub struct Config { /// The matrix commitment configuration. pub matrix_commit: matrix_commit::Config, - /// Slack to the Jonhnson bound in list decoding. - /// Zero indicates unique decoding. - pub johnson_slack: OrderedFloat, + /// Materialized Reed–Solomon decoding regime (Unique / Johnson w/ slack). + pub regime: DecodingRegimeParams, /// The number of in-domain samples. pub in_domain_samples: usize, - /// The number of out-of-domain samples. - pub out_domain_samples: usize, - /// Whether to sort and deduplicate the in-domain samples. /// /// Deduplication can slightly reduce proof size and prover/verifier /// complexity, but it makes transcript pattern and control flow /// non-deterministic. pub deduplicate_in_domain: bool, + + /// Standard / ZeroKnowledge. + pub mode: IrsMode, } #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Default, Serialize, Deserialize)] #[must_use] -pub struct Witness -where - G: Field, -{ +pub struct Witness { pub masks: Vec, pub matrix: Vec, pub matrix_witness: matrix_commit::Witness, - pub out_of_domain: Evaluations, } #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Default, Serialize, Deserialize)] #[must_use] -pub struct Commitment { - matrix_commitment: matrix_commit::Commitment, - out_of_domain: Evaluations, +pub struct Commitment { + pub matrix_commitment: matrix_commit::Commitment, } /// Interleaved Reed-Solomon code. @@ -122,65 +122,39 @@ pub struct Evaluations { } impl Config { + #[allow(clippy::too_many_arguments)] pub fn new( security_target: f64, - unique_decoding: bool, + decoding_regime: DecodingRegime, hash_id: EngineId, num_vectors: usize, vector_size: usize, interleaving_depth: usize, rate: f64, + mode: IrsMode, ) -> Self where M: Default, { assert!(vector_size.is_multiple_of(interleaving_depth)); assert!(rate > 0. && rate <= 1.); - let message_length = vector_size / interleaving_depth; + let masked_message_length = vector_size / interleaving_depth + mode.mask_length(); + // `interleaved_encode` requires `codeword_length` to divide the NTT root + // order. `masked_message_length` is allowed to be arbitrary (the coset + // NTT zero-extends internally), so we only round the codeword side here. #[allow(clippy::cast_sign_loss)] - let codeword_length = (message_length as f64 / rate).ceil() as usize; - let rate = message_length as f64 / codeword_length as f64; - - // Pick in- and out-of-domain samples. - // η = slack to Johnson bound. We pick η = √ρ / 20. - // TODO: Optimize picking η. - let johnson_slack = if unique_decoding { - 0.0 - } else { - rate.sqrt() / 20. - }; - let out_domain_samples = { - let list_size = 1. / (2. * johnson_slack * rate.sqrt()); - num_ood_samples( - unique_decoding, - security_target, - M::Target::field_size_bits(), - list_size, - vector_size, - ) - }; - #[allow(clippy::cast_sign_loss)] - let in_domain_samples = { - // Query error is (1 - δ)^q, so we compute 1 - δ - let per_sample = if unique_decoding { - // Unique decoding bound: δ = (1 - ρ) / 2 - f64::midpoint(1., rate) - } else { - // Johnson bound: δ = 1 - √ρ - η - rate.sqrt() + johnson_slack - }; - (security_target / (-per_sample.log2())).ceil() as usize - }; - debug_assert_eq!( - in_domain_samples, - num_in_domain_queries(unique_decoding, security_target, rate) - ); + let raw_codeword_length = (masked_message_length as f64 / rate).ceil() as usize; + let codeword_length = ntt::next_order::(raw_codeword_length) + .expect("codeword length exceeds NTT engine support"); + let rate = masked_message_length as f64 / codeword_length as f64; + + let regime = DecodingRegimeParams::from_policy(decoding_regime, rate); + let in_domain_samples = num_in_domain_queries(decoding_regime, security_target, rate).get(); Self { embedding: Typed::::default(), num_vectors, vector_size, - mask_length: 0, codeword_length, interleaving_depth, matrix_commit: matrix_commit::Config::with_hash( @@ -188,10 +162,10 @@ impl Config { codeword_length, interleaving_depth * num_vectors, ), - johnson_slack: OrderedFloat(johnson_slack), + regime, in_domain_samples, - out_domain_samples, deduplicate_in_domain: false, + mode, } } @@ -216,9 +190,14 @@ impl Config { self.vector_size / self.interleaving_depth } + /// Per-polynomial IRS randomness length. Returns 0 in Standard mode. + pub const fn mask_length(&self) -> usize { + self.mode.mask_length() + } + /// Message length including mask coefficients. pub fn masked_message_length(&self) -> usize { - self.message_length() + self.mask_length + self.message_length() + self.mask_length() } pub fn evaluation_points(&self, indices: &[usize]) -> Vec { @@ -233,58 +212,34 @@ impl Config { self.masked_message_length() as f64 / self.codeword_length as f64 } - pub fn unique_decoding(&self) -> bool { - self.out_domain_samples == 0 && self.johnson_slack == 0.0 + pub const fn unique_decoding(&self) -> bool { + self.regime.is_unique() } - /// Compute a list size bound. - pub fn list_size(&self) -> f64 { - if self.unique_decoding() { - 1. - } else { - // This is the Johnson bound $1 / (2 η √ρ)$. - 1. / (2. * self.johnson_slack.into_inner() * self.rate().sqrt()) - } + fn log_inv_rate(&self) -> f64 { + -self.rate().log2() } - /// Round-by-round soundness of the out-of-domain samples in bits. - pub fn rbr_ood_sample(&self) -> f64 { - let list_size = self.list_size(); - let log_field_size = M::Target::field_size_bits(); - // See [STIR] lemma 4.5. - let l_choose_2 = list_size * (list_size - 1.) / 2.; - let log_per_sample = ((self.vector_size - 1) as f64).log2() - log_field_size; - -l_choose_2.log2() - self.out_domain_samples as f64 * log_per_sample + /// Compute a list size bound. + pub fn list_size(&self) -> f64 { + let log_degree = (self.masked_message_length() as f64).log2(); + self.regime.list_size(log_degree, self.log_inv_rate()) } /// Round-by-round soundness of the in-domain queries in bits. pub fn rbr_queries(&self) -> f64 { - let per_sample = if self.unique_decoding() { - // 1 - δ = 1 - (1 + ρ) / 2 - f64::midpoint(1., self.rate()) - } else { - // 1 - δ = sqrt(ρ) + η - self.rate().sqrt() + self.johnson_slack.into_inner() - }; - self.in_domain_samples as f64 * per_sample.log2().neg() + // Query error is (1 - δ)^q in bits = -q · log2(1 - δ). + -(self.in_domain_samples as f64) * self.regime.one_minus_distance_log2(self.log_inv_rate()) } - // Compute the proximity gaps term of the fold + /// Round-by-round soundness of the proximity-gaps fold in bits. + /// See WHIR Theorem 4.8. pub fn rbr_soundness_fold_prox_gaps(&self) -> f64 { - let log_field_size = M::Target::field_size_bits(); - let log_inv_rate = self.rate().log2().neg(); - let log_k = (self.masked_message_length() as f64).log2(); - // See WHIR Theorem 4.8 - // Recall, at each round we are only folding by two at a time - let error = if self.unique_decoding() { - log_k + log_inv_rate - } else { - let log_eta = self.johnson_slack.into_inner().log2(); - // Make sure η hits the min bound. - assert!(log_eta >= -(0.5 * log_inv_rate + LOG2_10 + 1.0) - 1e-6); - 7. * LOG2_10 + 3.5 * log_inv_rate + 2. * log_k - }; - log_field_size - error + -self.regime.eps_mca_log2( + self.log_inv_rate(), + self.masked_message_length(), + M::Target::field_size_bits(), + ) } /// Commit to one or more vectors. @@ -293,7 +248,7 @@ impl Config { &self, prover_state: &mut ProverState, vectors: &[&[M::Source]], - ) -> Witness + ) -> Witness where Standard: Distribution, H: DuplexSpongeInterface, @@ -311,7 +266,7 @@ impl Config { assert!(vectors.iter().all(|p| p.len() == self.vector_size)); // Generate random mask - let masks = random_vector(prover_state.rng(), self.mask_length * self.num_messages()); + let masks = random_vector(prover_state.rng(), self.mask_length() * self.num_messages()); // Interleaved RS Encode the vectors let messages = vectors @@ -323,29 +278,10 @@ impl Config { // Commit to the matrix let matrix_witness = self.matrix_commit.commit(prover_state, &matrix); - // Handle out-of-domain points and values - // TODO : Remove this logic after main whir protocol is updated - // as this is not required in the new construction. This will be - // removed in next PR (Parameter Selection) - let oods_points: Vec = - prover_state.verifier_message_vec(self.out_domain_samples); - let mut oods_matrix = Vec::with_capacity(self.out_domain_samples * self.num_vectors); - for &point in &oods_points { - for &vector in vectors { - let value = mixed_univariate_evaluate(&*self.embedding, vector, point); - prover_state.prover_message(&value); - oods_matrix.push(value); - } - } - Witness { masks, matrix, matrix_witness, - out_of_domain: Evaluations { - points: oods_points, - matrix: oods_matrix, - }, } } @@ -354,24 +290,67 @@ impl Config { pub fn receive_commitment( &self, verifier_state: &mut VerifierState, - ) -> VerificationResult> + ) -> VerificationResult where H: DuplexSpongeInterface, Hash: ProverMessage<[H::U]>, M::Target: Codec<[H::U]>, { let matrix_commitment = self.matrix_commit.receive_commitment(verifier_state)?; - let oods_points: Vec = - verifier_state.verifier_message_vec(self.out_domain_samples); - let oods_matrix = - verifier_state.prover_messages_vec(self.out_domain_samples * self.num_vectors)?; - Ok(Commitment { - matrix_commitment, - out_of_domain: Evaluations { - points: oods_points, - matrix: oods_matrix, - }, - }) + Ok(Commitment { matrix_commitment }) + } + + /// Commit to vectors and run the legacy WHIR OOD step in one call. + /// + /// Layered helper bundling `commit` + the OOD message exchange (sample + /// `out_domain_samples` random points, send each vector's evaluation at + /// each point). Used by the legacy WHIR protocol while the OOD step + /// is still part of the per-commit protocol shape; the new construction + /// (Construction 9.7) handles OOD at the code-switch level instead. + #[cfg_attr(feature = "tracing", instrument(skip_all, fields(self = %self)))] + pub fn commit_with_ood( + &self, + prover_state: &mut ProverState, + vectors: &[&[M::Source]], + out_domain_samples: usize, + ) -> (Witness, Evaluations) + where + Standard: Distribution, + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + M::Target: Codec<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + let witness = self.commit(prover_state, vectors); + let points: Vec = prover_state.verifier_message_vec(out_domain_samples); + let mut matrix = Vec::with_capacity(out_domain_samples * vectors.len()); + for &point in &points { + for &vector in vectors { + let value = mixed_univariate_evaluate(&*self.embedding, vector, point); + prover_state.prover_message(&value); + matrix.push(value); + } + } + (witness, Evaluations { points, matrix }) + } + + /// Receive a commitment and the legacy WHIR OOD evaluations in one call. + /// Verifier mirror of `commit_with_ood`. + #[cfg_attr(feature = "tracing", instrument(skip_all, fields(self = %self)))] + pub fn receive_commitment_with_ood( + &self, + verifier_state: &mut VerifierState, + out_domain_samples: usize, + ) -> VerificationResult<(Commitment, Evaluations)> + where + H: DuplexSpongeInterface, + Hash: ProverMessage<[H::U]>, + M::Target: Codec<[H::U]>, + { + let commitment = self.receive_commitment(verifier_state)?; + let points: Vec = verifier_state.verifier_message_vec(out_domain_samples); + let matrix = verifier_state.prover_messages_vec(out_domain_samples * self.num_vectors)?; + Ok((commitment, Evaluations { points, matrix })) } /// Opens the commitment and returns the evaluations of the vectors. @@ -385,7 +364,7 @@ impl Config { pub fn open( &self, prover_state: &mut ProverState, - witnesses: &[&Witness], + witnesses: &[&Witness], ) -> Evaluations where H: DuplexSpongeInterface, @@ -395,11 +374,6 @@ impl Config { { for witness in witnesses { assert_eq!(witness.matrix.len(), self.size()); - assert_eq!(witness.out_of_domain.points.len(), self.out_domain_samples); - assert_eq!( - witness.out_of_domain.matrix.len(), - self.out_domain_samples * self.num_vectors - ); } // Get in-domain openings @@ -438,20 +412,13 @@ impl Config { pub fn verify( &self, verifier_state: &mut VerifierState, - commitments: &[&Commitment], + commitments: &[&Commitment], ) -> VerificationResult> where H: DuplexSpongeInterface, u8: Decoding<[H::U]>, Hash: ProverMessage<[H::U]>, { - for commitment in commitments { - verify!(commitment.out_of_domain.points.len() == self.out_domain_samples); - verify!( - commitment.out_of_domain.matrix.len() == self.num_vectors * self.out_domain_samples - ); - } - // Get in-domain openings let (indices, points) = self.in_domain_challenges(verifier_state); @@ -500,28 +467,6 @@ impl Config { } } -impl Commitment { - /// Returns the out-of-domain evaluations. - pub const fn out_of_domain(&self) -> &Evaluations { - &self.out_of_domain - } - - pub fn num_vectors(&self) -> usize { - self.out_of_domain().num_columns() - } -} - -impl Witness { - /// Returns the out-of-domain evaluations. - pub const fn out_of_domain(&self) -> &Evaluations { - &self.out_of_domain - } - - pub fn num_vectors(&self) -> usize { - self.out_of_domain().num_columns() - } -} - impl Evaluations { pub const fn num_points(&self) -> usize { self.points.len() @@ -568,11 +513,7 @@ impl fmt::Display for Config { self.num_vectors, self.vector_size, self.interleaving_depth, )?; write!(f, " rate 2⁻{:.2}", -self.rate().log2())?; - write!( - f, - " samples {} in- {} out-domain", - self.in_domain_samples, self.out_domain_samples - ) + write!(f, " samples {} in-domain", self.in_domain_samples) } } @@ -582,19 +523,19 @@ impl fmt::Display for Config { /// where `L` is the list size and `degree` is the polynomial degree bound. /// See [STIR] Lemma 4.5. #[allow(clippy::cast_sign_loss)] -pub(crate) fn num_ood_samples( - unique_decoding: bool, +pub fn num_ood_samples( + decoding_regime: DecodingRegime, security_target: f64, field_size_bits: f64, list_size: f64, degree: usize, ) -> usize { - if unique_decoding { + if matches!(decoding_regime, DecodingRegime::Unique) { return 0; } - let l_choose_2 = list_size * (list_size - 1.) / 2.; - let log_per_sample = field_size_bits - ((degree - 1) as f64).log2(); + let log_per_sample = -ood_per_sample_log2(degree, field_size_bits); assert!(log_per_sample > 0.); + let l_choose_2 = list_size * (list_size - 1.) / 2.; ((security_target + l_choose_2.log2()) / log_per_sample) .ceil() .max(1.) as usize @@ -602,31 +543,22 @@ pub(crate) fn num_ood_samples( /// Return the number of in-domain queries. /// +/// Always ≥ 1 — the type carries that invariant so callers don't need to +/// re-prove it locally. +/// /// This is used by [`whir_zk`]. // TODO: A method with cleaner abstraction. #[allow(clippy::cast_sign_loss)] pub(crate) fn num_in_domain_queries( - unique_decoding: bool, + decoding_regime: DecodingRegime, security_target: f64, rate: f64, -) -> usize { - // Pick in- and out-of-domain samples. - // η = slack to Johnson bound. We pick η = √ρ / 20. - // TODO: Optimize picking η. - let johnson_slack = if unique_decoding { - 0.0 - } else { - rate.sqrt() / 20. - }; - // Query error is (1 - δ)^q, so we compute 1 - δ - let per_sample = if unique_decoding { - // Unique decoding bound: δ = (1 - ρ) / 2 - f64::midpoint(1., rate) - } else { - // Johnson bound: δ = 1 - √ρ - η - rate.sqrt() + johnson_slack - }; - (security_target / (-per_sample.log2())).ceil() as usize +) -> NonZeroUsize { + let regime = DecodingRegimeParams::from_policy(decoding_regime, rate); + // Query error is (1 - δ)^q in bits = -q · log2(1 - δ). + let log_one_minus_delta = regime.one_minus_distance_log2(-rate.log2()); + let q = (security_target / -log_one_minus_delta).ceil() as usize; + NonZeroUsize::new(q).unwrap_or(NonZeroUsize::MIN) } #[cfg(test)] @@ -683,24 +615,27 @@ pub(crate) mod tests { ) }); - (codeword_matrix, 0_usize..=10, 0_usize..=10, bool::ANY).prop_map( + (codeword_matrix, 0_usize..=10, bool::ANY).prop_map( move |( (codeword_length, matrix_commit), in_domain_samples, - out_domain_samples, - deduplicate_in_domain, - )| Self { - embedding: Typed::new(embedding.clone()), - num_vectors, - vector_size, - mask_length, - codeword_length, - interleaving_depth, - matrix_commit, - johnson_slack: OrderedFloat::default(), - in_domain_samples, - out_domain_samples, deduplicate_in_domain, + )| { + let mode = NonZeroUsize::new(mask_length).map_or(IrsMode::Standard, |n| { + IrsMode::ZeroKnowledge { mask_length: n } + }); + Self { + embedding: Typed::new(embedding.clone()), + num_vectors, + vector_size, + codeword_length, + interleaving_depth, + matrix_commit, + regime: DecodingRegimeParams::Unique, + in_domain_samples, + deduplicate_in_domain, + mode, + } }, ) } @@ -732,30 +667,6 @@ pub(crate) mod tests { &mut prover_state, &vectors.iter().map(|p| p.as_slice()).collect::>(), ); - assert_eq!( - witness.out_of_domain().points.len(), - config.out_domain_samples - ); - assert_eq!( - witness.out_of_domain().matrix.len(), - config.out_domain_samples * config.num_vectors - ); - if config.num_vectors > 0 { - for (point, evals) in zip_strict( - witness.out_of_domain().points.iter(), - witness - .out_of_domain() - .matrix - .chunks_exact(config.num_vectors), - ) { - for (vector, expected) in zip_strict(vectors.iter(), evals.iter()) { - assert_eq!( - mixed_univariate_evaluate(config.embedding(), vector, *point), - *expected - ); - } - } - } let in_domain_evals = config.open(&mut prover_state, &[&witness]); if config.deduplicate_in_domain { // Sorting is over index order, not points @@ -773,7 +684,11 @@ pub(crate) mod tests { in_domain_evals.matrix.len(), in_domain_evals.points.len() * config.num_vectors * config.interleaving_depth ); - if config.num_vectors > 0 { + // Value-correctness assertion only valid in non-ZK mode: in ZK the + // encoding is `Enc(f, r) = f(x) + x^ℓ · r(x)`, so opened values + // include the mask term. The lifecycle round-trip (open/verify + // agreement below) covers both modes. + if config.num_vectors > 0 && config.mask_length() == 0 { let base = config.vector_size / config.interleaving_depth; for (point, evals) in zip_strict( &in_domain_evals.points, @@ -799,7 +714,6 @@ pub(crate) mod tests { // Verifier let mut verifier_state = VerifierState::new_std(&ds, &proof); let commitment = config.receive_commitment(&mut verifier_state).unwrap(); - assert_eq!(commitment.out_of_domain(), witness.out_of_domain()); let verifier_in_domain_evals = config.verify(&mut verifier_state, &[&commitment]).unwrap(); assert_eq!(&verifier_in_domain_evals, &in_domain_evals); verifier_state.check_eof().unwrap(); @@ -816,13 +730,13 @@ pub(crate) mod tests { .collect::>(); let size = select(valid_sizes); - let config = (0_usize..=3, size, 1_usize..=10).prop_flat_map( - |(num_vectors, size, interleaving_depth)| { + let config = (0_usize..=3, size, 1_usize..=10, 0_usize..=8).prop_flat_map( + |(num_vectors, size, interleaving_depth, mask_length)| { Config::arbitrary( embedding.clone(), num_vectors, size * interleaving_depth, - 0, + mask_length, interleaving_depth, ) }, diff --git a/src/protocols/mask_proximity.rs b/src/protocols/mask_proximity.rs index 1d632ca0..0be304db 100644 --- a/src/protocols/mask_proximity.rs +++ b/src/protocols/mask_proximity.rs @@ -47,12 +47,13 @@ use serde::{Deserialize, Serialize}; use crate::{ algebra::{embedding::Identity, random_vector, scalar_mul_add_new, univariate_evaluate}, hash::Hash, - protocols::irs_commit::{ - Commitment as IrsCommitment, Config as IrsConfig, Witness as IrsWitness, + protocols::{ + irs_commit::{Commitment as IrsCommitment, Config as IrsConfig, Witness as IrsWitness}, + proof_of_work, }, transcript::{ - Codec, Decoding, DuplexSpongeInterface, ProverMessage, ProverState, VerificationResult, - VerifierMessage, VerifierState, + codecs::U64, Codec, Decoding, DuplexSpongeInterface, ProverMessage, ProverState, + VerificationResult, VerifierMessage, VerifierState, }, utils::zip_strict, verify, @@ -61,11 +62,16 @@ use crate::{ /// Mask proximity configuration. /// /// Wraps an IRS config for the shared mask tree and the number of mask pairs. +#[must_use] #[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] #[serde(bound = "")] pub struct Config { pub c_zk_commit: IrsConfig>, pub num_masks: usize, + pub pow: proof_of_work::Config, + /// γ-combination analytic floor recorded by `params::mask_proximity::solve`. + /// `None` for ad-hoc construction. + pub recorded_analytic: Option, } /// Prover output from the commit phase. @@ -76,33 +82,42 @@ pub struct Witness { } /// Verifier output from the commit phase. -pub type Commitment = IrsCommitment; +pub type Commitment = IrsCommitment; impl Config { - pub fn new(c_zk_commit: IrsConfig>, num_masks: usize) -> Self { + /// Required `c_zk.num_vectors` for `num_masks` originals: one fresh + /// mask-of-mask per original (Construction 7.2 originals + fresh pairs). + pub const fn num_vectors_for(num_masks: usize) -> usize { + 2 * num_masks + } + + pub fn new( + c_zk_commit: IrsConfig>, + num_masks: usize, + pow: proof_of_work::Config, + ) -> Self { assert_eq!( c_zk_commit.num_vectors, - 2 * num_masks, + Self::num_vectors_for(num_masks), "c_zk.num_vectors must be 2 * num_masks" ); assert_eq!( c_zk_commit.interleaving_depth, 1, "mask proximity requires interleaving_depth = 1" ); - // OOD evaluations are sent in the clear during IRS commit/receive, - // which would leak raw mask values before the γ-combination and - // break the ZK contract. The OOD path in irs_commit is slated for - // removal in the new construction; until then, enforce zero here. - assert_eq!( - c_zk_commit.out_domain_samples, 0, - "mask proximity requires out_domain_samples = 0 (OOD openings would leak raw mask evaluations)" - ); Self { c_zk_commit, num_masks, + pow, + recorded_analytic: None, } } + pub const fn with_recorded_analytic(mut self, analytic: crate::bits::Bits) -> Self { + self.recorded_analytic = Some(analytic); + self + } + /// Commit all masks and their mask-of-masks in a single shared tree. /// /// Samples n fresh mask-of-mask polynomials, combines them with the @@ -148,7 +163,7 @@ impl Config { pub fn receive_commitment( &self, verifier_state: &mut VerifierState, - ) -> VerificationResult> + ) -> VerificationResult where F: Codec<[H::U]>, H: DuplexSpongeInterface, @@ -169,17 +184,22 @@ impl Config { R: RngCore + CryptoRng, Standard: Distribution, u8: Decoding<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, Hash: ProverMessage<[H::U]>, { assert_eq!(original_msgs.len(), self.num_masks); assert_eq!(witness.fresh_msgs.len(), self.num_masks); + // Grind the Lemma 7.4 γ-combination gap before γ is sampled. + self.pow.prove(prover_state); + // Step 1: receive combination randomness γ let gamma: F = prover_state.verifier_message(); // Step 2: compute and send combined polynomials + IRS randomness let irs_masks_per_vector = - self.c_zk_commit.mask_length * self.c_zk_commit.interleaving_depth; + self.c_zk_commit.mask_length() * self.c_zk_commit.interleaving_depth; assert_eq!( witness.mask_witness.masks.len(), 2 * self.num_masks * irs_masks_per_vector @@ -214,21 +234,26 @@ impl Config { pub fn verify( &self, verifier_state: &mut VerifierState, - commitment: &Commitment, + commitment: &Commitment, ) -> VerificationResult<()> where F: Codec<[H::U]>, H: DuplexSpongeInterface, u8: Decoding<[H::U]>, + [u8; 32]: Decoding<[H::U]>, + U64: Codec<[H::U]>, Hash: ProverMessage<[H::U]>, { + // Grind the Lemma 7.4 γ-combination gap before γ is sampled. + self.pow.verify(verifier_state)?; + // Step 1: send combination randomness γ let gamma: F = verifier_state.verifier_message(); // Step 2: read combined polynomials + IRS randomness let msg_len = self.c_zk_commit.message_length(); let irs_masks_per_vector = - self.c_zk_commit.mask_length * self.c_zk_commit.interleaving_depth; + self.c_zk_commit.mask_length() * self.c_zk_commit.interleaving_depth; let has_irs_masks = irs_masks_per_vector > 0; let mut combined_msgs = Vec::with_capacity(self.num_masks); let mut combined_rs: Option>> = @@ -313,18 +338,16 @@ mod tests { .prop_flat_map(|(num_masks, vector_size, mask_length)| { let c_zk = IrsConfig::>::arbitrary( Identity::new(), - 2 * num_masks, + Self::num_vectors_for(num_masks), vector_size, mask_length, 1, ); (Just(num_masks), c_zk) }) - .prop_filter( - "mask proximity requires out_domain_samples = 0", - |(_, c_zk)| c_zk.out_domain_samples == 0, - ) - .prop_map(|(num_masks, c_zk)| Self::new(c_zk, num_masks)) + .prop_map(|(num_masks, c_zk)| { + Self::new(c_zk, num_masks, proof_of_work::Config::none()) + }) } } @@ -454,7 +477,7 @@ mod tests { let gamma: F = prover_state.verifier_message(); let irs_masks_per_vector = - config.c_zk_commit.mask_length * config.c_zk_commit.interleaving_depth; + config.c_zk_commit.mask_length() * config.c_zk_commit.interleaving_depth; for (i, (orig_msg, fresh_msg)) in original_msgs .iter() diff --git a/src/protocols/mod.rs b/src/protocols/mod.rs index 17e58040..64f1e0b9 100644 --- a/src/protocols/mod.rs +++ b/src/protocols/mod.rs @@ -16,6 +16,7 @@ pub mod irs_commit; pub mod mask_proximity; pub mod matrix_commit; pub mod merkle_tree; +pub mod params; pub mod proof_of_work; pub mod sumcheck; pub mod whir; diff --git a/src/protocols/params/basecase.rs b/src/protocols/params/basecase.rs new file mode 100644 index 00000000..fabcbb6c --- /dev/null +++ b/src/protocols/params/basecase.rs @@ -0,0 +1,205 @@ +//! Basecase (Construction 7.2, p.43) parameter selection + γ-combination bound. + +use ark_ff::Field; + +use crate::{ + algebra::{embedding::Identity, fields::FieldWithSize}, + bits::Bits, + protocols::{ + basecase::{self, Config as BasecaseConfig}, + irs_commit::Config as IrsConfig, + params::{ + error::{grind_to_at, DeriveError, Pow}, + irs_commit as irs_params, + spec::{Mode as SpecMode, OodSampleBudget, RoundContext, SecuritySpec}, + sumcheck as sumcheck_params, + }, + proof_of_work::Config as PowConfig, + sumcheck::{self, Config as SumcheckConfig}, + }, +}; + +pub fn solve( + spec: &SecuritySpec, + vector_size: usize, + log_inv_rate: u32, +) -> Result, DeriveError> { + assert!(vector_size > 0, "basecase requires vector_size ≥ 1"); + + let ctx = RoundContext { + vector_size, + log_inv_rate, + folding_factor: 0, + }; + let commit = irs_params::solve(spec, &ctx, OodSampleBudget::ZERO); + + let sumcheck_analytic = sumcheck_params::analytic_error_bits(&commit, None); + let sumcheck_pow = grind_to_at(spec, sumcheck_analytic, Pow::BasecaseSumcheck)?; + let sumcheck = SumcheckConfig::new( + vector_size, + sumcheck_pow, + vector_size.next_power_of_two().trailing_zeros() as usize, + sumcheck::SumcheckMode::Standard, + ) + .with_recorded_analytic(sumcheck_analytic); + + let mode = match spec.mode { + SpecMode::Standard => basecase::BasecaseMode::Standard, + SpecMode::ZeroKnowledge => basecase::BasecaseMode::ZeroKnowledge, + }; + + let (pow, gamma_analytic) = match mode { + basecase::BasecaseMode::Standard => (PowConfig::none(), None), + basecase::BasecaseMode::ZeroKnowledge => { + let a = analytic_error_bits(&commit); + ( + grind_to_at(spec, a, Pow::BasecaseGammaCombination)?, + Some(a), + ) + } + }; + + let mut cfg = BasecaseConfig::new(commit, sumcheck, mode, pow); + if let Some(a) = gamma_analytic { + cfg = cfg.with_recorded_analytic(a); + } + Ok(cfg) +} + +/// γ-combination soundness (Lemma 7.4 combination-randomness slot, paper p.45). +pub fn analytic_error_bits(commit: &IrsConfig>) -> Bits { + let field_bits = F::field_size_bits(); + let log_list = commit.list_size().log2(); + let prox_gaps = commit.rbr_soundness_fold_prox_gaps(); + let poly_id = field_bits - log_list; + Bits::new(prox_gaps.min(poly_id).max(0.0)) +} + +impl BasecaseConfig { + /// Analytic soundness bits (excluding PoW): `min(sumcheck round error, γ-slot error)`. + pub fn analytic_bits(&self) -> Bits { + let sumcheck_term = f64::from(sumcheck_params::analytic_error_bits(&self.commit, None)); + let min_bits = match self.mode { + basecase::BasecaseMode::Standard => sumcheck_term, + basecase::BasecaseMode::ZeroKnowledge => { + sumcheck_term.min(f64::from(analytic_error_bits(&self.commit))) + } + }; + Bits::new(min_bits.max(0.0)) + } +} + +#[cfg(test)] +mod tests { + use proptest::prelude::*; + + use super::*; + use crate::protocols::params::test_utils::{ + arb_standard_spec, arb_zk_spec, assert_close, assert_pow_closes_gap, deterministic_spec, + TestField, TEST_TARGET_RANGE, + }; + + const FIXTURE_VECTOR_SIZE: usize = 16; + const FIXTURE_LOG_INV_RATE: u32 = 2; + + fn arb_dims() -> impl Strategy { + (1u32..=4, 1u32..=3) + } + + #[test] + fn analytic_error_formula() { + use crate::protocols::params::{ + irs_commit as irs_params, + spec::{Mode, OodSampleBudget, RoundContext}, + }; + + let spec = deterministic_spec(Mode::ZeroKnowledge); + let ctx = RoundContext { + vector_size: FIXTURE_VECTOR_SIZE, + log_inv_rate: FIXTURE_LOG_INV_RATE, + folding_factor: 0, + }; + let commit: IrsConfig> = + irs_params::solve(&spec, &ctx, OodSampleBudget::ZERO); + + let got = f64::from(analytic_error_bits(&commit)); + let field_bits = TestField::field_size_bits(); + let log_list = commit.list_size().log2(); + let prox_gaps = commit.rbr_soundness_fold_prox_gaps(); + let poly_id = field_bits - log_list; + let expected = prox_gaps.min(poly_id).max(0.0); + + assert_close(got, expected); + } + + #[test] + fn analytic_error_uses_eps_mca_when_limiting() { + use crate::protocols::params::{ + irs_commit as irs_params, + spec::{Mode, OodSampleBudget, RoundContext}, + }; + + let spec = deterministic_spec(Mode::ZeroKnowledge); + let ctx = RoundContext { + vector_size: FIXTURE_VECTOR_SIZE, + log_inv_rate: 1, + folding_factor: 0, + }; + let commit: IrsConfig> = + irs_params::solve(&spec, &ctx, OodSampleBudget::ZERO); + + let field_bits = TestField::field_size_bits(); + let log_list = commit.list_size().log2(); + let prox_gaps = commit.rbr_soundness_fold_prox_gaps(); + let poly_id = field_bits - log_list; + assert!( + prox_gaps < poly_id, + "fixture wants prox_gaps to bind: prox_gaps {prox_gaps} ≥ poly_id {poly_id}", + ); + + let got = f64::from(analytic_error_bits(&commit)); + assert_close(got, prox_gaps.max(0.0)); + } + + proptest! { + #[test] + fn solve_standard_assembles( + spec in arb_standard_spec(TEST_TARGET_RANGE), + (log_size, log_inv_rate) in arb_dims(), + ) { + let config = solve::(&spec, 1usize << log_size, log_inv_rate).unwrap(); + prop_assert!(matches!(config.mode, basecase::BasecaseMode::Standard)); + prop_assert_eq!(config.commit.interleaving_depth, 1); + prop_assert_eq!(config.commit.num_vectors, 1); + prop_assert_eq!(config.commit.vector_size, config.sumcheck.initial_size); + } + + #[test] + fn solve_zk_assembles( + spec in arb_zk_spec(TEST_TARGET_RANGE), + (log_size, log_inv_rate) in arb_dims(), + ) { + let config = solve::(&spec, 1usize << log_size, log_inv_rate).unwrap(); + prop_assert!(matches!(config.mode, basecase::BasecaseMode::ZeroKnowledge)); + prop_assert!(config.commit.mask_length() > 0); + } + + #[test] + fn pow_closes_gap_to_target_zk( + spec in arb_zk_spec(TEST_TARGET_RANGE), + (log_size, log_inv_rate) in arb_dims(), + ) { + let config = solve::(&spec, 1usize << log_size, log_inv_rate).unwrap(); + assert_pow_closes_gap(&spec, analytic_error_bits(&config.commit), &config.pow); + } + + #[test] + fn standard_mode_has_no_pow( + spec in arb_standard_spec(TEST_TARGET_RANGE), + (log_size, log_inv_rate) in arb_dims(), + ) { + let config = solve::(&spec, 1usize << log_size, log_inv_rate).unwrap(); + prop_assert_eq!(config.pow, PowConfig::none()); + } + } +} diff --git a/src/protocols/params/bounds.rs b/src/protocols/params/bounds.rs new file mode 100644 index 00000000..c8cad942 --- /dev/null +++ b/src/protocols/params/bounds.rs @@ -0,0 +1,40 @@ +//! Regime-agnostic analytic primitives shared across the params solvers. +//! +//! Regime-specific math (Unique / Johnson / Capacity branches) lives on +//! [`super::regime::DecodingRegimeParams`]. + +/// `ρ = 2^-log_inv_rate`. Centralized so the rate formula lives in one place. +pub(super) fn rate(log_inv_rate: f64) -> f64 { + 2_f64.powf(-log_inv_rate) +} + +/// Lossy `usize → f64` for analytic-error formulas. Named so individual call +/// sites can stay terse and intent-tagged. +pub(super) const fn usize_to_f64(x: usize) -> f64 { + x as f64 +} + +/// log2 of the per-OOD-sample Schwartz–Zippel error: `(k-1)/|F|`. +pub fn ood_per_sample_log2(message_length: usize, field_bits: f64) -> f64 { + ((message_length - 1) as f64).log2() - field_bits +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::protocols::params::test_utils::assert_close; + + /// OOD per-sample Schwartz–Zippel: `log₂((k−1) / |F|) = log₂(k−1) − field_bits`. + #[test] + fn ood_per_sample_log2_formula() { + // `k = 129` so `k − 1 = 128 = 2^7` for exact `log2`. + const K: usize = 129; + const FIELD_BITS: f64 = 64.0; + + let got = ood_per_sample_log2(K, FIELD_BITS); + let expected = ((K - 1) as f64).log2() - FIELD_BITS; + assert_close(got, expected); + // (k−1)/|F| < 1 for sane parameters ⇒ log is negative. + assert!(got < 0.0); + } +} diff --git a/src/protocols/params/branch.rs b/src/protocols/params/branch.rs new file mode 100644 index 00000000..0797952f --- /dev/null +++ b/src/protocols/params/branch.rs @@ -0,0 +1,65 @@ +//! Standard-vs-ZK branching used to thread mode through the build pipeline. +//! +//! [`Branch`] is the shared shape: a transient choice between the standard +//! path (no payload) and the zero-knowledge path (payload `T`). Concrete +//! pipeline stages alias it with the payload they carry: +//! +//! - [`RoundBuildMode`] — input to [`super::build_round::build_round_config`]. +//! - [`OodMode`] — input to OOD-bound helpers. +//! - [`SolveMode`] — input to the per-sub-protocol solvers (`sumcheck`, `code_switch`). +//! +//! Sharing one enum gives us a free [`Branch::map`] for stage-to-stage +//! payload conversions, replacing one-off `to_ood_mode`-style helpers. + +use crate::protocols::params::{ + protocol_config::MaskOracleInfo, + spec::{LogInvRate, ZkSpec}, +}; + +/// Standard (no payload) vs. zero-knowledge (payload `T`). +#[derive(Clone, Copy, Debug)] +pub enum Branch { + Standard, + ZeroKnowledge(T), +} + +impl Branch { + pub const fn is_zk(&self) -> bool { + matches!(self, Self::ZeroKnowledge(_)) + } + + /// Transform the ZK payload, leaving `Standard` unchanged. Replaces + /// per-stage `to_*` conversion helpers. + pub fn map(self, f: impl FnOnce(T) -> U) -> Branch { + match self { + Self::Standard => Branch::Standard, + Self::ZeroKnowledge(t) => Branch::ZeroKnowledge(f(t)), + } + } + + pub const fn as_ref(&self) -> Branch<&T> { + match self { + Self::Standard => Branch::Standard, + Self::ZeroKnowledge(t) => Branch::ZeroKnowledge(t), + } + } +} + +/// Payload carried by [`RoundBuildMode::ZeroKnowledge`] — references the +/// `SecuritySpec` (so its lifetime threads through) plus the planner-chosen +/// `C_zk` rate. +#[derive(Clone, Copy, Debug)] +pub struct RoundBuildPayload<'a> { + pub zk_spec: ZkSpec<'a>, + pub c_zk_log_inv_rate: LogInvRate, +} + +/// Mode-dispatch input for [`super::build_round::build_round_config`]. +pub type RoundBuildMode<'a> = Branch>; + +/// Mode flag for the OOD security bound. Payload is the `C_zk` log-inverse +/// rate as `f64` (already coerced for the analytic formula). +pub type OodMode = Branch; + +/// Solver-input mode for the per-round sumcheck and code-switch builders. +pub type SolveMode = Branch; diff --git a/src/protocols/params/build_round.rs b/src/protocols/params/build_round.rs new file mode 100644 index 00000000..420d3ddb --- /dev/null +++ b/src/protocols/params/build_round.rs @@ -0,0 +1,240 @@ +//! Per-round build: turns a [`RoundShape`] into a [`RoundConfig`]. +//! +//! Solves the `t_ood` fix-point, builds source/target IRS configs, and +//! (in ZK) assembles the per-round mask oracle. Consumed by +//! [`super::derive`], which drives the per-round loop. + +use crate::{ + algebra::{ + embedding::{Embedding, Identity}, + fields::FieldWithSize, + }, + protocols::{ + irs_commit::Config as IrsConfig, + mask_proximity::Config as MaskProximityConfig, + params::{ + bounds::usize_to_f64, + branch::{Branch, OodMode, RoundBuildMode, RoundBuildPayload, SolveMode}, + code_switch as code_switch_params, + error::{DeriveError, Pow}, + irs_commit as irs_params, + layout::{round_context, target_context, RoundShape}, + mask_proximity as mask_proximity_params, + protocol_config::{MaskOracleConfig, RoundConfig, RoundMode}, + spec::{ + DecodingRegime, LogInvRate, MaskCodeMessageLen, OodSampleBudget, RoundContext, + SecuritySpec, ZkSpec, + }, + sumcheck as sumcheck_params, + }, + }, +}; + +const T_OOD_MAX_ITER: usize = 32; + +pub(super) fn build_round_config( + spec: &SecuritySpec, + shape: &RoundShape, + mode: RoundBuildMode<'_>, +) -> Result, DeriveError> { + let ctx = round_context(shape); + let ood_mode = mode.map(|p| f64::from(p.c_zk_log_inv_rate.get())); + let (source, t_ood) = solve_round_source::(spec, shape, ood_mode)?; + + let (target_budget, solve_mode, round_mode, mask_oracle) = match mode { + Branch::Standard => ( + OodSampleBudget::ZERO, + SolveMode::Standard, + RoundMode::Standard, + None, + ), + Branch::ZeroKnowledge(RoundBuildPayload { + zk_spec, + c_zk_log_inv_rate, + }) => { + let num_masks = + sumcheck_params::masks_required(&ctx) + code_switch_params::masks_required(); + let mask_oracle = build_mask_oracle::( + zk_spec, + &source, + t_ood, + num_masks, + c_zk_log_inv_rate, + shape.round_index, + )?; + let solve_mode = SolveMode::ZeroKnowledge(mask_oracle.info()); + let round_mode = RoundMode::ZeroKnowledge { + t_ood: OodSampleBudget::new(t_ood), + }; + ( + OodSampleBudget::new(t_ood), + solve_mode, + round_mode, + Some(mask_oracle), + ) + } + }; + + let target: IrsConfig> = + irs_params::solve(spec, &target_context(shape, &source), target_budget); + let sumcheck = sumcheck_params::solve( + spec, + &ctx, + &source, + solve_mode, + Pow::RoundSumcheck { + index: shape.round_index, + }, + )?; + let code_switch = + code_switch_params::solve(spec, source, target, t_ood, solve_mode, shape.round_index)?; + + Ok(RoundConfig::new( + shape.round_index, + sumcheck, + code_switch, + round_mode, + mask_oracle, + )) +} + +fn solve_round_source( + spec: &SecuritySpec, + shape: &RoundShape, + ood_mode: OodMode, +) -> Result<(IrsConfig, usize), DeriveError> { + let src_ctx = round_context(shape); + let target_log_inv_rate = f64::from( + shape + .source_log_inv_rate + .saturating_add(shape.source_folding_factor.saturating_sub(1)), + ); + let target_log_degree = f64::from( + shape + .source_vector_size + .trailing_zeros() + .saturating_sub(shape.source_folding_factor), + ); + let target_list_size = spec + .decoding_regime + .list_size_estimate(target_log_degree, target_log_inv_rate); + solve_t_ood::( + spec, + &src_ctx, + target_list_size, + ood_mode, + shape.round_index, + ) +} + +/// ZK-only: assemble the per-round mask oracle (C_zk codeword + mask-proximity +/// check). +fn build_mask_oracle( + zk_spec: ZkSpec<'_>, + source: &IrsConfig, + t_ood: usize, + num_masks: usize, + c_zk_log_inv_rate: LogInvRate, + round_index: usize, +) -> Result, DeriveError> { + let spec = zk_spec.as_inner(); + let l_zk = compute_l_zk(source, t_ood); + let c_zk: IrsConfig> = irs_params::solve_mask_code( + zk_spec, + l_zk, + source.mask_length(), + c_zk_log_inv_rate, + MaskProximityConfig::::num_vectors_for(num_masks), + ); + let c_zk_list_size_estimate = spec.decoding_regime.list_size_estimate( + (l_zk.get() as f64).log2(), + f64::from(c_zk_log_inv_rate.get()), + ); + debug_assert!( + (c_zk.list_size() - c_zk_list_size_estimate).abs() + < 1e-9 * c_zk_list_size_estimate.max(1.0), + "c_zk.list_size() {} drifted from planner estimate {}", + c_zk.list_size(), + c_zk_list_size_estimate, + ); + let mask_proximity = mask_proximity_params::solve(spec, c_zk.clone(), num_masks, round_index)?; + Ok(MaskOracleConfig::new(c_zk, l_zk, mask_proximity)) +} + +/// `ℓ_zk = next_pow2(r + t_ood)` (Theorem 9.6 + Lemma 9.3). +pub(super) const fn compute_l_zk( + source: &IrsConfig, + t_ood: usize, +) -> MaskCodeMessageLen { + MaskCodeMessageLen::new( + source + .mask_length() + .saturating_add(t_ood) + .next_power_of_two(), + ) +} + +/// Per-round `(source, t_ood)`. +/// +/// Under `Unique`, `t_ood = 1` is pinned (the `log(L·(L−1)/2)` term degenerates +/// when `L = 1`, and Construction 9.7 requires `out_domain_samples ≥ 1`). +/// Otherwise linear search over `t_ood = 1..=T_OOD_MAX_ITER` for the smallest +/// value where [`ood_security_bits_at`] meets `protocol_security_target_bits`. +pub(super) fn solve_t_ood( + spec: &SecuritySpec, + src_ctx: &RoundContext, + target_list_size: f64, + ood_mode: OodMode, + round_index: usize, +) -> Result<(IrsConfig, usize), DeriveError> { + if matches!(spec.decoding_regime, DecodingRegime::Unique) { + let source = irs_params::solve(spec, src_ctx, OodSampleBudget::new(1)); + return Ok((source, 1)); + } + + let security_target = f64::from(spec.protocol_security_target_bits()); + let field_bits = M::Target::field_size_bits(); + + for t_ood in 1..=T_OOD_MAX_ITER { + let source: IrsConfig = irs_params::solve(spec, src_ctx, OodSampleBudget::new(t_ood)); + let bits = + ood_security_bits_at(spec, &source, t_ood, target_list_size, ood_mode, field_bits); + if bits >= security_target { + return Ok((source, t_ood)); + } + } + Err(DeriveError::FixedPointDidNotConverge { round_index }) +} + +/// OOD security bits at candidate `t_ood`, per STIR Lemma 4.5: +/// `bits = t · (|F| − log d) − log(L · (L − 1) / 2) ≈ t·(|F| − log d) − 2·log L + 1`. +fn ood_security_bits_at( + spec: &SecuritySpec, + source: &IrsConfig, + t_ood: usize, + target_list_size: f64, + ood_mode: OodMode, + field_bits: f64, +) -> f64 { + let (log_degree, log_combined_list) = match ood_mode { + Branch::Standard => ( + usize_to_f64(source.message_length()).log2(), + target_list_size.log2(), + ), + Branch::ZeroKnowledge(c_zk_log_inv_rate) => { + let l_zk = source + .mask_length() + .saturating_add(t_ood) + .next_power_of_two(); + let c_zk_list = spec + .decoding_regime + .list_size_estimate(usize_to_f64(l_zk).log2(), c_zk_log_inv_rate); + ( + usize_to_f64(source.message_length().saturating_add(l_zk)).log2(), + (target_list_size * c_zk_list).log2(), + ) + } + }; + let ood = usize_to_f64(t_ood); + ood * (field_bits - log_degree) - 2.0 * log_combined_list + 1.0 +} diff --git a/src/protocols/params/code_switch.rs b/src/protocols/params/code_switch.rs new file mode 100644 index 00000000..b18f3cca --- /dev/null +++ b/src/protocols/params/code_switch.rs @@ -0,0 +1,438 @@ +//! Code-switching IOR (Construction 9.7, p.55) builder + Lemma 9.9 OOD bound. + +use std::num::NonZeroUsize; + +use crate::{ + algebra::{ + embedding::{Embedding, Identity}, + fields::FieldWithSize, + }, + bits::Bits, + protocols::{ + code_switch::{self, Config as CodeSwitchConfig}, + irs_commit::Config as IrsConfig, + params::{ + bounds::usize_to_f64, + branch::SolveMode, + error::{grind_to_at, DeriveError, Pow}, + protocol_config::MaskOracleInfo, + spec::SecuritySpec, + }, + }, +}; + +/// Per-round code-switch builder. +pub fn solve( + spec: &SecuritySpec, + source: IrsConfig, + target: IrsConfig>, + t_ood: usize, + mode: SolveMode, + round_index: usize, +) -> Result, DeriveError> { + let (mask_oracle, output_mode) = match mode { + SolveMode::Standard => (None, code_switch::CodeSwitchMode::Standard), + SolveMode::ZeroKnowledge(mask_oracle) => { + let l_zk = mask_oracle.l_zk.get(); + assert!( + l_zk >= source.mask_length().saturating_add(t_ood), + "ℓ_zk ({l_zk}) < r + t_ood ({} + {}) — violates Theorem 9.6 witness sizing", + source.mask_length(), + t_ood, + ); + ( + Some(mask_oracle), + code_switch::CodeSwitchMode::ZeroKnowledge { + message_mask_length: NonZeroUsize::new(l_zk).expect("ℓ_zk > 0"), + }, + ) + } + }; + + let analytic = analytic_error_bits(&source, &target, t_ood, mask_oracle); + let pow = grind_to_at(spec, analytic, Pow::RoundCodeSwitch { index: round_index })?; + + Ok( + CodeSwitchConfig::new(source, target, t_ood, output_mode, pow) + .with_recorded_analytic(analytic), + ) +} + +/// Per-round code-switch soundness in bits: `min` over Lemma 9.9's three RBR +/// error slots (OOD, in-domain, combination). +pub fn analytic_error_bits( + source: &IrsConfig, + target: &IrsConfig>, + t_ood: usize, + mask_oracle: Option, +) -> Bits { + assert!(t_ood > 0, "code-switch requires t_ood ≥ 1"); + + let field_bits = M::Target::field_size_bits(); + let combined_list = + target.list_size() * mask_oracle.map_or(1.0, |info| info.c_zk_list_size.get()); + // OOD polynomial is over witness `[f; r_C; s]` of length `ℓ + ℓ_zk` (ZK) or + // `ℓ` (Standard). + let degree = mask_oracle.map_or_else( + || source.message_length(), + |info| source.message_length().saturating_add(info.l_zk.get()), + ); + let t_ood_f = usize_to_f64(t_ood); + + // OOD term — Lemma 9.9, term 1. + let log_degree_minus_1 = usize_to_f64(degree.saturating_sub(1)).log2(); + let log_l_choose_2 = (combined_list * (combined_list - 1.0) / 2.0).log2(); + let ood_term = t_ood_f * (field_bits - log_degree_minus_1) - log_l_choose_2; + + // In-domain term — Lemma 9.9, term 2. + let in_domain_term = source.rbr_queries(); + + // Combination term — Lemma 9.9, term 3 (γ-RLC, bounds doc §5.1). + let log_count = + usize_to_f64(t_ood.saturating_add(source.in_domain_samples * source.interleaving_depth)) + .log2(); + let combination_term = field_bits - log_count - combined_list.log2(); + + Bits::new(ood_term.min(in_domain_term).min(combination_term).max(0.0)) +} + +/// Number of `(r ‖ s)` mask polynomials code-switch contributes to C_zk per +/// round. +pub const fn masks_required() -> usize { + 1 +} + +#[cfg(test)] +mod tests { + use proptest::prelude::*; + + use super::*; + use crate::{ + hash, + protocols::params::{ + branch::OodMode, + build_round::{compute_l_zk, solve_t_ood}, + irs_commit as irs_params, + spec::{ + DecodingRegime, ListSize, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, + PowBudget, RoundContext, SecuritySpec, ZkSpec, + }, + test_utils::{ + arb_standard_spec as utils_standard_spec, arb_zk_spec as utils_zk_spec, + assert_close, assert_pow_closes_gap, build_round_io, deterministic_spec, + TestEmbedding, TestExtensionField, TestField, TestNonIdentityEmbedding, + TEST_TARGET_RANGE, + }, + }, + }; + + type M = TestEmbedding; + + fn arb_zk_spec() -> impl Strategy { + utils_zk_spec(TEST_TARGET_RANGE) + } + + fn arb_standard_spec() -> impl Strategy { + utils_standard_spec(TEST_TARGET_RANGE) + } + + const NUM_VARS_HEADROOM: u32 = 4; + + fn arb_dims() -> impl Strategy { + (1u32..=3, 1u32..=2).prop_flat_map(|(log_inv_rate, folding_factor)| { + let min_num_vars = 2 * folding_factor; + ( + Just(log_inv_rate), + Just(folding_factor), + min_num_vars..=(min_num_vars + NUM_VARS_HEADROOM), + ) + }) + } + + const FORMULA_LOG_INV_RATE: u32 = 1; + const FORMULA_FOLDING_FACTOR: u32 = 2; + const FORMULA_NUM_VARS: u32 = 6; + + #[test] + fn analytic_error_standard_formula() { + let spec: SecuritySpec = deterministic_spec(Mode::Standard); + let (source, target, t_ood) = build_round_io::( + &spec, + FORMULA_LOG_INV_RATE, + FORMULA_FOLDING_FACTOR, + FORMULA_NUM_VARS, + None, + ); + let got = f64::from(analytic_error_bits(&source, &target, t_ood, None)); + + let field_bits = ::field_size_bits(); + let target_list = target.list_size(); + let degree = source.message_length(); + let log_deg_m1 = ((degree - 1) as f64).log2(); + let l_choose_2 = target_list * (target_list - 1.0) / 2.0; + let ood = (t_ood as f64) * (field_bits - log_deg_m1) - l_choose_2.log2(); + let in_domain = source.rbr_queries(); + let count = t_ood + source.in_domain_samples * source.interleaving_depth; + let comb = field_bits - (count as f64).log2() - target_list.log2(); + let expected = ood.min(in_domain).min(comb).max(0.0); + + assert_close(got, expected); + } + + #[test] + fn analytic_error_zk_formula() { + const C_ZK_LIST_SIZE: f64 = 4.0; + const L_ZK_USIZE: usize = 8; + + let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); + let mask_oracle = MaskOracleInfo { + c_zk_list_size: ListSize::new(C_ZK_LIST_SIZE), + l_zk: MaskCodeMessageLen::new(L_ZK_USIZE), + }; + let (source, target, t_ood) = build_round_io::( + &spec, + FORMULA_LOG_INV_RATE, + FORMULA_FOLDING_FACTOR, + FORMULA_NUM_VARS, + Some(FORMULA_LOG_INV_RATE), + ); + let got = f64::from(analytic_error_bits( + &source, + &target, + t_ood, + Some(mask_oracle), + )); + + let field_bits = ::field_size_bits(); + let target_list = target.list_size(); + let combined_list = target_list * C_ZK_LIST_SIZE; + let degree = source.message_length() + L_ZK_USIZE; + let log_deg_m1 = ((degree - 1) as f64).log2(); + let l_choose_2 = combined_list * (combined_list - 1.0) / 2.0; + let ood = (t_ood as f64) * (field_bits - log_deg_m1) - l_choose_2.log2(); + let in_domain = source.rbr_queries(); + let count = t_ood + source.in_domain_samples * source.interleaving_depth; + let comb = field_bits - (count as f64).log2() - target_list.log2() - C_ZK_LIST_SIZE.log2(); + let expected = ood.min(in_domain).min(comb).max(0.0); + + assert_close(got, expected); + } + + #[test] + fn analytic_error_uses_in_domain_when_limiting() { + const LIMITING_TARGET_BITS: u32 = 16; + const LIMITING_LOG_INV_RATE: u32 = 1; + const LIMITING_FOLDING_FACTOR: u32 = 1; + const LIMITING_NUM_VARS: u32 = 4; + + let spec = SecuritySpec { + mode: Mode::Standard, + decoding_regime: DecodingRegime::Johnson, + target_security_bits: LIMITING_TARGET_BITS, + pow_budget: PowBudget::Forbidden, + hash_id: hash::BLAKE3, + }; + let (source, target, t_ood) = build_round_io::( + &spec, + LIMITING_LOG_INV_RATE, + LIMITING_FOLDING_FACTOR, + LIMITING_NUM_VARS, + None, + ); + + let field_bits = ::field_size_bits(); + let target_list = target.list_size(); + let degree = source.message_length(); + let log_deg_m1 = ((degree - 1) as f64).log2(); + let l_choose_2 = target_list * (target_list - 1.0) / 2.0; + let ood = (t_ood as f64) * (field_bits - log_deg_m1) - l_choose_2.log2(); + let in_domain = source.rbr_queries(); + let count = t_ood + source.in_domain_samples * source.interleaving_depth; + let comb = field_bits - (count as f64).log2() - target_list.log2(); + assert!( + in_domain < ood && in_domain < comb, + "fixture wants in_domain to bind: in_domain {in_domain}, ood {ood}, comb {comb}", + ); + + let got = f64::from(analytic_error_bits(&source, &target, t_ood, None)); + assert_close(got, in_domain.max(0.0)); + } + + proptest! { + #[test] + fn solve_standard_assembles( + spec in arb_standard_spec(), + (log_inv_rate, folding_factor, num_vars) in arb_dims(), + ) { + let (source, target, t_ood) = + build_round_io::(&spec, log_inv_rate, folding_factor, num_vars, None); + let config = solve(&spec, source, target, t_ood, SolveMode::Standard, 0).unwrap(); + prop_assert!(matches!(config.mode, code_switch::CodeSwitchMode::Standard)); + prop_assert!(config.out_domain_samples >= 1); + } + + #[test] + fn solve_zk_mask_equals_padded_r_plus_t_ood( + spec in arb_zk_spec(), + (log_inv_rate, folding_factor, num_vars) in arb_dims(), + ) { + let (source, target, t_ood) = build_round_io::( + &spec, log_inv_rate, folding_factor, num_vars, Some(log_inv_rate), + ); + let r = source.mask_length(); + let l_zk = compute_l_zk(&source, t_ood); + let zk_spec = ZkSpec::try_new(&spec).expect("arb_zk_spec"); + let c_zk = irs_params::solve_mask_code::( + zk_spec, + l_zk, + r, + LogInvRate::new(log_inv_rate), + 2, + ); + let mask_oracle = MaskOracleInfo { + c_zk_list_size: ListSize::new(c_zk.list_size()), + l_zk, + }; + let config = solve( + &spec, + source, + target, + t_ood, + SolveMode::ZeroKnowledge(mask_oracle), + 0, + ) + .unwrap(); + prop_assert_eq!(config.message_mask_length(), (r + t_ood).next_power_of_two()); + } + + #[test] + fn pow_closes_gap_to_target_standard( + spec in arb_standard_spec(), + (log_inv_rate, folding_factor, num_vars) in arb_dims(), + ) { + let (source, target, t_ood) = + build_round_io::(&spec, log_inv_rate, folding_factor, num_vars, None); + let error = analytic_error_bits(&source, &target, t_ood, None); + let config = solve(&spec, source, target, t_ood, SolveMode::Standard, 0).unwrap(); + assert_pow_closes_gap(&spec, error, &config.pow); + } + } + + fn non_identity_smoke_ctxs() -> (RoundContext, RoundContext) { + const SOURCE_VECTOR_SIZE: usize = 64; + const SOURCE_LOG_INV_RATE: u32 = 1; + const FOLDING_FACTOR: u32 = 2; + + let source_ctx = RoundContext { + vector_size: SOURCE_VECTOR_SIZE, + log_inv_rate: SOURCE_LOG_INV_RATE, + folding_factor: FOLDING_FACTOR, + }; + let target_ctx = RoundContext { + vector_size: source_ctx.vector_size / (1 << source_ctx.folding_factor), + log_inv_rate: source_ctx.log_inv_rate + source_ctx.folding_factor - 1, + folding_factor: source_ctx.folding_factor, + }; + (source_ctx, target_ctx) + } + + #[test] + #[should_panic(expected = "violates Theorem 9.6")] + fn solve_zk_rejects_l_zk_below_r_plus_t_ood() { + const TOO_SMALL_L_ZK: usize = 1; + + let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); + let (source, target, t_ood) = build_round_io::( + &spec, + FORMULA_LOG_INV_RATE, + FORMULA_FOLDING_FACTOR, + FORMULA_NUM_VARS, + Some(FORMULA_LOG_INV_RATE), + ); + assert!(source.mask_length() + t_ood > TOO_SMALL_L_ZK); + + let mask_oracle = MaskOracleInfo { + c_zk_list_size: ListSize::new(SMOKE_C_ZK_LIST_SIZE), + l_zk: MaskCodeMessageLen::new(TOO_SMALL_L_ZK), + }; + let _ = solve( + &spec, + source, + target, + t_ood, + SolveMode::ZeroKnowledge(mask_oracle), + 0, + ); + } + + #[test] + fn solve_works_with_basefield_embedding_standard() { + let spec: SecuritySpec = deterministic_spec(Mode::Standard); + let (source_ctx, target_ctx) = non_identity_smoke_ctxs(); + let target_log_degree = + f64::from((source_ctx.vector_size / (1 << source_ctx.folding_factor)).trailing_zeros()); + let target_list_size = spec + .decoding_regime + .list_size_estimate(target_log_degree, f64::from(target_ctx.log_inv_rate)); + let (source, t_ood) = solve_t_ood::( + &spec, + &source_ctx, + target_list_size, + OodMode::Standard, + 0, + ) + .unwrap(); + let target = irs_params::solve::>( + &spec, + &target_ctx, + OodSampleBudget::ZERO, + ); + + let config = solve(&spec, source, target, t_ood, SolveMode::Standard, 0).unwrap(); + assert!(matches!(config.mode, code_switch::CodeSwitchMode::Standard)); + } + + const SMOKE_C_ZK_LIST_SIZE: f64 = 4.0; + + #[test] + fn solve_works_with_basefield_embedding_zk() { + let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); + let (source_ctx, target_ctx) = non_identity_smoke_ctxs(); + let target_log_degree = + f64::from((source_ctx.vector_size / (1 << source_ctx.folding_factor)).trailing_zeros()); + let target_list_size = spec + .decoding_regime + .list_size_estimate(target_log_degree, f64::from(target_ctx.log_inv_rate)); + let (source, t_ood) = solve_t_ood::( + &spec, + &source_ctx, + target_list_size, + OodMode::ZeroKnowledge(f64::from(source_ctx.log_inv_rate)), + 0, + ) + .unwrap(); + let target = irs_params::solve::>( + &spec, + &target_ctx, + OodSampleBudget::new(t_ood), + ); + + let mask_oracle = MaskOracleInfo { + c_zk_list_size: ListSize::new(SMOKE_C_ZK_LIST_SIZE), + l_zk: MaskCodeMessageLen::new((source.mask_length() + t_ood).next_power_of_two()), + }; + let config = solve( + &spec, + source, + target, + t_ood, + SolveMode::ZeroKnowledge(mask_oracle), + 0, + ) + .unwrap(); + assert!(matches!( + config.mode, + code_switch::CodeSwitchMode::ZeroKnowledge { .. } + )); + } +} diff --git a/src/protocols/params/derive.rs b/src/protocols/params/derive.rs new file mode 100644 index 00000000..fec28001 --- /dev/null +++ b/src/protocols/params/derive.rs @@ -0,0 +1,698 @@ +//! Derives a [`ProtocolConfig`] from a spec + tuning. + +use crate::{ + algebra::embedding::Embedding, + protocols::params::{ + basecase as basecase_params, + branch::{Branch, RoundBuildMode, RoundBuildPayload}, + build_round::build_round_config, + error::DeriveError, + layout::{round_layout, RoundLayout}, + protocol_config::{ProtocolConfig, RoundConfig}, + spec::{LogInvRate, Mode, SecuritySpec, TuningSpec, ZkSpec}, + }, +}; + +impl ProtocolConfig { + /// Fails with [`DeriveError`] when the spec/tuning combination is + /// infeasible. + pub fn derive(spec: SecuritySpec, tuning: TuningSpec) -> Result { + let RoundLayout { + shapes, + basecase_vector_size, + basecase_log_inv_rate, + } = round_layout(&tuning)?; + + let mode: RoundBuildMode<'_> = match spec.mode { + Mode::Standard => Branch::Standard, + Mode::ZeroKnowledge => Branch::ZeroKnowledge(RoundBuildPayload { + zk_spec: ZkSpec::try_new(&spec).expect("matched Mode::ZeroKnowledge above"), + c_zk_log_inv_rate: LogInvRate::new(tuning.starting_log_inv_rate), + }), + }; + + let rounds: Vec> = shapes + .iter() + .map(|shape| build_round_config::(&spec, shape, mode)) + .collect::>()?; + + let basecase = basecase_params::solve(&spec, basecase_vector_size, basecase_log_inv_rate)?; + + let plan = Self::new(spec, tuning, rounds, basecase); + plan.validate()?; + Ok(plan) + } +} + +#[cfg(test)] +mod tests { + use proptest::prelude::*; + + use crate::{ + algebra::{ + embedding::Embedding, + fields::{Field64, FieldWithSize}, + }, + hash, + protocols::{ + basecase::BasecaseMode, + params::{ + basecase as basecase_params, code_switch as code_switch_params, + error::{ChainSource, ChainTarget, DeriveError, Pow}, + mask_proximity as mask_proximity_params, + protocol_config::{ProtocolConfig, RoundMode}, + spec::{DecodingRegime, FoldingFactor, Mode, PowBudget, SecuritySpec, TuningSpec}, + sumcheck as sumcheck_params, + test_utils::{assert_close, assert_pow_closes_gap, TestEmbedding}, + }, + }, + }; + + fn arb_tuning() -> impl Strategy { + let folding = prop_oneof![ + (1usize..=3).prop_map(FoldingFactor::Constant), + (1usize..=3, 1usize..=3).prop_map(|(initial, rest)| { + FoldingFactor::ConstantFromSecondRound { initial, rest } + }), + ]; + (4u32..=8, 1u32..=3, folding).prop_map(|(log_size, log_inv_rate, folding_factor)| { + TuningSpec { + vector_size: 1usize << log_size, + starting_log_inv_rate: log_inv_rate, + folding_factor, + } + }) + } + + const FIXTURE_FOLDING_FACTOR: usize = 2; + const FIXTURE_LOG_INV_RATE: u32 = 1; + + const LOG_VECTOR_SIZE_NO_ROUNDS: u32 = 3; + const LOG_VECTOR_SIZE_MULTI_ROUND: u32 = 8; + + fn tuning_with(vector_size: usize) -> TuningSpec { + TuningSpec { + vector_size, + starting_log_inv_rate: FIXTURE_LOG_INV_RATE, + folding_factor: FoldingFactor::Constant(FIXTURE_FOLDING_FACTOR), + } + } + + const PLAN_FIXTURE_TARGET_BITS: u32 = 40; + + fn test_spec(mode: Mode) -> SecuritySpec { + SecuritySpec { + mode, + decoding_regime: DecodingRegime::Johnson, + target_security_bits: PLAN_FIXTURE_TARGET_BITS, + pow_budget: PowBudget::per_slot(LOOSE_POW_BUDGET_BITS), + hash_id: hash::BLAKE3, + } + } + + #[test] + fn derive_standard_with_no_rounds_uses_basecase_only() { + let spec = test_spec(Mode::Standard); + let vector_size = 1usize << LOG_VECTOR_SIZE_NO_ROUNDS; + let plan = ProtocolConfig::::derive(spec, tuning_with(vector_size)).unwrap(); + assert!(plan.rounds().is_empty()); + assert_eq!(plan.basecase().commit.vector_size, vector_size); + } + + #[test] + fn derive_zk_with_no_rounds_uses_zk_basecase_only() { + let spec = test_spec(Mode::ZeroKnowledge); + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_NO_ROUNDS), + ) + .unwrap(); + assert!(plan.rounds().is_empty()); + assert!(matches!( + plan.basecase().mode, + BasecaseMode::ZeroKnowledge + )); + } + + #[test] + fn t_ood_nonzero_in_johnson_zk() { + let spec = SecuritySpec { + decoding_regime: DecodingRegime::Johnson, + ..test_spec(Mode::ZeroKnowledge) + }; + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + for r in plan.rounds() { + let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode() else { + panic!("expected ZK round") + }; + assert!(t_ood.get() >= 1); + } + } + + #[test] + fn t_ood_pinned_to_one_in_unique_zk() { + let spec = SecuritySpec { + decoding_regime: DecodingRegime::Unique, + ..test_spec(Mode::ZeroKnowledge) + }; + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + for r in plan.rounds() { + let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode() else { + panic!("expected ZK round") + }; + assert_eq!(t_ood.get(), 1); + } + } + + #[test] + fn c_zk_keeps_code_switch_mask_under_unique() { + let spec = SecuritySpec { + decoding_regime: DecodingRegime::Unique, + ..test_spec(Mode::ZeroKnowledge) + }; + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + for r in plan.rounds() { + let mask_oracle = r.mask_oracle().expect("ZK round has a mask oracle"); + let k = r.code_switch().source.interleaving_depth.trailing_zeros() as usize; + let expected_num_masks = k + 1; + assert_eq!(mask_oracle.c_zk().num_vectors, 2 * expected_num_masks); + } + } + + #[test] + fn analytic_bits_finite_and_positive_standard() { + let spec = test_spec(Mode::Standard); + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + let bits: f64 = plan.analytic_bits().into(); + assert!(bits.is_finite() && bits > 0.0, "bits = {bits}"); + let min_round = plan + .rounds() + .iter() + .map(|r| f64::from(r.analytic_bits())) + .fold(f64::INFINITY, f64::min); + let expected = min_round.min(f64::from(plan.basecase().analytic_bits())); + assert_close(bits, expected); + } + + #[test] + fn analytic_bits_includes_mask_oracle_in_zk() { + let spec = test_spec(Mode::ZeroKnowledge); + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + let plan_bits: f64 = plan.analytic_bits().into(); + let mo_floor = plan + .rounds() + .iter() + .filter_map(|r| r.mask_oracle().map(|mo| f64::from(mo.analytic_bits()))) + .fold(f64::INFINITY, f64::min); + assert!( + mo_floor.is_finite(), + "ZK plan must contribute mask-oracle bits" + ); + let min_round = plan + .rounds() + .iter() + .map(|r| f64::from(r.analytic_bits())) + .fold(f64::INFINITY, f64::min); + let expected = mo_floor + .min(min_round) + .min(f64::from(plan.basecase().analytic_bits())); + assert_close(plan_bits, expected); + } + + #[test] + fn derive_plans_basecase() { + let spec = test_spec(Mode::ZeroKnowledge); + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + assert!(matches!( + plan.basecase().mode, + BasecaseMode::ZeroKnowledge + )); + assert_eq!(plan.basecase().commit.interleaving_depth, 1); + assert_eq!(plan.basecase().sumcheck.final_size(), 1); + } + + const LOOSE_POW_BUDGET_BITS: u32 = 60; + const OVER_BUDGET_INJECTED_BITS: f64 = 50.0; + + /// Bounds doc §5.3 + §5.7: HVZK privacy error in bits matches the closed + /// form `−log Σ_r (t_ood_r² + t_ood_r) / (2|F|)` over ZK rounds. + #[test] + fn privacy_error_bits_matches_bound_3_sum() { + let spec = test_spec(Mode::ZeroKnowledge); + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + let field_bits = ::field_size_bits(); + let mut expected_total = 0.0_f64; + for r in plan.rounds() { + let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode() else { + panic!("expected ZK round"); + }; + let t = t_ood.get() as f64; + expected_total += 2_f64.powf(f64::midpoint(t * t, t).log2() - field_bits); + } + let expected_bits = -expected_total.log2(); + let got = f64::from(plan.privacy_error_bits()); + assert_close(got, expected_bits); + } + + #[test] + fn privacy_error_bits_standard_returns_target_sentinel() { + let spec = test_spec(Mode::Standard); + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + assert_close( + f64::from(plan.privacy_error_bits()), + f64::from(PLAN_FIXTURE_TARGET_BITS), + ); + } + + #[test] + fn check_pow_bits_passes_on_derived_plan() { + let plan = ProtocolConfig::::derive( + test_spec(Mode::ZeroKnowledge), + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + assert!(plan.check_pow_bits()); + } + + #[test] + fn check_pow_bits_detects_over_budget_slot() { + use crate::{bits::Bits, protocols::proof_of_work::Config as PowConfig}; + const MODERATE_POW_BUDGET_BITS: u32 = 30; + let spec = SecuritySpec { + pow_budget: PowBudget::per_slot(MODERATE_POW_BUDGET_BITS), + ..test_spec(Mode::ZeroKnowledge) + }; + let mut plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + plan.override_basecase_pow_for_test(PowConfig::from_difficulty(Bits::new( + OVER_BUDGET_INJECTED_BITS, + ))); + assert!(!plan.check_pow_bits()); + } + + #[test] + fn validate_round_chaining_detects_adjacent_round_mismatch() { + let spec = test_spec(Mode::ZeroKnowledge); + let mut plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + let n = plan.rounds().len(); + assert!(n >= 2, "need ≥ 2 rounds to break a mid-chain link"); + assert!(plan.check_all_invariants(), "fresh plan must validate"); + + let bad_size = plan.rounds()[0].code_switch().target.vector_size + 1; + plan.corrupt_round_target_vector_size_for_test(0, bad_size); + + let err = plan + .validate_round_chaining() + .expect_err("adjacent-round mismatch must trip the chain check"); + assert!( + matches!( + err, + DeriveError::RoundChainBroken { + from: ChainSource::Round(0), + to: ChainTarget::NextRound(1), + .. + } + ), + "got {err:?}", + ); + assert!(!plan.check_all_invariants()); + } + + #[test] + fn validate_round_chaining_detects_basecase_mismatch() { + let spec = test_spec(Mode::ZeroKnowledge); + let mut plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + let n = plan.rounds().len(); + assert!(n >= 2, "need ≥ 2 rounds to break the chain by truncation"); + assert!(plan.check_all_invariants(), "fresh plan must validate"); + + plan.truncate_rounds_for_test(n - 1); + let err = plan + .validate_round_chaining() + .expect_err("truncated tail breaks basecase chaining"); + assert!( + matches!( + err, + DeriveError::RoundChainBroken { + to: ChainTarget::Basecase, + .. + } + ), + "got {err:?}", + ); + assert!(!plan.check_all_invariants()); + } + + #[test] + fn validate_security_target_met_passes_on_fresh_plan() { + let spec = test_spec(Mode::ZeroKnowledge); + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + plan.validate_security_target_met() + .expect("fresh plan must satisfy per-slot target check"); + } + + #[test] + fn validate_security_target_met_catches_recorded_analytic_drift() { + use crate::bits::Bits; + let spec = test_spec(Mode::ZeroKnowledge); + let mut plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + assert!(!plan.rounds().is_empty(), "need a round to corrupt"); + let recorded = plan + .rounds() + .first() + .and_then(|r| r.sumcheck().recorded_analytic) + .expect("params solver records sumcheck analytic"); + // Bump the recorded value far from the recompute → triggers drift. + plan.corrupt_round_sumcheck_recorded_analytic_for_test( + 0, + Bits::new(f64::from(recorded) + 10.0), + ); + let err = plan + .validate_security_target_met() + .expect_err("recorded vs recompute mismatch must trip drift check"); + assert!( + matches!( + err, + DeriveError::AnalyticDrift { + pow: Pow::RoundSumcheck { index: 0 }, + .. + } + ), + "got {err:?}", + ); + } + + #[test] + fn derive_reports_pow_ungrindable() { + const UNREACHABLE_TARGET_BITS: u32 = 200; + let spec = SecuritySpec { + target_security_bits: UNREACHABLE_TARGET_BITS, + ..test_spec(Mode::Standard) + }; + let err = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .expect_err("target above grind cap must fail"); + assert!( + matches!(err, DeriveError::PowUngrindable { .. }), + "got {err:?}", + ); + } + + #[test] + fn derive_reports_pow_budget_exceeded() { + const TIGHT_MAX_POW: u32 = 5; + let spec = SecuritySpec { + pow_budget: PowBudget::per_slot(TIGHT_MAX_POW), + ..test_spec(Mode::ZeroKnowledge) + }; + let err = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .expect_err("tight pow_budget must trip auto-validation"); + assert!( + matches!(err, DeriveError::PowBudgetExceeded { .. }), + "got {err:?}", + ); + } + + #[test] + fn derive_threads_unique_decoding_standard() { + let spec = SecuritySpec { + decoding_regime: DecodingRegime::Unique, + ..test_spec(Mode::Standard) + }; + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_NO_ROUNDS), + ) + .unwrap(); + assert!(plan.rounds().is_empty()); + assert!(plan.basecase().commit.unique_decoding()); + } + + #[test] + fn derive_threads_unique_decoding_zk() { + let spec = SecuritySpec { + decoding_regime: DecodingRegime::Unique, + ..test_spec(Mode::ZeroKnowledge) + }; + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_NO_ROUNDS), + ) + .unwrap(); + assert!(plan.rounds().is_empty()); + assert!(plan.basecase().commit.unique_decoding()); + } + + #[test] + fn derive_multi_round_unique_decoding_succeeds() { + let spec = SecuritySpec { + decoding_regime: DecodingRegime::Unique, + ..test_spec(Mode::Standard) + }; + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + assert!(!plan.rounds().is_empty(), "expected multi-round plan"); + for r in plan.rounds() { + let cs = r.code_switch(); + assert!(cs.source.unique_decoding()); + assert!(cs.target.unique_decoding()); + assert!(cs.out_domain_samples >= 1); + } + assert!(plan.basecase().commit.unique_decoding()); + } + + #[test] + fn derive_multi_round_unique_decoding_zk_succeeds() { + let spec = SecuritySpec { + decoding_regime: DecodingRegime::Unique, + ..test_spec(Mode::ZeroKnowledge) + }; + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + assert!(!plan.rounds().is_empty(), "expected multi-round plan"); + for r in plan.rounds() { + let mo = r.mask_oracle().expect("ZK round must own a mask oracle"); + assert!(mo.c_zk().unique_decoding()); + assert!(r.code_switch().source.unique_decoding()); + assert!(r.code_switch().out_domain_samples >= 1); + } + assert!(plan.basecase().commit.unique_decoding()); + } + + #[test] + fn derive_multi_round_capacity_decoding_succeeds() { + let spec = SecuritySpec { + decoding_regime: DecodingRegime::Capacity, + ..test_spec(Mode::Standard) + }; + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + assert!(!plan.rounds().is_empty(), "expected multi-round plan"); + for r in plan.rounds() { + assert!(r.code_switch().out_domain_samples >= 1); + } + } + + #[test] + fn derive_multi_round_capacity_decoding_zk_succeeds() { + let spec = SecuritySpec { + decoding_regime: DecodingRegime::Capacity, + ..test_spec(Mode::ZeroKnowledge) + }; + let plan = ProtocolConfig::::derive( + spec, + tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND), + ) + .unwrap(); + assert!(!plan.rounds().is_empty(), "expected multi-round plan"); + for r in plan.rounds() { + r.mask_oracle().expect("ZK round must own a mask oracle"); + assert!(r.code_switch().out_domain_samples >= 1); + } + } + + fn assert_plan_meets_target_per_slot( + spec: &SecuritySpec, + plan: &ProtocolConfig, + ) { + for r in plan.rounds() { + let mask_info = r.mask_oracle_info(); + let cs = r.code_switch(); + assert_pow_closes_gap( + spec, + sumcheck_params::analytic_error_bits(&cs.source, mask_info), + &r.sumcheck().round_pow, + ); + assert_pow_closes_gap( + spec, + code_switch_params::analytic_error_bits( + &cs.source, + &cs.target, + cs.out_domain_samples, + mask_info, + ), + &cs.pow, + ); + if let Some(mo) = r.mask_oracle() { + let mp = mo.mask_proximity(); + assert_pow_closes_gap( + spec, + mask_proximity_params::analytic_error_bits(&mp.c_zk_commit, mp.num_masks), + &mp.pow, + ); + } + } + assert_pow_closes_gap( + spec, + sumcheck_params::analytic_error_bits(&plan.basecase().commit, None), + &plan.basecase().sumcheck.round_pow, + ); + if matches!( + plan.basecase().mode, + BasecaseMode::ZeroKnowledge + ) { + assert_pow_closes_gap( + spec, + basecase_params::analytic_error_bits(&plan.basecase().commit), + &plan.basecase().pow, + ); + } + } + + proptest! { + #[test] + fn derived_plan_meets_target_per_slot_standard(tuning in arb_tuning()) { + let spec = test_spec(Mode::Standard); + let plan = ProtocolConfig::::derive(spec.clone(), tuning).unwrap(); + assert_plan_meets_target_per_slot(&spec, &plan); + } + + #[test] + fn derived_plan_meets_target_per_slot_zk(tuning in arb_tuning()) { + let log_threshold = + tuning.folding_factor.at_round(0) + tuning.folding_factor.at_round(1); + prop_assume!(tuning.vector_size.trailing_zeros() as usize >= log_threshold); + let spec = test_spec(Mode::ZeroKnowledge); + let plan = ProtocolConfig::::derive(spec.clone(), tuning).unwrap(); + assert_plan_meets_target_per_slot(&spec, &plan); + } + + #[test] + fn derive_standard_succeeds_over_tunings(tuning in arb_tuning()) { + let spec = test_spec(Mode::Standard); + let plan = ProtocolConfig::::derive(spec, tuning).unwrap(); + for r in plan.rounds() { + prop_assert!(matches!(r.mode(), RoundMode::Standard)); + prop_assert!(r.mask_oracle().is_none()); + } + prop_assert!(matches!( + plan.basecase().mode, + BasecaseMode::Standard + )); + prop_assert_eq!(plan.basecase().commit.interleaving_depth, 1); + } + + #[test] + fn derive_zk_succeeds_over_tunings(tuning in arb_tuning()) { + let log_threshold = + tuning.folding_factor.at_round(0) + tuning.folding_factor.at_round(1); + prop_assume!(tuning.vector_size.trailing_zeros() as usize >= log_threshold); + + let spec = test_spec(Mode::ZeroKnowledge); + let plan = ProtocolConfig::::derive(spec, tuning).unwrap(); + for r in plan.rounds() { + let mask_oracle = r + .mask_oracle() + .expect("ZK round must have a mask oracle"); + let RoundMode::ZeroKnowledge { t_ood, .. } = r.mode() else { + panic!("expected ZK round"); + }; + let cs = r.code_switch(); + let k = cs.source.interleaving_depth.trailing_zeros() as usize; + let num_masks = k + 1; + prop_assert_eq!(mask_oracle.c_zk().num_vectors, 2 * num_masks); + prop_assert_eq!(mask_oracle.mask_proximity().num_masks, num_masks); + let source_mask = cs.source.mask_length(); + prop_assert!(mask_oracle.l_zk().get() >= source_mask + t_ood.get()); + } + prop_assert!(matches!( + plan.basecase().mode, + BasecaseMode::ZeroKnowledge + )); + } + + #[test] + fn analytic_bits_finite_and_non_negative_standard(tuning in arb_tuning()) { + let spec = test_spec(Mode::Standard); + let plan = ProtocolConfig::::derive(spec, tuning).unwrap(); + let analytic = f64::from(plan.analytic_bits()); + prop_assert!(analytic.is_finite()); + prop_assert!(analytic >= 0.0); + } + } +} diff --git a/src/protocols/params/error.rs b/src/protocols/params/error.rs new file mode 100644 index 00000000..1e54af19 --- /dev/null +++ b/src/protocols/params/error.rs @@ -0,0 +1,181 @@ +//! Errors raised by [`super::derive::ProtocolConfig::derive`] and the +//! sub-protocol solvers. + +use std::fmt::{self, Display, Formatter}; + +use thiserror::Error; + +use crate::{ + bits::Bits, + protocols::{ + params::spec::SecuritySpec, + proof_of_work::{Config as PowConfig, PowError}, + }, +}; + +/// Identifies a single PoW grind in the derived protocol. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Pow { + /// Basecase γ-RLC grind (Lemma 7.4) — ZK mode only. + BasecaseGammaCombination, + /// Basecase sumcheck grind. + BasecaseSumcheck, + /// Per-round sumcheck grind at `index`. + RoundSumcheck { index: usize }, + /// Per-round code-switch grind at `index`. + RoundCodeSwitch { index: usize }, + /// Per-round mask-proximity grind at `index` — ZK mode only. + RoundMaskProximity { index: usize }, +} + +impl Display for Pow { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::BasecaseGammaCombination => f.write_str("basecase γ-combination"), + Self::BasecaseSumcheck => f.write_str("basecase sumcheck"), + Self::RoundSumcheck { index } => write!(f, "round {index} sumcheck"), + Self::RoundCodeSwitch { index } => write!(f, "round {index} code-switch"), + Self::RoundMaskProximity { index } => write!(f, "round {index} mask-proximity"), + } + } +} + +/// Origin side of a [`DeriveError::RoundChainBroken`]: either a numbered round +/// or the pre-round `tuning` shape (for plans with no rounds at all). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ChainSource { + Tuning, + Round(usize), +} + +impl Display for ChainSource { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::Tuning => f.write_str("tuning"), + Self::Round(i) => write!(f, "round {i}"), + } + } +} + +/// Destination side of a [`DeriveError::RoundChainBroken`]: the next round in +/// sequence, or the basecase. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ChainTarget { + NextRound(usize), + Basecase, +} + +impl Display for ChainTarget { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::NextRound(i) => write!(f, "round {i}"), + Self::Basecase => f.write_str("basecase"), + } + } +} + +/// Failure modes for [`super::derive::ProtocolConfig::derive`] and the +/// sub-protocol solvers it calls. +#[derive(Debug, Error, Clone, PartialEq, Eq)] +pub enum DeriveError { + /// The `t_ood` fixed-point in [`super::build_round::solve_t_ood`] ran out of + /// iterations. + #[error("t_ood fixed-point did not converge for round {round_index}")] + FixedPointDidNotConverge { round_index: usize }, + + /// A PoW grind cannot close the analytic-to-target gap — the spec is too + /// tight for any single grind to reach `target_security_bits`. + #[error("{pow} cannot be ground: {source}")] + PowUngrindable { + pow: Pow, + #[source] + source: PowError, + }, + + /// A PoW grind fits the grind cap but exceeds the per-slot budget set by + /// [`super::spec::SecuritySpec::pow_budget`]. + #[error("{pow} requires {required} bits, exceeds spec.pow_budget = {max}")] + PowBudgetExceeded { pow: Pow, required: Bits, max: Bits }, + + /// Computed codeword length exceeds the NTT engine's supported order. + #[error("codeword length {length} exceeds the NTT engine's supported order")] + CodewordExceedsNtt { length: usize }, + + /// Cross-round (or round → basecase) shape chain broken: the next + /// component's source `vector_size` does not match the previous + /// component's target `vector_size`. Surfaced by + /// [`super::protocol_config::ProtocolConfig::validate_round_chaining`]. + #[error("chain broken: {from} → {to} expected vector_size {expected}, found {found}")] + RoundChainBroken { + from: ChainSource, + to: ChainTarget, + expected: usize, + found: usize, + }, + + /// A PoW slot's `analytic + pow.difficulty()` is below `target_security_bits`. + /// `grind_to_at` guarantees this at construction; this error fires only if + /// the analytic-error formulas applied at validate-time disagree with what + /// the per-protocol `solve` functions consumed (e.g., a planner regression + /// drifts away from the actual configured IRS rate). Catches the case where + /// the rate-schedule plumbing under-reports a per-slot rate. + #[error("{pow} soundness gap: analytic {analytic} + pow {pow_bits} < target {target}")] + SecurityTargetNotMet { + pow: Pow, + analytic: Bits, + pow_bits: Bits, + target: Bits, + }, + + /// The analytic floor recorded at solve time disagrees with a fresh + /// recompute from the same config's state. Indicates that the inputs to + /// the `analytic_error_bits` formula drifted between solve and validate + /// (e.g., an IRS field was overwritten after construction). + #[error("{pow} analytic drift: recorded {recorded} vs recompute {recompute}")] + AnalyticDrift { + pow: Pow, + recorded: Bits, + recompute: Bits, + }, + + /// `tuning.vector_size` must be a power of 2. + #[error("tuning.vector_size ({vector_size}) must be a power of 2")] + TuningVectorSizeNotPowerOfTwo { vector_size: usize }, + + /// `tuning.folding_factor` must yield at least 1 at every round. + #[error("tuning.folding_factor min ({min}) must be ≥ 1")] + TuningFoldingFactorBelowOne { min: usize }, +} + +/// Lift `Result` into `Result` by attaching a +/// [`Pow`] label. +pub trait PowResultExt { + fn at(self, pow: Pow) -> Result; +} + +impl PowResultExt for Result { + fn at(self, pow: Pow) -> Result { + self.map_err(|source| DeriveError::PowUngrindable { pow, source }) + } +} + +/// Grind `analytic → spec.target_security_bits`, then check the result against +/// `spec.pow_budget`. +pub fn grind_to_at( + spec: &SecuritySpec, + analytic: Bits, + pow_kind: Pow, +) -> Result { + let target = Bits::new(f64::from(spec.target_security_bits)); + let pow = PowConfig::grind_to(target, analytic, spec.hash_id).at(pow_kind)?; + let required = pow.difficulty(); + let max = Bits::new(f64::from(spec.pow_budget.bits())); + if required > max { + return Err(DeriveError::PowBudgetExceeded { + pow: pow_kind, + required, + max, + }); + } + Ok(pow) +} diff --git a/src/protocols/params/irs_commit.rs b/src/protocols/params/irs_commit.rs new file mode 100644 index 00000000..9e0c8ea6 --- /dev/null +++ b/src/protocols/params/irs_commit.rs @@ -0,0 +1,201 @@ +//! IRS-commit parameter selection. +//! +//! ZK mask sizing follows Construction 9.7 / Theorem 9.6: +//! `mask_length = in_domain + t_ood` (Proposition 3.19). + +use crate::{ + algebra::embedding::Embedding, + protocols::{ + irs_commit::{num_in_domain_queries, Config as IrsConfig, IrsMode}, + params::{ + bounds::rate, + spec::{ + LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, RoundContext, SecuritySpec, + ZkSpec, + }, + }, + }, +}; + +pub fn solve( + spec: &SecuritySpec, + ctx: &RoundContext, + out_domain_samples: OodSampleBudget, +) -> IrsConfig { + let security_target = f64::from(spec.protocol_security_target_bits()); + let rate = rate(f64::from(ctx.log_inv_rate)); + let interleaving_depth = 1_usize << ctx.folding_factor; + + let mode = match spec.mode { + Mode::Standard => IrsMode::Standard, + Mode::ZeroKnowledge => { + let mask_length = num_in_domain_queries(spec.decoding_regime, security_target, rate) + .saturating_add(out_domain_samples.get()); + IrsMode::ZeroKnowledge { mask_length } + } + }; + + IrsConfig::new( + security_target, + spec.decoding_regime, + spec.hash_id, + 1, + ctx.vector_size, + interleaving_depth, + rate, + mode, + ) +} + +/// Shared C_zk IRS config for mask polynomials. +/// +/// - `l_zk`: message length, must be a power of 2. +/// - `source_mask_length`: `r` from Theorem 9.6. +/// - `num_vectors`: `2 * num_masks` (Construction 7.2: originals + fresh). +pub fn solve_mask_code( + spec: ZkSpec<'_>, + l_zk: MaskCodeMessageLen, + source_mask_length: usize, + log_inv_rate: LogInvRate, + num_vectors: usize, +) -> IrsConfig { + let l_zk = l_zk.get(); + assert!( + l_zk >= source_mask_length, + "Theorem 9.6: ℓ_zk ({l_zk}) ≥ source mask length ({source_mask_length})", + ); + assert!(l_zk.is_power_of_two(), "ℓ_zk ({l_zk}) must be a power of 2"); + assert!( + num_vectors.is_multiple_of(2), + "num_vectors ({num_vectors}) must be even (mask-proximity original/fresh pairs)", + ); + + let security_target = f64::from(spec.protocol_security_target_bits()); + let rate = rate(f64::from(log_inv_rate.get())); + + IrsConfig::new( + security_target, + spec.decoding_regime, + spec.hash_id, + num_vectors, + l_zk, + 1, + rate, + IrsMode::Standard, + ) +} + +#[cfg(test)] +mod tests { + use proptest::prelude::*; + + use super::*; + use crate::protocols::params::test_utils::{ + arb_round_ctx, arb_spec, arb_zk_spec, deterministic_spec, TestEmbedding, + TestNonIdentityEmbedding, + }; + + type M = TestEmbedding; + + #[test] + fn zk_spec_rejects_standard_mode() { + let spec: SecuritySpec = deterministic_spec(Mode::Standard); + assert!(ZkSpec::try_new(&spec).is_none()); + } + + #[test] + #[should_panic(expected = "must be a power of 2")] + fn solve_mask_code_rejects_non_pow2_l_zk() { + let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); + let zk_spec = ZkSpec::try_new(&spec).unwrap(); + let _ = solve_mask_code::( + zk_spec, + MaskCodeMessageLen::new(3), + 0, + LogInvRate::new(1), + 2, + ); + } + + #[test] + #[should_panic(expected = "Theorem 9.6")] + fn solve_mask_code_rejects_l_zk_below_source_mask_length() { + let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); + let zk_spec = ZkSpec::try_new(&spec).unwrap(); + let _ = solve_mask_code::( + zk_spec, + MaskCodeMessageLen::new(2), + 4, + LogInvRate::new(1), + 2, + ); + } + + #[test] + #[should_panic(expected = "must be even")] + fn solve_mask_code_rejects_odd_num_vectors() { + let spec: SecuritySpec = deterministic_spec(Mode::ZeroKnowledge); + let zk_spec = ZkSpec::try_new(&spec).unwrap(); + let _ = solve_mask_code::( + zk_spec, + MaskCodeMessageLen::new(2), + 0, + LogInvRate::new(1), + 3, + ); + } + + const IRS_TARGET_RANGE: std::ops::RangeInclusive = 80..=128; + + fn arb_zk_spec_default() -> impl Strategy { + arb_zk_spec(IRS_TARGET_RANGE) + } + + fn arb_standard_spec() -> impl Strategy { + arb_spec(Mode::Standard, IRS_TARGET_RANGE) + } + + proptest! { + #[test] + fn zk_mask_covers_in_domain_plus_ood( + spec in arb_zk_spec_default(), + ctx in arb_round_ctx(), + out_domain in 0usize..16, + ) { + let config = solve::(&spec, &ctx, OodSampleBudget::new(out_domain)); + prop_assert!( + config.mask_length() >= config.in_domain_samples + out_domain, + "mask {} < in_domain {} + out_domain {}", + config.mask_length(), config.in_domain_samples, out_domain, + ); + } + + #[test] + fn standard_has_no_mask( + spec in arb_standard_spec(), + ctx in arb_round_ctx(), + out_domain in 0usize..8, + ) { + let config = solve::(&spec, &ctx, OodSampleBudget::new(out_domain)); + prop_assert_eq!(config.mask_length(), 0); + } + } + + const SMOKE_VECTOR_SIZE: usize = 64; + const SMOKE_LOG_INV_RATE: u32 = 1; + const SMOKE_FOLDING_FACTOR: u32 = 2; + const SMOKE_OOD_BUDGET: usize = 2; + + #[test] + fn solve_works_with_basefield_embedding_zk() { + let spec = deterministic_spec(Mode::ZeroKnowledge); + let ctx = RoundContext { + vector_size: SMOKE_VECTOR_SIZE, + log_inv_rate: SMOKE_LOG_INV_RATE, + folding_factor: SMOKE_FOLDING_FACTOR, + }; + let config: IrsConfig = + solve(&spec, &ctx, OodSampleBudget::new(SMOKE_OOD_BUDGET)); + assert!(config.mask_length() > 0); + } +} diff --git a/src/protocols/params/layout.rs b/src/protocols/params/layout.rs new file mode 100644 index 00000000..79df2f0c --- /dev/null +++ b/src/protocols/params/layout.rs @@ -0,0 +1,214 @@ +//! Round-skeleton layout: pure-data walk over the witness shape. +//! +//! Produces per-round shapes (vector size, log_inv_rate, folding factors) +//! independent of [`SecuritySpec`] and IRS solving. Consumed by +//! [`super::build_round`] to instantiate per-round configs and by +//! [`super::derive`] to drive the round/basecase split. + +use crate::{ + algebra::embedding::Embedding, + protocols::{ + irs_commit::Config as IrsConfig, + params::{ + error::DeriveError, + spec::{RoundContext, TuningSpec}, + }, + }, +}; + +#[derive(Debug, Clone, Copy)] +pub(super) struct RoundShape { + pub(super) round_index: usize, + pub(super) source_vector_size: usize, + pub(super) source_log_inv_rate: u32, + pub(super) source_folding_factor: u32, + pub(super) target_folding_factor: u32, +} + +#[derive(Debug)] +pub(super) struct RoundLayout { + pub(super) shapes: Vec, + pub(super) basecase_vector_size: usize, + pub(super) basecase_log_inv_rate: u32, +} + +pub(super) fn round_layout(tuning: &TuningSpec) -> Result { + if !tuning.vector_size.is_power_of_two() { + return Err(DeriveError::TuningVectorSizeNotPowerOfTwo { + vector_size: tuning.vector_size, + }); + } + let min_folding = tuning.folding_factor.min(); + if min_folding < 1 { + return Err(DeriveError::TuningFoldingFactorBelowOne { min: min_folding }); + } + + let mut num_vars = tuning.vector_size.trailing_zeros() as usize; + let mut log_inv_rate = tuning.starting_log_inv_rate; + let mut shapes = Vec::new(); + + loop { + let round = shapes.len(); + let source_folding = tuning.folding_factor.at_round(round); + let target_folding = tuning.folding_factor.at_round(round.saturating_add(1)); + if num_vars < source_folding.saturating_add(target_folding) { + break; + } + shapes.push(RoundShape { + round_index: round, + source_vector_size: 1usize << num_vars, + source_log_inv_rate: log_inv_rate, + source_folding_factor: source_folding as u32, + target_folding_factor: target_folding as u32, + }); + num_vars = num_vars.saturating_sub(source_folding); + log_inv_rate = log_inv_rate.saturating_add((source_folding as u32).saturating_sub(1)); + } + + Ok(RoundLayout { + shapes, + basecase_vector_size: 1usize << num_vars, + basecase_log_inv_rate: log_inv_rate, + }) +} + +pub(super) const fn round_context(shape: &RoundShape) -> RoundContext { + RoundContext { + vector_size: shape.source_vector_size, + log_inv_rate: shape.source_log_inv_rate, + folding_factor: shape.source_folding_factor, + } +} + +pub(super) fn target_context( + shape: &RoundShape, + source: &IrsConfig, +) -> RoundContext { + RoundContext { + vector_size: source.message_length(), + log_inv_rate: shape + .source_log_inv_rate + .saturating_add(shape.source_folding_factor.saturating_sub(1)), + folding_factor: shape.target_folding_factor, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::protocols::params::spec::FoldingFactor; + + const FIXTURE_FOLDING_FACTOR: usize = 2; + const FIXTURE_LOG_INV_RATE: u32 = 1; + + const LOG_VECTOR_SIZE_NO_ROUNDS: u32 = 3; + const LOG_VECTOR_SIZE_MULTI_ROUND: u32 = 8; + + const VARIED_INITIAL_FOLDING: usize = 3; + const VARIED_STEADY_FOLDING: usize = 2; + + const RATE_STEPPING_STARTING_LOG_INV_RATE: u32 = 2; + const MIN_ROUNDS_FOR_CHAINING_TEST: usize = 2; + + fn tuning_with(vector_size: usize) -> TuningSpec { + TuningSpec { + vector_size, + starting_log_inv_rate: FIXTURE_LOG_INV_RATE, + folding_factor: FoldingFactor::Constant(FIXTURE_FOLDING_FACTOR), + } + } + + #[test] + fn round_layout_rate_steps_up_by_folding_minus_one() { + let tuning = TuningSpec { + vector_size: 1 << LOG_VECTOR_SIZE_MULTI_ROUND, + starting_log_inv_rate: RATE_STEPPING_STARTING_LOG_INV_RATE, + folding_factor: FoldingFactor::ConstantFromSecondRound { + initial: VARIED_INITIAL_FOLDING, + rest: VARIED_STEADY_FOLDING, + }, + }; + let layout = round_layout(&tuning).unwrap(); + + let mut expected_log_inv_rate = RATE_STEPPING_STARTING_LOG_INV_RATE; + for shape in &layout.shapes { + assert_eq!(shape.source_log_inv_rate, expected_log_inv_rate); + expected_log_inv_rate += shape.source_folding_factor.saturating_sub(1); + } + assert_eq!(layout.basecase_log_inv_rate, expected_log_inv_rate); + } + + #[test] + fn round_layout_chains_target_to_next_source_folding() { + let tuning = TuningSpec { + vector_size: 1 << LOG_VECTOR_SIZE_MULTI_ROUND, + starting_log_inv_rate: FIXTURE_LOG_INV_RATE, + folding_factor: FoldingFactor::ConstantFromSecondRound { + initial: VARIED_INITIAL_FOLDING, + rest: VARIED_STEADY_FOLDING, + }, + }; + let layout = round_layout(&tuning).unwrap(); + assert!( + layout.shapes.len() >= MIN_ROUNDS_FOR_CHAINING_TEST, + "need ≥ {MIN_ROUNDS_FOR_CHAINING_TEST} rounds to test chaining", + ); + for window in layout.shapes.windows(2) { + assert_eq!( + window[0].target_folding_factor, + window[1].source_folding_factor + ); + } + } + + #[test] + fn round_layout_basecase_size_consumes_remaining_num_vars() { + let tuning = tuning_with(1 << LOG_VECTOR_SIZE_MULTI_ROUND); + let layout = round_layout(&tuning).unwrap(); + let consumed: u32 = layout.shapes.iter().map(|s| s.source_folding_factor).sum(); + let initial_num_vars = tuning.vector_size.trailing_zeros(); + let remaining = initial_num_vars - consumed; + assert_eq!(layout.basecase_vector_size, 1usize << remaining); + } + + #[test] + fn round_layout_stops_when_no_room_for_source_plus_target() { + let vector_size = 1usize << LOG_VECTOR_SIZE_NO_ROUNDS; + let tuning = tuning_with(vector_size); + let layout = round_layout(&tuning).unwrap(); + assert!(layout.shapes.is_empty()); + assert_eq!(layout.basecase_vector_size, vector_size); + assert_eq!(layout.basecase_log_inv_rate, FIXTURE_LOG_INV_RATE); + } + + #[test] + fn round_layout_rejects_non_pow2_vector_size() { + let tuning = TuningSpec { + vector_size: 12, + starting_log_inv_rate: FIXTURE_LOG_INV_RATE, + folding_factor: FoldingFactor::Constant(FIXTURE_FOLDING_FACTOR), + }; + let err = round_layout(&tuning).expect_err("non-pow2 vector_size must fail"); + assert!( + matches!( + err, + DeriveError::TuningVectorSizeNotPowerOfTwo { vector_size: 12 } + ), + "got {err:?}", + ); + } + + #[test] + fn round_layout_rejects_zero_folding_factor() { + let tuning = TuningSpec { + vector_size: 1 << LOG_VECTOR_SIZE_MULTI_ROUND, + starting_log_inv_rate: FIXTURE_LOG_INV_RATE, + folding_factor: FoldingFactor::Constant(0), + }; + let err = round_layout(&tuning).expect_err("folding_factor = 0 must fail"); + assert!( + matches!(err, DeriveError::TuningFoldingFactorBelowOne { min: 0 }), + "got {err:?}", + ); + } +} diff --git a/src/protocols/params/mask_proximity.rs b/src/protocols/params/mask_proximity.rs new file mode 100644 index 00000000..4293796e --- /dev/null +++ b/src/protocols/params/mask_proximity.rs @@ -0,0 +1,163 @@ +//! Mask-proximity (Construction 7.2) builder + Lemma 7.4 γ-combination bound. +//! ZK-only. + +use ark_ff::Field; + +use crate::{ + algebra::{embedding::Identity, fields::FieldWithSize}, + bits::Bits, + protocols::{ + irs_commit::Config as IrsConfig, + mask_proximity::Config as MaskProximityConfig, + params::{ + bounds::usize_to_f64, + error::{grind_to_at, DeriveError, Pow}, + spec::SecuritySpec, + }, + }, +}; + +/// `c_zk.num_vectors` must equal `2 * num_masks` (originals + fresh). +pub fn solve( + spec: &SecuritySpec, + c_zk: IrsConfig>, + num_masks: usize, + round_index: usize, +) -> Result, DeriveError> { + let analytic = analytic_error_bits(&c_zk, num_masks); + let pow = grind_to_at( + spec, + analytic, + Pow::RoundMaskProximity { index: round_index }, + )?; + Ok(MaskProximityConfig::new(c_zk, num_masks, pow).with_recorded_analytic(analytic)) +} + +/// γ-combination soundness (Lemma 7.4): +/// `log|F| − log(num_masks · (deg − 1))`, with `deg = c_zk.masked_message_length()`. +pub fn analytic_error_bits(c_zk: &IrsConfig>, num_masks: usize) -> Bits { + let field_bits = F::field_size_bits(); + let deg = c_zk.masked_message_length(); + if deg <= 1 || num_masks == 0 { + return Bits::new(field_bits.max(0.0)); + } + let log_combined = usize_to_f64(num_masks * deg.saturating_sub(1)).log2(); + Bits::new((field_bits - log_combined).max(0.0)) +} + +impl MaskProximityConfig { + /// Analytic soundness bits (excluding PoW) for the Lemma 7.4 γ-combination. + pub fn analytic_bits(&self) -> Bits { + analytic_error_bits(&self.c_zk_commit, self.num_masks) + } +} + +#[cfg(test)] +mod tests { + use proptest::prelude::*; + + use super::*; + use crate::{ + algebra::fields::Field64, + hash, + protocols::{ + irs_commit::IrsMode, + params::{ + spec::{DecodingRegime, Mode}, + test_utils::{ + arb_zk_spec, assert_close, assert_pow_closes_gap, build_test_c_zk, + deterministic_spec, TEST_TARGET_RANGE, + }, + }, + }, + }; + + const FIXTURE_L_ZK: usize = 8; + const FIXTURE_NUM_MASKS: usize = 3; + const FIXTURE_LOG_INV_RATE: u32 = 1; + + #[test] + fn analytic_error_formula() { + let spec = deterministic_spec(Mode::ZeroKnowledge); + let c_zk = build_test_c_zk(&spec, FIXTURE_L_ZK, FIXTURE_LOG_INV_RATE, FIXTURE_NUM_MASKS); + + let got = f64::from(analytic_error_bits(&c_zk, FIXTURE_NUM_MASKS)); + + let field_bits = ::field_size_bits(); + let deg = c_zk.masked_message_length(); + let log_combined = ((FIXTURE_NUM_MASKS * (deg - 1)) as f64).log2(); + let expected = (field_bits - log_combined).max(0.0); + + assert_close(got, expected); + } + + #[test] + fn analytic_error_saturates_when_no_masks() { + let spec = deterministic_spec(Mode::ZeroKnowledge); + let c_zk = build_test_c_zk(&spec, 2, 1, 1); + let bits = f64::from(analytic_error_bits(&c_zk, 0)); + let field_bits = ::field_size_bits(); + assert_close(bits, field_bits.max(0.0)); + } + + proptest! { + #[test] + fn solve_assembles( + spec in arb_zk_spec(TEST_TARGET_RANGE), + log_inv_rate in 1u32..=3, + num_masks in 1usize..=8, + l_zk_log in 1u32..=5, + ) { + let c_zk = build_test_c_zk(&spec, 1usize << l_zk_log, log_inv_rate, num_masks); + let config = solve(&spec, c_zk, num_masks, 0).unwrap(); + prop_assert_eq!(config.num_masks, num_masks); + prop_assert_eq!(config.c_zk_commit.num_vectors, 2 * num_masks); + prop_assert_eq!(config.c_zk_commit.interleaving_depth, 1); + } + + #[test] + fn pow_closes_gap_to_target( + spec in arb_zk_spec(TEST_TARGET_RANGE), + log_inv_rate in 1u32..=3, + num_masks in 1usize..=8, + l_zk_log in 1u32..=5, + ) { + let c_zk = build_test_c_zk(&spec, 1usize << l_zk_log, log_inv_rate, num_masks); + let analytic = analytic_error_bits(&c_zk, num_masks); + let config = solve(&spec, c_zk, num_masks, 0).unwrap(); + assert_pow_closes_gap(&spec, analytic, &config.pow); + } + } + + #[test] + #[should_panic(expected = "c_zk.num_vectors must be 2 * num_masks")] + fn solve_rejects_mismatched_num_vectors() { + let spec = deterministic_spec(Mode::ZeroKnowledge); + let c_zk = build_test_c_zk(&spec, 2, 1, 2); + let _ = solve(&spec, c_zk, 3, 0); + } + + #[test] + #[should_panic(expected = "interleaving_depth = 1")] + fn solve_rejects_non_unit_interleaving() { + const SECURITY_TARGET_BITS: f64 = 80.0; + const NUM_VECTORS: usize = 2; + const VECTOR_SIZE: usize = 8; + const NON_UNIT_INTERLEAVING_DEPTH: usize = 2; + const RATE: f64 = 0.5; + const NUM_MASKS: usize = 1; + + let spec = deterministic_spec(Mode::ZeroKnowledge); + let c_zk = IrsConfig::>::new( + SECURITY_TARGET_BITS, + DecodingRegime::Johnson, + hash::BLAKE3, + NUM_VECTORS, + VECTOR_SIZE, + NON_UNIT_INTERLEAVING_DEPTH, + RATE, + IrsMode::Standard, + ); + let _ = solve(&spec, c_zk, NUM_MASKS, 0); + } +} diff --git a/src/protocols/params/mod.rs b/src/protocols/params/mod.rs new file mode 100644 index 00000000..6f7ddf7a --- /dev/null +++ b/src/protocols/params/mod.rs @@ -0,0 +1,33 @@ +//! Parameter selection for HVZK-WHIR. +//! +//! Soundness and ZK bound derivations (referred to in submodule comments as +//! "the bounds doc, §N") live at +//! . + +pub(crate) mod basecase; +pub(crate) mod bounds; +pub(crate) mod branch; +pub(crate) mod build_round; +pub(crate) mod code_switch; +pub(crate) mod derive; +pub(crate) mod error; +pub(crate) mod irs_commit; +pub(crate) mod layout; +pub(crate) mod mask_proximity; +pub(crate) mod protocol_config; +pub(crate) mod regime; +pub(crate) mod spec; +pub(crate) mod sumcheck; + +#[cfg(test)] +pub(crate) mod test_utils; + +pub use branch::{Branch, SolveMode}; +pub use error::{ChainSource, ChainTarget, DeriveError, Pow}; +pub use protocol_config::{ + MaskOracleConfig, MaskOracleInfo, ProtocolConfig, RoundConfig, RoundMode, +}; +pub use spec::{ + DecodingRegime, FoldingFactor, ListSize, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, + PowBudget, RoundContext, SecuritySpec, TuningSpec, ZkSpec, +}; diff --git a/src/protocols/params/protocol_config.rs b/src/protocols/params/protocol_config.rs new file mode 100644 index 00000000..41c8fb4b --- /dev/null +++ b/src/protocols/params/protocol_config.rs @@ -0,0 +1,472 @@ +//! Output of [`super::derive`]: the assembled per-round and basecase configs. + +use ark_ff::Field; + +use crate::{ + algebra::{ + embedding::{Embedding, Identity}, + fields::FieldWithSize, + }, + bits::Bits, + protocols::{ + basecase::Config as BasecaseConfig, + code_switch::Config as CodeSwitchConfig, + irs_commit::Config as IrsConfig, + mask_proximity::Config as MaskProximityConfig, + params::{ + basecase as basecase_params, + bounds::usize_to_f64, + code_switch as code_switch_params, + error::{ChainSource, ChainTarget, DeriveError, Pow}, + mask_proximity as mask_proximity_params, + spec::{ListSize, MaskCodeMessageLen, OodSampleBudget, SecuritySpec, TuningSpec}, + sumcheck as sumcheck_params, + }, + proof_of_work::Config as PowConfig, + sumcheck::Config as SumcheckConfig, + }, +}; + +#[derive(Clone, Debug)] +pub struct ProtocolConfig { + security: SecuritySpec, + tuning: TuningSpec, + rounds: Vec>, + basecase: BasecaseConfig, +} + +impl ProtocolConfig { + pub(crate) const fn new( + security: SecuritySpec, + tuning: TuningSpec, + rounds: Vec>, + basecase: BasecaseConfig, + ) -> Self { + Self { + security, + tuning, + rounds, + basecase, + } + } + + pub const fn security(&self) -> &SecuritySpec { + &self.security + } + + pub const fn tuning(&self) -> &TuningSpec { + &self.tuning + } + + pub fn rounds(&self) -> &[RoundConfig] { + &self.rounds + } + + pub const fn basecase(&self) -> &BasecaseConfig { + &self.basecase + } + + /// `true` if every PoW slot's difficulty fits within `security.pow_budget`. + pub fn check_pow_bits(&self) -> bool { + self.validate_pow_budget().is_ok() + } + + /// Returns `true` if every post-construction invariant holds. + pub fn check_all_invariants(&self) -> bool { + self.validate().is_ok() + } + + /// Run every post-construction invariant check. + pub fn validate(&self) -> Result<(), DeriveError> { + self.validate_pow_budget()?; + self.validate_round_chaining()?; + self.validate_security_target_met()?; + Ok(()) + } + + /// For each PoW slot: verify (a) the analytic-bits floor recorded at + /// solve time still matches a fresh recompute from the config's current + /// state, and (b) `recorded_analytic + pow.difficulty() ≥ target_security_bits`. + /// + /// `grind_to_at` guarantees (b) at solve time. If (a) holds, (b) holds + /// trivially. If (a) drifts, (b) may fail — most often because a planner + /// regression overwrote an IRS field after the solver consumed it. + /// + /// `EPS` matches the `assert_pow_closes_gap` slack used by the per-slot + /// proptest helper, so validation stays consistent with test-time + /// assertions. + pub fn validate_security_target_met(&self) -> Result<(), DeriveError> { + const EPS: f64 = 1e-3; + let target = Bits::new(f64::from(self.security.target_security_bits)); + let check = |pow_kind: Pow, + recorded: Option, + recompute: Bits, + pow_cfg: &PowConfig| + -> Result<(), DeriveError> { + if let Some(recorded) = recorded { + if (f64::from(recorded) - f64::from(recompute)).abs() > EPS { + return Err(DeriveError::AnalyticDrift { + pow: pow_kind, + recorded, + recompute, + }); + } + } + let analytic = recorded.unwrap_or(recompute); + let pow_bits = pow_cfg.difficulty(); + let sum = f64::from(analytic) + f64::from(pow_bits); + if sum + EPS < f64::from(target) { + return Err(DeriveError::SecurityTargetNotMet { + pow: pow_kind, + analytic, + pow_bits, + target, + }); + } + Ok(()) + }; + for r in &self.rounds { + let mask_info = r.mask_oracle_info(); + check( + Pow::RoundSumcheck { + index: r.round_index, + }, + r.sumcheck.recorded_analytic, + sumcheck_params::analytic_error_bits(&r.code_switch.source, mask_info), + &r.sumcheck.round_pow, + )?; + check( + Pow::RoundCodeSwitch { + index: r.round_index, + }, + r.code_switch.recorded_analytic, + code_switch_params::analytic_error_bits( + &r.code_switch.source, + &r.code_switch.target, + r.code_switch.out_domain_samples, + mask_info, + ), + &r.code_switch.pow, + )?; + if let Some(mo) = r.mask_oracle() { + check( + Pow::RoundMaskProximity { + index: r.round_index, + }, + mo.mask_proximity.recorded_analytic, + mask_proximity_params::analytic_error_bits( + &mo.mask_proximity.c_zk_commit, + mo.mask_proximity.num_masks, + ), + &mo.mask_proximity.pow, + )?; + } + } + check( + Pow::BasecaseSumcheck, + self.basecase.sumcheck.recorded_analytic, + sumcheck_params::analytic_error_bits(&self.basecase.commit, None), + &self.basecase.sumcheck.round_pow, + )?; + if self.basecase.is_zk() { + check( + Pow::BasecaseGammaCombination, + self.basecase.recorded_analytic, + basecase_params::analytic_error_bits(&self.basecase.commit), + &self.basecase.pow, + )?; + } + Ok(()) + } + + /// PoW slot difficulty ≤ `security.pow_budget` for every slot. + pub fn validate_pow_budget(&self) -> Result<(), DeriveError> { + let max = Bits::new(f64::from(self.security.pow_budget.bits())); + let check = |pow: Pow, cfg: &PowConfig| -> Result<(), DeriveError> { + let required = cfg.difficulty(); + if required > max { + Err(DeriveError::PowBudgetExceeded { pow, required, max }) + } else { + Ok(()) + } + }; + for r in &self.rounds { + check( + Pow::RoundSumcheck { + index: r.round_index, + }, + &r.sumcheck.round_pow, + )?; + check( + Pow::RoundCodeSwitch { + index: r.round_index, + }, + &r.code_switch.pow, + )?; + if let Some(mo) = r.mask_oracle() { + check( + Pow::RoundMaskProximity { + index: r.round_index, + }, + &mo.mask_proximity.pow, + )?; + } + } + check(Pow::BasecaseSumcheck, &self.basecase.sumcheck.round_pow)?; + check(Pow::BasecaseGammaCombination, &self.basecase.pow)?; + Ok(()) + } + + /// Cross-round shape chaining: + /// - adjacent rounds: `round[i+1].source.vector_size == round[i].target.vector_size` + /// - last round → basecase: `basecase.commit.vector_size == last.target.vector_size` + /// - no rounds: `basecase.commit.vector_size == tuning.vector_size` + pub fn validate_round_chaining(&self) -> Result<(), DeriveError> { + for window in self.rounds.windows(2) { + let prev = &window[0]; + let next = &window[1]; + let expected = prev.code_switch.target.vector_size; + let found = next.code_switch.source.vector_size; + if expected != found { + return Err(DeriveError::RoundChainBroken { + from: ChainSource::Round(prev.round_index), + to: ChainTarget::NextRound(next.round_index), + expected, + found, + }); + } + } + + let basecase_vector_size = self.basecase.commit.vector_size; + let expected = self.rounds.last().map_or(self.tuning.vector_size, |last| { + last.code_switch.target.vector_size + }); + if expected != basecase_vector_size { + let from = self + .rounds + .last() + .map_or(ChainSource::Tuning, |r| ChainSource::Round(r.round_index)); + return Err(DeriveError::RoundChainBroken { + from, + to: ChainTarget::Basecase, + expected, + found: basecase_vector_size, + }); + } + + Ok(()) + } + + /// HVZK privacy error in bits, summed across ZK rounds: + /// `−log Σ_r (t_ood_r² + t_ood_r) / (2|F|)` (bounds doc, §5.3 + §5.7). + pub fn privacy_error_bits(&self) -> Bits { + let field_bits = ::field_size_bits(); + let mut total_error = 0.0_f64; + for r in &self.rounds { + if let RoundMode::ZeroKnowledge { t_ood, .. } = &r.mode { + let t = usize_to_f64(t_ood.get()); + let log_err = f64::midpoint(t * t, t).log2() - field_bits; + total_error += 2_f64.powf(log_err); + } + } + if total_error == 0.0 { + return Bits::new(f64::from(self.security.target_security_bits)); + } + Bits::new((-total_error.log2()).max(0.0)) + } +} + +impl ProtocolConfig { + /// Analytic soundness bits (excluding PoW). + pub fn analytic_bits(&self) -> Bits { + let mut min_bits = f64::from(self.basecase.analytic_bits()); + for round in &self.rounds { + min_bits = min_bits.min(f64::from(round.analytic_bits())); + } + Bits::new(min_bits.max(0.0)) + } +} + +#[cfg(test)] +impl ProtocolConfig { + pub(crate) const fn override_basecase_pow_for_test(&mut self, pow: PowConfig) { + self.basecase.pow = pow; + } + + pub(crate) fn truncate_rounds_for_test(&mut self, len: usize) { + self.rounds.truncate(len); + } + + pub(crate) fn corrupt_round_target_vector_size_for_test( + &mut self, + round_idx: usize, + new_size: usize, + ) { + self.rounds[round_idx].code_switch.target.vector_size = new_size; + } + + pub(crate) fn corrupt_round_sumcheck_recorded_analytic_for_test( + &mut self, + round_idx: usize, + new_value: Bits, + ) { + self.rounds[round_idx].sumcheck.recorded_analytic = Some(new_value); + } +} + +#[derive(Clone, Debug)] +pub struct RoundConfig { + round_index: usize, + sumcheck: SumcheckConfig, + code_switch: CodeSwitchConfig, + mode: RoundMode, + /// `Some` iff `mode.is_zk()`. Sized for this round's `k + 1` masks. + mask_oracle: Option>, +} + +impl RoundConfig { + pub(crate) const fn new( + round_index: usize, + sumcheck: SumcheckConfig, + code_switch: CodeSwitchConfig, + mode: RoundMode, + mask_oracle: Option>, + ) -> Self { + Self { + round_index, + sumcheck, + code_switch, + mode, + mask_oracle, + } + } + + pub const fn round_index(&self) -> usize { + self.round_index + } + + pub const fn sumcheck(&self) -> &SumcheckConfig { + &self.sumcheck + } + + pub const fn code_switch(&self) -> &CodeSwitchConfig { + &self.code_switch + } + + pub const fn mode(&self) -> &RoundMode { + &self.mode + } + + /// Borrow the round's mask oracle if this is a ZK round. + pub const fn mask_oracle(&self) -> Option<&MaskOracleConfig> { + self.mask_oracle.as_ref() + } + + /// Slim mask-oracle view derived from `mask_oracle()`. + pub fn mask_oracle_info(&self) -> Option { + self.mask_oracle().map(MaskOracleConfig::info) + } +} + +/// Standard vs. ZK round. +/// +/// Non-generic — the per-round `MaskOracleConfig` lives on +/// [`RoundConfig`] as a sibling field. +#[derive(Clone, Copy, Debug)] +pub enum RoundMode { + Standard, + ZeroKnowledge { + /// Lemma 9.9 OOD-sample budget (bounds doc §5.2). + t_ood: OodSampleBudget, + }, +} + +impl RoundMode { + pub const fn is_zk(&self) -> bool { + matches!(self, Self::ZeroKnowledge { .. }) + } +} + +impl RoundConfig { + /// Round-level analytic floor: the smallest of `sumcheck`, `code_switch`, + /// and (when present) the per-round mask-oracle proximity check. + pub fn analytic_bits(&self) -> Bits { + let source = &self.code_switch.source; + let target = &self.code_switch.target; + let mask_info = self.mask_oracle_info(); + + let sumcheck_term = f64::from(sumcheck_params::analytic_error_bits(source, mask_info)); + let code_switch_term = f64::from(code_switch_params::analytic_error_bits( + source, + target, + self.code_switch.out_domain_samples, + mask_info, + )); + let mask_oracle_term = self + .mask_oracle() + .map_or(f64::INFINITY, |mo| f64::from(mo.analytic_bits())); + + Bits::new( + sumcheck_term + .min(code_switch_term) + .min(mask_oracle_term) + .max(0.0), + ) + } +} + +/// One round's mask oracle: a C_zk codeword + ℓ_zk + mask-proximity check. +#[derive(Clone, Debug)] +pub struct MaskOracleConfig { + c_zk: IrsConfig>, + /// `next_pow2(r + t_ood)` (Theorem 9.6 + Lemma 9.3). + l_zk: MaskCodeMessageLen, + mask_proximity: MaskProximityConfig, +} + +impl MaskOracleConfig { + pub(crate) const fn new( + c_zk: IrsConfig>, + l_zk: MaskCodeMessageLen, + mask_proximity: MaskProximityConfig, + ) -> Self { + Self { + c_zk, + l_zk, + mask_proximity, + } + } + + pub const fn c_zk(&self) -> &IrsConfig> { + &self.c_zk + } + + pub const fn l_zk(&self) -> MaskCodeMessageLen { + self.l_zk + } + + pub const fn mask_proximity(&self) -> &MaskProximityConfig { + &self.mask_proximity + } + + pub fn info(&self) -> MaskOracleInfo { + MaskOracleInfo { + c_zk_list_size: ListSize::new(self.c_zk.list_size()), + l_zk: self.l_zk, + } + } +} + +/// Slim mask-oracle view (C_zk's list size + ℓ_zk). +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct MaskOracleInfo { + pub c_zk_list_size: ListSize, + pub l_zk: MaskCodeMessageLen, +} + +impl MaskOracleConfig { + /// Analytic soundness bits (excluding PoW) for this round's mask oracle. + pub fn analytic_bits(&self) -> Bits { + self.mask_proximity.analytic_bits() + } +} diff --git a/src/protocols/params/regime.rs b/src/protocols/params/regime.rs new file mode 100644 index 00000000..478a71bd --- /dev/null +++ b/src/protocols/params/regime.rs @@ -0,0 +1,358 @@ +//! Reed–Solomon decoding regime — materialized per-round parameters and the +//! analytic helpers that depend on them. +//! +//! # References +//! +//! - Johnson proximity-gap error follows the BCSS25 improvement +//! (`O(n/η^5)`, m=10 at canonical slack) over BCIKS '20. +//! - Capacity bound follows STIR Conjecture 5.6: `(1 − ρ − η, d/(ρ·η))`-list +//! decodability for RS codes. + +use std::f64::consts::LOG2_10; + +use ordered_float::OrderedFloat; +use serde::{Deserialize, Serialize}; + +use crate::protocols::params::{ + bounds::{rate, usize_to_f64}, + spec::DecodingRegime, +}; + +/// Materialized decoding-regime parameters at a known rate. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub enum DecodingRegimeParams { + Unique, + Johnson { slack: OrderedFloat }, + Capacity { slack: OrderedFloat }, +} + +impl DecodingRegimeParams { + /// Materialize spec policy at a known rate. Canonical slacks: `√ρ/20` for + /// Johnson, `ρ/20` for Capacity. + // TODO: Optimize picking η. + pub fn from_policy(policy: DecodingRegime, rate: f64) -> Self { + match policy { + DecodingRegime::Unique => Self::Unique, + DecodingRegime::Johnson => Self::johnson_canonical(rate), + DecodingRegime::Capacity => Self::capacity_canonical(rate), + } + } + + /// Johnson regime with the canonical `η = √ρ / 20` slack. + pub fn johnson_canonical(rate: f64) -> Self { + Self::Johnson { + slack: OrderedFloat(rate.sqrt() / 20.0), + } + } + + /// Capacity regime with the canonical `η = ρ / 20` slack. + pub fn capacity_canonical(rate: f64) -> Self { + Self::Capacity { + slack: OrderedFloat(rate / 20.0), + } + } + + pub const fn is_unique(self) -> bool { + matches!(self, Self::Unique) + } + + /// `log₂ |Λ(C, δ)|`. + pub fn list_size_log2(self, log_degree: f64, log_inv_rate: f64) -> f64 { + match self { + Self::Unique => 0.0, + // Johnson: |Λ| = 1 / (2 η √ρ). + Self::Johnson { slack } => -1.0 - slack.into_inner().log2() + 0.5 * log_inv_rate, + // Capacity (STIR Conj 5.6): |Λ| = d / (ρ · η). + Self::Capacity { slack } => log_degree + log_inv_rate - slack.into_inner().log2(), + } + } + + /// `|Λ(C, δ)|`. + pub fn list_size(self, log_degree: f64, log_inv_rate: f64) -> f64 { + 2_f64.powf(self.list_size_log2(log_degree, log_inv_rate)) + } + + /// `log₂(1 − δ)`. + pub fn one_minus_distance_log2(self, log_inv_rate: f64) -> f64 { + let one_minus_delta = match self { + Self::Unique => f64::midpoint(1.0, rate(log_inv_rate)), + Self::Johnson { slack } => rate(log_inv_rate).sqrt() + slack.into_inner(), + Self::Capacity { slack } => rate(log_inv_rate) + slack.into_inner(), + }; + one_minus_delta.log2() + } + + /// Bits of security delivered by `ood_samples` OOD challenges (STIR Lemma 4.5): + /// `ood · (|F| − log d) − 2·log|Λ| + 1`. Returns `0` under `Unique`. + pub fn ood_security_bits( + self, + log_degree: f64, + log_inv_rate: f64, + field_bits: f64, + ood_samples: usize, + ) -> f64 { + if self.is_unique() { + return 0.0; + } + let log_list = self.list_size_log2(log_degree, log_inv_rate); + let ood = usize_to_f64(ood_samples); + ood * (field_bits - log_degree) - 2.0 * log_list + 1.0 + } + + /// `log₂ ε_mca(C, δ)` for the per-step proximity-gaps error. + /// + /// - Unique: `(k − 1) / |F|`, log = `log k − |F|` (with `+ log ρ⁻¹`). + /// - Johnson: BCSS25 Theorem 1.5 at canonical `η = √ρ/20`, `m = 10`: + /// `ε ≈ (2·10.5⁵/3) · n · ρ^{−3/2} / |F|`. + /// - Capacity: STIR Conj 5.6, `ε ≈ d / (η · ρ²) / |F|`. + pub fn eps_mca_log2(self, log_inv_rate: f64, message_length: usize, field_bits: f64) -> f64 { + let log_k = usize_to_f64(message_length).log2(); + let error = match self { + Self::Unique => log_k + log_inv_rate, + Self::Johnson { slack } => { + debug_assert!( + slack.into_inner().log2() >= -(0.5 * log_inv_rate + LOG2_10 + 1.0) - 1e-6 + ); + // BCSS25 with m = 10: log_2(2·10.5⁵/3) + log n + 1.5·log ρ⁻¹. + let bcss25_const = (2.0 * 10.5_f64.powi(5) / 3.0).log2(); + bcss25_const + log_k + 2.5 * log_inv_rate + } + Self::Capacity { slack } => { + debug_assert!(slack.into_inner().log2() >= -(log_inv_rate + LOG2_10 + 1.0) - 1e-6); + // d / (η · ρ²) at canonical η = ρ/20: log d + log 20 + 3·log ρ⁻¹. + log_k + 3.0 * log_inv_rate + LOG2_10 + 1.0 + } + }; + error - field_bits + } +} + +impl DecodingRegime { + /// `|Λ|` at canonical slack, before an IRS config exists. + pub fn list_size_estimate(self, log_degree: f64, log_inv_rate: f64) -> f64 { + DecodingRegimeParams::from_policy(self, rate(log_inv_rate)) + .list_size(log_degree, log_inv_rate) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::protocols::params::test_utils::assert_close; + + const TIGHT_EPS: f64 = 1e-12; + + fn johnson(slack: f64) -> DecodingRegimeParams { + DecodingRegimeParams::Johnson { + slack: OrderedFloat(slack), + } + } + + fn capacity(slack: f64) -> DecodingRegimeParams { + DecodingRegimeParams::Capacity { + slack: OrderedFloat(slack), + } + } + + /// Johnson list size: `|Λ| = 1 / (2η√ρ)`, log₂ form. + #[test] + fn list_size_log2_johnson_formula() { + let got = johnson(0.1).list_size_log2(/* log_degree */ 4.0, 2.0); + let expected = -1.0 - 0.1_f64.log2() + 0.5 * 2.0; + assert_close(got, expected); + } + + /// Capacity list size: `|Λ| = d / (ρ · η)`, log₂ form. + #[test] + fn list_size_log2_capacity_formula() { + let got = capacity(0.125).list_size_log2(4.0, 2.0); + let expected = 4.0 + 2.0 - 0.125_f64.log2(); + assert_close(got, expected); + } + + /// Unique-decoding regime gives `|Λ| = 1`, i.e. log = 0. + #[test] + fn list_size_log2_unique_decoding_is_zero() { + assert_close(DecodingRegimeParams::Unique.list_size_log2(4.0, 2.0), 0.0); + } + + /// `η = √ρ/20` ⇒ `|Λ| = 10/ρ` ⇒ `list_size_estimate(_, b) = 10 · 2^b`. + #[test] + fn johnson_list_size_closed_form() { + for b in [1.0, 2.0, 3.0, 5.0] { + let got = DecodingRegime::Johnson.list_size_estimate(/* log_degree */ 4.0, b); + let expected = 10.0 * 2_f64.powf(b); + assert!( + (got - expected).abs() / expected < TIGHT_EPS, + "log_inv_rate={b}: got {got} vs {expected}", + ); + } + } + + /// `η = ρ/20` ⇒ `|Λ| = 20 · d / ρ²`. + #[test] + fn capacity_list_size_closed_form() { + for (log_d, b) in [(4.0, 1.0), (6.0, 2.0), (8.0, 3.0)] { + let got = DecodingRegime::Capacity.list_size_estimate(log_d, b); + let expected = 20.0 * 2_f64.powf(log_d) * 2_f64.powf(2.0 * b); + assert!( + (got - expected).abs() / expected < TIGHT_EPS, + "log_d={log_d}, log_inv_rate={b}: got {got} vs {expected}", + ); + } + } + + #[test] + fn johnson_list_size_matches_config_list_size() { + use crate::{ + algebra::{embedding::Identity, fields::Field64}, + hash, + protocols::irs_commit::{Config, IrsMode}, + }; + const PLACEHOLDER_SECURITY_TARGET_BITS: f64 = 80.0; + const PLACEHOLDER_NUM_VECTORS: usize = 2; + const PLACEHOLDER_VECTOR_SIZE: usize = 8; + const PLACEHOLDER_INTERLEAVING_DEPTH: usize = 1; + const LOG_INV_RATE: u32 = 2; + + let config: Config> = Config::new( + PLACEHOLDER_SECURITY_TARGET_BITS, + DecodingRegime::Johnson, + hash::BLAKE3, + PLACEHOLDER_NUM_VECTORS, + PLACEHOLDER_VECTOR_SIZE, + PLACEHOLDER_INTERLEAVING_DEPTH, + 2_f64.powf(-f64::from(LOG_INV_RATE)), + IrsMode::Standard, + ); + let log_degree = (config.masked_message_length() as f64).log2(); + let got = DecodingRegime::Johnson.list_size_estimate(log_degree, f64::from(LOG_INV_RATE)); + let expected = config.list_size(); + assert!( + (got - expected).abs() / expected < TIGHT_EPS, + "regime helper ({got}) vs Config::list_size ({expected})", + ); + } + + /// `1 − δ` in unique-decoding mode: midpoint of 1 and ρ. + #[test] + fn one_minus_distance_log2_unique() { + let log_inv_rate = 2.0; + let got = DecodingRegimeParams::Unique.one_minus_distance_log2(log_inv_rate); + let rho = 2_f64.powf(-log_inv_rate); + let expected = f64::midpoint(1.0, rho).log2(); + assert_close(got, expected); + } + + /// `1 − δ` in Johnson regime: `√ρ + η`. + #[test] + fn one_minus_distance_log2_johnson() { + let log_inv_rate = 2.0; + let eta = 0.1; + let got = johnson(eta).one_minus_distance_log2(log_inv_rate); + let rho = 2_f64.powf(-log_inv_rate); + let expected = (rho.sqrt() + eta).log2(); + assert_close(got, expected); + } + + /// `1 − δ` in Capacity regime: `ρ + η`. + #[test] + fn one_minus_distance_log2_capacity() { + let log_inv_rate = 2.0; + let eta = 0.05; + let got = capacity(eta).one_minus_distance_log2(log_inv_rate); + let rho = 2_f64.powf(-log_inv_rate); + let expected = (rho + eta).log2(); + assert_close(got, expected); + } + + /// `ood_security_bits = t · (|F| − log d) − 2·log|Λ| + 1`. Returns 0 + /// under Unique. + #[test] + fn ood_security_bits_formula() { + const LOG_DEGREE: f64 = 6.0; + const LOG_INV_RATE: f64 = 2.0; + const FIELD_BITS: f64 = 64.0; + const OOD: usize = 3; + + // Unique → 0 (no soundness from OOD). + let unique = DecodingRegimeParams::Unique.ood_security_bits( + LOG_DEGREE, + LOG_INV_RATE, + FIELD_BITS, + OOD, + ); + assert_close(unique, 0.0); + + let slack = 2_f64.powf(-LOG_INV_RATE).sqrt() / 20.0; + let got = johnson(slack).ood_security_bits(LOG_DEGREE, LOG_INV_RATE, FIELD_BITS, OOD); + let log_list = johnson(slack).list_size_log2(LOG_DEGREE, LOG_INV_RATE); + let expected = (OOD as f64) * (FIELD_BITS - LOG_DEGREE) - 2.0 * log_list + 1.0; + assert_close(got, expected); + } + + const MCA_MESSAGE_LENGTH: usize = 16; + const MCA_LOG_INV_RATE: f64 = 2.0; + const MCA_FIELD_BITS: f64 = 64.0; + + /// MCA error, unique-decoding branch: `log k + log_inv_rate − field_bits`. + #[test] + fn eps_mca_log2_unique_decoding_formula() { + let got = DecodingRegimeParams::Unique.eps_mca_log2( + MCA_LOG_INV_RATE, + MCA_MESSAGE_LENGTH, + MCA_FIELD_BITS, + ); + let expected = (MCA_MESSAGE_LENGTH as f64).log2() + MCA_LOG_INV_RATE - MCA_FIELD_BITS; + assert_close(got, expected); + } + + /// MCA error, Johnson (BCSS25): `log₂(2·10.5⁵/3) + log k + 2.5·log_inv_rate − field_bits`. + #[test] + fn eps_mca_log2_johnson_formula() { + let canonical_slack = 2_f64.powf(-MCA_LOG_INV_RATE).sqrt() / 20.0; + + let got = johnson(canonical_slack).eps_mca_log2( + MCA_LOG_INV_RATE, + MCA_MESSAGE_LENGTH, + MCA_FIELD_BITS, + ); + let bcss25_const = (2.0 * 10.5_f64.powi(5) / 3.0).log2(); + let expected = bcss25_const + (MCA_MESSAGE_LENGTH as f64).log2() + 2.5 * MCA_LOG_INV_RATE + - MCA_FIELD_BITS; + assert_close(got, expected); + } + + /// MCA error, Capacity (STIR Conj 5.6 at canonical η): + /// `log k + 3·log_inv_rate + log₂10 + 1 − field_bits`. + #[test] + fn eps_mca_log2_capacity_formula() { + let canonical_slack = 2_f64.powf(-MCA_LOG_INV_RATE) / 20.0; + + let got = capacity(canonical_slack).eps_mca_log2( + MCA_LOG_INV_RATE, + MCA_MESSAGE_LENGTH, + MCA_FIELD_BITS, + ); + let expected = (MCA_MESSAGE_LENGTH as f64).log2() + 3.0 * MCA_LOG_INV_RATE + LOG2_10 + 1.0 + - MCA_FIELD_BITS; + assert_close(got, expected); + } + + /// `from_policy` dispatches to the canonical constructor for each regime. + #[test] + fn from_policy_matches_canonical() { + assert_eq!( + DecodingRegimeParams::from_policy(DecodingRegime::Unique, 0.25), + DecodingRegimeParams::Unique, + ); + assert_eq!( + DecodingRegimeParams::from_policy(DecodingRegime::Johnson, 0.25), + DecodingRegimeParams::johnson_canonical(0.25), + ); + assert_eq!( + DecodingRegimeParams::from_policy(DecodingRegime::Capacity, 0.25), + DecodingRegimeParams::capacity_canonical(0.25), + ); + } +} diff --git a/src/protocols/params/spec.rs b/src/protocols/params/spec.rs new file mode 100644 index 00000000..42ff3104 --- /dev/null +++ b/src/protocols/params/spec.rs @@ -0,0 +1,315 @@ +use std::{ + fmt::{self, Display, Formatter}, + marker::PhantomData, + num::NonZeroU32, + ops::Deref, + str::FromStr, +}; + +use ordered_float::OrderedFloat; +use serde::{Deserialize, Serialize}; + +use crate::{bits::Bits, engines::EngineId}; + +/// Per-slot proof-of-work policy. +/// +/// `bits` plays two roles: +/// - **Planning credit**: subtracted from `target_security_bits` so solvers +/// know the analytic floor they must reach. +/// - **Validation cap**: rejects any per-slot PoW that exceeds `bits`. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum PowBudget { + Forbidden, + PerSlot { bits: NonZeroU32 }, +} + +impl PowBudget { + /// `Forbidden` when `bits == 0`, else `PerSlot { bits }`. + pub const fn per_slot(bits: u32) -> Self { + match NonZeroU32::new(bits) { + Some(bits) => Self::PerSlot { bits }, + None => Self::Forbidden, + } + } + + /// Bits of grinding allowed per slot. `0` for [`PowBudget::Forbidden`]. + pub const fn bits(self) -> u32 { + match self { + Self::Forbidden => 0, + Self::PerSlot { bits } => bits.get(), + } + } +} + +/// Phantom-typed newtype. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct Tagged(T, PhantomData); + +impl Tagged { + pub const fn new(v: T) -> Self { + Self(v, PhantomData) + } + + pub const fn get(self) -> T { + self.0 + } +} + +#[derive(Debug, Clone)] +pub struct SecuritySpec { + pub mode: Mode, + pub decoding_regime: DecodingRegime, + pub target_security_bits: u32, + pub pow_budget: PowBudget, + pub hash_id: EngineId, +} + +impl SecuritySpec { + pub fn protocol_security_target_bits(&self) -> Bits { + let pow = self.pow_budget.bits(); + Bits::new(f64::from(self.target_security_bits.saturating_sub(pow))) + } +} + +/// Per-round folding strategy. `at_round(i)` returns the factor for round `i`. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum FoldingFactor { + /// Same folding factor across all rounds. + Constant(usize), + /// `at_round(0) = initial`; `at_round(i) = rest` for `i ≥ 1`. + ConstantFromSecondRound { initial: usize, rest: usize }, +} + +impl FoldingFactor { + pub const fn at_round(&self, round: usize) -> usize { + match self { + Self::Constant(f) => *f, + Self::ConstantFromSecondRound { initial, rest } => { + if round == 0 { + *initial + } else { + *rest + } + } + } + } + + /// Smallest factor across rounds. + pub const fn min(&self) -> usize { + match self { + Self::Constant(f) => *f, + Self::ConstantFromSecondRound { initial, rest } => { + if *initial < *rest { + *initial + } else { + *rest + } + } + } + } +} + +/// Proof-size / prover-time / soundness-margin tradeoffs. +#[derive(Debug, Clone)] +pub struct TuningSpec { + pub vector_size: usize, + pub starting_log_inv_rate: u32, + pub folding_factor: FoldingFactor, +} + +/// Per-round context handed to a sub-protocol builder. +#[derive(Debug, Clone)] +pub struct RoundContext { + pub vector_size: usize, + pub log_inv_rate: u32, + pub folding_factor: u32, +} + +/// Standard vs. zero-knowledge selection. Orthogonal to [`DecodingRegime`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Mode { + Standard, + ZeroKnowledge, +} + +/// A `SecuritySpec` borrow proven to be in [`Mode::ZeroKnowledge`]. +#[derive(Debug, Clone, Copy)] +pub struct ZkSpec<'a>(&'a SecuritySpec); + +impl<'a> ZkSpec<'a> { + pub fn try_new(spec: &'a SecuritySpec) -> Option { + matches!(spec.mode, Mode::ZeroKnowledge).then_some(Self(spec)) + } + + pub const fn as_inner(self) -> &'a SecuritySpec { + self.0 + } +} + +impl Deref for ZkSpec<'_> { + type Target = SecuritySpec; + fn deref(&self) -> &SecuritySpec { + self.0 + } +} + +/// Reed–Solomon decoding regime selection. +/// +/// - `Unique`: `δ < (1 − ρ)/2`, list size 1, no conjectures. +/// - `Johnson`: `δ < 1 − √ρ − η`, canonical `η = √ρ/20`. Proximity-gap error +/// per the BCSS25 improvement to BCIKS '20. +/// - `Capacity`: `δ < 1 − ρ − η`, canonical `η = ρ/20`. Conjectured list size +/// `d/(ρ·η)` and proximity-gap error per STIR Conjecture 5.6. +/// +/// WHIR's rate stepping (each round bumps `log_inv_rate` by +/// `folding_factor − 1`) pushes ρ → 1, shrinking the unique-decoding +/// radius. At high security targets or deep folding, `Unique` may exceed +/// the grind cap on per-round PoW and [`super::derive::ProtocolConfig::derive`] +/// will return `PowUngrindable` — pick `Johnson` or `Capacity` for those. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub enum DecodingRegime { + Unique, + Johnson, + Capacity, +} + +impl Display for DecodingRegime { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::Unique => f.write_str("Unique"), + Self::Johnson => f.write_str("Johnson"), + Self::Capacity => f.write_str("Capacity"), + } + } +} + +impl FromStr for DecodingRegime { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "Unique" => Ok(Self::Unique), + "Johnson" => Ok(Self::Johnson), + "Capacity" => Ok(Self::Capacity), + _ => Err(format!( + "invalid decoding regime: {s}, options are: Unique, Johnson, Capacity" + )), + } + } +} + +#[cfg(test)] +mod decoding_regime_tests { + use super::*; + + #[test] + fn from_str_round_trips_display() { + for r in [ + DecodingRegime::Unique, + DecodingRegime::Johnson, + DecodingRegime::Capacity, + ] { + assert_eq!(r.to_string().parse::().unwrap(), r); + } + } + + #[test] + fn from_str_rejects_unknown() { + assert!("johnson".parse::().is_err()); // case-sensitive + assert!("".parse::().is_err()); + assert!("capacity".parse::().is_err()); + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum OodSampleBudgetTag {} +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum MaskCodeMessageLenTag {} +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum LogInvRateTag {} + +/// OOD-sample budget (Lemma 9.9 / bounds doc §5.2). +pub type OodSampleBudget = Tagged; + +impl Tagged { + /// Sentinel for "no OOD samples". + pub const ZERO: Self = Self::new(0); +} + +/// C_zk message length (Theorem 9.6: `ℓ_zk ≥ source mask length`). +pub type MaskCodeMessageLen = Tagged; + +/// `rate = 2^-log_inv_rate`. +pub type LogInvRate = Tagged; + +/// Reed–Solomon list-decoding ball size `|Λ(C, δ)|`. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct ListSize(OrderedFloat); + +impl ListSize { + pub const fn new(v: f64) -> Self { + Self(OrderedFloat(v)) + } + + pub const fn get(self) -> f64 { + self.0 .0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::hash; + + const TARGET_BITS: u32 = 100; + + fn spec(pow_budget: PowBudget) -> SecuritySpec { + SecuritySpec { + mode: Mode::ZeroKnowledge, + decoding_regime: DecodingRegime::Johnson, + target_security_bits: TARGET_BITS, + pow_budget, + hash_id: hash::BLAKE3, + } + } + + #[test] + fn forbidden_means_no_pow_credit() { + assert_eq!( + spec(PowBudget::Forbidden).protocol_security_target_bits(), + Bits::new(f64::from(TARGET_BITS)), + ); + } + + #[test] + fn per_slot_zero_collapses_to_forbidden() { + assert_eq!(PowBudget::per_slot(0), PowBudget::Forbidden); + } + + #[test] + fn per_slot_bits_round_trip() { + assert_eq!(PowBudget::per_slot(20).bits(), 20); + assert_eq!(PowBudget::Forbidden.bits(), 0); + } + + #[test] + fn pow_credit_shifts_analytic_floor() { + assert_eq!( + spec(PowBudget::per_slot(20)).protocol_security_target_bits(), + Bits::new(80.0), + ); + assert_eq!( + spec(PowBudget::per_slot(60)).protocol_security_target_bits(), + Bits::new(40.0), + ); + } + + #[test] + fn pow_exceeding_target_saturates_to_zero() { + let pow_over_target = TARGET_BITS + 100; + assert_eq!( + spec(PowBudget::per_slot(pow_over_target)).protocol_security_target_bits(), + Bits::new(0.0), + ); + } +} diff --git a/src/protocols/params/sumcheck.rs b/src/protocols/params/sumcheck.rs new file mode 100644 index 00000000..68b263bf --- /dev/null +++ b/src/protocols/params/sumcheck.rs @@ -0,0 +1,281 @@ +//! Sumcheck parameter selection. ZK mode adds a degree-2 mask per round +//! (Lemma 6.4, p.38). + +use crate::{ + algebra::{embedding::Embedding, fields::FieldWithSize}, + bits::Bits, + protocols::{ + irs_commit::Config as IrsConfig, + params::{ + bounds::usize_to_f64, + branch::SolveMode, + error::{grind_to_at, DeriveError, Pow}, + protocol_config::MaskOracleInfo, + spec::{RoundContext, SecuritySpec}, + }, + sumcheck::{self, Config as SumcheckConfig, SumcheckMaskLen}, + }, +}; + +/// Per-round sumcheck builder. +pub fn solve( + spec: &SecuritySpec, + ctx: &RoundContext, + source_irs: &IrsConfig, + mode: SolveMode, + pow: Pow, +) -> Result, DeriveError> { + let (mask_oracle, output_mode) = match mode { + SolveMode::Standard => (None, sumcheck::SumcheckMode::Standard), + SolveMode::ZeroKnowledge(mask_oracle) => ( + Some(mask_oracle), + sumcheck::SumcheckMode::ZeroKnowledge { + mask_length: zk_mask_length(), + }, + ), + }; + let analytic = analytic_error_bits(source_irs, mask_oracle); + let round_pow = grind_to_at(spec, analytic, pow)?; + Ok(SumcheckConfig::new( + ctx.vector_size, + round_pow, + num_sumcheck_rounds(ctx), + output_mode, + ) + .with_recorded_analytic(analytic)) +} + +/// Per-sumcheck-round soundness in bits: `min(ε_mca, poly_identity_term)`. +/// +/// - Standard (degree-2): `log|F| − log|Λ(C)| − 1`. +/// - ZK (Lemma 6.5, p.40): `log|F| − log|Λ(C)| − log|Λ(C_zk)| − log ℓ_zk`. +pub fn analytic_error_bits( + source_irs: &IrsConfig, + mask_oracle: Option, +) -> Bits { + let field_bits = M::Target::field_size_bits(); + let log_list_size = source_irs.list_size().log2(); + let prox_gaps = source_irs.rbr_soundness_fold_prox_gaps(); + + let poly_id = mask_oracle.map_or(field_bits - log_list_size - 1.0, |info| { + let log_list_size_c_zk = info.c_zk_list_size.get().log2(); + let log_l_zk = usize_to_f64(info.l_zk.get()).log2(); + field_bits - log_list_size - log_list_size_c_zk - log_l_zk + }); + + Bits::new(prox_gaps.min(poly_id).max(0.0)) +} + +/// Number of degree-2 round-polynomial masks sumcheck contributes to C_zk +/// per round (Lemma 6.4). +pub const fn masks_required(ctx: &RoundContext) -> usize { + num_sumcheck_rounds(ctx) +} + +const fn num_sumcheck_rounds(ctx: &RoundContext) -> usize { + ctx.folding_factor as usize +} + +/// Construction 6.3 step 4(a) sends `h_j ∈ F^{ SumcheckMaskLen { + SumcheckMaskLen::new(3) +} + +#[cfg(test)] +mod tests { + use proptest::prelude::*; + + use super::*; + use crate::protocols::params::{ + irs_commit as irs_params, + spec::{ListSize, MaskCodeMessageLen, Mode, OodSampleBudget}, + test_utils::{ + arb_round_ctx, arb_standard_spec, arb_zk_spec, assert_close, assert_pow_closes_gap, + build_minimal_mask_oracle, deterministic_spec, TestEmbedding, TestField, + TestNonIdentityEmbedding, EPS, TEST_TARGET_RANGE, + }, + }; + + const FIXTURE_C_ZK_LIST_SIZE: f64 = 4.0; + const FIXTURE_L_ZK: usize = 8; + + fn build_source_irs(spec: &SecuritySpec, ctx: &RoundContext) -> IrsConfig { + irs_params::solve(spec, ctx, OodSampleBudget::ZERO) + } + + const FIXTURE_LOG_VECTOR_SIZE: u32 = 4; + const FIXTURE_LOG_INV_RATE: u32 = 1; + const FIXTURE_FOLDING_FACTOR: u32 = 2; + + fn fixture_ctx() -> RoundContext { + RoundContext { + vector_size: 1 << FIXTURE_LOG_VECTOR_SIZE, + log_inv_rate: FIXTURE_LOG_INV_RATE, + folding_factor: FIXTURE_FOLDING_FACTOR, + } + } + + #[test] + fn zk_mode_has_three_mask_coefficients() { + let spec = deterministic_spec(Mode::ZeroKnowledge); + let ctx = fixture_ctx(); + let source_irs = build_source_irs(&spec, &ctx); + let mask_oracle = + build_minimal_mask_oracle(&spec).expect("ZK spec must produce a mask oracle"); + let config = solve( + &spec, + &ctx, + &source_irs, + SolveMode::ZeroKnowledge(mask_oracle), + Pow::RoundSumcheck { index: 0 }, + ) + .unwrap(); + match config.mode { + sumcheck::SumcheckMode::ZeroKnowledge { mask_length } => { + assert_eq!(mask_length.get(), 3); + } + sumcheck::SumcheckMode::Standard => panic!("expected ZK"), + } + } + + #[test] + fn analytic_error_standard_formula() { + let spec = deterministic_spec(Mode::Standard); + let ctx = fixture_ctx(); + let irs = build_source_irs(&spec, &ctx); + + let got = f64::from(analytic_error_bits::(&irs, None)); + + let field_bits = TestField::field_size_bits(); + let log_list = irs.list_size().log2(); + let prox = irs.rbr_soundness_fold_prox_gaps(); + let expected = prox.min(field_bits - log_list - 1.0).max(0.0); + + assert_close(got, expected); + } + + #[test] + fn analytic_error_zk_formula() { + let log_c_zk_list = FIXTURE_C_ZK_LIST_SIZE.log2(); + let log_l_zk = (FIXTURE_L_ZK as f64).log2(); + + let spec = deterministic_spec(Mode::ZeroKnowledge); + let ctx = fixture_ctx(); + let irs = build_source_irs(&spec, &ctx); + let info = MaskOracleInfo { + c_zk_list_size: ListSize::new(FIXTURE_C_ZK_LIST_SIZE), + l_zk: MaskCodeMessageLen::new(FIXTURE_L_ZK), + }; + + let got = f64::from(analytic_error_bits::(&irs, Some(info))); + + let field_bits = TestField::field_size_bits(); + let log_list = irs.list_size().log2(); + let prox = irs.rbr_soundness_fold_prox_gaps(); + let expected = prox + .min(field_bits - log_list - log_c_zk_list - log_l_zk) + .max(0.0); + + assert_close(got, expected); + } + + #[test] + fn analytic_error_clamps_to_zero() { + const OVERSIZED_LOG_C_ZK_LIST: i32 = 60; + const OVERSIZED_LOG_L_ZK: u32 = 30; + + let spec = deterministic_spec(Mode::ZeroKnowledge); + let ctx = fixture_ctx(); + let irs = build_source_irs(&spec, &ctx); + let huge = MaskOracleInfo { + c_zk_list_size: ListSize::new(2_f64.powi(OVERSIZED_LOG_C_ZK_LIST)), + l_zk: MaskCodeMessageLen::new(1 << OVERSIZED_LOG_L_ZK), + }; + let bits = f64::from(analytic_error_bits::(&irs, Some(huge))); + assert_close(bits, 0.0); + } + + proptest! { + #[test] + fn standard_mode_propagates( + spec in arb_standard_spec(TEST_TARGET_RANGE), + ctx in arb_round_ctx(), + ) { + let source_irs = build_source_irs(&spec, &ctx); + let pow = Pow::RoundSumcheck { index: 0 }; + let config = solve(&spec, &ctx, &source_irs, SolveMode::Standard, pow).unwrap(); + prop_assert!(matches!(config.mode, sumcheck::SumcheckMode::Standard)); + } + + #[test] + fn num_rounds_matches_folding_factor( + spec in prop_oneof![ + arb_standard_spec(TEST_TARGET_RANGE), + arb_zk_spec(TEST_TARGET_RANGE), + ], + ctx in arb_round_ctx(), + ) { + let source_irs = build_source_irs(&spec, &ctx); + let pow = Pow::RoundSumcheck { index: 0 }; + let mode = build_minimal_mask_oracle(&spec) + .map_or(SolveMode::Standard, SolveMode::ZeroKnowledge); + let config = solve(&spec, &ctx, &source_irs, mode, pow).unwrap(); + prop_assert_eq!(config.num_rounds, ctx.folding_factor as usize); + } + + #[test] + fn zk_error_le_standard_error( + spec in arb_zk_spec(TEST_TARGET_RANGE), + ctx in arb_round_ctx(), + ) { + let irs = build_source_irs(&spec, &ctx); + let mo = build_minimal_mask_oracle(&spec); + let zk = f64::from(analytic_error_bits::(&irs, mo)); + let standard = f64::from(analytic_error_bits::(&irs, None)); + prop_assert!(zk <= standard + EPS, "zk {} > standard {}", zk, standard); + } + + #[test] + fn round_pow_closes_gap_to_target( + spec in prop_oneof![ + arb_standard_spec(TEST_TARGET_RANGE), + arb_zk_spec(TEST_TARGET_RANGE), + ], + ctx in arb_round_ctx(), + ) { + let source_irs = build_source_irs(&spec, &ctx); + let mask_oracle = build_minimal_mask_oracle(&spec); + let error = analytic_error_bits(&source_irs, mask_oracle); + let pow = Pow::RoundSumcheck { index: 0 }; + let mode = mask_oracle.map_or(SolveMode::Standard, SolveMode::ZeroKnowledge); + let config = solve(&spec, &ctx, &source_irs, mode, pow).unwrap(); + assert_pow_closes_gap(&spec, error, &config.round_pow); + } + } + + #[test] + fn solve_works_with_basefield_embedding_zk() { + let spec = deterministic_spec(Mode::ZeroKnowledge); + let ctx = fixture_ctx(); + let source_irs: IrsConfig = + irs_params::solve(&spec, &ctx, OodSampleBudget::ZERO); + let info = MaskOracleInfo { + c_zk_list_size: ListSize::new(FIXTURE_C_ZK_LIST_SIZE), + l_zk: MaskCodeMessageLen::new(FIXTURE_L_ZK), + }; + let config = solve( + &spec, + &ctx, + &source_irs, + SolveMode::ZeroKnowledge(info), + Pow::RoundSumcheck { index: 0 }, + ) + .unwrap(); + assert!(matches!( + config.mode, + sumcheck::SumcheckMode::ZeroKnowledge { .. } + )); + } +} diff --git a/src/protocols/params/test_utils.rs b/src/protocols/params/test_utils.rs new file mode 100644 index 00000000..ae3a0604 --- /dev/null +++ b/src/protocols/params/test_utils.rs @@ -0,0 +1,174 @@ +//! Shared test fixtures. + +use std::ops::RangeInclusive; + +use proptest::prelude::*; + +use crate::{ + algebra::{ + embedding::{Basefield, Embedding, Identity}, + fields::{Field64, Field64_2}, + }, + bits::Bits, + hash, + protocols::{ + irs_commit::Config as IrsConfig, + mask_proximity::Config as MaskProximityConfig, + params::{ + branch::OodMode, + build_round::solve_t_ood, + irs_commit as irs_params, + protocol_config::MaskOracleInfo, + spec::{ + DecodingRegime, ListSize, LogInvRate, MaskCodeMessageLen, Mode, OodSampleBudget, + PowBudget, RoundContext, SecuritySpec, ZkSpec, + }, + }, + proof_of_work::Config as PowConfig, + }, +}; + +pub type TestField = Field64; +pub type TestEmbedding = Identity; +pub type TestExtensionField = Field64_2; +/// `Source = Field64, Target = Field64_2`. +pub type TestNonIdentityEmbedding = Basefield; + +pub const TEST_TARGET_RANGE: RangeInclusive = 30..=50; + +pub const FIXTURE_TARGET_BITS: u32 = 80; + +pub const EPS: f64 = 1e-9; + +pub const FIXTURE_POW_BUDGET_BITS: u32 = 60; + +pub fn deterministic_spec(mode: Mode) -> SecuritySpec { + SecuritySpec { + mode, + decoding_regime: DecodingRegime::Johnson, + target_security_bits: FIXTURE_TARGET_BITS, + pow_budget: PowBudget::per_slot(FIXTURE_POW_BUDGET_BITS), + hash_id: hash::BLAKE3, + } +} + +fn arb_decoding_regime() -> impl Strategy { + prop_oneof![ + Just(DecodingRegime::Johnson), + Just(DecodingRegime::Unique), + Just(DecodingRegime::Capacity), + ] +} + +pub fn arb_spec( + mode: Mode, + target_range: RangeInclusive, +) -> impl Strategy { + (target_range, arb_decoding_regime()).prop_map(move |(target, decoding_regime)| SecuritySpec { + mode, + decoding_regime, + target_security_bits: target, + pow_budget: PowBudget::per_slot(FIXTURE_POW_BUDGET_BITS), + hash_id: hash::BLAKE3, + }) +} + +pub fn arb_zk_spec(target_range: RangeInclusive) -> impl Strategy { + arb_spec(Mode::ZeroKnowledge, target_range) +} + +pub fn arb_standard_spec(target_range: RangeInclusive) -> impl Strategy { + arb_spec(Mode::Standard, target_range) +} + +pub fn arb_round_ctx() -> impl Strategy { + (4u32..=8, 1u32..=4, 1u32..=3).prop_map(|(log_size, log_inv_rate, folding_factor)| { + RoundContext { + vector_size: 1usize << log_size, + log_inv_rate, + folding_factor, + } + }) +} + +/// `None` in Standard; `Some(ℓ_zk=2, c_zk rate 1/2)` in ZK. +pub fn build_minimal_mask_oracle(spec: &SecuritySpec) -> Option { + let zk_spec = ZkSpec::try_new(spec)?; + let l_zk = MaskCodeMessageLen::new(2); + let c_zk: IrsConfig = + irs_params::solve_mask_code(zk_spec, l_zk, 0, LogInvRate::new(1), 2); + Some(MaskOracleInfo { + c_zk_list_size: ListSize::new(c_zk.list_size()), + l_zk, + }) +} + +/// `analytic_error_bits + pow.difficulty() ≥ target_security_bits`. +pub fn assert_pow_closes_gap(spec: &SecuritySpec, analytic: Bits, pow: &PowConfig) { + let error = f64::from(analytic); + let pow_bits = f64::from(pow.difficulty()); + let target = f64::from(spec.target_security_bits); + assert!( + error + pow_bits >= target - 1e-3, + "error {error} + pow {pow_bits} < target {target}", + ); +} + +/// `|got − expected| < EPS`. +pub fn assert_close(got: f64, expected: f64) { + assert!( + (got - expected).abs() < EPS, + "got {got} vs expected {expected}", + ); +} + +/// C_zk fixture for `mask_proximity` tests. +pub fn build_test_c_zk( + spec: &SecuritySpec, + l_zk: usize, + log_inv_rate: u32, + num_masks: usize, +) -> IrsConfig { + let zk_spec = ZkSpec::try_new(spec).expect("build_test_c_zk requires a ZK spec"); + irs_params::solve_mask_code( + zk_spec, + MaskCodeMessageLen::new(l_zk), + 0, + LogInvRate::new(log_inv_rate), + MaskProximityConfig::::num_vectors_for(num_masks), + ) +} + +/// Builds a self-consistent `(source, target, t_ood)` triplet matching the +/// per-round shape that `code_switch::solve` expects. +pub fn build_round_io( + spec: &SecuritySpec, + log_inv_rate: u32, + folding_factor: u32, + num_vars: u32, + c_zk_log_inv_rate: Option, +) -> (IrsConfig, IrsConfig>, usize) { + let source_ctx = RoundContext { + vector_size: 1usize << num_vars, + log_inv_rate, + folding_factor, + }; + let target_log_inv_rate = log_inv_rate + folding_factor - 1; + let target_log_degree = f64::from(num_vars - folding_factor); + let target_list_size = spec + .decoding_regime + .list_size_estimate(target_log_degree, f64::from(target_log_inv_rate)); + let ood_mode = c_zk_log_inv_rate.map_or(OodMode::Standard, |rate| { + OodMode::ZeroKnowledge(f64::from(rate)) + }); + let (source, t_ood) = solve_t_ood::(spec, &source_ctx, target_list_size, ood_mode, 0) + .expect("solve_t_ood diverged in test fixture"); + + let target_ctx = RoundContext { + vector_size: source.message_length(), + log_inv_rate: target_log_inv_rate, + folding_factor, + }; + let target = irs_params::solve(spec, &target_ctx, OodSampleBudget::new(t_ood)); + (source, target, t_ood) +} diff --git a/src/protocols/proof_of_work.rs b/src/protocols/proof_of_work.rs index a2a8aabc..ccdc3a60 100644 --- a/src/protocols/proof_of_work.rs +++ b/src/protocols/proof_of_work.rs @@ -26,8 +26,20 @@ pub struct Config { pub threshold: u64, } +/// Failure modes for [`Config::grind_to`]. +#[derive(Debug, thiserror::Error, Clone, Copy, PartialEq, Eq)] +pub enum PowError { + /// `target − analytic_error` exceeds what a single grind slot can deliver + /// (`MAX_DIFFICULTY = 60` bits, matching the grinding engine's capacity). + #[error("required {required} bits exceeds the {max} grind cap")] + GapExceedsGrindCap { required: Bits, max: Bits }, +} + +/// Largest gap a single grind slot can close (in bits). +pub const MAX_DIFFICULTY: f64 = 60.0; + pub fn threshold(difficulty: Bits) -> u64 { - assert!((0.0..=60.0).contains(&difficulty.into())); + assert!((0.0..=MAX_DIFFICULTY).contains(&difficulty.into())); let threshold = (64.0 - f64::from(difficulty)).exp2().ceil(); #[allow(clippy::cast_sign_loss)] @@ -64,6 +76,34 @@ impl Config { difficulty(self.threshold) } + /// Build a PoW config whose difficulty closes the gap `target - analytic_error`, + /// clamped at zero. + /// + /// Used by parameter solvers: each PoW slot independently lifts its own + /// soundness up to `target`. The caller is responsible for ensuring + /// `analytic_error` is computed from the local protocol step (see e.g. + /// `params::sumcheck`). + /// + /// Returns [`PowError::GapExceedsGrindCap`] if the required difficulty + /// exceeds [`MAX_DIFFICULTY`] — the spec is too tight for any single slot. + pub fn grind_to( + target: Bits, + analytic_error: Bits, + hash_id: EngineId, + ) -> Result { + let gap = (f64::from(target) - f64::from(analytic_error)).max(0.0); + if gap > MAX_DIFFICULTY { + return Err(PowError::GapExceedsGrindCap { + required: Bits::new(gap), + max: Bits::new(MAX_DIFFICULTY), + }); + } + Ok(Self { + hash_id, + threshold: threshold(Bits::new(gap)), + }) + } + #[cfg_attr(feature = "tracing", instrument(skip_all, fields(engine)))] pub fn prove(&self, prover_state: &mut ProverState) where diff --git a/src/protocols/sumcheck.rs b/src/protocols/sumcheck.rs index 5018e22d..b6772baa 100644 --- a/src/protocols/sumcheck.rs +++ b/src/protocols/sumcheck.rs @@ -30,6 +30,36 @@ pub struct SumcheckOpening { pub mask_rlc: F, } +/// ZK sumcheck mask polynomial dimension. +/// +/// Validated at construction to be at least `MIN = 3` — the round polynomial +/// has 3 coefficients (degree-2), so the mask must have at least as many to +/// hide it. Lemma 6.4 itself only requires `ℓ_zk ≥ 2`; the `3` floor is a +/// WHIR design choice tied to the degree-2 round polynomial (see +/// `params::sumcheck::zk_mask_length`). +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct SumcheckMaskLen(usize); + +impl SumcheckMaskLen { + pub const MIN: usize = 3; + + pub const fn new(n: usize) -> Self { + assert!(n >= Self::MIN); + Self(n) + } + + pub const fn get(self) -> usize { + self.0 + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub enum SumcheckMode { + Standard, + ZeroKnowledge { mask_length: SumcheckMaskLen }, +} + +#[must_use] #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(bound = "")] pub struct Config @@ -40,10 +70,51 @@ where pub initial_size: usize, pub round_pow: proof_of_work::Config, pub num_rounds: usize, - pub mask_length: usize, + pub mode: SumcheckMode, + /// Analytic-error floor recorded by the params solver that produced this + /// config. `None` when the config wasn't built via `params::sumcheck::solve` + /// (legacy/test paths). Drift checks compare this against a recompute. + pub recorded_analytic: Option, } impl Config { + pub fn new( + initial_size: usize, + round_pow: proof_of_work::Config, + num_rounds: usize, + mode: SumcheckMode, + ) -> Self { + assert!(num_rounds == 0 || initial_size.next_power_of_two() >= 1 << num_rounds); + // `SumcheckMaskLen::new` already enforces the ≥ 3 floor at construction; + // here we only need the field-characteristic precondition from Lemma 6.4. + if matches!(mode, SumcheckMode::ZeroKnowledge { .. }) { + assert!( + !F::ONE.double().is_zero(), + "ZK sumcheck requires char(F) ≠ 2" + ); + } + Self { + field: Type::new(), + initial_size, + round_pow, + num_rounds, + mode, + recorded_analytic: None, + } + } + + pub const fn with_recorded_analytic(mut self, analytic: crate::bits::Bits) -> Self { + self.recorded_analytic = Some(analytic); + self + } + + const fn mask_length(&self) -> usize { + match &self.mode { + SumcheckMode::Standard => 0, + SumcheckMode::ZeroKnowledge { mask_length } => mask_length.get(), + } + } + pub fn final_size(&self) -> usize { assert!( self.num_rounds == 0 || self.initial_size.next_power_of_two() >= 1 << self.num_rounds @@ -84,29 +155,21 @@ impl Config { assert!( self.num_rounds == 0 || self.initial_size.next_power_of_two() >= 1 << self.num_rounds ); - // Mask must cover all 3 sumcheck polynomial coefficients (c0, c1, c2) - assert!(self.mask_length == 0 || self.mask_length >= 3); assert_eq!(a.len(), self.initial_size); assert_eq!(b.len(), self.initial_size); debug_assert_eq!(dot(a, b), *sum); - assert_eq!(masks.len(), self.num_rounds * self.mask_length); + assert_eq!(masks.len(), self.num_rounds * self.mask_length()); + let half = F::from(2).inverse().unwrap(); + let polynomial_len = self.mask_length().max(3); - let mut mask_sum = F::ZERO; - let mut mask_rlc = F::ONE; - if self.mask_length > 0 && self.num_rounds > 0 { - let sum_multiple = F::from(1 << self.num_rounds.saturating_sub(1)); - mask_sum = masks.chunks_exact(self.mask_length).map(eval_01).sum::() * sum_multiple; - prover_state.prover_message(&mask_sum); - mask_rlc = prover_state.verifier_message(); - } + let (mut mask_sum, mask_rlc) = self.maybe_send_initial_mask_sum(prover_state, masks); - // We do a staggered Sumcheck loop so we can merge the inner fold+compute loops. - let mut univariate = Vec::new(); + let mut univariate = Vec::with_capacity(polynomial_len); let mut round_challenges = Vec::with_capacity(self.num_rounds); let mut prev_round_challenge = None; for (round, mask) in - chunks_exact_or_empty(masks, self.mask_length, self.num_rounds).enumerate() + chunks_exact_or_empty(masks, self.mask_length(), self.num_rounds).enumerate() { // Fold and compute sumcheck polynomial in one pass. let (c0, c2) = if let Some(w) = prev_round_challenge { @@ -116,40 +179,33 @@ impl Config { }; let c1 = *sum - c0.double() - c2; - // Optionally mask with univariate - if mask.is_empty() { - prover_state.prover_messages(&[c0, c2]); - } else { - // Initialize to round masking univariate polynomial. - univariate.clear(); - let sum_multiple = F::from(1 << self.num_rounds.saturating_sub(round + 1)); - univariate.extend(mask.iter().map(|m| sum_multiple * *m)); - - // Add constant term from previous and future masks. - univariate[0] += (mask_sum - sum_multiple * eval_01(mask)) * half; - - // Add plain sumcheck polynomial - univariate[0] += mask_rlc * c0; - univariate[1] += mask_rlc * c1; - univariate[2] += mask_rlc * c2; - - prover_state.prover_message(&univariate[0]); - prover_state.prover_messages(&univariate[2..]); + // Build round polynomial. In Standard (`mask = []`, `mask_rlc = 1`, + // `mask_sum = 0`) this collapses to `[c0, c1, c2]`. + univariate.clear(); + univariate.resize(polynomial_len, F::ZERO); + let sum_multiple = F::from(1 << self.num_rounds.saturating_sub(round + 1)); + for (u, m) in univariate.iter_mut().zip(mask.iter()) { + *u = sum_multiple * *m; } + univariate[0] += (mask_sum - sum_multiple * eval_01(mask)) * half; + univariate[0] += mask_rlc * c0; + univariate[1] += mask_rlc * c1; + univariate[2] += mask_rlc * c2; + + prover_state.prover_message(&univariate[0]); + prover_state.prover_messages(&univariate[2..]); - // Receive the random evaluation point and update the sum + // Receive the random evaluation point and update the sum. self.round_pow.prove(prover_state); let r = prover_state.verifier_message::(); round_challenges.push(r); *sum = (c2 * r + c1) * r + c0; - if self.mask_length > 0 && self.num_rounds > 0 { - let masked_sum = univariate_evaluate(&univariate, r); - mask_sum = masked_sum - mask_rlc * *sum; - } + + mask_sum = univariate_evaluate(&univariate, r) - mask_rlc * *sum; prev_round_challenge = Some(r); } if let Some(w) = prev_round_challenge { - // Final fold of the inputs (no polynomial computation) + // Final fold of the inputs (no polynomial computation). fold(a, w); fold(b, w); } @@ -161,6 +217,35 @@ impl Config { } } + fn maybe_send_initial_mask_sum( + &self, + prover_state: &mut ProverState, + masks: &[F], + ) -> (F, F) + where + H: DuplexSpongeInterface, + R: CryptoRng + RngCore, + F: Codec<[H::U]>, + { + match &self.mode { + SumcheckMode::Standard => (F::ZERO, F::ONE), + SumcheckMode::ZeroKnowledge { mask_length } => { + if self.num_rounds == 0 { + return (F::ZERO, F::ONE); + } + let sum_multiple = F::from(1 << self.num_rounds.saturating_sub(1)); + let mask_sum = masks + .chunks_exact(mask_length.get()) + .map(eval_01) + .sum::() + * sum_multiple; + prover_state.prover_message(&mask_sum); + let mask_rlc = prover_state.verifier_message(); + (mask_sum, mask_rlc) + } + } + } + #[cfg_attr(feature = "tracing", instrument(skip_all))] pub fn verify( &self, @@ -176,17 +261,10 @@ impl Config { assert!( self.num_rounds == 0 || self.initial_size.next_power_of_two() >= 1 << self.num_rounds ); - // Mask must cover all 3 sumcheck polynomial coefficients (c0, c1, c2) - assert!(self.mask_length == 0 || self.mask_length >= 3); - - let mut mask_rlc = F::ONE; - if self.mask_length > 0 && self.num_rounds > 0 { - let mask_sum: F = verifier_state.prover_message()?; - mask_rlc = verifier_state.verifier_message(); - *sum = mask_sum + mask_rlc * *sum; - } - let mut univariate = vec![F::ZERO; self.mask_length.max(3)]; + let mask_rlc = self.maybe_receive_initial_mask_sum(verifier_state, sum)?; + + let mut univariate = vec![F::ZERO; self.mask_length().max(3)]; let mut round_challenges = Vec::with_capacity(self.num_rounds); for _ in 0..self.num_rounds { // Receive all but linear coefficient. @@ -195,17 +273,17 @@ impl Config { *c = verifier_state.prover_message()?; } - // Derive linear coefficient from relation `univariate(0) + univariate(1) = sum` + // Derive linear coefficient from relation `univariate(0) + univariate(1) = sum`. univariate[1] = *sum - univariate[0].double() - univariate[2..].iter().sum::(); - // Check proof of work (if any) + // Check proof of work (if any). self.round_pow.verify(verifier_state)?; - // Receive the random evaluation point + // Receive the random evaluation point. let round_challenge = verifier_state.verifier_message::(); round_challenges.push(round_challenge); - // Update the sum + // Update the sum. *sum = univariate_evaluate(&univariate, round_challenge); } Ok(SumcheckOpening { @@ -213,17 +291,46 @@ impl Config { mask_rlc, }) } + + fn maybe_receive_initial_mask_sum( + &self, + verifier_state: &mut VerifierState, + sum: &mut F, + ) -> VerificationResult + where + H: DuplexSpongeInterface, + F: Codec<[H::U]>, + { + match &self.mode { + SumcheckMode::Standard => Ok(F::ONE), + SumcheckMode::ZeroKnowledge { .. } => { + if self.num_rounds == 0 { + return Ok(F::ONE); + } + let mask_sum: F = verifier_state.prover_message()?; + let mask_rlc = verifier_state.verifier_message(); + *sum = mask_sum + mask_rlc * *sum; + Ok(mask_rlc) + } + } + } } impl fmt::Display for Config { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mode_str = match &self.mode { + SumcheckMode::Standard => "standard".to_string(), + SumcheckMode::ZeroKnowledge { mask_length } => { + format!("zk ℓ_zk={}", mask_length.get()) + } + }; write!( f, - "size {} rounds {} pow {:.2} ℓ_zk {}", + "size {} rounds {} pow {:.2} {}", self.initial_size, self.num_rounds, self.round_pow.difficulty(), - self.mask_length + mode_str, ) } } @@ -261,21 +368,22 @@ mod tests { Standard: Distribution, { pub fn arbitrary() -> impl Strategy { - let mask_length = prop_oneof![ - 3 => Just(0_usize), - 7 => 3_usize..20, + let mode_strategy = prop_oneof![ + 3 => Just(SumcheckMode::Standard), + 7 => (3_usize..20).prop_map(|n| SumcheckMode::ZeroKnowledge { + mask_length: SumcheckMaskLen::new(n), + }), ]; - (0_usize..(1 << 12), 0_usize..12, mask_length).prop_map( - |(initial_size, num_rounds, mask_length)| { + (0_usize..(1 << 12), 0_usize..12, mode_strategy).prop_map( + |(initial_size, num_rounds, mode)| { let num_rounds = num_rounds.min(initial_size.next_power_of_two().trailing_zeros() as usize); - Self { - field: Type::new(), + Self::new( initial_size, + proof_of_work::Config::none(), num_rounds, - round_pow: proof_of_work::Config::none(), - mask_length, - } + mode, + ) }, ) } @@ -296,7 +404,7 @@ mod tests { let initial_vector = random_vector(&mut rng, config.initial_size); let initial_covector = random_vector(&mut rng, config.initial_size); let initial_sum = dot(&initial_vector, &initial_covector); - let masks = random_vector(&mut rng, config.mask_length * config.num_rounds); + let masks = random_vector(&mut rng, config.mask_length() * config.num_rounds); // Prover let mut vector = initial_vector.clone(); @@ -323,7 +431,7 @@ mod tests { } let expected_mask_sum: F = - chunks_exact_or_empty(&masks, config.mask_length, config.num_rounds) + chunks_exact_or_empty(&masks, config.mask_length(), config.num_rounds) .zip(&point) .map(|(m, x)| univariate_evaluate(m, *x)) .sum(); @@ -345,8 +453,8 @@ mod tests { assert_eq!(verifier_sum, sum); verifier_state.check_eof().unwrap(); - // Non-ZK path: mask_rlc defaults to ONE (no combination randomness sampled) - if config.mask_length == 0 || config.num_rounds == 0 { + // Standard path: mask_rlc defaults to ONE (no combination randomness sampled). + if matches!(config.mode, SumcheckMode::Standard) || config.num_rounds == 0 { assert_eq!(mask_rlc, F::ONE); } } @@ -365,13 +473,14 @@ mod tests { fn test_single_round() { test_config( 0, - &Config:: { - field: Type::new(), - initial_size: 2, - round_pow: proof_of_work::Config::none(), - num_rounds: 1, - mask_length: 3, - }, + &Config::::new( + 2, + proof_of_work::Config::none(), + 1, + SumcheckMode::ZeroKnowledge { + mask_length: SumcheckMaskLen::new(3), + }, + ), ); } @@ -379,13 +488,14 @@ mod tests { fn test_two_rounds() { test_config( 0, - &Config:: { - field: Type::new(), - initial_size: 3, - round_pow: proof_of_work::Config::none(), - num_rounds: 2, - mask_length: 3, - }, + &Config::::new( + 3, + proof_of_work::Config::none(), + 2, + SumcheckMode::ZeroKnowledge { + mask_length: SumcheckMaskLen::new(3), + }, + ), ); } @@ -393,13 +503,14 @@ mod tests { fn test_three_rounds() { test_config( 0, - &Config:: { - field: Type::new(), - initial_size: 5, - round_pow: proof_of_work::Config::none(), - num_rounds: 3, - mask_length: 3, - }, + &Config::::new( + 5, + proof_of_work::Config::none(), + 3, + SumcheckMode::ZeroKnowledge { + mask_length: SumcheckMaskLen::new(3), + }, + ), ); } diff --git a/src/protocols/whir/config.rs b/src/protocols/whir/config.rs index 89892425..2a2c2f24 100644 --- a/src/protocols/whir/config.rs +++ b/src/protocols/whir/config.rs @@ -7,10 +7,26 @@ use crate::{ algebra::{embedding::Embedding, fields::FieldWithSize}, bits::Bits, parameters::ProtocolParameters, - protocols::{irs_commit, proof_of_work, sumcheck}, - type_info::Type, + protocols::{ + irs_commit::{self, num_ood_samples, IrsMode}, + proof_of_work, sumcheck, + }, }; +/// log2 round-by-round soundness of `t_ood` OOD samples against a code with +/// the given list size — formerly `irs_commit::Config::rbr_ood_sample`. +fn rbr_ood_sample( + list_size: f64, + log_field_size: f64, + vector_size: usize, + out_domain_samples: usize, +) -> f64 { + // [STIR] Lemma 4.5. + let l_choose_2 = list_size * (list_size - 1.) / 2.; + let log_per_sample = ((vector_size - 1) as f64).log2() - log_field_size; + -l_choose_2.log2() - out_domain_samples as f64 * log_per_sample +} + impl Config { #[allow(clippy::too_many_lines)] pub fn new(size: usize, whir_parameters: &ProtocolParameters) -> Self @@ -39,12 +55,20 @@ impl Config { #[allow(clippy::cast_possible_wrap)] let initial_committer = irs_commit::Config::new( protocol_security_level, - whir_parameters.unique_decoding, + whir_parameters.decoding_regime, whir_parameters.hash_id, whir_parameters.batch_size, size, 1 << whir_parameters.initial_folding_factor, 0.5_f64.powi(whir_parameters.starting_log_inv_rate as i32), + IrsMode::Standard, + ); + let initial_out_domain_samples = num_ood_samples( + whir_parameters.decoding_regime, + protocol_security_level, + field_size_bits, + initial_committer.list_size(), + size, ); // Initial sumcheck round pow bits. @@ -79,16 +103,24 @@ impl Config { #[allow(clippy::cast_possible_wrap)] let irs_committer = irs_commit::Config::new( protocol_security_level, - whir_parameters.unique_decoding, + whir_parameters.decoding_regime, whir_parameters.hash_id, 1, 1 << num_variables, 1 << whir_parameters.folding_factor, 0.5_f64.powi(next_rate as i32), + IrsMode::Standard, + ); + let round_out_domain_samples = num_ood_samples( + whir_parameters.decoding_regime, + protocol_security_level, + field_size_bits, + irs_committer.list_size(), + 1 << num_variables, ); let combination_error = { let log_list_size = irs_committer.list_size().log2(); - let count = irs_committer.out_domain_samples + in_domain_samples; + let count = round_out_domain_samples + in_domain_samples; let log_combination = (count as f64).log2(); field_size_bits - (log_combination + log_list_size + 1.) }; @@ -103,13 +135,13 @@ impl Config { let config = RoundConfig { irs_committer, - sumcheck: sumcheck::Config { - field: Type::new(), - initial_size: 1 << num_variables, - round_pow: pow(folding_pow_bits), - num_rounds: whir_parameters.folding_factor, - mask_length: 0, - }, + out_domain_samples: round_out_domain_samples, + sumcheck: sumcheck::Config::new( + 1 << num_variables, + pow(folding_pow_bits), + whir_parameters.folding_factor, + sumcheck::SumcheckMode::Standard, + ), pow: pow(pow_bits), }; @@ -131,22 +163,21 @@ impl Config { Self { initial_committer, - initial_sumcheck: sumcheck::Config { - field: Type::new(), - initial_size: size, - round_pow: pow(starting_folding_pow_bits), - num_rounds: whir_parameters.initial_folding_factor, - mask_length: 0, - }, + initial_out_domain_samples, + initial_sumcheck: sumcheck::Config::new( + size, + pow(starting_folding_pow_bits), + whir_parameters.initial_folding_factor, + sumcheck::SumcheckMode::Standard, + ), initial_skip_pow: pow(initial_skip_pow_bits), round_configs, - final_sumcheck: sumcheck::Config { - field: Type::new(), - initial_size: 1 << num_variables, - round_pow: pow(final_folding_pow_bits), - num_rounds: num_variables, - mask_length: 0, - }, + final_sumcheck: sumcheck::Config::new( + 1 << num_variables, + pow(final_folding_pow_bits), + num_variables, + sumcheck::SumcheckMode::Standard, + ), final_pow: pow(final_pow_bits), } } @@ -171,11 +202,15 @@ impl Config { security_level = security_level.min(field_size_bits - ((num_linear_forms - 1) as f64).log2()); } - let has_initial_constraints = - num_linear_forms > 0 || self.initial_committer.out_domain_samples > 0; + let has_initial_constraints = num_linear_forms > 0 || self.initial_out_domain_samples > 0; if !self.initial_committer.unique_decoding() { - security_level = security_level.min(self.initial_committer.rbr_ood_sample()); + security_level = security_level.min(rbr_ood_sample( + self.initial_committer.list_size(), + field_size_bits, + self.initial_committer.vector_size, + self.initial_out_domain_samples, + )); } // Initial sumcheck error (or the skipped version for LDT). @@ -200,13 +235,18 @@ impl Config { let new_unique_decoding = round.irs_committer.unique_decoding(); if !new_unique_decoding { - let ood_error = round.irs_committer.rbr_ood_sample(); + let ood_error = rbr_ood_sample( + round.irs_committer.list_size(), + field_size_bits, + round.irs_committer.vector_size, + round.out_domain_samples, + ); security_level = security_level.min(ood_error); } let log_list_size = round.irs_committer.list_size().log2(); let combination_error = { - let count = round.irs_committer.out_domain_samples + old_in_domain_samples; + let count = round.out_domain_samples + old_in_domain_samples; let log_combination = (count as f64).log2(); field_size_bits - (log_combination + log_list_size + 1.) }; @@ -347,7 +387,12 @@ impl Display for Config { writeln!( f, "{:.1} bits -- OOD commitment", - self.initial_committer.rbr_ood_sample() + rbr_ood_sample( + self.initial_committer.list_size(), + field_size_bits, + self.initial_committer.vector_size, + self.initial_out_domain_samples, + ) )?; } let prox_gaps_error = self.initial_committer.rbr_soundness_fold_prox_gaps(); @@ -372,13 +417,18 @@ impl Display for Config { writeln!( f, "{:.1} bits -- OOD sample", - r.irs_committer.rbr_ood_sample() + rbr_ood_sample( + r.irs_committer.list_size(), + field_size_bits, + r.irs_committer.vector_size, + r.out_domain_samples, + ) )?; } let log_list_size = r.irs_committer.list_size().log2(); let combination_error = { - let count = r.irs_committer.out_domain_samples + old_in_domain_samples; + let count = r.out_domain_samples + old_in_domain_samples; let log_combination = (count as f64).log2(); field_size_bits - (log_combination + log_list_size + 1.) }; @@ -462,8 +512,6 @@ impl Display for RoundConfig { #[cfg(test)] mod tests { - use ordered_float::OrderedFloat; - use super::*; use crate::{ algebra::{ @@ -472,7 +520,7 @@ mod tests { }, bits::Bits, hash, - protocols::matrix_commit, + protocols::{matrix_commit, params::regime::DecodingRegimeParams}, type_info::Typed, utils::test_serde, }; @@ -484,7 +532,7 @@ mod tests { pow_bits: 20, initial_folding_factor: 4, folding_factor: 4, - unique_decoding: false, + decoding_regime: crate::protocols::params::DecodingRegime::Johnson, starting_log_inv_rate: 1, batch_size: 1, hash_id: hash::BLAKE3, @@ -530,22 +578,21 @@ mod tests { embedding: Typed::new(embedding::Identity::new()), num_vectors: 1, vector_size: 1 << 10, - mask_length: 0, + mode: IrsMode::Standard, codeword_length: 1 << (10 + 3 - 2), interleaving_depth: 1 << 2, matrix_commit: matrix_commit::Config::::new(0, 0), - johnson_slack: OrderedFloat::default(), + regime: DecodingRegimeParams::Unique, in_domain_samples: 5, - out_domain_samples: 2, deduplicate_in_domain: true, }, - sumcheck: sumcheck::Config { - field: Type::::new(), - initial_size: 1 << 10, - round_pow: proof_of_work::Config::from_difficulty(Bits::new(19.0)), - num_rounds: 2, - mask_length: 0, - }, + out_domain_samples: 2, + sumcheck: sumcheck::Config::::new( + 1 << 10, + proof_of_work::Config::from_difficulty(Bits::new(19.0)), + 2, + sumcheck::SumcheckMode::Standard, + ), pow: proof_of_work::Config::from_difficulty(Bits::new(17.0)), }, RoundConfig { @@ -553,22 +600,21 @@ mod tests { embedding: Typed::new(embedding::Identity::new()), num_vectors: 1, vector_size: 1 << 10, - mask_length: 0, + mode: IrsMode::Standard, codeword_length: 1 << (10 + 4 - 2), interleaving_depth: 1 << 2, matrix_commit: matrix_commit::Config::::new(0, 0), - johnson_slack: OrderedFloat::default(), + regime: DecodingRegimeParams::Unique, in_domain_samples: 6, - out_domain_samples: 2, deduplicate_in_domain: true, }, - sumcheck: sumcheck::Config { - field: Type::::new(), - initial_size: 1 << 10, - round_pow: proof_of_work::Config::from_difficulty(Bits::new(19.5)), - num_rounds: 2, - mask_length: 0, - }, + out_domain_samples: 2, + sumcheck: sumcheck::Config::::new( + 1 << 10, + proof_of_work::Config::from_difficulty(Bits::new(19.5)), + 2, + sumcheck::SumcheckMode::Standard, + ), pow: proof_of_work::Config::from_difficulty(Bits::new(18.0)), }, ]; diff --git a/src/protocols/whir/mod.rs b/src/protocols/whir/mod.rs index 805a206e..3f338e92 100644 --- a/src/protocols/whir/mod.rs +++ b/src/protocols/whir/mod.rs @@ -30,6 +30,9 @@ use crate::{ #[serde(bound = "")] pub struct Config { pub initial_committer: irs_commit::Config, + /// OOD samples on the initial commit (Construction 9.7-style OOD step, + /// formerly inside `irs_commit::commit`). + pub initial_out_domain_samples: usize, pub initial_sumcheck: sumcheck::Config, pub initial_skip_pow: proof_of_work::Config, pub round_configs: Vec>, @@ -41,12 +44,42 @@ pub struct Config { #[serde(bound = "")] pub struct RoundConfig { pub irs_committer: irs_commit::Config>, + /// OOD samples for this round's commit. + pub out_domain_samples: usize, pub sumcheck: sumcheck::Config, pub pow: proof_of_work::Config, } -pub type Witness> = irs_commit::Witness; -pub type Commitment = irs_commit::Commitment; +/// WHIR-level witness: IRS witness + OOD evaluations sampled at commit time. +#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] +#[serde(bound( + serialize = "M::Source: Serialize, F: Serialize", + deserialize = "M::Source: Deserialize<'de>, F: Deserialize<'de>" +))] +pub struct Witness> { + pub irs: irs_commit::Witness, + pub out_of_domain: irs_commit::Evaluations, +} + +/// WHIR-level commitment: IRS commitment + OOD evaluations received at commit time. +#[derive(Clone, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] +#[serde(bound(serialize = "F: Serialize", deserialize = "F: Deserialize<'de>"))] +pub struct Commitment { + pub irs: irs_commit::Commitment, + pub out_of_domain: irs_commit::Evaluations, +} + +impl> Witness { + pub fn num_vectors(&self) -> usize { + self.out_of_domain.num_columns() + } +} + +impl Commitment { + pub fn num_vectors(&self) -> usize { + self.out_of_domain.num_columns() + } +} #[must_use = "The final claim must be checked if there where any linear forms."] #[derive(Debug, Clone, PartialEq, Eq, Default)] @@ -75,6 +108,10 @@ impl FinalClaim { impl Config { /// Commit to one or more vectors. + /// + /// After the IRS commit, runs the legacy WHIR OOD step: samples + /// `initial_out_domain_samples` random points from the verifier and sends + /// each vector's evaluation at each point. #[cfg_attr(feature = "tracing", instrument(skip_all, fields(size = vectors.first().unwrap().len())))] pub fn commit( &self, @@ -88,7 +125,12 @@ impl Config { M::Target: Codec<[H::U]>, Hash: ProverMessage<[H::U]>, { - self.initial_committer.commit(prover_state, vectors) + let (irs, out_of_domain) = self.initial_committer.commit_with_ood( + prover_state, + vectors, + self.initial_out_domain_samples, + ); + Witness { irs, out_of_domain } } /// Receive a commitment to vectors. @@ -101,7 +143,10 @@ impl Config { M::Target: Codec<[H::U]>, Hash: ProverMessage<[H::U]>, { - self.initial_committer.receive_commitment(verifier_state) + let (irs, out_of_domain) = self + .initial_committer + .receive_commitment_with_ood(verifier_state, self.initial_out_domain_samples)?; + Ok(Commitment { irs, out_of_domain }) } /// Disable proof-of-work for test. @@ -135,6 +180,7 @@ mod tests { }, hash, parameters::ProtocolParameters, + protocols::params::DecodingRegime, transcript::{codecs::Empty, DomainSeparator, ProverState, VerifierState}, utils::test_serde, }; @@ -178,7 +224,7 @@ mod tests { initial_folding_factor: usize, folding_factor: usize, num_points: usize, - unique_decoding: bool, + decoding_regime: DecodingRegime, pow_bits: usize, ) { // Number of coefficients in the multilinear polynomial (2^num_variables) @@ -190,7 +236,7 @@ mod tests { pow_bits, initial_folding_factor, folding_factor, - unique_decoding, + decoding_regime, starting_log_inv_rate: 1, batch_size: 1, hash_id: hash::SHA2, @@ -276,14 +322,14 @@ mod tests { let num_variables = folding_factor..=3 * folding_factor; for num_variable in num_variables { for num_points in [0, 1, 2] { - for unique_decoding in [true, false] { + for decoding_regime in [DecodingRegime::Unique, DecodingRegime::Johnson] { for pow_bits in [0, 5, 10] { eprintln!(); dbg!( folding_factor, num_variable, num_points, - unique_decoding, + decoding_regime, pow_bits ); @@ -292,7 +338,7 @@ mod tests { folding_factor, folding_factor, num_points, - unique_decoding, + decoding_regime, pow_bits, ); } @@ -304,7 +350,7 @@ mod tests { #[test] fn test_fail() { - make_whir_things(3, 2, 2, 0, false, 0); + make_whir_things(3, 2, 2, 0, DecodingRegime::Johnson, 0); } #[test] @@ -334,7 +380,7 @@ mod tests { initial_folding_factor, folding_factor, num_points, - false, + DecodingRegime::Johnson, 5, ); } @@ -353,7 +399,7 @@ mod tests { folding_factor: usize, num_points_per_poly: usize, num_vectors: usize, - unique_decoding: bool, + decoding_regime: DecodingRegime, pow_bits: usize, ) { let num_coeffs = 1 << num_variables; @@ -363,7 +409,7 @@ mod tests { pow_bits, initial_folding_factor, folding_factor, - unique_decoding, + decoding_regime, starting_log_inv_rate: 1, batch_size: 1, hash_id: hash::SHA2, @@ -484,7 +530,7 @@ mod tests { folding_factor, num_points_per_poly, num_polys, - false, + DecodingRegime::Johnson, 0, // pow_bits ); } @@ -503,7 +549,8 @@ mod tests { 2, // folding_factor 2, // num_points_per_poly 1, // num_polynomials (single!) - false, 0, + DecodingRegime::Johnson, + 0, ); } @@ -531,7 +578,7 @@ mod tests { pow_bits: 0, initial_folding_factor, folding_factor, - unique_decoding: false, + decoding_regime: DecodingRegime::Johnson, starting_log_inv_rate: 1, batch_size: 1, hash_id: hash::SHA2, @@ -627,7 +674,7 @@ mod tests { num_points_per_poly: usize, num_witnesses: usize, batch_size: usize, - unique_decoding: bool, + decoding_regime: DecodingRegime, pow_bits: usize, ) { let num_coeffs = 1 << num_variables; @@ -637,7 +684,7 @@ mod tests { pow_bits, initial_folding_factor, folding_factor, - unique_decoding, + decoding_regime, starting_log_inv_rate: 1, batch_size, // KEY: batch_size > 1 hash_id: hash::SHA2, @@ -743,7 +790,7 @@ mod tests { 1, // num_points_per_poly num_witness, batch_size, - false, + DecodingRegime::Johnson, 0, // pow_bits ); } @@ -758,7 +805,7 @@ mod tests { initial_folding_factor: usize, folding_factor: usize, num_points: usize, - unique_decoding: bool, + decoding_regime: DecodingRegime, pow_bits: usize, ) { eprintln!("\n---------------------"); @@ -768,7 +815,7 @@ mod tests { eprintln!(" initial_folding : {initial_folding_factor}"); eprintln!(" folding_factor : {folding_factor}"); eprintln!(" num_points : {num_points:?}"); - eprintln!(" unique_decoding : {unique_decoding:?}"); + eprintln!(" decoding_regime : {decoding_regime:?}"); eprintln!(" pow_bits : {pow_bits}"); // Number of coefficients in the multilinear polynomial (2^num_variables) @@ -780,7 +827,7 @@ mod tests { pow_bits, initial_folding_factor, folding_factor, - unique_decoding, + decoding_regime, starting_log_inv_rate: 1, batch_size, hash_id: hash::SHA2, @@ -865,7 +912,7 @@ mod tests { #[test] fn test_batched_whir() { let folding_factors = [1, 4]; - let unique_decoding_options = [false, true]; + let decoding_regime_options = [DecodingRegime::Johnson, DecodingRegime::Unique]; let num_points = [0, 2]; let pow_bits = [0, 10]; @@ -873,7 +920,7 @@ mod tests { let num_variables = (2 * folding_factor)..=3 * folding_factor; for num_variable in num_variables { for num_points in num_points { - for unique_decoding in unique_decoding_options { + for decoding_regime in decoding_regime_options { for pow_bits in pow_bits { for batch_size in 1..=4 { make_batched_whir_things( @@ -882,7 +929,7 @@ mod tests { folding_factor, folding_factor, num_points, - unique_decoding, + decoding_regime, pow_bits, ); } diff --git a/src/protocols/whir/prover.rs b/src/protocols/whir/prover.rs index 04b4473e..1a300cc0 100644 --- a/src/protocols/whir/prover.rs +++ b/src/protocols/whir/prover.rs @@ -26,8 +26,8 @@ use crate::{ }; enum RoundWitness<'a, F: Field, M: Embedding> { - Initial(Vec>>), - Round(irs_commit::Witness), + Initial(Vec>>), + Round(irs_commit::Witness), } impl Config { @@ -103,8 +103,8 @@ impl Config { let mut vector_offset = 0; for witness in &witnesses { for (oods_eval, oods_row) in zip_strict( - witness.out_of_domain().evaluators(self.initial_size()), - witness.out_of_domain().rows(), + witness.out_of_domain.evaluators(self.initial_size()), + witness.out_of_domain.rows(), ) { for (j, vector) in vectors.iter().enumerate() { if j >= vector_offset && j < oods_row.len() + vector_offset { @@ -219,8 +219,12 @@ impl Config { // Execute standard WHIR rounds on the batched vectors for (round_index, round_config) in self.round_configs.iter().enumerate() { - // Commit to the vector, this generates out-of-domain evaluations. - let new_witness = round_config.irs_committer.commit(prover_state, &[&vector]); + // Commit to the folded vector and run the per-round OOD step. + let (new_witness, out_of_domain) = round_config.irs_committer.commit_with_ood( + prover_state, + &[&vector], + round_config.out_domain_samples, + ); // Proof of work before in-domain challenges round_config.pow.prove(prover_state); @@ -228,9 +232,9 @@ impl Config { // Open the previous round's witness. let in_domain = match prev_witness { RoundWitness::Initial(init_witnesses) => { - let witness_refs: Vec<&_> = init_witnesses.iter().map(|c| &**c).collect(); + let irs_refs: Vec<&_> = init_witnesses.iter().map(|c| &c.irs).collect(); self.initial_committer - .open(prover_state, &witness_refs) + .open(prover_state, &irs_refs) .lift(self.embedding()) } RoundWitness::Round(old_witness) => { @@ -242,13 +246,11 @@ impl Config { }; // Collect constraints for this round and RLC them in - let stir_challenges = new_witness - .out_of_domain() + let stir_challenges = out_of_domain .evaluators(round_config.initial_size()) .chain(in_domain.evaluators(round_config.initial_size())) .collect::>(); - let stir_evaluations = new_witness - .out_of_domain() + let stir_evaluations = out_of_domain .values(&[M::Target::ONE]) .chain(in_domain.values(&tensor_product( &vector_rlc_coeffs, @@ -289,8 +291,8 @@ impl Config { // Open and consume the final previous witness. match prev_witness { RoundWitness::Initial(init_witnesses) => { - let witness_refs: Vec<&_> = init_witnesses.iter().map(|c| &**c).collect(); - let _in_domain = self.initial_committer.open(prover_state, &witness_refs); + let irs_refs: Vec<&_> = init_witnesses.iter().map(|c| &c.irs).collect(); + let _in_domain = self.initial_committer.open(prover_state, &irs_refs); } RoundWitness::Round(old_witness) => { let prev_config = self.round_configs.last().unwrap(); diff --git a/src/protocols/whir/verifier.rs b/src/protocols/whir/verifier.rs index 8421ceab..5f4e7b84 100644 --- a/src/protocols/whir/verifier.rs +++ b/src/protocols/whir/verifier.rs @@ -23,11 +23,11 @@ use crate::{ enum RoundCommitment<'a, F: Field> { Initial { - commitments: &'a [&'a irs_commit::Commitment], + commitments: &'a [&'a Commitment], batching_weights: Vec, }, Round { - commitment: irs_commit::Commitment, + commitment: irs_commit::Commitment, }, } @@ -63,6 +63,13 @@ impl Config { return Ok(FinalClaim::default()); } + let expected_matrix_len = + self.initial_out_domain_samples * self.initial_committer.num_vectors; + for commitment in commitments { + verify!(commitment.out_of_domain.points.len() == self.initial_out_domain_samples); + verify!(commitment.out_of_domain.matrix.len() == expected_matrix_len); + } + // Complete the constraint and evaluation matrix with OODs and their cross-terms. let (oods_evals, oods_matrix) = { let mut oods_evals = Vec::new(); @@ -72,8 +79,8 @@ impl Config { let mut vector_offset = 0; for commitment in commitments { for (weights, oods_row) in zip_strict( - commitment.out_of_domain().evaluators(self.initial_size()), - commitment.out_of_domain().rows(), + commitment.out_of_domain.evaluators(self.initial_size()), + commitment.out_of_domain.rows(), ) { for j in 0..num_vectors { if j >= vector_offset && j < oods_row.len() + vector_offset { @@ -133,10 +140,10 @@ impl Config { round_folding_randomness.push(folding_randomness); for (round_index, round_config) in self.round_configs.iter().enumerate() { - // Receive commitment to the folded vector, plus out-of-domain constraints - let commitment = round_config + // Receive commitment to the folded vector and the per-round OOD evaluations. + let (commitment, out_of_domain) = round_config .irs_committer - .receive_commitment(verifier_state)?; + .receive_commitment_with_ood(verifier_state, round_config.out_domain_samples)?; // Proof of work before in-domain challenges round_config.pow.verify(verifier_state)?; @@ -147,7 +154,8 @@ impl Config { commitments, batching_weights, } => { - let in_domain = self.initial_committer.verify(verifier_state, commitments)?; + let irs_refs: Vec<&_> = commitments.iter().map(|c| &c.irs).collect(); + let in_domain = self.initial_committer.verify(verifier_state, &irs_refs)?; // TODO: Skip lift and keep initial in-domain in subfield for evaluation. // This should be every so slightly more performant. (in_domain.lift(self.embedding()), batching_weights) @@ -162,13 +170,11 @@ impl Config { }; // Random linear combination of out- and in-domain constraints - let constraint_weights = commitment - .out_of_domain() + let constraint_weights = out_of_domain .evaluators(round_config.initial_size()) .chain(in_domain.evaluators(round_config.initial_size())) .collect::>(); - let constraint_values = commitment - .out_of_domain() + let constraint_values = out_of_domain .values(&[M::Target::ONE]) .chain(in_domain.values(&tensor_product( &poly_rlc, @@ -202,7 +208,8 @@ impl Config { commitments, batching_weights, } => { - let in_domain = self.initial_committer.verify(verifier_state, commitments)?; + let irs_refs: Vec<&_> = commitments.iter().map(|c| &c.irs).collect(); + let in_domain = self.initial_committer.verify(verifier_state, &irs_refs)?; (in_domain.lift(self.embedding()), batching_weights) } RoundCommitment::Round { commitment } => { diff --git a/src/protocols/whir_zk/committer.rs b/src/protocols/whir_zk/committer.rs index 942b87e1..44768014 100644 --- a/src/protocols/whir_zk/committer.rs +++ b/src/protocols/whir_zk/committer.rs @@ -5,8 +5,9 @@ use tracing::instrument; use super::{utils::BlindingPolynomials, Config}; use crate::{ + algebra::embedding::Identity, hash::Hash, - protocols::{irs_commit, whir}, + protocols::whir, transcript::{ Codec, DuplexSpongeInterface, ProverMessage, ProverState, VerificationResult, VerifierState, }, @@ -26,10 +27,10 @@ pub struct Commitment { #[derive(Clone, Debug)] pub struct Witness { pub f_hat_vectors: Vec>, - pub f_hat_witnesses: Vec>, + pub f_hat_witnesses: Vec>>, pub blinding_polynomials: Vec>, pub blinding_vectors: Vec>, - pub blinding_witness: irs_commit::Witness, + pub blinding_witness: whir::Witness>, } impl Config { diff --git a/src/protocols/whir_zk/mod.rs b/src/protocols/whir_zk/mod.rs index b5e41c6e..18dacc42 100644 --- a/src/protocols/whir_zk/mod.rs +++ b/src/protocols/whir_zk/mod.rs @@ -43,16 +43,18 @@ impl BlindingSizePolicy { .saturating_sub(main_whir_params.pow_bits); #[allow(clippy::cast_possible_wrap)] let q_delta_1 = irs_commit::num_in_domain_queries( - main_whir_params.unique_decoding, + main_whir_params.decoding_regime, protocol_security_level_main as f64, 0.5_f64.powi(main_whir_params.starting_log_inv_rate as i32), - ); + ) + .get(); #[allow(clippy::cast_possible_wrap)] let q_delta_2 = irs_commit::num_in_domain_queries( - main_whir_params.unique_decoding, + main_whir_params.decoding_regime, main_whir_params.security_level as f64, 0.5_f64.powi(main_whir_params.starting_log_inv_rate as i32), - ); + ) + .get(); // Default send-in-clear thresholds match query complexities. Self { @@ -255,6 +257,7 @@ mod tests { }, hash, parameters::ProtocolParameters, + protocols::params::DecodingRegime, transcript::{codecs::Empty, DomainSeparator, ProverState, VerifierState}, }; @@ -275,7 +278,7 @@ mod tests { fn make_test_config(num_polynomials: usize) -> Config { let whir_params = ProtocolParameters { - unique_decoding: false, + decoding_regime: DecodingRegime::Johnson, security_level: 16, pow_bits: 0, initial_folding_factor: 2, diff --git a/src/protocols/whir_zk/prover.rs b/src/protocols/whir_zk/prover.rs index d1527c13..c2da218e 100644 --- a/src/protocols/whir_zk/prover.rs +++ b/src/protocols/whir_zk/prover.rs @@ -320,7 +320,7 @@ impl Config { let initial_in_domain = { #[cfg(feature = "tracing")] let _span = tracing::info_span!("open_f_hat").entered(); - let witness_refs: Vec<_> = f_hat_witnesses.iter().collect(); + let witness_refs: Vec<_> = f_hat_witnesses.iter().map(|w| &w.irs).collect(); self.blinded_commitment .initial_committer .open(prover_state, &witness_refs) diff --git a/src/protocols/whir_zk/verifier.rs b/src/protocols/whir_zk/verifier.rs index 7df12023..1ad7e0d6 100644 --- a/src/protocols/whir_zk/verifier.rs +++ b/src/protocols/whir_zk/verifier.rs @@ -73,10 +73,11 @@ impl Config { let masking_challenge: F = verifier_state.verifier_message(); verify!(masking_challenge != F::ZERO); let commitments = commitment.f_hat.iter().collect::>(); + let irs_commitments = commitments.iter().map(|c| &c.irs).collect::>(); let initial_in_domain = self .blinded_commitment .initial_committer - .verify(verifier_state, &commitments)?; + .verify(verifier_state, &irs_commitments)?; // Expand base queries into coset points for the first folding round. let h_gammas = self.all_gammas(&initial_in_domain.points);