diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4a72623a..f567db95 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,6 +25,29 @@ jobs: - run: npm ci + # The wasm-pack output (crates/solver/pkg/) is build-only and gitignored, + # so it must be built here before typecheck/test/build consume it. This is + # the single source of truth for the WASM binding surface — building it in + # CI is what prevents the binary from silently drifting from the Rust + # source (the previous, committed-and-stale failure mode). + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@29eef336d9b2848a0b548edc03f92a220660cdb8 # stable (2026-04-15) + with: + targets: wasm32-unknown-unknown + + - name: Cache Rust dependencies + uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1 + with: + workspaces: 'crates/solver -> target' + + - name: Install wasm-pack + uses: taiki-e/install-action@e49978b799e49ff429d162b7a30601a569ab6538 # v2.81.1 + with: + tool: wasm-pack@0.13.1 + + - name: Build WASM + run: npm run build:wasm + - name: Cache build artifacts uses: actions/cache@v5 with: diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index b61bbe50..22ad9363 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -33,6 +33,25 @@ jobs: - run: npm ci + # pkg/ is build-only/gitignored — build it before typecheck/test consume it. + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@29eef336d9b2848a0b548edc03f92a220660cdb8 # stable (2026-04-15) + with: + targets: wasm32-unknown-unknown + + - name: Cache Rust dependencies + uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # v2.9.1 + with: + workspaces: 'crates/solver -> target' + + - name: Install wasm-pack + uses: taiki-e/install-action@e49978b799e49ff429d162b7a30601a569ab6538 # v2.81.1 + with: + tool: wasm-pack@0.13.1 + + - name: Build WASM + run: npm run build:wasm + - name: Format check run: npm run format:check diff --git a/.gitignore b/.gitignore index 0b8eea3d..024aaa6e 100644 --- a/.gitignore +++ b/.gitignore @@ -5,8 +5,14 @@ apps/*/dist/ packages/*/dist/ *.tsbuildinfo .vite/ -crates/solver/target/ +# Rust build output and wasm-pack artifacts are build-only — never committed. +# (`crates/*/target` covers solver and any future crate; `pkg/` is the +# wasm-pack output, rebuilt by scripts/ensure-wasm.mjs.) +crates/*/target/ +crates/*/pkg/ test/ +.test_data/ +.playwright-mcp/ .env .env.local .claude/ diff --git a/apps/cadecon/src/lib/algorithm-store.ts b/apps/cadecon/src/lib/algorithm-store.ts index 82ffc4a0..ffb28df1 100644 --- a/apps/cadecon/src/lib/algorithm-store.ts +++ b/apps/cadecon/src/lib/algorithm-store.ts @@ -9,6 +9,17 @@ const [lpFilterEnabled, setLpFilterEnabled] = createSignal(false); const [maxIterations, setMaxIterations] = createSignal(20); const [convergenceTol, setConvergenceTol] = createSignal(0.01); +// Inner-loop solver parameters. These directly affect deconvolution output, so +// they are configurable (not hardcoded) and travel with the run; defaults match +// the previously hardcoded values. Per-trace FISTA: +const [traceFistaMaxIters, setTraceFistaMaxIters] = createSignal(500); +const [traceFistaTol, setTraceFistaTol] = createSignal(1e-4); +// Per-subset free-form kernel estimation FISTA: +const [kernelFistaMaxIters, setKernelFistaMaxIters] = createSignal(200); +const [kernelFistaTol, setKernelFistaTol] = createSignal(1e-4); +// TV-L1 smoothness penalty for kernel estimation (0 = no smoothness): +const [kernelSmoothLambda, setKernelSmoothLambda] = createSignal(0); + // --- Derived --- const upsampleFactor = createMemo(() => { @@ -28,5 +39,15 @@ export { setMaxIterations, convergenceTol, setConvergenceTol, + traceFistaMaxIters, + setTraceFistaMaxIters, + traceFistaTol, + setTraceFistaTol, + kernelFistaMaxIters, + setKernelFistaMaxIters, + kernelFistaTol, + setKernelFistaTol, + kernelSmoothLambda, + setKernelSmoothLambda, upsampleFactor, }; diff --git a/apps/cadecon/src/lib/community/cadecon-service.ts b/apps/cadecon/src/lib/community/cadecon-service.ts index 8b497e2a..60c0ed1d 100644 --- a/apps/cadecon/src/lib/community/cadecon-service.ts +++ b/apps/cadecon/src/lib/community/cadecon-service.ts @@ -3,7 +3,12 @@ import { createSubmissionService } from '@calab/community'; import type { CadeconSubmission } from './types.ts'; -const service = createSubmissionService('cadecon_submissions'); +// Reads go through the PII-free public view (migration 010); writes/deletes +// target the base table under owner-scoped RLS. +const service = createSubmissionService( + 'cadecon_submissions', + 'cadecon_submissions_public', +); export const submitParameters = service.submit; export const fetchSubmissions = service.fetch; diff --git a/apps/cadecon/src/lib/iteration-manager.ts b/apps/cadecon/src/lib/iteration-manager.ts index 6bc1ee78..e3276e29 100644 --- a/apps/cadecon/src/lib/iteration-manager.ts +++ b/apps/cadecon/src/lib/iteration-manager.ts @@ -40,6 +40,11 @@ import { convergenceTol, hpFilterEnabled, lpFilterEnabled, + traceFistaMaxIters, + traceFistaTol, + kernelFistaMaxIters, + kernelFistaTol, + kernelSmoothLambda, } from './algorithm-store.ts'; import { parsedData, @@ -54,15 +59,9 @@ import { dataIndex } from './data-utils.ts'; import { median } from './math-utils.ts'; import { reconvolveAR2 } from './reconvolve.ts'; -/** Per-trace FISTA solver parameters (shared between subset and finalization passes). */ -const TRACE_FISTA_MAX_ITERS = 500; -const TRACE_FISTA_TOL = 1e-4; - -/** Per-subset kernel estimation solver parameters. */ -const KERNEL_FISTA_MAX_ITERS = 200; -const KERNEL_FISTA_TOL = 1e-4; -/** TV-L1 smoothness penalty for free-form kernel estimation. */ -const KERNEL_SMOOTH_LAMBDA = 0; +// Per-trace and per-kernel FISTA solver parameters are configurable via +// algorithm-store (traceFistaMaxIters/Tol, kernelFistaMaxIters/Tol, +// kernelSmoothLambda) so they are overridable and recorded with the run. /** Number of early free-kernel samples to skip in bi-exponential fitting. */ export const BIEXP_FIT_SKIP = 0; @@ -269,10 +268,10 @@ function dispatchKernelJobs( baselines, kernelLength, fs, - maxIters: KERNEL_FISTA_MAX_ITERS, - tol: KERNEL_FISTA_TOL, + maxIters: kernelFistaMaxIters(), + tol: kernelFistaTol(), refine: true, - smoothLambda: KERNEL_SMOOTH_LAMBDA, + smoothLambda: kernelSmoothLambda(), biexpSkip: BIEXP_FIT_SKIP, warmKernel, warmBiexp, @@ -512,8 +511,8 @@ export async function startRun(): Promise { tauD, fs, upFactor, - TRACE_FISTA_MAX_ITERS, - TRACE_FISTA_TOL, + traceFistaMaxIters(), + traceFistaTol(), hpOn, lpOn, sparsityLambda, @@ -793,8 +792,8 @@ export async function startRun(): Promise { tauDecay: tauD, fs, upsampleFactor: upFactor, - maxIters: TRACE_FISTA_MAX_ITERS, - tol: TRACE_FISTA_TOL, + maxIters: traceFistaMaxIters(), + tol: traceFistaTol(), hpEnabled: hpOn, lpEnabled: lpOn, lambda: sparsityLambda, diff --git a/apps/catune/src/lib/community/catune-service.ts b/apps/catune/src/lib/community/catune-service.ts index 3e906658..2005ccfa 100644 --- a/apps/catune/src/lib/community/catune-service.ts +++ b/apps/catune/src/lib/community/catune-service.ts @@ -3,7 +3,12 @@ import { createSubmissionService } from '@calab/community'; import type { CatuneSubmission } from './types.ts'; -const service = createSubmissionService('catune_submissions'); +// Reads go through the PII-free public view (migration 010); writes/deletes +// target the base table under owner-scoped RLS. +const service = createSubmissionService( + 'catune_submissions', + 'catune_submissions_public', +); export const submitParameters = service.submit; export const fetchSubmissions = service.fetch; diff --git a/crates/solver/pkg/README.md b/crates/solver/pkg/README.md deleted file mode 100644 index a362fb95..00000000 --- a/crates/solver/pkg/README.md +++ /dev/null @@ -1,98 +0,0 @@ -# calab-solver - -Rust FISTA deconvolution solver with dual WASM/PyO3 targets. - -## Overview - -This crate implements the FISTA (Fast Iterative Shrinkage-Thresholding Algorithm) solver used by CaTune for calcium trace deconvolution. It is compiled to WebAssembly via `wasm-pack` and runs in Web Workers in the browser. The compiled output in `pkg/` is committed to the repository so that CI and development do not require a Rust toolchain. - -## Algorithm - -The solver minimizes the following objective with a non-negativity constraint: - -``` -minimize (1/2)||y - K·s - b||² + λ·G_dc·||s||₁ subject to s ≥ 0 -``` - -| Symbol | Meaning | -| ------ | ------------------------------------------------------------------------------------------ | -| `y` | Input fluorescence trace | -| `K` | Convolution matrix from double-exponential kernel | -| `s` | Deconvolved activity (output) | -| `b` | Scalar baseline, estimated jointly as `mean(y - K·s)` | -| `λ` | Sparsity penalty (user-adjustable) | -| `G_dc` | Kernel DC gain `Σh`, scales λ so the sparsity slider is effective across all kernel shapes | - -**Kernel:** `h(t) = exp(-t/τ_decay) - exp(-t/τ_rise)`, normalized to peak = 1.0. Length extends until the decay envelope drops below 1e-6 of peak. - -**FISTA iteration:** Standard Beck & Teboulle (2009) with momentum extrapolation. Step size is `1/L` where `L` (Lipschitz constant) = max|H(ω)|² computed via DFT of the kernel. - -**Adaptive restart:** O'Donoghue & Candes (2015) gradient-mapping criterion — resets momentum to `t = 1` when the proximal step undoes the momentum direction. - -**Convergence:** Primal residual criterion `||x_{k+1} - x_k|| / ||x_k|| < 1e-6` after iteration 5. This avoids an expensive forward convolution + objective evaluation per iteration. - -## Modules - -| Module | Description | -| ----------- | ------------------------------------------------------------------------------------------------------------------------ | -| `lib.rs` | `Solver` struct — public wasm-bindgen API, parameter management, state serialization, bandpass filter methods | -| `kernel.rs` | `build_kernel` (double-exponential), `compute_lipschitz` (spectral bound via DFT) | -| `fista.rs` | `step_batch` — FISTA iteration loop with FFT convolutions, adaptive restart, convergence check | -| `fft.rs` | `FftConvolver` — self-contained FFT convolution engine with pre-computed kernel spectrum, forward and adjoint operations | -| `filter.rs` | `BandpassFilter` — FFT-based bandpass filter derived from kernel time constants, cosine-tapered transitions | - -## Public API - -Methods exposed to JavaScript via `wasm-bindgen`: - -| Method | Description | -| -------------------------------------------------- | ------------------------------------------------------------------------------- | -| `new()` | Create solver with default parameters (τ_rise=0.02, τ_decay=0.4, λ=0.01, fs=30) | -| `set_params(tau_rise, tau_decay, lambda, fs)` | Update parameters and rebuild kernel | -| `set_trace(trace)` | Load a trace, grow buffers if needed, reset iteration state | -| `step_batch(n_steps)` | Run N FISTA iterations, return true if converged | -| `get_solution()` | Get deconvolved activity (owned copy) | -| `get_reconvolution()` | Get K·s (lazy-computed, owned copy) | -| `get_reconvolution_with_baseline()` | Get K·s + b (owned copy) | -| `get_baseline()` | Get estimated scalar baseline | -| `get_trace()` | Get current trace (may be filtered) | -| `converged()` | Check convergence flag | -| `iteration_count()` | Get iteration count | -| `reset_momentum()` | Reset FISTA momentum for warm-start after kernel change | -| `export_state()` / `load_state(state)` | Serialize/restore solver state for warm-start cache | -| `set_filter_enabled(enabled)` / `filter_enabled()` | Toggle bandpass filter | -| `apply_filter()` | Apply bandpass filter to loaded trace | -| `get_power_spectrum()` | Get \|FFT\|² of current trace | -| `get_spectrum_frequencies()` | Get frequency axis in Hz | -| `get_filter_cutoffs()` | Get [f_hp, f_lp] cutoff frequencies | - -## Build - -```bash -cd crates/solver -wasm-pack build --target web --release -``` - -Output goes to `pkg/` which is committed to the repository. You only need to rebuild when modifying the solver Rust source. - -From the repo root: - -```bash -npm run build:wasm -``` - -## Performance - -- **Pre-allocated buffers** — grow but never shrink to prevent WASM memory fragmentation -- **f32 precision** — halves memory per worker compared to f64 (Lipschitz constant computed in f64 for step-size accuracy) -- **FFT convolution** — O(n log n) via `realfft`/`rustfft` for both forward and adjoint operations -- **Release profile** — `opt-level = 3`, LTO, single codegen unit, wasm-opt with bulk-memory - -## Dependencies - -| Crate | Purpose | -| -------------------------- | ------------------------------------------ | -| `wasm-bindgen` | JavaScript interop | -| `console_error_panic_hook` | Readable panic messages in browser console | -| `realfft` | Real-valued FFT (wraps rustfft) | -| `rustfft` | FFT computation | diff --git a/crates/solver/pkg/calab_solver.d.ts b/crates/solver/pkg/calab_solver.d.ts deleted file mode 100644 index 436961a4..00000000 --- a/crates/solver/pkg/calab_solver.d.ts +++ /dev/null @@ -1,313 +0,0 @@ -/* tslint:disable */ -/* eslint-disable */ - -/** - * Constraint type for the proximal step. - */ -export enum Constraint { - /** - * Current: max(0, z - threshold) — L1 + non-negativity. - */ - NonNegative = 0, - /** - * InDeCa Eq. 3: clamp(z, 0, 1) — box constraint, no L1 penalty. - */ - Box01 = 1, -} - -/** - * Convolution mode for forward/adjoint operations in FISTA. - */ -export enum ConvMode { - /** - * FFT-based O(T log T) per call — the original implementation. - */ - Fft = 0, - /** - * Banded AR(2) recursion O(T) per call — faster for long traces. - */ - BandedAR2 = 1, -} - -/** - * FISTA solver for calcium deconvolution. - * - * Minimizes (1/2)||y - K*s - b||^2 + lambda*G_dc*||s||_1 subject to s >= 0, - * where K is the convolution matrix derived from a double-exponential kernel, - * b is a scalar baseline estimated jointly, and G_dc = sum(K) scales lambda - * so the sparsity slider is effective across all kernel configurations. - * - * Pre-allocated buffers grow but never shrink to prevent WASM memory fragmentation. - */ -export class Solver { - free(): void; - [Symbol.dispose](): void; - /** - * Apply bandpass filter to the active trace region. Returns true if filtering was applied. - * - * Sets `self.filtered = true` only when HP is active, because HP removes DC and - * baseline estimation should be skipped. LP-only preserves DC, so baseline - * estimation must still run. - */ - apply_filter(): boolean; - /** - * Returns whether the solver has converged. - */ - converged(): boolean; - /** - * Serialize solver state for warm-start cache. - * Format: [active_len (u32)] [t_fista (f64)] [iteration (u32)] [baseline (f64)] [solution f32...] [solution_prev f32...] - */ - export_state(): Uint8Array; - filter_enabled(): boolean; - /** - * Returns the estimated scalar baseline (EMA-smoothed for stable display). - * Lazily computes reconvolution if stale, to ensure the EMA is up to date. - */ - get_baseline(): number; - /** - * Get filter cutoff frequencies as [f_hp, f_lp]. - */ - get_filter_cutoffs(): Float32Array; - /** - * Returns a copy of the kernel. - * - * Returns `Vec` which wasm-bindgen copies into a JS-owned `Float32Array`. - * A WASM memory view would be unsound here: any subsequent WASM allocation - * (e.g. `set_trace`) can grow the memory and invalidate the view. The JS side - * also transfers these buffers via `postMessage`, which requires ownership. - */ - get_kernel(): Float32Array; - /** - * Get the power spectrum of the current trace (N/2+1 bins). - */ - get_power_spectrum(): Float32Array; - /** - * Returns the reconvolution (K * solution) for the active region. - * Computes the reconvolution lazily if it is stale (not computed during iteration). - * - * See `get_kernel` for why this returns an owned copy rather than a memory view. - */ - get_reconvolution(): Float32Array; - /** - * Returns reconvolution with baseline added: K*s + b for the active region. - * Computes the reconvolution lazily if it is stale. - * - * See `get_kernel` for why this returns an owned copy rather than a memory view. - */ - get_reconvolution_with_baseline(): Float32Array; - /** - * Returns the current solution (spike train) for the active region. - * - * See `get_kernel` for why this returns an owned copy rather than a memory view. - */ - get_solution(): Float32Array; - /** - * Get frequency axis in Hz for the spectrum bins. - */ - get_spectrum_frequencies(): Float32Array; - /** - * Returns the current trace for the active region. - * After apply_filter(), this contains the filtered trace. - * - * See `get_kernel` for why this returns an owned copy rather than a memory view. - */ - get_trace(): Float32Array; - /** - * Returns the current iteration count. - */ - iteration_count(): number; - /** - * Load warm-start state. If state is empty or wrong size, performs cold-start (zero solution). - */ - load_state(state: Uint8Array): void; - /** - * Create a new Solver with default parameters. - */ - constructor(); - /** - * Reset FISTA momentum. Used for warm-start after kernel change. - * Sets t_fista = 1.0 and copies solution into solution_prev. - */ - reset_momentum(): void; - /** - * Set the constraint type (NonNegative or Box01). - */ - set_constraint(c: Constraint): void; - /** - * Set the convolution mode (FFT or BandedAR2). - * Recomputes the Lipschitz constant for the selected mode. - * Does NOT reset solution/iteration state — warm-start is preserved. - */ - set_conv_mode(mode: ConvMode): void; - /** - * Convenience: set both HP and LP together (used by CaTune's single toggle). - */ - set_filter_enabled(enabled: boolean): void; - set_hp_filter_enabled(enabled: boolean): void; - set_lp_filter_enabled(enabled: boolean): void; - /** - * Update solver parameters and rebuild kernel. - */ - set_params(tau_rise: number, tau_decay: number, lambda: number, fs: number): void; - /** - * Load a trace for deconvolution. Grows buffers if needed (never shrinks). - * Resets iteration state for a fresh solve. - */ - set_trace(trace: Float32Array): void; - /** - * Run n_steps of FISTA iterations. Returns true if converged. - * - * Uses the standard Beck & Teboulle FISTA with two sequences: - * - x_k (solution): the proximal update point - * - y_k (solution_prev used as extrapolated point): where gradient is evaluated - * - * The algorithm evaluates the gradient at the extrapolated point y_k, takes - * the proximal step to get x_{k+1}, then extrapolates to get y_{k+1}. - * - * Includes adaptive restart (O'Donoghue & Candes 2015): when the gradient-mapping - * criterion detects momentum is hurting progress, reset to avoid oscillation. - * - * Uses FFT-based O(n log n) convolutions instead of time-domain O(n*k), and - * primal residual convergence criterion to eliminate one convolution per iteration. - */ - step_batch(n_steps: number): boolean; - /** - * Subtract a rolling-percentile baseline from the active trace. - * - * Brings the trace floor to ~0, removing slow baseline drift while - * preserving positive-going calcium transients. After subtraction the - * baseline is ~0 so FISTA baseline estimation can be skipped (same - * rationale as when HP removes DC). - */ - subtract_baseline(): void; -} - -/** - * Get all built-in simulation preset names and their configs. - * - * Returns: JsValue containing Vec<(name, SimulationConfig)>. - */ -export function get_simulation_presets(): any; - -/** - * Compute the upsample factor for a given sampling rate and target rate. - */ -export function indeca_compute_upsample_factor(fs: number, target_fs: number): number; - -/** - * Estimate a free-form kernel from multiple traces and their spike trains. - * - * `warm_kernel`: optional kernel from a previous iteration. Pass an empty slice - * for cold-start. - * - * Returns the estimated kernel as Float32Array (via Vec). - */ -export function indeca_estimate_kernel(traces_flat: Float32Array, spikes_flat: Float32Array, trace_lengths: Uint32Array, alphas: Float64Array, baselines: Float64Array, kernel_length: number, max_iters: number, tol: number, warm_kernel: Float32Array, smooth_lambda: number): Float32Array; - -/** - * Fit a bi-exponential model to a free-form kernel. - * - * Warm-start: pass `use_warm=true` and the previous result's fields to add - * the previous result as an additional refined candidate alongside the cold - * grid search. This gives faster convergence when the kernel evolves smoothly. - * Pass `use_warm=false` (and any values for warm_* fields) for cold-start only. - * - * Returns a JsValue containing the serialized BiexpResult: - * { tau_rise, tau_decay, beta, residual, tau_rise_fast, tau_decay_fast, beta_fast } - */ -export function indeca_fit_biexponential(h_free: Float32Array, fs: number, refine: boolean, skip: number, warm_tau_rise: number, warm_tau_decay: number, warm_tau_rise_fast: number, warm_tau_decay_fast: number, warm_beta: number, warm_beta_fast: number, warm_residual: number, use_warm: boolean): any; - -/** - * Solve a single trace using the InDeCa pipeline. - * - * `warm_counts`: optional spike counts from a previous iteration at the original - * sampling rate. Pass an empty slice for cold-start. - * - * Returns a JsValue containing the serialized InDecaResult: - * { s_counts, alpha, baseline, threshold, pve, iterations, converged } - */ -export function indeca_solve_trace(trace: Float32Array, tau_r: number, tau_d: number, fs: number, upsample_factor: number, max_iters: number, tol: number, hp_enabled: boolean, lp_enabled: boolean, warm_counts: Float32Array, lambda: number): any; - -/** - * Run peak-seeded spike detection on a single trace. - * - * Returns a JsValue containing the serialized SeedTraceResult: - * { s_counts, alpha, baseline } - */ -export function seed_trace(trace: Float32Array, fs: number): any; - -/** - * Generate synthetic calcium traces from a config object. - * - * Accepts: JsValue containing a SimulationConfig-shaped object. - * Returns: JsValue containing a SimulationResult-shaped object. - */ -export function simulate_traces(config_js: any): any; - -export type InitInput = RequestInfo | URL | Response | BufferSource | WebAssembly.Module; - -export interface InitOutput { - readonly memory: WebAssembly.Memory; - readonly __wbg_solver_free: (a: number, b: number) => void; - readonly get_simulation_presets: () => number; - readonly indeca_compute_upsample_factor: (a: number, b: number) => number; - readonly indeca_estimate_kernel: (a: number, b: number, c: number, d: number, e: number, f: number, g: number, h: number, i: number, j: number, k: number, l: number, m: number, n: number, o: number, p: number, q: number) => void; - readonly indeca_fit_biexponential: (a: number, b: number, c: number, d: number, e: number, f: number, g: number, h: number, i: number, j: number, k: number, l: number, m: number) => number; - readonly indeca_solve_trace: (a: number, b: number, c: number, d: number, e: number, f: number, g: number, h: number, i: number, j: number, k: number, l: number, m: number) => number; - readonly seed_trace: (a: number, b: number, c: number) => number; - readonly simulate_traces: (a: number) => number; - readonly solver_apply_filter: (a: number) => number; - readonly solver_converged: (a: number) => number; - readonly solver_export_state: (a: number, b: number) => void; - readonly solver_filter_enabled: (a: number) => number; - readonly solver_get_baseline: (a: number) => number; - readonly solver_get_filter_cutoffs: (a: number, b: number) => void; - readonly solver_get_kernel: (a: number, b: number) => void; - readonly solver_get_power_spectrum: (a: number, b: number) => void; - readonly solver_get_reconvolution: (a: number, b: number) => void; - readonly solver_get_reconvolution_with_baseline: (a: number, b: number) => void; - readonly solver_get_solution: (a: number, b: number) => void; - readonly solver_get_spectrum_frequencies: (a: number, b: number) => void; - readonly solver_get_trace: (a: number, b: number) => void; - readonly solver_iteration_count: (a: number) => number; - readonly solver_load_state: (a: number, b: number, c: number) => void; - readonly solver_new: () => number; - readonly solver_reset_momentum: (a: number) => void; - readonly solver_set_constraint: (a: number, b: number) => void; - readonly solver_set_conv_mode: (a: number, b: number) => void; - readonly solver_set_filter_enabled: (a: number, b: number) => void; - readonly solver_set_hp_filter_enabled: (a: number, b: number) => void; - readonly solver_set_lp_filter_enabled: (a: number, b: number) => void; - readonly solver_set_params: (a: number, b: number, c: number, d: number, e: number) => void; - readonly solver_set_trace: (a: number, b: number, c: number) => void; - readonly solver_step_batch: (a: number, b: number) => number; - readonly solver_subtract_baseline: (a: number) => void; - readonly __wbindgen_export: (a: number, b: number) => number; - readonly __wbindgen_export2: (a: number, b: number, c: number, d: number) => number; - readonly __wbindgen_export3: (a: number) => void; - readonly __wbindgen_export4: (a: number, b: number, c: number) => void; - readonly __wbindgen_add_to_stack_pointer: (a: number) => number; -} - -export type SyncInitInput = BufferSource | WebAssembly.Module; - -/** - * Instantiates the given `module`, which can either be bytes or - * a precompiled `WebAssembly.Module`. - * - * @param {{ module: SyncInitInput }} module - Passing `SyncInitInput` directly is deprecated. - * - * @returns {InitOutput} - */ -export function initSync(module: { module: SyncInitInput } | SyncInitInput): InitOutput; - -/** - * If `module_or_path` is {RequestInfo} or {URL}, makes a request and - * for everything else, calls `WebAssembly.instantiate` directly. - * - * @param {{ module_or_path: InitInput | Promise }} module_or_path - Passing `InitInput` directly is deprecated. - * - * @returns {Promise} - */ -export default function __wbg_init (module_or_path?: { module_or_path: InitInput | Promise } | InitInput | Promise): Promise; diff --git a/crates/solver/pkg/calab_solver.js b/crates/solver/pkg/calab_solver.js deleted file mode 100644 index 9e78bb07..00000000 --- a/crates/solver/pkg/calab_solver.js +++ /dev/null @@ -1,1130 +0,0 @@ -/* @ts-self-types="./calab_solver.d.ts" */ - -/** - * Constraint type for the proximal step. - * @enum {0 | 1} - */ -export const Constraint = Object.freeze({ - /** - * Current: max(0, z - threshold) — L1 + non-negativity. - */ - NonNegative: 0, "0": "NonNegative", - /** - * InDeCa Eq. 3: clamp(z, 0, 1) — box constraint, no L1 penalty. - */ - Box01: 1, "1": "Box01", -}); - -/** - * Convolution mode for forward/adjoint operations in FISTA. - * @enum {0 | 1} - */ -export const ConvMode = Object.freeze({ - /** - * FFT-based O(T log T) per call — the original implementation. - */ - Fft: 0, "0": "Fft", - /** - * Banded AR(2) recursion O(T) per call — faster for long traces. - */ - BandedAR2: 1, "1": "BandedAR2", -}); - -/** - * FISTA solver for calcium deconvolution. - * - * Minimizes (1/2)||y - K*s - b||^2 + lambda*G_dc*||s||_1 subject to s >= 0, - * where K is the convolution matrix derived from a double-exponential kernel, - * b is a scalar baseline estimated jointly, and G_dc = sum(K) scales lambda - * so the sparsity slider is effective across all kernel configurations. - * - * Pre-allocated buffers grow but never shrink to prevent WASM memory fragmentation. - */ -export class Solver { - __destroy_into_raw() { - const ptr = this.__wbg_ptr; - this.__wbg_ptr = 0; - SolverFinalization.unregister(this); - return ptr; - } - free() { - const ptr = this.__destroy_into_raw(); - wasm.__wbg_solver_free(ptr, 0); - } - /** - * Apply bandpass filter to the active trace region. Returns true if filtering was applied. - * - * Sets `self.filtered = true` only when HP is active, because HP removes DC and - * baseline estimation should be skipped. LP-only preserves DC, so baseline - * estimation must still run. - * @returns {boolean} - */ - apply_filter() { - const ret = wasm.solver_apply_filter(this.__wbg_ptr); - return ret !== 0; - } - /** - * Returns whether the solver has converged. - * @returns {boolean} - */ - converged() { - const ret = wasm.solver_converged(this.__wbg_ptr); - return ret !== 0; - } - /** - * Serialize solver state for warm-start cache. - * Format: [active_len (u32)] [t_fista (f64)] [iteration (u32)] [baseline (f64)] [solution f32...] [solution_prev f32...] - * @returns {Uint8Array} - */ - export_state() { - try { - const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); - wasm.solver_export_state(retptr, this.__wbg_ptr); - var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); - var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); - var v1 = getArrayU8FromWasm0(r0, r1).slice(); - wasm.__wbindgen_export4(r0, r1 * 1, 1); - return v1; - } finally { - wasm.__wbindgen_add_to_stack_pointer(16); - } - } - /** - * @returns {boolean} - */ - filter_enabled() { - const ret = wasm.solver_filter_enabled(this.__wbg_ptr); - return ret !== 0; - } - /** - * Returns the estimated scalar baseline (EMA-smoothed for stable display). - * Lazily computes reconvolution if stale, to ensure the EMA is up to date. - * @returns {number} - */ - get_baseline() { - const ret = wasm.solver_get_baseline(this.__wbg_ptr); - return ret; - } - /** - * Get filter cutoff frequencies as [f_hp, f_lp]. - * @returns {Float32Array} - */ - get_filter_cutoffs() { - try { - const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); - wasm.solver_get_filter_cutoffs(retptr, this.__wbg_ptr); - var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); - var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); - var v1 = getArrayF32FromWasm0(r0, r1).slice(); - wasm.__wbindgen_export4(r0, r1 * 4, 4); - return v1; - } finally { - wasm.__wbindgen_add_to_stack_pointer(16); - } - } - /** - * Returns a copy of the kernel. - * - * Returns `Vec` which wasm-bindgen copies into a JS-owned `Float32Array`. - * A WASM memory view would be unsound here: any subsequent WASM allocation - * (e.g. `set_trace`) can grow the memory and invalidate the view. The JS side - * also transfers these buffers via `postMessage`, which requires ownership. - * @returns {Float32Array} - */ - get_kernel() { - try { - const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); - wasm.solver_get_kernel(retptr, this.__wbg_ptr); - var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); - var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); - var v1 = getArrayF32FromWasm0(r0, r1).slice(); - wasm.__wbindgen_export4(r0, r1 * 4, 4); - return v1; - } finally { - wasm.__wbindgen_add_to_stack_pointer(16); - } - } - /** - * Get the power spectrum of the current trace (N/2+1 bins). - * @returns {Float32Array} - */ - get_power_spectrum() { - try { - const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); - wasm.solver_get_power_spectrum(retptr, this.__wbg_ptr); - var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); - var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); - var v1 = getArrayF32FromWasm0(r0, r1).slice(); - wasm.__wbindgen_export4(r0, r1 * 4, 4); - return v1; - } finally { - wasm.__wbindgen_add_to_stack_pointer(16); - } - } - /** - * Returns the reconvolution (K * solution) for the active region. - * Computes the reconvolution lazily if it is stale (not computed during iteration). - * - * See `get_kernel` for why this returns an owned copy rather than a memory view. - * @returns {Float32Array} - */ - get_reconvolution() { - try { - const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); - wasm.solver_get_reconvolution(retptr, this.__wbg_ptr); - var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); - var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); - var v1 = getArrayF32FromWasm0(r0, r1).slice(); - wasm.__wbindgen_export4(r0, r1 * 4, 4); - return v1; - } finally { - wasm.__wbindgen_add_to_stack_pointer(16); - } - } - /** - * Returns reconvolution with baseline added: K*s + b for the active region. - * Computes the reconvolution lazily if it is stale. - * - * See `get_kernel` for why this returns an owned copy rather than a memory view. - * @returns {Float32Array} - */ - get_reconvolution_with_baseline() { - try { - const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); - wasm.solver_get_reconvolution_with_baseline(retptr, this.__wbg_ptr); - var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); - var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); - var v1 = getArrayF32FromWasm0(r0, r1).slice(); - wasm.__wbindgen_export4(r0, r1 * 4, 4); - return v1; - } finally { - wasm.__wbindgen_add_to_stack_pointer(16); - } - } - /** - * Returns the current solution (spike train) for the active region. - * - * See `get_kernel` for why this returns an owned copy rather than a memory view. - * @returns {Float32Array} - */ - get_solution() { - try { - const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); - wasm.solver_get_solution(retptr, this.__wbg_ptr); - var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); - var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); - var v1 = getArrayF32FromWasm0(r0, r1).slice(); - wasm.__wbindgen_export4(r0, r1 * 4, 4); - return v1; - } finally { - wasm.__wbindgen_add_to_stack_pointer(16); - } - } - /** - * Get frequency axis in Hz for the spectrum bins. - * @returns {Float32Array} - */ - get_spectrum_frequencies() { - try { - const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); - wasm.solver_get_spectrum_frequencies(retptr, this.__wbg_ptr); - var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); - var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); - var v1 = getArrayF32FromWasm0(r0, r1).slice(); - wasm.__wbindgen_export4(r0, r1 * 4, 4); - return v1; - } finally { - wasm.__wbindgen_add_to_stack_pointer(16); - } - } - /** - * Returns the current trace for the active region. - * After apply_filter(), this contains the filtered trace. - * - * See `get_kernel` for why this returns an owned copy rather than a memory view. - * @returns {Float32Array} - */ - get_trace() { - try { - const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); - wasm.solver_get_trace(retptr, this.__wbg_ptr); - var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); - var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); - var v1 = getArrayF32FromWasm0(r0, r1).slice(); - wasm.__wbindgen_export4(r0, r1 * 4, 4); - return v1; - } finally { - wasm.__wbindgen_add_to_stack_pointer(16); - } - } - /** - * Returns the current iteration count. - * @returns {number} - */ - iteration_count() { - const ret = wasm.solver_iteration_count(this.__wbg_ptr); - return ret >>> 0; - } - /** - * Load warm-start state. If state is empty or wrong size, performs cold-start (zero solution). - * @param {Uint8Array} state - */ - load_state(state) { - const ptr0 = passArray8ToWasm0(state, wasm.__wbindgen_export); - const len0 = WASM_VECTOR_LEN; - wasm.solver_load_state(this.__wbg_ptr, ptr0, len0); - } - /** - * Create a new Solver with default parameters. - */ - constructor() { - const ret = wasm.solver_new(); - this.__wbg_ptr = ret >>> 0; - SolverFinalization.register(this, this.__wbg_ptr, this); - return this; - } - /** - * Reset FISTA momentum. Used for warm-start after kernel change. - * Sets t_fista = 1.0 and copies solution into solution_prev. - */ - reset_momentum() { - wasm.solver_reset_momentum(this.__wbg_ptr); - } - /** - * Set the constraint type (NonNegative or Box01). - * @param {Constraint} c - */ - set_constraint(c) { - wasm.solver_set_constraint(this.__wbg_ptr, c); - } - /** - * Set the convolution mode (FFT or BandedAR2). - * Recomputes the Lipschitz constant for the selected mode. - * Does NOT reset solution/iteration state — warm-start is preserved. - * @param {ConvMode} mode - */ - set_conv_mode(mode) { - wasm.solver_set_conv_mode(this.__wbg_ptr, mode); - } - /** - * Convenience: set both HP and LP together (used by CaTune's single toggle). - * @param {boolean} enabled - */ - set_filter_enabled(enabled) { - wasm.solver_set_filter_enabled(this.__wbg_ptr, enabled); - } - /** - * @param {boolean} enabled - */ - set_hp_filter_enabled(enabled) { - wasm.solver_set_hp_filter_enabled(this.__wbg_ptr, enabled); - } - /** - * @param {boolean} enabled - */ - set_lp_filter_enabled(enabled) { - wasm.solver_set_lp_filter_enabled(this.__wbg_ptr, enabled); - } - /** - * Update solver parameters and rebuild kernel. - * @param {number} tau_rise - * @param {number} tau_decay - * @param {number} lambda - * @param {number} fs - */ - set_params(tau_rise, tau_decay, lambda, fs) { - wasm.solver_set_params(this.__wbg_ptr, tau_rise, tau_decay, lambda, fs); - } - /** - * Load a trace for deconvolution. Grows buffers if needed (never shrinks). - * Resets iteration state for a fresh solve. - * @param {Float32Array} trace - */ - set_trace(trace) { - const ptr0 = passArrayF32ToWasm0(trace, wasm.__wbindgen_export); - const len0 = WASM_VECTOR_LEN; - wasm.solver_set_trace(this.__wbg_ptr, ptr0, len0); - } - /** - * Run n_steps of FISTA iterations. Returns true if converged. - * - * Uses the standard Beck & Teboulle FISTA with two sequences: - * - x_k (solution): the proximal update point - * - y_k (solution_prev used as extrapolated point): where gradient is evaluated - * - * The algorithm evaluates the gradient at the extrapolated point y_k, takes - * the proximal step to get x_{k+1}, then extrapolates to get y_{k+1}. - * - * Includes adaptive restart (O'Donoghue & Candes 2015): when the gradient-mapping - * criterion detects momentum is hurting progress, reset to avoid oscillation. - * - * Uses FFT-based O(n log n) convolutions instead of time-domain O(n*k), and - * primal residual convergence criterion to eliminate one convolution per iteration. - * @param {number} n_steps - * @returns {boolean} - */ - step_batch(n_steps) { - const ret = wasm.solver_step_batch(this.__wbg_ptr, n_steps); - return ret !== 0; - } - /** - * Subtract a rolling-percentile baseline from the active trace. - * - * Brings the trace floor to ~0, removing slow baseline drift while - * preserving positive-going calcium transients. After subtraction the - * baseline is ~0 so FISTA baseline estimation can be skipped (same - * rationale as when HP removes DC). - */ - subtract_baseline() { - wasm.solver_subtract_baseline(this.__wbg_ptr); - } -} -if (Symbol.dispose) Solver.prototype[Symbol.dispose] = Solver.prototype.free; - -/** - * Get all built-in simulation preset names and their configs. - * - * Returns: JsValue containing Vec<(name, SimulationConfig)>. - * @returns {any} - */ -export function get_simulation_presets() { - const ret = wasm.get_simulation_presets(); - return takeObject(ret); -} - -/** - * Compute the upsample factor for a given sampling rate and target rate. - * @param {number} fs - * @param {number} target_fs - * @returns {number} - */ -export function indeca_compute_upsample_factor(fs, target_fs) { - const ret = wasm.indeca_compute_upsample_factor(fs, target_fs); - return ret >>> 0; -} - -/** - * Estimate a free-form kernel from multiple traces and their spike trains. - * - * `warm_kernel`: optional kernel from a previous iteration. Pass an empty slice - * for cold-start. - * - * Returns the estimated kernel as Float32Array (via Vec). - * @param {Float32Array} traces_flat - * @param {Float32Array} spikes_flat - * @param {Uint32Array} trace_lengths - * @param {Float64Array} alphas - * @param {Float64Array} baselines - * @param {number} kernel_length - * @param {number} max_iters - * @param {number} tol - * @param {Float32Array} warm_kernel - * @param {number} smooth_lambda - * @returns {Float32Array} - */ -export function indeca_estimate_kernel(traces_flat, spikes_flat, trace_lengths, alphas, baselines, kernel_length, max_iters, tol, warm_kernel, smooth_lambda) { - try { - const retptr = wasm.__wbindgen_add_to_stack_pointer(-16); - const ptr0 = passArrayF32ToWasm0(traces_flat, wasm.__wbindgen_export); - const len0 = WASM_VECTOR_LEN; - const ptr1 = passArrayF32ToWasm0(spikes_flat, wasm.__wbindgen_export); - const len1 = WASM_VECTOR_LEN; - const ptr2 = passArray32ToWasm0(trace_lengths, wasm.__wbindgen_export); - const len2 = WASM_VECTOR_LEN; - const ptr3 = passArrayF64ToWasm0(alphas, wasm.__wbindgen_export); - const len3 = WASM_VECTOR_LEN; - const ptr4 = passArrayF64ToWasm0(baselines, wasm.__wbindgen_export); - const len4 = WASM_VECTOR_LEN; - const ptr5 = passArrayF32ToWasm0(warm_kernel, wasm.__wbindgen_export); - const len5 = WASM_VECTOR_LEN; - wasm.indeca_estimate_kernel(retptr, ptr0, len0, ptr1, len1, ptr2, len2, ptr3, len3, ptr4, len4, kernel_length, max_iters, tol, ptr5, len5, smooth_lambda); - var r0 = getDataViewMemory0().getInt32(retptr + 4 * 0, true); - var r1 = getDataViewMemory0().getInt32(retptr + 4 * 1, true); - var v7 = getArrayF32FromWasm0(r0, r1).slice(); - wasm.__wbindgen_export4(r0, r1 * 4, 4); - return v7; - } finally { - wasm.__wbindgen_add_to_stack_pointer(16); - } -} - -/** - * Fit a bi-exponential model to a free-form kernel. - * - * Warm-start: pass `use_warm=true` and the previous result's fields to add - * the previous result as an additional refined candidate alongside the cold - * grid search. This gives faster convergence when the kernel evolves smoothly. - * Pass `use_warm=false` (and any values for warm_* fields) for cold-start only. - * - * Returns a JsValue containing the serialized BiexpResult: - * { tau_rise, tau_decay, beta, residual, tau_rise_fast, tau_decay_fast, beta_fast } - * @param {Float32Array} h_free - * @param {number} fs - * @param {boolean} refine - * @param {number} skip - * @param {number} warm_tau_rise - * @param {number} warm_tau_decay - * @param {number} warm_tau_rise_fast - * @param {number} warm_tau_decay_fast - * @param {number} warm_beta - * @param {number} warm_beta_fast - * @param {number} warm_residual - * @param {boolean} use_warm - * @returns {any} - */ -export function indeca_fit_biexponential(h_free, fs, refine, skip, warm_tau_rise, warm_tau_decay, warm_tau_rise_fast, warm_tau_decay_fast, warm_beta, warm_beta_fast, warm_residual, use_warm) { - const ptr0 = passArrayF32ToWasm0(h_free, wasm.__wbindgen_export); - const len0 = WASM_VECTOR_LEN; - const ret = wasm.indeca_fit_biexponential(ptr0, len0, fs, refine, skip, warm_tau_rise, warm_tau_decay, warm_tau_rise_fast, warm_tau_decay_fast, warm_beta, warm_beta_fast, warm_residual, use_warm); - return takeObject(ret); -} - -/** - * Solve a single trace using the InDeCa pipeline. - * - * `warm_counts`: optional spike counts from a previous iteration at the original - * sampling rate. Pass an empty slice for cold-start. - * - * Returns a JsValue containing the serialized InDecaResult: - * { s_counts, alpha, baseline, threshold, pve, iterations, converged } - * @param {Float32Array} trace - * @param {number} tau_r - * @param {number} tau_d - * @param {number} fs - * @param {number} upsample_factor - * @param {number} max_iters - * @param {number} tol - * @param {boolean} hp_enabled - * @param {boolean} lp_enabled - * @param {Float32Array} warm_counts - * @param {number} lambda - * @returns {any} - */ -export function indeca_solve_trace(trace, tau_r, tau_d, fs, upsample_factor, max_iters, tol, hp_enabled, lp_enabled, warm_counts, lambda) { - const ptr0 = passArrayF32ToWasm0(trace, wasm.__wbindgen_export); - const len0 = WASM_VECTOR_LEN; - const ptr1 = passArrayF32ToWasm0(warm_counts, wasm.__wbindgen_export); - const len1 = WASM_VECTOR_LEN; - const ret = wasm.indeca_solve_trace(ptr0, len0, tau_r, tau_d, fs, upsample_factor, max_iters, tol, hp_enabled, lp_enabled, ptr1, len1, lambda); - return takeObject(ret); -} - -/** - * Run peak-seeded spike detection on a single trace. - * - * Returns a JsValue containing the serialized SeedTraceResult: - * { s_counts, alpha, baseline } - * @param {Float32Array} trace - * @param {number} fs - * @returns {any} - */ -export function seed_trace(trace, fs) { - const ptr0 = passArrayF32ToWasm0(trace, wasm.__wbindgen_export); - const len0 = WASM_VECTOR_LEN; - const ret = wasm.seed_trace(ptr0, len0, fs); - return takeObject(ret); -} - -/** - * Generate synthetic calcium traces from a config object. - * - * Accepts: JsValue containing a SimulationConfig-shaped object. - * Returns: JsValue containing a SimulationResult-shaped object. - * @param {any} config_js - * @returns {any} - */ -export function simulate_traces(config_js) { - const ret = wasm.simulate_traces(addHeapObject(config_js)); - return takeObject(ret); -} - -function __wbg_get_imports() { - const import0 = { - __proto__: null, - __wbg_Error_8c4e43fe74559d73: function(arg0, arg1) { - const ret = Error(getStringFromWasm0(arg0, arg1)); - return addHeapObject(ret); - }, - __wbg_Number_04624de7d0e8332d: function(arg0) { - const ret = Number(getObject(arg0)); - return ret; - }, - __wbg___wbindgen_bigint_get_as_i64_8fcf4ce7f1ca72a2: function(arg0, arg1) { - const v = getObject(arg1); - const ret = typeof(v) === 'bigint' ? v : undefined; - getDataViewMemory0().setBigInt64(arg0 + 8 * 1, isLikeNone(ret) ? BigInt(0) : ret, true); - getDataViewMemory0().setInt32(arg0 + 4 * 0, !isLikeNone(ret), true); - }, - __wbg___wbindgen_boolean_get_bbbb1c18aa2f5e25: function(arg0) { - const v = getObject(arg0); - const ret = typeof(v) === 'boolean' ? v : undefined; - return isLikeNone(ret) ? 0xFFFFFF : ret ? 1 : 0; - }, - __wbg___wbindgen_debug_string_0bc8482c6e3508ae: function(arg0, arg1) { - const ret = debugString(getObject(arg1)); - const ptr1 = passStringToWasm0(ret, wasm.__wbindgen_export, wasm.__wbindgen_export2); - const len1 = WASM_VECTOR_LEN; - getDataViewMemory0().setInt32(arg0 + 4 * 1, len1, true); - getDataViewMemory0().setInt32(arg0 + 4 * 0, ptr1, true); - }, - __wbg___wbindgen_in_47fa6863be6f2f25: function(arg0, arg1) { - const ret = getObject(arg0) in getObject(arg1); - return ret; - }, - __wbg___wbindgen_is_bigint_31b12575b56f32fc: function(arg0) { - const ret = typeof(getObject(arg0)) === 'bigint'; - return ret; - }, - __wbg___wbindgen_is_function_0095a73b8b156f76: function(arg0) { - const ret = typeof(getObject(arg0)) === 'function'; - return ret; - }, - __wbg___wbindgen_is_object_5ae8e5880f2c1fbd: function(arg0) { - const val = getObject(arg0); - const ret = typeof(val) === 'object' && val !== null; - return ret; - }, - __wbg___wbindgen_is_undefined_9e4d92534c42d778: function(arg0) { - const ret = getObject(arg0) === undefined; - return ret; - }, - __wbg___wbindgen_jsval_eq_11888390b0186270: function(arg0, arg1) { - const ret = getObject(arg0) === getObject(arg1); - return ret; - }, - __wbg___wbindgen_jsval_loose_eq_9dd77d8cd6671811: function(arg0, arg1) { - const ret = getObject(arg0) == getObject(arg1); - return ret; - }, - __wbg___wbindgen_number_get_8ff4255516ccad3e: function(arg0, arg1) { - const obj = getObject(arg1); - const ret = typeof(obj) === 'number' ? obj : undefined; - getDataViewMemory0().setFloat64(arg0 + 8 * 1, isLikeNone(ret) ? 0 : ret, true); - getDataViewMemory0().setInt32(arg0 + 4 * 0, !isLikeNone(ret), true); - }, - __wbg___wbindgen_string_get_72fb696202c56729: function(arg0, arg1) { - const obj = getObject(arg1); - const ret = typeof(obj) === 'string' ? obj : undefined; - var ptr1 = isLikeNone(ret) ? 0 : passStringToWasm0(ret, wasm.__wbindgen_export, wasm.__wbindgen_export2); - var len1 = WASM_VECTOR_LEN; - getDataViewMemory0().setInt32(arg0 + 4 * 1, len1, true); - getDataViewMemory0().setInt32(arg0 + 4 * 0, ptr1, true); - }, - __wbg___wbindgen_throw_be289d5034ed271b: function(arg0, arg1) { - throw new Error(getStringFromWasm0(arg0, arg1)); - }, - __wbg_call_389efe28435a9388: function() { return handleError(function (arg0, arg1) { - const ret = getObject(arg0).call(getObject(arg1)); - return addHeapObject(ret); - }, arguments); }, - __wbg_done_57b39ecd9addfe81: function(arg0) { - const ret = getObject(arg0).done; - return ret; - }, - __wbg_entries_58c7934c745daac7: function(arg0) { - const ret = Object.entries(getObject(arg0)); - return addHeapObject(ret); - }, - __wbg_error_7534b8e9a36f1ab4: function(arg0, arg1) { - let deferred0_0; - let deferred0_1; - try { - deferred0_0 = arg0; - deferred0_1 = arg1; - console.error(getStringFromWasm0(arg0, arg1)); - } finally { - wasm.__wbindgen_export4(deferred0_0, deferred0_1, 1); - } - }, - __wbg_get_9b94d73e6221f75c: function(arg0, arg1) { - const ret = getObject(arg0)[arg1 >>> 0]; - return addHeapObject(ret); - }, - __wbg_get_b3ed3ad4be2bc8ac: function() { return handleError(function (arg0, arg1) { - const ret = Reflect.get(getObject(arg0), getObject(arg1)); - return addHeapObject(ret); - }, arguments); }, - __wbg_get_with_ref_key_1dc361bd10053bfe: function(arg0, arg1) { - const ret = getObject(arg0)[getObject(arg1)]; - return addHeapObject(ret); - }, - __wbg_instanceof_ArrayBuffer_c367199e2fa2aa04: function(arg0) { - let result; - try { - result = getObject(arg0) instanceof ArrayBuffer; - } catch (_) { - result = false; - } - const ret = result; - return ret; - }, - __wbg_instanceof_Map_53af74335dec57f4: function(arg0) { - let result; - try { - result = getObject(arg0) instanceof Map; - } catch (_) { - result = false; - } - const ret = result; - return ret; - }, - __wbg_instanceof_Uint8Array_9b9075935c74707c: function(arg0) { - let result; - try { - result = getObject(arg0) instanceof Uint8Array; - } catch (_) { - result = false; - } - const ret = result; - return ret; - }, - __wbg_isArray_d314bb98fcf08331: function(arg0) { - const ret = Array.isArray(getObject(arg0)); - return ret; - }, - __wbg_isSafeInteger_bfbc7332a9768d2a: function(arg0) { - const ret = Number.isSafeInteger(getObject(arg0)); - return ret; - }, - __wbg_iterator_6ff6560ca1568e55: function() { - const ret = Symbol.iterator; - return addHeapObject(ret); - }, - __wbg_length_32ed9a279acd054c: function(arg0) { - const ret = getObject(arg0).length; - return ret; - }, - __wbg_length_35a7bace40f36eac: function(arg0) { - const ret = getObject(arg0).length; - return ret; - }, - __wbg_new_361308b2356cecd0: function() { - const ret = new Object(); - return addHeapObject(ret); - }, - __wbg_new_3eb36ae241fe6f44: function() { - const ret = new Array(); - return addHeapObject(ret); - }, - __wbg_new_8a6f238a6ece86ea: function() { - const ret = new Error(); - return addHeapObject(ret); - }, - __wbg_new_dd2b680c8bf6ae29: function(arg0) { - const ret = new Uint8Array(getObject(arg0)); - return addHeapObject(ret); - }, - __wbg_next_3482f54c49e8af19: function() { return handleError(function (arg0) { - const ret = getObject(arg0).next(); - return addHeapObject(ret); - }, arguments); }, - __wbg_next_418f80d8f5303233: function(arg0) { - const ret = getObject(arg0).next; - return addHeapObject(ret); - }, - __wbg_prototypesetcall_bdcdcc5842e4d77d: function(arg0, arg1, arg2) { - Uint8Array.prototype.set.call(getArrayU8FromWasm0(arg0, arg1), getObject(arg2)); - }, - __wbg_set_3f1d0b984ed272ed: function(arg0, arg1, arg2) { - getObject(arg0)[takeObject(arg1)] = takeObject(arg2); - }, - __wbg_set_f43e577aea94465b: function(arg0, arg1, arg2) { - getObject(arg0)[arg1 >>> 0] = takeObject(arg2); - }, - __wbg_stack_0ed75d68575b0f3c: function(arg0, arg1) { - const ret = getObject(arg1).stack; - const ptr1 = passStringToWasm0(ret, wasm.__wbindgen_export, wasm.__wbindgen_export2); - const len1 = WASM_VECTOR_LEN; - getDataViewMemory0().setInt32(arg0 + 4 * 1, len1, true); - getDataViewMemory0().setInt32(arg0 + 4 * 0, ptr1, true); - }, - __wbg_value_0546255b415e96c1: function(arg0) { - const ret = getObject(arg0).value; - return addHeapObject(ret); - }, - __wbindgen_cast_0000000000000001: function(arg0) { - // Cast intrinsic for `F64 -> Externref`. - const ret = arg0; - return addHeapObject(ret); - }, - __wbindgen_cast_0000000000000002: function(arg0) { - // Cast intrinsic for `I64 -> Externref`. - const ret = arg0; - return addHeapObject(ret); - }, - __wbindgen_cast_0000000000000003: function(arg0, arg1) { - // Cast intrinsic for `Ref(String) -> Externref`. - const ret = getStringFromWasm0(arg0, arg1); - return addHeapObject(ret); - }, - __wbindgen_cast_0000000000000004: function(arg0) { - // Cast intrinsic for `U64 -> Externref`. - const ret = BigInt.asUintN(64, arg0); - return addHeapObject(ret); - }, - __wbindgen_object_clone_ref: function(arg0) { - const ret = getObject(arg0); - return addHeapObject(ret); - }, - __wbindgen_object_drop_ref: function(arg0) { - takeObject(arg0); - }, - }; - return { - __proto__: null, - "./calab_solver_bg.js": import0, - }; -} - -const SolverFinalization = (typeof FinalizationRegistry === 'undefined') - ? { register: () => {}, unregister: () => {} } - : new FinalizationRegistry(ptr => wasm.__wbg_solver_free(ptr >>> 0, 1)); - -function addHeapObject(obj) { - if (heap_next === heap.length) heap.push(heap.length + 1); - const idx = heap_next; - heap_next = heap[idx]; - - heap[idx] = obj; - return idx; -} - -function debugString(val) { - // primitive types - const type = typeof val; - if (type == 'number' || type == 'boolean' || val == null) { - return `${val}`; - } - if (type == 'string') { - return `"${val}"`; - } - if (type == 'symbol') { - const description = val.description; - if (description == null) { - return 'Symbol'; - } else { - return `Symbol(${description})`; - } - } - if (type == 'function') { - const name = val.name; - if (typeof name == 'string' && name.length > 0) { - return `Function(${name})`; - } else { - return 'Function'; - } - } - // objects - if (Array.isArray(val)) { - const length = val.length; - let debug = '['; - if (length > 0) { - debug += debugString(val[0]); - } - for(let i = 1; i < length; i++) { - debug += ', ' + debugString(val[i]); - } - debug += ']'; - return debug; - } - // Test for built-in - const builtInMatches = /\[object ([^\]]+)\]/.exec(toString.call(val)); - let className; - if (builtInMatches && builtInMatches.length > 1) { - className = builtInMatches[1]; - } else { - // Failed to match the standard '[object ClassName]' - return toString.call(val); - } - if (className == 'Object') { - // we're a user defined class or Object - // JSON.stringify avoids problems with cycles, and is generally much - // easier than looping through ownProperties of `val`. - try { - return 'Object(' + JSON.stringify(val) + ')'; - } catch (_) { - return 'Object'; - } - } - // errors - if (val instanceof Error) { - return `${val.name}: ${val.message}\n${val.stack}`; - } - // TODO we could test for more things here, like `Set`s and `Map`s. - return className; -} - -function dropObject(idx) { - if (idx < 132) return; - heap[idx] = heap_next; - heap_next = idx; -} - -function getArrayF32FromWasm0(ptr, len) { - ptr = ptr >>> 0; - return getFloat32ArrayMemory0().subarray(ptr / 4, ptr / 4 + len); -} - -function getArrayU8FromWasm0(ptr, len) { - ptr = ptr >>> 0; - return getUint8ArrayMemory0().subarray(ptr / 1, ptr / 1 + len); -} - -let cachedDataViewMemory0 = null; -function getDataViewMemory0() { - if (cachedDataViewMemory0 === null || cachedDataViewMemory0.buffer.detached === true || (cachedDataViewMemory0.buffer.detached === undefined && cachedDataViewMemory0.buffer !== wasm.memory.buffer)) { - cachedDataViewMemory0 = new DataView(wasm.memory.buffer); - } - return cachedDataViewMemory0; -} - -let cachedFloat32ArrayMemory0 = null; -function getFloat32ArrayMemory0() { - if (cachedFloat32ArrayMemory0 === null || cachedFloat32ArrayMemory0.byteLength === 0) { - cachedFloat32ArrayMemory0 = new Float32Array(wasm.memory.buffer); - } - return cachedFloat32ArrayMemory0; -} - -let cachedFloat64ArrayMemory0 = null; -function getFloat64ArrayMemory0() { - if (cachedFloat64ArrayMemory0 === null || cachedFloat64ArrayMemory0.byteLength === 0) { - cachedFloat64ArrayMemory0 = new Float64Array(wasm.memory.buffer); - } - return cachedFloat64ArrayMemory0; -} - -function getStringFromWasm0(ptr, len) { - ptr = ptr >>> 0; - return decodeText(ptr, len); -} - -let cachedUint32ArrayMemory0 = null; -function getUint32ArrayMemory0() { - if (cachedUint32ArrayMemory0 === null || cachedUint32ArrayMemory0.byteLength === 0) { - cachedUint32ArrayMemory0 = new Uint32Array(wasm.memory.buffer); - } - return cachedUint32ArrayMemory0; -} - -let cachedUint8ArrayMemory0 = null; -function getUint8ArrayMemory0() { - if (cachedUint8ArrayMemory0 === null || cachedUint8ArrayMemory0.byteLength === 0) { - cachedUint8ArrayMemory0 = new Uint8Array(wasm.memory.buffer); - } - return cachedUint8ArrayMemory0; -} - -function getObject(idx) { return heap[idx]; } - -function handleError(f, args) { - try { - return f.apply(this, args); - } catch (e) { - wasm.__wbindgen_export3(addHeapObject(e)); - } -} - -let heap = new Array(128).fill(undefined); -heap.push(undefined, null, true, false); - -let heap_next = heap.length; - -function isLikeNone(x) { - return x === undefined || x === null; -} - -function passArray32ToWasm0(arg, malloc) { - const ptr = malloc(arg.length * 4, 4) >>> 0; - getUint32ArrayMemory0().set(arg, ptr / 4); - WASM_VECTOR_LEN = arg.length; - return ptr; -} - -function passArray8ToWasm0(arg, malloc) { - const ptr = malloc(arg.length * 1, 1) >>> 0; - getUint8ArrayMemory0().set(arg, ptr / 1); - WASM_VECTOR_LEN = arg.length; - return ptr; -} - -function passArrayF32ToWasm0(arg, malloc) { - const ptr = malloc(arg.length * 4, 4) >>> 0; - getFloat32ArrayMemory0().set(arg, ptr / 4); - WASM_VECTOR_LEN = arg.length; - return ptr; -} - -function passArrayF64ToWasm0(arg, malloc) { - const ptr = malloc(arg.length * 8, 8) >>> 0; - getFloat64ArrayMemory0().set(arg, ptr / 8); - WASM_VECTOR_LEN = arg.length; - return ptr; -} - -function passStringToWasm0(arg, malloc, realloc) { - if (realloc === undefined) { - const buf = cachedTextEncoder.encode(arg); - const ptr = malloc(buf.length, 1) >>> 0; - getUint8ArrayMemory0().subarray(ptr, ptr + buf.length).set(buf); - WASM_VECTOR_LEN = buf.length; - return ptr; - } - - let len = arg.length; - let ptr = malloc(len, 1) >>> 0; - - const mem = getUint8ArrayMemory0(); - - let offset = 0; - - for (; offset < len; offset++) { - const code = arg.charCodeAt(offset); - if (code > 0x7F) break; - mem[ptr + offset] = code; - } - if (offset !== len) { - if (offset !== 0) { - arg = arg.slice(offset); - } - ptr = realloc(ptr, len, len = offset + arg.length * 3, 1) >>> 0; - const view = getUint8ArrayMemory0().subarray(ptr + offset, ptr + len); - const ret = cachedTextEncoder.encodeInto(arg, view); - - offset += ret.written; - ptr = realloc(ptr, len, offset, 1) >>> 0; - } - - WASM_VECTOR_LEN = offset; - return ptr; -} - -function takeObject(idx) { - const ret = getObject(idx); - dropObject(idx); - return ret; -} - -let cachedTextDecoder = new TextDecoder('utf-8', { ignoreBOM: true, fatal: true }); -cachedTextDecoder.decode(); -const MAX_SAFARI_DECODE_BYTES = 2146435072; -let numBytesDecoded = 0; -function decodeText(ptr, len) { - numBytesDecoded += len; - if (numBytesDecoded >= MAX_SAFARI_DECODE_BYTES) { - cachedTextDecoder = new TextDecoder('utf-8', { ignoreBOM: true, fatal: true }); - cachedTextDecoder.decode(); - numBytesDecoded = len; - } - return cachedTextDecoder.decode(getUint8ArrayMemory0().subarray(ptr, ptr + len)); -} - -const cachedTextEncoder = new TextEncoder(); - -if (!('encodeInto' in cachedTextEncoder)) { - cachedTextEncoder.encodeInto = function (arg, view) { - const buf = cachedTextEncoder.encode(arg); - view.set(buf); - return { - read: arg.length, - written: buf.length - }; - }; -} - -let WASM_VECTOR_LEN = 0; - -let wasmModule, wasm; -function __wbg_finalize_init(instance, module) { - wasm = instance.exports; - wasmModule = module; - cachedDataViewMemory0 = null; - cachedFloat32ArrayMemory0 = null; - cachedFloat64ArrayMemory0 = null; - cachedUint32ArrayMemory0 = null; - cachedUint8ArrayMemory0 = null; - return wasm; -} - -async function __wbg_load(module, imports) { - if (typeof Response === 'function' && module instanceof Response) { - if (typeof WebAssembly.instantiateStreaming === 'function') { - try { - return await WebAssembly.instantiateStreaming(module, imports); - } catch (e) { - const validResponse = module.ok && expectedResponseType(module.type); - - if (validResponse && module.headers.get('Content-Type') !== 'application/wasm') { - console.warn("`WebAssembly.instantiateStreaming` failed because your server does not serve Wasm with `application/wasm` MIME type. Falling back to `WebAssembly.instantiate` which is slower. Original error:\n", e); - - } else { throw e; } - } - } - - const bytes = await module.arrayBuffer(); - return await WebAssembly.instantiate(bytes, imports); - } else { - const instance = await WebAssembly.instantiate(module, imports); - - if (instance instanceof WebAssembly.Instance) { - return { instance, module }; - } else { - return instance; - } - } - - function expectedResponseType(type) { - switch (type) { - case 'basic': case 'cors': case 'default': return true; - } - return false; - } -} - -function initSync(module) { - if (wasm !== undefined) return wasm; - - - if (module !== undefined) { - if (Object.getPrototypeOf(module) === Object.prototype) { - ({module} = module) - } else { - console.warn('using deprecated parameters for `initSync()`; pass a single object instead') - } - } - - const imports = __wbg_get_imports(); - if (!(module instanceof WebAssembly.Module)) { - module = new WebAssembly.Module(module); - } - const instance = new WebAssembly.Instance(module, imports); - return __wbg_finalize_init(instance, module); -} - -async function __wbg_init(module_or_path) { - if (wasm !== undefined) return wasm; - - - if (module_or_path !== undefined) { - if (Object.getPrototypeOf(module_or_path) === Object.prototype) { - ({module_or_path} = module_or_path) - } else { - console.warn('using deprecated parameters for the initialization function; pass a single object instead') - } - } - - if (module_or_path === undefined) { - module_or_path = new URL('calab_solver_bg.wasm', import.meta.url); - } - const imports = __wbg_get_imports(); - - if (typeof module_or_path === 'string' || (typeof Request === 'function' && module_or_path instanceof Request) || (typeof URL === 'function' && module_or_path instanceof URL)) { - module_or_path = fetch(module_or_path); - } - - const { instance, module } = await __wbg_load(await module_or_path, imports); - - return __wbg_finalize_init(instance, module); -} - -export { initSync, __wbg_init as default }; diff --git a/crates/solver/pkg/calab_solver_bg.wasm b/crates/solver/pkg/calab_solver_bg.wasm deleted file mode 100644 index fe782d72..00000000 Binary files a/crates/solver/pkg/calab_solver_bg.wasm and /dev/null differ diff --git a/crates/solver/pkg/calab_solver_bg.wasm.d.ts b/crates/solver/pkg/calab_solver_bg.wasm.d.ts deleted file mode 100644 index 8033ef6a..00000000 --- a/crates/solver/pkg/calab_solver_bg.wasm.d.ts +++ /dev/null @@ -1,42 +0,0 @@ -/* tslint:disable */ -/* eslint-disable */ -export const memory: WebAssembly.Memory; -export const __wbg_solver_free: (a: number, b: number) => void; -export const get_simulation_presets: () => number; -export const indeca_compute_upsample_factor: (a: number, b: number) => number; -export const indeca_estimate_kernel: (a: number, b: number, c: number, d: number, e: number, f: number, g: number, h: number, i: number, j: number, k: number, l: number, m: number, n: number, o: number, p: number, q: number) => void; -export const indeca_fit_biexponential: (a: number, b: number, c: number, d: number, e: number, f: number, g: number, h: number, i: number, j: number, k: number, l: number, m: number) => number; -export const indeca_solve_trace: (a: number, b: number, c: number, d: number, e: number, f: number, g: number, h: number, i: number, j: number, k: number, l: number, m: number) => number; -export const seed_trace: (a: number, b: number, c: number) => number; -export const simulate_traces: (a: number) => number; -export const solver_apply_filter: (a: number) => number; -export const solver_converged: (a: number) => number; -export const solver_export_state: (a: number, b: number) => void; -export const solver_filter_enabled: (a: number) => number; -export const solver_get_baseline: (a: number) => number; -export const solver_get_filter_cutoffs: (a: number, b: number) => void; -export const solver_get_kernel: (a: number, b: number) => void; -export const solver_get_power_spectrum: (a: number, b: number) => void; -export const solver_get_reconvolution: (a: number, b: number) => void; -export const solver_get_reconvolution_with_baseline: (a: number, b: number) => void; -export const solver_get_solution: (a: number, b: number) => void; -export const solver_get_spectrum_frequencies: (a: number, b: number) => void; -export const solver_get_trace: (a: number, b: number) => void; -export const solver_iteration_count: (a: number) => number; -export const solver_load_state: (a: number, b: number, c: number) => void; -export const solver_new: () => number; -export const solver_reset_momentum: (a: number) => void; -export const solver_set_constraint: (a: number, b: number) => void; -export const solver_set_conv_mode: (a: number, b: number) => void; -export const solver_set_filter_enabled: (a: number, b: number) => void; -export const solver_set_hp_filter_enabled: (a: number, b: number) => void; -export const solver_set_lp_filter_enabled: (a: number, b: number) => void; -export const solver_set_params: (a: number, b: number, c: number, d: number, e: number) => void; -export const solver_set_trace: (a: number, b: number, c: number) => void; -export const solver_step_batch: (a: number, b: number) => number; -export const solver_subtract_baseline: (a: number) => void; -export const __wbindgen_export: (a: number, b: number) => number; -export const __wbindgen_export2: (a: number, b: number, c: number, d: number) => number; -export const __wbindgen_export3: (a: number) => void; -export const __wbindgen_export4: (a: number, b: number, c: number) => void; -export const __wbindgen_add_to_stack_pointer: (a: number) => number; diff --git a/crates/solver/pkg/package.json b/crates/solver/pkg/package.json deleted file mode 100644 index b6d5fb85..00000000 --- a/crates/solver/pkg/package.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "name": "calab-solver", - "type": "module", - "version": "0.1.0", - "files": [ - "calab_solver_bg.wasm", - "calab_solver.js", - "calab_solver.d.ts" - ], - "main": "calab_solver.js", - "types": "calab_solver.d.ts", - "sideEffects": [ - "./snippets/*" - ] -} \ No newline at end of file diff --git a/crates/solver/src/js_indeca.rs b/crates/solver/src/js_indeca.rs index 9f25bdd3..ff8fefe9 100644 --- a/crates/solver/src/js_indeca.rs +++ b/crates/solver/src/js_indeca.rs @@ -58,6 +58,11 @@ pub fn indeca_solve_trace( /// for cold-start. /// /// Returns the estimated kernel as Float32Array (via Vec). +/// +/// Throws a JS error (rather than aborting the WASM module) if the input array +/// lengths are inconsistent — `alphas`/`baselines` must have one entry per +/// trace, and `traces_flat`/`spikes_flat` must each be the sum of +/// `trace_lengths`. #[wasm_bindgen] pub fn indeca_estimate_kernel( traces_flat: &[f32], @@ -70,14 +75,25 @@ pub fn indeca_estimate_kernel( tol: f64, warm_kernel: &[f32], smooth_lambda: f64, -) -> Vec { +) -> Result, JsError> { let lengths: Vec = trace_lengths.iter().map(|&v| v as usize).collect(); + let total_len: usize = lengths.iter().sum(); + if alphas.len() != lengths.len() || baselines.len() != lengths.len() { + return Err(JsError::new( + "indeca_estimate_kernel: alphas and baselines must have one entry per trace", + )); + } + if traces_flat.len() != total_len || spikes_flat.len() != total_len { + return Err(JsError::new( + "indeca_estimate_kernel: traces_flat and spikes_flat length must equal sum(trace_lengths)", + )); + } let warm = if warm_kernel.is_empty() { None } else { Some(warm_kernel) }; - kernel_est::estimate_free_kernel( + Ok(kernel_est::estimate_free_kernel( traces_flat, spikes_flat, alphas, @@ -88,7 +104,7 @@ pub fn indeca_estimate_kernel( tol, warm, smooth_lambda, - ) + )) } /// Fit a bi-exponential model to a free-form kernel. diff --git a/crates/solver/src/kernel_est.rs b/crates/solver/src/kernel_est.rs index 6aa3a98e..00e796e3 100644 --- a/crates/solver/src/kernel_est.rs +++ b/crates/solver/src/kernel_est.rs @@ -122,14 +122,25 @@ pub fn estimate_free_kernel( smooth_lambda: f64, ) -> Vec { let n_traces = trace_lengths.len(); - assert_eq!(alphas.len(), n_traces); - assert_eq!(baselines.len(), n_traces); - let total_len: usize = trace_lengths.iter().sum(); - assert_eq!(traces.len(), total_len); - assert_eq!(spike_trains.len(), total_len); - if kernel_length == 0 || total_len == 0 { + // Length invariants. These are guaranteed by the FFI wrappers, which + // validate and return a typed error (see js_indeca / py_api). The + // debug_assert keeps the contract loud for internal callers and tests; + // the release-mode guard degrades to an empty kernel rather than panicking + // across the WASM/PyO3 boundary (a panic there aborts the module). + debug_assert_eq!(alphas.len(), n_traces); + debug_assert_eq!(baselines.len(), n_traces); + debug_assert_eq!(traces.len(), total_len); + debug_assert_eq!(spike_trains.len(), total_len); + + if alphas.len() != n_traces + || baselines.len() != n_traces + || traces.len() != total_len + || spike_trains.len() != total_len + || kernel_length == 0 + || total_len == 0 + { return vec![0.0; kernel_length]; } diff --git a/crates/solver/src/peak_seed.rs b/crates/solver/src/peak_seed.rs index 583911da..9ef39543 100644 --- a/crates/solver/src/peak_seed.rs +++ b/crates/solver/src/peak_seed.rs @@ -201,7 +201,6 @@ pub fn seed_kernel_estimate( fs: f64, ) -> SeedKernelResult { let total_len: usize = trace_lengths.iter().sum(); - assert_eq!(traces_flat.len(), total_len); let min_peak_distance_s = 5.0; @@ -209,6 +208,24 @@ pub fn seed_kernel_estimate( let kernel_length = (1.5 * fs).ceil() as usize; let kernel_length = kernel_length.clamp(10, 200); + // Length invariant guaranteed by the FFI wrapper (see py_api). Loud in + // debug/tests; in release degrade to an empty result rather than panicking + // across the PyO3 boundary. + debug_assert_eq!(traces_flat.len(), total_len); + if traces_flat.len() != total_len { + return SeedKernelResult { + free_kernel: vec![0.0; kernel_length], + tau_rise: 0.02, + tau_decay: 0.4, + beta: 0.0, + residual: f64::INFINITY, + tau_rise_fast: 0.0, + tau_decay_fast: 0.0, + beta_fast: 0.0, + n_seed_spikes: 0, + }; + } + let mut spike_trains = vec![0.0_f32; total_len]; let mut alphas = Vec::with_capacity(trace_lengths.len()); let mut baselines = Vec::with_capacity(trace_lengths.len()); diff --git a/crates/solver/src/py_api.rs b/crates/solver/src/py_api.rs index 57236348..1f77ae51 100644 --- a/crates/solver/src/py_api.rs +++ b/crates/solver/src/py_api.rs @@ -486,6 +486,20 @@ fn py_indeca_estimate_kernel<'py>( let warm = optional_to_f32_vec(warm_kernel)?; + // Validate array-length consistency before handing off, so a caller mistake + // surfaces as a clear ValueError instead of a Rust panic across the FFI. + let total_len: usize = lengths.iter().sum(); + if alphas_slice.len() != lengths.len() || baselines_slice.len() != lengths.len() { + return Err(pyo3::exceptions::PyValueError::new_err( + "alphas and baselines must have one entry per trace (len == trace_lengths.len())", + )); + } + if traces_f32.len() != total_len || spikes_f32.len() != total_len { + return Err(pyo3::exceptions::PyValueError::new_err( + "traces_flat and spikes_flat length must equal sum(trace_lengths)", + )); + } + let result = kernel_est::estimate_free_kernel( &traces_f32, &spikes_f32, diff --git a/crates/solver/src/simulate.rs b/crates/solver/src/simulate.rs index d91abb7b..79f496e9 100644 --- a/crates/solver/src/simulate.rs +++ b/crates/solver/src/simulate.rs @@ -436,7 +436,11 @@ pub fn simulate(config: &SimulationConfig) -> SimulationResult { let mut traces = Vec::with_capacity(n_cells * n_tp); let mut ground_truth = Vec::with_capacity(n_cells); - let bins_per_frame = (config.spike_sim_hz / config.fs_hz).round() as usize; + // Clamp to >= 1: `spike_sim_hz` and `fs_hz` come from a user-supplied + // config, and a spike rate below half the frame rate would otherwise round + // to 0, producing a zero-length high-resolution buffer and a degenerate + // (empty) simulation rather than a usable result. + let bins_per_frame = ((config.spike_sim_hz / config.fs_hz).round() as usize).max(1); let num_high_res = n_tp * bins_per_frame; let shared_kernel = if !has_kernel_variation { @@ -965,6 +969,27 @@ mod tests { } } + #[test] + fn low_spike_sim_hz_does_not_degenerate() { + // spike_sim_hz below fs_hz/2 rounds bins_per_frame to 0; the clamp + // keeps the high-resolution buffer non-empty so the simulation still + // produces full-length, finite traces instead of a degenerate result. + let cfg = SimulationConfig { + fs_hz: 30.0, + spike_sim_hz: 10.0, + num_timepoints: 300, + num_cells: 2, + alpha_cv: 0.0, + ..Default::default() + }; + let r = simulate(&cfg); + assert_eq!(r.traces.len(), 2 * 300); + assert!(r.traces.iter().all(|v| v.is_finite())); + for gt in &r.ground_truth { + assert_eq!(gt.spikes.len(), 300); + } + } + #[test] fn spikes_non_negative() { for gt in &simulate(&small_config()).ground_truth { diff --git a/package.json b/package.json index ca94ffb5..fb3903bd 100644 --- a/package.json +++ b/package.json @@ -8,16 +8,22 @@ "packages/*" ], "scripts": { + "ensure-wasm": "node scripts/ensure-wasm.mjs", + "predev": "npm run ensure-wasm", "dev": "npm run dev -w apps/catune", + "predev:carank": "npm run ensure-wasm", "dev:carank": "npm run dev -w apps/carank", "build": "npm run build:wasm && npm run build:apps", + "prebuild:apps": "npm run ensure-wasm", "build:apps": "node scripts/build-apps.mjs", "build:pages": "CALAB_PAGES=1 npm run build && node scripts/combine-dist.mjs", "build:wasm": "cd crates/solver && wasm-pack build --target web --release", + "pretest": "npm run ensure-wasm", "test": "npm run test --workspaces --if-present", "test:watch": "npm run test:watch -w apps/catune", "lint": "eslint apps/ packages/ scripts/", "lint:fix": "eslint --fix apps/ packages/ scripts/", + "pretypecheck": "npm run ensure-wasm", "typecheck": "tsc -b apps/catune apps/carank apps/admin apps/cadecon apps/_template", "format": "prettier --write .", "format:check": "prettier --check ." diff --git a/packages/community/src/submission-service.ts b/packages/community/src/submission-service.ts index a87a8960..84bf9771 100644 --- a/packages/community/src/submission-service.ts +++ b/packages/community/src/submission-service.ts @@ -22,9 +22,15 @@ async function requireClient(): Promise { /** * Create a typed CRUD service for a Supabase submission table. * Handles auth user injection, base filter application, and RLS-guarded delete. + * + * `readSource` is the relation used for `fetch` (community browsing). It + * defaults to the base table but should be a PII-free public view: base-table + * SELECT is restricted to owner+admin (migration 010), so anonymous browsing + * must go through the view. Writes and deletes always target `tableName`. */ export function createSubmissionService( tableName: string, + readSource: string = tableName, ): SubmissionService { return { async submit(payload) { @@ -48,7 +54,7 @@ export function createSubmissionService( async fetch(filters?) { const client = await requireClient(); - let query = client.from(tableName).select('*'); + let query = client.from(readSource).select('*'); if (filters?.indicator) { query = query.eq('indicator', filters.indicator); diff --git a/packages/compute/src/__tests__/worker-pool.test.ts b/packages/compute/src/__tests__/worker-pool.test.ts new file mode 100644 index 00000000..5ce35336 --- /dev/null +++ b/packages/compute/src/__tests__/worker-pool.test.ts @@ -0,0 +1,199 @@ +import { describe, it, expect, beforeEach } from 'vitest'; +import { createWorkerPool, type BaseJob, type MessageRouter } from '@calab/compute'; + +// ── Test doubles ──────────────────────────────────────────────────────────── + +type TestMsg = { type: 'ready' } | { type: 'result'; jobId: number }; + +/** Minimal stand-in for the DOM Worker the pool drives. */ +class FakeWorker { + static instances: FakeWorker[] = []; + onmessage: ((e: { data: TestMsg }) => void) | null = null; + posted: unknown[] = []; + terminated = false; + + constructor() { + FakeWorker.instances.push(this); + } + + postMessage(payload: unknown): void { + this.posted.push(payload); + } + + terminate(): void { + this.terminated = true; + } + + /** Simulate the worker emitting a message back to the pool. */ + emit(msg: TestMsg): void { + this.onmessage?.({ data: msg }); + } +} + +class TestJob implements BaseJob { + cancelled = false; + errored: string | null = null; + done = false; + + constructor( + public jobId: number, + private priority?: number, + ) {} + + onCancelled(): void { + this.cancelled = true; + } + + onError(msg: string): void { + this.errored = msg; + } + + getPriority(): number { + return this.priority ?? 1; + } +} + +const router: MessageRouter = { + isReady: (msg) => msg.type === 'ready', + getJobId: (msg) => (msg.type === 'result' ? msg.jobId : undefined), + routeMessage: (job, _msg, finish) => { + job.done = true; + finish(); + }, + buildDispatch: (job) => [{ jobId: job.jobId }, []], +}; + +function makePool(poolSize: number) { + FakeWorker.instances = []; + const pool = createWorkerPool( + () => new FakeWorker() as unknown as Worker, + router, + poolSize, + ); + return { pool, workers: FakeWorker.instances }; +} + +// ── Tests ──────────────────────────────────────────────────────────────────── + +describe('createWorkerPool', () => { + beforeEach(() => { + FakeWorker.instances = []; + }); + + it('creates the requested number of workers', () => { + const { pool, workers } = makePool(3); + expect(pool.size).toBe(3); + expect(workers).toHaveLength(3); + }); + + it('queues jobs until a worker reports ready, then dispatches', () => { + const { pool, workers } = makePool(2); + const job = new TestJob(1); + + pool.dispatch(job); + // Workers start in `init`; nothing dispatched yet. + expect(workers.every((w) => w.posted.length === 0)).toBe(true); + + workers[0].emit({ type: 'ready' }); + expect(workers[0].posted).toEqual([{ jobId: 1 }]); + }); + + it('queues a second job while busy and drains it when the first finishes', () => { + const { pool, workers } = makePool(1); + const j1 = new TestJob(1); + const j2 = new TestJob(2); + + workers[0].emit({ type: 'ready' }); + pool.dispatch(j1); + pool.dispatch(j2); + expect(workers[0].posted).toEqual([{ jobId: 1 }]); + + workers[0].emit({ type: 'result', jobId: 1 }); + expect(j1.done).toBe(true); + expect(workers[0].posted).toEqual([{ jobId: 1 }, { jobId: 2 }]); + }); + + it('dispatches queued jobs in priority order (lower first)', () => { + const { pool, workers } = makePool(1); + const busy = new TestJob(1); + const low = new TestJob(2, 10); + const high = new TestJob(3, 1); + + workers[0].emit({ type: 'ready' }); + pool.dispatch(busy); // occupies the only worker + pool.dispatch(low); // queued + pool.dispatch(high); // queued + + workers[0].emit({ type: 'result', jobId: 1 }); // frees worker → drains by priority + expect(workers[0].posted).toEqual([{ jobId: 1 }, { jobId: 3 }]); + + workers[0].emit({ type: 'result', jobId: 3 }); + expect(workers[0].posted).toEqual([{ jobId: 1 }, { jobId: 3 }, { jobId: 2 }]); + }); + + it('cancel() removes a queued job and notifies it, without dispatching it', () => { + const { pool, workers } = makePool(1); + const busy = new TestJob(1); + const queued = new TestJob(2); + + workers[0].emit({ type: 'ready' }); + pool.dispatch(busy); + pool.dispatch(queued); + + pool.cancel(queued.jobId); + expect(queued.cancelled).toBe(true); + + // Finishing the busy job must not dispatch the cancelled one. + workers[0].emit({ type: 'result', jobId: 1 }); + expect(workers[0].posted).toEqual([{ jobId: 1 }]); + }); + + it('cancel() signals a cancel message to the worker for an in-flight job', () => { + const { pool, workers } = makePool(1); + const job = new TestJob(1); + + workers[0].emit({ type: 'ready' }); + pool.dispatch(job); + pool.cancel(job.jobId); + + expect(workers[0].posted).toContainEqual({ type: 'cancel' }); + expect(job.cancelled).toBe(false); // in-flight cancel is acknowledged by the worker, not here + }); + + it('cancelAll() cancels queued jobs and signals busy workers', () => { + const { pool, workers } = makePool(1); + const busy = new TestJob(1); + const queued = new TestJob(2); + + workers[0].emit({ type: 'ready' }); + pool.dispatch(busy); + pool.dispatch(queued); + + pool.cancelAll(); + expect(queued.cancelled).toBe(true); + expect(workers[0].posted).toContainEqual({ type: 'cancel' }); + }); + + it('ignores result messages for unknown / already-finished jobs', () => { + const { pool, workers } = makePool(1); + workers[0].emit({ type: 'ready' }); + + // No job in flight with id 999 — must not throw and must leave worker idle. + expect(() => workers[0].emit({ type: 'result', jobId: 999 })).not.toThrow(); + + const job = new TestJob(1); + pool.dispatch(job); + expect(workers[0].posted).toEqual([{ jobId: 1 }]); + }); + + it('dispose() terminates workers and blocks further dispatch', () => { + const { pool, workers } = makePool(2); + workers.forEach((w) => w.emit({ type: 'ready' })); + + pool.dispose(); + expect(workers.every((w) => w.terminated)).toBe(true); + + pool.dispatch(new TestJob(1)); + expect(workers.every((w) => w.posted.length === 0)).toBe(true); + }); +}); diff --git a/python/src/calab/_bridge/_apps.py b/python/src/calab/_bridge/_apps.py index 0bf60e54..11672837 100644 --- a/python/src/calab/_bridge/_apps.py +++ b/python/src/calab/_bridge/_apps.py @@ -20,6 +20,11 @@ HEARTBEAT_TIMEOUT = 10 # seconds without heartbeat = browser disconnected +# Kernel waveforms are truncated to this many decay time-constants +# (kernel_length = KERNEL_LENGTH_DECAY_MULTIPLES * tau_decay * fs). Five decay +# constants capture >99% of a bi-exponential's mass. +KERNEL_LENGTH_DECAY_MULTIPLES = 5.0 + # Default app URLs (GitHub Pages deployment) _DEFAULT_CATUNE_URL = "https://miniscope.github.io/CaLab/CaTune/" _DEFAULT_CADECON_URL = "https://miniscope.github.io/CaLab/CaDecon/" @@ -300,14 +305,14 @@ def decon( tau_rise = results.get("tau_rise", 0.2) tau_decay = results.get("tau_decay", 1.0) beta = results.get("beta", 1.0) - kernel_length = int(5.0 * tau_decay * result_fs) + kernel_length = int(KERNEL_LENGTH_DECAY_MULTIPLES * tau_decay * result_fs) kernel_slow = _build_biexp_waveform(tau_rise, tau_decay, beta, result_fs, kernel_length) tau_rise_fast = results.get("tau_rise_fast", 0.0) tau_decay_fast = results.get("tau_decay_fast", 0.0) beta_fast = results.get("beta_fast", 0.0) if tau_decay_fast > 0 and beta_fast != 0: - kernel_length_fast = int(5.0 * tau_decay_fast * result_fs) + kernel_length_fast = int(KERNEL_LENGTH_DECAY_MULTIPLES * tau_decay_fast * result_fs) kernel_fast = _build_biexp_waveform( tau_rise_fast, tau_decay_fast, beta_fast, result_fs, kernel_length_fast, ) diff --git a/scripts/ensure-wasm.mjs b/scripts/ensure-wasm.mjs new file mode 100644 index 00000000..740e2622 --- /dev/null +++ b/scripts/ensure-wasm.mjs @@ -0,0 +1,67 @@ +#!/usr/bin/env node +/** + * Ensure the wasm-pack output (crates/solver/pkg/) exists and is not stale. + * + * The pkg/ directory is build-only and gitignored — it is NOT committed (the + * binary previously was, and went silently stale because gitignored rebuilds + * never show up in `git status`). This guard runs as a pre-hook for the JS + * entry points (dev/typecheck/test/build:apps) so consumers always see a fresh + * binding surface and binary, without paying for a rebuild when nothing changed. + * + * Rebuilds only when: + * - pkg/calab_solver_bg.wasm is missing, OR + * - any tracked solver source (src/**, Cargo.toml) is newer than the binary. + */ +import { execFileSync } from 'node:child_process'; +import { existsSync, statSync, readdirSync } from 'node:fs'; +import { fileURLToPath } from 'node:url'; +import { dirname, join } from 'node:path'; + +const repoRoot = join(dirname(fileURLToPath(import.meta.url)), '..'); +const solverDir = join(repoRoot, 'crates', 'solver'); +const wasmFile = join(solverDir, 'pkg', 'calab_solver_bg.wasm'); + +/** Latest mtime (ms) across a directory tree, recursively. */ +function newestMtime(dir) { + let newest = 0; + for (const entry of readdirSync(dir, { withFileTypes: true })) { + const full = join(dir, entry.name); + if (entry.isDirectory()) { + newest = Math.max(newest, newestMtime(full)); + } else { + newest = Math.max(newest, statSync(full).mtimeMs); + } + } + return newest; +} + +function needsRebuild() { + if (!existsSync(wasmFile)) return 'pkg/ missing'; + const wasmMtime = statSync(wasmFile).mtimeMs; + const srcMtime = Math.max( + newestMtime(join(solverDir, 'src')), + statSync(join(solverDir, 'Cargo.toml')).mtimeMs, + ); + return srcMtime > wasmMtime ? 'solver source changed' : null; +} + +const reason = needsRebuild(); +if (!reason) { + console.log('[ensure-wasm] pkg/ is up to date — skipping rebuild.'); + process.exit(0); +} + +console.log(`[ensure-wasm] Rebuilding WASM (${reason})...`); +try { + // Invoke wasm-pack directly (no shell) — mirrors the `build:wasm` npm script. + execFileSync('wasm-pack', ['build', '--target', 'web', '--release'], { + cwd: solverDir, + stdio: 'inherit', + }); +} catch { + console.error( + '[ensure-wasm] WASM build failed. Install the Rust toolchain + wasm-pack ' + + '(see rust-toolchain.toml), or run `npm run build:wasm` manually.', + ); + process.exit(1); +} diff --git a/supabase/README.md b/supabase/README.md index cb86f65c..a24d62d8 100644 --- a/supabase/README.md +++ b/supabase/README.md @@ -8,21 +8,28 @@ usage analytics, and admin moderation. Each CaLab app has its own submissions table (e.g., `catune_submissions`). All tables share a common set of base columns defined in `000_base_template.sql`. -| Migration | Purpose | -| ---------------------------- | ------------------------------------------------------------------------------- | -| `000_base_template.sql` | **Template only** (not executed). Copy and extend for new apps. | -| `001_catune_submissions.sql` | CaTune submissions table with deconvolution-specific columns. | -| `002_field_options.sql` | Shared canonical field options lookup table. | -| `003_analytics.sql` | Analytics tables (`analytics_sessions`, `analytics_events`) for usage tracking. | -| `004_admin_role.sql` | `is_admin()` helper function and admin moderation policies. | +| Migration | Purpose | +| --------------------------------- | -------------------------------------------------------------------------------- | +| `000_base_template.sql` | **Template only** (not executed). Copy and extend for new apps. | +| `001_catune_submissions.sql` | CaTune submissions table with deconvolution-specific columns. | +| `002_field_options.sql` | Shared canonical field options lookup table. | +| `003_analytics.sql` | Analytics tables (`analytics_sessions`, `analytics_events`) for usage tracking. | +| `004_admin_role.sql` | `is_admin()` helper function and admin moderation policies. | +| `005`–`009` | CaDecon table, bridge data source, analytics hardening, and tighter constraints. | +| `010_restrict_submission_pii.sql` | Locks submission reads to owner+admin; adds PII-free `*_public` browsing views. | ## Applying migrations 1. Open your Supabase project dashboard. 2. Navigate to **SQL Editor**. -3. Run each numbered file in order (001 through 004). +3. Run each numbered file in order (001 through the highest-numbered). Existing + projects only need to run migrations they have not applied yet. 4. Run `supabase/seed/field_options_seed.sql` to populate the indicator, species, and brain region lookup values. +> **Note:** migration `010` must be applied **before** deploying the app build +> that reads the `*_public` views, or community browsing will fail until the +> views exist. + ## Edge Functions CaLab uses one Edge Function for server-side GeoIP resolution during analytics session creation. diff --git a/supabase/migrations/010_restrict_submission_pii.sql b/supabase/migrations/010_restrict_submission_pii.sql new file mode 100644 index 00000000..9465360f --- /dev/null +++ b/supabase/migrations/010_restrict_submission_pii.sql @@ -0,0 +1,77 @@ +-- Restrict read access to submission tables and expose PII-free public views. +-- +-- Problem (audit H1): migrations 001/006 granted submission SELECT to +-- `anon, authenticated USING (true)`, so any unauthenticated visitor with the +-- public anon key could read every submission's free-text PII — `orcid` (a +-- globally unique researcher ID), `lab_name`, and `notes` — and deanonymize +-- who submitted what. This contradicts the project's "don't leak data between +-- users" goal. +-- +-- Fix: base-table SELECT is now owner-or-admin only. Community browsing reads +-- a dedicated view that exposes every column EXCEPT those three PII fields. +-- No app reads orcid/lab_name/notes (they are write-only submission metadata), +-- so this is non-breaking for the CaTune/CaDecon community browsers and the +-- admin dashboard. +-- +-- The *_public views intentionally run with the view owner's privileges +-- (security_invoker = false, the default) so community browsing still sees +-- every contributor's submission — only the PII columns are dropped, not the +-- rows. Writes and deletes continue against the base tables under the existing +-- owner-scoped policies. + +BEGIN; + +-- ── catune_submissions ────────────────────────────────────────────────────── + +DROP POLICY IF EXISTS "Public read access" ON catune_submissions; + +CREATE POLICY "Owner and admin read access" +ON catune_submissions FOR SELECT +TO anon, authenticated +USING ((select auth.uid()) = user_id OR public.is_admin()); + +CREATE VIEW catune_submissions_public +WITH (security_invoker = false) AS +SELECT + id, created_at, user_id, + tau_rise, tau_decay, t_peak, fwhm, lambda, sampling_rate, + ar2_g1, ar2_g2, + indicator, species, brain_region, + filter_enabled, + virus_construct, time_since_injection_days, + num_cells, recording_length_s, fps, + dataset_hash, quality_score, app_version, data_source, + microscope_type, imaging_depth_um, cell_type, + extra_metadata +FROM catune_submissions; + +GRANT SELECT ON catune_submissions_public TO anon, authenticated; + +-- ── cadecon_submissions ───────────────────────────────────────────────────── + +DROP POLICY IF EXISTS "Public read" ON cadecon_submissions; + +CREATE POLICY "Owner and admin read access" +ON cadecon_submissions FOR SELECT +TO anon, authenticated +USING ((select auth.uid()) = user_id OR public.is_admin()); + +CREATE VIEW cadecon_submissions_public +WITH (security_invoker = false) AS +SELECT + id, created_at, user_id, + tau_rise, tau_decay, t_peak, fwhm, beta, upsample_factor, sampling_rate, + num_subsets, target_coverage, max_iterations, convergence_tol, + weighting_enabled, hp_filter_enabled, lp_filter_enabled, + median_alpha, median_pve, mean_event_rate, num_iterations, converged, + indicator, species, brain_region, + virus_construct, time_since_injection_days, + microscope_type, imaging_depth_um, cell_type, + num_cells, recording_length_s, fps, + dataset_hash, app_version, data_source, + extra_metadata +FROM cadecon_submissions; + +GRANT SELECT ON cadecon_submissions_public TO anon, authenticated; + +COMMIT; diff --git a/supabase/tests/rls/test.sql b/supabase/tests/rls/test.sql index 14a4f477..896c1f34 100644 --- a/supabase/tests/rls/test.sql +++ b/supabase/tests/rls/test.sql @@ -152,6 +152,10 @@ SET LOCAL "request.jwt.claims" = '{"sub":"22222222-2222-2222-2222-222222222222", -- so RLS filters the candidate rows to zero → query succeeds with 0 rows -- affected. Assert no rows were deleted. DELETE FROM catune_submissions WHERE dataset_hash = 'hash-alice'; +-- Verify under a privileged identity: as of migration 010, bob can no longer +-- SELECT alice's row (reads are owner-or-admin only), so the existence check +-- must run with RLS bypassed or it would read 0 for the wrong reason. +RESET ROLE; SELECT assert_row_count( $sql$SELECT COUNT(*)::int FROM catune_submissions WHERE dataset_hash = 'hash-alice'$sql$, 1, @@ -191,6 +195,8 @@ BEGIN; SET LOCAL ROLE authenticated; SET LOCAL "request.jwt.claims" = '{"sub":"22222222-2222-2222-2222-222222222222","role":"authenticated"}'; DELETE FROM cadecon_submissions WHERE dataset_hash = 'hash-alice'; +-- See note above: bob cannot read alice's row post-010, so verify privileged. +RESET ROLE; SELECT assert_row_count( $sql$SELECT COUNT(*)::int FROM cadecon_submissions WHERE dataset_hash = 'hash-alice'$sql$, 1, @@ -198,6 +204,78 @@ SELECT assert_row_count( ); ROLLBACK; +-- ── submission PII lockdown (migration 010) ──────────────────────────────── + +-- anon cannot read the base submission tables at all (no rows leak). +BEGIN; +SET LOCAL ROLE anon; +SELECT assert_row_count( + $sql$SELECT COUNT(*)::int FROM catune_submissions$sql$, + 0, + 'catune base table: anon SELECT returns 0 rows' +); +SELECT assert_row_count( + $sql$SELECT COUNT(*)::int FROM cadecon_submissions$sql$, + 0, + 'cadecon base table: anon SELECT returns 0 rows' +); +ROLLBACK; + +-- A non-owner authenticated user cannot read another user's base-table rows. +BEGIN; +SET LOCAL ROLE authenticated; +SET LOCAL "request.jwt.claims" = '{"sub":"22222222-2222-2222-2222-222222222222","role":"authenticated"}'; +SELECT assert_row_count( + $sql$SELECT COUNT(*)::int FROM catune_submissions WHERE dataset_hash = 'hash-alice'$sql$, + 0, + 'catune base table: non-owner SELECT cannot see foreign row' +); +ROLLBACK; + +-- The owner can still read their own base-table row (needed for insert-return). +BEGIN; +SET LOCAL ROLE authenticated; +SET LOCAL "request.jwt.claims" = '{"sub":"11111111-1111-1111-1111-111111111111","role":"authenticated"}'; +SELECT assert_row_count( + $sql$SELECT COUNT(*)::int FROM catune_submissions WHERE dataset_hash = 'hash-alice'$sql$, + 1, + 'catune base table: owner SELECT sees own row' +); +ROLLBACK; + +-- Community browsing still works: anon reads all rows through the public view. +BEGIN; +SET LOCAL ROLE anon; +SELECT assert_row_count( + $sql$SELECT COUNT(*)::int FROM catune_submissions_public WHERE dataset_hash = 'hash-alice'$sql$, + 1, + 'catune public view: anon can browse submissions' +); +SELECT assert_row_count( + $sql$SELECT COUNT(*)::int FROM cadecon_submissions_public WHERE dataset_hash = 'hash-alice'$sql$, + 1, + 'cadecon public view: anon can browse submissions' +); +ROLLBACK; + +-- The public view must NOT expose the PII columns. Selecting them errors with +-- undefined_column (42703), which assert_denied accepts. +BEGIN; +SET LOCAL ROLE anon; +SELECT assert_denied( + $sql$SELECT orcid FROM catune_submissions_public LIMIT 1$sql$, + 'catune public view omits orcid' +); +SELECT assert_denied( + $sql$SELECT lab_name FROM catune_submissions_public LIMIT 1$sql$, + 'catune public view omits lab_name' +); +SELECT assert_denied( + $sql$SELECT notes FROM cadecon_submissions_public LIMIT 1$sql$, + 'cadecon public view omits notes' +); +ROLLBACK; + -- ── analytics_sessions: anon INSERT denied ───────────────────────────────── BEGIN;