HiDream-ai/PS-SR


PS-SR: Pseudo-Single-Step Video Super-Resolution via Speculative Diffusion


Aiqiu Wu, Zhaofan Qiu, Ting Yao, Tao Mei
In CVPR, 2026.

PS-SR is a video super-resolution method that accelerates diffusion-based VSR through speculative diffusion. Instead of relying on either slow multi-step sampling or an aggressively distilled single-step model, PS-SR uses a powerful base model to perform a comprehensive denoising step, then uses a lightweight draft model for subsequent speculative refinement. The draft model reuses guidance from the base trajectory while reducing computation, and a frequency-domain update preserves low-frequency content consistency while injecting high-frequency details from the draft predictions.

This repository provides the official implementation scripts for speculative sampling, frequency-domain update, training, and evaluation.

Demo

Low-Resolution Input    PS-SR Result
demo_1.mp4              demo_1_result.mp4
demo_2.mp4              demo_2_result.mp4
demo_3.mp4              demo_3_result.mp4

Installation

Requirements

  • Linux GPU environment with CUDA
  • Python 3.10
  • PyTorch 2.7.1
  • DiffSynth-Studio 1.1.8

Environment Setup

git clone https://github.com/HiDream-ai/PS-SR.git
cd PS-SR

conda create -n pssr python=3.10 -y
conda activate pssr

pip install -r requirements.txt

The scripts are designed for GPU inference and training. Adjust the PyTorch/CUDA installation if your cluster requires a specific CUDA build.

Model Preparation

Create the dependency and checkpoint folders:

mkdir -p dependent_models/Wan2.1-T2V-1.3B
mkdir -p checkpoints/pretrained_models

Download the required auxiliary models and place them as follows:

dependent_models/
|-- DAPE.pth
|-- ram_swin_large_14m.pth
`-- Wan2.1-T2V-1.3B/
    |-- diffusion_pytorch_model.safetensors
    |-- models_t5_umt5-xxl-enc-bf16.pth
    `-- Wan2.1_VAE.pth

Model sources:

Download the PS-SR pretrained checkpoints and place them as follows:

checkpoints/pretrained_models/
|-- base.safetensors
`-- draft.safetensors
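
Before running inference, it can help to confirm that every file from the two layouts above is in place. The following is a minimal sketch, not part of the repository: the helper name `missing_files` is hypothetical, and the file list is copied from the directory trees shown above.

```python
from pathlib import Path

# Paths taken from the directory layouts in this README.
REQUIRED_FILES = [
    "dependent_models/DAPE.pth",
    "dependent_models/ram_swin_large_14m.pth",
    "dependent_models/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
    "dependent_models/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
    "dependent_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    "checkpoints/pretrained_models/base.safetensors",
    "checkpoints/pretrained_models/draft.safetensors",
]


def missing_files(repo_root="."):
    """Return the required files that are absent under repo_root (hypothetical helper)."""
    root = Path(repo_root)
    return [rel for rel in REQUIRED_FILES if not (root / rel).is_file()]
```

Calling `missing_files()` from the repository root returns an empty list once all weights are downloaded.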

Quick Start

Put low-resolution input videos in videos_for_test/input/, then run speculative sampling followed by frequency-domain update:

python inference_step_1.py \
    --input_dir videos_for_test/input \
    --output_dir videos_for_test/output \
    --lora_base_path ./checkpoints/pretrained_models/base.safetensors \
    --lora_draft_path ./checkpoints/pretrained_models/draft.safetensors \
    --wan_model_dir ./dependent_models/Wan2.1-T2V-1.3B

python inference_step_2.py \
    --consistent_dir videos_for_test/output/base \
    --sharp_dir videos_for_test/output/base+draft_3 \
    --output_dir videos_for_test/output/final

Final super-resolved videos are written to:

videos_for_test/output/final/
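
The two Quick Start commands can also be chained from Python. This is a hedged sketch that only rebuilds the argument lists shown above; the function names `quick_start_commands` and `run_quick_start` are illustrative and do not exist in the repository.

```python
import subprocess
import sys


def quick_start_commands(io_root="videos_for_test",
                         ckpt_dir="./checkpoints/pretrained_models",
                         wan_dir="./dependent_models/Wan2.1-T2V-1.3B"):
    """Build the argv lists for step 1 and step 2, mirroring the Quick Start commands."""
    out = f"{io_root}/output"
    step_1 = [
        sys.executable, "inference_step_1.py",
        "--input_dir", f"{io_root}/input",
        "--output_dir", out,
        "--lora_base_path", f"{ckpt_dir}/base.safetensors",
        "--lora_draft_path", f"{ckpt_dir}/draft.safetensors",
        "--wan_model_dir", wan_dir,
    ]
    step_2 = [
        sys.executable, "inference_step_2.py",
        "--consistent_dir", f"{out}/base",
        "--sharp_dir", f"{out}/base+draft_3",
        "--output_dir", f"{out}/final",
    ]
    return [step_1, step_2]


def run_quick_start():
    """Run both steps in order, stopping at the first failure."""
    for command in quick_start_commands():
        subprocess.run(command, check=True)
```

Running the steps through `subprocess.run(..., check=True)` ensures the fusion step is skipped if speculative sampling fails.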

Inference

Step 1: Speculative Sampling

inference_step_1.py loads the base and draft checkpoints, applies temporal and spatial sliding-window inference, and writes the base prediction together with draft predictions at the configured speculative timesteps. With the default --timestep_base 699 and --timestep_draft_list "[599,499,399]", the script writes base/ and three draft branches.

python inference_step_1.py \
    --input_dir /path/to/input_videos \
    --output_dir /path/to/output_root \
    --lora_base_path ./checkpoints/pretrained_models/base.safetensors \
    --lora_draft_path ./checkpoints/pretrained_models/draft.safetensors \
    --wan_model_dir ./dependent_models/Wan2.1-T2V-1.3B \
    --window_t 33 \
    --overlap_t 16 \
    --window_h 720 \
    --window_w 1280

Default outputs under --output_dir:

output_root/
|-- base/
|-- base+draft_1/
|-- base+draft_2/
`-- base+draft_3/
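
The relation between the draft timestep list and the output folders can be sketched as follows. The naming rule (one `base+draft_<i>` folder per entry, numbered from 1) is inferred from the default layout above rather than read from the script, and `expected_branches` is a hypothetical helper.

```python
import json
from pathlib import Path


def expected_branches(output_root, timestep_draft_list="[599,499,399]"):
    """List the folders expected under output_root for a given draft timestep list.

    Assumes one base+draft_<i> folder per draft timestep, as in the default layout.
    """
    timesteps = json.loads(timestep_draft_list)
    root = Path(output_root)
    branches = [root / "base"]
    branches += [root / f"base+draft_{i}" for i in range(1, len(timesteps) + 1)]
    return branches
```

With the default list, the last entry is `base+draft_3`, which is the folder the Quick Start passes to step 2 as `--sharp_dir`.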

Useful options include --prompt, --negative_prompt, --seed, --torch_dtype, --cfg_scale, --timestep_base, --timestep_draft_list, --sort_files, --skip_existing, --temp_dir, and --keep_temp.
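
To give a feel for what `--window_t 33` and `--overlap_t 16` imply, here is a minimal sketch of a common overlapping sliding-window scheme. It is an assumption about the general technique, not the script's actual tiling logic; `window_starts` is a hypothetical name, and the last window is simply shifted back so that it ends on the final frame.

```python
def window_starts(length, window=33, overlap=16):
    """Return start indices of overlapping windows that cover `length` frames."""
    if overlap >= window:
        raise ValueError("overlap must be smaller than window")
    if length <= window:
        return [0]
    stride = window - overlap  # 17 frames with the values used above
    starts = list(range(0, length - window + 1, stride))
    # Shift one extra window back so the tail of the clip is covered.
    if starts[-1] + window < length:
        starts.append(length - window)
    return starts


print(window_starts(100))  # [0, 17, 34, 51, 67]
```

The same idea applies spatially with `--window_h` and `--window_w`; overlapping regions are typically blended when the windows are merged.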

Step 2: Frequency-Domain Update

inference_step_2.py exposes the paper's frequency-domain update as a separate fusion script. The default setup keeps the low-frequency structure from base/ and blends high-frequency details from base+draft_3/.

python inference_step_2.py \
    --consistent_dir /path/to/output_root/base \
    --sharp_dir /path/to/output_root/base+draft_3 \
    --output_dir /path/to/output_root/final

Useful fusion options include --fc, --alpha, --border, --order, --window_t, --overlap_t, --window_h, --window_w, --sort_files, and --skip_existing. The default fusion strength is --alpha 0.6, matching the value used by the released script.
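
The idea behind the fusion can be illustrated with a small NumPy sketch on a single grayscale frame. This is not the released implementation: the Butterworth low-pass filter, the meaning given to `fc` (cutoff in cycles per pixel) and `order`, and the blending rule are all assumptions made for illustration, and `fuse_frame` is a hypothetical name.

```python
import numpy as np


def butterworth_lowpass(height, width, fc=0.1, order=2):
    """Butterworth low-pass response on the FFT frequency grid (assumed filter shape)."""
    fy = np.fft.fftfreq(height)[:, None]
    fx = np.fft.fftfreq(width)[None, :]
    radius = np.sqrt(fx ** 2 + fy ** 2)
    return 1.0 / (1.0 + (radius / fc) ** (2 * order))


def fuse_frame(consistent, sharp, fc=0.1, alpha=0.6, order=2):
    """Keep low frequencies from `consistent`; blend high frequencies toward `sharp`."""
    height, width = consistent.shape
    low = butterworth_lowpass(height, width, fc, order)
    high = 1.0 - low
    spec_c = np.fft.fft2(consistent)
    spec_s = np.fft.fft2(sharp)
    # alpha controls how much high-frequency detail is taken from the sharp frame.
    fused = low * spec_c + high * ((1.0 - alpha) * spec_c + alpha * spec_s)
    return np.real(np.fft.ifft2(fused))
```

In this sketch the filter equals 1 at zero frequency, so the mean intensity of the output always comes from the consistent frame, and setting `alpha` to 0 returns the consistent frame unchanged.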

Evaluation

Use eval_metrics.py to evaluate restored videos or image sequences.

Full-reference evaluation:

python eval_metrics.py \
    --gt /path/to/ground_truth \
    --pred /path/to/predictions \
    --out /path/to/metrics_output \
    --metrics psnr,ssim,lpips \
    --crop 0
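
For reference, PSNR follows the standard definition 10 * log10(MAX^2 / MSE). The sketch below computes it over flat pixel sequences using only the standard library; it illustrates the metric itself and is not the code path used by eval_metrics.py.

```python
import math


def psnr(reference, restored, max_value=255.0):
    """Peak signal-to-noise ratio in dB between two equal-length pixel sequences."""
    if len(reference) != len(restored) or not reference:
        raise ValueError("inputs must be non-empty and of equal length")
    mse = sum((a - b) ** 2 for a, b in zip(reference, restored)) / len(reference)
    if mse == 0:
        return math.inf  # identical inputs
    return 10.0 * math.log10(max_value ** 2 / mse)
```

Higher values indicate a closer match to the ground truth; identical inputs give infinity.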

No-reference evaluation is supported by omitting --gt and selecting no-reference metrics available in pyiqa, for example:

python eval_metrics.py \
    --pred /path/to/predictions \
    --metrics clipiqa

The script writes a JSON summary named metrics_<metric_names>.json to --out. If --out is omitted, the JSON file is saved in the prediction folder.

Training

Training uses accelerate and the configs in config/. The training pipeline follows the paper's base/draft decomposition: train the base model first, then initialize and train the lightweight draft model from the trained base checkpoint. The provided shell scripts are examples for single-node and multi-node runs:

bash train_base.sh
bash train_draft.sh

Dataset Format

Training data is read by VideoDataset in Wan_SR/trainers/utils.py.

  • --dataset_base_path points to the folder containing videos or images.
  • --dataset_metadata_path accepts CSV or JSON metadata.
  • --data_file_keys defaults to image,video,LQ_video.
  • Metadata should include paths relative to --dataset_base_path.
  • Prompt text is expected in the metadata. If metadata is not provided, the dataset loader can generate metadata from media files with same-name .txt prompt files.

Example CSV:

video,prompt
dataset/0001.mp4,"4K Ultra-clear, Sharp, Fine Details Restored, Temporal Consistency, Natural Colors"

Note: "4K Ultra-clear, Sharp, Fine Details Restored, Temporal Consistency, Natural Colors" is the default prompt used during training. You may use a different prompt, but the same prompt should also be used during inference to ensure consistency between training and inference.
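
A metadata CSV in the format above can be generated with a few lines of Python. This is a minimal sketch under stated assumptions: it writes only the `video` and `prompt` columns shown in the example, it collects `.mp4` files only, and `write_metadata` is a hypothetical helper rather than a repository function.

```python
import csv
from pathlib import Path

DEFAULT_PROMPT = ("4K Ultra-clear, Sharp, Fine Details Restored, "
                  "Temporal Consistency, Natural Colors")


def write_metadata(dataset_base_path, csv_path, prompt=DEFAULT_PROMPT):
    """Write a video,prompt CSV with paths relative to dataset_base_path."""
    base = Path(dataset_base_path)
    videos = sorted(base.glob("**/*.mp4"))
    with open(csv_path, "w", newline="", encoding="utf-8") as handle:
        writer = csv.writer(handle)
        writer.writerow(["video", "prompt"])
        for video in videos:
            # csv.writer quotes the prompt automatically because it contains commas.
            writer.writerow([video.relative_to(base).as_posix(), prompt])
    return len(videos)
```

The return value is the number of videos written, which is a quick way to check that the glob found the expected files.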

Base Model Training

train_base.py trains the base LoRA and regularization branch. The base model is responsible for the comprehensive denoising step that anchors content and temporal consistency. A minimal command is:

accelerate launch \
    --config_file ./config/accelerate_config_single.yaml \
    train_base.py \
    --dataset_base_path ./datasets/YouHQ \
    --dataset_metadata_path ./metadata_YouHQ.csv \
    --wan_model_dir ./dependent_models/Wan2.1-T2V-1.3B \
    --output_path ./experiments/train_base \
    --lora_model_base dit \
    --lora_model_reg dit_update \
    --lora_target_modules q,k,v,o,ffn.0,ffn.2 \
    --lora_rank 32 \
    --save_steps 200 \
    --save_latest True

Expected outputs include:

experiments/train_base/
|-- latest_base.safetensors
`-- latest_reg.safetensors

Draft Model Training

train_draft.py trains the lightweight draft model from a trained base checkpoint. The draft model prunes the DiT backbone according to --k_select and adds feature-fusion layers for speculative refinement after the base step:

accelerate launch \
    --config_file ./config/accelerate_config_single.yaml \
    train_draft.py \
    --dataset_base_path ./datasets/YouHQ \
    --dataset_metadata_path ./metadata_YouHQ.csv \
    --wan_model_dir ./dependent_models/Wan2.1-T2V-1.3B \
    --output_path ./experiments/train_draft \
    --lora_model_base dit \
    --lora_target_modules q,k,v,o,ffn.0,ffn.2 \
    --lora_rank 32 \
    --save_steps 200 \
    --save_latest True \
    --load_model_base_from ./experiments/train_base/latest_base.safetensors

Expected output:

experiments/train_draft/
`-- latest_draft.safetensors

For distributed training, use config/accelerate_config_multi.yaml and set the standard multi-node variables required by the example scripts, such as RANK, MASTER_ADDR, and MASTER_PORT.
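
A small pre-flight check can catch unset variables before a multi-node launch. The sketch below checks only the three variables named above; the example scripts may require others, and `missing_multinode_vars` is a hypothetical helper.

```python
import os

# Only the variables named in this README; the scripts may need more.
REQUIRED_VARS = ("RANK", "MASTER_ADDR", "MASTER_PORT")


def missing_multinode_vars(environ=None):
    """Return the names of required variables that are unset or empty."""
    env = os.environ if environ is None else environ
    return [name for name in REQUIRED_VARS if not env.get(name)]
```

An empty list means all three variables are set in the current environment.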

Repository Structure

PS-SR/
|-- checkpoints/          # PS-SR pretrained or trained checkpoints
|-- config/               # Accelerate configuration files
|-- dependent_models/     # DAPE, RAM, and Wan2.1 dependency weights
|-- models/               # Local model assets and tokenizer files
|-- videos_for_test/      # Example input/output video folders
|-- Wan_SR/               # Core PS-SR implementation
|-- inference_step_1.py   # Base and draft video SR inference
|-- inference_step_2.py   # Frequency-domain output fusion
|-- eval_metrics.py       # Video/image quality evaluation
|-- train_base.py         # Base model training
|-- train_draft.py        # Draft model training
`-- requirements.txt

Citation

If you find this work useful for your research, please cite:

@inproceedings{wu2026pssr,
  title={PS-SR: Pseudo-Single-Step Video Super-Resolution via Speculative Diffusion},
  author={Wu, Aiqiu and Qiu, Zhaofan and Yao, Ting and Mei, Tao},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2026}
}

Acknowledgements

This codebase is built upon or inspired by excellent open-source projects in video restoration and diffusion modeling.

License

This project is released under the Apache-2.0 License. See LICENSE for details.
