This is a PyTorch implementation of Variational Diffusion Models, where the focus is on optimizing likelihood rather than sample quality, in the spirit of probabilistic generative modeling.
This implementation should match the
official one in JAX.
However, the purpose is mainly educational and the focus is on simplicity.
So far, the repo only includes CIFAR10, and variance minimization with a learned noise schedule (Appendix I.2 in the paper) is not implemented (it's only used for CIFAR10 with augmentations and, according to the paper, it does not have a significant impact).
The samples below are from a model trained on CIFAR10 for 2M steps with gradient clipping and with a fixed noise schedule (linear in $t$, with endpoints $\gamma_{\min} = -13.3$ and $\gamma_{\max} = 5.0$ as in the paper).
Without gradient clipping (as in the paper), the test set variational lower bound (VLB) is 2.715 bpd after 2M steps (the paper reports 2.65 after 10M steps). However, training is somewhat unstable, has a tendency to overfit, and the train-test gap is rather large, so it requires some care. With gradient clipping, the test set VLB is slightly worse, but training seems better behaved.
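A fixed noise schedule of this kind can be sketched in a few lines. This is only an illustration, not the repo's actual code; it assumes the linear log-SNR ("gamma") parameterization from the paper, and the endpoint constants are the values the paper reports for CIFAR10:

```python
import math

# Endpoints of the fixed linear log-SNR schedule (values reported in the
# VDM paper for CIFAR10; treat them as an assumption of this sketch).
GAMMA_MIN = -13.3
GAMMA_MAX = 5.0

def gamma(t: float) -> float:
    """Linear log-SNR schedule: gamma(t) = gamma_min + (gamma_max - gamma_min) * t."""
    return GAMMA_MIN + (GAMMA_MAX - GAMMA_MIN) * t

def alpha_sigma(t: float) -> tuple[float, float]:
    """Signal and noise scales: alpha_t^2 = sigmoid(-gamma(t)), sigma_t^2 = sigmoid(gamma(t))."""
    g = gamma(t)
    sigma2 = 1.0 / (1.0 + math.exp(-g))  # sigmoid(gamma(t))
    return math.sqrt(1.0 - sigma2), math.sqrt(sigma2)
```

At $t = 0$ this gives almost no noise ($\alpha \approx 1$), and at $t = 1$ almost pure noise ($\sigma \approx 1$), with $\alpha_t^2 + \sigma_t^2 = 1$ by construction.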
Let $\mathbf{x}$ be a data point and $\mathbf{z}_t$, with $t \in [0, 1]$, the latent variables of the diffusion process, defined by

$$q(\mathbf{z}_t \mid \mathbf{x}) = \mathcal{N}\!\left(\alpha_t \mathbf{x},\, \sigma_t^2 \mathbf{I}\right)$$

with $\alpha_t^2 = \operatorname{sigmoid}(-\gamma(t))$ and $\sigma_t^2 = \operatorname{sigmoid}(\gamma(t))$, where $\gamma$ is the noise schedule, so that the signal-to-noise ratio is $\operatorname{SNR}(t) = \alpha_t^2 / \sigma_t^2 = e^{-\gamma(t)}$.

In discrete time, the generative (denoising) process in $T$ steps is

$$p(\mathbf{z}_s \mid \mathbf{z}_t) = q\!\left(\mathbf{z}_s \mid \mathbf{z}_t,\, \mathbf{x} = \hat{\mathbf{x}}_\theta(\mathbf{z}_t; t)\right)$$

where $s = (i-1)/T$ and $t = i/T$ for $i = T, \dots, 1$, and where $\hat{\mathbf{x}}_\theta$ is the denoising model.

The loss function is given by the usual variational lower bound:

$$-\log p(\mathbf{x}) \leq \mathcal{L}_{\text{prior}}(\mathbf{x}) + \mathcal{L}_{\text{recon}}(\mathbf{x}) + \mathcal{L}_T(\mathbf{x})$$

where the diffusion loss $\mathcal{L}_T(\mathbf{x})$ is a sum of $T$ KL divergence terms, one per denoising step. Long story short, using the classic noise-prediction parameterization of the denoising model:

$$\hat{\mathbf{x}}_\theta(\mathbf{z}_t; t) = \frac{\mathbf{z}_t - \sigma_t \hat{\boldsymbol{\epsilon}}_\theta(\mathbf{z}_t; t)}{\alpha_t}$$

and considering the continuous-time limit ($T \to \infty$), the diffusion loss simplifies to

$$\mathcal{L}_\infty(\mathbf{x}) = \frac{1}{2}\, \mathbb{E}_{\boldsymbol{\epsilon} \sim \mathcal{N}(0, \mathbf{I}),\; t \sim \mathcal{U}(0, 1)} \left[ \gamma'(t)\, \left\lVert \boldsymbol{\epsilon} - \hat{\boldsymbol{\epsilon}}_\theta(\mathbf{z}_t; t) \right\rVert_2^2 \right].$$
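A single Monte Carlo sample of this continuous-time loss can be sketched in PyTorch as follows. This is not the repo's actual code: `model` stands in for the noise-prediction network $\hat{\boldsymbol{\epsilon}}_\theta$, and `gamma` / `gamma_prime` for the noise schedule and its derivative.

```python
import torch

def diffusion_loss(model, x, gamma, gamma_prime):
    """One Monte Carlo estimate of the continuous-time diffusion loss (sketch)."""
    t = torch.rand(x.shape[0])                     # t ~ U(0, 1), one per example
    g = gamma(t).view(-1, 1, 1, 1)
    alpha = torch.sqrt(torch.sigmoid(-g))          # alpha_t^2 = sigmoid(-gamma(t))
    sigma = torch.sqrt(torch.sigmoid(g))           # sigma_t^2 = sigmoid(gamma(t))
    eps = torch.randn_like(x)                      # eps ~ N(0, I)
    z_t = alpha * x + sigma * eps                  # sample z_t ~ q(z_t | x)
    eps_hat = model(z_t, t)                        # predict the noise
    sq_err = ((eps - eps_hat) ** 2).flatten(1).sum(1)
    return 0.5 * (gamma_prime(t) * sq_err).mean()  # estimate of L_infinity
```

With a fixed linear schedule, `gamma_prime` is simply a constant; with a learned schedule it would be obtained by differentiating the schedule network.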
One of the key components to reach SOTA likelihood is the concatenation of Fourier features to $\mathbf{z}_t$ before it is fed to the denoising model. For each scalar element $z$ of the input, the features

$$\sin(2^n \pi z) \quad \text{and} \quad \cos(2^n \pi z)$$

are appended, with $n$ ranging over a set of integers. Assume that each scalar variable takes values $k/128$ with $k$ integer (e.g., in our case, the 256 discrete pixel values are rescaled to $[-1, 1]$, so consecutive values are $1/128$ apart), which means the features with frequency $2^n$ have period $2^{1-n}$, i.e., one full period every $2^{8-n}$ consecutive pixel values.

Below we visualize the feature values for pixel values 0 to 25, varying the frequency $2^n$.

Below are the sine features on the Mandrill image (and detail on the right) with smoothly increasing frequency.
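As a rough sketch of the feature construction (the frequency range `n_min`/`n_max` below is illustrative, not necessarily the repo's exact choice):

```python
import torch

def concat_fourier_features(z: torch.Tensor, n_min: int = 0, n_max: int = 7) -> torch.Tensor:
    """Append sin(2^n * pi * z) and cos(2^n * pi * z) channels to z (sketch)."""
    feats = [z]
    for n in range(n_min, n_max + 1):
        arg = (2.0 ** n) * torch.pi * z
        feats.append(torch.sin(arg))
        feats.append(torch.cos(arg))
    return torch.cat(feats, dim=1)  # concatenate along the channel dimension
```

For a `(B, C, H, W)` input this yields `C * (1 + 2 * (n_max - n_min + 1))` output channels; the high-frequency channels make small differences between adjacent discretized pixel values visible to the network.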
The environment can be set up with `requirements.txt`. For example, with conda:

```shell
conda create --name vdm python=3.9
conda activate vdm
pip install -r requirements.txt
```
To train with default parameters and options:
```shell
accelerate launch --config_file accelerate_config.yaml train.py --results-path results/my_experiment/
```
Append `--resume` to the command above to resume training from the latest checkpoint. See `train.py` for more training options.
Here we provide a sensible configuration for training on 2 GPUs in the file `accelerate_config.yaml`. This can be modified directly, or overridden on the command line by adding flags before `train.py` (e.g., `--num_processes N` to train on N GPUs).
See the Accelerate docs for more configuration options.
After initialization, we print an estimate of the required GPU memory for the given
batch size, so that the number of GPUs can be adjusted accordingly.
The training loop periodically logs train and validation metrics to a JSONL file,
and generates samples.
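Since each line of a JSONL log is a standalone JSON object, the metrics can be loaded with a few lines of Python. The file path and field names below are hypothetical; check the actual log for the keys used:

```python
import json

def read_metrics(path: str) -> list[dict]:
    """Read a JSONL metrics log: one JSON object per non-empty line."""
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]

# Hypothetical usage (path and keys are illustrative):
# metrics = read_metrics("results/my_experiment/metrics.jsonl")
# best = min(metrics, key=lambda m: m["val_bpd"])
```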
To evaluate a trained model:

```shell
python eval.py --results-path results/my_experiment/ --n-sample-steps 1000
```
This implementation is based on the VDM paper and official code. The code structure for training diffusion models with Accelerate is inspired by this repo.