Skip to content

MLX support for odd sequence lengths with SAME-S#59

Merged
Cortexelus merged 1 commit into
mainfrom
fix-tlat-decoder-independent
Jul 1, 2026
Merged

MLX support for odd sequence lengths with SAME-S#59
Cortexelus merged 1 commit into
mainfrom
fix-tlat-decoder-independent

Conversation

@Cortexelus

Copy link
Copy Markdown
Collaborator

T_lat was computed as ceil(seconds*44100/4096) and then bumped to even whenever --decoder=same-s. Because most durations round to an odd length (30s -> 323, 3s -> 33), same-s silently ran at L+1 while same-l ran at L. Since the noise is drawn as normal((1,256,T_lat), seed), the two decoders seeded the DiT with differently shaped noise and therefore sampled entirely different latents -- so switching only the decoder produced completely different music for the same prompt and seed.

Derive T_lat from --seconds alone (natural ceil, no even-bump), matching the TensorRT pipeline (sa3_trt.py::resolve_T_lat). The latent -- and thus the music -- is now identical across decoders for a given prompt/seed/seconds.

SAME-S still needs an even internal length (T_lat*17 must align to 34). That is handled at decode time, not by reshaping the noise: odd T_lat is routed through decode_chunked, whose windows are always even kernels. The decode dispatch is hardened for the sub-0.5s edge (odd T_lat <= 6) via a one-latent pad-and-trim so any duration decodes. The --init-audio encode path rounds the encode grid up to the encoder's modulo and trims the latent back to T_lat, keeping the DiT/noise path decoder-independent there too.

Verified end-to-end vs PyTorch-eager (dit_torch / same_*_decoder_torch) at odd (33) and even (34) T_lat, both DiTs and both decoders: MLX-fp16 vs torch-fp32 latent PSNR is equal at odd vs even (small 61.7/59.5, medium 76.4/75.0 dB), confirming the odd path is not degraded; and same-s vs same-l on one DiT/seed/seconds now yield a bit-identical latent.

… music)

T_lat was computed as ceil(seconds*44100/4096) and then bumped to even
whenever --decoder=same-s. Because most durations round to an odd length
(30s -> 323, 3s -> 33), same-s silently ran at L+1 while same-l ran at L.
Since the noise is drawn as normal((1,256,T_lat), seed), the two decoders
seeded the DiT with differently shaped noise and therefore sampled entirely
different latents -- so switching only the decoder produced completely
different music for the same prompt and seed.

Derive T_lat from --seconds alone (natural ceil, no even-bump), matching the
TensorRT pipeline (sa3_trt.py::resolve_T_lat). The latent -- and thus the
music -- is now identical across decoders for a given prompt/seed/seconds.

SAME-S still needs an even internal length (T_lat*17 must align to 34). That
is handled at decode time, not by reshaping the noise: odd T_lat is routed
through decode_chunked, whose windows are always even kernels. The decode
dispatch is hardened for the sub-0.5s edge (odd T_lat <= 6) via a one-latent
pad-and-trim so any duration decodes. The --init-audio encode path rounds the
encode grid up to the encoder's modulo and trims the latent back to T_lat,
keeping the DiT/noise path decoder-independent there too.

Verified end-to-end vs PyTorch-eager (dit_torch / same_*_decoder_torch) at
odd (33) and even (34) T_lat, both DiTs and both decoders: MLX-fp16 vs
torch-fp32 latent PSNR is equal at odd vs even (small 61.7/59.5, medium
76.4/75.0 dB), confirming the odd path is not degraded; and same-s vs same-l
on one DiT/seed/seconds now yield a bit-identical latent.
@Cortexelus Cortexelus merged commit dedace1 into main Jul 1, 2026
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant