Support multi-gpus training with accelerate #778
base: main
Conversation
@bot /style
Style fixes have been applied. View the workflow run here.
First round with comments, nice addition ;)
It'd be nice to add a quick tutorial in examples/ on how to do multi-gpu training.
Also, could you add accelerate as an extra to pyproject.toml?
# pyproject.toml
[project.optional-dependencies]
...
+ accelerate = [
+     "accelerate>=1.4.0",
+ ]
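With that extra in place, the multi-GPU dependencies could then be pulled in with a single install command; a sketch, assuming a source checkout of the repo:

    pip install -e ".[accelerate]"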
    optimizer.step()
else:
    grad_scaler.scale(loss).backward()
    # Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**.
Suggested change:
- # Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**.
+ # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
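For context, the comment being fixed describes the standard PyTorch AMP ordering: gradients must be unscaled before clipping so the clip threshold applies to true gradient magnitudes. A minimal sketch of that pattern, with `loss`, `optimizer`, `policy`, and `max_norm` standing in for the training loop's actual variables:

    import torch

    grad_scaler = torch.cuda.amp.GradScaler()

    grad_scaler.scale(loss).backward()
    # Unscale the gradients in-place *before* clipping, so clipping sees
    # the true gradient magnitudes rather than the scaled ones.
    grad_scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(policy.parameters(), max_norm)
    grad_scaler.step(optimizer)  # skips the step if inf/NaN gradients were found
    grad_scaler.update()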
if accelerator:
    if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
        accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
else:
    if has_method(policy, "update"):
        # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
        policy.update()
Nit

Suggested change:
- if accelerator:
-     if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
-         accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
- else:
-     if has_method(policy, "update"):
-         # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
-         policy.update()
+ if accelerator and has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
+     accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
+ elif has_method(policy, "update"):
+     # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
+     policy.update()
if accelerator and not accelerator.is_main_process:
    # Disable logging on non-main processes.
    cfg.wandb.enable = False
I'm not sure this works as intended, are the reported metrics correct?
We should probably integrate accelerate's WandBTracker inside our WandBLogger instead.
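For reference, accelerate's built-in tracking only logs from the main process by default, which would make the manual wandb disabling above unnecessary. A minimal sketch of that approach (the project name and logged metrics are placeholders, not the PR's actual wiring):

    from accelerate import Accelerator

    accelerator = Accelerator(log_with="wandb")
    # init_trackers and log are no-ops on non-main processes by default.
    accelerator.init_trackers(project_name="lerobot", config=vars(cfg))  # placeholder project/config

    for step in range(num_steps):
        ...
        accelerator.log({"train/loss": loss.item()}, step=step)

    accelerator.end_training()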
if accelerator:
    accelerator.wait_for_everyone()
I think this should be added before checkpointing as well. Also it should probably go inside the if is_saving_step or if is_eval_step statements, otherwise this will be blocking at each step.
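A sketch of the placement being suggested, using the surrounding loop's variables (`cfg`, `is_saving_step`, and `is_eval_step` come from the training script):

    # Synchronize processes only on steps that actually save or evaluate,
    # instead of paying for a barrier on every training step.
    if cfg.save_checkpoint and is_saving_step:
        if accelerator:
            accelerator.wait_for_everyone()
        ...  # checkpointing, as in the diff below

    if is_eval_step:
        if accelerator:
            accelerator.wait_for_everyone()
        ...  # evaluation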
if cfg.save_checkpoint and is_saving_step:
    logging.info(f"Checkpoint policy after step {step}")
    checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
-   save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
+   save_checkpoint(
+       checkpoint_dir,
+       step,
+       cfg,
+       policy if not accelerator else accelerator.unwrap_model(policy),
+       optimizer,
+       lr_scheduler,
+   )
I don't know if this is fully equivalent to Accelerator.save_state. In particular, I don't think this still works with the training state (optimizer, scheduler, rng, etc.)
Pointer: https://huggingface.co/docs/accelerate/v1.4.0/en/usage_guides/checkpoint
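For comparison, the linked guide checkpoints the full training state through accelerate itself; a minimal sketch, assuming the objects have been passed through accelerator.prepare:

    from accelerate import Accelerator

    accelerator = Accelerator()
    policy, optimizer, lr_scheduler = accelerator.prepare(policy, optimizer, lr_scheduler)

    # save_state captures the model, optimizer, scheduler, and RNG states
    # in one call; load_state restores all of them when resuming.
    accelerator.save_state(checkpoint_dir)
    ...
    accelerator.load_state(checkpoint_dir)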
What this does
This PR supports training on multiple GPUs using the accelerate library.

How it was tested
Launching training on aloha sim with multiple GPUs and obtaining similar scores.

Examples:
This requires installing accelerate:
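A hedged example of the setup (the training script path and flags below are placeholders, not the PR's exact command):

    pip install accelerate

    # Launch on e.g. 2 GPUs; replace the script path and arguments with
    # the actual training entry point and config.
    accelerate launch --num_processes 2 lerobot/scripts/train.py ...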