
GRPO: Scalable training with one LLM/node #3186


Open · wants to merge 16 commits into base: main
47 changes: 24 additions & 23 deletions docs/source/grpo_trainer.md
@@ -163,39 +163,39 @@ For more information, see [Speeding up training with vLLM](speeding_up_training#

When training large models like **Qwen2.5-72B**, you need several key optimizations to make the training efficient and scalable across multiple GPUs and nodes. These include:

- **DeepSpeed ZeRO Stage 3**: ZeRO leverages data parallelism to distribute model states (weights, gradients, optimizer states) across multiple GPUs and CPUs, reducing memory and compute requirements on each device. Since large models cannot fit on a single GPU, using ZeRO Stage 3 is required for training such models. For more details, see [DeepSpeed Integration](deepspeed_integration).
- **DeepSpeed ZeRO Stage 3** or **Fully Sharded Data Parallel (FSDP)**: Both leverage data parallelism to shard model states (weights, gradients, optimizer states) across multiple GPUs and CPUs, reducing the memory and compute required on each device. Since large models cannot fit on a single GPU, ZeRO Stage 3 or FSDP is required to train such models. For more details, see [DeepSpeed Integration](deepspeed_integration).
- **Accelerate**: Accelerate is a library that simplifies distributed training across multiple GPUs and nodes. It provides a simple API to launch distributed training and handles the complexities of distributed training, such as data parallelism, gradient accumulation, and distributed data loading. For more details, see [Distributing Training](distributing_training).
- **vLLM**: See the previous section on how to use vLLM to speed up generation.
- **vLLM**: See the previous section on how to use vLLM to speed up generation. For scalable generation, deploy **one vLLM process per node**.

Below is an example SLURM script to train a 70B model with GRPO on multiple nodes. This script trains a model on 4 nodes and uses the 5th node for vLLM-powered generation.
Below is an example SLURM script to train a 70B model with GRPO on multiple nodes. This script trains on 4 nodes, using the first 7 GPUs of each node for training and the 8th GPU of each node for vLLM-powered generation.

```sh
#!/bin/bash
#SBATCH --nodes=5
#SBATCH --nodes=4
#SBATCH --gres=gpu:8
#SBATCH --tasks-per-node=1

# Get the list of allocated nodes
NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
export JOB_MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)

# Assign the first 4 nodes for training and the 5th node for vLLM
TRAIN_NODES="${NODELIST[@]:0:4}" # Nodes 0, 1, 2, 3 for training
VLLM_NODE="${NODELIST[4]}" # Node 4 for vLLM
# Assign the first 7 GPUs on every node to training, and the 8th GPU to inference
cmd=$(tr -d "\n" << EOF
accelerate launch
--mixed_precision=bf16
--num_machines=${SLURM_NNODES}
--num_processes=$((${SLURM_NNODES}*7))
--machine_rank=\${SLURM_NODEID}
--main_process_ip=${JOB_MASTER_ADDR}
--main_process_port=29500
--gpu_ids=0,1,2,3,4,5,6
train_grpo.py --vllm_server_host=127.0.0.1 <other script args> &

# Run training on the first 4 nodes (Group 1)
srun --nodes=4 --ntasks=4 --nodelist="${NODELIST[@]:0:4}" accelerate launch \
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
--num_processes 32 \
--num_machines 4 \
--main_process_ip ${NODELIST[0]} \
--machine_rank $SLURM_PROCID \
--rdzv_backend c10d \
train_grpo.py \
--server_ip $VLLM_NODE &

# Run vLLM server on the 5th node (Group 2)
srun --nodes=1 --ntasks=1 --nodelist="${NODELIST[4]}" trl vllm-serve --model Qwen/Qwen2.5-72B --tensor_parallel_size 8 &
VLLM_PORT=29501 CUDA_VISIBLE_DEVICES=7 trl vllm-serve --model ${MODEL} --host=127.0.0.1 &

wait
EOF
)

srun bash -c "${cmd}"
```

```python
@@ -217,13 +217,14 @@ def main():
return [len(set(c)) for c in completions]

training_args = GRPOConfig(
# <add your fsdp or deepspeed config here>
output_dir="Qwen2.5-72B-GRPO",
per_device_train_batch_size=4,
bf16=True,
gradient_checkpointing=True,
logging_steps=10,
use_vllm=True,
vllm_server_host=args.vllm_server_host.replace("ip-", "").replace("-", "."), # from ip-X-X-X-X to X.X.X.X
vllm_server_host=args.vllm_server_host,
)

trainer = GRPOTrainer(model="Qwen/Qwen2.5-72B", args=training_args, reward_funcs=reward_num_unique_chars, train_dataset=dataset)
3 changes: 1 addition & 2 deletions trl/extras/vllm_client.py
@@ -205,8 +205,7 @@ def init_communicator(self):

# Initialize weight update group
url = f"http://{self.host}:{self.server_port}/init_communicator/"
# In the server side, the host is set to 0.0.0.0
response = self.session.post(url, json={"host": "0.0.0.0", "port": self.group_port, "world_size": world_size})
response = self.session.post(url, json={"host": self.host, "port": self.group_port, "world_size": world_size})
if response.status_code != 200:
raise Exception(f"Request failed: {response.status_code}, {response.text}")

16 changes: 13 additions & 3 deletions trl/scripts/vllm_serve.py
@@ -25,7 +25,6 @@
from trl import TrlParser
from trl.import_utils import is_fastapi_available, is_pydantic_available, is_uvicorn_available, is_vllm_available


if is_fastapi_available():
from fastapi import BackgroundTasks, FastAPI

@@ -176,6 +175,8 @@ class ScriptArguments:
enable_prefix_caching (`bool` or `None`, *optional*, defaults to `None`):
Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the hardware support
this feature.
timeout_keep_alive (`int`, *optional*, defaults to `60`):
    Number of seconds to keep an idle HTTP connection open when no data is being exchanged.
"""

model: str = field(metadata={"help": "Model name or path to load the model from."})
@@ -226,6 +227,12 @@ class ScriptArguments:
"hardware support this feature."
},
)
timeout_keep_alive: int = field(
    default=60,
    metadata={"help": "Number of seconds to keep an idle HTTP connection open when no data is exchanged."},
)


def main(script_args: ScriptArguments):
@@ -342,7 +349,7 @@ async def generate(request: GenerateRequest):
max_tokens=request.max_tokens,
guided_decoding=guided_decoding,
)
all_outputs = llm.generate(request.prompts, sampling_params=sampling_params)
all_outputs = llm.generate(request.prompts, sampling_params=sampling_params, use_tqdm=False)
completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs]
return {"completion_ids": completion_ids}

@@ -416,7 +423,10 @@ async def close_communicator():
return {"message": "Request received, closing communicator"}

# Start the server
uvicorn.run(app, host=script_args.host, port=script_args.port)
uvicorn.run(
    app,
    host=script_args.host,
    port=script_args.port,
    timeout_keep_alive=script_args.timeout_keep_alive,
)

dist.destroy_process_group()
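As a minimal illustration of what the new option controls (a toy app, not the TRL server): uvicorn drops idle keep-alive connections after `timeout_keep_alive` seconds, so a trainer that pauses between requests for long stretches would otherwise have to re-open the HTTP connection.

```python
# Toy example (not the TRL server) showing how timeout_keep_alive is passed to
# uvicorn: idle HTTP connections are closed after this many seconds.
from fastapi import FastAPI
import uvicorn

app = FastAPI()

@app.get("/health/")
async def health():
    return {"status": "ok"}

if __name__ == "__main__":
    # 60 mirrors the new --timeout_keep_alive default
    uvicorn.run(app, host="0.0.0.0", port=8000, timeout_keep_alive=60)
```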

121 changes: 96 additions & 25 deletions trl/trainer/grpo_trainer.py
@@ -40,6 +40,8 @@
is_wandb_available,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers.utils import is_liger_kernel_available, is_peft_available

from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
@@ -57,7 +59,6 @@
selective_log_softmax,
)


if is_deepspeed_available():
import deepspeed

@@ -519,7 +520,7 @@ def data_collator(features): # No data collation is needed in GRPO
"`pip install vllm` to use it."
)

if self.accelerator.is_main_process:
if self.accelerator.is_local_main_process:
self.vllm_client = VLLMClient(
args.vllm_server_host, args.vllm_server_port, connection_timeout=args.vllm_server_timeout
)
@@ -569,6 +570,17 @@ def data_collator(features): # No data collation is needed in GRPO
if isinstance(reward_func, PreTrainedModel):
self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)

# Compute the local world size (processes per node): each local main process contributes
# a single 1, so the number of gathered ones equals the number of nodes.
if self.accelerator.is_local_main_process:
one = [1]
else:
one = []
ones = gather_object(one)
self.num_local_processes = self.accelerator.num_processes // sum(ones)

# Create the intra-node communicator (one subgroup per node)
self.intra_node_group, _ = torch.distributed.new_subgroups(self.num_local_processes)
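For intuition, a minimal sketch of how `new_subgroups` carves the global ranks into per-node groups, assuming an already-initialized (NCCL) process group and the same number of processes on every node: consecutive blocks of `group_size` ranks each get their own communicator, and the call returns the subgroup containing the current rank.

```python
# Minimal sketch, assuming torch.distributed is already initialized and all
# nodes run the same number of processes.
import torch.distributed as dist

def build_intra_node_group(num_local_processes: int):
    # Ranks [0..k-1], [k..2k-1], ... each form one subgroup (k = num_local_processes);
    # the first return value is the subgroup the current rank belongs to.
    intra_node_group, _all_subgroups = dist.new_subgroups(group_size=num_local_processes)
    return intra_node_group
```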

def _set_signature_columns_if_needed(self):
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
# By default, this method sets `self._signature_columns` to the model's expected inputs.
@@ -680,6 +692,41 @@ def _move_model_to_vllm(self):
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
gather_if_zero3 = deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext

def post_order_fsdp_processing(module: nn.Module, prefix: str = "", visited=None):
""" memory-efficient module gather """
extra_names = ['_fsdp_wrapped_module.', '_checkpoint_wrapped_module.']
if visited is None:
visited = set()

for child_name, child_module in module.named_children():
if prefix == "":
child_prefix = child_name
else:
child_prefix = f"{prefix}.{child_name}"

# Recurse into the child
post_order_fsdp_processing(child_module, prefix=child_prefix, visited=visited)

if isinstance(module, FSDP):
with FSDP.summon_full_params(module, recurse=False, writeback=False):
for param_name, param in module.named_parameters():
if prefix == "":
full_name = param_name
else:
full_name = f"{prefix}.{param_name}"

for extra in extra_names:
full_name = full_name.replace(extra, '')

if full_name in visited:
# skip FSDP subtrees already traversed
continue

visited.add(full_name)

if self.accelerator.is_local_main_process:
self.vllm_client.update_named_param(full_name, param.data)

if is_peft_model(self.model):
# With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
# adapters in a sharded manner is not supported.
@@ -697,21 +744,24 @@ def _move_model_to_vllm(self):
continue
name = name.replace("modules_to_save.default.", "")

if self.accelerator.is_main_process:
if self.accelerator.is_local_main_process:
self.vllm_client.update_named_param(name, param.data)

# Unmerge adapters while parameters are still gathered
self.model.unmerge_adapter()
# Parameters will automatically be repartitioned when exiting the context
else:
# For non-PEFT models, simply gather and update each parameter individually.
for name, param in self.model.named_parameters():
with gather_if_zero3([param]):
if self.accelerator.is_main_process:
self.vllm_client.update_named_param(name, param.data)
if self.is_fsdp_enabled:
post_order_fsdp_processing(self.model_wrapped)
else:
for name, param in self.model.named_parameters():
with gather_if_zero3([param]):
if self.accelerator.is_local_main_process:
self.vllm_client.update_named_param(name, param.data)

# Reset cache on main process
if self.accelerator.is_main_process:
if self.accelerator.is_local_main_process:
self.vllm_client.reset_prefix_cache()

@profiling_decorator
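To illustrate the primitive that `post_order_fsdp_processing` builds on (a toy helper, not TRL code): `FSDP.summon_full_params` with `recurse=False` temporarily materializes only the parameters owned by a single FSDP unit, which is what keeps the post-order walk above memory-efficient.

```python
# Toy helper (not TRL code): gather one FSDP unit's parameters at a time and
# inspect their full shapes. writeback=False discards any in-context changes
# when the shards are re-formed.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def print_full_shapes(fsdp_unit: FSDP) -> None:
    with FSDP.summon_full_params(fsdp_unit, recurse=False, writeback=False):
        for name, param in fsdp_unit.named_parameters():
            print(name, tuple(param.shape))
```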
@@ -748,24 +798,40 @@ def _generate_and_score_completions(
prompt_ids = prompt_ids[:, -self.max_prompt_length :]
prompt_mask = prompt_mask[:, -self.max_prompt_length :]

def gather_object_on_main_process(obj: Any, group=torch.distributed.group.WORLD):
output_objects = [None for _ in range(torch.distributed.get_world_size(group))]
is_main_process = torch.distributed.get_rank(group) == 0
torch.distributed.gather_object(obj,
output_objects if is_main_process else None,
group_dst=0, group=group)
# gather_object returns a list of lists, so we need to flatten it
return [x for y in output_objects for x in y] if is_main_process else None

def scatter_objects_from_main_process(obj_list: list[Any], group=torch.distributed.group.WORLD):
output_objects = [None]
torch.distributed.scatter_object_list(output_objects, obj_list, group_src=0, group=group)
return output_objects[0]

# Generate completions using either vLLM or regular generation
if self.args.use_vllm:
# First, have main process load weights if needed
if self.state.global_step != self._last_loaded_step:
self._move_model_to_vllm()
self._last_loaded_step = self.state.global_step

# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process:
# Generate completions using vLLM: gather all prompts of the current node
# and use them in a single call
all_prompts_text = gather_object_on_main_process(prompts_text, self.intra_node_group)

if self.accelerator.is_local_main_process:
# 'prompts' contains 'num_generations' duplicates of each prompt. Deduplicating and generating
# 'num_generations' outputs per unique prompt would be faster, but that optimization is not yet
# implemented for per-node (local main process) generation, so all prompts are passed as-is.
# TODO: restore the unique-prompt optimization for local main processes.

with profiling_context(self, "vLLM.generate"):
completion_ids = self.vllm_client.generate(
prompts=ordered_set_of_prompts,
n=self.num_generations,
all_prompts_text,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
Expand All @@ -774,16 +840,21 @@ def _generate_and_score_completions(
max_tokens=self.max_completion_length,
guided_decoding_regex=self.guided_decoding_regex,
)

# Ensure each process receives its corresponding slice.
def chunk_list(my_list, chunk_size):
chunks = []
for i in range(0, len(my_list), chunk_size):
chunks.append(my_list[i:i + chunk_size])
return chunks

if self.accelerator.is_local_main_process:
process_slices = chunk_list(completion_ids, len(prompts))
else:
completion_ids = [None] * len(all_prompts_text)
# Broadcast the completions from the main process to all processes, ensuring each process receives its
# corresponding slice.
completion_ids = broadcast_object_list(completion_ids, from_process=0)
process_slice = slice(
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),
)
completion_ids = completion_ids[process_slice]
process_slices = [None] * self.num_local_processes

# Scatter each slice back to its originating process within the node
completion_ids = scatter_objects_from_main_process(process_slices, self.intra_node_group)

# Pad the completions, and concatenate them with the prompts
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
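A self-contained toy of the gather/generate/scatter flow above (no vLLM involved; assumes an initialized process group, equally sized inputs on every rank, and a PyTorch version that supports the `group_dst`/`group_src` arguments used in the diff): items are gathered to the group's rank 0, processed in one batch, split into equal slices, and scattered so each rank gets back exactly the results for its own inputs.

```python
# Toy fan-in/fan-out over an intra-node group (no vLLM); assumes an initialized
# process group and the same number of items on every rank.
import torch.distributed as dist

def fan_in_fan_out(items: list[str], group) -> list[str]:
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)

    # Fan-in: gather every rank's items on the group's rank 0.
    gathered = [None] * world_size if rank == 0 else None
    dist.gather_object(items, gathered, group_dst=0, group=group)

    if rank == 0:
        flat = [x for per_rank in gathered for x in per_rank]
        results = [x.upper() for x in flat]  # stand-in for the single vLLM generate call
        # Split the results into one equally sized slice per rank.
        slices = [results[i : i + len(items)] for i in range(0, len(results), len(items))]
    else:
        slices = [None] * world_size

    # Fan-out: each rank receives the slice matching its own items.
    out = [None]
    dist.scatter_object_list(out, slices, group_src=0, group=group)
    return out[0]
```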
@@ -960,8 +1031,8 @@ def _generate_and_score_completions(
self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())

# Log prompt and completion texts
self._textual_logs["prompt"].extend(gather_object(prompts_text))
self._textual_logs["completion"].extend(gather_object(completions_text))
self._textual_logs["prompt"].extend(gather_object_on_main_process(prompts_text))
self._textual_logs["completion"].extend(gather_object_on_main_process(completions_text))
for i, name in enumerate(reward_func_names):
self._textual_logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
