MLX support for odd sequence lengths with SAME-S#59
Merged
Conversation
… music) T_lat was computed as ceil(seconds*44100/4096) and then bumped to even whenever --decoder=same-s. Because most durations round to an odd length (30s -> 323, 3s -> 33), same-s silently ran at L+1 while same-l ran at L. Since the noise is drawn as normal((1,256,T_lat), seed), the two decoders seeded the DiT with differently shaped noise and therefore sampled entirely different latents -- so switching only the decoder produced completely different music for the same prompt and seed. Derive T_lat from --seconds alone (natural ceil, no even-bump), matching the TensorRT pipeline (sa3_trt.py::resolve_T_lat). The latent -- and thus the music -- is now identical across decoders for a given prompt/seed/seconds. SAME-S still needs an even internal length (T_lat*17 must align to 34). That is handled at decode time, not by reshaping the noise: odd T_lat is routed through decode_chunked, whose windows are always even kernels. The decode dispatch is hardened for the sub-0.5s edge (odd T_lat <= 6) via a one-latent pad-and-trim so any duration decodes. The --init-audio encode path rounds the encode grid up to the encoder's modulo and trims the latent back to T_lat, keeping the DiT/noise path decoder-independent there too. Verified end-to-end vs PyTorch-eager (dit_torch / same_*_decoder_torch) at odd (33) and even (34) T_lat, both DiTs and both decoders: MLX-fp16 vs torch-fp32 latent PSNR is equal at odd vs even (small 61.7/59.5, medium 76.4/75.0 dB), confirming the odd path is not degraded; and same-s vs same-l on one DiT/seed/seconds now yield a bit-identical latent.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
T_lat was computed as ceil(seconds*44100/4096) and then bumped to even whenever --decoder=same-s. Because most durations round to an odd length (30s -> 323, 3s -> 33), same-s silently ran at L+1 while same-l ran at L. Since the noise is drawn as normal((1,256,T_lat), seed), the two decoders seeded the DiT with differently shaped noise and therefore sampled entirely different latents -- so switching only the decoder produced completely different music for the same prompt and seed.
Derive T_lat from --seconds alone (natural ceil, no even-bump), matching the TensorRT pipeline (sa3_trt.py::resolve_T_lat). The latent -- and thus the music -- is now identical across decoders for a given prompt/seed/seconds.
SAME-S still needs an even internal length (T_lat*17 must align to 34). That is handled at decode time, not by reshaping the noise: odd T_lat is routed through decode_chunked, whose windows are always even kernels. The decode dispatch is hardened for the sub-0.5s edge (odd T_lat <= 6) via a one-latent pad-and-trim so any duration decodes. The --init-audio encode path rounds the encode grid up to the encoder's modulo and trims the latent back to T_lat, keeping the DiT/noise path decoder-independent there too.
Verified end-to-end vs PyTorch-eager (dit_torch / same_*_decoder_torch) at odd (33) and even (34) T_lat, both DiTs and both decoders: MLX-fp16 vs torch-fp32 latent PSNR is equal at odd vs even (small 61.7/59.5, medium 76.4/75.0 dB), confirming the odd path is not degraded; and same-s vs same-l on one DiT/seed/seconds now yield a bit-identical latent.