Dream 7B


Dream is a 7B diffusion large language model that achieves performance comparable to leading autoregressive models of a similar size.

News

  • [2025-04-05]: Dream checkpoints and inference code are released.
  • [2025-04-02]: Dream blog is released.

Installation

Our implementation of Dream is based on the Hugging Face transformers library. First install transformers via pip install transformers==4.46.2 together with torch==2.5.1, as Dream uses the SdpaAttention built into torch. Other versions of transformers and torch have not been fully tested.
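
For example:

pip install torch==2.5.1
pip install transformers==4.46.2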

Running the model requires a GPU with at least 20 GB of memory.

Thanks to Apolinário for providing the online demo at https://huggingface.co/spaces/multimodalart/Dream.

Usage

We provide several demos showing the inference code of Dream. A simple example is:

import torch
from transformers import AutoModel, AutoTokenizer

# Load the model and tokenizer (trust_remote_code is required for the custom diffusion code).
model_path = "Dream-org/Dream-v0-Instruct-7B"
model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = model.to("cuda").eval()

# Build the chat prompt.
messages = [
    {"role": "user", "content": "Please write a Python class that implements a PyTorch trainer capable of training a model on a toy dataset."}
]
inputs = tokenizer.apply_chat_template(
    messages, return_tensors="pt", return_dict=True, add_generation_prompt=True
)
input_ids = inputs.input_ids.to(device="cuda")
attention_mask = inputs.attention_mask.to(device="cuda")

# Run diffusion sampling.
output = model.diffusion_generate(
    input_ids,
    attention_mask=attention_mask,
    max_new_tokens=512,
    output_history=True,
    return_dict_in_generate=True,
    steps=512,
    temperature=0.2,
    top_p=0.95,
    alg="entropy",
    alg_temp=0.,
)

# Strip the prompt tokens and decode only the generated continuation.
generations = [
    tokenizer.decode(g[len(p) :].tolist())
    for p, g in zip(input_ids, output.sequences)
]

# Keep only the text before the first end-of-sequence token.
print(generations[0].split(tokenizer.eos_token)[0])

Parameters of diffusion_generate()

model.diffusion_generate() supports a subset of the arguments of model.generate(), plus some diffusion-specific arguments:

  • input_ids: The input token ids.
  • attention_mask: The attention mask when performing batch inference.
  • max_new_tokens: The maximum number of tokens to generate. Note that the context length (input + output) of Dream is currently 2048.
  • output_history: Whether to return the output at each intermediate step.
  • return_dict_in_generate: Whether to return the output as a structured object; usually set to True.
  • steps: The number of diffusion timesteps. max_new_tokens/steps tokens will be generated at each step. Fewer steps yield faster but coarser results.
  • temperature: The value used to modulate the next-token probabilities. By default 0.0. Smaller values give more accurate results (e.g., in math or coding); larger values give more diverse results (e.g., in general conversation). If you notice repeated output, consider increasing the temperature.
  • top_p: If set to a float < 1, only the smallest set of most probable tokens whose probabilities add up to top_p or higher is kept for generation. By default None. Controls the diversity of generation.
  • top_k: The number of highest-probability vocabulary tokens to keep for top-k filtering. By default None. Controls the diversity of generation.
  • alg: The remasking strategy in diffusion sampling, controlling the token generation order. Supports one random strategy and three confidence-based strategies:
    • origin: Tokens are generated in a purely random order, following https://arxiv.org/abs/2107.03006. The default strategy. Note this may degrade performance on some tasks.
    • maskgit_plus: Tokens are generated based on top-1 confidence, following https://arxiv.org/abs/2202.04200.
    • topk_margin: Tokens are generated based on margin confidence (top-1 minus top-2), following https://arxiv.org/abs/2502.06768.
    • entropy: Tokens are generated based on the entropy of each token's distribution.
  • alg_temp: Adds randomness to alg when using confidence-based strategies. By default None.
  • generation_logits_hook_func: A user-defined hook that controls the logits at each intermediate step, e.g., to apply guidance.
  • generation_tokens_hook_func: A user-defined hook that controls the tokens at each intermediate step, e.g., for printing, infilling, or other token-control strategies. See demo_token_control.py for reference, and the sketch after this list.
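
The hooks make it possible to inspect or steer the intermediate sampling steps. Below is a minimal sketch of a token hook that prints the partially denoised sequence every 64 steps; it assumes the hook is called as hook(step, x, logits) and must return the (possibly modified) token tensor, mirroring demo_token_control.py. Check that file for the exact interface.

# A minimal sketch, assuming the hook receives (step, x, logits) and returns x;
# see demo_token_control.py for the exact interface.
def print_progress_hook(step, x, logits):
    # x holds the full sequence (prompt + partially unmasked response);
    # step may be None on the initial call before sampling starts.
    if step is not None and step % 64 == 0:
        print(f"step {step}: {tokenizer.decode(x[0].tolist())}")
    return x  # return tokens unchanged; modify x here to infill or constrain

output = model.diffusion_generate(
    input_ids,
    attention_mask=attention_mask,
    max_new_tokens=512,
    return_dict_in_generate=True,
    steps=512,
    generation_tokens_hook_func=print_progress_hook,
)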

Citation

@misc{dream2025,
    title = {Dream 7B},
    url = {https://hkunlp.github.io/blog/2025/dream},
    author = {Ye, Jiacheng and Xie, Zhihui and Zheng, Lin and Gao, Jiahui and Wu, Zirui and Jiang, Xin and Li, Zhenguo and Kong, Lingpeng},
    year = {2025}
}
