HACLINE/DiffPPOGAN

Enhancing Pre-trained Diffusion Models with Reinforcement Learning and Adversarial Reward Functions

This is a course project for Computer Vision, a second-year course of the Yao Class at Tsinghua University. The work is far below paper level, but it presents an interesting idea.

Contributors

Final Results

The final report and poster are in the root directory. They are the same versions we submitted to the TA and presented in class.

Setup

conda env create -f environment.yaml
conda activate diffppogan

The dataset is downloaded automatically when you run the code. Refer to Generative Zoo for more details on the dataset.

Training

We do not provide a training script for the standard diffusion model. You can train it with the code in Generative Zoo, or use the pre-trained model provided in pretrained/base.pt.

The training script is scripts/train.sh. To start training, run the following command:

export DATA_DIR=$PWD/data

python -m src.train.train \
    cfg=adv_schedule_r3 \
    cfg.gpu_id=0 \
    cfg.fid.real_image_path=$DATA_DIR/real/cifar10/imgs \
    cfg.ref_model_path=pretrained/base.pt \
    cfg.wandb.name=WANDB_NAME

Evaluation

You can sample images from the pre-trained model with scripts/sample.sh or with the following command:

export DATA_DIR=$PWD/data

python -m src.sample.sample \
    cfg=adv_schedule_r3 \
    cfg.gpu_id=0 \
    cfg.ref_model_path=pretrained/base.pt \
    cfg.fid.real_image_path=$DATA_DIR/real/cifar10/imgs \
    cfg.checkpoint=pretrained/best.pt

Run scripts/sample_fid.sh to sample images, then use scripts/eval_fid.sh to evaluate the FID score.
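For readers unfamiliar with the metric, the following is a minimal sketch of how an FID-style score is computed from two sets of feature vectors. It is not the implementation behind scripts/eval_fid.sh. It assumes NumPy is available, and it uses random arrays as stand-ins for the Inception features that a real FID evaluation extracts from images. The function names `feature_statistics` and `frechet_distance` are hypothetical and do not come from this repository.

```python
import numpy as np


def feature_statistics(features):
    """Return the mean vector and covariance matrix of an (N, D) feature array."""
    mu = features.mean(axis=0)
    sigma = np.atleast_2d(np.cov(features, rowvar=False))
    return mu, sigma


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between the Gaussians N(mu1, sigma1) and N(mu2, sigma2).

    d^2 = ||mu1 - mu2||^2 + tr(sigma1) + tr(sigma2) - 2 * tr((sigma1 sigma2)^(1/2))
    """
    diff = mu1 - mu2
    # The trace of the matrix square root equals the sum of the square roots
    # of the eigenvalues of sigma1 @ sigma2. For positive semi-definite inputs
    # these eigenvalues are real and non-negative up to numerical noise, so we
    # drop the imaginary part and clip small negative values to zero.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_sqrt)


# Stand-in features (assumption): a real evaluation would use Inception
# activations of real and generated CIFAR-10 images instead.
rng = np.random.default_rng(0)
real_features = rng.normal(0.0, 1.0, size=(2000, 8))
fake_features = rng.normal(0.5, 1.0, size=(2000, 8))

mu_r, sigma_r = feature_statistics(real_features)
mu_f, sigma_f = feature_statistics(fake_features)
print("FID-style distance:", frechet_distance(mu_r, sigma_r, mu_f, sigma_f))
```

A lower value means the two feature distributions are closer. The distance is zero when both sets have identical mean and covariance.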
