
LLama 405B Tweaks #1298

Draft: wants to merge 2 commits into mlperf/5.0
Conversation

@anfals (Collaborator) commented Feb 22, 2025

Just a few changes I made while converting and experimenting with the Llama 3.1 405B checkpoint.

@@ -320,12 +321,20 @@ def load_params_from_path(load_parameters_from_path, abstract_unboxed_params):
assert load_parameters_from_path, "load_parameters_from_path is not defined."
max_logging.log(f"restoring params from {load_parameters_from_path}")
ckpt = epath.Path(load_parameters_from_path)
ckptr = ocp.PyTreeCheckpointer()
# ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler(restore_concurrent_gb=500, save_concurrent_gb=500))
Review comment from @anfals (Collaborator, Author):

You might need something like this if you see an error like ValueError: Requested more bytes than we reserved space for: 109924319232 > 96000000000; in that case you need to increase the restore limit.

The sharded checkpoint I linked to in our internal docs should remove the need for this, though.
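
For reference, a minimal sketch of that workaround, based on the commented-out line above (the 500 GB limits are illustrative; size them to your checkpoint):

```python
import orbax.checkpoint as ocp

# Raise Orbax's per-operation byte limits so restoring the very large 405B
# parameter tree does not trip the reserved-space check. The 500 GB values
# are illustrative, not a recommendation.
ckptr = ocp.Checkpointer(
    ocp.PyTreeCheckpointHandler(
        restore_concurrent_gb=500,
        save_concurrent_gb=500,
    )
)
```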

# This is a memory optimization. We don't want to restore the entire checkpoint - only the params.
# Rather than pass the entire abstract state, which could unnecessarily restore opt_state and such and waste
# memory, we instead specify here that we are just restoring the params field of the checkpoint
# (which itself may be a dictionary containing a key named 'params').
restore_args = ocp.checkpoint_utils.construct_restore_args(abstract_unboxed_params)
def update_restore_args(restore_args):
Review comment from @anfals (Collaborator, Author):

This is Orbax's way of allowing truncation of parameters when restoring a checkpoint. It is needed because NV used only a 32K vocab size, as opposed to the 128K one; the truncation Orbax performs when strict mode is disabled is effectively A[:32000, :].

I made the entire PyTree non-strict, but you could probably limit this to just the tokenizer embedding layer.
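
A minimal sketch of what update_restore_args could do, assuming each leaf of restore_args is an ocp.ArrayRestoreArgs dataclass with a strict field; mapping over every leaf mirrors the "entire PyTree non-strict" approach described above:

```python
import dataclasses
import jax
import orbax.checkpoint as ocp


def update_restore_args(restore_args):
  """Disable strict shape checking on every leaf of the restore-args tree.

  With strict=False, Orbax truncates (or pads) each restored array to the
  target shape, e.g. a 128K-vocab embedding is restored as
  embedding[:32000, :] when the target model only has a 32K vocab.
  """
  def _relax(args):
    if isinstance(args, ocp.ArrayRestoreArgs):
      return dataclasses.replace(args, strict=False)
    return args

  return jax.tree_util.tree_map(_relax, restore_args)
```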

cosine_learning_rate_final_fraction: 0.1
- warmup_steps_fraction: 0.1
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
+ warmup_steps_fraction: 0.0067
Review comment from @anfals (Collaborator, Author):

This value is just gibberish; we should compute it dynamically in pyconfig.
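
A hypothetical sketch of computing it in pyconfig instead of hard-coding 0.0067 (the desired_warmup_steps knob and this derivation are assumptions for illustration, not part of this PR):

```python
def derive_warmup_steps_fraction(desired_warmup_steps, learning_rate_schedule_steps, steps):
  # Hypothetical helper: express warmup as a fraction of the schedule length
  # rather than hard-coding a magic number like 0.0067.
  schedule_steps = steps if learning_rate_schedule_steps == -1 else learning_rate_schedule_steps
  return desired_warmup_steps / schedule_steps


# e.g. 40 warmup steps over a 6000-step schedule is roughly 0.0067
print(derive_warmup_steps_fraction(40, -1, 6000))
```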

@@ -203,7 +203,7 @@ def save_checkpoint(
) -> bool:
"""Wrapper for saving checkpoint."""
if config and config.enable_checkpointing:
-    if (step % config.checkpoint_period == 0) or (
+    if (step % config.checkpoint_period == 0 and step != 0) or (
Review comment from @anfals (Collaborator, Author):

We should make sure NOT to save a checkpoint during our real runs. GPT-3 had some hackery in its checkpointing to start from a non-zero checkpoint.

Also set checkpoint_period to something huge, like 10000.
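
For illustration, a sketch of how the changed condition behaves together with a huge checkpoint_period (a standalone helper, not the actual save_checkpoint wrapper):

```python
def should_save(step, checkpoint_period):
  # Mirrors the new guard: never save at step 0, only at period boundaries.
  return step % checkpoint_period == 0 and step != 0


assert not should_save(0, 10000)      # step 0 no longer triggers a save
assert not should_save(500, 10000)    # short benchmark runs never save
assert should_save(10000, 10000)      # only at the (huge) period boundary
```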

@anfals anfals changed the base branch from main to mlperf/5.0 February 22, 2025 00:41
@anfals anfals marked this pull request as draft February 22, 2025 00:41