diff --git a/src/algebra/ntt/cooley_tukey.rs b/src/algebra/ntt/cooley_tukey.rs index 74cc5bdb..a9380b3e 100644 --- a/src/algebra/ntt/cooley_tukey.rs +++ b/src/algebra/ntt/cooley_tukey.rs @@ -356,6 +356,166 @@ impl NttEngine { size => self.ntt_recurse(values, roots, size), } } + + /// Output-pruned NTT (Sorensen-Burrus, radix-2 DIT). + /// + /// Computes the size-`size` NTT of `values` (zero-padded to `size` if + /// shorter) and returns the outputs at positions `indices`, in input + /// order. Output `j` equals the full NTT at position `indices[j]`. + /// + /// Walks the butterfly DAG backwards from `indices` to mark only the + /// cone of butterflies that contribute to the queried outputs, then + /// runs only the marked butterflies on the forward pass. Cost is + /// `O(size + indices.len() * log(size))` field operations, vs + /// `O(size * log(size))` for a full NTT. + /// + /// `size` must be a power of two. + #[allow(dead_code)] // public single-shot entry; batched callers use the plan-based path + pub fn ntt_partial(&self, values: &[F], size: usize, indices: &[usize]) -> Vec { + let plan = PartialNttPlan::new(size, indices); + let mut out = vec![F::ZERO; indices.len()]; + self.ntt_partial_with_plan_into(values, &plan, &mut out, 1); + out + } + + /// Run a pruned NTT using a precomputed plan and write outputs into + /// `out` at stride `stride` (so `out[j * stride]` holds the result for + /// `plan.indices[j]`). When `stride == 1`, output is contiguous. + /// + /// Sharing a single plan across many NTTs with the same `(size, indices)` + /// avoids re-running the O(size · log size) mask construction per call. + #[allow(clippy::significant_drop_tightening)] // roots guard intentionally held across DIT stages + pub fn ntt_partial_with_plan_into( + &self, + values: &[F], + plan: &PartialNttPlan, + out: &mut [F], + stride: usize, + ) { + let size = plan.size; + let indices = &plan.indices; + assert!(values.len() <= size, "input longer than NTT size"); + if indices.is_empty() { + return; + } + assert!( + out.len() > (indices.len() - 1) * stride, + "output buffer too small for stride" + ); + if size == 1 { + let v = values.first().copied().unwrap_or(F::ZERO); + for j in 0..indices.len() { + out[j * stride] = v; + } + return; + } + + let log_n = size.trailing_zeros() as usize; + let roots = self.roots_table(size); + + // Load bit-reversed input into work buffer, gated by mask[0]. + let mut work = vec![F::ZERO; size]; + let shift = (usize::BITS as usize) - log_n; + for (j, &c) in values.iter().enumerate() { + let rev = j.reverse_bits() >> shift; + if plan.mask[0][rev] { + work[rev] = c; + } + } + + // Forward DIT, skipping butterflies with no needed outputs. + // The shared roots table may hold roots at a larger order than `size`; + // `roots[k * twiddle_step]` retrieves ω_m^k regardless. + for stage in 1..=log_n { + let m = 1usize << stage; + let half = m >> 1; + let twiddle_step = roots.len() / m; + let cur = &plan.mask[stage]; + let mut base = 0; + while base < size { + for k in 0..half { + let a = base + k; + let b = a + half; + if cur[a] || cur[b] { + let w = roots[k * twiddle_step]; + let t = work[b] * w; + let u = work[a]; + work[a] = u + t; + work[b] = u - t; + } + } + base += m; + } + } + + for (j, &i) in indices.iter().enumerate() { + out[j * stride] = work[i]; + } + } +} + +/// Pruning plan for an output-pruned NTT. +/// +/// Holds the queried output indices and the precomputed per-stage +/// "needed-position" masks used by [`NttEngine::ntt_partial_with_plan_into`]. +/// Construct once per `(size, indices)` and reuse across multiple NTTs of +/// the same shape (e.g. all polynomials in an interleaved batch). +#[derive(Debug, Clone)] +pub struct PartialNttPlan { + size: usize, + indices: Vec, + /// `mask[stage][p]` is true iff position `p` after `stage` DIT stages + /// must be correct for the final outputs. `mask[log_n]` mirrors + /// `indices`; `mask[0]` selects the bit-reversed input positions that + /// must be loaded. + mask: Vec>, +} + +impl PartialNttPlan { + pub fn new(size: usize, indices: &[usize]) -> Self { + assert!(size.is_power_of_two(), "size must be a power of two"); + assert!( + indices.iter().all(|&i| i < size), + "query index out of range" + ); + let log_n = size.trailing_zeros() as usize; + let mut mask: Vec> = vec![vec![false; size]; log_n + 1]; + for &i in indices { + mask[log_n][i] = true; + } + for stage in (1..=log_n).rev() { + let m = 1usize << stage; + let half = m >> 1; + let (lo, hi) = mask.split_at_mut(stage); + let cur = &hi[0]; + let prev = &mut lo[stage - 1]; + let mut base = 0; + while base < size { + for k in 0..half { + let a = base + k; + let b = a + half; + if cur[a] || cur[b] { + prev[a] = true; + prev[b] = true; + } + } + base += m; + } + } + Self { + size, + indices: indices.to_vec(), + mask, + } + } + + pub const fn size(&self) -> usize { + self.size + } + + pub fn indices(&self) -> &[usize] { + &self.indices + } } /// Applies twiddle factors to a slice of field elements in-place. @@ -963,4 +1123,93 @@ mod tests { assert_eq!(values_ntt, expected_values); } + + #[test] + fn test_ntt_partial_matches_full() { + use ark_std::{rand::Rng, UniformRand}; + + let engine = NttEngine::::new_from_fftfield(); + let mut rng = ark_std::test_rng(); + + for &size in &[4usize, 16, 64, 256, 1024, 1 << 15] { + for _ in 0..8 { + // Full NTT reference. + let coeffs: Vec<_> = (0..size).map(|_| Field64::rand(&mut rng)).collect(); + let mut full = coeffs.clone(); + engine.ntt_batch(&mut full, size); + + // Random subset of varying size (cover dense + sparse). + let k = rng.gen_range(1..=size.min(64)); + let mut perm: Vec = (0..size).collect(); + for i in (1..size).rev() { + perm.swap(i, rng.gen_range(0..=i)); + } + let indices: Vec = perm.into_iter().take(k).collect(); + + let partial = engine.ntt_partial(&coeffs, size, &indices); + assert_eq!(partial.len(), indices.len()); + for (j, &idx) in indices.iter().enumerate() { + assert_eq!(partial[j], full[idx], "size={size} idx={idx}"); + } + } + } + } + + #[test] + fn test_ntt_partial_zero_padded_input() { + // M < N: input is zero-padded. Partial NTT must agree with full NTT + // computed over the zero-padded coefficient vector. + use ark_std::UniformRand; + + let engine = NttEngine::::new_from_fftfield(); + let mut rng = ark_std::test_rng(); + + for (m, size) in [(1usize, 4), (4, 16), (256, 1024), (1 << 13, 1 << 15)] { + let coeffs: Vec<_> = (0..m).map(|_| Field64::rand(&mut rng)).collect(); + let mut padded = coeffs.clone(); + padded.resize(size, Field64::ZERO); + engine.ntt_batch(&mut padded, size); + + let stride = (size / 8).max(1); + let indices: Vec = (0..size).step_by(stride).take(8).collect(); + let partial = engine.ntt_partial(&coeffs, size, &indices); + for (j, &idx) in indices.iter().enumerate() { + assert_eq!(partial[j], padded[idx], "m={m} size={size} idx={idx}"); + } + } + } + + #[test] + fn test_ntt_partial_edge_cases() { + use ark_std::UniformRand; + + let engine = NttEngine::::new_from_fftfield(); + let mut rng = ark_std::test_rng(); + + // Empty index set. + let coeffs: Vec<_> = (0..16).map(|_| Field64::rand(&mut rng)).collect(); + let out = engine.ntt_partial(&coeffs, 16, &[]); + assert!(out.is_empty()); + + // Singleton at position 0 and position N-1. + let coeffs: Vec<_> = (0..64).map(|_| Field64::rand(&mut rng)).collect(); + let mut full = coeffs.clone(); + engine.ntt_batch(&mut full, 64); + for idx in [0usize, 1, 31, 32, 63] { + let out = engine.ntt_partial(&coeffs, 64, &[idx]); + assert_eq!(out, vec![full[idx]], "idx={idx}"); + } + + // Repeated indices: each occurrence must yield the matching output. + let indices = vec![5usize, 5, 17, 5, 17]; + let out = engine.ntt_partial(&coeffs, 64, &indices); + for (j, &idx) in indices.iter().enumerate() { + assert_eq!(out[j], full[idx]); + } + + // size = 1: any indices must all return values[0]. + let single = vec![Field64::from(42)]; + let out = engine.ntt_partial(&single, 1, &[0, 0, 0]); + assert_eq!(out, vec![Field64::from(42); 3]); + } } diff --git a/src/algebra/ntt/mod.rs b/src/algebra/ntt/mod.rs index 525a4985..2103a28a 100644 --- a/src/algebra/ntt/mod.rs +++ b/src/algebra/ntt/mod.rs @@ -17,13 +17,15 @@ use std::{ }; use ark_ff::{FftField, Field}; +#[cfg(feature = "parallel")] +use rayon::prelude::*; use static_assertions::assert_obj_safe; #[cfg(feature = "tracing")] use tracing::instrument; -use self::matrix::MatrixMut; +use self::{cooley_tukey::NttEngine, matrix::MatrixMut}; pub use self::{ - cooley_tukey::{generator, intt, intt_batch, ntt, ntt_batch}, + cooley_tukey::{generator, intt, intt_batch, ntt, ntt_batch, PartialNttPlan}, transpose::transpose, wavelet::{inverse_wavelet_transform, wavelet_transform}, }; @@ -93,6 +95,76 @@ pub fn interleaved_rs_encode( engine.interleaved_encode(interleaved_coeffs, codeword_length, interleaving_depth) } +/// Partial Reed-Solomon encode that materialises only the rows at `indices`. +/// +/// Equivalent to taking [`interleaved_rs_encode`]'s output (a row-major +/// `(codeword_length, num_polys * interleaving_depth)` matrix) and +/// extracting the rows whose row index is in `indices`. Output layout is +/// row-major `(indices.len(), num_polys * interleaving_depth)`, byte-exact +/// against the full encode. +/// +/// Uses an output-pruned NTT (see [`PartialNttPlan`]) so peak memory and +/// flop count are both proportional to `indices.len()`, not +/// `codeword_length`. The pruning plan is built once for the index set and +/// reused across every polynomial × interleaving slot. +#[cfg_attr(feature = "tracing", instrument(level = "debug", skip(coeffs, indices), fields(size = coeffs.len(), k = indices.len())))] +pub fn partial_interleaved_rs_encode( + coeffs: &[&[F]], + codeword_length: usize, + interleaving_depth: usize, + indices: &[usize], +) -> Vec { + if coeffs.is_empty() || indices.is_empty() { + return Vec::new(); + } + let poly_size = coeffs[0].len(); + for poly in coeffs { + assert_eq!(poly.len(), poly_size); + } + assert!(poly_size.is_multiple_of(interleaving_depth)); + let message_length = poly_size / interleaving_depth; + assert!(codeword_length.is_multiple_of(message_length)); + + let num_polys = coeffs.len(); + let num_cols = num_polys * interleaving_depth; + let k = indices.len(); + + let engine = NttEngine::::new_from_cache(); + let plan = PartialNttPlan::new(codeword_length, indices); + + // Build the submatrix in batch-major layout (`(num_cols, k)`): each + // contiguous k-chunk is one NTT's outputs. Batches are independent, so + // populate in parallel across (poly_idx, slot_idx). Final transpose + // converts to the row-major `(k, num_cols)` layout that + // `irs_commit::open_inner_from_coeffs` expects. + let mut batch_major = vec![F::ZERO; num_cols * k]; + + #[cfg(feature = "parallel")] + { + batch_major + .par_chunks_exact_mut(k) + .enumerate() + .for_each(|(col, dst)| { + let poly_idx = col / interleaving_depth; + let slot_idx = col % interleaving_depth; + let block = &coeffs[poly_idx] + [slot_idx * message_length..(slot_idx + 1) * message_length]; + engine.ntt_partial_with_plan_into(block, &plan, dst, 1); + }); + } + #[cfg(not(feature = "parallel"))] + for (col, dst) in batch_major.chunks_exact_mut(k).enumerate() { + let poly_idx = col / interleaving_depth; + let slot_idx = col % interleaving_depth; + let block = + &coeffs[poly_idx][slot_idx * message_length..(slot_idx + 1) * message_length]; + engine.ntt_partial_with_plan_into(block, &plan, dst, 1); + } + + transpose(&mut batch_major, num_cols, k); + batch_major +} + /// /// RS encode coefficients grouped in `interleaving_depth` contiguous blocks /// at the rate 1/`expansion`, then interleave the evaluations per point. @@ -350,4 +422,58 @@ mod tests { interleaved_rs_encode(&[poly.as_slice()], codeword_length, 1 << folding_factor); assert_eq!(expected, interleaved_ntt); } + + #[test] + fn test_partial_interleaved_rs_encode_matches_full() { + use ark_std::{rand::Rng, UniformRand}; + + let mut rng = ark_std::test_rng(); + + // Span several (num_polys, interleaving_depth, M, N) shapes covering + // the regimes that actually appear in whir_zk (single witness with + // depth 8, multi-witness with depth 1, M = N/4 blowup). + let cases = [ + (1usize, 1usize, 64usize, 256usize), + (1, 8, 16, 64), + (2, 4, 32, 128), + (1, 8, 1 << 10, 1 << 12), + ]; + + for (num_polys, interleaving_depth, message_length, codeword_length) in cases { + let poly_size = message_length * interleaving_depth; + let polys: Vec> = (0..num_polys) + .map(|_| (0..poly_size).map(|_| Field64::rand(&mut rng)).collect()) + .collect(); + let poly_slices: Vec<&[Field64]> = polys.iter().map(Vec::as_slice).collect(); + + let full = interleaved_rs_encode(&poly_slices, codeword_length, interleaving_depth); + let num_cols = num_polys * interleaving_depth; + assert_eq!(full.len(), codeword_length * num_cols); + + // Random subset including 0, last, and a sprinkling in between. + let k = rng.gen_range(1..=codeword_length.min(16)); + let mut perm: Vec = (0..codeword_length).collect(); + for i in (1..codeword_length).rev() { + perm.swap(i, rng.gen_range(0..=i)); + } + let indices: Vec = perm.into_iter().take(k).collect(); + + let partial = partial_interleaved_rs_encode( + &poly_slices, + codeword_length, + interleaving_depth, + &indices, + ); + assert_eq!(partial.len(), k * num_cols); + + for (row, &idx) in indices.iter().enumerate() { + let full_row = &full[idx * num_cols..(idx + 1) * num_cols]; + let partial_row = &partial[row * num_cols..(row + 1) * num_cols]; + assert_eq!( + partial_row, full_row, + "shape=({num_polys},{interleaving_depth},{message_length},{codeword_length}) row idx={idx}" + ); + } + } + } } diff --git a/src/protocols/irs_commit.rs b/src/protocols/irs_commit.rs index f7fbdbec..d7c0df49 100644 --- a/src/protocols/irs_commit.rs +++ b/src/protocols/irs_commit.rs @@ -465,6 +465,114 @@ where self.verify_inner(verifier_state, commitments, indices, points) } + /// Opens the commitment without requiring `witness.matrix` to be + /// populated. + /// + /// Functionally identical to [`open`]: same in-domain challenges, same + /// transcript bytes (submatrix hint + Merkle paths), same returned + /// [`Evaluations`]. The difference is that the queried codeword rows + /// are reconstructed from the supplied polynomial coefficients via an + /// output-pruned NTT (see [`ntt::partial_interleaved_rs_encode`]), so + /// the prover never materialises the full `(num_cols × codeword_length)` + /// codeword matrix held in `witness.matrix`. + /// + /// `coeffs_per_witness[i]` must be the same polynomial slice set that + /// would have produced `witnesses[i].matrix` via + /// [`interleaved_rs_encode`]. Mismatch results in verifier rejection. + pub fn open_from_coeffs( + &self, + prover_state: &mut ProverState, + coeffs_per_witness: &[&[&[M::Source]]], + witnesses: &[&Witness], + ) -> Evaluations + where + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + u8: Decoding<[H::U]>, + Hash: ProverMessage<[H::U]>, + { + assert_eq!(coeffs_per_witness.len(), witnesses.len()); + for witness in witnesses { + assert_eq!(witness.out_of_domain.points.len(), self.out_domain_samples); + assert_eq!( + witness.out_of_domain.matrix.len(), + self.out_domain_samples * self.num_vectors + ); + } + let (indices, points) = self.in_domain_challenges(prover_state); + self.open_inner_from_coeffs( + prover_state, + coeffs_per_witness, + witnesses, + &indices, + points, + ) + } + + /// Like [`open_from_coeffs`] but with caller-provided indices, mirroring + /// [`open_at_indices`]. Used for the Γ consistency check. + pub fn open_at_indices_from_coeffs( + &self, + prover_state: &mut ProverState, + coeffs_per_witness: &[&[&[M::Source]]], + witnesses: &[&Witness], + indices: &[usize], + ) -> Evaluations + where + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + Hash: ProverMessage<[H::U]>, + { + assert!( + indices.iter().all(|&i| i < self.codeword_length), + "index out of bounds: all indices must be < codeword_length ({})", + self.codeword_length + ); + assert_eq!(coeffs_per_witness.len(), witnesses.len()); + let generator = self.generator(); + let points: Vec = indices.iter().map(|&i| generator.pow([i as u64])).collect(); + self.open_inner_from_coeffs(prover_state, coeffs_per_witness, witnesses, indices, points) + } + + /// Shared open logic for [`open_from_coeffs`] and [`open_at_indices_from_coeffs`]. + fn open_inner_from_coeffs( + &self, + prover_state: &mut ProverState, + coeffs_per_witness: &[&[&[M::Source]]], + witnesses: &[&Witness], + indices: &[usize], + points: Vec, + ) -> Evaluations + where + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + Hash: ProverMessage<[H::U]>, + { + let num_cols = self.num_cols(); + let stride = witnesses.len() * num_cols; + let mut matrix = vec![M::Source::ZERO; indices.len() * stride]; + let mut matrix_col_offset = 0; + for (coeffs, witness) in coeffs_per_witness.iter().zip(witnesses) { + let submatrix = ntt::partial_interleaved_rs_encode( + coeffs, + self.codeword_length, + self.interleaving_depth, + indices, + ); + debug_assert_eq!(submatrix.len(), indices.len() * num_cols); + for (row, src) in submatrix.chunks_exact(num_cols).enumerate() { + let dst = &mut matrix[row * stride + matrix_col_offset + ..row * stride + matrix_col_offset + num_cols]; + dst.copy_from_slice(src); + } + prover_state.prover_hint_ark(&submatrix); + self.matrix_commit + .open(prover_state, &witness.matrix_witness, indices); + matrix_col_offset += num_cols; + } + Evaluations { points, matrix } + } + /// Shared open logic for [`open`] and [`open_at_indices`]. fn open_inner( &self, diff --git a/src/protocols/whir_zk/committer.rs b/src/protocols/whir_zk/committer.rs index 7b431565..12bdcbc9 100644 --- a/src/protocols/whir_zk/committer.rs +++ b/src/protocols/whir_zk/committer.rs @@ -103,7 +103,14 @@ impl Config { // Step 1b: Commit [[f̂]] via first WHIR instance. let f_hat_refs: Vec<&[F]> = f_hat_polys.iter().map(|p| p.as_slice()).collect(); - let f_hat_witness = self.blinded_polynomial.commit(prover_state, &f_hat_refs); + let mut f_hat_witness = self.blinded_polynomial.commit(prover_state, &f_hat_refs); + + // Drop the encoded codeword; will be re-encoded immediately before each + // open in prove_blinded_polynomial (Steps 4 and 6). This keeps the + // ~codeword_length × interleaving_depth field elements out of the + // resident set during the prepare_and_sumcheck rounds where global peak + // hits. + f_hat_witness.matrix = Vec::new(); // Step 1c: Sample ν + 1 random ℓ-variate blinding polynomials ĝ₀..ĝ_ν. let num_blinding_polys = dims.num_g_polys(); @@ -138,10 +145,17 @@ impl Config { } let blinding_refs: Vec<&[F]> = blinding_vectors.iter().map(|v| v.as_slice()).collect(); - let blinding_poly_witness = self + let mut blinding_poly_witness = self .blinding_polynomial .commit(prover_state, &blinding_refs); + // The encoded codeword is only needed when [[M, ĝ]] is opened in + // Step 7. Until then it is dead weight (held resident through all of + // prove_blinded_polynomial, where global peak hits). Drop the matrix + // here; the prover re-encodes from `secrets.blinding_vectors` just + // before calling `blinding_polynomial.prove`. + blinding_poly_witness.matrix = Vec::new(); + Witness { f_hat_witness, blinding_poly_witness, diff --git a/src/protocols/whir_zk/mod.rs b/src/protocols/whir_zk/mod.rs index 8b449c69..ee782854 100644 --- a/src/protocols/whir_zk/mod.rs +++ b/src/protocols/whir_zk/mod.rs @@ -964,10 +964,11 @@ mod tests { .irs_committer .commit(&mut prover_state, &[&f_zk]); round_config.pow.prove(&mut prover_state); + let f_hat_refs: Vec<&[F]> = f_hat_polys.iter().map(Vec::as_slice).collect(); let in_domain = config .blinded_polynomial .initial_committer - .open(&mut prover_state, &[&f_hat_witness]); + .open_from_coeffs(&mut prover_state, &[&f_hat_refs], &[&f_hat_witness]); let mut lambda_z_points: Vec = Vec::new(); let send_blinding = |ps: &mut ProverState<_, _>, z: F| { @@ -989,7 +990,6 @@ mod tests { send_blinding(&mut prover_state, z); lambda_z_points.push(z); } - drop(f_hat_polys); for &z in &in_domain.points { send_blinding(&mut prover_state, z); lambda_z_points.push(z); @@ -1034,11 +1034,16 @@ mod tests { &round0_folding, ); let gamma_points = remaining.first_in_domain_points; - let _ = config.blinded_polynomial.initial_committer.open_at_indices( - &mut prover_state, - &[&f_hat_witness], - &gamma_to_f_hat_indices(&gamma_points, &config), - ); + let _ = config + .blinded_polynomial + .initial_committer + .open_at_indices_from_coeffs( + &mut prover_state, + &[&f_hat_refs], + &[&f_hat_witness], + &gamma_to_f_hat_indices(&gamma_points, &config), + ); + drop(f_hat_polys); for &gamma in &gamma_points { send_blinding(&mut prover_state, gamma); lambda_z_points.push(gamma); @@ -1068,6 +1073,21 @@ mod tests { .iter() .map(|v| Cow::Borrowed(v.as_slice())) .collect(); + // Re-encode blinding_poly_witness.matrix (cleared at commit time); + // mirrors prover.rs::prove_blinded_polynomial before + // `prove_blinding_polynomial`. + let blinding_refs: Vec<&[F]> = secrets + .blinding_vectors + .iter() + .map(|v| v.as_slice()) + .collect(); + let mut blinding_poly_witness = blinding_poly_witness; + blinding_poly_witness.matrix = crate::algebra::ntt::interleaved_rs_encode( + &blinding_refs, + config.blinding_polynomial.initial_committer.codeword_length, + config.blinding_polynomial.initial_committer.interleaving_depth, + ); + drop(blinding_refs); let _ = config.blinding_polynomial.prove( &mut prover_state, blinding_cows, diff --git a/src/protocols/whir_zk/prover.rs b/src/protocols/whir_zk/prover.rs index a20c23fa..57ec8f36 100644 --- a/src/protocols/whir_zk/prover.rs +++ b/src/protocols/whir_zk/prover.rs @@ -22,7 +22,9 @@ use crate::{ embedding::Identity, geometric_sequence, linear_form::{Covector, Evaluate, LinearForm, UnivariateEvaluation}, - multilinear_extend, univariate_evaluate, MultilinearPoint, + multilinear_extend, + ntt::interleaved_rs_encode, + univariate_evaluate, MultilinearPoint, }, hash::Hash, protocols::{ @@ -119,7 +121,7 @@ where &mut self, vectors: Vec>, g_polys: &[Vec], - linear_forms: &[Box>], + linear_forms: Vec>>, evaluations: &[F], ) -> PrepareResult { let num_vectors = self.dims.num_vectors; @@ -159,7 +161,7 @@ where let g_claims: Vec = { let mut buf = vec![F::ZERO; size]; let mut claims = Vec::with_capacity(linear_forms.len()); - for w in linear_forms { + for w in &linear_forms { buf.fill(F::ZERO); w.accumulate(&mut buf, F::ONE); claims.push(dot(&buf, &g_poly)); @@ -247,6 +249,8 @@ where for (coeff, lf) in constraint_rlc_coeffs.iter().zip(linear_forms.iter()) { lf.accumulate(&mut covector, *coeff); } + // Only the combined `covector` is needed past this point. + drop(linear_forms); let mut the_sum: F = constraint_rlc_coeffs .iter() @@ -312,8 +316,10 @@ where /// Step 5: OOD/STIR queries, STIR constraint accumulation, and remaining WHIR rounds. /// - /// Takes ownership of `f_hat_polys` so it can be freed after OOD evaluations, - /// before the memory-intensive WHIR rounds begin. + /// Borrows `f_hat_polys` so it remains available for the f̂ open in + /// Step 6 (`gamma_check`). The [[f̂]] open uses an output-pruned NTT + /// (`open_from_coeffs`) that materialises only the queried codeword + /// rows, so `f_hat_witness.matrix` stays empty throughout. #[allow(clippy::too_many_arguments)] fn ood_stir_and_rounds( &mut self, @@ -322,7 +328,7 @@ where rho: F, folding_randomness: MultilinearPoint, f_hat_witness: &irs_commit::Witness, - f_hat_polys: Vec>, + f_hat_polys: &[Vec], masking_polys: &[Vec], g_polys: &[Vec], ) -> OodStirResult { @@ -334,11 +340,21 @@ where .irs_committer .commit(self.prover_state, &[state.vector.as_slice()]); round_config.pow.prove(self.prover_state); + + // Open [[f̂]] at in-domain indices via output-pruned NTT: only the + // k = in_domain_samples queried codeword rows are materialised, + // skipping the full Reed-Solomon re-encode and its (num_cols × + // codeword_length) allocation. + let f_hat_refs: Vec<&[F]> = f_hat_polys.iter().map(|p| p.as_slice()).collect(); let in_domain = self .config .blinded_polynomial .initial_committer - .open(self.prover_state, &[f_hat_witness]); + .open_from_coeffs( + self.prover_state, + &[&f_hat_refs], + &[f_hat_witness], + ); let r_bar = folding_randomness.0; let eq_weights = compute_eq_weights(&r_bar); @@ -383,9 +399,9 @@ where lambda_z_points.push(z); } - // Release f̂ data before WHIR rounds. + // Release f̂_combined before WHIR rounds. f_hat_polys is borrowed + // from the caller (still needed for re-encoding in gamma_check). drop(f_hat_combined); - drop(f_hat_polys); // --- STIR responses --- for &z in &in_domain.points { @@ -433,10 +449,13 @@ where /// Step 6: Γ consistency check. /// - /// Opens [[f̂]] at Γ indices and sends blinding evaluations for each γ ∈ Γ. + /// Opens [[f̂]] at Γ indices via `open_at_indices_from_coeffs` (output- + /// pruned NTT) and sends blinding evaluations for each γ ∈ Γ. The + /// codeword matrix is never materialised. fn gamma_check( &mut self, f_hat_witness: &irs_commit::Witness, + f_hat_polys: &[Vec], masking_coeffs_all: &[Vec], g_i_coeffs: &[Vec], gamma_points: &[F], @@ -444,14 +463,20 @@ where ) { let gamma_f_hat_indices = gamma_to_f_hat_indices(gamma_points, self.config); - // Writes [[f̂]] openings at Γ indices to the transcript. - // The verifier uses these to reconstruct fold(r̄, [[f̂]])(γ). - // Return value (Evaluations) is unused: the prover already knows the values. + // Open [[f̂]] at Γ indices via output-pruned NTT: the verifier + // reconstructs fold(r̄, [[f̂]])(γ) from these openings. Return value + // is unused; the prover already knows the values. + let f_hat_refs: Vec<&[F]> = f_hat_polys.iter().map(|p| p.as_slice()).collect(); let _f_hat_openings = self .config .blinded_polynomial .initial_committer - .open_at_indices(self.prover_state, &[f_hat_witness], &gamma_f_hat_indices); + .open_at_indices_from_coeffs( + self.prover_state, + &[&f_hat_refs], + &[f_hat_witness], + &gamma_f_hat_indices, + ); for &gamma in gamma_points { send_blinding_evals(self.prover_state, gamma, masking_coeffs_all, g_i_coeffs); @@ -463,19 +488,21 @@ where impl Config { /// Steps 2-6: Prove the blinded polynomial instance. /// - /// `f_hat_polys` is taken by value and freed during OOD evaluations (Step 5), - /// before the memory-intensive WHIR rounds begin. - /// Other witness fields are borrowed; the caller frees them before Step 7. + /// `f_hat_witness.matrix` is empty on entry (cleared at commit time) + /// and stays empty: both [[f̂]] opens (in `ood_stir_and_rounds` and + /// `gamma_check`) use output-pruned encoding, so the full codeword + /// matrix is never materialised. `f_hat_polys` is borrowed because + /// both opens read coefficients from it. #[allow(clippy::too_many_arguments)] fn prove_blinded_polynomial( &self, prover_state: &mut ProverState, vectors: Vec>, f_hat_witness: &irs_commit::Witness, - f_hat_polys: Vec>, + f_hat_polys: &[Vec], masking_polys: &[Vec], g_polys: &[Vec], - linear_forms: &[Box>], + linear_forms: Vec>>, evaluations: &[F], ) -> BlindedProveResult where @@ -548,6 +575,7 @@ impl Config { ctx.gamma_check( f_hat_witness, + f_hat_polys, &masking_coeffs_all, &g_i_coeffs, &gamma_points, @@ -669,26 +697,49 @@ impl Config { { let Witness { f_hat_witness, - blinding_poly_witness, + mut blinding_poly_witness, f_hat_polys, secrets, } = witness; // Steps 2-6: blinded polynomial proof. + // Both `f_hat_witness.matrix` and `blinding_poly_witness.matrix` are + // empty here (cleared at commit time). The blinded prover re-encodes + // f_hat transiently around each of its two opens; blinding_poly is + // re-encoded just before Step 7 below. This keeps both codewords + // (~codeword_length × interleaving_depth field elements each) out of + // the resident set during the prepare_and_sumcheck rounds where global + // peak hits. let blinded = self.prove_blinded_polynomial( prover_state, vectors, &f_hat_witness, - f_hat_polys, + &f_hat_polys, &secrets.masking_polys, &secrets.g_polys, - &linear_forms, + linear_forms, &evaluations, ); // Free fields only needed during Steps 2-6, before Step 7. drop(f_hat_witness); - drop(linear_forms); + drop(f_hat_polys); + + // Re-encode the [[M, ĝ]] codeword, which was dropped at commit time + // to keep the resident set small through Step 6. + let blinding_refs: Vec<&[F]> = secrets + .blinding_vectors + .iter() + .map(|v| v.as_slice()) + .collect(); + blinding_poly_witness.matrix = interleaved_rs_encode( + &blinding_refs, + self.blinding_polynomial.initial_committer.codeword_length, + self.blinding_polynomial + .initial_committer + .interleaving_depth, + ); + drop(blinding_refs); // Step 7: batched blinding polynomial proof. self.prove_blinding_polynomial(