- Overview
- Installation
- Data
- Pretrained weights
- Train
- Evaluation
- Implementation details
- Acknowledgments
- Contacts
Official PyTorch implementation of "DiMSUM: Diffusion Mamba - A Scalable and Unified Spatial-Frequency Method for Image Generation" (NeurIPS'24)
Hoang Phan4 · Dimitris N. Metaxas2 · Anh Tran1
1VinAI Research 2Rutgers University 3Cornell University 4New York University
[Page] [Paper]
*Equal contribution †Work done while at VinAI Research
We propose DiMSUM, a hybrid Mamba-Transformer diffusion model that leverages both spatial and frequency information for high-quality image synthesis. In extensive experiments on standard benchmarks, our method achieves state-of-the-art results, with an FID of 4.62 on CelebA-HQ 256, 3.76 on LSUN Church, and 2.11 on ImageNet-1K 256. It also converges faster in training than ZigMa and other diffusion methods. In particular, it outperforms both DiT and SiT while requiring less than a third of the training iterations, reaching its best FID score of 2.11.
Details of the model architecture and experimental results can be found in our paper:
@inproceedings{phung2024dimsum,
title={DiMSUM: Diffusion Mamba - A Scalable and Unified Spatial-Frequency Method for Image Generation},
author={Hao Phung and Quan Dao and Trung Dao and Hoang Phan and Dimitris Metaxas and Anh Tran},
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
year= {2024},
}
Please CITE our paper and give us a ⭐ whenever this repository is used to help produce published results or incorporated into other software.
- Python 3.10.13:
conda create -n dimsum python=3.10.13
- torch 2.1.1 + cu118:
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
- Requirements:
pip install -r requirements.txt
- Install causal_conv1d and mamba:
conda install conda-forge::cudatoolkit-dev
cd causal_conv1d && pip install -e . && cd ..
cd mamba && pip install -e . && cd ..
- Add python path for DiMSUM:
export PYTHONPATH=$PYTHONPATH:$(pwd)
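After the steps above, a quick check can confirm that the key packages are importable from the active environment. The following is a minimal sketch, not part of the repository: the helper name `find_missing` is hypothetical, and the module names (`torch`, `torchvision`, `causal_conv1d`, `mamba_ssm`) are assumptions inferred from the install commands and paths in this README.

```python
import importlib.util
import sys

# Module names are assumptions inferred from the install steps;
# adjust them if a package exposes a different import name.
EXPECTED_MODULES = ("torch", "torchvision", "causal_conv1d", "mamba_ssm")


def find_missing(modules=EXPECTED_MODULES):
    """Return the names in `modules` that cannot be located for import."""
    missing = []
    for name in modules:
        try:
            spec = importlib.util.find_spec(name)
        except (ImportError, ValueError):
            # Treat lookup errors the same as "not installed".
            spec = None
        if spec is None:
            missing.append(name)
    return missing


if __name__ == "__main__":
    print("Python", sys.version.split()[0])
    missing = find_missing()
    print("missing modules:", ", ".join(missing) if missing else "none")
```

This only checks that the modules can be found. It does not verify that the CUDA extensions were compiled correctly.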
For CelebA HQ (256) and LSUN, please follow this repo for dataset preparation.
For evaluation, first resize the dataset images and extract them as JPEG files.
For LMDB data (such as celeba_256 and lsun_church), run this command:
python eval_toolbox/resize_lmdb.py --dataset celeba_256 --datadir ./data/celeba_256/celeba-lmdb/ --image_size 256 --save_dir real_samples/
For image folder of jpeg/png images, run this command instead:
python eval_toolbox/resize.py main input_data_dir real_samples/dataname
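Before running evaluation, it can help to confirm that the extraction actually wrote images into the output folder. The sketch below is a hypothetical helper (`count_images` is not part of the repository) that counts image files under a directory such as `real_samples/`. The set of accepted suffixes is an assumption.

```python
from pathlib import Path

# Suffixes treated as images; an assumption, extend as needed.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def count_images(folder):
    """Count image files under `folder`, searching subfolders recursively."""
    root = Path(folder)
    return sum(
        1
        for path in root.rglob("*")
        if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES
    )


if __name__ == "__main__":
    # Path taken from the commands above; adjust to your save directory.
    print(count_images("real_samples/"))
```

Comparing the printed count with the known size of the dataset is a cheap way to catch a partial or failed extraction.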
| Experiment | #Params | FID | Checkpoint |
|---|---|---|---|
| CelebA-HQ 256 | 460M | 4.62 | celeb256_225ep.pt |
| LSUN Church 256 | 460M | 3.76 | church_395ep.pt |
| ImageNet-1K 256 (CFG) | 460M | 2.11 | imnet256_510ep.pt |
Comment or uncomment the command lines for the desired dataset in scripts/train.sh, then run:
bash scripts/train.sh
To sample images from pretrained checkpoints, run:
bash scripts/sample.sh
To evaluate, select a relevant command and run:
bash scripts/eval.sh
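For reference, the FID values reported in this README are assumed to follow the standard Fréchet Inception Distance definition. Here $\mu_r, \Sigma_r$ and $\mu_g, \Sigma_g$ denote the mean and covariance of Inception features for real and generated images. The exact settings (number of samples, feature extractor) depend on the commands in scripts/eval.sh and are not restated here.

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \right)
```

Lower is better, and the value is zero only when the two feature distributions have identical means and covariances.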
- The DiMSUM architecture is defined in dimsum/models_dim.py.
- Conditional Mamba can be found in mamba/mamba_ssm/ops/selective_scan_interface.py and causal-conv1d/csrc/causal_conv1d.cpp.
- Frequency transformations: dimsum/wavelet_layer.py and dimsum/dct_layer.py.
- Mamba Scanning strategies (e.g. sweep8, jpeg8): dimsum/scanning_orders.py.
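A scanning strategy turns the 2D grid of image tokens into the 1D sequence that Mamba processes. The sketch below illustrates the idea with two simple orders: a row-major sweep and a JPEG-style zigzag. It is an illustration under assumptions, not the code in dimsum/scanning_orders.py. The function names are hypothetical, and the repository's sweep8 and jpeg8 variants may be defined differently.

```python
def sweep_order(h, w):
    """Row-major order: left to right within a row, rows top to bottom."""
    return [r * w + c for r in range(h) for c in range(w)]


def zigzag_order(h, w):
    """JPEG-style zigzag: walk anti-diagonals, alternating direction."""
    order = []
    for s in range(h + w - 1):
        # Cells on the anti-diagonal r + c == s, listed with r increasing.
        diag = [(r, s - r) for r in range(max(0, s - w + 1), min(s, h - 1) + 1)]
        if s % 2 == 0:
            diag.reverse()  # even diagonals run from bottom-left to top-right
        order.extend(r * w + c for r, c in diag)
    return order


def apply_order(tokens, order):
    """Reorder a flat, row-major token list according to `order`."""
    return [tokens[i] for i in order]


if __name__ == "__main__":
    tokens = list("abcdefghi")  # a 3x3 grid flattened row by row
    print(apply_order(tokens, sweep_order(3, 3)))
    print(apply_order(tokens, zigzag_order(3, 3)))
```

Each order is a permutation of the flat token indices, so it can be inverted after the sequence model to restore the original grid layout.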
This project is based on Vim, LFM, SiT, DiT, and ZigMa. We thank the authors for releasing their excellent work and code.
If you have any problems, please open an issue in this repository or send an email to [email protected].