
Efficient Generative Model Training via Embedded Representation Warmup

Deyuan Liu 1* · Peng Sun 1,2* · Xufeng Li 1,3 · Tao Lin 1†
1 Westlake University   2 Zhejiang University   3 Nanjing University
* These authors contributed equally.   † Corresponding author.
(Teaser figure: teaser_page1)

Summary: Diffusion models have made impressive progress in generating high-fidelity images. However, training them from scratch requires learning both robust semantic representations and the generative process simultaneously. Our work introduces Embedded Representation Warmup (ERW) – a plug-and-play two-phase training framework that:

  • Phase 1 – Warmup: Initializes the early layers of the diffusion model with high-quality, pretrained visual representations (e.g., from DINOv2 or other self-supervised encoders).
  • Phase 2 – Full Training: Continues with standard diffusion training while gradually reducing the alignment loss, so the model can focus on refining generation.
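
A minimal sketch of how these two phases could be combined into one objective, assuming a total loss of the form L_total = L_diffusion + λ(step)·L_align with λ held high during warmup and decayed afterwards. The decay shape and constants below are illustrative assumptions, not the repository's exact schedule:

# Illustrative sketch of a two-phase loss schedule (not the repository's exact code).
# Phase 1 (warmup): alignment to the pretrained encoder dominates the early layers.
# Phase 2 (full training): the alignment weight decays so the diffusion loss takes over.

def alignment_weight(step, warmup_steps=100_000, max_steps=300_000, base=1.0):
    """Return the alignment-loss weight lambda(step); linear decay is an assumption."""
    if step < warmup_steps:
        return base                                    # Phase 1: full-strength alignment
    frac = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return base * max(0.0, 1.0 - frac)                 # Phase 2: decay toward 0

def total_loss(diffusion_loss, alignment_loss, step):
    return diffusion_loss + alignment_weight(step) * alignment_loss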

🔥 News

  • (🔥 New) [2025/4/15] 🔥 ERW code & weights are released! 🎉 Training & inference code and the Hugging Face weights are all available.

1. Environment setup

conda create -n erw python=3.9 -y
conda activate erw
pip install -r requirements.txt

2. Dataset

Dataset download

Currently, we provide experiments for ImageNet. Place the data wherever you like and specify its location via the --data-dir argument in the training scripts. Please refer to the REPA preprocessing guide.

Process ImageNet-1K to Latents

python extract_latent.py \
  --data-path=[YOUR_DATA_PATH] \
  --output-dir=[YOUR_DATA_OUTPUT_PATH]
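
For reference, a minimal sketch of what per-image latent extraction typically looks like with a Stable Diffusion VAE; this is an assumption for illustration, and extract_latent.py may use a different autoencoder, resolution, or output layout:

# Sketch of VAE latent extraction (illustrative; not the repository's extract_latent.py).
import torch
from PIL import Image
from torchvision import transforms
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").eval().cuda()

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),  # scale pixels to [-1, 1]
])

@torch.no_grad()
def image_to_latent(path):
    x = preprocess(Image.open(path).convert("RGB")).unsqueeze(0).cuda()
    # 0.18215 is the standard Stable Diffusion latent scaling factor.
    return vae.encode(x).latent_dist.sample() * 0.18215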

Warmed-up Checkpoints

Weights for 100K warmup

Weights for 100K warmup + 200K full training (with QK-Norm)

Weights for 100K warmup + 200K full training (without QK-Norm)
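
If you want to inspect a downloaded checkpoint before resuming, a quick sketch; the file name and key names below are hypothetical, and the released checkpoints may use a different layout:

# Sketch: inspect a warmed-up checkpoint before resuming (names are hypothetical).
import torch

ckpt = torch.load("exps/erw_warmup_100k.pt", map_location="cpu")
state = ckpt.get("model", ckpt) if isinstance(ckpt, dict) else ckpt
print("top-level keys:", list(ckpt.keys()) if isinstance(ckpt, dict) else type(ckpt))
print("number of parameter tensors:", len(state))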

3. Training

Training from Scratch

accelerate launch train.py \
  --report-to="wandb" \
  --allow-tf32 \
  --mixed-precision="fp16" \
  --seed=0 \
  --path-type="linear" \
  --prediction="v" \
  --weighting="uniform" \
  --model="SiT-XL/2" \
  --enc-type="dinov2-vit-b" \
  --encoder-depth=14 \
  --output-dir="exps" \
  --exp-name="erw-linear-dinov2-b-enc14" \
  --max-train-steps 300000 \
  --checkpointing-steps 50000 \
  --warmup-steps 100000 \
  --use_rope \
  --data-dir=[YOUR_DATA_PATH]

Resume from Our Warmed-up Checkpoint

accelerate launch train.py \
  --report-to="wandb" \
  --allow-tf32 \
  --mixed-precision="fp16" \
  --seed=0 \
  --path-type="linear" \
  --prediction="v" \
  --weighting="uniform" \
  --model="SiT-XL/2" \
  --enc-type="dinov2-vit-b" \
  --encoder-depth=14 \
  --output-dir="exps" \
  --exp-name="erw-linear-dinov2-b-enc14" \
  --max-train-steps 300000 \
  --checkpointing-steps 50000 \
  --warmup-steps 100000 \
  --resume-step 100000 \
  --use_rope \
  --data-dir=[YOUR_DATA_PATH]

This script automatically creates a folder under exps to save logs and checkpoints. You can adjust the following options:

  • --model: [SiT-B/2, SiT-L/2, SiT-XL/2]
  • --enc-type: [dinov2-vit-b, dinov2-vit-l, dinov2-vit-g, dinov1-vit-b, mocov3-vit-b, mocov3-vit-l, clip-vit-L, jepa-vit-h, mae-vit-l]
  • --encoder-depth: Any value between 1 and the depth of the model
  • --output-dir: Any directory where you want to save checkpoints and logs
  • --exp-name: Any string name (a folder with this name is created under output-dir)

DINOv2 models are downloaded automatically from torch.hub, and CLIP models are downloaded automatically from the CLIP repository. For other pretrained visual encoders, please download the model weights from the links below and place them in the following directories with these names:

  • mocov3: Download the ViT-B/16 or ViT-L/16 model from the RCG repository and place it as ./ckpts/mocov3_vitb.pth or ./ckpts/mocov3_vitl.pth
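
As a reference, a sketch of how these target encoders are typically obtained; the repository handles this internally, and the exact hub entry points used may differ:

# Sketch: obtaining the pretrained target encoders (illustrative).
import torch

# DINOv2 ViT-B/14 is fetched from torch.hub (matches --enc-type=dinov2-vit-b).
dinov2 = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14")
dinov2.eval().requires_grad_(False)

# MoCo v3 weights are downloaded manually (see above) and loaded from ./ckpts/.
mocov3_state = torch.load("./ckpts/mocov3_vitb.pth", map_location="cpu")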

4. Evaluation

You can generate images (and an .npz file that can be used with the ADM evaluation suite) through the following script:

torchrun --nnodes=1 --nproc_per_node=8 inference.py \
  --model SiT-XL/2 \
  --num-fid-samples 50000 \
  --ckpt YOUR_CHECKPOINT_PATH \
  --path-type=linear \
  --encoder-depth=14 \
  --projector-embed-dims=768 \
  --per-proc-batch-size=64 \
  --mode=sde \
  --num-steps=250 \
  --cfg-scale=1.0 \
  --use_rope \
  --fid-reference YOUR_VIRTUAL_imagenet256_labeled.npz_PATH

w/ CFG

torchrun --nnodes=1 --nproc_per_node=8 inference.py \
  --model SiT-XL/2 \
  --num-fid-samples 50000 \
  --ckpt YOUR_CHECKPOINT_PATH \
  --path-type=linear \
  --encoder-depth=14 \
  --projector-embed-dims=768 \
  --per-proc-batch-size=64 \
  --mode=sde \
  --num-steps=250 \
  --cfg-scale=2.2 \
  --guidance-high=0.95 \
  --use_rope \
  --fid-reference YOUR_VIRTUAL_imagenet256_labeled.npz_PATH

We also provide the SiT-XL/2 checkpoint (trained for 4M iterations) used in the final evaluation. It will be automatically downloaded if you do not specify --ckpt.
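
For reference, the ADM evaluation suite expects the samples as a uint8 array of shape [N, 256, 256, 3] stored under the arr_0 key of an .npz file; inference.py writes such a file, but a sketch of building one manually from saved PNGs (paths below are illustrative) looks like:

# Sketch: pack saved PNG samples into an ADM-style .npz (paths are illustrative).
import glob
import numpy as np
from PIL import Image

paths = sorted(glob.glob("samples/*.png"))
samples = np.stack([np.asarray(Image.open(p).convert("RGB")) for p in paths])
assert samples.dtype == np.uint8 and samples.shape[1:] == (256, 256, 3)
np.savez("sample_batch.npz", arr_0=samples)
# Then evaluate with the ADM suite, e.g.:
#   python evaluator.py VIRTUAL_imagenet256_labeled.npz sample_batch.npz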

5. CKNNA

To compute CKNNA, extract features with tools/extract_fecture.py.
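
A sketch of the kind of intermediate-feature extraction this script performs before CKNNA is computed; the block index and forward signature below are illustrative assumptions:

# Sketch: capture an intermediate block's output with a forward hook (illustrative).
import torch
from torch import nn

def extract_block_features(model: nn.Module, block: nn.Module, *model_inputs):
    """Run `model` once and return the output produced by `block`."""
    captured = {}

    def _hook(module, inputs, output):
        captured["feat"] = output.detach()

    handle = block.register_forward_hook(_hook)
    try:
        with torch.no_grad():
            model(*model_inputs)
    finally:
        handle.remove()
    return captured["feat"]

# e.g. feats = extract_block_features(sit_model, sit_model.blocks[13], latents, t, y)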

Note

This code may not exactly reproduce the results reported in the paper, due to potential human errors during preparation and cleaning of the code for release. If you encounter difficulties reproducing our findings, please let us know. We also plan to run sanity-check experiments in the near future.

Acknowledgement

This code is mainly built upon the REPA, LightningDiT, DiT, SiT, edm2, and RCG repositories.

BibTeX

@misc{liu2025efficientgenerativemodeltraining,
      title={Efficient Generative Model Training via Embedded Representation Warmup}, 
      author={Deyuan Liu and Peng Sun and Xufeng Li and Tao Lin},
      year={2025},
      eprint={2504.10188},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2504.10188}, 
}
