diff --git a/src/algebra/buffer.rs b/src/algebra/buffer.rs new file mode 100644 index 00000000..c6c8fb2b --- /dev/null +++ b/src/algebra/buffer.rs @@ -0,0 +1,442 @@ +use std::{any::Any, mem}; + +use ark_ff::Field; +use ark_std::rand::{distributions::Standard, prelude::Distribution, CryptoRng, Rng, RngCore}; +use spongefish::DuplexSpongeInterface; + +use crate::{ + algebra::{ + embedding::{Embedding, Identity}, + linear_form::{Covector, LinearForm, UnivariateEvaluation}, + mixed_dot, mixed_multilinear_extend, mixed_scalar_mul_add, mixed_univariate_evaluate, ntt, + sumcheck::{compute_sumcheck_polynomial, fold, fold_and_compute_polynomial}, + }, + hash::Hash, + protocols::matrix_commit, + transcript::{ProverMessage, ProverState}, + type_info::TypeInfo, + utils::chunks_exact_or_empty, +}; + +pub trait BufferOps: Clone { + type Buffer: BufferOps; + type Matrix: MatrixBufferOps; + + fn from_vec(source: Vec) -> Self; + fn zeros(length: usize) -> Self; + fn len(&self) -> usize; + + fn is_empty(&self) -> bool { + self.len() == 0 + } + + fn random(rng: &mut R, length: usize) -> Self + where + R: RngCore + CryptoRng, + Standard: Distribution; + + fn linear_forms_rlc( + size: usize, + linear_forms: &mut [Box>], + rlc_coeffs: &[F], + ) -> Self; + + fn zero_pad(&mut self); + fn fold(&mut self, weight: F); + fn sumcheck_polynomial(&self, other: &Self) -> (F, F); + fn fold_and_sumcheck_polynomial(&mut self, other: &mut Self, weight: F) -> (F, F); + fn accumulate_univariate_evaluations( + &mut self, + evaluators: &[UnivariateEvaluation], + scalars: &[F], + ); + fn write_to_prover(&self, prover_state: &mut ProverState) + where + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + F: ProverMessage<[H::U]>; + fn mixed_extend, T: Field>( + &self, + embedding: &M, + point: &[M::Target], + ) -> M::Target; + fn mixed_dot, T: Field>( + &self, + embedding: &M, + other: &Self::Buffer, + ) -> M::Target; + fn mixed_univariate_evaluate>( + &self, + embedding: &M, + point: M::Target, + ) -> M::Target; + fn mixed_linear_combination>( + embedding: &M, + vectors: &[&Self], + coeffs: &[M::Target], + ) -> Self::Buffer; + fn mixed_scalar_mul_add_to>( + &self, + embedding: &M, + accumulator: &mut Self::Buffer, + weight: M::Target, + ); + fn mixed_dot_slice>( + &self, + embedding: &M, + other: &[M::Target], + ) -> M::Target; + fn interleaved_rs_encode( + vectors: &[&Self], + masks: &Self, + message_length: usize, + interleaving_depth: usize, + codeword_length: usize, + ) -> Self::Matrix + where + F: 'static; + fn dot(&self, other: &Self) -> F; +} + +pub trait MatrixBufferOps { + type Witness; + + fn len(&self) -> usize; + fn num_rows(&self) -> usize; + fn num_cols(&self) -> usize; + + fn commit_rows( + &self, + config: &matrix_commit::Config, + prover_state: &mut ProverState, + ) -> Self::Witness + where + F: TypeInfo + matrix_commit::Encodable + Send + Sync, + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + Hash: ProverMessage<[H::U]>; + + fn open_rows( + &self, + config: &matrix_commit::Config, + prover_state: &mut ProverState, + witness: &Self::Witness, + indices: &[usize], + ) where + F: TypeInfo + matrix_commit::Encodable + Send + Sync, + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + Hash: ProverMessage<[H::U]>; + + fn read_rows(&self, indices: &[usize]) -> Vec; +} + +#[derive( + Clone, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + Debug, + Default, + serde::Serialize, + serde::Deserialize, +)] +pub struct CpuMatrix { + data: Vec, + num_rows: usize, + num_cols: usize, +} + +impl CpuMatrix { + pub fn from_vec(data: Vec, num_rows: usize, num_cols: usize) -> Self { + assert_eq!(data.len(), num_rows * num_cols); + Self { + data, + num_rows, + num_cols, + } + } + + fn row(&self, row: usize) -> &[F] { + let start = row * self.num_cols; + let end = start + self.num_cols; + &self.data[start..end] + } +} + +impl MatrixBufferOps for CpuMatrix +where + F: Field + TypeInfo + matrix_commit::Encodable + Send + Sync, +{ + type Witness = matrix_commit::Witness; + + fn len(&self) -> usize { + self.data.len() + } + + fn num_rows(&self) -> usize { + self.num_rows + } + + fn num_cols(&self) -> usize { + self.num_cols + } + + fn commit_rows( + &self, + config: &matrix_commit::Config, + prover_state: &mut ProverState, + ) -> Self::Witness + where + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + Hash: ProverMessage<[H::U]>, + { + assert_eq!(config.num_rows(), self.num_rows); + assert_eq!(config.num_cols, self.num_cols); + config.commit(prover_state, &self.data) + } + + fn open_rows( + &self, + config: &matrix_commit::Config, + prover_state: &mut ProverState, + witness: &Self::Witness, + indices: &[usize], + ) where + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + Hash: ProverMessage<[H::U]>, + { + config.open(prover_state, witness, indices); + } + + fn read_rows(&self, indices: &[usize]) -> Vec { + let mut rows = Vec::with_capacity(indices.len() * self.num_cols); + for &index in indices { + assert!(index < self.num_rows); + rows.extend_from_slice(self.row(index)); + } + rows + } +} +#[derive( + Clone, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + Debug, + Default, + serde::Serialize, + serde::Deserialize, +)] +pub struct CpuBuffer { + data: Vec, +} + +impl CpuBuffer { + pub fn from_vec(source: Vec) -> Self { + Self { data: source } + } + + pub fn from_slice(source: &[F]) -> Self { + Self { + data: Vec::from(source), + } + } + + pub(crate) fn as_slice(&self) -> &[F] { + self.data.as_slice() + } +} + +impl BufferOps for CpuBuffer { + type Buffer = CpuBuffer; + type Matrix = CpuMatrix; + + fn from_vec(source: Vec) -> Self { + Self::from_vec(source) + } + + fn zeros(length: usize) -> Self { + Self { + data: vec![F::ZERO; length], + } + } + + fn len(&self) -> usize { + self.data.len() + } + + fn random(rng: &mut R, length: usize) -> Self + where + R: RngCore + CryptoRng, + Standard: Distribution, + { + Self { + data: (0..length).map(|_| rng.gen()).collect(), + } + } + + fn linear_forms_rlc( + size: usize, + linear_forms: &mut [Box>], + rlc_coeffs: &[F], + ) -> Self { + assert_eq!(linear_forms.len(), rlc_coeffs.len()); + let mut covector = vec![F::ZERO; size]; + if let Some((first, linear_forms)) = linear_forms.split_first_mut() { + debug_assert_eq!(rlc_coeffs[0], F::ONE); + if let Some(covector_form) = + (first.as_mut() as &mut dyn Any).downcast_mut::>() + { + mem::swap(&mut covector, &mut covector_form.vector); + } else { + first.accumulate(&mut covector, F::ONE); + } + for (rlc_coeff, linear_form) in rlc_coeffs[1..].iter().zip(linear_forms) { + linear_form.accumulate(&mut covector, *rlc_coeff); + } + } + Self { data: covector } + } + + fn zero_pad(&mut self) { + if !self.is_empty() { + self.data.resize(self.len().next_power_of_two(), F::ZERO); + } + } + + fn fold(&mut self, weight: F) { + fold(&mut self.data, weight); + } + + fn sumcheck_polynomial(&self, other: &Self) -> (F, F) { + compute_sumcheck_polynomial(&self.data, other.as_slice()) + } + + fn fold_and_sumcheck_polynomial(&mut self, other: &mut Self, weight: F) -> (F, F) { + fold_and_compute_polynomial(&mut self.data, &mut other.data, weight) + } + + fn accumulate_univariate_evaluations( + &mut self, + evaluators: &[UnivariateEvaluation], + scalars: &[F], + ) { + UnivariateEvaluation::accumulate_many(evaluators, &mut self.data, scalars); + } + + fn write_to_prover(&self, prover_state: &mut ProverState) + where + H: DuplexSpongeInterface, + R: RngCore + CryptoRng, + F: ProverMessage<[H::U]>, + { + prover_state.prover_messages(&self.data); + } + + fn mixed_extend, T: Field>( + &self, + embedding: &M, + point: &[M::Target], + ) -> M::Target { + mixed_multilinear_extend(embedding, &self.data, point) + } + + fn mixed_dot, T: Field>( + &self, + embedding: &M, + other: &Self::Buffer, + ) -> M::Target { + mixed_dot(embedding, other.as_slice(), self.as_slice()) + } + + fn dot(&self, other: &Self) -> F { + self.mixed_dot(&Identity::new(), other) + } + + fn mixed_univariate_evaluate>( + &self, + embedding: &M, + point: M::Target, + ) -> M::Target { + mixed_univariate_evaluate(embedding, &self.data, point) + } + + fn mixed_linear_combination>( + embedding: &M, + vectors: &[&Self], + coeffs: &[M::Target], + ) -> Self::Buffer { + assert_eq!(vectors.len(), coeffs.len()); + let Some((first, vectors)) = vectors.split_first() else { + return CpuBuffer::from_vec(Vec::new()); + }; + debug_assert_eq!(coeffs[0], M::Target::ONE); + let mut accumulator = crate::algebra::lift(embedding, first.as_slice()); + for (coeff, vector) in coeffs[1..].iter().zip(vectors) { + mixed_scalar_mul_add(embedding, &mut accumulator, *coeff, vector.as_slice()); + } + CpuBuffer::from_vec(accumulator) + } + + fn mixed_scalar_mul_add_to>( + &self, + embedding: &M, + accumulator: &mut Self::Buffer, + weight: M::Target, + ) { + mixed_scalar_mul_add(embedding, &mut accumulator.data, weight, self.as_slice()); + } + + fn mixed_dot_slice>( + &self, + embedding: &M, + other: &[M::Target], + ) -> M::Target { + mixed_dot(embedding, other, self.as_slice()) + } + + fn interleaved_rs_encode( + vectors: &[&Self], + masks: &Self, + message_length: usize, + interleaving_depth: usize, + codeword_length: usize, + ) -> Self::Matrix + where + F: 'static, + { + let vectors = vectors.iter().map(|v| v.as_slice()).collect::>(); + interleaved_rs_encode_slices( + &vectors, + masks.as_slice(), + message_length, + interleaving_depth, + codeword_length, + ) + } +} + +fn interleaved_rs_encode_slices( + vectors: &[&[F]], + masks: &[F], + message_length: usize, + interleaving_depth: usize, + codeword_length: usize, +) -> CpuMatrix { + let messages = vectors + .iter() + .flat_map(|v| chunks_exact_or_empty(v, message_length, interleaving_depth)) + .collect::>(); + CpuMatrix::from_vec( + ntt::interleaved_rs_encode(&messages, masks, codeword_length), + codeword_length, + vectors.len() * interleaving_depth, + ) +} diff --git a/src/algebra/mod.rs b/src/algebra/mod.rs index 2c7e4838..4d17dc3b 100644 --- a/src/algebra/mod.rs +++ b/src/algebra/mod.rs @@ -1,3 +1,4 @@ +pub mod buffer; pub mod embedding; pub mod fields; pub mod linear_form; diff --git a/src/bin/benchmark.rs b/src/bin/benchmark.rs index d9e79ce6..42e654b0 100644 --- a/src/bin/benchmark.rs +++ b/src/bin/benchmark.rs @@ -10,6 +10,7 @@ use clap::Parser; use serde::Serialize; use whir::{ algebra::{ + buffer::CpuBuffer, embedding::{Basefield, Embedding, Identity}, fields::{Field128, Field192, Field256, Field64, Field64_2, Field64_3}, linear_form::{Evaluate, LinearForm, MultilinearExtension}, @@ -161,12 +162,13 @@ where HASH_COUNTER.reset(); - let witness = params.commit(&mut prover_state, &[&vector]); + let vector_buffer = CpuBuffer::from_slice(&vector); + let witness = params.commit(&mut prover_state, &[&vector_buffer]); let _ = params.prove( &mut prover_state, - vec![Cow::Borrowed(vector.as_slice())], - vec![Cow::Owned(witness)], + &[&vector_buffer], + vec![&witness], vec![], Cow::Owned(vec![]), ); @@ -238,7 +240,8 @@ where HASH_COUNTER.reset(); let whir_prover_time = Instant::now(); - let witness = params.commit(&mut prover_state, &[&vector]); + let vector_buffer = CpuBuffer::from_slice(&vector); + let witness = params.commit(&mut prover_state, &[&vector_buffer]); let prove_linear_forms: Vec>> = points .iter() @@ -249,8 +252,8 @@ where let _ = params.prove( &mut prover_state, - vec![Cow::Borrowed(vector.as_slice())], - vec![Cow::Owned(witness)], + &[&vector_buffer], + vec![&witness], prove_linear_forms, Cow::Borrowed(evaluations.as_slice()), ); diff --git a/src/bin/main.rs b/src/bin/main.rs index bafe16f5..c63ada92 100644 --- a/src/bin/main.rs +++ b/src/bin/main.rs @@ -6,6 +6,7 @@ use ark_std::rand::distributions::{Distribution, Standard}; use clap::Parser; use whir::{ algebra::{ + buffer::CpuBuffer, embedding::{Basefield, Embedding, Identity}, fields::{Field128, Field192, Field256, Field64, Field64_2, Field64_3}, linear_form::{Covector, Evaluate, LinearForm, MultilinearExtension}, @@ -148,9 +149,10 @@ where } let vector = (0..num_coeffs).map(M::Source::from).collect::>(); + let vector_buffer = CpuBuffer::from_slice(&vector); let whir_commit_time = Instant::now(); - let witness = params.commit(&mut prover_state, &[&vector]); + let witness = params.commit(&mut prover_state, &[&vector_buffer]); let whir_commit_time = whir_commit_time.elapsed(); // Allocate constraints @@ -183,8 +185,8 @@ where let whir_prove_time = Instant::now(); let _ = params.prove( &mut prover_state, - vec![Cow::Borrowed(vector.as_slice())], - vec![Cow::Owned(witness)], + &[&vector_buffer], + vec![&witness], prove_linear_forms, Cow::Borrowed(evaluations.as_slice()), ); @@ -314,14 +316,15 @@ where } let whir_commit_time = Instant::now(); - let witness = params.commit(&mut prover_state, &[vector.as_slice()]); + let vector_buffer = CpuBuffer::from_slice(&vector); + let witness = params.commit(&mut prover_state, &[&vector_buffer]); let whir_commit_time = whir_commit_time.elapsed(); let whir_prove_time = Instant::now(); let _ = params.prove( &mut prover_state, - vec![Cow::Borrowed(&vector)], - witness, + &[&vector_buffer], + vec![&witness], prove_linear_forms, Cow::Borrowed(&evaluations), ); diff --git a/src/protocols/basecase.rs b/src/protocols/basecase.rs index 32763271..d1e7e73b 100644 --- a/src/protocols/basecase.rs +++ b/src/protocols/basecase.rs @@ -1,3 +1,5 @@ +// IGNORE CHANGES TO THIS FILE - NOT FULLY PORTED TO PROPERLY USE BUFFER ABSTRACTION. + //! Base Case Linear Opening Protocol //! //! It support honest verifier zero-knowledge (HVZK), but is not succinct. @@ -11,8 +13,8 @@ use spongefish::{Decoding, VerificationResult}; use crate::{ algebra::{ - dot, embedding::Identity, multilinear_extend, random_vector, scalar_mul_add_new, - univariate_evaluate, + buffer::CpuBuffer, dot, embedding::Identity, multilinear_extend, random_vector, + scalar_mul_add_new, univariate_evaluate, }, hash::Hash, protocols::{irs_commit, sumcheck}, @@ -49,9 +51,9 @@ impl Config { pub fn prove( &self, prover_state: &mut ProverState, - mut vector: Vec, + vector: Vec, witness: &irs_commit::Witness, - mut covector: Vec, + covector: Vec, mut sum: F, ) -> Opening where @@ -79,24 +81,33 @@ impl Config { // Even more trivial non-zk protocol: send f and r directly. if !self.masked { prover_state.prover_messages(&vector); - prover_state.prover_messages(&witness.masks); + prover_state.prover_messages(witness.masks.as_slice()); let _ = self.commit.open(prover_state, &[witness]); + let mut vector_buffer = CpuBuffer::from_vec(vector); + let mut covector_buffer = CpuBuffer::from_vec(covector); let point = self .sumcheck - .prove(prover_state, &mut vector, &mut covector, &mut sum, &[]) + .prove( + prover_state, + &mut vector_buffer, + &mut covector_buffer, + &mut sum, + &[], + ) .round_challenges; - assert!(!vector[0].is_zero(), "Proof failed"); + assert!(!vector_buffer.as_slice()[0].is_zero(), "Proof failed"); return Opening { evaluation_points: point, - linear_form_evaluation: covector[0], + linear_form_evaluation: covector_buffer.as_slice()[0], }; } // Create masking vector. let mask = random_vector(prover_state.rng(), vector.len()); + let mask_buffer = CpuBuffer::from_slice(&mask); // Commit to the masking vector. - let mask_witness = self.commit.commit(prover_state, &[&mask]); + let mask_witness = self.commit.commit(prover_state, &[&mask_buffer]); // Compute and send linear form of mask (μ' in paper). let mask_sum = dot(&mask, &covector); @@ -105,11 +116,15 @@ impl Config { // RLC the mask with the vector let mask_rlc = prover_state.verifier_message::(); assert!(!mask_rlc.is_zero(), "Proof failed"); - let mut masked_vector = scalar_mul_add_new(&mask, mask_rlc, &vector); + let masked_vector = scalar_mul_add_new(&mask, mask_rlc, &vector); prover_state.prover_messages(&masked_vector); // Send combined IRS randomness. (r^* in paper) - let masked_masks = scalar_mul_add_new(&mask_witness.masks, mask_rlc, &witness.masks); + let masked_masks = scalar_mul_add_new( + mask_witness.masks.as_slice(), + mask_rlc, + witness.masks.as_slice(), + ); prover_state.prover_messages(&masked_masks); // Open the commitment and mask simultaneously. @@ -117,12 +132,14 @@ impl Config { // Run sumcheck to reduce linear form claim let mut masked_sum = mask_sum + mask_rlc * sum; + let mut masked_vector_buffer = CpuBuffer::from_vec(masked_vector); + let mut covector_buffer = CpuBuffer::from_vec(covector); let point = self .sumcheck .prove( prover_state, - &mut masked_vector, - &mut covector, + &mut masked_vector_buffer, + &mut covector_buffer, &mut masked_sum, &[], ) @@ -132,11 +149,14 @@ impl Config { // Basically the sumcheck equation has degenerated to 0 * l(r) = 0, which provides // no constraints on l(r) that the verifier can return. // This event is cryptographically unlikely as `F` is challenge sized. - assert!(!masked_vector[0].is_zero(), "Proof failed"); + assert!( + !masked_vector_buffer.as_slice()[0].is_zero(), + "Proof failed" + ); Opening { evaluation_points: point, - linear_form_evaluation: covector[0], + linear_form_evaluation: covector_buffer.as_slice()[0], } } @@ -284,7 +304,8 @@ mod tests { // Prover let mut prover_state = ProverState::new_std(&ds); - let witness = config.commit.commit(&mut prover_state, &[&vector]); + let vector_buffer = CpuBuffer::from_slice(&vector); + let witness = config.commit.commit(&mut prover_state, &[&vector_buffer]); let prover_result = config.prove( &mut prover_state, vector.clone(), diff --git a/src/protocols/code_switch.rs b/src/protocols/code_switch.rs index 62bbb602..aa2a6370 100644 --- a/src/protocols/code_switch.rs +++ b/src/protocols/code_switch.rs @@ -1,3 +1,5 @@ +// IGNORE CHANGES TO THIS FILE - NOT FULLY PORTED TO PROPERLY USE BUFFER ABSTRACTION. + //! Code-switching IOR: R_{C, C_zk, sl} → R_{C', C_zk, sl'} //! //! Reduces a proximity claim about oracle f (source code C) to a proximity @@ -13,6 +15,7 @@ use tracing::instrument; use crate::{ algebra::{ + buffer::CpuBuffer, dot, embedding::{Embedding, Identity}, eq_weights, geometric_accumulate, lift, mixed_dot, scalar_mul, univariate_evaluate, @@ -195,7 +198,8 @@ impl Config { }; // Step 1: g := Enc_{C'}(f, r') — Construction 9.7 Step 1, p.55 - let target_witness = self.target.commit(prover_state, &[&message]); + let message_buffer = CpuBuffer::from_slice(&message); + let target_witness = self.target.commit(prover_state, &[&message_buffer]); // Step 2-3: OOD challenge + answers — Construction 9.7 Steps 2-3, p.55 // y := ze_ood(ρ) · [f; r; s] = f(α) + α^ℓ · (r,s)(α) @@ -484,7 +488,7 @@ mod tests { } // Lift ι parallel masks (total length source.mask_length × ι) and fold // chunks of length source.mask_length down to a single chunk. - let raw = lift(config.source.embedding(), &source_witness.masks); + let raw = lift(config.source.embedding(), source_witness.masks.as_slice()); let mut mask = fold_chunks(&raw, config.source.mask_length, folding_randomness); // Append fresh padding s of length message_mask_length - source.mask_length. mask.extend(random_vector::( @@ -521,7 +525,8 @@ mod tests { .session(&format!("Test at {}:{}", file!(), line!())) .instance(&instance); let mut prover_state = ProverState::new_std(&ds); - let source_witness = config.source.commit(&mut prover_state, &[&f_full]); + let f_full_buffer = CpuBuffer::from_slice(&f_full); + let source_witness = config.source.commit(&mut prover_state, &[&f_full_buffer]); // Sample γ for sumcheck folding (length log2(ι)). let folding_randomness = sample_folding_randomness(config, &mut rng); @@ -574,7 +579,8 @@ mod tests { .session(&format!("Test at {}:{}", file!(), line!())) .instance(&instance); let mut prover_state = ProverState::new_std(&ds); - let source_witness = config.source.commit(&mut prover_state, &[&f_full]); + let f_full_buffer = CpuBuffer::from_slice(&f_full); + let source_witness = config.source.commit(&mut prover_state, &[&f_full_buffer]); let folding_randomness = sample_folding_randomness(config, &mut rng); let folded_message = @@ -642,7 +648,8 @@ mod tests { // Commit honest f_full, fold to get the honest post-fold message. let mut prover_state = ProverState::new_std(&ds); - let source_witness = config.source.commit(&mut prover_state, &[&f_full]); + let f_full_buffer = CpuBuffer::from_slice(&f_full); + let source_witness = config.source.commit(&mut prover_state, &[&f_full_buffer]); let folding_randomness = sample_folding_randomness(config, &mut rng); let folded_message = fold_chunks(&f_full, config.source.message_length(), &folding_randomness); diff --git a/src/protocols/irs_commit.rs b/src/protocols/irs_commit.rs index b5716f3a..1d9e20a8 100644 --- a/src/protocols/irs_commit.rs +++ b/src/protocols/irs_commit.rs @@ -20,6 +20,7 @@ use std::{ f64::{self, consts::LOG2_10}, fmt, + marker::PhantomData, ops::Neg, }; @@ -32,8 +33,13 @@ use tracing::instrument; use crate::{ algebra::{ - dot, embedding::Embedding, fields::FieldWithSize, lift, linear_form::UnivariateEvaluation, - mixed_univariate_evaluate, ntt, random_vector, + buffer::{BufferOps, CpuBuffer, CpuMatrix, MatrixBufferOps}, + dot, + embedding::Embedding, + fields::FieldWithSize, + lift, + linear_form::UnivariateEvaluation, + ntt, }, engines::EngineId, hash::Hash, @@ -43,7 +49,7 @@ use crate::{ VerifierMessage, VerifierState, }, type_info::Typed, - utils::{chunks_exact_or_empty, zip_strict}, + utils::zip_strict, verify, }; @@ -92,14 +98,17 @@ pub struct Config { #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Default, Serialize, Deserialize)] #[must_use] -pub struct Witness +pub struct Witness, Matrix = CpuMatrix> where G: Field, + Masks: BufferOps, + Matrix: MatrixBufferOps, { - pub masks: Vec, - pub matrix: Vec, - pub matrix_witness: matrix_commit::Witness, + pub masks: Masks, + pub matrix: Matrix, + pub matrix_witness: Matrix::Witness, pub out_of_domain: Evaluations, + source_field: PhantomData, } #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Default, Serialize, Deserialize)] @@ -289,12 +298,15 @@ impl Config { /// Commit to one or more vectors. #[cfg_attr(feature = "tracing", instrument(skip_all, fields(self = %self)))] - pub fn commit( + pub fn commit( &self, prover_state: &mut ProverState, - vectors: &[&[M::Source]], - ) -> Witness + vectors: &[&B], + ) -> Witness where + B: BufferOps, + >::Witness: + Clone + PartialEq + Eq + PartialOrd + Ord + fmt::Debug + std::hash::Hash + Default, Standard: Distribution, H: DuplexSpongeInterface, R: RngCore + CryptoRng, @@ -310,18 +322,17 @@ impl Config { assert_eq!(vectors.len(), self.num_vectors); assert!(vectors.iter().all(|p| p.len() == self.vector_size)); - // Generate random mask - let masks = random_vector(prover_state.rng(), self.mask_length * self.num_messages()); - - // Interleaved RS Encode the vectors - let messages = vectors - .iter() - .flat_map(|v| chunks_exact_or_empty(v, self.message_length(), self.interleaving_depth)) - .collect::>(); - let matrix = ntt::interleaved_rs_encode(&messages, &masks, self.codeword_length); + let masks = B::random(prover_state.rng(), self.mask_length * self.num_messages()); + let matrix = B::interleaved_rs_encode( + vectors, + &masks, + self.message_length(), + self.interleaving_depth, + self.codeword_length, + ); // Commit to the matrix - let matrix_witness = self.matrix_commit.commit(prover_state, &matrix); + let matrix_witness = matrix.commit_rows(&self.matrix_commit, prover_state); // Handle out-of-domain points and values // TODO : Remove this logic after main whir protocol is updated @@ -332,7 +343,7 @@ impl Config { let mut oods_matrix = Vec::with_capacity(self.out_domain_samples * self.num_vectors); for &point in &oods_points { for &vector in vectors { - let value = mixed_univariate_evaluate(&*self.embedding, vector, point); + let value = vector.mixed_univariate_evaluate(&*self.embedding, point); prover_state.prover_message(&value); oods_matrix.push(value); } @@ -346,6 +357,7 @@ impl Config { points: oods_points, matrix: oods_matrix, }, + source_field: PhantomData, } } @@ -382,12 +394,14 @@ impl Config { /// When there are multiple openings, the evaluation matrices will /// be horizontally concatenated. #[cfg_attr(feature = "tracing", instrument(skip_all, fields(self = %self)))] - pub fn open( + pub fn open( &self, prover_state: &mut ProverState, - witnesses: &[&Witness], + witnesses: &[&Witness], ) -> Evaluations where + Masks: BufferOps, + Matrix: MatrixBufferOps, H: DuplexSpongeInterface, R: RngCore + CryptoRng, u8: Decoding<[H::U]>, @@ -409,22 +423,23 @@ impl Config { // and collect them in the evaluation matrix. let stride = witnesses.len() * self.num_cols(); let mut matrix = vec![M::Source::ZERO; indices.len() * stride]; - let mut submatrix = Vec::with_capacity(indices.len() * self.num_cols()); let mut matrix_col_offset = 0; for witness in witnesses { - submatrix.clear(); - for (point_index, &code_index) in indices.iter().enumerate() { - let row = &witness.matrix - [code_index * self.num_cols()..(code_index + 1) * self.num_cols()]; - submatrix.extend_from_slice(row); - - let matrix_row = &mut matrix[point_index * stride..(point_index + 1) * stride]; - matrix_row[matrix_col_offset..matrix_col_offset + self.num_cols()] - .copy_from_slice(row); + let submatrix = witness.matrix.read_rows(&indices); + if self.num_cols() != 0 { + for (point_index, row) in submatrix.chunks_exact(self.num_cols()).enumerate() { + let matrix_row = &mut matrix[point_index * stride..(point_index + 1) * stride]; + matrix_row[matrix_col_offset..matrix_col_offset + self.num_cols()] + .copy_from_slice(row); + } } prover_state.prover_hint_ark(&submatrix); - self.matrix_commit - .open(prover_state, &witness.matrix_witness, &indices); + witness.matrix.open_rows( + &self.matrix_commit, + prover_state, + &witness.matrix_witness, + &indices, + ); matrix_col_offset += self.num_cols(); } @@ -511,7 +526,13 @@ impl Commitment { } } -impl Witness { +impl Witness +where + F: Field, + G: Field, + Masks: BufferOps, + Matrix: MatrixBufferOps, +{ /// Returns the out-of-domain evaluations. pub const fn out_of_domain(&self) -> &Evaluations { &self.out_of_domain @@ -642,7 +663,7 @@ pub(crate) mod tests { use crate::{ algebra::{ embedding::{Basefield, Compose, Frobenius, Identity}, - fields, + fields, mixed_univariate_evaluate, ntt::NTT, random_vector, univariate_evaluate, }, @@ -728,10 +749,12 @@ pub(crate) mod tests { // Prover let mut prover_state = ProverState::new_std(&ds); - let witness = config.commit( - &mut prover_state, - &vectors.iter().map(|p| p.as_slice()).collect::>(), - ); + let vector_buffers = vectors + .iter() + .map(|v| CpuBuffer::from_slice(v)) + .collect::>(); + let vector_refs = vector_buffers.iter().collect::>(); + let witness = config.commit(&mut prover_state, &vector_refs); assert_eq!( witness.out_of_domain().points.len(), config.out_domain_samples diff --git a/src/protocols/mask_proximity.rs b/src/protocols/mask_proximity.rs index 1d632ca0..7413158c 100644 --- a/src/protocols/mask_proximity.rs +++ b/src/protocols/mask_proximity.rs @@ -1,3 +1,5 @@ +// IGNORE CHANGES TO THIS FILE - NOT FULLY PORTED TO PROPERLY USE BUFFER ABSTRACTION. + //! Mask proximity verification via γ-combination. //! //! Implements Construction 7.2 (p.43-44) specialized for zero-constraint mask @@ -45,7 +47,10 @@ use ark_std::rand::{distributions::Standard, prelude::Distribution, CryptoRng, R use serde::{Deserialize, Serialize}; use crate::{ - algebra::{embedding::Identity, random_vector, scalar_mul_add_new, univariate_evaluate}, + algebra::{ + buffer::CpuBuffer, embedding::Identity, random_vector, scalar_mul_add_new, + univariate_evaluate, + }, hash::Hash, protocols::irs_commit::{ Commitment as IrsCommitment, Config as IrsConfig, Witness as IrsWitness, @@ -129,11 +134,19 @@ impl Config { .map(|_| random_vector(prover_state.rng(), self.c_zk_commit.vector_size)) .collect(); + let original_buffers = original_msgs + .iter() + .map(|msg| CpuBuffer::from_slice(msg)) + .collect::>(); + let fresh_buffers = fresh_msgs + .iter() + .map(|msg| CpuBuffer::from_slice(msg)) + .collect::>(); + // Tree layout: [originals..., freshes...] - let all_vectors: Vec<&[F]> = original_msgs + let all_vectors: Vec<&CpuBuffer> = original_buffers .iter() - .chain(fresh_msgs.iter()) - .map(|v| v.as_slice()) + .chain(fresh_buffers.iter()) .collect(); let mask_witness = self.c_zk_commit.commit(prover_state, &all_vectors); @@ -180,10 +193,8 @@ impl Config { // Step 2: compute and send combined polynomials + IRS randomness let irs_masks_per_vector = self.c_zk_commit.mask_length * self.c_zk_commit.interleaving_depth; - assert_eq!( - witness.mask_witness.masks.len(), - 2 * self.num_masks * irs_masks_per_vector - ); + let irs_masks = witness.mask_witness.masks.as_slice(); + assert_eq!(irs_masks.len(), 2 * self.num_masks * irs_masks_per_vector); for (i, (orig_msg, fresh_msg)) in original_msgs .iter() .zip(witness.fresh_msgs.iter()) @@ -195,10 +206,8 @@ impl Config { // r*_i = r'_i + γ · r_i if irs_masks_per_vector > 0 { - let orig_r = &witness.mask_witness.masks - [i * irs_masks_per_vector..(i + 1) * irs_masks_per_vector]; - let fresh_r = &witness.mask_witness.masks[(self.num_masks + i) - * irs_masks_per_vector + let orig_r = &irs_masks[i * irs_masks_per_vector..(i + 1) * irs_masks_per_vector]; + let fresh_r = &irs_masks[(self.num_masks + i) * irs_masks_per_vector ..(self.num_masks + i + 1) * irs_masks_per_vector]; let combined_r = scalar_mul_add_new(fresh_r, gamma, orig_r); prover_state.prover_messages(&combined_r); @@ -455,6 +464,7 @@ mod tests { let gamma: F = prover_state.verifier_message(); let irs_masks_per_vector = config.c_zk_commit.mask_length * config.c_zk_commit.interleaving_depth; + let irs_masks = witness.mask_witness.masks.as_slice(); for (i, (orig_msg, fresh_msg)) in original_msgs .iter() @@ -468,10 +478,8 @@ mod tests { prover_state.prover_messages(&combined_msg); if irs_masks_per_vector > 0 { - let orig_r = &witness.mask_witness.masks - [i * irs_masks_per_vector..(i + 1) * irs_masks_per_vector]; - let fresh_r = &witness.mask_witness.masks[(config.num_masks + i) - * irs_masks_per_vector + let orig_r = &irs_masks[i * irs_masks_per_vector..(i + 1) * irs_masks_per_vector]; + let fresh_r = &irs_masks[(config.num_masks + i) * irs_masks_per_vector ..(config.num_masks + i + 1) * irs_masks_per_vector]; let combined_r = scalar_mul_add_new(fresh_r, gamma, orig_r); prover_state.prover_messages(&combined_r); diff --git a/src/protocols/sumcheck.rs b/src/protocols/sumcheck.rs index 5018e22d..13c0c255 100644 --- a/src/protocols/sumcheck.rs +++ b/src/protocols/sumcheck.rs @@ -9,11 +9,7 @@ use serde::{Deserialize, Serialize}; use tracing::instrument; use crate::{ - algebra::{ - dot, - sumcheck::{compute_sumcheck_polynomial, fold, fold_and_compute_polynomial}, - univariate_evaluate, - }, + algebra::{buffer::BufferOps, univariate_evaluate}, protocols::proof_of_work, transcript::{ codecs::U64, Codec, Decoding, DuplexSpongeInterface, ProverState, VerificationResult, @@ -66,15 +62,16 @@ impl Config { /// - Applies proof-of-work grinding if required. /// - Returns the sampled folding randomness values used in each reduction step. #[cfg_attr(feature = "tracing", instrument(skip_all))] - pub fn prove( + pub fn prove( &self, prover_state: &mut ProverState, - a: &mut Vec, - b: &mut Vec, + a: &mut B, + b: &mut B, sum: &mut F, masks: &[F], ) -> SumcheckOpening where + B: BufferOps, H: DuplexSpongeInterface, R: CryptoRng + RngCore, F: Codec<[H::U]>, @@ -88,7 +85,7 @@ impl Config { assert!(self.mask_length == 0 || self.mask_length >= 3); assert_eq!(a.len(), self.initial_size); assert_eq!(b.len(), self.initial_size); - debug_assert_eq!(dot(a, b), *sum); + debug_assert_eq!(a.dot(b), *sum); assert_eq!(masks.len(), self.num_rounds * self.mask_length); let half = F::from(2).inverse().unwrap(); @@ -110,9 +107,9 @@ impl Config { { // Fold and compute sumcheck polynomial in one pass. let (c0, c2) = if let Some(w) = prev_round_challenge { - fold_and_compute_polynomial(a, b, w) + a.fold_and_sumcheck_polynomial(b, w) } else { - compute_sumcheck_polynomial(a, b) + a.sumcheck_polynomial(b) }; let c1 = *sum - c0.double() - c2; @@ -150,8 +147,8 @@ impl Config { } if let Some(w) = prev_round_challenge { // Final fold of the inputs (no polynomial computation) - fold(a, w); - fold(b, w); + a.fold(w); + b.fold(w); } *sum = mask_sum + mask_rlc * *sum; @@ -238,23 +235,24 @@ fn eval_01(coefficients: &[F]) -> F { #[cfg(test)] mod tests { - use ark_std::rand::{ - distributions::{Distribution, Standard}, - rngs::StdRng, - SeedableRng, - }; - use proptest::{prelude::Just, prop_oneof, proptest, strategy::Strategy}; - #[cfg(feature = "tracing")] - use tracing::instrument; - use super::*; use crate::{ algebra::{ + buffer::CpuBuffer, + dot, fields::{self, Field64}, multilinear_extend, random_vector, }, transcript::DomainSeparator, }; + use ark_std::rand::{ + distributions::{Distribution, Standard}, + rngs::StdRng, + SeedableRng, + }; + use proptest::{prelude::Just, prop_oneof, proptest, strategy::Strategy}; + #[cfg(feature = "tracing")] + use tracing::instrument; impl Config where @@ -299,8 +297,8 @@ mod tests { let masks = random_vector(&mut rng, config.mask_length * config.num_rounds); // Prover - let mut vector = initial_vector.clone(); - let mut covector = initial_covector.clone(); + let mut vector = CpuBuffer::from_slice(&initial_vector); + let mut covector = CpuBuffer::from_slice(&initial_covector); let mut sum = initial_sum; let mut prover_state = ProverState::new_std(&ds); let SumcheckOpening { @@ -316,8 +314,14 @@ mod tests { assert_eq!(vector.len(), config.final_size()); assert_eq!(covector.len(), config.final_size()); if config.final_size() == 1 { - assert_eq!(multilinear_extend(&initial_vector, &point), vector[0]); - assert_eq!(multilinear_extend(&initial_covector, &point), covector[0]); + assert_eq!( + multilinear_extend(&initial_vector, &point), + vector.as_slice()[0] + ); + assert_eq!( + multilinear_extend(&initial_covector, &point), + covector.as_slice()[0] + ); } else { // TODO: Check correct folding. } @@ -327,7 +331,10 @@ mod tests { .zip(&point) .map(|(m, x)| univariate_evaluate(m, *x)) .sum(); - assert_eq!(sum, expected_mask_sum + mask_rlc * dot(&vector, &covector)); + assert_eq!( + sum, + expected_mask_sum + mask_rlc * dot(vector.as_slice(), covector.as_slice()) + ); let proof = prover_state.proof(); diff --git a/src/protocols/whir/mod.rs b/src/protocols/whir/mod.rs index 805a206e..0f20a14b 100644 --- a/src/protocols/whir/mod.rs +++ b/src/protocols/whir/mod.rs @@ -14,6 +14,7 @@ use tracing::instrument; use crate::{ algebra::{ + buffer::{BufferOps, CpuBuffer}, embedding::{Embedding, Identity}, linear_form::LinearForm, }, @@ -45,7 +46,13 @@ pub struct RoundConfig { pub pow: proof_of_work::Config, } -pub type Witness> = irs_commit::Witness; +pub type Witness, B = CpuBuffer<::Source>> = + irs_commit::Witness< + ::Source, + F, + B, + ::Source>>::Matrix, + >; pub type Commitment = irs_commit::Commitment; #[must_use = "The final claim must be checked if there where any linear forms."] @@ -75,13 +82,19 @@ impl FinalClaim { impl Config { /// Commit to one or more vectors. - #[cfg_attr(feature = "tracing", instrument(skip_all, fields(size = vectors.first().unwrap().len())))] - pub fn commit( + #[cfg_attr( + feature = "tracing", + instrument(skip_all, fields(size = vectors.first().unwrap().len())) + )] + pub fn commit( &self, prover_state: &mut ProverState, - vectors: &[&[M::Source]], - ) -> Witness + vectors: &[&B], + ) -> Witness where + B: BufferOps, + >::Witness: + Clone + PartialEq + Eq + PartialOrd + Ord + Debug + std::hash::Hash + Default, Standard: Distribution, H: DuplexSpongeInterface, R: RngCore + CryptoRng, @@ -128,6 +141,7 @@ mod tests { use super::*; use crate::{ algebra::{ + buffer::CpuBuffer, embedding::Basefield, fields::{Field64, Field64_3}, linear_form::{Covector, Evaluate, LinearForm, MultilinearExtension}, @@ -239,15 +253,16 @@ mod tests { let mut prover_state = ProverState::new_std(&ds); // Commit to the polynomial and generate auxiliary witness data - let witness = params.commit(&mut prover_state, &[&vector]); + let vector_buffer = CpuBuffer::from_slice(&vector); + let witness = params.commit(&mut prover_state, &[&vector_buffer]); let prove_linear_forms = build_prove_forms(&points, num_variables, true); // Generate a proof for the given statement and witness let _ = params.prove( &mut prover_state, - vec![Cow::from(vector)], - vec![Cow::Owned(witness)], + &[&vector_buffer], + vec![&witness], prove_linear_forms, Cow::Borrowed(evaluations.as_slice()), ); @@ -380,7 +395,7 @@ mod tests { vec![F::from((i + 1) as u64); num_coeffs] }) .collect(); - let vec_refs = vectors.iter().map(|v| v.as_slice()).collect::>(); + let vec_refs = vectors.iter().collect::>(); let points: Vec<_> = (0..num_points_per_poly) .map(|_| random_vector(thread_rng(), num_variables)) @@ -401,7 +416,7 @@ mod tests { .flat_map(|linear_form| { vec_refs .iter() - .map(|vec| linear_form.evaluate(params.embedding(), vec)) + .map(|&vec| linear_form.evaluate(params.embedding(), vec)) }) .collect::>(); @@ -412,8 +427,12 @@ mod tests { let mut prover_state = ProverState::new_std(&ds); // Commit to each polynomial and generate witnesses + let vector_buffers = vectors + .iter() + .map(|v| CpuBuffer::from_slice(v)) + .collect::>(); let mut witnesses = Vec::new(); - for &vec in &vec_refs { + for vec in &vector_buffers { let witness = params.commit(&mut prover_state, &[vec]); witnesses.push(witness); } @@ -423,11 +442,8 @@ mod tests { // Batch prove all polynomials together let _ = params.prove( &mut prover_state, - vectors - .iter() - .map(|v| Cow::Borrowed(v.as_slice())) - .collect(), - witnesses.into_iter().map(Cow::Owned).collect(), + &vector_buffers.iter().collect::>(), + witnesses.iter().collect(), prove_linear_forms, Cow::Borrowed(evaluations.as_slice()), ); @@ -569,16 +585,19 @@ mod tests { .instance(&Empty); let mut prover_state = ProverState::new_std(&ds); - let witness1 = params.commit(&mut prover_state, &[&vec1]); - let witness2 = params.commit(&mut prover_state, &[&vec2]); + let vec1_buffer = CpuBuffer::from_slice(&vec1); + let vec2_buffer = CpuBuffer::from_slice(&vec2); + let witness1 = params.commit(&mut prover_state, &[&vec1_buffer]); + let witness2 = params.commit(&mut prover_state, &[&vec2_buffer]); let prove_linear_forms = build_prove_forms(&constraint_points, num_variables, false); // Generate proof with mismatched polynomials + let vec_wrong_buffer = CpuBuffer::from_vec(vec_wrong); let _ = params.prove( &mut prover_state, - vec![Cow::Borrowed(vec1.as_slice()), Cow::from(vec_wrong)], - vec![Cow::Owned(witness1), Cow::Owned(witness2)], + &[&vec1_buffer, &vec_wrong_buffer], + vec![&witness1, &witness2], prove_linear_forms, Cow::Borrowed(evaluations.as_slice()), ); @@ -651,7 +670,7 @@ mod tests { let all_vectors: Vec> = (0..num_witnesses * batch_size) .map(|i| vec![F::from((i + 1) as u64); num_coeffs]) .collect::>(); - let vec_refs = all_vectors.iter().map(|p| p.as_slice()).collect::>(); + let vec_refs = all_vectors.iter().collect::>(); let points: Vec<_> = (0..num_points_per_poly) .map(|_| random_vector(thread_rng(), num_variables)) @@ -672,7 +691,7 @@ mod tests { .flat_map(|linear_form| { vec_refs .iter() - .map(|vec| linear_form.evaluate(params.embedding(), vec)) + .map(|&vec| linear_form.evaluate(params.embedding(), vec)) }) .collect::>(); @@ -683,8 +702,13 @@ mod tests { let mut prover_state = ProverState::new_std(&ds); // Commit using commit_batch (stacks batch_size polynomials per witness) + let vector_buffers = all_vectors + .iter() + .map(|v| CpuBuffer::from_slice(v)) + .collect::>(); + let buffer_refs = vector_buffers.iter().collect::>(); let mut witnesses = Vec::new(); - for witness_polys in vec_refs.chunks(batch_size) { + for witness_polys in buffer_refs.chunks(batch_size) { let witness = params.commit(&mut prover_state, witness_polys); witnesses.push(witness); } @@ -694,11 +718,8 @@ mod tests { // Batch prove all witnesses together let _ = params.prove( &mut prover_state, - all_vectors - .iter() - .map(|v| Cow::Borrowed(v.as_slice())) - .collect(), - witnesses.into_iter().map(Cow::Owned).collect(), + &vector_buffers.iter().collect::>(), + witnesses.iter().collect(), prove_linear_forms, Cow::Borrowed(evaluations.as_slice()), ); @@ -793,7 +814,7 @@ mod tests { let vectors: Vec> = (0..batch_size) .map(|_| random_vector(thread_rng(), num_coeffs)) .collect(); - let vec_refs = vectors.iter().map(|v| v.as_slice()).collect::>(); + let vec_refs = vectors.iter().collect::>(); // Generate `num_points` random points in the multilinear domain let points: Vec<_> = (0..num_points) @@ -809,7 +830,12 @@ mod tests { let mut prover_state = ProverState::new_std(&ds); // Create a commitment to the polynomial and generate auxiliary witness data - let batched_witness = params.commit(&mut prover_state, &vec_refs); + let vector_buffers = vectors + .iter() + .map(|v| CpuBuffer::from_slice(v)) + .collect::>(); + let buffer_refs = vector_buffers.iter().collect::>(); + let batched_witness = params.commit(&mut prover_state, &buffer_refs); // Create a weights matrix and evaluations for each polynomial let mut linear_forms: Vec>>> = Vec::new(); @@ -839,11 +865,8 @@ mod tests { .collect::>(); let _ = params.prove( &mut prover_state, - vectors - .iter() - .map(|v| Cow::Borrowed(v.as_slice())) - .collect(), - vec![Cow::Owned(batched_witness)], + &buffer_refs, + vec![&batched_witness], prove_linear_forms, Cow::Borrowed(values.as_slice()), ); diff --git a/src/protocols/whir/prover.rs b/src/protocols/whir/prover.rs index 04b4473e..01c87f3c 100644 --- a/src/protocols/whir/prover.rs +++ b/src/protocols/whir/prover.rs @@ -1,4 +1,4 @@ -use std::{any::Any, borrow::Cow, mem}; +use std::{borrow::Cow, fmt::Debug}; use ark_ff::{AdditiveGroup, Field}; use ark_std::rand::{distributions::Standard, prelude::Distribution, CryptoRng, RngCore}; @@ -8,12 +8,11 @@ use tracing::instrument; use super::{Config, Witness}; use crate::{ algebra::{ + buffer::{BufferOps, MatrixBufferOps}, dot, embedding::Embedding, - eq_weights, lift, - linear_form::{Covector, Evaluate, LinearForm, UnivariateEvaluation}, - mixed_scalar_mul_add, - sumcheck::fold, + eq_weights, + linear_form::LinearForm, tensor_product, }, hash::Hash, @@ -25,9 +24,15 @@ use crate::{ utils::zip_strict, }; -enum RoundWitness<'a, F: Field, M: Embedding> { - Initial(Vec>>), - Round(irs_commit::Witness), +enum RoundWitness<'a, F, M, B> +where + F: Field, + M: Embedding, + B: BufferOps, + B::Buffer: BufferOps, +{ + Initial(Vec<&'a irs_commit::Witness>), + Round(irs_commit::Witness, as BufferOps>::Matrix>), } impl Config { @@ -47,15 +52,29 @@ impl Config { /// #[cfg_attr(feature = "tracing", instrument(skip_all))] #[allow(clippy::too_many_lines, clippy::cognitive_complexity)] - pub fn prove<'a, H, R>( + pub fn prove<'a, H, R, B>( &self, prover_state: &mut ProverState, - vectors: Vec>, - witnesses: Vec>>, + vectors: &[&B], + witnesses: Vec<&'a Witness>, linear_forms: Vec>>, evaluations: Cow<'a, [M::Target]>, ) -> FinalClaim where + B: BufferOps, + B::Buffer: BufferOps, + >::Witness: + Clone + PartialEq + Eq + PartialOrd + Ord + Debug + std::hash::Hash + Default, + < as BufferOps>::Matrix as MatrixBufferOps< + M::Target, + >>::Witness: Clone + + PartialEq + + Eq + + PartialOrd + + Ord + + Debug + + std::hash::Hash + + Default, Standard: Distribution + Distribution, H: DuplexSpongeInterface, R: RngCore + CryptoRng, @@ -73,7 +92,7 @@ impl Config { witnesses.len() * self.initial_committer.num_vectors ); assert_eq!(evaluations.len(), num_vectors * linear_forms.len()); - for vector in &vectors { + for vector in vectors { assert_eq!(vector.len(), self.initial_size()); } for linear_form in &linear_forms { @@ -85,8 +104,11 @@ impl Config { { use crate::algebra::linear_form::Covector; let covector = Covector::from(&**linear_form); - for (vector, evaluation) in zip_strict(&vectors, evaluations) { - debug_assert_eq!(covector.evaluate(self.embedding(), vector), *evaluation); + for (vector, evaluation) in zip_strict(vectors, evaluations) { + debug_assert_eq!( + vector.mixed_dot_slice(self.embedding(), &covector.vector), + *evaluation + ); } } if vectors.is_empty() { @@ -110,12 +132,13 @@ impl Config { if j >= vector_offset && j < oods_row.len() + vector_offset { debug_assert_eq!( oods_row[j - vector_offset], - oods_eval.evaluate(self.embedding(), vector) + vector.mixed_univariate_evaluate(self.embedding(), oods_eval.point) ); oods_matrix.push(oods_row[j - vector_offset]); } else { - let eval = oods_eval.evaluate(self.embedding(), vector); + let eval = + vector.mixed_univariate_evaluate(self.embedding(), oods_eval.point); prover_state.prover_message(&eval); oods_matrix.push(eval); } @@ -130,18 +153,9 @@ impl Config { // Random linear combination of the vectors. let mut vector_rlc_coeffs: Vec = geometric_challenge(prover_state, num_vectors); assert_eq!(vector_rlc_coeffs[0], M::Target::ONE); - // Recycle the first input as the accumulator (its coefficient is always ONE). - let mut vectors = vectors.into_iter(); - let first = vectors.next().expect("non-empty"); - let mut vector = match first { - Cow::Borrowed(slice) => lift(self.embedding(), slice), - Cow::Owned(vec) => self.embedding().map_vec(vec), - }; - for (rlc_coeff, input_vector) in zip_strict(&vector_rlc_coeffs[1..], vectors) { - mixed_scalar_mul_add(self.embedding(), &mut vector, *rlc_coeff, &input_vector); - } + let mut vector = B::mixed_linear_combination(self.embedding(), vectors, &vector_rlc_coeffs); - let mut prev_witness: RoundWitness<'a, M::Target, M> = RoundWitness::Initial(witnesses); + let mut prev_witness: RoundWitness<'a, M::Target, M, B> = RoundWitness::Initial(witnesses); // Random linear combination of the constraints. let constraint_rlc_coeffs: Vec = @@ -149,26 +163,16 @@ impl Config { let has_constraints = !constraint_rlc_coeffs.is_empty(); let (initial_forms_rlc_coeffs, oods_rlc_coeffs) = constraint_rlc_coeffs.split_at(linear_forms.len()); - // Try to recycle the first linear form as Covector. - let mut covector = vec![]; let mut linear_forms = linear_forms; - if let Some((first, linear_forms)) = linear_forms.split_first_mut() { - debug_assert_eq!(initial_forms_rlc_coeffs[0], M::Target::ONE); - if let Some(covector_form) = - (first.as_mut() as &mut dyn Any).downcast_mut::>() - { - mem::swap(&mut covector, &mut covector_form.vector); - } else { - covector.resize(self.initial_size(), M::Target::ZERO); - first.accumulate(&mut covector, M::Target::ONE); - } - for (rlc_coeff, linear_form) in zip_strict(&initial_forms_rlc_coeffs[1..], linear_forms) - { - linear_form.accumulate(&mut covector, *rlc_coeff); - } - } else if has_constraints { - covector.resize(self.initial_size(), M::Target::ZERO); - } + let mut covector = if has_constraints { + as BufferOps>::linear_forms_rlc( + self.initial_size(), + &mut linear_forms, + initial_forms_rlc_coeffs, + ) + } else { + as BufferOps>::zeros(0) + }; drop(linear_forms); // Compute "The Sum" @@ -180,17 +184,17 @@ impl Config { .sum(); drop(evaluations); - debug_assert!(!has_constraints || dot(&vector, &covector) == the_sum); + debug_assert!(!has_constraints || vector.dot(&covector) == the_sum); // Add OODS constraints - UnivariateEvaluation::accumulate_many(&oods_evals, &mut covector, oods_rlc_coeffs); + covector.accumulate_univariate_evaluations(&oods_evals, oods_rlc_coeffs); the_sum += zip_strict(oods_rlc_coeffs, oods_matrix.chunks_exact(num_vectors)) .map(|(poly_coeff, row)| *poly_coeff * dot(&vector_rlc_coeffs, row)) .sum::(); drop(oods_evals); drop(oods_matrix); - debug_assert!(!has_constraints || dot(&vector, &covector) == the_sum); + debug_assert!(!has_constraints || vector.dot(&covector) == the_sum); // Run initial sumcheck on batched vectors with combined statement let mut folding_randomness = if has_constraints { @@ -207,15 +211,17 @@ impl Config { self.initial_skip_pow.prove(prover_state); // Fold vector for &f in &folding_randomness { - fold(&mut vector, f); + vector.fold(f); } // Covector must be all zeros. - covector = vec![M::Target::ZERO; self.initial_sumcheck.final_size()]; + covector = as BufferOps>::zeros( + self.initial_sumcheck.final_size(), + ); folding_randomness }; let mut evaluation_point = folding_randomness.clone(); - debug_assert_eq!(dot(&vector, &covector), the_sum); + debug_assert_eq!(vector.dot(&covector), the_sum); // Execute standard WHIR rounds on the batched vectors for (round_index, round_config) in self.round_configs.iter().enumerate() { @@ -227,12 +233,10 @@ impl Config { // Open the previous round's witness. let in_domain = match prev_witness { - RoundWitness::Initial(init_witnesses) => { - let witness_refs: Vec<&_> = init_witnesses.iter().map(|c| &**c).collect(); - self.initial_committer - .open(prover_state, &witness_refs) - .lift(self.embedding()) - } + RoundWitness::Initial(init_witnesses) => self + .initial_committer + .open(prover_state, &init_witnesses) + .lift(self.embedding()), RoundWitness::Round(old_witness) => { let prev_round_config = &self.round_configs[round_index - 1]; prev_round_config @@ -256,13 +260,9 @@ impl Config { ))) .collect::>(); let stir_rlc_coeffs = geometric_challenge(prover_state, stir_challenges.len()); - UnivariateEvaluation::accumulate_many( - &stir_challenges, - &mut covector, - &stir_rlc_coeffs, - ); + covector.accumulate_univariate_evaluations(&stir_challenges, &stir_rlc_coeffs); the_sum += dot(&stir_rlc_coeffs, &stir_evaluations); - debug_assert_eq!(dot(&vector, &covector), the_sum); + debug_assert_eq!(vector.dot(&covector), the_sum); // Run sumcheck for this round folding_randomness = round_config @@ -271,7 +271,7 @@ impl Config { .round_challenges; evaluation_point.extend(folding_randomness.iter().copied()); - debug_assert_eq!(dot(&vector, &covector), the_sum); + debug_assert_eq!(vector.dot(&covector), the_sum); prev_witness = RoundWitness::Round(new_witness); vector_rlc_coeffs = vec![M::Target::ONE]; @@ -279,9 +279,7 @@ impl Config { // Directly send the vector to the verifier. assert_eq!(vector.len(), self.final_sumcheck.initial_size); - for coeff in &vector { - prover_state.prover_message(coeff); - } + vector.write_to_prover(prover_state); // PoW self.final_pow.prove(prover_state); @@ -289,8 +287,7 @@ impl Config { // Open and consume the final previous witness. match prev_witness { RoundWitness::Initial(init_witnesses) => { - let witness_refs: Vec<&_> = init_witnesses.iter().map(|c| &**c).collect(); - let _in_domain = self.initial_committer.open(prover_state, &witness_refs); + let _in_domain = self.initial_committer.open(prover_state, &init_witnesses); } RoundWitness::Round(old_witness) => { let prev_config = self.round_configs.last().unwrap(); diff --git a/src/protocols/whir_zk/committer.rs b/src/protocols/whir_zk/committer.rs index 942b87e1..dbc875ec 100644 --- a/src/protocols/whir_zk/committer.rs +++ b/src/protocols/whir_zk/committer.rs @@ -1,3 +1,5 @@ +// IGNORE CHANGES TO THIS FILE - NOT FULLY PORTED TO PROPERLY USE BUFFER ABSTRACTION. + use ark_ff::Field; use ark_std::rand::{distributions::Standard, prelude::Distribution}; #[cfg(feature = "tracing")] @@ -5,6 +7,7 @@ use tracing::instrument; use super::{utils::BlindingPolynomials, Config}; use crate::{ + algebra::buffer::CpuBuffer, hash::Hash, protocols::{irs_commit, whir}, transcript::{ @@ -43,7 +46,7 @@ impl Config { pub fn commit( &self, prover_state: &mut ProverState, - polynomials: &[&[F]], + polynomials: &[&CpuBuffer], ) -> Witness where Standard: Distribution, @@ -63,6 +66,7 @@ impl Config { let num_blinding_variables = self.num_blinding_variables(); let num_witness_variables = self.num_witness_variables(); for &poly in polynomials { + let poly = poly.as_slice(); let blinding = BlindingPolynomials::sample( prover_state.rng(), num_blinding_variables, @@ -86,9 +90,10 @@ impl Config { .zip(mask.iter().cycle()) .map(|(&coeff, &m)| coeff + m) .collect::>(); + let f_hat_buffer = CpuBuffer::from_slice(&f_hat_vec); let witness = self .blinded_commitment - .commit(prover_state, &[f_hat_vec.as_slice()]); + .commit(prover_state, &[&f_hat_buffer]); f_hat_vectors.push(f_hat_vec); f_hat_witnesses.push(witness); blinding_polynomials.push(blinding); @@ -110,10 +115,11 @@ impl Config { ); blinding_vectors.extend(layout); } - let blinding_vector_refs = blinding_vectors + let blinding_buffers = blinding_vectors .iter() - .map(Vec::as_slice) + .map(|v| CpuBuffer::from_slice(v)) .collect::>(); + let blinding_vector_refs = blinding_buffers.iter().collect::>(); let blinding_witness = self .blinding_commitment .commit(prover_state, &blinding_vector_refs); diff --git a/src/protocols/whir_zk/mod.rs b/src/protocols/whir_zk/mod.rs index b5e41c6e..493ea32e 100644 --- a/src/protocols/whir_zk/mod.rs +++ b/src/protocols/whir_zk/mod.rs @@ -1,3 +1,5 @@ +// IGNORE CHANGES TO THIS FILE - NOT FULLY PORTED TO PROPERLY USE BUFFER ABSTRACTION. + #![cfg(feature = "rs_in_order")] // TODO: Support permuted. mod committer; mod prover; @@ -249,6 +251,7 @@ mod tests { use super::*; use crate::{ algebra::{ + buffer::CpuBuffer, fields::Field64, linear_form::{Covector, Evaluate, LinearForm, MultilinearExtension}, random_vector, @@ -350,7 +353,12 @@ mod tests { .session(&tag) .instance(&Empty); let mut prover_state = ProverState::new_std(&ds); - let witness = params.commit(&mut prover_state, vectors); + let vector_buffers = vectors + .iter() + .map(|v| CpuBuffer::from_slice(v)) + .collect::>(); + let vector_refs = vector_buffers.iter().collect::>(); + let witness = params.commit(&mut prover_state, &vector_refs); let _ = params.prove( &mut prover_state, vectors @@ -442,7 +450,12 @@ mod tests { .session(&format!("zk-stage1-negative {}:{}", file!(), line!())) .instance(&Empty); let mut prover_state = ProverState::new_std(&ds); - let witness = params.commit(&mut prover_state, &vectors); + let vector_buffers = vectors + .iter() + .map(|v| CpuBuffer::from_slice(v)) + .collect::>(); + let vector_refs = vector_buffers.iter().collect::>(); + let witness = params.commit(&mut prover_state, &vector_refs); let _ = params.prove( &mut prover_state, vectors @@ -496,7 +509,12 @@ mod tests { .session(&format!("zk-stage1-tamper {}:{}", file!(), line!())) .instance(&Empty); let mut prover_state = ProverState::new_std(&ds); - let witness = params.commit(&mut prover_state, &vectors); + let vector_buffers = vectors + .iter() + .map(|v| CpuBuffer::from_slice(v)) + .collect::>(); + let vector_refs = vector_buffers.iter().collect::>(); + let witness = params.commit(&mut prover_state, &vector_refs); let _ = params.prove( &mut prover_state, vectors @@ -557,7 +575,8 @@ mod tests { let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { let mut prover_state = ProverState::new_std(&ds); - let witness = params.commit(&mut prover_state, &[&vector]); + let vector_buffer = CpuBuffer::from_slice(&vector); + let witness = params.commit(&mut prover_state, &[&vector_buffer]); let _ = params.prove( &mut prover_state, vec![Cow::Borrowed(&vector)],