Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
save_cluster_labels,
save_feature_atlas,
)
from .architectures import MoESAE, ReLUSAE, SparseAutoencoder, TopKSAE
from .architectures import MoESAE, ReLUSAE, ShardedTopKSAE, SparseAutoencoder, TopKSAE
from .autointerp import (
DEFAULT_PROMPT_TEMPLATE,
TOKEN_PROMPT_TEMPLATE,
Expand Down Expand Up @@ -80,8 +80,10 @@
evaluate_sae,
evaluate_sparsity,
)
from .kernels import HAS_TRITON, TritonDecoderAutograd
from .perf_logger import PerfLogger
from .process_group_manager import ProcessGroupManager
from .streaming import StreamingActivationDataset, StreamingConfig, make_streaming_dataloader
from .training import ParallelConfig, Trainer, TrainingConfig, WandbConfig
from .utils import get_device, set_seed

Expand All @@ -90,6 +92,7 @@

__all__ = [
"DEFAULT_PROMPT_TEMPLATE",
"HAS_TRITON",
"TOKEN_PROMPT_TEMPLATE",
"ActivationStore",
"ActivationStoreConfig",
Expand Down Expand Up @@ -117,14 +120,18 @@
"PerfLogger",
"ProcessGroupManager",
"ReLUSAE",
"ShardedTopKSAE",
"SparseAutoencoder",
"SparsityMetrics",
"StreamingActivationDataset",
"StreamingConfig",
"TokenActivationCollector",
"TokenExample",
"TopExample",
"TopKSAE",
"Trainer",
"TrainingConfig",
"TritonDecoderAutograd",
"WandbConfig",
"build_cluster_label_prompt",
"compute_cluster_centroids",
Expand All @@ -140,6 +147,7 @@
"get_device",
"launch_dashboard",
"load_activations",
"make_streaming_dataloader",
"save_activations",
"save_cluster_labels",
"save_feature_atlas",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
from .moe import MoESAE
from .relu_l1 import ReLUSAE
from .topk import TopKSAE
from .topk_tp import ShardedTopKSAE


__all__ = [
"MoESAE",
"ReLUSAE",
"ShardedTopKSAE",
"SparseAutoencoder",
"TopKSAE",
]
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,26 @@ def __init__(
dead_tokens_threshold: int = 10_000_000,
init_encoder_from_decoder: bool = True,
init_pre_bias: bool = True,
decoder_impl: str = "dense",
):
"""Initialize the Top-K SAE with encoder, decoder, and optional auxiliary loss."""
"""Initialize the Top-K SAE with encoder, decoder, and optional auxiliary loss.

``decoder_impl`` selects the decode path: "dense" (default) builds the dense
[batch, hidden_dim] code tensor and runs a full decoder matmul; "triton"
decodes directly from the top-k (indices, values) via a sparse kernel
(O(batch*k*d), no dense code tensor), enabling much larger hidden_dim. Weights
are identical, so checkpoints are interchangeable between the two.
"""
super().__init__(input_dim, hidden_dim)
self.top_k = top_k
self.init_pre_bias = init_pre_bias
self.normalize_input = normalize_input
self.auxk = auxk
self.auxk_coef = auxk_coef
self.dead_tokens_threshold = dead_tokens_threshold
if decoder_impl not in ("dense", "triton"):
raise ValueError(f"decoder_impl must be 'dense' or 'triton', got {decoder_impl!r}")
self.decoder_impl = decoder_impl

# Pre-bias (subtracted from normalized input, added to output before denorm)
self.pre_bias = nn.Parameter(torch.zeros(input_dim))
Expand Down Expand Up @@ -208,9 +219,40 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
top_k_vals, top_k_indices = torch.topk(codes_relu, self.top_k, dim=-1)
codes = torch.zeros_like(codes_relu).scatter(-1, top_k_indices, top_k_vals)

recon = self.decode(codes, info)
if self.decoder_impl == "triton":
recon = self._decode_topk_triton(top_k_vals, top_k_indices, info)
else:
recon = self.decode(codes, info)
return recon, codes

def _decode_topk_triton(
self,
top_k_vals: torch.Tensor,
top_k_indices: torch.Tensor,
info: Optional[Dict[str, torch.Tensor]] = None,
denormalize: bool = True,
) -> torch.Tensor:
"""Decode from top-k (values, indices) via the sparse Triton kernel.

Returns reconstruction with pre_bias added; denormalized to input scale when
``denormalize`` (set False to get the normalized-space recon for aux loss).
"""
from ..kernels import TritonDecoderAutograd

recon = TritonDecoderAutograd.apply(top_k_indices.contiguous(), top_k_vals.contiguous(), self.decoder.weight)
recon = recon + self.pre_bias
if denormalize and self.normalize_input and info is not None:
recon = self._denormalize(recon, info)
return recon

def _update_dead_latent_stats_from_indices(self, top_k_indices: torch.Tensor, n_tokens: int) -> None:
"""Update stats_last_nonzero from top-k indices (no dense [batch, hidden] tensor)."""
active_mask = torch.zeros_like(self.stats_last_nonzero, dtype=torch.bool)
active_mask[top_k_indices.reshape(-1)] = True
self.stats_last_nonzero = torch.where(
active_mask, torch.zeros_like(self.stats_last_nonzero), self.stats_last_nonzero + n_tokens
)

def forward_with_aux(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
"""Forward pass with auxiliary info for auxk loss computation.

Expand Down Expand Up @@ -257,8 +299,9 @@ def _compute_auxk_loss(
x: torch.Tensor,
recon: torch.Tensor,
pre_act: torch.Tensor,
codes: torch.Tensor,
codes: Optional[torch.Tensor],
norm_info: Optional[Dict[str, torch.Tensor]] = None,
recon_norm: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Compute auxiliary loss for dead latents.

Expand Down Expand Up @@ -293,8 +336,10 @@ def _compute_auxk_loss(
if self.normalize_input and norm_info is not None:
# Normalize x to match the space where encoding happened
x_norm = (x - norm_info["mu"]) / norm_info["std"]
# Reuse codes from forward pass instead of re-encoding
recon_norm = self.decoder(codes) + self.pre_bias
# Reuse codes from forward pass instead of re-encoding (or a precomputed
# normalized recon, e.g. from the sparse/triton decode path).
if recon_norm is None:
recon_norm = self.decoder(codes) + self.pre_bias
residual = x_norm - recon_norm.detach()
else:
residual = x - recon.detach() + self.pre_bias.detach()
Expand Down Expand Up @@ -375,6 +420,9 @@ def loss(self, x: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
- aux (if auxk enabled): auxiliary loss value
- dead_pct (if auxk enabled): percentage of dead latents
"""
if self.decoder_impl == "triton":
return self._loss_triton(x)

# Forward pass with auxiliary info
info = self.forward_with_aux(x)
recon = info["recon"]
Expand Down Expand Up @@ -422,3 +470,53 @@ def loss(self, x: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
result["aux"] = aux_loss

return result

def _loss_triton(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
"""loss() using the sparse Triton decoder.

Numerically equivalent to the dense loss() but never materializes the dense
[batch, hidden_dim] code tensor or runs the full decoder matmul: it decodes
from the top-k (values, indices) and derives dead-latent stats / L0 from the
indices. This is what lets hidden_dim scale to ~1M+.
"""
pre_act, info = self.encode_pre_act(x)
codes_relu = torch.relu(pre_act)
top_k_vals, top_k_indices = torch.topk(codes_relu, self.top_k, dim=-1)

# Sparse decode in normalized space (pre_bias added); denormalize for the main loss.
recon_norm = self._decode_topk_triton(top_k_vals, top_k_indices, info, denormalize=False)
recon = self._denormalize(recon_norm, info) if (self.normalize_input and info) else recon_norm

# Dead-latent stats from indices (no dense codes tensor).
self._update_dead_latent_stats_from_indices(top_k_indices, x.shape[0])

# Primary reconstruction loss (FVU), centered by pre_bias -- matches dense loss().
mse = (recon - x).pow(2).mean(dim=-1)
x_var = (x - self.pre_bias).pow(2).mean(dim=-1)
recon_loss = (mse / (x_var + 1e-8)).mean()

# For TopK, L0 == count of nonzero top-k values.
l0 = (top_k_vals != 0).float().sum(dim=-1).mean()

with torch.no_grad():
raw_mse = (recon - x).pow(2).mean()
total_var = torch.var(x, dim=0).sum()
residual_var = torch.var(recon - x, dim=0).sum()
var_explained = 1.0 - (residual_var / (total_var + 1e-8))

result = {
"total": recon_loss,
"fvu": 1.0 - var_explained,
"sparsity": l0,
"mse": raw_mse,
"variance_explained": var_explained,
}
dead_pct = (self.stats_last_nonzero > self.dead_tokens_threshold).float().mean() * 100
result["dead_pct"] = dead_pct

if self.auxk is not None:
aux_loss = self._compute_auxk_loss(x, recon, pre_act, codes=None, norm_info=info, recon_norm=recon_norm)
result["total"] = recon_loss + self.auxk_coef * aux_loss
result["aux"] = aux_loss

return result
Loading
Loading