diff --git a/acestep/engine/model_context.py b/acestep/engine/model_context.py index 1c74e330..c8ae069b 100644 --- a/acestep/engine/model_context.py +++ b/acestep/engine/model_context.py @@ -15,6 +15,8 @@ import math import os +import sys +import threading import time from contextlib import contextmanager from typing import Optional, Tuple @@ -79,6 +81,11 @@ def __init__( self._compile_vae = compile_vae self._offload_text_encoder = False self._diffusion_engine = None + # Serializes model placement (vram_parked) against every + # eager-module consumer (_load_model_context). Reentrant so a + # consumer that nests _load_model_context calls doesn't deadlock + # on itself. + self._placement_lock = threading.RLock() try: self._load_models( @@ -501,56 +508,271 @@ def _recursive_to_device(self, model, device, dtype=None) -> None: @contextmanager def _load_model_context(self, model_name: str): - """Move *model_name* to GPU for the duration of the block.""" - if ( - model_name == "text_encoder" - and self._offload_text_encoder - ): - pass # fall through to offload logic - elif not self.offload_to_cpu: - yield - return + """Move *model_name* to GPU for the duration of the block. + + Holds ``_placement_lock`` for the whole block: every eager-module + consumer routes through here, so a concurrent :meth:`vram_parked` + (which holds the same lock while the models sit on CPU) blocks + the consumer until the models are back on the device instead of + letting it run GPU inputs against CPU weights. + + Self-healing for persistently-offloaded contexts (see + :meth:`offload_eager_to_cpu`): when this context is in resident + mode (``offload_to_cpu=False``) but the requested module was + parked on CPU, it is restored to the device here before the + consumer runs — so a parked context never serves CPU weights to + a GPU consumer, no matter which code path touches it first. 
+ """ + with self._placement_lock: + if ( + model_name == "text_encoder" + and self._offload_text_encoder + ): + pass # fall through to offload logic + elif not self.offload_to_cpu: + self._restore_module_if_parked(model_name) + yield + return + + if model_name == "model" and not self.offload_dit_to_cpu: + model = getattr(self, model_name, None) + if model is not None: + try: + param = next(model.parameters()) + if param.device.type == "cpu": + logger.info(f"Moving {model_name} to {self.device} (persistent)") + self._recursive_to_device(model, self.device, self.dtype) + if self.silence_latent is not None: + self.silence_latent = self.silence_latent.to(self.device).to(self.dtype) + except StopIteration: + pass + yield + return - if model_name == "model" and not self.offload_dit_to_cpu: model = getattr(self, model_name, None) - if model is not None: - try: - param = next(model.parameters()) - if param.device.type == "cpu": - logger.info(f"Moving {model_name} to {self.device} (persistent)") - self._recursive_to_device(model, self.device, self.dtype) - if self.silence_latent is not None: - self.silence_latent = self.silence_latent.to(self.device).to(self.dtype) - except StopIteration: - pass - yield - return + if model is None: + yield + return - model = getattr(self, model_name, None) - if model is None: - yield - return + logger.info(f"Loading {model_name} to {self.device}") + t0 = time.time() + if model_name == "vae": + self._recursive_to_device(model, self.device, self._get_vae_dtype()) + else: + self._recursive_to_device(model, self.device, self.dtype) - logger.info(f"Loading {model_name} to {self.device}") - t0 = time.time() - if model_name == "vae": - self._recursive_to_device(model, self.device, self._get_vae_dtype()) - else: - self._recursive_to_device(model, self.device, self.dtype) + if model_name == "model" and self.silence_latent is not None: + self.silence_latent = self.silence_latent.to(self.device).to(self.dtype) - if model_name == "model" and 
self.silence_latent is not None: - self.silence_latent = self.silence_latent.to(self.device).to(self.dtype) + logger.info(f"Loaded {model_name} in {time.time() - t0:.3f}s") - logger.info(f"Loaded {model_name} in {time.time() - t0:.3f}s") + try: + yield + finally: + logger.info(f"Offloading {model_name} to CPU") + t0 = time.time() + self._recursive_to_device(model, "cpu") + torch.cuda.empty_cache() + logger.info(f"Offloaded {model_name} in {time.time() - t0:.3f}s") + + # Eager torch modules vram_parked() may move between devices. TRT + # engines are deliberately NOT here: their device memory is owned by + # TensorRT execution contexts and cannot be offloaded without + # destroying (and expensively rebuilding) them. + _PARKABLE_MODULES = ("model", "vae", "text_encoder") + def _accel_device(self) -> Optional[torch.device]: + """The context's device when it's an offloadable accelerator.""" + device = torch.device(self.device) + return device if device.type in ("cuda", "xpu") else None + + def _module_on_device(self, module, device: torch.device) -> bool: try: - yield - finally: - logger.info(f"Offloading {model_name} to CPU") - t0 = time.time() - self._recursive_to_device(model, "cpu") + return any( + p.device.type == device.type for p in module.parameters() + ) + except Exception: + return False + + def _offload_eager_locked(self) -> list: + """Move every GPU-resident parkable module + the silence latent + to CPU and hand the freed pages back to CUDA. Caller must hold + ``_placement_lock``. 
Returns the moved attribute names.""" + parked: list = [] + device = self._accel_device() + if device is None: + return parked + for name in self._PARKABLE_MODULES: + module = getattr(self, name, None) + if module is None or not self._module_on_device(module, device): + continue + self._recursive_to_device(module, "cpu") + parked.append(name) + if ( + self.silence_latent is not None + and self.silence_latent.device.type == device.type + ): + self.silence_latent = self.silence_latent.cpu() + parked.append("silence_latent") + if parked and device.type == "cuda" and torch.cuda.is_available(): torch.cuda.empty_cache() - logger.info(f"Offloaded {model_name} in {time.time() - t0:.3f}s") + return parked + + def _restore_eager_locked(self, names) -> list: + """Move the named parked attributes back to ``self.device`` with + their canonical dtypes. Caller must hold ``_placement_lock``. + Attempts every name even if one fails; returns the exceptions.""" + errors: list = [] + for name in names: + try: + if name == "silence_latent": + self._ensure_silence_latent_on_device() + continue + module = getattr(self, name, None) + if module is None: + continue + dtype = ( + self._get_vae_dtype() if name == "vae" else self.dtype + ) + self._recursive_to_device(module, self.device, dtype) + except Exception as exc: + logger.exception( + "vram_restore_failed module={} error={}", name, exc, + ) + errors.append(exc) + return errors + + def _restore_module_if_parked(self, model_name: str) -> None: + """Restore one module (plus the silence latent for ``model``) + when a persistent offload left it on CPU. Caller must hold + ``_placement_lock``. 
No-op when already resident.""" + device = self._accel_device() + if device is None: + return + module = getattr(self, model_name, None) + if module is None or self._module_on_device(module, device): + return + logger.info( + "vram_restore_on_demand module={} device={}", + model_name, self.device, + ) + names = [model_name] + if ( + model_name == "model" + and self.silence_latent is not None + and self.silence_latent.device.type == "cpu" + ): + names.append("silence_latent") + errors = self._restore_eager_locked(names) + if errors: + raise errors[0] + + def offload_eager_to_cpu(self) -> list: + """Persistently evict the eager modules from VRAM. + + Unlike :meth:`vram_parked`, nothing restores the modules when + this returns — they stay in system RAM until a consumer touches + them through :meth:`_load_model_context` (per-module lazy + restore) or :meth:`ensure_eager_on_device` brings everything + back at once. + + Used by the demo's shared upload-encoder session: the eager + weights are only needed while an upload is in flight, so parking + them between uploads keeps ~6 GB of VRAM free for the live + streaming session. Returns the parked attribute names. + """ + with self._placement_lock: + parked = self._offload_eager_locked() + if parked: + logger.info( + "vram_offloaded_persistent modules={} device={}", + parked, self.device, + ) + return parked + + def ensure_eager_on_device(self) -> list: + """Inverse of :meth:`offload_eager_to_cpu`: restore every parked + eager module to the device at once. Returns the restored names. 
+ No-op in ``offload_to_cpu`` mode (per-op placement owns moves + there) and on CPU contexts.""" + if self.offload_to_cpu: + return [] + with self._placement_lock: + device = self._accel_device() + if device is None: + return [] + names: list = [] + for name in self._PARKABLE_MODULES: + module = getattr(self, name, None) + if module is None or self._module_on_device(module, device): + continue + names.append(name) + if ( + self.silence_latent is not None + and self.silence_latent.device.type == "cpu" + ): + names.append("silence_latent") + if not names: + return [] + errors = self._restore_eager_locked(names) + logger.info( + "vram_restored_bulk modules={} errors={}", names, len(errors), + ) + if errors: + raise errors[0] + return names + + @contextmanager + def vram_parked(self): + """Temporarily evict the eager ACE-Step modules from VRAM. + + Moves every GPU-resident parkable module (DiT, VAE, text + encoder) plus the silence latent to CPU, returns the freed pages + to CUDA via ``empty_cache()``, and yields the list of parked + attribute names. On exit — success or exception — everything + that was parked is moved back to ``self.device`` with its + canonical dtype. + + Used by ``acestep.streaming.stems.extract_upload_stems`` to make + room for the Mel-Band RoFormer separator on VRAM-constrained + pods: park ACE-Step → run the separator → release it → restore + ACE-Step. + + Holds ``_placement_lock`` for the entire block (including the + caller's body), so concurrent conditioning / VAE / semantic ops + — which all enter :meth:`_load_model_context` — wait for the + restore instead of crashing on CPU-placed weights. The lock is + reentrant, but callers must not run ACE-Step inference from + inside the block: the weights are on CPU. + + ``model.to()`` moves parameters in place, so long-lived + references into these modules (DiffusionEngine, node handles) + stay valid across a park/restore cycle. 
No-op (yields ``[]``) + when the context lives on CPU or nothing is GPU-resident — + nesting is therefore safe and idempotent. + """ + with self._placement_lock: + t0 = time.time() + parked = self._offload_eager_locked() + if parked: + logger.info( + "vram_parked modules={} duration_s={:.2f}", + parked, time.time() - t0, + ) + try: + yield list(parked) + finally: + if parked: + t0 = time.time() + restore_errors = self._restore_eager_locked(parked) + logger.info( + "vram_unparked modules={} duration_s={:.2f} errors={}", + parked, time.time() - t0, len(restore_errors), + ) + # Surface a restore failure, but never mask an + # exception already unwinding from the body. + if restore_errors and sys.exc_info()[0] is None: + raise restore_errors[0] # ------------------------------------------------------------------ # Text embedding inference diff --git a/acestep/gpu_config.py b/acestep/gpu_config.py index 2c1d3cae..ab95ec65 100644 --- a/acestep/gpu_config.py +++ b/acestep/gpu_config.py @@ -149,6 +149,45 @@ def get_gpu_memory_gb() -> float: return 0 +def get_vram_telemetry(device=None) -> Optional[Dict[str, float]]: + """Snapshot of CUDA VRAM occupancy in GiB for ``device``. + + Returns ``None`` when CUDA isn't available or ``device`` isn't a + CUDA device — callers treat that as "no VRAM to manage". + + Keys: + free_gb: driver-reported free memory (excludes torch's cache). + total_gb: device capacity. + allocated_gb: torch tensors currently live. + reserved_gb: torch caching-allocator pages held from the driver. + available_gb: what a new allocation can realistically claim — + driver-free plus pages torch has cached but not handed out. 
+ """ + try: + import torch + + if not torch.cuda.is_available(): + return None + dev = torch.device(device if device is not None else "cuda") + if dev.type != "cuda": + return None + index = dev.index if dev.index is not None else torch.cuda.current_device() + free_b, total_b = torch.cuda.mem_get_info(index) + allocated_b = torch.cuda.memory_allocated(index) + reserved_b = torch.cuda.memory_reserved(index) + gib = float(1024 ** 3) + return { + "free_gb": free_b / gib, + "total_gb": total_b / gib, + "allocated_gb": allocated_b / gib, + "reserved_gb": reserved_b / gib, + "available_gb": (free_b + max(0, reserved_b - allocated_b)) / gib, + } + except Exception as e: + logger.warning(f"Failed to read VRAM telemetry: {e}") + return None + + def get_gpu_tier(gpu_memory_gb: float) -> str: """ Determine GPU tier based on available memory. diff --git a/acestep/nodes/vae_nodes.py b/acestep/nodes/vae_nodes.py index 3e9e86fe..505108df 100644 --- a/acestep/nodes/vae_nodes.py +++ b/acestep/nodes/vae_nodes.py @@ -279,6 +279,41 @@ def _find_trt_engine(name: str) -> Optional[str]: return None +def _trt_vae_profile_fits( + engine_path: str, tensor_name: str, shape, +) -> Optional[bool]: + """Whether ``shape`` satisfies the cached engine's optimization profile. + + The module-level TRT VAE cache is shared process-wide, so a cached + engine may belong to a session whose profile doesn't cover another + caller's input — e.g. the live streaming session's 60 s + ``vae_encode`` engine vs a 120 s upload going through the eager + upload-encoder. Callers use this to decide between the cached TRT + engine and an eager fallback BEFORE committing to the TRT path. + + Returns ``True``/``False`` when the profile verdict is known, or + ``None`` when it can't be determined (engine not cached, API error) + — treat ``None`` as "behave as before" (use TRT). 
+ """ + try: + entry = _trt_vae_cache.get(os.path.abspath(engine_path)) + if entry is None or entry.get("engine") is None: + return None + mn, _opt, mx = entry["engine"].get_tensor_profile_shape(tensor_name, 0) + if len(mn) != len(shape): + return False + return all( + int(mn[i]) <= int(shape[i]) <= int(mx[i]) + for i in range(len(shape)) + ) + except Exception as exc: + logger.warning( + "trt_vae_profile_check_failed engine={} tensor={} error={}", + engine_path, tensor_name, exc, + ) + return None + + def _find_best_vae_engine(component: str) -> Optional[str]: """Return a TRT VAE engine path for *component* if one was preloaded. @@ -347,6 +382,21 @@ def execute(self, **kwargs: Any) -> dict[str, Any]: waveform = waveform.unsqueeze(0) trt_path = _find_best_vae_engine("vae_encode") if _trt_available() else None + if ( + trt_path + and handler.vae is not None + and _trt_vae_profile_fits(trt_path, "audio", tuple(waveform.shape)) is False + ): + # The cached engine belongs to another session and its + # profile can't take this input (e.g. a >60 s upload vs the + # live session's 60 s engine). This handler carries an eager + # VAE — use it instead of letting TRT reject the shape. + logger.info( + "vae_encode_trt_profile_mismatch input_shape={} engine={} " + "fallback=eager", + tuple(waveform.shape), os.path.basename(trt_path), + ) + trt_path = None if trt_path: logger.info("VAE encode via TRT") latents_bdt = _trt_vae_encode(waveform, trt_path, device) @@ -398,8 +448,22 @@ def execute(self, **kwargs: Any) -> dict[str, Any]: lat_bdt = latent.tensor.transpose(1, 2) trt_path = _find_best_vae_engine("vae_decode") if _trt_available() else None + if ( + trt_path + and handler.vae is not None + and _trt_vae_profile_fits(trt_path, "latents", tuple(lat_bdt.shape)) is False + ): + # Same cross-session profile hazard as the encode node: an + # eager-VAE handler must not be forced through another + # session's cached engine when the shape can't fit it. 
+ logger.info( + "vae_decode_trt_profile_mismatch input_shape={} engine={} " + "fallback=eager", + tuple(lat_bdt.shape), os.path.basename(trt_path), + ) + trt_path = None if trt_path: - logger.info("VAE decode via TRT") + # logger.info("VAE decode via TRT") waveform = _trt_vae_decode(lat_bdt, trt_path, device) else: logger.info("VAE decode via PyTorch (no TRT engine found)") diff --git a/acestep/streaming/session.py b/acestep/streaming/session.py index b3acaa06..a62a74f2 100644 --- a/acestep/streaming/session.py +++ b/acestep/streaming/session.py @@ -41,6 +41,7 @@ import functools import os +import threading import time from dataclasses import asdict from pathlib import Path @@ -253,6 +254,11 @@ def extract_and_select_upload_stem( waveform=waveform, device=session.handler.device, backend_sample_rate=SAMPLE_RATE, + # Park this session's eager modules while the RoFormer + # runs (restored before the prepare_source below needs + # them back). Safe here: at create the runner doesn't + # exist yet, and at swap WE ARE the runner thread. + model_context=session.handler, ) if source_mode == "full": return upload_stems, None, source, waveform @@ -452,6 +458,12 @@ def __init__( # methods publish; transport adapters subscribe and serialize. self.bus = EventBus() + # Set at the end of close(), after GPU state is released. A + # preempting connection (ws_adapter's single-active-session + # policy) waits on this before creating its own session so the + # two model stacks never need VRAM simultaneously. + self.closed = threading.Event() + # The session's GeneratorBackend, selected by SessionConfig.backend # via the family registry. Constructed here (not in run()) so the # contract surface — capabilities()/geometry()/knob_specs() — is @@ -704,6 +716,9 @@ def close(self) -> None: self.session.close() except Exception as exc: logger.warning("session_close_raised error={}", exc) + # Last: signal waiters (preempting connections) that this + # session's GPU state is gone. 
+ self.closed.set() def _on_audio_ready(self, wav_np, win_start=None, win_end=None): """Runner callback. Mutates ``audio_eng`` for full-buffer diff --git a/acestep/streaming/stems.py b/acestep/streaming/stems.py index 507c1a10..e2dac4ab 100644 --- a/acestep/streaming/stems.py +++ b/acestep/streaming/stems.py @@ -2,6 +2,7 @@ from __future__ import annotations +import contextlib import gc import importlib import math @@ -17,9 +18,11 @@ import torch.nn.functional as F import torchaudio.functional as TAF from einops import pack, rearrange, reduce, repeat, unpack +from loguru import logger from torch import nn from torch.nn import Module, ModuleList +from acestep.gpu_config import get_vram_telemetry from acestep.model_downloader import resolve_melband_roformer_model_path try: @@ -741,6 +744,74 @@ def separate_stems( _INFER_LOCK = threading.Lock() +# --------------------------------------------------------------------------- +# VRAM management for the separator +# --------------------------------------------------------------------------- +# +# The RoFormer loads on top of a resident ACE-Step session (the streaming +# session at create/swap time, or the shared eager upload-encoder session +# in the demo's upload path). On VRAM-constrained pods that stack is the +# memory-pressure spike: before the separator loads we therefore park the +# ACE-Step context's eager modules on CPU (ModelContext.vram_parked), +# run separation, release the RoFormer, and only then restore ACE-Step — +# the separator and the parked models never need VRAM at the same time. + +# Free VRAM (GiB) the separator needs before it will load WITHOUT parking +# the resident ACE-Step models first. Covers fp16 weights plus chunked +# STFT/transformer activations and the on-device track buffers. Override +# with the env var; 0 disables parking entirely, a large value forces it. 
+MELBAND_VRAM_RESERVE_ENV = "DEMON_MELBAND_VRAM_RESERVE_GB" +DEFAULT_MELBAND_VRAM_RESERVE_GB = 6.0 + + +def melband_vram_reserve_gb() -> float: + raw = os.environ.get(MELBAND_VRAM_RESERVE_ENV) + if raw is None or not raw.strip(): + return DEFAULT_MELBAND_VRAM_RESERVE_GB + try: + return max(0.0, float(raw)) + except ValueError: + logger.warning( + "melband_vram_reserve_invalid value={!r} fallback={}", + raw, DEFAULT_MELBAND_VRAM_RESERVE_GB, + ) + return DEFAULT_MELBAND_VRAM_RESERVE_GB + + +def log_vram_telemetry(phase: str, device: torch.device) -> dict | None: + """Log one structured ``stems_vram`` line; returns the snapshot.""" + telemetry = get_vram_telemetry(device) if device.type == "cuda" else None + if telemetry is not None: + logger.info( + "stems_vram phase={} free_gb={:.2f} available_gb={:.2f} " + "allocated_gb={:.2f} reserved_gb={:.2f} total_gb={:.2f}", + phase, + telemetry["free_gb"], + telemetry["available_gb"], + telemetry["allocated_gb"], + telemetry["reserved_gb"], + telemetry["total_gb"], + ) + return telemetry + + +def should_park_for_melband(device: torch.device) -> tuple[bool, float, float]: + """Decide whether resident ACE-Step models must vacate VRAM first. + + Returns ``(park, available_gb, reserve_gb)``. Parks only when the + device is CUDA, parking isn't disabled (reserve > 0), and the + realistically-claimable VRAM (driver-free + torch's cached slack) + is below the reserve the separator needs. 
+ """ + reserve_gb = melband_vram_reserve_gb() + if device.type != "cuda" or reserve_gb <= 0.0: + return False, 0.0, reserve_gb + telemetry = get_vram_telemetry(device) + if telemetry is None: + return False, 0.0, reserve_gb + available_gb = float(telemetry["available_gb"]) + return available_gb < reserve_gb, available_gb, reserve_gb + def normalize_stem_source_mode(value: object) -> str | None: if not isinstance(value, str): @@ -768,6 +839,7 @@ def extract_upload_stems( waveform: torch.Tensor, device: torch.device | str, backend_sample_rate: int, + model_context=None, ) -> dict[str, torch.Tensor]: """Use Mel-Band RoFormer for vocal and instrumental separation. @@ -775,30 +847,62 @@ def extract_upload_stems( is trained for 44.1 kHz. The separator handles the downsample internally; we resample its returned stems back to the backend sample rate before sending overlays or preparing a selected stem as the inference source. + + ``model_context`` is the resident ACE-Step + :class:`~acestep.engine.model_context.ModelContext` sharing ``device`` + (``session.handler``). When VRAM is tight (see + :func:`should_park_for_melband`) its eager modules are parked on CPU + for the duration of separation and restored only after the RoFormer + has been released, so the two model stacks never need VRAM + simultaneously. ``None`` preserves the legacy load-on-top behavior. 
""" torch_device = _coerce_device(device) t0 = time.time() with _INFER_LOCK: - model: MelBandRoformer | None = None - try: - model_path = _resolve_model_path() - print(f"[Server] Loading Mel-Band RoFormer model on {torch_device}...") - load_t0 = time.time() - model = load_model(model_path, torch_device) - print(f"[Server] Mel-Band RoFormer loaded in {time.time() - load_t0:.1f}s") - vocals_44k, instruments_44k = separate_stems( - model, - waveform.detach().cpu().float().unsqueeze(0), - backend_sample_rate, - torch_device, - ) - if torch_device.type == "cuda": - torch.cuda.synchronize(torch_device) - finally: - if model is not None: - print(f"[Server] Releasing Mel-Band RoFormer model from {torch_device}...") - del model - _collect_device_cache(torch_device) + park, available_gb, reserve_gb = should_park_for_melband(torch_device) + if model_context is None: + park = False + logger.info( + "stems_vram_plan park={} available_gb={:.2f} reserve_gb={:.2f} " + "model_context={}", + park, available_gb, reserve_gb, + "present" if model_context is not None else "absent", + ) + log_vram_telemetry("before_separation", torch_device) + park_ctx = ( + model_context.vram_parked() if park else contextlib.nullcontext() + ) + with park_ctx: + if park: + log_vram_telemetry("acestep_parked", torch_device) + model: MelBandRoformer | None = None + try: + model_path = _resolve_model_path() + print(f"[Server] Loading Mel-Band RoFormer model on {torch_device}...") + load_t0 = time.time() + model = load_model(model_path, torch_device) + log_vram_telemetry("melband_loaded", torch_device) + print(f"[Server] Mel-Band RoFormer loaded in {time.time() - load_t0:.1f}s") + vocals_44k, instruments_44k = separate_stems( + model, + waveform.detach().cpu().float().unsqueeze(0), + backend_sample_rate, + torch_device, + ) + if torch_device.type == "cuda": + torch.cuda.synchronize(torch_device) + log_vram_telemetry("melband_separated", torch_device) + finally: + # Release the RoFormer BEFORE the park 
context restores + # ACE-Step (this finally runs first on block exit), so + # the restore lands in the VRAM the separator vacated. + if model is not None: + print(f"[Server] Releasing Mel-Band RoFormer model from {torch_device}...") + del model + _collect_device_cache(torch_device) + log_vram_telemetry("melband_released", torch_device) + if park: + log_vram_telemetry("acestep_restored", torch_device) print(f"[Server] Mel-Band RoFormer stems complete in {time.time() - t0:.1f}s") vocals = _fit_stem_waveform( diff --git a/demos/realtime_motion_graph_web/web/hooks/useStartSession.ts b/demos/realtime_motion_graph_web/web/hooks/useStartSession.ts index 56fef257..d4506b10 100644 --- a/demos/realtime_motion_graph_web/web/hooks/useStartSession.ts +++ b/demos/realtime_motion_graph_web/web/hooks/useStartSession.ts @@ -6,7 +6,7 @@ import { AudioPlayer } from "@demon/client"; import { listFixtures, loadFixtureAudio, pickDefaultFixture } from "@/engine/audio/loadFixture"; import { createNetworkMonitor } from "@/engine/networkMonitor"; import { defaultWsUrl } from "@/engine/podUrl"; -import { RemoteBackend, SAMPLE_RATE, SLICE_FLAG_DELTA } from "@demon/client"; +import { PREEMPTED_CLOSE_CODE, RemoteBackend, SAMPLE_RATE, SLICE_FLAG_DELTA } from "@demon/client"; import { getApiKey, getClientId } from "@/engine/rtmgConfig"; import { WsReconnector } from "@demon/client"; import { @@ -311,6 +311,19 @@ function wireRemoteListeners( // session start is tearing this one down. Don't reconnect; just // get out of the way. if (remote.closedByUser) return; + // Server preempted this session because a newer connection took + // the pod (one-session-per-pod). FINAL: reconnecting would just + // preempt the newer session back and ping-pong the pod through + // full session rebuilds. + if (detail?.code === PREEMPTED_CLOSE_CODE) { + useSessionStore + .getState() + .setStatus( + "closed", + "Session ended: another connection took over this pod.", + ); + return; + } onUnexpectedClose(detail ?? 
{ code: undefined, reason: undefined }); }); diff --git a/demos/realtime_motion_graph_web/web/sdk/protocol.ts b/demos/realtime_motion_graph_web/web/sdk/protocol.ts index ea2eb25a..e8c5c5c9 100644 --- a/demos/realtime_motion_graph_web/web/sdk/protocol.ts +++ b/demos/realtime_motion_graph_web/web/sdk/protocol.ts @@ -15,6 +15,7 @@ import * as fzstd from "fzstd"; import { + PREEMPTED_CLOSE_CODE, SAMPLE_RATE, SLICE_FLAG_DELTA, SLICE_HDR_SIZE, @@ -742,7 +743,9 @@ export class RemoteBackend extends EventTarget { // most often, both recoverable by reloading. if (!this.ready) { let msg: string; - if (e.code === 1011) { + if (e.code === PREEMPTED_CLOSE_CODE) { + msg = "Another connection took over this session."; + } else if (e.code === 1011) { msg = "Session failed while starting — refresh the page to retry."; } else if (e.code === 1006) { msg = "Connection lost — refresh to retry."; diff --git a/demos/realtime_motion_graph_web/web/sdk/types/protocol.ts b/demos/realtime_motion_graph_web/web/sdk/types/protocol.ts index 38b5fcee..f313c365 100644 --- a/demos/realtime_motion_graph_web/web/sdk/types/protocol.ts +++ b/demos/realtime_motion_graph_web/web/sdk/types/protocol.ts @@ -139,3 +139,12 @@ export const CROSSFADE_SECONDS = 0.025; export const SLICE_HDR_SIZE = 23; // 1+4+4+2+4+4+4 export const SLICE_FLAG_RAW = 0; export const SLICE_FLAG_DELTA = 1; + +/** + * Server-initiated WebSocket close meaning "a newer connection took + * over this pod" (one-session-per-pod policy; see ws_adapter.py + * PREEMPTED_CLOSE_CODE). Treated as FINAL by the client: reconnecting + * would preempt the newer session back and ping-pong the pod through + * full session rebuilds. 
+ */ +export const PREEMPTED_CLOSE_CODE = 4001; diff --git a/demos/realtime_motion_graph_web/ws_adapter.py b/demos/realtime_motion_graph_web/ws_adapter.py index f40f7169..0f2792ba 100644 --- a/demos/realtime_motion_graph_web/ws_adapter.py +++ b/demos/realtime_motion_graph_web/ws_adapter.py @@ -107,9 +107,15 @@ # mutable GPU object in the system (every StreamingSession otherwise owns # its own Session). The tradeoff is deliberate: encoding uploads needs the # VAE encoder, which the streaming TRT path doesn't expose, so we keep a -# second resident eager copy of the weights rather than rebuild one per -# upload. Two costs follow from the sharing: -# - VRAM: the first upload permanently adds a second model copy. +# second eager copy of the weights rather than rebuild one per upload. +# Two costs follow from the sharing: +# - VRAM: the eager weights occupy GPU memory WHILE AN UPLOAD IS IN +# FLIGHT. Between uploads they are parked in system RAM +# (ModelContext.offload_eager_to_cpu in _handle_upload_track's +# finally); ModelContext._load_model_context lazily restores exactly +# the modules the next upload touches. Without the parking, the +# first upload would permanently pin ~6 GB next to the live +# streaming session. # - Concurrency: prepare_source / stem extraction are NOT thread-safe on # a shared Session, so _UPLOAD_INFER_LOCK serializes all GPU work on it. _UPLOAD_ENCODERS: dict[str, Session] = {} @@ -117,6 +123,111 @@ _UPLOAD_INFER_LOCK = threading.Lock() +# --------------------------------------------------------------------------- +# Single-active-session policy +# --------------------------------------------------------------------------- +# +# The rtmg backend is one-session-per-pod: the TRT VAE cache, the LoRA +# library, and the GPU budget are all sized for exactly one streaming +# session. 
Two concurrent ``StreamingSession.create`` calls stack two +# full model stacks (OOM on a 24 GB card), and either session's +# teardown evicts shared TRT VAE cache entries out from under the +# other — the failure cascade is: dual create → one OOMs → its cleanup +# evicts the shared engines → the HEALTHY session crashes on its next +# decode. Doubled connections happen routinely (page reload while the +# old socket is still draining, dev StrictMode double-mount, a stale +# tab auto-reconnecting), so the policy is enforced here: +# +# - ``_SESSION_LIFECYCLE_LOCK`` serializes preempt+create: at most one +# session is ever being constructed, and construction never overlaps +# another session's teardown. +# - A new main-session connection PREEMPTS the active session: its +# runner is stopped, its WebSocket is closed with +# ``PREEMPTED_CLOSE_CODE`` (the client treats that close as final — +# no reconnect war between two tabs), and the new connection WAITS +# for ``StreamingSession.closed`` so the old stack's VRAM is +# actually free before the new stack loads. +# +# The ``upload_track`` side-channel WS never touches this policy. + +_SESSION_LIFECYCLE_LOCK = threading.Lock() +_ACTIVE_SLOT_LOCK = threading.Lock() +_ACTIVE_SESSION: list = [None] # [_ActiveSession | None] + +# 4000-range application close code: "this session was replaced by a +# newer connection". The web client (web/sdk/protocol.ts + +# web/hooks/useStartSession.ts) recognizes it and does NOT enter the +# reconnect loop — reconnecting would just preempt the newer session +# back and ping-pong the pod through full session rebuilds. +PREEMPTED_CLOSE_CODE = 4001 + +# How long a preempting connection waits for the old session's teardown +# to release VRAM. Generous: the old runner may be mid stem-extraction +# (it only observes running=False between pipeline iterations). 
+_PREEMPT_TEARDOWN_TIMEOUT_S = 45.0 + + +class _ActiveSession: + __slots__ = ("session_id", "streaming", "ws") + + def __init__(self, session_id: str, streaming, ws): + self.session_id = session_id + self.streaming = streaming + self.ws = ws + + +def _preempt_active_session(new_session_id: str) -> None: + """Stop and drain the currently-active session, if any. + + Caller must hold ``_SESSION_LIFECYCLE_LOCK``. Returns once the old + session has released its GPU state (or after a bounded wait with a + warning — create proceeds either way; the OOM-retry paths downstream + are the backstop).""" + with _ACTIVE_SLOT_LOCK: + prev = _ACTIVE_SESSION[0] + if prev is None: + return + logger.info( + "session_preempt prev={} new={} reason=single_session_policy", + prev.session_id, new_session_id, + ) + # Stop the runner; it observes this between pipeline iterations and + # exits run() into close(). + prev.streaming.state.running = False + # Close the old socket so its handler unblocks from any recv/send + # and the client sees a deliberate, final close (not a 1006 blip). 
+ try: + prev.ws.close(PREEMPTED_CLOSE_CODE, "preempted by a newer session") + except Exception: + pass + if not prev.streaming.closed.wait(timeout=_PREEMPT_TEARDOWN_TIMEOUT_S): + logger.warning( + "session_preempt_teardown_timeout prev={} waited_s={}", + prev.session_id, _PREEMPT_TEARDOWN_TIMEOUT_S, + ) + else: + logger.info("session_preempt_complete prev={}", prev.session_id) + with _ACTIVE_SLOT_LOCK: + if _ACTIVE_SESSION[0] is prev: + _ACTIVE_SESSION[0] = None + + +def _log_session_vram(stage: str) -> None: + from acestep.gpu_config import get_vram_telemetry + + telemetry = get_vram_telemetry() + if telemetry is not None: + logger.info( + "session_vram stage={} free_gb={:.2f} available_gb={:.2f} " + "allocated_gb={:.2f} reserved_gb={:.2f}", + stage, + telemetry["free_gb"], + telemetry["available_gb"], + telemetry["allocated_gb"], + telemetry["reserved_gb"], + ) + + def _upload_encoder_session(checkpoint: str) -> Session: with _UPLOAD_ENCODERS_LOCK: session = _UPLOAD_ENCODERS.get(checkpoint) @@ -222,20 +333,42 @@ def _handle_upload_track(ws, header: dict, *, checkpoint: str) -> None: # from multiple connections would otherwise drive prepare_source / # stem extraction on one Session at once and corrupt its state. with _UPLOAD_INFER_LOCK: - sources = { - "full": encoder.prepare_source( - Audio(waveform=waveform, sample_rate=SAMPLE_RATE), - ), - } - stems = extract_upload_stems( - waveform=waveform, - device=encoder.handler.device, - backend_sample_rate=SAMPLE_RATE, - ) - for mode in ("vocals", "instruments"): - sources[mode] = encoder.prepare_source( - Audio(waveform=stems[mode], sample_rate=SAMPLE_RATE), + try: + sources = { + "full": encoder.prepare_source( + Audio(waveform=waveform, sample_rate=SAMPLE_RATE), + ), + } + stems = extract_upload_stems( + waveform=waveform, + device=encoder.handler.device, + backend_sample_rate=SAMPLE_RATE, + # Park the shared eager encoder (a full second copy + # of the ACE-Step weights) while the RoFormer runs. 
+ # Any live StreamingSession owns its own + # ModelContext and keeps streaming untouched. + model_context=encoder.handler, ) + for mode in ("vocals", "instruments"): + sources[mode] = encoder.prepare_source( + Audio(waveform=stems[mode], sample_rate=SAMPLE_RATE), + ) + finally: + # The encoder's eager weights are only needed while an + # upload is in flight. Between uploads they would pin + # ~6 GB of VRAM next to the live streaming session, so + # park them in system RAM; _load_model_context restores + # exactly the modules the next upload touches, lazily. + try: + parked = encoder.handler.offload_eager_to_cpu() + if parked: + logger.info( + "upload_encoder_offloaded modules={}", parked, + ) + except Exception as exc: + logger.warning( + "upload_encoder_offload_failed error={}", exc, + ) packet = persist_user_upload_packet( name, waveform=waveform, @@ -410,15 +543,24 @@ def _ms(stage: str) -> None: _ms("resolve_source_start") try: - streaming = StreamingSession.create( - audio=audio_in, - config=cfg, - checkpoint=checkpoint, - decoder_backend=decoder_backend, - vae_backend=vae_backend, - offload_text_encoder=offload_text_encoder, - session_id=session_id, - ) + # Single-active-session policy: serialize construction and + # preempt whatever session currently owns the GPU. See the + # policy comment block at module top. 
+ with _SESSION_LIFECYCLE_LOCK: + _preempt_active_session(session_id) + _log_session_vram("create_start") + streaming = StreamingSession.create( + audio=audio_in, + config=cfg, + checkpoint=checkpoint, + decoder_backend=decoder_backend, + vae_backend=vae_backend, + offload_text_encoder=offload_text_encoder, + session_id=session_id, + ) + with _ACTIVE_SLOT_LOCK: + _ACTIVE_SESSION[0] = _ActiveSession(session_id, streaming, ws) + _log_session_vram("create_done") except UnsupportedTrtCheckpointError as exc: try: ws.send(json.dumps({ @@ -466,6 +608,16 @@ def _ms(stage: str) -> None: streaming_entered_run = False session_registered = False + def _release_active_slot() -> None: + # Compare-and-swap: only clear the slot if it's still ours (a + # preempting connection may have already replaced it). + with _ACTIVE_SLOT_LOCK: + cur = _ACTIVE_SESSION[0] + if cur is not None and cur.streaming is streaming: + _ACTIVE_SESSION[0] = None + + ctx_stack.callback(_release_active_slot) + def _close_streaming_if_init_fails() -> None: if not streaming_entered_run: if session_registered: diff --git a/docs/VOCALSTEM.md b/docs/VOCALSTEM.md index aa0cf080..76d52b4e 100644 --- a/docs/VOCALSTEM.md +++ b/docs/VOCALSTEM.md @@ -59,6 +59,76 @@ MELBAND_ROFORMER_MODEL_PATH The instrumental bed is the RoFormer instrumental output, not ACE-guided spectral suppression. +## VRAM Management + +The RoFormer always loads on top of a resident ACE-Step model stack — the +streaming session at create/swap time, or the shared eager upload-encoder +session in the demo's `upload_track` path. On VRAM-constrained pods that +stack is a memory-pressure spike, so `extract_upload_stems()` accepts the +resident `ModelContext` as `model_context` and runs a park/restore cycle +around separation: + +1. **Decide** (`should_park_for_melband()`): park only when the + realistically-claimable VRAM (driver-free + torch's cached slack) is + below the reserve the separator needs. 
+ `DEMON_MELBAND_VRAM_RESERVE_GB` overrides the default (6.0); `0` + disables parking, a large value forces it. +2. **Park** (`ModelContext.vram_parked()`): the eager modules (DiT, VAE, + text encoder) and the silence latent move to CPU and the freed pages + return to CUDA. TRT engines are untouched — their device memory + belongs to TensorRT execution contexts and cannot be offloaded. +3. **Separate**: the RoFormer loads into the vacated VRAM and runs. +4. **Release**: the RoFormer is dropped and its cache emptied BEFORE the + restore, so ACE-Step returns into the space the separator vacated. +5. **Restore**: parked modules move back to the device with their + canonical dtypes. + +Concurrency: `vram_parked()` holds the ModelContext's placement lock for +the whole cycle, and every eager-module consumer routes through +`ModelContext._load_model_context()`, which takes the same lock. A +concurrent operation (prompt re-encode, timbre/structure set) issued +while the models are parked therefore blocks until the restore instead +of running GPU inputs against CPU weights. The session create and swap +paths run the extraction on the runner thread (or before the runner +exists), so streaming ticks never overlap a park; in the upload path the +parked encoder is a separate `ModelContext` from any live session, which +keeps streaming untouched. + +Beyond the per-separation cycle, three standing policies keep steady-state +VRAM flat: + +- **The shared upload encoder lives on GPU only while an upload is in + flight.** `_handle_upload_track` calls + `ModelContext.offload_eager_to_cpu()` when each upload finishes + (persistent park — nothing auto-restores), and + `_load_model_context()` lazily restores exactly the modules the next + upload touches (`model` for semantic extract; `vae` only when no TRT + VAE engine is cached). Without this the first upload would + permanently pin ~6 GB next to the live streaming session.
+- **Shape-aware TRT VAE engine selection** (`acestep/nodes/vae_nodes.py`): + the process-wide TRT VAE cache can hand the upload encoder an engine + belonging to the live streaming session, whose optimization profile + may not cover the upload's length (a 120 s upload vs the session's + 60 s `vae_encode` engine). `_trt_vae_profile_fits()` checks the input + shape against the cached engine's profile first; on a mismatch, a + handler that carries an eager VAE falls back to it instead of letting + TRT reject the shape and fail the upload. +- **One streaming session per pod, enforced** (`ws_adapter`): + `StreamingSession.create` calls are serialized, and a new main-session + connection preempts the active session — stops its runner, closes its + socket with close code 4001 (`PREEMPTED_CLOSE_CODE`, which the web + client treats as final rather than reconnecting), and waits on + `StreamingSession.closed` for its GPU teardown before building the new + stack. This prevents the dual-create OOM and the cascade where a dying + session's cleanup evicts shared TRT VAE cache entries out from under a + live one. + +Every phase emits a structured `stems_vram` log line +(free/available/allocated/reserved GiB), plus `vram_parked` / +`vram_unparked` from the context. Verify against real models with +`scripts/verify_melband_vram.py`; unit coverage lives in +`tests/unit/test_melband_vram_management.py`. + ## Returned Stem Assets `extract_upload_stems()` returns: diff --git a/scripts/calibration/precompute_fixture_sidecars.py b/scripts/calibration/precompute_fixture_sidecars.py index 4c120d9a..7ae8bd5e 100644 --- a/scripts/calibration/precompute_fixture_sidecars.py +++ b/scripts/calibration/precompute_fixture_sidecars.py @@ -201,6 +201,7 @@ def precompute_one( waveform=waveform, device=session.handler.device, backend_sample_rate=SAMPLE_RATE, + model_context=session.handler, ) write_stem_wavs(out_dir, name, stems=stems, sample_rate=SAMPLE_RATE) for mode in ("vocals", "instruments"):