diff --git a/README.md b/README.md
index 96341f9..dec0976 100644
--- a/README.md
+++ b/README.md
@@ -3,11 +3,15 @@
 ## Setup
 
-First, download and set up the repo:
+First, clone the repo **with submodules** — the SC kernels live in the
+[`scmp_kernels`](https://github.com/CrucibleComputingGroup/scmp_kernels)
+submodule:
 
 ```bash
-git clone https://github.com/Juanerx/Q-DiT.git
-cd Q-DiT
+git clone --recurse-submodules https://github.com/CrucibleComputingGroup/scmp_diffusion.git
+cd scmp_diffusion
+# if you forgot --recurse-submodules:
+git submodule update --init --recursive
 ```
 
 Then create the environment and install required packages:
 ```bash
@@ -15,8 +19,12 @@ conda create -n qdit python=3.8
 conda activate qdit
 pip install -r requirements.txt
 pip install .
+pip install -e ./scmp_kernels  # SC matmul kernels (Triton, needs a CUDA GPU)
 ```
 
+> SC mode imports `scmp_kernels` (`sc_matmul`, `scmp_kernels.mp`); standard
+> GPTQ/static quantization works without it.
+
 ## Usage
 
@@ -43,7 +51,7 @@ SC mode replaces floating-point matrix multiplications with stochastic computing
 python scripts/quant_sc_main.py \
     --wbits 8 --abits 8 --w_sym --a_sym \
     --timewise 0.5 --qklayerwise 1.0 --avlayerwise 0.2 \
-    --sc_prec 8 --sc_enable \
+    --sc_prec 8 \
     --image-size 256 --num-sampling-steps 100 --batch-size 16 \
     --ckpt pretrained_models/DiT-XL-2-256x256.pt
 ```
@@ -59,7 +67,6 @@ python scripts/quant_sc_main.py \
 | `--mlplayerwise` | float | 0.0 | Fraction of blocks to use SC for MLP fc1 and fc2 (0-1). |
 | `--inputprojlayerwise` | float | 0.0 | Fraction of blocks to use SC for QKV input projection (0-1). |
 | `--sc_prec` | int | 8 | SC precision in bits. Sets `stoc_len = 2^sc_prec` (e.g., 8 → stoc_len=256). |
-| `--sc_enable` | flag | false | Use enable-signal SC multiplication (compact kernel) instead of standard XNOR/AND. |
 | `--sc_noise_model` | flag | false | Replace real SC kernels with a fast analytical noise surrogate. See [SC Noise Model](#sc-noise-model-fast-surrogate) below. |
 | `--sc_noise_local_correction` | float | 0.15 | Variance correction for per-row/per-batch scaled matmuls (MLP, AV, QK, input proj). Recommended: 0.10 - 0.15. |
 | `--sc_noise_global_correction` | float | 0.60 | Variance correction for per-tensor scaled matmuls (output proj). Recommended: 0.10 - 0.20. |
@@ -71,6 +78,14 @@ python scripts/quant_sc_main.py \
 | `--mp` | flag | false | Enable per-token-row mixed precision for QK/AV. Assigns different stoc_len levels to different token rows based on runtime importance. |
 | `--mp_levels` | str | `256,128,64,32` | Comma-separated stoc_len levels for MP, sorted descending. |
 | `--mp_fractions` | str | None | Comma-separated fractions of rows per MP level (must sum to 1). Default: equal fractions (1/N each). |
+| `--adaptive_mp` | flag | false | Enable adaptive mixed precision with timestep-aware thresholds (rows ranked by a runtime metric; thresholds slide with diffusion progress). Tune with `--mp_alpha` (sensitivity to timestep progress) and `--mp_beta` (base threshold offset). Per-operator overrides exist (`--mp_alpha_qk`, `--mp_beta_av`, …). |
+| `--adaptive_mp_table` | str | None | Path to a calibrated adaptive-MP threshold JSON table (precomputed thresholds instead of `--mp_alpha`/`--mp_beta`). |
+| `--range_mp` | flag | false | Enable range-based mixed precision: assign `stoc_len` per group by its weight (max−min) range. Configure with `--range_mp_levels` and `--range_mp_threshold` (normalized 0–1; higher → more groups at lower precision; per-operator `--range_mp_threshold_qk` etc.). |
+| `--sc_fixed_level_prec` | flag | false | Keep kernel `sc_prec` fixed to `--sc_prec` for all MP `stoc_len` levels (instead of deriving `sc_prec` per level). |
+
+> **MP modes are mutually exclusive.** Pick one of `--mp` (static row fractions),
+> `--adaptive_mp` (metric- and timestep-aware), or `--range_mp` (weight-range based).
+> All three only activate for blocks/timesteps where QK/AV/MLP SC is already enabled.
 
 #### Examples
 
@@ -79,7 +94,7 @@ python scripts/quant_sc_main.py \
 python scripts/quant_sc_main.py \
     --wbits 8 --abits 8 --w_sym --a_sym \
     --timewise 0.5 --qklayerwise 1.0 \
-    --sc_prec 8 --sc_enable \
+    --sc_prec 8 \
     --image-size 256 --num-sampling-steps 100 --batch-size 16 \
     --ckpt pretrained_models/DiT-XL-2-256x256.pt
 ```
@@ -89,7 +104,7 @@ python scripts/quant_sc_main.py \
 python scripts/quant_sc_main.py \
     --wbits 8 --abits 8 --w_sym --a_sym \
     --timewise 1.0 --qklayerwise 0.8 --avlayerwise 0.8 --mlplayerwise 0.8 \
-    --sc_prec 8 --sc_enable \
+    --sc_prec 8 \
     --image-size 256 --num-sampling-steps 100 --batch-size 16 \
     --ckpt pretrained_models/DiT-XL-2-256x256.pt
 ```
@@ -99,7 +114,7 @@ python scripts/quant_sc_main.py \
 python scripts/quant_sc_main.py \
     --wbits 8 --abits 8 --w_sym --a_sym \
     --timewise 0.5 --qklayerwise 1.0 --avlayerwise 0.5 \
-    --sc_prec 8 --sc_enable \
+    --sc_prec 8 \
     --save_sc_config my_config.json \
     --image-size 256 --num-sampling-steps 100 --batch-size 16 \
     --ckpt pretrained_models/DiT-XL-2-256x256.pt
@@ -116,7 +131,7 @@ Rows/heads are bucketed into N discrete `stoc_len` levels using quantile boundar
 
 **Quick test (4 levels, equal fractions):**
 ```bash
-python scripts/quant_sc_main.py --wbits 8 --abits 8 --w_sym --a_sym --timewise 0.5 --qklayerwise 0.5 --avlayerwise 0.5 --sc_prec 8 --sc_enable --mp --mp_levels 256,128,64,32
+python scripts/quant_sc_main.py --wbits 8 --abits 8 --w_sym --a_sym --timewise 0.5 --qklayerwise 0.5 --avlayerwise 0.5 --sc_prec 8 --mp --mp_levels 256,128,64,32
 ```
 
 **Custom fractions (aggressive — only 10% at full precision):**
@@ -124,7 +139,7 @@ python scripts/quant_sc_main.py --wbits 8 --abits 8 --w_sym --a_sym --timewise 0
 python scripts/quant_sc_main.py \
     --wbits 8 --abits 8 --w_sym --a_sym \
     --timewise 0.5 --qklayerwise 0.5 --avlayerwise 0.5 \
-    --sc_prec 8 --sc_enable \
+    --sc_prec 8 \
     --mp --mp_levels 256,128,64,32 --mp_fractions 0.1,0.2,0.3,0.4
 ```
 
@@ -133,7 +148,7 @@ python scripts/quant_sc_main.py \
 python scripts/quant_sc_main.py \
     --wbits 8 --abits 8 --w_sym --a_sym \
     --timewise 0.5 --qklayerwise 0.5 --avlayerwise 0.5 \
-    --sc_prec 8 --sc_enable \
+    --sc_prec 8 \
     --mp --mp_levels 256,64 --mp_fractions 0.3,0.7
 ```
 
@@ -172,7 +187,7 @@ python scripts/quant_sc_main.py \
     --wbits 8 --abits 8 --w_sym --a_sym \
     --timewise 1 --qklayerwise 1.0 --avlayerwise 1.0 \
     --projlayerwise 1.0 --mlplayerwise 1.0 --inputprojlayerwise 1.0 \
-    --sc_prec 8 --sc_enable --sc_noise_model \
+    --sc_prec 8 --sc_noise_model \
     --image-size 256 --num-sampling-steps 100 --batch-size 16
 ```
 
@@ -182,7 +197,7 @@ python scripts/quant_sc_main.py \
     --wbits 8 --abits 8 --w_sym --a_sym \
     --timewise 1 --qklayerwise 1.0 --avlayerwise 1.0 \
     --projlayerwise 1.0 --mlplayerwise 1.0 --inputprojlayerwise 1.0 \
-    --sc_prec 8 --sc_enable --sc_noise_model \
+    --sc_prec 8 --sc_noise_model \
     --adaptive_mp --mp_levels 256,128,64,32,16 \
     --mp_alpha 0.3 --mp_beta 0.1 \
     --image-size 256 --num-sampling-steps 100 --batch-size 16
@@ -210,7 +225,7 @@ python scripts/quant_sc_main.py \
 
 #### Behavior
 
-- Orthogonal to `--sc_enable` (both can be on simultaneously)
+- Orthogonal to the real SC kernels — the surrogate replaces them, leaving the rest of the quant pipeline unchanged
 - All SC dispatch paths are supported: uniform, adaptive MP, range MP, per-head mixed
 - The real SC path is completely untouched when `--sc_noise_model` is not set
 - Uses `torch.compile` for kernel fusion; first ~5 iterations are slow (JIT warmup), then steady state ~12 it/s at batch=8 on RTX PRO 6000
@@ -333,7 +348,7 @@ Non-power-of-2 values (e.g., 181 for int7.5) are supported. The quantization gri
 python scripts/quant_sc_main.py \
     --wbits 8 --abits 8 --w_sym --a_sym \
     --timewise 1.0 --qklayerwise 1.0 --avlayerwise 1.0 \
-    --sc_prec 8 --sc_enable \
+    --sc_prec 8 \
     --sc_config path/to/config.json \
     --image-size 256 --num-sampling-steps 100 --batch-size 16 \
     --ckpt pretrained_models/DiT-XL-2-256x256.pt
@@ -380,7 +395,7 @@ Results are saved to `../results/-qdit_sc_/`:
 
 - **Do NOT combine `--static` with `--a_sym`** — the static quantizer only supports asymmetric activation quantization. Use `--a_sym` without `--static`, or `--static` without `--a_sym`.
 - DiT-XL/2 has **28 blocks** and **16 attention heads** with head_dim=72.
-- The `--sc_enable` flag selects the compact enable-signal kernel path, which is required for mixed-precision early termination.
+- SC matmuls now always run through the compact **enable-signal** kernel path (via `scmp_kernels.sc_matmul`). The old `--sc_enable` toggle and the legacy XNOR/AND path have been removed — enable-signal is canonical and is what makes mixed-precision early termination possible.
 
 ## BibTeX