[feat] SID generation — RQ-VAE and RQ-KMeans models (FAISS-trained)#517
[feat] SID generation — RQ-VAE and RQ-KMeans models (FAISS-trained)#517WhiteSwan1 wants to merge 41 commits into
Conversation
The online Mini-Batch K-Means backend is removed. ResidualKMeans /
RQKMeans / SidRqkmeans now always train the codebook once via FAISS
at end-of-loop through `flush_offline_fit`; `forward` is read-only
(assignment + centroid lookup).
- MiniBatchKMeans: stripped to a centroid container (load_centroids_
+ predict). train_step, cluster_counts, init_buffer, offline-lock,
unlock_for_online_finetune_ are gone.
- ResidualKMeans / RQKMeans: drop the train_mode and init_buffer_size
parameters; train_offline is the only training entry point.
- SidRqkmeans: predict always buffers during training and runs
assignment during eval; legacy `train_mode` proto field is silently
accepted (default "online") but a warning fires if it is explicitly
set to a non-FAISS value.
- Tests trimmed to FAISS-only paths; smoke test
ft_scripts/maxcompute_smoke_test.py runs RQKMeans + RQVAE over
forge_contrastive_item_embedding.csv. FAISS codes and RQ-VAE
training trajectory match the pre-refactor code_style baseline
bit-for-bit; full suite is 25/25 green.
Diff stat: -577 / +135 lines.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- SidRqkmeans proto: remove the legacy ``init_buffer_size`` (tag 5)
and ``train_mode`` (tag 6) fields. Their numbers and names are
``reserved`` so they can never be silently reused by future fields.
Regenerate sid_model_pb2.{py,pyi} with scripts/gen_proto.sh (or
``python -m grpc_tools.protoc -I . tzrec/protos/models/sid_model.proto
--python_out=. --pyi_out=.``).
- SidRqkmeans model: drop the deprecation guard now that the field is
gone; drop the corresponding legacy test.
- Rename ``MiniBatchKMeans`` → ``KMeansLayer``. After dropping the
online path the class is a per-layer centroid container (centroids
buffer + load_centroids_ + predict + an is_initialized guard), and
the old name was misleading. PyTorch state-dict keys are scoped by
attribute path (``layers.<i>.centroids``), so this rename does not
affect checkpoint compatibility.
Verified: 24/24 unit tests pass; the FAISS smoke test on
forge_contrastive_item_embedding.csv produces bit-identical codes and
RQ-VAE losses vs the pre-rename baseline.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…rgs to tag 5
The reservations were defensive against textproto / binary-proto mixups
on the dropped ``init_buffer_size`` (tag 5) and ``train_mode`` (tag 6).
tzrec configs are textproto-only, which key fields by name, not tag —
old configs already fail to parse as "unknown field" regardless of the
tag table — so the reservation block buys no real safety here. Drop it
and tighten the schema to a contiguous 1/3/4/5/40 tag set.
Regenerate sid_model_pb2.{py,pyi} via scripts/gen_proto.sh.
Verified: 24/24 unit tests pass; FAISS smoke test on
forge_contrastive_item_embedding.csv produces bit-identical codes and
RQ-VAE losses; both demo .config files parse cleanly.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Move ``_kmeans`` and ``_residual_kmeans`` (Lloyd's K-Means + residual
variant used by ResidualQuantized.init_embed_) from residual_quantized.py
into kmeans.py, next to ``_kmeans_plus_plus`` / ``_squared_euclidean_distance``
and the ``KMeansLayer`` centroid container. After this move:
- kmeans.py is the single home for torch-native K-Means code in the
SID-generation stack — distance helper, KMeans++ seeding, Lloyd's
iterator, residual variant, and the per-layer centroid container.
- residual_quantized.py is a quantizer module; it now imports the
single function it needs (``_residual_kmeans``) instead of pulling
in two private helpers and re-implementing Lloyd's on top.
- No functional change. RQ-VAE warm-start remains pure torch / GPU /
dependency-free — FAISS was not introduced here on purpose: the
init runs once on ~2k × 64 encoder outputs, where FAISS would add
a CPU round-trip + a hard dep for no measurable win.
Verified: 24/24 unit tests pass; FAISS smoke test on
forge_contrastive_item_embedding.csv produces bit-identical codes and
RQ-VAE losses.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add a no-op ``on_train_end`` method on ``BaseModel`` and override it in ``SidRqkmeans`` (renamed from ``flush_offline_fit``). Subclasses that need end-of-loop work override it; the rest get the no-op default. ``tzrec.main.train_and_evaluate`` now calls ``_model.on_train_end()`` unconditionally — no more ``hasattr`` duck-typing of a SID-specific method on every model. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
``reference_code`` was a teacher-codebook-injection hook ported from
upstream al_sid (al_sid/SID_generation/rqvae_embed/{rqvae,quantizations}.py).
It was a defensive pass-through in both projects: the parameter
threads ``RQVAE → ResidualQuantized → VectorQuantize`` and, when set,
overrides ~4 % of distance-assigned ids with the supplied teacher
codes (vector_quantize.py:407-411). Nothing in tzrec or al_sid ever
supplies a non-None value — no caller, no config knob, no test.
Drop:
- the ``reference_code`` kwarg from ``VectorQuantize.forward`` and
its Step 3 random-replacement block;
- the ``reference_code`` kwarg from ``ResidualQuantized.forward``
and the per-layer slicing that fed it into VectorQuantize;
- docstrings + step numbering updated accordingly.
Verified: 11/11 surviving SID model tests pass; FAISS smoke test on
forge_contrastive_item_embedding.csv produces bit-identical codes and
RQ-VAE losses.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Two-part change driven by an end-to-end lifecycle audit (see new
ft_scripts/lifecycle_smoke.py).
Comment pass:
- sid_rqkmeans.py / sid_rqvae.py / rqvae.py / residual_quantized.py /
residual_kmeans.py: drop comments that only narrate the next line
("# Eval metrics", "# Reconstruction MSE", "# Unique SID ratio",
"# Parse n_embed list", etc.). Keep every comment whose value is
*why* (gradient / dtype / memory rationale, DDP invariants, paper
references, step numbering tied to docstrings).
- rqvae.py: drop the duplicate ``self.use_clip = use_clip`` set inside
the ``if use_clip:`` block — the same assignment runs unconditionally
a few lines earlier.
SidRqkmeans bug fix:
- During training, predict() returns dummy zero codes (the codebook
does not exist until on_train_end() fits FAISS). The previous
init_metric / update_train_metric registered + updated mse,
rel_loss, unique_sid_ratio for the train path, which produced
``MeanMetric.compute() == nan`` for mse / rel_loss (never updated)
and a trivially constant ``1/B`` for unique_sid_ratio.
- Fix: drop the train-metric registrations and turn
update_train_metric into a documented no-op. compute_train_metric
now returns an empty dict, which ``tzrec.main`` already tolerates
via ``if train_metrics:``.
Lifecycle driver:
- ft_scripts/lifecycle_smoke.py exercises construct → init_loss →
init_metric → train (predict/loss/backward/update_train_metric) →
on_train_end → eval (predict/loss/update_metric/compute_metric) →
set_is_inference(True) → infer for SidRqkmeans, SidRqvae(no-clip),
SidRqvae(clip). Asserts shapes, codes ranges, and that no metric
comes back non-finite. All three pass.
Verified:
- 11/11 existing unit tests pass.
- lifecycle_smoke.py: ALL LIFECYCLES OK. RQ-VAE (no-clip) loss
1.63 -> 1.49 over 8 Adam steps; RQ-VAE (clip) 15.17 -> 9.33.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The scripts under ft_scripts/ are personal smoke / one-off experiment drivers that have always been treated as local-only by everyone except the three files that happened to get committed along the way (lifecycle_smoke.py, maxcompute_smoke_test.py, sid_rqvae_recon.config). Stop tracking those three and add ft_scripts/ to .gitignore so future edits stay local. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Translate the Chinese comments in EmaConfig, SinkhornConfig, ClipConfig,
and SidRqvae (network structure / quantization strategy / optional
sub-module sections) to English. No semantic changes; defaults, field
numbers, and field types are untouched.
Regenerate sid_model_pb2.{py,pyi} via scripts/gen_proto.sh (gitignored).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
ft_scripts/ configs use neither EMA (no ``ema_config`` block) nor the
``restart_unused_codes`` knob, and the fused ``ema_mask`` pass-through
chain has no consumer. Strip the whole feature surface end-to-end.
Proto (sid_model.proto, regen pb2 via grpc_tools.protoc):
- drop ``message EmaConfig`` and ``SidRqkmeans.ema_config = 14``.
Field 14 is left unallocated (no reserved block, consistent with
the previous proto tightening).
Code:
- SidRqvae: drop the EMA sub-config parse + the three kwargs
(``use_ema``, ``ema_decay``, ``restart_unused_codes``) it forwarded.
- RQVAE: drop the same three ctor args + docs + propagation.
- ResidualQuantized: drop the three ctor args, ``self.use_ema``,
the ``ema_mask`` parameter on ``forward``, and the
EMA-zeros-loss2 branch in ``_single_commitment_loss``. The
commitment loss now always sums both directions (mathematically
equivalent to running today with ``use_ema=False`` explicitly).
- VectorQuantize: drop ``use_ema`` / ``ema_decay`` /
``restart_unused_codes`` / ``ema_mask``, the
``cluster_size_ema`` / ``embed_ema`` buffers,
``_update_ema_buffers``, ``_update_embedding_from_ema``,
``_tile_with_noise``, and the dead-code-restart machinery.
Commitment loss collapses to the single
``commitment_weight * e_latent_loss + q_latent_loss`` formula.
Side effects worth flagging:
- ``ema_mask`` was the mechanism by which mixed-mode CLIP path-2
skipped EMA updates for recon rows. Without EMA those rows now
contribute to the commitment-loss gradient through both paths (the
recon-row contribution shows up in both quant1 and quant2 because
fea2 == fea1 there). The masked CLIP loss itself is unchanged.
- Old SidRqvae checkpoints carrying ``cluster_size_ema`` /
``embed_ema`` buffers will need ``strict=False`` to load.
Verified: 11/11 unit tests pass; lifecycle_smoke.py — all three
scenarios pass; RQ-VAE no-clip 1.44 -> 1.34, RQ-VAE clip 13.14 -> 11.20
over 8 Adam steps.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… loop index
- sid_rqvae_test.py: drop the ``losses =`` assignments in the
all-recon / all-clip CLIP edge-case tests (the predictions dict
carries the asserted values; loss() is kept as a bare call to
ensure it doesn't raise).
- clip_loss.py::GatherLayer.forward / backward: add one-line
docstrings.
- residual_quantized.py::ResidualQuantized.forward: drop the unused
``i`` enumerate index in the per-layer loop.
Verified: 11/11 SID unit tests pass; lifecycle_smoke.py green for all
three scenarios (RQKMeans, RQVAE no-clip, RQVAE CLIP).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
torch.unique(codes, dim=0).shape[0] forces a GPU->host sync every training step because reading .shape[0] of a variable-length unique output materializes it. unique_sid_ratio is a codebook-coverage diagnostic, not a training signal — keep it on the eval path only, where it runs at lower frequency and the cost is amortized at metric flush. Train path now logs only mse; eval path is unchanged. (SidRqkmeans's train-path update_train_metric is already a no-op, so this commit only touches SidRqvae.) Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…t is_clip_pair flag Forward used to route rows via ``torch.all(embedding == fea2, dim=-1)`` — a bit-exact equality check on float tensors that silently mislabels samples on any upstream float cast, normalization, or numerical noise. Introduce a required per-sample boolean column (``is_clip_pair``, value_dim=1) emitted by the FG layer. The model now reads the flag directly and routes via ``flag > 0.5``. No back-compat fallback: the column is required in clip mode; missing it raises cleanly. Proto: ``ClipConfig.is_clip_pair_feature_name`` (required string). Existing clip-mode tests are updated to populate the flag in the batch (no longer relying on bit-exact identity). One new test ``test_clip_mask_uses_flag_not_equality`` constructs a batch where ``image_emb == item_emb`` numerically but ``is_clip_pair=1`` — proves the row routes to CLIP under the new logic (it would have been silently relabeled recon under the old logic). The ft_scripts/ working-area configs are updated locally (untracked) to add the column and wire the field. Downstream pipelines must add a matching raw_feature mapping (e.g. ``expression: "item:is_contrastive"`` when the source dataset uses that column name). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Five new test docstrings tripped the D205 "blank line between summary and description" check; restructure each to have a one-line summary, blank line, then body. ruff-format also wanted two long register_buffer / _extract_feature calls collapsed to a single line (removed unnecessary wrapping). No behavior change. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
| # later resume-then-infer would silently return dummy zero codes | ||
| # instead of raising. Each load must re-run ``load_centroids_`` | ||
| # (or the FAISS fit) to set the flag. | ||
| self.register_buffer("_is_initialized", torch.tensor(False), persistent=False) |
There was a problem hiding this comment.
Critical — resume-then-infer silently returns zero codes.
_is_initialized is persistent=False, so on checkpoint resume the buffer is reset to False even though centroids was restored. ResidualKMeans.forward then takes the else branch and returns torch.zeros(...) for every layer (residual_kmeans.py:135-138) — no exception, no warning. The comment above accurately predicts this scenario ("would silently return dummy zero codes") but the code does not raise either; it just emits zeros.
There is no set_extra_state/on_load_state_dict hook that re-flips this flag after a load, so users running evaluate/predict against a saved RQKMeans checkpoint get garbage SIDs.
Suggested fix: derive initialization from centroids.any() at use sites, or raise in forward/get_codes when not is_initialized rather than silently returning zeros. Otherwise resume-then-eval is broken.
| optional bool normalize_residuals = 7 [default = false]; | ||
| // Distance metric: "l2" or "cosine". | ||
| optional string distance_type = 9 [default = "l2"]; | ||
| // Commitment loss type: "l2" or "cos". |
There was a problem hiding this comment.
ResidualQuantized.__init__ asserts commitment_loss in ("l2", "l1", "cos") (residual_quantized.py:90) and has a dedicated l1 branch in _single_commitment_loss. The proto comment, the RQVAE.__init__ docstring (rqvae.py:52-53), and the in-code "flow step 3" comment (residual_quantized.py:297) all advertise only "l2"/"cos". l1 is a fully reachable, supported value — please document it everywhere or remove it from the assert.
| // Commitment loss type: "l2" or "cos". | |
| // Commitment loss type: "l2", "l1", or "cos". |
| # Only forward latent_weight when proto sets it; otherwise let | ||
| # RQVAE / ResidualQuantized apply their signature default (1.0, 0.5). | ||
| rqvae_extra: Dict[str, Any] = {} | ||
| if cfg.latent_weight: | ||
| rqvae_extra["latent_weight"] = parse_float_list(cfg.latent_weight) |
There was a problem hiding this comment.
The proto declares optional string latent_weight = 11 [default = "1.0,0.5"] (sid_model.proto:54). In proto2 Python, accessing an unset field with a default returns the default literal — so cfg.latent_weight is always the truthy string "1.0,0.5" and the if cfg.latent_weight: guard is always True. The "let RQVAE / ResidualQuantized apply their signature default" branch is unreachable.
Not a behavioral bug (both paths yield (1.0, 0.5)), but the comment misrepresents control flow. If the intent really is "only override when user set it explicitly", drop the proto default and use cfg.HasField("latent_weight"). Otherwise simplify the comment and just always forward the value.
| def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: | ||
| """Dispatch based on use_clip. | ||
|
|
||
| use_clip=False: forward(x) -> forward_rqvae(x) | ||
| use_clip=True: forward(fea1, fea2, clip_mask) -> forward_mixed(...) | ||
| """ | ||
| if self._is_inference or not self.use_clip: | ||
| assert len(args) >= 1, "Standard mode requires (x,)" | ||
| return self.forward_rqvae(args[0], **kwargs) | ||
| else: | ||
| assert len(args) == 3, "Mixed mode requires (fea1, fea2, clip_mask)" | ||
| return self.forward_mixed(args[0], args[1], args[2], **kwargs) |
There was a problem hiding this comment.
self._is_inference is assigned False in __init__ (line 106) and never written anywhere else — there is no setter and RQVAE subclasses nn.Module directly, not BaseModule. The dispatch reduces to if not self.use_clip:. SidRqvae already calls forward_rqvae/forward_mixed directly and doesn't reach into RQVAE._is_inference, so it works "by accident".
Either drop the dead _is_inference flag and the misleading dispatch branch, or actually plumb BaseModule.set_is_inference through to the inner RQVAE so the documented behavior is real.
| # DDP path: every rank ships its local buffer to rank 0 via | ||
| # gather_object (variable-length pickle — fine for this one- | ||
| # shot, CPU-resident gather). Only rank 0 holds the corpus, | ||
| # so peak memory is O(world_size) on rank 0 and O(1) elsewhere | ||
| # (vs O(world_size²) for all_gather_object). | ||
| local = torch.cat(self._offline_buffer, dim=0) | ||
| del self._offline_buffer | ||
| self._offline_buffer = [] | ||
|
|
||
| rank = dist.get_rank() | ||
| gathered: Optional[List[Optional[torch.Tensor]]] = ( | ||
| [None] * dist.get_world_size() if rank == 0 else None | ||
| ) | ||
| dist.gather_object(local, gathered, dst=0) |
There was a problem hiding this comment.
Two concerns with the gather:
-
NCCL incompatibility.
dist.gather_objectrequires a backend that supports object collectives. Production GPU training typically initialises a single NCCL group andgather_objectraisesRuntimeError: ProcessGroupNCCL does not support gather_object(PyTorch documents this). Consider initialising a side gloo group at module construction and passing it asgroup=, or replace with a tensor-baseddist.gatherafter padding to a uniform per-rank shape. -
Memory claim is off. The comment says peak rank-0 memory is
O(world_size).gather_objectpickles every rank's tensor into a Pythonbytesthen deserialises intogathered, so rank 0 holds allworld_sizeper-rank tensors plus the pickle bytestream simultaneously — closer to ~3× total corpus at peak. For large datasets this is a real OOM risk. Either chunk the gather (every rank ships fixed-size chunks in a loop) or write embeddings to per-rank shard files duringpredictand have rank 0 stream them.
| def _coerce_proto_numbers(d: Dict) -> Dict: | ||
| """Coerce float-typed integers back to int. | ||
|
|
||
| ``google.protobuf.Struct.number_value`` is always float, but most | ||
| ``faiss.Kmeans`` kwargs (``niter``, ``seed``, ``nredo``, ...) require | ||
| Python ``int``. This helper converts any float that is an exact | ||
| integer to ``int`` for downstream consumption. | ||
| """ | ||
| out: Dict = {} | ||
| for k, v in d.items(): | ||
| if isinstance(v, float) and v.is_integer(): | ||
| out[k] = int(v) | ||
| else: | ||
| out[k] = v | ||
| return out |
There was a problem hiding this comment.
The output is forwarded straight through to faiss.Kmeans(D, K, **self.faiss_kmeans_kwargs) (residual_kmeans.py:244) with no whitelist or schema validation. A misspelled or unsupported kwarg (e.g. niters instead of niter, or gpu=True on a faiss-cpu build) raises inside faiss.Kmeans.__init__ only on rank 0, after gather_object has already pulled the entire corpus to rank 0. The other ranks then block on the subsequent broadcast/barrier until the NCCL watchdog tears the job down — no actionable error.
Two-step fix: (a) whitelist known FAISS keys (niter, nredo, seed, verbose, spherical, gpu, min/max_points_per_centroid, ...) and reject unknowns at construction time so the failure is symmetric across ranks; (b) construct the faiss.Kmeans object early (in __init__ or right before the gather) so config errors fire before the expensive cross-rank gather.
| # Global batch size for distributed training | ||
| if is_distributed and dist.is_initialized(): | ||
| B = Q.size(1) * dist.get_world_size() | ||
| else: | ||
| B = Q.size(1) |
There was a problem hiding this comment.
B = Q.size(1) * dist.get_world_size() assumes uniform per-rank batch sizes. On the final training step (or any step with drop_last=False and uneven shards) this is wrong: each rank computes a different denominator at Q /= B (line 100) and a different scale at Q *= B (line 103). argmax is invariant to a positive per-row scalar so the final codes stay consistent — but the resulting Q values do not match across ranks, which biases any downstream code that relies on the assignment magnitudes.
Suggested fix: all_reduce the local Q.size(1) once at the top to get the true global B (one extra collective is negligible vs the ones already inside the loop). Also, dist.is_initialized() alone isn't sufficient — guard on dist.get_world_size() > 1 to avoid no-op collectives in single-process toy runs.
| # Training: buffer for the end-of-loop FAISS fit and return dummy | ||
| # codes — the codebook does not exist yet. | ||
| if self.is_train: | ||
| self._offline_buffer.append(embedding.detach().cpu()) |
There was a problem hiding this comment.
Two issues that compound at scale:
-
Synchronous D2H every training step.
embedding.detach().cpu()is a sync copy on a non-pinned host buffer — it blocks the CUDA stream until the copy completes. Useembedding.detach().to('cpu', non_blocking=True)against a pinned host buffer if you want overlap. Per-step impact is small but adds up. -
Unbounded host-memory growth. The buffer grows linearly with
steps × batch_size × D × 4 bytesper rank, and rank 0 later holdsworld_size ×that during the gather (see the comment on the gather block). For a 100M-row dataset with D=512 you're looking at ~200 GB before the gather and ~world_size ×that on rank 0 — a near-certain OOM at production scale.
Consider streaming per-step embeddings to a pre-allocated np.memmap of known max size (or per-rank chunked .npy files) and mmap-reading them at FAISS fit time. For very large datasets, reservoir-sampling at the buffer step is usually sufficient for K-Means quality (~256 × K samples is plenty).
| # --- Safe labels: recon rows fallback to first clip column --- | ||
| labels = self.labels | ||
| fallback = clip_mask.long().argmax() # first clip sample index | ||
| safe_labels = torch.where(clip_mask, labels, fallback.expand_as(labels)) |
There was a problem hiding this comment.
fallback = clip_mask.long().argmax() is computed from the local mask but safe_labels indexes into the global gathered logits (offset local_batch_size * self._rank + arange). When this rank has no clip rows, argmax returns 0 → safe_labels points at local column 0 of the global logits, which may itself be a recon column (all -inf). cross_entropy then yields NaN, which the nan_to_num at line 134-135 swallows before the mask multiply.
The current code is safe only because nan_to_num runs before * clip_mask.float() — 0 × NaN = NaN, so the ordering is load-bearing. Any later "optimization" that swaps the order, or that removes the seemingly-defensive nan_to_num, silently poisons the loss.
Suggested fix: compute fallback from the gathered mask (clip_mask_all.long().argmax()) so safe_labels always points at a real clip column globally; then the nan_to_num becomes unnecessary and the contract is explicit. Also consider reading dist.get_rank() lazily in forward instead of caching at __init__ (line 104) — modules can be constructed before init_process_group.
Review summary — SID generation (RQ-VAE / RQ-KMeans)Substantial, well-organized PR. The DDP Cross-cutting issues to consider
Doc/comment drift
Nice work overall. The critical items are the checkpoint-resume gap, the |
…-time guard Commit 85932d8 made _is_initialized non-persistent to protect against checkpoints taken mid-FAISS-fit. That fix had a regression: a legitimate POST-fit checkpoint also loses the flag (centroids restore, but the flag defaults to False on load), so predict() then silently returns zero codes. The two states became indistinguishable to the loader. Make _is_initialized persistent again so normal post-fit save/load round-trips work. Catch the mid-fit-checkpoint case at load time via _load_from_state_dict: if the flag loaded as True but centroids sum to zero, the framework appends a clear error message (raised under strict=True load_state_dict). Two regression tests: - test_post_fit_checkpoint_round_trips: fit -> save -> load fresh -> predict produces non-zero codes. - test_mid_fit_checkpoint_rejected_on_load: tampered state with flag=True and zero centroids raises with the "mid-FAISS-fit" hint. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…oto doc - sid_model.proto: commitment_loss comment now lists "l1" (Commit 2 added the branch but missed the proto-side doc). - sid_rqkmeans.py: add TODO at the .cpu() per-step site documenting the D2H stall + unbounded buffer concern (Commit 4 was deliberately skipped; this records why so future readers don't repeat the audit). - sid_rqkmeans.py: collapse the 10-line DDP collective-check rationale to 3 lines. - sid_rqvae.py: drop redundant comments on _is_clip_pair_feature_name, Sinkhorn parse, and the clip_mask extraction — the proto/field name + code already convey the intent. - residual_quantized.py: drop the "mirrors l2 branch" comment on the L1 case; the branch reads obviously next to its siblings. No behavior change. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
| if self.training and self.forward_mode == QuantizeForwardMode.GUMBEL_SOFTMAX: | ||
| weights = _gumbel_softmax_sample( | ||
| -distances, temperature=temperature, hard=True | ||
| ) | ||
| emb = weights @ self.embedding.weight |
There was a problem hiding this comment.
Correctness — Gumbel-Softmax does not propagate gradient into the encoder.
_compute_distances (and _find_nearest_embedding) are decorated with @torch.no_grad(), so distances returned here is fully detached. In the Gumbel branch:
weights = _gumbel_softmax_sample(-distances, temperature=temperature, hard=True)
emb = weights @ self.embedding.weightweights is a function only of the detached distances, so gradient flows into self.embedding.weight but not back through x into the encoder. The STE branch happens to work because x + (q-x).detach() explicitly routes the gradient — Gumbel has no such shortcut.
Effect: choosing forward_mode="gumbel_softmax" silently trains only the codebook, not the encoder.
Fix: remove @torch.no_grad() from _compute_distances (the argmin/argmax step is naturally non-differentiable), or split out a _compute_distances_no_grad/_compute_distances_with_grad pair and use the differentiable one for the Gumbel logits.
| # Average per-layer centroids across DDP ranks so every rank | ||
| # starts from the same codebook. | ||
| if dist.is_initialized() and dist.get_world_size() > 1: | ||
| for c in centers: | ||
| dist.all_reduce(c, op=dist.ReduceOp.SUM) | ||
| c /= dist.get_world_size() |
There was a problem hiding this comment.
Correctness — averaging KMeans centroids across DDP ranks produces meaningless init.
_kmeans_plus_plus (kmeans.py:116–128) uses torch.randint/torch.multinomial with the per-rank default RNG, so each rank picks different seeds from its own local batch and runs Lloyd to different local minima. Cluster indices are unordered, so "centroid 0" on rank A and "centroid 0" on rank B correspond to different concepts. all_reduce(SUM)/world_size then averages permutation-misaligned centroids — the result is closer to noise than to a meaningful warm start.
The comment "every rank starts from the same codebook" is true but disguises that the codebook is junk.
Fix: run KMeans on rank 0 only, then dist.broadcast(c, src=0) — mirrors what SidRqkmeans.on_train_end already does for FAISS. Bonus: world_size× less compute.
| for i in range(1, n_clusters): | ||
| dists = _squared_euclidean_distance(data, centroids[:i]) # (N, i) | ||
| min_dists = dists.min(dim=1)[0] # (N,) | ||
| if min_dists.sum() == 0: | ||
| centroids[i:] = data[ | ||
| torch.randint(0, N, (n_clusters - i,), device=data.device) | ||
| ] | ||
| break | ||
| next_idx = torch.multinomial(min_dists, num_samples=1) | ||
| centroids[i] = data[next_idx] |
There was a problem hiding this comment.
Perf — KMeans++ is O(K²·N), should be O(K·N).
Each loop iteration calls _squared_euclidean_distance(data, centroids[:i]), recomputing distances to all previously selected centroids. Standard KMeans++ tracks a running min_dists and only computes the distance to the newly selected centroid:
min_dists = _squared_euclidean_distance(data, centroids[:1]).squeeze(1)
for i in range(1, n_clusters):
if min_dists.sum() == 0:
centroids[i:] = data[torch.randint(0, N, (n_clusters - i,), device=data.device)]
break
next_idx = torch.multinomial(min_dists, num_samples=1)
centroids[i] = data[next_idx]
new_dists = _squared_euclidean_distance(data, centroids[i:i+1]).squeeze(1)
min_dists = torch.minimum(min_dists, new_dists)For K=256, N=batch_size this is ~256× more work than necessary on the first training step. Combined with _kmeans having no convergence early-stop (always 100 iters), the first-step startup stall under kmeans_init=True is substantial.
Also: torch.randint/torch.multinomial here use the global default generator with no seed — KMeans++ is non-deterministic across runs even with torch.manual_seed set elsewhere. Consider accepting/threading a torch.Generator.
| @property | ||
| def is_initialized(self) -> bool: | ||
| """Whether centroids have been injected via ``load_centroids_``.""" | ||
| return self._is_initialized.item() |
There was a problem hiding this comment.
Perf — .item() on the hot path forces a GPU→host sync every forward.
is_initialized is read inside ResidualKMeans.forward (residual_kmeans.py:129) and all_initialized (residual_kmeans.py:101) on every eval batch, so each call serializes the CUDA stream. The flag is flipped exactly once (in load_centroids_ from on_train_end) — there's no reason to re-read the device tensor every step.
Fix: cache as a Python bool attribute, set in load_centroids_ and re-derived from the buffer in _load_from_state_dict. Keep _is_initialized as a buffer for checkpoint round-trip, but stop reading it on the forward path. For N_layers=3 this turns eval into 3 syncs per batch.
| local = torch.cat(self._offline_buffer, dim=0) | ||
| del self._offline_buffer | ||
| self._offline_buffer = [] | ||
|
|
||
| rank = dist.get_rank() | ||
| gathered: Optional[List[Optional[torch.Tensor]]] = ( | ||
| [None] * dist.get_world_size() if rank == 0 else None | ||
| ) | ||
| dist.gather_object(local, gathered, dst=0) | ||
| del local | ||
| if rank == 0: | ||
| assert gathered is not None | ||
| full = torch.cat([g for g in gathered if g is not None], dim=0) | ||
| del gathered | ||
| logger.info( | ||
| "[SidRqkmeans.on_train_end] rank0 fitting FAISS " | ||
| "on %d samples (D=%d)." % (full.shape[0], full.shape[1]) | ||
| ) | ||
| self._rqkmeans.train_offline(full, verbose=True) |
There was a problem hiding this comment.
Scalability — dist.gather_object of the full per-rank buffer will OOM rank 0 on realistic corpora.
Each rank concats its host buffer (line 282), then ships via gather_object (which pickle-serializes the tensor, ~2× transient bytes), then rank 0 concats again. Peak rank-0 memory is roughly 2 × world_size × per_rank_corpus × D × 4B. For a 10M-row × 64-D corpus across 8 ranks that's ~20 GB on rank 0 alone — taking down DDP training right at the end, after hours of compute.
Combined with the per-step .cpu() D2H (line 137, already TODO'd), the FAISS-only path is the largest operational risk in the PR.
Fix options:
- Write per-rank buffers to a shared filesystem (
torch.save(buffer, f"{model_dir}/sid_embed_rank{rank}.pt")),dist.barrier(), then have rank 0 stream them into a pre-allocatednp.empty((total_N, D))(size known via anall_gatherof row counts first). - Reservoir-sample the per-rank buffer down to a configured cap (FAISS K-Means rarely benefits from > ~256 × K samples).
At minimum, expose a sample cap and add a guard that errors loudly when the gather would exceed available host memory.
| # Only forward latent_weight when proto sets it; otherwise let | ||
| # RQVAE / ResidualQuantized apply their signature default (1.0, 0.5). | ||
| rqvae_extra: Dict[str, Any] = {} | ||
| if cfg.latent_weight: | ||
| rqvae_extra["latent_weight"] = parse_float_list(cfg.latent_weight) |
There was a problem hiding this comment.
Comment contradicts behavior — if cfg.latent_weight: is always True.
The proto field has optional string latent_weight = 11 [default = "1.0,0.5"];, so cfg.latent_weight returns "1.0,0.5" (truthy) when unset and the user-provided string otherwise. The "let signature default apply" branch is unreachable; latent_weight is always forwarded.
If the intent is to detect "user explicitly set vs proto default", use cfg.HasField("latent_weight") — but that only works if you drop the proto-level default. Otherwise just always forward and delete the conditional + comment.
The same shape applies to commitment_loss (cfg defaults to "l2" per sid_model.proto:52) — the RQVAE.commitment_loss=None defaulting in rqvae.py:111 is dead from the proto-driven path.
| optional string hidden_dims = 3; | ||
| // Per-layer codebook size, comma-separated, e.g. "256,256,256". | ||
| // List length is the number of residual quantization layers; | ||
| // non-uniform codebooks such as "512,256,128" are supported. | ||
| optional string codebook = 5; | ||
|
|
||
| // === Quantization strategy === | ||
| // VQ forward mode: "ste" or "gumbel_softmax". | ||
| optional string forward_mode = 6 [default = "ste"]; | ||
| // L2-normalize residuals before each quantization layer. | ||
| optional bool normalize_residuals = 7 [default = false]; | ||
| // Distance metric: "l2" or "cosine". | ||
| optional string distance_type = 9 [default = "l2"]; | ||
| // Commitment loss type: "l2", "l1" or "cos". | ||
| optional string commitment_loss = 10 [default = "l2"]; | ||
| // Commitment loss weights [w1, w2], comma-separated. | ||
| optional string latent_weight = 11 [default = "1.0,0.5"]; |
There was a problem hiding this comment.
Stringly-typed list fields are inconsistent with the rest of tzrec.
hidden_dims, codebook, and latent_weight are optional string parsed at runtime via _sid_helpers.parse_int_list / parse_float_list. Everywhere else in tzrec these are repeated uint32 / repeated float (see optimizer.proto:226-228, tower.proto:158, most of feature.proto).
Cost of the string encoding:
- Defers all validation to model construction (
int(x.strip())raisesValueErroron a typo instead of failing at proto load). - Disables proto-level type checking, IDE completion, and
text_formatvalidation. - Requires the ad-hoc
_sid_helpers.pyshim. - The empty-string-truthy bug noted on
sid_rqvae.py:67-71would not exist with arepeatedfield (callers would always pass the list as-is).
Suggest repeated uint32 hidden_dims;, repeated uint32 codebook;, repeated float latent_weight;, then delete _sid_helpers.py.
Also: the comment "non-uniform codebooks such as 512,256,128 are supported" disagrees with ResidualKMeans.__init__ (residual_kmeans.py:82-85) which asserts uniformity for SidRqkmeans. Clarify the per-model rule.
| mse = F.mse_loss(predictions["x_hat"], embedding, reduction="mean") | ||
| self._metric_modules["mse"].update(mse) | ||
|
|
||
| unique_sids = torch.unique(codes, dim=0).shape[0] |
There was a problem hiding this comment.
Perf — torch.unique(..., dim=0).shape[0] plus int / int triggers a GPU→host sync per eval batch.
torch.unique with dim=0 does a CUDA sort + dedupe; shape[0] then forces a host sync, and unique_sids / B (Python int / int) produces a Python float, syncing again. The same pattern is in sid_rqkmeans.py:229-230.
Fix:
# pack codes per-row into a single int64 — unique is much faster on 1-D
mults = torch.tensor([K ** i for i in range(codes.shape[1])], device=codes.device)
packed = (codes.long() * mults).sum(dim=1)
unique_sids = torch.unique(packed).numel() # tensor-level until .update
self._metric_modules["unique_sid_ratio"].update(
torch.tensor(unique_sids, device=codes.device, dtype=torch.float) / B
)Or buffer packed codes across the eval epoch and run unique once at epoch end.
| normalize_residuals (bool): L2-normalize residuals. Default: False. | ||
| distance_type (str|List[str]): distance metric ('l2'|'cosine'). | ||
| Default: 'l2'. | ||
| commitment_loss (str|None): commitment loss type ('l2'|'cos'). |
There was a problem hiding this comment.
Docstring drops 'l1'. ResidualQuantized asserts commitment_loss in ("l2", "l1", "cos") (residual_quantized.py:90), and sid_model.proto:51 also documents the three options.
| commitment_loss (str|None): commitment loss type ('l2'|'cos'). | |
| commitment_loss (str|None): commitment loss type ('l2'|'l1'|'cos'). |
| if lr.by_epoch: | ||
| lr.step() | ||
|
|
||
| _model.on_train_end() |
There was a problem hiding this comment.
Lifecycle ordering — on_train_end runs before the tail-save, but the in-loop checkpoint at last_ckpt_step was taken before the FAISS fit.
For SidRqkmeans, if the last in-loop save coincides with the final training step (last_ckpt_step == i_step), the tail-save block at lines 529-548 is skipped and the only persisted artifact has _is_initialized=False with zero centroids. Reloading that checkpoint silently yields a model that returns dummy-zero codes — KMeansLayer._load_from_state_dict's mid-fit guard does not fire because the flag is False (which reads as "uninitialized" rather than corrupted).
Either:
- Always force a final save after
on_train_end()(independent oflast_ckpt_step), or - Have
SidRqkmeans.on_train_end()overwrite the most recent checkpoint after the FAISS fit (only rank 0), or - Refuse
state_dict()serialization onKMeansLayeruntil_is_initialized=True.
Review summaryBig, well-scoped PR. Models and module stack are cleanly separated, tests cover the happy paths for both Correctness (should block merge)
Performance / scalability
Code quality / API consistency
Test gaps worth filling
Documentation nits
Nice touches worth calling out: |
Summary
Introduces two new SID-generation models to
tzrecplus the supportingtzrec.modules.sid_generationmodule stack:SidRqvae— end-to-end RQ-VAE (Encoder + Residual VQ + Decoder), trainable via STE or Gumbel-Softmax, with optional CLIP contrastive learning for paired multimodal inputs.SidRqkmeans— encoder-free residual K-Means, trained in one shot by FAISS at the end of the train_eval loop ( gradient-free).Both produce N-layer semantic IDs
(code_0, …, code_{n_layers-1})that downstream generative-retrieval models can consume.What's in scope
Models (
tzrec/models/):sid_rqvae.py+sid_rqvae_test.py—SidRqvae(BaseModel), no-clip + mixed clip+recon paths.sid_rqkmeans.py+sid_rqkmeans_test.py—SidRqkmeans(BaseModel), FAISS-only training._sid_helpers.py— sharedparse_int_list/parse_float_listfor SID proto fields.model.py— addsBaseModel.on_train_end()no-op lifecycle hook (overridden bySidRqkmeansfor the FAISS fit).Modules (
tzrec/modules/sid_generation/):rqvae.py—RQVAE(encoder/decoder MLP + quantizer), supportsuse_clipmixed-mode.residual_quantized.py— gradient-trained multi-layer residual VQ used by RQ-VAE.residual_kmeans.py—ResidualKMeans+RQKMeanswrapper (FAISS-trained centroid store).kmeans.py—KMeansLayercentroid container + KMeans / KMeans++ utilities.vector_quantize.py— single VQ layer with Sinkhorn uniform assignment, STE / Gumbel-Softmax forward.clip_loss.py—CLIPLossandMaskedCLIPLoss(gradient-preservingGatherLayerfor DDP all-gather, three-way contrastive loss).types.py—QuantizeForwardMode,QuantizeOutput,ResidualQuantizedOutput.Protos (
tzrec/protos/models/sid_model.proto+tzrec/protos/model.proto):SidRqvaewith sub-messagesSinkhornConfigandClipConfig.SidRqkmeans(codebook + normalize_residuals +faiss_kmeans_kwargsStruct).ModelConfig.Wiring (
tzrec/main.py):_model.on_train_end()is called once aftertrain_and_evaluateexits the loop. No-op for every model exceptSidRqkmeans, which uses the hook to fit FAISS over the buffered embeddings.Test plan
python -m unittest tzrec.models.sid_rqkmeans_test tzrec.models.sid_rqvae_test→ 11/11 pass (5 RQKMeans, 6 RQVAE; covers train / eval / inference modes, CLIP all-recon and all-clip edge cases, backward pass).python ft_scripts/lifecycle_smoke.pydrives construct → init_loss → init_metric → train (predict/loss/backward/update_train_metric) →on_train_end→ eval (predict/loss/update_metric/compute_metric) →set_is_inference(True)→ infer, forSidRqkmeans,SidRqvaeno-clip, andSidRqvaeCLIP. All three pass.ModelConfig'soneof.