zai-org/RelayDiffusion
Description: The official implementation of "Relay Diffusion: Unifying diffusion process across resolutions for image synthesis" [ICLR 2024 Spotlight]
Language: Python
License: Apache-2.0
Stars: 316
Forks: 19
Open issues: 2
Created: 2023-09-04T14:28:18Z
Pushed: 2024-04-29T09:29:51Z
Default branch: main
Fork: no
Archived: no
README:
Relay Diffusion: Unifying diffusion process across resolutions for image synthesis
Official PyTorch implementation 🌐[[WiseModel]](https://www.wisemodel.cn/models/ZhipuAI/RelayDiffsuon/intro) 🌐[[ModelScope]](https://www.modelscope.cn/models/ZhipuAI/RelayDiffusion/summary)
🎉 News! The RelayDiffusion paper has been accepted to ICLR 2024 (Spotlight)!

We propose the *Relay Diffusion Model (RDM)* as a better framework for diffusion-based image generation. *RDM* converts a low-resolution image or noise into an equivalent high-resolution one via blurring diffusion and block noise. As a result, the diffusion process can continue seamlessly at any new resolution or with a new model, without restarting from pure noise or relying on low-resolution conditioning.
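One way to see how blurring links two resolutions: if every k×k patch of an image is blurred toward its own mean, the fully blurred result is exactly a nearest-neighbour upsampled low-resolution image. The NumPy sketch below illustrates this under stated assumptions: it uses a DCT-domain (heat-equation-style) blur applied independently within each patch, in the spirit of blurring diffusion models. The function names and the damping schedule are hypothetical illustrations, not code from this repository.

```python
import numpy as np


def dct_matrix(k: int) -> np.ndarray:
    """Orthonormal DCT-II matrix C: C @ x is the DCT of a length-k vector x."""
    n = np.arange(k)
    c = np.sqrt(2.0 / k) * np.cos(np.pi * (2 * n[None, :] + 1) * n[:, None] / (2 * k))
    c[0, :] = np.sqrt(1.0 / k)
    return c


def blockwise_blur(img: np.ndarray, k: int, blur_sigma: float) -> np.ndarray:
    """Blur every k x k patch of a 2-D image independently in the DCT domain.

    Assumed schedule: frequency (u, v) is damped by exp(-lambda_uv * blur_sigma**2 / 2)
    with lambda_uv = (pi*u/k)**2 + (pi*v/k)**2. The DC term (u = v = 0) is never
    damped, so each patch keeps its mean.
    """
    h, w = img.shape
    if h % k or w % k:
        raise ValueError("image size must be divisible by the patch size k")
    c = dct_matrix(k)
    freqs = np.pi * np.arange(k) / k
    damping = np.exp(-(freqs[:, None] ** 2 + freqs[None, :] ** 2) * blur_sigma ** 2 / 2.0)
    # (h, w) -> (h//k, w//k, k, k): one k x k patch per leading index pair.
    patches = img.reshape(h // k, k, w // k, k).transpose(0, 2, 1, 3)
    coeffs = c @ patches @ c.T              # 2-D DCT of each patch
    blurred = c.T @ (coeffs * damping) @ c  # damp high frequencies, then invert
    return blurred.transpose(0, 2, 1, 3).reshape(h, w)


def upsample_lowres(img: np.ndarray, k: int) -> np.ndarray:
    """Average-pool by k, then nearest-neighbour upsample back by k."""
    h, w = img.shape
    low = img.reshape(h // k, k, w // k, k).mean(axis=(1, 3))
    return np.repeat(np.repeat(low, k, axis=0), k, axis=1)


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 8))
print(np.allclose(blockwise_blur(x, 4, 0.0), x))                      # True: no blur
print(np.allclose(blockwise_blur(x, 4, 1e3), upsample_lowres(x, 4)))  # True: full blur
```

Because the end state of the per-patch blur is an upsampled low-resolution image, a model working at the higher resolution can, in principle, pick up from there instead of from pure noise.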
RDM achieved state-of-the-art FID on CelebA-HQ and state-of-the-art sFID on ImageNet-256 (FID=1.87)!
For a formal introduction, read our paper: Relay Diffusion: Unifying diffusion process across resolutions for image synthesis.
Setup
Environment
Clone the repo and set up the environment with:
    git clone https://github.com/THUDM/RelayDiffusion.git
    cd RelayDiffusion
    conda env create -f environment.yml
    conda activate rdm
We enable xformers.ops.memory_efficient_attention, which reduces training cost by about 15%. If you do not need it, you can remove xformers from environment.yml.
Linux servers with NVIDIA A100 GPUs are recommended. However, by setting a smaller --batch-gpu (the batch size on a single GPU), you can still run the inference and training scripts on less powerful GPUs.
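Lowering --batch-gpu trades speed for memory rather than changing the effective batch size, provided the training script accumulates gradients. The sketch below assumes the EDM convention that this codebase appears to follow: a global batch size (assumed here to be set by a --batch flag, which this README does not show) is split evenly across GPUs, and each GPU processes its share in chunks of --batch-gpu. The helper name is hypothetical.

```python
from typing import Optional, Tuple


def accumulation_plan(batch: int, world_size: int,
                      batch_gpu: Optional[int] = None) -> Tuple[int, int]:
    """Return (per-GPU micro-batch size, number of gradient accumulation rounds).

    Assumption (EDM-style): `batch` is the global batch size shared by `world_size`
    GPUs; each GPU runs `rounds` forward/backward passes of `batch_gpu` samples
    before one optimizer step, so batch == batch_gpu * rounds * world_size.
    """
    if batch % world_size:
        raise ValueError("global batch must be divisible by the number of GPUs")
    per_gpu_total = batch // world_size
    if batch_gpu is None or batch_gpu > per_gpu_total:
        batch_gpu = per_gpu_total
    if per_gpu_total % batch_gpu:
        raise ValueError("per-GPU batch must be divisible by batch_gpu")
    return batch_gpu, per_gpu_total // batch_gpu


print(accumulation_plan(batch=512, world_size=8))                # (64, 1)
print(accumulation_plan(batch=512, world_size=8, batch_gpu=16))  # (16, 4)
```

Under this convention, halving --batch-gpu doubles the number of accumulation rounds and roughly halves peak activation memory, while each optimizer step still sees the same number of samples.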
Dataset
We preprocess datasets into the same format as EDM. For CelebA-HQ, follow *Progressive Growing of GANs for Improved Quality, Stability, and Variation* to construct the high-quality subset of CelebA. For ImageNet, download the data from the official site.
To convert the original data into a dataset ready for training at $64\times 64$ or $256\times 256$ resolution, run:
    python dataset_tool.py \
        --source=/path/to/original/data \
        --dest=/path/to/output/data.zip \
        --transform=center-crop \
        --resolution=64x64  # or --resolution=256x256
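A sketch of what --transform=center-crop plausibly does before resizing, assuming EDM's convention: crop the largest centered square, then resize it to --resolution (EDM uses a Lanczos filter for the resize, which is omitted here to keep the example NumPy-only). Both helper names are hypothetical.

```python
import numpy as np


def parse_resolution(spec: str) -> tuple:
    """Turn a string such as '256x256' into (width, height)."""
    width, height = (int(v) for v in spec.lower().split("x"))
    return width, height


def center_crop_square(img: np.ndarray) -> np.ndarray:
    """Crop the largest centered square from an H x W (x C) image array."""
    h, w = img.shape[:2]
    crop = min(h, w)
    top, left = (h - crop) // 2, (w - crop) // 2
    return img[top:top + crop, left:left + crop]


photo = np.zeros((480, 640, 3), dtype=np.uint8)  # e.g. a landscape 640x480 RGB image
print(center_crop_square(photo).shape, parse_resolution("64x64"))  # (480, 480, 3) (64, 64)
```

Cropping before resizing keeps the aspect ratio intact, so faces and objects are not stretched when non-square source images are mapped to a square training resolution.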
Inference & Evaluation
Sample Generation
To generate samples from RDM models, run:
    torchrun --standalone --nproc_per_node=1 generate.py --sampler_stages=both --outdir=/path/to/output/dir/ \
        --network_first=/path/to/1st/ckpt --network_second=/path/to/2nd/ckpt
To generate $N$ images, set --seed=[K]-[K+N-1] for a randomly picked $K$. To generate in parallel on multiple GPUs, set --nproc_per_node to the number of GPUs.
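The seed range is inclusive at both ends, so --seed=1000-50999 yields exactly 50,000 images (K = 1000, N = 50000). The sketch below assumes EDM-style parsing (comma-separated integers or inclusive ranges) and a simple round-robin split of seeds across GPU processes; the helper names are hypothetical, and generate.py may batch and distribute seeds differently.

```python
import re
from typing import List


def parse_seed_spec(spec: str) -> List[int]:
    """Expand a spec such as '100-103,200' into [100, 101, 102, 103, 200].

    Ranges are inclusive at both ends, so 'K-(K+N-1)' yields exactly N seeds.
    """
    seeds: List[int] = []
    for part in spec.split(","):
        match = re.fullmatch(r"(\d+)-(\d+)", part.strip())
        if match:
            start, end = int(match.group(1)), int(match.group(2))
            seeds.extend(range(start, end + 1))
        else:
            seeds.append(int(part))
    return seeds


def seeds_for_rank(seeds: List[int], rank: int, world_size: int) -> List[int]:
    """Round-robin split: each of world_size processes gets a disjoint share."""
    return seeds[rank::world_size]


seeds = parse_seed_spec("1000-50999")  # K = 1000, N = 50000
print(len(seeds), seeds_for_rank(seeds, rank=0, world_size=8)[:3])  # 50000 [1000, 1008, 1016]
```

Since every image is tied to its own seed, a run split across GPUs produces the same set of images as a single-GPU run with the same seed range.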
To generate final samples from existing first-stage results (using only the second-stage model), set --sampler_stages=second and pass the directory of first-stage results via --indir.
The sampler arguments for the first stage are:
- num_steps_first: number of sampling steps.
- sigma_min_first & sigma_max_first: lowest & highest noise level.
- rho_first: time step exponent.
- cfg_scale_first: scale of classifier-free guidance.
- S_churn: stochasticity strength.
- S_min & S_max: min & max noise level.
- S_noise: noise inflation.
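These first-stage arguments share their names with the stochastic sampler of EDM (Karras et al., 2022). Assuming they keep EDM's meanings, the sketch below shows how num_steps, sigma_min/sigma_max and rho define the noise levels, and how S_churn, S_min/S_max and S_noise inject extra noise at each step. The function names are hypothetical; classifier-free guidance and the denoising update itself are omitted.

```python
import numpy as np


def edm_sigma_steps(num_steps, sigma_min, sigma_max, rho):
    """Noise levels sigma_0 > ... > sigma_{N-1}, spaced linearly in sigma**(1/rho)
    (EDM's time step discretization), followed by a final sigma = 0."""
    i = np.arange(num_steps)
    lo, hi = sigma_min ** (1.0 / rho), sigma_max ** (1.0 / rho)
    sigmas = (hi + i / (num_steps - 1) * (lo - hi)) ** rho
    return np.append(sigmas, 0.0)


def churn_gamma(sigma, num_steps, S_churn, S_min, S_max):
    """Stochasticity factor gamma: nonzero only while S_min <= sigma <= S_max."""
    if S_min <= sigma <= S_max:
        return min(S_churn / num_steps, np.sqrt(2.0) - 1.0)
    return 0.0


def churn(x, sigma, gamma, S_noise, rng):
    """Raise the noise level from sigma to sigma_hat = (1 + gamma) * sigma by adding
    fresh Gaussian noise whose scale is inflated by S_noise."""
    sigma_hat = sigma * (1.0 + gamma)
    x_hat = x + np.sqrt(sigma_hat ** 2 - sigma ** 2) * S_noise * rng.standard_normal(x.shape)
    return x_hat, sigma_hat


sigmas = edm_sigma_steps(num_steps=18, sigma_min=0.002, sigma_max=80.0, rho=7.0)
gamma = churn_gamma(sigmas[5], num_steps=18, S_churn=40.0, S_min=0.05, S_max=50.0)
print(np.round(sigmas[:3], 2), gamma)
```

A larger rho packs more of the steps near sigma_min, where fine detail is resolved; S_churn = 0 turns the sampler into a deterministic ODE solver.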
The sampler arguments for the second stage are:
- num_steps_second: number of sampling steps.
- sigma_min_second & sigma_max_second: lowest & highest noise level.
- blur_sigma_max_second: maximum sigma of blurring schedule.
- rho_second: time step exponent.
- cfg_scale_second: scale of classifier-free guidance.
- up_scale_second: scale of upsampling.
- truncation_sigma_second & truncation_t_second: truncation point of noise & time schedule.
- s_block_second: strength of block noise addition.
- s_noise_second: strength of stochasticity.
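s_block_second is described only as the strength of block noise addition. The sketch below assumes block noise is low-resolution Gaussian noise repeated over each block_size×block_size block and mixed with i.i.d. noise at weight `strength`, then renormalized to unit variance. The function name, the mixing rule, and the mapping from s_block_second to `strength` are all assumptions for illustration, not taken from generate.py.

```python
import numpy as np


def mixed_block_noise(shape, block_size, strength, rng):
    """Gaussian noise with extra correlation inside block_size x block_size blocks.

    Assumed form: (eps_indep + strength * eps_block) / sqrt(1 + strength**2),
    where eps_block is low-resolution N(0, 1) noise repeated over each block.
    Every pixel stays N(0, 1); two pixels in the same block have correlation
    strength**2 / (1 + strength**2).
    """
    h, w = shape
    if h % block_size or w % block_size:
        raise ValueError("shape must be divisible by block_size")
    eps_indep = rng.standard_normal((h, w))
    low_res = rng.standard_normal((h // block_size, w // block_size))
    eps_block = np.repeat(np.repeat(low_res, block_size, axis=0), block_size, axis=1)
    return (eps_indep + strength * eps_block) / np.sqrt(1.0 + strength ** 2)


rng = np.random.default_rng(0)
noise = mixed_block_noise((256, 256), block_size=4, strength=2.0, rng=rng)
# Rows 0 and 1 of every 4-row block share the block term, so their correlation
# should be close to 2**2 / (1 + 2**2) = 0.8.
same_block_corr = np.corrcoef(noise[0::4].ravel(), noise[1::4].ravel())[0, 1]
print(noise.shape, round(float(noise.var()), 2), round(float(same_block_corr), 2))
```

Dividing by sqrt(1 + strength**2) keeps the per-pixel noise level unchanged, so under this assumption the strength only shifts how much of the noise energy sits at the block (low-frequency) scale.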
Evaluation Metrics
We measure sample quality quantitatively with Fréchet Inception Distance (FID), spatial FID (sFID), Inception Score (IS), Precision and Recall. For sFID, IS, Precision and Recall, we adapt the calculation pipeline from ADM's TensorFlow-based evaluation.
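FID compares Gaussians fitted to Inception activations of samples and of reference data: $\lVert\mu_1-\mu_2\rVert^2 + \operatorname{Tr}\big(\Sigma_1+\Sigma_2-2(\Sigma_1\Sigma_2)^{1/2}\big)$. The NumPy sketch below assumes the activations are already extracted (as in the activations step of evaluate.py). It obtains the trace term from the eigenvalues of $\Sigma_1\Sigma_2$ instead of an explicit matrix square root (commonly computed with scipy.linalg.sqrtm); evaluate.py may compute it differently, and the function names are hypothetical.

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 @ sigma2)^{1/2})."""
    diff = mu1 - mu2
    # sigma1 @ sigma2 is similar to a PSD matrix, so its eigenvalues are real and
    # non-negative; Tr of its square root is the sum of their square roots.
    # Clip tiny negative or imaginary parts caused by floating-point error.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_covmean = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)


def fid_from_activations(act_sample, act_ref):
    """FID between two (num_images, feature_dim) activation arrays."""
    mu1, mu2 = act_sample.mean(axis=0), act_ref.mean(axis=0)
    s1 = np.cov(act_sample, rowvar=False)
    s2 = np.cov(act_ref, rowvar=False)
    return frechet_distance(mu1, s1, mu2, s2)


rng = np.random.default_rng(0)
feats = rng.standard_normal((500, 8))  # stand-in for Inception activations
print(round(fid_from_activations(feats, feats), 6), round(fid_from_activations(feats, feats + 1.0), 6))
```

Shifting every one of the 8 feature dimensions by 1 leaves the covariance unchanged, so the second value is the squared mean difference, 8.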
First, build activation files for both the samples and the reference dataset:
    torchrun --standalone --nproc_per_node=1 evaluate.py activations --data=/sample/dir/ --dest=eval-refs/activations_sample.npz --batch=64  # build sample activations
    torchrun --standalone --nproc_per_node=1 evaluate.py activations --data=/path/to/dataset.zip --dest=eval-refs/activations_ref.npz --batch=64  # build reference activations
Then, to calculate metrics from the pre-built activations, run:
    torchrun --standalone --nproc_per_node=1 evaluate.py calc --batch=64 \
        --activations_sample=eval-refs/activations_sample.npz \
        --activations_ref=eval-refs/activations_ref.npz \
        [-m fid] [-m sfid] [-m is] [-m pr]  # assign metrics to be calculated
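For -m pr, precision asks how many generated samples land on the estimated manifold of real data, and recall asks the reverse. The NumPy sketch below assumes the k-nearest-neighbour manifold estimate of improved precision and recall (Kynkäänniemi et al., 2019), on which ADM's evaluator is based; k = 3 is an assumed default, and the brute-force distance matrix is only suitable for small feature sets. Function names are hypothetical.

```python
import numpy as np


def _pairwise_dist(a, b):
    """Euclidean distances between the rows of a and the rows of b."""
    sq = (a ** 2).sum(1)[:, None] + (b ** 2).sum(1)[None, :] - 2.0 * a @ b.T
    return np.sqrt(np.maximum(sq, 0.0))


def knn_radii(feats, k):
    """Distance from each point to its k-th nearest *other* point."""
    d = np.sort(_pairwise_dist(feats, feats), axis=1)
    return d[:, k]  # column 0 is the point itself (distance ~0)


def coverage(query, support, k):
    """Fraction of query points inside at least one support point's k-NN ball."""
    radii = knn_radii(support, k)
    d = _pairwise_dist(query, support)
    return float((d <= radii[None, :]).any(axis=1).mean())


def precision_recall(real, fake, k=3):
    precision = coverage(fake, real, k)  # generated samples that look real
    recall = coverage(real, fake, k)     # real samples covered by generations
    return precision, recall


rng = np.random.default_rng(1)
real = rng.standard_normal((50, 4))
print(precision_recall(real, real), precision_recall(real, real + 100.0))  # (1.0, 1.0) (0.0, 0.0)
```

Reporting both numbers separates fidelity (precision) from diversity (recall), which a single FID value mixes together.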
Performance Reproduction
RDM achieves results competitive with previous state-of-the-art models:
| Dataset | Resolution | Training Samples | FID | sFID | IS | Precision | Recall |
| --------- | ---------- | ---------------- | :--: | :--: | :----: | :-------: | :----: |
| CelebA-HQ | 256x256 | 47M | 3.15 | - | - | 0.77 | 0.55 |
| ImageNet | 256x256 | 1250M | 1.87 | 3.97 | 278.75 | 0.81 | 0.59 |
We provide our best pre-trained RDM checkpoints, together with their sampler settings, for reproducing these results:
- CelebA-HQ $256\times 256$:
Download the first-stage and second-stage checkpoints, place them in ckpts/, and generate samples and their activations with:
    torchrun --standalone --nproc_per_node=8 generate_celebahq.py --outdir=generations/celebahq_samples/ \
        --network_first=ckpts/celebahq_first_stage.pt \
        --network_second=ckpts/celebahq_second_stage.pt
    torchrun…