diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/__init__.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/__init__.py index 51bd5ff777..aaa053ef6f 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/__init__.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/__init__.py @@ -48,7 +48,7 @@ save_cluster_labels, save_feature_atlas, ) -from .architectures import MoESAE, ReLUSAE, SparseAutoencoder, TopKSAE +from .architectures import MoESAE, ReLUSAE, ShardedTopKSAE, SparseAutoencoder, TopKSAE from .autointerp import ( DEFAULT_PROMPT_TEMPLATE, TOKEN_PROMPT_TEMPLATE, @@ -80,8 +80,10 @@ evaluate_sae, evaluate_sparsity, ) +from .kernels import HAS_TRITON, TritonDecoderAutograd from .perf_logger import PerfLogger from .process_group_manager import ProcessGroupManager +from .streaming import StreamingActivationDataset, StreamingConfig, make_streaming_dataloader from .training import ParallelConfig, Trainer, TrainingConfig, WandbConfig from .utils import get_device, set_seed @@ -90,6 +92,7 @@ __all__ = [ "DEFAULT_PROMPT_TEMPLATE", + "HAS_TRITON", "TOKEN_PROMPT_TEMPLATE", "ActivationStore", "ActivationStoreConfig", @@ -117,14 +120,18 @@ "PerfLogger", "ProcessGroupManager", "ReLUSAE", + "ShardedTopKSAE", "SparseAutoencoder", "SparsityMetrics", + "StreamingActivationDataset", + "StreamingConfig", "TokenActivationCollector", "TokenExample", "TopExample", "TopKSAE", "Trainer", "TrainingConfig", + "TritonDecoderAutograd", "WandbConfig", "build_cluster_label_prompt", "compute_cluster_centroids", @@ -140,6 +147,7 @@ "get_device", "launch_dashboard", "load_activations", + "make_streaming_dataloader", "save_activations", "save_cluster_labels", "save_feature_atlas", diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/architectures/__init__.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/architectures/__init__.py index 978a4c297d..b3e9333843 
100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/architectures/__init__.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/architectures/__init__.py @@ -19,11 +19,13 @@ from .moe import MoESAE from .relu_l1 import ReLUSAE from .topk import TopKSAE +from .topk_tp import ShardedTopKSAE __all__ = [ "MoESAE", "ReLUSAE", + "ShardedTopKSAE", "SparseAutoencoder", "TopKSAE", ] diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/architectures/topk.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/architectures/topk.py index 3be46cdf7f..34c68f0dd0 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/architectures/topk.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/architectures/topk.py @@ -74,8 +74,16 @@ def __init__( dead_tokens_threshold: int = 10_000_000, init_encoder_from_decoder: bool = True, init_pre_bias: bool = True, + decoder_impl: str = "dense", ): - """Initialize the Top-K SAE with encoder, decoder, and optional auxiliary loss.""" + """Initialize the Top-K SAE with encoder, decoder, and optional auxiliary loss. + + ``decoder_impl`` selects the decode path: "dense" (default) builds the dense + [batch, hidden_dim] code tensor and runs a full decoder matmul; "triton" + decodes directly from the top-k (indices, values) via a sparse kernel + (O(batch*k*d), no dense code tensor), enabling much larger hidden_dim. Weights + are identical, so checkpoints are interchangeable between the two. 
+ """ super().__init__(input_dim, hidden_dim) self.top_k = top_k self.init_pre_bias = init_pre_bias @@ -83,6 +91,9 @@ def __init__( self.auxk = auxk self.auxk_coef = auxk_coef self.dead_tokens_threshold = dead_tokens_threshold + if decoder_impl not in ("dense", "triton"): + raise ValueError(f"decoder_impl must be 'dense' or 'triton', got {decoder_impl!r}") + self.decoder_impl = decoder_impl # Pre-bias (subtracted from normalized input, added to output before denorm) self.pre_bias = nn.Parameter(torch.zeros(input_dim)) @@ -208,9 +219,40 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: top_k_vals, top_k_indices = torch.topk(codes_relu, self.top_k, dim=-1) codes = torch.zeros_like(codes_relu).scatter(-1, top_k_indices, top_k_vals) - recon = self.decode(codes, info) + if self.decoder_impl == "triton": + recon = self._decode_topk_triton(top_k_vals, top_k_indices, info) + else: + recon = self.decode(codes, info) return recon, codes + def _decode_topk_triton( + self, + top_k_vals: torch.Tensor, + top_k_indices: torch.Tensor, + info: Optional[Dict[str, torch.Tensor]] = None, + denormalize: bool = True, + ) -> torch.Tensor: + """Decode from top-k (values, indices) via the sparse Triton kernel. + + Returns reconstruction with pre_bias added; denormalized to input scale when + ``denormalize`` (set False to get the normalized-space recon for aux loss). 
+ """ + from ..kernels import TritonDecoderAutograd + + recon = TritonDecoderAutograd.apply(top_k_indices.contiguous(), top_k_vals.contiguous(), self.decoder.weight) + recon = recon + self.pre_bias + if denormalize and self.normalize_input and info is not None: + recon = self._denormalize(recon, info) + return recon + + def _update_dead_latent_stats_from_indices(self, top_k_indices: torch.Tensor, n_tokens: int) -> None: + """Update stats_last_nonzero from top-k indices (no dense [batch, hidden] tensor).""" + active_mask = torch.zeros_like(self.stats_last_nonzero, dtype=torch.bool) + active_mask[top_k_indices.reshape(-1)] = True + self.stats_last_nonzero = torch.where( + active_mask, torch.zeros_like(self.stats_last_nonzero), self.stats_last_nonzero + n_tokens + ) + def forward_with_aux(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: """Forward pass with auxiliary info for auxk loss computation. @@ -257,8 +299,9 @@ def _compute_auxk_loss( x: torch.Tensor, recon: torch.Tensor, pre_act: torch.Tensor, - codes: torch.Tensor, + codes: Optional[torch.Tensor], norm_info: Optional[Dict[str, torch.Tensor]] = None, + recon_norm: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Compute auxiliary loss for dead latents. @@ -293,8 +336,10 @@ def _compute_auxk_loss( if self.normalize_input and norm_info is not None: # Normalize x to match the space where encoding happened x_norm = (x - norm_info["mu"]) / norm_info["std"] - # Reuse codes from forward pass instead of re-encoding - recon_norm = self.decoder(codes) + self.pre_bias + # Reuse codes from forward pass instead of re-encoding (or a precomputed + # normalized recon, e.g. from the sparse/triton decode path). 
def _loss_triton(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
    """Compute loss() through the sparse top-k decoder.

    Never materializes the dense [batch, hidden_dim] code tensor and never runs the
    full decoder matmul: it decodes from the top-k (values, indices) and derives the
    dead-latent statistics and L0 from them. This is what lets hidden_dim scale to
    ~1M+.

    Args:
        x: Input activations, one row per token ([batch, input_dim]).

    Returns:
        Dict with ``total``, ``fvu``, ``sparsity``, ``mse``, ``variance_explained``
        and ``dead_pct``; plus ``aux`` (and an aux-augmented ``total``) when
        ``self.auxk`` is set.
    """
    pre_act, info = self.encode_pre_act(x)
    codes_relu = torch.relu(pre_act)
    top_k_vals, top_k_indices = torch.topk(codes_relu, self.top_k, dim=-1)

    # Sparse decode in normalized space (pre_bias added); denormalize for the main loss.
    recon_norm = self._decode_topk_triton(top_k_vals, top_k_indices, info, denormalize=False)
    recon = self._denormalize(recon_norm, info) if (self.normalize_input and info) else recon_norm

    # Dead-latent stats from indices (no dense codes tensor). Only slots that actually
    # fired are counted: when a row has fewer than k positive activations, topk pads
    # the selection with zero-valued latents, and marking those as active would keep
    # dead latents looking alive.
    # NOTE(review): the 1e-3 threshold mirrors
    # ShardedTopKSAE._update_dead_latent_stats_local; confirm it matches the dense
    # loss() criterion.
    fired = top_k_vals.detach() > 1e-3
    self._update_dead_latent_stats_from_indices(top_k_indices[fired], x.shape[0])

    # Primary reconstruction loss (FVU), centered by pre_bias.
    mse = (recon - x).pow(2).mean(dim=-1)
    x_var = (x - self.pre_bias).pow(2).mean(dim=-1)
    recon_loss = (mse / (x_var + 1e-8)).mean()

    # For TopK, L0 == count of nonzero top-k values.
    l0 = (top_k_vals != 0).float().sum(dim=-1).mean()

    # Reporting-only metrics; kept out of the autograd graph.
    with torch.no_grad():
        raw_mse = (recon - x).pow(2).mean()
        total_var = torch.var(x, dim=0).sum()
        residual_var = torch.var(recon - x, dim=0).sum()
        var_explained = 1.0 - (residual_var / (total_var + 1e-8))

    result = {
        "total": recon_loss,
        "fvu": 1.0 - var_explained,
        "sparsity": l0,
        "mse": raw_mse,
        "variance_explained": var_explained,
    }
    dead_pct = (self.stats_last_nonzero > self.dead_tokens_threshold).float().mean() * 100
    result["dead_pct"] = dead_pct

    if self.auxk is not None:
        # codes=None: the aux loss reuses the precomputed normalized reconstruction
        # instead of re-decoding a dense code tensor.
        aux_loss = self._compute_auxk_loss(x, recon, pre_act, codes=None, norm_info=info, recon_norm=recon_norm)
        result["total"] = recon_loss + self.auxk_coef * aux_loss
        result["aux"] = aux_loss

    return result
+ +Each rank owns a contiguous block of ``hidden_dim // world_size`` latents: its slice +of the encoder rows, latent bias, and decoder columns. ``pre_bias`` is replicated. +Forward: each rank computes local pre-activations, a sharded global top-k selects the +true global top-k, each rank decodes the selections it owns, and the partial +reconstructions are summed across ranks (all-reduce). + +Numerically equivalent to the dense ``TopKSAE`` (verified by parity tests). Kept as a +separate class so the dense ``TopKSAE`` is untouched; small helpers (_normalize) are +duplicated rather than refactored out of it. + +Replicated ``pre_bias`` gradient note: ``pre_bias`` contributes via both the encoder +(``x - pre_bias``, sharded) and the decoder (added once). We add ``pre_bias / +world_size`` inside the all-reduced decode path so that, after an all-reduce(SUM) of +the ``pre_bias`` gradient across ranks (done by the TP trainer), the total gradient +equals the dense one exactly: sharded encoder parts sum, and the decoder part (1/P per +rank) sums back to a single full contribution. 
class ShardedTopKSAE(nn.Module):
    """Latent-sharded TopK SAE (tensor parallel across `world_size` ranks).

    Each rank holds ``L = hidden_dim // world_size`` latents:
    ``W_enc_local`` [L, input_dim], ``latent_bias_local`` [L] and
    ``W_dec_local`` [input_dim, L]. ``pre_bias`` [input_dim] is replicated on every
    rank; its gradient must be summed across ranks with
    :meth:`reduce_replicated_grads` after backward.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        top_k: int,
        rank: int,
        world_size: int,
        normalize_input: bool = True,
        auxk: Optional[int] = None,
        auxk_coef: float = 1 / 32,
        dead_tokens_threshold: int = 10_000_000,
        decoder_impl: str = "dense",
        group=None,
    ):
        """Args mirror TopKSAE; `hidden_dim` is the GLOBAL latent count (divisible by world_size).

        Raises:
            ValueError: if ``world_size`` < 1, ``rank`` is outside
                ``[0, world_size)``, ``hidden_dim`` is not divisible by
                ``world_size``, or ``decoder_impl`` is not 'dense'/'triton'.
        """
        super().__init__()
        if world_size < 1:
            raise ValueError(f"world_size must be >= 1, got {world_size}")
        if not 0 <= rank < world_size:
            raise ValueError(f"rank {rank} must be in [0, {world_size})")
        if hidden_dim % world_size != 0:
            raise ValueError(f"hidden_dim {hidden_dim} must be divisible by world_size {world_size}")
        if decoder_impl not in ("dense", "triton"):
            raise ValueError(f"decoder_impl must be 'dense' or 'triton', got {decoder_impl!r}")

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.top_k = top_k
        self.rank = rank
        self.world_size = world_size
        self.normalize_input = normalize_input
        self.auxk = auxk
        self.auxk_coef = auxk_coef
        self.dead_tokens_threshold = dead_tokens_threshold
        self.decoder_impl = decoder_impl
        self.group = group
        self.latents_per_rank = hidden_dim // world_size

        L = self.latents_per_rank
        self.W_enc_local = nn.Parameter(torch.empty(L, input_dim))
        self.latent_bias_local = nn.Parameter(torch.zeros(L))
        self.W_dec_local = nn.Parameter(torch.empty(input_dim, L))
        self.pre_bias = nn.Parameter(torch.zeros(input_dim))  # replicated
        nn.init.kaiming_uniform_(self.W_dec_local, a=5**0.5)
        with torch.no_grad():
            self.W_enc_local.copy_(self.W_dec_local.t())  # encoder = decoder.T init

        # Tokens seen since each local latent last fired.
        self.register_buffer("stats_last_nonzero", torch.zeros(L, dtype=torch.long))

    @torch.no_grad()
    def normalize_decoder(self) -> None:
        """Unit-norm each decoder ROW over all latents (matches dense normalize_decoder).

        Dense does F.normalize(weight[d, n], dim=1). Each rank holds [d, L], so the
        per-row norm over the full n latents needs an all-reduce of per-row
        sum-of-squares. Load-bearing for TopK training stability.
        """
        sumsq = (self.W_dec_local**2).sum(dim=1)  # [d], local sum over this rank's L latents
        all_reduce_sum(sumsq, self.group)  # -> per-row sum-of-squares over all n latents
        # F.normalize divides by max(norm, eps) with eps=1e-12, so clamp the norm
        # itself; clamping the squared norm would floor the norm at 1e-6 instead.
        norm = sumsq.sqrt().clamp_min(1e-12).unsqueeze(1)  # [d, 1]
        self.W_dec_local.data = self.W_dec_local.data / norm

    def post_step(self) -> None:
        """Called by the TP trainer after optimizer.step()."""
        self.normalize_decoder()

    def reduce_replicated_grads(self) -> None:
        """All-reduce(SUM) gradients of replicated params (pre_bias) across the TP group.

        Sharded params are distinct per rank and need no sync. pre_bias is replicated:
        summing its grad combines the per-rank encoder contributions and the
        1/world_size decode parts into the dense gradient. Call after backward,
        before optimizer.step().
        """
        if self.pre_bias.grad is not None:
            all_reduce_sum(self.pre_bias.grad, self.group)

    @torch.no_grad()
    def load_shard_from_dense(self, dense) -> None:
        """Copy this rank's slice of a dense TopKSAE's weights (for tests / merge)."""
        lo = self.rank * self.latents_per_rank
        hi = lo + self.latents_per_rank
        self.W_enc_local.copy_(dense.encoder.weight[lo:hi, :])
        self.latent_bias_local.copy_(dense.latent_bias[lo:hi])
        self.W_dec_local.copy_(dense.decoder.weight[:, lo:hi])
        self.pre_bias.copy_(dense.pre_bias)

    def _normalize(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Standardize each row of ``x``; return the result and the stats to undo it."""
        mu = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True) + 1e-8
        return (x - mu) / std, {"mu": mu, "std": std}

    def _denormalize(self, x: torch.Tensor, info: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Invert :meth:`_normalize` using the stored ``mu`` / ``std``."""
        return x * info["std"] + info["mu"]

    @torch.no_grad()
    def init_pre_bias_from_data(self, data: torch.Tensor, max_iter: int = 100, eps: float = 1e-6) -> None:
        """Initialize the (replicated) pre_bias to the geometric median of the data.

        Runs an iteratively re-weighted mean (weights = 1 / distance to the current
        estimate) on CPU in float32, stopping after ``max_iter`` steps or once the
        update is smaller than ``eps``. No communication happens here, so all ranks
        end up with the same value only if they are given the same data sample.
        """
        data = data.float().cpu()
        if self.normalize_input:
            mu = data.mean(dim=-1, keepdim=True)
            std = data.std(dim=-1, keepdim=True) + 1e-8
            data = (data - mu) / std
        median = data.mean(dim=0)
        for _ in range(max_iter):
            diffs = data - median.unsqueeze(0)
            distances = diffs.norm(dim=1, keepdim=True).clamp(min=1e-8)
            weights = 1.0 / distances
            new_median = (data * weights).sum(dim=0) / weights.sum()
            if (new_median - median).norm() < eps:
                break
            median = new_median
        # copy_ keeps the parameter's own device and dtype; rebinding .data would
        # silently switch the parameter to float32.
        self.pre_bias.copy_(median)

    def encode_pre_act(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Local pre-activations [batch, L]: optional normalize, subtract pre_bias, encoder."""
        info: Dict[str, torch.Tensor] = {}
        if self.normalize_input:
            x, info = self._normalize(x)
        x_centered = x - self.pre_bias
        pre_act_local = F.linear(x_centered, self.W_enc_local, self.latent_bias_local)  # [B, L]
        return pre_act_local, info

    def _decode_local(self, vals: torch.Tensor, local_indices: torch.Tensor) -> torch.Tensor:
        """Decode this rank's owned selections -> partial reconstruction [B, d] (normalized space)."""
        if self.decoder_impl == "triton" and HAS_TRITON and vals.is_cuda:
            from ..kernels import TritonDecoderAutograd

            return TritonDecoderAutograd.apply(local_indices.contiguous(), vals.contiguous(), self.W_dec_local)
        # Dense gather-sum (no scatter -> safe against the index-0 padding of unowned slots):
        #   partial[b] = sum_j vals[b,j] * W_dec_local[:, local_indices[b,j]]
        gathered = self.W_dec_local.t()[local_indices]  # [B, k, d]
        return (gathered * vals.unsqueeze(-1)).sum(dim=1)

    def _sharded_recon(self, x: torch.Tensor):
        """Run encode -> global top-k -> local decode -> all-reduce once.

        Shared by :meth:`forward` and :meth:`loss` so both use the same decode path.

        Returns:
            ``(pre_act_local, info, gtk, vals, recon_norm, recon)`` where ``recon_norm``
            is the summed reconstruction before denormalization and ``recon`` is in
            input scale.
        """
        pre_act_local, info = self.encode_pre_act(x)
        acts_local = torch.relu(pre_act_local)

        gtk = global_topk(acts_local, self.top_k, self.rank, self.latents_per_rank, self.group)
        # Differentiable values from the local activations (grad flows to the encoder);
        # zero out selections this rank does not own.
        vals = acts_local.gather(1, gtk.local_indices) * gtk.owned_mask.to(acts_local.dtype)

        # Each rank adds pre_bias / world_size, so the all-reduce(SUM) adds pre_bias
        # exactly once and the summed pre_bias gradient is not over-counted.
        partial = self._decode_local(vals, gtk.local_indices) + self.pre_bias / self.world_size
        recon_norm = autograd_all_reduce_sum(partial, self.group)  # sum partials -> full recon
        recon = self._denormalize(recon_norm, info) if (self.normalize_input and info) else recon_norm
        return pre_act_local, info, gtk, vals, recon_norm, recon

    def forward(self, x: torch.Tensor):
        """Return (reconstruction [B, d], GlobalTopK). recon is replicated across ranks."""
        _, _, gtk, _, _, recon = self._sharded_recon(x)
        return recon, gtk

    def _update_dead_latent_stats_local(self, gtk, vals: torch.Tensor, n_tokens: int) -> None:
        """Mark local latents that fired (owned selection with value > 1e-3); else age them."""
        active = torch.zeros_like(self.stats_last_nonzero, dtype=torch.bool)
        fired = vals > 1e-3  # [B, k]; vals already zeroed on unowned slots
        active[gtk.local_indices[fired]] = True
        self.stats_last_nonzero = torch.where(
            active, torch.zeros_like(self.stats_last_nonzero), self.stats_last_nonzero + n_tokens
        )

    def loss(self, x: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
        """Sharded loss with the same keys as the dense TopKSAE.loss().

        recon is replicated across ranks, so recon-derived metrics are the same on
        every rank; dead_pct is reduced globally and auxk uses a sharded global
        top-k over dead latents. Extra ``kwargs`` are accepted and ignored.
        """
        pre_act_local, info, gtk, vals, recon_norm, recon = self._sharded_recon(x)

        self._update_dead_latent_stats_local(gtk, vals, x.shape[0])

        mse = (recon - x).pow(2).mean(dim=-1)
        # x_var uses pre_bias in a *replicated* (non-sharded) way, so its gradient would
        # be over-counted x world_size by the all-reduce(SUM) of pre_bias grads. Scale the
        # grad by 1/world_size (value unchanged) so the sum recovers the dense gradient --
        # same principle as the pre_bias/world_size decode term.
        pb = self.pre_bias / self.world_size + (self.pre_bias - self.pre_bias / self.world_size).detach()
        x_var = (x - pb).pow(2).mean(dim=-1)
        recon_loss = (mse / (x_var + 1e-8)).mean()
        l0 = (gtk.global_values != 0).float().sum(dim=-1).mean()

        # Reporting-only metrics; kept out of the autograd graph.
        with torch.no_grad():
            raw_mse = (recon - x).pow(2).mean()
            total_var = torch.var(x, dim=0).sum()
            residual_var = torch.var(recon - x, dim=0).sum()
            var_explained = 1.0 - (residual_var / (total_var + 1e-8))

        result = {
            "total": recon_loss,
            "fvu": 1.0 - var_explained,
            "sparsity": l0,
            "mse": raw_mse,
            "variance_explained": var_explained,
        }

        # Global dead fraction across all shards.
        with torch.no_grad():
            local_dead = (self.stats_last_nonzero > self.dead_tokens_threshold).sum().float()
            total_dead = all_reduce_sum(local_dead.clone(), self.group)
            result["dead_pct"] = total_dead / self.hidden_dim * 100

        if self.auxk is not None:
            aux_loss = self._compute_auxk_loss(x, recon, recon_norm, pre_act_local, info)
            result["total"] = recon_loss + self.auxk_coef * aux_loss
            result["aux"] = aux_loss

        return result

    def _compute_auxk_loss(self, x, recon, recon_norm, pre_act_local, info) -> torch.Tensor:
        """Auxiliary dead-latent loss: a sharded global top-auxk over dead latents.

        Selects the top-auxk dead latents by relu value across all shards, decodes
        them, and fits the primary reconstruction's residual (in normalized space
        when input normalization is on). Returns a scalar zero when no latent is dead.
        """
        dead_mask_local = self.stats_last_nonzero > self.dead_tokens_threshold  # [L]
        total_dead = int(all_reduce_sum(dead_mask_local.sum().float().clone(), self.group).item())
        if total_dead == 0:
            return torch.zeros((), device=x.device, dtype=x.dtype)

        k_aux = min(self.auxk, total_dead)
        acts_local = torch.relu(pre_act_local)
        # Only dead latents are selectable; -inf so the global top-k never picks live ones.
        masked = acts_local.masked_fill(~dead_mask_local, float("-inf"))
        gtk_aux = global_topk(masked, k_aux, self.rank, self.latents_per_rank, self.group)
        # Values come from the unmasked activations so no -inf reaches the decoder.
        vals_aux = acts_local.gather(1, gtk_aux.local_indices) * gtk_aux.owned_mask.to(acts_local.dtype)
        recon_aux = autograd_all_reduce_sum(self._decode_local(vals_aux, gtk_aux.local_indices), self.group)

        if self.normalize_input and info:
            x_norm = (x - info["mu"]) / info["std"]
            residual = x_norm - recon_norm.detach()
        else:
            residual = x - recon.detach() + self.pre_bias.detach()

        mse = (recon_aux - residual).pow(2).mean(dim=-1)
        target_var = residual.pow(2).mean(dim=-1)
        return (mse / (target_var + 1e-8)).mean()
diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/benchmarks/bench_decoder.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/benchmarks/bench_decoder.py new file mode 100644 index 0000000000..2d38abee84 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/benchmarks/bench_decoder.py @@ -0,0 +1,161 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Benchmark: dense vs Triton sparse TopK decoder. + +Measures forward+backward latency and peak GPU memory across a sweep of latent +counts, isolating the win the sparse kernel gives as ``n_latents`` grows (the dense +path OOMs first — that OOM is itself the headline result). Two modes: + + --mode kernel : the decode op alone (TritonDecoderAutograd vs dense reference) + --mode sae : a full TopKSAE.loss() fwd+bwd (decoder_impl dense vs triton) + +Usage (on a GPU box): + python -m sae.benchmarks.bench_decoder --impl all --mode kernel + python -m sae.benchmarks.bench_decoder --impl all --mode sae --batch 4096 --d 2688 + python -m sae.benchmarks.bench_decoder --json out.json +""" + +import argparse +import json + +import torch + +from sae.kernels import HAS_TRITON, TritonDecoderAutograd, reference_decode + + +# (label, expansion) for d_model=2688 -> n_latents; expansion is informational. 
+DEFAULT_NS = [21_504, 86_016, 344_064, 688_128, 1_048_576] + + +def _unique_topk(a, n, k, d, dtype, device): + scores = torch.rand(a, n, device=device) + idx = scores.argsort(dim=-1)[:, :k].contiguous().to(torch.int64) + vals = torch.rand(a, k, device=device, dtype=dtype).contiguous() + w = torch.randn(d, n, device=device, dtype=dtype) + return idx, vals, w + + +def _time_fwd_bwd(fn, iters, warmup): + """Return (fwd_ms, bwd_ms) medians, or raise the underlying error (e.g. OOM).""" + fwd, bwd = [], [] + for i in range(warmup + iters): + torch.cuda.synchronize() + s = torch.cuda.Event(enable_timing=True) + m = torch.cuda.Event(enable_timing=True) + e = torch.cuda.Event(enable_timing=True) + s.record() + out, bwd_fn = fn() + m.record() + bwd_fn(out) + e.record() + torch.cuda.synchronize() + if i >= warmup: + fwd.append(s.elapsed_time(m)) + bwd.append(m.elapsed_time(e)) + fwd.sort() + bwd.sort() + return fwd[len(fwd) // 2], bwd[len(bwd) // 2] + + +def bench_cell(impl, mode, a, n, k, d, dtype, device, iters, warmup): + """Benchmark one (impl, n) cell. 
Returns dict with timings + peak mem, or OOM.""" + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + try: + if mode == "kernel": + idx, vals, w = _unique_topk(a, n, k, d, dtype, device) + + def fn(): + v = vals.clone().requires_grad_(True) + ww = w.clone().requires_grad_(True) + gseed = torch.randn(a, d, device=device, dtype=dtype) + if impl == "triton": + out = TritonDecoderAutograd.apply(idx, v, ww) + else: + out = reference_decode(idx, v, ww) + return (out * gseed).sum(), lambda loss: loss.backward() + else: # sae + from sae.architectures import TopKSAE + + sae = TopKSAE(input_dim=d, hidden_dim=n, top_k=k, normalize_input=True, decoder_impl=impl).to(device) + x = torch.randn(a, d, device=device, dtype=dtype) + + def fn(): + sae.zero_grad(set_to_none=True) + out = sae.loss(x)["total"] + return out, lambda loss: loss.backward() + + fwd_ms, bwd_ms = _time_fwd_bwd(fn, iters, warmup) + peak_gb = torch.cuda.max_memory_allocated() / 1e9 + return {"fwd_ms": round(fwd_ms, 3), "bwd_ms": round(bwd_ms, 3), "peak_gb": round(peak_gb, 2)} + except RuntimeError as exc: + if "out of memory" in str(exc).lower(): + torch.cuda.empty_cache() + return {"status": "OOM"} + raise + + +def main(): + """Parse args and run the dense-vs-Triton decoder benchmark sweep.""" + p = argparse.ArgumentParser(description="Benchmark dense vs Triton sparse TopK decoder") + p.add_argument("--impl", choices=["dense", "triton", "all"], default="all") + p.add_argument("--mode", choices=["kernel", "sae"], default="kernel") + p.add_argument("--batch", type=int, default=4096) + p.add_argument("--d", type=int, default=2688) + p.add_argument("--k", type=int, default=32) + p.add_argument("--ns", type=int, nargs="+", default=DEFAULT_NS) + p.add_argument("--dtype", choices=["fp32", "bf16", "fp16"], default="bf16") + p.add_argument("--iters", type=int, default=20) + p.add_argument("--warmup", type=int, default=5) + p.add_argument("--json", type=str, default=None) + args = p.parse_args() + + if not 
torch.cuda.is_available(): + raise SystemExit("CUDA required for this benchmark.") + if args.impl in ("triton", "all") and not HAS_TRITON: + raise SystemExit("Triton not available; install triton or use --impl dense.") + + dtype = {"fp32": torch.float32, "bf16": torch.bfloat16, "fp16": torch.float16}[args.dtype] + impls = ["dense", "triton"] if args.impl == "all" else [args.impl] + gpu = torch.cuda.get_device_name(0) + print(f"GPU: {gpu} | mode={args.mode} | batch={args.batch} d={args.d} k={args.k} dtype={args.dtype}\n") + + results = [] + header = f"{'n_latents':>10} {'exp':>5} " + " ".join(f"{i:>22}" for i in impls) + print(header) + print("-" * len(header)) + for n in args.ns: + row = {"n_latents": n, "expansion": round(n / args.d, 1)} + cells = [] + for impl in impls: + r = bench_cell(impl, args.mode, args.batch, n, args.k, args.d, dtype, "cuda", args.iters, args.warmup) + row[impl] = r + if "status" in r: + cells.append(f"{'OOM':>22}") + else: + cells.append(f"{r['fwd_ms']:>6.2f}/{r['bwd_ms']:>6.2f}ms {r['peak_gb']:>5.1f}GB") + results.append(row) + print(f"{n:>10} {row['expansion']:>5}x " + " ".join(cells)) + + print("\n(cells: fwd/bwd ms, peak GB; 'OOM' = ran out of memory)") + if args.json: + with open(args.json, "w") as f: + json.dump({"gpu": gpu, "args": vars(args), "results": results}, f, indent=2) + print(f"Wrote {args.json}") + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/kernels/__init__.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/kernels/__init__.py new file mode 100644 index 0000000000..4979d6118f --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/kernels/__init__.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sparse SAE kernels (Triton now; CUDA later) for scaling TopK SAEs. + +Public surface: + HAS_TRITON -- whether Triton imported successfully + TritonDecoderAutograd -- sparse TopK decode autograd Function + reference_decode -- dense oracle used by tests +""" + +from .reference import reference_decode, reference_sparse_dense_matmul +from .triton_decoder import HAS_TRITON, TritonDecoderAutograd + + +__all__ = [ + "HAS_TRITON", + "TritonDecoderAutograd", + "reference_decode", + "reference_sparse_dense_matmul", +] diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/kernels/reference.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/kernels/reference.py new file mode 100644 index 0000000000..a32fa69aac --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/kernels/reference.py @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
def reference_sparse_dense_matmul(
    sparse_indices: torch.Tensor,
    sparse_values: torch.Tensor,
    dense: torch.Tensor,
) -> torch.Tensor:
    """Dense equivalent of ``triton_sparse_dense_matmul``: ``sparse @ dense``.

    Args:
        sparse_indices: (A, k) integer row indices into ``dense``.
        sparse_values: (A, k) weights paired element-wise with ``sparse_indices``.
        dense: (N, B) matrix whose rows are gathered.

    Returns:
        (A, B) tensor with ``out[a] = sum_j sparse_values[a, j] * dense[sparse_indices[a, j]]``.
        An index that appears more than once in a row contributes once per occurrence.
    """
    # Advanced indexing gathers the selected rows: (A, k) indices -> (A, k, B).
    gathered = dense[sparse_indices]
    # Weight every gathered row by its value, then reduce over the k axis.
    return (gathered * sparse_values.unsqueeze(-1)).sum(dim=1)  # (A, B)


def reference_decode(
    sparse_indices: torch.Tensor,
    sparse_values: torch.Tensor,
    decoder_weight: torch.Tensor,
) -> torch.Tensor:
    """Dense, differentiable reference for ``TritonDecoderAutograd``.

    Builds the dense code tensor and runs the standard decoder matmul, so autograd
    through it yields reference gradients for ``sparse_values`` and ``decoder_weight``.

    Args:
        sparse_indices: (A, k) integer latent indices.
        sparse_values: (A, k) activation values paired with ``sparse_indices``.
        decoder_weight: (d, n) decoder matrix (``nn.Linear(n, d).weight``).

    Returns:
        (A, d) reconstruction (no bias added).
    """
    a, _ = sparse_indices.shape
    n = decoder_weight.shape[1]
    codes = torch.zeros(a, n, device=sparse_values.device, dtype=sparse_values.dtype)
    # scatter_add (not scatter): duplicate indices within a row accumulate, which is
    # what reference_sparse_dense_matmul computes. Plain scatter would keep an
    # arbitrary one of the duplicates. For unique indices both give the same codes.
    codes = codes.scatter_add(-1, sparse_indices, sparse_values)
    return codes @ decoder_weight.T  # (A, d)
import torch


# Optional dependency: the import is attempted once, when this module is imported.
# Any failure (not only ImportError) downgrades to the placeholder class below.
try:
    import triton
    import triton.language as tl

    HAS_TRITON = True
except Exception:  # pragma: no cover - depends on environment
    HAS_TRITON = False


if HAS_TRITON:

    def triton_sparse_dense_matmul(
        sparse_indices: torch.Tensor,
        sparse_values: torch.Tensor,
        dense: torch.Tensor,
    ) -> torch.Tensor:
        """Compute ``sparse @ dense`` (reduce along the uncollated dim of sparse).

        sparse_indices/sparse_values are (A, k); dense is (N, B); output is (A, B).
        No layout is asserted on ``dense``: the kernel indexes it through both of its
        strides. The decoder path passes ``decoder_weight.T``, i.e. a view whose
        transpose is the contiguous tensor.

        Returns a freshly allocated (A, B) tensor on ``dense.device`` with the dtype
        of ``sparse_values``.
        """
        N = dense.shape[0]
        assert sparse_indices.shape == sparse_values.shape
        assert sparse_indices.is_contiguous()
        assert sparse_values.is_contiguous()
        # NOTE: OpenAI asserts dense.is_contiguous(). Our decoder weight is a
        # standard nn.Linear [d, n], so dense = weight.T is a strided [n, d] view.
        # The kernel reads via stride_dn/stride_db so it is correct on the strided
        # view (loads along B are uncoalesced -> a perf, not correctness, follow-up).

        A = sparse_indices.shape[0]
        K = sparse_indices.shape[1]
        B = dense.shape[1]

        out = torch.zeros(A, B, device=dense.device, dtype=sparse_values.dtype)

        # One program per row of the sparse operand; each program covers the whole
        # B axis in a single block, padded up to the next power of two.
        triton_sparse_dense_matmul_kernel[(A,)](
            sparse_indices,
            sparse_values,
            dense,
            out,
            stride_dn=dense.stride(0),
            stride_db=dense.stride(1),
            A=A,
            B=B,
            N=N,
            K=K,
            BLOCK_SIZE_K=triton.next_power_of_2(K),
            BLOCK_SIZE_B=triton.next_power_of_2(B),
        )
        return out

    # Kernel for triton_sparse_dense_matmul. Program `pid` handles row `pid` of the
    # (A, K) index/value arrays and writes row `pid` of the (A, B) output.
    @triton.jit
    def triton_sparse_dense_matmul_kernel(  # noqa: D103 - low-level Triton kernel (ported)
        sparse_indices_ptr,
        sparse_values_ptr,
        dense_ptr,
        out_ptr,
        stride_dn,
        stride_db,
        A,
        B,
        N,
        K,
        BLOCK_SIZE_K: tl.constexpr,
        BLOCK_SIZE_B: tl.constexpr,
    ):
        pid = tl.program_id(0)

        # Load this row's K indices/values; lanes >= K are masked off (block is padded).
        offsets_k = tl.arange(0, BLOCK_SIZE_K)
        sparse_indices = tl.load(sparse_indices_ptr + pid * K + offsets_k, mask=offsets_k < K)  # (K,)
        sparse_values = tl.load(sparse_values_ptr + pid * K + offsets_k, mask=offsets_k < K)  # (K,)

        # Accumulate in float32 regardless of the value dtype; cast back on store.
        accum = tl.zeros((BLOCK_SIZE_B,), dtype=tl.float32)
        offsets_b = tl.arange(0, BLOCK_SIZE_B)

        for k in range(K):
            # workaround to do sparse_indices[k]
            i = tl.sum(
                tl.where(
                    tl.arange(0, BLOCK_SIZE_K) == k,
                    sparse_indices,
                    tl.zeros((BLOCK_SIZE_K,), dtype=tl.int64),
                )
            )
            # workaround to do sparse_values[k]
            v = tl.sum(
                tl.where(
                    tl.arange(0, BLOCK_SIZE_K) == k,
                    sparse_values,
                    tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32),
                )
            )

            tl.device_assert(i < N)
            # Zero-valued entries are skipped, so their dense row is never loaded.
            if v != 0:
                accum += v * tl.load(dense_ptr + i * stride_dn + offsets_b * stride_db, mask=offsets_b < B)

        # The output is addressed as row-major (A, B): row stride B, unit column stride.
        tl.store(out_ptr + pid * B + offsets_b, accum.to(sparse_values.dtype), mask=offsets_b < B)

    def triton_sparse_transpose_dense_matmul(
        sparse_indices: torch.Tensor,
        sparse_values: torch.Tensor,
        dense: torch.Tensor,
        N: int,
        BLOCK_SIZE_AK=128,
    ) -> torch.Tensor:
        """Compute ``sparse.T @ dense`` (reduce along the collated dim of sparse).

        sparse_indices/sparse_values are (A, k); dense is (A, B); output is (N, B).
        All three tensors must be contiguous (asserted). ``BLOCK_SIZE_AK`` is the
        number of sorted COO entries handled by one kernel program.
        """
        assert sparse_indices.shape == sparse_values.shape
        assert sparse_indices.is_contiguous()
        assert sparse_values.is_contiguous()
        assert dense.is_contiguous()  # contiguous along B

        K = sparse_indices.shape[1]
        A = dense.shape[0]
        assert sparse_indices.shape[0] == A

        # COO-format and sorted (by latent index) so equal latents are contiguous.
        sorted_indices = sparse_indices.view(-1).sort()
        # Row 0 holds the source row (A coordinate) of every entry, row 1 its latent
        # index; both are permuted into the sorted-by-latent order.
        coo_indices = torch.stack(
            [
                torch.arange(A, device=sparse_indices.device).repeat_interleave(K)[sorted_indices.indices],
                sorted_indices.values,
            ]
        )  # (2, A * K)
        coo_values = sparse_values.view(-1)[sorted_indices.indices]  # (A * K,)
        return triton_coo_sparse_dense_matmul(coo_indices, coo_values, dense, N, BLOCK_SIZE_AK)

    def triton_coo_sparse_dense_matmul(
        coo_indices: torch.Tensor,
        coo_values: torch.Tensor,
        dense: torch.Tensor,
        N: int,
        BLOCK_SIZE_AK=128,
    ) -> torch.Tensor:
        """Launch the sparse-transpose kernel on COO input sorted by latent index.

        coo_indices is (2, AK) with row 0 = source row and row 1 = latent index, as
        built by ``triton_sparse_transpose_dense_matmul``; coo_values is (AK,); dense
        is (A, B). Returns a zero-initialised (N, B) tensor that the kernel fills with
        atomic adds, in the dtype of ``coo_values``.
        """
        AK = coo_indices.shape[1]
        B = dense.shape[1]

        out = torch.zeros(N, B, device=dense.device, dtype=coo_values.dtype)

        # 2-D grid whose second dimension is fixed at 1: one program per chunk of
        # BLOCK_SIZE_AK COO entries, each covering the full B axis.
        grid = lambda META: (triton.cdiv(AK, META["BLOCK_SIZE_AK"]), 1)  # noqa: E731
        triton_sparse_transpose_dense_matmul_kernel[grid](
            coo_indices,
            coo_values,
            dense,
            out,
            stride_da=dense.stride(0),
            stride_db=dense.stride(1),
            B=B,
            N=N,
            AK=AK,
            BLOCK_SIZE_AK=BLOCK_SIZE_AK,
            BLOCK_SIZE_B=triton.next_power_of_2(B),
        )
        return out

    # Kernel for triton_coo_sparse_dense_matmul. Each program walks BLOCK_SIZE_AK
    # consecutive COO entries (sorted by latent index), sums the contributions of a
    # run of equal latent indices in registers, and flushes the run to the output row
    # with atomic_add. Chunks are cut by entry position rather than by latent, so
    # several programs can add into the same output row.
    @triton.jit
    def triton_sparse_transpose_dense_matmul_kernel(  # noqa: D103 - low-level Triton kernel (ported)
        coo_indices_ptr,
        coo_values_ptr,
        dense_ptr,
        out_ptr,
        stride_da,
        stride_db,
        B,
        N,
        AK,
        BLOCK_SIZE_AK: tl.constexpr,
        BLOCK_SIZE_B: tl.constexpr,
    ):
        pid_ak = tl.program_id(0)
        pid_b = tl.program_id(1)

        coo_offsets = tl.arange(0, BLOCK_SIZE_AK)
        b_offsets = tl.arange(0, BLOCK_SIZE_B)

        # coo_indices is a contiguous (2, AK) tensor: the first AK entries are the
        # source rows, the next AK entries (offset by + AK) are the latent indices.
        A_coords = tl.load(
            coo_indices_ptr + pid_ak * BLOCK_SIZE_AK + coo_offsets,
            mask=pid_ak * BLOCK_SIZE_AK + coo_offsets < AK,
        )
        K_coords = tl.load(
            coo_indices_ptr + pid_ak * BLOCK_SIZE_AK + coo_offsets + AK,
            mask=pid_ak * BLOCK_SIZE_AK + coo_offsets < AK,
        )
        values = tl.load(
            coo_values_ptr + pid_ak * BLOCK_SIZE_AK + coo_offsets,
            mask=pid_ak * BLOCK_SIZE_AK + coo_offsets < AK,
        )

        # NOTE(review): the masked loads above have no `other=` fill, so in a partial
        # final chunk tl.min may also see the masked-off lanes. The first flush would
        # then add an all-zero accumulator to that row, which looks harmless -- confirm.
        last_k = tl.min(K_coords)
        accum = tl.zeros((BLOCK_SIZE_B,), dtype=tl.float32)

        for ind in range(BLOCK_SIZE_AK):
            if ind + pid_ak * BLOCK_SIZE_AK < AK:
                # workaround to do A_coords[ind]
                a = tl.sum(
                    tl.where(tl.arange(0, BLOCK_SIZE_AK) == ind, A_coords, tl.zeros((BLOCK_SIZE_AK,), dtype=tl.int64))
                )
                k = tl.sum(
                    tl.where(tl.arange(0, BLOCK_SIZE_AK) == ind, K_coords, tl.zeros((BLOCK_SIZE_AK,), dtype=tl.int64))
                )
                v = tl.sum(
                    tl.where(tl.arange(0, BLOCK_SIZE_AK) == ind, values, tl.zeros((BLOCK_SIZE_AK,), dtype=tl.float32))
                )

                tl.device_assert(k < N)

                # Latent index changed: flush the finished run and start a new one.
                if k != last_k:
                    tl.atomic_add(
                        out_ptr + last_k * B + BLOCK_SIZE_B * pid_b + b_offsets,
                        accum,
                        mask=BLOCK_SIZE_B * pid_b + b_offsets < B,
                    )
                    accum *= 0
                    last_k = k

                # NOTE(review): this load applies neither stride_db nor the pid_b
                # offset. The Python wrapper asserts dense.is_contiguous() and launches
                # a grid whose second dimension is 1, so both are no-ops as called here.
                if v != 0:
                    accum += v * tl.load(dense_ptr + a * stride_da + b_offsets, mask=b_offsets < B)

        # Flush the last run of this chunk.
        tl.atomic_add(
            out_ptr + last_k * B + BLOCK_SIZE_B * pid_b + b_offsets,
            accum,
            mask=BLOCK_SIZE_B * pid_b + b_offsets < B,
        )

    def triton_dense_dense_sparseout_matmul(
        dense1: torch.Tensor,
        dense2: torch.Tensor,
        at_indices: torch.Tensor,
    ) -> torch.Tensor:
        """Equivalent to ``(dense1 @ dense2).gather(1, at_indices)``.

        dense1 is (A, B); dense2 is (B, N); at_indices is (A, K); output is (A, K).
        ``dense1`` must have unit stride along B and ``at_indices`` must be contiguous
        (both asserted). For ``K > 512`` the dense expression above is evaluated
        directly instead of launching the kernel. The kernel path returns a tensor in
        the dtype of ``dense1``.
        """
        A, B = dense1.shape
        N = dense2.shape[1]
        assert dense2.shape[0] == B
        assert at_indices.shape[0] == A
        K = at_indices.shape[1]
        assert at_indices.is_contiguous()
        assert dense1.stride(1) == 1, "dense1 must be contiguous along B"
        # dense2 ([d, n] decoder weight) is read via stride_d2b/stride_d2n, so a
        # row-major [d, n] (stride(0)=n) is fine even though OpenAI assumes stride(0)==1.

        if K > 512:
            # naive is more efficient for large K
            return (dense1 @ dense2).gather(1, at_indices)

        out = torch.zeros(A, K, device=dense1.device, dtype=dense1.dtype)

        # One program per row of dense1. BLOCK_SIZE_N is passed through, but the
        # kernel body below never references it.
        triton_dense_dense_sparseout_matmul_kernel[(A,)](
            dense1,
            dense2,
            at_indices,
            out,
            stride_d1a=dense1.stride(0),
            stride_d1b=dense1.stride(1),
            stride_d2b=dense2.stride(0),
            stride_d2n=dense2.stride(1),
            A=A,
            B=B,
            N=N,
            K=K,
            BLOCK_SIZE_B=triton.next_power_of_2(B),
            BLOCK_SIZE_N=triton.next_power_of_2(N),
            BLOCK_SIZE_K=triton.next_power_of_2(K),
        )
        return out

    # Kernel for triton_dense_dense_sparseout_matmul. Program `pid` loads row `pid` of
    # dense1 once, then for each of its K requested columns computes the dot product
    # with that column of dense2 and places it in lane k of the output row.
    @triton.jit
    def triton_dense_dense_sparseout_matmul_kernel(  # noqa: D103 - low-level Triton kernel (ported)
        dense1_ptr,
        dense2_ptr,
        at_indices_ptr,
        out_ptr,
        stride_d1a,
        stride_d1b,
        stride_d2b,
        stride_d2n,
        A,
        B,
        N,
        K,
        BLOCK_SIZE_B: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
        BLOCK_SIZE_K: tl.constexpr,
    ):
        pid = tl.program_id(0)

        offsets_k = tl.arange(0, BLOCK_SIZE_K)
        at_indices = tl.load(at_indices_ptr + pid * K + offsets_k, mask=offsets_k < K)  # (K,)

        offsets_b = tl.arange(0, BLOCK_SIZE_B)
        dense1 = tl.load(dense1_ptr + pid * stride_d1a + offsets_b * stride_d1b, mask=offsets_b < B)  # (B,)

        accum = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32)

        for k in range(K):
            # workaround to do at_indices[k]
            i = tl.sum(
                tl.where(tl.arange(0, BLOCK_SIZE_K) == k, at_indices, tl.zeros((BLOCK_SIZE_K,), dtype=tl.int64))
            )
            tl.device_assert(i < N)

            dense2col = tl.load(dense2_ptr + offsets_b * stride_d2b + i * stride_d2n, mask=offsets_b < B)  # (B,)
            # NOTE: fixed vs upstream OpenAI, which used tl.int64 zeros here for a
            # float32 accumulator -- a type bug that truncates value gradients on
            # current Triton. The else-branch must be float32 to match `accum`.
            accum += tl.where(
                tl.arange(0, BLOCK_SIZE_K) == k,
                tl.sum(dense1 * dense2col),
                tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32),
            )

        tl.store(out_ptr + pid * K + offsets_k, accum, mask=offsets_k < K)

    class TritonDecoderAutograd(torch.autograd.Function):
        """Sparse TopK decode with custom forward/backward (mirrors OpenAI).

        ``apply(sparse_indices, sparse_values, decoder_weight)`` takes (A, k) indices
        and values plus a (d, n) decoder weight and returns the (A, d) reconstruction
        without bias. Gradients are produced for ``sparse_values`` and
        ``decoder_weight`` only; ``sparse_indices`` receives ``None``.
        """

        @staticmethod
        def forward(ctx, sparse_indices, sparse_values, decoder_weight):
            """Reconstruction = sparse(top-k) @ decoder_weight.T (no dense codes)."""
            ctx.save_for_backward(sparse_indices, sparse_values, decoder_weight)
            # decoder_weight.T is a strided (n, d) view; the kernel reads it via strides.
            return triton_sparse_dense_matmul(sparse_indices, sparse_values, decoder_weight.T)

        @staticmethod
        def backward(ctx, grad_output):
            """Gradients for sparse_values (gathered) and decoder_weight (sparse-transpose)."""
            sparse_indices, sparse_values, decoder_weight = ctx.saved_tensors

            # The transpose/sparseout kernels require a contiguous grad_output.
            grad_output = grad_output.contiguous()

            # sparse.T @ grad_output is (n, d); transposing gives the (d, n) weight grad.
            decoder_grad = triton_sparse_transpose_dense_matmul(
                sparse_indices, sparse_values, grad_output, N=decoder_weight.shape[1]
            ).T

            # NOTE(review): four entries are returned for three forward inputs (kept
            # from the upstream port); the trailing None has no matching input --
            # confirm the targeted PyTorch versions accept the extra trailing None.
            return (
                None,
                triton_dense_dense_sparseout_matmul(grad_output, decoder_weight, sparse_indices),
                # decoder is contiguous when transposed so this is a matching layout
                decoder_grad,
                None,
            )

else:  # pragma: no cover - exercised only when triton is unavailable

    class TritonDecoderAutograd:
        """Placeholder that errors clearly when Triton is unavailable.

        Exposes only ``apply`` so call sites written against the autograd Function
        fail with an actionable message instead of a NameError.
        """

        @staticmethod
        def apply(*args, **kwargs):
            """Error: Triton unavailable, so the sparse decoder cannot run."""
            raise RuntimeError(
                "Triton is not available, so decoder_impl='triton' cannot run. "
                "Install triton (ships with recent PyTorch) or use decoder_impl='dense'."
            )
+ +"""Tensor-parallel building blocks for latent-sharded SAEs.""" + +from .checkpoint import load_and_merge, save_sharded +from .comms import all_gather_cat, all_reduce_sum, autograd_all_reduce_sum +from .topk import GlobalTopK, dense_topk_reference, global_topk +from .training import train_tp_loop + + +__all__ = [ + "GlobalTopK", + "all_gather_cat", + "all_reduce_sum", + "autograd_all_reduce_sum", + "dense_topk_reference", + "global_topk", + "load_and_merge", + "save_sharded", + "train_tp_loop", +] diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/parallel/checkpoint.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/parallel/checkpoint.py new file mode 100644 index 0000000000..b7ac00de5c --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/parallel/checkpoint.py @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sharded checkpointing for the tensor-parallel TopK SAE. + +Each rank saves its own latent slice (`save_sharded`); `load_and_merge` reassembles +the shards into a single dense `TopKSAE` (on CPU) so the existing dense eval / loss- +recovered path can be reused without any TP machinery. 
def save_sharded(model, out_dir: str, rank=None) -> None:
    """Save this rank's shard to ``out_dir/shard_{rank:03d}.pt`` (+ ``meta.json`` on rank 0).

    Args:
        model: sharded SAE exposing ``W_enc_local``, ``latent_bias_local``,
            ``W_dec_local`` and ``pre_bias`` tensors, plus the scalar attributes
            copied into ``meta.json`` below.
        out_dir: destination directory; created if it does not exist.
        rank: shard index used in the file name. Defaults to ``model.rank``.
    """
    rank = model.rank if rank is None else rank
    os.makedirs(out_dir, exist_ok=True)
    # Tensors are detached and moved to CPU so the shard loads without a GPU.
    torch.save(
        {
            "W_enc_local": model.W_enc_local.detach().cpu(),
            "latent_bias_local": model.latent_bias_local.detach().cpu(),
            "W_dec_local": model.W_dec_local.detach().cpu(),
            "pre_bias": model.pre_bias.detach().cpu(),
            "rank": rank,
        },
        os.path.join(out_dir, f"shard_{rank:03d}.pt"),
    )
    if rank == 0:
        # Only rank 0 writes the shared metadata, so ranks never race on the file.
        meta = {
            "world_size": model.world_size,
            "input_dim": model.input_dim,
            "hidden_dim": model.hidden_dim,
            "top_k": model.top_k,
            "normalize_input": model.normalize_input,
            "auxk": model.auxk,
            "auxk_coef": model.auxk_coef,
            "dead_tokens_threshold": model.dead_tokens_threshold,
        }
        with open(os.path.join(out_dir, "meta.json"), "w", encoding="utf-8") as f:
            json.dump(meta, f, indent=2)


def load_and_merge(out_dir: str):
    """Reassemble sharded files into a single dense TopKSAE (CPU) for eval.

    Reads ``meta.json`` and ``shard_000.pt`` .. ``shard_{world_size-1:03d}.pt`` from
    ``out_dir``, concatenates the per-rank slices in rank order and copies them into
    a freshly constructed ``TopKSAE``.

    Args:
        out_dir: directory previously written by ``save_sharded`` on every rank.

    Returns:
        A ``TopKSAE`` holding the merged weights.
    """
    from ..architectures.topk import TopKSAE  # lazy import to avoid any import cycle

    with open(os.path.join(out_dir, "meta.json"), encoding="utf-8") as f:
        meta = json.load(f)
    world_size = meta["world_size"]

    # Shards contain only tensors and an int, so the restricted unpickler is enough
    # and a tampered checkpoint cannot execute arbitrary code while loading.
    shards = [
        torch.load(os.path.join(out_dir, f"shard_{r:03d}.pt"), map_location="cpu", weights_only=True)
        for r in range(world_size)
    ]
    w_enc = torch.cat([s["W_enc_local"] for s in shards], dim=0)  # [n, d]
    w_dec = torch.cat([s["W_dec_local"] for s in shards], dim=1)  # [d, n]
    latent_bias = torch.cat([s["latent_bias_local"] for s in shards], dim=0)  # [n]
    pre_bias = shards[0]["pre_bias"]  # replicated

    # Both init flags are off: every parameter is overwritten by the copies below.
    sae = TopKSAE(
        input_dim=meta["input_dim"],
        hidden_dim=meta["hidden_dim"],
        top_k=meta["top_k"],
        normalize_input=meta["normalize_input"],
        auxk=meta["auxk"],
        auxk_coef=meta["auxk_coef"],
        dead_tokens_threshold=meta["dead_tokens_threshold"],
        init_encoder_from_decoder=False,
        init_pre_bias=False,
    )
    with torch.no_grad():
        sae.encoder.weight.copy_(w_enc)
        sae.decoder.weight.copy_(w_dec)
        sae.latent_bias.copy_(latent_bias)
        sae.pre_bias.copy_(pre_bias)
    return sae
def all_gather_cat(tensor: torch.Tensor, group=None, dim: int = 0) -> torch.Tensor:
    """All-gather `tensor` from every rank and concatenate along `dim`.

    Args:
        tensor: this rank's contribution; every rank must pass the same shape and dtype.
        group: process group to gather over (None = default group).
        dim: dimension along which the per-rank pieces are concatenated, in rank order.

    Returns:
        The concatenation of all ranks' tensors. Not differentiable.
    """
    world_size = dist.get_world_size(group)
    # Make the send buffer contiguous first and allocate the receive buffers from it:
    # empty_like on a permuted (e.g. transposed) input would copy its strides, so the
    # receive buffers would be non-contiguous even though the sent data is contiguous.
    send = tensor.contiguous()
    parts = [torch.empty_like(send, memory_format=torch.contiguous_format) for _ in range(world_size)]
    dist.all_gather(parts, send, group=group)
    return torch.cat(parts, dim=dim)


def all_reduce_sum(tensor: torch.Tensor, group=None) -> torch.Tensor:
    """In-place sum-reduce across ranks (for non-differentiable tensors / grads).

    Returns the same tensor object that was passed in, now holding the sum.
    """
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
    return tensor


class _AllReduceSum(torch.autograd.Function):
    """Sum-all-reduce whose backward passes the incoming gradient through unchanged."""

    @staticmethod
    def forward(ctx, x, group):
        """Return the cross-rank sum of `x`, leaving `x` itself untouched."""
        # Reduce a copy: all_reduce works in place and `x` may still be needed by autograd.
        y = x.clone()
        dist.all_reduce(y, op=dist.ReduceOp.SUM, group=group)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        """Identity for `x`; `group` is not a tensor and gets no gradient."""
        # The summed output is replicated and the downstream loss is identical on
        # every rank, so grad_output is the same on all ranks and d(sum)/d(x_r)=I.
        # Each rank therefore keeps the incoming gradient unchanged.
        return grad_output, None


def autograd_all_reduce_sum(x: torch.Tensor, group=None) -> torch.Tensor:
    """Differentiable sum-all-reduce (combine per-rank partial reconstructions).

    Forward returns the sum of `x` over all ranks in `group`; backward hands each
    rank the gradient of that sum as-is (no second collective).
    """
    return _AllReduceSum.apply(x, group)
class GlobalTopK(NamedTuple):
    """Outcome of a top-k taken across latent shards.

    Every field has shape [batch, k]:
        global_values: the selected activation values (same on every rank).
        global_indices: the selected latents, numbered over the full latent space
            (same on every rank).
        local_indices: shard-relative index of each selection owned by this rank;
            0 for selections owned elsewhere, so always read it with `owned_mask`.
        owned_mask: True exactly where the selection lives in this rank's shard.
    """

    global_values: torch.Tensor
    global_indices: torch.Tensor
    local_indices: torch.Tensor
    owned_mask: torch.Tensor


def global_topk(
    pre_act_local: torch.Tensor,
    k: int,
    rank: int,
    latents_per_rank: int,
    group=None,
) -> GlobalTopK:
    """Select the k largest pre-activations per token over all shards.

    Args:
        pre_act_local: [batch, latents_per_rank] pre-activations held by this rank.
        k: how many latents to keep per token, counted over all shards.
        rank: index of this rank inside the TP group.
        latents_per_rank: shard width, used to convert between local and global indices.
        group: TP process group (None = default group).

    Returns:
        A GlobalTopK; see that class for the meaning of each field.
    """
    _batch, shard_width = pre_act_local.shape
    shard_start = rank * latents_per_rank

    # Stage 1: this shard's own candidates (at most `shard_width` of them).
    shard_best = torch.topk(pre_act_local, min(k, shard_width), dim=-1)

    # Stage 2: pool every shard's candidates, values first and then their global ids.
    pooled_vals = all_gather_cat(shard_best.values, group=group, dim=1)  # [batch, world*local_k]
    pooled_ids = all_gather_cat(shard_best.indices + shard_start, group=group, dim=1)

    # Stage 3: final selection over the pooled candidates.
    winners = torch.topk(pooled_vals, k, dim=-1)  # [batch, k]
    winner_ids = pooled_ids.gather(1, winners.indices)  # [batch, k]

    # A winner belongs to this rank when its shard-relative index is in range.
    relative = winner_ids - shard_start
    mine = (relative >= 0) & (relative < latents_per_rank)
    local_ids = torch.where(mine, relative, torch.zeros_like(winner_ids))

    return GlobalTopK(
        global_values=winners.values,
        global_indices=winner_ids,
        local_indices=local_ids,
        owned_mask=mine,
    )


def dense_topk_reference(pre_act_full: torch.Tensor, k: int) -> "tuple[torch.Tensor, torch.Tensor]":
    """Oracle for tests: plain top-k along the last dim of the unsharded tensor.

    Returns a ``(values, indices)`` tuple.
    """
    picked = torch.topk(pre_act_full, k, dim=-1)
    return picked.values, picked.indices
# Dtypes a receiving rank can be told to allocate through an integer code in the
# broadcast header. Rank 0 casts any other dtype to float32 before broadcasting.
_TP_BATCH_DTYPES = (torch.float32, torch.float16, torch.bfloat16, torch.float64)


def _next_cycled_batch(dataloader, data_iter):
    """Return ``(batch, iterator)``, restarting `dataloader` once when it is exhausted."""
    try:
        return next(data_iter), data_iter
    except StopIteration:
        data_iter = iter(dataloader)
    try:
        return next(data_iter), data_iter
    except StopIteration:
        raise ValueError("dataloader produced no batch after being restarted") from None


def _broadcast_batch(x, rank, device, group):
    """Send rank 0's 2-D batch to every rank, announcing its shape and dtype first."""
    if rank == 0:
        if x.dtype not in _TP_BATCH_DTYPES:
            x = x.to(torch.float32)
        header = torch.tensor(
            [x.shape[0], x.shape[1], _TP_BATCH_DTYPES.index(x.dtype)], dtype=torch.long, device=device
        )
    else:
        header = torch.empty(3, dtype=torch.long, device=device)
    dist.broadcast(header, src=0, group=group)
    if rank != 0:
        rows, cols, dtype_code = header.tolist()
        # Allocate with rank 0's dtype: a float32 buffer would not match the sender
        # when the batch is fp16/bf16/fp64.
        x = torch.empty(rows, cols, dtype=_TP_BATCH_DTYPES[dtype_code], device=device)
    dist.broadcast(x, src=0, group=group)
    return x


def train_tp_loop(
    sae,
    dataloader,
    *,
    lr: float,
    max_steps: int,
    device: str,
    log_interval: int = 100,
    max_grad_norm=None,
    checkpoint_dir=None,
    group=None,
    perf_logger=None,
) -> float:
    """Train a ShardedTopKSAE for `max_steps` optimizer steps. Returns final loss.

    If `perf_logger` is provided (rank 0 only), per-step metrics are logged through it
    (same metrics/W&B path as the dense recipe); otherwise rank 0 prints periodically.

    Args:
        sae: model exposing ``loss(x)`` (a dict with at least ``"total"``, ``"fvu"``
            and ``"dead_pct"``), ``reduce_replicated_grads()``, ``world_size`` and,
            optionally, ``post_step()``.
        dataloader: iterable of batches, read on rank 0 only and restarted when
            exhausted. A tuple/list batch contributes its first element.
        lr: Adam learning rate.
        max_steps: number of optimizer steps to run.
        device: device the model and batches are moved to.
        log_interval: print every this many steps when no `perf_logger` is given.
        max_grad_norm: if not None, clip the gradient norm to this value.
        checkpoint_dir: if not None, every rank saves its shard there after training.
        group: TP process group (None = default group).
        perf_logger: optional logger, used on rank 0 only.

    Returns:
        The total loss of the last step as a float (NaN when `max_steps` is 0).

    Raises:
        ValueError: on rank 0 when `dataloader` yields no batch even after a restart.
    """
    rank = dist.get_rank(group) if dist.is_initialized() else 0
    sae = sae.to(device)
    sae.train()
    optimizer = torch.optim.Adam(sae.parameters(), lr=lr)

    # Tensor parallelism replicates the batch: every rank MUST train on the same data,
    # else the all-reduced reconstruction combines different inputs and diverges. Rank 0
    # drives the dataloader and broadcasts each batch to the TP group.
    distributed = dist.is_initialized() and dist.get_world_size(group) > 1
    data_iter = iter(dataloader) if rank == 0 else None
    # torch.cuda.max_memory_allocated rejects non-CUDA devices, so only query it when
    # training actually runs on a CUDA device (not merely when CUDA is installed).
    report_cuda_mem = torch.cuda.is_available() and torch.device(device).type == "cuda"

    final_loss = float("nan")
    t0 = time.time()
    for step in range(max_steps):
        x = None
        if rank == 0:
            batch, data_iter = _next_cycled_batch(dataloader, data_iter)
            x = (batch[0] if isinstance(batch, (tuple, list)) else batch).to(device).contiguous()
        if distributed:
            x = _broadcast_batch(x, rank, device, group)

        optimizer.zero_grad(set_to_none=True)
        out = sae.loss(x)
        loss = out["total"]
        loss.backward()
        sae.reduce_replicated_grads()  # all-reduce replicated (pre_bias) grad
        grad_norm = None
        if max_grad_norm is not None:
            grad_norm = float(torch.nn.utils.clip_grad_norm_(sae.parameters(), max_grad_norm))
        optimizer.step()
        if hasattr(sae, "post_step"):
            sae.post_step()  # e.g. decoder normalization (keeps TopK training stable)

        final_loss = float(loss.detach())
        if rank == 0 and perf_logger is not None:
            perf_logger.log_step(
                step=step,
                batch=x,
                loss_dict={k: float(v) for k, v in out.items()},
                grad_norm=grad_norm,
                lr=lr,
                extra_metrics={"dead_pct": float(out["dead_pct"])},
            )
        elif rank == 0 and (step % log_interval == 0 or step == max_steps - 1):
            # Guard against a zero elapsed time on coarse clocks.
            rate = (step + 1) / max(time.time() - t0, 1e-9)
            mem = torch.cuda.max_memory_allocated(device) / 1e9 if report_cuda_mem else 0.0
            print(
                f"step {step:>7} | loss {final_loss:.6f} | fvu {float(out['fvu']):.4f} "
                f"| dead% {float(out['dead_pct']):.2f} | {rate:.1f} steps/s | peak {mem:.1f}GB",
                flush=True,
            )

    if rank == 0 and perf_logger is not None:
        perf_logger.finish()
    if checkpoint_dir is not None:
        from .checkpoint import save_sharded

        save_sharded(sae, checkpoint_dir, rank=rank)
        if rank == 0:
            print(f"Saved sharded checkpoint ({sae.world_size} shards) to {checkpoint_dir}", flush=True)
    return final_loss
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Producer-consumer streaming of activations for on-the-fly SAE training. + +Instead of extracting all activations to disk and then training on them, a +background *producer* thread runs a caller-supplied activation source and pushes +activation chunks onto a bounded queue, while the SAE ``Trainer`` *consumes* +re-batched activations off the queue. This avoids persisting activations to disk +(beyond checkpoints) and bounds host memory via the queue size: when the queue +is full the producer blocks (backpressure), so at most ``queue_size`` chunks are +in flight at once. + +The ``sae`` package is model-agnostic, so the activation source is supplied by +the caller as a *factory* -- a zero-argument callable that returns a fresh +iterator of activation chunks each time it is called. A factory (rather than a +bare iterator) is required so that each training epoch re-runs the source. Each +chunk is a tensor of shape ``[n_tokens, hidden_dim]`` (already flattened and +masked by the caller). + +Streaming is OFF by default. Enable it with ``StreamingConfig(enabled=True)``; +the flag is consulted by callers (e.g. a recipe's training script) to decide +whether to build a streaming dataloader instead of reading a cached store. + +Example: + >>> def producer_factory(): + ... for batch_of_texts in batches: + ... 
yield model.activations(batch_of_texts) # [n_tokens, hidden_dim] + >>> cfg = StreamingConfig(enabled=True, queue_size=8) + >>> dataloader = make_streaming_dataloader(producer_factory, batch_size=4096, config=cfg) + >>> trainer.fit(dataloader) # Trainer consumes batches as they are produced +""" + +import queue +import threading +from dataclasses import dataclass +from typing import Callable, Iterable, Iterator, Optional + +import torch +from torch.utils.data import DataLoader, IterableDataset + + +# A factory returning a fresh iterator of activation chunks, each of shape +# [n_tokens, hidden_dim]. A factory (not a bare iterator) lets each epoch +# re-run the source. +ActivationProducer = Callable[[], Iterable[torch.Tensor]] + + +@dataclass +class StreamingConfig: + """Configuration for producer-consumer activation streaming. + + Attributes: + enabled: Master flag. Off by default; callers check this to decide + whether to stream instead of reading a cached activation store. + queue_size: Maximum number of activation chunks buffered between the + producer thread and the consumer. Bounds host memory and provides + backpressure (the producer blocks when the queue is full). + shuffle_buffer_size: If > 0, shuffle incoming tokens within a buffer of + this many rows before emitting batches (approximate shuffle). 0 + preserves producer order. + seed: Seed for the shuffle buffer (ignored when shuffle_buffer_size == 0). + drop_last: If True, drop the final partial batch (keeps batch sizes + uniform, which matters for DDP). + """ + + enabled: bool = False + queue_size: int = 8 + shuffle_buffer_size: int = 0 + seed: Optional[int] = None + drop_last: bool = False + + +# Sentinel signalling the producer has finished. A unique object so it can never +# collide with a yielded activation chunk. 
_DONE = object()


def _normalize_chunk(chunk: torch.Tensor) -> torch.Tensor:
    """Coerce a producer chunk to a 2D float32 CPU tensor.

    Args:
        chunk: A tensor, or anything ``torch.as_tensor`` accepts.

    Returns:
        The chunk detached from autograd, moved to the CPU and cast to float32.

    Raises:
        ValueError: If the chunk is not 2-dimensional.
    """
    if not torch.is_tensor(chunk):
        chunk = torch.as_tensor(chunk)
    chunk = chunk.detach().to(device="cpu", dtype=torch.float32)
    if chunk.ndim != 2:
        raise ValueError(f"Activation chunks must be 2D [n_tokens, hidden_dim], got shape {tuple(chunk.shape)}")
    return chunk


class StreamingActivationDataset(IterableDataset):
    """IterableDataset that streams activations from one or more producers.

    One daemon producer thread per factory iterates its activation source and
    puts chunks onto a single shared bounded queue; ``__iter__`` pulls chunks
    from all producers, optionally shuffles within a buffer, and yields
    pre-formed ``[batch_size, hidden_dim]`` batches. Wrap it in
    ``DataLoader(..., batch_size=None)`` (or use ``make_streaming_dataloader``)
    so the loader passes batches through as-is.

    Every call to ``__iter__`` starts fresh threads and calls every factory
    again, so each pass over the dataset re-runs the sources.

    Exceptions raised by any producer are re-raised in the consumer. When the
    consumer stops early (e.g. the generator is closed), producers are
    signalled to stop; every queue operation a producer performs is stop-aware,
    so the threads exit instead of blocking on a queue nobody reads any more.
    """

    def __init__(
        self,
        producer_factory,
        batch_size: int,
        config: Optional["StreamingConfig"] = None,
    ):
        """Initialize the streaming dataset.

        Args:
            producer_factory: A zero-arg callable returning a fresh iterator of
                activation chunks ([n_tokens, hidden_dim]), OR a list of such
                callables (one background thread is spawned per factory, all
                feeding the same queue).
            batch_size: Number of token activations per emitted batch.
            config: Streaming configuration (uses defaults if None).

        Raises:
            ValueError: If ``batch_size`` is not positive or no factory is given.
            TypeError: If an entry of a factory list is not callable.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        # Normalize to a list of factories (single factory -> one producer).
        factories = [producer_factory] if callable(producer_factory) else list(producer_factory)
        if not factories:
            raise ValueError("producer_factory must be a callable or a non-empty list of callables")
        # Fail here rather than inside a background thread on first iteration.
        for index, factory in enumerate(factories):
            if not callable(factory):
                raise TypeError(f"producer_factory[{index}] is not callable: {factory!r}")
        self.producer_factories = factories
        self.batch_size = batch_size
        self.config = config or StreamingConfig()

    def __iter__(self) -> Iterator[torch.Tensor]:
        """Yield ``[batch_size, hidden_dim]`` batches as the producers fill the queue.

        The last batch may be smaller than ``batch_size`` unless
        ``config.drop_last`` is set.
        """
        cfg = self.config
        batch_size = self.batch_size
        shuffling = cfg.shuffle_buffer_size > 0
        # With shuffling, accumulate a full shuffle buffer before emitting.
        emit_threshold = max(cfg.shuffle_buffer_size, batch_size) if shuffling else batch_size

        q: "queue.Queue" = queue.Queue(maxsize=max(1, cfg.queue_size))
        stop_event = threading.Event()
        n_producers = len(self.producer_factories)

        def _put(item) -> bool:
            """Enqueue ``item`` with backpressure; return False once the consumer stopped."""
            while not stop_event.is_set():
                try:
                    q.put(item, timeout=0.1)
                    return True
                except queue.Full:
                    continue
            return False

        def _produce(factory) -> None:
            try:
                for chunk in factory():
                    if not _put(chunk):
                        return  # consumer is gone; nobody is waiting for a DONE marker
                _put(_DONE)  # one DONE marker per producer
            except BaseException as exc:
                # Forward every failure: a producer that dies silently would
                # leave the consumer blocked in q.get() forever.
                _put(exc)

        generator = None
        if shuffling and cfg.seed is not None:
            generator = torch.Generator().manual_seed(cfg.seed)

        def _shuffle(buf: torch.Tensor) -> torch.Tensor:
            if not shuffling:
                return buf
            return buf[torch.randperm(buf.shape[0], generator=generator)]

        def _split(buf: torch.Tensor):
            """Return (list of full batches, leftover rows) for ``buf``."""
            n_full = (buf.shape[0] // batch_size) * batch_size
            return [buf[start : start + batch_size] for start in range(0, n_full, batch_size)], buf[n_full:]

        threads = [
            threading.Thread(target=_produce, args=(f,), name=f"sae-activation-producer-{i}", daemon=True)
            for i, f in enumerate(self.producer_factories)
        ]
        for t in threads:
            t.start()

        buffer: Optional[torch.Tensor] = None
        try:
            done_count = 0
            while done_count < n_producers:
                item = q.get()
                if item is _DONE:
                    done_count += 1
                    continue
                if isinstance(item, BaseException):
                    raise item

                chunk = _normalize_chunk(item)
                if chunk.shape[0] == 0:
                    continue
                buffer = chunk if buffer is None else torch.cat([buffer, chunk], dim=0)

                if buffer.shape[0] >= emit_threshold:
                    batches, buffer = _split(_shuffle(buffer))
                    yield from batches

            # Flush whatever remains after all producers finished.
            if buffer is not None and buffer.shape[0] > 0:
                batches, remainder = _split(_shuffle(buffer))
                yield from batches
                if remainder.shape[0] > 0 and not cfg.drop_last:
                    yield remainder
        finally:
            # Signal producers, then drain so a put() already in progress returns.
            stop_event.set()
            try:
                while True:
                    q.get_nowait()
            except queue.Empty:
                pass
            # Bounded wait: a producer may still be inside a long source step.
            for t in threads:
                t.join(timeout=5.0)


def make_streaming_dataloader(
    producer_factory,
    batch_size: int,
    config: Optional["StreamingConfig"] = None,
) -> DataLoader:
    """Build a DataLoader that streams activations from one or more producers.

    The returned DataLoader yields pre-formed ``[batch_size, hidden_dim]``
    tensors. ``num_workers`` is fixed to 0 because producers run as threads in
    the calling process; worker processes would each re-run the factories.

    Args:
        producer_factory: A zero-arg callable returning a fresh iterator of
            activation chunks ([n_tokens, hidden_dim]), OR a list of such
            callables (one thread each, all feeding the same queue).
        batch_size: Number of token activations per emitted batch.
        config: Streaming configuration (uses defaults if None).

    Returns:
        A DataLoader yielding ``[batch_size, hidden_dim]`` float32 tensors.
    """
    dataset = StreamingActivationDataset(producer_factory, batch_size, config)
    # batch_size=None: the dataset already yields pre-formed batches.
    return DataLoader(dataset, batch_size=None, num_workers=0)
+ return DataLoader(dataset, batch_size=None, num_workers=0) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/training.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/training.py index d8e7c8f960..58c7e85707 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/training.py +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/training.py @@ -20,6 +20,7 @@ """ import contextlib +import itertools import math import os from dataclasses import dataclass, field @@ -69,6 +70,9 @@ class TrainingConfig: lr_schedule: LR schedule after warmup ('constant', 'cosine', 'linear') lr_min: Minimum LR for decay schedules lr_decay_steps: Total steps for LR decay (None = use full training duration) + max_steps: Stop after this many optimizer steps (None = run all n_epochs). + When set, epochs loop until the step budget is reached, so it controls + duration directly (useful for streaming, which has no fixed length). """ lr: float = 3e-4 @@ -89,6 +93,7 @@ class TrainingConfig: lr_schedule: str = "constant" lr_min: float = 0.0 lr_decay_steps: Optional[int] = None + max_steps: Optional[int] = None @dataclass @@ -560,25 +565,32 @@ def fit( if self.config.lr_decay_steps is not None: self._total_decay_steps = self.config.lr_decay_steps elif self.config.lr_schedule != "constant": - # Estimate total optimizer steps from dataloader length - try: - batches_per_epoch = len(self.dataloader) - steps_per_epoch = batches_per_epoch // accum_steps - total_steps = steps_per_epoch * self.config.n_epochs - self._total_decay_steps = max(0, total_steps - self.config.warmup_steps) - except TypeError: - self._total_decay_steps = 0 - self._print_rank0( - "WARNING: Cannot compute decay steps for streaming dataloader. " - "Set lr_decay_steps explicitly or use lr_schedule='constant'." 
- ) + if self.config.max_steps is not None: + # max_steps gives an exact budget (works for streaming too) + self._total_decay_steps = max(0, self.config.max_steps - self.config.warmup_steps) + else: + # Estimate total optimizer steps from dataloader length + try: + batches_per_epoch = len(self.dataloader) + steps_per_epoch = batches_per_epoch // accum_steps + total_steps = steps_per_epoch * self.config.n_epochs + self._total_decay_steps = max(0, total_steps - self.config.warmup_steps) + except TypeError: + self._total_decay_steps = 0 + self._print_rank0( + "WARNING: Cannot compute decay steps for streaming dataloader. " + "Set lr_decay_steps or max_steps explicitly, or use lr_schedule='constant'." + ) else: self._total_decay_steps = 0 remaining_info = "" if resume_from is not None: remaining_info = f" (resuming from epoch {self.current_epoch})" - self._print_rank0(f"\nTraining SAE for {self.config.n_epochs} epochs{remaining_info}...") + if self.config.max_steps is not None: + self._print_rank0(f"\nTraining SAE for up to {self.config.max_steps:,} steps{remaining_info}...") + else: + self._print_rank0(f"\nTraining SAE for {self.config.n_epochs} epochs{remaining_info}...") try: self._print_rank0(f"Batches per epoch: ~{len(self.dataloader)}") except TypeError: @@ -606,8 +618,17 @@ def fit( self._print_rank0(f"Resuming from epoch {start_epoch}, step {self.global_step}") epoch_losses = [] + max_steps = self.config.max_steps + # When max_steps is set, loop epochs indefinitely until the step budget is + # reached (so duration is controlled by steps, which also works for streaming). 
+ epoch_iter = ( + itertools.count(start_epoch) if max_steps is not None else range(start_epoch, self.config.n_epochs) + ) + reached_max_steps = False - for epoch in range(start_epoch, self.config.n_epochs): + for epoch in epoch_iter: + if reached_max_steps: + break self.current_epoch = epoch batch_losses = [] @@ -722,8 +743,12 @@ def fit( self.global_step += 1 + if max_steps is not None and self.global_step >= max_steps: + reached_max_steps = True + break + # Epoch complete - avg_loss = np.mean(batch_losses) + avg_loss = np.mean(batch_losses) if batch_losses else float("nan") epoch_losses.append(avg_loss) # Print progress (only if no perf_logger, as it handles printing) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/__init__.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/__init__.py new file mode 100644 index 0000000000..1dd47a63cf --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/_dist_utils.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/_dist_utils.py new file mode 100644 index 0000000000..3f5c86fe9f --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/_dist_utils.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multi-process test harness for tensor-parallel logic (CPU / gloo). + +`run_distributed(fn, world_size)` launches `world_size` processes, each running +`fn(rank, world_size, *args)` inside an initialized gloo process group, and returns +`{rank: return_value}`. Exceptions in a worker are re-raised in the parent with the +full traceback. + +We use the *spawn* start method: the pytest session also runs GPU/autograd tests, +and fork-after-autograd/CUDA is unsafe ("Unable to handle autograd's threading in +combination with fork-based multiprocessing"). Spawn starts clean interpreters, so +we extend PYTHONPATH for the children to import both the `sae` package and this +`tests` package (which holds the worker functions). 
import os
import socket
import traceback

import torch.distributed as dist
import torch.multiprocessing as mp


_TESTS_DIR = os.path.dirname(os.path.abspath(__file__))  # .../sae/tests
_SAE_DIR = os.path.dirname(_TESTS_DIR)  # .../sae (root of the `tests` package)
_SRC_DIR = os.path.join(_SAE_DIR, "src")  # .../sae/src (root of the `sae` package)


def _free_port() -> int:
    """Ask the OS for an unused TCP port and return its number.

    The probing socket is closed before returning, so the port is only
    *likely* to still be free when the process group binds it.
    """
    with socket.socket() as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]


def _child_pythonpath(existing: str) -> str:
    """Return ``existing`` with the `sae` and `tests` roots prepended exactly once.

    Empty entries are dropped (an empty PYTHONPATH entry means the current
    directory), and the two roots are de-duplicated so repeated calls do not
    keep growing the variable.
    """
    required = [_SRC_DIR, _SAE_DIR]
    kept = [entry for entry in existing.split(os.pathsep) if entry and entry not in required]
    return os.pathsep.join(required + kept)


def _worker(rank, world_size, backend, fn, args, ret):
    """Child-process entry point: join the process group, run ``fn``, record the outcome.

    On success the return value is stored under ``ret[rank]``. On failure the
    formatted traceback is stored under ``ret["error_<rank>"]`` and the
    exception is re-raised so the process exits with a non-zero code.
    """
    try:
        dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
        ret[rank] = fn(rank, world_size, *args)
    except Exception:
        ret[f"error_{rank}"] = traceback.format_exc()
        raise
    finally:
        if dist.is_initialized():
            dist.destroy_process_group()


def run_distributed(fn, world_size, backend="gloo", args=()):
    """Spawn `world_size` workers running ``fn``; return ``{rank: result}``.

    Args:
        fn: Picklable callable invoked as ``fn(rank, world_size, *args)`` in
            each worker, inside an initialized process group.
        world_size: Number of worker processes to start.
        backend: ``torch.distributed`` backend name.
        args: Extra positional arguments forwarded to ``fn``.

    Returns:
        A plain dict mapping each rank to the value its ``fn`` call returned.

    Raises:
        AssertionError: If any worker raised; the message holds the tracebacks.
        RuntimeError: If a worker exited non-zero without recording a traceback.
    """
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ["MASTER_PORT"] = str(_free_port())
    # Ensure spawned children can import `sae` and the `tests` package.
    os.environ["PYTHONPATH"] = _child_pythonpath(os.environ.get("PYTHONPATH", ""))

    ctx = mp.get_context("spawn")
    # The context manager shuts the manager's server process down afterwards.
    with ctx.Manager() as manager:
        ret = manager.dict()
        procs = [
            ctx.Process(target=_worker, args=(rank, world_size, backend, fn, args, ret)) for rank in range(world_size)
        ]
        for p in procs:
            p.start()
        for p in procs:
            p.join()
        # Copy out while the manager is alive; the proxy is unusable after shutdown.
        outcome = dict(ret)

    errors = [v for k, v in outcome.items() if isinstance(k, str) and k.startswith("error_")]
    if errors:
        raise AssertionError("distributed worker failed:\n" + "\n".join(errors))
    for p in procs:
        if p.exitcode != 0:
            raise RuntimeError(f"worker exited with code {p.exitcode}")
    return {k: v for k, v in outcome.items() if isinstance(k, int)}
+ os.environ["PYTHONPATH"] = os.pathsep.join([_SRC_DIR, _SAE_DIR, os.environ.get("PYTHONPATH", "")]) + + ctx = mp.get_context("spawn") + ret = ctx.Manager().dict() + procs = [ + ctx.Process(target=_worker, args=(rank, world_size, backend, fn, args, ret)) for rank in range(world_size) + ] + for p in procs: + p.start() + for p in procs: + p.join() + + errors = {k: v for k, v in ret.items() if isinstance(k, str) and k.startswith("error_")} + if errors: + raise AssertionError("distributed worker failed:\n" + "\n".join(errors.values())) + for p in procs: + if p.exitcode != 0: + raise RuntimeError(f"worker exited with code {p.exitcode}") + return {k: v for k, v in ret.items() if isinstance(k, int)} diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_kernels.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_kernels.py new file mode 100644 index 0000000000..388bb1b1a2 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_kernels.py @@ -0,0 +1,149 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Correctness tests for the Triton sparse decoder kernels. + +GPU-gated: these skip cleanly on CPU-only machines (Triton kernels need CUDA), so +the suite stays green everywhere and is validated on a GPU box. 
+ +Oracle = the dense, autograd-differentiable reference in sae.kernels.reference. +The decoder-weight gradient kernel uses atomic adds (nondeterministic FP +accumulation), so gradient comparisons use tolerances, not exact equality. +""" + +import pytest +import torch +from sae.kernels import HAS_TRITON, TritonDecoderAutograd, reference_decode + + +pytestmark = pytest.mark.skipif( + not (HAS_TRITON and torch.cuda.is_available()), + reason="Triton sparse decoder kernels require CUDA + Triton", +) + +DEVICE = "cuda" + +# The Triton kernels accumulate in true fp32, but cuBLAS fp32 matmuls use TF32 by +# default on Ampere+/Hopper (~1e-2 error), which would make the dense reference the +# *less* accurate side. Disable TF32 so the reference is exact fp32 for comparison. +if torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + + +def _random_topk(a, n, k, d, dtype, seed=0): + """Build random (indices, values, decoder_weight) with unique indices per row.""" + g = torch.Generator(device=DEVICE).manual_seed(seed) + # Unique top-k indices per row (argsort of random scores, take first k). 
+ scores = torch.rand(a, n, generator=g, device=DEVICE) + indices = scores.argsort(dim=-1)[:, :k].contiguous().to(torch.int64) + values = torch.rand(a, k, generator=g, device=DEVICE, dtype=torch.float32).to(dtype).contiguous() + weight = torch.randn(d, n, generator=g, device=DEVICE, dtype=torch.float32).to(dtype) + return indices, values, weight + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("k", [1, 32]) +def test_forward_matches_reference(dtype, k): + idx, vals, w = _random_topk(a=64, n=4096, k=k, d=256, dtype=dtype) + out = TritonDecoderAutograd.apply(idx, vals, w) + ref = reference_decode(idx, vals, w) + atol, rtol = (1e-3, 1e-3) if dtype == torch.float32 else (5e-2, 5e-2) + torch.testing.assert_close(out.float(), ref.float(), atol=atol, rtol=rtol) + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +def test_backward_matches_reference(dtype): + idx, vals0, w0 = _random_topk(a=128, n=8192, k=32, d=256, dtype=dtype) + grad_seed = torch.randn(128, 256, device=DEVICE, dtype=dtype).contiguous() + + # Triton path + vals_t = vals0.clone().requires_grad_(True) + w_t = w0.clone().requires_grad_(True) + (TritonDecoderAutograd.apply(idx, vals_t, w_t) * grad_seed).sum().backward() + + # Dense reference path + vals_r = vals0.clone().requires_grad_(True) + w_r = w0.clone().requires_grad_(True) + (reference_decode(idx, vals_r, w_r) * grad_seed).sum().backward() + + # fp32 is the strict correctness gate. bf16 grads are inherently coarse: a + # magnitude-~20 dot product has bf16 ulp ~0.06-0.12, so a few-percent relative + # tolerance with a matching atol is the right bar (the kernel itself accumulates + # in fp32 and matches the fp64 truth — verified separately). 
+ atol, rtol = (2e-3, 2e-3) if dtype == torch.float32 else (3e-1, 3e-2) + torch.testing.assert_close(vals_t.grad.float(), vals_r.grad.float(), atol=atol, rtol=rtol) + torch.testing.assert_close(w_t.grad.float(), w_r.grad.float(), atol=atol, rtol=rtol) + + +def test_topksae_dense_vs_triton_parity(): + """End-to-end: dense and triton TopKSAE give matching loss + param grads.""" + from sae.architectures import TopKSAE + + torch.manual_seed(0) + x = torch.randn(256, 128, device=DEVICE) + + def build(impl): + torch.manual_seed(123) + sae = TopKSAE(input_dim=128, hidden_dim=1024, top_k=16, normalize_input=True, decoder_impl=impl) + return sae.to(DEVICE) + + dense = build("dense") + triton = build("triton") + triton.load_state_dict(dense.state_dict()) # identical weights + + ld = dense.loss(x) + lt = triton.loss(x) + torch.testing.assert_close(lt["total"], ld["total"], atol=1e-3, rtol=1e-3) + torch.testing.assert_close(lt["mse"], ld["mse"], atol=1e-3, rtol=1e-3) + + ld["total"].backward() + lt["total"].backward() + for (n_d, p_d), (n_t, p_t) in zip(dense.named_parameters(), triton.named_parameters()): + assert n_d == n_t + if p_d.grad is None and p_t.grad is None: + continue + torch.testing.assert_close(p_t.grad, p_d.grad, atol=2e-3, rtol=2e-3, msg=f"grad mismatch: {n_d}") + + +def test_topksae_parity_with_auxk(): + """Parity including the auxk dead-latent path (codes=None path in triton).""" + from sae.architectures import TopKSAE + + torch.manual_seed(1) + x = torch.randn(256, 64, device=DEVICE) + + def build(impl): + torch.manual_seed(7) + sae = TopKSAE( + input_dim=64, + hidden_dim=512, + top_k=8, + normalize_input=True, + auxk=32, + dead_tokens_threshold=0, + decoder_impl=impl, # threshold 0 -> many "dead" -> exercises auxk + ) + return sae.to(DEVICE) + + dense, triton = build("dense"), build("triton") + triton.load_state_dict(dense.state_dict()) + # Prime dead-latent stats identically with one step. 
+ dense.loss(x)["total"].backward() + triton.loss(x)["total"].backward() + + ld, lt = dense.loss(x), triton.loss(x) + torch.testing.assert_close(lt["total"], ld["total"], atol=2e-3, rtol=2e-3) + assert "aux" in lt and "aux" in ld diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_streaming.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_streaming.py new file mode 100644 index 0000000000..fae0fc1e9c --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_streaming.py @@ -0,0 +1,228 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for producer-consumer activation streaming (sae.streaming). + +All tests are model-agnostic and CPU-only: the "producer" is a plain Python +generator of tensors, so nothing here loads a base model or touches a GPU. +""" + +import gc +import math + +import pytest +import torch +from sae.streaming import ( + StreamingActivationDataset, + StreamingConfig, + make_streaming_dataloader, +) + + +HIDDEN_DIM = 4 + + +def make_factory(chunk_sizes, hidden_dim=HIDDEN_DIM, on_produce=None): + """Build a producer factory yielding chunks whose rows carry a unique id. + + Row global index ``g`` becomes the tensor row ``[g, g, ...]`` so order and + membership can be checked exactly. 
``on_produce(chunk_index)`` is called as + each chunk is produced (for instrumentation). + """ + + def factory(): + g = 0 + for ci, n in enumerate(chunk_sizes): + if on_produce is not None: + on_produce(ci) + rows = torch.arange(g, g + n, dtype=torch.float32).unsqueeze(1).repeat(1, hidden_dim) + g += n + yield rows + + return factory + + +def collect(dataset): + """Drain a streaming dataset into a single concatenated tensor of batches.""" + batches = list(iter(dataset)) + return batches + + +def test_all_tokens_consumed_in_order(): + chunk_sizes = [5, 7, 4] # 16 rows total + total = sum(chunk_sizes) + ds = StreamingActivationDataset(make_factory(chunk_sizes), batch_size=4, config=StreamingConfig()) + + batches = collect(ds) + out = torch.cat(batches, dim=0) + + assert out.shape == (total, HIDDEN_DIM) + # No shuffle -> exact order preserved, every token exactly once. + expected = torch.arange(total, dtype=torch.float32).unsqueeze(1).repeat(1, HIDDEN_DIM) + assert torch.equal(out, expected) + + +def test_batch_shape_and_dtype(): + ds = StreamingActivationDataset(make_factory([10, 10]), batch_size=4, config=StreamingConfig()) + batches = collect(ds) + + # 20 rows, batch 4 -> 5 full batches, no remainder. 
+ assert all(b.shape == (4, HIDDEN_DIM) for b in batches) + assert all(b.dtype == torch.float32 for b in batches) + assert sum(b.shape[0] for b in batches) == 20 + + +def test_partial_last_batch_kept_by_default(): + ds = StreamingActivationDataset(make_factory([10]), batch_size=4, config=StreamingConfig()) + batches = collect(ds) + sizes = [b.shape[0] for b in batches] + assert sizes == [4, 4, 2] # final partial batch retained + + +def test_drop_last(): + ds = StreamingActivationDataset(make_factory([10]), batch_size=4, config=StreamingConfig(drop_last=True)) + batches = collect(ds) + sizes = [b.shape[0] for b in batches] + assert sizes == [4, 4] # 2-row remainder dropped + assert sum(sizes) == 8 + + +def test_producer_exception_propagates(): + def factory(): + yield torch.zeros(4, HIDDEN_DIM) + raise RuntimeError("boom") + + ds = StreamingActivationDataset(factory, batch_size=2, config=StreamingConfig(queue_size=1)) + with pytest.raises(RuntimeError, match="boom"): + list(iter(ds)) # must raise, not hang + + +def test_multi_epoch_refreshes_producer(): + calls = [] + factory = make_factory([8], on_produce=lambda ci: calls.append(ci)) + ds = StreamingActivationDataset(factory, batch_size=4, config=StreamingConfig()) + + first = torch.cat(collect(ds), dim=0) + second = torch.cat(collect(ds), dim=0) + + # Producer factory was invoked once per pass (one chunk each). + assert len(calls) == 2 + assert torch.equal(first, second) + + +def test_backpressure_bounds_in_flight(): + produced = [] + queue_size = 2 + # Each chunk is one full batch; a fast producer would otherwise race ahead. 
+ factory = make_factory([4] * 1000, on_produce=lambda ci: produced.append(ci)) + + ds = StreamingActivationDataset(factory, batch_size=4, config=StreamingConfig(queue_size=queue_size)) + it = iter(ds) + next(it) # consume a single batch, then stop early + it.close() # triggers cleanup (stop event + drain + join) + gc.collect() + + # With backpressure the producer cannot run away: at most queue_size buffered + # + one blocked put + a little slack. Far below the 1000 available chunks. + assert len(produced) <= queue_size + 3 + + +def test_shuffle_buffer_is_seeded_permutation(): + chunk_sizes = [16, 16, 16, 16] # 64 rows + total = sum(chunk_sizes) + cfg = StreamingConfig(shuffle_buffer_size=32, seed=123) + + ds_a = StreamingActivationDataset(make_factory(chunk_sizes), batch_size=8, config=cfg) + ds_b = StreamingActivationDataset( + make_factory(chunk_sizes), batch_size=8, config=StreamingConfig(shuffle_buffer_size=32, seed=123) + ) + + out_a = torch.cat(collect(ds_a), dim=0) + out_b = torch.cat(collect(ds_b), dim=0) + + # Same seed -> identical (deterministic) output order. + assert torch.equal(out_a, out_b) + # Multiset preserved: every original token present exactly once. + ids = out_a[:, 0].sort().values + expected = torch.arange(total, dtype=torch.float32) + assert torch.equal(ids, expected) + # And it actually shuffled (not identity order). 
+ assert not torch.equal(out_a[:, 0], expected) + + +def test_trainer_fit_streaming_smoke(): + from sae.architectures import TopKSAE + from sae.training import Trainer, TrainingConfig + + torch.manual_seed(0) + input_dim, hidden_dim, batch_size = 8, 16, 32 + + def factory(): + for _ in range(8): # 8 chunks * 32 rows = 256 tokens + yield torch.randn(batch_size, input_dim) + + dataloader = make_streaming_dataloader( + factory, batch_size=batch_size, config=StreamingConfig(enabled=True, queue_size=4) + ) + + sae = TopKSAE(input_dim=input_dim, hidden_dim=hidden_dim, top_k=4) + trainer = Trainer(sae, TrainingConfig(n_epochs=1, batch_size=batch_size, device="cpu", log_interval=1000)) + + final_loss = trainer.fit(dataloader) + assert isinstance(final_loss, float) + assert math.isfinite(final_loss) + + +def test_multiple_producers_all_consumed(): + """Multiple producer factories feed one queue; every token is consumed once.""" + + def make(ids): + def factory(): + for g in ids: + yield torch.full((1, HIDDEN_DIM), float(g)) + + return factory + + a = list(range(0, 10)) + b = list(range(100, 117)) + ds = StreamingActivationDataset([make(a), make(b)], batch_size=4, config=StreamingConfig(queue_size=2)) + + out = torch.cat(collect(ds), dim=0) + assert out.shape[0] == len(a) + len(b) + got = sorted(int(v) for v in out[:, 0].tolist()) + assert got == sorted(a + b) # union of both producers, each row exactly once + + +def test_max_steps_stops_across_streaming_epochs(): + """max_steps stops at an exact step budget, looping the producer across epochs.""" + from sae.architectures import TopKSAE + from sae.training import Trainer, TrainingConfig + + torch.manual_seed(0) + input_dim, batch_size = 8, 4 + + # 5 batches per producer pass; max_steps=12 needs ~3 passes (producer re-runs). 
+ def factory(): + for _ in range(5): + yield torch.randn(batch_size, input_dim) + + dl = make_streaming_dataloader(factory, batch_size=batch_size, config=StreamingConfig(enabled=True)) + sae = TopKSAE(input_dim=input_dim, hidden_dim=16, top_k=4) + trainer = Trainer( + sae, + TrainingConfig(n_epochs=1, batch_size=batch_size, device="cpu", max_steps=12, log_interval=10_000), + ) + trainer.fit(dl) + assert trainer.global_step == 12 diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_tp_checkpoint.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_tp_checkpoint.py new file mode 100644 index 0000000000..ac96314297 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_tp_checkpoint.py @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""B0: sharded checkpoint save -> merge round-trips to the dense model (single process).""" + +import pytest +import torch +from sae.architectures import ShardedTopKSAE, TopKSAE +from sae.parallel import load_and_merge, save_sharded + + +D, N, K, B = 16, 64, 8, 32 + + +@pytest.mark.parametrize("world", [2, 4]) +def test_shard_save_merge_roundtrip(tmp_path, world): + torch.manual_seed(0) + dense = TopKSAE(input_dim=D, hidden_dim=N, top_k=K, normalize_input=True, auxk=16) + out = str(tmp_path) + + for r in range(world): + sh = ShardedTopKSAE(D, N, K, r, world, normalize_input=True, auxk=16) + sh.load_shard_from_dense(dense) + save_sharded(sh, out, rank=r) + + merged = load_and_merge(out) + torch.testing.assert_close(merged.encoder.weight, dense.encoder.weight) + torch.testing.assert_close(merged.decoder.weight, dense.decoder.weight) + torch.testing.assert_close(merged.latent_bias, dense.latent_bias) + torch.testing.assert_close(merged.pre_bias, dense.pre_bias) + + # Merged dense model reproduces the original dense outputs. + torch.manual_seed(1) + x = torch.randn(B, D) + rd, _ = dense(x) + rm, _ = merged(x) + torch.testing.assert_close(rm, rd, atol=1e-6, rtol=1e-6) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_tp_comms.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_tp_comms.py new file mode 100644 index 0000000000..fc149cd773 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_tp_comms.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for tensor-parallel comms helpers (CPU / gloo).""" + +import pytest +import torch +from sae.parallel import all_gather_cat, all_reduce_sum, autograd_all_reduce_sum + +from ._dist_utils import run_distributed + + +def _w_all_gather(rank, world): + t = torch.tensor([rank * 10.0, rank * 10.0 + 1]) + out = all_gather_cat(t, dim=0) + expected = torch.cat([torch.tensor([r * 10.0, r * 10.0 + 1]) for r in range(world)]) + assert torch.equal(out, expected), (rank, out, expected) + return out + + +def _w_all_reduce(rank, world): + t = torch.full((3,), float(rank + 1)) + all_reduce_sum(t) + expected = float(sum(range(1, world + 1))) + assert torch.allclose(t, torch.full((3,), expected)), (rank, t) + return t + + +def _w_autograd_all_reduce(rank, world): + x = (torch.ones(2) * (rank + 1)).requires_grad_(True) + y = autograd_all_reduce_sum(x) + y.sum().backward() + total = float(sum(range(1, world + 1))) + assert torch.allclose(y.detach(), torch.full((2,), total)), (rank, y) + # d(sum_b sum_r x_r)/dx_r = 1 -> grad is ones on every rank. 
+ assert torch.allclose(x.grad, torch.ones(2)), (rank, x.grad) + return x.grad + + +@pytest.mark.parametrize("world", [2, 4]) +def test_all_gather_cat(world): + run_distributed(_w_all_gather, world) + + +@pytest.mark.parametrize("world", [2, 4]) +def test_all_reduce_sum(world): + run_distributed(_w_all_reduce, world) + + +@pytest.mark.parametrize("world", [2, 4]) +def test_autograd_all_reduce_sum(world): + run_distributed(_w_autograd_all_reduce, world) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_tp_global_topk.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_tp_global_topk.py new file mode 100644 index 0000000000..58a525d5cb --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_tp_global_topk.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for sharded global top-k (CPU / gloo). + +The sharded global top-k must exactly match a single-process torch.topk over the +full (concatenated) latent dimension, and must partition the selections so each is +owned by exactly one rank. 
+""" + +import pytest +import torch +from sae.parallel import global_topk + +from ._dist_utils import run_distributed + + +B, N, K, SEED = 16, 256, 8, 0 + + +def _full_pre_act(): + # Distinct continuous values (randn) -> no ties, so indices match exactly. + torch.manual_seed(SEED) + return torch.randn(B, N) + + +def _w_global_topk(rank, world): + full = _full_pre_act() + latents_per_rank = N // world + local = full[:, rank * latents_per_rank : (rank + 1) * latents_per_rank].contiguous() + res = global_topk(local, K, rank, latents_per_rank) + return (res.global_values, res.global_indices, res.owned_mask, res.local_indices) + + +@pytest.mark.parametrize("world", [2, 4]) +def test_global_topk_matches_dense(world): + results = run_distributed(_w_global_topk, world) + full = _full_pre_act() + dense_vals, dense_idx = torch.topk(full, K, dim=-1) + latents_per_rank = N // world + + gv0, gidx0, _, _ = results[0] + torch.testing.assert_close(gv0, dense_vals) + assert torch.equal(gidx0, dense_idx), (gidx0, dense_idx) + + # Global selection is replicated identically on every rank. + for r in range(1, world): + assert torch.equal(results[r][1], gidx0) + + # Ownership partitions the selections: exactly one rank owns each (b, j). + owned_stack = torch.stack([results[r][2].int() for r in range(world)], dim=0) # [world, B, K] + assert torch.equal(owned_stack.sum(0), torch.ones(B, K, dtype=torch.int32)) + + # Owned local indices map back to the global indices. 
+ for r in range(world): + _, gidx, owned, lidx = results[r] + recon_global = lidx + r * latents_per_rank + assert torch.equal(recon_global[owned], gidx[owned]) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_tp_init.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_tp_init.py new file mode 100644 index 0000000000..d846983eeb --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_tp_init.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Sharded init_pre_bias_from_data matches dense (single process; pre_bias is replicated).""" + +import pytest +import torch +from sae.architectures import ShardedTopKSAE, TopKSAE + + +@pytest.mark.parametrize("normalize", [True, False]) +def test_init_pre_bias_matches_dense(normalize): + torch.manual_seed(0) + data = torch.randn(500, 16) + + dense = TopKSAE(input_dim=16, hidden_dim=64, top_k=8, normalize_input=normalize) + dense.init_pre_bias_from_data(data) + + sh = ShardedTopKSAE(16, 64, 8, rank=0, world_size=2, normalize_input=normalize) + sh.init_pre_bias_from_data(data) + + torch.testing.assert_close(sh.pre_bias, dense.pre_bias) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_tp_loss.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_tp_loss.py new file mode 100644 index 0000000000..f1bacc7b58 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_tp_loss.py @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Parity tests for ShardedTopKSAE.loss() incl. metrics, dead_pct, and auxk (CPU/gloo). + +dead_tokens_threshold=0 forces dead latents after the first step so dead_pct and the +auxk path are actually exercised. The sharded loss dict must match the dense one. 
+""" + +import pytest +import torch +from sae.architectures import TopKSAE +from sae.architectures.topk_tp import ShardedTopKSAE + +from ._dist_utils import run_distributed + + +D, N, K, B = 16, 64, 8, 32 + + +def _build_dense(normalize, auxk): + torch.manual_seed(0) + dense = TopKSAE( + input_dim=D, + hidden_dim=N, + top_k=K, + normalize_input=normalize, + auxk=auxk, + dead_tokens_threshold=0, + ) + torch.manual_seed(1) + x = torch.randn(B, D) + return dense, x + + +def _w_loss(rank, world, normalize, auxk): + dense, x = _build_dense(normalize, auxk) + sh = ShardedTopKSAE( + D, + N, + K, + rank, + world, + normalize_input=normalize, + auxk=auxk, + dead_tokens_threshold=0, + decoder_impl="dense", + ) + sh.load_shard_from_dense(dense) + sh.loss(x) # prime dead-latent stats + out = sh.loss(x) + return {k: v.detach() for k, v in out.items()} + + +@pytest.mark.parametrize("world", [2, 4]) +@pytest.mark.parametrize("normalize,auxk", [(True, None), (False, None), (True, 16)]) +def test_sharded_loss_matches_dense(world, normalize, auxk): + res = run_distributed(_w_loss, world, args=(normalize, auxk)) + + dense, x = _build_dense(normalize, auxk) + dense.loss(x) # prime + out_d = dense.loss(x) + tol = {"atol": 1e-5, "rtol": 1e-5} + + keys = ["total", "fvu", "sparsity", "mse", "variance_explained", "dead_pct"] + if auxk is not None: + keys.append("aux") + for key in keys: + torch.testing.assert_close(res[0][key], out_d[key], msg=lambda m, k=key: f"{k}: {m}", **tol) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_tp_topk.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_tp_topk.py new file mode 100644 index 0000000000..72bf543305 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_tp_topk.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Parity tests: ShardedTopKSAE == dense TopKSAE (CPU / gloo). + +Shard a dense TopKSAE's weights across ranks and assert the sharded forward +reconstruction and all parameter gradients match the dense model. +""" + +import pytest +import torch +import torch.nn.functional as F +from sae.architectures import TopKSAE +from sae.architectures.topk_tp import ShardedTopKSAE + +from ._dist_utils import run_distributed + + +D, N, K, B = 16, 64, 8, 32 + + +def _build_dense(normalize): + torch.manual_seed(0) + dense = TopKSAE(input_dim=D, hidden_dim=N, top_k=K, normalize_input=normalize) + torch.manual_seed(1) + x = torch.randn(B, D) + return dense, x + + +def _w_parity(rank, world, normalize): + dense, x = _build_dense(normalize) + sh = ShardedTopKSAE(D, N, K, rank, world, normalize_input=normalize, decoder_impl="dense") + sh.load_shard_from_dense(dense) + recon, _ = sh(x) + F.mse_loss(recon, x).backward() + return { + "recon": recon.detach(), + "W_enc": sh.W_enc_local.grad.detach(), + "W_dec": sh.W_dec_local.grad.detach(), + "lb": sh.latent_bias_local.grad.detach(), + "pre_bias": sh.pre_bias.grad.detach(), + } + + +@pytest.mark.parametrize("world", [2, 4]) +@pytest.mark.parametrize("normalize", [True, False]) +def test_sharded_matches_dense(world, normalize): + res = run_distributed(_w_parity, world, args=(normalize,)) + + dense, x = _build_dense(normalize) + recon_d, _ = dense(x) 
+ F.mse_loss(recon_d, x).backward() + L = N // world + tol = {"atol": 1e-5, "rtol": 1e-5} + + # Forward reconstruction (replicated; check rank 0). + torch.testing.assert_close(res[0]["recon"], recon_d, **tol) + + # Sharded parameter grads == the corresponding dense slices. + for r in range(world): + torch.testing.assert_close(res[r]["W_enc"], dense.encoder.weight.grad[r * L : (r + 1) * L, :], **tol) + torch.testing.assert_close(res[r]["W_dec"], dense.decoder.weight.grad[:, r * L : (r + 1) * L], **tol) + torch.testing.assert_close(res[r]["lb"], dense.latent_bias.grad[r * L : (r + 1) * L], **tol) + + # Replicated pre_bias grad: sum across ranks == dense (the pre_bias/world_size trick). + pre_bias_sum = sum(res[r]["pre_bias"] for r in range(world)) + torch.testing.assert_close(pre_bias_sum, dense.pre_bias.grad, **tol) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_tp_train_step.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_tp_train_step.py new file mode 100644 index 0000000000..c89ad9d0f0 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/tests/test_tp_train_step.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""B1: one optimizer step on the sharded model matches the dense model (CPU / gloo). 
+ +Validates that sharded grads + the replicated-pre_bias all-reduce produce the same +parameter update as training the dense TopKSAE. +""" + +import pytest +import torch +from sae.architectures import ShardedTopKSAE, TopKSAE + +from ._dist_utils import run_distributed + + +D, N, K, B = 16, 64, 8, 32 + + +def _build_dense(normalize): + torch.manual_seed(0) + dense = TopKSAE(input_dim=D, hidden_dim=N, top_k=K, normalize_input=normalize) + torch.manual_seed(1) + x = torch.randn(B, D) + return dense, x + + +def _w_step(rank, world, normalize): + dense, x = _build_dense(normalize) + sh = ShardedTopKSAE(D, N, K, rank, world, normalize_input=normalize, decoder_impl="dense") + sh.load_shard_from_dense(dense) + opt = torch.optim.Adam(sh.parameters(), lr=1e-3) + sh.loss(x)["total"].backward() + sh.reduce_replicated_grads() + opt.step() + return { + "W_enc": sh.W_enc_local.detach(), + "W_dec": sh.W_dec_local.detach(), + "lb": sh.latent_bias_local.detach(), + "pre_bias": sh.pre_bias.detach(), + } + + +@pytest.mark.parametrize("world", [2, 4]) +@pytest.mark.parametrize("normalize", [True, False]) +def test_one_step_matches_dense(world, normalize): + res = run_distributed(_w_step, world, args=(normalize,)) + + dense, x = _build_dense(normalize) + opt = torch.optim.Adam(dense.parameters(), lr=1e-3) + dense.loss(x)["total"].backward() + opt.step() + + L = N // world + tol = {"atol": 1e-5, "rtol": 1e-5} + for r in range(world): + torch.testing.assert_close(res[r]["W_enc"], dense.encoder.weight[r * L : (r + 1) * L, :], **tol) + torch.testing.assert_close(res[r]["W_dec"], dense.decoder.weight[:, r * L : (r + 1) * L], **tol) + torch.testing.assert_close(res[r]["lb"], dense.latent_bias[r * L : (r + 1) * L], **tol) + torch.testing.assert_close(res[r]["pre_bias"], dense.pre_bias, **tol) # replicated