diff --git a/crates/mpcs/src/jagged/assist.rs b/crates/mpcs/src/jagged/assist.rs index 51c5e3c..3d3ed75 100644 --- a/crates/mpcs/src/jagged/assist.rs +++ b/crates/mpcs/src/jagged/assist.rs @@ -35,27 +35,64 @@ pub fn assist_sumcheck_prove( n_robp: usize, transcript: &mut impl Transcript, ) -> (IOPProof, Vec) { + let (_, proof, challenges) = assist_sumcheck_prove_impl( + z_row_padded, + rho_padded, + eq_col, + cumulative_heights, + n_robp, + transcript, + false, + ); + (proof, challenges) +} + +/// Compute the assist claimed sum, append it to the transcript, and run the +/// assist sumcheck prover using the same ROBP precompute. +pub fn assist_sumcheck_prove_and_append_claim( + z_row_padded: &[E], + rho_padded: &[E], + eq_col: &[E], + cumulative_heights: &[usize], + n_robp: usize, + transcript: &mut impl Transcript, +) -> (E, IOPProof, Vec) { + assist_sumcheck_prove_impl( + z_row_padded, + rho_padded, + eq_col, + cumulative_heights, + n_robp, + transcript, + true, + ) +} + +fn assist_sumcheck_prove_impl( + z_row_padded: &[E], + rho_padded: &[E], + eq_col: &[E], + cumulative_heights: &[usize], + n_robp: usize, + transcript: &mut impl Transcript, + append_claim: bool, +) -> (E, IOPProof, Vec) { let num_polys = cumulative_heights.len() - 1; let n_vars = 2 * n_robp; let max_degree: usize = 2; - // Write transcript header (must match verifier). - transcript.append_message(&n_vars.to_le_bytes()); - transcript.append_message(&max_degree.to_le_bytes()); - // Precompute per-step symbol matrices. let step_mats: Vec<[TransitionMatrix; 4]> = (0..n_robp) .map(|i| symbol_transition_matrices(z_row_padded[i], rho_padded[i])) .collect(); - // Extract Boolean bits in step-major layout: c_bits[i][y], d_bits[i][y]. - // c_bits[i][y] = bit_i(t_y), d_bits[i][y] = bit_i(t_{y+1}) - let mut c_bits = vec![vec![0usize; num_polys]; n_robp]; - let mut d_bits = vec![vec![0usize; num_polys]; n_robp]; - for i in 0..n_robp { - for y in 0..num_polys { - c_bits[i][y] = (cumulative_heights[y] >> i) & 1; - d_bits[i][y] = (cumulative_heights[y + 1] >> i) & 1; + // Extract Boolean symbol pairs in step-major layout: + // cd_bits[i][y] = 2 * bit_i(t_y) + bit_i(t_{y+1}). + let mut cd_bits = vec![vec![0u8; num_polys]; n_robp]; + for (i, cd_bits_i) in cd_bits.iter_mut().enumerate() { + for (y, cd_bit) in cd_bits_i.iter_mut().enumerate() { + *cd_bit = ((((cumulative_heights[y] >> i) & 1) << 1) + | ((cumulative_heights[y + 1] >> i) & 1)) as u8; } } @@ -89,14 +126,27 @@ pub fn assist_sumcheck_prove( let dst = &mut left[i]; let src = &right[0]; dst.into_par_iter().enumerate().for_each(|(y, dst_y)| { - let cd = c_bits[i][y] * 2 + d_bits[i][y]; + let cd = cd_bits[i][y] as usize; *dst_y = mat_vec_mul(&step_mats[i][cd], &src[y]); }); } + let source = source_vec(); + let mut claimed_sum = E::ZERO; + for (y, eq) in eq_col.iter().enumerate().take(num_polys) { + claimed_sum += *eq * dot4(&source, &bwd[0][y]); + } + if append_claim { + transcript.append_field_element_ext(&claimed_sum); + } + + // Write transcript header (must match verifier). + transcript.append_message(&n_vars.to_le_bytes()); + transcript.append_message(&max_degree.to_le_bytes()); + // Initialize weights and forward vector. let mut weights: Vec = eq_col[..num_polys].to_vec(); - let mut fwd: StateVec = source_vec(); + let mut fwd: StateVec = source; let mut challenges: Vec = Vec::with_capacity(n_vars); let mut proof_messages: Vec> = Vec::with_capacity(n_vars); @@ -138,7 +188,7 @@ pub fn assist_sumcheck_prove( .map(|chunk| { let mut local_bwd_sum = [[E::ZERO; ROBP_WIDTH]; 4]; for &y in chunk { - let cd = c_bits[i][y] * 2 + d_bits[i][y]; + let cd = cd_bits[i][y] as usize; let w = weights[y]; for s in 0..ROBP_WIDTH { local_bwd_sum[cd][s] += w * bwd[i + 1][y][s]; @@ -220,7 +270,7 @@ pub fn assist_sumcheck_prove( let start = chunk_idx * batch_size; for (j, w) in w_chunk.iter_mut().enumerate() { let y = start + j; - let cd = c_bits[i][y] * 2 + d_bits[i][y]; + let cd = cd_bits[i][y] as usize; *w *= eq_cd[cd]; } }); @@ -235,6 +285,7 @@ pub fn assist_sumcheck_prove( } ( + claimed_sum, IOPProof { proofs: proof_messages, }, diff --git a/crates/mpcs/src/jagged/mod.rs b/crates/mpcs/src/jagged/mod.rs index 3c1f9b9..e45dcb3 100644 --- a/crates/mpcs/src/jagged/mod.rs +++ b/crates/mpcs/src/jagged/mod.rs @@ -112,7 +112,9 @@ pub mod evaluator; pub mod sumcheck; mod types; -pub use assist::{assist_sumcheck_prove, compute_q_at_assist_point}; +pub use assist::{ + assist_sumcheck_prove, assist_sumcheck_prove_and_append_claim, compute_q_at_assist_point, +}; pub use evaluator::{evaluate_g, evaluate_g_backward, evaluate_g_forward}; pub use sumcheck::{JaggedSumcheckInput, QPrimeEvaluations, jagged_sumcheck_prove}; pub use types::{JaggedBatchOpenProof, JaggedCommitment, JaggedCommitmentWithWitness, JaggedProof}; diff --git a/crates/mpcs/src/lib.rs b/crates/mpcs/src/lib.rs index 4b7cf2e..3ba273f 100644 --- a/crates/mpcs/src/lib.rs +++ b/crates/mpcs/src/lib.rs @@ -282,8 +282,8 @@ pub mod jagged; pub use jagged::{ JAGGED_RESHAPE_GROUP_WIDTH, Jagged, JaggedBatchOpenProof, JaggedCommitment, JaggedCommitmentWithWitness, JaggedProof, JaggedSumcheckInput, assist_sumcheck_prove, - evaluate_g, evaluate_g_backward, evaluate_g_forward, jagged_batch_open, jagged_batch_verify, - jagged_commit, jagged_sumcheck_prove, + assist_sumcheck_prove_and_append_claim, evaluate_g, evaluate_g_backward, evaluate_g_forward, + jagged_batch_open, jagged_batch_verify, jagged_commit, jagged_sumcheck_prove, }; #[cfg(feature = "whir")] extern crate whir as whir_external; diff --git a/crates/sumcheck/src/frontload.rs b/crates/sumcheck/src/frontload.rs index 8515fec..bb85ba1 100644 --- a/crates/sumcheck/src/frontload.rs +++ b/crates/sumcheck/src/frontload.rs @@ -103,6 +103,7 @@ use crate::{ pub struct FrontloadProverState { pub challenges: Vec>, pub final_evaluations: Vec>, + pub claimed_sum: E, } #[derive(Clone, Copy, Debug, Eq, PartialEq)] @@ -181,6 +182,7 @@ pub fn prove_2phase<'a, E: ExtensionField>( } let mut proofs = Vec::with_capacity(global_num_vars); let mut challenge: Option> = None; + let mut claimed_sum = E::ZERO; for round in 0..local_num_vars { workers.par_iter_mut().for_each(|worker| { @@ -208,6 +210,9 @@ pub fn prove_2phase<'a, E: ExtensionField>( }, ) }; + if round == 0 { + claimed_sum = evaluations[0] + evaluations[1]; + } evaluations.remove(0); transcript.append_field_element_exts(&evaluations); proofs.push(IOPProverMessage { evaluations }); @@ -249,6 +254,7 @@ pub fn prove_2phase<'a, E: ExtensionField>( FrontloadProverState { challenges, final_evaluations, + claimed_sum, }, ) } @@ -268,6 +274,7 @@ fn prove_inner<'a, E: ExtensionField>( let mut proof = Vec::with_capacity(num_vars); let mut challenge: Option> = None; + let mut claimed_sum = E::ZERO; for round in 0..num_vars { if let Some(challenge) = challenge.take() { @@ -276,6 +283,9 @@ fn prove_inner<'a, E: ExtensionField>( } let mut evaluations = state.round_evaluations(round); + if round == 0 { + claimed_sum = evaluations[0] + evaluations[1]; + } evaluations.remove(0); transcript.append_field_element_exts(&evaluations); proof.push(IOPProverMessage { evaluations }); @@ -293,6 +303,7 @@ fn prove_inner<'a, E: ExtensionField>( FrontloadProverState { challenges: state.challenges, final_evaluations, + claimed_sum, }, ) } diff --git a/crates/sumcheck/src/prover.rs b/crates/sumcheck/src/prover.rs index d75e3dc..dae2941 100644 --- a/crates/sumcheck/src/prover.rs +++ b/crates/sumcheck/src/prover.rs @@ -192,6 +192,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { max_num_variables, poly_meta: vec![], final_evaluations: Some(state.final_evaluations), + claimed_sum: state.claimed_sum, phase2_numvar: None, } } @@ -449,6 +450,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { poly: polynomial, poly_meta: poly_meta.unwrap_or_else(|| vec![PolyMeta::Normal; num_polys]), final_evaluations: None, + claimed_sum: E::ZERO, phase2_numvar, } } @@ -512,6 +514,9 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { exit_span!(start); assert!(uni_polys.len() > 1); + if self.round == 1 { + self.claimed_sum = uni_polys[0] + uni_polys[1]; + } // NOTE remove uni_polys.eval(0) from lagrange domain // as verifier can derive via claim - uni_polys.eval(1) uni_polys.remove(0); @@ -843,6 +848,10 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { .collect_vec() } + pub fn claimed_sum(&self) -> E { + self.claimed_sum + } + pub fn expected_numvars_at_round(&self) -> usize { // first round start from 1 let num_vars = self.max_num_variables + 1 - self.round; @@ -1021,6 +1030,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { poly: polynomial, poly_meta, final_evaluations: None, + claimed_sum: E::ZERO, phase2_numvar: None, }; @@ -1130,6 +1140,9 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { exit_span!(start); assert!(uni_polys.len() > 1); + if self.round == 1 { + self.claimed_sum = uni_polys[0] + uni_polys[1]; + } // NOTE remove uni_polys.eval(0) from lagrange domain // as verifier can derive via claim - uni_polys.eval(1) uni_polys.remove(0); diff --git a/crates/sumcheck/src/structs.rs b/crates/sumcheck/src/structs.rs index ef5dc28..ba207d3 100644 --- a/crates/sumcheck/src/structs.rs +++ b/crates/sumcheck/src/structs.rs @@ -99,6 +99,7 @@ pub struct IOPProverState<'a, E: ExtensionField> { pub(crate) max_num_variables: usize, pub(crate) poly_meta: Vec, pub(crate) final_evaluations: Option>>, + pub(crate) claimed_sum: E, /// phase 1 and phase 2 sumcheck we share similar implementation /// thus this option variable only use for phase 1 sumcheck to mark how many variables belongs to phase 2 pub(crate) phase2_numvar: Option,