Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 249 additions & 0 deletions src/algebra/ntt/cooley_tukey.rs
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,166 @@ impl<F: Field> NttEngine<F> {
size => self.ntt_recurse(values, roots, size),
}
}

/// Output-pruned NTT (Sorensen-Burrus, radix-2 DIT).
///
/// Computes the size-`size` NTT of `values` (zero-padded to `size` if
/// shorter) and returns the outputs at positions `indices`, in input
/// order. Output `j` equals the full NTT at position `indices[j]`.
///
/// Walks the butterfly DAG backwards from `indices` to mark only the
/// cone of butterflies that contribute to the queried outputs, then
/// runs only the marked butterflies on the forward pass. Cost is
/// `O(size + indices.len() * log(size))` field operations, vs
/// `O(size * log(size))` for a full NTT.
///
/// `size` must be a power of two.
#[allow(dead_code)] // public single-shot entry; batched callers use the plan-based path
pub fn ntt_partial(&self, values: &[F], size: usize, indices: &[usize]) -> Vec<F> {
let plan = PartialNttPlan::new(size, indices);
let mut out = vec![F::ZERO; indices.len()];
self.ntt_partial_with_plan_into(values, &plan, &mut out, 1);
out
}

/// Run a pruned NTT using a precomputed plan and write outputs into
/// `out` at stride `stride` (so `out[j * stride]` holds the result for
/// `plan.indices[j]`). When `stride == 1`, output is contiguous.
///
/// Sharing a single plan across many NTTs with the same `(size, indices)`
/// avoids re-running the O(size · log size) mask construction per call.
#[allow(clippy::significant_drop_tightening)] // roots guard intentionally held across DIT stages
pub fn ntt_partial_with_plan_into(
&self,
values: &[F],
plan: &PartialNttPlan,
out: &mut [F],
stride: usize,
) {
let size = plan.size;
let indices = &plan.indices;
assert!(values.len() <= size, "input longer than NTT size");
if indices.is_empty() {
return;
}
assert!(
out.len() > (indices.len() - 1) * stride,
"output buffer too small for stride"
);
if size == 1 {
let v = values.first().copied().unwrap_or(F::ZERO);
for j in 0..indices.len() {
out[j * stride] = v;
}
return;
}

let log_n = size.trailing_zeros() as usize;
let roots = self.roots_table(size);

// Load bit-reversed input into work buffer, gated by mask[0].
let mut work = vec![F::ZERO; size];
let shift = (usize::BITS as usize) - log_n;
for (j, &c) in values.iter().enumerate() {
let rev = j.reverse_bits() >> shift;
if plan.mask[0][rev] {
work[rev] = c;
}
}

// Forward DIT, skipping butterflies with no needed outputs.
// The shared roots table may hold roots at a larger order than `size`;
// `roots[k * twiddle_step]` retrieves ω_m^k regardless.
for stage in 1..=log_n {
let m = 1usize << stage;
let half = m >> 1;
let twiddle_step = roots.len() / m;
let cur = &plan.mask[stage];
let mut base = 0;
while base < size {
for k in 0..half {
let a = base + k;
let b = a + half;
if cur[a] || cur[b] {
let w = roots[k * twiddle_step];
let t = work[b] * w;
let u = work[a];
work[a] = u + t;
work[b] = u - t;
}
}
base += m;
}
}

for (j, &i) in indices.iter().enumerate() {
out[j * stride] = work[i];
}
}
}

/// Pruning plan for an output-pruned NTT.
///
/// Holds the queried output indices and the precomputed per-stage
/// "needed-position" masks used by [`NttEngine::ntt_partial_with_plan_into`].
/// Construct once per `(size, indices)` and reuse across multiple NTTs of
/// the same shape (e.g. all polynomials in an interleaved batch).
#[derive(Debug, Clone)]
pub struct PartialNttPlan {
size: usize,
indices: Vec<usize>,
/// `mask[stage][p]` is true iff position `p` after `stage` DIT stages
/// must be correct for the final outputs. `mask[log_n]` mirrors
/// `indices`; `mask[0]` selects the bit-reversed input positions that
/// must be loaded.
mask: Vec<Vec<bool>>,
}

impl PartialNttPlan {
pub fn new(size: usize, indices: &[usize]) -> Self {
assert!(size.is_power_of_two(), "size must be a power of two");
assert!(
indices.iter().all(|&i| i < size),
"query index out of range"
);
let log_n = size.trailing_zeros() as usize;
let mut mask: Vec<Vec<bool>> = vec![vec![false; size]; log_n + 1];
for &i in indices {
mask[log_n][i] = true;
}
for stage in (1..=log_n).rev() {
let m = 1usize << stage;
let half = m >> 1;
let (lo, hi) = mask.split_at_mut(stage);
let cur = &hi[0];
let prev = &mut lo[stage - 1];
let mut base = 0;
while base < size {
for k in 0..half {
let a = base + k;
let b = a + half;
if cur[a] || cur[b] {
prev[a] = true;
prev[b] = true;
}
}
base += m;
}
}
Self {
size,
indices: indices.to_vec(),
mask,
}
}

pub const fn size(&self) -> usize {
self.size
}

pub fn indices(&self) -> &[usize] {
&self.indices
}
}

/// Applies twiddle factors to a slice of field elements in-place.
Expand Down Expand Up @@ -963,4 +1123,93 @@ mod tests {

assert_eq!(values_ntt, expected_values);
}

#[test]
fn test_ntt_partial_matches_full() {
use ark_std::{rand::Rng, UniformRand};

let engine = NttEngine::<Field64>::new_from_fftfield();
let mut rng = ark_std::test_rng();

for &size in &[4usize, 16, 64, 256, 1024, 1 << 15] {
for _ in 0..8 {
// Full NTT reference.
let coeffs: Vec<_> = (0..size).map(|_| Field64::rand(&mut rng)).collect();
let mut full = coeffs.clone();
engine.ntt_batch(&mut full, size);

// Random subset of varying size (cover dense + sparse).
let k = rng.gen_range(1..=size.min(64));
let mut perm: Vec<usize> = (0..size).collect();
for i in (1..size).rev() {
perm.swap(i, rng.gen_range(0..=i));
}
let indices: Vec<usize> = perm.into_iter().take(k).collect();

let partial = engine.ntt_partial(&coeffs, size, &indices);
assert_eq!(partial.len(), indices.len());
for (j, &idx) in indices.iter().enumerate() {
assert_eq!(partial[j], full[idx], "size={size} idx={idx}");
}
}
}
}

#[test]
fn test_ntt_partial_zero_padded_input() {
// M < N: input is zero-padded. Partial NTT must agree with full NTT
// computed over the zero-padded coefficient vector.
use ark_std::UniformRand;

let engine = NttEngine::<Field64>::new_from_fftfield();
let mut rng = ark_std::test_rng();

for (m, size) in [(1usize, 4), (4, 16), (256, 1024), (1 << 13, 1 << 15)] {
let coeffs: Vec<_> = (0..m).map(|_| Field64::rand(&mut rng)).collect();
let mut padded = coeffs.clone();
padded.resize(size, Field64::ZERO);
engine.ntt_batch(&mut padded, size);

let stride = (size / 8).max(1);
let indices: Vec<usize> = (0..size).step_by(stride).take(8).collect();
let partial = engine.ntt_partial(&coeffs, size, &indices);
for (j, &idx) in indices.iter().enumerate() {
assert_eq!(partial[j], padded[idx], "m={m} size={size} idx={idx}");
}
}
}

#[test]
fn test_ntt_partial_edge_cases() {
use ark_std::UniformRand;

let engine = NttEngine::<Field64>::new_from_fftfield();
let mut rng = ark_std::test_rng();

// Empty index set.
let coeffs: Vec<_> = (0..16).map(|_| Field64::rand(&mut rng)).collect();
let out = engine.ntt_partial(&coeffs, 16, &[]);
assert!(out.is_empty());

// Singleton at position 0 and position N-1.
let coeffs: Vec<_> = (0..64).map(|_| Field64::rand(&mut rng)).collect();
let mut full = coeffs.clone();
engine.ntt_batch(&mut full, 64);
for idx in [0usize, 1, 31, 32, 63] {
let out = engine.ntt_partial(&coeffs, 64, &[idx]);
assert_eq!(out, vec![full[idx]], "idx={idx}");
}

// Repeated indices: each occurrence must yield the matching output.
let indices = vec![5usize, 5, 17, 5, 17];
let out = engine.ntt_partial(&coeffs, 64, &indices);
for (j, &idx) in indices.iter().enumerate() {
assert_eq!(out[j], full[idx]);
}

// size = 1: any indices must all return values[0].
let single = vec![Field64::from(42)];
let out = engine.ntt_partial(&single, 1, &[0, 0, 0]);
assert_eq!(out, vec![Field64::from(42); 3]);
}
}
Loading
Loading