
Support multi-GPU training with accelerate #778

Status: Open. mshukor wants to merge 11 commits into base: main.

Conversation

@mshukor (Collaborator) commented Feb 26, 2025

What this does

This PR adds support for training on multiple GPUs using the accelerate library.

How it was tested

Launched training on aloha sim with multiple GPUs and obtained scores similar to single-GPU training.

Examples:
This requires installing accelerate:

pip install accelerate

POLICY=act
ENV=aloha
TASK=AlohaTransferCube-v0
REPO_ID=lerobot/aloha_sim_transfer_cube_human
DATASET_NAME=aloha_sim_transfer_cube_human

GPUS=2
PORT=29502
OFFLINE_STEPS=100000
EVAL_FREQ=1000
BATCH_SIZE=8
EVAL_BATCH_SIZE=10
SAVE_FREQ=10000

# GPUS must be set before it is interpolated into TASK_NAME.
TASK_NAME=lerobot_${DATASET_NAME}_${POLICY}_gpus${GPUS}
TRAIN_DIR=$WORK/logs/lerobot/$TASK_NAME
echo $TRAIN_DIR

export MUJOCO_GL=egl

python -m accelerate.commands.launch --num_processes=$GPUS --mixed_precision=fp16 --main_process_port=$PORT lerobot/scripts/train.py \
    --policy.type=$POLICY \
    --dataset.repo_id=$REPO_ID \
    --env.type=$ENV \
    --env.task=$TASK \
    --output_dir=$TRAIN_DIR \
    --batch_size=$BATCH_SIZE \
    --steps=$OFFLINE_STEPS \
    --eval_freq=$EVAL_FREQ \
    --save_freq=$SAVE_FREQ \
    --eval.batch_size=$EVAL_BATCH_SIZE \
    --eval.n_episodes=$EVAL_BATCH_SIZE
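Note: python -m accelerate.commands.launch is the module form of the accelerate launch CLI entry point; with accelerate installed, running accelerate launch with the same arguments should behave identically.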

@mshukor mshukor marked this pull request as ready for review February 27, 2025 08:28
@mshukor mshukor requested review from aliberts and Cadene and removed request for aliberts February 27, 2025 08:28
@qgallouedec (Member) commented:

@bot /style

Bot reply: Style fixes have been applied. View the workflow run here.

@huggingface huggingface deleted a comment from qgallouedec Feb 27, 2025
@aliberts (Collaborator) left a comment:

First round with comments, nice addition ;)

It'd be nice to add a quick tutorial in examples/ on how to do multi-GPU training.
Also, could you add accelerate as an extra to pyproject.toml?

# pyproject.toml

[project.optional-dependencies]
...
+ accelerate = [
+     "accelerate>=1.4.0",
+ ]
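With that extra in place, the dependency could then be pulled in via pip install "lerobot[accelerate]" (assuming the distribution keeps the name lerobot).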

        optimizer.step()
    else:
        grad_scaler.scale(loss).backward()
        # Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**.
Suggested change:
-        # Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**.
+        # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
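For context, this comment documents the standard torch.amp recipe of unscaling gradients before clipping. A minimal sketch, where dataloader, policy, optimizer, and grad_clip_norm are assumed to exist and this is not the PR's exact code:

import torch

scaler = torch.cuda.amp.GradScaler()

for batch in dataloader:
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = policy.forward(batch)["loss"]  # assumed loss interface
    scaler.scale(loss).backward()
    # Unscale in-place so that clip_grad_norm_ sees the true gradient magnitudes.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(policy.parameters(), max_norm=grad_clip_norm)
    # scaler.step() skips the optimizer update if gradients contain inf/NaN.
    scaler.step(optimizer)
    scaler.update()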

Comment on lines +109 to +115:

    if accelerator:
        if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
            accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
    else:
        if has_method(policy, "update"):
            # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
            policy.update()
Nit

Suggested change:
-    if accelerator:
-        if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
-            accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
-    else:
-        if has_method(policy, "update"):
-            # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
-            policy.update()
+    if accelerator and has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
+        accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
+    elif has_method(policy, "update"):
+        # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
+        policy.update()

Comment on lines +129 to +131:

    if accelerator and not accelerator.is_main_process:
        # Disable logging on non-main processes.
        cfg.wandb.enable = False

I'm not sure this works as intended; are the reported metrics correct?
We should probably integrate accelerate's WandBTracker into our WandBLogger instead.
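For reference, accelerate's built-in tracker wiring looks roughly like this (a sketch of the public accelerate API; the project name and logged values are made up):

from accelerate import Accelerator

accelerator = Accelerator(log_with="wandb")
accelerator.init_trackers(project_name="lerobot", config={"lr": 1e-4})
# accelerate logs only from the main process by default.
accelerator.log({"train/loss": 0.42}, step=100)
accelerator.end_training()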

Comment on lines +296 to +297:

    if accelerator:
        accelerator.wait_for_everyone()
@aliberts commented on Feb 27, 2025:

I think this should be added before checkpointing as well. Also it should probably go inside the if is_saving_step or if is_eval_step statements, otherwise this will be blocking at each step.
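Concretely, the suggested placement would look something like this (a sketch built from the surrounding train.py names; the exact position is the point under discussion):

if cfg.save_checkpoint and is_saving_step:
    if accelerator:
        # Synchronize all processes before the main process writes the checkpoint.
        accelerator.wait_for_everyone()
    save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)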

Comment on lines 281 to +291:

  if cfg.save_checkpoint and is_saving_step:
      logging.info(f"Checkpoint policy after step {step}")
      checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
-     save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
+     save_checkpoint(
+         checkpoint_dir,
+         step,
+         cfg,
+         policy if not accelerator else accelerator.unwrap_model(policy),
+         optimizer,
+         lr_scheduler,
+     )
I don't know if this is fully equivalent to Accelerator.save_state. In particular, I don't think this still works with the training state (optimizer, scheduler, RNG, etc.).

Pointer: https://huggingface.co/docs/accelerate/v1.4.0/en/usage_guides/checkpoint
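From the linked guide, the accelerate-native way to handle the full training state looks roughly like this (a sketch; policy, optimizer, lr_scheduler, and dataloader are assumed to already exist):

from accelerate import Accelerator

accelerator = Accelerator()
policy, optimizer, lr_scheduler, dataloader = accelerator.prepare(
    policy, optimizer, lr_scheduler, dataloader
)

# save_state captures model, optimizer, scheduler, and RNG states together.
accelerator.save_state("outputs/checkpoints/step_001000")

# load_state restores everything that save_state wrote.
accelerator.load_state("outputs/checkpoints/step_001000")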
