Our paper has been accepted to IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI). 🎉
Note: This work extends EPD-Solver (ICCV 2025). You are currently in the default branch, EPD-Solver++. For those interested in our previous work, the original ICCV 2025 implementation is available in the `EPD-Solver` branch.
EPD-Solver (Ensemble Parallel Direction) mitigates truncation errors in diffusion sampling by leveraging parallel gradient evaluations within a single step. We introduce a novel two-stage optimization framework that aligns the solver with human preferences without fine-tuning the heavy diffusion backbone.
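To make the idea of parallel gradient evaluations concrete, here is a minimal sketch of a single step on a toy ODE. The update rule (Euler predictions to each intermediate time, followed by a simplex-weighted sum of the gradients evaluated there), the name `epd_step`, and the toy drift are illustrative assumptions, not the code shipped in this repository.

```python
import math

def epd_step(f, x, t_cur, t_next, taus, weights):
    """One illustrative parallel-direction step from t_cur to t_next.

    f       : callable f(x, t) -> dx/dt (stands in for the diffusion backbone)
    taus    : intermediate times inside the step (learned in EPD-Solver)
    weights : non-negative combination weights summing to 1 (also learned)
    """
    if abs(sum(weights) - 1.0) > 1e-8:
        raise ValueError("weights must sum to 1")
    d_cur = f(x, t_cur)  # gradient at the current state
    # Every intermediate state depends only on (x, d_cur), so these
    # evaluations are independent and could be batched in parallel.
    grads = [f(x + (tau - t_cur) * d_cur, tau) for tau in taus]
    direction = sum(w * g for w, g in zip(weights, grads))
    return x + (t_next - t_cur) * direction

# Toy problem (assumption): dx/dt = -x with exact solution x(t) = x0 * exp(-t).
drift = lambda x, t: -x
x0 = 1.0
x_epd = epd_step(drift, x0, 0.0, 0.5, taus=[0.2, 0.3], weights=[0.5, 0.5])
x_euler = x0 + 0.5 * drift(x0, 0.0)
x_exact = x0 * math.exp(-0.5)
print(abs(x_epd - x_exact), abs(x_euler - x_exact))
```

On this toy problem the combined direction lands closer to the exact solution than a plain Euler step of the same size, which is the truncation-error effect described above.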
Instead of sequential evaluations, EPD-Solver computes gradients at multiple learned intermediate timesteps within each step.
- Stage 1: Distillation-Based Initialization. We first distill a few-step student solver by minimizing the trajectory error against a high-fidelity teacher (e.g., DPM-Solver). This provides a robust initialization that captures the trajectory curvature.
- Stage 2: Residual Dirichlet Policy Optimization (RDPO). We reformulate the solver as a stochastic Dirichlet policy. Using a lightweight PPO variant, we fine-tune the solver's low-dimensional parameters (time segments and weights) to maximize human-aligned rewards (e.g., HPSv2, ImageReward). This ensures high perceptual quality and semantic alignment even at low NFEs (e.g., 20 steps).
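The Stage 2 idea of treating the solver weights as a draw from a Dirichlet policy can be sketched with the standard library alone. The parameterization below (policy mean equal to the softmax of the log Stage 1 weights plus a learned residual, scaled by a concentration) is only one plausible reading of "residual", and all names and values are hypothetical.

```python
import math
import random

def dirichlet_alpha(init_weights, residual, concentration):
    """Concentration vector whose mean is softmax(log(init_weights) + residual)."""
    logits = [math.log(w) + r for w, r in zip(init_weights, residual)]
    top = max(logits)
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [concentration * e / total for e in exps]

def sample_dirichlet(alpha, rng):
    """Draw from Dirichlet(alpha) by normalizing independent Gamma draws."""
    draws = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(draws)
    return [d / total for d in draws]

def dirichlet_log_prob(weights, alpha):
    """Log-density of Dirichlet(alpha) at `weights` (used in a PPO ratio)."""
    log_norm = math.lgamma(sum(alpha)) - sum(math.lgamma(a) for a in alpha)
    return log_norm + sum((a - 1.0) * math.log(w) for a, w in zip(alpha, weights))

rng = random.Random(0)
init = [0.6, 0.4]  # hypothetical Stage 1 weights for one solver step
alpha = dirichlet_alpha(init, residual=[0.0, 0.0], concentration=10.0)
weights = sample_dirichlet(alpha, rng)
logp = dirichlet_log_prob(weights, alpha)
```

With a zero residual the policy mean equals the distilled weights, so under this assumed parameterization training starts from the Stage 1 solver; a larger concentration keeps samples closer to that mean.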
- Create Environment

  ```bash
  conda env create -f environment.yml -n epd
  conda activate epd
  ```
- Install Dependencies

  ```bash
  # Core dependencies
  pip install omegaconf gdown lightning fairscale piq accelerate timm einops kornia HPSv2
  # Quoted so that shells such as zsh do not expand the brackets
  pip install --upgrade "diffusers[torch]"

  # CLIP & Transformers
  pip install git+https://github.com/openai/CLIP.git
  pip install transformers
  pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
  ```
- Set Up Environment Variables

  Important: Run this before training or inference.

  ```bash
  export PYTHONPATH="$PWD/training/ppo/reward_models/HPSv2:$PWD/src/taming-transformers:$PYTHONPATH"
  ```
FLUX.1-dev is a gated Hugging Face model. For sampling, set `FLUX_MODEL_PATH` to a local snapshot when available. For training, `exps/flux/flux-start.pkl` stores the default `black-forest-labs/FLUX.1-dev` reference; make sure it is available in your Hugging Face cache, or set `FLUX_ALLOW_REMOTE=1` after authentication.
We provide pre-trained predictors (Stage 1: Distilled) and RL-finetuned solvers (Stage 2: Best).
Stage 2 models are optimized using Residual Dirichlet Policy Optimization for better human preference alignment.
For FLUX.1-dev, flux-start.pkl is a manually initialized start table for PPO, not a distilled checkpoint.
| Model | Resolution | Type | Download |
|---|---|---|---|
| Stable Diffusion v1.5 | 512x512 | RL-Best (Stage 2) | sd15-best.pkl |
| Stable Diffusion v1.5 | 512x512 | Distilled (Stage 1) | sd15-distilled.pkl |
| SD3-Medium | 1024x1024 | RL-Best (Stage 2) | sd3-1024-best.pkl |
| SD3-Medium | 1024x1024 | Distilled (Stage 1) | sd3-1024-distilled.pkl |
| SD3-Medium | 512x512 | RL-Best (Stage 2) | sd3-512-best.pkl |
| SD3-Medium | 512x512 | Distilled (Stage 1) | sd3-512-distilled.pkl |
| FLUX.1-dev | 1024x1024 | RL-Best (Stage 2) | flux-best.pkl |
| FLUX.1-dev | 1024x1024 | Start (manual init) | flux-start.pkl |
We also provide a detailed guide for each part below.
To train EPD-Solver using RDPO:
```bash
# Available configs: sd15.yaml, sd3_512.yaml, sd3_1024.yaml, flux_dev.yaml
torchrun --master_port=12345 --nproc_per_node=1 -m training.ppo.launch \
    --config training/ppo/cfgs/sd3_1024.yaml

# FLUX.1-dev. This uses the checked-in exps/flux/flux-start.pkl table.
./train_flux.sh

# Or launch manually from the checked-in initial table. The FLUX model
# reference is read from the predictor metadata in exps/flux/flux-start.pkl.
python -m training.ppo.launch \
    --config training/ppo/cfgs/flux_dev.yaml
```

Note: RDPO training was performed using a single NVIDIA H200 GPU. Refer to `launch.sh` for full scripts.
To generate images with an EPD-Solver, use the examples below (replace checkpoint paths with your own exports as needed):
```bash
## SD1.5
MASTER_PORT=12345 python sample.py \
    --predictor_path exps/sd15/sd15-best.pkl \
    --prompt-file src/prompts/test.txt \
    --seeds "0-19" \
    --batch 4 \
    --outdir samples/sd15

## SD3-Medium
python sample_sd3.py --predictor exps/sd3-1024/sd3-1024-best.pkl \
    --seeds "0" \
    --outdir samples/sd3 \
    --prompt "..."

## FLUX.1-dev EPD
python sample_flux.py --predictor exps/flux/flux-best.pkl \
    --model-id /path/to/local/FLUX.1-dev \
    --prompt-file src/prompts/test.txt \
    --seeds "0" \
    --outdir samples/flux
```

We provide six metrics to evaluate generated images: HPSv2.1, PickScore, ImageReward, CLIP, Aesthetic, and MPS. Please refer to the evaluation script section in `launch.sh`.
Sampling (sample.py)
| Parameter | Default | What it controls |
|---|---|---|
| `predictor_path` | required | EPD predictor snapshot (.pkl); numeric IDs auto-resolve to the latest matching checkpoint in ./exps. |
| `model_path` | None | Optional backbone checkpoint override; for SD3/FLUX this maps to model_name_or_path. |
| `max_batch_size` (`--batch`) | `64` | Per-process batch size; seeds are split across ranks. |
| `seeds` | `0-63` | Seed list or range; determines how many images are generated. |
| `prompt` | None | Single text prompt for all seeds; if omitted, falls back to prompt-file or MS-COCO eval captions for dataset_name=ms_coco. |
| `prompt-file` | None | Text or CSV (column text) with prompts; used when prompt is empty. |
| `backend` | Predictor metadata | Override backbone (ldm/sd3/flux); defaults to what is stored in the predictor. |
| `backend-config` | None | JSON object overriding backend options (e.g., SD3/FLUX resolution, torch_dtype, offload, token). |
| `use_fp16` | `False` | Reserved flag for mixed precision (not currently wired). |
| `return_inters` | `False` | Reserved flag for saving intermediates (not currently wired). |
| `outdir` | Auto (`./samples/{dataset}` or `./samples/grids/{dataset}`) | Output root; falls back to a derived path when unset. |
| `grid` | `False` | Save a grid per batch instead of per-image files. |
| `subdirs` | `True` | When saving per-image files, create 1k-chunked subfolders. |
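As an illustration of how a `seeds` value such as `0-19` turns into per-rank batches, here is a small sketch. The accepted formats (ranges, comma-separated lists, or a mix) and the strided split across ranks are assumptions; the helper names are hypothetical and do not come from `sample.py`.

```python
def parse_seeds(spec):
    """Parse '0-19', '0,3,7', or a mix such as '0-3,8' into a list of ints."""
    seeds = []
    for part in spec.split(","):
        part = part.strip()
        if not part:
            continue
        if "-" in part:
            lo, hi = part.split("-", 1)
            seeds.extend(range(int(lo), int(hi) + 1))  # ranges are inclusive
        else:
            seeds.append(int(part))
    return seeds

def seeds_for_rank(seeds, rank, world_size):
    """Assumed split: each process takes every world_size-th seed."""
    return seeds[rank::world_size]

def batches(items, batch_size):
    """Chunk a list into consecutive batches of at most batch_size items."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

all_seeds = parse_seeds("0-19")
rank0 = seeds_for_rank(all_seeds, rank=0, world_size=2)
print(len(all_seeds), batches(rank0, 4))
```

Under these assumptions, `--seeds "0-19"` yields 20 images in total, and with two processes and `--batch 4` each rank would run three batches.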
Sampling (sample_sd3.py)
| Parameter | Default | What it controls |
|---|---|---|
| `predictor` | required | SD3 EPD predictor snapshot (.pkl). |
| `seeds` | `0-3` | Seed list or range; determines how many images are generated. |
| `prompt` | None | Single prompt for all seeds; if empty, uses prompt-file or falls back to empty prompts. |
| `prompt-file` | None | Text/CSV file with prompts; repeats to match seeds length. |
| `outdir` | `./samples/sd3_epd` | Output directory. |
| `grid` | `False` | Save a grid per batch. |
| `max-batch-size` | `4` | Per-batch sample count (--max-batch-size). |
| `resolution` | Predictor/back-end config (512 or 1024) | Optional override; must match predictor metadata if set. |
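The `prompt-file` behavior described in the table (text or CSV with a `text` column, repeated to match the number of seeds, empty prompts as a last resort) could look roughly like the sketch below. The function names and the choice to detect CSV by file extension are assumptions for illustration, not the scripts' actual logic.

```python
import csv
from pathlib import Path

def load_prompts(path):
    """Read prompts from a .csv file (column 'text') or a plain text file."""
    path = Path(path)
    if path.suffix.lower() == ".csv":
        with path.open(newline="", encoding="utf-8") as fh:
            return [row["text"] for row in csv.DictReader(fh) if row.get("text")]
    lines = path.read_text(encoding="utf-8").splitlines()
    return [line.strip() for line in lines if line.strip()]

def match_prompts_to_seeds(prompts, num_seeds):
    """Cycle through prompts so that every seed gets one."""
    if not prompts:
        return [""] * num_seeds  # fall back to empty prompts
    return [prompts[i % len(prompts)] for i in range(num_seeds)]

print(match_prompts_to_seeds(["a red fox", "a blue bird"], 5))
```

Cycling keeps the seed-to-prompt pairing deterministic, so re-running with the same seeds and prompt file reproduces the same pairs.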
Sampling (sample_flux.py)
| Parameter | Default | What it controls |
|---|---|---|
| `predictor` | required | FLUX EPD predictor snapshot (.pkl). |
| `model-id` | Predictor metadata or `FLUX_MODEL_PATH` | FLUX.1-dev repo id or local snapshot path. |
| `seeds` | `0` | Seed list or range; determines how many images are generated. |
| `prompt` | None | Single prompt for all seeds. |
| `prompt-file` | None | Text/CSV file with prompts; repeats to match seeds length. |
| `outdir` | `./samples/flux_epd` | Output directory. |
| `max-batch-size` | `1` | Per-batch sample count. |
FLUX.1-dev Notes
- Supported FLUX variant: `black-forest-labs/FLUX.1-dev`.
- FLUX support is fixed to `1024x1024`, `schedule_type=flowmatch`, and embedded guidance scale `3.5`.
- The sampling scripts resolve FLUX locally first via `FLUX_MODEL_PATH` or the Hugging Face cache, then fall back to the Hugging Face repo id. Set `FLUX_ALLOW_REMOTE=1` when intentionally loading the gated Hugging Face repo instead of a local snapshot.
- `exps/flux/flux-best.pkl` is the released RL-best inference checkpoint. FLUX training starts from `exps/flux/flux-start.pkl`, a manually initialized start table that `train_flux.sh` uses directly before PPO launch.
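The resolution order in the notes above can be written down as a short sketch. `resolve_flux_model` is a hypothetical helper, the cache lookup is passed in as a plain argument rather than queried from Hugging Face, and raising an error when `FLUX_ALLOW_REMOTE` is unset is an assumption about the intended behavior.

```python
import os

DEFAULT_FLUX_REPO = "black-forest-labs/FLUX.1-dev"

def resolve_flux_model(env=None, cached_snapshot=None, repo_id=DEFAULT_FLUX_REPO):
    """Return (model_reference, source) following a local-first order.

    env             : mapping of environment variables (defaults to os.environ)
    cached_snapshot : path of an already-downloaded snapshot, if one is known
    """
    env = os.environ if env is None else env
    local = env.get("FLUX_MODEL_PATH")
    if local:
        return local, "local"
    if cached_snapshot:
        return cached_snapshot, "cache"
    if env.get("FLUX_ALLOW_REMOTE") == "1":
        return repo_id, "remote"
    raise RuntimeError(
        "No local FLUX snapshot found; set FLUX_MODEL_PATH or FLUX_ALLOW_REMOTE=1"
    )

print(resolve_flux_model({"FLUX_MODEL_PATH": "/path/to/local/FLUX.1-dev"}))
```

Keeping the remote download behind an explicit opt-in avoids surprise authentication prompts for a gated model on machines that are expected to run offline.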
Solver metadata (read from predictor checkpoints)
| Parameter | Default source | Notes |
|---|---|---|
| `dataset_name` | Predictor ckpt | Dataset tag (e.g., ms_coco); drives prompt fallback and output paths. |
| `backend` / `backend_config` | Predictor ckpt | Backbone type plus stored options (resolution, flow-match params, offload/token settings for SD3/FLUX, etc.). |
| `num_steps` | Predictor ckpt | Inference steps; base NFE `2*(num_steps-1)` (minus one eval when afs=True, doubled again for CFG in ms_coco). |
| `num_points` | Predictor ckpt | Number of intermediate points per step; used for NFE reporting/outdir naming. |
| `guidance_type` / `guidance_rate` | Predictor ckpt | CFG sampling (e.g., 4.5 for SD3 PPO configs, 7.5 for SD1.5). |
| `schedule_type` / `schedule_rho` | Predictor ckpt | flowmatch for SD3/FLUX, discrete for SD1.5. |
| `sigma_min` / `sigma_max` | Predictor or backend | Noise range passed to scheduler (falls back to backend defaults when unset). |
| `flowmatch_mu` / `flowmatch_shift` | Predictor or backend | Flow-matching parameters used by SD3/FLUX schedules. |
| `afs`, `max_order`, `predict_x0`, `lower_order_final` | Predictor ckpt | EPD/DPM solver behavior flags. |
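The NFE rule quoted in the `num_steps` row can be turned into a quick calculator. `estimate_nfe` is a hypothetical helper; the order of operations (subtract one for `afs`, then double for CFG) is an assumption, since the table does not spell it out.

```python
def estimate_nfe(num_steps, afs=False, cfg=False):
    """Estimate function evaluations from the rule in the table above."""
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2")
    nfe = 2 * (num_steps - 1)  # base count
    if afs:
        nfe -= 1               # one evaluation saved when afs=True
    if cfg:
        nfe *= 2               # classifier-free guidance doubles the cost
    return nfe

print(estimate_nfe(11))                      # 20
print(estimate_nfe(11, afs=True))            # 19
print(estimate_nfe(11, afs=True, cfg=True))  # 38
```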
RDPO Training configs (training/ppo/cfgs/*.yaml)
| Key | sd3_512 | sd3_1024 | sd15 | flux_dev | Purpose |
|---|---|---|---|---|---|
| `data.predictor_snapshot` | `exps/sd3-512/...-distilled.pkl` | `exps/sd3-1024/...-distilled.pkl` | `exps/sd15/...-distilled.pkl` | `exps/flux/flux-start.pkl` | Starting EPD predictor. |
| `model.backend` | `sd3` | `sd3` | `ldm` | `flux` | Backbone family used during RL. |
| `model.resolution` | `512` | `1024` | n/a | `1024` | Training resolution for flow-matching backbones. |
| `model.schedule_type` | `flowmatch` | `flowmatch` | `discrete` | `flowmatch` | Diffusion schedule during RL. |
| `model.guidance_rate` | `4.5` | `4.5` | `7.5` | `3.5` | Guidance scale used while training the solver. |
| `ppo.rollout_batch_size` | `16` | `8` | `8` | `8` | Samples per PPO rollout. |
| `ppo.dirichlet_concentration` | `10` | `10` | `20` | `10` | Dirichlet policy concentration. |
| `reward.batch_size` | `4` | `4` | `4` | `1` | Reward evaluation batch size. |
| `reward.multi.weights` | `hps:1.0` (others 0) | same | same | same | Per-head reward weights. |
Shared defaults across configs: model.dataset_name=ms_coco, model.guidance_type=cfg, model.schedule_rho=1.0, model.num_steps/num_points left null to inherit predictors, reward.type=multi, reward.enable_amp=true, reward.weights_path=weights/HPS_v2.1_compressed.pt, ppo.learning_rate=7e-5, ppo.minibatch_size=4, ppo.ppo_epochs=1, ppo.rloo_k=4, ppo.clip_range=0.2, ppo.kl_coef=0.0, ppo.entropy_coef=0.0, ppo.max_grad_norm=1.0, ppo.decode_rgb=true, logging.log_interval=1, run.output_root=exps, run.seed=0. The SD configs use ppo.steps=99999 and logging.save_interval=500; flux_dev sets sigma_min=0.001, sigma_max=1.0, ppo.steps=20000, and logging.save_interval=200.
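For readers unfamiliar with the `ppo.rloo_k` and `ppo.clip_range` settings, the sketch below shows the generic leave-one-out baseline and the PPO clipped surrogate they usually refer to. It is a textbook illustration with made-up reward values, and it is an assumption that the training code combines them in exactly this way.

```python
import math

def rloo_advantages(rewards):
    """Leave-one-out baseline: each sample is compared to the mean of the others."""
    k = len(rewards)
    if k < 2:
        raise ValueError("need at least two samples")
    total = sum(rewards)
    return [r - (total - r) / (k - 1) for r in rewards]

def clipped_surrogate(logp_new, logp_old, advantage, clip_range=0.2):
    """PPO clipped objective for one sample (to be maximized)."""
    ratio = math.exp(logp_new - logp_old)
    clipped = min(max(ratio, 1.0 - clip_range), 1.0 + clip_range)
    return min(ratio * advantage, clipped * advantage)

rewards = [0.30, 0.26, 0.28, 0.32]  # made-up scores for k=4 rollouts of one prompt
advantages = rloo_advantages(rewards)
objective = clipped_surrogate(logp_new=0.5, logp_old=0.0, advantage=1.0)
print(advantages, objective)
```

The leave-one-out baseline needs no learned value network, which fits a policy with only a handful of solver parameters.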
@misc{wang2025paralleldiffusionsolverresidual,
title={Parallel Diffusion Solver via Residual Dirichlet Policy Optimization},
author={Ruoyu Wang and Ziyu Li and Beier Zhu and Liangyu Yuan and Hanwang Zhang and Xun Yang and Xiaojun Chang and Chi Zhang},
year={2025},
eprint={2512.22796},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2512.22796},
}
@inproceedings{zhu2025distilling,
title={Distilling Parallel Gradients for Fast ODE Solvers of Diffusion Models},
author={Zhu, Beier and Wang, Ruoyu and Zhao, Tong and Zhang, Hanwang and Zhang, Chi},
booktitle={International Conference on Computer Vision (ICCV)},
year={2025}
}

