Skip to content
40 changes: 33 additions & 7 deletions acestep/engine/model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@ class ACEAdapter:

def __init__(self, pipeline):
    """Wrap ``pipeline`` and create the empty eager-forward reuse caches."""
    self._pipeline = pipeline
    # Reuse caches for the eager (non-TRT) forward path, so the
    # per-forward tensors are allocated once per shape instead of on
    # every call:
    #   * ``_t_bufs``    -- batch size B -> (host, device) timestep
    #                       buffer pair (the host side is pinned when
    #                       the pipeline device is CUDA);
    #   * ``_attn_ones`` -- (B, T) -> all-ones attention mask.
    # Each dict is cleared wholesale by its reader once it holds more
    # than 16 entries.
    # NOTE(review): ``batched_forward`` rewrites the cached host buffer
    # and then issues ``copy_(..., non_blocking=True)``; confirm that a
    # previous in-flight copy can never still be reading the buffer
    # when the next call overwrites it.
    self._t_bufs: dict = {}
    self._attn_ones: dict = {}

def build_schedule(self, config, denoise: float, device, dtype) -> torch.Tensor:
from .diffusion import DiffusionConfig
Expand Down Expand Up @@ -163,14 +171,32 @@ def batched_forward(
ctx_batch=ctx_b,
)

t_b = torch.tensor(
timestep_list, device=p._device, dtype=p._dtype,
)
B = xt_batch.shape[0]
tb = self._t_bufs.get(B)
if tb is None:
if len(self._t_bufs) > 16:
self._t_bufs.clear()
pin = p._device is not None and p._device.type == "cuda"
tb = (
torch.empty(B, dtype=p._dtype, pin_memory=pin),
torch.empty(B, dtype=p._dtype, device=p._device),
)
self._t_bufs[B] = tb
t_host, t_b = tb
for i, t in enumerate(timestep_list):
t_host[i] = t
t_b.copy_(t_host, non_blocking=True)

mask_b = torch.cat(mask_list, dim=0)
attn_b = torch.ones(
xt_batch.shape[0], xt_batch.shape[1],
device=p._device, dtype=p._dtype,
)
key = (B, xt_batch.shape[1])
attn_b = self._attn_ones.get(key)
if attn_b is None:
if len(self._attn_ones) > 16:
self._attn_ones.clear()
attn_b = torch.ones(
key[0], key[1], device=p._device, dtype=p._dtype,
)
self._attn_ones[key] = attn_b

out = p.decoder(
hidden_states=xt_batch,
Expand Down
128 changes: 105 additions & 23 deletions acestep/engine/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,18 @@ class _Slot:
# populated on step 0, reused on every subsequent step of this slot.
initial_noise: Optional[torch.Tensor] = None
vt_neg_cached: Optional[torch.Tensor] = None
# Memo for normalized per-slot curve fields (the ``_eff_shared``
# fallback path). Request fields are fixed after submit — the
# hot-mutation contract is ``set_shared_curve`` — so a scalar field
# is normalized once per slot instead of allocating a fresh
# ``[1, 1, 1]`` tensor on every read (per slot, per step).
curve_cache: Dict[str, torch.Tensor] = field(default_factory=dict)
# Memo for "does this slot field have any nonzero value" — the
# x0_target gate used to answer that with a per-slot-per-step
# ``.abs().any().item()`` device readback; the flag is computed at
# most once per slot here (and without any tensor op for the
# common Python-scalar fields).
nonzero_cache: Dict[str, bool] = field(default_factory=dict)


class StreamPipeline:
Expand Down Expand Up @@ -326,6 +338,12 @@ def __init__(
# at the setter so callers can pass scalars without thinking
# about shape.
self._shared_curves: dict[str, torch.Tensor] = {}
# Sidecar of ``_shared_curves``: "any nonzero value" per name,
# computed at the setter (no tensor op for scalar sets) so the
# x0_target gate in the step loop never has to read a tensor
# back from the device. Maintained strictly in lockstep with
# ``_shared_curves`` — see set_shared_curve / _curve_nonzero.
self._shared_curve_nonzero: dict[str, bool] = {}

# Channel guidance: a ``[1, T, 64]`` per-channel gain applied to
# ``xt`` before each forward pass. Lives in its own field rather
Expand Down Expand Up @@ -718,9 +736,14 @@ def _trt_forward(
else:
bufs["hidden_states"].copy_(xt_io)

# timestep: one scalar per row.
# timestep: stage all rows in the pinned host buffer, then one
# async H2D copy — instead of one tiny device write per row.
# Ordering with the TRT exec below is the same as the other
# input copies (legacy default stream → blocking TRT stream).
t_host = bufs["_timestep_host"]
for i, t in enumerate(timestep_list):
bufs["timestep"][i] = t
t_host[i] = t
bufs["timestep"].copy_(t_host, non_blocking=True)

# encoder_hidden_states: already padded to max_L + catted by
# the caller. The engine has no ``encoder_attention_mask``
Expand Down Expand Up @@ -888,6 +911,11 @@ def _ensure_trt_bufs(self, B: int, T: int, max_L: int):
bufs["_eff_T"] = eff_T
bufs["_T"] = T
bufs["_out_buf"] = out_buf
# Pinned host staging for the per-row timestep scalars (see
# _trt_forward). Underscore-prefixed so the bind loops skip it.
bufs["_timestep_host"] = torch.empty(
B, dtype=bufs["timestep"].dtype, pin_memory=True,
)
self._trt_bufs_cache[key] = bufs
while len(self._trt_bufs_cache) > self._trt_bufs_cache_max:
self._trt_bufs_cache.popitem(last=False)
Expand Down Expand Up @@ -1218,16 +1246,11 @@ def _forward_pairs(

# ``x0_target_strength`` path: blend toward a target latent
# at scalar (or per-frame curve) strength, gated to the
# refinement half. Preserving the historical "strength==0
# falls through to the fast path" behavior — checks the
# effective (shared override or slot field) strength via a
# tensor.any() sync, which costs one host-device fence per
# slot per step but lets the gate stay tensor-safe.
eff_strength = self._eff_shared(slot, "x0_target_strength")
strength_active = (
eff_strength is not None
and bool(eff_strength.abs().any().item())
)
# refinement half. Preserves the historical "strength==0
# falls through to the fast path" behavior; the nonzero
# check reads cached flags (maintained at set time), not
# tensor data, so the gate costs no host-device fence.
strength_active = self._curve_nonzero(slot, "x0_target_strength")
scalar_x0_target = (
req.x0_target is not None
and strength_active
Expand Down Expand Up @@ -1300,6 +1323,9 @@ def _forward_pairs(
x0_pred, req.x0_target, curve * blend_gate,
)
elif scalar_x0_target:
# Non-None whenever ``strength_active`` held: the gate
# and this read resolve from the same sources.
eff_strength = self._eff_shared(slot, "x0_target_strength")
alpha = eff_strength.to(device=x0_pred.device, dtype=x0_pred.dtype)
x0_pred = (1.0 - alpha) * x0_pred + alpha * req.x0_target

Expand Down Expand Up @@ -1402,28 +1428,83 @@ def set_shared_curve(
``value`` can be a scalar or a per-frame tensor; both flow
through :func:`ode_steps.normalize_curve` so the storage form is
always ``[B, T, 1]`` and downstream consumers do not need to
type-discriminate. Pass ``None`` to revert that name to per-slot
behavior.
type-discriminate. The canonical tensor is moved to the
pipeline's device here (dtype is left alone — consumers cast at
their own boundary exactly as before) so the per-step readers
never pay a host-to-device copy. Pass ``None`` to revert that
name to per-slot behavior.
"""
if value is None:
self._shared_curves.pop(name, None)
self._shared_curve_nonzero.pop(name, None)
return
self._shared_curves[name] = ode_steps.normalize_curve(value)
# Nonzero flag for the x0_target gate, computed where it is
# free: scalar sets need no tensor op at all, and tensor sets
# pay one readback here (per knob write) instead of one per
# slot per step in the loop.
if isinstance(value, (int, float, bool)):
nonzero = float(value) != 0.0
else:
nonzero = bool(value.abs().any().item())
v = ode_steps.normalize_curve(value)
if self._device is not None and v.device != self._device:
v = v.to(device=self._device)
self._shared_curves[name] = v
self._shared_curve_nonzero[name] = nonzero

def _eff_shared(self, slot: "_Slot", name: str):
    """Return the shared override for ``name`` if set, else the slot's field.

    The result is either ``None`` or a tensor in the canonical form
    produced by ``ode_steps.normalize_curve``, so callers never need to
    ``isinstance``-check.

    Resolution order:

    1. ``self._shared_curves[name]`` -- already canonical. If it lives
       on a different device than ``self._device`` it is moved once
       and the moved tensor is stored back, so later reads return it
       unchanged.
    2. ``slot.curve_cache[name]`` -- the slot field normalized on an
       earlier call.
    3. ``slot.request.<name>`` -- normalized, moved to ``self._device``
       when one is set, then memoized in ``slot.curve_cache``.

    Args:
        slot: The slot whose request field is the fallback source.
        name: Curve / field name to resolve.

    Returns:
        The effective curve tensor, or ``None`` when neither a shared
        override nor a slot field is present.
    """
    device = self._device

    shared = self._shared_curves.get(name)
    if shared is not None:
        if device is not None and shared.device != device:
            # Store the moved tensor back so the transfer is not
            # repeated on the next read.
            shared = shared.to(device=device)
            self._shared_curves[name] = shared
        return shared

    cached = slot.curve_cache.get(name)
    if cached is not None:
        return cached

    raw = getattr(slot.request, name, None)
    if raw is None:
        return None

    # The memo is never invalidated here, so a request field mutated
    # after the first read is not observed; this relies on request
    # fields staying fixed for the lifetime of the slot.
    curve = ode_steps.normalize_curve(raw)
    if device is not None and curve.device != device:
        curve = curve.to(device=device)
    slot.curve_cache[name] = curve
    return curve

def _curve_nonzero(self, slot: "_Slot", name: str) -> bool:
    """Return True when the effective curve for ``name`` has any nonzero.

    Sources are consulted in the same order as :meth:`_eff_shared`
    (shared override first, then the slot field), but the answer comes
    from cached booleans rather than from tensor data:

    * When ``name`` has a shared override, the flag stored in
      ``self._shared_curve_nonzero`` is returned. A missing flag is
      treated as ``True`` so an override is never silently disabled.
    * Otherwise the flag is computed from ``slot.request.<name>`` on
      the first call and memoized in ``slot.nonzero_cache``. Python
      scalars are compared directly; any other value is converted with
      ``torch.as_tensor`` (a no-op for tensors) and reduced, which
      costs a single ``.item()`` readback per slot and also accepts
      lists and NumPy values instead of failing on a missing ``.abs``.

    Args:
        slot: The slot whose request field is the fallback source.
        name: Curve / field name to test.

    Returns:
        ``True`` if any element of the effective curve is nonzero,
        ``False`` when it is absent or all zeros.
    """
    if name in self._shared_curves:
        return self._shared_curve_nonzero.get(name, True)

    flag = slot.nonzero_cache.get(name)
    if flag is None:
        raw = getattr(slot.request, name, None)
        if raw is None:
            flag = False
        elif isinstance(raw, (int, float)):
            # ``bool`` is an ``int`` subclass, so it is covered here.
            flag = raw != 0
        else:
            flag = bool((torch.as_tensor(raw) != 0).any().item())
        slot.nonzero_cache[name] = flag
    return flag

def set_dcw(
self,
Expand Down Expand Up @@ -1625,6 +1706,7 @@ def close(self) -> None:
self._schedule_cache.clear()
self._compiled_cache.clear()
self._shared_curves.clear()
self._shared_curve_nonzero.clear()
# DCW corrector holds wavelet basis tensors on GPU; drop it.
self._dcw_corrector = None
# Detach references to the engine + decoder so DiffusionEngine.close
Expand Down
49 changes: 43 additions & 6 deletions acestep/streaming/ace_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,14 +903,37 @@ def _after_produce(self, prep: dict, result_latent, is_fresh: bool) -> None:
# Stash the values the params echo + sampled trace need.
self._echo = prep["echo"]

def _dec_event_pair(self):
    """Return the cached ``(start, end)`` CUDA timing events, or ``None``.

    ``None`` is returned whenever ``torch.cuda.is_available()`` is
    false. Otherwise two ``torch.cuda.Event(enable_timing=True)``
    objects are created on first use, stored on ``self._dec_ev_pair``
    and handed back on every later call.

    The render methods record the first event before a decode and the
    second after it, then read ``elapsed_time`` once the waveform has
    been copied to the host, in place of a ``torch.cuda.synchronize()``
    bracket.

    NOTE(review): ``elapsed_time`` needs both events to have completed.
    The callers rely on the waveform's ``.cpu()`` copy for that --
    confirm the waveform is always a CUDA tensor produced on the
    stream the events were recorded on.
    """
    if not torch.cuda.is_available():
        return None
    cached = getattr(self, "_dec_ev_pair", None)
    if cached is not None:
        return cached
    start_ev = torch.cuda.Event(enable_timing=True)
    end_ev = torch.cuda.Event(enable_timing=True)
    self._dec_ev_pair = (start_ev, end_ev)
    return self._dec_ev_pair

def render_window(self, t_start_s: float):
decode_src = (
self._current_result if self._current_result is not None
else self._last_result_latent
)
if decode_src is None:
return None
t1 = time.perf_counter()
dec_ev = self._dec_event_pair()
if dec_ev is not None:
dec_ev[0].record()
else:
t1 = time.perf_counter()
if self._walk_active:
# The DiT output spans [win_start_s,
# win_start_s + walk_window_s] of the song.
Expand Down Expand Up @@ -939,9 +962,15 @@ def render_window(self, t_start_s: float):
else:
audio_out = self.codec.decode(decode_src, t_start=t_start_s, cyclic=True)
win_offset_samples = 0
torch.cuda.synchronize()
self.last_dec_ms += (time.perf_counter() - t1) * 1000
if dec_ev is not None:
dec_ev[1].record()
else:
self.last_dec_ms += (time.perf_counter() - t1) * 1000
win_wav = audio_out.waveform.detach().cpu().float().squeeze(0)
if dec_ev is not None:
# The D2H copy above drained the decode work past the end
# event, so this resolves without a device-wide sync.
self.last_dec_ms += dec_ev[0].elapsed_time(dec_ev[1])
win_np = win_wav.numpy().T
win_start = audio_out.start_sample + win_offset_samples
return AudioChunk(pcm=win_np, start_sample=win_start)
Expand All @@ -962,11 +991,19 @@ def render_full(self):
and (result - self._mse_prev).pow(2).mean().item() < self.skip_threshold
):
return None
t1 = time.perf_counter()
dec_ev = self._dec_event_pair()
if dec_ev is not None:
dec_ev[0].record()
else:
t1 = time.perf_counter()
audio_out = self.codec.decode(result_latent)
torch.cuda.synchronize()
self.last_dec_ms += (time.perf_counter() - t1) * 1000
if dec_ev is not None:
dec_ev[1].record()
else:
self.last_dec_ms += (time.perf_counter() - t1) * 1000
wav = audio_out.waveform.detach().cpu().float().squeeze(0)
if dec_ev is not None:
self.last_dec_ms += dec_ev[0].elapsed_time(dec_ev[1])
wav_np = wav.numpy().T
if self.crop_seconds > 0:
wav_np = wav_np[:int(self.crop_seconds * SAMPLE_RATE)]
Expand Down
37 changes: 31 additions & 6 deletions acestep/streaming/diffusion_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,15 @@ def __init__(self, *, adapter=None, codec=None):
# runner for its latency trace.
self.last_tick_ms = 0.0
self.last_dec_ms = 0.0
# CUDA-event bracket around the engine step. Recorded each
# produce, resolved lazily at the START of the next produce so
# the measurement never inserts a full-device synchronize into
# the tick (the old bracket cost two of them). ``last_tick_ms``
# is therefore one tick stale on GPU — it only feeds the
# runner trace and the params echo, both diagnostics.
self._tick_ev_start = None
self._tick_ev_end = None
self._tick_ev_pending = False

# ---- contract defaults --------------------------------------------------

Expand Down Expand Up @@ -99,9 +108,23 @@ def produce(self, knobs: dict, ctx: TickContext, mode: ProduceMode) -> bool:
"""
prep = self._prepare_tick(knobs, ctx)

if torch.cuda.is_available():
torch.cuda.synchronize()
t0 = time.perf_counter()
use_events = torch.cuda.is_available()
if use_events:
if self._tick_ev_start is None:
self._tick_ev_start = torch.cuda.Event(enable_timing=True)
self._tick_ev_end = torch.cuda.Event(enable_timing=True)
if self._tick_ev_pending:
# Last tick's bracket. Its work completed long ago (the
# render's D2H copy synced the stream), so this resolves
# without stalling; if it somehow hasn't, blocking here
# is no worse than the old synchronize.
self._tick_ev_end.synchronize()
self.last_tick_ms = self._tick_ev_start.elapsed_time(
self._tick_ev_end
)
self._tick_ev_start.record()
else:
t0 = time.perf_counter()

if mode == "reuse":
result_latent = self._last_result_latent
Expand All @@ -115,9 +138,11 @@ def produce(self, knobs: dict, ctx: TickContext, mode: ProduceMode) -> bool:
if result_latent is not None:
self._last_result_latent = result_latent

if torch.cuda.is_available():
torch.cuda.synchronize()
self.last_tick_ms = (time.perf_counter() - t0) * 1000
if use_events:
self._tick_ev_end.record()
self._tick_ev_pending = True
else:
self.last_tick_ms = (time.perf_counter() - t0) * 1000
self.last_dec_ms = 0.0

self._current_result = result_latent
Expand Down
Loading