📝 vLLM integration doc #3358

Closed
wants to merge 10 commits
140 changes: 132 additions & 8 deletions docs/source/vllm_integration.md
@@ -1,14 +1,115 @@
# vLLM Integration
If you are here to learn how to use vLLM with TRL, you are in the right place. This document guides you through using vLLM with TRL for faster generation in online methods like GRPO and Online DPO. We first give a TL;DR on how to use vLLM with TRL, and then go into the details of how it works under the hood. Let's go! 🔥

# 🫠🚀 How can I use vLLM with TRL to make things go faster? TL;DR
First, start the server (this example allocates 4 GPUs to vLLM generation):
```sh
trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 2 --data-parallel-size 2
```
Then, run the training script with `use_vllm=True` in the training arguments (this example allocates 4 GPUs to training):

```sh
CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
```

A minimal `train.py` script:

```python
from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig

dataset = load_dataset("trl-lib/tldr", split="train")


# Dummy reward function: count the number of unique characters in the completions
def reward_num_unique_chars(completions, **kwargs):
    return [len(set(c)) for c in completions]


training_args = GRPOConfig(
    output_dir="my_test",
    use_vllm=True,
    bf16=True,
    gradient_checkpointing=True,
    logging_steps=10,
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-7B",
    args=training_args,
    reward_funcs=reward_num_unique_chars,
    train_dataset=dataset,
)
trainer.train()
```

# 🎬 Flashback: why do we need to use vLLM in online methods?
Online methods like GRPO or Online DPO require the model to generate completions during training, which are then used to compute reward signals. However, generation can be extremely time-consuming, especially with large models. In the default setup (without vLLM), completions are generated using the (unwrapped) model's [`generate` method](https://github.com/huggingface/trl/blob/f3e8c2304428ef16e9ae5de9e5741ed84d533b7b/trl/trainer/grpo_trainer.py#L965C39-L965C66). This approach quickly becomes a major bottleneck: generation is slow and inefficient, particularly for large batches or large models. As a result, training times increase significantly and overall efficiency drops. To address this, we turn to vLLM, which enables much faster and more scalable generation, helping to eliminate this bottleneck in online methods.
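
For intuition, here is a simplified sketch of what the default (non-vLLM) path boils down to: a plain `model.generate` call whenever completions are needed. This is illustrative only; the actual GRPO trainer code is more involved, and the small model and prompts are assumptions chosen just to keep the example light:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # decoder-only models are left-padded for batched generation

prompts = ["Summarize: the cat sat on the mat.", "Summarize: it rained all day in Paris."]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

# Autoregressive decoding: one forward pass per generated token, with the KV cache
# for every sequence held in GPU memory for the whole generation. This is the part
# that becomes the bottleneck for large models and large batches.
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=64, do_sample=True)

completions = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
print(completions)
```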

# 🤔 How does vLLM solve the slow generation issue?
If you've ever done autoregressive decoder training, you know that all the input tokens to the LLM produce attention key and value tensors, and that these tensors are kept in GPU memory so that subsequent tokens can be generated from them. These cached key and value tensors are often referred to as the KV cache. Storing them is a real pain, though, because it occupies a lot of memory. Here is vLLM's secret sauce: it uses a technique called PagedAttention to solve this problem. PagedAttention, inspired by the operating system's virtual memory concept, stores the keys and values of contiguous tokens in **non-contiguous memory space**, which is far more efficient. The details are beyond the scope of this document, but in short, it allows the model to store keys and values more efficiently, reducing the memory footprint and speeding up generation. If you are interested, check out the [vLLM PagedAttention blog post](https://blog.vllm.ai/2023/06/20/vllm.html) for more details.
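
To build some intuition for the block-based bookkeeping, here is a toy sketch of the idea: a shared pool of fixed-size physical blocks plus a per-sequence block table. This is purely conceptual and is not vLLM's actual implementation; all names and sizes are made up for illustration:

```python
# Toy sketch of the PagedAttention bookkeeping idea (not vLLM's implementation).
import torch

BLOCK_SIZE = 16   # tokens stored per physical block
NUM_BLOCKS = 64   # size of the shared physical block pool
HEAD_DIM = 8      # toy key/value dimension

kv_pool = torch.zeros(NUM_BLOCKS, BLOCK_SIZE, HEAD_DIM)  # one shared physical pool for all sequences
free_blocks = list(range(NUM_BLOCKS))                    # ids of unused physical blocks
block_tables = {}                                        # seq_id -> physical block ids, in logical order
seq_lens = {}                                            # seq_id -> number of tokens stored so far

def append_kv(seq_id, kv_vector):
    """Store one token's key/value vector for `seq_id`, allocating a block only when needed."""
    table = block_tables.setdefault(seq_id, [])
    n = seq_lens.get(seq_id, 0)
    if n % BLOCK_SIZE == 0:              # current block is full (or this is the first token)
        table.append(free_blocks.pop())  # grab any free physical block, wherever it lives
    kv_pool[table[-1], n % BLOCK_SIZE] = kv_vector
    seq_lens[seq_id] = n + 1

# Two sequences grow token by token; their blocks interleave in the physical pool,
# so no large contiguous region has to be reserved up front for either sequence.
for _ in range(40):
    append_kv(0, torch.randn(HEAD_DIM))
    append_kv(1, torch.randn(HEAD_DIM))

print(block_tables)  # e.g. {0: [63, 61, 59], 1: [62, 60, 58]}
```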


# ⚙️ How to use vLLM in practice for generation in TRL's online methods?

1. To use [vLLM](https://github.com/vllm-project/vllm), first install it using:

```bash
pip install vllm
```

or

```bash
pip install "trl[vllm]"
```

2. Then, **start a vLLM server** by running:

```bash
trl vllm-serve --model <model_name>
```
# 🔎 What exactly happens when you run `trl vllm-serve --model <model_name>`?
When you run, for example, `trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 1 --data-parallel-size 4`, the following happens:
![vllm](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/vllm-doc.png)
1. When you run a command like `trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 1 --data-parallel-size 4`, vLLM first spawns multiple workers to handle incoming requests in parallel. The number of workers is the product of `--tensor-parallel-size` and `--data-parallel-size`; in this example, it spawns 4 workers (1 × 4).
Each worker operates independently and processes a chunk of the incoming requests, which are basically the prompts sent to the server for generation. A key point to understand is that these 4 workers run in parallel, each responsible for a subset of the total incoming load.

2. Once the incoming requests (prompts) are distributed across the workers, the model starts generating completions. Internally, the model’s weights are split across multiple GPUs based on the `--tensor-parallel-size` argument — this is how tensor parallelism is handled. Meanwhile, data parallelism (controlled by `--data-parallel-size`) ensures that different sets of requests are processed independently across the workers. In short: tensor parallelism splits the model across GPUs, and data parallelism splits the batch of requests across different model replicas.

3. Although the GPUs process requests independently and in parallel, they still need to communicate with each other. Remember that each GPU handles only a slice of the incoming prompts (for example, with 4 GPUs and 8 prompts using `--data-parallel-size=4`, each GPU processes 2 prompts).
This GPU-to-GPU communication is managed efficiently by NVIDIA's NCCL library. The communication mainly ensures that each GPU gets its correct portion of the incoming requests; it is lightweight and does not interfere with generation itself.
Separately, the number of completions to generate per prompt is controlled by the `num_generations` setting in the GRPO config. For instance, if you set `num_generations=2` (as in the picture above), each prompt will have 2 completions. So, with 8 prompts and `num_generations=2`, you end up with 16 completions in total, regardless of the number of GPUs or parallelism settings (see the worked example after this list).

4. **🔬 How does it work in practice in TRL's GRPO trainer?**
- The vLLM server starts when you run `trl vllm-serve --model Qwen/Qwen2.5-7B`.
- Once the server is running, it generates completions for requests sent by the client (trainer) via `vllm_client.generate`.
- The client (trainer) requests these completions from the server.
- The completions are used to compute the reward signal.
- Based on the reward signal and the model's output, the loss is computed and the backward pass is performed to update the model's weights.
- **Note**: The server only handles completion generation; it does not train the model. Therefore, the model's weights are not updated on the server. Once the backward pass is complete, the client sends the updated weights to the server using `vllm_client.update_named_param(name, param.data)`.
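
The worker and completion counts described above follow directly from the parallelism settings and the GRPO config. Here is a minimal worked example of the arithmetic (plain Python, no server needed; the prompt count is an assumption chosen for illustration):

```python
# Worked example for: trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 1 --data-parallel-size 4
tensor_parallel_size = 1
data_parallel_size = 4
num_prompts = 8        # prompts sent by the trainer in one batch (example value)
num_generations = 2    # completions per prompt, from GRPOConfig

num_workers = tensor_parallel_size * data_parallel_size   # 1 * 4 = 4 workers
prompts_per_replica = num_prompts // data_parallel_size   # 8 / 4 = 2 prompts per model replica
total_completions = num_prompts * num_generations         # 8 * 2 = 16 completions overall

print(num_workers, prompts_per_replica, total_completions)  # 4 2 16
```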

## 📝 Important vLLM Notes:
When using vLLM, make sure the GPUs assigned to training and generation are separate to avoid resource conflicts. For instance, if you plan to use 4 GPUs for training and another 4 for vLLM generation, specify the GPU allocation for training with `CUDA_VISIBLE_DEVICES`. See the example below:

- **Set GPUs 0-3 for vLLM generation:** assume `CUDA_VISIBLE_DEVICES=0,1,2,3` is allocated to vLLM generation.
```sh
trl vllm-serve --model <model_name> --tensor-parallel-size 1 --data-parallel-size 4
```

- **And GPUs 4-7 for training:** if you do not set the `CUDA_VISIBLE_DEVICES` environment variable, the training script will use all available GPUs by default, which may lead to resource conflicts. To avoid this, set `CUDA_VISIBLE_DEVICES` to specify which GPUs to use for training. For example, to use GPUs 4-7 for training:
```sh
CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
```

## 🍷 More customization options with vLLM?
You can customize the server configuration by passing additional arguments.

```sh
$ trl vllm-serve --help
```
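
As a concrete (but hedged) example, you might cap the GPU memory fraction and the maximum context length used for generation. The flag names below are assumptions based on the options vLLM exposes; confirm them against the `--help` output of your installed TRL and vLLM versions:

```sh
# Assumed flags -- verify with `trl vllm-serve --help` before relying on them.
trl vllm-serve --model Qwen/Qwen2.5-7B --gpu-memory-utilization 0.9 --max-model-len 4096
```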

## 🥸 Okay, now that we have the server running, how can we use it to generate completions?

Then, run the training script and pass `use_vllm=True` in the training arguments.

```python
from trl import GRPOConfig

training_args = GRPOConfig(..., use_vllm=True)
```

# 💆🏻‍♀️ Find the best distributed setup
First and foremost, always remember that the optimal setup depends on:
- the model size
- the number of GPUs you have
- the GPU memory size
- the batch size you are using
- the number of requests you are sending to the server (prompts)
- the `max_model_len` you are using (the maximum input sequence length the model can process, i.e. the context window size)
- the number of completions you are generating for each request (`num_generations`)
![8gpu](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/tp_dp_throughput_8_gpus.png)
![4gpu](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/tp_dp_throughput_4_gpus.png)
Our experiments on the Qwen model family (3B, 7B, 14B, 32B) on 8 H100 GPUs show that:
- For reasonably sized models (3-14B) and a moderate context window (`max_len < 8k`), using the full capacity for data parallelism gives better throughput: (tp=1, dp=8) is the best setup.
- For larger models (32B) and longer context windows (`max_len > 8k`), a smaller DP size combined with some model-side parallelism works better: for example, (tp=2, dp=4) is a good setup for 32B models with a longer context window.
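
To make these two setups concrete, here is roughly how they would be launched. The model names are only examples drawn from the experiments above; adjust them to your own model and hardware:

```sh
# ~3-14B model with a moderate context window (<8k): favor pure data parallelism.
trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 1 --data-parallel-size 8

# ~32B model or a longer context window (>8k): also shard the model itself.
trl vllm-serve --model Qwen/Qwen2.5-32B --tensor-parallel-size 2 --data-parallel-size 4
```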

2 changes: 0 additions & 2 deletions tests/test_trainers_args.py
@@ -370,7 +370,6 @@ def test_sft(self):
packing=True,
max_length=256,
dataset_num_proc=4,
dataset_batch_size=512,
neftune_noise_alpha=0.1,
model_init_kwargs={"trust_remote_code": True},
dataset_kwargs={"append_concat_token": True, "skip_prepare_dataset": True},
@@ -381,7 +380,6 @@
self.assertEqual(trainer.args.packing, True)
self.assertEqual(trainer.args.max_length, 256)
self.assertEqual(trainer.args.dataset_num_proc, 4)
self.assertEqual(trainer.args.dataset_batch_size, 512)
self.assertEqual(trainer.args.neftune_noise_alpha, 0.1)
self.assertEqual(trainer.args.model_init_kwargs, {"trust_remote_code": True})
self.assertIn("append_concat_token", trainer.args.dataset_kwargs)
32 changes: 32 additions & 0 deletions tests/test_utils.py
@@ -95,6 +95,38 @@ def test_pad_2_dim_right_multidim(self):
)
self.assertTrue(torch.equal(output, expected))

    def test_pad_to_multiple_of_1(self):
        x = torch.tensor([1, 2, 3])
        y = torch.tensor([4, 5])
        # Max length is 3, pad to multiple of 4
        output = pad((x, y), padding_value=0, padding_side="right", pad_to_multiple_of=4)
        expected = torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]])
        self.assertTrue(torch.equal(output, expected))

    def test_pad_to_multiple_of_2(self):
        x = torch.tensor([1, 2, 3, 4, 5])
        y = torch.tensor([6, 7, 8])
        # Max length is 5, pad to multiple of 4 (i.e. to length 8)
        output = pad((x, y), padding_value=0, padding_side="right", pad_to_multiple_of=4)
        expected = torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0], [6, 7, 8, 0, 0, 0, 0, 0]])
        self.assertTrue(torch.equal(output, expected))

    def test_pad_to_multiple_of_side_left(self):
        x = torch.tensor([1, 2, 3, 4, 5])
        y = torch.tensor([6, 7, 8])
        # Max length is 5, pad to multiple of 4 (i.e. to length 8), padding on the left
        output = pad((x, y), padding_value=0, padding_side="left", pad_to_multiple_of=4)
        expected = torch.tensor([[0, 0, 0, 1, 2, 3, 4, 5], [0, 0, 0, 0, 0, 6, 7, 8]])
        self.assertTrue(torch.equal(output, expected))

    def test_pad_to_multiple_of_no_extra_padding(self):
        x = torch.tensor([1, 2, 3, 4])
        y = torch.tensor([5, 6, 7, 8])
        # Length is already a multiple of 4, so no extra padding is added
        output = pad((x, y), padding_value=0, padding_side="left", pad_to_multiple_of=4)
        expected = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
        self.assertTrue(torch.equal(output, expected))


@require_peft
class TestGetPEFTConfig(unittest.TestCase):
18 changes: 12 additions & 6 deletions trl/models/utils.py
@@ -18,7 +18,6 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal, Optional, Union

from accelerate.utils import is_deepspeed_available
from packaging import version
from transformers import PreTrainedModel, PreTrainedTokenizer

@@ -30,12 +29,10 @@
AutoModelForSeq2SeqLMWithValueHead,
)

if is_deepspeed_available():
import deepspeed

if TYPE_CHECKING:
from accelerate import Accelerator
from deepspeed.runtime.engine import DeepSpeedEngine
from torch.nn import Module
from torch.nn.parallel.distributed import DistributedDataParallel


@@ -167,6 +164,8 @@ def iter_params(module, recurse=False):

def add_hooks(model: "DeepSpeedEngine") -> None:
"""Adds the optimizer hooks from a DeepSpeed ZeRO-3 model."""
import deepspeed

if not hasattr(model, "optimizer"): # before the first training step, the model has no optimizer
return
if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
@@ -214,6 +213,8 @@ def unwrap_model_for_generation(
if not gather_deepspeed3_params:
yield accelerator.unwrap_model(model)
else:
import deepspeed

with deepspeed.zero.GatheredParameters(model.parameters()):
remove_hooks(model)
yield accelerator.unwrap_model(model)
@@ -222,8 +223,13 @@
yield unwrapped_model


def prepare_deepspeed(model, accelerator):
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
def prepare_deepspeed(model: "Module", accelerator: "Accelerator"):
"""Prepares the model for DeepSpeed inference or evaluation by initializing it with the appropriate configuration.

Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
"""
import deepspeed # local import (instead of top-level) to avoid DS init interfering with other backends (like vllm): https://github.com/deepspeedai/DeepSpeed/issues/7252

deepspeed_plugin = accelerator.state.deepspeed_plugin
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
stage = config_kwargs["zero_optimization"]["stage"]
41 changes: 3 additions & 38 deletions trl/trainer/bco_trainer.py
@@ -19,7 +19,6 @@
import warnings
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from operator import itemgetter
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union

@@ -32,7 +31,7 @@
import transformers
from accelerate import PartialState
from accelerate.logging import get_logger
from accelerate.utils import is_deepspeed_available, tqdm
from accelerate.utils import tqdm
from datasets import Dataset
from packaging import version
from torch.utils.data import DataLoader, SequentialSampler
@@ -56,7 +55,7 @@

from ..data_utils import maybe_apply_chat_template
from ..import_utils import is_joblib_available
from ..models import PreTrainedModelWrapper, create_reference_model
from ..models import create_reference_model, prepare_deepspeed
from .bco_config import BCOConfig
from .utils import (
DPODataCollatorWithPadding,
@@ -83,9 +82,6 @@
if is_joblib_available():
import joblib

if is_deepspeed_available():
import deepspeed

if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer

@@ -712,7 +708,7 @@ def make_inputs_require_grad(module, input, output):
)
else:
if self.is_deepspeed_enabled:
self.ref_model = self._prepare_deepspeed(self.ref_model)
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

@@ -846,37 +842,6 @@ def _get_sample_prompt_embeddings(self, dataset: Dataset, sample_size: int = 512

return all_embeddings

def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)

if model is not None:
if hasattr(model, "config"):
hidden_size = (
max(model.config.hidden_sizes)
if getattr(model.config, "hidden_sizes", None)
else getattr(model.config, "hidden_size", None)
)
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
config_kwargs.update(
{
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
}
)

# If ZeRO-3 is used, we shard both the active and reference model.
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
if config_kwargs["zero_optimization"]["stage"] != 3:
config_kwargs["zero_optimization"]["stage"] = 0
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
model.eval()
return model

def _save_optimizer_and_scheduler(self, output_dir):
output_dir = output_dir if output_dir is not None else self.args.output_dir
super()._save_optimizer_and_scheduler(output_dir)