Add Checkpoint Restore Benchmarking to PW Recipe
SujeethJinesh committed Feb 27, 2025
1 parent d2e9450 commit 0f7408a
Showing 4 changed files with 638 additions and 19 deletions.
125 changes: 125 additions & 0 deletions benchmarks/maxtext_trillium_model_configs.py
@@ -866,6 +866,131 @@ def _add_to_model_dictionary(
),
)

llama3_1_70b_8192_iter_real_data_and_checkpointing_tfds = _add_to_model_dictionary(
trillium_model_dict,
MaxTextModel(
model_name="llama3_1-70b-8192",
model_type="llama3.1-70b",
tuning_params={
"per_device_batch_size": 2,
"ici_fsdp_parallelism": -1,
"remat_policy": "custom",
"decoder_layer_input": "offload",
"query_proj": "offload",
"key_proj": "offload",
"value_proj": "offload",
"max_target_length": 8192,
"attention": "flash",
"use_iota_embed": True,
"dataset_path": "gs://trillium-scale-datasets-q1-25-west",
"dataset_type": "tfds",
"enable_checkpointing": True,
"async_checkpointing": True,
"checkpoint_period": 20,
"enable_checkpoint_cloud_logger": True,
"sa_block_q": 2048,
"sa_block_kv": 2048,
"sa_block_kv_compute": 2048,
"sa_block_q_dkv": 2048,
"sa_block_kv_dkv": 2048,
"sa_block_kv_dkv_compute": 2048,
"sa_block_q_dq": 2048,
"sa_block_kv_dq": 2048,
"sa_use_fused_bwd_kernel": True,
"gcs_metrics": True,
"profiler": "xplane",
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 5,
"metrics_file": "metrics.txt",
},
xla_flags=(
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+ xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
+ xla_flags_library.DATA_PARALLEL_OVERLAP
+ xla_flags_library.CF_FOR_ALL_GATHER
+ xla_flags_library.HOST_OFFLOAD_FLAGS
+ xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE
+ " --xla_tpu_iova_dma_chunk_size_bytes=104857"
),
pathways_xla_flag_options={
xla_flags_library.REMOVE: [
"--2a886c8_chip_config_name=megachip_tccontrol"
],
xla_flags_library.ADD_SERVER: (
xla_flags_library.ENHANCED_LAUNCH_BARRIER
),
xla_flags_library.ADD_PROXY: (
xla_flags_library.ENHANCED_LAUNCH_BARRIER
),
xla_flags_library.ADD_WORKER: (
xla_flags_library.ENHANCED_LAUNCH_BARRIER
),
}
),
)

llama3_1_70b_8192_iter_synth_data_and_checkpointing = _add_to_model_dictionary(
trillium_model_dict,
MaxTextModel(
model_name="llama3_1-70b-8192-synth",
model_type="llama3.1-70b",
tuning_params={
"per_device_batch_size": 2,
"ici_fsdp_parallelism": -1,
"remat_policy": "custom",
"decoder_layer_input": "offload",
"query_proj": "offload",
"key_proj": "offload",
"value_proj": "offload",
"max_target_length": 8192,
"attention": "flash",
"use_iota_embed": True,
"dataset_path": "gs://max-datasets-rogue",
"dataset_type": "synthetic",
"enable_checkpointing": True,
"async_checkpointing": True,
"checkpoint_period": 20,
"enable_checkpoint_cloud_logger": True,
"sa_block_q": 2048,
"sa_block_kv": 2048,
"sa_block_kv_compute": 2048,
"sa_block_q_dkv": 2048,
"sa_block_kv_dkv": 2048,
"sa_block_kv_dkv_compute": 2048,
"sa_block_q_dq": 2048,
"sa_block_kv_dq": 2048,
"sa_use_fused_bwd_kernel": True,
"gcs_metrics": True,
"profiler": "xplane",
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 5,
"metrics_file": "metrics.txt",
},
xla_flags=(
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+ xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
+ xla_flags_library.DATA_PARALLEL_OVERLAP
+ xla_flags_library.CF_FOR_ALL_GATHER
+ xla_flags_library.HOST_OFFLOAD_FLAGS
+ xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE
+ " --xla_tpu_iova_dma_chunk_size_bytes=104857"
),
pathways_xla_flag_options={
xla_flags_library.REMOVE: [
"--2a886c8_chip_config_name=megachip_tccontrol"
],
xla_flags_library.ADD_SERVER: (
xla_flags_library.ENHANCED_LAUNCH_BARRIER
),
xla_flags_library.ADD_PROXY: (
xla_flags_library.ENHANCED_LAUNCH_BARRIER
),
xla_flags_library.ADD_WORKER: (
xla_flags_library.ENHANCED_LAUNCH_BARRIER
),
}
),
)
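
Both checkpointing configs above are registered in trillium_model_dict and exported as module-level names, so a recipe can reference them directly. A minimal sketch of pulling one out (attribute access on MaxTextModel is assumed from the constructor above, not shown elsewhere in this diff):

# Hedged sketch, not part of this commit: look up the synthetic-data
# checkpointing config and inspect its checkpoint cadence.
import maxtext_trillium_model_configs as model_configs

ckpt_model = model_configs.llama3_1_70b_8192_iter_synth_data_and_checkpointing
# tuning_params is the dict passed to MaxTextModel above; a checkpoint every 20 steps.
print(ckpt_model.model_name, ckpt_model.tuning_params["checkpoint_period"])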

llama3_1_70b_8192_lr_real_data = _add_to_model_dictionary(
trillium_model_dict,
154 changes: 135 additions & 19 deletions benchmarks/maxtext_xpk_runner.py
@@ -27,11 +27,13 @@
import datetime
import enum
import os
import queue
import random
import string
import subprocess
import sys
import tempfile
import threading
import time

import maxtext_trillium_model_configs as model_configs
@@ -277,6 +279,85 @@ def run_command_with_updates(command, task, verbose=True) -> int:
return 0


def wait_for_xpk_workload_completion(
cluster_config: XpkClusterConfig, workload_name, xpk_path
) -> int:
"""Waits for the given XPK workload to complete.
Args:
cluster_config: XPK cluster configuration.
workload_name: Name of the workload to wait for.
xpk_path: Path to the xpk.py script.
Returns:
return_code: 0 if successful and non-zero otherwise.
"""
wait_command = [
f'python3 {xpk_path}/xpk.py workload list',
f'--cluster={cluster_config.cluster_name}',
f'--project={cluster_config.project}',
f'--zone={cluster_config.zone}',
f'--wait-for-job-completion={workload_name}',
]
wait_command_str = ' '.join(wait_command)
print(f'Waiting for workload "{workload_name}" to complete...')
return_code = run_command_with_updates(
wait_command_str, f'Wait for {workload_name} completion'
)
if return_code != 0:
print(
f'Error waiting for workload {workload_name} to complete. Return code:'
f' {return_code}'
)
else:
print(f'Workload "{workload_name}" completed successfully.')
return return_code
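
A brief usage sketch for the helper above; the workload name and xpk path are placeholders:

# Hypothetical call: blocks until the named workload finishes,
# returning 0 on success and non-zero otherwise.
rc = wait_for_xpk_workload_completion(cluster_config, 'my-checkpointing-run', '~/xpk')
if rc != 0:
    print(f'workload wait failed with code {rc}')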


def wait_for_xpk_workloads_completion_async(
cluster_config: XpkClusterConfig, workload_names, xpk_path
):
"""Waits for a list of XPK workloads to complete in parallel and yields names and exit codes as they complete.
Args:
cluster_config: XPK cluster configuration.
workload_names: List of workload names to wait for.
xpk_path: Path to the xpk.py script.
Yields:
Tuple[workload_name, return_code]: The name of the workload that has just
completed and its return code.
"""
threads = []
result_queue = queue.Queue()

def _wait_for_completion_threaded(name):
return_code = wait_for_xpk_workload_completion(
cluster_config, name, xpk_path
)
result_queue.put((name, return_code))

for name in workload_names:
thread = threading.Thread(
target=_wait_for_completion_threaded, args=(name,)
)
threads.append(thread)
thread.start()

completed_count = 0
while completed_count < len(workload_names):
try:
# Wait for a result with a timeout (adjust timeout as needed)
workload_name, return_code = result_queue.get(timeout=10)
completed_count += 1

# Yield the result as soon as it's available
yield workload_name, return_code
except queue.Empty:
# Queue is empty, no thread has finished yet, continue waiting
print('Waiting for workloads to complete...')
time.sleep(10)
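
For completeness, a sketch of consuming the generator above; it yields (workload_name, return_code) tuples in whatever order the workloads finish. The workload names are placeholders:

# Hypothetical usage of the async wait generator defined above.
names = ['pw-llama3-1-70b-8192', 'llama3-1-70b-8192-synth']
for finished, rc in wait_for_xpk_workloads_completion_async(cluster_config, names, '~/xpk'):
    status = 'succeeded' if rc == 0 else f'failed (code {rc})'
    print(f'{finished} {status}')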


def _get_config_tuning_params(wl_config: WorkloadConfig):
"""Get config tuning parameters for the workload.
@@ -528,6 +609,7 @@ def _get_pathways_specific_flags(wl_config: WorkloadConfig):
def generate_xpk_workload_cmd(
cluster_config: XpkClusterConfig,
wl_config: WorkloadConfig,
workload_name=None, # Added optional workload_name
):
"""Generates a command to run a maxtext model on XPK."""

@@ -545,15 +627,18 @@
common_prefix = os.environ['USER']
pw_prefix = "pw-"

if is_pathways_enabled:
name = (
f"{pw_prefix}{wl_config.model.model_name.replace('_', '-')[:truncate_model_name - len(pw_prefix)]}"
)
if workload_name is None: # Generate name if not provided
if is_pathways_enabled:
name = (
f"{pw_prefix}{wl_config.model.model_name.replace('_', '-')[:truncate_model_name - len(pw_prefix)]}"
)
else:
name = (
f"{wl_config.model.model_name.replace('_', '-')[:truncate_model_name]}"
)
name = f"{common_prefix[:truncate_prefix]}-{name}{common_post_fix}"
else:
name = (
f"{wl_config.model.model_name.replace('_', '-')[:truncate_model_name]}"
)
name = f"{common_prefix[:truncate_prefix]}-{name}{common_post_fix}"
name = workload_name # Use provided name

user_command = build_user_command(
name=name,
@@ -608,27 +693,41 @@ def generate_xpk_workload_cmd(
def run_xpk_workload(
cluster_config: XpkClusterConfig,
wl_config: WorkloadConfig,
):
"""Runs a maxtext model on XPK.
wait_for_completion: bool = False,
) -> int:
"""Runs a maxtext model on XPK and waits for completion.
Args:
model:
cluster_config:
cluster_config: XPK cluster configuration.
wl_config: Workload configuration.
wait_for_completion: Whether to wait for workload completion. Defaults to
False.
Returns:
return_code: Return code of the workload creation command, or workload
completion wait command if wait_for_completion is True.
"""
assert cluster_config.device_type == wl_config.device_type, f"The workload device type {wl_config.device_type} and cluster device type {cluster_config.device_type} don't match."
command, _ = generate_xpk_workload_cmd(
command, workload_name = generate_xpk_workload_cmd(
cluster_config=cluster_config,
wl_config=wl_config
)
return run_command_with_updates(command, 'Run XPK workload')
return_code = run_command_with_updates(command, 'Run XPK workload')
if return_code == 0 and wait_for_completion:
return_code = wait_for_xpk_workload_completion(cluster_config, workload_name, wl_config.xpk_path) # Wait for completion after successful run
return return_code
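
A short sketch of launching one workload with the new blocking flag; cluster_config and wl_config are assumed to be built by the recipe:

# Hypothetical call: creates the workload and, on success, waits for it to finish.
rc = run_xpk_workload(cluster_config, wl_config, wait_for_completion=True)
if rc != 0:
    print(f'benchmark did not complete cleanly (code {rc})')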


def xpk_benchmark_runner(
cluster_config: XpkClusterConfig,
workload_configs: list[WorkloadConfig],
):
"""Runs a list of maxtext models on XPK in parallel and waits for all to complete.
Args:
cluster_config: XPK cluster configuration.
workload_configs: List of workload configurations.
"""
xpk_workload_names = []
xpk_workload_cmds = []
for wl_config in workload_configs:
@@ -639,15 +738,32 @@ def xpk_benchmark_runner(

print(f"Name of the workload is: {name} \n")
xpk_workload_names.append(name)

print(f"XPK command to be used is: {command} \n")
xpk_workload_cmds.append(command)

# TODO(@vbarr) Support batch workloads.
for xpk_workload_name, xpk_workload_cmd in zip(xpk_workload_names, xpk_workload_cmds):
return_code = run_command_with_updates(xpk_workload_cmd, xpk_workload_name)
# Launch all workloads
workload_creation_return_codes = {}
for workload_name, workload_cmd in zip(xpk_workload_names, xpk_workload_cmds):
return_code = run_command_with_updates(workload_cmd, workload_name)
workload_creation_return_codes[workload_name] = return_code
if return_code != 0:
print('Unable to run xpk workload: {xpk_workload_name}')
print(f'Warning: Unable to start xpk workload: {workload_name}, but continuing to launch others.')


# Wait for workloads to complete in parallel and process them as they finish
completed_workload_names = wait_for_xpk_workloads_completion_async(
cluster_config, xpk_workload_names, workload_configs[0].xpk_path # Assuming xpk_path is the same for all wl_configs
)

for completed_name, return_code in completed_workload_names:
if workload_creation_return_codes[completed_name] == 0:
if return_code == 0:
print(f"Workload '{completed_name}' finished successfully and was waited for.")
else:
print(f"Workload '{completed_name}' finished with errors, wait returned code: {return_code}")
else:
print(f"Workload '{completed_name}' had creation errors but wait completion was still checked.") # Or handle errors differently


def on_device_benchmark_runner(
workload_configs: list[WorkloadConfig],
23 changes: 23 additions & 0 deletions benchmarks/recipes/args_helper.py
@@ -51,6 +51,29 @@ def _handle_delete(
os.system(delete_command)


def handle_delete_specific_workload(
cluster_config: mxr.XpkClusterConfig, workload_name: str, **kwargs
) -> int:
"""Handles the deletion of workloads with a specific name.
Args:
cluster_config: mxr.XpkClusterConfig object
workload_name: workload name
**kwargs: Optional keyword arguments, such as xpk_path
"""
xpk_path = kwargs.get("xpk_path", "xpk") # Default to "xpk" if not provided
delete_command = (
f"python3 {xpk_path}/xpk.py workload delete "
f"--project={cluster_config.project} --cluster={cluster_config.cluster_name}"
f" --filter-by-job={workload_name} --zone={cluster_config.zone}"
)
print(
f"Deleting workload: {workload_name} using command:"
f" {delete_command}"
)
os.system(f"yes | {delete_command}")


def handle_cmd_args(
cluster_config: mxr.XpkClusterConfig, *actions: str, **kwargs
) -> bool:
(The diff for the fourth changed file was not loaded in this view.)
