From e6f9c3eca8659439c718d31958562716a9014441 Mon Sep 17 00:00:00 2001 From: zkfriendly Date: Tue, 2 Jun 2026 13:59:01 +0200 Subject: [PATCH 1/5] wip: introduce buffer ops trait and cpu buffer --- src/algebra/buffer.rs | 460 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 460 insertions(+) create mode 100644 src/algebra/buffer.rs diff --git a/src/algebra/buffer.rs b/src/algebra/buffer.rs new file mode 100644 index 00000000..89cc4807 --- /dev/null +++ b/src/algebra/buffer.rs @@ -0,0 +1,460 @@ +use std::cmp::max; + +use ark_ff::{AdditiveGroup, Field}; + +use crate::algebra::embedding::{Embedding, Identity}; +#[cfg(feature = "parallel")] +use crate::utils::workload_size; +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +pub trait BufferOps: Clone { + fn len(&self) -> usize; + fn is_empty(&self) -> bool; + // zero pad to 2**log_m + fn zero_pad(&mut self, log_m: usize); + fn mixed_extend, T: Field>( + &self, + embedding: &M, + point: &[M::Target], + ) -> M::Target; + fn mixed_dot, T: Field>( + &self, + embedding: &M, + other: &impl BufferOps, + ) -> M::Target; + fn dot(&self, other: &Self) -> F; + fn as_slice(&self) -> &[F]; + fn as_ro_buffer(&self, size: usize) -> impl BufferOps; +} + +// read-only buffer ops +pub trait ROBufferOps { + fn split_at(&self, mid: usize) -> (&impl BufferOps, &impl BufferOps); +} + +#[derive(Clone)] +pub struct SliceCpuBuffer<'a, F: Field> { + data: &'a [F], +} + +#[derive(Clone)] +pub struct CpuBuffer { + data: Vec, + len: usize, +} + +impl CpuBuffer { + pub fn from_vec(source: Vec) -> Self { + let len = source.len(); + Self { data: source, len } + } + + pub fn from_slice(source: &[F]) -> Self { + Self { + data: Vec::from(source), + len: source.len(), + } + } +} + +impl BufferOps for CpuBuffer { + fn len(&self) -> usize { + self.len + } + + fn is_empty(&self) -> bool { + self.len == 0 + } + + fn zero_pad(&mut self, log_m: usize) { + if !self.is_empty() { + self.data.resize(1 << log_m, F::ZERO); + } + } + + fn mixed_extend, T: Field>( + &self, + embedding: &M, + point: &[M::Target], + ) -> M::Target { + #[inline] + fn eval_exact( + embedding: &M, + evals: &[M::Source], + point: &[M::Target], + ) -> M::Target { + debug_assert_eq!(evals.len(), 1 << point.len()); + + // Helper to compute (a + (b - a) * c) efficiently with a, b in source field. + let mixed = |a, b, c| embedding.mixed_add(embedding.mixed_mul(c, b - a), a); + + match point { + [] => embedding.map(evals[0]), + [x] => mixed(evals[0], evals[1], *x), + [x0, x1] => { + let a0 = mixed(evals[0], evals[1], *x1); + let a1 = mixed(evals[2], evals[3], *x1); + a0 + (a1 - a0) * *x0 + } + [x0, x1, x2] => { + let a00 = mixed(evals[0], evals[1], *x2); + let a01 = mixed(evals[2], evals[3], *x2); + let a10 = mixed(evals[4], evals[5], *x2); + let a11 = mixed(evals[6], evals[7], *x2); + let a0 = a00 + (a01 - a00) * *x1; + let a1 = a10 + (a11 - a10) * *x1; + a0 + (a1 - a0) * *x0 + } + [x0, x1, x2, x3] => { + let a000 = mixed(evals[0], evals[1], *x3); + let a001 = mixed(evals[2], evals[3], *x3); + let a010 = mixed(evals[4], evals[5], *x3); + let a011 = mixed(evals[6], evals[7], *x3); + let a100 = mixed(evals[8], evals[9], *x3); + let a101 = mixed(evals[10], evals[11], *x3); + let a110 = mixed(evals[12], evals[13], *x3); + let a111 = mixed(evals[14], evals[15], *x3); + let a00 = a000 + (a001 - a000) * *x2; + let a01 = a010 + (a011 - a010) * *x2; + let a10 = a100 + (a101 - a100) * *x2; + let a11 = a110 + (a111 - a110) * *x2; + let a0 = a00 + (a01 - a00) * *x1; + let a1 = a10 + (a11 - a10) * *x1; + a0 + (a1 - a0) * *x0 + } + [x, tail @ ..] => { + let (f0, f1) = evals.split_at(evals.len() / 2); + #[cfg(not(feature = "parallel"))] + let (f0, f1) = ( + eval_exact(embedding, f0, tail), + eval_exact(embedding, f1, tail), + ); + + #[cfg(feature = "parallel")] + let (f0, f1) = { + use crate::utils::workload_size; + if evals.len() > workload_size::() { + rayon::join( + || eval_exact(embedding, f0, tail), + || eval_exact(embedding, f1, tail), + ) + } else { + ( + eval_exact(embedding, f0, tail), + eval_exact(embedding, f1, tail), + ) + } + }; + + f0 + (f1 - f0) * *x + } + } + } + + #[inline] + fn eval_partial( + embedding: &M, + evals: &[M::Source], + point: &[M::Target], + ) -> M::Target { + let size = 1 << point.len(); + debug_assert!(evals.len() <= size); + if evals.is_empty() { + return M::Target::ZERO; + } + if evals.len() == size { + return eval_exact(embedding, evals, point); + } + + match point { + [] => embedding.map(evals[0]), + [x, tail @ ..] => { + let half = size / 2; + + // Only low half has data; high half is all implicit zeros. + if evals.len() <= half { + let f0 = eval_partial(embedding, evals, tail); + return f0 * (M::Target::ONE - *x); + } + + // Low subtree is exact/full, high subtree is partial. + let (low, high) = evals.split_at(half); + + #[cfg(not(feature = "parallel"))] + let (f0, f1) = ( + eval_exact(embedding, low, tail), + eval_partial(embedding, high, tail), + ); + + #[cfg(feature = "parallel")] + let (f0, f1) = { + use crate::utils::workload_size; + if evals.len() > workload_size::() { + rayon::join( + || eval_exact(embedding, low, tail), + || eval_partial(embedding, high, tail), + ) + } else { + ( + eval_exact(embedding, low, tail), + eval_partial(embedding, high, tail), + ) + } + }; + + f0 + (f1 - f0) * *x + } + } + } + + eval_partial(embedding, &self.data, point) + } + + fn mixed_dot, T: Field>( + &self, + embedding: &M, + other: &impl BufferOps, + ) -> M::Target { + assert_eq!(self.len(), other.len()); + + let a = other.as_slice(); + let b = self.as_slice(); + + #[cfg(feature = "parallel")] + if a.len() > workload_size::() { + return a + .par_iter() + .zip(b) + .map(|(a, b)| embedding.mixed_mul(*a, *b)) + .sum(); + } + + a.iter() + .zip(b) + .map(|(a, b)| embedding.mixed_mul(*a, *b)) + .sum() + } + + fn dot(&self, other: &Self) -> F { + self.mixed_dot(&Identity::new(), other) + } + + fn as_slice(&self) -> &[F] { + &self.data[..self.len] + } + + fn as_ro_buffer(&self, size: usize) -> impl BufferOps { + SliceCpuBuffer::from_buffer_with_size(self, size) + } +} + +impl<'a, F: Field> SliceCpuBuffer<'a, F> { + pub fn from_buffer(buffer: &'a CpuBuffer) -> Self { + Self { + data: buffer.as_slice(), + } + } + + pub fn from_buffer_with_size(buffer: &'a CpuBuffer, size: usize) -> Self { + Self { + data: &buffer.data[..max(buffer.len, size)], + } + } + + pub fn from_slice_with_size(slice: &'a &[F], size: usize) -> Self { + assert!(size <= slice.len()); + Self { + data: &slice[..size], + } + } +} + +impl<'a, F: Field> BufferOps for SliceCpuBuffer<'a, F> { + fn len(&self) -> usize { + self.data.len() + } + + fn is_empty(&self) -> bool { + self.data.is_empty() + } + + fn zero_pad(&mut self, log_m: usize) { + panic!("read only") + } + + fn mixed_extend, T: Field>( + &self, + embedding: &M, + point: &[M::Target], + ) -> M::Target { + #[inline] + fn eval_exact( + embedding: &M, + evals: &[M::Source], + point: &[M::Target], + ) -> M::Target { + debug_assert_eq!(evals.len(), 1 << point.len()); + + // Helper to compute (a + (b - a) * c) efficiently with a, b in source field. + let mixed = |a, b, c| embedding.mixed_add(embedding.mixed_mul(c, b - a), a); + + match point { + [] => embedding.map(evals[0]), + [x] => mixed(evals[0], evals[1], *x), + [x0, x1] => { + let a0 = mixed(evals[0], evals[1], *x1); + let a1 = mixed(evals[2], evals[3], *x1); + a0 + (a1 - a0) * *x0 + } + [x0, x1, x2] => { + let a00 = mixed(evals[0], evals[1], *x2); + let a01 = mixed(evals[2], evals[3], *x2); + let a10 = mixed(evals[4], evals[5], *x2); + let a11 = mixed(evals[6], evals[7], *x2); + let a0 = a00 + (a01 - a00) * *x1; + let a1 = a10 + (a11 - a10) * *x1; + a0 + (a1 - a0) * *x0 + } + [x0, x1, x2, x3] => { + let a000 = mixed(evals[0], evals[1], *x3); + let a001 = mixed(evals[2], evals[3], *x3); + let a010 = mixed(evals[4], evals[5], *x3); + let a011 = mixed(evals[6], evals[7], *x3); + let a100 = mixed(evals[8], evals[9], *x3); + let a101 = mixed(evals[10], evals[11], *x3); + let a110 = mixed(evals[12], evals[13], *x3); + let a111 = mixed(evals[14], evals[15], *x3); + let a00 = a000 + (a001 - a000) * *x2; + let a01 = a010 + (a011 - a010) * *x2; + let a10 = a100 + (a101 - a100) * *x2; + let a11 = a110 + (a111 - a110) * *x2; + let a0 = a00 + (a01 - a00) * *x1; + let a1 = a10 + (a11 - a10) * *x1; + a0 + (a1 - a0) * *x0 + } + [x, tail @ ..] => { + let (f0, f1) = evals.split_at(evals.len() / 2); + #[cfg(not(feature = "parallel"))] + let (f0, f1) = ( + eval_exact(embedding, f0, tail), + eval_exact(embedding, f1, tail), + ); + + #[cfg(feature = "parallel")] + let (f0, f1) = { + use crate::utils::workload_size; + if evals.len() > workload_size::() { + rayon::join( + || eval_exact(embedding, f0, tail), + || eval_exact(embedding, f1, tail), + ) + } else { + ( + eval_exact(embedding, f0, tail), + eval_exact(embedding, f1, tail), + ) + } + }; + + f0 + (f1 - f0) * *x + } + } + } + + #[inline] + fn eval_partial( + embedding: &M, + evals: &[M::Source], + point: &[M::Target], + ) -> M::Target { + let size = 1 << point.len(); + debug_assert!(evals.len() <= size); + if evals.is_empty() { + return M::Target::ZERO; + } + if evals.len() == size { + return eval_exact(embedding, evals, point); + } + + match point { + [] => embedding.map(evals[0]), + [x, tail @ ..] => { + let half = size / 2; + + // Only low half has data; high half is all implicit zeros. + if evals.len() <= half { + let f0 = eval_partial(embedding, evals, tail); + return f0 * (M::Target::ONE - *x); + } + + // Low subtree is exact/full, high subtree is partial. + let (low, high) = evals.split_at(half); + + #[cfg(not(feature = "parallel"))] + let (f0, f1) = ( + eval_exact(embedding, low, tail), + eval_partial(embedding, high, tail), + ); + + #[cfg(feature = "parallel")] + let (f0, f1) = { + use crate::utils::workload_size; + if evals.len() > workload_size::() { + rayon::join( + || eval_exact(embedding, low, tail), + || eval_partial(embedding, high, tail), + ) + } else { + ( + eval_exact(embedding, low, tail), + eval_partial(embedding, high, tail), + ) + } + }; + + f0 + (f1 - f0) * *x + } + } + } + + eval_partial(embedding, &self.data, point) + } + + fn mixed_dot, T: Field>( + &self, + embedding: &M, + other: &impl BufferOps, + ) -> M::Target { + assert_eq!(self.len(), other.len()); + + let a = other.as_slice(); + let b = self.as_slice(); + + #[cfg(feature = "parallel")] + if a.len() > workload_size::() { + return a + .par_iter() + .zip(b) + .map(|(a, b)| embedding.mixed_mul(*a, *b)) + .sum(); + } + + a.iter() + .zip(b) + .map(|(a, b)| embedding.mixed_mul(*a, *b)) + .sum() + } + + fn dot(&self, other: &Self) -> F { + self.mixed_dot(&Identity::new(), other) + } + + fn as_slice(&self) -> &[F] { + &self.data + } + + fn as_ro_buffer(&self, size: usize) -> impl BufferOps { + SliceCpuBuffer::from_slice_with_size(&self.data, size) + } +} From 3d284a980909eb67d00cf188eec631cb2da06a36 Mon Sep 17 00:00:00 2001 From: zkfriendly Date: Wed, 3 Jun 2026 12:11:06 +0200 Subject: [PATCH 2/5] feat: buffer abstraction with basic interface and cpu buffer implementation --- src/algebra/buffer.rs | 607 +++++++++++++++++------------------------- 1 file changed, 241 insertions(+), 366 deletions(-) diff --git a/src/algebra/buffer.rs b/src/algebra/buffer.rs index 89cc4807..a38a95a4 100644 --- a/src/algebra/buffer.rs +++ b/src/algebra/buffer.rs @@ -1,18 +1,35 @@ -use std::cmp::max; - -use ark_ff::{AdditiveGroup, Field}; - -use crate::algebra::embedding::{Embedding, Identity}; -#[cfg(feature = "parallel")] -use crate::utils::workload_size; -#[cfg(feature = "parallel")] -use rayon::prelude::*; +use ark_ff::Field; +use ark_std::rand::{distributions::Standard, prelude::Distribution, CryptoRng, Rng, RngCore}; +use spongefish::DuplexSpongeInterface; + +use crate::{ + algebra::{ + embedding::{Embedding, Identity}, + mixed_dot, mixed_multilinear_extend, mixed_univariate_evaluate, ntt, + }, + hash::Hash, + protocols::matrix_commit, + transcript::{ProverMessage, ProverState}, + type_info::TypeInfo, + utils::chunks_exact_or_empty, +}; pub trait BufferOps: Clone { + type Buffer: BufferOps; + type Matrix: MatrixBufferOps; + fn len(&self) -> usize; - fn is_empty(&self) -> bool; - // zero pad to 2**log_m - fn zero_pad(&mut self, log_m: usize); + + fn is_empty(&self) -> bool { + self.len() == 0 + } + + fn random(rng: &mut R, length: usize) -> Self + where + R: RngCore + CryptoRng, + Standard: Distribution; + + fn zero_pad(&mut self); fn mixed_extend, T: Field>( &self, embedding: &M, @@ -21,55 +38,179 @@ pub trait BufferOps: Clone { fn mixed_dot, T: Field>( &self, embedding: &M, - other: &impl BufferOps, + other: &Self::Buffer, + ) -> M::Target; + fn mixed_univariate_evaluate>( + &self, + embedding: &M, + point: M::Target, ) -> M::Target; + fn interleaved_rs_encode( + vectors: &[&Self], + masks: &Self, + message_length: usize, + interleaving_depth: usize, + codeword_length: usize, + ) -> Self::Matrix + where + F: 'static; fn dot(&self, other: &Self) -> F; - fn as_slice(&self) -> &[F]; - fn as_ro_buffer(&self, size: usize) -> impl BufferOps; } -// read-only buffer ops -pub trait ROBufferOps { - fn split_at(&self, mid: usize) -> (&impl BufferOps, &impl BufferOps); +fn interleaved_rs_encode_slices( + vectors: &[&[F]], + masks: &[F], + message_length: usize, + interleaving_depth: usize, + codeword_length: usize, +) -> CpuMatrix { + let messages = vectors + .iter() + .flat_map(|v| chunks_exact_or_empty(v, message_length, interleaving_depth)) + .collect::>(); + CpuMatrix::from_vec( + ntt::interleaved_rs_encode(&messages, masks, codeword_length), + codeword_length, + vectors.len() * interleaving_depth, + ) } -#[derive(Clone)] -pub struct SliceCpuBuffer<'a, F: Field> { - data: &'a [F], +pub trait MatrixBufferOps { + fn len(&self) -> usize; + fn num_rows(&self) -> usize; + fn num_cols(&self) -> usize; + + fn commit_rows( + &self, + config: &matrix_commit::Config, + prover_state: &mut ProverState, + ) -> matrix_commit::Witness + where + F: TypeInfo + matrix_commit::Encodable + Send + Sync, + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + Hash: ProverMessage<[H::U]>; + + fn read_rows(&self, indices: &[usize]) -> Vec; +} + +#[derive( + Clone, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + Debug, + Default, + serde::Serialize, + serde::Deserialize, +)] +pub struct CpuMatrix { + data: Vec, + num_rows: usize, + num_cols: usize, +} + +impl CpuMatrix { + pub fn from_vec(data: Vec, num_rows: usize, num_cols: usize) -> Self { + assert_eq!(data.len(), num_rows * num_cols); + Self { + data, + num_rows, + num_cols, + } + } + + fn row(&self, row: usize) -> &[F] { + let start = row * self.num_cols; + let end = start + self.num_cols; + &self.data[start..end] + } } +impl MatrixBufferOps for CpuMatrix +where + F: Field + TypeInfo + matrix_commit::Encodable + Send + Sync, +{ + fn len(&self) -> usize { + self.data.len() + } + + fn num_rows(&self) -> usize { + self.num_rows + } + + fn num_cols(&self) -> usize { + self.num_cols + } + + fn commit_rows( + &self, + config: &matrix_commit::Config, + prover_state: &mut ProverState, + ) -> matrix_commit::Witness + where + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + Hash: ProverMessage<[H::U]>, + { + assert_eq!(config.num_rows(), self.num_rows); + assert_eq!(config.num_cols, self.num_cols); + config.commit(prover_state, &self.data) + } + + fn read_rows(&self, indices: &[usize]) -> Vec { + let mut rows = Vec::with_capacity(indices.len() * self.num_cols); + for &index in indices { + assert!(index < self.num_rows); + rows.extend_from_slice(self.row(index)); + } + rows + } +} #[derive(Clone)] pub struct CpuBuffer { data: Vec, - len: usize, } impl CpuBuffer { pub fn from_vec(source: Vec) -> Self { - let len = source.len(); - Self { data: source, len } + Self { data: source } } pub fn from_slice(source: &[F]) -> Self { Self { data: Vec::from(source), - len: source.len(), } } + + fn as_slice(&self) -> &[F] { + self.data.as_slice() + } } impl BufferOps for CpuBuffer { + type Buffer = CpuBuffer; + type Matrix = CpuMatrix; + fn len(&self) -> usize { - self.len + self.data.len() } - fn is_empty(&self) -> bool { - self.len == 0 + fn random(rng: &mut R, length: usize) -> Self + where + R: RngCore + CryptoRng, + Standard: Distribution, + { + Self { + data: (0..length).map(|_| rng.gen()).collect(), + } } - fn zero_pad(&mut self, log_m: usize) { + fn zero_pad(&mut self) { if !self.is_empty() { - self.data.resize(1 << log_m, F::ZERO); + self.data.resize(self.len().next_power_of_two(), F::ZERO); } } @@ -78,209 +219,70 @@ impl BufferOps for CpuBuffer { embedding: &M, point: &[M::Target], ) -> M::Target { - #[inline] - fn eval_exact( - embedding: &M, - evals: &[M::Source], - point: &[M::Target], - ) -> M::Target { - debug_assert_eq!(evals.len(), 1 << point.len()); - - // Helper to compute (a + (b - a) * c) efficiently with a, b in source field. - let mixed = |a, b, c| embedding.mixed_add(embedding.mixed_mul(c, b - a), a); - - match point { - [] => embedding.map(evals[0]), - [x] => mixed(evals[0], evals[1], *x), - [x0, x1] => { - let a0 = mixed(evals[0], evals[1], *x1); - let a1 = mixed(evals[2], evals[3], *x1); - a0 + (a1 - a0) * *x0 - } - [x0, x1, x2] => { - let a00 = mixed(evals[0], evals[1], *x2); - let a01 = mixed(evals[2], evals[3], *x2); - let a10 = mixed(evals[4], evals[5], *x2); - let a11 = mixed(evals[6], evals[7], *x2); - let a0 = a00 + (a01 - a00) * *x1; - let a1 = a10 + (a11 - a10) * *x1; - a0 + (a1 - a0) * *x0 - } - [x0, x1, x2, x3] => { - let a000 = mixed(evals[0], evals[1], *x3); - let a001 = mixed(evals[2], evals[3], *x3); - let a010 = mixed(evals[4], evals[5], *x3); - let a011 = mixed(evals[6], evals[7], *x3); - let a100 = mixed(evals[8], evals[9], *x3); - let a101 = mixed(evals[10], evals[11], *x3); - let a110 = mixed(evals[12], evals[13], *x3); - let a111 = mixed(evals[14], evals[15], *x3); - let a00 = a000 + (a001 - a000) * *x2; - let a01 = a010 + (a011 - a010) * *x2; - let a10 = a100 + (a101 - a100) * *x2; - let a11 = a110 + (a111 - a110) * *x2; - let a0 = a00 + (a01 - a00) * *x1; - let a1 = a10 + (a11 - a10) * *x1; - a0 + (a1 - a0) * *x0 - } - [x, tail @ ..] => { - let (f0, f1) = evals.split_at(evals.len() / 2); - #[cfg(not(feature = "parallel"))] - let (f0, f1) = ( - eval_exact(embedding, f0, tail), - eval_exact(embedding, f1, tail), - ); - - #[cfg(feature = "parallel")] - let (f0, f1) = { - use crate::utils::workload_size; - if evals.len() > workload_size::() { - rayon::join( - || eval_exact(embedding, f0, tail), - || eval_exact(embedding, f1, tail), - ) - } else { - ( - eval_exact(embedding, f0, tail), - eval_exact(embedding, f1, tail), - ) - } - }; - - f0 + (f1 - f0) * *x - } - } - } - - #[inline] - fn eval_partial( - embedding: &M, - evals: &[M::Source], - point: &[M::Target], - ) -> M::Target { - let size = 1 << point.len(); - debug_assert!(evals.len() <= size); - if evals.is_empty() { - return M::Target::ZERO; - } - if evals.len() == size { - return eval_exact(embedding, evals, point); - } - - match point { - [] => embedding.map(evals[0]), - [x, tail @ ..] => { - let half = size / 2; - - // Only low half has data; high half is all implicit zeros. - if evals.len() <= half { - let f0 = eval_partial(embedding, evals, tail); - return f0 * (M::Target::ONE - *x); - } - - // Low subtree is exact/full, high subtree is partial. - let (low, high) = evals.split_at(half); - - #[cfg(not(feature = "parallel"))] - let (f0, f1) = ( - eval_exact(embedding, low, tail), - eval_partial(embedding, high, tail), - ); - - #[cfg(feature = "parallel")] - let (f0, f1) = { - use crate::utils::workload_size; - if evals.len() > workload_size::() { - rayon::join( - || eval_exact(embedding, low, tail), - || eval_partial(embedding, high, tail), - ) - } else { - ( - eval_exact(embedding, low, tail), - eval_partial(embedding, high, tail), - ) - } - }; - - f0 + (f1 - f0) * *x - } - } - } - - eval_partial(embedding, &self.data, point) + mixed_multilinear_extend(embedding, &self.data, point) } fn mixed_dot, T: Field>( &self, embedding: &M, - other: &impl BufferOps, + other: &Self::Buffer, ) -> M::Target { - assert_eq!(self.len(), other.len()); - - let a = other.as_slice(); - let b = self.as_slice(); - - #[cfg(feature = "parallel")] - if a.len() > workload_size::() { - return a - .par_iter() - .zip(b) - .map(|(a, b)| embedding.mixed_mul(*a, *b)) - .sum(); - } - - a.iter() - .zip(b) - .map(|(a, b)| embedding.mixed_mul(*a, *b)) - .sum() + mixed_dot(embedding, other.as_slice(), self.as_slice()) } fn dot(&self, other: &Self) -> F { self.mixed_dot(&Identity::new(), other) } - fn as_slice(&self) -> &[F] { - &self.data[..self.len] + fn mixed_univariate_evaluate>( + &self, + embedding: &M, + point: M::Target, + ) -> M::Target { + mixed_univariate_evaluate(embedding, &self.data, point) } - fn as_ro_buffer(&self, size: usize) -> impl BufferOps { - SliceCpuBuffer::from_buffer_with_size(self, size) + fn interleaved_rs_encode( + vectors: &[&Self], + masks: &Self, + message_length: usize, + interleaving_depth: usize, + codeword_length: usize, + ) -> Self::Matrix + where + F: 'static, + { + let vectors = vectors.iter().map(|v| v.as_slice()).collect::>(); + interleaved_rs_encode_slices( + &vectors, + masks.as_slice(), + message_length, + interleaving_depth, + codeword_length, + ) } } -impl<'a, F: Field> SliceCpuBuffer<'a, F> { - pub fn from_buffer(buffer: &'a CpuBuffer) -> Self { - Self { - data: buffer.as_slice(), - } - } +impl BufferOps for Vec { + type Buffer = Vec; + type Matrix = CpuMatrix; - pub fn from_buffer_with_size(buffer: &'a CpuBuffer, size: usize) -> Self { - Self { - data: &buffer.data[..max(buffer.len, size)], - } - } - - pub fn from_slice_with_size(slice: &'a &[F], size: usize) -> Self { - assert!(size <= slice.len()); - Self { - data: &slice[..size], - } - } -} - -impl<'a, F: Field> BufferOps for SliceCpuBuffer<'a, F> { fn len(&self) -> usize { - self.data.len() + self.len() } - fn is_empty(&self) -> bool { - self.data.is_empty() + fn random(rng: &mut R, length: usize) -> Self + where + R: RngCore + CryptoRng, + Standard: Distribution, + { + (0..length).map(|_| rng.gen()).collect() } - fn zero_pad(&mut self, log_m: usize) { - panic!("read only") + fn zero_pad(&mut self) { + if !self.is_empty() { + self.resize(self.len().next_power_of_two(), F::ZERO); + } } fn mixed_extend, T: Field>( @@ -288,173 +290,46 @@ impl<'a, F: Field> BufferOps for SliceCpuBuffer<'a, F> { embedding: &M, point: &[M::Target], ) -> M::Target { - #[inline] - fn eval_exact( - embedding: &M, - evals: &[M::Source], - point: &[M::Target], - ) -> M::Target { - debug_assert_eq!(evals.len(), 1 << point.len()); - - // Helper to compute (a + (b - a) * c) efficiently with a, b in source field. - let mixed = |a, b, c| embedding.mixed_add(embedding.mixed_mul(c, b - a), a); - - match point { - [] => embedding.map(evals[0]), - [x] => mixed(evals[0], evals[1], *x), - [x0, x1] => { - let a0 = mixed(evals[0], evals[1], *x1); - let a1 = mixed(evals[2], evals[3], *x1); - a0 + (a1 - a0) * *x0 - } - [x0, x1, x2] => { - let a00 = mixed(evals[0], evals[1], *x2); - let a01 = mixed(evals[2], evals[3], *x2); - let a10 = mixed(evals[4], evals[5], *x2); - let a11 = mixed(evals[6], evals[7], *x2); - let a0 = a00 + (a01 - a00) * *x1; - let a1 = a10 + (a11 - a10) * *x1; - a0 + (a1 - a0) * *x0 - } - [x0, x1, x2, x3] => { - let a000 = mixed(evals[0], evals[1], *x3); - let a001 = mixed(evals[2], evals[3], *x3); - let a010 = mixed(evals[4], evals[5], *x3); - let a011 = mixed(evals[6], evals[7], *x3); - let a100 = mixed(evals[8], evals[9], *x3); - let a101 = mixed(evals[10], evals[11], *x3); - let a110 = mixed(evals[12], evals[13], *x3); - let a111 = mixed(evals[14], evals[15], *x3); - let a00 = a000 + (a001 - a000) * *x2; - let a01 = a010 + (a011 - a010) * *x2; - let a10 = a100 + (a101 - a100) * *x2; - let a11 = a110 + (a111 - a110) * *x2; - let a0 = a00 + (a01 - a00) * *x1; - let a1 = a10 + (a11 - a10) * *x1; - a0 + (a1 - a0) * *x0 - } - [x, tail @ ..] => { - let (f0, f1) = evals.split_at(evals.len() / 2); - #[cfg(not(feature = "parallel"))] - let (f0, f1) = ( - eval_exact(embedding, f0, tail), - eval_exact(embedding, f1, tail), - ); - - #[cfg(feature = "parallel")] - let (f0, f1) = { - use crate::utils::workload_size; - if evals.len() > workload_size::() { - rayon::join( - || eval_exact(embedding, f0, tail), - || eval_exact(embedding, f1, tail), - ) - } else { - ( - eval_exact(embedding, f0, tail), - eval_exact(embedding, f1, tail), - ) - } - }; - - f0 + (f1 - f0) * *x - } - } - } - - #[inline] - fn eval_partial( - embedding: &M, - evals: &[M::Source], - point: &[M::Target], - ) -> M::Target { - let size = 1 << point.len(); - debug_assert!(evals.len() <= size); - if evals.is_empty() { - return M::Target::ZERO; - } - if evals.len() == size { - return eval_exact(embedding, evals, point); - } - - match point { - [] => embedding.map(evals[0]), - [x, tail @ ..] => { - let half = size / 2; - - // Only low half has data; high half is all implicit zeros. - if evals.len() <= half { - let f0 = eval_partial(embedding, evals, tail); - return f0 * (M::Target::ONE - *x); - } - - // Low subtree is exact/full, high subtree is partial. - let (low, high) = evals.split_at(half); - - #[cfg(not(feature = "parallel"))] - let (f0, f1) = ( - eval_exact(embedding, low, tail), - eval_partial(embedding, high, tail), - ); - - #[cfg(feature = "parallel")] - let (f0, f1) = { - use crate::utils::workload_size; - if evals.len() > workload_size::() { - rayon::join( - || eval_exact(embedding, low, tail), - || eval_partial(embedding, high, tail), - ) - } else { - ( - eval_exact(embedding, low, tail), - eval_partial(embedding, high, tail), - ) - } - }; - - f0 + (f1 - f0) * *x - } - } - } - - eval_partial(embedding, &self.data, point) + mixed_multilinear_extend(embedding, self, point) } fn mixed_dot, T: Field>( &self, embedding: &M, - other: &impl BufferOps, + other: &Self::Buffer, ) -> M::Target { - assert_eq!(self.len(), other.len()); - - let a = other.as_slice(); - let b = self.as_slice(); - - #[cfg(feature = "parallel")] - if a.len() > workload_size::() { - return a - .par_iter() - .zip(b) - .map(|(a, b)| embedding.mixed_mul(*a, *b)) - .sum(); - } - - a.iter() - .zip(b) - .map(|(a, b)| embedding.mixed_mul(*a, *b)) - .sum() + mixed_dot(embedding, other, self) } - fn dot(&self, other: &Self) -> F { - self.mixed_dot(&Identity::new(), other) + fn mixed_univariate_evaluate>( + &self, + embedding: &M, + point: M::Target, + ) -> M::Target { + mixed_univariate_evaluate(embedding, self, point) } - fn as_slice(&self) -> &[F] { - &self.data + fn interleaved_rs_encode( + vectors: &[&Self], + masks: &Self, + message_length: usize, + interleaving_depth: usize, + codeword_length: usize, + ) -> Self::Matrix + where + F: 'static, + { + let vectors = vectors.iter().map(|v| v.as_slice()).collect::>(); + interleaved_rs_encode_slices( + &vectors, + masks, + message_length, + interleaving_depth, + codeword_length, + ) } - fn as_ro_buffer(&self, size: usize) -> impl BufferOps { - SliceCpuBuffer::from_slice_with_size(&self.data, size) + fn dot(&self, other: &Self) -> F { + self.mixed_dot(&Identity::new(), other) } } From 7f4dd420095534285a4c3a7b409b6ca695dd7eb2 Mon Sep 17 00:00:00 2001 From: zkfriendly Date: Wed, 3 Jun 2026 12:38:49 +0200 Subject: [PATCH 3/5] feat: wire irs commit path to buffer --- src/algebra/buffer.rs | 152 ++++++++++++----------------- src/algebra/mod.rs | 1 + src/bin/benchmark.rs | 7 +- src/bin/main.rs | 7 +- src/protocols/basecase.rs | 18 ++-- src/protocols/code_switch.rs | 15 ++- src/protocols/irs_commit.rs | 105 ++++++++++++-------- src/protocols/mask_proximity.rs | 38 +++++--- src/protocols/whir/mod.rs | 43 +++++--- src/protocols/whir/prover.rs | 6 +- src/protocols/whir_zk/committer.rs | 12 ++- src/protocols/whir_zk/mod.rs | 25 ++++- 12 files changed, 245 insertions(+), 184 deletions(-) diff --git a/src/algebra/buffer.rs b/src/algebra/buffer.rs index a38a95a4..e282707a 100644 --- a/src/algebra/buffer.rs +++ b/src/algebra/buffer.rs @@ -57,25 +57,9 @@ pub trait BufferOps: Clone { fn dot(&self, other: &Self) -> F; } -fn interleaved_rs_encode_slices( - vectors: &[&[F]], - masks: &[F], - message_length: usize, - interleaving_depth: usize, - codeword_length: usize, -) -> CpuMatrix { - let messages = vectors - .iter() - .flat_map(|v| chunks_exact_or_empty(v, message_length, interleaving_depth)) - .collect::>(); - CpuMatrix::from_vec( - ntt::interleaved_rs_encode(&messages, masks, codeword_length), - codeword_length, - vectors.len() * interleaving_depth, - ) -} - pub trait MatrixBufferOps { + type Witness; + fn len(&self) -> usize; fn num_rows(&self) -> usize; fn num_cols(&self) -> usize; @@ -84,13 +68,25 @@ pub trait MatrixBufferOps { &self, config: &matrix_commit::Config, prover_state: &mut ProverState, - ) -> matrix_commit::Witness + ) -> Self::Witness where F: TypeInfo + matrix_commit::Encodable + Send + Sync, H: DuplexSpongeInterface, R: RngCore + CryptoRng, Hash: ProverMessage<[H::U]>; + fn open_rows( + &self, + config: &matrix_commit::Config, + prover_state: &mut ProverState, + witness: &Self::Witness, + indices: &[usize], + ) where + F: TypeInfo + matrix_commit::Encodable + Send + Sync, + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + Hash: ProverMessage<[H::U]>; + fn read_rows(&self, indices: &[usize]) -> Vec; } @@ -133,6 +129,8 @@ impl MatrixBufferOps for CpuMatrix where F: Field + TypeInfo + matrix_commit::Encodable + Send + Sync, { + type Witness = matrix_commit::Witness; + fn len(&self) -> usize { self.data.len() } @@ -149,7 +147,7 @@ where &self, config: &matrix_commit::Config, prover_state: &mut ProverState, - ) -> matrix_commit::Witness + ) -> Self::Witness where H: DuplexSpongeInterface, R: RngCore + CryptoRng, @@ -160,6 +158,20 @@ where config.commit(prover_state, &self.data) } + fn open_rows( + &self, + config: &matrix_commit::Config, + prover_state: &mut ProverState, + witness: &Self::Witness, + indices: &[usize], + ) where + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + Hash: ProverMessage<[H::U]>, + { + config.open(prover_state, witness, indices); + } + fn read_rows(&self, indices: &[usize]) -> Vec { let mut rows = Vec::with_capacity(indices.len() * self.num_cols); for &index in indices { @@ -169,7 +181,18 @@ where rows } } -#[derive(Clone)] +#[derive( + Clone, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + Debug, + Default, + serde::Serialize, + serde::Deserialize, +)] pub struct CpuBuffer { data: Vec, } @@ -185,7 +208,7 @@ impl CpuBuffer { } } - fn as_slice(&self) -> &[F] { + pub(crate) fn as_slice(&self) -> &[F] { self.data.as_slice() } } @@ -263,73 +286,20 @@ impl BufferOps for CpuBuffer { } } -impl BufferOps for Vec { - type Buffer = Vec; - type Matrix = CpuMatrix; - - fn len(&self) -> usize { - self.len() - } - - fn random(rng: &mut R, length: usize) -> Self - where - R: RngCore + CryptoRng, - Standard: Distribution, - { - (0..length).map(|_| rng.gen()).collect() - } - - fn zero_pad(&mut self) { - if !self.is_empty() { - self.resize(self.len().next_power_of_two(), F::ZERO); - } - } - - fn mixed_extend, T: Field>( - &self, - embedding: &M, - point: &[M::Target], - ) -> M::Target { - mixed_multilinear_extend(embedding, self, point) - } - - fn mixed_dot, T: Field>( - &self, - embedding: &M, - other: &Self::Buffer, - ) -> M::Target { - mixed_dot(embedding, other, self) - } - - fn mixed_univariate_evaluate>( - &self, - embedding: &M, - point: M::Target, - ) -> M::Target { - mixed_univariate_evaluate(embedding, self, point) - } - - fn interleaved_rs_encode( - vectors: &[&Self], - masks: &Self, - message_length: usize, - interleaving_depth: usize, - codeword_length: usize, - ) -> Self::Matrix - where - F: 'static, - { - let vectors = vectors.iter().map(|v| v.as_slice()).collect::>(); - interleaved_rs_encode_slices( - &vectors, - masks, - message_length, - interleaving_depth, - codeword_length, - ) - } - - fn dot(&self, other: &Self) -> F { - self.mixed_dot(&Identity::new(), other) - } +fn interleaved_rs_encode_slices( + vectors: &[&[F]], + masks: &[F], + message_length: usize, + interleaving_depth: usize, + codeword_length: usize, +) -> CpuMatrix { + let messages = vectors + .iter() + .flat_map(|v| chunks_exact_or_empty(v, message_length, interleaving_depth)) + .collect::>(); + CpuMatrix::from_vec( + ntt::interleaved_rs_encode(&messages, masks, codeword_length), + codeword_length, + vectors.len() * interleaving_depth, + ) } diff --git a/src/algebra/mod.rs b/src/algebra/mod.rs index 2c7e4838..4d17dc3b 100644 --- a/src/algebra/mod.rs +++ b/src/algebra/mod.rs @@ -1,3 +1,4 @@ +pub mod buffer; pub mod embedding; pub mod fields; pub mod linear_form; diff --git a/src/bin/benchmark.rs b/src/bin/benchmark.rs index d9e79ce6..b8d4eeab 100644 --- a/src/bin/benchmark.rs +++ b/src/bin/benchmark.rs @@ -10,6 +10,7 @@ use clap::Parser; use serde::Serialize; use whir::{ algebra::{ + buffer::CpuBuffer, embedding::{Basefield, Embedding, Identity}, fields::{Field128, Field192, Field256, Field64, Field64_2, Field64_3}, linear_form::{Evaluate, LinearForm, MultilinearExtension}, @@ -161,7 +162,8 @@ where HASH_COUNTER.reset(); - let witness = params.commit(&mut prover_state, &[&vector]); + let vector_buffer = CpuBuffer::from_slice(&vector); + let witness = params.commit(&mut prover_state, &[&vector_buffer]); let _ = params.prove( &mut prover_state, @@ -238,7 +240,8 @@ where HASH_COUNTER.reset(); let whir_prover_time = Instant::now(); - let witness = params.commit(&mut prover_state, &[&vector]); + let vector_buffer = CpuBuffer::from_slice(&vector); + let witness = params.commit(&mut prover_state, &[&vector_buffer]); let prove_linear_forms: Vec>> = points .iter() diff --git a/src/bin/main.rs b/src/bin/main.rs index bafe16f5..a69ba1a0 100644 --- a/src/bin/main.rs +++ b/src/bin/main.rs @@ -6,6 +6,7 @@ use ark_std::rand::distributions::{Distribution, Standard}; use clap::Parser; use whir::{ algebra::{ + buffer::CpuBuffer, embedding::{Basefield, Embedding, Identity}, fields::{Field128, Field192, Field256, Field64, Field64_2, Field64_3}, linear_form::{Covector, Evaluate, LinearForm, MultilinearExtension}, @@ -148,9 +149,10 @@ where } let vector = (0..num_coeffs).map(M::Source::from).collect::>(); + let vector_buffer = CpuBuffer::from_slice(&vector); let whir_commit_time = Instant::now(); - let witness = params.commit(&mut prover_state, &[&vector]); + let witness = params.commit(&mut prover_state, &[&vector_buffer]); let whir_commit_time = whir_commit_time.elapsed(); // Allocate constraints @@ -314,7 +316,8 @@ where } let whir_commit_time = Instant::now(); - let witness = params.commit(&mut prover_state, &[vector.as_slice()]); + let vector_buffer = CpuBuffer::from_slice(&vector); + let witness = params.commit(&mut prover_state, &[&vector_buffer]); let whir_commit_time = whir_commit_time.elapsed(); let whir_prove_time = Instant::now(); diff --git a/src/protocols/basecase.rs b/src/protocols/basecase.rs index 32763271..eac97802 100644 --- a/src/protocols/basecase.rs +++ b/src/protocols/basecase.rs @@ -11,8 +11,8 @@ use spongefish::{Decoding, VerificationResult}; use crate::{ algebra::{ - dot, embedding::Identity, multilinear_extend, random_vector, scalar_mul_add_new, - univariate_evaluate, + buffer::CpuBuffer, dot, embedding::Identity, multilinear_extend, random_vector, + scalar_mul_add_new, univariate_evaluate, }, hash::Hash, protocols::{irs_commit, sumcheck}, @@ -79,7 +79,7 @@ impl Config { // Even more trivial non-zk protocol: send f and r directly. if !self.masked { prover_state.prover_messages(&vector); - prover_state.prover_messages(&witness.masks); + prover_state.prover_messages(witness.masks.as_slice()); let _ = self.commit.open(prover_state, &[witness]); let point = self .sumcheck @@ -94,9 +94,10 @@ impl Config { // Create masking vector. let mask = random_vector(prover_state.rng(), vector.len()); + let mask_buffer = CpuBuffer::from_slice(&mask); // Commit to the masking vector. - let mask_witness = self.commit.commit(prover_state, &[&mask]); + let mask_witness = self.commit.commit(prover_state, &[&mask_buffer]); // Compute and send linear form of mask (μ' in paper). let mask_sum = dot(&mask, &covector); @@ -109,7 +110,11 @@ impl Config { prover_state.prover_messages(&masked_vector); // Send combined IRS randomness. (r^* in paper) - let masked_masks = scalar_mul_add_new(&mask_witness.masks, mask_rlc, &witness.masks); + let masked_masks = scalar_mul_add_new( + mask_witness.masks.as_slice(), + mask_rlc, + witness.masks.as_slice(), + ); prover_state.prover_messages(&masked_masks); // Open the commitment and mask simultaneously. @@ -284,7 +289,8 @@ mod tests { // Prover let mut prover_state = ProverState::new_std(&ds); - let witness = config.commit.commit(&mut prover_state, &[&vector]); + let vector_buffer = CpuBuffer::from_slice(&vector); + let witness = config.commit.commit(&mut prover_state, &[&vector_buffer]); let prover_result = config.prove( &mut prover_state, vector.clone(), diff --git a/src/protocols/code_switch.rs b/src/protocols/code_switch.rs index 62bbb602..9da26424 100644 --- a/src/protocols/code_switch.rs +++ b/src/protocols/code_switch.rs @@ -13,6 +13,7 @@ use tracing::instrument; use crate::{ algebra::{ + buffer::CpuBuffer, dot, embedding::{Embedding, Identity}, eq_weights, geometric_accumulate, lift, mixed_dot, scalar_mul, univariate_evaluate, @@ -195,7 +196,8 @@ impl Config { }; // Step 1: g := Enc_{C'}(f, r') — Construction 9.7 Step 1, p.55 - let target_witness = self.target.commit(prover_state, &[&message]); + let message_buffer = CpuBuffer::from_slice(&message); + let target_witness = self.target.commit(prover_state, &[&message_buffer]); // Step 2-3: OOD challenge + answers — Construction 9.7 Steps 2-3, p.55 // y := ze_ood(ρ) · [f; r; s] = f(α) + α^ℓ · (r,s)(α) @@ -484,7 +486,7 @@ mod tests { } // Lift ι parallel masks (total length source.mask_length × ι) and fold // chunks of length source.mask_length down to a single chunk. - let raw = lift(config.source.embedding(), &source_witness.masks); + let raw = lift(config.source.embedding(), source_witness.masks.as_slice()); let mut mask = fold_chunks(&raw, config.source.mask_length, folding_randomness); // Append fresh padding s of length message_mask_length - source.mask_length. mask.extend(random_vector::( @@ -521,7 +523,8 @@ mod tests { .session(&format!("Test at {}:{}", file!(), line!())) .instance(&instance); let mut prover_state = ProverState::new_std(&ds); - let source_witness = config.source.commit(&mut prover_state, &[&f_full]); + let f_full_buffer = CpuBuffer::from_slice(&f_full); + let source_witness = config.source.commit(&mut prover_state, &[&f_full_buffer]); // Sample γ for sumcheck folding (length log2(ι)). let folding_randomness = sample_folding_randomness(config, &mut rng); @@ -574,7 +577,8 @@ mod tests { .session(&format!("Test at {}:{}", file!(), line!())) .instance(&instance); let mut prover_state = ProverState::new_std(&ds); - let source_witness = config.source.commit(&mut prover_state, &[&f_full]); + let f_full_buffer = CpuBuffer::from_slice(&f_full); + let source_witness = config.source.commit(&mut prover_state, &[&f_full_buffer]); let folding_randomness = sample_folding_randomness(config, &mut rng); let folded_message = @@ -642,7 +646,8 @@ mod tests { // Commit honest f_full, fold to get the honest post-fold message. let mut prover_state = ProverState::new_std(&ds); - let source_witness = config.source.commit(&mut prover_state, &[&f_full]); + let f_full_buffer = CpuBuffer::from_slice(&f_full); + let source_witness = config.source.commit(&mut prover_state, &[&f_full_buffer]); let folding_randomness = sample_folding_randomness(config, &mut rng); let folded_message = fold_chunks(&f_full, config.source.message_length(), &folding_randomness); diff --git a/src/protocols/irs_commit.rs b/src/protocols/irs_commit.rs index b5716f3a..1d9e20a8 100644 --- a/src/protocols/irs_commit.rs +++ b/src/protocols/irs_commit.rs @@ -20,6 +20,7 @@ use std::{ f64::{self, consts::LOG2_10}, fmt, + marker::PhantomData, ops::Neg, }; @@ -32,8 +33,13 @@ use tracing::instrument; use crate::{ algebra::{ - dot, embedding::Embedding, fields::FieldWithSize, lift, linear_form::UnivariateEvaluation, - mixed_univariate_evaluate, ntt, random_vector, + buffer::{BufferOps, CpuBuffer, CpuMatrix, MatrixBufferOps}, + dot, + embedding::Embedding, + fields::FieldWithSize, + lift, + linear_form::UnivariateEvaluation, + ntt, }, engines::EngineId, hash::Hash, @@ -43,7 +49,7 @@ use crate::{ VerifierMessage, VerifierState, }, type_info::Typed, - utils::{chunks_exact_or_empty, zip_strict}, + utils::zip_strict, verify, }; @@ -92,14 +98,17 @@ pub struct Config { #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Default, Serialize, Deserialize)] #[must_use] -pub struct Witness +pub struct Witness, Matrix = CpuMatrix> where G: Field, + Masks: BufferOps, + Matrix: MatrixBufferOps, { - pub masks: Vec, - pub matrix: Vec, - pub matrix_witness: matrix_commit::Witness, + pub masks: Masks, + pub matrix: Matrix, + pub matrix_witness: Matrix::Witness, pub out_of_domain: Evaluations, + source_field: PhantomData, } #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Default, Serialize, Deserialize)] @@ -289,12 +298,15 @@ impl Config { /// Commit to one or more vectors. #[cfg_attr(feature = "tracing", instrument(skip_all, fields(self = %self)))] - pub fn commit( + pub fn commit( &self, prover_state: &mut ProverState, - vectors: &[&[M::Source]], - ) -> Witness + vectors: &[&B], + ) -> Witness where + B: BufferOps, + >::Witness: + Clone + PartialEq + Eq + PartialOrd + Ord + fmt::Debug + std::hash::Hash + Default, Standard: Distribution, H: DuplexSpongeInterface, R: RngCore + CryptoRng, @@ -310,18 +322,17 @@ impl Config { assert_eq!(vectors.len(), self.num_vectors); assert!(vectors.iter().all(|p| p.len() == self.vector_size)); - // Generate random mask - let masks = random_vector(prover_state.rng(), self.mask_length * self.num_messages()); - - // Interleaved RS Encode the vectors - let messages = vectors - .iter() - .flat_map(|v| chunks_exact_or_empty(v, self.message_length(), self.interleaving_depth)) - .collect::>(); - let matrix = ntt::interleaved_rs_encode(&messages, &masks, self.codeword_length); + let masks = B::random(prover_state.rng(), self.mask_length * self.num_messages()); + let matrix = B::interleaved_rs_encode( + vectors, + &masks, + self.message_length(), + self.interleaving_depth, + self.codeword_length, + ); // Commit to the matrix - let matrix_witness = self.matrix_commit.commit(prover_state, &matrix); + let matrix_witness = matrix.commit_rows(&self.matrix_commit, prover_state); // Handle out-of-domain points and values // TODO : Remove this logic after main whir protocol is updated @@ -332,7 +343,7 @@ impl Config { let mut oods_matrix = Vec::with_capacity(self.out_domain_samples * self.num_vectors); for &point in &oods_points { for &vector in vectors { - let value = mixed_univariate_evaluate(&*self.embedding, vector, point); + let value = vector.mixed_univariate_evaluate(&*self.embedding, point); prover_state.prover_message(&value); oods_matrix.push(value); } @@ -346,6 +357,7 @@ impl Config { points: oods_points, matrix: oods_matrix, }, + source_field: PhantomData, } } @@ -382,12 +394,14 @@ impl Config { /// When there are multiple openings, the evaluation matrices will /// be horizontally concatenated. #[cfg_attr(feature = "tracing", instrument(skip_all, fields(self = %self)))] - pub fn open( + pub fn open( &self, prover_state: &mut ProverState, - witnesses: &[&Witness], + witnesses: &[&Witness], ) -> Evaluations where + Masks: BufferOps, + Matrix: MatrixBufferOps, H: DuplexSpongeInterface, R: RngCore + CryptoRng, u8: Decoding<[H::U]>, @@ -409,22 +423,23 @@ impl Config { // and collect them in the evaluation matrix. let stride = witnesses.len() * self.num_cols(); let mut matrix = vec![M::Source::ZERO; indices.len() * stride]; - let mut submatrix = Vec::with_capacity(indices.len() * self.num_cols()); let mut matrix_col_offset = 0; for witness in witnesses { - submatrix.clear(); - for (point_index, &code_index) in indices.iter().enumerate() { - let row = &witness.matrix - [code_index * self.num_cols()..(code_index + 1) * self.num_cols()]; - submatrix.extend_from_slice(row); - - let matrix_row = &mut matrix[point_index * stride..(point_index + 1) * stride]; - matrix_row[matrix_col_offset..matrix_col_offset + self.num_cols()] - .copy_from_slice(row); + let submatrix = witness.matrix.read_rows(&indices); + if self.num_cols() != 0 { + for (point_index, row) in submatrix.chunks_exact(self.num_cols()).enumerate() { + let matrix_row = &mut matrix[point_index * stride..(point_index + 1) * stride]; + matrix_row[matrix_col_offset..matrix_col_offset + self.num_cols()] + .copy_from_slice(row); + } } prover_state.prover_hint_ark(&submatrix); - self.matrix_commit - .open(prover_state, &witness.matrix_witness, &indices); + witness.matrix.open_rows( + &self.matrix_commit, + prover_state, + &witness.matrix_witness, + &indices, + ); matrix_col_offset += self.num_cols(); } @@ -511,7 +526,13 @@ impl Commitment { } } -impl Witness { +impl Witness +where + F: Field, + G: Field, + Masks: BufferOps, + Matrix: MatrixBufferOps, +{ /// Returns the out-of-domain evaluations. pub const fn out_of_domain(&self) -> &Evaluations { &self.out_of_domain @@ -642,7 +663,7 @@ pub(crate) mod tests { use crate::{ algebra::{ embedding::{Basefield, Compose, Frobenius, Identity}, - fields, + fields, mixed_univariate_evaluate, ntt::NTT, random_vector, univariate_evaluate, }, @@ -728,10 +749,12 @@ pub(crate) mod tests { // Prover let mut prover_state = ProverState::new_std(&ds); - let witness = config.commit( - &mut prover_state, - &vectors.iter().map(|p| p.as_slice()).collect::>(), - ); + let vector_buffers = vectors + .iter() + .map(|v| CpuBuffer::from_slice(v)) + .collect::>(); + let vector_refs = vector_buffers.iter().collect::>(); + let witness = config.commit(&mut prover_state, &vector_refs); assert_eq!( witness.out_of_domain().points.len(), config.out_domain_samples diff --git a/src/protocols/mask_proximity.rs b/src/protocols/mask_proximity.rs index 1d632ca0..deb49de5 100644 --- a/src/protocols/mask_proximity.rs +++ b/src/protocols/mask_proximity.rs @@ -45,7 +45,10 @@ use ark_std::rand::{distributions::Standard, prelude::Distribution, CryptoRng, R use serde::{Deserialize, Serialize}; use crate::{ - algebra::{embedding::Identity, random_vector, scalar_mul_add_new, univariate_evaluate}, + algebra::{ + buffer::CpuBuffer, embedding::Identity, random_vector, scalar_mul_add_new, + univariate_evaluate, + }, hash::Hash, protocols::irs_commit::{ Commitment as IrsCommitment, Config as IrsConfig, Witness as IrsWitness, @@ -129,11 +132,19 @@ impl Config { .map(|_| random_vector(prover_state.rng(), self.c_zk_commit.vector_size)) .collect(); + let original_buffers = original_msgs + .iter() + .map(|msg| CpuBuffer::from_slice(msg)) + .collect::>(); + let fresh_buffers = fresh_msgs + .iter() + .map(|msg| CpuBuffer::from_slice(msg)) + .collect::>(); + // Tree layout: [originals..., freshes...] - let all_vectors: Vec<&[F]> = original_msgs + let all_vectors: Vec<&CpuBuffer> = original_buffers .iter() - .chain(fresh_msgs.iter()) - .map(|v| v.as_slice()) + .chain(fresh_buffers.iter()) .collect(); let mask_witness = self.c_zk_commit.commit(prover_state, &all_vectors); @@ -180,10 +191,8 @@ impl Config { // Step 2: compute and send combined polynomials + IRS randomness let irs_masks_per_vector = self.c_zk_commit.mask_length * self.c_zk_commit.interleaving_depth; - assert_eq!( - witness.mask_witness.masks.len(), - 2 * self.num_masks * irs_masks_per_vector - ); + let irs_masks = witness.mask_witness.masks.as_slice(); + assert_eq!(irs_masks.len(), 2 * self.num_masks * irs_masks_per_vector); for (i, (orig_msg, fresh_msg)) in original_msgs .iter() .zip(witness.fresh_msgs.iter()) @@ -195,10 +204,8 @@ impl Config { // r*_i = r'_i + γ · r_i if irs_masks_per_vector > 0 { - let orig_r = &witness.mask_witness.masks - [i * irs_masks_per_vector..(i + 1) * irs_masks_per_vector]; - let fresh_r = &witness.mask_witness.masks[(self.num_masks + i) - * irs_masks_per_vector + let orig_r = &irs_masks[i * irs_masks_per_vector..(i + 1) * irs_masks_per_vector]; + let fresh_r = &irs_masks[(self.num_masks + i) * irs_masks_per_vector ..(self.num_masks + i + 1) * irs_masks_per_vector]; let combined_r = scalar_mul_add_new(fresh_r, gamma, orig_r); prover_state.prover_messages(&combined_r); @@ -455,6 +462,7 @@ mod tests { let gamma: F = prover_state.verifier_message(); let irs_masks_per_vector = config.c_zk_commit.mask_length * config.c_zk_commit.interleaving_depth; + let irs_masks = witness.mask_witness.masks.as_slice(); for (i, (orig_msg, fresh_msg)) in original_msgs .iter() @@ -468,10 +476,8 @@ mod tests { prover_state.prover_messages(&combined_msg); if irs_masks_per_vector > 0 { - let orig_r = &witness.mask_witness.masks - [i * irs_masks_per_vector..(i + 1) * irs_masks_per_vector]; - let fresh_r = &witness.mask_witness.masks[(config.num_masks + i) - * irs_masks_per_vector + let orig_r = &irs_masks[i * irs_masks_per_vector..(i + 1) * irs_masks_per_vector]; + let fresh_r = &irs_masks[(config.num_masks + i) * irs_masks_per_vector ..(config.num_masks + i + 1) * irs_masks_per_vector]; let combined_r = scalar_mul_add_new(fresh_r, gamma, orig_r); prover_state.prover_messages(&combined_r); diff --git a/src/protocols/whir/mod.rs b/src/protocols/whir/mod.rs index 805a206e..9c6a31bb 100644 --- a/src/protocols/whir/mod.rs +++ b/src/protocols/whir/mod.rs @@ -14,6 +14,7 @@ use tracing::instrument; use crate::{ algebra::{ + buffer::CpuBuffer, embedding::{Embedding, Identity}, linear_form::LinearForm, }, @@ -79,7 +80,7 @@ impl Config { pub fn commit( &self, prover_state: &mut ProverState, - vectors: &[&[M::Source]], + vectors: &[&CpuBuffer], ) -> Witness where Standard: Distribution, @@ -128,6 +129,7 @@ mod tests { use super::*; use crate::{ algebra::{ + buffer::CpuBuffer, embedding::Basefield, fields::{Field64, Field64_3}, linear_form::{Covector, Evaluate, LinearForm, MultilinearExtension}, @@ -239,7 +241,8 @@ mod tests { let mut prover_state = ProverState::new_std(&ds); // Commit to the polynomial and generate auxiliary witness data - let witness = params.commit(&mut prover_state, &[&vector]); + let vector_buffer = CpuBuffer::from_slice(&vector); + let witness = params.commit(&mut prover_state, &[&vector_buffer]); let prove_linear_forms = build_prove_forms(&points, num_variables, true); @@ -380,7 +383,7 @@ mod tests { vec![F::from((i + 1) as u64); num_coeffs] }) .collect(); - let vec_refs = vectors.iter().map(|v| v.as_slice()).collect::>(); + let vec_refs = vectors.iter().collect::>(); let points: Vec<_> = (0..num_points_per_poly) .map(|_| random_vector(thread_rng(), num_variables)) @@ -401,7 +404,7 @@ mod tests { .flat_map(|linear_form| { vec_refs .iter() - .map(|vec| linear_form.evaluate(params.embedding(), vec)) + .map(|&vec| linear_form.evaluate(params.embedding(), vec)) }) .collect::>(); @@ -412,8 +415,12 @@ mod tests { let mut prover_state = ProverState::new_std(&ds); // Commit to each polynomial and generate witnesses + let vector_buffers = vectors + .iter() + .map(|v| CpuBuffer::from_slice(v)) + .collect::>(); let mut witnesses = Vec::new(); - for &vec in &vec_refs { + for vec in &vector_buffers { let witness = params.commit(&mut prover_state, &[vec]); witnesses.push(witness); } @@ -569,8 +576,10 @@ mod tests { .instance(&Empty); let mut prover_state = ProverState::new_std(&ds); - let witness1 = params.commit(&mut prover_state, &[&vec1]); - let witness2 = params.commit(&mut prover_state, &[&vec2]); + let vec1_buffer = CpuBuffer::from_slice(&vec1); + let vec2_buffer = CpuBuffer::from_slice(&vec2); + let witness1 = params.commit(&mut prover_state, &[&vec1_buffer]); + let witness2 = params.commit(&mut prover_state, &[&vec2_buffer]); let prove_linear_forms = build_prove_forms(&constraint_points, num_variables, false); @@ -651,7 +660,7 @@ mod tests { let all_vectors: Vec> = (0..num_witnesses * batch_size) .map(|i| vec![F::from((i + 1) as u64); num_coeffs]) .collect::>(); - let vec_refs = all_vectors.iter().map(|p| p.as_slice()).collect::>(); + let vec_refs = all_vectors.iter().collect::>(); let points: Vec<_> = (0..num_points_per_poly) .map(|_| random_vector(thread_rng(), num_variables)) @@ -672,7 +681,7 @@ mod tests { .flat_map(|linear_form| { vec_refs .iter() - .map(|vec| linear_form.evaluate(params.embedding(), vec)) + .map(|&vec| linear_form.evaluate(params.embedding(), vec)) }) .collect::>(); @@ -683,8 +692,13 @@ mod tests { let mut prover_state = ProverState::new_std(&ds); // Commit using commit_batch (stacks batch_size polynomials per witness) + let vector_buffers = all_vectors + .iter() + .map(|v| CpuBuffer::from_slice(v)) + .collect::>(); + let buffer_refs = vector_buffers.iter().collect::>(); let mut witnesses = Vec::new(); - for witness_polys in vec_refs.chunks(batch_size) { + for witness_polys in buffer_refs.chunks(batch_size) { let witness = params.commit(&mut prover_state, witness_polys); witnesses.push(witness); } @@ -793,7 +807,7 @@ mod tests { let vectors: Vec> = (0..batch_size) .map(|_| random_vector(thread_rng(), num_coeffs)) .collect(); - let vec_refs = vectors.iter().map(|v| v.as_slice()).collect::>(); + let vec_refs = vectors.iter().collect::>(); // Generate `num_points` random points in the multilinear domain let points: Vec<_> = (0..num_points) @@ -809,7 +823,12 @@ mod tests { let mut prover_state = ProverState::new_std(&ds); // Create a commitment to the polynomial and generate auxiliary witness data - let batched_witness = params.commit(&mut prover_state, &vec_refs); + let vector_buffers = vectors + .iter() + .map(|v| CpuBuffer::from_slice(v)) + .collect::>(); + let buffer_refs = vector_buffers.iter().collect::>(); + let batched_witness = params.commit(&mut prover_state, &buffer_refs); // Create a weights matrix and evaluations for each polynomial let mut linear_forms: Vec>>> = Vec::new(); diff --git a/src/protocols/whir/prover.rs b/src/protocols/whir/prover.rs index 04b4473e..e3fc2043 100644 --- a/src/protocols/whir/prover.rs +++ b/src/protocols/whir/prover.rs @@ -8,6 +8,7 @@ use tracing::instrument; use super::{Config, Witness}; use crate::{ algebra::{ + buffer::CpuBuffer, dot, embedding::Embedding, eq_weights, lift, @@ -220,7 +221,10 @@ impl Config { // Execute standard WHIR rounds on the batched vectors for (round_index, round_config) in self.round_configs.iter().enumerate() { // Commit to the vector, this generates out-of-domain evaluations. - let new_witness = round_config.irs_committer.commit(prover_state, &[&vector]); + let vector_buffer = CpuBuffer::from_slice(&vector); + let new_witness = round_config + .irs_committer + .commit(prover_state, &[&vector_buffer]); // Proof of work before in-domain challenges round_config.pow.prove(prover_state); diff --git a/src/protocols/whir_zk/committer.rs b/src/protocols/whir_zk/committer.rs index 942b87e1..01b7a863 100644 --- a/src/protocols/whir_zk/committer.rs +++ b/src/protocols/whir_zk/committer.rs @@ -5,6 +5,7 @@ use tracing::instrument; use super::{utils::BlindingPolynomials, Config}; use crate::{ + algebra::buffer::CpuBuffer, hash::Hash, protocols::{irs_commit, whir}, transcript::{ @@ -43,7 +44,7 @@ impl Config { pub fn commit( &self, prover_state: &mut ProverState, - polynomials: &[&[F]], + polynomials: &[&CpuBuffer], ) -> Witness where Standard: Distribution, @@ -63,6 +64,7 @@ impl Config { let num_blinding_variables = self.num_blinding_variables(); let num_witness_variables = self.num_witness_variables(); for &poly in polynomials { + let poly = poly.as_slice(); let blinding = BlindingPolynomials::sample( prover_state.rng(), num_blinding_variables, @@ -86,9 +88,10 @@ impl Config { .zip(mask.iter().cycle()) .map(|(&coeff, &m)| coeff + m) .collect::>(); + let f_hat_buffer = CpuBuffer::from_slice(&f_hat_vec); let witness = self .blinded_commitment - .commit(prover_state, &[f_hat_vec.as_slice()]); + .commit(prover_state, &[&f_hat_buffer]); f_hat_vectors.push(f_hat_vec); f_hat_witnesses.push(witness); blinding_polynomials.push(blinding); @@ -110,10 +113,11 @@ impl Config { ); blinding_vectors.extend(layout); } - let blinding_vector_refs = blinding_vectors + let blinding_buffers = blinding_vectors .iter() - .map(Vec::as_slice) + .map(|v| CpuBuffer::from_slice(v)) .collect::>(); + let blinding_vector_refs = blinding_buffers.iter().collect::>(); let blinding_witness = self .blinding_commitment .commit(prover_state, &blinding_vector_refs); diff --git a/src/protocols/whir_zk/mod.rs b/src/protocols/whir_zk/mod.rs index b5e41c6e..c27ca37b 100644 --- a/src/protocols/whir_zk/mod.rs +++ b/src/protocols/whir_zk/mod.rs @@ -249,6 +249,7 @@ mod tests { use super::*; use crate::{ algebra::{ + buffer::CpuBuffer, fields::Field64, linear_form::{Covector, Evaluate, LinearForm, MultilinearExtension}, random_vector, @@ -350,7 +351,12 @@ mod tests { .session(&tag) .instance(&Empty); let mut prover_state = ProverState::new_std(&ds); - let witness = params.commit(&mut prover_state, vectors); + let vector_buffers = vectors + .iter() + .map(|v| CpuBuffer::from_slice(v)) + .collect::>(); + let vector_refs = vector_buffers.iter().collect::>(); + let witness = params.commit(&mut prover_state, &vector_refs); let _ = params.prove( &mut prover_state, vectors @@ -442,7 +448,12 @@ mod tests { .session(&format!("zk-stage1-negative {}:{}", file!(), line!())) .instance(&Empty); let mut prover_state = ProverState::new_std(&ds); - let witness = params.commit(&mut prover_state, &vectors); + let vector_buffers = vectors + .iter() + .map(|v| CpuBuffer::from_slice(v)) + .collect::>(); + let vector_refs = vector_buffers.iter().collect::>(); + let witness = params.commit(&mut prover_state, &vector_refs); let _ = params.prove( &mut prover_state, vectors @@ -496,7 +507,12 @@ mod tests { .session(&format!("zk-stage1-tamper {}:{}", file!(), line!())) .instance(&Empty); let mut prover_state = ProverState::new_std(&ds); - let witness = params.commit(&mut prover_state, &vectors); + let vector_buffers = vectors + .iter() + .map(|v| CpuBuffer::from_slice(v)) + .collect::>(); + let vector_refs = vector_buffers.iter().collect::>(); + let witness = params.commit(&mut prover_state, &vector_refs); let _ = params.prove( &mut prover_state, vectors @@ -557,7 +573,8 @@ mod tests { let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { let mut prover_state = ProverState::new_std(&ds); - let witness = params.commit(&mut prover_state, &[&vector]); + let vector_buffer = CpuBuffer::from_slice(&vector); + let witness = params.commit(&mut prover_state, &[&vector_buffer]); let _ = params.prove( &mut prover_state, vec![Cow::Borrowed(&vector)], From ee1b207b7e0301f9a2a03048abeef1d78db690d9 Mon Sep 17 00:00:00 2001 From: zkfriendly Date: Wed, 3 Jun 2026 17:13:47 +0200 Subject: [PATCH 4/5] wip: integrate buffer abstraction in prove and sumcheck --- src/algebra/buffer.rs | 139 +++++++++++++++++++++++++++++++++- src/bin/benchmark.rs | 8 +- src/bin/main.rs | 8 +- src/protocols/basecase.rs | 33 +++++--- src/protocols/sumcheck.rs | 61 ++++++++------- src/protocols/whir/mod.rs | 54 +++++++------ src/protocols/whir/prover.rs | 143 +++++++++++++++++------------------ 7 files changed, 300 insertions(+), 146 deletions(-) diff --git a/src/algebra/buffer.rs b/src/algebra/buffer.rs index e282707a..c6c8fb2b 100644 --- a/src/algebra/buffer.rs +++ b/src/algebra/buffer.rs @@ -1,3 +1,5 @@ +use std::{any::Any, mem}; + use ark_ff::Field; use ark_std::rand::{distributions::Standard, prelude::Distribution, CryptoRng, Rng, RngCore}; use spongefish::DuplexSpongeInterface; @@ -5,7 +7,9 @@ use spongefish::DuplexSpongeInterface; use crate::{ algebra::{ embedding::{Embedding, Identity}, - mixed_dot, mixed_multilinear_extend, mixed_univariate_evaluate, ntt, + linear_form::{Covector, LinearForm, UnivariateEvaluation}, + mixed_dot, mixed_multilinear_extend, mixed_scalar_mul_add, mixed_univariate_evaluate, ntt, + sumcheck::{compute_sumcheck_polynomial, fold, fold_and_compute_polynomial}, }, hash::Hash, protocols::matrix_commit, @@ -18,6 +22,8 @@ pub trait BufferOps: Clone { type Buffer: BufferOps; type Matrix: MatrixBufferOps; + fn from_vec(source: Vec) -> Self; + fn zeros(length: usize) -> Self; fn len(&self) -> usize; fn is_empty(&self) -> bool { @@ -29,7 +35,26 @@ pub trait BufferOps: Clone { R: RngCore + CryptoRng, Standard: Distribution; + fn linear_forms_rlc( + size: usize, + linear_forms: &mut [Box>], + rlc_coeffs: &[F], + ) -> Self; + fn zero_pad(&mut self); + fn fold(&mut self, weight: F); + fn sumcheck_polynomial(&self, other: &Self) -> (F, F); + fn fold_and_sumcheck_polynomial(&mut self, other: &mut Self, weight: F) -> (F, F); + fn accumulate_univariate_evaluations( + &mut self, + evaluators: &[UnivariateEvaluation], + scalars: &[F], + ); + fn write_to_prover(&self, prover_state: &mut ProverState) + where + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + F: ProverMessage<[H::U]>; fn mixed_extend, T: Field>( &self, embedding: &M, @@ -45,6 +70,22 @@ pub trait BufferOps: Clone { embedding: &M, point: M::Target, ) -> M::Target; + fn mixed_linear_combination>( + embedding: &M, + vectors: &[&Self], + coeffs: &[M::Target], + ) -> Self::Buffer; + fn mixed_scalar_mul_add_to>( + &self, + embedding: &M, + accumulator: &mut Self::Buffer, + weight: M::Target, + ); + fn mixed_dot_slice>( + &self, + embedding: &M, + other: &[M::Target], + ) -> M::Target; fn interleaved_rs_encode( vectors: &[&Self], masks: &Self, @@ -217,6 +258,16 @@ impl BufferOps for CpuBuffer { type Buffer = CpuBuffer; type Matrix = CpuMatrix; + fn from_vec(source: Vec) -> Self { + Self::from_vec(source) + } + + fn zeros(length: usize) -> Self { + Self { + data: vec![F::ZERO; length], + } + } + fn len(&self) -> usize { self.data.len() } @@ -231,12 +282,64 @@ impl BufferOps for CpuBuffer { } } + fn linear_forms_rlc( + size: usize, + linear_forms: &mut [Box>], + rlc_coeffs: &[F], + ) -> Self { + assert_eq!(linear_forms.len(), rlc_coeffs.len()); + let mut covector = vec![F::ZERO; size]; + if let Some((first, linear_forms)) = linear_forms.split_first_mut() { + debug_assert_eq!(rlc_coeffs[0], F::ONE); + if let Some(covector_form) = + (first.as_mut() as &mut dyn Any).downcast_mut::>() + { + mem::swap(&mut covector, &mut covector_form.vector); + } else { + first.accumulate(&mut covector, F::ONE); + } + for (rlc_coeff, linear_form) in rlc_coeffs[1..].iter().zip(linear_forms) { + linear_form.accumulate(&mut covector, *rlc_coeff); + } + } + Self { data: covector } + } + fn zero_pad(&mut self) { if !self.is_empty() { self.data.resize(self.len().next_power_of_two(), F::ZERO); } } + fn fold(&mut self, weight: F) { + fold(&mut self.data, weight); + } + + fn sumcheck_polynomial(&self, other: &Self) -> (F, F) { + compute_sumcheck_polynomial(&self.data, other.as_slice()) + } + + fn fold_and_sumcheck_polynomial(&mut self, other: &mut Self, weight: F) -> (F, F) { + fold_and_compute_polynomial(&mut self.data, &mut other.data, weight) + } + + fn accumulate_univariate_evaluations( + &mut self, + evaluators: &[UnivariateEvaluation], + scalars: &[F], + ) { + UnivariateEvaluation::accumulate_many(evaluators, &mut self.data, scalars); + } + + fn write_to_prover(&self, prover_state: &mut ProverState) + where + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + F: ProverMessage<[H::U]>, + { + prover_state.prover_messages(&self.data); + } + fn mixed_extend, T: Field>( &self, embedding: &M, @@ -265,6 +368,40 @@ impl BufferOps for CpuBuffer { mixed_univariate_evaluate(embedding, &self.data, point) } + fn mixed_linear_combination>( + embedding: &M, + vectors: &[&Self], + coeffs: &[M::Target], + ) -> Self::Buffer { + assert_eq!(vectors.len(), coeffs.len()); + let Some((first, vectors)) = vectors.split_first() else { + return CpuBuffer::from_vec(Vec::new()); + }; + debug_assert_eq!(coeffs[0], M::Target::ONE); + let mut accumulator = crate::algebra::lift(embedding, first.as_slice()); + for (coeff, vector) in coeffs[1..].iter().zip(vectors) { + mixed_scalar_mul_add(embedding, &mut accumulator, *coeff, vector.as_slice()); + } + CpuBuffer::from_vec(accumulator) + } + + fn mixed_scalar_mul_add_to>( + &self, + embedding: &M, + accumulator: &mut Self::Buffer, + weight: M::Target, + ) { + mixed_scalar_mul_add(embedding, &mut accumulator.data, weight, self.as_slice()); + } + + fn mixed_dot_slice>( + &self, + embedding: &M, + other: &[M::Target], + ) -> M::Target { + mixed_dot(embedding, other, self.as_slice()) + } + fn interleaved_rs_encode( vectors: &[&Self], masks: &Self, diff --git a/src/bin/benchmark.rs b/src/bin/benchmark.rs index b8d4eeab..42e654b0 100644 --- a/src/bin/benchmark.rs +++ b/src/bin/benchmark.rs @@ -167,8 +167,8 @@ where let _ = params.prove( &mut prover_state, - vec![Cow::Borrowed(vector.as_slice())], - vec![Cow::Owned(witness)], + &[&vector_buffer], + vec![&witness], vec![], Cow::Owned(vec![]), ); @@ -252,8 +252,8 @@ where let _ = params.prove( &mut prover_state, - vec![Cow::Borrowed(vector.as_slice())], - vec![Cow::Owned(witness)], + &[&vector_buffer], + vec![&witness], prove_linear_forms, Cow::Borrowed(evaluations.as_slice()), ); diff --git a/src/bin/main.rs b/src/bin/main.rs index a69ba1a0..c63ada92 100644 --- a/src/bin/main.rs +++ b/src/bin/main.rs @@ -185,8 +185,8 @@ where let whir_prove_time = Instant::now(); let _ = params.prove( &mut prover_state, - vec![Cow::Borrowed(vector.as_slice())], - vec![Cow::Owned(witness)], + &[&vector_buffer], + vec![&witness], prove_linear_forms, Cow::Borrowed(evaluations.as_slice()), ); @@ -323,8 +323,8 @@ where let whir_prove_time = Instant::now(); let _ = params.prove( &mut prover_state, - vec![Cow::Borrowed(&vector)], - witness, + &[&vector_buffer], + vec![&witness], prove_linear_forms, Cow::Borrowed(&evaluations), ); diff --git a/src/protocols/basecase.rs b/src/protocols/basecase.rs index eac97802..e6faa026 100644 --- a/src/protocols/basecase.rs +++ b/src/protocols/basecase.rs @@ -49,9 +49,9 @@ impl Config { pub fn prove( &self, prover_state: &mut ProverState, - mut vector: Vec, + vector: Vec, witness: &irs_commit::Witness, - mut covector: Vec, + covector: Vec, mut sum: F, ) -> Opening where @@ -81,14 +81,22 @@ impl Config { prover_state.prover_messages(&vector); prover_state.prover_messages(witness.masks.as_slice()); let _ = self.commit.open(prover_state, &[witness]); + let mut vector_buffer = CpuBuffer::from_vec(vector); + let mut covector_buffer = CpuBuffer::from_vec(covector); let point = self .sumcheck - .prove(prover_state, &mut vector, &mut covector, &mut sum, &[]) + .prove( + prover_state, + &mut vector_buffer, + &mut covector_buffer, + &mut sum, + &[], + ) .round_challenges; - assert!(!vector[0].is_zero(), "Proof failed"); + assert!(!vector_buffer.as_slice()[0].is_zero(), "Proof failed"); return Opening { evaluation_points: point, - linear_form_evaluation: covector[0], + linear_form_evaluation: covector_buffer.as_slice()[0], }; } @@ -106,7 +114,7 @@ impl Config { // RLC the mask with the vector let mask_rlc = prover_state.verifier_message::(); assert!(!mask_rlc.is_zero(), "Proof failed"); - let mut masked_vector = scalar_mul_add_new(&mask, mask_rlc, &vector); + let masked_vector = scalar_mul_add_new(&mask, mask_rlc, &vector); prover_state.prover_messages(&masked_vector); // Send combined IRS randomness. (r^* in paper) @@ -122,12 +130,14 @@ impl Config { // Run sumcheck to reduce linear form claim let mut masked_sum = mask_sum + mask_rlc * sum; + let mut masked_vector_buffer = CpuBuffer::from_vec(masked_vector); + let mut covector_buffer = CpuBuffer::from_vec(covector); let point = self .sumcheck .prove( prover_state, - &mut masked_vector, - &mut covector, + &mut masked_vector_buffer, + &mut covector_buffer, &mut masked_sum, &[], ) @@ -137,11 +147,14 @@ impl Config { // Basically the sumcheck equation has degenerated to 0 * l(r) = 0, which provides // no constraints on l(r) that the verifier can return. // This event is cryptographically unlikely as `F` is challenge sized. - assert!(!masked_vector[0].is_zero(), "Proof failed"); + assert!( + !masked_vector_buffer.as_slice()[0].is_zero(), + "Proof failed" + ); Opening { evaluation_points: point, - linear_form_evaluation: covector[0], + linear_form_evaluation: covector_buffer.as_slice()[0], } } diff --git a/src/protocols/sumcheck.rs b/src/protocols/sumcheck.rs index 5018e22d..13c0c255 100644 --- a/src/protocols/sumcheck.rs +++ b/src/protocols/sumcheck.rs @@ -9,11 +9,7 @@ use serde::{Deserialize, Serialize}; use tracing::instrument; use crate::{ - algebra::{ - dot, - sumcheck::{compute_sumcheck_polynomial, fold, fold_and_compute_polynomial}, - univariate_evaluate, - }, + algebra::{buffer::BufferOps, univariate_evaluate}, protocols::proof_of_work, transcript::{ codecs::U64, Codec, Decoding, DuplexSpongeInterface, ProverState, VerificationResult, @@ -66,15 +62,16 @@ impl Config { /// - Applies proof-of-work grinding if required. /// - Returns the sampled folding randomness values used in each reduction step. #[cfg_attr(feature = "tracing", instrument(skip_all))] - pub fn prove( + pub fn prove( &self, prover_state: &mut ProverState, - a: &mut Vec, - b: &mut Vec, + a: &mut B, + b: &mut B, sum: &mut F, masks: &[F], ) -> SumcheckOpening where + B: BufferOps, H: DuplexSpongeInterface, R: CryptoRng + RngCore, F: Codec<[H::U]>, @@ -88,7 +85,7 @@ impl Config { assert!(self.mask_length == 0 || self.mask_length >= 3); assert_eq!(a.len(), self.initial_size); assert_eq!(b.len(), self.initial_size); - debug_assert_eq!(dot(a, b), *sum); + debug_assert_eq!(a.dot(b), *sum); assert_eq!(masks.len(), self.num_rounds * self.mask_length); let half = F::from(2).inverse().unwrap(); @@ -110,9 +107,9 @@ impl Config { { // Fold and compute sumcheck polynomial in one pass. let (c0, c2) = if let Some(w) = prev_round_challenge { - fold_and_compute_polynomial(a, b, w) + a.fold_and_sumcheck_polynomial(b, w) } else { - compute_sumcheck_polynomial(a, b) + a.sumcheck_polynomial(b) }; let c1 = *sum - c0.double() - c2; @@ -150,8 +147,8 @@ impl Config { } if let Some(w) = prev_round_challenge { // Final fold of the inputs (no polynomial computation) - fold(a, w); - fold(b, w); + a.fold(w); + b.fold(w); } *sum = mask_sum + mask_rlc * *sum; @@ -238,23 +235,24 @@ fn eval_01(coefficients: &[F]) -> F { #[cfg(test)] mod tests { - use ark_std::rand::{ - distributions::{Distribution, Standard}, - rngs::StdRng, - SeedableRng, - }; - use proptest::{prelude::Just, prop_oneof, proptest, strategy::Strategy}; - #[cfg(feature = "tracing")] - use tracing::instrument; - use super::*; use crate::{ algebra::{ + buffer::CpuBuffer, + dot, fields::{self, Field64}, multilinear_extend, random_vector, }, transcript::DomainSeparator, }; + use ark_std::rand::{ + distributions::{Distribution, Standard}, + rngs::StdRng, + SeedableRng, + }; + use proptest::{prelude::Just, prop_oneof, proptest, strategy::Strategy}; + #[cfg(feature = "tracing")] + use tracing::instrument; impl Config where @@ -299,8 +297,8 @@ mod tests { let masks = random_vector(&mut rng, config.mask_length * config.num_rounds); // Prover - let mut vector = initial_vector.clone(); - let mut covector = initial_covector.clone(); + let mut vector = CpuBuffer::from_slice(&initial_vector); + let mut covector = CpuBuffer::from_slice(&initial_covector); let mut sum = initial_sum; let mut prover_state = ProverState::new_std(&ds); let SumcheckOpening { @@ -316,8 +314,14 @@ mod tests { assert_eq!(vector.len(), config.final_size()); assert_eq!(covector.len(), config.final_size()); if config.final_size() == 1 { - assert_eq!(multilinear_extend(&initial_vector, &point), vector[0]); - assert_eq!(multilinear_extend(&initial_covector, &point), covector[0]); + assert_eq!( + multilinear_extend(&initial_vector, &point), + vector.as_slice()[0] + ); + assert_eq!( + multilinear_extend(&initial_covector, &point), + covector.as_slice()[0] + ); } else { // TODO: Check correct folding. } @@ -327,7 +331,10 @@ mod tests { .zip(&point) .map(|(m, x)| univariate_evaluate(m, *x)) .sum(); - assert_eq!(sum, expected_mask_sum + mask_rlc * dot(&vector, &covector)); + assert_eq!( + sum, + expected_mask_sum + mask_rlc * dot(vector.as_slice(), covector.as_slice()) + ); let proof = prover_state.proof(); diff --git a/src/protocols/whir/mod.rs b/src/protocols/whir/mod.rs index 9c6a31bb..0f20a14b 100644 --- a/src/protocols/whir/mod.rs +++ b/src/protocols/whir/mod.rs @@ -14,7 +14,7 @@ use tracing::instrument; use crate::{ algebra::{ - buffer::CpuBuffer, + buffer::{BufferOps, CpuBuffer}, embedding::{Embedding, Identity}, linear_form::LinearForm, }, @@ -46,7 +46,13 @@ pub struct RoundConfig { pub pow: proof_of_work::Config, } -pub type Witness> = irs_commit::Witness; +pub type Witness, B = CpuBuffer<::Source>> = + irs_commit::Witness< + ::Source, + F, + B, + ::Source>>::Matrix, + >; pub type Commitment = irs_commit::Commitment; #[must_use = "The final claim must be checked if there where any linear forms."] @@ -76,13 +82,19 @@ impl FinalClaim { impl Config { /// Commit to one or more vectors. - #[cfg_attr(feature = "tracing", instrument(skip_all, fields(size = vectors.first().unwrap().len())))] - pub fn commit( + #[cfg_attr( + feature = "tracing", + instrument(skip_all, fields(size = vectors.first().unwrap().len())) + )] + pub fn commit( &self, prover_state: &mut ProverState, - vectors: &[&CpuBuffer], - ) -> Witness + vectors: &[&B], + ) -> Witness where + B: BufferOps, + >::Witness: + Clone + PartialEq + Eq + PartialOrd + Ord + Debug + std::hash::Hash + Default, Standard: Distribution, H: DuplexSpongeInterface, R: RngCore + CryptoRng, @@ -249,8 +261,8 @@ mod tests { // Generate a proof for the given statement and witness let _ = params.prove( &mut prover_state, - vec![Cow::from(vector)], - vec![Cow::Owned(witness)], + &[&vector_buffer], + vec![&witness], prove_linear_forms, Cow::Borrowed(evaluations.as_slice()), ); @@ -430,11 +442,8 @@ mod tests { // Batch prove all polynomials together let _ = params.prove( &mut prover_state, - vectors - .iter() - .map(|v| Cow::Borrowed(v.as_slice())) - .collect(), - witnesses.into_iter().map(Cow::Owned).collect(), + &vector_buffers.iter().collect::>(), + witnesses.iter().collect(), prove_linear_forms, Cow::Borrowed(evaluations.as_slice()), ); @@ -584,10 +593,11 @@ mod tests { let prove_linear_forms = build_prove_forms(&constraint_points, num_variables, false); // Generate proof with mismatched polynomials + let vec_wrong_buffer = CpuBuffer::from_vec(vec_wrong); let _ = params.prove( &mut prover_state, - vec![Cow::Borrowed(vec1.as_slice()), Cow::from(vec_wrong)], - vec![Cow::Owned(witness1), Cow::Owned(witness2)], + &[&vec1_buffer, &vec_wrong_buffer], + vec![&witness1, &witness2], prove_linear_forms, Cow::Borrowed(evaluations.as_slice()), ); @@ -708,11 +718,8 @@ mod tests { // Batch prove all witnesses together let _ = params.prove( &mut prover_state, - all_vectors - .iter() - .map(|v| Cow::Borrowed(v.as_slice())) - .collect(), - witnesses.into_iter().map(Cow::Owned).collect(), + &vector_buffers.iter().collect::>(), + witnesses.iter().collect(), prove_linear_forms, Cow::Borrowed(evaluations.as_slice()), ); @@ -858,11 +865,8 @@ mod tests { .collect::>(); let _ = params.prove( &mut prover_state, - vectors - .iter() - .map(|v| Cow::Borrowed(v.as_slice())) - .collect(), - vec![Cow::Owned(batched_witness)], + &buffer_refs, + vec![&batched_witness], prove_linear_forms, Cow::Borrowed(values.as_slice()), ); diff --git a/src/protocols/whir/prover.rs b/src/protocols/whir/prover.rs index e3fc2043..01c87f3c 100644 --- a/src/protocols/whir/prover.rs +++ b/src/protocols/whir/prover.rs @@ -1,4 +1,4 @@ -use std::{any::Any, borrow::Cow, mem}; +use std::{borrow::Cow, fmt::Debug}; use ark_ff::{AdditiveGroup, Field}; use ark_std::rand::{distributions::Standard, prelude::Distribution, CryptoRng, RngCore}; @@ -8,13 +8,11 @@ use tracing::instrument; use super::{Config, Witness}; use crate::{ algebra::{ - buffer::CpuBuffer, + buffer::{BufferOps, MatrixBufferOps}, dot, embedding::Embedding, - eq_weights, lift, - linear_form::{Covector, Evaluate, LinearForm, UnivariateEvaluation}, - mixed_scalar_mul_add, - sumcheck::fold, + eq_weights, + linear_form::LinearForm, tensor_product, }, hash::Hash, @@ -26,9 +24,15 @@ use crate::{ utils::zip_strict, }; -enum RoundWitness<'a, F: Field, M: Embedding> { - Initial(Vec>>), - Round(irs_commit::Witness), +enum RoundWitness<'a, F, M, B> +where + F: Field, + M: Embedding, + B: BufferOps, + B::Buffer: BufferOps, +{ + Initial(Vec<&'a irs_commit::Witness>), + Round(irs_commit::Witness, as BufferOps>::Matrix>), } impl Config { @@ -48,15 +52,29 @@ impl Config { /// #[cfg_attr(feature = "tracing", instrument(skip_all))] #[allow(clippy::too_many_lines, clippy::cognitive_complexity)] - pub fn prove<'a, H, R>( + pub fn prove<'a, H, R, B>( &self, prover_state: &mut ProverState, - vectors: Vec>, - witnesses: Vec>>, + vectors: &[&B], + witnesses: Vec<&'a Witness>, linear_forms: Vec>>, evaluations: Cow<'a, [M::Target]>, ) -> FinalClaim where + B: BufferOps, + B::Buffer: BufferOps, + >::Witness: + Clone + PartialEq + Eq + PartialOrd + Ord + Debug + std::hash::Hash + Default, + < as BufferOps>::Matrix as MatrixBufferOps< + M::Target, + >>::Witness: Clone + + PartialEq + + Eq + + PartialOrd + + Ord + + Debug + + std::hash::Hash + + Default, Standard: Distribution + Distribution, H: DuplexSpongeInterface, R: RngCore + CryptoRng, @@ -74,7 +92,7 @@ impl Config { witnesses.len() * self.initial_committer.num_vectors ); assert_eq!(evaluations.len(), num_vectors * linear_forms.len()); - for vector in &vectors { + for vector in vectors { assert_eq!(vector.len(), self.initial_size()); } for linear_form in &linear_forms { @@ -86,8 +104,11 @@ impl Config { { use crate::algebra::linear_form::Covector; let covector = Covector::from(&**linear_form); - for (vector, evaluation) in zip_strict(&vectors, evaluations) { - debug_assert_eq!(covector.evaluate(self.embedding(), vector), *evaluation); + for (vector, evaluation) in zip_strict(vectors, evaluations) { + debug_assert_eq!( + vector.mixed_dot_slice(self.embedding(), &covector.vector), + *evaluation + ); } } if vectors.is_empty() { @@ -111,12 +132,13 @@ impl Config { if j >= vector_offset && j < oods_row.len() + vector_offset { debug_assert_eq!( oods_row[j - vector_offset], - oods_eval.evaluate(self.embedding(), vector) + vector.mixed_univariate_evaluate(self.embedding(), oods_eval.point) ); oods_matrix.push(oods_row[j - vector_offset]); } else { - let eval = oods_eval.evaluate(self.embedding(), vector); + let eval = + vector.mixed_univariate_evaluate(self.embedding(), oods_eval.point); prover_state.prover_message(&eval); oods_matrix.push(eval); } @@ -131,18 +153,9 @@ impl Config { // Random linear combination of the vectors. let mut vector_rlc_coeffs: Vec = geometric_challenge(prover_state, num_vectors); assert_eq!(vector_rlc_coeffs[0], M::Target::ONE); - // Recycle the first input as the accumulator (its coefficient is always ONE). - let mut vectors = vectors.into_iter(); - let first = vectors.next().expect("non-empty"); - let mut vector = match first { - Cow::Borrowed(slice) => lift(self.embedding(), slice), - Cow::Owned(vec) => self.embedding().map_vec(vec), - }; - for (rlc_coeff, input_vector) in zip_strict(&vector_rlc_coeffs[1..], vectors) { - mixed_scalar_mul_add(self.embedding(), &mut vector, *rlc_coeff, &input_vector); - } + let mut vector = B::mixed_linear_combination(self.embedding(), vectors, &vector_rlc_coeffs); - let mut prev_witness: RoundWitness<'a, M::Target, M> = RoundWitness::Initial(witnesses); + let mut prev_witness: RoundWitness<'a, M::Target, M, B> = RoundWitness::Initial(witnesses); // Random linear combination of the constraints. let constraint_rlc_coeffs: Vec = @@ -150,26 +163,16 @@ impl Config { let has_constraints = !constraint_rlc_coeffs.is_empty(); let (initial_forms_rlc_coeffs, oods_rlc_coeffs) = constraint_rlc_coeffs.split_at(linear_forms.len()); - // Try to recycle the first linear form as Covector. - let mut covector = vec![]; let mut linear_forms = linear_forms; - if let Some((first, linear_forms)) = linear_forms.split_first_mut() { - debug_assert_eq!(initial_forms_rlc_coeffs[0], M::Target::ONE); - if let Some(covector_form) = - (first.as_mut() as &mut dyn Any).downcast_mut::>() - { - mem::swap(&mut covector, &mut covector_form.vector); - } else { - covector.resize(self.initial_size(), M::Target::ZERO); - first.accumulate(&mut covector, M::Target::ONE); - } - for (rlc_coeff, linear_form) in zip_strict(&initial_forms_rlc_coeffs[1..], linear_forms) - { - linear_form.accumulate(&mut covector, *rlc_coeff); - } - } else if has_constraints { - covector.resize(self.initial_size(), M::Target::ZERO); - } + let mut covector = if has_constraints { + as BufferOps>::linear_forms_rlc( + self.initial_size(), + &mut linear_forms, + initial_forms_rlc_coeffs, + ) + } else { + as BufferOps>::zeros(0) + }; drop(linear_forms); // Compute "The Sum" @@ -181,17 +184,17 @@ impl Config { .sum(); drop(evaluations); - debug_assert!(!has_constraints || dot(&vector, &covector) == the_sum); + debug_assert!(!has_constraints || vector.dot(&covector) == the_sum); // Add OODS constraints - UnivariateEvaluation::accumulate_many(&oods_evals, &mut covector, oods_rlc_coeffs); + covector.accumulate_univariate_evaluations(&oods_evals, oods_rlc_coeffs); the_sum += zip_strict(oods_rlc_coeffs, oods_matrix.chunks_exact(num_vectors)) .map(|(poly_coeff, row)| *poly_coeff * dot(&vector_rlc_coeffs, row)) .sum::(); drop(oods_evals); drop(oods_matrix); - debug_assert!(!has_constraints || dot(&vector, &covector) == the_sum); + debug_assert!(!has_constraints || vector.dot(&covector) == the_sum); // Run initial sumcheck on batched vectors with combined statement let mut folding_randomness = if has_constraints { @@ -208,35 +211,32 @@ impl Config { self.initial_skip_pow.prove(prover_state); // Fold vector for &f in &folding_randomness { - fold(&mut vector, f); + vector.fold(f); } // Covector must be all zeros. - covector = vec![M::Target::ZERO; self.initial_sumcheck.final_size()]; + covector = as BufferOps>::zeros( + self.initial_sumcheck.final_size(), + ); folding_randomness }; let mut evaluation_point = folding_randomness.clone(); - debug_assert_eq!(dot(&vector, &covector), the_sum); + debug_assert_eq!(vector.dot(&covector), the_sum); // Execute standard WHIR rounds on the batched vectors for (round_index, round_config) in self.round_configs.iter().enumerate() { // Commit to the vector, this generates out-of-domain evaluations. - let vector_buffer = CpuBuffer::from_slice(&vector); - let new_witness = round_config - .irs_committer - .commit(prover_state, &[&vector_buffer]); + let new_witness = round_config.irs_committer.commit(prover_state, &[&vector]); // Proof of work before in-domain challenges round_config.pow.prove(prover_state); // Open the previous round's witness. let in_domain = match prev_witness { - RoundWitness::Initial(init_witnesses) => { - let witness_refs: Vec<&_> = init_witnesses.iter().map(|c| &**c).collect(); - self.initial_committer - .open(prover_state, &witness_refs) - .lift(self.embedding()) - } + RoundWitness::Initial(init_witnesses) => self + .initial_committer + .open(prover_state, &init_witnesses) + .lift(self.embedding()), RoundWitness::Round(old_witness) => { let prev_round_config = &self.round_configs[round_index - 1]; prev_round_config @@ -260,13 +260,9 @@ impl Config { ))) .collect::>(); let stir_rlc_coeffs = geometric_challenge(prover_state, stir_challenges.len()); - UnivariateEvaluation::accumulate_many( - &stir_challenges, - &mut covector, - &stir_rlc_coeffs, - ); + covector.accumulate_univariate_evaluations(&stir_challenges, &stir_rlc_coeffs); the_sum += dot(&stir_rlc_coeffs, &stir_evaluations); - debug_assert_eq!(dot(&vector, &covector), the_sum); + debug_assert_eq!(vector.dot(&covector), the_sum); // Run sumcheck for this round folding_randomness = round_config @@ -275,7 +271,7 @@ impl Config { .round_challenges; evaluation_point.extend(folding_randomness.iter().copied()); - debug_assert_eq!(dot(&vector, &covector), the_sum); + debug_assert_eq!(vector.dot(&covector), the_sum); prev_witness = RoundWitness::Round(new_witness); vector_rlc_coeffs = vec![M::Target::ONE]; @@ -283,9 +279,7 @@ impl Config { // Directly send the vector to the verifier. assert_eq!(vector.len(), self.final_sumcheck.initial_size); - for coeff in &vector { - prover_state.prover_message(coeff); - } + vector.write_to_prover(prover_state); // PoW self.final_pow.prove(prover_state); @@ -293,8 +287,7 @@ impl Config { // Open and consume the final previous witness. match prev_witness { RoundWitness::Initial(init_witnesses) => { - let witness_refs: Vec<&_> = init_witnesses.iter().map(|c| &**c).collect(); - let _in_domain = self.initial_committer.open(prover_state, &witness_refs); + let _in_domain = self.initial_committer.open(prover_state, &init_witnesses); } RoundWitness::Round(old_witness) => { let prev_config = self.round_configs.last().unwrap(); From 595b19d4be6928fc6ad2e229b1b2977a78bc347e Mon Sep 17 00:00:00 2001 From: zkfriendly Date: Wed, 3 Jun 2026 17:30:35 +0200 Subject: [PATCH 5/5] chore: add temp ignore notes --- src/protocols/basecase.rs | 2 ++ src/protocols/code_switch.rs | 2 ++ src/protocols/mask_proximity.rs | 2 ++ src/protocols/whir_zk/committer.rs | 2 ++ src/protocols/whir_zk/mod.rs | 2 ++ 5 files changed, 10 insertions(+) diff --git a/src/protocols/basecase.rs b/src/protocols/basecase.rs index e6faa026..d1e7e73b 100644 --- a/src/protocols/basecase.rs +++ b/src/protocols/basecase.rs @@ -1,3 +1,5 @@ +// IGNORE CHANGES TO THIS FILE - NOT FULLY PORTED TO PROPERLY USE BUFFER ABSTRACTION. + //! Base Case Linear Opening Protocol //! //! It support honest verifier zero-knowledge (HVZK), but is not succinct. diff --git a/src/protocols/code_switch.rs b/src/protocols/code_switch.rs index 9da26424..aa2a6370 100644 --- a/src/protocols/code_switch.rs +++ b/src/protocols/code_switch.rs @@ -1,3 +1,5 @@ +// IGNORE CHANGES TO THIS FILE - NOT FULLY PORTED TO PROPERLY USE BUFFER ABSTRACTION. + //! Code-switching IOR: R_{C, C_zk, sl} → R_{C', C_zk, sl'} //! //! Reduces a proximity claim about oracle f (source code C) to a proximity diff --git a/src/protocols/mask_proximity.rs b/src/protocols/mask_proximity.rs index deb49de5..7413158c 100644 --- a/src/protocols/mask_proximity.rs +++ b/src/protocols/mask_proximity.rs @@ -1,3 +1,5 @@ +// IGNORE CHANGES TO THIS FILE - NOT FULLY PORTED TO PROPERLY USE BUFFER ABSTRACTION. + //! Mask proximity verification via γ-combination. //! //! Implements Construction 7.2 (p.43-44) specialized for zero-constraint mask diff --git a/src/protocols/whir_zk/committer.rs b/src/protocols/whir_zk/committer.rs index 01b7a863..dbc875ec 100644 --- a/src/protocols/whir_zk/committer.rs +++ b/src/protocols/whir_zk/committer.rs @@ -1,3 +1,5 @@ +// IGNORE CHANGES TO THIS FILE - NOT FULLY PORTED TO PROPERLY USE BUFFER ABSTRACTION. + use ark_ff::Field; use ark_std::rand::{distributions::Standard, prelude::Distribution}; #[cfg(feature = "tracing")] diff --git a/src/protocols/whir_zk/mod.rs b/src/protocols/whir_zk/mod.rs index c27ca37b..493ea32e 100644 --- a/src/protocols/whir_zk/mod.rs +++ b/src/protocols/whir_zk/mod.rs @@ -1,3 +1,5 @@ +// IGNORE CHANGES TO THIS FILE - NOT FULLY PORTED TO PROPERLY USE BUFFER ABSTRACTION. + #![cfg(feature = "rs_in_order")] // TODO: Support permuted. mod committer; mod prover;