47 changes: 31 additions & 16 deletions README.md
@@ -3,20 +3,28 @@

## Setup

First, download and set up the repo:
First, clone the repo **with submodules** — the SC kernels live in the
[`scmp_kernels`](https://github.com/CrucibleComputingGroup/scmp_kernels)
submodule:

```bash
git clone https://github.com/Juanerx/Q-DiT.git
cd Q-DiT
git clone --recurse-submodules https://github.com/CrucibleComputingGroup/scmp_diffusion.git
cd scmp_diffusion
# if you forgot --recurse-submodules:
git submodule update --init --recursive
```
Then create the environment and install required packages:
```bash
conda create -n qdit python=3.8
conda activate qdit
pip install -r requirements.txt
pip install .
pip install -e ./scmp_kernels # SC matmul kernels (Triton, needs a CUDA GPU)
```

> SC mode imports `scmp_kernels` (`sc_matmul`, `scmp_kernels.mp`); standard
> GPTQ/static quantization works without it.


## Usage

@@ -43,7 +51,7 @@ SC mode replaces floating-point matrix multiplications with stochastic computing
python scripts/quant_sc_main.py \
--wbits 8 --abits 8 --w_sym --a_sym \
--timewise 0.5 --qklayerwise 1.0 --avlayerwise 0.2 \
--sc_prec 8 --sc_enable \
--sc_prec 8 \
--image-size 256 --num-sampling-steps 100 --batch-size 16 \
--ckpt pretrained_models/DiT-XL-2-256x256.pt
```
@@ -59,7 +67,6 @@ python scripts/quant_sc_main.py \
| `--mlplayerwise` | float | 0.0 | Fraction of blocks to use SC for MLP fc1 and fc2 (0-1). |
| `--inputprojlayerwise` | float | 0.0 | Fraction of blocks to use SC for QKV input projection (0-1). |
| `--sc_prec` | int | 8 | SC precision in bits. Sets `stoc_len = 2^sc_prec` (e.g., 8 → stoc_len=256). |
| `--sc_enable` | flag | false | Use enable-signal SC multiplication (compact kernel) instead of standard XNOR/AND. |
| `--sc_noise_model` | flag | false | Replace real SC kernels with a fast analytical noise surrogate. See [SC Noise Model](#sc-noise-model-fast-surrogate) below. |
| `--sc_noise_local_correction` | float | 0.15 | Variance correction for per-row/per-batch scaled matmuls (MLP, AV, QK, input proj). Recommended: 0.10 - 0.15. |
| `--sc_noise_global_correction` | float | 0.60 | Variance correction for per-tensor scaled matmuls (output proj). Recommended: 0.10 - 0.20. |
@@ -71,6 +78,14 @@ python scripts/quant_sc_main.py \
| `--mp` | flag | false | Enable per-token-row mixed precision for QK/AV. Assigns different stoc_len levels to different token rows based on runtime importance. |
| `--mp_levels` | str | `256,128,64,32` | Comma-separated stoc_len levels for MP, sorted descending. |
| `--mp_fractions` | str | None | Comma-separated fractions of rows per MP level (must sum to 1). Default: equal fractions (1/N each). |
| `--adaptive_mp` | flag | false | Enable adaptive mixed precision with timestep-aware thresholds (rows ranked by a runtime metric; thresholds slide with diffusion progress). Tune with `--mp_alpha` (sensitivity to timestep progress) and `--mp_beta` (base threshold offset). Per-operator overrides exist (`--mp_alpha_qk`, `--mp_beta_av`, …). |
| `--adaptive_mp_table` | str | None | Path to a calibrated adaptive-MP threshold JSON table (precomputed thresholds instead of `--mp_alpha`/`--mp_beta`). |
| `--range_mp` | flag | false | Enable range-based mixed precision: assign `stoc_len` per group by its weight (max−min) range. Configure with `--range_mp_levels` and `--range_mp_threshold` (normalized 0–1; higher → more groups at lower precision; per-operator `--range_mp_threshold_qk` etc.). |
| `--sc_fixed_level_prec` | flag | false | Keep kernel `sc_prec` fixed to `--sc_prec` for all MP `stoc_len` levels (instead of deriving `sc_prec` per level). |
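
The `--sc_prec` row in this table ties stream length to precision as `stoc_len = 2^sc_prec`. A minimal sketch of that relationship follows. The helper names `stoc_len_from_prec` and `prec_from_stoc_len` are hypothetical and not taken from the repo, and reading "deriving `sc_prec` per level" as `log2(stoc_len)` is an assumption.

```python
import math


def stoc_len_from_prec(sc_prec: int) -> int:
    """Stream length implied by an SC precision in bits: stoc_len = 2^sc_prec."""
    if sc_prec < 0:
        raise ValueError("sc_prec must be non-negative")
    return 1 << sc_prec


def prec_from_stoc_len(stoc_len: int) -> float:
    """Inverse mapping (assumed): precision in bits is log2 of the stream length."""
    if stoc_len <= 0:
        raise ValueError("stoc_len must be positive")
    return math.log2(stoc_len)


# Default --mp_levels from the table, mapped back to bits of precision.
for level in (256, 128, 64, 32):
    print(f"stoc_len={level:>3} -> sc_prec={prec_from_stoc_len(level):g}")
```

Under this reading, `--sc_fixed_level_prec` would skip the per-level inverse mapping and keep the single `--sc_prec` value for every level.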

> **MP modes are mutually exclusive.** Pick one of `--mp` (static row fractions),
> `--adaptive_mp` (metric- and timestep-aware), or `--range_mp` (weight-range based).
> All three only activate for blocks/timesteps where QK/AV/MLP SC is already enabled.
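
One way to enforce the "pick one MP mode" rule at the CLI is an `argparse` mutually exclusive group. The sketch below is illustrative only: the README does not say how `quant_sc_main.py` actually checks this, and `build_parser` is a hypothetical name.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Parser with only the three MP mode flags (illustrative subset)."""
    parser = argparse.ArgumentParser(description="MP mode selection sketch")
    # argparse rejects any command line that sets more than one flag in the group.
    mp_mode = parser.add_mutually_exclusive_group()
    mp_mode.add_argument("--mp", action="store_true",
                         help="static row fractions")
    mp_mode.add_argument("--adaptive_mp", action="store_true",
                         help="metric- and timestep-aware thresholds")
    mp_mode.add_argument("--range_mp", action="store_true",
                         help="weight-range based levels")
    return parser


# A single mode parses normally.
args = build_parser().parse_args(["--adaptive_mp"])
print(args.mp, args.adaptive_mp, args.range_mp)
```

With this setup, passing `--mp --range_mp` together makes `argparse` print a usage error and exit with status 2 before any model code runs.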

#### Examples

@@ -79,7 +94,7 @@ python scripts/quant_sc_main.py \
python scripts/quant_sc_main.py \
--wbits 8 --abits 8 --w_sym --a_sym \
--timewise 0.5 --qklayerwise 1.0 \
--sc_prec 8 --sc_enable \
--sc_prec 8 \
--image-size 256 --num-sampling-steps 100 --batch-size 16 \
--ckpt pretrained_models/DiT-XL-2-256x256.pt
```
@@ -89,7 +104,7 @@ python scripts/quant_sc_main.py \
python scripts/quant_sc_main.py \
--wbits 8 --abits 8 --w_sym --a_sym \
--timewise 1.0 --qklayerwise 0.8 --avlayerwise 0.8 --mlplayerwise 0.8 \
--sc_prec 8 --sc_enable \
--sc_prec 8 \
--image-size 256 --num-sampling-steps 100 --batch-size 16 \
--ckpt pretrained_models/DiT-XL-2-256x256.pt
```
@@ -99,7 +114,7 @@ python scripts/quant_sc_main.py \
python scripts/quant_sc_main.py \
--wbits 8 --abits 8 --w_sym --a_sym \
--timewise 0.5 --qklayerwise 1.0 --avlayerwise 0.5 \
--sc_prec 8 --sc_enable \
--sc_prec 8 \
--save_sc_config my_config.json \
--image-size 256 --num-sampling-steps 100 --batch-size 16 \
--ckpt pretrained_models/DiT-XL-2-256x256.pt
@@ -116,15 +131,15 @@ Rows/heads are bucketed into N discrete `stoc_len` levels using quantile boundar

**Quick test (4 levels, equal fractions):**
```bash
python scripts/quant_sc_main.py --wbits 8 --abits 8 --w_sym --a_sym --timewise 0.5 --qklayerwise 0.5 --avlayerwise 0.5 --sc_prec 8 --sc_enable --mp --mp_levels 256,128,64,32
python scripts/quant_sc_main.py --wbits 8 --abits 8 --w_sym --a_sym --timewise 0.5 --qklayerwise 0.5 --avlayerwise 0.5 --sc_prec 8 --mp --mp_levels 256,128,64,32
```

**Custom fractions (aggressive — only 10% at full precision):**
```bash
python scripts/quant_sc_main.py \
--wbits 8 --abits 8 --w_sym --a_sym \
--timewise 0.5 --qklayerwise 0.5 --avlayerwise 0.5 \
--sc_prec 8 --sc_enable \
--sc_prec 8 \
--mp --mp_levels 256,128,64,32 --mp_fractions 0.1,0.2,0.3,0.4
```

@@ -133,7 +148,7 @@ python scripts/quant_sc_main.py \
python scripts/quant_sc_main.py \
--wbits 8 --abits 8 --w_sym --a_sym \
--timewise 0.5 --qklayerwise 0.5 --avlayerwise 0.5 \
--sc_prec 8 --sc_enable \
--sc_prec 8 \
--mp --mp_levels 256,64 --mp_fractions 0.3,0.7
```
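
The `--mp_levels` / `--mp_fractions` pairs in these commands can be read as a ranking-and-bucketing step: rows are ordered by an importance score and the top fraction gets the longest stream. The sketch below is an assumption-laden illustration, not the repo's implementation. `assign_levels`, the `scores` input, and the rounding of bucket boundaries are all hypothetical.

```python
from typing import List, Optional, Sequence


def assign_levels(scores: Sequence[float],
                  levels: Sequence[int],
                  fractions: Optional[Sequence[float]] = None) -> List[int]:
    """Give each row a stoc_len level; higher-scoring rows get higher levels."""
    n_levels = len(levels)
    if fractions is None:
        # Documented default: equal fractions, 1/N per level.
        fractions = [1.0 / n_levels] * n_levels
    if len(fractions) != n_levels:
        raise ValueError("need one fraction per level")
    if abs(sum(fractions) - 1.0) > 1e-6:
        raise ValueError("fractions must sum to 1")
    if list(levels) != sorted(levels, reverse=True):
        raise ValueError("levels must be sorted descending")

    # Row indices from most to least important (ties keep their original order).
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    assigned = [levels[-1]] * len(scores)
    start, cumulative = 0, 0.0
    for level, frac in zip(levels, fractions):
        cumulative += frac
        end = round(cumulative * len(scores))  # assumed boundary rounding
        for i in order[start:end]:
            assigned[i] = level
        start = end
    return assigned


# Two-level case: 30% of rows at stoc_len=256, the rest at 64.
scores = [0.9, 0.1, 0.5, 0.7, 0.3, 0.2, 0.8, 0.4, 0.6, 0.05]
print(assign_levels(scores, levels=(256, 64), fractions=(0.3, 0.7)))
```

How the real importance metric is computed, and whether it is taken per row or per head, is not shown in this diff, so `scores` here is just placeholder data.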

@@ -172,7 +187,7 @@ python scripts/quant_sc_main.py \
--wbits 8 --abits 8 --w_sym --a_sym \
--timewise 1 --qklayerwise 1.0 --avlayerwise 1.0 \
--projlayerwise 1.0 --mlplayerwise 1.0 --inputprojlayerwise 1.0 \
--sc_prec 8 --sc_enable --sc_noise_model \
--sc_prec 8 --sc_noise_model \
--image-size 256 --num-sampling-steps 100 --batch-size 16
```

@@ -182,7 +197,7 @@ python scripts/quant_sc_main.py \
--wbits 8 --abits 8 --w_sym --a_sym \
--timewise 1 --qklayerwise 1.0 --avlayerwise 1.0 \
--projlayerwise 1.0 --mlplayerwise 1.0 --inputprojlayerwise 1.0 \
--sc_prec 8 --sc_enable --sc_noise_model \
--sc_prec 8 --sc_noise_model \
--adaptive_mp --mp_levels 256,128,64,32,16 \
--mp_alpha 0.3 --mp_beta 0.1 \
--image-size 256 --num-sampling-steps 100 --batch-size 16
@@ -210,7 +225,7 @@ python scripts/quant_sc_main.py \

#### Behavior

- Orthogonal to `--sc_enable` (both can be on simultaneously)
- Orthogonal to the real SC kernels — the surrogate replaces them, leaving the rest of the quant pipeline unchanged
- All SC dispatch paths are supported: uniform, adaptive MP, range MP, per-head mixed
- The real SC path is completely untouched when `--sc_noise_model` is not set
- Uses `torch.compile` for kernel fusion; first ~5 iterations are slow (JIT warmup), then steady state ~12 it/s at batch=8 on RTX PRO 6000
@@ -333,7 +348,7 @@ Non-power-of-2 values (e.g., 181 for int7.5) are supported. The quantization gri
python scripts/quant_sc_main.py \
--wbits 8 --abits 8 --w_sym --a_sym \
--timewise 1.0 --qklayerwise 1.0 --avlayerwise 1.0 \
--sc_prec 8 --sc_enable \
--sc_prec 8 \
--sc_config path/to/config.json \
--image-size 256 --num-sampling-steps 100 --batch-size 16 \
--ckpt pretrained_models/DiT-XL-2-256x256.pt
@@ -380,7 +395,7 @@ Results are saved to `../results/<NNN>-qdit_sc_<params>/`:

- **Do NOT combine `--static` with `--a_sym`** — the static quantizer only supports asymmetric activation quantization. Use `--a_sym` without `--static`, or `--static` without `--a_sym`.
- DiT-XL/2 has **28 blocks** and **16 attention heads** with head_dim=72.
- The `--sc_enable` flag selects the compact enable-signal kernel path, which is required for mixed-precision early termination.
- SC matmuls now always run through the compact **enable-signal** kernel path (via `scmp_kernels.sc_matmul`). The old `--sc_enable` toggle and the legacy XNOR/AND path have been removed — enable-signal is canonical and is what makes mixed-precision early termination possible.


## BibTeX