Add TCM (CVPR 2023) + CCA (NeurIPS 2024) with Family-1 codec refactor #355
Open
Yiozolm wants to merge 8 commits into
feat(latent_codecs): add containerized infrastructure for Family 1 codecs
Lay the groundwork for refactoring channel-slice models (STF, WACNN, TCM,
CCA) toward an ELIC-style containerized layout where the model owns only
g_a/g_s and the latent_codec owns the entire entropy stack (h_a, h_s, z
bottleneck, per-slice channel context, per-slice leaves).
Codec primitives (compressai/latent_codecs):
- DualHyperSynthesis(h_mean_s, h_scale_s) adapter so HyperpriorLatentCodec
can wrap two parallel hyper-synthesis heads while keeping their
state_dict paths split.
- LRPGaussianLatentCodec(GaussianConditionalLatentCodec): subclass adding
the lrp_scale * tanh(lrp_transform(cat(ctx_params, y_hat))) refinement
used by Zhu2022 and follow-ups; lives in the same file as its base.
- ChannelGroupsLatentCodec gains optional max_support_slices and
support_filter parameters (defaults -1 / None preserve ELIC's
use-all-prior behaviour). Enables STF/TCM-style support clamping and
CCA-aux skip-most-recent selection without sibling codec classes.
- _slice_helpers hosts make_entropy_transform, slice_support_channels,
lrp_support_channels, infer_num_slices, infer_max_support_slices, with
default state_dict prefixes pointing at the post-refactor layout
(latent_codec.latent_codec.y.channel_context.yK.mean_cc.*).
Application-layer helpers (compressai/models/_helpers):
- build_channel_slice_codec: per-slice factory that hides the
y0..yK-1 dictionary boilerplate when assembling
ChannelGroupsLatentCodec.
- MeanScaleContextHead + build_mean_scale_head: independent mean_cc and
scale_cc Sequentials with optional per-path support transforms (for
SWAtten / NAFTransform).
No model wiring is changed in this commit; STF/WACNN continue to use the
existing _bases/slice_entropy + ChannelSliceLatentCodec path. Coverage:
28 new unit tests across tests/test_latent_codecs.py and
tests/test_models_helpers.py; tests/test_models.py::TestStf still passes;
importing compressai / compressai.zoo / compressai.latent_codecs introduces
no new dependencies.
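For orientation, a minimal standalone sketch of the LRP refinement named above; the transform widths and the 0.5 scale are illustrative assumptions, not values fixed by this commit:

```python
import torch
import torch.nn as nn

# Placeholder LRP transform: maps (support + slice) channels back to the slice width.
# The real per-slice transforms come from the checkpoint; widths here are illustrative.
support_ch, slice_ch = 64, 32
lrp_transform = nn.Sequential(
    nn.Conv2d(support_ch + slice_ch, 48, 3, padding=1),
    nn.GELU(),
    nn.Conv2d(48, slice_ch, 3, padding=1),
)

def apply_lrp(y_hat, ctx_params, lrp_scale=0.5):
    """y_hat += lrp_scale * tanh(lrp_transform(cat(ctx_params, y_hat))) as described above."""
    lrp = lrp_scale * torch.tanh(lrp_transform(torch.cat([ctx_params, y_hat], dim=1)))
    return y_hat + lrp

y_hat = torch.randn(1, slice_ch, 16, 16)
ctx_params = torch.randn(1, support_ch, 16, 16)
print(apply_lrp(y_hat, ctx_params).shape)  # torch.Size([1, 32, 16, 16])
```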
refactor(models/stf): migrate WACNN + SymmetricalTransFormer to containerized codec

Replace the SliceEntropyCompressionModel base + monolithic ChannelSliceLatentCodec wiring with an ELIC-style HyperpriorLatentCodec that owns h_a, h_s, the z entropy bottleneck, and the per-slice channel context. Both WACNN and SymmetricalTransFormer now inherit CompressionModel directly with 5-line forward / compress / decompress methods that delegate to self.latent_codec.

Supporting infra extensions (built on the prior containerized scaffolding commit 51f536f):
- ChannelGroupsLatentCodec gains side_in_context: bool. Off (default) is the existing ELIC behaviour. On, _get_ctx_params: (a) routes side_params through channel_context.y0 instead of returning it raw, (b) for k>=1 feeds cat(side_params, prev_y_hat) to channel_context.y_k, (c) skips the trailing cat with side_params at the leaf level. Family 1 models opt in.
- MeanScaleContextHead gains side_split (split the leading 2*side_split channels into latent_means / latent_scales and route them to mean_cc / scale_cc separately) and emit_mean_support (append cat(latent_means, prev_y_hat) to the output so downstream LRP can recover the upstream input layout).
- LRPGaussianLatentCodec gains mean_support_trail_channels: when > 0, the leaf splits ctx_params into [gaussian_params, mean_support] and feeds the trailing block to the LRP transform. This restores the upstream cat(latent_means, *prev_y_hat, y_hat) LRP input shape, so the per-slice lrp_transforms.{k} weights from upstream Zou et al. checkpoints transfer byte-for-byte.
- build_channel_slice_codec accepts side_in_context + side_channels and builds the y0 channel_context entry on opt-in.
- _slice_helpers.infer_num_slices auto-detects whether y0 is present in state_dict and adjusts the count accordingly.

Upstream checkpoint converter (convert_upstream_stf_state_dict):
- Strips the DataParallel module. prefix.
- Re-roots cc_mean_transforms / cc_scale_transforms / lrp_transforms / gaussian_conditional / entropy_bottleneck / h_a / h_mean_s / h_scale_s under their new latent_codec.* paths.
- Replicates the single shared gaussian_conditional buffer set into per-slice leaves (driven by the discovered slice count).
- Nests upstream conv_b.<i>.attn.{qkv, proj, relative_position_*} keys under the WMSA wrapper level (.attn.attn.<x>) via _nest_winmsa_keys, so WindowAttention parameters land on the right submodule.

Verified end-to-end with the upstream Zou et al. checkpoints candidate/cnn_0018_best.pth.tar (WACNN) and candidate/stf_0018_best.pth.tar (SymmetricalTransFormer): both WACNN.from_state_dict and SymmetricalTransFormer.from_state_dict succeed under strict loading and forward pass. State-dict round-trip is exercised by an updated tests/test_models.py::TestStf with self-checks on the new key paths.
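A behaviour-level sketch of the side_in_context routing in (a)-(c) above, assuming the y0..yK-1 channel_context naming used elsewhere in this PR; this mirrors the description, not the actual _get_ctx_params implementation:

```python
import torch

def ctx_params_side_in_context(k, side_params, prev_y_hats, channel_context):
    """Behavioural sketch only; not the real ChannelGroupsLatentCodec code."""
    if k == 0:
        # (a) side_params is routed through channel_context.y0 instead of being returned raw
        return channel_context["y0"](side_params)
    # (b) for k >= 1, the head sees cat(side_params, prev_y_hat...)
    support = torch.cat([side_params, *prev_y_hats], dim=1)
    # (c) no trailing cat with side_params at the leaf level
    return channel_context[f"y{k}"](support)
```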
Migrate TCM (Liu et al., CVPR 2023) to the H+G containerized entropy stack used by STF / WACNN. The hyperprior backbone (h_a, h_mean_s, h_scale_s, EntropyBottleneck) plus the per-slice channel-conditional entropy heads now live under a single HyperpriorLatentCodec — matching ELIC-pattern wiring with three Family 1 specializations:
- DualHyperSynthesis(h_mean_s, h_scale_s) cats the two parallel hyper-synthesis outputs into side_params of width 2*M.
- ChannelGroupsLatentCodec(side_in_context=True) routes side_params into every channel_context head (incl. y0).
- LRPGaussianLatentCodec(mean_support_trail_channels=...) plus MeanScaleContextHead(emit_mean_support=True) recover the upstream cat(latent_means, *prev_y_hat, y_hat) LRP layout for byte-for-byte weight transfer from upstream LIC_TCM checkpoints.

TCM-specific kwargs vs STF/WACNN: 3-conv widths=(224, 128) plus support_transform_factory=SWAtten (independent windowed-attention per mean / scale path), mirroring upstream atten_mean[k] / atten_scale[k].

use_cca / use_auxt are intentionally not implemented in this PR:
- use_cca will be re-added once Phase 5 lands the containerized _CCAAuxEntropyModel.
- use_auxt (AuxT, ICLR 2025) depends on layers (WLS / iWLS / OLP) not in this branch and is out of scope.

Verified on the LIC_TCM 0.05.pth.tar (N=64) and mse_lambda_0.05.pth.tar (N=128) candidate checkpoints: strict load succeeds, and a sinusoidal smoke test reaches PSNR 39.15 dB / 39.41 dB respectively (vs 5.41 dB for a fresh-init model), confirming all weights — including LRP — transfer byte-for-byte through convert_upstream_tcm_state_dict.

Also: append a Family 1 wiring comment block to compressai/latent_codecs/__init__.py so reviewers can see how the ELIC-style upstream codecs compose into the STF / WACNN / TCM / CCA / DCAE / MambaVC pattern without reading model source.

Tests:
- tests/test_models.py::TestTcm::test_tcm_forward_and_state_dict_round_trip — forward + new state_dict path self-check + round-trip allclose.
- tests/test_models.py::TestTcm::test_tcm_upstream_state_dict_conversion — synthetic upstream LIC_TCM-style state_dict, asserts MSA buffer reshape + per-slice / SWAtten-wrapper / hyperprior re-rooting.

make static-analysis clean. pytest tests/test_models.py tests/test_latent_codecs.py tests/test_models_helpers.py tests/test_layers.py tests/test_init.py: 71/71. import compressai / compressai.zoo / compressai.latent_codecs still triggers zero timm imports (TCM follows the STF lazy-load convention — not re-exported from compressai.models.__init__).
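For reference, a behaviour-level sketch of the DualHyperSynthesis adapter described above; the real class lives in compressai/latent_codecs/_hyper_synthesis.py, and the mean-then-scale concatenation order here is an assumption consistent with how side_split is described:

```python
import torch
import torch.nn as nn

class DualHyperSynthesisSketch(nn.Module):
    """Wrap two parallel hyper-synthesis heads so HyperpriorLatentCodec sees a single h_s."""

    def __init__(self, h_mean_s: nn.Module, h_scale_s: nn.Module):
        super().__init__()
        # Registered as separate submodules so their state_dict paths stay split.
        self.h_mean_s = h_mean_s
        self.h_scale_s = h_scale_s

    def forward(self, z_hat: torch.Tensor) -> torch.Tensor:
        # side_params of width 2*M: latent means first, latent scales second (assumed order).
        return torch.cat([self.h_mean_s(z_hat), self.h_scale_s(z_hat)], dim=1)
```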
Add the Causal Context Adjustment (CCA) standalone autoencoder from Han et al., NeurIPS 2024 (https://arxiv.org/abs/2410.04847) using the H+G containerized entropy stack already adopted by STF / WACNN / TCM. The hyperprior backbone (h_a, h_mean_s, h_scale_s, EntropyBottleneck) plus the per-slice channel-conditional heads live under a single HyperpriorLatentCodec — same Family 1 pattern, with two CCA-specific specializations:
- Variable-length channel slices (slice_proportions defaults to the upstream M=320 layout (8, 28, 56, 92, 136)) instead of the equal slices used by STF / WACNN / TCM.
- Per-slice NAFTransform mean / scale support transforms (analogous to TCM's SWAtten), wired via build_mean_scale_head with support_transform_factory.

Auxiliary CCA branch (cca_training=True) adds a _CCAAuxEntropyModel field that re-encodes y with skip-most-recent support selection (support_filter=lambda k, prior: prior[: max(k - 1, 0)]) and produces y_aux / y_cca likelihoods consumed by CCARateDistortionLoss. All slices use LRPGaussianLatentCodec — the upstream published checkpoints carry LRP weights for every slice; the unused last-two-slice LRPs are benign because support_filter excludes those slices' y_hat from any later slice's prior.

Three small infra additions to support the wiring above:
- MeanScaleContextHead.emit_mean_support extended to bool|"pre"|"post". CCA's upstream LRP heads consume the *post*-NAFTransform mean_support, while STF / TCM use the raw pre layout (Identity transform makes pre/post equivalent for them, so True still maps to "pre" for back-compat).
- build_channel_slice_codec.support_count_fn lets the caller declare the prior-slice count seen by each channel_context head when a custom support_filter selects a non-default count (CCA-aux's skip-most-recent: lambda k: max(k - 1, 0)).
- ChannelGroupsLatentCodec._get_ctx_params_side_in_context falls back to side_params alone when support_filter returns an empty list (CCA-aux at k=1), avoiding torch.cat() on an empty list.

CCAModel.forward in cca_training=True replays the hyperprior path to recover latent_means / latent_scales for the aux branch instead of piercing the HyperpriorLatentCodec abstraction — small cost, zero interface churn for the main codec.

Verified on candidate/CCA/checkpoint_lambda_0.3.pth.tar (M=320, slice_sizes=[8, 28, 56, 92, 136], em_hidden=224, em_layers=4, cca_training=True; 97M params): strict-loads after convert_upstream_cca_state_dict, sinusoidal smoke yields PSNR 50.07 dB / total bpp 0.072 (vs ~5 dB for a fresh-init model), confirming all weights — including LRP and aux — transfer byte-for-byte. Upstream → compressai key count delta of +56 is explained by replicating the single shared gaussian_conditional buffer set across each per-slice leaf (7 buffers × 4 extra copies per branch × 2 branches = 56), matching the channel_context.y{k} layout.

Tests:
- tests/test_models.py::TestCca::test_cca_forward_and_state_dict_round_trip
- tests/test_models.py::TestCca::test_cca_training_branch_forward_and_round_trip
- tests/test_models.py::TestCca::test_cca_upstream_state_dict_conversion

Also drop a few stray internal phase-tracking labels from existing docstrings (stf.py / tcm.py / channel_context.py / test_latent_codecs.py / test_models.py) so the comments describe the wiring directly rather than referencing project-internal sequencing.
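To make the aux branch's support selection concrete, here are the two callables quoted above with a toy check (slice names are placeholders):

```python
# support_filter / support_count_fn pair used by the CCA-aux branch (quoted above).
skip_most_recent = lambda k, prior: prior[: max(k - 1, 0)]
support_count = lambda k: max(k - 1, 0)

prior = ["y0_hat", "y1_hat", "y2_hat"]        # slices already decoded before slice k=3
assert skip_most_recent(3, prior) == ["y0_hat", "y1_hat"]   # drops the most recent slice
assert skip_most_recent(1, prior[:0]) == []                  # k=1: empty, falls back to side_params alone
assert support_count(3) == len(skip_most_recent(3, prior))   # head input widths sized from this count
```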
chore(latent_codecs,models): drop ChannelSliceLatentCodec and SliceEntropyCompressionModel

Family 1 models (STF, WACNN, TCM, CCA) have all migrated to ChannelGroupsLatentCodec + _slice_helpers, leaving these two scaffolding modules with no callers in production code or tests. Delete:
- compressai/latent_codecs/channel_slice.py
- compressai/models/_bases/ (slice_entropy.py + __init__.py; whole dir is empty)
- corresponding exports from latent_codecs/__init__.py
Register tcm/cca in image_models and model_architectures via _LazyImport so import compressai.zoo stays timm-free; add tcm()/cca() factory functions mirroring stf()/stf_wacnn() (pretrained=True raises until weights are hosted).
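A quick sanity check of the lazy-import guarantee, runnable in a fresh interpreter; the quality=1 argument mirrors the other zoo factories and is an assumption here:

```python
import sys
import compressai.zoo  # must not pull in timm by itself

assert not any(m == "timm" or m.startswith("timm.") for m in sys.modules), "timm leaked"

from compressai.zoo import image_models

# Constructing the model may lazily import timm afterwards; pretrained=True raises
# RuntimeError until weights are hosted (per this commit).
net = image_models["tcm"](quality=1, pretrained=False)
```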
…ffer for 2D-shape leaves

ChannelGroupsLatentCodec.decompress reconstructed the destination buffer shape with (sum(s[0] for s in shape), *shape[0][1:]), which assumes each leaf reports a 3D (C, H, W) shape (the CheckerboardLatentCodec convention). Family 1 leaves (LRPGaussianLatentCodec via GaussianConditionalLatentCodec) report a 2D (H, W) shape, which collapsed the buffer to 3D (N, sum_H, W) and triggered a broadcast RuntimeError when assigning the leaf's 4D y_hat back into the split slice (manifesting on STF/WACNN compress->decompress round-trip).

Use self.groups for the channel total and take the trailing two dims of any per-group shape as spatial -- works for both leaf shape conventions. Adds regression coverage for both 2D and 3D leaf shapes.
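The fix reduces to computing the destination buffer shape in a leaf-agnostic way; a standalone sketch of that computation (helper name hypothetical):

```python
def dest_buffer_shape(groups, per_group_shapes, batch_size=1):
    """Destination y_hat shape that works whether leaves report 2D (H, W) or 3D (C, H, W)."""
    h, w = per_group_shapes[0][-2:]           # trailing two dims are always spatial
    return (batch_size, sum(groups), h, w)    # channel total comes from self.groups, not the shapes

# 2D leaf shapes (GaussianConditionalLatentCodec convention):
assert dest_buffer_shape([32, 32], [(16, 16), (16, 16)]) == (1, 64, 16, 16)
# 3D leaf shapes (CheckerboardLatentCodec convention):
assert dest_buffer_shape([32, 32], [(32, 16, 16), (32, 16, 16)]) == (1, 64, 16, 16)
```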
The original LIC TCM (`tcm.py:434`), WACNN (`cnn.py:152`), and SymmetricalTransFormer (`stf.py:602`) implementations all run `quantize_ste(z - z_offset) + z_offset` after the entropy bottleneck so that downstream `h_s` consumes a STE-rounded `z_hat` (likelihoods are still computed on noisy z to train the parametric prior). The Family 1 containerization in `_build_family1_latent_codec` and `_build_tcm_latent_codec` was passing `quantizer="noise"` to the `z` leaf, which silently propagated noisy `z_hat` to `h_s` during training -- a real RD-relevant deviation from the published models.

Switch both build sites to `quantizer="ste"`, matching CCA-main (`cca.py:360`) which already used STE. Eval-mode forward and state-dict layout are unchanged (same module tree, same parameters); only training-time `z_hat` propagated to `h_s` becomes deterministic. Also drop the misleading STF/TCM-noise vs CCA-STE distinction from `compressai/latent_codecs/__init__.py` and `tests/test_models.py`, replacing it with a Family 1 invariant note: all four models use STE on z.
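For readers less familiar with the distinction, a small sketch of the STE rounding on `z` that this commit standardises on; the offset below is a stand-in for the entropy bottleneck's learned medians:

```python
import torch
from compressai.ops import quantize_ste  # round() in the forward pass, identity gradient

z = torch.randn(1, 192, 4, 4, requires_grad=True)
z_offset = torch.zeros(1, 192, 1, 1)                 # stand-in for entropy_bottleneck medians
z_hat = quantize_ste(z - z_offset) + z_offset        # what h_s consumes after this commit
assert torch.allclose(z_hat, torch.round(z - z_offset) + z_offset)

z_hat.sum().backward()                               # gradient passes straight through the rounding
assert torch.allclose(z.grad, torch.ones_like(z))
```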
Adds two new models from the per-model PR series in #353:

- TCM (Liu et al., CVPR 2023)
- CCA (Han et al., NeurIPS 2024)
Per the discussion in #354, this PR also delivers the refactored latent-codec abstraction I committed to ship next ("I'll include the refactored abstraction layer in the next PR"). The new shared infrastructure unifies the channel-slice topology used by STF / WACNN (added in #354), TCM (this PR), CCA (this PR), and the upcoming DCAE / MambaVC follow-ups onto upstream
`ChannelGroupsLatentCodec` rather than the temporary `ChannelSliceLatentCodec` introduced in #354. Pretrained weights are intentionally not bundled — calling
`pretrained=True` raises a clear `RuntimeError` until weights are hosted on S3 (per the discussion in #353).

Summary
"tcm"and"cca"(compressai.models.tcm.TCM,compressai.models.cca.CCAModel), wired via lazy-import_LazyImportproxy inmodel_architecturessoimport compressai.zoostaystimm-free.compressai.losses.cca.CCARateDistortionLoss— extendsRateDistortionLosswith the auxiliary "causal context adjustment" term (NeurIPS 2024 §3.2) wired to the optional_CCAAuxEntropyModelhead.compressai/latent_codecs/(see Refactor below).compressai/models/_helpers/{channel_slice,channel_context}.py— declarative factories that wire Family-1 models in ~3 calls.ChannelSliceLatentCodec+SliceEntropyCompressionModelscaffolding from Add WACNN and SymmetricalTransFormer (STF, CVPR 2022) #354 with no checkpoint-format break (LRP weights are byte-for-byte transferable).examples/convert_tcm_checkpoint.pyandexamples/convert_cca_checkpoint.pyfor the published upstream weights.timm.layers.LayerNorm2d(already pulled in by STF), so they live under the existing[attn]extras group set up in Add WACNN and SymmetricalTransFormer (STF, CVPR 2022) #354.Refactor: Family-1 latent-codec abstraction
The four models targeted by this PR series so far (STF / WACNN / TCM / CCA) all follow the same outer entropy-stack shape and differ only in a few per-model components (the shaded boxes in the PR's wiring diagram). #354 absorbed the variation by introducing a dedicated
`ChannelSliceLatentCodec`; this PR shows that all four variants fit cleanly inside upstream `ChannelGroupsLatentCodec` once it gains four optional kwargs, eliminating the duplicate codec class and giving Family-1 models the same wiring story as ELIC.

Concretely the PR adds:
- `DualHyperSynthesis` (`latent_codecs/_hyper_synthesis.py`): runs `h_mean_s(z)` and `h_scale_s(z)` in parallel and concatenates the result, so `HyperpriorLatentCodec` sees a single `h_s`.
- `LRPGaussianLatentCodec` (`latent_codecs/gaussian_conditional.py`, ~30 lines appended): subclass of `GaussianConditionalLatentCodec` that adds the LRP residual prediction (`y_hat += lrp_scale * tanh(lrp_transform(cat(mean_support, y_hat)))`). With `mean_support_trail_channels` set, the leaf reads its LRP input from a trailing block of `ctx_params` produced by the head's `emit_mean_support` mode — giving byte-for-byte weight transfer from the upstream `cat(latent_means, *prev_y_hat, y_hat)` layout.
- `ChannelGroupsLatentCodec` extensions (`latent_codecs/channel_groups.py`, ~50-line diff): four optional kwargs: `max_support_slices` (clamp the number of preceding slices used as prior), `support_filter` (callable to pick a custom subset of priors), `support_count_fn` (declare how many priors `support_filter` yields, so head input widths can be sized correctly), and `side_in_context` (route `side_params` from `h_s` through every `channel_context` head instead of only handing it to the leaves). ELIC and other existing users default-through to the original behaviour.
- `MeanScaleContextHead` + `build_mean_scale_head` (`models/_helpers/channel_context.py`): independent `mean_cc` / `scale_cc` stacks with optional independent support-transforms per branch and an optional `emit_mean_support="pre"|"post"` mode.
- `build_channel_slice_codec` (`models/_helpers/channel_slice.py`): assembles a `ChannelGroupsLatentCodec` from `groups` + `leaf_factory` + `channel_context_factory` in one call.
- `_slice_helpers` (`latent_codecs/_slice_helpers.py`): shared helpers (`slice_support_channels`, `lrp_support_channels`, `make_entropy_transform`, `infer_num_slices`, `infer_max_support_slices`) used by all four models' `from_state_dict` machinery.

Per-model variation now lives entirely in the kwargs:
| Model | `groups` | Support transform | Notes |
| --- | --- | --- | --- |
| STF / WACNN | `[M//10]*10` | (none) | `cc` heads `widths=(224, 176, 128, 64)` |
| TCM | `[M//K]*K` | `SWAtten` (independent per mean/scale) | `cc` heads `widths=(224, 128)` |
| CCA (main) | `slice_proportions=(8, 28, 56, 92, 136)` (variable-length) | `NAFTransform` (independent per mean/scale) | `EntropyBottleneckLatentCodec(quantizer="ste")` for `z` |
| CCA (aux) | (as main) | `NAFTransform` | lives outside the `HyperpriorLatentCodec` tree; uses `support_filter=skip_most_recent` + matching `support_count_fn` |

The `__init__.py` of `compressai/latent_codecs/` documents this wiring story in a top-level comment block so reviewers don't need to read each model file to understand the pattern.
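To make the `groups` column concrete, the per-model channel splits work out as follows (the TCM slice count K=5 is an assumption based on upstream defaults, not something the table above pins down):

```python
M = 320
stf_wacnn_groups = [M // 10] * 10            # ten equal 32-channel slices
tcm_groups = [M // 5] * 5                    # K equal slices; K=5 assumed here
cca_groups = [8, 28, 56, 92, 136]            # upstream variable-length layout
assert sum(stf_wacnn_groups) == sum(tcm_groups) == sum(cca_groups) == M
```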
State-dict layout

Containerization shifts the saved keys to a single-layer
`latent_codec.*` prefix (the `HyperpriorLatentCodec`'s `self.y` / `self.z` are real `nn.Module` registrations, not nested dicts). The published upstream checkpoints round-trip via the `examples/convert_*_checkpoint.py` converters — LRP weights transfer byte-for-byte thanks to `mean_support_trail_channels`, and TCM's per-slice `gaussian_conditional` buffer is materialized by copying the single shared upstream copy K times.
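The converters themselves live under `examples/`; the core move is key re-rooting plus the `module.` strip, roughly like this sketch (the target prefixes in the toy example are placeholders, not the real paths):

```python
def strip_data_parallel(state_dict):
    """Drop the DataParallel 'module.' prefix if present."""
    return {k.removeprefix("module."): v for k, v in state_dict.items()}

def reroot(state_dict, prefix_map):
    """Re-root keys whose prefix appears in prefix_map; everything else passes through."""
    out = {}
    for k, v in state_dict.items():
        for old, new in prefix_map.items():
            if k.startswith(old):
                k = new + k[len(old):]
                break
        out[k] = v
    return out

# Toy example (targets are placeholders; the real converters also replicate the shared
# gaussian_conditional buffers per slice and nest the WMSA attention keys):
sd = {"module.h_a.0.weight": 1, "module.g_a.0.weight": 2}
sd = reroot(strip_data_parallel(sd), {"h_a.": "latent_codec.hyper.h_a."})
assert set(sd) == {"latent_codec.hyper.h_a.0.weight", "g_a.0.weight"}
```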
Commits

Six commits, designed to be reviewed independently:
1. `feat(latent_codecs): add containerized infrastructure for Family 1 codecs`: `latent_codecs/{_hyper_synthesis, _slice_helpers, gaussian_conditional, channel_groups, __init__}.py` + `models/_helpers/{channel_slice, channel_context, __init__}.py` + tests
2. `refactor(models/stf): migrate WACNN + SymmetricalTransFormer to containerized codec`: `models/stf.py` + `examples/convert_stf_checkpoint.py` updates + `tests/test_models.py::TestStf`
3. `feat(models): add TCM with containerized codec`: `models/tcm.py` + `examples/convert_tcm_checkpoint.py` + `tests/test_models.py::TestTcm`
4. `feat(models): add CCA model and loss with containerized codec`: `models/cca.py` + `losses/cca.py` + `examples/convert_cca_checkpoint.py` + `tests/test_models.py::TestCca`
5. `chore(latent_codecs,models): drop ChannelSliceLatentCodec and SliceEntropyCompressionModel`: `latent_codecs/channel_slice.py` + entire `models/_bases/` directory + remove exports
6. `chore(zoo): wire cca/tcm zoo entries with lazy import`: `zoo/{__init__,image}.py` factory functions + `_LazyImport` proxies

The cleanup commit lands after all four models are migrated, so the branch never goes through a state where STF/WACNN are broken. The refactor and migrations preserve the existing public model classes — only the internal codec-tree shape and the corresponding state-dict paths change.
License & attribution
- `compressai/models/tcm.py` carries a dual-license header pointing at the upstream `jmliu206/LIC_TCM` (Apache-2.0) alongside the standard InterDigital BSD 3-Clause Clear license for modifications.
- `compressai/models/cca.py` carries a dual-license header pointing at the upstream `LabShuHangGU/CCA` (MIT) alongside the standard InterDigital BSD 3-Clause Clear license for modifications. The internal `_NAFBlock` / `_NAFTransform` are derived from NAFNet (Chen et al. 2022, MIT) — happy to add per-class attribution headers if maintainers prefer.
- `compressai/losses/cca.py` similarly attributes the CCA paper for the auxiliary-loss formulation.

Verified
- `pytest tests/ -q` (excluding pretrained-dependent suites — the local S3 ckpt cache is corrupted with `unexpected EOF`, unrelated to this PR) → 213 passed, 4 skipped, 32 deselected.
- `pytest tests/test_models.py tests/test_latent_codecs.py tests/test_models_helpers.py tests/test_layers.py tests/test_init.py -q` → 74 passed (3 new `TestStf` + 2 new `TestTcm` + 3 new `TestCca` + existing).
- Upstream checkpoints (`from_state_dict(strict=True)` then forward + sinusoidal-image smoke):
  - `cnn_0018_best.pth.tar` (585 keys) — strict load OK.
  - `stf_0018_best.pth.tar` (779 keys) — strict load OK.
  - `0.05.pth.tar` (N=64, M=320, 1397 keys after per-slice GC copy) — strict load OK, sinusoidal PSNR 39.15 dB / total bpp 0.317.
  - `mse_lambda_0.05.pth.tar` (N=128, M=320, 1397 keys) — strict load OK, sinusoidal PSNR 39.41 dB / total bpp 0.236.
  - `checkpoint_lambda_0.3.pth.tar` (M=320, slice_sizes=[8,28,56,92,136], 97M params, 2384 keys with main + aux) — strict load OK, sinusoidal PSNR 50.07 dB / total bpp 0.072. Fresh-init baseline at the same config gives ~5 dB, confirming weights are participating.
- `import compressai` + `import compressai.zoo` + `import compressai.latent_codecs` triggers 0 timm modules (verified via `sys.modules` snapshot diff).
- `make static-analysis` (ruff format / imports / lint, fail-fast) → all 3 steps clean.
- `uv lock --check` → consistent (no `pyproject.toml` changes in this PR).

Test plan
- Unit tests for all three model suites (`TestStf`, `TestTcm`, `TestCca`).
- Upstream-conversion tests assert the new `latent_codec.*` paths exist and the old top-level paths are gone (`test_*_upstream_state_dict_conversion`).
- The `ChannelGroupsLatentCodec` extensions are backward-compatible with ELIC's existing usage (`tests/test_models.py::TestElic` still green with default kwargs).

Notes for follow-up PRs (per #353)
`_CCAAuxEntropyModel` is a private `nn.Module` inside `CCAModel`, but its forward signature `(y, latent_means, latent_scales)` only depends on `latent_channels` + `slice_proportions` — not on the host backbone — so it should plug cleanly into WACNN / STF / TCM / MLIC++ / DCAE / SAAF / Mamba-family models via a `use_cca=True` opt-in. The plan is to extract it into a public `compressai.entropy_models.CausalContextAdjustmentEntropyModel` (or upgrade it to a `LatentCodec` variant), pair it with the existing `CCARateDistortionLoss`, and let host models add it in ~30 lines without touching their main entropy path. Whether this transfers the RD gains the CCA paper reports on `LICAutoencoder` to other backbones is an empirical question for the follow-up PR; this PR only commits to keeping the API minimal so the migration is straightforward.