Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Asynchronous RLHF: Faster and More Efficient Online DPO #2278

Open
wants to merge 11 commits into
base: main
Choose a base branch
from

Conversation

mnoukhov
Copy link
Contributor

@mnoukhov mnoukhov commented Oct 24, 2024

This implements a proposed faster and more efficient paradigm: asynchronous RLHF. See the paper: https://arxiv.org/abs/2410.18252

reasoning

vllm inference is generally faster than hf generate. We want to separate the generation and training so we can use fast generation libraries / utilities

this proposes a simple first solution: run training on n gpus and generation on 1 gpu. This can be extended to run generation on more GPUs but,in practice for >=8 GPU setups with 8B models and less, 1 gpu for generation tends to be fine.

setup

We create an asynchronous trainer for Online DPO that uses vllm for generation. The generation GPU has vllm started on a separate python thread and communication between training and generation is via Queues. The training looks something like this:

  1. the training thread gets the batch of prompts
  2. send data
    a. training thread sends batch of prompts and current model weights to generation
    b. generation thread sends previous prompts with generated completions to training
  3. parallel training and generation
    a. training thread calculates reward, then trains on previous prompts and completions, updates weights
    b. generation thread generates completions to each prompt
  4. back to step 2
    ...
  5. at the end of training the training thread sends None for both prompts and parameters so the generation thread closes itself

Example usage for a 4 GPU setup is

accelerate launch --num_processes 3 examples/scripts/dpo_online_async.py \
    --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft  \
    --reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \
    --dataset_name trl-lib/tldr \
    --learning_rate 5.0e-7 \
    --output_dir pythia-1b-tldr-online-dpo \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 16

the generation GPU is by default the accelerate.num_processes + 1 GPU, so GPUs [1,2,3] are for training and GPU 4 is generation with vllm.

notes

  • online_dpo_trainer currently extends the regular huggingface Trainer and is limited to generating one minibatch of samples and training one step on those samples. As argued in the paper, more training steps on data (num_ppo_epochs) or generating more minibatches (num_mini_batches) can be useful. For this reason, AsyncOnlineDPOTrainer follows the style of RLOOTrainer.

  • to test functionality, I've added SyncOnlineDPOTrainer that has the exact same structure as Async but is synchronous and uses hf generate. I can remove it for the final submission,

Before submitting / To Do

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
    - [ ] Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
  • Did you write any new necessary tests?

I am working on docs. Do we want tests similar to RLOOTrainer ?

Who can review?

@qgallouedec @lewtun anyone else! comments welcome

@seanexp
Copy link
Contributor

seanexp commented Oct 25, 2024

What is the primary difference between this PR and #1628 ?

@mnoukhov
Copy link
Contributor Author

This is an updated and multi-gpu extension of #1628. It is also work between @vwxyzjn and I!

Instead of keeping vllm models on the same GPU, we move them to another. It also uses the more flexible vllm_utils.py written by @vwxyzjn in allenai/open_instruct (https://github.com/allenai/open-instruct/blob/main/open_instruct/vllm_utils.py) which allows using any version of vllm as opposed to the fixed 0.4.2 from #1628.

Finally, this has been tested and verified to match regular Online DPO performance while being faster and more efficient, see our new preprint https://arxiv.org/abs/2410.18252

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines +17 to +46
"""This file basically allows us to place vLLM's driver worker in a specified
GPU. For example. you can try the following.
```python
from transformers import AutoTokenizer
from vllm import SamplingParams
from open_instruct.vllm_utils import SingleGPULLM
tok = AutoTokenizer.from_pretrained("facebook/opt-125m")
tok.chat_template = (
"{% for message in messages %}"
"{{'\n\n' if not loop.first else ''}}"
"{{message['role']|capitalize + ': ' +message['content']}}"
"{% if loop.last and not add_generation_prompt %}{{ eos_token }}{% endif %}"
"{% endfor %}"
)
prompts = [
{"role": "user", "content": "Compose a speech about the need for more affordable dental care."},
]
prompt_ids = tok.apply_chat_template(prompts, add_generation_prompt=True)
sampling_params = SamplingParams(temperature=0.001, top_p=1.0, max_tokens=1024, include_stop_str_in_output=True)
llm = SingleGPULLM(model="facebook/opt-125m", tensor_parallel_size=1, device="cuda:1")
llmp = llm.llm_engine.model_executor.driver_worker.model_runner.model
print(f"🔥🔥🔥 vllm lives in {llmp.lm_head.weight.device}")
print("prepare to generate")
outputs = llm.generate(prompt_token_ids=[prompt_ids], sampling_params=sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mnoukhov

Could you kindly update the docstring? I think SingleGPULLM should be replaced with vllm_single_gpu_patch and LLM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants