Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 67 additions & 16 deletions crates/mpcs/src/jagged/assist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,27 +35,64 @@ pub fn assist_sumcheck_prove<E: ExtensionField>(
n_robp: usize,
transcript: &mut impl Transcript<E>,
) -> (IOPProof<E>, Vec<E>) {
let (_, proof, challenges) = assist_sumcheck_prove_impl(
z_row_padded,
rho_padded,
eq_col,
cumulative_heights,
n_robp,
transcript,
false,
);
(proof, challenges)
}

/// Compute the assist claimed sum, append it to the transcript, and run the
/// assist sumcheck prover using the same ROBP precompute.
pub fn assist_sumcheck_prove_and_append_claim<E: ExtensionField>(
z_row_padded: &[E],
rho_padded: &[E],
eq_col: &[E],
cumulative_heights: &[usize],
n_robp: usize,
transcript: &mut impl Transcript<E>,
) -> (E, IOPProof<E>, Vec<E>) {
assist_sumcheck_prove_impl(
z_row_padded,
rho_padded,
eq_col,
cumulative_heights,
n_robp,
transcript,
true,
)
}

fn assist_sumcheck_prove_impl<E: ExtensionField>(
z_row_padded: &[E],
rho_padded: &[E],
eq_col: &[E],
cumulative_heights: &[usize],
n_robp: usize,
transcript: &mut impl Transcript<E>,
append_claim: bool,
) -> (E, IOPProof<E>, Vec<E>) {
let num_polys = cumulative_heights.len() - 1;
let n_vars = 2 * n_robp;
let max_degree: usize = 2;

// Write transcript header (must match verifier).
transcript.append_message(&n_vars.to_le_bytes());
transcript.append_message(&max_degree.to_le_bytes());

// Precompute per-step symbol matrices.
let step_mats: Vec<[TransitionMatrix<E>; 4]> = (0..n_robp)
.map(|i| symbol_transition_matrices(z_row_padded[i], rho_padded[i]))
.collect();

// Extract Boolean bits in step-major layout: c_bits[i][y], d_bits[i][y].
// c_bits[i][y] = bit_i(t_y), d_bits[i][y] = bit_i(t_{y+1})
let mut c_bits = vec![vec![0usize; num_polys]; n_robp];
let mut d_bits = vec![vec![0usize; num_polys]; n_robp];
for i in 0..n_robp {
for y in 0..num_polys {
c_bits[i][y] = (cumulative_heights[y] >> i) & 1;
d_bits[i][y] = (cumulative_heights[y + 1] >> i) & 1;
// Extract Boolean symbol pairs in step-major layout:
// cd_bits[i][y] = 2 * bit_i(t_y) + bit_i(t_{y+1}).
let mut cd_bits = vec![vec![0u8; num_polys]; n_robp];
for (i, cd_bits_i) in cd_bits.iter_mut().enumerate() {
for (y, cd_bit) in cd_bits_i.iter_mut().enumerate() {
*cd_bit = ((((cumulative_heights[y] >> i) & 1) << 1)
| ((cumulative_heights[y + 1] >> i) & 1)) as u8;
}
}

Expand Down Expand Up @@ -89,14 +126,27 @@ pub fn assist_sumcheck_prove<E: ExtensionField>(
let dst = &mut left[i];
let src = &right[0];
dst.into_par_iter().enumerate().for_each(|(y, dst_y)| {
let cd = c_bits[i][y] * 2 + d_bits[i][y];
let cd = cd_bits[i][y] as usize;
*dst_y = mat_vec_mul(&step_mats[i][cd], &src[y]);
});
}

let source = source_vec();
let mut claimed_sum = E::ZERO;
for (y, eq) in eq_col.iter().enumerate().take(num_polys) {
claimed_sum += *eq * dot4(&source, &bwd[0][y]);
}
if append_claim {
transcript.append_field_element_ext(&claimed_sum);
Comment thread
kunxian-xia marked this conversation as resolved.
}

// Write transcript header (must match verifier).
transcript.append_message(&n_vars.to_le_bytes());
transcript.append_message(&max_degree.to_le_bytes());

// Initialize weights and forward vector.
let mut weights: Vec<E> = eq_col[..num_polys].to_vec();
let mut fwd: StateVec<E> = source_vec();
let mut fwd: StateVec<E> = source;

let mut challenges: Vec<E> = Vec::with_capacity(n_vars);
let mut proof_messages: Vec<IOPProverMessage<E>> = Vec::with_capacity(n_vars);
Expand Down Expand Up @@ -138,7 +188,7 @@ pub fn assist_sumcheck_prove<E: ExtensionField>(
.map(|chunk| {
let mut local_bwd_sum = [[E::ZERO; ROBP_WIDTH]; 4];
for &y in chunk {
let cd = c_bits[i][y] * 2 + d_bits[i][y];
let cd = cd_bits[i][y] as usize;
let w = weights[y];
for s in 0..ROBP_WIDTH {
local_bwd_sum[cd][s] += w * bwd[i + 1][y][s];
Expand Down Expand Up @@ -220,7 +270,7 @@ pub fn assist_sumcheck_prove<E: ExtensionField>(
let start = chunk_idx * batch_size;
for (j, w) in w_chunk.iter_mut().enumerate() {
let y = start + j;
let cd = c_bits[i][y] * 2 + d_bits[i][y];
let cd = cd_bits[i][y] as usize;
*w *= eq_cd[cd];
}
});
Expand All @@ -235,6 +285,7 @@ pub fn assist_sumcheck_prove<E: ExtensionField>(
}

(
claimed_sum,
IOPProof {
proofs: proof_messages,
},
Expand Down
4 changes: 3 additions & 1 deletion crates/mpcs/src/jagged/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ pub mod evaluator;
pub mod sumcheck;
mod types;

pub use assist::{assist_sumcheck_prove, compute_q_at_assist_point};
pub use assist::{
assist_sumcheck_prove, assist_sumcheck_prove_and_append_claim, compute_q_at_assist_point,
};
pub use evaluator::{evaluate_g, evaluate_g_backward, evaluate_g_forward};
pub use sumcheck::{JaggedSumcheckInput, QPrimeEvaluations, jagged_sumcheck_prove};
pub use types::{JaggedBatchOpenProof, JaggedCommitment, JaggedCommitmentWithWitness, JaggedProof};
Expand Down
4 changes: 2 additions & 2 deletions crates/mpcs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,8 @@ pub mod jagged;
pub use jagged::{
JAGGED_RESHAPE_GROUP_WIDTH, Jagged, JaggedBatchOpenProof, JaggedCommitment,
JaggedCommitmentWithWitness, JaggedProof, JaggedSumcheckInput, assist_sumcheck_prove,
evaluate_g, evaluate_g_backward, evaluate_g_forward, jagged_batch_open, jagged_batch_verify,
jagged_commit, jagged_sumcheck_prove,
assist_sumcheck_prove_and_append_claim, evaluate_g, evaluate_g_backward, evaluate_g_forward,
jagged_batch_open, jagged_batch_verify, jagged_commit, jagged_sumcheck_prove,
};
#[cfg(feature = "whir")]
extern crate whir as whir_external;
Expand Down
11 changes: 11 additions & 0 deletions crates/sumcheck/src/frontload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ use crate::{
pub struct FrontloadProverState<E: ExtensionField> {
pub challenges: Vec<Challenge<E>>,
pub final_evaluations: Vec<Vec<E>>,
pub claimed_sum: E,
}

#[derive(Clone, Copy, Debug, Eq, PartialEq)]
Expand Down Expand Up @@ -181,6 +182,7 @@ pub fn prove_2phase<'a, E: ExtensionField>(
}
let mut proofs = Vec::with_capacity(global_num_vars);
let mut challenge: Option<Challenge<E>> = None;
let mut claimed_sum = E::ZERO;

for round in 0..local_num_vars {
workers.par_iter_mut().for_each(|worker| {
Expand Down Expand Up @@ -208,6 +210,9 @@ pub fn prove_2phase<'a, E: ExtensionField>(
},
)
};
if round == 0 {
claimed_sum = evaluations[0] + evaluations[1];
}
evaluations.remove(0);
transcript.append_field_element_exts(&evaluations);
proofs.push(IOPProverMessage { evaluations });
Expand Down Expand Up @@ -249,6 +254,7 @@ pub fn prove_2phase<'a, E: ExtensionField>(
FrontloadProverState {
challenges,
final_evaluations,
claimed_sum,
},
)
}
Expand All @@ -268,6 +274,7 @@ fn prove_inner<'a, E: ExtensionField>(

let mut proof = Vec::with_capacity(num_vars);
let mut challenge: Option<Challenge<E>> = None;
let mut claimed_sum = E::ZERO;

for round in 0..num_vars {
if let Some(challenge) = challenge.take() {
Expand All @@ -276,6 +283,9 @@ fn prove_inner<'a, E: ExtensionField>(
}

let mut evaluations = state.round_evaluations(round);
if round == 0 {
claimed_sum = evaluations[0] + evaluations[1];
}
evaluations.remove(0);
transcript.append_field_element_exts(&evaluations);
proof.push(IOPProverMessage { evaluations });
Expand All @@ -293,6 +303,7 @@ fn prove_inner<'a, E: ExtensionField>(
FrontloadProverState {
challenges: state.challenges,
final_evaluations,
claimed_sum,
},
)
}
Expand Down
13 changes: 13 additions & 0 deletions crates/sumcheck/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
max_num_variables,
poly_meta: vec![],
final_evaluations: Some(state.final_evaluations),
claimed_sum: state.claimed_sum,
phase2_numvar: None,
}
}
Expand Down Expand Up @@ -449,6 +450,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
poly: polynomial,
poly_meta: poly_meta.unwrap_or_else(|| vec![PolyMeta::Normal; num_polys]),
final_evaluations: None,
claimed_sum: E::ZERO,
phase2_numvar,
}
}
Expand Down Expand Up @@ -512,6 +514,9 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
exit_span!(start);

assert!(uni_polys.len() > 1);
if self.round == 1 {
self.claimed_sum = uni_polys[0] + uni_polys[1];
}
// NOTE remove uni_polys.eval(0) from lagrange domain
// as verifier can derive via claim - uni_polys.eval(1)
uni_polys.remove(0);
Expand Down Expand Up @@ -843,6 +848,10 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
.collect_vec()
}

pub fn claimed_sum(&self) -> E {
self.claimed_sum
}

pub fn expected_numvars_at_round(&self) -> usize {
// first round start from 1
let num_vars = self.max_num_variables + 1 - self.round;
Expand Down Expand Up @@ -1021,6 +1030,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
poly: polynomial,
poly_meta,
final_evaluations: None,
claimed_sum: E::ZERO,
phase2_numvar: None,
};

Expand Down Expand Up @@ -1130,6 +1140,9 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
exit_span!(start);

assert!(uni_polys.len() > 1);
if self.round == 1 {
self.claimed_sum = uni_polys[0] + uni_polys[1];
}
// NOTE remove uni_polys.eval(0) from lagrange domain
// as verifier can derive via claim - uni_polys.eval(1)
uni_polys.remove(0);
Expand Down
1 change: 1 addition & 0 deletions crates/sumcheck/src/structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ pub struct IOPProverState<'a, E: ExtensionField> {
pub(crate) max_num_variables: usize,
pub(crate) poly_meta: Vec<PolyMeta>,
pub(crate) final_evaluations: Option<Vec<Vec<E>>>,
pub(crate) claimed_sum: E,
/// phase 1 and phase 2 sumcheck we share similar implementation
/// thus this option variable only use for phase 1 sumcheck to mark how many variables belongs to phase 2
pub(crate) phase2_numvar: Option<usize>,
Expand Down
Loading