From 70ebbffaa8d2db539f4dbbb67ef57f62063e790e Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Tue, 9 Jun 2026 17:49:58 -0400 Subject: [PATCH 1/7] perf(streaming): time produce/decode with CUDA events, not device syncs The produce bracket ran two full-device torch.cuda.synchronize() per tick and render_window/render_full a third, all purely to measure last_tick_ms/last_dec_ms. Record CUDA events around the engine step and the decode instead: the produce bracket resolves lazily at the start of the next produce (one tick stale; both readers are diagnostics), and the decode bracket resolves right after the waveform's D2H copy, which already drained the stream. Removes every measurement-only host-device sync from the tick loop and lets CPU prep overlap GPU work. --- acestep/streaming/ace_backend.py | 49 ++++++++++++++++++++++---- acestep/streaming/diffusion_backend.py | 37 +++++++++++++++---- 2 files changed, 74 insertions(+), 12 deletions(-) diff --git a/acestep/streaming/ace_backend.py b/acestep/streaming/ace_backend.py index ff454297..df0a9ad1 100644 --- a/acestep/streaming/ace_backend.py +++ b/acestep/streaming/ace_backend.py @@ -903,6 +903,25 @@ def _after_produce(self, prep: dict, result_latent, is_fresh: bool) -> None: # Stash the values the params echo + sampled trace need. self._echo = prep["echo"] + def _dec_event_pair(self): + """Cached CUDA-event pair for timing decodes, or None on CPU. + + Events replace the old ``torch.cuda.synchronize()`` bracket so + the decode measurement never stalls the runner thread: the + bracket is resolved right after the waveform's D2H copy, which + already completed everything the events depend on. + """ + if not torch.cuda.is_available(): + return None + pair = getattr(self, "_dec_ev_pair", None) + if pair is None: + pair = ( + torch.cuda.Event(enable_timing=True), + torch.cuda.Event(enable_timing=True), + ) + self._dec_ev_pair = pair + return pair + def render_window(self, t_start_s: float): decode_src = ( self._current_result if self._current_result is not None @@ -910,7 +929,11 @@ def render_window(self, t_start_s: float): ) if decode_src is None: return None - t1 = time.perf_counter() + dec_ev = self._dec_event_pair() + if dec_ev is not None: + dec_ev[0].record() + else: + t1 = time.perf_counter() if self._walk_active: # The DiT output spans [win_start_s, # win_start_s + walk_window_s] of the song. @@ -939,9 +962,15 @@ def render_window(self, t_start_s: float): else: audio_out = self.codec.decode(decode_src, t_start=t_start_s, cyclic=True) win_offset_samples = 0 - torch.cuda.synchronize() - self.last_dec_ms += (time.perf_counter() - t1) * 1000 + if dec_ev is not None: + dec_ev[1].record() + else: + self.last_dec_ms += (time.perf_counter() - t1) * 1000 win_wav = audio_out.waveform.detach().cpu().float().squeeze(0) + if dec_ev is not None: + # The D2H copy above drained the decode work past the end + # event, so this resolves without a device-wide sync. + self.last_dec_ms += dec_ev[0].elapsed_time(dec_ev[1]) win_np = win_wav.numpy().T win_start = audio_out.start_sample + win_offset_samples return AudioChunk(pcm=win_np, start_sample=win_start) @@ -962,11 +991,19 @@ def render_full(self): and (result - self._mse_prev).pow(2).mean().item() < self.skip_threshold ): return None - t1 = time.perf_counter() + dec_ev = self._dec_event_pair() + if dec_ev is not None: + dec_ev[0].record() + else: + t1 = time.perf_counter() audio_out = self.codec.decode(result_latent) - torch.cuda.synchronize() - self.last_dec_ms += (time.perf_counter() - t1) * 1000 + if dec_ev is not None: + dec_ev[1].record() + else: + self.last_dec_ms += (time.perf_counter() - t1) * 1000 wav = audio_out.waveform.detach().cpu().float().squeeze(0) + if dec_ev is not None: + self.last_dec_ms += dec_ev[0].elapsed_time(dec_ev[1]) wav_np = wav.numpy().T if self.crop_seconds > 0: wav_np = wav_np[:int(self.crop_seconds * SAMPLE_RATE)] diff --git a/acestep/streaming/diffusion_backend.py b/acestep/streaming/diffusion_backend.py index 86bbaf24..4e010d8b 100644 --- a/acestep/streaming/diffusion_backend.py +++ b/acestep/streaming/diffusion_backend.py @@ -65,6 +65,15 @@ def __init__(self, *, adapter=None, codec=None): # runner for its latency trace. self.last_tick_ms = 0.0 self.last_dec_ms = 0.0 + # CUDA-event bracket around the engine step. Recorded each + # produce, resolved lazily at the START of the next produce so + # the measurement never inserts a full-device synchronize into + # the tick (the old bracket cost two of them). ``last_tick_ms`` + # is therefore one tick stale on GPU — it only feeds the + # runner trace and the params echo, both diagnostics. + self._tick_ev_start = None + self._tick_ev_end = None + self._tick_ev_pending = False # ---- contract defaults -------------------------------------------------- @@ -99,9 +108,23 @@ def produce(self, knobs: dict, ctx: TickContext, mode: ProduceMode) -> bool: """ prep = self._prepare_tick(knobs, ctx) - if torch.cuda.is_available(): - torch.cuda.synchronize() - t0 = time.perf_counter() + use_events = torch.cuda.is_available() + if use_events: + if self._tick_ev_start is None: + self._tick_ev_start = torch.cuda.Event(enable_timing=True) + self._tick_ev_end = torch.cuda.Event(enable_timing=True) + if self._tick_ev_pending: + # Last tick's bracket. Its work completed long ago (the + # render's D2H copy synced the stream), so this resolves + # without stalling; if it somehow hasn't, blocking here + # is no worse than the old synchronize. + self._tick_ev_end.synchronize() + self.last_tick_ms = self._tick_ev_start.elapsed_time( + self._tick_ev_end + ) + self._tick_ev_start.record() + else: + t0 = time.perf_counter() if mode == "reuse": result_latent = self._last_result_latent @@ -115,9 +138,11 @@ def produce(self, knobs: dict, ctx: TickContext, mode: ProduceMode) -> bool: if result_latent is not None: self._last_result_latent = result_latent - if torch.cuda.is_available(): - torch.cuda.synchronize() - self.last_tick_ms = (time.perf_counter() - t0) * 1000 + if use_events: + self._tick_ev_end.record() + self._tick_ev_pending = True + else: + self.last_tick_ms = (time.perf_counter() - t0) * 1000 self.last_dec_ms = 0.0 self._current_result = result_latent From e13f03f540cc23547f48e9a20c3849ea40af3bfc Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Tue, 9 Jun 2026 17:52:14 -0400 Subject: [PATCH 2/7] perf(stream): cache normalized shared curves on-device set_shared_curve stored CPU fp32 tensors and _eff_shared re-ran normalize_curve on every read, so every shared override (the exact path the live denoise/guidance knobs ride) cost a fresh CPU alloc plus a host-to-device copy per slot per step. Canonicalize and device-cast once at the setter (same approach set_channel_gain_tensor already uses), return dict hits directly, and memoize normalized SlotRequest scalar fields once per slot in _Slot.curve_cache. Dtype casting stays at the consumer boundary, so the math is byte-identical. --- acestep/engine/stream.py | 48 +++++++++++++++++++++++++++++++--------- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/acestep/engine/stream.py b/acestep/engine/stream.py index 839b45c2..dda0297a 100644 --- a/acestep/engine/stream.py +++ b/acestep/engine/stream.py @@ -219,6 +219,12 @@ class _Slot: # populated on step 0, reused on every subsequent step of this slot. initial_noise: Optional[torch.Tensor] = None vt_neg_cached: Optional[torch.Tensor] = None + # Memo for normalized per-slot curve fields (the ``_eff_shared`` + # fallback path). Request fields are fixed after submit — the + # hot-mutation contract is ``set_shared_curve`` — so a scalar field + # is normalized once per slot instead of allocating a fresh + # ``[1, 1, 1]`` tensor on every read (per slot, per step). + curve_cache: Dict[str, torch.Tensor] = field(default_factory=dict) class StreamPipeline: @@ -1402,28 +1408,48 @@ def set_shared_curve( ``value`` can be a scalar or a per-frame tensor; both flow through :func:`ode_steps.normalize_curve` so the storage form is always ``[B, T, 1]`` and downstream consumers do not need to - type-discriminate. Pass ``None`` to revert that name to per-slot - behavior. + type-discriminate. The canonical tensor is moved to the + pipeline's device here (dtype is left alone — consumers cast at + their own boundary exactly as before) so the per-step readers + never pay a host-to-device copy. Pass ``None`` to revert that + name to per-slot behavior. """ if value is None: self._shared_curves.pop(name, None) return - self._shared_curves[name] = ode_steps.normalize_curve(value) + v = ode_steps.normalize_curve(value) + if self._device is not None and v.device != self._device: + v = v.to(device=self._device) + self._shared_curves[name] = v def _eff_shared(self, slot: "_Slot", name: str): """Return shared override for ``name`` if set, else slot's field. Output is always either ``None`` or a normalized ``[B, T, 1]`` - tensor — the shared override is canonicalized at the setter, and - any ``SlotRequest`` field is normalized here so callers never - need to ``isinstance``-check. + tensor — the shared override is canonicalized (and device-cast) + at the setter, and any ``SlotRequest`` field is normalized once + per slot via ``slot.curve_cache``, so this hot-path read never + allocates or copies. Curves set before the pipeline learned its + device (first submit) are device-fixed here once and stored + back. """ v = self._shared_curves.get(name) - if v is None: - v = getattr(slot.request, name, None) - if v is None: - return None - return ode_steps.normalize_curve(v) + if v is not None: + if self._device is not None and v.device != self._device: + v = v.to(device=self._device) + self._shared_curves[name] = v + return v + v = slot.curve_cache.get(name) + if v is not None: + return v + raw = getattr(slot.request, name, None) + if raw is None: + return None + v = ode_steps.normalize_curve(raw) + if self._device is not None and v.device != self._device: + v = v.to(device=self._device) + slot.curve_cache[name] = v + return v def set_dcw( self, From 8b08147fbb3b6cf608ee754ac7434d82950a27ea Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Tue, 9 Jun 2026 17:54:29 -0400 Subject: [PATCH 3/7] perf(stream): answer the x0_target gate from set-time flags MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The strength gate ran eff_strength.abs().any().item() per slot per step — a host-device fence in the middle of the integration loop (the old comment even called it out). Track an "any nonzero" flag alongside each shared curve at set_shared_curve (free for scalar sets, one readback per knob write for tensor sets) and memoize the slot-field flag once per slot, computed without tensor ops for Python-scalar fields. The gate decision is unchanged for every input; the strength tensor itself is only fetched on the path that actually blends. --- acestep/engine/stream.py | 66 ++++++++++++++++++++++++++++++++++------ 1 file changed, 56 insertions(+), 10 deletions(-) diff --git a/acestep/engine/stream.py b/acestep/engine/stream.py index dda0297a..e581efad 100644 --- a/acestep/engine/stream.py +++ b/acestep/engine/stream.py @@ -225,6 +225,12 @@ class _Slot: # is normalized once per slot instead of allocating a fresh # ``[1, 1, 1]`` tensor on every read (per slot, per step). curve_cache: Dict[str, torch.Tensor] = field(default_factory=dict) + # Memo for "does this slot field have any nonzero value" — the + # x0_target gate used to answer that with a per-slot-per-step + # ``.abs().any().item()`` device readback; the flag is computed at + # most once per slot here (and without any tensor op for the + # common Python-scalar fields). + nonzero_cache: Dict[str, bool] = field(default_factory=dict) class StreamPipeline: @@ -332,6 +338,12 @@ def __init__( # at the setter so callers can pass scalars without thinking # about shape. self._shared_curves: dict[str, torch.Tensor] = {} + # Sidecar of ``_shared_curves``: "any nonzero value" per name, + # computed at the setter (no tensor op for scalar sets) so the + # x0_target gate in the step loop never has to read a tensor + # back from the device. Maintained strictly in lockstep with + # ``_shared_curves`` — see set_shared_curve / _curve_nonzero. + self._shared_curve_nonzero: dict[str, bool] = {} # Channel guidance: a ``[1, T, 64]`` per-channel gain applied to # ``xt`` before each forward pass. Lives in its own field rather @@ -1224,16 +1236,11 @@ def _forward_pairs( # ``x0_target_strength`` path: blend toward a target latent # at scalar (or per-frame curve) strength, gated to the - # refinement half. Preserving the historical "strength==0 - # falls through to the fast path" behavior — checks the - # effective (shared override or slot field) strength via a - # tensor.any() sync, which costs one host-device fence per - # slot per step but lets the gate stay tensor-safe. - eff_strength = self._eff_shared(slot, "x0_target_strength") - strength_active = ( - eff_strength is not None - and bool(eff_strength.abs().any().item()) - ) + # refinement half. Preserves the historical "strength==0 + # falls through to the fast path" behavior; the nonzero + # check reads cached flags (maintained at set time), not + # tensor data, so the gate costs no host-device fence. + strength_active = self._curve_nonzero(slot, "x0_target_strength") scalar_x0_target = ( req.x0_target is not None and strength_active @@ -1306,6 +1313,9 @@ def _forward_pairs( x0_pred, req.x0_target, curve * blend_gate, ) elif scalar_x0_target: + # Non-None whenever ``strength_active`` held: the gate + # and this read resolve from the same sources. + eff_strength = self._eff_shared(slot, "x0_target_strength") alpha = eff_strength.to(device=x0_pred.device, dtype=x0_pred.dtype) x0_pred = (1.0 - alpha) * x0_pred + alpha * req.x0_target @@ -1416,11 +1426,21 @@ def set_shared_curve( """ if value is None: self._shared_curves.pop(name, None) + self._shared_curve_nonzero.pop(name, None) return + # Nonzero flag for the x0_target gate, computed where it is + # free: scalar sets need no tensor op at all, and tensor sets + # pay one readback here (per knob write) instead of one per + # slot per step in the loop. + if isinstance(value, (int, float, bool)): + nonzero = float(value) != 0.0 + else: + nonzero = bool(value.abs().any().item()) v = ode_steps.normalize_curve(value) if self._device is not None and v.device != self._device: v = v.to(device=self._device) self._shared_curves[name] = v + self._shared_curve_nonzero[name] = nonzero def _eff_shared(self, slot: "_Slot", name: str): """Return shared override for ``name`` if set, else slot's field. @@ -1451,6 +1471,31 @@ def _eff_shared(self, slot: "_Slot", name: str): slot.curve_cache[name] = v return v + def _curve_nonzero(self, slot: "_Slot", name: str) -> bool: + """True when the effective curve for ``name`` has any nonzero. + + Same source-resolution order as :meth:`_eff_shared` (shared + override, then slot field), but answers the boolean gate + without touching tensor data in the hot loop: the shared flag + is maintained by ``set_shared_curve`` and the slot-field flag + is computed at most once per slot — for free when the field is + a Python scalar (every production caller), with a single + readback when it is a tensor. + """ + if name in self._shared_curves: + return self._shared_curve_nonzero.get(name, True) + flag = slot.nonzero_cache.get(name) + if flag is None: + raw = getattr(slot.request, name, None) + if raw is None: + flag = False + elif isinstance(raw, (int, float, bool)): + flag = float(raw) != 0.0 + else: + flag = bool(raw.abs().any().item()) + slot.nonzero_cache[name] = flag + return flag + def set_dcw( self, *, @@ -1651,6 +1696,7 @@ def close(self) -> None: self._schedule_cache.clear() self._compiled_cache.clear() self._shared_curves.clear() + self._shared_curve_nonzero.clear() # DCW corrector holds wavelet basis tensors on GPU; drop it. self._dcw_corrector = None # Detach references to the engine + decoder so DiffusionEngine.close From 0d225e654f3bff216bacfe366cad4b5c61f603f5 Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Tue, 9 Jun 2026 17:56:50 -0400 Subject: [PATCH 4/7] perf(stream): batch the per-row small H2D traffic in both forwards TRT path: the timestep fill wrote one Python scalar per row straight into the device buffer (B tiny H2D writes per forward, doubled under CFG). Stage the rows in a pinned host buffer that lives in the shape-keyed bufs cache and issue a single async copy; ordering with the TRT exec matches the other input copies (legacy default stream vs the blocking polygraphy stream). Eager path: every forward allocated a fresh timestep tensor (pageable H2D) and a fresh all-ones attention mask kernel. Reuse both from tiny per-shape caches on the adapter; values and dtypes are unchanged. --- acestep/engine/model_adapter.py | 40 +++++++++++++++++++++++++++------ acestep/engine/stream.py | 14 ++++++++++-- 2 files changed, 45 insertions(+), 9 deletions(-) diff --git a/acestep/engine/model_adapter.py b/acestep/engine/model_adapter.py index 5c50d648..ddbb7aa4 100644 --- a/acestep/engine/model_adapter.py +++ b/acestep/engine/model_adapter.py @@ -104,6 +104,14 @@ class ACEAdapter: def __init__(self, pipeline): self._pipeline = pipeline + # Per-shape reuse caches for the eager (non-TRT) forward: + # (pinned-host, device) timestep buffer pairs keyed by B, and + # the all-ones attention mask keyed by (B, T). Both were + # allocated fresh on every forward. Shapes only vary with the + # ring-buffer fill level and the source duration, so the dicts + # stay tiny; cleared wholesale if they ever grow past 16. + self._t_bufs: dict = {} + self._attn_ones: dict = {} def build_schedule(self, config, denoise: float, device, dtype) -> torch.Tensor: from .diffusion import DiffusionConfig @@ -163,14 +171,32 @@ def batched_forward( ctx_batch=ctx_b, ) - t_b = torch.tensor( - timestep_list, device=p._device, dtype=p._dtype, - ) + B = xt_batch.shape[0] + tb = self._t_bufs.get(B) + if tb is None: + if len(self._t_bufs) > 16: + self._t_bufs.clear() + pin = p._device is not None and p._device.type == "cuda" + tb = ( + torch.empty(B, dtype=p._dtype, pin_memory=pin), + torch.empty(B, dtype=p._dtype, device=p._device), + ) + self._t_bufs[B] = tb + t_host, t_b = tb + for i, t in enumerate(timestep_list): + t_host[i] = t + t_b.copy_(t_host, non_blocking=True) + mask_b = torch.cat(mask_list, dim=0) - attn_b = torch.ones( - xt_batch.shape[0], xt_batch.shape[1], - device=p._device, dtype=p._dtype, - ) + key = (B, xt_batch.shape[1]) + attn_b = self._attn_ones.get(key) + if attn_b is None: + if len(self._attn_ones) > 16: + self._attn_ones.clear() + attn_b = torch.ones( + key[0], key[1], device=p._device, dtype=p._dtype, + ) + self._attn_ones[key] = attn_b out = p.decoder( hidden_states=xt_batch, diff --git a/acestep/engine/stream.py b/acestep/engine/stream.py index e581efad..a3974e6d 100644 --- a/acestep/engine/stream.py +++ b/acestep/engine/stream.py @@ -736,9 +736,14 @@ def _trt_forward( else: bufs["hidden_states"].copy_(xt_io) - # timestep: one scalar per row. + # timestep: stage all rows in the pinned host buffer, then one + # async H2D copy — instead of one tiny device write per row. + # Ordering with the TRT exec below is the same as the other + # input copies (legacy default stream → blocking TRT stream). + t_host = bufs["_timestep_host"] for i, t in enumerate(timestep_list): - bufs["timestep"][i] = t + t_host[i] = t + bufs["timestep"].copy_(t_host, non_blocking=True) # encoder_hidden_states: already padded to max_L + catted by # the caller. The engine has no ``encoder_attention_mask`` @@ -906,6 +911,11 @@ def _ensure_trt_bufs(self, B: int, T: int, max_L: int): bufs["_eff_T"] = eff_T bufs["_T"] = T bufs["_out_buf"] = out_buf + # Pinned host staging for the per-row timestep scalars (see + # _trt_forward). Underscore-prefixed so the bind loops skip it. + bufs["_timestep_host"] = torch.empty( + B, dtype=bufs["timestep"].dtype, pin_memory=True, + ) self._trt_bufs_cache[key] = bufs while len(self._trt_bufs_cache) > self._trt_bufs_cache_max: self._trt_bufs_cache.popitem(last=False) From 0cddfae7944cfab3a5a804a719258fb0f82bfd29 Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Tue, 9 Jun 2026 18:02:47 -0400 Subject: [PATCH 5/7] perf(session): run prompt/timbre encodes off the tick-loop lock MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit set_prompt, the timbre applies, and clear_timbre held state._lock for the full duration of their GPU encodes (text encoder x2 + VAE for timbre). The runner acquires the same lock at the top of every tick (read_knobs / has_pending_refit), so every prompt or timbre change stalled the entire tick loop — no ticks, no gap-fill — for the encoder's duration. Restructure those flows as snapshot-encode-commit: inputs are read under state._lock, the encodes run unlocked (serialized against each other by the new state._encode_lock, preserving the old encodes-do-not-overlap property), and the commit re-checks the new state.cond_epoch, re-snapshotting and re-encoding if another conditioning commit landed meanwhile (so concurrent prompt/timbre/swap effects all land, matching the old serialized end state). The swap commit keeps its atomic locked shape and bumps the epoch. Timbre apply no longer needs rollback state: nothing mutates before the commit. --- acestep/streaming/session.py | 288 +++++++++++++++++++++++------------ acestep/streaming/state.py | 13 ++ 2 files changed, 201 insertions(+), 100 deletions(-) diff --git a/acestep/streaming/session.py b/acestep/streaming/session.py index b3acaa06..00f70d3a 100644 --- a/acestep/streaming/session.py +++ b/acestep/streaming/session.py @@ -354,13 +354,18 @@ class StreamingSession: ``state.pending_*`` and drain inside ``apply_pending`` which the runner calls from ``before_tick``. Cheap from any thread. - Conditioning, timbre, structure, and prompt-blend mutations run - ``encode_cond_pair`` / ``encode_audio`` on the caller thread and - serialize against each other and against the source-swap commit - via ``state._lock``. Two transports calling these concurrently - will block one until the other returns. The swap's setup work - (TRT profile resolve, ``prepare_source``, stem extract) runs - unlocked; only the commit phase that writes the new track's - fields holds the lock. + ``encode_cond_pair`` / ``encode_audio`` on the caller thread. + Prompt and timbre encodes run OFF ``state._lock`` (the runner + acquires that lock every tick, so a locked encode used to stall + the whole tick loop): inputs are snapshotted under the lock, the + encode serializes against other encodes via + ``state._encode_lock``, and the commit re-checks + ``state.cond_epoch``, re-encoding if another conditioning commit + landed meanwhile. Structure mutations and the source-swap commit + still hold ``state._lock`` throughout (the swap's setup work — + TRT profile resolve, ``prepare_source``, stem extract — runs + unlocked; the commit bumps ``cond_epoch`` so racing encodes + retry against the new track). - ``set_knobs`` and ``set_prompt_blend`` echo without mutating when ``origin=CommandOrigin.EXTERNAL``. @@ -967,6 +972,13 @@ def _apply_swap_if_pending(self) -> None: # Use the active timbre reference if one is uploaded; # otherwise the new playback source's own latent. # Override persists across source swaps. + # + # This commit rewrites cond_pair AND its encode inputs + # (bpm/key/duration/source), so bump the conditioning + # epoch: any prompt/timbre encode in flight on another + # thread re-snapshots and re-encodes instead of + # committing pairs built from the old track. + state.cond_epoch += 1 self.stream.source = new_source state.source = new_source state.playback_samples = int(new_wf.shape[-1]) @@ -1062,6 +1074,42 @@ def _active_refer_latent(self): tl = self.state.timbre_latent return tl if tl is not None else self.state.source.latent + def _encode_cond_pairs( + self, + tags: str, + tags_b: str | None, + refer, + bpm, + duration, + key, + time_signature, + *, + encode_b: bool, + ): + """Run the text encodes for an A (and optionally B) cond pair + WITHOUT holding ``state._lock``. + + The runner takes ``state._lock`` at the top of every tick + (``read_knobs`` / ``has_pending_refit``), so encoding under it + stalled the whole tick loop — no ticks, no gap-fill — for the + encoder's duration on every prompt/timbre change. + ``state._encode_lock`` keeps the historical "encodes serialize + against each other" property. Returns ``(pair, pair_b)`` with + ``pair_b`` None when ``encode_b`` is False (caller mirrors A). + """ + with self.state._encode_lock: + pair = encode_cond_pair( + self.session, tags, refer, bpm, duration, key, + time_signature, + ) + pair_b = None + if encode_b: + pair_b = encode_cond_pair( + self.session, tags_b, refer, bpm, duration, key, + time_signature, + ) + return pair, pair_b + def _refresh_conditioning(self): """Recompose ``stream.conditioning`` from the cached A/B pairs, current timbre strength, and current prompt blend.""" @@ -1166,26 +1214,34 @@ def _clear_struct_override(self): r.mark_hint_dirty() def _apply_timbre_waveform(self, t_wf: torch.Tensor, name: str) -> float: - """Mutate timbre state for a new ref. Returns post-truncation - duration (seconds). Rolls back to prior state and re-raises on - any failure.""" + """Apply a new timbre ref. Returns post-truncation duration + (seconds). Compute-then-commit: the VAE encode and the cond + re-encodes run WITHOUT ``state._lock`` (the runner takes that + lock every tick), into locals; the commit under the lock is + atomic and re-checks ``cond_epoch``, retrying the encodes if + another conditioning commit landed meanwhile. Nothing is + mutated before the commit, so failure needs no rollback — + exceptions propagate to ``_apply_ref``'s handler as before.""" state = self.state - prev_timbre_latent = state.timbre_latent - prev_timbre_name = state.timbre_name - prev_cond_pair = state.cond_pair - prev_cond_pair_b = state.cond_pair_b - prev_stream_cond = self.stream.conditioning - try: - cap = int(state.duration * SAMPLE_RATE) - t_wf = t_wf[:, :cap] - rem = t_wf.shape[-1] % self.pool + for _ in range(5): + with state._lock: + duration = state.duration + prompt_text = state.prompt_text + prompt_text_b = state.prompt_text_b + bpm = state.bpm + key = state.key + ts = state.time_signature + epoch = state.cond_epoch + cap = int(duration * SAMPLE_RATE) + wf = t_wf[:, :cap] + rem = wf.shape[-1] % self.pool if rem: - t_wf = t_wf[:, :t_wf.shape[-1] - rem] - if t_wf.shape[-1] < self.pool: + wf = wf[:, :wf.shape[-1] - rem] + if wf.shape[-1] < self.pool: raise ValueError("timbre clip too short") - clip_s = t_wf.shape[-1] / SAMPLE_RATE + clip_s = wf.shape[-1] / SAMPLE_RATE sc = _try_load_sidecar( - name, samples=int(t_wf.shape[-1]), + name, samples=int(wf.shape[-1]), ) if sc is not None: device = self.session.handler.device @@ -1196,42 +1252,36 @@ def _apply_timbre_waveform(self, t_wf: torch.Tensor, name: str) -> float: logger.debug("timbre_sidecar_hit name={}", name) else: timbre_audio = Audio( - waveform=t_wf, sample_rate=SAMPLE_RATE, + waveform=wf, sample_rate=SAMPLE_RATE, ) logger.debug( "timbre_vae_encode_start clip_s={:.1f} channels={}", - clip_s, t_wf.shape[0], + clip_s, wf.shape[0], ) - timbre_latent = self.session.encode_audio(timbre_audio) + with state._encode_lock: + timbre_latent = self.session.encode_audio(timbre_audio) logger.debug( "timbre_vae_encode_done latent_shape={}", tuple(timbre_latent.tensor.shape), ) - state.timbre_latent = timbre_latent - state.timbre_name = name - state.cond_pair = encode_cond_pair( - self.session, state.prompt_text, timbre_latent, - state.bpm, state.duration, state.key, - state.time_signature, + pair, pair_b = self._encode_cond_pairs( + prompt_text, prompt_text_b, timbre_latent, + bpm, duration, key, ts, + encode_b=(prompt_text_b != prompt_text), ) - # Re-encode B against the new timbre too. - if state.prompt_text_b != state.prompt_text: - state.cond_pair_b = encode_cond_pair( - self.session, state.prompt_text_b, timbre_latent, - state.bpm, state.duration, state.key, - state.time_signature, - ) - else: - state.cond_pair_b = state.cond_pair - self._refresh_conditioning() + with state._lock: + if state.cond_epoch != epoch: + continue + state.cond_epoch += 1 + state.timbre_latent = timbre_latent + state.timbre_name = name + state.cond_pair = pair + state.cond_pair_b = pair_b if pair_b is not None else pair + self._refresh_conditioning() return clip_s - except Exception: - state.timbre_latent = prev_timbre_latent - state.timbre_name = prev_timbre_name - state.cond_pair = prev_cond_pair - state.cond_pair_b = prev_cond_pair_b - self.stream.conditioning = prev_stream_cond - raise + raise RuntimeError( + "timbre apply lost to concurrent conditioning changes 5 times", + ) def _apply_structure_waveform(self, s_wf: torch.Tensor, name: str) -> tuple[float, float]: """Stash a structure-ref waveform and re-derive the override's @@ -1390,34 +1440,56 @@ def set_prompt( ) -> None: """Re-encode A (and optionally B) against the active timbre reference and refresh the live conditioning. Publishes - :class:`PromptApplied`.""" + :class:`PromptApplied`. + + Snapshot-encode-commit: inputs are read under ``state._lock``, + the text encodes run unlocked (see ``_encode_cond_pairs``), and + the commit re-checks ``cond_epoch`` — if another conditioning + commit landed while we encoded, re-snapshot and re-encode so + both effects land (the old whole-method lock gave the same + end state by serializing callers).""" state = self.state state.last_activity_ts = time.monotonic() - with state._lock: - ts_override = _normalize_time_signature(time_signature) - if ts_override is not None: - state.time_signature = ts_override - refer = self._active_refer_latent() - key_used = key or state.key + for _ in range(5): + with state._lock: + ts_override = _normalize_time_signature(time_signature) + if ts_override is not None: + state.time_signature = ts_override + refer = self._active_refer_latent() + key_used = key or state.key + bpm = state.bpm + duration = state.duration + ts_used = state.time_signature + epoch = state.cond_epoch logger.info( "prompt_set origin={} tags={!r} tags_b={!r} key={} time_signature={}", - origin.value, tags, tags_b, key_used, state.time_signature, + origin.value, tags, tags_b, key_used, ts_used, ) - state.cond_pair = encode_cond_pair( - self.session, tags, refer, state.bpm, state.duration, - key_used, state.time_signature, + encode_b = bool(tags_b and tags_b != tags) + pair, pair_b = self._encode_cond_pairs( + tags, tags_b, refer, bpm, duration, key_used, ts_used, + encode_b=encode_b, ) - state.prompt_text = tags - if tags_b and tags_b != tags: - state.cond_pair_b = encode_cond_pair( - self.session, tags_b, refer, state.bpm, state.duration, - key_used, state.time_signature, - ) - state.prompt_text_b = tags_b - else: - state.cond_pair_b = state.cond_pair - state.prompt_text_b = tags - self._refresh_conditioning() + with state._lock: + if state.cond_epoch != epoch: + continue + state.cond_epoch += 1 + state.cond_pair = pair + state.prompt_text = tags + if pair_b is not None: + state.cond_pair_b = pair_b + state.prompt_text_b = tags_b + else: + state.cond_pair_b = pair + state.prompt_text_b = tags + self._refresh_conditioning() + break + else: + logger.warning( + "prompt_set_dropped tags={!r}: lost to concurrent " + "conditioning changes 5 times", tags, + ) + return self.bus.publish(PromptApplied(tags=tags)) def set_prompt_blend( @@ -1622,12 +1694,14 @@ def set_timbre_source( (or hits the fixture sidecar) and replaces cond_full.""" self.state.last_activity_ts = time.monotonic() logger.info("set_timbre_source_recv origin={} name={}", origin.value, name) - with self.state._lock: - self._apply_ref( - "timbre", name, - lambda: audio.waveform[:2], - "source", - ) + # No outer lock: _apply_timbre_waveform snapshots, encodes + # unlocked, and commits under state._lock itself, so the tick + # loop keeps running through the encode. + self._apply_ref( + "timbre", name, + lambda: audio.waveform[:2], + "source", + ) @requires_capability("timbre", "set_timbre_fixture") def set_timbre_fixture( @@ -1640,12 +1714,12 @@ def set_timbre_fixture( from the pod's local cache; same apply path as upload.""" self.state.last_activity_ts = time.monotonic() logger.info("set_timbre_fixture origin={} name={}", origin.value, name) - with self.state._lock: - self._apply_ref( - "timbre", name, - lambda: self._load_fixture_waveform(name), - "fixture", - ) + # No outer lock — same reasoning as set_timbre_source. + self._apply_ref( + "timbre", name, + lambda: self._load_fixture_waveform(name), + "fixture", + ) @requires_capability("timbre", "clear_timbre_source") def clear_timbre_source( @@ -1657,24 +1731,38 @@ def clear_timbre_source( (encode against the playback source's own latent).""" state = self.state state.last_activity_ts = time.monotonic() - with state._lock: - state.timbre_latent = None - state.timbre_name = None - refer = state.source.latent - state.cond_pair = encode_cond_pair( - self.session, state.prompt_text, refer, - state.bpm, state.duration, state.key, - state.time_signature, + # Snapshot-encode-commit, same shape as set_prompt: the + # re-encodes run without state._lock so ticks keep flowing. + for _ in range(5): + with state._lock: + refer = state.source.latent + prompt_text = state.prompt_text + prompt_text_b = state.prompt_text_b + bpm = state.bpm + duration = state.duration + key = state.key + ts = state.time_signature + epoch = state.cond_epoch + pair, pair_b = self._encode_cond_pairs( + prompt_text, prompt_text_b, refer, bpm, duration, key, ts, + encode_b=(prompt_text_b != prompt_text), ) - if state.prompt_text_b != state.prompt_text: - state.cond_pair_b = encode_cond_pair( - self.session, state.prompt_text_b, refer, - state.bpm, state.duration, state.key, - state.time_signature, - ) - else: - state.cond_pair_b = state.cond_pair - self._refresh_conditioning() + with state._lock: + if state.cond_epoch != epoch: + continue + state.cond_epoch += 1 + state.timbre_latent = None + state.timbre_name = None + state.cond_pair = pair + state.cond_pair_b = pair_b if pair_b is not None else pair + self._refresh_conditioning() + break + else: + logger.warning( + "timbre_clear_dropped: lost to concurrent conditioning " + "changes 5 times", + ) + return self.bus.publish(TimbreCleared()) logger.info("timbre_cleared origin={}", origin.value) diff --git a/acestep/streaming/state.py b/acestep/streaming/state.py index 778337bc..c87b3d15 100644 --- a/acestep/streaming/state.py +++ b/acestep/streaming/state.py @@ -134,3 +134,16 @@ class SessionState: # dict, or that read/write multiple fields atomically. Plain # single-field reads/writes don't take the lock (GIL atomicity). _lock: threading.RLock = field(default_factory=threading.RLock) + + # === Conditioning epoch + encode serialization === + # ``cond_epoch`` is bumped under ``_lock`` by every commit that + # writes ``cond_pair``/``cond_pair_b`` or their encode inputs + # (prompt, timbre latent, swap commit). The unlocked encode flows + # in the session snapshot it before encoding and retry when it + # moved, so a slow encode can never commit conditioning built from + # stale inputs. ``_encode_lock`` serializes the GPU encodes + # themselves against each other (the historical property that + # holding ``_lock`` provided) without blocking the runner's + # per-tick ``_lock`` acquisitions. + cond_epoch: int = 0 + _encode_lock: threading.Lock = field(default_factory=threading.Lock) From 3ba48f66118860f8538eaf9fb99bb642b1dccc09 Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Tue, 9 Jun 2026 18:06:48 -0400 Subject: [PATCH 6/7] feat(runner): near-playhead re-patch after fresh generations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The frontier write lands at the adaptive lead, which is floored (0.25s default) for stall safety, and the region between the playhead and that lead is never otherwise rewritten — so the lead floor was the hard audibility floor for every control change regardless of how fast the engine reacted (the measured knob-to-ear was dominated by it). After a real generation (mode=="generate", fresh result) and only within 2s of inbound activity, render a second window from the new latent at playhead + (interval_ema * gain + safety_margin) — the lead formula without the floor and stall bump — and patch it in with the same crossfade/clamp behavior as the frontier write. The frontier write keeps covering the buffer through stalls exactly as before; the close write is opportunistic, so a late landing just leaves the valid older audio in place. Loop-band aware (wraps the target inside an armed band, clamps at B, skips sub-window bands); walk mode passes no band. Does not feed _note_decode_gap (second write in the same tick, like the band-wrap render). Costs one extra windowed decode (~2.4ms) per active-generation tick; idle and untouched sessions pay nothing. --- acestep/streaming/pipeline_runner.py | 117 +++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) diff --git a/acestep/streaming/pipeline_runner.py b/acestep/streaming/pipeline_runner.py index 82dded30..f756860f 100644 --- a/acestep/streaming/pipeline_runner.py +++ b/acestep/streaming/pipeline_runner.py @@ -96,6 +96,7 @@ def __init__( lead_floor_s=None, lead_ceiling_s=None, lead_release_tau_s=None, + near_patch_enabled=True, ): self.backend = backend self.audio_eng = audio_eng @@ -288,6 +289,30 @@ def on_audio_ready(wav, win_start=None, win_end=None): # the multi-second session-startup build) can't push it to the ceiling. self._rebuild_prewarm_s = 1.1 self._rebuild_prewarm_cap_s = 1.3 + + # ----- Near-playhead re-patch (knob-to-ear) ----- + # The frontier write above always lands at the adaptive lead, + # which is FLOORED (0.25s default) for stall safety — and the + # region between the playhead and that lead is never otherwise + # rewritten. That makes the lead floor the hard audibility + # floor for every control change: however fast the engine + # reacts, the listener plays out the floor's worth of + # already-written audio first. After a real generation lands + # (and only while the operator is actively driving the + # session), a SECOND window is rendered from the new latent at + # ``playhead + (interval_ema * gain + margin)`` — the same lead + # formula WITHOUT the floor and stall bump — and patched in. + # The frontier write keeps the buffer covered through stalls + # exactly as before; this write is purely opportunistic (if it + # lands late the buffer simply keeps the valid, older audio), + # so it can run as close as transit allows. See + # ``_near_playhead_repatch``. + self._near_patch_enabled = bool(near_patch_enabled) + # Only re-patch within this window after the last inbound + # activity: while knobs are being ridden the close-in region + # refreshes every tick; an untouched session pays zero extra + # decode/wire traffic. + self._near_patch_active_window_s = 2.0 self._playhead_clock = _RemotePlayheadClock(self.audio_eng) # ---- delegates kept for the session's runner_holder contract ---------- @@ -368,6 +393,78 @@ def _note_decode_gap(self) -> float: def _playhead_seconds_now(self) -> float: return self._playhead_clock.seconds() + def _near_playhead_repatch(self, backend, eff_dur: float, band) -> None: + """Render + patch a window just ahead of the playhead from the + backend's newest latent. See the init-block comment for why + this exists (the lead floor is otherwise the audibility floor + for every control change). + + Mirrors the frontier write's crossfade/clamp behavior; loop-band + aware (wraps the target inside an armed band, clamps the write + at B, and skips bands narrower than one window — the frontier + render already rewrites those whole). Does NOT feed + ``_note_decode_gap``: like the band-wrap render, it is a second + write within the same tick, not its own production interval. + """ + close_s = ( + self._decode_interval_ema_s * self._lead_interval_gain + + self._lead_safety_margin_s + ) + if close_s >= self._decode_advance_s(): + return # frontier write already lands this close + playhead_now = self._playhead_seconds_now() + target = playhead_now + close_s + band_end_sample = None + if band is not None and eff_dur > 0: + a_s = max(0.0, min(float(band[0]), eff_dur)) + b_s = max(0.0, min(float(band[1]), eff_dur)) + span = b_s - a_s + if span > 1e-3 and a_s <= playhead_now <= b_s: + if span < self.vae_window: + return + target = a_s + ((playhead_now + close_s - a_s) % span) + band_end_sample = int(round(b_s * SAMPLE_RATE)) + if eff_dur > 0: + target = target % eff_dur + chunk = backend.render_window(target) + if chunk is None: + return + win_np = chunk.pcm + win_start = chunk.start_sample + win_end = win_start + win_np.shape[0] + current = self.audio_eng.current + xfade = min(1200, win_np.shape[0] // 4) + if win_start > 0 and xfade > 0: + t_in = np.linspace(0.0, 1.0, xfade).reshape(-1, 1) + win_np[:xfade] = ( + current[win_start:win_start + xfade] * (1 - t_in) + + win_np[:xfade] * t_in + ) + if win_end < current.shape[0] and xfade > 0: + t_out = np.linspace(1.0, 0.0, xfade).reshape(-1, 1) + tail = min(xfade, current.shape[0] - win_end + xfade) + s = win_np.shape[0] - tail + win_np[s:] = ( + win_np[s:] * t_out[:tail] + + current[win_start + s:win_start + s + tail] + * (1 - t_out[:tail]) + ) + clamp_end = min(win_end, current.shape[0]) + if band_end_sample is not None and band_end_sample > win_start: + clamp_end = min(clamp_end, band_end_sample) + if clamp_end <= win_start: + return + patched = win_np[:clamp_end - win_start] + self.audio_eng.patch_window(patched, win_start) + self.on_audio_ready(patched, win_start, win_end) + if _LAT_TRACE: + logger.info( + "lat_nearpatch playhead_s={:.3f} close_s={:.3f} " + "win_start_s={:.3f} win_end_s={:.3f}", + playhead_now, close_s, + win_start / SAMPLE_RATE, win_end / SAMPLE_RATE, + ) + # ---- the loop ------------------------------------------------------------- def run(self): @@ -754,6 +851,26 @@ def run(self): ) self.audio_eng.patch_window(wrap_np, wrap_start) self.on_audio_ready(wrap_np, wrap_start, wrap_end) + + # Knob-to-ear: after a real generation, also refresh + # the window just ahead of the playhead from the new + # latent (see _near_playhead_repatch). Gated to + # active operation so an idle session pays nothing; + # "reuse" (DiT-pause) and gap-fill ticks carry no new + # content, so only mode=="generate" qualifies. + if ( + self._near_patch_enabled + and is_fresh + and mode == "generate" + and ( + time.monotonic() - self.state.last_activity_ts + < self._near_patch_active_window_s + ) + ): + self._near_playhead_repatch( + backend, eff_dur, + None if position_chase_only else band, + ) else: # Legacy full-buffer mode. Gap-fill never reaches # here (it requires vae_window > 0), so only fresh From 0539d714bcdc6a88fdfbc877574bcd3dbe4f44fb Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Tue, 9 Jun 2026 18:17:54 -0400 Subject: [PATCH 7/7] fix(golden): audible_first = earliest heard, not first arrival MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit _action_audible took the FIRST qualifying slice in arrival order, which equaled the earliest-audible one only while slice starts advanced monotonically. The near-playhead re-patch emits a second, closer-to-playhead slice a few ms after each frontier write, so the first-arriving qualifying slice is now the FARTHEST one — the metric reported the old lead-floor number (~218ms) while re-patched content sat ~120ms from the playhead. Take min over qualifying slices of max(arrival, playhead-reaches-start); identical for monotonic streams, and a late-arriving slice cannot game it (arrival is inside the max). --- tests/golden/runner.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/golden/runner.py b/tests/golden/runner.py index 08433192..f486eb41 100644 --- a/tests/golden/runner.py +++ b/tests/golden/runner.py @@ -157,13 +157,18 @@ def _action_audible(ready_at: float, slices: list, entry: dict) -> dict: later = slices[entry["slices_before"]:] out: dict = {"audible_first_ms": None, "audible_full_ms": None} - first = next((s for s in later - if s.recv_at >= sent - and s.start_sample / SAMPLE_RATE > pos_at_send), None) - if first is not None: - heard = max(first.recv_at, - ready_at + first.start_sample / SAMPLE_RATE) - out["audible_first_ms"] = round((heard - sent) * 1000.0, 1) + # Earliest HEARD over all qualifying slices, not the first arrival: + # the near-playhead re-patch emits a second, closer-to-playhead + # slice a few ms after each frontier write, so slice starts are no + # longer monotonic in arrival order. A late-arriving slice can't + # game this — arrival time is inside the max(). + first_heard = [ + max(s.recv_at, ready_at + s.start_sample / SAMPLE_RATE) + for s in later + if s.recv_at >= sent and s.start_sample / SAMPLE_RATE > pos_at_send + ] + if first_heard: + out["audible_first_ms"] = round((min(first_heard) - sent) * 1000.0, 1) frontier = entry.get("frontier_s") if frontier is not None: