LLama 405B Tweaks #1298
base: mlperf/5.0
@@ -67,7 +67,7 @@ def create_orbax_checkpoint_manager(
   p.mkdir(exist_ok=True, parents=True)
   # we need to use ocdbt and zarr3 to control max file size in the checkpoint
   # omitting `iter` uses default handler for `iter`
-  item_handlers = {"items": PyTreeCheckpointHandler(use_ocdbt=use_ocdbt, use_zarr3=use_zarr3)}
+  item_handlers = {"items": PyTreeCheckpointHandler(save_concurrent_gb=500, use_ocdbt=use_ocdbt, use_zarr3=use_zarr3)}
   mngr = CheckpointManager(
       p,
       item_names=item_names,
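For context, a minimal sketch (not part of the diff) of how this handler configuration plugs into a CheckpointManager. The directory and item names are placeholders; save_concurrent_gb is intended to bound how much array data Orbax serializes concurrently during a save.

# Hedged sketch: same handler settings as the changed line above, placeholder path.
import orbax.checkpoint as ocp
from etils import epath

item_handlers = {
    "items": ocp.PyTreeCheckpointHandler(
        save_concurrent_gb=500,  # budget (GB) for concurrent array serialization on save
        use_ocdbt=True,          # ocdbt + zarr3 keep individual checkpoint files small
        use_zarr3=True,
    )
}
mngr = ocp.CheckpointManager(
    epath.Path("/tmp/ckpts"),    # placeholder checkpoint directory
    item_names=("items",),
    item_handlers=item_handlers,
)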
@@ -221,6 +221,7 @@ def map_to_pspec(data):
         single_replica_sharding=single_replica_sharding,
         global_shape=data.shape,
         dtype=data.dtype,
+        strict=False
     )

   if enable_single_replica_ckpt_restoring:
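The strict=False added above is a field on Orbax's ArrayRestoreArgs. A hedged sketch of what it expresses in isolation, with made-up shapes that mirror the 128K-to-32K vocab truncation discussed in a review comment further down:

# Hedged sketch (not from the diff): with strict=False, restoring into a target
# shape smaller than the checkpointed shape is allowed, and Orbax truncates the
# stored array (roughly A[:32000, :] for an embedding saved with a 128K vocab).
import jax.numpy as jnp
import orbax.checkpoint as ocp

restore_arg = ocp.ArrayRestoreArgs(
    global_shape=(32_000, 16_384),  # made-up target shape, smaller than the saved one
    dtype=jnp.bfloat16,
    strict=False,                   # tolerate the shape mismatch instead of raising
)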
@@ -320,12 +321,20 @@ def load_params_from_path(load_parameters_from_path, abstract_unboxed_params):
   assert load_parameters_from_path, "load_parameters_from_path is not defined."
   max_logging.log(f"restoring params from {load_parameters_from_path}")
   ckpt = epath.Path(load_parameters_from_path)
-  ckptr = ocp.PyTreeCheckpointer()
+  # ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler(restore_concurrent_gb=500, save_concurrent_gb=500))
+  ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
   # This is a memory optimization. We don't want to restore the entire checkpoint - only the params.
   # Rather than pass the entire abstract state, which could unnecessarily restore opt_state and such and waste
   # memory, we instead specify here that we are just restoring the params field of the checkpoint
   # (which itself may be a dictionary containing a key named 'params').
   restore_args = ocp.checkpoint_utils.construct_restore_args(abstract_unboxed_params)
+  def update_restore_args(restore_args):
+    for value in restore_args.values():
+      if type(value) == ocp._src.serialization.type_handlers.ArrayRestoreArgs:
+        value.strict = False
+      elif type(value) == dict:
+        update_restore_args(value)
+  update_restore_args(restore_args)
   restored = ckptr.restore(
       ckpt, item={"params": abstract_unboxed_params}, transforms={}, restore_args={"params": restore_args}
   )

Review comment (on the update_restore_args addition): This is Orbax's way of allowing truncation of the parameters when restoring a checkpoint. It is needed because NV used only a 32K vocab size, as opposed to the 128K one. The truncation Orbax does when strict mode is disabled is like A[:32000, :]. I made the entire PyTree non-strict, but you could probably restrict this to just the tokenizer embedding layer.
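As that comment suggests, the relaxation could be limited to the embedding layer instead of the whole PyTree. A hedged sketch, assuming the embedding leaf lives at restore_args["token_embedder"]["embedding"]; that path is hypothetical and will differ per model:

# Hedged sketch (not in the PR): relax strict for a single leaf of restore_args.
import orbax.checkpoint as ocp

def relax_embedding_strictness(restore_args, path=("token_embedder", "embedding")):
  """Set strict=False on one leaf of a nested restore_args dict."""
  node = restore_args
  for key in path[:-1]:
    node = node[key]                  # walk the nested dict down to the parent
  leaf = node[path[-1]]
  if isinstance(leaf, ocp.ArrayRestoreArgs):
    leaf.strict = False               # only this array may be truncated on restore
  return restore_args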
@@ -388,10 +388,10 @@ skip_jax_distributed_system: False # If True we will not initialize the jax dist
 # 2) Cosine decay from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] from warmup to learning_rate_schedule_steps
 # 3) Constant learning rate of 0 from learning_rate_schedule_steps to steps.
 # The zero learning rate section can be used to more accurately measure the fully trained model's performance.
-learning_rate: 3.e-5
+learning_rate: 8.e-5
 cosine_learning_rate_final_fraction: 0.1
-warmup_steps_fraction: 0.1
-learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
+warmup_steps_fraction: 0.0067
+learning_rate_schedule_steps: 2400000 # By default the length of the schedule is set to the number of steps.
 # However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
 # dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0.

Review comment (on warmup_steps_fraction: 0.0067): This is just gibberish; we should compute this value dynamically in pyconfig.
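Following that comment, a hedged sketch of how the fraction could be derived in pyconfig rather than hard-coded. The helper name and the idea of an explicit warmup-step count are assumptions; the PR's own values imply roughly 0.0067 * 2,400,000 = 16,080 warmup steps.

# Hedged sketch (not in the PR): derive warmup_steps_fraction from an explicit
# warmup step count instead of hard-coding 0.0067. Argument names are hypothetical.
def compute_warmup_steps_fraction(warmup_steps: int, learning_rate_schedule_steps: int) -> float:
  """Fraction of the LR schedule spent in linear warmup."""
  if learning_rate_schedule_steps <= 0:
    raise ValueError("learning_rate_schedule_steps must be positive to derive a fraction")
  return warmup_steps / learning_rate_schedule_steps

print(compute_warmup_steps_fraction(16_080, 2_400_000))  # 0.0067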
@@ -203,7 +203,7 @@ def save_checkpoint(
 ) -> bool:
   """Wrapper for saving checkpoint."""
   if config and config.enable_checkpointing:
-    if (step % config.checkpoint_period == 0) or (
+    if (step % config.checkpoint_period == 0 and step != 0) or (
         config.enable_emergency_checkpoint and step % config.local_checkpoint_period == 0
     ):
       blocking_until_ready_start = time.time()

Review comment (on the changed checkpoint_period condition): We should make sure NOT to save a checkpoint during our real runs. GPT3 had some hackery in its ckpt to start at a non-zero ckpt. Also set the …
Review comment: You might need something like this if you see an error like "ValueError: Requested more bytes than we reserved space for: 109924319232 > 96000000000"; you need to increase the restore limit. The sharded ckpt I linked to in our internal docs should resolve the need for this, though.
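For reference, a hedged sketch of the kind of change that comment points at, presumably the commented-out restore_concurrent_gb line earlier in this diff. The 500 GB figure is the PR's own and should be tuned to host memory.

# Hedged sketch: raise Orbax's restore budget if you hit
# "Requested more bytes than we reserved space for". Mirrors the commented-out
# line in load_params_from_path above.
import orbax.checkpoint as ocp

ckptr = ocp.Checkpointer(
    ocp.PyTreeCheckpointHandler(restore_concurrent_gb=500, save_concurrent_gb=500)
)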