diff --git a/plonk/src/nightfall/accumulation/circuit/atomic_gadgets.rs b/plonk/src/nightfall/accumulation/circuit/atomic_gadgets.rs index 257e6a44..f88d7989 100644 --- a/plonk/src/nightfall/accumulation/circuit/atomic_gadgets.rs +++ b/plonk/src/nightfall/accumulation/circuit/atomic_gadgets.rs @@ -322,8 +322,8 @@ mod tests { circuit .finalize_for_recursive_mle_arithmetization::>() .unwrap(); - let pi = circuit.public_input().unwrap()[0]; - circuit.check_circuit_satisfiability(&[pi]).unwrap(); + let pi = circuit.public_input().unwrap(); + circuit.check_circuit_satisfiability(&pi).unwrap(); } #[test] @@ -450,8 +450,8 @@ mod tests { circuit .finalize_for_recursive_mle_arithmetization::>() .unwrap(); - let pi = circuit.public_input().unwrap()[0]; - circuit.check_circuit_satisfiability(&[pi]).unwrap(); + let pi = circuit.public_input().unwrap(); + circuit.check_circuit_satisfiability(&pi).unwrap(); let srs_size = circuit.num_gates().ilog2() as usize; let srs = Zeromorph::>::gen_srs_for_testing(&mut rng, srs_size) .unwrap(); diff --git a/plonk/src/nightfall/circuit/plonk_partial_verifier/gadgets.rs b/plonk/src/nightfall/circuit/plonk_partial_verifier/gadgets.rs index 74070bd1..8550ea6c 100644 --- a/plonk/src/nightfall/circuit/plonk_partial_verifier/gadgets.rs +++ b/plonk/src/nightfall/circuit/plonk_partial_verifier/gadgets.rs @@ -33,7 +33,7 @@ use super::{ #[allow(clippy::too_many_arguments)] pub fn compute_scalars_for_native_field( circuit: &mut PlonkCircuit, - pi: &Variable, + pi: &[Variable; 2], challenges: &ChallengesVar, proof_evals: &ProofEvalsVarNative, lookup_evals: &Option, @@ -275,7 +275,7 @@ pub fn compute_scalars_for_native_field( #[allow(clippy::too_many_arguments)] pub(crate) fn compute_scalars_for_native_field_base( circuit: &mut PlonkCircuit, - pi: &Variable, + pi: &[Variable; 2], challenges: &ChallengesVar, proof_evals: &ProofEvalsVarNative, lookup_evals: &Option, @@ -307,6 +307,7 @@ pub(crate) fn compute_scalars_for_native_field_base as Transcript>::new_transcript(b"mle_plonk"); - let mle_challenges = MLEChallenges::new_recursion(&proof.proof, &[pi], &mut transcript) + let mle_challenges = MLEChallenges::new_recursion(&proof.proof, &pi, &mut transcript) .map_err(|_| { - CircuitError::ParameterError("MLE challenge generation failed".to_string()) - })?; + CircuitError::ParameterError("MLE challenge generation failed".to_string()) + })?; let mut plonk_circuit = PlonkCircuit::::new_ultra_plonk(RANGE_BIT_LEN_FOR_TEST); let mle_proof_var = SAMLEProofVar::from_struct::

(&mut plonk_circuit, &proof.proof)?; let mut transcript_var = RescueTranscriptVar::::new_transcript(&mut plonk_circuit); - let pi_var = plonk_circuit.create_emulated_variable(pi)?; + let pi_vars: [EmulatedVariable; 2] = pi + .iter() + .map(|val| plonk_circuit.create_emulated_variable(*val)) + .collect::, CircuitError>>()? + .try_into() + .map_err(|_| { + CircuitError::ParameterError( + "Couldn't convert public inputs to fixed length array".to_string(), + ) + })?; let mle_challenges_var = EmulatedMLEChallenges::::compute_challenges_vars::< PCS, @@ -786,7 +796,7 @@ mod test { RescueTranscriptVar, >( &mut plonk_circuit, - &pi_var, + &pi_vars, &mle_proof_var, &mut transcript_var, )?; diff --git a/plonk/src/nightfall/circuit/plonk_partial_verifier/mod.rs b/plonk/src/nightfall/circuit/plonk_partial_verifier/mod.rs index bd2f82d9..4da21605 100644 --- a/plonk/src/nightfall/circuit/plonk_partial_verifier/mod.rs +++ b/plonk/src/nightfall/circuit/plonk_partial_verifier/mod.rs @@ -936,7 +936,9 @@ mod test { circuit .finalize_for_recursive_arithmetization::>() .unwrap(); - let pi = circuit.public_input().unwrap()[0]; + let pi: [P::ScalarField; 2] = circuit.public_input()?.try_into().map_err(|_| { + CircuitError::ParameterError("Couldn't convert to fixed length array".to_string()) + })?; let max_degree = circuit.srs_size(blind)?; let srs = FFTPlonk::::universal_setup_for_testing(max_degree, rng).unwrap(); @@ -954,7 +956,7 @@ mod test { let pcs_info = verifier .prepare_pcs_info::>( &vk, - &[pi], + &pi, &proof.proof, &None, blind, @@ -969,7 +971,7 @@ mod test { let challenges = FFTVerifier::::compute_challenges::< RescueTranscript, - >(&vk, &[pi], &proof.proof, &None)?; + >(&vk, &pi, &proof.proof, &None)?; let mut circuit = PlonkCircuit::::new_turbo_plonk(); let tau = circuit.create_variable(challenges.tau)?; @@ -999,11 +1001,20 @@ mod test { &proof.proof.plookup_proof.as_ref().unwrap().poly_evals, )?; - let pi_var = circuit.create_variable(pi)?; + let pi_vars: [Variable; 2] = pi + .iter() + .map(|val| circuit.create_variable(*val)) + .collect::, CircuitError>>()? + .try_into() + .map_err(|_| { + CircuitError::ParameterError( + "Couldn't convert to fixed length array".to_string(), + ) + })?; let scalars = compute_scalars_for_native_field::( &mut circuit, - &pi_var, + &pi_vars, &challenges_var, &proof_evals, &Some(lookup_evals), @@ -1083,7 +1094,7 @@ mod test { let vk_k = vec![Fr254::zero(); 6]; let scalars = compute_scalars_for_native_field::( &mut circuit, - &0, + &[0, 0], &challenges_var, &proof_evals, &Some(lookup_evals), diff --git a/plonk/src/nightfall/circuit/plonk_partial_verifier/poly.rs b/plonk/src/nightfall/circuit/plonk_partial_verifier/poly.rs index cb167910..c3eb1f8f 100644 --- a/plonk/src/nightfall/circuit/plonk_partial_verifier/poly.rs +++ b/plonk/src/nightfall/circuit/plonk_partial_verifier/poly.rs @@ -120,7 +120,7 @@ where /// This helper function generate the variables for the following data /// - Circuit evaluation of vanishing polynomial at point `zeta` i.e., output = /// zeta ^ domain_size - 1 mod Fr::modulus -/// - Evaluations of the first and the last lagrange polynomial at point `zeta` +/// - Evaluations of the first, second and last lagrange polynomial at point `zeta` /// /// Note that outputs and zeta are both Fr element /// so this needs to be carried out over a non-native circuit @@ -138,7 +138,7 @@ pub(super) fn evaluate_poly_helper_native( zeta_var: Variable, gen_inv: F, domain_size: usize, -) -> Result<[Variable; 4], CircuitError> +) -> Result<[Variable; 5], CircuitError> where F: PrimeField + RescueParameter, { @@ -158,10 +158,16 @@ where // ================================ // evaluate lagrange at 1 - // lagrange_1_eval = (zeta^n - 1) / (zeta - 1) / domain_size + // lagrange_1_eval = (zeta^n - 1) / (domain_size * (zeta - 1)) // // which is proven via // domain_size * lagrange_1_eval * (zeta - 1) = zeta^n - 1 mod Fr::modulus + // + // similarly we calculate + // lagrange_2_eval = (zeta^n - 1) * omega / (domain_size * (zeta - omega)) = (zeta^n - 1) / (domain_size * (zeta * omega^{-1} - 1)) + // and + // lagrange_n_eval = (zeta^n - 1) * omega^{-1} / (domain_size * (zeta - omega^{-1})) = (zeta^n - 1) / (domain_size * (zeta * omega - 1)) + // // ================================ let domain_size = F::from(domain_size as u64); @@ -178,20 +184,30 @@ where // Constrain the lagrange_1_eval to be correct. circuit.mul_gate(divisor_var, lagrange_1_eval_var, zeta_n_minus_one_var)?; + // Compute lagrange_2_eval + let divisor_var = circuit.lin_comb(&[domain_size * gen_inv], &(-domain_size), &[zeta_var])?; + let divisor = circuit.witness(divisor_var)?; + let lagrange_2_eval = zeta_n_minus_one / divisor; + let lagrange_2_eval_var = circuit.create_variable(lagrange_2_eval)?; + // Constrain the lagrange_2_eval to be correct. + circuit.mul_gate(divisor_var, lagrange_2_eval_var, zeta_n_minus_one_var)?; + + let gen = gen_inv + .inverse() + .ok_or(CircuitError::ParameterError("Inverse Failed".to_string()))?; // Compute lagrange_n_eval - let divisor_var = circuit.lin_comb(&[domain_size], &(-domain_size * gen_inv), &[zeta_var])?; - let numerator_var = circuit.mul_constant(zeta_n_minus_one_var, &gen_inv)?; + let divisor_var = circuit.lin_comb(&[domain_size * gen], &(-domain_size), &[zeta_var])?; let divisor = circuit.witness(divisor_var)?; - let numerator = circuit.witness(numerator_var)?; - let lagrange_n_eval = numerator / divisor; + let lagrange_n_eval = zeta_n_minus_one / divisor; let lagrange_n_eval_var = circuit.create_variable(lagrange_n_eval)?; // Constrain the lagrange_n_eval to be correct. - circuit.mul_gate(divisor_var, lagrange_n_eval_var, numerator_var)?; + circuit.mul_gate(divisor_var, lagrange_n_eval_var, zeta_n_minus_one_var)?; Ok([ zeta_n_var, zeta_n_minus_one_var, lagrange_1_eval_var, + lagrange_2_eval_var, lagrange_n_eval_var, ]) } @@ -199,10 +215,11 @@ where pub(super) fn evaluate_poly_helper_native_base( circuit: &mut PlonkCircuit, zeta_var: Variable, + gen_var: Variable, gen_inv_var: Variable, domain_size_var: Variable, max_domain_size: usize, -) -> Result<[Variable; 4], CircuitError> +) -> Result<[Variable; 5], CircuitError> where F: PrimeField + RescueParameter, { @@ -216,10 +233,16 @@ where // ================================ // evaluate lagrange at 1 - // lagrange_1_eval = (zeta^n - 1) / (zeta - 1) / domain_size + // lagrange_1_eval = (zeta^n - 1) / (domain_size * (zeta - 1)) // // which is proven via // domain_size * lagrange_1_eval * (zeta - 1) = zeta^n - 1 mod Fr::modulus + // + // similarly we calculate + // lagrange_2_eval = (zeta^n - 1) * omega / (domain_size * (zeta - omega)) + // and + // lagrange_n_eval = (zeta^n - 1) * omega^{-1} / (domain_size * (zeta - omega^{-1})) + // // ================================ // lagrange_1_eval @@ -237,6 +260,19 @@ where // Constrain the lagrange_1_eval to be correct. circuit.mul_gate(divisor_var, lagrange_1_eval_var, zeta_n_minus_one_var)?; + // Compute lagrange_2_eval + let divisor_var = circuit.mul_add( + &[domain_size_var, zeta_var, domain_size_var, gen_var], + &[F::one(), -F::one()], + )?; + let numerator_var = circuit.mul(zeta_n_minus_one_var, gen_var)?; + let divisor = circuit.witness(divisor_var)?; + let numerator = circuit.witness(numerator_var)?; + let lagrange_2_eval = numerator / divisor; + let lagrange_2_eval_var = circuit.create_variable(lagrange_2_eval)?; + // Constrain the lagrange_n_eval to be correct. + circuit.mul_gate(divisor_var, lagrange_2_eval_var, numerator_var)?; + // Compute lagrange_n_eval let divisor_var = circuit.mul_add( &[domain_size_var, zeta_var, domain_size_var, gen_inv_var], @@ -254,6 +290,7 @@ where zeta_n_var, zeta_n_minus_one_var, lagrange_1_eval_var, + lagrange_2_eval_var, lagrange_n_eval_var, ]) } @@ -431,8 +468,8 @@ pub(super) fn compute_lin_poly_constant_term_circuit_native( gen_inv: &F, challenges: &ChallengesVar, proof_evals: &ProofEvalsVarNative, - pi: &Variable, - evals: &[Variable; 4], + pi: &[Variable; 2], + evals: &[Variable; 5], lookup_evals: &Option, ) -> Result where @@ -493,7 +530,7 @@ where prod = circuit.mul(tmp, prod)?; // r_plonk - let pi_eval = circuit.mul(*pi, evals[2])?; + let pi_eval = circuit.mul_add(&[pi[0], evals[2], pi[1], evals[3]], &[F::one(), F::one()])?; let wires = [pi_eval, prod, evals[2], challenges.alphas[1]]; let non_lookup = circuit.gen_quad_poly( &wires, @@ -512,7 +549,7 @@ where ]; let tmp = circuit.lc(&wires, &[F::one(), -F::one(), -F::one(), F::zero()])?; let term_one = circuit.mul_add( - &[evals[3], tmp, evals[2], challenges.alphas[0]], + &[evals[4], tmp, evals[2], challenges.alphas[0]], &[F::one(), -F::one()], )?; @@ -581,8 +618,8 @@ pub(super) fn compute_lin_poly_constant_term_circuit_native_base( gen_inv_var: &Variable, challenges: &ChallengesVar, proof_evals: &ProofEvalsVarNative, - pi: &Variable, - evals: &[Variable; 4], + pi: &[Variable; 2], + evals: &[Variable; 5], lookup_evals: &Option, ) -> Result where @@ -643,7 +680,7 @@ where prod = circuit.mul(tmp, prod)?; // r_plonk - let pi_eval = circuit.mul(*pi, evals[2])?; + let pi_eval = circuit.mul_add(&[pi[0], evals[2], pi[1], evals[3]], &[F::one(), F::one()])?; let wires = [pi_eval, prod, evals[2], challenges.alphas[1]]; let non_lookup = circuit.gen_quad_poly( &wires, @@ -662,7 +699,7 @@ where ]; let tmp = circuit.lc(&wires, &[F::one(), -F::one(), -F::one(), F::zero()])?; let term_one = circuit.mul_add( - &[evals[3], tmp, evals[2], challenges.alphas[0]], + &[evals[4], tmp, evals[2], challenges.alphas[0]], &[F::one(), -F::one()], )?; @@ -760,7 +797,7 @@ pub fn linearization_scalars_circuit_native( circuit: &mut PlonkCircuit, vk_k: &[F], challenges: &ChallengesVar, - evals: &[Variable; 4], + evals: &[Variable; 5], poly_evals: &ProofEvalsVarNative, lookup_evals: &Option, gen_inv: &F, @@ -1004,7 +1041,7 @@ where challenges.alphas[0], evals[2], challenges.alphas[1], - evals[3], + evals[4], ], &[F::one(), F::one()], )?; @@ -1081,7 +1118,7 @@ pub fn linearization_scalars_circuit_native_base( circuit: &mut PlonkCircuit, vk_k: &[Variable], challenges: &ChallengesVar, - evals: &[Variable; 4], + evals: &[Variable; 5], poly_evals: &ProofEvalsVarNative, lookup_evals: &Option, gen_inv_var: &Variable, @@ -1318,7 +1355,7 @@ where challenges.alphas[0], evals[2], challenges.alphas[1], - evals[3], + evals[4], ], &[F::one(), F::one()], )?; diff --git a/plonk/src/nightfall/circuit/plonk_partial_verifier/proof_to_var.rs b/plonk/src/nightfall/circuit/plonk_partial_verifier/proof_to_var.rs index a9f0f7f8..8e930080 100644 --- a/plonk/src/nightfall/circuit/plonk_partial_verifier/proof_to_var.rs +++ b/plonk/src/nightfall/circuit/plonk_partial_verifier/proof_to_var.rs @@ -278,7 +278,7 @@ pub struct Bn254OutputScalarsAndBasesVar { /// The proof generated by the recursive prover. pub proof: Bn254ProofScalarsandBasesVar, /// The hash of the public inputs to this proof stored in the clear. - pub pi_hash: Fr254, + pub pi_hash: [Fr254; 2], /// The transcript of the proof stored in the clear. pub transcript: RescueTranscript, } diff --git a/plonk/src/nightfall/circuit/plonk_partial_verifier/scalars_and_bases.rs b/plonk/src/nightfall/circuit/plonk_partial_verifier/scalars_and_bases.rs index f377120e..6a7a4d95 100644 --- a/plonk/src/nightfall/circuit/plonk_partial_verifier/scalars_and_bases.rs +++ b/plonk/src/nightfall/circuit/plonk_partial_verifier/scalars_and_bases.rs @@ -747,7 +747,7 @@ mod tests { [true, false], ) { let circuit = gen_circuit_for_test::(m, 3, PlonkType::UltraPlonk, true)?; - let pi = circuit.public_input()?[0]; + let pi = circuit.public_input()?; let srs_size = circuit.srs_size(blind)?; let srs = UnivariateKzgPCS::::gen_srs_for_testing(rng, srs_size)?; @@ -764,7 +764,7 @@ mod tests { let pcs_info = fft_verifier.prepare_pcs_info::>( &vk, - &[pi], + &pi, &output.proof, &None, blind, @@ -776,7 +776,7 @@ mod tests { let (mut output_var, pcs_info_var) = fft_verifier .prepare_pcs_info_with_bases_var::>( &vk_var, - &[pi], + &pi, &output, &None, &mut verifier_circuit, diff --git a/plonk/src/nightfall/circuit/plonk_partial_verifier/structs.rs b/plonk/src/nightfall/circuit/plonk_partial_verifier/structs.rs index f1318470..d2fe0e06 100644 --- a/plonk/src/nightfall/circuit/plonk_partial_verifier/structs.rs +++ b/plonk/src/nightfall/circuit/plonk_partial_verifier/structs.rs @@ -106,7 +106,7 @@ impl ChallengesVar { pub fn compute_challenges( circuit: &mut PlonkCircuit, vk_id: Option, - pi: &Variable, + pi: &[Variable; 2], proof: &ProofVarNative

, transcript: &mut C, ) -> Result @@ -121,7 +121,9 @@ impl ChallengesVar { if let Some(id) = vk_id { transcript.push_variable(&id)?; } - transcript.push_variable(pi)?; + for pi in pi { + transcript.push_variable(pi)?; + } transcript.append_point_variables(&proof.wire_commitments, circuit)?; @@ -214,7 +216,7 @@ impl EmulatedMLEChallenges { /// Computes challenges from a proof. pub fn compute_challenges_vars( circuit: &mut PlonkCircuit, - pi: &EmulatedVariable, + public_input: &[EmulatedVariable; 2], proof_var: &SAMLEProofVar, transcript_var: &mut C, ) -> Result, CircuitError> @@ -225,7 +227,9 @@ impl EmulatedMLEChallenges { P::ScalarField: PrimeField + EmulationConfig + RescueParameter, C: CircuitTranscript, { - transcript_var.push_emulated_variable(pi, circuit)?; + for pi in public_input { + transcript_var.push_emulated_variable(pi, circuit)?; + } transcript_var.append_point_variables(&proof_var.wire_commitments_var, circuit)?; let [gamma, tau]: [usize; 2] = transcript_var @@ -931,7 +935,7 @@ impl ProofEvalsVarNative { pub struct ProofScalarsVarNative { pub(crate) evals: ProofEvalsVarNative, pub(crate) lookup_evals: Option, - pub(crate) pi_hash: Variable, + pub(crate) pi_hash: [Variable; 2], } impl ProofScalarsVarNative { @@ -939,7 +943,7 @@ impl ProofScalarsVarNative { pub fn new( evals: ProofEvalsVarNative, lookup_evals: Option, - pi_hash: Variable, + pi_hash: [Variable; 2], ) -> Self { Self { evals, @@ -951,7 +955,7 @@ impl ProofScalarsVarNative { /// Create a new [`ProofScalarVarNative`] variable from a reference to a [`ProofVarNative`]. pub fn from_struct

( proof_var_native: &ProofVarNative

, - pi_hash: Variable, + pi_hash: [Variable; 2], ) -> Result where P: HasTEForm, diff --git a/plonk/src/nightfall/ipa_snark.rs b/plonk/src/nightfall/ipa_snark.rs index 75dcf5d8..fb08f131 100644 --- a/plonk/src/nightfall/ipa_snark.rs +++ b/plonk/src/nightfall/ipa_snark.rs @@ -158,7 +158,7 @@ where let verifier = FFTVerifier::::new(verify_key.domain_size)?; let pcs_info = verifier.prepare_pcs_info::( verify_key, - &[proof.pi_hash], + &proof.pi_hash, &proof.proof, &extra_transcript_init_msg, blind, @@ -445,8 +445,16 @@ where if prove_keys.vk.id.is_some() { transcript.append_visitor(&prove_keys.vk)?; } - // In the recursive setting we know that the public inputs have length 1. - transcript.push_message(b"public_input", &circuits.public_input()?[0])?; + // In the recursive setting we know that the public inputs have length 2. + let public_inputs: [P::ScalarField; 2] = + circuits.public_input()?.try_into().map_err(|_| { + PlonkError::InvalidParameters( + "Public input length is not equal to 2 in recursive proving".to_string(), + ) + })?; + for pub_in in public_inputs.iter() { + transcript.push_message(b"public_input", pub_in)?; + } // Round 1 let ((wires_poly_comms, wire_polys), pi_poly) = @@ -849,7 +857,11 @@ where blind, )?; - let pi_hash = circuit.public_input()?[0]; + let pi_hash: [P::ScalarField; 2] = circuit.public_input()?.try_into().map_err(|_| { + PlonkError::InvalidParameters( + "Public input length is not equal to 2 in recursive proving".to_string(), + ) + })?; Ok(RecursiveOutput::new(proof, pi_hash, transcript)) } @@ -1325,7 +1337,7 @@ pub mod test { // Inconsistent proof should fail the verification. let bad_proof = RecursiveOutput { proof: proof.proof.clone(), - pi_hash: E::ScalarField::zero(), + pi_hash: [E::ScalarField::zero(); 2], transcript: T::new_transcript(b"bad_transcript"), }; diff --git a/plonk/src/nightfall/ipa_verifier.rs b/plonk/src/nightfall/ipa_verifier.rs index 761c200f..cc65918d 100644 --- a/plonk/src/nightfall/ipa_verifier.rs +++ b/plonk/src/nightfall/ipa_verifier.rs @@ -72,7 +72,7 @@ pub(crate) struct FFTVerifier { /// Function used to reproduce the end state of a transcript, used in recursive proving and verification. pub fn reproduce_transcript( vk_id: Option, - public_input: E::ScalarField, + public_input: [E::ScalarField; 2], proof: &Proof, ) -> Result where @@ -93,7 +93,10 @@ where if let Some(id) = vk_id { transcript.push_message(b"vk_id", &E::ScalarField::from(id as u8))?; } - transcript.push_message(b"public_input", &public_input)?; + + for pi in public_input.iter() { + transcript.push_message(b"public_input", pi)?; + } transcript.append_curve_points(b"witness_poly_comms", &proof.wires_poly_comms)?; diff --git a/plonk/src/nightfall/mle/mle_structs.rs b/plonk/src/nightfall/mle/mle_structs.rs index 07116f11..436eeadb 100644 --- a/plonk/src/nightfall/mle/mle_structs.rs +++ b/plonk/src/nightfall/mle/mle_structs.rs @@ -467,7 +467,9 @@ impl MLEChallenges { T: Transcript, { // Append public input to transcript. - transcript.push_message(b"public input", &public_input[0])?; + for pi in public_input.iter() { + transcript.push_message(b"public input", pi)?; + } // We know that the commitments we are using will always be points on an SW curve. // We append wire commitments here. diff --git a/plonk/src/nightfall/mle/mod.rs b/plonk/src/nightfall/mle/mod.rs index bbfbdfce..1169001d 100644 --- a/plonk/src/nightfall/mle/mod.rs +++ b/plonk/src/nightfall/mle/mod.rs @@ -183,7 +183,11 @@ where } let (proof, transcript) = Self::sa_prove::<_, _, _, T>(circuit, prove_key, extra_transcript_init_msg)?; - let pi_hash = circuit.public_input()?[0]; + let pi_hash: [P::ScalarField; 2] = circuit.public_input()?.try_into().map_err(|_| { + PlonkError::InvalidParameters( + "Public input length does not match expected length".to_string(), + ) + })?; Ok(RecursiveOutput::new(proof, pi_hash, transcript)) } } diff --git a/plonk/src/nightfall/mle/snark.rs b/plonk/src/nightfall/mle/snark.rs index ea9021eb..27487032 100644 --- a/plonk/src/nightfall/mle/snark.rs +++ b/plonk/src/nightfall/mle/snark.rs @@ -218,7 +218,7 @@ impl MLEPlonk { T::new_transcript(b"mle_plonk") }; - // Append public input to transcript. + // Append public inputs to transcript. for public_input in circuit.public_input()? { transcript.push_message(b"public input", &public_input)?; } @@ -622,8 +622,10 @@ impl MLEPlonk { // Compute the public input mle. let mut public_inputs = circuit.public_input()?; - // Append the singular public input to the transcript. - transcript.push_message(b"pi", &public_inputs[0])?; + // Append the public inputs to the transcript. + for pi in &public_inputs { + transcript.push_message(b"public input", pi)?; + } public_inputs.resize(circuit.num_gates(), P::ScalarField::zero()); @@ -1114,7 +1116,7 @@ impl MLEPlonk { recursion_output: &RecursiveOutput, opening_proof: &PCS::Proof, vk: &MLEVerifyingKey, - public_input: P::ScalarField, + public_input: [P::ScalarField; 2], _rng: &mut R, extra_transcript_init_msg: Option>, ) -> Result @@ -1142,17 +1144,14 @@ impl MLEPlonk { let n = 1usize << num_vars; let shared = MLEProofShared::from(proof); - check_proof_shape(&shared, vk, &[public_input], num_vars)?; + check_proof_shape(&shared, vk, &public_input, num_vars)?; - let mut pi_evals = vec![public_input]; + let mut pi_evals = public_input.to_vec(); pi_evals.resize(n, P::ScalarField::zero()); let pi_poly = DenseMultilinearExtension::from_evaluations_vec(num_vars, pi_evals); - let challenges = MLEChallenges::::new_recursion( - proof, - &[public_input], - &mut transcript, - )?; + let challenges = + MLEChallenges::::new_recursion(proof, &public_input, &mut transcript)?; let gkr_deferred_check = batch_verify_gkr::(&proof.gkr_proof, &mut transcript)?; @@ -1632,12 +1631,18 @@ pub mod tests { &proof.proof.opening_point, )?; + let pi: [E::ScalarField; 2] = public_inputs[i].clone().try_into().map_err(|_| { + PlonkError::SnarkError(SnarkError::ParameterError( + "Public inputs length mismatch".to_string(), + )) + })?; + assert!( MLEPlonk::::verify_recursive_proof::<_, _, _, RescueTranscript>( proof, &opening_proof, vk_ref, - public_inputs[i][0], + pi, rng, None ) @@ -1650,7 +1655,7 @@ pub mod tests { proof, &opening_proof, vk_ref, - E::ScalarField::zero(), + [E::ScalarField::zero(); 2], rng, None ) @@ -1667,7 +1672,7 @@ pub mod tests { proof, &default_opening, vk_ref, - public_inputs[i][0], + pi, rng, None ) diff --git a/plonk/src/proof_system/mod.rs b/plonk/src/proof_system/mod.rs index b2f4d3f8..28b60757 100644 --- a/plonk/src/proof_system/mod.rs +++ b/plonk/src/proof_system/mod.rs @@ -167,8 +167,8 @@ where { /// The proof generated by the recursive prover. pub proof: Scheme::RecursiveProof, - /// The hash of the public inputs to this proof. - pub pi_hash: ::ScalarField, + /// The hash of the public inputs to this proof represented by two field elements. + pub pi_hash: [::ScalarField; 2], /// The transcript of the proof. pub transcript: T, } @@ -187,7 +187,7 @@ where /// Create a new recursive output. pub fn new( proof: Scheme::RecursiveProof, - pi_hash: ::ScalarField, + pi_hash: [::ScalarField; 2], transcript: T, ) -> Self { Self { diff --git a/plonk/src/recursion/circuits/challenges.rs b/plonk/src/recursion/circuits/challenges.rs index 13f25e08..6627fcab 100644 --- a/plonk/src/recursion/circuits/challenges.rs +++ b/plonk/src/recursion/circuits/challenges.rs @@ -128,7 +128,7 @@ impl MLEProofChallenges { pub fn reconstruct_mle_challenges( proof_var: &SAMLEProofVar, circuit: &mut PlonkCircuit, - pi_hash: &EmulatedVariable, + pi_hash: &[EmulatedVariable; 2], initialisation_msg: &Option>, ) -> Result<(MLEProofChallengesEmulatedVar, C), CircuitError> where diff --git a/plonk/src/recursion/circuits/emulated_mle_arithmetic.rs b/plonk/src/recursion/circuits/emulated_mle_arithmetic.rs index 3dc42631..f3354858 100644 --- a/plonk/src/recursion/circuits/emulated_mle_arithmetic.rs +++ b/plonk/src/recursion/circuits/emulated_mle_arithmetic.rs @@ -463,7 +463,7 @@ type MLEScalarsAndAccEval = (Vec>, EmulatedVariable, EmulatedVariable)], + proof_vars: &[(SAMLEProofVar, [EmulatedVariable; 2])], acc_info: &SplitAccumulationInfo, old_accs: &[(PointVariable, EmulatedVariable)], gate_info: &GateInfo, @@ -503,17 +503,26 @@ pub fn emulated_combine_mle_proof_scalars( &mut transcript_var, )?; - let zero_eval = - proof_var - .sumcheck_proof - .point_var - .iter() - .try_fold(one_var.clone(), |acc, point| { - let tmp1 = circuit.emulated_mul(&acc, point)?; - circuit.emulated_sub(&acc, &tmp1) - })?; + // Since we have two public inputs, we need to construct the public input polynomial given by: + // pi[0] * (1 - p_0) * (1 - p_1) * ... * (1 - p_{n-1}) + pi[1] * p_0 * (1 - p_1) * ... * (1 - p_{n-1}), + // where p_i are the coordinates in the sumcheck proof point. + // We first construct (1 - p_1) * ... * (1 - p_{n-1}). + let intermediate_eval = proof_var.sumcheck_proof.point_var.iter().skip(1).try_fold( + one_var.clone(), + |acc, point| { + let tmp1 = circuit.emulated_mul(&acc, point)?; + circuit.emulated_sub(&acc, &tmp1) + }, + )?; + + let one_minus_p0 = + circuit.emulated_sub(&one_var, &proof_var.sumcheck_proof.point_var[0])?; + let eval_0 = circuit.emulated_mul(&one_minus_p0, &intermediate_eval)?; + let eval_1 = + circuit.emulated_mul(&proof_var.sumcheck_proof.point_var[0], &intermediate_eval)?; - let pi_eval = circuit.emulated_mul(pi, &zero_eval)?; + let mut pi_eval = circuit.emulated_mul(&pi[0], &eval_0)?; + pi_eval = circuit.emulated_mul_add(&pi[1], &eval_1, &pi_eval)?; let (scalars, eval) = verify_mleplonk_emulated_scalar_arithmetic( circuit, @@ -719,12 +728,22 @@ mod tests { }) .collect::)>, CircuitError>>()?; - let proof_vars: Vec<(SAMLEProofVar, EmulatedVariable)> = outputs + let proof_vars: Vec<(SAMLEProofVar, [EmulatedVariable; 2])> = outputs .iter() .map(|o| { let p = SAMLEProofVar::from_struct(&mut verifier_circuit, &o.proof)?; - let h = verifier_circuit.create_emulated_variable(o.pi_hash)?; - Ok::<(SAMLEProofVar, EmulatedVariable), CircuitError>((p, h)) + let h: [EmulatedVariable; 2] = o + .pi_hash + .iter() + .map(|val| verifier_circuit.create_emulated_variable(*val)) + .collect::>, CircuitError>>()? + .try_into() + .map_err(|_| { + CircuitError::ParameterError("Invalid pi_hash length".to_string()) + })?; + Ok::<(SAMLEProofVar, [EmulatedVariable; 2]), CircuitError>(( + p, h, + )) }) .collect::, _>>()?; diff --git a/plonk/src/recursion/circuits/fft_arithmetic.rs b/plonk/src/recursion/circuits/fft_arithmetic.rs index 26c76fe1..5b703220 100644 --- a/plonk/src/recursion/circuits/fft_arithmetic.rs +++ b/plonk/src/recursion/circuits/fft_arithmetic.rs @@ -445,7 +445,9 @@ mod tests { let rng = &mut jf_utils::test_rng(); for (m, blind) in (2..8).zip([true, false]) { let circuit = gen_circuit_for_test::(m, 3, PlonkType::UltraPlonk, true)?; - let pi = circuit.public_input()?[0]; + let pi: [Fr254; 2] = circuit.public_input()?.try_into().map_err(|_| { + PlonkError::InvalidParameters("public input must have length 2".to_string()) + })?; let srs_size = circuit.srs_size(blind)?; let srs = UnivariateKzgPCS::::gen_srs_for_testing(rng, srs_size)?; @@ -460,7 +462,15 @@ mod tests { let mut verifier_circuit = PlonkCircuit::::new_ultra_plonk(8); let base_var = ProofVarNative::from_struct(&output.proof, &mut verifier_circuit)?; - let pi_hash = verifier_circuit.create_variable(output.pi_hash)?; + let pi_hash: [Variable; 2] = output + .pi_hash + .into_iter() + .map(|pi_elem| verifier_circuit.create_variable(pi_elem)) + .collect::, _>>()? + .try_into() + .map_err(|_| { + PlonkError::InvalidParameters("pi_hash must have length 2".to_string()) + })?; let scalar_var = ProofScalarsVarNative::from_struct(&base_var, pi_hash)?; let pcs_info_circuit = partial_verify_fft_plonk( @@ -476,7 +486,7 @@ mod tests { let pcs_info = fft_verifier.prepare_pcs_info::>( &vk, - &[pi], + &pi, &output.proof, &None, blind, @@ -517,7 +527,9 @@ mod tests { [true, false], ) { let circuit = gen_circuit_for_test::(m, 3, PlonkType::UltraPlonk, true)?; - let pi = circuit.public_input()?[0]; + let pi: [Fr254; 2] = circuit.public_input()?.try_into().map_err(|_| { + PlonkError::InvalidParameters("public input must have length 2".to_string()) + })?; let srs_size = circuit.srs_size(blind)?; let srs = UnivariateKzgPCS::::gen_srs_for_testing(rng, srs_size)?; @@ -532,7 +544,15 @@ mod tests { let mut verifier_circuit = PlonkCircuit::::new_ultra_plonk(8); let vk_var = VerifyingKeyNativeScalarsVar::new(&mut verifier_circuit, &vk)?; - let pi_hash = verifier_circuit.create_variable(output.pi_hash)?; + let pi_hash: [Variable; 2] = output + .pi_hash + .iter() + .map(|val| verifier_circuit.create_variable(*val)) + .collect::, CircuitError>>()? + .try_into() + .map_err(|_| { + PlonkError::InvalidParameters("pi_hash must have length 2".to_string()) + })?; let base_var = ProofVarNative::from_struct(&output.proof, &mut verifier_circuit)?; let scalar_var = ProofScalarsVarNative::from_struct(&base_var, pi_hash)?; @@ -550,7 +570,7 @@ mod tests { let pcs_info = fft_verifier.prepare_pcs_info::>( &vk, - &[pi], + &pi, &output.proof, &None, blind, @@ -585,8 +605,8 @@ mod tests { for (m, blind) in (2..8).zip([true, false]) { let circuit_one = gen_circuit_for_test::(m, 3, PlonkType::UltraPlonk, true)?; let circuit_two = gen_circuit_for_test::(m, 4, PlonkType::UltraPlonk, true)?; - let pi_one = circuit_one.public_input()?[0]; - let pi_two = circuit_two.public_input()?[0]; + let pi_one = circuit_one.public_input()?; + let pi_two = circuit_two.public_input()?; let srs_size = circuit_one.srs_size(blind)?; @@ -652,7 +672,7 @@ mod tests { let verifier = FFTVerifier::new(vk.domain_size)?; verifier.prepare_pcs_info_with_bases_var::>( vk, - &[output.pi_hash], + &output.pi_hash, output, &None, &mut bases_verifier_circuit, @@ -695,7 +715,11 @@ mod tests { .iter() .map(|output| { let proof = ProofVarNative::from_struct(&output.proof, &mut scalars_verifier_circuit)?; - let pi_hash = scalars_verifier_circuit.create_variable(output.pi_hash)?; + let pi_hash: [Variable; 2] = output.pi_hash.iter().map(|val| { + scalars_verifier_circuit.create_variable(*val) + }).collect::, CircuitError>>()?.try_into().map_err(|_| { + PlonkError::InvalidParameters("pi_hash must have length 2".to_string()) + })?; let proof_evals = ProofScalarsVarNative::from_struct(&proof, pi_hash)?; Ok((proof_evals, proof)) }) @@ -785,10 +809,10 @@ mod tests { let pcs_infos = outputs .iter() .zip(pis.iter()) - .map(|(output, &pi)| { + .map(|(output, pi)| { fft_verifier.prepare_pcs_info::>( &vk, - &[pi], + pi, &output.proof, &None, blind, @@ -880,8 +904,8 @@ mod tests { ) { let circuit_one = gen_circuit_for_test::(m, 3, PlonkType::UltraPlonk, true)?; let circuit_two = gen_circuit_for_test::(m, 4, PlonkType::UltraPlonk, true)?; - let pi_one = circuit_one.public_input()?[0]; - let pi_two = circuit_two.public_input()?[0]; + let pi_one = circuit_one.public_input()?; + let pi_two = circuit_two.public_input()?; let srs_size = circuit_one.srs_size(blind)?; @@ -951,7 +975,7 @@ mod tests { let verifier = FFTVerifier::new(vk.domain_size)?; verifier.prepare_pcs_info_with_bases_var::>( vk, - &[output.pi_hash], + &output.pi_hash, output, &None, &mut bases_verifier_circuit, @@ -995,7 +1019,11 @@ mod tests { .iter() .map(|output| { let proof = ProofVarNative::from_struct(&output.proof, &mut scalars_verifier_circuit)?; - let pi_hash = scalars_verifier_circuit.create_variable(output.pi_hash)?; + let pi_hash: [Variable; 2] = output.pi_hash.iter().map(|val| { + scalars_verifier_circuit.create_variable(*val) + }).collect::, CircuitError>>()?.try_into().map_err(|_| { + PlonkError::InvalidParameters("pi_hash must have length 2".to_string()) + })?; let proof_evals = ProofScalarsVarNative::from_struct(&proof, pi_hash)?; Ok((proof_evals, proof)) }) @@ -1086,7 +1114,7 @@ mod tests { let fft_verifier = FFTVerifier::::new(vk.domain_size)?; fft_verifier.prepare_pcs_info::>( &vk, - &[pi], + &pi, &output.proof, &None, blind, diff --git a/plonk/src/recursion/circuits/mle_arithmetic.rs b/plonk/src/recursion/circuits/mle_arithmetic.rs index a0693123..02343889 100644 --- a/plonk/src/recursion/circuits/mle_arithmetic.rs +++ b/plonk/src/recursion/circuits/mle_arithmetic.rs @@ -508,7 +508,7 @@ type MLEScalarsAndAccEval = (Vec, Variable); pub fn combine_mle_proof_scalars( outputs: &[RecursiveOutput, RescueTranscript>], challenges: &[MLEProofChallengesVar], - pi_hashes: &[Variable], + pi_hashes: &Vec<[Variable; 2]>, acc_info: &SplitAccumulationInfoVar, vk: &MLEVerifyingKey, circuit: &mut PlonkCircuit, @@ -533,19 +533,32 @@ pub fn combine_mle_proof_scalars( { let proof_var = SAMLEProofNative::from_struct(circuit, &output.proof)?; - let zero_eval = - proof_var - .sumcheck_proof() - .point_var - .iter() - .try_fold(circuit.one(), |acc, point| { - circuit.mul_add( - &[acc, circuit.one(), acc, *point], - &[Fq254::one(), -Fq254::one()], - ) - })?; + // Since we have two public inputs, we need to construct the public input polynomial given by: + // pi_hash[0] * (1 - p_0) * (1 - p_1) * ... * (1 - p_{n-1}) + pi_hash[1] * p_0 * (1 - p_1) * ... * (1 - p_{n-1}), + // where p_i are the coordinates in the sumcheck proof point. + // We first construct (1 - p_1) * ... * (1 - p_{n-1}). + let intermediate_eval = proof_var.sumcheck_proof.point_var.iter().skip(1).try_fold( + circuit.one(), + |acc, point| { + circuit.mul_add( + &[acc, circuit.one(), acc, *point], + &[Fq254::one(), -Fq254::one()], + ) + }, + )?; + + let one_minus_p0 = circuit.lin_comb( + &[-Fq254::one()], + &Fq254::one(), + &[proof_var.sumcheck_proof.point_var[0]], + )?; + let eval_0 = circuit.mul(one_minus_p0, intermediate_eval)?; + let eval_1 = circuit.mul(proof_var.sumcheck_proof.point_var[0], intermediate_eval)?; - let pi_eval = circuit.mul(*pi_hash, zero_eval)?; + let pi_eval = circuit.mul_add( + &[pi_hash[0], eval_0, pi_hash[1], eval_1], + &[Fq254::one(), Fq254::one()], + )?; let (scalars, eval) = verify_mleplonk_scalar_arithmetic( circuit, @@ -655,7 +668,7 @@ mod tests { use ark_std::{sync::Arc, vec, vec::Vec, UniformRand}; use jf_primitives::pcs::{Accumulation, PolynomialCommitmentScheme}; use jf_relation::{ - gadgets::{ecc::HasTEForm, EmulationConfig}, + gadgets::{ecc::HasTEForm, EmulatedVariable, EmulationConfig}, Arithmetization, PlonkType, }; use jf_utils::test_rng; @@ -745,7 +758,7 @@ mod tests { let mut transcript = RescueTranscript::::new_transcript(b"mle_plonk"); let mle_challenges = MLEChallenges::::new_recursion( &proof.proof, - &[public_input[0]], + public_input, &mut transcript, ) .unwrap(); @@ -802,12 +815,15 @@ mod tests { &inner_proof.opening_point, ) .unwrap(); + let pi: &[F; 2] = public_input.as_slice().try_into().map_err(|_| { + PlonkError::InvalidParameters("Public input length mismatch".to_string()) + })?; assert!(MLEPlonk::::verify_recursive_proof::< _, _, _, RescueTranscript, - >(proof, &opening_proof, &_vk1, public_input[0], rng, None) + >(proof, &opening_proof, &_vk1, *pi, rng, None) .unwrap()); let mut circuit = PlonkCircuit::::new_ultra_plonk(8); @@ -834,13 +850,15 @@ mod tests { let proof_native = SAMLEProofNative::from_struct(&mut circuit, &proof.proof)?; let gate_info = &pk1.verifying_key.gate_info; - let mut pi_poly = vec![public_input[0]]; let num_vars = gkr_sumcheck_challenges.last().as_ref().unwrap().len(); - pi_poly.resize(1 << num_vars, F::zero()); - let pi_pol = DenseMultilinearExtension::::from_evaluations_vec(num_vars, pi_poly); + let mut pi_poly = vec![F::zero(); 1 << num_vars]; + pi_poly[..public_input.len()].copy_from_slice(public_input); - let pi_eval = pi_pol.evaluate(&proof.proof.sumcheck_proof.point).unwrap(); + let pi_poly = + DenseMultilinearExtension::::from_evaluations_vec(num_vars, pi_poly.to_vec()); + + let pi_eval = pi_poly.evaluate(&proof.proof.sumcheck_proof.point).unwrap(); let pi_eval = circuit.create_variable(pi_eval)?; let epsilon_var = circuit.create_variable(epsilon)?; @@ -961,8 +979,13 @@ mod tests { let mle_proof_challenges = outputs .iter() .map(|output| { - let pi_hash = challenges_circuit - .create_emulated_variable(output.pi_hash) + let pi_hash: [EmulatedVariable; 2] = output + .pi_hash + .iter() + .map(|val| challenges_circuit.create_emulated_variable(*val)) + .collect::>, CircuitError>>() + .unwrap() + .try_into() .unwrap(); let proof_var = SAMLEProofVar::from_struct(&mut challenges_circuit, &output.proof).unwrap(); @@ -984,11 +1007,20 @@ mod tests { }) .collect::>(); - let pi_hashes: Vec = outputs + let pi_hashes: Vec<[Variable; 2]> = outputs .iter() - .map(|o| verifier_circuit.create_variable(o.pi_hash)) - .collect::, _>>() - .unwrap(); + .map(|o| { + let pi: [Variable; 2] = o + .pi_hash + .iter() + .map(|val| verifier_circuit.create_variable(*val)) + .collect::, CircuitError>>() + .unwrap() + .try_into() + .unwrap(); + pi + }) + .collect::>(); let (combined_scalars, combined_eval) = combine_mle_proof_scalars( &outputs, &mle_proof_challenges, diff --git a/plonk/src/recursion/merge_functions.rs b/plonk/src/recursion/merge_functions.rs index e1c9976b..1e01e647 100644 --- a/plonk/src/recursion/merge_functions.rs +++ b/plonk/src/recursion/merge_functions.rs @@ -468,7 +468,7 @@ pub fn prove_bn254_accumulation( verifier.prepare_pcs_info_with_bases_var::>( vk, - &[output.pi_hash], + &output.pi_hash, output, &fs_msg, circuit, @@ -588,11 +588,22 @@ pub fn prove_bn254_accumulation( // Now we verify scalar arithmetic for the four previous Grumpkin proofs and the pi_hash. if !IS_FIRST_ROUND { - let old_pi_hashes: Vec = bn254info + let old_pi_hashes: Vec<[Variable; 2]> = bn254info .grumpkin_outputs .iter() - .map(|o| circuit.create_variable(o.pi_hash)) - .collect::, _>>()?; + .map(|o| { + o.pi_hash + .iter() + .map(|val| circuit.create_variable(*val)) + .collect::, CircuitError>>()? + .try_into() + .map_err(|_| { + PlonkError::InvalidParameters( + "Expected pi_hash to have length 2".to_string(), + ) + }) + }) + .collect::, _>>()?; let mle_plonk_challenges: Vec = bn254info .challenges @@ -625,7 +636,7 @@ pub fn prove_bn254_accumulation( let (scalars, eval) = combine_mle_proof_scalars( output_pair, challenges_pair, - old_pi_hashes, + &old_pi_hashes.to_vec(), split_acc_info, vk_grumpkin, circuit, @@ -736,7 +747,11 @@ pub fn prove_bn254_accumulation( old_acc.as_slice(), forwarded_acc.as_slice(), acc_eval.as_slice(), - old_hashes, + &old_hashes + .iter() + .flatten() + .copied() + .collect::>(), mle_challenges_flat.as_slice(), batch_challenge.as_slice(), ] @@ -750,38 +765,38 @@ pub fn prove_bn254_accumulation( let bytes = value.into_bigint().to_bytes_le(); let (challenge, leftover) = bytes.split_at(31); - let pi_hash = circuit.create_variable(Fq254::from_le_bytes_mod_order(challenge))?; - let leftover_var = - circuit.create_variable(Fq254::from_le_bytes_mod_order(leftover))?; + let low_var = circuit.create_variable(Fq254::from_le_bytes_mod_order(challenge))?; + let high_var = circuit.create_variable(Fq254::from_le_bytes_mod_order(leftover))?; - circuit.enforce_in_range(pi_hash, 8 * 31)?; - circuit.enforce_in_range(leftover_var, 6)?; + circuit.enforce_in_range(low_var, 8 * 31)?; + circuit.enforce_in_range(high_var, 6)?; let coeff = Fq254::from(2u32).pow([248u64]); circuit.lc_gate( &[ - pi_hash, - leftover_var, + low_var, + high_var, circuit.zero(), circuit.zero(), pi_hash_pre, ], &[Fq254::one(), coeff, Fq254::zero(), Fq254::zero()], )?; - Ok(pi_hash) + Ok([low_var, high_var]) }, ) - .collect::, CircuitError>>()?; + .collect::, CircuitError>>()?; // For checking correctness during testing #[cfg(test)] { - for (circuit_hash, actual_hash) in pi_hashes.iter().zip(bn254info.bn254_outputs.iter()) - { - assert_eq!( - circuit.witness(*circuit_hash).unwrap(), - fr_to_fq::(&actual_hash.pi_hash) - ); + for (circuit_hash, output) in pi_hashes.iter().zip(bn254info.bn254_outputs.iter()) { + for (pi, exp_pi) in circuit_hash.iter().zip(output.pi_hash.iter()) { + assert_eq!( + circuit.witness(*pi).unwrap(), + fr_to_fq::(exp_pi) + ); + } } } @@ -840,9 +855,11 @@ pub fn prove_bn254_accumulation( .iter() .try_for_each(|&var| circuit.set_variable_public(var))?; - pi_hashes - .iter() - .try_for_each(|x| circuit.set_variable_public(*x))?; + pi_hashes.iter().try_for_each(|pi_hash| { + pi_hash + .iter() + .try_for_each(|x| circuit.set_variable_public(*x)) + })?; let specific_pi_out = specific_pi_vars .iter() @@ -890,19 +907,18 @@ pub fn prove_bn254_accumulation( let bytes = value.into_bigint().to_bytes_le(); let (challenge, leftover) = bytes.split_at(31); - let pi_hash = circuit.create_variable(Fq254::from_le_bytes_mod_order(challenge))?; - let leftover_var = - circuit.create_variable(Fq254::from_le_bytes_mod_order(leftover))?; + let low_var = circuit.create_variable(Fq254::from_le_bytes_mod_order(challenge))?; + let high_var = circuit.create_variable(Fq254::from_le_bytes_mod_order(leftover))?; - circuit.enforce_in_range(pi_hash, 8 * 31)?; - circuit.enforce_in_range(leftover_var, 6)?; + circuit.enforce_in_range(low_var, 8 * 31)?; + circuit.enforce_in_range(high_var, 6)?; let coeff = Fq254::from(2u32).pow([248u64]); circuit.lc_gate( &[ - pi_hash, - leftover_var, + low_var, + high_var, circuit.zero(), circuit.zero(), pi_hash_pre, @@ -910,19 +926,20 @@ pub fn prove_bn254_accumulation( &[Fq254::one(), coeff, Fq254::zero(), Fq254::zero()], )?; - Ok(pi_hash) + Ok([low_var, high_var]) }) - .collect::, CircuitError>>()?; + .collect::, CircuitError>>()?; // For checking correctness during testing #[cfg(test)] { - for (circuit_hash, actual_hash) in pi_hashes.iter().zip(bn254info.bn254_outputs.iter()) - { - assert_eq!( - circuit.witness(*circuit_hash).unwrap(), - fr_to_fq::(&actual_hash.pi_hash) - ); + for (circuit_hash, output) in pi_hashes.iter().zip(bn254info.bn254_outputs.iter()) { + for (pi, exp_pi) in circuit_hash.iter().zip(output.pi_hash.iter()) { + assert_eq!( + circuit.witness(*pi).unwrap(), + fr_to_fq::(exp_pi) + ); + } } } // Do any specific pi required @@ -984,9 +1001,11 @@ pub fn prove_bn254_accumulation( .iter() .try_for_each(|&var| circuit.set_variable_public(var))?; - pi_hashes - .iter() - .try_for_each(|x| circuit.set_variable_public(*x))?; + pi_hashes.iter().try_for_each(|pi_hash| { + pi_hash + .iter() + .try_for_each(|x| circuit.set_variable_public(*x)) + })?; let specific_pi = specific_pi_vars .iter() @@ -1202,7 +1221,15 @@ pub fn prove_grumpkin_accumulation( .iter() .map(|output| { let proof = ProofVarNative::from_struct(&output.proof, circuit)?; - let pi_hash = circuit.create_variable(output.pi_hash)?; + let pi_hash: [Variable; 2] = output + .pi_hash + .iter() + .map(|val| circuit.create_variable(*val)) + .collect::, CircuitError>>()? + .try_into() + .map_err(|_| { + PlonkError::InvalidParameters("Expected pi_hash to have length 2".to_string()) + })?; let proof_evals = ProofScalarsVarNative::from_struct(&proof, pi_hash)?; Ok((proof_evals, proof)) }) @@ -1339,12 +1366,12 @@ pub fn prove_grumpkin_accumulation( .collect::>, CircuitError>>()?; let mut bn254_acc_vars = Vec::::new(); - let mut pi_hash_vars = Vec::::new(); + let mut pi_hash_vars = Vec::<[Variable; 2]>::new(); let bn254_pi_hashes = output_scalar_vars .iter() .map(|scalar_vars| scalar_vars.pi_hash) - .collect::>(); + .collect::>(); let acc_comms: Vec<(PointVariable, EmulatedVariable)> = grumpkin_info .old_accumulators @@ -1362,7 +1389,7 @@ pub fn prove_grumpkin_accumulation( .map(|e| SAMLEProofVar::from_struct(circuit, &e.proof)) .collect::, _>>()?; - let grumpkin_pi_hashes: Vec = grumpkin_info + let grumpkin_pi_hashes: Vec<[Fq254; 2]> = grumpkin_info .grumpkin_outputs .iter() .map(|e| e.pi_hash) @@ -1529,7 +1556,9 @@ pub fn prove_grumpkin_accumulation( let bn_pi_hashes_prepped = bn254_pi_hashes .iter() - .map(|&var| convert_to_hash_form(circuit, var)) + .flatten() + .copied() + .map(|var| convert_to_hash_form(circuit, var)) .collect::, CircuitError>>()? .into_iter() .flatten() @@ -1555,37 +1584,48 @@ pub fn prove_grumpkin_accumulation( let bytes = value.into_bigint().to_bytes_le(); let (challenge, leftover) = bytes.split_at(31); - let pi_hash = circuit.create_variable(Fr254::from_le_bytes_mod_order(challenge))?; + let low_var = circuit.create_variable(Fr254::from_le_bytes_mod_order(challenge))?; + let high_var = circuit.create_variable(Fr254::from_le_bytes_mod_order(leftover))?; - let leftover_var = circuit.create_variable(Fr254::from_le_bytes_mod_order(leftover))?; - - circuit.enforce_in_range(pi_hash, 8 * 31)?; - circuit.enforce_in_range(leftover_var, 6)?; + circuit.enforce_in_range(low_var, 8 * 31)?; + circuit.enforce_in_range(high_var, 6)?; let coeff = Fr254::from(2u32).pow([248u64]); circuit.lc_gate( &[ - pi_hash, - leftover_var, + low_var, + high_var, circuit.zero(), circuit.zero(), calc_pi_hash, ], &[Fr254::one(), coeff, Fr254::zero(), Fr254::zero()], )?; + let pi_hash = [low_var, high_var]; pi_hash_vars.push(pi_hash); // For checking correctness during testing #[cfg(test)] { - assert_eq!( - circuit.witness(pi_hash).unwrap(), - fr_to_fq::(_grumpkin_pi_hash) - ); + for i in 0..2 { + assert_eq!( + circuit.witness(pi_hash[i]).unwrap(), + fr_to_fq::(&_grumpkin_pi_hash[i]) + ) + } } - let pi_hash_emul: EmulatedVariable = circuit.to_emulated_variable(pi_hash)?; + let pi_hash_emul: [EmulatedVariable; 2] = pi_hash + .iter() + .map(|var| circuit.to_emulated_variable(*var)) + .collect::>, CircuitError>>()? + .try_into() + .map_err(|_| { + CircuitError::ParameterError( + "Could not convert slice to fixed length array".to_string(), + ) + })?; let fs_msg = if let Some(fs_metadata) = fs_metadata { let layer = if IS_BASE { @@ -1714,8 +1754,10 @@ pub fn prove_grumpkin_accumulation( .collect::, CircuitError>>()?; // Finally pi hashes are constructed to fit into either field - for var in pi_hash_vars.iter() { - circuit.set_variable_public(*var)?; + for pi_hash_var in pi_hash_vars.iter() { + for &var in pi_hash_var.iter() { + circuit.set_variable_public(var)?; + } } next_grumpkin_challenges.iter().try_for_each(|c| { @@ -1774,7 +1816,15 @@ pub fn decider_circuit( .iter() .map(|output| { let proof = ProofVarNative::from_struct(&output.proof, circuit)?; - let pi_hash = circuit.create_variable(output.pi_hash)?; + let pi_hash: [Variable; 2] = output + .pi_hash + .iter() + .map(|val| circuit.create_variable(*val)) + .collect::, CircuitError>>()? + .try_into() + .map_err(|_| { + PlonkError::InvalidParameters("Expected pi_hash to have length 2".to_string()) + })?; let proof_evals = ProofScalarsVarNative::from_struct(&proof, pi_hash)?; Ok((proof_evals, proof)) }) @@ -1840,7 +1890,7 @@ pub fn decider_circuit( let bn254_pi_hashes = output_scalar_vars .iter() .map(|o| o.pi_hash) - .collect::>(); + .collect::>(); let acc_comms: Vec<(PointVariable, EmulatedVariable)> = grumpkin_info .old_accumulators @@ -1852,11 +1902,20 @@ pub fn decider_circuit( }) .collect::, CircuitError>>()?; - let pi_hashes: Vec> = grumpkin_info + let pi_hashes: Vec<[EmulatedVariable; 2]> = grumpkin_info .grumpkin_outputs .iter() - .map(|o| circuit.create_emulated_variable(o.pi_hash)) - .collect::>, _>>()?; + .map(|o| { + o.pi_hash + .iter() + .map(|val| circuit.create_emulated_variable(*val)) + .collect::>, CircuitError>>()? + .try_into() + .map_err(|_| { + CircuitError::ParameterError("Expected pi_hash to have length 2".to_string()) + }) + }) + .collect::; 2]>, _>>()?; // Now we reform the pi_hashes for both grumpkin proof and extract the scalars from them. izip!( @@ -1994,7 +2053,9 @@ pub fn decider_circuit( let bn_pi_hashes_prepped = bn254_pi_hashes .iter() - .map(|&var| convert_to_hash_form(circuit, var)) + .flatten() + .copied() + .map(|var| convert_to_hash_form(circuit, var)) .collect::, CircuitError>>()? .into_iter() .flatten() @@ -2016,29 +2077,31 @@ pub fn decider_circuit( let bytes = value.into_bigint().to_bytes_le(); let (challenge, leftover) = bytes.split_at(31); - let pi_hash = circuit.create_variable(Fr254::from_le_bytes_mod_order(challenge))?; - - let leftover_var = circuit.create_variable(Fr254::from_le_bytes_mod_order(leftover))?; + let low_var = circuit.create_variable(Fr254::from_le_bytes_mod_order(challenge))?; + let high_var = circuit.create_variable(Fr254::from_le_bytes_mod_order(leftover))?; let coeff = Fr254::from(2u32).pow([248u64]); - circuit.enforce_in_range(pi_hash, 8 * 31)?; - circuit.enforce_in_range(leftover_var, 6)?; + circuit.enforce_in_range(low_var, 8 * 31)?; + circuit.enforce_in_range(high_var, 6)?; circuit.lc_gate( &[ - pi_hash, - leftover_var, + low_var, + high_var, circuit.zero(), circuit.zero(), calc_pi_hash, ], &[Fr254::one(), coeff, Fr254::zero(), Fr254::zero()], )?; + let pi_hash = [low_var, high_var]; - let pi_native = circuit.mod_to_native_field(pi_hash_emul)?; - - circuit.enforce_equal(pi_native, pi_hash) + for i in 0..2 { + let pi_native = circuit.mod_to_native_field(&pi_hash_emul[i])?; + circuit.enforce_equal(pi_native, pi_hash[i])?; + } + Ok::<(), CircuitError>(()) }, )?; let split_acc_info = SplitAccumulationInfo::perform_accumulation( @@ -2094,7 +2157,7 @@ pub fn decider_circuit( &[proof_one, proof_two] .into_iter() .zip_eq(pi_hashes) - .collect::, EmulatedVariable)>>(), + .collect::, [EmulatedVariable; 2])>>(), &split_acc_info, &acc_comms, &pk_grumpkin.verifying_key.gate_info, diff --git a/primitives/src/rescue/sponge.rs b/primitives/src/rescue/sponge.rs index add56c0b..dca82ca5 100644 --- a/primitives/src/rescue/sponge.rs +++ b/primitives/src/rescue/sponge.rs @@ -36,7 +36,7 @@ pub struct RescueCRHF { impl RecursionHasher for RescueCRHF { type Error = CircuitError; - fn hash_public_inputs(public_inputs: &[E]) -> Result { + fn hash_public_inputs(public_inputs: &[E]) -> Result<[E; 2], Self::Error> { let mut input = Vec::new(); let e_modulus: BigUint = E::MODULUS.into(); let f_modulus: BigUint = F::MODULUS.into(); @@ -61,17 +61,20 @@ impl RecursionHasher for RescueCRHF { // Find the byte length of the scalar field (minus one). let field_bytes_length = (E::MODULUS_BIT_SIZE as usize - 1) / 8; - let hash = E::from_le_bytes_mod_order( - output - .into_bigint() - .to_bytes_le() - .iter() - .take(field_bytes_length) - .copied() - .collect::>() - .as_slice(), - ); - Ok(hash) + + // Two elements of E. One per chunk of `field_bytes_length` bytes + let hashes: [E; 2] = output + .into_bigint() + .to_bytes_le() + .chunks(field_bytes_length) + .map(|chunk| E::from_le_bytes_mod_order(chunk)) + .collect::>() + .try_into() + .map_err(|_| { + CircuitError::InternalError("Could not convert to fixed length array".to_string()) + })?; + + Ok(hashes) } } /// PRF diff --git a/relation/src/constraint_system.rs b/relation/src/constraint_system.rs index c797ec99..ab5e60eb 100644 --- a/relation/src/constraint_system.rs +++ b/relation/src/constraint_system.rs @@ -1539,17 +1539,19 @@ impl PlonkCircuit { .unwrap(); } - // Here we make the only public input the hash of all the public inputs. + // Here we make the only two public inputs the hash of all the public inputs. let public_input = self .pub_input_indices .iter() .map(|&i| self.witness[i]) .collect::>(); - let new_public_input: F = H::hash_public_inputs::(&public_input) + let new_public_inputs: [F; 2] = H::hash_public_inputs::(&public_input) .map_err(|_| CircuitError::InternalError("Public input hashing failed".to_string()))?; self.pub_input_indices = vec![]; - let _ = self.create_public_variable(new_public_input)?; + for new_public_input in new_public_inputs.iter() { + self.create_public_variable(*new_public_input)?; + } let wire_vars = self .pub_input_indices @@ -1664,17 +1666,19 @@ impl PlonkCircuit { .unwrap(); } - // Here we make the only public input the hash of all the public inputs. + // Here we make the only two public inputs the hash of all the public inputs. let public_input = self .pub_input_indices .iter() .map(|&i| self.witness[i]) .collect::>(); - let new_public_input: F = H::hash_public_inputs::(&public_input) + let new_public_input: [F; 2] = H::hash_public_inputs::(&public_input) .map_err(|_| CircuitError::InternalError("Public input hashing failed".to_string()))?; self.pub_input_indices = vec![]; - let _ = self.create_public_variable(new_public_input)?; + for new_public_input in new_public_input.iter() { + self.create_public_variable(*new_public_input)?; + } let wire_vars = self .pub_input_indices diff --git a/relation/src/lib.rs b/relation/src/lib.rs index feb58288..02808850 100644 --- a/relation/src/lib.rs +++ b/relation/src/lib.rs @@ -22,5 +22,5 @@ pub trait RecursionHasher { /// The error type for this hasher. type Error; /// This function defines how public inputs will be hashed in a recursive setting. - fn hash_public_inputs(public_inputs: &[F]) -> Result; + fn hash_public_inputs(public_inputs: &[F]) -> Result<[F; 2], Self::Error>; }