Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Pin orbax-checkpoint to v0.10.3 to resolve dependency error (#1273) #1274

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

shota-inoue-lts
Copy link

Situation

I execute following content shell script to train model via TextMax with xpk.

# !/bin/bash
# GCP Settings
PROJECT=XXXXXXX
ZONE=XXXXXXX
CLUSTER=XXXXXXX
TPU_TYPE=v6e-8
NUM_SLICES=1

# Storage path
BASE_OUTPUT_DIR=XXXXXXX
DATASET_PATH=XXXXXXX
DATASET_TYPE=tfds

# HyperParameters
PER_DEVICE_BATCH_SIZE=3
MODEL_NAME=llama3.1-8b
MAX_TARGET_LENGTH=4096
STEPS=35
BLOCK_SIZE=2048
REMAT_POLICY=full
TOKENIZER_PATH=assets/tokenizer_llama3.tiktoken
VMEM_LIMIT=114688
ENABLE_CHECKPOINTING=true
CHECKPOINT_PERIOD=30

# Parallelism
ICI_DATA_PARALLELISM=1
ICI_PIPELINE_PARALLELISM=4
ICI_FSDP_PARALLELISM=1
ICI_FSDP_TRANSPOSE_PARALLELISM=1
ICI_SEQUENCE_PARALLELISM=1
ICI_TENSOR_PARALLELISM=2
ICI_TENSOR_SEQUENCE_PARALLELISM=1
ICI_EXPERT_PARALLELISM=1
ICI_AUTOREGRESSIVE_PARALLELISM=1

# image settings
CLOUD_IMAGE_NAME=${USER}_runner
DOCKER_IMAGE=gcr.io/${PROJECT}/${CLOUD_IMAGE_NAME}:latest

# EXP settings
EXP_NAME=$(echo $MODEL_NAME | tr '.' '-')-bs${PER_DEVICE_BATCH_SIZE}-$(date +'%m-%d-%H-%M-%S') # --workload: Workload name must be less than 40 characters and match the pattern `[a-z]([-a-z0-9]*[a-z0-9])?`

# download dataset
cd ~/maxtext
bash download_dataset.sh ${PROJECT} ${DATASET_PATH}

# create and push image
cd ~/maxtext
bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1
bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${CLOUD_IMAGE_NAME}

# create workload for model training
cd ~/xpk
python3 xpk.py workload create \
    --cluster ${CLUSTER} \
    --docker-image ${DOCKER_IMAGE} \
    --workload ${EXP_NAME} \
    --tpu-type ${TPU_TYPE} \
    --num-slices ${NUM_SLICES}  \
    --use-vertex-tensorboard \
    --experiment-name ${EXP_NAME} \
    --zone ${ZONE} \
    --on-demand \
    --enable-debug-logs \
    --project ${PROJECT} \
    --command "export LIBTPU_INIT_ARGS='--xla_tpu_use_minor_sharding_for_major_trivial_input=true --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 --xla_tpu_scoped_vmem_limit_kib=${VMEM_LIMIT} --xla_tpu_enable_async_collective_fusion=true --xla_tpu_assign_all_reduce_scatter_layout --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true' && python3 MaxText/train.py MaxText/configs/base.yml model_name=${MODEL_NAME} base_output_directory=${BASE_OUTPUT_DIR} dataset_path=${DATASET_PATH} run_name=${EXP_NAME} tokenizer_path=${TOKENIZER_PATH} max_target_length=${MAX_TARGET_LENGTH} per_device_batch_size=${PER_DEVICE_BATCH_SIZE} remat_policy=${REMAT_POLICY} steps=${STEPS} enable_checkpointing=${ENABLE_CHECKPOINTING} checkpoint_period=${CHECKPOINT_PERIOD} use_iota_embed=true gcs_metrics=true dataset_type=${DATASET_TYPE} reuse_example_batch=1 profiler=xplane attention=flash sa_block_q=${BLOCK_SIZE} sa_block_q_dkv=${BLOCK_SIZE} sa_block_q_dq=${BLOCK_SIZE} ici_data_parallelism=${ICI_DATA_PARALLELISM} ici_pipeline_parallelism=${ICI_PIPELINE_PARALLELISM} ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} ici_fsdp_transpose_parallelism=${ICI_FSDP_TRANSPOSE_PARALLELISM} ici_sequence_parallelism=${ICI_SEQUENCE_PARALLELISM} ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} ici_tensor_sequence_parallelism=${ICI_TENSOR_SEQUENCE_PARALLELISM} ici_expert_parallelism=${ICI_EXPERT_PARALLELISM} ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM}"

Error Message

I got the following error during process of MaxText/train.py. Especially, the error occur if I activate a checkpoint setting (ENABLE_CHECKPOINTING=true).

"'Traceback (most recent call last):
File ""/deps/MaxText/train.py"", line 1031, in <module>
app.run(main)
File ""/usr/local/lib/python3.10/site-packages/absl/app.py"", line 308, in run
_run_main(main, args)
File ""/usr/local/lib/python3.10/site-packages/absl/app.py"", line 254, in _run_main
sys.exit(main(argv))
File ""/deps/MaxText/train.py"", line 1027, in main
train_loop(config)
File ""/deps/MaxText/train.py"", line 897, in train_loop
if save_checkpoint(checkpoint_manager, int(step), state_to_save, config.dataset_type, data_iterator, config):
File ""/deps/MaxText/train.py"", line 241, in save_checkpoint
return checkpoint_manager.save(
File ""/usr/local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py"", line 1278, in save
self._checkpointer.save(
File ""/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py"", line 491, in save
asyncio_utils.run_sync(
File ""/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/asyncio_utils.py"", line 50, in run_sync
return asyncio.run(coro)
File ""/usr/local/lib/python3.10/asyncio/runners.py"", line 44, in run
return loop.run_until_complete(main)
File ""/usr/local/lib/python3.10/asyncio/base_events.py"", line 649, in run_until_complete
return future.result()
File ""/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py"", line 392, in _save
await self._handler.async_save(tmpdir.get(), args=ckpt_args) or []
File ""/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py"", line 706, in async_save
jax.tree.flatten(await asyncio.gather(*save_ops))[0] or []
File ""/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py"", line 583, in async_save
return await self._handler_impl.async_save(directory, args=args)
File ""/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py"", line 482, in async_save
commit_futures = await asyncio.gather(*serialize_ops)
File ""/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/serialization/type_handlers.py"", line 1127, in serialize
future.CommitFutureAwaitingContractedSignals(
File ""/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/futures/future.py"", line 367, in init
receive_signals = get_awaitable_signals_from_contract()
File ""/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/futures/future.py"", line 57, in get_awaitable_signals_from_contract
values_str = str(client.key_value_try_get(barrier_key))
AttributeError: 'DistributedRuntimeClient' object has no attribute 'key_value_try_get'. Did you mean: 'key_value_dir_get'?"

Solution

We should install specific package version orbax-checkpoint==0.10.3 (Now orbax-checkpoint==0.11.5 will be installed without version specification) when we create docker image. We solved the problem by rewriting these requirements file (requirements_with_jax_stable_stack.txt, requirements_with_jax_stable_stack.txt).

# maxtext/requirements_with_jax_stable_stack.txt
...
orbax-checkpoint==0.10.3
...
# maxtext/requirements.txt
...
orbax-checkpoint==0.10.3
...

Reference

I referred the following URLs when I create the shell script.

How to run MaxText with XPK?
https://github.com/AI-Hypercomputer/maxtext/blob/main/getting_started/Run_MaxText_via_xpk.md

…Hypercomputer#1273)

- Update requirements.txt and requirements_with_jax_stable_stack.txt to specify orbax-checkpoint==0.10.3.
- Prevent AttributeError in MaxText/train.py related to key_value_try_get.
Copy link

google-cla bot commented Feb 14, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@A9isha
Copy link
Collaborator

A9isha commented Feb 20, 2025

Hi @shota-inoue-lts ,

Thank you so much for looking into this!
We are working on fixing this issue by updating the dependencies without needing to make a change in MaxText. In the meantime, you could also try updating your Jax version to the following and it should resolve the error:

jax==0.5.0
jaxlib==0.5.0
jaxtyping==0.2.38

@shota-inoue-lts
Copy link
Author

Hi @A9isha ,
Thank you for your suggestion to fix the problem more easily!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants