[feat] SID generation — RQ-VAE and RQ-KMeans models (FAISS-trained) #517

Open
WhiteSwan1 wants to merge 41 commits into alibaba:master from WhiteSwan1:remove_ema_2

Conversation

@WhiteSwan1
Collaborator

Summary

Introduces two new SID-generation models to tzrec plus the supporting tzrec.modules.sid_generation module stack:

  • SidRqvae — end-to-end RQ-VAE (Encoder + Residual VQ + Decoder), trainable via STE or Gumbel-Softmax, with optional CLIP contrastive learning for paired multimodal inputs.
  • SidRqkmeans — encoder-free residual K-Means, trained in one shot by FAISS at the end of the train_eval loop (gradient-free).

Both produce N-layer semantic IDs (code_0, …, code_{n_layers-1}) that downstream generative-retrieval models can consume.
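
To make the N-layer semantic ID concrete, here is a minimal plain-Python sketch of residual assignment: each layer picks the nearest centroid for whatever the previous layers left unexplained. The codebook values and the helper names (`nearest`, `residual_codes`) are illustrative assumptions, not code from this PR, and squared L2 distance is assumed.

```python
from typing import List, Sequence

Vector = Sequence[float]


def nearest(codebook: Sequence[Vector], x: Vector) -> int:
    """Index of the centroid with the smallest squared L2 distance to x."""
    dists = [sum((a - b) ** 2 for a, b in zip(c, x)) for c in codebook]
    return dists.index(min(dists))


def residual_codes(codebooks: Sequence[Sequence[Vector]], x: Vector) -> List[int]:
    """Assign one code per layer, quantizing the residual left by earlier layers."""
    residual = list(x)
    codes = []
    for codebook in codebooks:
        idx = nearest(codebook, residual)
        codes.append(idx)
        # Subtract the chosen centroid so the next layer sees only the remainder.
        residual = [r - c for r, c in zip(residual, codebook[idx])]
    return codes


# Two hypothetical layers with two centroids each (toy values).
codebooks = [
    [[0.0, 0.0], [10.0, 10.0]],  # coarse layer
    [[-1.0, -1.0], [1.0, 1.0]],  # fine layer, fitted on residuals
]
sid = residual_codes(codebooks, [11.0, 11.2])  # (code_0, code_1) for one item
```

The resulting list plays the role of (code_0, …, code_{n_layers-1}) for a single item.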

What's in scope

  • Models (tzrec/models/):

    • sid_rqvae.py + sid_rqvae_test.py: SidRqvae(BaseModel), no-clip + mixed clip+recon paths.
    • sid_rqkmeans.py + sid_rqkmeans_test.py: SidRqkmeans(BaseModel), FAISS-only training.
    • _sid_helpers.py — shared parse_int_list / parse_float_list for SID proto fields.
    • model.py — adds BaseModel.on_train_end() no-op lifecycle hook (overridden by SidRqkmeans for the FAISS fit).
  • Modules (tzrec/modules/sid_generation/):

    • rqvae.py: RQVAE (encoder/decoder MLP + quantizer), supports use_clip mixed-mode.
    • residual_quantized.py — gradient-trained multi-layer residual VQ used by RQ-VAE.
    • residual_kmeans.py: ResidualKMeans + RQKMeans wrapper (FAISS-trained centroid store).
    • kmeans.py: KMeansLayer centroid container + KMeans / KMeans++ utilities.
    • vector_quantize.py — single VQ layer with Sinkhorn uniform assignment, STE / Gumbel-Softmax forward.
    • clip_loss.py: CLIPLoss and MaskedCLIPLoss (gradient-preserving GatherLayer for DDP all-gather, three-way contrastive loss).
    • types.py: QuantizeForwardMode, QuantizeOutput, ResidualQuantizedOutput.
  • Protos (tzrec/protos/models/sid_model.proto + tzrec/protos/model.proto):

    • SidRqvae with sub-messages SinkhornConfig and ClipConfig.
    • SidRqkmeans (codebook + normalize_residuals + faiss_kmeans_kwargs Struct).
    • Both registered under ModelConfig.
  • Wiring (tzrec/main.py):

    • Two-line addition: _model.on_train_end() is called once after train_and_evaluate exits the loop. No-op for every model except SidRqkmeans, which uses the hook to fit FAISS over the buffered embeddings.
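
The hook pattern described in the wiring note can be sketched as follows. This is a toy stand-in under stated assumptions: the class names, `train_step`, and the `fitted` attribute are hypothetical, and the real tzrec BaseModel has a much larger interface.

```python
from typing import Iterable, List


class BaseModelSketch:
    """Stand-in for a base model; only the lifecycle hook is modeled."""

    def train_step(self, batch: List[float]) -> None:
        """Default training step: nothing to record."""

    def on_train_end(self) -> None:
        """Default hook: nothing to do once the training loop exits."""


class DeferredFitModel(BaseModelSketch):
    """Hypothetical model that buffers samples and fits once at the end."""

    def __init__(self) -> None:
        self.buffer: List[float] = []
        self.fitted = False

    def train_step(self, batch: List[float]) -> None:
        self.buffer.extend(batch)  # collect samples; no codebook exists yet

    def on_train_end(self) -> None:
        # A real model would run its one-shot fit here (FAISS in this PR).
        self.fitted = bool(self.buffer)
        self.buffer.clear()


def train_and_evaluate(model: BaseModelSketch, batches: Iterable[List[float]]) -> None:
    for batch in batches:
        model.train_step(batch)
    model.on_train_end()  # unconditional call, no hasattr check needed


deferred = DeferredFitModel()
train_and_evaluate(deferred, [[1.0, 2.0], [3.0]])
plain = BaseModelSketch()
train_and_evaluate(plain, [[1.0]])  # the no-op default is safe to call
```

Because the default is a no-op on the base class, the training loop does not need to know which models defer work to the end.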

Test plan

  • Unit tests: python -m unittest tzrec.models.sid_rqkmeans_test tzrec.models.sid_rqvae_test passes 11/11 (5 RQKMeans, 6 RQVAE; covers train / eval / inference modes, CLIP all-recon and all-clip edge cases, backward pass).
  • Lifecycle smoke: python ft_scripts/lifecycle_smoke.py drives construct → init_loss → init_metric → train (predict/loss/backward/update_train_metric) → on_train_end → eval (predict/loss/update_metric/compute_metric) → set_is_inference(True) → infer, for SidRqkmeans, SidRqvae no-clip, and SidRqvae CLIP. All three pass.
  • Backward-compat — existing model configs in master are unaffected; the two new entries are additions to ModelConfig's oneof.

WhiteSwan1 and others added 25 commits May 20, 2026 02:52
The online Mini-Batch K-Means backend is removed. ResidualKMeans /
RQKMeans / SidRqkmeans now always train the codebook once via FAISS
at end-of-loop through `flush_offline_fit`; `forward` is read-only
(assignment + centroid lookup).

  - MiniBatchKMeans: stripped to a centroid container (load_centroids_
    + predict). train_step, cluster_counts, init_buffer, offline-lock,
    unlock_for_online_finetune_ are gone.
  - ResidualKMeans / RQKMeans: drop the train_mode and init_buffer_size
    parameters; train_offline is the only training entry point.
  - SidRqkmeans: predict always buffers during training and runs
    assignment during eval; legacy `train_mode` proto field is silently
    accepted (default "online") but a warning fires if it is explicitly
    set to a non-FAISS value.
  - Tests trimmed to FAISS-only paths; smoke test
    ft_scripts/maxcompute_smoke_test.py runs RQKMeans + RQVAE over
    forge_contrastive_item_embedding.csv. FAISS codes and RQ-VAE
    training trajectory match the pre-refactor code_style baseline
    bit-for-bit; full suite is 25/25 green.

Diff stat: -577 / +135 lines.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- SidRqkmeans proto: remove the legacy ``init_buffer_size`` (tag 5)
  and ``train_mode`` (tag 6) fields. Their numbers and names are
  ``reserved`` so they can never be silently reused by future fields.
  Regenerate sid_model_pb2.{py,pyi} with scripts/gen_proto.sh (or
  ``python -m grpc_tools.protoc -I . tzrec/protos/models/sid_model.proto
  --python_out=. --pyi_out=.``).
- SidRqkmeans model: drop the deprecation guard now that the field is
  gone; drop the corresponding legacy test.
- Rename ``MiniBatchKMeans`` → ``KMeansLayer``. After dropping the
  online path the class is a per-layer centroid container (centroids
  buffer + load_centroids_ + predict + an is_initialized guard), and
  the old name was misleading. PyTorch state-dict keys are scoped by
  attribute path (``layers.<i>.centroids``), so this rename does not
  affect checkpoint compatibility.

Verified: 24/24 unit tests pass; the FAISS smoke test on
forge_contrastive_item_embedding.csv produces bit-identical codes and
RQ-VAE losses vs the pre-rename baseline.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…rgs to tag 5

The reservations were defensive against textproto / binary-proto mixups
on the dropped ``init_buffer_size`` (tag 5) and ``train_mode`` (tag 6).
tzrec configs are textproto-only, which key fields by name, not tag —
old configs already fail to parse as "unknown field" regardless of the
tag table — so the reservation block buys no real safety here. Drop it
and tighten the schema to a contiguous 1/3/4/5/40 tag set.

Regenerate sid_model_pb2.{py,pyi} via scripts/gen_proto.sh.

Verified: 24/24 unit tests pass; FAISS smoke test on
forge_contrastive_item_embedding.csv produces bit-identical codes and
RQ-VAE losses; both demo .config files parse cleanly.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Move ``_kmeans`` and ``_residual_kmeans`` (Lloyd's K-Means + residual
variant used by ResidualQuantized.init_embed_) from residual_quantized.py
into kmeans.py, next to ``_kmeans_plus_plus`` / ``_squared_euclidean_distance``
and the ``KMeansLayer`` centroid container. After this move:

  - kmeans.py is the single home for torch-native K-Means code in the
    SID-generation stack — distance helper, KMeans++ seeding, Lloyd's
    iterator, residual variant, and the per-layer centroid container.
  - residual_quantized.py is a quantizer module; it now imports the
    single function it needs (``_residual_kmeans``) instead of pulling
    in two private helpers and re-implementing Lloyd's on top.
  - No functional change. RQ-VAE warm-start remains pure torch / GPU /
    dependency-free — FAISS was not introduced here on purpose: the
    init runs once on ~2k × 64 encoder outputs, where FAISS would add
    a CPU round-trip + a hard dep for no measurable win.
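
For readers unfamiliar with the residual variant, the idea behind Lloyd's iterations plus per-layer residual fitting can be sketched in plain Python. This is an illustration only: it seeds with the first k points rather than KMeans++, uses lists instead of torch tensors, and the function names are hypothetical, not the private helpers moved by this commit.

```python
from typing import List, Sequence

Vector = List[float]


def _sq_dist(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def lloyd_kmeans(points: List[Vector], k: int, n_iter: int = 10) -> List[Vector]:
    """Plain Lloyd's iterations; seeds with the first k points (not KMeans++)."""
    centroids = [list(p) for p in points[:k]]
    for _ in range(n_iter):
        buckets: List[List[Vector]] = [[] for _ in range(k)]
        for p in points:
            j = min(range(k), key=lambda c: _sq_dist(p, centroids[c]))
            buckets[j].append(p)
        for j, bucket in enumerate(buckets):
            if bucket:  # keep the old centroid when a cluster is empty
                centroids[j] = [sum(col) / len(bucket) for col in zip(*bucket)]
    return centroids


def residual_kmeans(
    points: List[Vector], k: int, n_layers: int, n_iter: int = 10
) -> List[List[Vector]]:
    """Fit one codebook per layer on the residual left by the previous layers."""
    residuals = [list(p) for p in points]
    codebooks = []
    for _ in range(n_layers):
        centroids = lloyd_kmeans(residuals, k, n_iter)
        codebooks.append(centroids)
        residuals = [
            [r - c for r, c in zip(p, min(centroids, key=lambda c: _sq_dist(p, c)))]
            for p in residuals
        ]
    return codebooks


data = [[0.0, 0.0], [10.0, 12.0], [0.0, 2.0], [10.0, 10.0]]
books = residual_kmeans(data, k=2, n_layers=2)
```

On this toy data the first layer separates the two clusters and the second layer captures the remaining offset within each cluster.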

Verified: 24/24 unit tests pass; FAISS smoke test on
forge_contrastive_item_embedding.csv produces bit-identical codes and
RQ-VAE losses.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add a no-op ``on_train_end`` method on ``BaseModel`` and override it
in ``SidRqkmeans`` (renamed from ``flush_offline_fit``). Subclasses that
need end-of-loop work override it; the rest get the no-op default.

``tzrec.main.train_and_evaluate`` now calls ``_model.on_train_end()``
unconditionally — no more ``hasattr`` duck-typing of a SID-specific
method on every model.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
``reference_code`` was a teacher-codebook-injection hook ported from
upstream al_sid (al_sid/SID_generation/rqvae_embed/{rqvae,quantizations}.py).
It was a defensive pass-through in both projects: the parameter
threads ``RQVAE → ResidualQuantized → VectorQuantize`` and, when set,
overrides ~4 % of distance-assigned ids with the supplied teacher
codes (vector_quantize.py:407-411). Nothing in tzrec or al_sid ever
supplies a non-None value — no caller, no config knob, no test.

Drop:
  - the ``reference_code`` kwarg from ``VectorQuantize.forward`` and
    its Step 3 random-replacement block;
  - the ``reference_code`` kwarg from ``ResidualQuantized.forward``
    and the per-layer slicing that fed it into VectorQuantize;
  - docstrings + step numbering updated accordingly.

Verified: 11/11 surviving SID model tests pass; FAISS smoke test on
forge_contrastive_item_embedding.csv produces bit-identical codes and
RQ-VAE losses.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Two-part change driven by an end-to-end lifecycle audit (see new
ft_scripts/lifecycle_smoke.py).

Comment pass:
  - sid_rqkmeans.py / sid_rqvae.py / rqvae.py / residual_quantized.py /
    residual_kmeans.py: drop comments that only narrate the next line
    ("# Eval metrics", "# Reconstruction MSE", "# Unique SID ratio",
    "# Parse n_embed list", etc.). Keep every comment whose value is
    *why* (gradient / dtype / memory rationale, DDP invariants, paper
    references, step numbering tied to docstrings).
  - rqvae.py: drop the duplicate ``self.use_clip = use_clip`` set inside
    the ``if use_clip:`` block — the same assignment runs unconditionally
    a few lines earlier.

SidRqkmeans bug fix:
  - During training, predict() returns dummy zero codes (the codebook
    does not exist until on_train_end() fits FAISS). The previous
    init_metric / update_train_metric registered + updated mse,
    rel_loss, unique_sid_ratio for the train path, which produced
    ``MeanMetric.compute() == nan`` for mse / rel_loss (never updated)
    and a trivially constant ``1/B`` for unique_sid_ratio.
  - Fix: drop the train-metric registrations and turn
    update_train_metric into a documented no-op. compute_train_metric
    now returns an empty dict, which ``tzrec.main`` already tolerates
    via ``if train_metrics:``.

Lifecycle driver:
  - ft_scripts/lifecycle_smoke.py exercises construct → init_loss →
    init_metric → train (predict/loss/backward/update_train_metric) →
    on_train_end → eval (predict/loss/update_metric/compute_metric) →
    set_is_inference(True) → infer for SidRqkmeans, SidRqvae(no-clip),
    SidRqvae(clip). Asserts shapes, codes ranges, and that no metric
    comes back non-finite. All three pass.

Verified:
  - 11/11 existing unit tests pass.
  - lifecycle_smoke.py: ALL LIFECYCLES OK. RQ-VAE (no-clip) loss
    1.63 -> 1.49 over 8 Adam steps; RQ-VAE (clip) 15.17 -> 9.33.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The scripts under ft_scripts/ are personal smoke / one-off experiment
drivers that have always been treated as local-only by everyone except
the three files that happened to get committed along the way
(lifecycle_smoke.py, maxcompute_smoke_test.py, sid_rqvae_recon.config).
Stop tracking those three and add ft_scripts/ to .gitignore so future
edits stay local.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Translate the Chinese comments in EmaConfig, SinkhornConfig, ClipConfig,
and SidRqvae (network structure / quantization strategy / optional
sub-module sections) to English. No semantic changes; defaults, field
numbers, and field types are untouched.

Regenerate sid_model_pb2.{py,pyi} via scripts/gen_proto.sh (gitignored).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
ft_scripts/ configs use neither EMA (no ``ema_config`` block) nor the
``restart_unused_codes`` knob, and the fused ``ema_mask`` pass-through
chain has no consumer. Strip the whole feature surface end-to-end.

Proto (sid_model.proto, regen pb2 via grpc_tools.protoc):
  - drop ``message EmaConfig`` and ``SidRqkmeans.ema_config = 14``.
    Field 14 is left unallocated (no reserved block, consistent with
    the previous proto tightening).

Code:
  - SidRqvae: drop the EMA sub-config parse + the three kwargs
    (``use_ema``, ``ema_decay``, ``restart_unused_codes``) it forwarded.
  - RQVAE: drop the same three ctor args + docs + propagation.
  - ResidualQuantized: drop the three ctor args, ``self.use_ema``,
    the ``ema_mask`` parameter on ``forward``, and the
    EMA-zeros-loss2 branch in ``_single_commitment_loss``. The
    commitment loss now always sums both directions (mathematically
    equivalent to running today with ``use_ema=False`` explicitly).
  - VectorQuantize: drop ``use_ema`` / ``ema_decay`` /
    ``restart_unused_codes`` / ``ema_mask``, the
    ``cluster_size_ema`` / ``embed_ema`` buffers,
    ``_update_ema_buffers``, ``_update_embedding_from_ema``,
    ``_tile_with_noise``, and the dead-code-restart machinery.
    Commitment loss collapses to the single
    ``commitment_weight * e_latent_loss + q_latent_loss`` formula.
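
Written out for the squared-L2 case, the collapsed formula might look like the following. This assumes that e_latent_loss is the encoder-side term and q_latent_loss the codebook-side term (the usual VQ-VAE naming; the commit does not spell this out). Here β stands for commitment_weight, z_e for the encoder output, z_q for the selected codebook vector, and sg for the stop-gradient operator.

```latex
\mathcal{L}_{\text{commit}}
  = \beta \,\bigl\lVert z_e - \operatorname{sg}[z_q] \bigr\rVert_2^2
  + \bigl\lVert \operatorname{sg}[z_e] - z_q \bigr\rVert_2^2
```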

Side effects worth flagging:
  - ``ema_mask`` was the mechanism by which mixed-mode CLIP path-2
    skipped EMA updates for recon rows. Without EMA those rows now
    contribute to the commitment-loss gradient through both paths (the
    recon-row contribution shows up in both quant1 and quant2 because
    fea2 == fea1 there). The masked CLIP loss itself is unchanged.
  - Old SidRqvae checkpoints carrying ``cluster_size_ema`` /
    ``embed_ema`` buffers will need ``strict=False`` to load.

Verified: 11/11 unit tests pass; lifecycle_smoke.py — all three
scenarios pass; RQ-VAE no-clip 1.44 -> 1.34, RQ-VAE clip 13.14 -> 11.20
over 8 Adam steps.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@CLAassistant

CLAassistant commented May 20, 2026

CLA assistant check
All committers have signed the CLA.

… loop index

  - sid_rqvae_test.py: drop the ``losses =`` assignments in the
    all-recon / all-clip CLIP edge-case tests (the predictions dict
    carries the asserted values; loss() is kept as a bare call to
    ensure it doesn't raise).
  - clip_loss.py::GatherLayer.forward / backward: add one-line
    docstrings.
  - residual_quantized.py::ResidualQuantized.forward: drop the unused
    ``i`` enumerate index in the per-layer loop.

Verified: 11/11 SID unit tests pass; lifecycle_smoke.py green for all
three scenarios (RQKMeans, RQVAE no-clip, RQVAE CLIP).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
WhiteSwan1 and others added 3 commits May 22, 2026 07:34
torch.unique(codes, dim=0).shape[0] forces a GPU->host sync every
training step because reading .shape[0] of a variable-length unique
output materializes it. unique_sid_ratio is a codebook-coverage
diagnostic, not a training signal — keep it on the eval path only,
where it runs at lower frequency and the cost is amortized at metric
flush.

Train path now logs only mse; eval path is unchanged.

(SidRqkmeans's train-path update_train_metric is already a no-op, so
this commit only touches SidRqvae.)
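
As a rough plain-Python analogue of the metric discussed here, the sketch below computes the share of distinct code tuples in a batch. The definition (unique rows divided by batch size) is an assumption inferred from the torch.unique expression above; the function name and sample batches are illustrative.

```python
from typing import Sequence


def unique_sid_ratio(codes: Sequence[Sequence[int]]) -> float:
    """Fraction of distinct code tuples in a batch of shape (B, n_layers)."""
    if not codes:
        return 0.0
    return len({tuple(row) for row in codes}) / len(codes)


batch = [[3, 1, 4], [3, 1, 4], [2, 7, 1], [0, 0, 5]]
dummy = [[0, 0, 0]] * 4  # all-zero placeholder codes, as returned before a fit

ratio_real = unique_sid_ratio(batch)
ratio_dummy = unique_sid_ratio(dummy)
```

With placeholder codes every row is identical, so the ratio degenerates to 1/B regardless of the data, which is why it carries no information on a path that only emits dummy codes.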

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…t is_clip_pair flag

Forward used to route rows via ``torch.all(embedding == fea2, dim=-1)``
— a bit-exact equality check on float tensors that silently mislabels
samples on any upstream float cast, normalization, or numerical noise.

Introduce a required per-sample boolean column (``is_clip_pair``,
value_dim=1) emitted by the FG layer. The model now reads the flag
directly and routes via ``flag > 0.5``. No back-compat fallback: the
column is required in clip mode; missing it raises cleanly.

Proto: ``ClipConfig.is_clip_pair_feature_name`` (required string).
Existing clip-mode tests are updated to populate the flag in the
batch (no longer relying on bit-exact identity). One new test
``test_clip_mask_uses_flag_not_equality`` constructs a batch where
``image_emb == item_emb`` numerically but ``is_clip_pair=1`` — proves
the row routes to CLIP under the new logic (it would have been
silently relabeled recon under the old logic).

The ft_scripts/ working-area configs are updated locally (untracked)
to add the column and wire the field. Downstream pipelines must add a
matching raw_feature mapping (e.g. ``expression: "item:is_contrastive"``
when the source dataset uses that column name).
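
The fragility of the old equality check can be reproduced without torch. The sketch below round-trips values through float32 storage with the standard struct module and compares the two routing rules; the function names and sample vectors are hypothetical, and the "equal means recon" rule is taken from the description above.

```python
import struct
from typing import List


def to_float32(x: float) -> float:
    """Round-trip a Python float (float64) through float32 storage."""
    return struct.unpack("f", struct.pack("f", x))[0]


def route_by_equality(emb: List[float], fea2: List[float]) -> str:
    """Old rule: bit-identical vectors mean 'recon', anything else 'clip'."""
    return "recon" if emb == fea2 else "clip"


def route_by_flag(is_clip_pair: float) -> str:
    """New rule: trust the explicit per-sample flag."""
    return "clip" if is_clip_pair > 0.5 else "recon"


item_emb = [0.1, 0.2, 0.3]
fea2_cast = [to_float32(v) for v in item_emb]  # same row after an upstream cast

old_route = route_by_equality(item_emb, fea2_cast)  # cast breaks bit equality
new_route = route_by_flag(0.0)                      # flag says: not a clip pair
```

A recon row whose second feature went through a cast no longer compares equal, so the old rule sends it down the CLIP path, while the flag-based rule is unaffected.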

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Five new test docstrings tripped the D205 "blank line between summary
and description" check; restructure each to have a one-line summary,
blank line, then body. ruff-format also wanted two long
register_buffer / _extract_feature calls collapsed to a single line
(removed unnecessary wrapping).

No behavior change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@WhiteSwan1 WhiteSwan1 added the claude-review Let Claude Review label May 22, 2026
@github-actions github-actions Bot removed the claude-review Let Claude Review label May 22, 2026
Comment thread tzrec/modules/sid_generation/kmeans.py Outdated
# later resume-then-infer would silently return dummy zero codes
# instead of raising. Each load must re-run ``load_centroids_``
# (or the FAISS fit) to set the flag.
self.register_buffer("_is_initialized", torch.tensor(False), persistent=False)

Critical — resume-then-infer silently returns zero codes.

_is_initialized is persistent=False, so on checkpoint resume the buffer is reset to False even though centroids was restored. ResidualKMeans.forward then takes the else branch and returns torch.zeros(...) for every layer (residual_kmeans.py:135-138) — no exception, no warning. The comment above accurately predicts this scenario ("would silently return dummy zero codes") but the code does not raise either; it just emits zeros.

There is no set_extra_state/on_load_state_dict hook that re-flips this flag after a load, so users running evaluate/predict against a saved RQKMeans checkpoint get garbage SIDs.

Suggested fix: derive initialization from centroids.any() at use sites, or raise in forward/get_codes when not is_initialized rather than silently returning zeros. Otherwise resume-then-eval is broken.
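
One way to read the suggested fix, sketched in plain Python rather than torch: derive the initialized state from the centroids themselves and raise at the use site. The class below is a hypothetical stand-in, not KMeansLayer, and it assumes that an all-zero centroid table means "never fitted".

```python
from typing import List, Optional, Sequence


class CentroidLayerSketch:
    """Hypothetical plain-Python stand-in for a per-layer centroid store."""

    def __init__(self) -> None:
        self.centroids: Optional[List[List[float]]] = None

    @property
    def is_initialized(self) -> bool:
        # Derived from the centroids themselves, so a lost flag cannot
        # make a restored codebook look unfitted.
        return self.centroids is not None and any(
            v != 0.0 for row in self.centroids for v in row
        )

    def load_centroids(self, centroids: Sequence[Sequence[float]]) -> None:
        self.centroids = [list(row) for row in centroids]

    def predict(self, x: Sequence[float]) -> int:
        if not self.is_initialized:
            # Fail loudly instead of returning a dummy zero code.
            raise RuntimeError("centroids are not fitted or loaded")
        dists = [sum((a - b) ** 2 for a, b in zip(c, x)) for c in self.centroids]
        return dists.index(min(dists))
```

The trade-off is that a legitimately all-zero codebook would be rejected, which seems acceptable for fitted K-Means centroids but is worth stating.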

Comment thread tzrec/protos/models/sid_model.proto Outdated
optional bool normalize_residuals = 7 [default = false];
// Distance metric: "l2" or "cosine".
optional string distance_type = 9 [default = "l2"];
// Commitment loss type: "l2" or "cos".

ResidualQuantized.__init__ asserts commitment_loss in ("l2", "l1", "cos") (residual_quantized.py:90) and has a dedicated l1 branch in _single_commitment_loss. The proto comment, the RQVAE.__init__ docstring (rqvae.py:52-53), and the in-code "flow step 3" comment (residual_quantized.py:297) all advertise only "l2"/"cos". l1 is a fully reachable, supported value — please document it everywhere or remove it from the assert.

Suggested change
- // Commitment loss type: "l2" or "cos".
+ // Commitment loss type: "l2", "l1", or "cos".

Comment thread tzrec/models/sid_rqvae.py
Comment on lines +71 to +75
# Only forward latent_weight when proto sets it; otherwise let
# RQVAE / ResidualQuantized apply their signature default (1.0, 0.5).
rqvae_extra: Dict[str, Any] = {}
if cfg.latent_weight:
    rqvae_extra["latent_weight"] = parse_float_list(cfg.latent_weight)

The proto declares optional string latent_weight = 11 [default = "1.0,0.5"] (sid_model.proto:54). In proto2 Python, accessing an unset field with a default returns the default literal — so cfg.latent_weight is always the truthy string "1.0,0.5" and the if cfg.latent_weight: guard is always True. The "let RQVAE / ResidualQuantized apply their signature default" branch is unreachable.

Not a behavioral bug (both paths yield (1.0, 0.5)), but the comment misrepresents control flow. If the intent really is "only override when user set it explicitly", drop the proto default and use cfg.HasField("latent_weight"). Otherwise simplify the comment and just always forward the value.

Comment on lines +222 to +233
def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
    """Dispatch based on use_clip.

    use_clip=False: forward(x) -> forward_rqvae(x)
    use_clip=True: forward(fea1, fea2, clip_mask) -> forward_mixed(...)
    """
    if self._is_inference or not self.use_clip:
        assert len(args) >= 1, "Standard mode requires (x,)"
        return self.forward_rqvae(args[0], **kwargs)
    else:
        assert len(args) == 3, "Mixed mode requires (fea1, fea2, clip_mask)"
        return self.forward_mixed(args[0], args[1], args[2], **kwargs)

self._is_inference is assigned False in __init__ (line 106) and never written anywhere else — there is no setter and RQVAE subclasses nn.Module directly, not BaseModule. The dispatch reduces to if not self.use_clip:. SidRqvae already calls forward_rqvae/forward_mixed directly and doesn't reach into RQVAE._is_inference, so it works "by accident".

Either drop the dead _is_inference flag and the misleading dispatch branch, or actually plumb BaseModule.set_is_inference through to the inner RQVAE so the documented behavior is real.

Comment on lines +280 to +293
# DDP path: every rank ships its local buffer to rank 0 via
# gather_object (variable-length pickle — fine for this one-
# shot, CPU-resident gather). Only rank 0 holds the corpus,
# so peak memory is O(world_size) on rank 0 and O(1) elsewhere
# (vs O(world_size²) for all_gather_object).
local = torch.cat(self._offline_buffer, dim=0)
del self._offline_buffer
self._offline_buffer = []

rank = dist.get_rank()
gathered: Optional[List[Optional[torch.Tensor]]] = (
    [None] * dist.get_world_size() if rank == 0 else None
)
dist.gather_object(local, gathered, dst=0)

Two concerns with the gather:

  1. NCCL incompatibility. dist.gather_object requires a backend that supports object collectives. Production GPU training typically initialises a single NCCL group and gather_object raises RuntimeError: ProcessGroupNCCL does not support gather_object (PyTorch documents this). Consider initialising a side gloo group at module construction and passing it as group=, or replace with a tensor-based dist.gather after padding to a uniform per-rank shape.

  2. Memory claim is off. The comment says peak rank-0 memory is O(world_size). gather_object pickles every rank's tensor into a Python bytes then deserialises into gathered, so rank 0 holds all world_size per-rank tensors plus the pickle bytestream simultaneously — closer to ~3× total corpus at peak. For large datasets this is a real OOM risk. Either chunk the gather (every rank ships fixed-size chunks in a loop) or write embeddings to per-rank shard files during predict and have rank 0 stream them.

Comment on lines +39 to +53
def _coerce_proto_numbers(d: Dict) -> Dict:
    """Coerce float-typed integers back to int.

    ``google.protobuf.Struct.number_value`` is always float, but most
    ``faiss.Kmeans`` kwargs (``niter``, ``seed``, ``nredo``, ...) require
    Python ``int``. This helper converts any float that is an exact
    integer to ``int`` for downstream consumption.
    """
    out: Dict = {}
    for k, v in d.items():
        if isinstance(v, float) and v.is_integer():
            out[k] = int(v)
        else:
            out[k] = v
    return out

The output is forwarded straight through to faiss.Kmeans(D, K, **self.faiss_kmeans_kwargs) (residual_kmeans.py:244) with no whitelist or schema validation. A misspelled or unsupported kwarg (e.g. niters instead of niter, or gpu=True on a faiss-cpu build) raises inside faiss.Kmeans.__init__ only on rank 0, after gather_object has already pulled the entire corpus to rank 0. The other ranks then block on the subsequent broadcast/barrier until the NCCL watchdog tears the job down — no actionable error.

Two-step fix: (a) whitelist known FAISS keys (niter, nredo, seed, verbose, spherical, gpu, min/max_points_per_centroid, ...) and reject unknowns at construction time so the failure is symmetric across ranks; (b) construct the faiss.Kmeans object early (in __init__ or right before the gather) so config errors fire before the expensive cross-rank gather.
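
Step (a) could look roughly like the sketch below. The whitelist is an assumption built from the keys named in this comment and should be checked against the installed FAISS version; `validate_faiss_kwargs` is a hypothetical helper that also folds in the integer coercion.

```python
from typing import Any, Dict

# Assumed whitelist, taken from the keys named in the review comment; verify
# it against the FAISS build in use before relying on it.
ALLOWED_FAISS_KWARGS = frozenset(
    {
        "niter",
        "nredo",
        "seed",
        "verbose",
        "spherical",
        "gpu",
        "min_points_per_centroid",
        "max_points_per_centroid",
    }
)


def validate_faiss_kwargs(raw: Dict[str, Any]) -> Dict[str, Any]:
    """Reject unknown keys, then coerce integral floats (Struct numbers) to int."""
    unknown = sorted(set(raw) - ALLOWED_FAISS_KWARGS)
    if unknown:
        raise ValueError(f"unsupported faiss_kmeans_kwargs: {unknown}")
    return {
        k: int(v) if isinstance(v, float) and v.is_integer() else v
        for k, v in raw.items()
    }


clean = validate_faiss_kwargs({"niter": 20.0, "spherical": True})
```

Running this in the model constructor would make a misspelled key fail identically on every rank, before any embeddings are buffered or gathered.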

Comment on lines +76 to +80
# Global batch size for distributed training
if is_distributed and dist.is_initialized():
    B = Q.size(1) * dist.get_world_size()
else:
    B = Q.size(1)

B = Q.size(1) * dist.get_world_size() assumes uniform per-rank batch sizes. On the final training step (or any step with drop_last=False and uneven shards) this is wrong: each rank computes a different denominator at Q /= B (line 100) and a different scale at Q *= B (line 103). argmax is invariant to a positive per-row scalar so the final codes stay consistent — but the resulting Q values do not match across ranks, which biases any downstream code that relies on the assignment magnitudes.

Suggested fix: all_reduce the local Q.size(1) once at the top to get the true global B (one extra collective is negligible vs the ones already inside the loop). Also, dist.is_initialized() alone isn't sufficient — guard on dist.get_world_size() > 1 to avoid no-op collectives in single-process toy runs.
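
The mismatch is easy to see with plain arithmetic. The sketch below simulates three ranks without torch.distributed; the per-rank sizes are made up, and `true_global_batch` stands in for what a SUM all_reduce over the local sizes would give every rank.

```python
from typing import List


def assumed_global_batch(local_sizes: List[int]) -> List[int]:
    """What each rank computes with local_size * world_size."""
    world_size = len(local_sizes)
    return [n * world_size for n in local_sizes]


def true_global_batch(local_sizes: List[int]) -> int:
    """What a SUM all_reduce over the local sizes would return on every rank."""
    return sum(local_sizes)


local_sizes = [4, 4, 3]  # hypothetical last step with an uneven final shard

per_rank_guess = assumed_global_batch(local_sizes)
global_b = true_global_batch(local_sizes)
```

Here the ranks would normalize by 12, 12, and 9 respectively, while the actual global batch is 11, so no rank uses the right denominator.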

# Training: buffer for the end-of-loop FAISS fit and return dummy
# codes — the codebook does not exist yet.
if self.is_train:
    self._offline_buffer.append(embedding.detach().cpu())

Two issues that compound at scale:

  1. Synchronous D2H every training step. embedding.detach().cpu() is a sync copy on a non-pinned host buffer — it blocks the CUDA stream until the copy completes. Use embedding.detach().to('cpu', non_blocking=True) against a pinned host buffer if you want overlap. Per-step impact is small but adds up.

  2. Unbounded host-memory growth. The buffer grows linearly with steps × batch_size × D × 4 bytes per rank, and rank 0 later holds world_size × that during the gather (see the comment on the gather block). For a 100M-row dataset with D=512 you're looking at ~200 GB before the gather and ~world_size × that on rank 0 — a near-certain OOM at production scale.

Consider streaming per-step embeddings to a pre-allocated np.memmap of known max size (or per-rank chunked .npy files) and mmap-reading them at FAISS fit time. For very large datasets, reservoir-sampling at the buffer step is usually sufficient for K-Means quality (~256 × K samples is plenty).
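
The reservoir-sampling option can be sketched with the standard library alone. This uses the classic Algorithm R to keep a bounded, uniformly sampled buffer; the class name, capacity, and seed are illustrative assumptions, and a real version would store tensors or NumPy rows rather than Python lists.

```python
import random
from typing import List, Sequence


class ReservoirBuffer:
    """Keep a uniform random sample of at most `capacity` rows (Algorithm R)."""

    def __init__(self, capacity: int, seed: int = 0) -> None:
        self.capacity = capacity
        self.rows: List[Sequence[float]] = []
        self.seen = 0
        self._rng = random.Random(seed)

    def add_batch(self, batch: Sequence[Sequence[float]]) -> None:
        for row in batch:
            self.seen += 1
            if len(self.rows) < self.capacity:
                self.rows.append(row)  # still filling the reservoir
            else:
                j = self._rng.randrange(self.seen)  # uniform in [0, seen)
                if j < self.capacity:
                    self.rows[j] = row  # replace with probability capacity / seen


buf = ReservoirBuffer(capacity=8)
for step in range(100):
    buf.add_batch([[float(step), float(i)] for i in range(4)])
```

Host memory then stays bounded by the capacity (for example a multiple of K, as suggested above) no matter how many steps run.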

Comment on lines +201 to +204
# --- Safe labels: recon rows fallback to first clip column ---
labels = self.labels
fallback = clip_mask.long().argmax() # first clip sample index
safe_labels = torch.where(clip_mask, labels, fallback.expand_as(labels))

fallback = clip_mask.long().argmax() is computed from the local mask but safe_labels indexes into the global gathered logits (offset local_batch_size * self._rank + arange). When this rank has no clip rows, argmax returns 0 → safe_labels points at local column 0 of the global logits, which may itself be a recon column (all -inf). cross_entropy then yields NaN, which the nan_to_num at line 134-135 swallows before the mask multiply.

The current code is safe only because nan_to_num runs before * clip_mask.float() (0 × NaN = NaN), so the ordering is load-bearing. Any later "optimization" that swaps the order, or that removes the seemingly-defensive nan_to_num, silently poisons the loss.

Suggested fix: compute fallback from the gathered mask (clip_mask_all.long().argmax()) so safe_labels always points at a real clip column globally; then the nan_to_num becomes unnecessary and the contract is explicit. Also consider reading dist.get_rank() lazily in forward instead of caching at __init__ (line 104) — modules can be constructed before init_process_group.
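
The ordering hazard can be shown with plain floats. In the sketch below, `nan_to_num` is a simplified stand-in that only handles NaN (torch.nan_to_num also maps infinities), and the per-row losses are made-up values.

```python
import math
from typing import List


def nan_to_num(values: List[float]) -> List[float]:
    """Replace NaN with 0.0 (a simplified stand-in for torch.nan_to_num)."""
    return [0.0 if math.isnan(v) else v for v in values]


def masked_sum(losses: List[float], mask: List[bool], clean_first: bool) -> float:
    """Sum per-row losses over masked-in rows, in either operation order."""
    if clean_first:
        losses = nan_to_num(losses)
    # Multiplying by 0.0 does not remove a NaN: nan * 0.0 is still nan.
    return sum(loss * float(m) for loss, m in zip(losses, mask))


per_row = [0.7, math.nan, 1.3]  # the masked-out recon row produced NaN
clip_mask = [True, False, True]

safe_total = masked_sum(per_row, clip_mask, clean_first=True)
poisoned_total = masked_sum(per_row, clip_mask, clean_first=False)
```

Only the clean-first order yields a finite total, which is the behavior the current code depends on.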

@github-actions

Review summary — SID generation (RQ-VAE / RQ-KMeans)

Substantial, well-organized PR. The DDP on_train_end vote-then-bail logic, the persistent=False decision on _is_initialized (intent is correct even if the resume path is incomplete — see inline), the FAISS in-place residual updates, and the GatherLayer.backward slice-then-reduce trick are all good calls. Inline comments cover individual concerns; below are themes worth tracking at the PR level.

Cross-cutting issues to consider

  • Checkpoint resume of SidRqkmeans is broken. Even with _is_initialized correctly non-persistent, there is no on_load hook that re-flips the flag after centroids are restored, so evaluate/predict against a saved checkpoint silently emits zero codes. See inline on kmeans.py:232.
  • on_train_end lifecycle. Called only on the normal-exit path in _train_and_evaluate — if training raises (OOM, KeyboardInterrupt, dataloader error), the FAISS fit never runs but the docstring at sid_rqkmeans.py:17 says "called unconditionally." Either wrap in try/finally or update the docstring. Also worth deciding whether a failed/empty fit should block the terminal checkpoint write so users don't silently ship unusable checkpoints.
  • Distributed code paths are untested. GatherLayer.backward, _sinkhorn's distributed reductions, and SidRqkmeans.on_train_end's DDP branch (gather_object/broadcast/empty-vote) all have zero test coverage — only single-process paths are exercised. A torch.multiprocessing.spawn smoke test on gloo with world_size=2 comparing MaskedCLIPLoss.forward → backward against a single-process reference would catch the most likely class of bug.
  • Other untested branches: forward_mode="gumbel_softmax", rotation_trick=True, kmeans_init=True, distance_type="cosine", normalize_residuals=True, loss_type in {"l1","cosine"}, RQVAE.decode_codes/get_codes, and the ResidualKMeans.train_offline tensor-input path. Several test assertions are also "exists / requires_grad" rather than numerical (e.g. test_commitment_loss_l1_branch doesn't verify the l1 value actually differs from l2; test_rqvae_backward accepts any one non-zero grad anywhere). Adding setUp: torch.manual_seed(0) to both test classes would also remove a latent flakiness source.
  • test_on_train_end_runs_faiss is silently skipped without faiss. If CI lacks FAISS the entire ResidualKMeans.train_offline (and KMeansLayer.load_centroids_) codepath becomes unverified while CI looks green. Either make FAISS a required test dep or add a fallback test using the pure-torch _residual_kmeans to validate the layered residual math.
  • Sinkhorn perf in DDP. Each iteration of _sinkhorn issues one all_reduce(sum_of_rows), so per training step you pay n_layers × (1 + sinkhorn_iters) small NCCL collectives (18 with defaults). Mostly fixed launch overhead, but worth either batching across layers or documenting the cost.
  • forward_mixed doubles encoder/decoder/quantizer cost for all rows in CLIP mode, including pure-recon rows whose fea2 pass is wasted (siamese semantics only require it for clip rows). Acceptable if CLIP batches are mostly clip pairs; worth documenting otherwise.

Doc/comment drift

  • Proto comment, RQVAE.__init__ docstring, and ResidualQuantized.forward step-3 comment all advertise commitment_loss as "l2"|"cos", but the code accepts "l1" too (see inline on the proto).
  • sid_rqvae.py:71-75 comment describes a fallback path that the proto default makes unreachable (see inline).
  • gather_object block memory comment claims O(world_size) rank-0 memory — actual is closer to ~3× corpus due to pickle (see inline).

Nice work overall. The critical items are the checkpoint-resume gap, the gather_object NCCL/memory situation, and the missing distributed tests; everything else is cleanup.

WhiteSwan1 and others added 2 commits May 22, 2026 10:06
…-time guard

Commit 85932d8 made _is_initialized non-persistent to protect against
checkpoints taken mid-FAISS-fit. That fix had a regression: a legitimate
POST-fit checkpoint also loses the flag (centroids restore, but the flag
defaults to False on load), so predict() then silently returns zero
codes. The two states became indistinguishable to the loader.

Make _is_initialized persistent again so normal post-fit save/load
round-trips work. Catch the mid-fit-checkpoint case at load time via
_load_from_state_dict: if the flag loaded as True but centroids sum to
zero, the framework appends a clear error message (raised under
strict=True load_state_dict).

Two regression tests:
- test_post_fit_checkpoint_round_trips: fit -> save -> load fresh ->
  predict produces non-zero codes.
- test_mid_fit_checkpoint_rejected_on_load: tampered state with flag=True
  and zero centroids raises with the "mid-FAISS-fit" hint.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…oto doc

- sid_model.proto: commitment_loss comment now lists "l1" (Commit 2
  added the branch but missed the proto-side doc).
- sid_rqkmeans.py: add TODO at the .cpu() per-step site documenting the
  D2H stall + unbounded buffer concern (Commit 4 was deliberately
  skipped; this records why so future readers don't repeat the audit).
- sid_rqkmeans.py: collapse the 10-line DDP collective-check rationale
  to 3 lines.
- sid_rqvae.py: drop redundant comments on _is_clip_pair_feature_name,
  Sinkhorn parse, and the clip_mask extraction — the proto/field name
  + code already convey the intent.
- residual_quantized.py: drop the "mirrors l2 branch" comment on the L1
  case; the branch reads obviously next to its siblings.

No behavior change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@WhiteSwan1 WhiteSwan1 added the claude-review Let Claude Review label May 22, 2026
@github-actions github-actions Bot removed the claude-review Let Claude Review label May 22, 2026
Comment on lines +250 to +254
if self.training and self.forward_mode == QuantizeForwardMode.GUMBEL_SOFTMAX:
    weights = _gumbel_softmax_sample(
        -distances, temperature=temperature, hard=True
    )
    emb = weights @ self.embedding.weight

Correctness — Gumbel-Softmax does not propagate gradient into the encoder.

_compute_distances (and _find_nearest_embedding) are decorated with @torch.no_grad(), so distances returned here is fully detached. In the Gumbel branch:

weights = _gumbel_softmax_sample(-distances, temperature=temperature, hard=True)
emb = weights @ self.embedding.weight

weights is a function only of the detached distances, so gradient flows into self.embedding.weight but not back through x into the encoder. The STE branch happens to work because x + (q-x).detach() explicitly routes the gradient — Gumbel has no such shortcut.

Effect: choosing forward_mode="gumbel_softmax" silently trains only the codebook, not the encoder.

Fix: remove @torch.no_grad() from _compute_distances (the argmin/argmax step is naturally non-differentiable), or split out a _compute_distances_no_grad/_compute_distances_with_grad pair and use the differentiable one for the Gumbel logits.
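
To make the missing gradient concrete, here is a small stdlib-only sketch that estimates, by finite differences, how the soft-assigned embedding changes with the encoder output. It is a simplification under stated assumptions: plain softmax over negative squared distances, no Gumbel noise, no hard one-hot, and `soft_assign` / `quantize` are hypothetical helper names rather than functions from the PR.

```python
import math


def soft_assign(x, codebook, temperature=1.0):
    """Softmax over negative squared distances from x to each code."""
    logits = [
        -sum((xi - ci) ** 2 for xi, ci in zip(x, code)) / temperature
        for code in codebook
    ]
    top = max(logits)
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def quantize(x, codebook):
    """Weighted sum of codes, i.e. weights @ codebook."""
    weights = soft_assign(x, codebook)
    return [
        sum(w * code[d] for w, code in zip(weights, codebook))
        for d in range(len(x))
    ]


codebook = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
x = [0.3, 0.2]
eps = 1e-6
x_shift = [x[0] + eps, x[1]]

# d(emb)/d(x[0]) estimated numerically; this is the signal that is lost
# when the distances feeding the logits are computed under no_grad.
grad_wrt_x0 = [
    (a - b) / eps
    for a, b in zip(quantize(x_shift, codebook), quantize(x, codebook))
]
```

The derivative is clearly non-zero, so treating the distances as constants discards a real dependence on the encoder output.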

Comment on lines +183 to +188
# Average per-layer centroids across DDP ranks so every rank
# starts from the same codebook.
if dist.is_initialized() and dist.get_world_size() > 1:
    for c in centers:
        dist.all_reduce(c, op=dist.ReduceOp.SUM)
        c /= dist.get_world_size()

Correctness — averaging KMeans centroids across DDP ranks produces meaningless init.

_kmeans_plus_plus (kmeans.py:116–128) uses torch.randint/torch.multinomial with the per-rank default RNG, so each rank picks different seeds from its own local batch and runs Lloyd to different local minima. Cluster indices are unordered, so "centroid 0" on rank A and "centroid 0" on rank B correspond to different concepts. all_reduce(SUM)/world_size then averages permutation-misaligned centroids — the result is closer to noise than to a meaningful warm start.

The comment "every rank starts from the same codebook" is true but disguises that the codebook is junk.

Fix: run KMeans on rank 0 only, then dist.broadcast(c, src=0) — mirrors what SidRqkmeans.on_train_end already does for FAISS. Bonus: world_size× less compute.
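
A toy illustration of the misalignment, using plain Python lists in place of tensors. The two "ranks" here are hypothetical and deliberately extreme: both found the same two clusters but stored them in opposite index order.

```python
def average(rank_centroids):
    """Element-wise mean across ranks, as all_reduce(SUM) / world_size does."""
    world = len(rank_centroids)
    k = len(rank_centroids[0])
    dim = len(rank_centroids[0][0])
    return [
        [sum(r[i][d] for r in rank_centroids) / world for d in range(dim)]
        for i in range(k)
    ]


# Same clusters on both ranks, opposite index order.
rank_a = [[0.0, 0.0], [10.0, 10.0]]
rank_b = [[10.0, 10.0], [0.0, 0.0]]

averaged = average([rank_a, rank_b])   # both centroids collapse to the midpoint
broadcast = [list(c) for c in rank_a]  # rank-0 result copied to every rank
```

Averaging leaves two identical centroids at (5, 5), which match neither cluster; broadcasting rank 0's result keeps the structure intact.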

Comment on lines +119 to +128
for i in range(1, n_clusters):
    dists = _squared_euclidean_distance(data, centroids[:i])  # (N, i)
    min_dists = dists.min(dim=1)[0]  # (N,)
    if min_dists.sum() == 0:
        centroids[i:] = data[
            torch.randint(0, N, (n_clusters - i,), device=data.device)
        ]
        break
    next_idx = torch.multinomial(min_dists, num_samples=1)
    centroids[i] = data[next_idx]

Perf — KMeans++ is O(K²·N), should be O(K·N).

Each loop iteration calls _squared_euclidean_distance(data, centroids[:i]), recomputing distances to all previously selected centroids. Standard KMeans++ tracks a running min_dists and only computes the distance to the newly selected centroid:

min_dists = _squared_euclidean_distance(data, centroids[:1]).squeeze(1)
for i in range(1, n_clusters):
    if min_dists.sum() == 0:
        centroids[i:] = data[torch.randint(0, N, (n_clusters - i,), device=data.device)]
        break
    next_idx = torch.multinomial(min_dists, num_samples=1)
    centroids[i] = data[next_idx]
    new_dists = _squared_euclidean_distance(data, centroids[i:i+1]).squeeze(1)
    min_dists = torch.minimum(min_dists, new_dists)

For K=256, N=batch_size the loop evaluates K(K-1)/2 = 32,640 centroid-distance columns instead of K, i.e. ~128× more work than necessary on the first training step. Combined with _kmeans having no convergence early-stop (always 100 iters), the first-step startup stall under kmeans_init=True is substantial.

Also: torch.randint/torch.multinomial here draw from the global default generator, so the KMeans++ result depends on everything else that consumed that RNG stream before this call; it is hard to reproduce in isolation even with torch.manual_seed set elsewhere. Consider accepting/threading a torch.Generator.
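
Both points (running min-distance and an explicit generator) can be seen together in a stdlib-only sketch. This is not the PR's `_kmeans_plus_plus`; `kmeans_pp_init` and `sq_dist` are hypothetical names, and `random.Random` stands in for a threaded `torch.Generator`.

```python
import random


def sq_dist(a, b):
    """Squared Euclidean distance between two points."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeans_pp_init(data, n_clusters, rng):
    """KMeans++ seeding with a running min-distance list (O(K*N) distances)."""
    centroids = [data[rng.randrange(len(data))]]
    min_dists = [sq_dist(p, centroids[0]) for p in data]
    while len(centroids) < n_clusters:
        if sum(min_dists) == 0:
            # Every point coincides with a chosen centroid: fill uniformly.
            missing = n_clusters - len(centroids)
            centroids.extend(rng.choice(data) for _ in range(missing))
            break
        # Sample the next centroid proportionally to squared distance.
        nxt = rng.choices(data, weights=min_dists, k=1)[0]
        centroids.append(nxt)
        # Only distances to the newest centroid are computed.
        min_dists = [min(m, sq_dist(p, nxt)) for m, p in zip(min_dists, data)]
    return centroids


points = [(float(i % 7), float(i % 5)) for i in range(50)]
run_a = kmeans_pp_init(points, 4, random.Random(0))
run_b = kmeans_pp_init(points, 4, random.Random(0))
```

Because the generator is passed in, two calls with the same seed return the same centroids regardless of what else has touched the global RNG.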

Comment on lines +230 to +233
@property
def is_initialized(self) -> bool:
    """Whether centroids have been injected via ``load_centroids_``."""
    return self._is_initialized.item()

Perf — .item() on the hot path forces a GPU→host sync every forward.

is_initialized is read inside ResidualKMeans.forward (residual_kmeans.py:129) and all_initialized (residual_kmeans.py:101) on every eval batch, so each call serializes the CUDA stream. The flag is flipped exactly once (in load_centroids_ from on_train_end) — there's no reason to re-read the device tensor every step.

Fix: cache as a Python bool attribute, set in load_centroids_ and re-derived from the buffer in _load_from_state_dict. Keep _is_initialized as a buffer for checkpoint round-trip, but stop reading it on the forward path. For N_layers=3 the current code pays at least 3 syncs per eval batch; with the cached bool that drops to zero.
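
The caching pattern, sketched without torch. `SyncCountingFlag` is a hypothetical stand-in for the device-resident buffer whose `.item()` call models one GPU→host sync; `Layer` and `restore_from_checkpoint` are likewise illustrative names, not the PR's `KMeansLayer` API.

```python
class SyncCountingFlag:
    """Hypothetical device flag; counts how often the host reads it."""

    def __init__(self):
        self.value = False
        self.reads = 0

    def item(self):
        self.reads += 1  # each call models one GPU->host sync
        return self.value


class Layer:
    def __init__(self):
        self._flag = SyncCountingFlag()   # kept for checkpoint round-trip
        self._initialized_cached = False  # plain bool read on the hot path

    def load_centroids_(self):
        self._flag.value = True
        self._initialized_cached = True

    def restore_from_checkpoint(self, flag_value):
        # Re-derive the cache once at load time: one read, not one per batch.
        self._flag.value = flag_value
        self._initialized_cached = self._flag.item()

    @property
    def is_initialized(self):
        return self._initialized_cached


layer = Layer()
layer.load_centroids_()
results = [layer.is_initialized for _ in range(1000)]  # no flag reads here
```

The buffer still exists for serialization; only the per-batch read goes away.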

Comment on lines +282 to +300
local = torch.cat(self._offline_buffer, dim=0)
del self._offline_buffer
self._offline_buffer = []

rank = dist.get_rank()
gathered: Optional[List[Optional[torch.Tensor]]] = (
    [None] * dist.get_world_size() if rank == 0 else None
)
dist.gather_object(local, gathered, dst=0)
del local
if rank == 0:
    assert gathered is not None
    full = torch.cat([g for g in gathered if g is not None], dim=0)
    del gathered
    logger.info(
        "[SidRqkmeans.on_train_end] rank0 fitting FAISS "
        "on %d samples (D=%d)." % (full.shape[0], full.shape[1])
    )
    self._rqkmeans.train_offline(full, verbose=True)

Scalability — dist.gather_object of the full per-rank buffer will OOM rank 0 on realistic corpora.

Each rank concats its host buffer (line 282), then ships via gather_object (which pickle-serializes the tensor, ~2× transient bytes), then rank 0 concats again. Peak rank-0 memory is roughly 2 × world_size × per_rank_corpus × D × 4B. For 10M rows per rank at 64-D across 8 ranks, the gathered corpus alone is ~20 GB and the 2× transient puts the peak near 40 GB on rank 0, taking down DDP training right at the end, after hours of compute.

Combined with the per-step .cpu() D2H (line 137, already TODO'd), the FAISS-only path is the largest operational risk in the PR.
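
The estimate above is easy to recompute for other shapes. The helper below is a back-of-the-envelope sketch only: `gather_peak_bytes` is a hypothetical name, the 2× transient factor is the rough pickle overhead quoted in this comment, and the example reads "10M rows" as a per-rank count, which is an assumption.

```python
def gather_peak_bytes(rows_per_rank, dim, world_size,
                      bytes_per_elem=4, transient_factor=2):
    """Return (gathered_bytes, rough_peak_bytes) on rank 0."""
    gathered = rows_per_rank * dim * bytes_per_elem * world_size
    return gathered, transient_factor * gathered


gathered, peak = gather_peak_bytes(
    rows_per_rank=10_000_000, dim=64, world_size=8
)
gathered_gb = gathered / 1e9  # float32 corpus after the gather
peak_gb = peak / 1e9          # including the transient copy
```

A guard that compares `peak` against available host memory before calling `gather_object` would turn the late OOM into an early, readable error.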

Fix options:

  1. Write per-rank buffers to a shared filesystem (torch.save(buffer, f"{model_dir}/sid_embed_rank{rank}.pt")), dist.barrier(), then have rank 0 stream them into a pre-allocated np.empty((total_N, D)) (size known via an all_gather of row counts first).
  2. Reservoir-sample the per-rank buffer down to a configured cap (FAISS K-Means rarely benefits from > ~256 × K samples).

At minimum, expose a sample cap and add a guard that errors loudly when the gather would exceed available host memory.
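
Option 2 can be as small as the classic reservoir algorithm (Algorithm R). The sketch below works on any iterable of rows; `reservoir_sample` and the cap value are hypothetical, and a real implementation would operate on tensor batches rather than single items.

```python
import random


def reservoir_sample(stream, cap, rng):
    """Keep a uniform random sample of at most `cap` items from a stream."""
    sample = []
    for n, item in enumerate(stream):
        if n < cap:
            sample.append(item)
        else:
            j = rng.randint(0, n)  # inclusive on both ends
            if j < cap:
                sample[j] = item   # replace a kept item with probability cap/(n+1)
    return sample


rng = random.Random(0)
cap = 256 * 4  # hypothetical cap: 256 samples per centroid with K=4
kept = reservoir_sample(range(100_000), cap, rng)
small = reservoir_sample(range(10), cap, rng)  # shorter than cap: keep all
```

Memory stays bounded by `cap` per rank no matter how long training runs, which also addresses the unbounded-buffer concern at the per-step `.cpu()` site.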

Comment thread tzrec/models/sid_rqvae.py
Comment on lines +67 to +71
# Only forward latent_weight when proto sets it; otherwise let
# RQVAE / ResidualQuantized apply their signature default (1.0, 0.5).
rqvae_extra: Dict[str, Any] = {}
if cfg.latent_weight:
    rqvae_extra["latent_weight"] = parse_float_list(cfg.latent_weight)

Comment contradicts behavior — if cfg.latent_weight: is always True.

The proto field has optional string latent_weight = 11 [default = "1.0,0.5"];, so cfg.latent_weight returns "1.0,0.5" (truthy) when unset and the user-provided string otherwise. The "let signature default apply" branch is unreachable; latent_weight is always forwarded.

If the intent is to detect "user explicitly set vs proto default", use cfg.HasField("latent_weight"); for a proto2 optional field this reports explicit presence even when a default is declared. Otherwise just always forward and delete the conditional + comment.

The same shape applies to commitment_loss (cfg defaults to "l2" per sid_model.proto:52) — the RQVAE.commitment_loss=None defaulting in rqvae.py:111 is dead from the proto-driven path.
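
The unreachable branch is easy to reproduce without protobuf. `FakeCfg` below is a hypothetical stand-in that mimics only one behavior of the generated message: reading an unset field returns the declared default string.

```python
class FakeCfg:
    """Hypothetical stand-in for the parsed proto (not the protobuf API)."""

    def __init__(self, latent_weight=None):
        # Mirror the proto-level default: unset reads back as "1.0,0.5".
        self.latent_weight = (
            latent_weight if latent_weight is not None else "1.0,0.5"
        )


def build_extra(cfg):
    extra = {}
    if cfg.latent_weight:  # same shape as the check under review
        extra["latent_weight"] = [
            float(x) for x in cfg.latent_weight.split(",")
        ]
    return extra


unset = build_extra(FakeCfg())               # user did not set the field
explicit = build_extra(FakeCfg("2.0,0.25"))  # user set it
```

Both calls populate `latent_weight`, so the "let the signature default apply" path never runs.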

Comment on lines +38 to +54
optional string hidden_dims = 3;
// Per-layer codebook size, comma-separated, e.g. "256,256,256".
// List length is the number of residual quantization layers;
// non-uniform codebooks such as "512,256,128" are supported.
optional string codebook = 5;

// === Quantization strategy ===
// VQ forward mode: "ste" or "gumbel_softmax".
optional string forward_mode = 6 [default = "ste"];
// L2-normalize residuals before each quantization layer.
optional bool normalize_residuals = 7 [default = false];
// Distance metric: "l2" or "cosine".
optional string distance_type = 9 [default = "l2"];
// Commitment loss type: "l2", "l1" or "cos".
optional string commitment_loss = 10 [default = "l2"];
// Commitment loss weights [w1, w2], comma-separated.
optional string latent_weight = 11 [default = "1.0,0.5"];

Stringly-typed list fields are inconsistent with the rest of tzrec.

hidden_dims, codebook, and latent_weight are optional string parsed at runtime via _sid_helpers.parse_int_list / parse_float_list. Everywhere else in tzrec these are repeated uint32 / repeated float (see optimizer.proto:226-228, tower.proto:158, most of feature.proto).

Cost of the string encoding:

  • Defers all validation to model construction (int(x.strip()) raises ValueError on a typo instead of failing at proto load).
  • Disables proto-level type checking, IDE completion, and text_format validation.
  • Requires the ad-hoc _sid_helpers.py shim.
  • The always-truthy default check noted on sid_rqvae.py:67-71 would not exist with a repeated field (callers would always pass the list as-is).

Suggest repeated uint32 hidden_dims;, repeated uint32 codebook;, repeated float latent_weight;, then delete _sid_helpers.py.

Also: the comment "non-uniform codebooks such as 512,256,128 are supported" disagrees with ResidualKMeans.__init__ (residual_kmeans.py:82-85) which asserts uniformity for SidRqkmeans. Clarify the per-model rule.
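
The deferred-validation cost is visible in a few lines. `parse_int_list` here is a hypothetical re-implementation written for illustration; the actual helper in `_sid_helpers.py` may differ in details.

```python
def parse_int_list(text):
    """Parse a comma-separated string such as "256,256,256" into ints."""
    return [int(x.strip()) for x in text.split(",")]


good = parse_int_list("256, 256,256")

# A typo survives config loading and only fails when the model is built.
try:
    parse_int_list("256,25b,256")
    typo_error = None
except ValueError as exc:
    typo_error = str(exc)
```

With `repeated uint32` the same typo would be rejected when the config text is parsed, before any model code runs.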

Comment thread tzrec/models/sid_rqvae.py
mse = F.mse_loss(predictions["x_hat"], embedding, reduction="mean")
self._metric_modules["mse"].update(mse)

unique_sids = torch.unique(codes, dim=0).shape[0]

Perf: torch.unique(..., dim=0) triggers a GPU→host sync per eval batch.

torch.unique with dim=0 does a CUDA sort + dedupe, and because its output size is data-dependent the call itself blocks until the host knows the result shape; shape[0] then reads that size as a Python int, so unique_sids / B is plain host arithmetic. The same pattern is in sid_rqkmeans.py:229-230.

Fix:

# pack codes per-row into a single int64; unique is much faster on 1-D
# K = per-layer codebook size (assumed uniform, with K ** n_layers < 2 ** 63)
mults = torch.tensor([K ** i for i in range(codes.shape[1])], device=codes.device)
packed = (codes.long() * mults).sum(dim=1)
unique_sids = torch.unique(packed).numel()  # Python int; unique still syncs once
self._metric_modules["unique_sid_ratio"].update(
    torch.tensor(unique_sids, device=codes.device, dtype=torch.float) / B
)

Or buffer packed codes across the eval epoch and run unique once at epoch end.
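
If non-uniform codebooks are allowed, the packing needs a mixed-radix key rather than powers of a single K. A stdlib-only sketch of that encoding follows; `pack_codes` is a hypothetical name, and in the tensor version the product of all codebook sizes would have to stay below 2^63 to fit in int64.

```python
def pack_codes(rows, codebook_sizes):
    """Map each row of per-layer codes to one integer (mixed-radix)."""
    packed = []
    for row in rows:
        key = 0
        for code, size in zip(row, codebook_sizes):
            assert 0 <= code < size, "code out of range for its layer"
            key = key * size + code
        packed.append(key)
    return packed


codes = [(3, 1, 0), (3, 1, 0), (0, 1, 3), (255, 127, 63)]
sizes = (256, 128, 64)  # non-uniform codebooks
packed = pack_codes(codes, sizes)
unique_ratio = len(set(packed)) / len(codes)
```

Distinct code tuples always map to distinct keys, so counting unique keys is equivalent to counting unique rows.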

normalize_residuals (bool): L2-normalize residuals. Default: False.
distance_type (str|List[str]): distance metric ('l2'|'cosine').
Default: 'l2'.
commitment_loss (str|None): commitment loss type ('l2'|'cos').

Docstring drops 'l1'. ResidualQuantized asserts commitment_loss in ("l2", "l1", "cos") (residual_quantized.py:90), and sid_model.proto:51 also documents the three options.

Suggested change
- commitment_loss (str|None): commitment loss type ('l2'|'cos').
+ commitment_loss (str|None): commitment loss type ('l2'|'l1'|'cos').

Comment thread tzrec/main.py
if lr.by_epoch:
    lr.step()

_model.on_train_end()

Lifecycle ordering — on_train_end runs before the tail-save, but the in-loop checkpoint at last_ckpt_step was taken before the FAISS fit.

For SidRqkmeans, if the last in-loop save coincides with the final training step (last_ckpt_step == i_step), the tail-save block at lines 529-548 is skipped and the only persisted artifact has _is_initialized=False with zero centroids. Reloading that checkpoint silently yields a model that returns dummy-zero codes — KMeansLayer._load_from_state_dict's mid-fit guard does not fire because the flag is False (which reads as "uninitialized" rather than corrupted).

Either:

  • Always force a final save after on_train_end() (independent of last_ckpt_step), or
  • Have SidRqkmeans.on_train_end() overwrite the most recent checkpoint after the FAISS fit (only rank 0), or
  • Refuse state_dict() serialization on KMeansLayer until _is_initialized=True.
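
The first option can be checked with a toy simulation of the save schedule. Nothing here is tzrec code: `run`, `ckpt_every`, and `force_final_save` are hypothetical, and `fitted` stands in for `_is_initialized` after the FAISS fit.

```python
def run(total_steps, ckpt_every, force_final_save):
    """Return the (step, fitted) snapshots written during a fake run."""
    fitted = False
    saved = []
    last_ckpt_step = -1
    for step in range(1, total_steps + 1):
        if step % ckpt_every == 0:
            saved.append((step, fitted))  # in-loop checkpoint, pre-fit
            last_ckpt_step = step
    fitted = True  # on_train_end(): the offline fit happens here
    # The tail-save is skipped when the last in-loop save hit the final
    # step, unless a final save is forced.
    if force_final_save or last_ckpt_step != total_steps:
        saved.append((total_steps, fitted))
    return saved


skipped = run(total_steps=100, ckpt_every=50, force_final_save=False)
forced = run(total_steps=100, ckpt_every=50, force_final_save=True)
```

In the unforced run the newest snapshot is still unfitted; forcing the save after the hook makes the last persisted state the fitted one.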

@github-actions

Review summary

Big, well-scoped PR. Models and module stack are cleanly separated, tests cover the happy paths for both SidRqvae and SidRqkmeans, and the on_train_end() hook is a reasonable way to land an offline-fit model in the existing train loop. Below is a consolidated punch list — inline comments have the full detail on each item.

Correctness (should block merge)

  • Gumbel-Softmax does not propagate gradient into the encoder. _compute_distances is @torch.no_grad(), so the Gumbel branch only trains the codebook. STE works only by accident of its x + (q-x).detach() shortcut. See vector_quantize.py:250-254.
  • DDP kmeans_init averages permutation-misaligned centroids. Each rank seeds KMeans++ independently and the all_reduce(SUM)/world_size blends unrelated clusters. Should be rank-0 fit + broadcast, like SidRqkmeans.on_train_end. See residual_quantized.py:183-188.
  • on_train_end lifecycle vs. final-step checkpoint. If last_ckpt_step == i_step, the only persisted state is pre-FAISS-fit and the mid-fit corruption guard does not catch it. See main.py:514.

Performance / scalability

  • KMeans++ is O(K²·N). Should incrementally update min_dists. Combined with _kmeans having no convergence early-stop (always 100 iters), this is a visible first-step stall. kmeans.py:119-128, 156-169.
  • KMeansLayer.is_initialized .item() per forward. Per-batch GPU→host sync on the eval path; cache as a Python bool. kmeans.py:230-233.
  • dist.gather_object of the full per-rank buffer. Rank-0 OOM risk at realistic corpus sizes; pickle adds a 2× transient. sid_rqkmeans.py:282-300. Combine with the already-TODO'd per-step .cpu() D2H (sid_rqkmeans.py:137) — the FAISS path is the largest operational risk.
  • torch.unique(codes, dim=0).shape[0] host sync per eval batch. Pack codes to int64 first and keep arithmetic on-device. sid_rqvae.py:258, sid_rqkmeans.py:229-230.

Code quality / API consistency

  • Protos use comma-separated string for hidden_dims / codebook / latent_weight while the rest of tzrec uses repeated fields. Drop _sid_helpers.py and use repeated uint32 / repeated float. sid_model.proto:38-54.
  • latent_weight "default detection" comment contradicts behavior. Proto default is "1.0,0.5" (non-empty), so if cfg.latent_weight: is always True. Same shape applies to commitment_loss. sid_rqvae.py:67-71, rqvae.py:111-114.
  • RQVAE.forward(*args, **kwargs) dispatch. Hides the signature; both call sites already invoke forward_rqvae / forward_mixed directly — consider removing the shim.

Test gaps worth filling

  • No test exercises forward_mode="gumbel_softmax" — the buggy path above would never be caught by the current suite.
  • No test exercises rotation_trick=True (25-line Householder math), distance_type="cosine", or loss_type ∈ {"l1","cosine"}.
  • No multi-process test covers GatherLayer / _all_gather_with_grad / rank-offset labels in MaskedCLIPLoss — gloo @ world_size=2 on CPU would suffice.
  • _kmeans_plus_plus reproducibility: no generator is plumbed through, so seeding depends on the global RNG state at call time and codebooks are hard to reproduce across runs.

Documentation nits

  • RQVAE docstring drops 'l1' from commitment_loss options.
  • forward_mixed's "recon rows == fea1" annotation is not enforced and not used.
  • _predict_mixed docstring omits the _is_inference fast-path.

Nice touches worth calling out: KMeansLayer._load_from_state_dict mid-fit poisoning guard, int32 (not bool) NCCL collectives in on_train_end, empty-rank deadlock avoidance via all_reduce(MAX) on the empty-flag, and the rank-0 fit + broadcast pattern in SidRqkmeans — the same pattern that would fix the DDP kmeans-init issue above.


3 participants