LLama 405B Tweaks #1298

Draft · wants to merge 2 commits into base: mlperf/5.0
13 changes: 11 additions & 2 deletions MaxText/checkpointing.py
@@ -67,7 +67,7 @@ def create_orbax_checkpoint_manager(
p.mkdir(exist_ok=True, parents=True)
# we need to use ocdbt and zarr3 to control max file size in the checkpoint
# omitting `iter` uses default handler for `iter`
-  item_handlers = {"items": PyTreeCheckpointHandler(use_ocdbt=use_ocdbt, use_zarr3=use_zarr3)}
+  item_handlers = {"items": PyTreeCheckpointHandler(save_concurrent_gb=500, use_ocdbt=use_ocdbt, use_zarr3=use_zarr3)}
mngr = CheckpointManager(
p,
item_names=item_names,
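For context, a minimal sketch (not taken from the PR) of how a PyTreeCheckpointHandler with a raised save budget plugs into Orbax's CheckpointManager; the directory path and item name are illustrative, and the exact keyword set depends on the orbax.checkpoint version:

import orbax.checkpoint as ocp

# Illustrative only: a handler with a larger concurrent-write budget, plus
# OCDBT + Zarr3 so Orbax can cap the size of individual checkpoint files.
item_handlers = {
  "items": ocp.PyTreeCheckpointHandler(
    save_concurrent_gb=500,
    use_ocdbt=True,
    use_zarr3=True,
  )
}
mngr = ocp.CheckpointManager(
  "/tmp/ckpts",  # hypothetical checkpoint directory
  item_names=("items",),
  item_handlers=item_handlers,
)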
@@ -221,6 +221,7 @@ def map_to_pspec(data):
single_replica_sharding=single_replica_sharding,
global_shape=data.shape,
dtype=data.dtype,
+  strict=False
)

if enable_single_replica_ckpt_restoring:
@@ -320,12 +321,20 @@ def load_params_from_path(load_parameters_from_path, abstract_unboxed_params):
assert load_parameters_from_path, "load_parameters_from_path is not defined."
max_logging.log(f"restoring params from {load_parameters_from_path}")
ckpt = epath.Path(load_parameters_from_path)
-  ckptr = ocp.PyTreeCheckpointer()
+  # ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler(restore_concurrent_gb=500, save_concurrent_gb=500))
Collaborator Author commented:
You might need something like this if you see an error such as ValueError: Requested more bytes than we reserved space for: 109924319232 > 96000000000; in that case you need to raise the restore limit.

The sharded checkpoint I linked in our internal docs should remove the need for this, though.
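If that error does come up, the commented-out line in the diff already shows the workaround; as a standalone sketch (the 500 GB budgets are illustrative and should be sized to the checkpoint actually being restored):

import orbax.checkpoint as ocp

# Raise the byte budgets Orbax reserves for concurrent restore/save work.
ckptr = ocp.Checkpointer(
  ocp.PyTreeCheckpointHandler(restore_concurrent_gb=500, save_concurrent_gb=500)
)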

+  ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
# This is a memory optimization. We don't want to restore the entire checkpoint - only the params.
# Rather than pass the entire abstract state, which could unnecessarily restore opt_state and such and waste
# memory, we instead specify here that we are just restoring the params field of the checkpoint
# (which itself may be a dictionary containing a key named 'params').
restore_args = ocp.checkpoint_utils.construct_restore_args(abstract_unboxed_params)
+  def update_restore_args(restore_args):
Collaborator Author commented:
This is Orbax's way of allowing parameters to be truncated when restoring a checkpoint. It is needed because NV used only a 32K vocab size, as opposed to the 128K one. The truncation Orbax performs when strict mode is disabled is effectively A[:32000, :].

I made the entire PyTree non-strict, but you could probably restrict this to just the tokenizer embedding layer (see the sketch after this hunk).

+    for value in restore_args.values():
+      if type(value) == ocp._src.serialization.type_handlers.ArrayRestoreArgs:
+        value.strict = False
+      elif type(value) == dict:
+        update_restore_args(value)
+  update_restore_args(restore_args)
  restored = ckptr.restore(
      ckpt, item={"params": abstract_unboxed_params}, transforms={}, restore_args={"params": restore_args}
  )
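Picking up the reviewer's suggestion above, a hedged sketch of limiting the non-strict restore to the embedding table only; the 'token_embedder'/'embedding' key path is an assumption about the MaxText parameter tree and would need to be checked against the real names:

# Hypothetical alternative: relax strictness only for the vocab-sized embedding,
# leaving every other parameter strictly shape-checked.
restore_args = ocp.checkpoint_utils.construct_restore_args(abstract_unboxed_params)
embed_args = restore_args["token_embedder"]["embedding"]  # assumed key path
if isinstance(embed_args, ocp.ArrayRestoreArgs):
  embed_args.strict = False  # lets Orbax truncate the 128256-row table to 32000 rows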
6 changes: 3 additions & 3 deletions MaxText/configs/base.yml
@@ -388,10 +388,10 @@ skip_jax_distributed_system: False # If True we will not initialize the jax dist
# 2) Cosine decay from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] from warmup to learning_rate_schedule_steps
# 3) Constant learning rate of 0 from learning_rate_schedule_steps to steps.
# The zero learning rate section can be used to more accurately measure the fully trained model's performance.
-learning_rate: 3.e-5
+learning_rate: 8.e-5
cosine_learning_rate_final_fraction: 0.1
-warmup_steps_fraction: 0.1
-learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
+warmup_steps_fraction: 0.0067
Collaborator Author commented:
This is just gibberish; we should compute it dynamically in pyconfig (see the sketch after this hunk).

+learning_rate_schedule_steps: 2400000 # By default the length of the schedule is set to the number of steps.
# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0.
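Following up on the comment above, one way the fraction could be derived in pyconfig instead of hard-coded; the warmup_steps key and the raw_keys dict shape are assumptions for illustration, not existing MaxText settings:

# Hypothetical pyconfig-style derivation: express warmup as an absolute step
# count and compute the fraction from the schedule length at config time.
def derive_warmup_fraction(raw_keys):
  schedule_steps = raw_keys["learning_rate_schedule_steps"]
  if schedule_steps <= 0:  # -1 means "use the total number of training steps"
    schedule_steps = raw_keys["steps"]
  warmup_steps = raw_keys.get("warmup_steps")  # assumed new absolute setting
  if warmup_steps:
    raw_keys["warmup_steps_fraction"] = warmup_steps / schedule_steps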

2 changes: 1 addition & 1 deletion MaxText/configs/models/llama3.1-405b.yml
@@ -21,7 +21,7 @@ base_num_decoder_layers: 126
base_mlp_dim: 53248
head_dim: 128
mlp_activations: ["silu","linear"]
-vocab_size: 128256
+vocab_size: 32000
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-5
2 changes: 1 addition & 1 deletion MaxText/train.py
@@ -203,7 +203,7 @@ def save_checkpoint(
) -> bool:
"""Wrapper for saving checkpoint."""
if config and config.enable_checkpointing:
-    if (step % config.checkpoint_period == 0) or (
+    if (step % config.checkpoint_period == 0 and step != 0) or (
Collaborator Author commented:
We should make sure NOT to save a checkpoint during our real runs. GPT-3 had some hackery in its checkpointing to start at a non-zero checkpoint.

Also, set checkpoint_period to something huge, like 10000, as well (see the sketch after this hunk).

config.enable_emergency_checkpoint and step % config.local_checkpoint_period == 0
):
blocking_until_ready_start = time.time()
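To make the effect of the changed guard concrete, a tiny illustration (values are made up) of when a save would now fire; with checkpoint_period pushed to 10000 as suggested, a shorter benchmark run never triggers one:

# Illustration only: the added `step != 0` stops the save that used to fire
# right after restoring at step 0, and a huge checkpoint_period avoids
# mid-run saves entirely for runs shorter than that period.
checkpoint_period = 10_000
for step in (0, 500, 10_000, 20_000):
  will_save = (step % checkpoint_period == 0 and step != 0)
  print(step, will_save)  # 0 -> False, 500 -> False, 10000 -> True, 20000 -> True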