Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
304 changes: 263 additions & 41 deletions acestep/engine/model_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import math
import os
import sys
import threading
import time
from contextlib import contextmanager
from typing import Optional, Tuple
Expand Down Expand Up @@ -79,6 +81,11 @@ def __init__(
self._compile_vae = compile_vae
self._offload_text_encoder = False
self._diffusion_engine = None
# Serializes model placement (vram_parked) against every
# eager-module consumer (_load_model_context). Reentrant so a
# consumer that nests _load_model_context calls doesn't deadlock
# on itself.
self._placement_lock = threading.RLock()

try:
self._load_models(
Expand Down Expand Up @@ -501,56 +508,271 @@ def _recursive_to_device(self, model, device, dtype=None) -> None:

@contextmanager
def _load_model_context(self, model_name: str):
"""Move *model_name* to GPU for the duration of the block."""
if (
model_name == "text_encoder"
and self._offload_text_encoder
):
pass # fall through to offload logic
elif not self.offload_to_cpu:
yield
return
"""Move *model_name* to GPU for the duration of the block.

Holds ``_placement_lock`` for the whole block: every eager-module
consumer routes through here, so a concurrent :meth:`vram_parked`
(which holds the same lock while the models sit on CPU) blocks
the consumer until the models are back on the device instead of
letting it run GPU inputs against CPU weights.

Self-healing for persistently-offloaded contexts (see
:meth:`offload_eager_to_cpu`): when this context is in resident
mode (``offload_to_cpu=False``) but the requested module was
parked on CPU, it is restored to the device here before the
consumer runs — so a parked context never serves CPU weights to
a GPU consumer, no matter which code path touches it first.
"""
with self._placement_lock:
if (
model_name == "text_encoder"
and self._offload_text_encoder
):
pass # fall through to offload logic
elif not self.offload_to_cpu:
self._restore_module_if_parked(model_name)
yield
return

if model_name == "model" and not self.offload_dit_to_cpu:
model = getattr(self, model_name, None)
if model is not None:
try:
param = next(model.parameters())
if param.device.type == "cpu":
logger.info(f"Moving {model_name} to {self.device} (persistent)")
self._recursive_to_device(model, self.device, self.dtype)
if self.silence_latent is not None:
self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
except StopIteration:
pass
yield
return

if model_name == "model" and not self.offload_dit_to_cpu:
model = getattr(self, model_name, None)
if model is not None:
try:
param = next(model.parameters())
if param.device.type == "cpu":
logger.info(f"Moving {model_name} to {self.device} (persistent)")
self._recursive_to_device(model, self.device, self.dtype)
if self.silence_latent is not None:
self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
except StopIteration:
pass
yield
return
if model is None:
yield
return

model = getattr(self, model_name, None)
if model is None:
yield
return
logger.info(f"Loading {model_name} to {self.device}")
t0 = time.time()
if model_name == "vae":
self._recursive_to_device(model, self.device, self._get_vae_dtype())
else:
self._recursive_to_device(model, self.device, self.dtype)

logger.info(f"Loading {model_name} to {self.device}")
t0 = time.time()
if model_name == "vae":
self._recursive_to_device(model, self.device, self._get_vae_dtype())
else:
self._recursive_to_device(model, self.device, self.dtype)
if model_name == "model" and self.silence_latent is not None:
self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)

if model_name == "model" and self.silence_latent is not None:
self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
logger.info(f"Loaded {model_name} in {time.time() - t0:.3f}s")

logger.info(f"Loaded {model_name} in {time.time() - t0:.3f}s")
try:
yield
finally:
logger.info(f"Offloading {model_name} to CPU")
t0 = time.time()
self._recursive_to_device(model, "cpu")
torch.cuda.empty_cache()
logger.info(f"Offloaded {model_name} in {time.time() - t0:.3f}s")

# Attribute names of the eager torch modules that may be moved between
# devices; iterated by _offload_eager_locked() and
# ensure_eager_on_device(). TRT engines are deliberately NOT here: their
# device memory is owned by TensorRT execution contexts and cannot be
# offloaded without destroying (and expensively rebuilding) them.
_PARKABLE_MODULES = ("model", "vae", "text_encoder")

def _accel_device(self) -> Optional[torch.device]:
"""The context's device when it's an offloadable accelerator."""
device = torch.device(self.device)
return device if device.type in ("cuda", "xpu") else None

def _module_on_device(self, module, device: torch.device) -> bool:
try:
yield
finally:
logger.info(f"Offloading {model_name} to CPU")
t0 = time.time()
self._recursive_to_device(model, "cpu")
return any(
p.device.type == device.type for p in module.parameters()
)
except Exception:
return False

def _offload_eager_locked(self) -> list:
"""Move every GPU-resident parkable module + the silence latent
to CPU and hand the freed pages back to CUDA. Caller must hold
``_placement_lock``. Returns the moved attribute names."""
parked: list = []
device = self._accel_device()
if device is None:
return parked
for name in self._PARKABLE_MODULES:
module = getattr(self, name, None)
if module is None or not self._module_on_device(module, device):
continue
self._recursive_to_device(module, "cpu")
parked.append(name)
if (
self.silence_latent is not None
and self.silence_latent.device.type == device.type
):
self.silence_latent = self.silence_latent.cpu()
parked.append("silence_latent")
if parked and device.type == "cuda" and torch.cuda.is_available():
torch.cuda.empty_cache()
logger.info(f"Offloaded {model_name} in {time.time() - t0:.3f}s")
return parked

def _restore_eager_locked(self, names) -> list:
"""Move the named parked attributes back to ``self.device`` with
their canonical dtypes. Caller must hold ``_placement_lock``.
Attempts every name even if one fails; returns the exceptions."""
errors: list = []
for name in names:
try:
if name == "silence_latent":
self._ensure_silence_latent_on_device()
continue
module = getattr(self, name, None)
if module is None:
continue
dtype = (
self._get_vae_dtype() if name == "vae" else self.dtype
)
self._recursive_to_device(module, self.device, dtype)
except Exception as exc:
logger.exception(
"vram_restore_failed module={} error={}", name, exc,
)
errors.append(exc)
return errors

def _restore_module_if_parked(self, model_name: str) -> None:
"""Restore one module (plus the silence latent for ``model``)
when a persistent offload left it on CPU. Caller must hold
``_placement_lock``. No-op when already resident."""
device = self._accel_device()
if device is None:
return
module = getattr(self, model_name, None)
if module is None or self._module_on_device(module, device):
return
logger.info(
"vram_restore_on_demand module={} device={}",
model_name, self.device,
)
names = [model_name]
if (
model_name == "model"
and self.silence_latent is not None
and self.silence_latent.device.type == "cpu"
):
names.append("silence_latent")
errors = self._restore_eager_locked(names)
if errors:
raise errors[0]

def offload_eager_to_cpu(self) -> list:
    """Evict the eager modules from VRAM and leave them on CPU.

    In contrast to :meth:`vram_parked`, nothing is restored when this
    returns. The modules stay in system RAM until a consumer reaches
    them through :meth:`_load_model_context` (lazy, per module) or
    :meth:`ensure_eager_on_device` restores all of them together.

    Used by the demo's shared upload-encoder session: the eager weights
    are only needed while an upload is in flight, so parking them
    between uploads keeps ~6 GB of VRAM free for the live streaming
    session.

    Returns:
        The attribute names that were moved to CPU.
    """
    with self._placement_lock:
        moved = self._offload_eager_locked()
        if not moved:
            return moved
        logger.info(
            "vram_offloaded_persistent modules={} device={}",
            moved, self.device,
        )
        return moved

def ensure_eager_on_device(self) -> list:
    """Restore every parked eager module to the device in one pass.

    Counterpart of :meth:`offload_eager_to_cpu`. Returns ``[]`` without
    doing anything in ``offload_to_cpu`` mode (per-op placement owns the
    moves there), on non-accelerator contexts, and when nothing is
    parked.

    Returns:
        The attribute names that were restored.

    Raises:
        Exception: the first error reported by ``_restore_eager_locked``.
    """
    if self.offload_to_cpu:
        return []

    with self._placement_lock:
        accel = self._accel_device()
        if accel is None:
            return []

        pending: list = [
            attr
            for attr in self._PARKABLE_MODULES
            if getattr(self, attr, None) is not None
            and not self._module_on_device(getattr(self, attr, None), accel)
        ]
        if (
            self.silence_latent is not None
            and self.silence_latent.device.type == "cpu"
        ):
            pending.append("silence_latent")
        if not pending:
            return []

        failures = self._restore_eager_locked(pending)
        logger.info(
            "vram_restored_bulk modules={} errors={}", pending, len(failures),
        )
        if failures:
            raise failures[0]
        return pending

@contextmanager
def vram_parked(self):
    """Temporarily evict the eager ACE-Step modules from VRAM.

    Moves every GPU-resident parkable module (DiT, VAE, text encoder)
    plus the silence latent to CPU, returns the freed pages to CUDA via
    ``empty_cache()``, and yields the list of parked attribute names. On
    exit — success or exception — everything that was parked is moved
    back to ``self.device`` with its canonical dtype.

    Used by ``acestep.streaming.stems.extract_upload_stems`` to make
    room for the Mel-Band RoFormer separator on VRAM-constrained pods:
    park ACE-Step → run the separator → release it → restore ACE-Step.

    Holds ``_placement_lock`` for the entire block (including the
    caller's body), so concurrent conditioning / VAE / semantic ops —
    which all enter :meth:`_load_model_context` — wait for the restore
    instead of crashing on CPU-placed weights. The lock is reentrant,
    but callers must not run ACE-Step inference from inside the block:
    the weights are on CPU.

    ``model.to()`` moves parameters in place, so long-lived references
    into these modules (DiffusionEngine, node handles) stay valid across
    a park/restore cycle. No-op (yields ``[]``) when the context lives
    on CPU or nothing is GPU-resident — nesting is therefore safe and
    idempotent.

    Yields:
        A copy of the list of attribute names that were parked.

    Raises:
        Exception: the first restore error, but only when the body
            itself completed without raising.
    """
    with self._placement_lock:
        t0 = time.time()
        parked = self._offload_eager_locked()
        if parked:
            logger.info(
                "vram_parked modules={} duration_s={:.2f}",
                parked, time.time() - t0,
            )
        # Track body failure explicitly. sys.exc_info() would also
        # report an exception being handled further up the call stack
        # (e.g. when this block runs inside a caller's ``except``
        # handler) and would then wrongly suppress a restore failure.
        body_failed = False
        try:
            yield list(parked)
        except BaseException:
            body_failed = True
            raise
        finally:
            if parked:
                t0 = time.time()
                restore_errors = self._restore_eager_locked(parked)
                logger.info(
                    "vram_unparked modules={} duration_s={:.2f} errors={}",
                    parked, time.time() - t0, len(restore_errors),
                )
                # Surface a restore failure, but never mask an
                # exception already unwinding from the body.
                if restore_errors and not body_failed:
                    raise restore_errors[0]

# ------------------------------------------------------------------
# Text embedding inference
Expand Down
39 changes: 39 additions & 0 deletions acestep/gpu_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,45 @@ def get_gpu_memory_gb() -> float:
return 0


def get_vram_telemetry(device=None) -> Optional[Dict[str, float]]:
    """Return a snapshot of CUDA memory occupancy, in GiB, for ``device``.

    Args:
        device: Anything ``torch.device`` accepts; ``None`` means
            ``"cuda"``. A device without an index resolves to
            ``torch.cuda.current_device()``.

    Returns:
        ``None`` when CUDA is unavailable, ``device`` is not a CUDA
        device, or reading the counters fails (a warning is logged).
        Otherwise a dict with:

        free_gb: driver-reported free memory (excludes torch's cache).
        total_gb: device capacity.
        allocated_gb: torch tensors currently live.
        reserved_gb: torch caching-allocator pages held from the driver.
        available_gb: what a new allocation can realistically claim —
            driver-free plus pages torch has cached but not handed out.
    """
    try:
        import torch

        if not torch.cuda.is_available():
            return None

        target = torch.device("cuda" if device is None else device)
        if target.type != "cuda":
            return None

        ordinal = target.index
        if ordinal is None:
            ordinal = torch.cuda.current_device()

        free_bytes, total_bytes = torch.cuda.mem_get_info(ordinal)
        allocated_bytes = torch.cuda.memory_allocated(ordinal)
        reserved_bytes = torch.cuda.memory_reserved(ordinal)

        bytes_per_gib = float(1024 ** 3)
        # Pages torch holds from the driver but has not handed out.
        cached_unused = max(0, reserved_bytes - allocated_bytes)
        return {
            "free_gb": free_bytes / bytes_per_gib,
            "total_gb": total_bytes / bytes_per_gib,
            "allocated_gb": allocated_bytes / bytes_per_gib,
            "reserved_gb": reserved_bytes / bytes_per_gib,
            "available_gb": (free_bytes + cached_unused) / bytes_per_gib,
        }
    except Exception as e:
        logger.warning(f"Failed to read VRAM telemetry: {e}")
        return None


def get_gpu_tier(gpu_memory_gb: float) -> str:
"""
Determine GPU tier based on available memory.
Expand Down
Loading