Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions onecomp/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ def main():
default="auto",
help='save directory (default: auto-generated, "none" to skip)',
)
parser.add_argument(
"--check-env",
action="store_true",
help=(
"Print an environment and memory report before quantization. "
"Exits with code 1 if OOM risk is 'danger'."
),
)
parser.add_argument(
"--version",
action="version",
Expand All @@ -76,6 +84,23 @@ def main():
# Lazy import to keep --help fast
from .runner import Runner # pylint: disable=import-outside-toplevel

if args.check_env:
import sys # pylint: disable=import-outside-toplevel
from .utils.vram_estimator import ( # pylint: disable=import-outside-toplevel
check_environment,
print_env_report,
)

env_result = check_environment(
args.model_id,
total_vram_gb=args.total_vram_gb,
group_size=args.groupsize,
save_dir=save_dir if isinstance(save_dir, str) and save_dir != "auto" else None,
)
print_env_report(env_result, total_vram_gb_override=args.total_vram_gb)
if env_result.risk == "danger":
sys.exit(1)

Runner.auto_run(
model_id=args.model_id,
wbits=args.wbits,
Expand Down
19 changes: 19 additions & 0 deletions onecomp/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ def auto_run(
evaluate: bool = True,
eval_original_model: bool = False,
save_dir: str = "auto",
check_env: bool = False,
**kwargs,
):
"""One-liner quantization with sensible defaults.
Expand Down Expand Up @@ -487,6 +488,24 @@ def auto_run(
setup_logger()
logger = getLogger(__name__)

if check_env:
from .utils.vram_estimator import ( # pylint: disable=import-outside-toplevel
check_environment,
print_env_report,
)

env_result = check_environment(
model_id,
total_vram_gb=total_vram_gb,
group_size=groupsize,
save_dir=save_dir if isinstance(save_dir, str) and save_dir != "auto" else None,
)
print_env_report(env_result, total_vram_gb_override=total_vram_gb)
if env_result.risk == "danger":
raise RuntimeError(
f"Environment check failed (OOM risk=danger): {env_result.risk_detail}"
)

candidate_bits = (2, 3, 4, 8)

if wbits is None:
Expand Down
5 changes: 5 additions & 0 deletions onecomp/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
effective_bits_for_quantizer,
weight_memory_gb,
VRAMBitwidthEstimation,
EnvironmentSnapshot,
ModelMemoryProfile,
EnvCheckResult,
check_environment,
print_env_report,
)

from .model_inputs import add_model_specific_inputs
Expand Down
274 changes: 274 additions & 0 deletions onecomp/utils/vram_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,42 @@ class VRAMBitwidthEstimation:
meta_bits_per_param: float


@dataclass
class EnvironmentSnapshot:
"""Physical hardware readings at check-env time."""

gpu_count: int
gpu_name: str | None
gpu_total_vram_gb: float | None
gpu_free_vram_gb: float | None
ram_total_gb: float | None
ram_available_gb: float | None
disk_available_gb: float | None
disk_path: str


@dataclass
class ModelMemoryProfile:
"""Derived memory footprint for the target model."""

total_params: int
fp16_gb: float
quantized_gb: dict
calibration_overhead_gb: float


@dataclass
class EnvCheckResult:
"""Composite result returned by check_environment()."""

model_id: str
env: EnvironmentSnapshot
model: ModelMemoryProfile
estimation: VRAMBitwidthEstimation | None
risk: str
risk_detail: str


def estimate_target_bitwidth(
model: torch.nn.Module,
vram_ratio: float = 0.70,
Expand Down Expand Up @@ -322,3 +358,241 @@ def estimate_wbits_from_vram(
wbits=wbits,
logger=logger,
)


def check_environment(
model_id: str,
*,
total_vram_gb: float | None = None,
group_size: int = 128,
save_dir: str | None = None,
vram_ratio: float = 0.80,
calibration_overhead_ratio: float = 0.15,
) -> EnvCheckResult:
"""Collect hardware info and estimate OOM risk before quantization.

Loads the model architecture on a ``meta`` device (no GPU/CPU memory)
to count parameters, then compares available VRAM against estimated
memory requirements at 2/4/8-bit quantization.

Args:
model_id: Hugging Face model ID or local path.
total_vram_gb: Override GPU VRAM in GB for estimation math only.
Physical GPU readings are always from the real device.
group_size: GPTQ group size for metadata calculation.
save_dir: Path used for disk-space check. Defaults to cwd.
vram_ratio: Fraction of VRAM allocated for the estimation budget.
calibration_overhead_ratio: Calibration activation buffer as a
fraction of the FP16 model footprint (default 15 %).

Returns:
:class:`EnvCheckResult` with hardware snapshot, memory profile,
VRAM estimation, and risk level (``"safe"``, ``"warning"``,
``"danger"``, or ``"unknown"``).
"""
import os
import pathlib
import shutil

from transformers import AutoConfig, AutoModelForCausalLM

# --- GPU snapshot --------------------------------------------------------
gpu_count = torch.cuda.device_count()
if gpu_count > 0:
dev = torch.cuda.current_device()
props = torch.cuda.get_device_properties(dev)
gpu_name = props.name
gpu_total_vram_gb = props.total_memory / _BYTES_PER_GB
try:
free_bytes, _ = torch.cuda.mem_get_info(dev)
gpu_free_vram_gb = free_bytes / _BYTES_PER_GB
except Exception:
gpu_free_vram_gb = None
else:
gpu_name = None
gpu_total_vram_gb = None
gpu_free_vram_gb = None

# --- CPU RAM (psutil optional) -------------------------------------------
try:
import psutil

vm = psutil.virtual_memory()
ram_total_gb = vm.total / _BYTES_PER_GB
ram_available_gb = vm.available / _BYTES_PER_GB
except ImportError:
ram_total_gb = None
ram_available_gb = None

# --- Disk space (stdlib) -------------------------------------------------
check_path = save_dir if save_dir else os.getcwd()
p = pathlib.Path(check_path)
while not p.exists():
p = p.parent
disk_available_gb = shutil.disk_usage(p).free / _BYTES_PER_GB

# --- Model memory profile ------------------------------------------------
config = AutoConfig.from_pretrained(model_id)
with torch.device("meta"):
model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)

total_params = sum(p.numel() for p in model.parameters())
fp16_gb = (total_params * 2) / _BYTES_PER_GB
quantized_gb = {b: weight_memory_gb(total_params, b, group_size) for b in (2, 4, 8)}
calibration_overhead_gb = fp16_gb * calibration_overhead_ratio

# --- VRAM bitwidth estimation (reuse existing) ---------------------------
try:
estimation = estimate_target_bitwidth(
model,
vram_ratio=vram_ratio,
total_vram_gb=total_vram_gb,
group_size=group_size,
)
except (RuntimeError, ValueError):
estimation = None

# --- OOM risk assessment -------------------------------------------------
# Use free VRAM (runtime reality) when available; fall back to override.
effective_vram = gpu_free_vram_gb if gpu_free_vram_gb is not None else total_vram_gb

if effective_vram is None:
risk = "unknown"
risk_detail = "No GPU detected and no --total-vram-gb provided."
else:
need_4bit = quantized_gb[4] + calibration_overhead_gb
if effective_vram >= fp16_gb * 1.2:
risk = "safe"
risk_detail = (
f"Free VRAM ({effective_vram:.1f} GB) comfortably fits "
f"even FP16 weights ({fp16_gb:.1f} GB × 1.2)."
)
elif effective_vram >= need_4bit:
risk = "warning"
risk_detail = (
f"Free VRAM ({effective_vram:.1f} GB) fits 4-bit quantized "
f"weights but is tight (calibration overhead included)."
)
else:
risk = "danger"
risk_detail = (
f"Free VRAM ({effective_vram:.1f} GB) is insufficient for "
f"4-bit + calibration ({need_4bit:.1f} GB needed)."
)

return EnvCheckResult(
model_id=model_id,
env=EnvironmentSnapshot(
gpu_count=gpu_count,
gpu_name=gpu_name,
gpu_total_vram_gb=gpu_total_vram_gb,
gpu_free_vram_gb=gpu_free_vram_gb,
ram_total_gb=ram_total_gb,
ram_available_gb=ram_available_gb,
disk_available_gb=disk_available_gb,
disk_path=str(p),
),
model=ModelMemoryProfile(
total_params=total_params,
fp16_gb=fp16_gb,
quantized_gb=quantized_gb,
calibration_overhead_gb=calibration_overhead_gb,
),
estimation=estimation,
risk=risk,
risk_detail=risk_detail,
)


def print_env_report(result: EnvCheckResult, *, total_vram_gb_override: float | None = None) -> None:
"""Print a human-readable environment and OOM risk report to stdout.

Args:
result: The :class:`EnvCheckResult` from :func:`check_environment`.
total_vram_gb_override: When not ``None``, annotates the VRAM budget
line with ``[--total-vram-gb override]``.
"""
_W = 60
_SEP = "=" * _W
_COL = 22

def _row(label: str, value: str) -> str:
return f" {label:<{_COL}}: {value}"

risk_labels = {
"safe": "SAFE",
"warning": "WARNING",
"danger": "DANGER !!",
"unknown": "UNKNOWN",
}
risk_label = risk_labels.get(result.risk, result.risk.upper())

e = result.env
m = result.model

print(_SEP)
print(" OneComp Environment Check")
print(_SEP)
print()

# Hardware
print("Hardware")
print(_row("GPU count", str(e.gpu_count)))
if e.gpu_name is not None:
print(_row("GPU name", e.gpu_name))
if e.gpu_total_vram_gb is not None:
label = "GPU VRAM (total)"
value = f"{e.gpu_total_vram_gb:.1f} GB"
if total_vram_gb_override is not None:
value += " [physical]"
print(_row(label, value))
if total_vram_gb_override is not None:
print(_row("VRAM budget used", f"{total_vram_gb_override:.1f} GB [--total-vram-gb override]"))
if e.gpu_free_vram_gb is not None:
print(_row("GPU VRAM (free)", f"{e.gpu_free_vram_gb:.1f} GB"))
if e.ram_total_gb is not None:
print(_row("CPU RAM (total)", f"{e.ram_total_gb:.1f} GB"))
print(_row("CPU RAM (avail)", f"{e.ram_available_gb:.1f} GB"))
else:
print(_row("CPU RAM", "n/a (install psutil for RAM info)"))
print(_row("Disk (avail)", f"{e.disk_available_gb:.1f} GB [{e.disk_path}]"))
print()

# Model
print(f"Model: {result.model_id}")
print(_row("Parameters", f"{m.total_params:,}"))
print(_row("FP16 footprint", f"{m.fp16_gb:.2f} GB"))
print()

# Memory estimates
gs = "(group_size varies)"
print(f"Memory Estimates")
for bits in (2, 4, 8):
print(_row(f"{bits}-bit quantized", f"{m.quantized_gb[bits]:.2f} GB"))
print(_row("Calib. overhead", f"{m.calibration_overhead_gb:.2f} GB (15% of FP16)"))
print(_row("4-bit + overhead", f"{m.quantized_gb[4] + m.calibration_overhead_gb:.2f} GB"))
print()

# OOM risk
print("OOM Risk Assessment")
print(_row("Risk level", risk_label))
detail_words = result.risk_detail.split()
detail_line = ""
detail_lines = []
for word in detail_words:
if len(detail_line) + len(word) + 1 > 34:
detail_lines.append(detail_line)
detail_line = word
else:
detail_line = (detail_line + " " + word).lstrip()
if detail_line:
detail_lines.append(detail_line)
for i, dl in enumerate(detail_lines):
if i == 0:
print(_row("Detail", dl))
else:
print(f" {'':<{_COL}} {dl}")
print()
if result.estimation is not None:
print(_row("Recommended wbits", f"{result.estimation.target_bitwidth:.2f} (VRAM-estimated)"))
print(_SEP)
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,12 @@ cu124 = ["torch", "torchvision"]
cu126 = ["torch", "torchvision"]
cu128 = ["torch", "torchvision"]
cu130 = ["torch", "torchvision"]
check-env = ["psutil>=5.9"]
dev = [
"black",
"hydra-core",
"pylint",
"psutil>=5.9",
"pytest",
]
visualize = [
Expand Down