From 61cd6f447c56a47331de0b5d1159034ad758ca4b Mon Sep 17 00:00:00 2001 From: Sujeeth Jinesh Date: Fri, 31 Jan 2025 00:31:56 +0000 Subject: [PATCH 01/10] Add Pathways Benchmarking Recipes for Scale Testing --- benchmarks/maxtext_trillium_model_configs.py | 1920 ++++++++--------- benchmarks/maxtext_xpk_runner.py | 108 +- benchmarks/recipes/pw_long_running_recipe.py | 132 ++ .../recipes/pw_mcjax_benchmark_recipe.py | 143 ++ benchmarks/recipes/pw_remote_python_recipe.py | 2 +- benchmarks/xla_flags_library.py | 13 - 6 files changed, 1290 insertions(+), 1028 deletions(-) create mode 100644 benchmarks/recipes/pw_long_running_recipe.py create mode 100644 benchmarks/recipes/pw_mcjax_benchmark_recipe.py diff --git a/benchmarks/maxtext_trillium_model_configs.py b/benchmarks/maxtext_trillium_model_configs.py index 9b4e7ba87..ba7f44080 100644 --- a/benchmarks/maxtext_trillium_model_configs.py +++ b/benchmarks/maxtext_trillium_model_configs.py @@ -24,6 +24,8 @@ # TODO(vbarr@) Make slice dependent configurations to allow for a model's tuning # to adjust at scales. +REMOVE = "remove" + @dataclasses.dataclass class MaxTextModel: @@ -32,1150 +34,1050 @@ class MaxTextModel: tuning_params: dict[str, typing.Any] xla_flags: str + # Additional pathways tuning params as necessary. Adding + # enable_single_controller=True to pathways_tuning_params is not necessary. + pathways_tuning_params: dict[str, typing.Any] = None + + # XLA flags for pathways, if different from the default. Some flags may not + # be supported by pathways e.g. "--2a886c8_chip_config_name". + pathways_xla_flag_options: dict[str, typing.Any] = None + + trillium_model_dict = {} + # Run this for new definitions that should be part of the library. -def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_model: MaxTextModel)-> MaxTextModel: - model_dictionary[maxtext_model.model_name.replace('-', '_')] = maxtext_model +def _add_to_model_dictionary( + model_dictionary: dict[str, MaxTextModel], maxtext_model: MaxTextModel +) -> MaxTextModel: + model_dictionary[maxtext_model.model_name.replace("-", "_")] = maxtext_model return maxtext_model + default_basic_1 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="default-basic-1", - model_type="default", - tuning_params={ - "per_device_batch_size": 1, - "remat_policy": "full", - "global_parameter_scale": 1, - "attention": "flash", - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - }, - xla_flags="", - ) + trillium_model_dict, + MaxTextModel( + model_name="default-basic-1", + model_type="default", + tuning_params={ + "per_device_batch_size": 1, + "remat_policy": "full", + "global_parameter_scale": 1, + "attention": "flash", + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + }, + pathways_tuning_params={ + "enable_checkpointing": True, + "async_checkpointing": True, + "checkpoint_period": 100, + "metrics_file": "metrics.txt", + "goodput_upload_interval_seconds": 30, + "enable_checkpoint_cloud_logger": True, + }, + xla_flags="", + ), ) -default_basic_1_pw = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="default-basic-1-pw", - model_type="default", - tuning_params={ - "per_device_batch_size": 1, - "remat_policy": "full", - "global_parameter_scale": 1, - "attention": "flash", - "dataset_path": 
"gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - # "profiler": "xplane", - - # Additional tuning params for pathways long running test. - "enable_checkpointing": True, - "async_checkpointing": True, - "checkpoint_period": 100, - "checkpoint_storage_use_ocdbt": False, - "checkpoint_storage_use_zarr3": False, - "metrics_file": "metrics.txt", - "goodput_upload_interval_seconds": 30, - # "enable_pathways_goodput": True, - "enable_checkpoint_cloud_logger": True, - "enable_single_controller": True, - }, - xla_flags="", - ) -) - default_32 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="default-32", - model_type="default", - tuning_params={ - "per_device_batch_size": 13, - "ici_fsdp_parallelism": -1, - "remat_policy": "full", - "global_parameter_scale": 32, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 1024, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER + trillium_model_dict, + MaxTextModel( + model_name="default-32", + model_type="default", + tuning_params={ + "per_device_batch_size": 13, + "ici_fsdp_parallelism": -1, + "remat_policy": "full", + "global_parameter_scale": 32, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 1024, + "sa_block_q_dq": 2048, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + ), ), - ) ) default_64 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="default-64", - model_type="default", - tuning_params={ - "per_device_batch_size": 6, - "ici_fsdp_parallelism": -1, - "remat_policy": "full", - "global_parameter_scale": 64, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER + trillium_model_dict, + MaxTextModel( + model_name="default-64", + model_type="default", + tuning_params={ + "per_device_batch_size": 6, + "ici_fsdp_parallelism": -1, + "remat_policy": "full", + "global_parameter_scale": 64, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + ), ), - ) ) default_128 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="default-128", - model_type="default", - tuning_params={ - "per_device_batch_size": 2, - "ici_fsdp_parallelism": -1, - "remat_policy": "full", - "global_parameter_scale": 128, - "attention": "flash", - "gcs_metrics": 
True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER + trillium_model_dict, + MaxTextModel( + model_name="default-128", + model_type="default", + tuning_params={ + "per_device_batch_size": 2, + "ici_fsdp_parallelism": -1, + "remat_policy": "full", + "global_parameter_scale": 128, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + ), ), - ) ) # OOM, Not Optimized yet default_256 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="default-256", - model_type="default", - tuning_params={ - "per_device_batch_size": 1, - "ici_fsdp_parallelism": -1, - "dcn_fsdp_transpose_parallelism": -1, - "remat_policy": "full", - "global_parameter_scale": 256, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER + trillium_model_dict, + MaxTextModel( + model_name="default-256", + model_type="default", + tuning_params={ + "per_device_batch_size": 1, + "ici_fsdp_parallelism": -1, + "dcn_fsdp_transpose_parallelism": -1, + "remat_policy": "full", + "global_parameter_scale": 256, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + ), ), - ) ) # OOM, Not Optimized yet default_512 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="default-512", - model_type="default", - tuning_params={ - "per_device_batch_size": 1, - "ici_fsdp_parallelism": -1, - # "dcn_fsdp_parallelism": 2, - "dcn_fsdp_parallelism": -1, - "remat_policy": "full", - "global_parameter_scale": 512, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER + trillium_model_dict, + MaxTextModel( + model_name="default-512", + model_type="default", + tuning_params={ + "per_device_batch_size": 1, + "ici_fsdp_parallelism": -1, + # "dcn_fsdp_parallelism": 2, + "dcn_fsdp_parallelism": -1, + "remat_policy": "full", + "global_parameter_scale": 512, + "attention": "flash", + "gcs_metrics": True, + 
"use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + ), ), - ) ) gpt_3_175b = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="gpt-3-175b", - model_type="gpt3-175b", - tuning_params={ - "per_device_batch_size": 3, - "ici_fsdp_parallelism": -1, - "remat_policy": "full", - "attention": "flash", - "quantization": "int8", - "gcs_metrics": True, - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER - + xla_flags_library.DATA_PARALLEL_OVERLAP - + xla_flags_library.DISABLE_BUNDLE_AWARE_COST_MODEL + trillium_model_dict, + MaxTextModel( + model_name="gpt-3-175b", + model_type="gpt3-175b", + tuning_params={ + "per_device_batch_size": 3, + "ici_fsdp_parallelism": -1, + "remat_policy": "full", + "attention": "flash", + "quantization": "int8", + "gcs_metrics": True, + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.DATA_PARALLEL_OVERLAP + + xla_flags_library.DISABLE_BUNDLE_AWARE_COST_MODEL + ), ), - ) ) llama2_7b_4096 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama2-7b-4096", - model_type="llama2-7b", - tuning_params={ - "per_device_batch_size": 12, - "ici_fsdp_parallelism": -1, - "remat_policy": "full", - "max_target_length": 4096, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER - ), - ) -) - -llama2_7b_4096_pw = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama2-7b-4096-pw", - model_type="llama2-7b", - tuning_params={ - "per_device_batch_size": 4, - "ici_fsdp_parallelism": -1, - "remat_policy": "full", - "max_target_length": 4096, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - "steps": 1000000, - - # Additional tuning params for pathways long running test. 
- "enable_checkpointing": True, - "async_checkpointing": True, - "checkpoint_period": 100, - "checkpoint_storage_use_ocdbt": False, - "checkpoint_storage_use_zarr3": False, - "metrics_file": "metrics.txt", - "goodput_upload_interval_seconds": 30, - # "enable_pathways_goodput": True, - "enable_checkpoint_cloud_logger": True, - "enable_single_controller": True, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER + trillium_model_dict, + MaxTextModel( + model_name="llama2-7b-4096", + model_type="llama2-7b", + tuning_params={ + "per_device_batch_size": 12, + "ici_fsdp_parallelism": -1, + "remat_policy": "full", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + pathways_tuning_params={ + "enable_checkpointing": True, + "async_checkpointing": True, + "checkpoint_period": 100, + "metrics_file": "metrics.txt", + "goodput_upload_interval_seconds": 30, + "enable_checkpoint_cloud_logger": True, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + ), ), - ) ) llama2_70b_4096 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama2-70b-4096", - model_type="llama2-70b", - tuning_params={ - "per_device_batch_size": 4, - "ici_fsdp_parallelism": 1, - "ici_fsdp_transpose_parallelism": -1, - "ici_tensor_parallelism": 1, - "remat_policy": "full", - "max_target_length": 4096, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER + trillium_model_dict, + MaxTextModel( + model_name="llama2-70b-4096", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 4, + "ici_fsdp_parallelism": 1, + "ici_fsdp_transpose_parallelism": -1, + "ici_tensor_parallelism": 1, + "remat_policy": "full", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + ), ), - ) ) llama2_70b_4096_optimized = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama2_70b_4096_synthetic", - model_type="llama2-70b", - tuning_params={ - "per_device_batch_size": 2, - "ici_fsdp_parallelism": 1, - "ici_fsdp_transpose_parallelism": -1, - "ici_tensor_parallelism": 1, - "remat_policy": "qkv_proj_offloaded", - "max_target_length": 4096, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + 
xla_flags_library.CF_FOR_ALL_GATHER + trillium_model_dict, + MaxTextModel( + model_name="llama2_70b_4096_synthetic", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 2, + "ici_fsdp_parallelism": 1, + "ici_fsdp_transpose_parallelism": -1, + "ici_tensor_parallelism": 1, + "remat_policy": "qkv_proj_offloaded", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + pathways_tuning_params={ + "enable_checkpointing": True, + "async_checkpointing": True, + "checkpoint_period": 100, + "metrics_file": "metrics.txt", + "goodput_upload_interval_seconds": 30, + "enable_checkpoint_cloud_logger": True, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + ), ), - ) ) # Enable SparseCore Offloading of AR in an optimized model. llama2_70b_4096_sc = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama2-70b-4096-sc", - model_type="llama2-70b", - tuning_params={ - "per_device_batch_size": 2, - "ici_fsdp_parallelism": 1, - "ici_fsdp_transpose_parallelism": -1, - "ici_tensor_parallelism": 1, - "remat_policy": "qkv_proj_offloaded", - "max_target_length": 4096, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER - + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE + trillium_model_dict, + MaxTextModel( + model_name="llama2-70b-4096-sc", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 2, + "ici_fsdp_parallelism": 1, + "ici_fsdp_transpose_parallelism": -1, + "ici_tensor_parallelism": 1, + "remat_policy": "qkv_proj_offloaded", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE + ), + pathways_xla_flag_options={ + REMOVE: ["--2a886c8_chip_config_name=megachip_tccontrol"], + }, ), - ) ) llama2_70b_4096_sc_real_data_tfds = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama2-70b-4096-sc", - model_type="llama2-70b", - tuning_params={ - "per_device_batch_size": 2, - "ici_fsdp_parallelism": 1, - "ici_fsdp_transpose_parallelism": -1, - "ici_tensor_parallelism": 1, - "remat_policy": "qkv_proj_offloaded", - "max_target_length": 4096, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://trillium-storage-datasets-sr", - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER - + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE + 
trillium_model_dict, + MaxTextModel( + model_name="llama2-70b-4096-sc", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 2, + "ici_fsdp_parallelism": 1, + "ici_fsdp_transpose_parallelism": -1, + "ici_tensor_parallelism": 1, + "remat_policy": "qkv_proj_offloaded", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://trillium-storage-datasets-sr", + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE + ), + pathways_xla_flag_options={ + REMOVE: ["--2a886c8_chip_config_name=megachip_tccontrol"], + }, ), - ) ) llama2_70b_4096_sc_real_data_grain = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama2-70b-4096", - model_type="llama2-70b", - tuning_params={ - "per_device_batch_size": 2, - "ici_fsdp_parallelism": 1, - "ici_fsdp_transpose_parallelism": -1, - "ici_tensor_parallelism": 1, - "remat_policy": "qkv_proj_offloaded", - "max_target_length": 4096, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://trillium-storage-datasets-sr", - "base_output_directory": ( - "gs://trillium-storage-tests-nov24-sr/long-run-dec11" + trillium_model_dict, + MaxTextModel( + model_name="llama2-70b-4096", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 2, + "ici_fsdp_parallelism": 1, + "ici_fsdp_transpose_parallelism": -1, + "ici_tensor_parallelism": 1, + "remat_policy": "qkv_proj_offloaded", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://trillium-storage-datasets-sr", + "base_output_directory": ( + "gs://trillium-storage-tests-nov24-sr/long-run-dec11" + ), + "enable_checkpointing": False, + "dataset_type": "grain", + "grain_train_files": ( + "/tmp/dataset/array-record/c4/en/3.0.1/c4-train.array_record*" + ), + "grain_worker_count": 24, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + "profile_cleanly": False, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE ), - "enable_checkpointing": False, - "dataset_type": "grain", - "grain_train_files": "/tmp/dataset/array-record/c4/en/3.0.1/c4-train.array_record*", - "grain_worker_count": 24, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - "profile_cleanly": False, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER - + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE + pathways_xla_flag_options={ + REMOVE: ["--2a886c8_chip_config_name=megachip_tccontrol"], + }, ), - ) ) llama2_70b_4096_sc_real_data_grain_checkpoint = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama2-70b-4096", - model_type="llama2-70b", - tuning_params={ - "per_device_batch_size": 2, - "ici_fsdp_parallelism": 1, - "ici_fsdp_transpose_parallelism": -1, - "ici_tensor_parallelism": 1, - "remat_policy": "qkv_proj_offloaded", - "max_target_length": 4096, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://trillium-storage-datasets-sr", - 
"base_output_directory": ( - "gs://trillium-storage-tests-nov24-sr/long-run-dec11" + trillium_model_dict, + MaxTextModel( + model_name="llama2-70b-4096", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 2, + "ici_fsdp_parallelism": 1, + "ici_fsdp_transpose_parallelism": -1, + "ici_tensor_parallelism": 1, + "remat_policy": "qkv_proj_offloaded", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://trillium-storage-datasets-sr", + "base_output_directory": ( + "gs://trillium-storage-tests-nov24-sr/long-run-dec11" + ), + "checkpoint_period": 100, + "enable_checkpointing": True, + "async_checkpointing": True, + "dataset_type": "grain", + "grain_train_files": ( + "/tmp/dataset/array-record/c4/en/3.0.1/c4-train.array_record*" + ), + "grain_worker_count": 24, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE ), - "checkpoint_period": 100, - "enable_checkpointing": True, - "async_checkpointing": True, - "dataset_type": "grain", - "grain_train_files": "/tmp/dataset/array-record/c4/en/3.0.1/c4-train.array_record*", - "grain_worker_count": 24, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER - + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE - ), - ) -) - - -llama2_70b_4096_real_data_pw_long_run = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama2-70b-4096-rd-pw-lr", - model_type="llama2-70b", - tuning_params={ - "per_device_batch_size": 4, - "ici_fsdp_parallelism": -1, - "remat_policy": "full", - "max_target_length": 4096, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "reuse_example_batch": 0, - "profiler": "xplane", - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "tfds", - "tokenizer_path": "assets/tokenizer.llama2", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - "steps": 1000000, - - # Additional tuning params for pathways long running test. 
- "enable_checkpointing": True, - "async_checkpointing": True, - "checkpoint_period": 100, - "checkpoint_storage_use_ocdbt": False, - "checkpoint_storage_use_zarr3": False, - "metrics_file": "metrics.txt", - "goodput_upload_interval_seconds": 30, - "enable_pathways_goodput": True, - "enable_checkpoint_cloud_logger": True, - "enable_single_controller": True, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER + pathways_xla_flag_options={ + REMOVE: ["--2a886c8_chip_config_name=megachip_tccontrol"], + }, ), - ) ) - -llama2_70b_4096_synthetic_pw_lr = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama2_70b_4096_synthetic_pw_lr", - model_type="llama2-70b", - tuning_params={ - "per_device_batch_size": 2, - "ici_fsdp_parallelism": 1, - "ici_fsdp_transpose_parallelism": -1, - "ici_tensor_parallelism": 1, - "remat_policy": "qkv_proj_offloaded", - "max_target_length": 4096, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - # "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - "steps": 1000000, - - # Additional tuning params for pathways long running test. - "enable_checkpointing": True, - "async_checkpointing": True, - "checkpoint_period": 100, - "checkpoint_storage_use_ocdbt": False, - "checkpoint_storage_use_zarr3": False, - "metrics_file": "metrics.txt", - "goodput_upload_interval_seconds": 30, - "enable_pathways_goodput": True, - "enable_checkpoint_cloud_logger": True, - "enable_single_controller": True, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER - ), - ) -) - - -llama2_70b_4096_pw_long_run = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama2-70b-4096-pw-lr", - model_type="llama2-70b", - tuning_params={ - "per_device_batch_size": 4, - "ici_fsdp_parallelism": 1, - "ici_fsdp_transpose_parallelism": -1, - "ici_tensor_parallelism": 1, - "remat_policy": "full", - "max_target_length": 4096, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - "steps": 1000000, - - # Additional tuning params for pathways long running test. 
- "enable_checkpointing": True, - "async_checkpointing": True, - "checkpoint_period": 100, - "checkpoint_storage_use_ocdbt": False, - "checkpoint_storage_use_zarr3": False, - "metrics_file": "metrics.txt", - "goodput_upload_interval_seconds": 30, - "enable_pathways_goodput": True, - "enable_checkpoint_cloud_logger": True, - "enable_single_controller": True, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER - ), - ) -) - - -llama2_70b_4096_pw_rd_tfds = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama2_70b_4096_pw_rd_tfds", - model_type="llama2-70b", - tuning_params={ - "per_device_batch_size": 2, - "ici_fsdp_parallelism": 1, - "ici_fsdp_transpose_parallelism": -1, - "ici_tensor_parallelism": 1, - "remat_policy": "qkv_proj_offloaded", - "max_target_length": 4096, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://trillium-storage-datasets-sr", - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - - # Additional tuning params for pathways long running test. - "enable_checkpointing": True, - "async_checkpointing": True, - "checkpoint_period": 100, - "checkpoint_storage_use_ocdbt": False, - "checkpoint_storage_use_zarr3": False, - "metrics_file": "metrics.txt", - "goodput_upload_interval_seconds": 30, - "enable_pathways_goodput": True, - "enable_checkpoint_cloud_logger": True, - "enable_single_controller": True, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER +llama2_70b_4096_real_data_long_run = _add_to_model_dictionary( + trillium_model_dict, + MaxTextModel( + model_name="llama2-70b-4096-rd-lr", + model_type="llama2-70b", + tuning_params={ + "per_device_batch_size": 4, + "ici_fsdp_parallelism": -1, + "remat_policy": "full", + "max_target_length": 4096, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "reuse_example_batch": 0, + "profiler": "xplane", + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "tfds", + "tokenizer_path": "assets/tokenizer.llama2", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + pathways_tuning_params={ + "enable_checkpointing": True, + "async_checkpointing": True, + "checkpoint_period": 100, + "metrics_file": "metrics.txt", + "goodput_upload_interval_seconds": 30, + "enable_checkpoint_cloud_logger": True, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + ), ), - ) ) llama3_8b_8192 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama3-8b-8192", - model_type="llama3-8b", - tuning_params={ - "per_device_batch_size": 8, - "ici_fsdp_parallelism": -1, - "remat_policy": "full", - "max_target_length": 8192, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER + trillium_model_dict, + MaxTextModel( + model_name="llama3-8b-8192", + model_type="llama3-8b", + tuning_params={ + "per_device_batch_size": 8, + "ici_fsdp_parallelism": -1, + "remat_policy": "full", + "max_target_length": 8192, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": 
True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + ), ), - ) ) llama3_70b_8192 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama3-70b-8192", - model_type="llama3-70b", - tuning_params={ - "per_device_batch_size": 2, - "ici_fsdp_parallelism": -1, - "remat_policy": "full", - "optimizer_memory_host_offload": True, - "gradient_clipping_threshold": 0, - "max_target_length": 8192, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER - + xla_flags_library.HOST_OFFLOAD_FLAGS - + " --xla_tpu_scheduler_percent_shared_memory_limit=90" + trillium_model_dict, + MaxTextModel( + model_name="llama3-70b-8192", + model_type="llama3-70b", + tuning_params={ + "per_device_batch_size": 2, + "ici_fsdp_parallelism": -1, + "remat_policy": "full", + "optimizer_memory_host_offload": True, + "gradient_clipping_threshold": 0, + "max_target_length": 8192, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.HOST_OFFLOAD_FLAGS + + " --xla_tpu_scheduler_percent_shared_memory_limit=90" + ), ), - ) ) llama3_1_405b_8192_fsdp_dcn = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama3-1-405b-8192-fsdp-dcn", - model_type="llama3.1-405b", - tuning_params={ - "per_device_batch_size": 1, - "ici_fsdp_parallelism": 64, - "ici_tensor_parallelism": 4, - "dcn_fsdp_parallelism": 2, - "allow_split_physical_axes": True, - "custom_mesh": "hybrid_ring_64x4", - "remat_policy": "custom", - "decoder_layer_input": "offload", - "query_proj": "offload", - "key_proj": "offload", - "value_proj": "offload", - "out_proj": "offload", - "max_target_length": 8192, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 1024, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER - + xla_flags_library.HOST_OFFLOAD_FLAGS + trillium_model_dict, + MaxTextModel( + model_name="llama3-1-405b-8192-fsdp-dcn", + model_type="llama3.1-405b", + tuning_params={ + "per_device_batch_size": 1, + "ici_fsdp_parallelism": 64, + "ici_tensor_parallelism": 4, + "dcn_fsdp_parallelism": 2, + "allow_split_physical_axes": True, + "custom_mesh": "hybrid_ring_64x4", + "remat_policy": "custom", + "decoder_layer_input": "offload", + "query_proj": "offload", + "key_proj": "offload", + "value_proj": "offload", + "out_proj": 
"offload", + "max_target_length": 8192, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 1024, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.HOST_OFFLOAD_FLAGS + ), ), - ) ) llama3_1_8b_8192 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama3_1-8b-8192", - model_type="llama3.1-8b", - tuning_params={ - "per_device_batch_size": 4, - "ici_fsdp_parallelism": -1, - "remat_policy": "custom", - "decoder_layer_input": "offload", - "out_proj": "offload", - "query_proj": "offload", - "key_proj": "offload", - "value_proj": "offload", - "max_target_length": 8192, - "attention": "flash", - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "enable_checkpointing": False, - "sa_block_q": 2048, - "sa_block_kv": 2048, - "sa_block_kv_compute": 2048, - "sa_block_q_dkv": 2048, - "sa_block_kv_dkv": 2048, - "sa_block_kv_dkv_compute": 2048, - "sa_block_q_dq": 2048, - "sa_block_kv_dq": 2048, - "sa_use_fused_bwd_kernel": True, - "profiler": "xplane", - "skip_first_n_steps_for_profiler": 10, - "profiler_steps": 5, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER - + xla_flags_library.DATA_PARALLEL_OVERLAP - + xla_flags_library.CF_FOR_ALL_GATHER - + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE - + xla_flags_library.HOST_OFFLOAD_FLAGS + trillium_model_dict, + MaxTextModel( + model_name="llama3_1-8b-8192", + model_type="llama3.1-8b", + tuning_params={ + "per_device_batch_size": 4, + "ici_fsdp_parallelism": -1, + "remat_policy": "custom", + "decoder_layer_input": "offload", + "out_proj": "offload", + "query_proj": "offload", + "key_proj": "offload", + "value_proj": "offload", + "max_target_length": 8192, + "attention": "flash", + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "enable_checkpointing": False, + "sa_block_q": 2048, + "sa_block_kv": 2048, + "sa_block_kv_compute": 2048, + "sa_block_q_dkv": 2048, + "sa_block_kv_dkv": 2048, + "sa_block_kv_dkv_compute": 2048, + "sa_block_q_dq": 2048, + "sa_block_kv_dq": 2048, + "sa_use_fused_bwd_kernel": True, + "profiler": "xplane", + "skip_first_n_steps_for_profiler": 10, + "profiler_steps": 5, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER + + xla_flags_library.DATA_PARALLEL_OVERLAP + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE + + xla_flags_library.HOST_OFFLOAD_FLAGS + ), + pathways_xla_flag_options={ + REMOVE: ["--2a886c8_chip_config_name=megachip_tccontrol"], + }, ), - ) ) llama3_1_70b_8192 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama3_1-70b-8192", - model_type="llama3.1-70b", - tuning_params={ - "per_device_batch_size": 4, - "ici_fsdp_parallelism": -1, - "remat_policy": "custom", - "decoder_layer_input": "offload", - "query_proj": "offload", - "key_proj": "offload", - "value_proj": "offload", - "max_target_length": 8192, - "attention": "flash", - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - 
"enable_checkpointing": False, - "sa_block_q": 2048, - "sa_block_kv": 2048, - "sa_block_kv_compute": 2048, - "sa_block_q_dkv": 2048, - "sa_block_kv_dkv": 2048, - "sa_block_kv_dkv_compute": 2048, - "sa_block_q_dq": 2048, - "sa_block_kv_dq": 2048, - "sa_use_fused_bwd_kernel": True, - "profiler": "xplane", - "skip_first_n_steps_for_profiler": 10, - "profiler_steps": 5, - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER - + xla_flags_library.DATA_PARALLEL_OVERLAP - + xla_flags_library.CF_FOR_ALL_GATHER - + xla_flags_library.HOST_OFFLOAD_FLAGS + trillium_model_dict, + MaxTextModel( + model_name="llama3_1-70b-8192", + model_type="llama3.1-70b", + tuning_params={ + "per_device_batch_size": 4, + "ici_fsdp_parallelism": -1, + "remat_policy": "custom", + "decoder_layer_input": "offload", + "query_proj": "offload", + "key_proj": "offload", + "value_proj": "offload", + "max_target_length": 8192, + "attention": "flash", + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "enable_checkpointing": False, + "sa_block_q": 2048, + "sa_block_kv": 2048, + "sa_block_kv_compute": 2048, + "sa_block_q_dkv": 2048, + "sa_block_kv_dkv": 2048, + "sa_block_kv_dkv_compute": 2048, + "sa_block_q_dq": 2048, + "sa_block_kv_dq": 2048, + "sa_use_fused_bwd_kernel": True, + "profiler": "xplane", + "skip_first_n_steps_for_profiler": 10, + "profiler_steps": 5, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER + + xla_flags_library.DATA_PARALLEL_OVERLAP + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.HOST_OFFLOAD_FLAGS + ), ), - ) ) +llama3_1_70b_8192_lr_real_data = _add_to_model_dictionary( + trillium_model_dict, + MaxTextModel( + model_name="llama3_1-70b-8192-pw-lr-rd", + model_type="llama3.1-70b", + tuning_params={ + "per_device_batch_size": 4, + "ici_fsdp_parallelism": -1, + "remat_policy": "custom", + "decoder_layer_input": "offload", + "query_proj": "offload", + "key_proj": "offload", + "value_proj": "offload", + "max_target_length": 8192, + "attention": "flash", + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "enable_checkpointing": False, + "sa_block_q": 2048, + "sa_block_kv": 2048, + "sa_block_kv_compute": 2048, + "sa_block_q_dkv": 2048, + "sa_block_kv_dkv": 2048, + "sa_block_kv_dkv_compute": 2048, + "sa_block_q_dq": 2048, + "sa_block_kv_dq": 2048, + "sa_use_fused_bwd_kernel": True, + "profiler": "xplane", + "skip_first_n_steps_for_profiler": 10, + "profiler_steps": 5, + }, + pathways_tuning_params={ + "enable_checkpointing": True, + "async_checkpointing": True, + "checkpoint_period": 100, + "metrics_file": "metrics.txt", + "goodput_upload_interval_seconds": 30, + "enable_checkpoint_cloud_logger": True, + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER + + xla_flags_library.DATA_PARALLEL_OVERLAP + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.HOST_OFFLOAD_FLAGS + ), + ), +) + llama3_1_70b_129024 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="llama3_1-70b-129024", - model_type="llama3.1-70b", - tuning_params={ - "per_device_batch_size": 0.125, - "ici_fsdp_parallelism": -1, - "ici_sequence_parallelism": 8, - "remat_policy": "custom", - "decoder_layer_input": "offload", - "out_proj": "offload", - "query_proj": "offload", - "key_proj": "offload", - "value_proj": "offload", - 
"max_target_length": 129024, - "attention": "flash", - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "enable_checkpointing": False, - "sa_block_q": 2048, - "sa_block_kv": 2048, - "sa_block_kv_compute": 2048, - "sa_block_q_dkv": 2048, - "sa_block_kv_dkv": 2048, - "sa_block_kv_dkv_compute": 2048, - "sa_block_q_dq": 2048, - "sa_block_kv_dq": 2048, - "sa_use_fused_bwd_kernel": True, - "profiler": "xplane", - "skip_first_n_steps_for_profiler": 10, - "profiler_steps": 5, - "allow_split_physical_axes": True, - "custom_mesh": "hybrid_ring_32x8", - }, - xla_flags=( - xla_flags_library.DENSE_VMEM_LIMIT_FLAG - + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER - + xla_flags_library.DATA_PARALLEL_OVERLAP - + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_GATHER - + xla_flags_library.HOST_OFFLOAD_FLAGS + trillium_model_dict, + MaxTextModel( + model_name="llama3_1-70b-129024", + model_type="llama3.1-70b", + tuning_params={ + "per_device_batch_size": 0.125, + "ici_fsdp_parallelism": -1, + "ici_sequence_parallelism": 8, + "remat_policy": "custom", + "decoder_layer_input": "offload", + "out_proj": "offload", + "query_proj": "offload", + "key_proj": "offload", + "value_proj": "offload", + "max_target_length": 129024, + "attention": "flash", + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "enable_checkpointing": False, + "sa_block_q": 2048, + "sa_block_kv": 2048, + "sa_block_kv_compute": 2048, + "sa_block_q_dkv": 2048, + "sa_block_kv_dkv": 2048, + "sa_block_kv_dkv_compute": 2048, + "sa_block_q_dq": 2048, + "sa_block_kv_dq": 2048, + "sa_use_fused_bwd_kernel": True, + "profiler": "xplane", + "skip_first_n_steps_for_profiler": 10, + "profiler_steps": 5, + "allow_split_physical_axes": True, + "custom_mesh": "hybrid_ring_32x8", + }, + xla_flags=( + xla_flags_library.DENSE_VMEM_LIMIT_FLAG + + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER + + xla_flags_library.DATA_PARALLEL_OVERLAP + + xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_GATHER + + xla_flags_library.HOST_OFFLOAD_FLAGS + ), + pathways_xla_flag_options={ + REMOVE: ["--2a886c8_chip_config_name=megachip_tccontrol"], + } ), - ) ) mixtral_8x7b_dropless = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="mixtral_8x7b_dropless", - model_type="mixtral-8x7b", - tuning_params={ - "per_device_batch_size": 12, - "ici_fsdp_parallelism": -1, - "max_target_length": 4096, - "remat_policy": "full", - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 2048, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - "megablox": True, - "sparse_matmul": True, - }, - xla_flags=( - xla_flags_library.MOE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER - + xla_flags_library.DATA_PARALLEL_OVERLAP + trillium_model_dict, + MaxTextModel( + model_name="mixtral_8x7b_dropless", + model_type="mixtral-8x7b", + tuning_params={ + "per_device_batch_size": 12, + "ici_fsdp_parallelism": -1, + "max_target_length": 4096, + "remat_policy": "full", + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 2048, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 
2048, + "megablox": True, + "sparse_matmul": True, + }, + xla_flags=( + xla_flags_library.MOE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.DATA_PARALLEL_OVERLAP + ), ), - ) ) mixtral_8x7b_dropped = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="mixtral_8x7b_dropped", - model_type="mixtral-8x7b", - tuning_params={ - "per_device_batch_size": 12, - "ici_fsdp_parallelism": -1, - "max_target_length": 4096, - "remat_policy": "custom", - "decoder_layer_input": "offload", - "out_proj": "offload", - "query_proj": "offload", - "key_proj": "offload", - "value_proj": "offload", - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 2048, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - "megablox": False, - "sparse_matmul": False, - "capacity_factor": 1.25, - "tokenizer_path": "assets/tokenizer.mistral-v1", - }, - xla_flags=( - xla_flags_library.MOE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER - + xla_flags_library.DATA_PARALLEL_OVERLAP + trillium_model_dict, + MaxTextModel( + model_name="mixtral_8x7b_dropped", + model_type="mixtral-8x7b", + tuning_params={ + "per_device_batch_size": 12, + "ici_fsdp_parallelism": -1, + "max_target_length": 4096, + "remat_policy": "custom", + "decoder_layer_input": "offload", + "out_proj": "offload", + "query_proj": "offload", + "key_proj": "offload", + "value_proj": "offload", + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 2048, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + "megablox": False, + "sparse_matmul": False, + "capacity_factor": 1.25, + "tokenizer_path": "assets/tokenizer.mistral-v1", + }, + xla_flags=( + xla_flags_library.MOE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.DATA_PARALLEL_OVERLAP + ), ), - ) ) mixtral_8x7b_dropped_int8 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="mixtral_8x7b_dropped_int8", - model_type="mixtral-8x7b", - tuning_params={ - "per_device_batch_size": 8, - "ici_fsdp_parallelism": -1, - "max_target_length": 4096, - "remat_policy": "full", - "attention": "flash", - "gcs_metrics": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 2048, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - "megablox": False, - "sparse_matmul": False, - "capacity_factor": 1.25, - "quantization": "int8", - "tokenizer_path": "assets/tokenizer.mistral-v1", - }, - xla_flags=( - xla_flags_library.MOE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER - + xla_flags_library.DATA_PARALLEL_OVERLAP + trillium_model_dict, + MaxTextModel( + model_name="mixtral_8x7b_dropped_int8", + model_type="mixtral-8x7b", + tuning_params={ + "per_device_batch_size": 8, + "ici_fsdp_parallelism": -1, + "max_target_length": 4096, + "remat_policy": "full", + "attention": "flash", + "gcs_metrics": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 2048, + "sa_block_q_dkv": 2048, 
+ "sa_block_q_dq": 2048, + "megablox": False, + "sparse_matmul": False, + "capacity_factor": 1.25, + "quantization": "int8", + "tokenizer_path": "assets/tokenizer.mistral-v1", + }, + xla_flags=( + xla_flags_library.MOE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.DATA_PARALLEL_OVERLAP + ), ), - ) ) mixtral_8x22b_dropped = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="mixtral_8x22b_dropped", - model_type="mixtral-8x22b", - tuning_params={ - "per_device_batch_size": 8, - "max_target_length": 4096, - "ici_fsdp_parallelism": 64, - "ici_expert_parallelism": 4, - "remat_policy": "custom", - "decoder_layer_input": "offload", - "out_proj": "offload", - "query_proj": "offload", - "key_proj": "offload", - "value_proj": "offload", - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "sa_block_q": 2048, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - "megablox": False, - "sparse_matmul": False, - "capacity_factor": 1.25, - "tokenizer_path": "assets/tokenizer.mistral-v3", - "dtype": "bfloat16", - "weight_dtype": "bfloat16", - "allow_split_physical_axes": True, - "custom_mesh": "hybrid_ring_64x4", - }, - xla_flags=( - xla_flags_library.MOE_VMEM_LIMIT_FLAG - + xla_flags_library.CF_FOR_ALL_GATHER - + xla_flags_library.DATA_PARALLEL_OVERLAP + trillium_model_dict, + MaxTextModel( + model_name="mixtral_8x22b_dropped", + model_type="mixtral-8x22b", + tuning_params={ + "per_device_batch_size": 8, + "max_target_length": 4096, + "ici_fsdp_parallelism": 64, + "ici_expert_parallelism": 4, + "remat_policy": "custom", + "decoder_layer_input": "offload", + "out_proj": "offload", + "query_proj": "offload", + "key_proj": "offload", + "value_proj": "offload", + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "sa_block_q": 2048, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + "megablox": False, + "sparse_matmul": False, + "capacity_factor": 1.25, + "tokenizer_path": "assets/tokenizer.mistral-v3", + "dtype": "bfloat16", + "weight_dtype": "bfloat16", + "allow_split_physical_axes": True, + "custom_mesh": "hybrid_ring_64x4", + }, + xla_flags=( + xla_flags_library.MOE_VMEM_LIMIT_FLAG + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.DATA_PARALLEL_OVERLAP + ), ), - ) ) gemma2_9b_8192 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="gemma2-9b-8192", - model_type="gemma2-9b", - tuning_params={ - "per_device_batch_size": 3, - "ici_fsdp_transpose_parallelism": 256, - "remat_policy": "full", - "max_target_length": 8192, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "tokenizer_path": "assets/tokenizer.llama2", - "sa_block_q": 2048, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.CUSTOM_VMEM_LIMIT_FLAG(114688) - + xla_flags_library.REDUCE_SCATTER_FUSION - + xla_flags_library.CF_FOR_ALL_GATHER - + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER + trillium_model_dict, + MaxTextModel( + model_name="gemma2-9b-8192", + 
model_type="gemma2-9b", + tuning_params={ + "per_device_batch_size": 3, + "ici_fsdp_transpose_parallelism": 256, + "remat_policy": "full", + "max_target_length": 8192, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "tokenizer_path": "assets/tokenizer.llama2", + "sa_block_q": 2048, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=( + xla_flags_library.CUSTOM_VMEM_LIMIT_FLAG(114688) + + xla_flags_library.REDUCE_SCATTER_FUSION + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER + ), ), - ) ) gemma2_27b_8192 = _add_to_model_dictionary( - trillium_model_dict, - MaxTextModel( - model_name="gemma2-27b-8192", - model_type="gemma2-27b", - tuning_params={ - "per_device_batch_size": 2, - "ici_fsdp_transpose_parallelism": 256, - "remat_policy": "full", - "max_target_length": 8192, - "attention": "flash", - "gcs_metrics": True, - "use_iota_embed": True, - "dataset_path": "gs://max-datasets-rogue", - "dataset_type": "synthetic", - "reuse_example_batch": 1, - "enable_checkpointing": False, - "profiler": "xplane", - "tokenizer_path": "assets/tokenizer.llama2", - "sa_block_q": 2048, - "sa_block_q_dkv": 2048, - "sa_block_q_dq": 2048, - }, - xla_flags=( - xla_flags_library.CUSTOM_VMEM_LIMIT_FLAG(122880) - + xla_flags_library.REDUCE_SCATTER_FUSION - + xla_flags_library.CF_FOR_ALL_GATHER - + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER + trillium_model_dict, + MaxTextModel( + model_name="gemma2-27b-8192", + model_type="gemma2-27b", + tuning_params={ + "per_device_batch_size": 2, + "ici_fsdp_transpose_parallelism": 256, + "remat_policy": "full", + "max_target_length": 8192, + "attention": "flash", + "gcs_metrics": True, + "use_iota_embed": True, + "dataset_path": "gs://max-datasets-rogue", + "dataset_type": "synthetic", + "reuse_example_batch": 1, + "enable_checkpointing": False, + "profiler": "xplane", + "tokenizer_path": "assets/tokenizer.llama2", + "sa_block_q": 2048, + "sa_block_q_dkv": 2048, + "sa_block_q_dq": 2048, + }, + xla_flags=( + xla_flags_library.CUSTOM_VMEM_LIMIT_FLAG(122880) + + xla_flags_library.REDUCE_SCATTER_FUSION + + xla_flags_library.CF_FOR_ALL_GATHER + + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER + ), ), - ) ) diff --git a/benchmarks/maxtext_xpk_runner.py b/benchmarks/maxtext_xpk_runner.py index c22de9157..48217b6be 100644 --- a/benchmarks/maxtext_xpk_runner.py +++ b/benchmarks/maxtext_xpk_runner.py @@ -40,6 +40,15 @@ # Assumes you have xpk installed in a git clone repo of ~/{wl_config.xpk_path}/xpk.py _DEFAULT_MAXTEXT_BASE_DOCKER_IMAGE_NAME = 'maxtext_base_image' +# The minimum set of tuning params required for pathways. +BASE_PATHWAYS_TUNING_PARAMS = { + 'checkpoint_storage_use_ocdbt': False, + 'checkpoint_storage_use_zarr3': False, + 'enable_pathways_goodput': True, + 'enable_single_controller': True, +} + + class LibTpuType(enum.Enum): NIGHTLY = 'nightly-libtpu' # In order to use a custom libtpu, put a libtpu.so file in your local @@ -275,15 +284,70 @@ def run_command_with_updates(command, task, verbose=True) -> int: return 0 +def _get_config_tuning_params(wl_config: WorkloadConfig): + """Get config tuning parameters for the workload. + + Args: + wl_config: Workload configuration. + + Returns: + A string of config tuning parameters. 
+ """ + is_pw_enabled = wl_config.pathways_config is not None + + config_tuning_params = '' + unified_tuning_params = wl_config.model.tuning_params.copy() # Create a copy + + # Overwrite the tuning params with pathways specific tuning params if present. + # otherwise add them to the dictionary. If pathays tuning params are not + # present, add the default pathways tuning params. + if is_pw_enabled: + if wl_config.model.pathways_tuning_params is None: + print( + 'WARNING: Pathways tuning params are not present for model:' + f' {wl_config.model.model_name}, Adding the following base params to' + f' support pathways: {BASE_PATHWAYS_TUNING_PARAMS}' + ) + wl_config.model.pathways_tuning_params = BASE_PATHWAYS_TUNING_PARAMS + + # Automatically inject Base Pathways tuning params if not present. The user + # can override these values if they want, but if not present, we will add + # them to the dictionary. + for key, value in BASE_PATHWAYS_TUNING_PARAMS.items(): + if key not in wl_config.model.pathways_tuning_params: + wl_config.model.pathways_tuning_params[key] = value + + print( + f'WARNING: {key} is not present in pathways tuning' + f' params for model: {wl_config.model.model_name}, Adding the' + f' param {key}={value} to support pathways.' + ) + + print( + f'Pathways tuning params for model: {wl_config.model.model_name} are:' + f' {wl_config.model.pathways_tuning_params}' + ) + for key, value in wl_config.model.pathways_tuning_params.items(): + unified_tuning_params[key] = value + + print( + f'Unified tuning params for model are:' + f' {unified_tuning_params}' + ) + + for key, value in unified_tuning_params.items(): + config_tuning_params += f'{key}={value} ' + + return config_tuning_params + + def build_user_command( name: str, wl_config: WorkloadConfig, ): is_pw_enabled = wl_config.pathways_config is not None - config_tuning_params = '' - for key, value in wl_config.model.tuning_params.items(): - config_tuning_params += f'{key}={value} ' + config_tuning_params = _get_config_tuning_params(wl_config) install_libtpu_cmd = '' jax_platforms = None @@ -336,6 +400,41 @@ def build_user_command( return command +def _get_pathways_proxy_flags(wl_config: WorkloadConfig): + """Get the pathways proxy flags for the workload and removes any extras.""" + # Add in the xla flags alongside the proxy flags from the pathways config. + pw_config = wl_config.pathways_config + + # Get proxy and xla flag string from model config + proxy_flags_string = pw_config.proxy_flags + xla_flags_string = wl_config.model.xla_flags + + # Split both proxy_flags_string and xla_flags_string into lists of flags + proxy_flags_list = proxy_flags_string.strip().split() + xla_flags_list = xla_flags_string.strip().split() + + # Combine the two lists of flags into a single list + proxy_flags = proxy_flags_list + xla_flags_list + + # Remove the flags that are specified to be removed. 
+ if ( + wl_config.model.pathways_xla_flag_options + and model_configs.REMOVE in wl_config.model.pathways_xla_flag_options + ): + flags_to_remove = wl_config.model.pathways_xla_flag_options[ + model_configs.REMOVE + ] + updated_proxy_flags = [] + for flag in proxy_flags: + if flag not in flags_to_remove: + updated_proxy_flags.append(flag) + proxy_flags = updated_proxy_flags + + # Join the list of flags back into a single string, space-separated + return " ".join(proxy_flags) + + + def _get_pathways_specific_flags(wl_config: WorkloadConfig): pw_config = wl_config.pathways_config if pw_config is None: @@ -357,7 +456,7 @@ def _get_pathways_specific_flags(wl_config: WorkloadConfig): else '' ) - proxy_flags = pw_config.proxy_flags + wl_config.model.xla_flags + proxy_flags = _get_pathways_proxy_flags(wl_config) pathways_specific_flags = ( f' {server_image_flag} ' @@ -365,7 +464,6 @@ def _get_pathways_specific_flags(wl_config: WorkloadConfig): f' {remote_python_sidecar_image_flag} ' f' --termination-grace-period-seconds=300 ' f' --pathways-gcs-location={wl_config.base_output_directory} ' - f' --restart-on-user-code-failure' f' --custom-pathways-server-args="{pw_config.server_flags}" ' f' --custom-pathways-proxy-server-args="{proxy_flags}" ' f' --custom-pathways-worker-args="{pw_config.worker_flags}" ' diff --git a/benchmarks/recipes/pw_long_running_recipe.py b/benchmarks/recipes/pw_long_running_recipe.py new file mode 100644 index 000000000..7a0f24fe2 --- /dev/null +++ b/benchmarks/recipes/pw_long_running_recipe.py @@ -0,0 +1,132 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ """ + +import datetime +import sys +import os +import args_helper as helper + +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.append(parent_dir) + +import maxtext_trillium_model_configs as model_configs +import maxtext_xpk_runner as mxr + +PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/sanitized_proxy_server:latest" +SERVER_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/sanitized_server:latest" +RUNNER = "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/maxtext_jax_stable:latest" + +# Cluster Params +CLUSTER = "v6e-256-cluster" +PROJECT = "tpu-prod-env-cluster" +ZONE = "us-east5-b" +COUNTRY = "us" +DEVICE_TYPE = "v6e-256" + +# Other parameters (MUST BE SET BY USER) +XPK_PATH = "../xpk" # We're running this script from the maxtext directory +USER = os.environ["USER"] +BASE_OUTPUT_DIRECTORY = ( + f"gs://{USER}-{PROJECT}-{COUNTRY}/pw_mcjax_benchmarking/" +) + +BENCHMARK_STEPS=10_000_000 + + +def main() -> int: + # V6e cluster config + cluster_config = mxr.XpkClusterConfig( + cluster_name=CLUSTER, + project=PROJECT, + zone=ZONE, + device_type=DEVICE_TYPE, + ) + + # Handle command line arguments using args_helper + should_continue = helper.handle_cmd_args( + cluster_config, helper.DELETE, xpk_path=XPK_PATH + ) + + if not should_continue: + return 0 + + model_list = [ + # model_configs.llama3_1_70b_8192_pw_lr_real_data, + model_configs.llama3_1_8b_8192_pw, + ] + pathways_config = mxr.PathwaysConfig( + server_image=SERVER_IMAGE, + proxy_server_image=PROXY_IMAGE, + runner_image=RUNNER, + server_flags=( + "--temporary_flags_for_debugging=" + "temporary_flag_for_debugging_worker_expected_tpu_chip_config=" + "megachip_tccontrol --xla_tpu_use_enhanced_launch_barrier" + ), + proxy_flags="--xla_tpu_use_enhanced_launch_barrier", + worker_flags="--xla_tpu_use_enhanced_launch_barrier", + ) + num_slices_list = [ + 2 + ] + + xpk_workload_cmds = [] + xpk_workload_names = [] + + for model in model_list: + # Run workloads on the below clusters + for cluster_config in [ + cluster_config, + ]: + # Run workloads in the following slice configurations + for num_slices in num_slices_list: + wl_config = mxr.WorkloadConfig( + model=model, + num_slices=num_slices, + device_type=cluster_config.device_type, + base_output_directory=BASE_OUTPUT_DIRECTORY, + max_restarts=10_000, + libtpu_type=None, + libtpu_nightly_version="", + base_docker_image=None, + pathways_config=pathways_config, + xpk_path=XPK_PATH, + num_steps=BENCHMARK_STEPS, + ) + command, name = mxr.generate_xpk_workload_cmd( + cluster_config=cluster_config, wl_config=wl_config + ) + + print(f"Name of the workload is: {name} \n") + xpk_workload_names.append(name) + + print(f"XPK command to be used is: {command} \n") + xpk_workload_cmds.append(command) + + for xpk_workload_name, xpk_workload_cmd in zip( + xpk_workload_names, xpk_workload_cmds + ): + timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + print(f"[{timestamp}] Running workload: {xpk_workload_name} with command: {xpk_workload_cmd}") + return_code = mxr.run_command_with_updates( + xpk_workload_cmd, xpk_workload_name + ) + if return_code != 0: + print(f"Unable to run xpk workload: {xpk_workload_name}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/recipes/pw_mcjax_benchmark_recipe.py b/benchmarks/recipes/pw_mcjax_benchmark_recipe.py new file mode 100644 index 000000000..888e17fe0 --- /dev/null +++ b/benchmarks/recipes/pw_mcjax_benchmark_recipe.py @@ -0,0 +1,143 @@ +""" + Copyright 2025 Google LLC 
+ + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +import datetime +import sys +import os +import args_helper as helper + +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.append(parent_dir) + +import maxtext_trillium_model_configs as model_configs +import maxtext_xpk_runner as mxr + +PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/unsanitized_proxy_server:latest" +SERVER_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/unsanitized_server:latest" +RUNNER = "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/maxtext_jax_stable:latest" + +# Cluster Params +CLUSTER = "v6e-256-cluster" +PROJECT = "tpu-prod-env-cluster" +ZONE = "us-east5-b" +COUNTRY = "us" +DEVICE_TYPE = "v6e-256" + +# Other parameters (MUST BE SET BY USER) +XPK_PATH = "../xpk" # We're running this script from the maxtext directory +USER = os.environ["USER"] +BASE_OUTPUT_DIRECTORY = ( + f"gs://{USER}-{PROJECT}-{COUNTRY}/pw_mcjax_benchmarking/" +) + +BENCHMARK_STEPS = 20 + + +def main() -> int: + # V6e cluster config + cluster_config = mxr.XpkClusterConfig( + cluster_name=CLUSTER, + project=PROJECT, + zone=ZONE, + device_type=DEVICE_TYPE, + ) + + # Handle command line arguments using args_helper + should_continue = helper.handle_cmd_args( + cluster_config, helper.DELETE, xpk_path=XPK_PATH + ) + + if not should_continue: + return 0 + + models = { + "mcjax": [ + # model_configs.llama3_1_8b_8192, + # model_configs.llama3_1_70b_8192, + # model_configs.llama3_1_405b_8192_fsdp_dcn, + # model_configs.llama2_70b_4096_real_data_long_run, + ], + "pathways": [ + model_configs.llama3_1_8b_8192, + # model_configs.llama3_1_70b_8192, + # model_configs.llama3_1_405b_8192_fsdp_dcn, + # model_configs.llama2_70b_4096_real_data_long_run, + ] + } + pathways_config = mxr.PathwaysConfig( + server_image=SERVER_IMAGE, + proxy_server_image=PROXY_IMAGE, + runner_image=RUNNER, + server_flags=( + "--temporary_flags_for_debugging=" + "temporary_flag_for_debugging_worker_expected_tpu_chip_config=" + "megachip_tccontrol --xla_tpu_use_enhanced_launch_barrier" + ), + proxy_flags="--xla_tpu_use_enhanced_launch_barrier", + worker_flags="--xla_tpu_use_enhanced_launch_barrier", + ) + num_slices_list = [ + 2 + ] + + xpk_workload_cmds = [] + xpk_workload_names = [] + + for infra, model_list in models.items(): + for model in model_list: + # Run workloads on the below clusters + for cluster_config in [ + cluster_config, + ]: + # Run workloads in the following slice configurations + for num_slices in num_slices_list: + wl_config = mxr.WorkloadConfig( + model=model, + num_slices=num_slices, + device_type=cluster_config.device_type, + base_output_directory=BASE_OUTPUT_DIRECTORY, + max_restarts=0, + libtpu_type=None, + libtpu_nightly_version="", + base_docker_image=RUNNER if infra == "mcjax" else None, + pathways_config=pathways_config if infra == "pathways" else None, + xpk_path=XPK_PATH, + num_steps=BENCHMARK_STEPS, + ) + command, name = mxr.generate_xpk_workload_cmd( + cluster_config=cluster_config, 
wl_config=wl_config + ) + + print(f"Name of the workload is: {name} \n") + xpk_workload_names.append(name) + + print(f"XPK command to be used is: {command} \n") + xpk_workload_cmds.append(command) + + for xpk_workload_name, xpk_workload_cmd in zip( + xpk_workload_names, xpk_workload_cmds + ): + timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + print(f"[{timestamp}] Running workload: {xpk_workload_name} with command: {xpk_workload_cmd}") + return_code = mxr.run_command_with_updates( + xpk_workload_cmd, xpk_workload_name + ) + if return_code != 0: + print(f"Unable to run xpk workload: {xpk_workload_name}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/recipes/pw_remote_python_recipe.py b/benchmarks/recipes/pw_remote_python_recipe.py index c0a98acd2..7c2e7f2d8 100644 --- a/benchmarks/recipes/pw_remote_python_recipe.py +++ b/benchmarks/recipes/pw_remote_python_recipe.py @@ -59,7 +59,7 @@ def main() -> int: base_output_directory = f"gs://{user}-{region}/{user}" list_of_models = [ - model_configs.default_basic_1_pw, + model_configs.default_basic_1, ] pathways_config = mxr.PathwaysConfig( server_image=server_image, diff --git a/benchmarks/xla_flags_library.py b/benchmarks/xla_flags_library.py index d3ca4d32e..78b8b736d 100644 --- a/benchmarks/xla_flags_library.py +++ b/benchmarks/xla_flags_library.py @@ -74,13 +74,6 @@ " --2a886c8_chip_config_name=megachip_tccontrol" ) -ENABLE_SPARSECORE_OFFLOADING_BASE_FLAGS_PW = ( - " --xla_tpu_use_tc_device_shape_on_sc=true" - " --xla_sc_enable_instruction_fusion=false" - " --xla_sc_disjoint_spmem=false" - " --xla_sc_disable_megacore_partitioning=true" - # " --2a886c8_chip_config_name=megachip_tccontrol" # Flag has issues in PW. -) # Enable SparseCore All Gather (1D), Reduce Scatter (1D) and All Reduce (ND) ENABLE_SPARSECORE_OFFLOADING_FOR_RS_AG_AR = ( @@ -123,12 +116,6 @@ " --xla_tpu_enable_all_reduce_offload_tracing=true" ) + ENABLE_SPARSECORE_OFFLOADING_BASE_FLAGS -ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE_PW = ( - " --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=false" - " --xla_tpu_enable_sparse_core_collective_offload_all_reduce=true" - " --xla_tpu_enable_all_reduce_offload_tracing=true" -) + ENABLE_SPARSECORE_OFFLOADING_BASE_FLAGS_PW - # Better memory layout for all-reduce (AR). LAYOUT_FOR_ALL_REDUCE_SCATTER = ( " --xla_tpu_use_minor_sharding_for_major_trivial_input=true" From 5c44764a86caf9837416067e7cf3691f1c9f4864 Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Thu, 21 Nov 2024 21:09:09 +0000 Subject: [PATCH 02/10] Updated input pipeline to use values from elastic utils. Adding ElasticUtils to config Added elasticutils and gkeutils Added a watchdog/timebomb to each step Completely working test run. Further test runs and optimizations to follow Fixed a bug if DATA_LOSS occurs during a save added timeit to reshard_fn Updated the watchdog to repeatably stack trace every timeout intervals. 
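For context, the watchdog mentioned above is a context manager that wraps each training step and periodically logs the stack trace of every live thread. A minimal sketch of the intended usage, where the step body is a hypothetical placeholder:

    import elasticutils  # MaxText/elasticutils.py, added by this change

    def run_one_step():
      ...  # placeholder for the real training step body

    # Dump all thread stack traces every 120 seconds while the step is running.
    with elasticutils.watchdog(120):
      run_one_step()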
send a fatal log if the failures > max failures in slice_down() instead of in the training loop in order to fail correctly if there is a reshard/failure loop within the reshard handler --- MaxText/elasticutils.py | 282 +++++++++++++++++++++++++++++++++++ MaxText/max_utils.py | 2 +- MaxText/pyconfig.py | 18 +++ MaxText/train.py | 320 ++++++++++++++++++++++++++++------------ 4 files changed, 530 insertions(+), 92 deletions(-) create mode 100644 MaxText/elasticutils.py diff --git a/MaxText/elasticutils.py b/MaxText/elasticutils.py new file mode 100644 index 000000000..2f5e32d5b --- /dev/null +++ b/MaxText/elasticutils.py @@ -0,0 +1,282 @@ +"""Utilities for elastic training.""" + +import collections +import contextlib +import itertools +import functools +import logging +import os +import sys +import time +import threading +import traceback +from typing import Sequence, Any, Optional, Callable +import jax +import numpy as np + +PyTree = Any + +logger = logging.getLogger(__name__) + +logging.basicConfig(level=logging.INFO) +logger.setLevel(logging.INFO) + + +@contextlib.contextmanager +def timer(name: str): + start = time.time() + try: + yield + finally: + end = time.time() + logger.info("%s elaspsed %.2fs.", name, end - start) + +def timeit(func: Callable): + @functools.wraps(func) + def wrapper(*args, **kwargs): + with timer(func.__name__): + return func(*args, **kwargs) + return wrapper + + +class ElasticUtils: + """Utility class for elastic training.""" + + TEST_VALUE = 100 + + def __init__( + self, + devices: Sequence[jax.Device], + total_slice_count: int, + save_period: Optional[int] = None, + reshard_check_period: Optional[int] = None, + max_failures: Optional[int] = None, + ): + self.devices = devices + self.total_slice_count = total_slice_count + + if save_period is None: + save_period = 1 + self.save_period = save_period + + if reshard_check_period is None: + reshard_check_period = 1 + self.reshard_check_period = reshard_check_period + + if max_failures is None: + max_failures = float("inf") + self.max_failures = max_failures + + self.failure_count = 0 + self.good_slice_indices = self.get_slice_availability() + self.data = {} + + def slice_down(self): + """Slice down.""" + logger.info("Slice down") + self.good_slice_indices = self.get_slice_availability() + self.failure_count += 1 + + logger.info(f"Failure count: {self.failure_count} with max {self.max_failures}") + if self.failure_count >= self.max_failures: + logger.fatal(f"Max failures reached {self.max_failures}") + + @timeit + def save(self, save_step: int, **kwargs): + """Save step and state.""" + # In case DATA_LOSS occurs during jax.block_until_ready, overwrite self.data + # at the end + data = {k: jax.tree.map(lambda x: x.copy(), v) for k, v in kwargs.items()} + for v in data.values(): + jax.block_until_ready(v) + data["save_step"] = save_step + + self.data = data + + def is_ready_to_reshard(self, step: int): + """ + Indicates if it is time to reshard. + + May update `good_slice_indices`. 
+ """ + if step % self.reshard_check_period: + return False + if self.good_slice_count >= self.total_slice_count: + return False + + good_slice_indices = self.get_slice_availability() + + if len(good_slice_indices) <= self.good_slice_count: + return False + + logger.info("New slice available.") + logger.info(f"Previous good slice indices: {self.good_slice_indices}") + logger.info(f"Current good slice indices: {good_slice_indices}") + + if not good_slice_indices & self.good_slice_indices: + raise ValueError("All copies of the data have been lost") + + self.good_slice_indices = good_slice_indices + + return True + + @property + def devices(self) -> Sequence[jax.Device]: + """Returns the devices.""" + return self._devices + + @devices.setter + def devices(self, devices: Sequence[jax.Device]) -> None: + """Sets the devices.""" + self._devices = devices + + self.slice_to_devices = collections.defaultdict(list) + for d in self._devices: + self.slice_to_devices[d.slice_index].append(d) + self.slice_to_devices = dict(self.slice_to_devices) + + @property + def good_slice_to_devices(self) -> dict[int, Sequence[jax.Device]]: + """Returns the good slice to devices map.""" + return { + slice_index: self.slice_to_devices[slice_index] + for slice_index in self.good_slice_indices + } + + @property + def good_devices(self) -> Sequence[jax.Device]: + """Returns the good data slice indices.""" + return list( + itertools.chain.from_iterable(self.good_slice_to_devices.values()) + ) + + @property + def default_device(self) -> jax.Device: + """Returns the device that should be set to the default device.""" + return self.slice_to_devices[next(iter(self.good_slice_indices))][0] + + @property + def good_slice_count(self) -> int: + """Returns the number of slices.""" + return len(self.good_slice_indices) + + def slice_device_count(self, slice_index: int) -> int: + """Returns the number of devices in a slice.""" + return len(self.slice_to_devices[slice_index]) + + def _simple_execution( + self, devices: Sequence[jax.Device], block: bool = True + ) -> jax.Array: + """Simple execution to test if a slice is available.""" + x = np.zeros(len(devices), dtype=float) + (self.TEST_VALUE - 1) + y = jax.pmap(lambda x: x + 1, devices=devices)(x) + if block: + y.block_until_ready() + return y + + @timeit + def get_slice_availability(self) -> set[int]: + """Returns the set of good and bad slices.""" + good_slice_indices = set() + + results = { + slice_index: self._simple_execution(devices, block=False) + for slice_index, devices in self.slice_to_devices.items() + } + + for slice_index, x in results.items(): + logger.info(f"checking {slice_index=}") # pylint: disable=logging-fstring-interpolation + expected = ( + np.zeros(self.slice_device_count(slice_index), dtype=float) + + self.TEST_VALUE + ) + try: + with timer(f"checking {slice_index=}"): + if np.allclose(x, expected): + good_slice_indices.add(slice_index) + logger.info(f"{slice_index=} good") + else: + logger.error( # pylint: disable=logging-fstring-interpolation + f"Error with _simple_execution for {slice_index=}. " + "This should not happen." 
+ ) + except jax.errors.JaxRuntimeError as e: + if "DATA_LOSS" in str(e): + logger.info( # pylint: disable=logging-fstring-interpolation + f"Caught JaxRuntimeError DATA_LOSS exception for {slice_index=}" + ) + logger.info(f"{e}") + else: + logger.exception(f"Unknown JaxRuntimeError for {slice_index=}") # pylint: disable=logging-fstring-interpolation + logger.info(f"{slice_index=} bad") + + logger.info(f"{good_slice_indices=}") + + return good_slice_indices + + @staticmethod + @timeit + def reshard( + tree: PyTree, + mesh: jax.sharding.Mesh, + *, + donate: bool = True, + ) -> PyTree: + """Reshard a PyTree.""" + def func(leaf): + return jax.device_put( + leaf, + jax.sharding.NamedSharding(mesh, leaf.sharding.spec), + donate=donate, + ) + + return jax.tree.map(func, tree) + + def scale_by_good_slices(self, x: int | float) -> int | float: + """Scale x by the number of good slices.""" + if isinstance(x, int): + ret, remainder = divmod(x * self.good_slice_count, self.total_slice_count) + if remainder: + raise ValueError( + f"Cannot scale {x=} by good slices because it will result in a " + f"remainder of {remainder=}." + ) + return ret + elif isinstance(x, float): + return x * self.good_slice_count / self.total_slice_count + else: + raise ValueError(f"Unsupported type: {type(x)}") + + +@contextlib.contextmanager +def watchdog(timeout): + event = threading.Event() + + def handler(): + count = 0 + while not event.wait(timeout): + logger.info(f"Watchdog thread dump every {timeout=} seconds. {count=}") + try: + for thread in threading.enumerate(): + try: + logger.info(f"Thread: {thread.ident}") + logger.info("".join(traceback.format_stack(sys._current_frames().get(thread.ident, [])))) + except: + logger.info(f"Error print traceback for {thread.ident=}") + pass + finally: + # logger.fatal("Timeout from timebomb!") + # os.abort() + pass + + count += 1 + + logger.debug("Registering watchdog") + watchdog = threading.Thread(target=handler, name="watchdog") + watchdog.start() + try: + yield + finally: + event.set() + watchdog.join() + logger.debug("Degistering watchdog") diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index ae0e49563..879ae413f 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -634,7 +634,7 @@ def optimize_mesh_for_tpu_v6e(mesh, devices): def create_device_mesh(config, devices=None): """Creates a device mesh with each slice in its own data parallel group. 
If there is only one slice, uses two replicas""" if devices is None: - devices = jax.devices() + devices = config.eu.good_devices num_devices = len(devices) num_slices = 1 if config.inference_benchmark_test else config.num_slices num_devices_per_slice = num_devices // num_slices diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index 5c0d5c007..6c56116a6 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -22,6 +22,7 @@ import sys from typing import Any, Union +import elasticutils import jax from jax.experimental.compilation_cache import compilation_cache from layers.attentions import AttentionType @@ -473,6 +474,7 @@ def user_init(raw_keys): raw_keys["add_eos"] = False max_logging.log("Override add_bos and add_eos to False when dataset_type=c4_mlperf") + raw_keys["eu"] = elasticutils.ElasticUtils(jax.devices(), raw_keys["num_slices"], save_period=3, reshard_check_period=5, max_failures=10) # Write raw_keys to GCS before type conversions max_utils.write_config_raw_keys_for_gcs(raw_keys) @@ -869,6 +871,22 @@ class HyperParameters: def __init__(self, config): object.__setattr__(self, "_config", config) + @property + def global_batch_size_to_train_on(self): + return self.eu.scale_by_good_slices(_config.keys["global_batch_size_to_train_on"]) + + @property + def global_batch_size_to_load(self): + return self.eu.scale_by_good_slices(_config.keys["global_batch_size_to_load"]) + + @property + def micro_batch_size_to_train_on(self): + return self.eu.scale_by_good_slices(_config.keys["micro_batch_size_to_train_on"]) + + @property + def num_slices(self): + return self.eu.good_slice_count + def __getattr__(self, attr): try: # Attempt to perform the normal lookup diff --git a/MaxText/train.py b/MaxText/train.py index 76d061eef..43bb4c740 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -25,6 +25,7 @@ import sys import functools import time +import traceback import queue from typing import Sequence @@ -71,12 +72,26 @@ from ml_goodput_measurement import goodput from ml_goodput_measurement import monitoring +import elasticutils + # pylint: disable=too-many-positional-arguments Transformer = models.Transformer EPS = 1e-8 _DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE = 2 * 1024**3 +def find_leaf_devices(tree): + return jax.tree.map(lambda x: {d.slice_index for d in x.devices()}, tree) + +def find_leaf_bad(tree): + def func(x): + try: + jax.block_until_ready(x) + return True + except: + return False + return jax.tree.map(func, tree) + def validate_train_config(config): """Validates the configuration is set correctly for train.py""" @@ -767,6 +782,86 @@ def setup_train_loop(config): ) +@elasticutils.timeit +def reshard_fn(config: pyconfig.HyperParameters): + """Reshard function.""" + while True: + try: + clear_buffered_metrics() + + init_rng, _, checkpoint_manager, mesh, model, learning_rate_schedule, tx = ( + setup_mesh_and_model(config) + ) + + restore_step = config.eu.data["save_step"] + restore_state = config.eu.reshard(config.eu.data["state"], mesh, donate=False) + + config.eu.save(restore_step, state=restore_state) + + data_iterator, _ = create_data_iterator(config, mesh) + state, _, state_mesh_shardings, data_iterator = max_utils.setup_training_state( + model, + data_iterator, + tx, + config, + jax.random.fold_in(init_rng, restore_step), + mesh, + checkpoint_manager, + ) + + state = state.replace( + step=restore_state.step, + params=restore_state.params, + opt_state=restore_state.opt_state, + ) + + ( + functional_train, + in_shard_train, + out_shard_train, + static_argnums_train, + 
donate_argnums_train, + ) = maxtext_utils.get_functional_train_with_signature( + train_step, mesh, state_mesh_shardings, model, config + ) + + p_train_step = jax.jit( + functional_train, + in_shardings=in_shard_train, + out_shardings=out_shard_train, + static_argnums=static_argnums_train, + donate_argnums=donate_argnums_train, + ) + + example_batch = None + jax.block_until_ready(state) + break + except jax.errors.JaxRuntimeError as e: + if "DATA_LOSS" in str(e): + max_logging.log("Caught JaxRuntimeError DATA_LOSS exception during resharding!") + max_logging.log(traceback.format_exc()) + elif "INTERNAL" in str(e): + max_logging.log("Caught JaxRuntimeError INTERNAL exception during resharding!") + max_logging.log(traceback.format_exc()) + + else: + max_logging.log("Unknown JaxRuntimeError during resharding!") + raise + + config.eu.slice_down() + + return ( + restore_step, + state, + mesh, + checkpoint_manager, + data_iterator, + p_train_step, + example_batch, + learning_rate_schedule, + ) + + def train_loop(config, state=None): """Main Training loop. Args: @@ -877,100 +972,143 @@ def train_loop(config, state=None): performance_metric_queue = queue.Queue() gcp_workload_monitor.start_performance_reporting_thread(performance_metric_queue) - for step in np.arange(start_step, config.steps): - if step == first_profiling_step or prof.should_activate_periodic_profile(step): - optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else "" - prof.activate(blocking_object=state, optional_postfix=optional_postfix) - - with jax.profiler.StepTraceAnnotation("train", step_num=step): - record_goodput(recorder, config, recorder.record_data_loading_start_time if recorder else None) - example_batch = load_next_batch(data_iterator, example_batch, config) - record_goodput(recorder, config, recorder.record_data_loading_end_time if recorder else None) - check_example_batch(config, example_batch=example_batch) - # pylint: disable=not-callable - nextrng = jax.jit(jax.random.fold_in)(init_rng, step) - record_goodput(recorder, config, recorder.record_step_start_time if recorder else None, step) - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - state, metrics = p_train_step(state, example_batch, nextrng) - - step_time_delta = datetime.datetime.now() - last_step_completion - last_step_completion = datetime.datetime.now() - record_scalar_metrics(metrics, step_time_delta, per_device_tflops, learning_rate_schedule(step), per_device_tokens) - if performance_metric_queue: - performance_metric_queue.put(step_time_delta.total_seconds()) - - if checkpoint_manager is not None: - state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] - if save_checkpoint(checkpoint_manager, int(step), state_to_save, config.dataset_type, data_iterator, config): - checkpointing.print_save_message(step, config.async_checkpointing) - - # Upon preemption, exit when and only when all ongoing saves are complete. - if checkpoint_manager.reached_preemption(step): - checkpoint_manager.wait_until_finished() - sys.exit() - - write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config) - - if config.dump_hlo and step == start_step: - jax.block_until_ready(state) # Ensure compilation has finished. 
- max_utils.upload_dump( - config.dump_hlo_local_dir, - config.dump_hlo_gcs_dir, - module_name=config.dump_hlo_module_name, - delete_local_after=config.dump_hlo_delete_local_after, - all_host_upload=config.dump_hlo_upload_all, - ) + step = start_step - if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0: - assert eval_data_iterator - cumulative_eval_metrics = { - "scalar": { - "eval/total_loss": 0.0, - "eval/total_weights": 0.0, - "eval/avg_loss": 0.0, - "eval/moe_lb_loss": 0.0, - } - } - eval_dpo_reward_accuracy = 0.0 - eval_step_count = 0 - # pylint: disable=not-callable - for eval_batch in eval_data_iterator: - if config.eval_steps > 0 and eval_step_count >= config.eval_steps: - break - with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - eval_metrics = p_eval_step(state, eval_batch, nextrng) - cumulative_eval_metrics["scalar"]["eval/total_loss"] += float(eval_metrics["scalar"]["evaluation/total_loss"]) - cumulative_eval_metrics["scalar"]["eval/total_weights"] += float(eval_metrics["scalar"]["evaluation/total_weights"]) - cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] += float(eval_metrics["scalar"]["evaluation/moe_lb_loss"]) - eval_dpo_reward_accuracy += float(eval_metrics["scalar"].get("evaluation/dpo_reward_accuracy", 0.0)) # for dpo only - max_logging.log(f"Completed eval step {eval_step_count}") - eval_step_count += 1 - eval_loss = cumulative_eval_metrics["scalar"]["eval/total_loss"] / ( - cumulative_eval_metrics["scalar"]["eval/total_weights"] + EPS - ) - cumulative_eval_metrics["scalar"]["eval/avg_loss"] = eval_loss - cumulative_eval_metrics["scalar"]["eval/avg_moe_lb_loss"] = ( - cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] / eval_step_count - ) - if config.use_dpo: - cumulative_eval_metrics["scalar"]["eval/dpo_reward_accuracy"] = eval_dpo_reward_accuracy / eval_step_count - write_metrics( - writer, local_metrics_file, running_gcs_metrics, cumulative_eval_metrics, step, config, is_training=False - ) - max_logging.log( - f"average loss after {step=}: {eval_step_count=}, {eval_loss=}," - f" total_weights={cumulative_eval_metrics['scalar']['eval/total_weights']}" - ) - if eval_loss <= config.target_eval_loss: - max_logging.log(f"Early stop and exit loop after reaching {config.target_eval_loss=}") - prof.deactivate() - break + while True: + with elasticutils.watchdog(120): + if step == first_profiling_step or prof.should_activate_periodic_profile(step): + optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else "" + prof.activate(blocking_object=state, optional_postfix=optional_postfix) - if step == last_profiling_step or prof.should_deactivate_periodic_profile(step): - prof.deactivate(blocking_object=state) + if step >= config.steps: + break + max_logging.log(f"{step=} {config.eu.failure_count=} {config.eu.good_slice_count=}") + try: + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules), jax.default_device(config.eu.default_device): + with jax.profiler.StepTraceAnnotation("train", step_num=step): + record_goodput(recorder, config, recorder.record_data_loading_start_time if recorder else None) + example_batch = load_next_batch(data_iterator, example_batch, config) + record_goodput(recorder, config, recorder.record_data_loading_end_time if recorder else None) + check_example_batch(config, example_batch=example_batch) + # pylint: disable=not-callable + nextrng = jax.jit(jax.random.fold_in)(init_rng, step) + record_goodput(recorder, config, recorder.record_step_start_time if 
recorder else None, step) + state, metrics = p_train_step(state, example_batch, nextrng) + + step_time_delta = datetime.datetime.now() - last_step_completion + last_step_completion = datetime.datetime.now() + record_scalar_metrics(metrics, step_time_delta, per_device_tflops, learning_rate_schedule(step), per_device_tokens) + if performance_metric_queue: + performance_metric_queue.put(step_time_delta.total_seconds()) + + if checkpoint_manager is not None: + state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] + if save_checkpoint(checkpoint_manager, int(step), state_to_save, config.dataset_type, data_iterator, config): + checkpointing.print_save_message(step, config.async_checkpointing) + + # Upon preemption, exit when and only when all ongoing saves are complete. + if checkpoint_manager.reached_preemption(step): + checkpoint_manager.wait_until_finished() + sys.exit() + + write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config) + + if config.dump_hlo and step == start_step: + jax.block_until_ready(state) # Ensure compilation has finished. + max_utils.upload_dump( + config.dump_hlo_local_dir, + config.dump_hlo_gcs_dir, + module_name=config.dump_hlo_module_name, + delete_local_after=config.dump_hlo_delete_local_after, + all_host_upload=config.dump_hlo_upload_all, + ) - if step == start_step: - max_utils.print_mem_stats("After params initialized") + if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0: + assert eval_data_iterator + cumulative_eval_metrics = { + "scalar": { + "eval/total_loss": 0.0, + "eval/total_weights": 0.0, + "eval/avg_loss": 0.0, + "eval/moe_lb_loss": 0.0, + } + } + eval_dpo_reward_accuracy = 0.0 + eval_step_count = 0 + # pylint: disable=not-callable + for eval_batch in eval_data_iterator: + if config.eval_steps > 0 and eval_step_count >= config.eval_steps: + break + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + eval_metrics = p_eval_step(state, eval_batch, nextrng) + cumulative_eval_metrics["scalar"]["eval/total_loss"] += float(eval_metrics["scalar"]["evaluation/total_loss"]) + cumulative_eval_metrics["scalar"]["eval/total_weights"] += float(eval_metrics["scalar"]["evaluation/total_weights"]) + cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] += float(eval_metrics["scalar"]["evaluation/moe_lb_loss"]) + eval_dpo_reward_accuracy += float(eval_metrics["scalar"].get("evaluation/dpo_reward_accuracy", 0.0)) # for dpo only + max_logging.log(f"Completed eval step {eval_step_count}") + eval_step_count += 1 + eval_loss = cumulative_eval_metrics["scalar"]["eval/total_loss"] / ( + cumulative_eval_metrics["scalar"]["eval/total_weights"] + EPS + ) + cumulative_eval_metrics["scalar"]["eval/avg_loss"] = eval_loss + cumulative_eval_metrics["scalar"]["eval/avg_moe_lb_loss"] = ( + cumulative_eval_metrics["scalar"]["eval/moe_lb_loss"] / eval_step_count + ) + if config.use_dpo: + cumulative_eval_metrics["scalar"]["eval/dpo_reward_accuracy"] = eval_dpo_reward_accuracy / eval_step_count + write_metrics( + writer, local_metrics_file, running_gcs_metrics, cumulative_eval_metrics, step, config, is_training=False + ) + max_logging.log( + f"average loss after {step=}: {eval_step_count=}, {eval_loss=}," + f" total_weights={cumulative_eval_metrics['scalar']['eval/total_weights']}" + ) + if eval_loss <= config.target_eval_loss: + max_logging.log(f"Early stop and exit loop after reaching {config.target_eval_loss=}") + prof.deactivate() + break + + if step == last_profiling_step or 
prof.should_deactivate_periodic_profile(step): + prof.deactivate(blocking_object=state) + + reshard_flag = config.eu.is_ready_to_reshard(step) + if reshard_flag or step % config.eu.save_period == 0: + config.eu.save(step, state=state) + + if step == start_step: + max_utils.print_mem_stats("After params initialized") + + step += 1 + + except jax.errors.JaxRuntimeError as e: + if "DATA_LOSS" in str(e): + max_logging.log("Caught JaxRuntimeError DATA_LOSS exception") + max_logging.log(traceback.format_exc()) + elif "INTERNAL" in str(e): + max_logging.log("Caught JaxRuntimeError INTERNAL exception") + max_logging.log(traceback.format_exc()) + + else: + max_logging.log("Unknown JaxRuntimeError") + raise + + config.eu.slice_down() + reshard_flag = True + + if reshard_flag: + (step, + state, + mesh, + checkpoint_manager, + data_iterator, + p_train_step, + example_batch, + learning_rate_schedule,) = reshard_fn(config) + max_logging.log("Resharding complete. Continuing") + reshard_flag = False + + if step == start_step: + max_utils.print_mem_stats("After params initialized") if checkpoint_manager is not None: checkpoint_manager.wait_until_finished() From 413484e814c01fedfc7cd21521cc188b3ad60270 Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Wed, 5 Feb 2025 21:59:08 +0000 Subject: [PATCH 03/10] Updated elasticutils and added a fake elasticutils --- MaxText/elasticutils.py | 283 ++++++++++++++++++++++++++++++----- MaxText/elasticutils_fake.py | 128 ++++++++++++++++ MaxText/train.py | 4 + 3 files changed, 376 insertions(+), 39 deletions(-) create mode 100644 MaxText/elasticutils_fake.py diff --git a/MaxText/elasticutils.py b/MaxText/elasticutils.py index 2f5e32d5b..8d0b6acbf 100644 --- a/MaxText/elasticutils.py +++ b/MaxText/elasticutils.py @@ -2,39 +2,72 @@ import collections import contextlib -import itertools import functools +import itertools import logging -import os import sys -import time import threading +import time import traceback -from typing import Sequence, Any, Optional, Callable +from typing import Any, Callable, Optional, Sequence + import jax import numpy as np +jax._src.array.ArrayImpl._check_if_deleted = lambda _: False # pylint: disable=protected-access + PyTree = Any logger = logging.getLogger(__name__) -logging.basicConfig(level=logging.INFO) logger.setLevel(logging.INFO) +# pylint: disable=logging-fstring-interpolation + + +class Profile: + """Profile context manager.""" + + def __init__(self, gcs_path: Optional[str] = None): + self.gcs_path = gcs_path + + def __enter__(self): + if self.gcs_path: + jax.profiler.start_trace(self.gcs_path) + + def __exit__(self, exc_type, exc_value, tb): + if self.gcs_path: + jax.profiler.stop_trace() -@contextlib.contextmanager -def timer(name: str): - start = time.time() - try: - yield - finally: - end = time.time() - logger.info("%s elaspsed %.2fs.", name, end - start) -def timeit(func: Callable): +class Timer: + """Timer context manager.""" + + def __init__(self, name): + self.name = name + + def __enter__(self): + self.start = time.time() + return self + + def __exit__(self, exc_type, exc_value, tb): + self.stop = time.time() + self.time = self.stop - self.start + logger.info(str(self)) + + def __str__(self): + return f"{self.name} elaspsed {self.time}." 
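A short usage sketch of the Timer context manager above and the timeit decorator defined just below; the timed work here is illustrative only:

    import time

    # Timer as a context manager: logs the elapsed seconds for the block.
    with Timer("sleep"):
      time.sleep(0.1)

    # timeit as a decorator: wraps each call in a Timer named after the function.
    @timeit
    def busy_work():
      return sum(range(1_000_000))

    busy_work()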
+ + +def timeit( + func: Callable[..., Any], name: Optional[str] = None +) -> Callable[..., Any]: + if name is None: + name = getattr(func, "__name__", "Unknown") + @functools.wraps(func) def wrapper(*args, **kwargs): - with timer(func.__name__): + with Timer(name): return func(*args, **kwargs) return wrapper @@ -77,7 +110,9 @@ def slice_down(self): self.good_slice_indices = self.get_slice_availability() self.failure_count += 1 - logger.info(f"Failure count: {self.failure_count} with max {self.max_failures}") + logger.info( + f"Failure count: {self.failure_count} with max {self.max_failures}" + ) if self.failure_count >= self.max_failures: logger.fatal(f"Max failures reached {self.max_failures}") @@ -94,10 +129,15 @@ def save(self, save_step: int, **kwargs): self.data = data def is_ready_to_reshard(self, step: int): - """ - Indicates if it is time to reshard. + """Indicates if it is time to reshard. May update `good_slice_indices`. + + Args: + step: The current step. + + Returns: + True if it is time to reshard, False otherwise. """ if step % self.reshard_check_period: return False @@ -191,7 +231,7 @@ def get_slice_availability(self) -> set[int]: + self.TEST_VALUE ) try: - with timer(f"checking {slice_index=}"): + with Timer(f"checking {slice_index=}"): if np.allclose(x, expected): good_slice_indices.add(slice_index) logger.info(f"{slice_index=} good") @@ -214,23 +254,170 @@ def get_slice_availability(self) -> set[int]: return good_slice_indices - @staticmethod + @classmethod @timeit def reshard( - tree: PyTree, - mesh: jax.sharding.Mesh, + cls, + x: Any, + sharding: jax.sharding.Sharding | Any, *, - donate: bool = True, - ) -> PyTree: - """Reshard a PyTree.""" - def func(leaf): - return jax.device_put( - leaf, - jax.sharding.NamedSharding(mesh, leaf.sharding.spec), - donate=donate, - ) + donate_input: bool = True, + put_array: Optional[ + Callable[ + [jax.Array, Sequence[jax.sharding.Sharding], bool], jax.Array + ] + ] = None, + ) -> Any: + """Reshards `x` to the specified `sharding`. + + Args: + x: An array, scalar, or a nested Python container thereof. + sharding: A `Sharding` or a nested `Sharding` in a Python container + (must match the structure of `x`), specifying the target sharding. + donate_input: If `True`, donates the input arrays to reduce memory + needed for resharding. Donated buffers should not be reused. + put_array: A function that takes an array, a sharding, and a boolean + indicating whether to donate the input, and returns a copy of the + array with the specified sharding. + + Returns: + A copy of `x` with the specified `sharding`. 
+ """ + if put_array is None: + put_array = cls.default_put_array + + flat_x, tree_def = jax.tree_util.tree_flatten(x) + flat_sharding = jax.api_util.flatten_axes( + "reshard sharding", tree_def, sharding + ) + + if len(flat_x) != len(flat_sharding): + raise ValueError("Mismatched length between `x` and `sharding`.") + + arrays = [ + put_array(arr, dst_sharding, donate_input) + for arr, dst_sharding in zip(flat_x, flat_sharding) + ] + return jax.tree_util.tree_unflatten(tree_def, arrays) + + @staticmethod + def put_array_device_put0( + arr: jax.Array, + dst_sharding: jax.sharding.Sharding, + donate_input: bool, + ): + if not isinstance(dst_sharding, jax.sharding.Sharding): + raise ValueError("`sharding` must contain only `Sharding` instances.") + return jax.device_put(arr, dst_sharding, donate=donate_input) + + default_put_array = put_array_device_put0 + + def put_array_device_put1( + self, + arr: jax.Array, + dst_sharding: jax.sharding.Sharding, + donate_input: bool, # pylint: disable=unused-argument + ): + """Reshards `arr` to the specified `dst_sharding`. + + Args: + arr: An array, scalar, or a nested Python container thereof. + dst_sharding: A `Sharding` or a nested `Sharding` in a Python container + (must match the structure of `x`), specifying the target sharding. + donate_input: If `True`, donates the input arrays to reduce memory + needed for resharding. Donated buffers should not be reused. + + Returns: + A copy of `x` with the specified `sharding`. + """ + if not isinstance(dst_sharding, jax.sharding.Sharding): + raise ValueError("`sharding` must contain only `Sharding` instances.") + + if dst_sharding.num_devices <= arr.sharding.num_devices: + # Reshard down + arrays = [ + x.data + for x in arr.addressable_shards + if x.device.slice_index in self.good_slice_indices + ] + else: + # Reshard up + arrays = [x.data for x in arr.addressable_shards] + + good_reference_slice = arr.addressable_shards[0].device.slice_index + good_reference_arrays = [ + array + for array in arrays + if array.device.slice_index == good_reference_slice + ] + + new_slice_index = ( + self.good_slice_indices + - {d.slice_index for d in arr.sharding.device_set} + ).pop() + + for device, array in zip( + self.slice_to_devices[new_slice_index], good_reference_arrays + ): + arrays.append(jax.device_put(array, device)) + + return jax.make_array_from_single_device_arrays( + arr.shape, dst_sharding, arrays + ) - return jax.tree.map(func, tree) + def put_array_device_put2( + self, + arr: jax.Array, + dst_sharding: jax.sharding.Sharding, + donate_input: bool, # pylint: disable=unused-argument + ): + """Reshards `arr` to the specified `dst_sharding`. + + Args: + arr: An array, scalar, or a nested Python container thereof. + dst_sharding: A `Sharding` or a nested `Sharding` in a Python container + (must match the structure of `x`), specifying the target sharding. + donate_input: If `True`, donates the input arrays to reduce memory + needed for resharding. Donated buffers should not be reused. + + Returns: + A copy of `x` with the specified `sharding`. 
+ """ + if not isinstance(dst_sharding, jax.sharding.Sharding): + raise ValueError("`sharding` must contain only `Sharding` instances.") + + if dst_sharding.num_devices <= arr.sharding.num_devices: + # Reshard down + arrays = [ + x.data + for x in arr.addressable_shards + if x.device.slice_index in self.good_slice_indices + ] + else: + # Reshard up + arrays = [x.data for x in arr.addressable_shards] + + good_reference_slice = arr.addressable_shards[0].device.slice_index + good_reference_arrays = [ + array + for array in arrays + if array.device.slice_index == good_reference_slice + ] + + new_slice_index = ( + self.good_slice_indices + - {d.slice_index for d in arr.sharding.device_set} + ).pop() + + new_arrays = jax.device_put( + good_reference_arrays, self.slice_to_devices[new_slice_index] + ) + + arrays += new_arrays + + return jax.make_array_from_single_device_arrays( + arr.shape, dst_sharding, arrays + ) def scale_by_good_slices(self, x: int | float) -> int | float: """Scale x by the number of good slices.""" @@ -249,7 +436,17 @@ def scale_by_good_slices(self, x: int | float) -> int | float: @contextlib.contextmanager -def watchdog(timeout): +def watchdog(timeout: float): + """Watchdog context manager. + + Prints the stack trace of all threads every `timeout` seconds. + + Args: + timeout: The timeout in seconds. + + Yields: + None + """ event = threading.Event() def handler(): @@ -260,8 +457,15 @@ def handler(): for thread in threading.enumerate(): try: logger.info(f"Thread: {thread.ident}") - logger.info("".join(traceback.format_stack(sys._current_frames().get(thread.ident, [])))) - except: + logger.info( + "".join( + traceback.format_stack( + sys._current_frames() # pylint: disable=protected-access + .get(thread.ident, []) + ) + ) + ) + except Exception: # pylint: disable=broad-exception-caught logger.info(f"Error print traceback for {thread.ident=}") pass finally: @@ -272,11 +476,12 @@ def handler(): count += 1 logger.debug("Registering watchdog") - watchdog = threading.Thread(target=handler, name="watchdog") - watchdog.start() + watchdog_thread = threading.Thread(target=handler, name="watchdog") + watchdog_thread.start() try: yield finally: event.set() - watchdog.join() - logger.debug("Degistering watchdog") + watchdog_thread.join() + logger.debug("Deregistering watchdog") + diff --git a/MaxText/elasticutils_fake.py b/MaxText/elasticutils_fake.py new file mode 100644 index 000000000..5e9586220 --- /dev/null +++ b/MaxText/elasticutils_fake.py @@ -0,0 +1,128 @@ +"""Utilities for elastic training.""" +import logging +from typing import Any, Optional, Sequence + +from elasticutils import ElasticUtils +from elasticutils import timeit +import jax + + +PyTree = Any + +logger = logging.getLogger(__name__) + +logger.setLevel(logging.INFO) + +# pylint: disable=logging-fstring-interpolation + + +class FakeElasticUtils(ElasticUtils): + """Utility class for elastic training. + + This class will simulate slices going down and coming back up. 
+ """ + + def __init__( + self, + devices: Sequence[jax.Device], + total_slice_count: int, + save_period: Optional[int] = None, + reshard_check_period: Optional[int] = None, + max_failures: Optional[int] = None, + ): + self.fake_good_slice_indices = set(d.slice_index for d in devices) + + super().__init__( + devices, + total_slice_count, + save_period, + reshard_check_period, + max_failures, + ) + + def update_good_slice_indices(self, good_slice_indices: set[int]): + """Start step handler.""" + self.fake_good_slice_indices = good_slice_indices + self.good_slice_indices = self.get_slice_availability() + + @timeit + def get_slice_availability(self) -> set[int]: + """Returns the set of good and bad slices.""" + good_slice_indices = self.fake_good_slice_indices + + logger.info(f"{good_slice_indices=}") + + return good_slice_indices + + # Does not work + @staticmethod + def put_array_jit( + arr: jax.Array, + dst_sharding: jax.sharding.Sharding, + donate_input: bool, + ): + if not isinstance(dst_sharding, jax.sharding.Sharding): + raise ValueError("`sharding` must contain only `Sharding` instances.") + + return jax.jit( + lambda x: x, + out_shardings=dst_sharding, + donate_argnums=(0,) if donate_input else (), + )(arr) + + # Slower than actually resharding + @staticmethod + def put_array_fake0( + arr: jax.Array, + dst_sharding: jax.sharding.Sharding, + donate_input: bool, + ): + if not isinstance(dst_sharding, jax.sharding.Sharding): + raise ValueError("`sharding` must contain only `Sharding` instances.") + return jax.numpy.ones_like(arr, device=dst_sharding) + + def put_array_fake1( + self, + arr: jax.Array, + dst_sharding: jax.sharding.Sharding, + donate_input: bool, # pylint: disable=unused-argument + ): + """Reshards `arr` to the specified `dst_sharding`. + + Args: + arr: An array, scalar, or a nested Python container thereof. + dst_sharding: A `Sharding` or a nested `Sharding` in a Python container + (must match the structure of `x`), specifying the target sharding. + donate_input: If `True`, donates the input arrays to reduce memory + needed for resharding. Donated buffers should not be reused. + + Returns: + A copy of `x` with the specified `sharding`. 
+ """ + if not isinstance(dst_sharding, jax.sharding.Sharding): + raise ValueError("`sharding` must contain only `Sharding` instances.") + + if dst_sharding.num_devices <= arr.sharding.num_devices: + # Reshard down + arrays = [ + x.data + for x in arr.addressable_shards + if x.device.slice_index in self.good_slice_indices + ] + else: + # Reshard up + arrays = [x.data for x in arr.addressable_shards] + + new_slice_index = ( + self.good_slice_indices + - {d.slice_index for d in arr.sharding.device_set} + ).pop() + + new_arrays = [jax.numpy.zeros_like(arr, device=device) + for device in self.slice_to_devices[new_slice_index]] + + arrays += new_arrays + + return jax.make_array_from_single_device_arrays( + arr.shape, dst_sharding, arrays + ) diff --git a/MaxText/train.py b/MaxText/train.py index 43bb4c740..2a6aa0502 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -794,6 +794,10 @@ def reshard_fn(config: pyconfig.HyperParameters): ) restore_step = config.eu.data["save_step"] + sharding = jax.sharding.NamedSharding( + mesh, + config.eu.data["state"].sharding.spec, + ) restore_state = config.eu.reshard(config.eu.data["state"], mesh, donate=False) config.eu.save(restore_step, state=restore_state) From b9982319cff8cfb27afcd1d79dd12e20633b29c7 Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Thu, 6 Feb 2025 23:23:27 +0000 Subject: [PATCH 04/10] Working fake elastic utils --- MaxText/elasticutils_fake.py | 2 +- MaxText/pyconfig.py | 4 ++-- MaxText/train.py | 39 ++++++++++++++++++++++++++---------- 3 files changed, 31 insertions(+), 14 deletions(-) diff --git a/MaxText/elasticutils_fake.py b/MaxText/elasticutils_fake.py index 5e9586220..72ebf78e5 100644 --- a/MaxText/elasticutils_fake.py +++ b/MaxText/elasticutils_fake.py @@ -43,7 +43,7 @@ def __init__( def update_good_slice_indices(self, good_slice_indices: set[int]): """Start step handler.""" self.fake_good_slice_indices = good_slice_indices - self.good_slice_indices = self.get_slice_availability() + logger.info(f"Updated: {self.fake_good_slice_indices=}") @timeit def get_slice_availability(self) -> set[int]: diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index 6c56116a6..b652b07e4 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -22,7 +22,7 @@ import sys from typing import Any, Union -import elasticutils +from elasticutils_fake import FakeElasticUtils as ElasticUtils import jax from jax.experimental.compilation_cache import compilation_cache from layers.attentions import AttentionType @@ -474,7 +474,7 @@ def user_init(raw_keys): raw_keys["add_eos"] = False max_logging.log("Override add_bos and add_eos to False when dataset_type=c4_mlperf") - raw_keys["eu"] = elasticutils.ElasticUtils(jax.devices(), raw_keys["num_slices"], save_period=3, reshard_check_period=5, max_failures=10) + raw_keys["eu"] = ElasticUtils(jax.devices(), raw_keys["num_slices"], save_period=3, reshard_check_period=5, max_failures=10) # Write raw_keys to GCS before type conversions max_utils.write_config_raw_keys_for_gcs(raw_keys) diff --git a/MaxText/train.py b/MaxText/train.py index 2a6aa0502..8690bec8f 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -794,11 +794,14 @@ def reshard_fn(config: pyconfig.HyperParameters): ) restore_step = config.eu.data["save_step"] - sharding = jax.sharding.NamedSharding( - mesh, - config.eu.data["state"].sharding.spec, + sharding = jax.tree.map( + lambda x: jax.sharding.NamedSharding(mesh, x.sharding.spec), + config.eu.data["state"], + ) + restore_state = config.eu.reshard( + config.eu.data["state"], + sharding, 
) - restore_state = config.eu.reshard(config.eu.data["state"], mesh, donate=False) config.eu.save(restore_step, state=restore_state) @@ -978,16 +981,30 @@ def train_loop(config, state=None): step = start_step + step_down = {10, 30, 44} + step_up = {14, 16, 40, 45} while True: with elasticutils.watchdog(120): - if step == first_profiling_step or prof.should_activate_periodic_profile(step): - optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else "" - prof.activate(blocking_object=state, optional_postfix=optional_postfix) - - if step >= config.steps: - break - max_logging.log(f"{step=} {config.eu.failure_count=} {config.eu.good_slice_count=}") try: + if step in step_down: + step_down.remove(step) + # Remove a slice + config.eu.update_good_slice_indices(set(range(config.eu.total_slice_count)) - {step % config.eu.total_slice_count}) + raise jax.errors.JaxRuntimeError("DATA_LOSS: Fake") + elif step in step_up: + step_up.remove(step) + + config.eu.update_good_slice_indices(set(range(config.eu.total_slice_count))) + + + if step == first_profiling_step or prof.should_activate_periodic_profile(step): + optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else "" + prof.activate(blocking_object=state, optional_postfix=optional_postfix) + + if step >= config.steps: + break + + max_logging.log(f"{step=} {config.eu.failure_count=} {config.eu.good_slice_count=}") with mesh, nn_partitioning.axis_rules(config.logical_axis_rules), jax.default_device(config.eu.default_device): with jax.profiler.StepTraceAnnotation("train", step_num=step): record_goodput(recorder, config, recorder.record_data_loading_start_time if recorder else None) From ebffd1ba205a7ab476790579f7198be6670bbfeb Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Fri, 7 Feb 2025 00:20:07 +0000 Subject: [PATCH 05/10] Working host memory offloading. 
Unverified --- MaxText/elasticutils.py | 8 +++++++- MaxText/train.py | 36 +++++++++++++++++++++--------------- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/MaxText/elasticutils.py b/MaxText/elasticutils.py index 8d0b6acbf..a9d989961 100644 --- a/MaxText/elasticutils.py +++ b/MaxText/elasticutils.py @@ -121,7 +121,13 @@ def save(self, save_step: int, **kwargs): """Save step and state.""" # In case DATA_LOSS occurs during jax.block_until_ready, overwrite self.data # at the end - data = {k: jax.tree.map(lambda x: x.copy(), v) for k, v in kwargs.items()} + data = { + k: jax.device_put( + v, + jax.tree.map(lambda x: x.sharding.with_memory_kind(kind="pinned_host"), v), + ) + for k, v in kwargs.items() + } for v in data.values(): jax.block_until_ready(v) data["save_step"] = save_step diff --git a/MaxText/train.py b/MaxText/train.py index 8690bec8f..06a5721e7 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -794,16 +794,6 @@ def reshard_fn(config: pyconfig.HyperParameters): ) restore_step = config.eu.data["save_step"] - sharding = jax.tree.map( - lambda x: jax.sharding.NamedSharding(mesh, x.sharding.spec), - config.eu.data["state"], - ) - restore_state = config.eu.reshard( - config.eu.data["state"], - sharding, - ) - - config.eu.save(restore_step, state=restore_state) data_iterator, _ = create_data_iterator(config, mesh) state, _, state_mesh_shardings, data_iterator = max_utils.setup_training_state( @@ -816,10 +806,22 @@ def reshard_fn(config: pyconfig.HyperParameters): checkpoint_manager, ) - state = state.replace( - step=restore_state.step, - params=restore_state.params, - opt_state=restore_state.opt_state, + state = state.replace(step=restore_step) + params_sharding = jax.tree.map( + lambda x: jax.sharding.NamedSharding(mesh, x.sharding.spec), + config.eu.data["params"], + ) + state = state.replace(params=config.eu.reshard(config.eu.data["params"], params_sharding)) + opt_state_sharding = jax.tree.map( + lambda x: jax.sharding.NamedSharding(mesh, x.sharding.spec), + config.eu.data["opt_state"], + ) + state = state.replace(opt_state=config.eu.reshard(config.eu.data["opt_state"], opt_state_sharding)) + + config.eu.save( + restore_step, + params=state.params, + opt_state=state.opt_state, ) ( @@ -1094,7 +1096,11 @@ def train_loop(config, state=None): reshard_flag = config.eu.is_ready_to_reshard(step) if reshard_flag or step % config.eu.save_period == 0: - config.eu.save(step, state=state) + config.eu.save( + step, + params=state.params, + opt_state=state.opt_state, + ) if step == start_step: max_utils.print_mem_stats("After params initialized") From 21b7600a546815378ecbffec8701ffb79ded0d2a Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Fri, 7 Feb 2025 20:13:58 +0000 Subject: [PATCH 06/10] Host memory offloading done --- MaxText/train.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/MaxText/train.py b/MaxText/train.py index 06a5721e7..cf62440ad 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -793,9 +793,10 @@ def reshard_fn(config: pyconfig.HyperParameters): setup_mesh_and_model(config) ) + data_iterator, _ = create_data_iterator(config, mesh) + restore_step = config.eu.data["save_step"] - data_iterator, _ = create_data_iterator(config, mesh) state, _, state_mesh_shardings, data_iterator = max_utils.setup_training_state( model, data_iterator, @@ -806,17 +807,21 @@ def reshard_fn(config: pyconfig.HyperParameters): checkpoint_manager, ) - state = state.replace(step=restore_step) - params_sharding = jax.tree.map( - 
lambda x: jax.sharding.NamedSharding(mesh, x.sharding.spec), - config.eu.data["params"], - ) - state = state.replace(params=config.eu.reshard(config.eu.data["params"], params_sharding)) - opt_state_sharding = jax.tree.map( - lambda x: jax.sharding.NamedSharding(mesh, x.sharding.spec), - config.eu.data["opt_state"], + def reshard(arr): + return config.eu.reshard( + arr, + jax.tree.map( + lambda x: jax.sharding.NamedSharding(mesh, x.sharding.spec), + arr, + ), + ) + + state = state.replace(step=0, params=None, opt_state=None) + state = state.replace( + step=restore_step, + params=reshard(config.eu.data["params"]), + opt_state=reshard(config.eu.data["opt_state"]), ) - state = state.replace(opt_state=config.eu.reshard(config.eu.data["opt_state"], opt_state_sharding)) config.eu.save( restore_step, From 8b768a3fe118a742160b75c8fbd98c8d0cdfb4e2 Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Wed, 12 Feb 2025 22:29:38 +0000 Subject: [PATCH 07/10] Updating elasticutils --- MaxText/elasticutils.py | 70 +++++++++++++++++++++++++++++++++--- MaxText/elasticutils_fake.py | 46 ------------------------ 2 files changed, 66 insertions(+), 50 deletions(-) diff --git a/MaxText/elasticutils.py b/MaxText/elasticutils.py index a9d989961..f18ff536e 100644 --- a/MaxText/elasticutils.py +++ b/MaxText/elasticutils.py @@ -117,19 +117,22 @@ def slice_down(self): logger.fatal(f"Max failures reached {self.max_failures}") @timeit - def save(self, save_step: int, **kwargs): + def save(self, save_step: int, blocking: bool = True, **kwargs): """Save step and state.""" # In case DATA_LOSS occurs during jax.block_until_ready, overwrite self.data # at the end data = { k: jax.device_put( v, - jax.tree.map(lambda x: x.sharding.with_memory_kind(kind="pinned_host"), v), + jax.tree.map( + lambda x: x.sharding.with_memory_kind(kind="pinned_host"), v + ), ) for k, v in kwargs.items() } - for v in data.values(): - jax.block_until_ready(v) + if blocking: + for v in data.values(): + jax.block_until_ready(v) data["save_step"] = save_step self.data = data @@ -425,6 +428,65 @@ def put_array_device_put2( arr.shape, dst_sharding, arrays ) + def put_array_device_put3( + self, + arr: jax.Array, + dst_sharding: jax.sharding.Sharding, + donate_input: bool, # pylint: disable=unused-argument + ): + """Reshards `arr` to the specified `dst_sharding`. + + Args: + arr: An array, scalar, or a nested Python container thereof. + dst_sharding: A `Sharding` or a nested `Sharding` in a Python container + (must match the structure of `x`), specifying the target sharding. + donate_input: If `True`, donates the input arrays to reduce memory + needed for resharding. Donated buffers should not be reused. + + Returns: + A copy of `x` with the specified `sharding`. 
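
The save()/reshard() changes in patches 05-07 amount to a device-to-pinned-host-to-device round trip for the train state. A minimal version of that round trip using only public JAX APIs is below; it assumes a runtime with memory-kind support (TPU, or a recent JAX build that exposes pinned_host on the local backend), and the state contents are made up.

    import jax
    import jax.numpy as jnp

    state = {"params": jnp.ones((8, 8)), "opt_state": jnp.zeros((8, 8))}

    # Snapshot: keep the same sharding layout but move the bytes to pinned
    # host memory, then wait so the copy is complete before training resumes.
    host_shardings = jax.tree.map(
        lambda x: x.sharding.with_memory_kind(kind="pinned_host"), state)
    snapshot = jax.device_put(state, host_shardings)
    jax.block_until_ready(snapshot)

    # Restore: rebuild device shardings (in train.py this is a NamedSharding
    # over the post-failure mesh) and copy the snapshot back.
    device_shardings = jax.tree.map(
        lambda x: x.sharding.with_memory_kind(kind="device"), snapshot)
    restored = jax.device_put(snapshot, device_shardings)
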
+ """ + if not isinstance(dst_sharding, jax.sharding.Sharding): + raise ValueError("`sharding` must contain only `Sharding` instances.") + + if dst_sharding.num_devices <= arr.sharding.num_devices: + # Reshard down + arrays = [ + x.data + for x in arr.addressable_shards + if x.device.slice_index in self.good_slice_indices + ] + else: + # Reshard up + slice_to_arrays = collections.defaultdict(list) + for x in arr.addressable_shards: + slice_to_arrays[x.data.device.slice_index].append(x.data) + slice_to_arrays = dict(slice_to_arrays) + + good_data_slice_indices = {d.slice_index for d in arr.sharding.device_set} + new_slice_indices = self.good_slice_indices - good_data_slice_indices + + new_arrays = [] + for i, slice_index in enumerate(good_data_slice_indices): + arrays = slice_to_arrays[slice_index] + start_index = len(arrays) * i // len(self.good_slice_indices) + end_index = len(arrays) * (i + 1) // len(self.good_slice_indices) + + arrays_to_put = arrays[start_index:end_index] + + for new_slice_index in new_slice_indices: + new_arrays += jax.device_put( + arrays_to_put, + self.slice_to_devices[new_slice_index][start_index:end_index], + ) + + arrays = sum(slice_to_arrays.values(), []) + new_arrays + + return jax.make_array_from_single_device_arrays( + arr.shape, dst_sharding, arrays + ) + + def scale_by_good_slices(self, x: int | float) -> int | float: """Scale x by the number of good slices.""" if isinstance(x, int): diff --git a/MaxText/elasticutils_fake.py b/MaxText/elasticutils_fake.py index 72ebf78e5..f758fe98d 100644 --- a/MaxText/elasticutils_fake.py +++ b/MaxText/elasticutils_fake.py @@ -80,49 +80,3 @@ def put_array_fake0( if not isinstance(dst_sharding, jax.sharding.Sharding): raise ValueError("`sharding` must contain only `Sharding` instances.") return jax.numpy.ones_like(arr, device=dst_sharding) - - def put_array_fake1( - self, - arr: jax.Array, - dst_sharding: jax.sharding.Sharding, - donate_input: bool, # pylint: disable=unused-argument - ): - """Reshards `arr` to the specified `dst_sharding`. - - Args: - arr: An array, scalar, or a nested Python container thereof. - dst_sharding: A `Sharding` or a nested `Sharding` in a Python container - (must match the structure of `x`), specifying the target sharding. - donate_input: If `True`, donates the input arrays to reduce memory - needed for resharding. Donated buffers should not be reused. - - Returns: - A copy of `x` with the specified `sharding`. 
- """ - if not isinstance(dst_sharding, jax.sharding.Sharding): - raise ValueError("`sharding` must contain only `Sharding` instances.") - - if dst_sharding.num_devices <= arr.sharding.num_devices: - # Reshard down - arrays = [ - x.data - for x in arr.addressable_shards - if x.device.slice_index in self.good_slice_indices - ] - else: - # Reshard up - arrays = [x.data for x in arr.addressable_shards] - - new_slice_index = ( - self.good_slice_indices - - {d.slice_index for d in arr.sharding.device_set} - ).pop() - - new_arrays = [jax.numpy.zeros_like(arr, device=device) - for device in self.slice_to_devices[new_slice_index]] - - arrays += new_arrays - - return jax.make_array_from_single_device_arrays( - arr.shape, dst_sharding, arrays - ) From 0cbd4b5b896d60bb67575f012ecf650e7903ae4e Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Fri, 21 Feb 2025 22:09:20 +0000 Subject: [PATCH 08/10] Checking in --- MaxText/elasticutils_fake.py | 1 + MaxText/train.py | 1 + 2 files changed, 2 insertions(+) diff --git a/MaxText/elasticutils_fake.py b/MaxText/elasticutils_fake.py index f758fe98d..93979c187 100644 --- a/MaxText/elasticutils_fake.py +++ b/MaxText/elasticutils_fake.py @@ -80,3 +80,4 @@ def put_array_fake0( if not isinstance(dst_sharding, jax.sharding.Sharding): raise ValueError("`sharding` must contain only `Sharding` instances.") return jax.numpy.ones_like(arr, device=dst_sharding) + diff --git a/MaxText/train.py b/MaxText/train.py index cf62440ad..344d7bcb9 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -814,6 +814,7 @@ def reshard(arr): lambda x: jax.sharding.NamedSharding(mesh, x.sharding.spec), arr, ), + put_array=config.eu.put_array_device_put2, ) state = state.replace(step=0, params=None, opt_state=None) From e1ca0d2c4691f135db037c2cb3ab18f788f03e55 Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Fri, 21 Feb 2025 23:02:30 +0000 Subject: [PATCH 09/10] Added a max reshard retry count --- MaxText/elasticutils.py | 33 ++++++++++++++++++++++++--------- MaxText/elasticutils_fake.py | 4 +++- MaxText/train.py | 2 +- 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/MaxText/elasticutils.py b/MaxText/elasticutils.py index f18ff536e..90c4fc303 100644 --- a/MaxText/elasticutils.py +++ b/MaxText/elasticutils.py @@ -83,7 +83,8 @@ def __init__( total_slice_count: int, save_period: Optional[int] = None, reshard_check_period: Optional[int] = None, - max_failures: Optional[int] = None, + max_failure_count: Optional[int] = None, + max_reshard_retry_count: Optional[int] = None, ): self.devices = devices self.total_slice_count = total_slice_count @@ -96,25 +97,40 @@ def __init__( reshard_check_period = 1 self.reshard_check_period = reshard_check_period - if max_failures is None: - max_failures = float("inf") - self.max_failures = max_failures + if max_failure_count is None: + max_failure_count = float("inf") + self.max_failure_count = max_failure_count + + if max_reshard_retry_count is None: + max_reshard_retry_count = float("inf") + self.max_reshard_retry_count = max_reshard_retry_count self.failure_count = 0 + self.reshard_retry_count = 0 self.good_slice_indices = self.get_slice_availability() self.data = {} - def slice_down(self): + def slice_down(self, reshard_retry: bool = False): """Slice down.""" logger.info("Slice down") self.good_slice_indices = self.get_slice_availability() self.failure_count += 1 + if reshard_retry: + self.reshard_retry_count += 1 + else: + self.reshard_retry_count = 0 + + logger.info(f"{self.failure_count=} {self.max_failure_count=}") + if 
self.failure_count >= self.max_failure_count: + logger.fatal(f"Max failure count reached {self.max_failure_count}") logger.info( - f"Failure count: {self.failure_count} with max {self.max_failures}" + f"{self.reshard_retry_count=} {self.max_reshard_retry_count=}" ) - if self.failure_count >= self.max_failures: - logger.fatal(f"Max failures reached {self.max_failures}") + if self.reshard_retry_count > self.max_reshard_retry_count: + logger.fatal( + f"Max reshard retry count reached {self.max_reshard_retry_count}" + ) @timeit def save(self, save_step: int, blocking: bool = True, **kwargs): @@ -486,7 +502,6 @@ def put_array_device_put3( arr.shape, dst_sharding, arrays ) - def scale_by_good_slices(self, x: int | float) -> int | float: """Scale x by the number of good slices.""" if isinstance(x, int): diff --git a/MaxText/elasticutils_fake.py b/MaxText/elasticutils_fake.py index 93979c187..05c717055 100644 --- a/MaxText/elasticutils_fake.py +++ b/MaxText/elasticutils_fake.py @@ -28,7 +28,8 @@ def __init__( total_slice_count: int, save_period: Optional[int] = None, reshard_check_period: Optional[int] = None, - max_failures: Optional[int] = None, + max_failure_count: Optional[int] = None, + max_reshard_retry_count: Optional[int] = None, ): self.fake_good_slice_indices = set(d.slice_index for d in devices) @@ -38,6 +39,7 @@ def __init__( save_period, reshard_check_period, max_failures, + max_reshard_retry_count, ) def update_good_slice_indices(self, good_slice_indices: set[int]): diff --git a/MaxText/train.py b/MaxText/train.py index 344d7bcb9..743adf006 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -863,7 +863,7 @@ def reshard(arr): max_logging.log("Unknown JaxRuntimeError during resharding!") raise - config.eu.slice_down() + config.eu.slice_down(reshard_retry=True) return ( restore_step, From 0c975ce3b6e993e0e07a8ea1ced3d9f2b019f705 Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Wed, 26 Feb 2025 00:25:35 +0000 Subject: [PATCH 10/10] Updated elasticutils to how it will be structured in pathwaysutils --- MaxText/elastic/reshard.py | 71 +++++++++++ MaxText/elastic/simulator.py | 70 +++++++++++ MaxText/{elasticutils.py => elastic/utils.py} | 112 +++++++++--------- MaxText/elasticutils_fake.py | 85 ------------- MaxText/pyconfig.py | 2 +- 5 files changed, 198 insertions(+), 142 deletions(-) create mode 100644 MaxText/elastic/reshard.py create mode 100644 MaxText/elastic/simulator.py rename MaxText/{elasticutils.py => elastic/utils.py} (90%) delete mode 100644 MaxText/elasticutils_fake.py diff --git a/MaxText/elastic/reshard.py b/MaxText/elastic/reshard.py new file mode 100644 index 000000000..a5c37ef63 --- /dev/null +++ b/MaxText/elastic/reshard.py @@ -0,0 +1,71 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
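
The two counters introduced in PATCH 09 separate total slice failures from consecutive failed reshard attempts. A toy version of the bookkeeping is below; the thresholds are placeholders and a plain logger stands in for whatever fatal-error handling the real code wires up.

    import logging

    logger = logging.getLogger(__name__)

    class RetryBookkeeping:
        """Reduced sketch of the failure/retry accounting in slice_down()."""

        def __init__(self, max_failure_count=float("inf"),
                     max_reshard_retry_count=float("inf")):
            self.max_failure_count = max_failure_count
            self.max_reshard_retry_count = max_reshard_retry_count
            self.failure_count = 0
            self.reshard_retry_count = 0

        def slice_down(self, reshard_retry=False):
            self.failure_count += 1
            # Retries only accumulate while reshard attempts fail back to
            # back; a failure during normal training resets the counter.
            self.reshard_retry_count = (
                self.reshard_retry_count + 1 if reshard_retry else 0)
            if self.failure_count >= self.max_failure_count:
                logger.critical("Max failure count reached %s",
                                self.max_failure_count)
            if self.reshard_retry_count > self.max_reshard_retry_count:
                logger.critical("Max reshard retry count reached %s",
                                self.max_reshard_retry_count)

    bk = RetryBookkeeping(max_failure_count=10, max_reshard_retry_count=2)
    bk.slice_down()                    # ordinary slice loss
    bk.slice_down(reshard_retry=True)  # the reshard that followed also failed
    assert (bk.failure_count, bk.reshard_retry_count) == (2, 1)
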
+"""Resharding API for elastic training.""" + +from typing import Any +from typing import Callable, Optional, Sequence +import jax + + +def default_put_array( + arr: jax.Array, + dst_sharding: jax.sharding.Sharding, + donate_input: bool, +): + if not isinstance(dst_sharding, jax.sharding.Sharding): + raise ValueError("`sharding` must contain only `Sharding` instances.") + return jax.device_put(arr, dst_sharding, donate=donate_input) + + +def reshard( + x: Any, + sharding: jax.sharding.Sharding | Any, + *, + donate_input: bool = True, + put_array: Optional[ + Callable[[jax.Array, Sequence[jax.sharding.Sharding], bool], jax.Array] + ] = None, +) -> Any: + """Reshards `x` to the specified `sharding`. + + Args: + x: An array, scalar, or a nested Python container thereof. + sharding: A `Sharding` or a nested `Sharding` in a Python container (must + match the structure of `x`), specifying the target sharding. + donate_input: If `True`, donates the input arrays to reduce memory needed + for resharding. Donated buffers should not be reused. + put_array: A function that takes an array, a sharding, and a boolean + indicating whether to donate the input, and returns a copy of the array + with the specified sharding. + + Returns: + A copy of `x` with the specified `sharding`. + """ + if put_array is None: + put_array = default_put_array + + flat_x, tree_def = jax.tree_util.tree_flatten(x) + flat_sharding = jax.api_util.flatten_axes( + "reshard sharding", tree_def, sharding + ) + + if len(flat_x) != len(flat_sharding): + raise ValueError("Mismatched length between `x` and `sharding`.") + + arrays = [ + put_array(arr, dst_sharding, donate_input) + for arr, dst_sharding in zip(flat_x, flat_sharding) + ] + return jax.tree_util.tree_unflatten(tree_def, arrays) + diff --git a/MaxText/elastic/simulator.py b/MaxText/elastic/simulator.py new file mode 100644 index 000000000..f59d8021e --- /dev/null +++ b/MaxText/elastic/simulator.py @@ -0,0 +1,70 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for elastic training.""" + +import logging +from typing import Any, Optional, Sequence + +import jax +from pathwaysutils.google_internal.elastic import utils + + +PyTree = Any + +logger = logging.getLogger(__name__) + +logger.setLevel(logging.INFO) + +# pylint: disable=logging-fstring-interpolation + + +class ElasticUtilsSimulator(utils.ElasticUtils): + """Utility class for elastic training. + + This class will simulate slices going down and coming back up. 
+ """ + simulated_good_slice_indices: set[int] + + def __init__( + self, + devices: Sequence[jax.Device], + total_slice_count: int, + save_period: Optional[int] = None, + reshard_check_period: Optional[int] = None, + max_failures: Optional[int] = None, + ): + self.simulated_good_slice_indices = set(d.slice_index for d in devices) + + super().__init__( + devices, + total_slice_count, + save_period, + reshard_check_period, + max_failures, + ) + + def update_good_slice_indices(self, good_slice_indices: set[int]): + """Start step handler.""" + self.simulated_good_slice_indices = good_slice_indices + logger.info(f"Updated: {self.simulated_good_slice_indices=}") + + @utils.timeit + def get_slice_availability(self) -> set[int]: + """Returns the set of good and bad slices.""" + good_slice_indices = self.simulated_good_slice_indices + + logger.info(f"{good_slice_indices=}") + + return good_slice_indices + diff --git a/MaxText/elasticutils.py b/MaxText/elastic/utils.py similarity index 90% rename from MaxText/elasticutils.py rename to MaxText/elastic/utils.py index 90c4fc303..4cfcfa8fa 100644 --- a/MaxText/elasticutils.py +++ b/MaxText/elastic/utils.py @@ -1,6 +1,20 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Utilities for elastic training.""" import collections +from collections.abc import Mapping import contextlib import functools import itertools @@ -13,6 +27,7 @@ import jax import numpy as np +from pathwaysutils.google_internal.elastic import reshard jax._src.array.ArrayImpl._check_if_deleted = lambda _: False # pylint: disable=protected-access @@ -74,6 +89,17 @@ def wrapper(*args, **kwargs): class ElasticUtils: """Utility class for elastic training.""" + _devices: Sequence[jax.Device] + slice_to_devices: Mapping[int, Sequence[jax.Device]] + total_slice_count: int + save_period: int + reshard_check_period: int + max_failure_count: Optional[int] + max_reshard_retry_count: Optional[int] + failure_count: int + reshard_retry_count: int + good_slice_indices: set[int] + data: Mapping[str, Any] TEST_VALUE = 100 @@ -81,28 +107,16 @@ def __init__( self, devices: Sequence[jax.Device], total_slice_count: int, - save_period: Optional[int] = None, - reshard_check_period: Optional[int] = None, + save_period: int = 1, + reshard_check_period: int = 1, max_failure_count: Optional[int] = None, max_reshard_retry_count: Optional[int] = None, ): self.devices = devices self.total_slice_count = total_slice_count - - if save_period is None: - save_period = 1 self.save_period = save_period - - if reshard_check_period is None: - reshard_check_period = 1 self.reshard_check_period = reshard_check_period - - if max_failure_count is None: - max_failure_count = float("inf") self.max_failure_count = max_failure_count - - if max_reshard_retry_count is None: - max_reshard_retry_count = float("inf") self.max_reshard_retry_count = max_reshard_retry_count self.failure_count = 0 @@ -121,14 +135,20 @@ def slice_down(self, reshard_retry: bool = False): self.reshard_retry_count = 0 
logger.info(f"{self.failure_count=} {self.max_failure_count=}") - if self.failure_count >= self.max_failure_count: - logger.fatal(f"Max failure count reached {self.max_failure_count}") + if ( + self.max_failure_count is not None + and self.failure_count >= self.max_failure_count + ): + logger.critical(f"Max failure count reached {self.max_failure_count}") logger.info( f"{self.reshard_retry_count=} {self.max_reshard_retry_count=}" ) - if self.reshard_retry_count > self.max_reshard_retry_count: - logger.fatal( + if ( + self.max_reshard_retry_count is not None + and self.reshard_retry_count > self.max_reshard_retry_count + ): + logger.critical( f"Max reshard retry count reached {self.max_reshard_retry_count}" ) @@ -311,31 +331,24 @@ def reshard( if put_array is None: put_array = cls.default_put_array - flat_x, tree_def = jax.tree_util.tree_flatten(x) - flat_sharding = jax.api_util.flatten_axes( - "reshard sharding", tree_def, sharding + return reshard.reshard( + x, sharding, donate_input=donate_input, put_array=put_array ) - if len(flat_x) != len(flat_sharding): - raise ValueError("Mismatched length between `x` and `sharding`.") - - arrays = [ - put_array(arr, dst_sharding, donate_input) - for arr, dst_sharding in zip(flat_x, flat_sharding) - ] - return jax.tree_util.tree_unflatten(tree_def, arrays) - - @staticmethod - def put_array_device_put0( - arr: jax.Array, - dst_sharding: jax.sharding.Sharding, - donate_input: bool, - ): - if not isinstance(dst_sharding, jax.sharding.Sharding): - raise ValueError("`sharding` must contain only `Sharding` instances.") - return jax.device_put(arr, dst_sharding, donate=donate_input) - - default_put_array = put_array_device_put0 + def scale_by_good_slices(self, x: int | float) -> int | float: + """Scale x by the number of good slices.""" + if isinstance(x, int): + ret, remainder = divmod(x * self.good_slice_count, self.total_slice_count) + if remainder: + raise ValueError( + f"Cannot scale {x=} by good slices because it will result in a " + f"remainder of {remainder=}." + ) + return ret + elif isinstance(x, float): + return x * self.good_slice_count / self.total_slice_count + else: + raise ValueError(f"Unsupported type: {type(x)}") def put_array_device_put1( self, @@ -502,20 +515,7 @@ def put_array_device_put3( arr.shape, dst_sharding, arrays ) - def scale_by_good_slices(self, x: int | float) -> int | float: - """Scale x by the number of good slices.""" - if isinstance(x, int): - ret, remainder = divmod(x * self.good_slice_count, self.total_slice_count) - if remainder: - raise ValueError( - f"Cannot scale {x=} by good slices because it will result in a " - f"remainder of {remainder=}." 
- ) - return ret - elif isinstance(x, float): - return x * self.good_slice_count / self.total_slice_count - else: - raise ValueError(f"Unsupported type: {type(x)}") + default_put_array = put_array_device_put1 @contextlib.contextmanager @@ -552,7 +552,7 @@ def handler(): logger.info(f"Error print traceback for {thread.ident=}") pass finally: - # logger.fatal("Timeout from timebomb!") + # logger.critical("Timeout from timebomb!") # os.abort() pass diff --git a/MaxText/elasticutils_fake.py b/MaxText/elasticutils_fake.py deleted file mode 100644 index 05c717055..000000000 --- a/MaxText/elasticutils_fake.py +++ /dev/null @@ -1,85 +0,0 @@ -"""Utilities for elastic training.""" -import logging -from typing import Any, Optional, Sequence - -from elasticutils import ElasticUtils -from elasticutils import timeit -import jax - - -PyTree = Any - -logger = logging.getLogger(__name__) - -logger.setLevel(logging.INFO) - -# pylint: disable=logging-fstring-interpolation - - -class FakeElasticUtils(ElasticUtils): - """Utility class for elastic training. - - This class will simulate slices going down and coming back up. - """ - - def __init__( - self, - devices: Sequence[jax.Device], - total_slice_count: int, - save_period: Optional[int] = None, - reshard_check_period: Optional[int] = None, - max_failure_count: Optional[int] = None, - max_reshard_retry_count: Optional[int] = None, - ): - self.fake_good_slice_indices = set(d.slice_index for d in devices) - - super().__init__( - devices, - total_slice_count, - save_period, - reshard_check_period, - max_failures, - max_reshard_retry_count, - ) - - def update_good_slice_indices(self, good_slice_indices: set[int]): - """Start step handler.""" - self.fake_good_slice_indices = good_slice_indices - logger.info(f"Updated: {self.fake_good_slice_indices=}") - - @timeit - def get_slice_availability(self) -> set[int]: - """Returns the set of good and bad slices.""" - good_slice_indices = self.fake_good_slice_indices - - logger.info(f"{good_slice_indices=}") - - return good_slice_indices - - # Does not work - @staticmethod - def put_array_jit( - arr: jax.Array, - dst_sharding: jax.sharding.Sharding, - donate_input: bool, - ): - if not isinstance(dst_sharding, jax.sharding.Sharding): - raise ValueError("`sharding` must contain only `Sharding` instances.") - - return jax.jit( - lambda x: x, - out_shardings=dst_sharding, - donate_argnums=(0,) if donate_input else (), - )(arr) - - # Slower than actually resharding - @staticmethod - def put_array_fake0( - arr: jax.Array, - dst_sharding: jax.sharding.Sharding, - donate_input: bool, - ): - if not isinstance(dst_sharding, jax.sharding.Sharding): - raise ValueError("`sharding` must contain only `Sharding` instances.") - return jax.numpy.ones_like(arr, device=dst_sharding) - diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index b652b07e4..2178980a6 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -22,7 +22,7 @@ import sys from typing import Any, Union -from elasticutils_fake import FakeElasticUtils as ElasticUtils +from elastic.simulator import ElasticUtilsSimulator as ElasticUtils import jax from jax.experimental.compilation_cache import compilation_cache from layers.attentions import AttentionType
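
Taken together, the series converges on the control flow sketched below. Only the ElasticUtils calls (save, slice_down, is_ready_to_reshard, the data["save_step"] snapshot key) come from the patches above; the train-step and mesh handling are placeholders, and restoring params/opt_state from the snapshot happens inside reshard_fn in train.py rather than in this skeleton.

    import jax

    def elastic_train_loop(config, state, p_train_step):
        """Schematic elastic loop: snapshot, detect slice loss, resume."""
        step = config.eu.data.get("save_step", 0)
        while step < config.steps:
            try:
                if step % config.eu.save_period == 0:
                    # Pinned-host snapshot of the replicatable training state.
                    config.eu.save(step, params=state.params,
                                   opt_state=state.opt_state)
                state = p_train_step(state)
                step += 1
                if config.eu.is_ready_to_reshard(step):
                    # A slice came back: train.py rebuilds the mesh and
                    # reshards params/opt_state from the snapshot here.
                    step = config.eu.data["save_step"]
            except jax.errors.JaxRuntimeError as e:
                if "DATA_LOSS" not in str(e):
                    raise
                # A slice disappeared: record the failure and roll back to
                # the last snapshot on the surviving slices.
                config.eu.slice_down()
                step = config.eu.data["save_step"]
        return state
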