From 03254c3914d99554e1b2774e6b15f72952b38ef1 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 5 Jun 2026 19:17:06 +0800 Subject: [PATCH 1/6] fix: align jagged q layout with committed rows --- crates/mpcs/src/jagged/mod.rs | 113 ++++++++++++++++++++++++++-------- 1 file changed, 89 insertions(+), 24 deletions(-) diff --git a/crates/mpcs/src/jagged/mod.rs b/crates/mpcs/src/jagged/mod.rs index ee6d861..48ae9fa 100644 --- a/crates/mpcs/src/jagged/mod.rs +++ b/crates/mpcs/src/jagged/mod.rs @@ -21,9 +21,8 @@ //! ## Cumulative Heights //! //! Each polynomial `p_i` has `s_i = ceil_log2(h_i)` variables, where `h_i` is the -//! real number of evaluations from the input matrix column. `q'` stores exactly -//! those `h_i` evaluations; any implicit zero padding to `2^{s_i}` is only an MLE -//! evaluation convention and is not materialized inside the concatenation. +//! committed padded height of the input matrix column. `q'` stores exactly those +//! `h_i` evaluations so its layout matches the inner PCS column MLEs. //! //! The cumulative height sequence `t` tracks the starting position of each polynomial in `q'`: //! - `t[0] = 0` @@ -38,8 +37,8 @@ //! //! ## Commit Protocol //! -//! 1. For each input matrix `M_k` (with `h_k` rows and `w_k` columns), extract each -//! column as a polynomial with `h_k` evaluations. +//! 1. For each input matrix `M_k` (with committed height `h_k` and `w_k` columns), +//! extract each column as a polynomial with `h_k` evaluations. //! 2. Concatenate all column polynomials: `cat = p_0 || p_1 || ...` //! 3. Compute cumulative heights `t[i]`. //! 4. Pad `cat` to the next power of two (required for MLE representation). @@ -54,8 +53,8 @@ //! //! ### Correction factors for different-height polynomials //! -//! In the giga polynomial `q'`, each `p_i` occupies only `h_i` slots (padded to -//! `2^{s_i}` for MLE representation). When `s_i < m`, this is equivalent to +//! In the giga polynomial `q'`, each `p_i` occupies `h_i` committed slots. When +//! `s_i < m`, this is equivalent to //! zero-padding `p_i` to `m` variables: //! //! ```text @@ -201,7 +200,7 @@ impl Serialize for Jagged { /// /// # Arguments /// * `pp` — Prover parameters for `InnerPcs`. -/// * `rmms` — Non-empty sequence of row-major matrices. This function uses each matrix's height exactly as given. +/// * `rmms` — Non-empty sequence of row-major matrices. /// /// # Errors /// Returns `Error::InvalidPcsParam` if `rmms` is empty or all matrices are empty. @@ -222,10 +221,10 @@ pub fn jagged_commit> .flat_map(|rmm| rmm.to_mles().into_iter().map(Arc::new)) .collect_vec(); - // --- Step 1: Compute cumulative heights from real matrix heights --- + // --- Step 1: Compute cumulative heights from committed matrix heights --- let mut poly_heights: Vec = Vec::new(); for rmm in &rmms { - let num_rows = rmm.occupied_physical_rows(); + let num_rows = rmm.height(); let num_cols = rmm.width(); if num_rows == 0 { @@ -273,7 +272,7 @@ pub fn jagged_commit> let mut poly_idx = 0; for rmm in rmms { let n_cols = rmm.width(); - let n_rows = rmm.occupied_physical_rows(); + let n_rows = rmm.height(); let n_cells = n_cols * n_rows; // The start position in `concatenated` for this matrix's block of polynomials. @@ -362,7 +361,7 @@ fn default_reshape_log_height usize { let max_poly_height = rmms .iter() - .map(|rmm| rmm.occupied_physical_rows()) + .map(|rmm| rmm.height()) .max() .unwrap_or(1); pp.get_max_message_size_log() @@ -1060,8 +1059,8 @@ mod tests { } #[test] - fn test_jagged_commit_uses_real_heights() { - // 3x1 + 5x2 → heights [3, 5, 5], not [4, 8, 8]. + fn test_jagged_commit_uses_committed_heights() { + // 3x1 + 5x2 are padded by RowMajorMatrix to committed heights [4, 8, 8]. let reshape_log_height = 4; let (pp, _vp) = setup_pcs::(reshape_log_height); let m1 = make_rmm(3, 1); @@ -1071,14 +1070,80 @@ mod tests { .expect("commit should succeed"); assert_eq!(comm.num_polys(), 3); - assert_eq!(comm.poly_heights, vec![3, 5, 5]); - assert_eq!(comm.cumulative_heights, vec![0, 3, 8, 13]); - assert_eq!(comm.total_evaluations(), 13); + assert_eq!(comm.poly_heights, vec![4, 8, 8]); + assert_eq!(comm.cumulative_heights, vec![0, 4, 12, 20]); + assert_eq!(comm.total_evaluations(), 20); + assert_eq!(comm.polys[0].occupied_len(), 4); + assert_eq!(comm.polys[1].occupied_len(), 8); + assert_eq!(comm.polys[2].occupied_len(), 8); } #[test] - fn test_jagged_commit_uses_unpadded_physical_rotation_height() { - // 3 logical rows with 4-way rotation occupy 12 real physical rows, padded to 16. + fn test_jagged_commit_layout_matches_custom_padding() { + let mut rng = thread_rng(); + + let num_rows = 1023usize; + let num_cols = 2usize; + let reshape_log_height = 8; + let (pp, vp) = setup_pcs::(reshape_log_height); + + let values: Vec = (0..num_rows * num_cols) + .map(|i| F::from_canonical_u64(i as u64 + 1)) + .collect(); + let mut rmm = WitnessRowMajorMatrix::new_by_inner_matrix( + RowMajorMatrix::new(values, num_cols), + InstancePaddingStrategy::Custom(Arc::new(|_, _| 99)), + ); + rmm.padding_by_strategy(); + + let committed_height = rmm.height(); + assert_eq!(rmm.occupied_physical_rows(), num_rows); + assert_eq!(committed_height, 1024); + + let col_polys: Vec> = (0..num_cols) + .map(|c| { + (0..committed_height) + .map(|r| rmm.values[r * num_cols + c]) + .collect() + }) + .collect(); + + let mut transcript_p = BasicTranscript::::new(b"jagged_custom_padding_layout"); + let comm = jagged_commit::(&pp, vec![rmm], reshape_log_height).expect("commit"); + Pcs::write_commitment(&comm.to_commitment().inner, &mut transcript_p).unwrap(); + + assert_eq!(comm.poly_heights, vec![committed_height; num_cols]); + assert_eq!(comm.cumulative_heights, vec![0, committed_height, committed_height * 2]); + assert_eq!(comm.polys[0].occupied_len(), committed_height); + assert_eq!(comm.polys[1].occupied_len(), committed_height); + + let point: Vec = (0..ceil_log2(committed_height)) + .map(|_| E::random(&mut rng)) + .collect(); + let evals: Vec = col_polys + .iter() + .map(|col| eval_column_poly_at_point(col, &point)) + .collect(); + + let proof = jagged_batch_open::(&pp, &comm, &point, &evals, &mut transcript_p) + .expect("batch open"); + + let mut transcript_v = BasicTranscript::::new(b"jagged_custom_padding_layout"); + Pcs::write_commitment(&comm.to_commitment().inner, &mut transcript_v).unwrap(); + jagged_batch_verify::( + &vp, + &comm.to_commitment(), + &point, + &evals, + &proof, + &mut transcript_v, + ) + .expect("batch verify"); + } + + #[test] + fn test_jagged_commit_uses_committed_rotation_height() { + // 3 logical rows with 4-way rotation occupy 12 physical rows, padded to 16. let reshape_log_height = 4; let (pp, _vp) = setup_pcs::(reshape_log_height); let rmm = WitnessRowMajorMatrix::new_by_rotation(3, 2, 2, InstancePaddingStrategy::Default); @@ -1087,9 +1152,9 @@ mod tests { .expect("commit should succeed"); assert_eq!(comm.num_polys(), 2); - assert_eq!(comm.poly_heights, vec![12, 12]); - assert_eq!(comm.cumulative_heights, vec![0, 12, 24]); - assert_eq!(comm.total_evaluations(), 24); + assert_eq!(comm.poly_heights, vec![16, 16]); + assert_eq!(comm.cumulative_heights, vec![0, 16, 32]); + assert_eq!(comm.total_evaluations(), 32); } #[test] @@ -1442,13 +1507,13 @@ mod tests { let heights = [1023usize, 777, 513]; let log_heights: Vec = heights.iter().map(|&h| ceil_log2(h)).collect(); let max_s = *log_heights.iter().max().unwrap(); - let total_evals: usize = heights.iter().sum(); let reshape_log_height = 8; let h = 1usize << reshape_log_height; - let w = total_evals.div_ceil(h); let (pp, vp) = setup_pcs::(reshape_log_height); let rmms: Vec<_> = heights.iter().map(|&h| make_rmm(h, 1)).collect(); + let total_evals: usize = rmms.iter().map(|rmm| rmm.height()).sum(); + let w = total_evals.div_ceil(h); let col_polys: Vec> = rmms .iter() From 193131227f40fe5c080f8cb5567616f6780fd246 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 8 Jun 2026 14:38:49 +0800 Subject: [PATCH 2/6] fix(mpcs): compact jagged q over occupied rows --- crates/mpcs/src/jagged/mod.rs | 254 ++++++++++++++++++++++++++-------- 1 file changed, 195 insertions(+), 59 deletions(-) diff --git a/crates/mpcs/src/jagged/mod.rs b/crates/mpcs/src/jagged/mod.rs index 48ae9fa..3fc9508 100644 --- a/crates/mpcs/src/jagged/mod.rs +++ b/crates/mpcs/src/jagged/mod.rs @@ -21,8 +21,9 @@ //! ## Cumulative Heights //! //! Each polynomial `p_i` has `s_i = ceil_log2(h_i)` variables, where `h_i` is the -//! committed padded height of the input matrix column. `q'` stores exactly those -//! `h_i` evaluations so its layout matches the inner PCS column MLEs. +//! occupied physical height of the input matrix column. `q'` stores exactly those +//! `h_i` occupied evaluations; any matrix padding remains outside the jagged q +//! layout. //! //! The cumulative height sequence `t` tracks the starting position of each polynomial in `q'`: //! - `t[0] = 0` @@ -37,8 +38,8 @@ //! //! ## Commit Protocol //! -//! 1. For each input matrix `M_k` (with committed height `h_k` and `w_k` columns), -//! extract each column as a polynomial with `h_k` evaluations. +//! 1. For each input matrix `M_k` (with occupied physical height `h_k` and `w_k` +//! columns), extract each occupied column as a polynomial with `h_k` evaluations. //! 2. Concatenate all column polynomials: `cat = p_0 || p_1 || ...` //! 3. Compute cumulative heights `t[i]`. //! 4. Pad `cat` to the next power of two (required for MLE representation). @@ -53,7 +54,7 @@ //! //! ### Correction factors for different-height polynomials //! -//! In the giga polynomial `q'`, each `p_i` occupies `h_i` committed slots. When +//! In the giga polynomial `q'`, each `p_i` occupies `h_i` occupied slots. When //! `s_i < m`, this is equivalent to //! zero-padding `p_i` to `m` variables: //! @@ -221,10 +222,10 @@ pub fn jagged_commit> .flat_map(|rmm| rmm.to_mles().into_iter().map(Arc::new)) .collect_vec(); - // --- Step 1: Compute cumulative heights from committed matrix heights --- + // --- Step 1: Compute cumulative heights from occupied matrix heights --- let mut poly_heights: Vec = Vec::new(); for rmm in &rmms { - let num_rows = rmm.height(); + let num_rows = rmm.occupied_physical_rows(); let num_cols = rmm.width(); if num_rows == 0 { @@ -272,25 +273,20 @@ pub fn jagged_commit> let mut poly_idx = 0; for rmm in rmms { let n_cols = rmm.width(); - let n_rows = rmm.height(); + let n_rows = rmm.occupied_physical_rows(); let n_cells = n_cols * n_rows; // The start position in `concatenated` for this matrix's block of polynomials. let start = cumulative_heights[poly_idx]; - // Step 3: Transpose — write each column j of `rmm` (= one polynomial) - // into its corresponding contiguous slice in `concatenated`. + // Step 3: Write each committed column prefix into its compact q slice. + // Use the column MLEs as the source so q matches the committed polynomial + // ordering even when the row-major backing has backend-specific layout details. (0..n_cols) .into_par_iter() .zip(concatenated[start..start + n_cells].par_chunks_mut(n_rows)) .for_each(|(j, chunk)| { - rmm.values - .iter() - .take(n_cells) - .skip(j) - .step_by(n_cols) - .zip_eq(chunk.iter_mut()) - .for_each(|(v, out)| *out = *v); + chunk.copy_from_slice(&polys[poly_idx + j].get_base_field_vec()[..n_rows]); }); poly_idx += n_cols; @@ -317,23 +313,43 @@ pub fn jagged_commit> row_major.set_len(h * group_cols) }; - (0..group_cols).into_par_iter().for_each(|local_col| { - let global_col = group_start_col + local_col; - let src_start = global_col * h; - let col_len = total_size.saturating_sub(src_start).min(h); - let dst = unsafe { - &mut *std::ptr::slice_from_raw_parts_mut( - row_major.as_ptr() as *mut E::BaseField, - h * group_cols, - ) - }; - for b in 0..col_len { - dst[b * group_cols + local_col] = concatenated[src_start + b]; - } - for b in col_len..h { - dst[b * group_cols + local_col] = E::BaseField::ZERO; + let group_start = group_start_col * h; + let group_len = total_size.saturating_sub(group_start).min(h * group_cols); + let full_cols = group_len / h; + let tail_rows = group_len % h; + + let write_full_cols = |row: usize, dst: &mut [E::BaseField]| { + for (local_col, value) in dst.iter_mut().take(full_cols).enumerate() { + let global_col = group_start_col + local_col; + *value = concatenated[global_col * h + row]; } - }); + }; + + if tail_rows == 0 { + row_major + .par_chunks_mut(group_cols) + .enumerate() + .for_each(|(row, dst)| write_full_cols(row, dst)); + } else { + let tail_col = full_cols; + let (occupied_tail_rows, padded_tail_rows) = + row_major.split_at_mut(tail_rows * group_cols); + occupied_tail_rows + .par_chunks_mut(group_cols) + .enumerate() + .for_each(|(row, dst)| { + write_full_cols(row, dst); + dst[tail_col] = concatenated[(group_start_col + tail_col) * h + row]; + }); + padded_tail_rows + .par_chunks_mut(group_cols) + .enumerate() + .for_each(|(row_offset, dst)| { + let row = tail_rows + row_offset; + write_full_cols(row, dst); + dst[tail_col] = E::BaseField::ZERO; + }); + } giga_rmms.push(WitnessRowMajorMatrix::::new_by_values( row_major, @@ -361,7 +377,7 @@ fn default_reshape_log_height usize { let max_poly_height = rmms .iter() - .map(|rmm| rmm.height()) + .map(|rmm| rmm.occupied_physical_rows()) .max() .unwrap_or(1); pp.get_max_message_size_log() @@ -1010,6 +1026,7 @@ mod tests { type F = Goldilocks; type E = GoldilocksExt2; type Pcs = Basefold; + type JPcs = Jagged; fn make_rmm(num_rows: usize, num_cols: usize) -> WitnessRowMajorMatrix { let values: Vec = (0..num_rows * num_cols) @@ -1059,8 +1076,9 @@ mod tests { } #[test] - fn test_jagged_commit_uses_committed_heights() { - // 3x1 + 5x2 are padded by RowMajorMatrix to committed heights [4, 8, 8]. + fn test_jagged_commit_uses_occupied_heights() { + // 3x1 + 5x2 are padded by RowMajorMatrix to committed heights [4, 8, 8], + // but jagged q only stores occupied rows [3, 5, 5]. let reshape_log_height = 4; let (pp, _vp) = setup_pcs::(reshape_log_height); let m1 = make_rmm(3, 1); @@ -1070,16 +1088,16 @@ mod tests { .expect("commit should succeed"); assert_eq!(comm.num_polys(), 3); - assert_eq!(comm.poly_heights, vec![4, 8, 8]); - assert_eq!(comm.cumulative_heights, vec![0, 4, 12, 20]); - assert_eq!(comm.total_evaluations(), 20); + assert_eq!(comm.poly_heights, vec![3, 5, 5]); + assert_eq!(comm.cumulative_heights, vec![0, 3, 8, 13]); + assert_eq!(comm.total_evaluations(), 13); assert_eq!(comm.polys[0].occupied_len(), 4); assert_eq!(comm.polys[1].occupied_len(), 8); assert_eq!(comm.polys[2].occupied_len(), 8); } #[test] - fn test_jagged_commit_layout_matches_custom_padding() { + fn test_jagged_commit_uses_zero_rmm_padding_with_compact_q() { let mut rng = thread_rng(); let num_rows = 1023usize; @@ -1087,37 +1105,48 @@ mod tests { let reshape_log_height = 8; let (pp, vp) = setup_pcs::(reshape_log_height); - let values: Vec = (0..num_rows * num_cols) - .map(|i| F::from_canonical_u64(i as u64 + 1)) - .collect(); - let mut rmm = WitnessRowMajorMatrix::new_by_inner_matrix( - RowMajorMatrix::new(values, num_cols), - InstancePaddingStrategy::Custom(Arc::new(|_, _| 99)), - ); - rmm.padding_by_strategy(); + let rmm = make_rmm(num_rows, num_cols); let committed_height = rmm.height(); assert_eq!(rmm.occupied_physical_rows(), num_rows); assert_eq!(committed_height, 1024); + assert!( + rmm.values[num_rows * num_cols..] + .iter() + .all(|value| *value == F::ZERO) + ); let col_polys: Vec> = (0..num_cols) .map(|c| { - (0..committed_height) + (0..num_rows) .map(|r| rmm.values[r * num_cols + c]) .collect() }) .collect(); - let mut transcript_p = BasicTranscript::::new(b"jagged_custom_padding_layout"); + let mut transcript_p = BasicTranscript::::new(b"jagged_compact_q_layout"); let comm = jagged_commit::(&pp, vec![rmm], reshape_log_height).expect("commit"); Pcs::write_commitment(&comm.to_commitment().inner, &mut transcript_p).unwrap(); - assert_eq!(comm.poly_heights, vec![committed_height; num_cols]); - assert_eq!(comm.cumulative_heights, vec![0, committed_height, committed_height * 2]); + assert_eq!(comm.poly_heights, vec![num_rows; num_cols]); + assert_eq!( + comm.cumulative_heights, + vec![0, num_rows, num_rows * num_cols] + ); assert_eq!(comm.polys[0].occupied_len(), committed_height); assert_eq!(comm.polys[1].occupied_len(), committed_height); + assert!( + comm.polys[0].get_base_field_vec()[num_rows..] + .iter() + .all(|value| *value == F::ZERO) + ); + assert!( + comm.polys[1].get_base_field_vec()[num_rows..] + .iter() + .all(|value| *value == F::ZERO) + ); - let point: Vec = (0..ceil_log2(committed_height)) + let point: Vec = (0..ceil_log2(num_rows)) .map(|_| E::random(&mut rng)) .collect(); let evals: Vec = col_polys @@ -1128,7 +1157,7 @@ mod tests { let proof = jagged_batch_open::(&pp, &comm, &point, &evals, &mut transcript_p) .expect("batch open"); - let mut transcript_v = BasicTranscript::::new(b"jagged_custom_padding_layout"); + let mut transcript_v = BasicTranscript::::new(b"jagged_compact_q_layout"); Pcs::write_commitment(&comm.to_commitment().inner, &mut transcript_v).unwrap(); jagged_batch_verify::( &vp, @@ -1142,7 +1171,7 @@ mod tests { } #[test] - fn test_jagged_commit_uses_committed_rotation_height() { + fn test_jagged_commit_uses_occupied_rotation_height() { // 3 logical rows with 4-way rotation occupy 12 physical rows, padded to 16. let reshape_log_height = 4; let (pp, _vp) = setup_pcs::(reshape_log_height); @@ -1152,9 +1181,11 @@ mod tests { .expect("commit should succeed"); assert_eq!(comm.num_polys(), 2); - assert_eq!(comm.poly_heights, vec![16, 16]); - assert_eq!(comm.cumulative_heights, vec![0, 16, 32]); - assert_eq!(comm.total_evaluations(), 32); + assert_eq!(comm.poly_heights, vec![12, 12]); + assert_eq!(comm.cumulative_heights, vec![0, 12, 24]); + assert_eq!(comm.total_evaluations(), 24); + assert_eq!(comm.polys[0].occupied_len(), 16); + assert_eq!(comm.polys[1].occupied_len(), 16); } #[test] @@ -1356,6 +1387,50 @@ mod tests { .expect("batch verify"); } + #[test] + fn test_jagged_trait_batch_open_verify_padded_witness_compact_q() { + let mut rng = thread_rng(); + + let num_rows = 1023usize; + let num_cols = 2usize; + let reshape_log_height = 8; + let (pp, vp) = setup_pcs::(reshape_log_height); + let rmm = make_rmm(num_rows, num_cols); + + let mut transcript_p = BasicTranscript::::new(b"jagged_trait_compact_q"); + let comm = jagged_commit::(&pp, vec![rmm], reshape_log_height).expect("commit"); + JPcs::write_commitment(&comm.to_commitment(), &mut transcript_p).unwrap(); + + let polys = JPcs::get_arc_mle_witness_from_commitment(&comm); + assert_eq!(polys.len(), num_cols); + assert_eq!(polys[0].occupied_len(), num_rows.next_power_of_two()); + assert_eq!(comm.poly_heights, vec![num_rows; num_cols]); + + let point: Vec = (0..polys[0].num_vars()) + .map(|_| E::random(&mut rng)) + .collect(); + let evals = polys.iter().map(|poly| poly.evaluate(&point)).collect_vec(); + + let proof = JPcs::batch_open( + &pp, + vec![(&comm, vec![(point.clone(), evals.clone())])], + &mut transcript_p, + ) + .expect("batch open"); + + let mut transcript_v = BasicTranscript::::new(b"jagged_trait_compact_q"); + let pure_comm = JPcs::get_pure_commitment(&comm); + JPcs::write_commitment(&pure_comm, &mut transcript_v).unwrap(); + + JPcs::batch_verify( + &vp, + vec![(pure_comm, vec![(point.len(), (point, evals))])], + &proof, + &mut transcript_v, + ) + .expect("batch verify"); + } + #[test] fn test_jagged_batch_open_verify_soundness() { let mut rng = thread_rng(); @@ -1512,7 +1587,7 @@ mod tests { let (pp, vp) = setup_pcs::(reshape_log_height); let rmms: Vec<_> = heights.iter().map(|&h| make_rmm(h, 1)).collect(); - let total_evals: usize = rmms.iter().map(|rmm| rmm.height()).sum(); + let total_evals: usize = rmms.iter().map(|rmm| rmm.occupied_physical_rows()).sum(); let w = total_evals.div_ceil(h); let col_polys: Vec> = rmms @@ -1548,4 +1623,65 @@ mod tests { ) .expect("batch verify"); } + + #[test] + fn test_jagged_trait_batch_open_verify_multiple_rounds_compact_q() { + let mut rng = thread_rng(); + + let reshape_log_height = 8; + let (pp, vp) = setup_pcs::(reshape_log_height); + let comm0 = jagged_commit::(&pp, vec![make_rmm(1023, 2)], reshape_log_height) + .expect("commit 0"); + let comm1 = jagged_commit::(&pp, vec![make_rmm(777, 1)], reshape_log_height) + .expect("commit 1"); + + let mut transcript_p = BasicTranscript::::new(b"jagged_trait_multi_round"); + JPcs::write_commitment(&comm0.to_commitment(), &mut transcript_p).unwrap(); + JPcs::write_commitment(&comm1.to_commitment(), &mut transcript_p).unwrap(); + + let polys0 = JPcs::get_arc_mle_witness_from_commitment(&comm0); + let point0: Vec = (0..polys0[0].num_vars()) + .map(|_| E::random(&mut rng)) + .collect(); + let evals0 = polys0 + .iter() + .map(|poly| poly.evaluate(&point0)) + .collect_vec(); + + let polys1 = JPcs::get_arc_mle_witness_from_commitment(&comm1); + let point1: Vec = (0..polys1[0].num_vars()) + .map(|_| E::random(&mut rng)) + .collect(); + let evals1 = polys1 + .iter() + .map(|poly| poly.evaluate(&point1)) + .collect_vec(); + + let proof = JPcs::batch_open( + &pp, + vec![ + (&comm0, vec![(point0.clone(), evals0.clone())]), + (&comm1, vec![(point1.clone(), evals1.clone())]), + ], + &mut transcript_p, + ) + .expect("batch open"); + + let pure_comm0 = JPcs::get_pure_commitment(&comm0); + let pure_comm1 = JPcs::get_pure_commitment(&comm1); + let mut transcript_v = BasicTranscript::::new(b"jagged_trait_multi_round"); + JPcs::write_commitment(&pure_comm0, &mut transcript_v).unwrap(); + JPcs::write_commitment(&pure_comm1, &mut transcript_v).unwrap(); + + JPcs::batch_verify( + &vp, + vec![ + (pure_comm0, vec![(point0.len(), (point0, evals0))]), + (pure_comm1, vec![(point1.len(), (point1, evals1))]), + ], + &proof, + &mut transcript_v, + ) + .expect("batch verify"); + } } From ce7f94dd70e1620f440abb8e33d6c958c63571ee Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 8 Jun 2026 20:30:38 +0800 Subject: [PATCH 3/6] rollback unnessesary comments --- crates/mpcs/src/jagged/mod.rs | 220 +++------------------------------- 1 file changed, 16 insertions(+), 204 deletions(-) diff --git a/crates/mpcs/src/jagged/mod.rs b/crates/mpcs/src/jagged/mod.rs index 3fc9508..8257563 100644 --- a/crates/mpcs/src/jagged/mod.rs +++ b/crates/mpcs/src/jagged/mod.rs @@ -21,9 +21,9 @@ //! ## Cumulative Heights //! //! Each polynomial `p_i` has `s_i = ceil_log2(h_i)` variables, where `h_i` is the -//! occupied physical height of the input matrix column. `q'` stores exactly those -//! `h_i` occupied evaluations; any matrix padding remains outside the jagged q -//! layout. +//! real number of evaluations from the input matrix column. `q'` stores exactly +//! those `h_i` evaluations; any implicit zero padding to `2^{s_i}` is only an MLE +//! evaluation convention and is not materialized inside the concatenation. //! //! The cumulative height sequence `t` tracks the starting position of each polynomial in `q'`: //! - `t[0] = 0` @@ -38,8 +38,8 @@ //! //! ## Commit Protocol //! -//! 1. For each input matrix `M_k` (with occupied physical height `h_k` and `w_k` -//! columns), extract each occupied column as a polynomial with `h_k` evaluations. +//! 1. For each input matrix `M_k` (with `h_k` rows and `w_k` columns), extract each +//! column as a polynomial with `h_k` evaluations. //! 2. Concatenate all column polynomials: `cat = p_0 || p_1 || ...` //! 3. Compute cumulative heights `t[i]`. //! 4. Pad `cat` to the next power of two (required for MLE representation). @@ -54,8 +54,8 @@ //! //! ### Correction factors for different-height polynomials //! -//! In the giga polynomial `q'`, each `p_i` occupies `h_i` occupied slots. When -//! `s_i < m`, this is equivalent to +//! In the giga polynomial `q'`, each `p_i` occupies only `h_i` slots (padded to +//! `2^{s_i}` for MLE representation). When `s_i < m`, this is equivalent to //! zero-padding `p_i` to `m` variables: //! //! ```text @@ -201,7 +201,7 @@ impl Serialize for Jagged { /// /// # Arguments /// * `pp` — Prover parameters for `InnerPcs`. -/// * `rmms` — Non-empty sequence of row-major matrices. +/// * `rmms` — Non-empty sequence of row-major matrices. This function uses each matrix's height exactly as given. /// /// # Errors /// Returns `Error::InvalidPcsParam` if `rmms` is empty or all matrices are empty. @@ -222,7 +222,7 @@ pub fn jagged_commit> .flat_map(|rmm| rmm.to_mles().into_iter().map(Arc::new)) .collect_vec(); - // --- Step 1: Compute cumulative heights from occupied matrix heights --- + // --- Step 1: Compute cumulative heights from real matrix heights --- let mut poly_heights: Vec = Vec::new(); for rmm in &rmms { let num_rows = rmm.occupied_physical_rows(); @@ -279,9 +279,7 @@ pub fn jagged_commit> // The start position in `concatenated` for this matrix's block of polynomials. let start = cumulative_heights[poly_idx]; - // Step 3: Write each committed column prefix into its compact q slice. - // Use the column MLEs as the source so q matches the committed polynomial - // ordering even when the row-major backing has backend-specific layout details. + // Step 3: write each committed column prefix into its compact q slice. (0..n_cols) .into_par_iter() .zip(concatenated[start..start + n_cells].par_chunks_mut(n_rows)) @@ -1026,7 +1024,6 @@ mod tests { type F = Goldilocks; type E = GoldilocksExt2; type Pcs = Basefold; - type JPcs = Jagged; fn make_rmm(num_rows: usize, num_cols: usize) -> WitnessRowMajorMatrix { let values: Vec = (0..num_rows * num_cols) @@ -1076,9 +1073,8 @@ mod tests { } #[test] - fn test_jagged_commit_uses_occupied_heights() { - // 3x1 + 5x2 are padded by RowMajorMatrix to committed heights [4, 8, 8], - // but jagged q only stores occupied rows [3, 5, 5]. + fn test_jagged_commit_uses_real_heights() { + // 3x1 + 5x2 → heights [3, 5, 5], not [4, 8, 8]. let reshape_log_height = 4; let (pp, _vp) = setup_pcs::(reshape_log_height); let m1 = make_rmm(3, 1); @@ -1091,88 +1087,11 @@ mod tests { assert_eq!(comm.poly_heights, vec![3, 5, 5]); assert_eq!(comm.cumulative_heights, vec![0, 3, 8, 13]); assert_eq!(comm.total_evaluations(), 13); - assert_eq!(comm.polys[0].occupied_len(), 4); - assert_eq!(comm.polys[1].occupied_len(), 8); - assert_eq!(comm.polys[2].occupied_len(), 8); } #[test] - fn test_jagged_commit_uses_zero_rmm_padding_with_compact_q() { - let mut rng = thread_rng(); - - let num_rows = 1023usize; - let num_cols = 2usize; - let reshape_log_height = 8; - let (pp, vp) = setup_pcs::(reshape_log_height); - - let rmm = make_rmm(num_rows, num_cols); - - let committed_height = rmm.height(); - assert_eq!(rmm.occupied_physical_rows(), num_rows); - assert_eq!(committed_height, 1024); - assert!( - rmm.values[num_rows * num_cols..] - .iter() - .all(|value| *value == F::ZERO) - ); - - let col_polys: Vec> = (0..num_cols) - .map(|c| { - (0..num_rows) - .map(|r| rmm.values[r * num_cols + c]) - .collect() - }) - .collect(); - - let mut transcript_p = BasicTranscript::::new(b"jagged_compact_q_layout"); - let comm = jagged_commit::(&pp, vec![rmm], reshape_log_height).expect("commit"); - Pcs::write_commitment(&comm.to_commitment().inner, &mut transcript_p).unwrap(); - - assert_eq!(comm.poly_heights, vec![num_rows; num_cols]); - assert_eq!( - comm.cumulative_heights, - vec![0, num_rows, num_rows * num_cols] - ); - assert_eq!(comm.polys[0].occupied_len(), committed_height); - assert_eq!(comm.polys[1].occupied_len(), committed_height); - assert!( - comm.polys[0].get_base_field_vec()[num_rows..] - .iter() - .all(|value| *value == F::ZERO) - ); - assert!( - comm.polys[1].get_base_field_vec()[num_rows..] - .iter() - .all(|value| *value == F::ZERO) - ); - - let point: Vec = (0..ceil_log2(num_rows)) - .map(|_| E::random(&mut rng)) - .collect(); - let evals: Vec = col_polys - .iter() - .map(|col| eval_column_poly_at_point(col, &point)) - .collect(); - - let proof = jagged_batch_open::(&pp, &comm, &point, &evals, &mut transcript_p) - .expect("batch open"); - - let mut transcript_v = BasicTranscript::::new(b"jagged_compact_q_layout"); - Pcs::write_commitment(&comm.to_commitment().inner, &mut transcript_v).unwrap(); - jagged_batch_verify::( - &vp, - &comm.to_commitment(), - &point, - &evals, - &proof, - &mut transcript_v, - ) - .expect("batch verify"); - } - - #[test] - fn test_jagged_commit_uses_occupied_rotation_height() { - // 3 logical rows with 4-way rotation occupy 12 physical rows, padded to 16. + fn test_jagged_commit_uses_unpadded_physical_rotation_height() { + // 3 logical rows with 4-way rotation occupy 12 real physical rows, padded to 16. let reshape_log_height = 4; let (pp, _vp) = setup_pcs::(reshape_log_height); let rmm = WitnessRowMajorMatrix::new_by_rotation(3, 2, 2, InstancePaddingStrategy::Default); @@ -1184,8 +1103,6 @@ mod tests { assert_eq!(comm.poly_heights, vec![12, 12]); assert_eq!(comm.cumulative_heights, vec![0, 12, 24]); assert_eq!(comm.total_evaluations(), 24); - assert_eq!(comm.polys[0].occupied_len(), 16); - assert_eq!(comm.polys[1].occupied_len(), 16); } #[test] @@ -1387,50 +1304,6 @@ mod tests { .expect("batch verify"); } - #[test] - fn test_jagged_trait_batch_open_verify_padded_witness_compact_q() { - let mut rng = thread_rng(); - - let num_rows = 1023usize; - let num_cols = 2usize; - let reshape_log_height = 8; - let (pp, vp) = setup_pcs::(reshape_log_height); - let rmm = make_rmm(num_rows, num_cols); - - let mut transcript_p = BasicTranscript::::new(b"jagged_trait_compact_q"); - let comm = jagged_commit::(&pp, vec![rmm], reshape_log_height).expect("commit"); - JPcs::write_commitment(&comm.to_commitment(), &mut transcript_p).unwrap(); - - let polys = JPcs::get_arc_mle_witness_from_commitment(&comm); - assert_eq!(polys.len(), num_cols); - assert_eq!(polys[0].occupied_len(), num_rows.next_power_of_two()); - assert_eq!(comm.poly_heights, vec![num_rows; num_cols]); - - let point: Vec = (0..polys[0].num_vars()) - .map(|_| E::random(&mut rng)) - .collect(); - let evals = polys.iter().map(|poly| poly.evaluate(&point)).collect_vec(); - - let proof = JPcs::batch_open( - &pp, - vec![(&comm, vec![(point.clone(), evals.clone())])], - &mut transcript_p, - ) - .expect("batch open"); - - let mut transcript_v = BasicTranscript::::new(b"jagged_trait_compact_q"); - let pure_comm = JPcs::get_pure_commitment(&comm); - JPcs::write_commitment(&pure_comm, &mut transcript_v).unwrap(); - - JPcs::batch_verify( - &vp, - vec![(pure_comm, vec![(point.len(), (point, evals))])], - &proof, - &mut transcript_v, - ) - .expect("batch verify"); - } - #[test] fn test_jagged_batch_open_verify_soundness() { let mut rng = thread_rng(); @@ -1582,13 +1455,13 @@ mod tests { let heights = [1023usize, 777, 513]; let log_heights: Vec = heights.iter().map(|&h| ceil_log2(h)).collect(); let max_s = *log_heights.iter().max().unwrap(); + let total_evals: usize = heights.iter().sum(); let reshape_log_height = 8; let h = 1usize << reshape_log_height; + let w = total_evals.div_ceil(h); let (pp, vp) = setup_pcs::(reshape_log_height); let rmms: Vec<_> = heights.iter().map(|&h| make_rmm(h, 1)).collect(); - let total_evals: usize = rmms.iter().map(|rmm| rmm.occupied_physical_rows()).sum(); - let w = total_evals.div_ceil(h); let col_polys: Vec> = rmms .iter() @@ -1623,65 +1496,4 @@ mod tests { ) .expect("batch verify"); } - - #[test] - fn test_jagged_trait_batch_open_verify_multiple_rounds_compact_q() { - let mut rng = thread_rng(); - - let reshape_log_height = 8; - let (pp, vp) = setup_pcs::(reshape_log_height); - let comm0 = jagged_commit::(&pp, vec![make_rmm(1023, 2)], reshape_log_height) - .expect("commit 0"); - let comm1 = jagged_commit::(&pp, vec![make_rmm(777, 1)], reshape_log_height) - .expect("commit 1"); - - let mut transcript_p = BasicTranscript::::new(b"jagged_trait_multi_round"); - JPcs::write_commitment(&comm0.to_commitment(), &mut transcript_p).unwrap(); - JPcs::write_commitment(&comm1.to_commitment(), &mut transcript_p).unwrap(); - - let polys0 = JPcs::get_arc_mle_witness_from_commitment(&comm0); - let point0: Vec = (0..polys0[0].num_vars()) - .map(|_| E::random(&mut rng)) - .collect(); - let evals0 = polys0 - .iter() - .map(|poly| poly.evaluate(&point0)) - .collect_vec(); - - let polys1 = JPcs::get_arc_mle_witness_from_commitment(&comm1); - let point1: Vec = (0..polys1[0].num_vars()) - .map(|_| E::random(&mut rng)) - .collect(); - let evals1 = polys1 - .iter() - .map(|poly| poly.evaluate(&point1)) - .collect_vec(); - - let proof = JPcs::batch_open( - &pp, - vec![ - (&comm0, vec![(point0.clone(), evals0.clone())]), - (&comm1, vec![(point1.clone(), evals1.clone())]), - ], - &mut transcript_p, - ) - .expect("batch open"); - - let pure_comm0 = JPcs::get_pure_commitment(&comm0); - let pure_comm1 = JPcs::get_pure_commitment(&comm1); - let mut transcript_v = BasicTranscript::::new(b"jagged_trait_multi_round"); - JPcs::write_commitment(&pure_comm0, &mut transcript_v).unwrap(); - JPcs::write_commitment(&pure_comm1, &mut transcript_v).unwrap(); - - JPcs::batch_verify( - &vp, - vec![ - (pure_comm0, vec![(point0.len(), (point0, evals0))]), - (pure_comm1, vec![(point1.len(), (point1, evals1))]), - ], - &proof, - &mut transcript_v, - ) - .expect("batch verify"); - } } From 2b2c75ecfc587b3b4b421a49eda6f40bbc3e9ec2 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 8 Jun 2026 20:41:02 +0800 Subject: [PATCH 4/6] rollback unnessesary change --- crates/mpcs/src/jagged/mod.rs | 52 +++++++++++------------------------ 1 file changed, 16 insertions(+), 36 deletions(-) diff --git a/crates/mpcs/src/jagged/mod.rs b/crates/mpcs/src/jagged/mod.rs index 8257563..3c1f9b9 100644 --- a/crates/mpcs/src/jagged/mod.rs +++ b/crates/mpcs/src/jagged/mod.rs @@ -311,43 +311,23 @@ pub fn jagged_commit> row_major.set_len(h * group_cols) }; - let group_start = group_start_col * h; - let group_len = total_size.saturating_sub(group_start).min(h * group_cols); - let full_cols = group_len / h; - let tail_rows = group_len % h; - - let write_full_cols = |row: usize, dst: &mut [E::BaseField]| { - for (local_col, value) in dst.iter_mut().take(full_cols).enumerate() { - let global_col = group_start_col + local_col; - *value = concatenated[global_col * h + row]; + (0..group_cols).into_par_iter().for_each(|local_col| { + let global_col = group_start_col + local_col; + let src_start = global_col * h; + let col_len = total_size.saturating_sub(src_start).min(h); + let dst = unsafe { + &mut *std::ptr::slice_from_raw_parts_mut( + row_major.as_ptr() as *mut E::BaseField, + h * group_cols, + ) + }; + for b in 0..col_len { + dst[b * group_cols + local_col] = concatenated[src_start + b]; } - }; - - if tail_rows == 0 { - row_major - .par_chunks_mut(group_cols) - .enumerate() - .for_each(|(row, dst)| write_full_cols(row, dst)); - } else { - let tail_col = full_cols; - let (occupied_tail_rows, padded_tail_rows) = - row_major.split_at_mut(tail_rows * group_cols); - occupied_tail_rows - .par_chunks_mut(group_cols) - .enumerate() - .for_each(|(row, dst)| { - write_full_cols(row, dst); - dst[tail_col] = concatenated[(group_start_col + tail_col) * h + row]; - }); - padded_tail_rows - .par_chunks_mut(group_cols) - .enumerate() - .for_each(|(row_offset, dst)| { - let row = tail_rows + row_offset; - write_full_cols(row, dst); - dst[tail_col] = E::BaseField::ZERO; - }); - } + for b in col_len..h { + dst[b * group_cols + local_col] = E::BaseField::ZERO; + } + }); giga_rmms.push(WitnessRowMajorMatrix::::new_by_values( row_major, From 3b99622c356ed60fed1d6f48968feb619795211c Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 9 Jun 2026 20:14:33 +0800 Subject: [PATCH 5/6] Fuse jagged assist claim computation --- crates/mpcs/src/jagged/assist.rs | 79 ++++++++++++++++++++++++++------ crates/mpcs/src/jagged/mod.rs | 4 +- crates/mpcs/src/lib.rs | 4 +- crates/sumcheck/src/frontload.rs | 11 +++++ crates/sumcheck/src/prover.rs | 13 ++++++ crates/sumcheck/src/structs.rs | 1 + 6 files changed, 95 insertions(+), 17 deletions(-) diff --git a/crates/mpcs/src/jagged/assist.rs b/crates/mpcs/src/jagged/assist.rs index 51c5e3c..539f0e2 100644 --- a/crates/mpcs/src/jagged/assist.rs +++ b/crates/mpcs/src/jagged/assist.rs @@ -35,27 +35,64 @@ pub fn assist_sumcheck_prove( n_robp: usize, transcript: &mut impl Transcript, ) -> (IOPProof, Vec) { + let (_, proof, challenges) = assist_sumcheck_prove_impl( + z_row_padded, + rho_padded, + eq_col, + cumulative_heights, + n_robp, + transcript, + false, + ); + (proof, challenges) +} + +/// Compute the assist claimed sum, append it to the transcript, and run the +/// assist sumcheck prover using the same ROBP precompute. +pub fn assist_sumcheck_prove_and_append_claim( + z_row_padded: &[E], + rho_padded: &[E], + eq_col: &[E], + cumulative_heights: &[usize], + n_robp: usize, + transcript: &mut impl Transcript, +) -> (E, IOPProof, Vec) { + assist_sumcheck_prove_impl( + z_row_padded, + rho_padded, + eq_col, + cumulative_heights, + n_robp, + transcript, + true, + ) +} + +fn assist_sumcheck_prove_impl( + z_row_padded: &[E], + rho_padded: &[E], + eq_col: &[E], + cumulative_heights: &[usize], + n_robp: usize, + transcript: &mut impl Transcript, + append_claim: bool, +) -> (E, IOPProof, Vec) { let num_polys = cumulative_heights.len() - 1; let n_vars = 2 * n_robp; let max_degree: usize = 2; - // Write transcript header (must match verifier). - transcript.append_message(&n_vars.to_le_bytes()); - transcript.append_message(&max_degree.to_le_bytes()); - // Precompute per-step symbol matrices. let step_mats: Vec<[TransitionMatrix; 4]> = (0..n_robp) .map(|i| symbol_transition_matrices(z_row_padded[i], rho_padded[i])) .collect(); - // Extract Boolean bits in step-major layout: c_bits[i][y], d_bits[i][y]. - // c_bits[i][y] = bit_i(t_y), d_bits[i][y] = bit_i(t_{y+1}) - let mut c_bits = vec![vec![0usize; num_polys]; n_robp]; - let mut d_bits = vec![vec![0usize; num_polys]; n_robp]; + // Extract Boolean symbol pairs in step-major layout: + // cd_bits[i][y] = 2 * bit_i(t_y) + bit_i(t_{y+1}). + let mut cd_bits = vec![vec![0u8; num_polys]; n_robp]; for i in 0..n_robp { for y in 0..num_polys { - c_bits[i][y] = (cumulative_heights[y] >> i) & 1; - d_bits[i][y] = (cumulative_heights[y + 1] >> i) & 1; + cd_bits[i][y] = ((((cumulative_heights[y] >> i) & 1) << 1) + | ((cumulative_heights[y + 1] >> i) & 1)) as u8; } } @@ -89,14 +126,27 @@ pub fn assist_sumcheck_prove( let dst = &mut left[i]; let src = &right[0]; dst.into_par_iter().enumerate().for_each(|(y, dst_y)| { - let cd = c_bits[i][y] * 2 + d_bits[i][y]; + let cd = cd_bits[i][y] as usize; *dst_y = mat_vec_mul(&step_mats[i][cd], &src[y]); }); } + let source = source_vec(); + let mut claimed_sum = E::ZERO; + for y in 0..num_polys { + claimed_sum += eq_col[y] * dot4(&source, &bwd[0][y]); + } + if append_claim { + transcript.append_field_element_ext(&claimed_sum); + } + + // Write transcript header (must match verifier). + transcript.append_message(&n_vars.to_le_bytes()); + transcript.append_message(&max_degree.to_le_bytes()); + // Initialize weights and forward vector. let mut weights: Vec = eq_col[..num_polys].to_vec(); - let mut fwd: StateVec = source_vec(); + let mut fwd: StateVec = source; let mut challenges: Vec = Vec::with_capacity(n_vars); let mut proof_messages: Vec> = Vec::with_capacity(n_vars); @@ -138,7 +188,7 @@ pub fn assist_sumcheck_prove( .map(|chunk| { let mut local_bwd_sum = [[E::ZERO; ROBP_WIDTH]; 4]; for &y in chunk { - let cd = c_bits[i][y] * 2 + d_bits[i][y]; + let cd = cd_bits[i][y] as usize; let w = weights[y]; for s in 0..ROBP_WIDTH { local_bwd_sum[cd][s] += w * bwd[i + 1][y][s]; @@ -220,7 +270,7 @@ pub fn assist_sumcheck_prove( let start = chunk_idx * batch_size; for (j, w) in w_chunk.iter_mut().enumerate() { let y = start + j; - let cd = c_bits[i][y] * 2 + d_bits[i][y]; + let cd = cd_bits[i][y] as usize; *w *= eq_cd[cd]; } }); @@ -235,6 +285,7 @@ pub fn assist_sumcheck_prove( } ( + claimed_sum, IOPProof { proofs: proof_messages, }, diff --git a/crates/mpcs/src/jagged/mod.rs b/crates/mpcs/src/jagged/mod.rs index 3c1f9b9..e45dcb3 100644 --- a/crates/mpcs/src/jagged/mod.rs +++ b/crates/mpcs/src/jagged/mod.rs @@ -112,7 +112,9 @@ pub mod evaluator; pub mod sumcheck; mod types; -pub use assist::{assist_sumcheck_prove, compute_q_at_assist_point}; +pub use assist::{ + assist_sumcheck_prove, assist_sumcheck_prove_and_append_claim, compute_q_at_assist_point, +}; pub use evaluator::{evaluate_g, evaluate_g_backward, evaluate_g_forward}; pub use sumcheck::{JaggedSumcheckInput, QPrimeEvaluations, jagged_sumcheck_prove}; pub use types::{JaggedBatchOpenProof, JaggedCommitment, JaggedCommitmentWithWitness, JaggedProof}; diff --git a/crates/mpcs/src/lib.rs b/crates/mpcs/src/lib.rs index 4b7cf2e..3ba273f 100644 --- a/crates/mpcs/src/lib.rs +++ b/crates/mpcs/src/lib.rs @@ -282,8 +282,8 @@ pub mod jagged; pub use jagged::{ JAGGED_RESHAPE_GROUP_WIDTH, Jagged, JaggedBatchOpenProof, JaggedCommitment, JaggedCommitmentWithWitness, JaggedProof, JaggedSumcheckInput, assist_sumcheck_prove, - evaluate_g, evaluate_g_backward, evaluate_g_forward, jagged_batch_open, jagged_batch_verify, - jagged_commit, jagged_sumcheck_prove, + assist_sumcheck_prove_and_append_claim, evaluate_g, evaluate_g_backward, evaluate_g_forward, + jagged_batch_open, jagged_batch_verify, jagged_commit, jagged_sumcheck_prove, }; #[cfg(feature = "whir")] extern crate whir as whir_external; diff --git a/crates/sumcheck/src/frontload.rs b/crates/sumcheck/src/frontload.rs index 8515fec..bb85ba1 100644 --- a/crates/sumcheck/src/frontload.rs +++ b/crates/sumcheck/src/frontload.rs @@ -103,6 +103,7 @@ use crate::{ pub struct FrontloadProverState { pub challenges: Vec>, pub final_evaluations: Vec>, + pub claimed_sum: E, } #[derive(Clone, Copy, Debug, Eq, PartialEq)] @@ -181,6 +182,7 @@ pub fn prove_2phase<'a, E: ExtensionField>( } let mut proofs = Vec::with_capacity(global_num_vars); let mut challenge: Option> = None; + let mut claimed_sum = E::ZERO; for round in 0..local_num_vars { workers.par_iter_mut().for_each(|worker| { @@ -208,6 +210,9 @@ pub fn prove_2phase<'a, E: ExtensionField>( }, ) }; + if round == 0 { + claimed_sum = evaluations[0] + evaluations[1]; + } evaluations.remove(0); transcript.append_field_element_exts(&evaluations); proofs.push(IOPProverMessage { evaluations }); @@ -249,6 +254,7 @@ pub fn prove_2phase<'a, E: ExtensionField>( FrontloadProverState { challenges, final_evaluations, + claimed_sum, }, ) } @@ -268,6 +274,7 @@ fn prove_inner<'a, E: ExtensionField>( let mut proof = Vec::with_capacity(num_vars); let mut challenge: Option> = None; + let mut claimed_sum = E::ZERO; for round in 0..num_vars { if let Some(challenge) = challenge.take() { @@ -276,6 +283,9 @@ fn prove_inner<'a, E: ExtensionField>( } let mut evaluations = state.round_evaluations(round); + if round == 0 { + claimed_sum = evaluations[0] + evaluations[1]; + } evaluations.remove(0); transcript.append_field_element_exts(&evaluations); proof.push(IOPProverMessage { evaluations }); @@ -293,6 +303,7 @@ fn prove_inner<'a, E: ExtensionField>( FrontloadProverState { challenges: state.challenges, final_evaluations, + claimed_sum, }, ) } diff --git a/crates/sumcheck/src/prover.rs b/crates/sumcheck/src/prover.rs index d75e3dc..dae2941 100644 --- a/crates/sumcheck/src/prover.rs +++ b/crates/sumcheck/src/prover.rs @@ -192,6 +192,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { max_num_variables, poly_meta: vec![], final_evaluations: Some(state.final_evaluations), + claimed_sum: state.claimed_sum, phase2_numvar: None, } } @@ -449,6 +450,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { poly: polynomial, poly_meta: poly_meta.unwrap_or_else(|| vec![PolyMeta::Normal; num_polys]), final_evaluations: None, + claimed_sum: E::ZERO, phase2_numvar, } } @@ -512,6 +514,9 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { exit_span!(start); assert!(uni_polys.len() > 1); + if self.round == 1 { + self.claimed_sum = uni_polys[0] + uni_polys[1]; + } // NOTE remove uni_polys.eval(0) from lagrange domain // as verifier can derive via claim - uni_polys.eval(1) uni_polys.remove(0); @@ -843,6 +848,10 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { .collect_vec() } + pub fn claimed_sum(&self) -> E { + self.claimed_sum + } + pub fn expected_numvars_at_round(&self) -> usize { // first round start from 1 let num_vars = self.max_num_variables + 1 - self.round; @@ -1021,6 +1030,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { poly: polynomial, poly_meta, final_evaluations: None, + claimed_sum: E::ZERO, phase2_numvar: None, }; @@ -1130,6 +1140,9 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { exit_span!(start); assert!(uni_polys.len() > 1); + if self.round == 1 { + self.claimed_sum = uni_polys[0] + uni_polys[1]; + } // NOTE remove uni_polys.eval(0) from lagrange domain // as verifier can derive via claim - uni_polys.eval(1) uni_polys.remove(0); diff --git a/crates/sumcheck/src/structs.rs b/crates/sumcheck/src/structs.rs index ef5dc28..ba207d3 100644 --- a/crates/sumcheck/src/structs.rs +++ b/crates/sumcheck/src/structs.rs @@ -99,6 +99,7 @@ pub struct IOPProverState<'a, E: ExtensionField> { pub(crate) max_num_variables: usize, pub(crate) poly_meta: Vec, pub(crate) final_evaluations: Option>>, + pub(crate) claimed_sum: E, /// phase 1 and phase 2 sumcheck we share similar implementation /// thus this option variable only use for phase 1 sumcheck to mark how many variables belongs to phase 2 pub(crate) phase2_numvar: Option, From 14b672a2f19d10440de8571ad685bdd1c4f73e6e Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 19 Jun 2026 13:24:51 +0800 Subject: [PATCH 6/6] Fix jagged assist clippy warnings --- crates/mpcs/src/jagged/assist.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/crates/mpcs/src/jagged/assist.rs b/crates/mpcs/src/jagged/assist.rs index 539f0e2..3d3ed75 100644 --- a/crates/mpcs/src/jagged/assist.rs +++ b/crates/mpcs/src/jagged/assist.rs @@ -89,9 +89,9 @@ fn assist_sumcheck_prove_impl( // Extract Boolean symbol pairs in step-major layout: // cd_bits[i][y] = 2 * bit_i(t_y) + bit_i(t_{y+1}). let mut cd_bits = vec![vec![0u8; num_polys]; n_robp]; - for i in 0..n_robp { - for y in 0..num_polys { - cd_bits[i][y] = ((((cumulative_heights[y] >> i) & 1) << 1) + for (i, cd_bits_i) in cd_bits.iter_mut().enumerate() { + for (y, cd_bit) in cd_bits_i.iter_mut().enumerate() { + *cd_bit = ((((cumulative_heights[y] >> i) & 1) << 1) | ((cumulative_heights[y + 1] >> i) & 1)) as u8; } } @@ -133,8 +133,8 @@ fn assist_sumcheck_prove_impl( let source = source_vec(); let mut claimed_sum = E::ZERO; - for y in 0..num_polys { - claimed_sum += eq_col[y] * dot4(&source, &bwd[0][y]); + for (y, eq) in eq_col.iter().enumerate().take(num_polys) { + claimed_sum += *eq * dot4(&source, &bwd[0][y]); } if append_claim { transcript.append_field_element_ext(&claimed_sum);