Llama 405B Tweaks #1298
base: mlperf/5.0
Conversation
@@ -320,12 +321,20 @@ def load_params_from_path(load_parameters_from_path, abstract_unboxed_params):
  assert load_parameters_from_path, "load_parameters_from_path is not defined."
  max_logging.log(f"restoring params from {load_parameters_from_path}")
  ckpt = epath.Path(load_parameters_from_path)
  ckptr = ocp.PyTreeCheckpointer()
  # ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler(restore_concurrent_gb=500, save_concurrent_gb=500))
You might need something like this if you see an error like `ValueError: Requested more bytes than we reserved space for: 109924319232 > 96000000000`; in that case you need to increase the restore limit.
The sharded ckpt I linked to in our internal docs should remove the need for this, though.
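A minimal sketch of that swap, based on the commented-out line above (the 500 GB budgets come from that line and are an assumption to tune for your checkpoint size):

```python
import orbax.checkpoint as ocp

# Default: fine as long as the restore fits in Orbax's default concurrency budget.
ckptr = ocp.PyTreeCheckpointer()

# If restore fails with "Requested more bytes than we reserved space for",
# build the handler explicitly with larger concurrent restore/save budgets (in GB).
ckptr = ocp.Checkpointer(
    ocp.PyTreeCheckpointHandler(restore_concurrent_gb=500, save_concurrent_gb=500)
)
```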
  # This is a memory optimization. We don't want to restore the entire checkpoint - only the params.
  # Rather than pass the entire abstract state, which could unnecessarily restore opt_state and such and waste
  # memory, we instead specify here that we are just restoring the params field of the checkpoint
  # (which itself may be a dictionary containing a key named 'params').
  restore_args = ocp.checkpoint_utils.construct_restore_args(abstract_unboxed_params)
  def update_restore_args(restore_args):
This is Orbax's way of allowing truncation of the parameters when restoring a ckpt. It is needed because NV used only a 32K vocab size, as opposed to the 128K one. The truncation that Orbax does when strict mode is disabled is like `A[:32000, :]`.
I made the entire PyTree non-strict, but you could probably restrict this to just the tokenizer embedding layer.
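A sketch of what `update_restore_args` could do under that approach; it assumes Orbax's `ArrayRestoreArgs` exposes a `strict` flag and is not the PR's exact implementation:

```python
import dataclasses

import jax
import orbax.checkpoint as ocp


def update_restore_args(restore_args):
  """Mark every ArrayRestoreArgs as non-strict so Orbax may truncate (or pad)
  arrays to the target shape, e.g. a 128K-vocab embedding down to 32K rows."""

  def _relax(arg):
    if isinstance(arg, ocp.ArrayRestoreArgs):
      return dataclasses.replace(arg, strict=False)  # non-strict => allow truncation
    return arg

  # RestoreArgs objects are the leaves of this tree, so map over them directly.
  return jax.tree_util.tree_map(
      _relax, restore_args, is_leaf=lambda x: isinstance(x, ocp.RestoreArgs)
  )
```

Scoping this to just the token-embedding entry, as suggested, would mean relaxing only that leaf instead of the whole tree.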
cosine_learning_rate_final_fraction: 0.1
warmup_steps_fraction: 0.1
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
warmup_steps_fraction: 0.0067
This is just gibberish. We should compute this dynamically in pyconfig.
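A sketch of what computing it dynamically in pyconfig could look like; the helper and the absolute `warmup_steps` knob are hypothetical, and the numbers in the example are illustrative only:

```python
def derive_warmup_steps_fraction(warmup_steps: int,
                                 learning_rate_schedule_steps: int,
                                 steps: int) -> float:
  """Derive the warmup fraction from an absolute warmup-step count instead of
  hard-coding a magic number like 0.0067."""
  # learning_rate_schedule_steps == -1 means "schedule length = total training steps".
  schedule_steps = learning_rate_schedule_steps if learning_rate_schedule_steps > 0 else steps
  return warmup_steps / schedule_steps


# e.g. 8 warmup steps over a 1200-step schedule gives roughly 0.0067
print(derive_warmup_steps_fraction(8, -1, 1200))
```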
@@ -203,7 +203,7 @@ def save_checkpoint(
) -> bool:
  """Wrapper for saving checkpoint."""
  if config and config.enable_checkpointing:
-    if (step % config.checkpoint_period == 0) or (
+    if (step % config.checkpoint_period == 0 and step != 0) or (
We should make sure NOT to save a checkpoint during our real runs. GPT-3 had some hackery in its checkpointing to start at a non-zero checkpoint.
Also set `checkpoint_period` to something huge, like 10000, as well.
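A standalone sketch of the resulting behaviour (not the MaxText function itself), showing why the `step != 0` guard plus a huge period means no saves during a timed run:

```python
def should_save(step: int, checkpoint_period: int) -> bool:
  # Mirrors the new condition: skip step 0, otherwise save every `checkpoint_period` steps.
  return step % checkpoint_period == 0 and step != 0


assert not should_save(0, 10000)                                # no save at step 0
assert not any(should_save(s, 10000) for s in range(1, 10000))  # no saves mid-run
```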
Just a few changes I made while converting and experimenting with the Llama 3.1 405B checkpoint.