Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

POC for GPU e2e feature testing #1243

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions end_to_end/gpu/test_collective_matmul_llama2_7b.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#!/bin/bash

# Fail fast if the caller did not export an output root: both the XLA HLO
# dump path below and the post-run HLO file search depend on it, and an
# empty value would silently dump to "/<run-name>/...".
: "${BASE_OUTPUT_PATH:?BASE_OUTPUT_PATH must be set to a writable output directory}"

# XLA collective-combining thresholds, in bytes.
AR_THRESHOLD=134217728  # all-reduce: 128 MiB
AG_THRESHOLD=134217728  # all-gather: 128 MiB
RS_THRESHOLD=67108864   # reduce-scatter: 64 MiB

MODEL="llama2-7b"
# Timestamped run name so repeated runs get distinct dump directories.
RUN_NAME=$MODEL-$(date +%Y-%m-%d-%H-%M)

# XLA GPU compiler configuration for this run. Key points:
#   - dump the optimized HLO as text into $BASE_OUTPUT_PATH/$RUN_NAME/HLO_dumps/
#     (the feature check at the end of this script parses that dump);
#   - windowed-einsum threshold of 0 plus multi-streaming enables the
#     collective-matmul rewrite this test is looking for;
#   - pipelined collectives and double buffering are disabled so the unrolled
#     all-gather / reduce-scatter counts in the dump stay deterministic.
export XLA_FLAGS="--xla_dump_hlo_as_text
--xla_dump_to=$BASE_OUTPUT_PATH/$RUN_NAME/HLO_dumps/
--xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_triton_gemm=false
--xla_gpu_graph_level=0
--xla_gpu_enable_highest_priority_async_stream=true
--xla_gpu_all_reduce_combine_threshold_bytes=${AR_THRESHOLD}
--xla_gpu_all_gather_combine_threshold_bytes=${AG_THRESHOLD}
--xla_gpu_reduce_scatter_combine_threshold_bytes=${RS_THRESHOLD}
--xla_gpu_enable_pipelined_all_gather=false
--xla_gpu_enable_pipelined_reduce_scatter=false
--xla_gpu_enable_pipelined_all_reduce=false
--xla_gpu_enable_while_loop_double_buffering=false
--xla_gpu_enable_all_gather_combine_by_dim=false
--xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization
--xla_gpu_threshold_for_windowed_einsum_mib=0
--xla_gpu_multi_streamed_windowed_einsum=true"

# Single-step training run; we only care about the compiled HLO it dumps,
# not the training result itself.
# BUG FIX: logits_dot_in_fp32=false previously had no space before the
# line-continuation backslash, fusing it with the next argument into
# "logits_dot_in_fp32=falseenable_checkpointing=false".
python3 MaxText/train.py \
  MaxText/configs/base.yml \
  model_name=${MODEL} \
  per_device_batch_size=0.125 \
  steps=1 \
  scan_layers=true \
  monitor_goodput=false \
  enable_goodput_recording=false \
  remat_policy=minimal_flash \
  attention=cudnn_flash_te \
  max_target_length=4096 \
  use_iota_embed=true \
  logits_dot_in_fp32=false \
  enable_checkpointing=false \
  ici_fsdp_parallelism=1 \
  ici_tensor_parallelism=8 \
  base_output_directory=local_train \
  dataset_path=local \
  dataset_type=synthetic \
  hardware=gpu \
  run_name=$RUN_NAME

# Extended regex matching the optimized-HLO text dump emitted for the
# jit_train_step module by the run above.
FILE_PATTERN="module_[0-9]+\.jit_train_step\.sm_[0-9]+\.[0-9]+_gpu_after_optimizations\.txt"

# search_file <dir> <regex>
# Print every regular file under <dir> whose basename matches <regex>.
search_file() {
  local root="$1"
  local regex="$2"
  find "$root" -type f | grep -E ".*/${regex}"
}

# Locate the dumped HLO file. The directory argument is quoted so paths with
# spaces don't word-split, and only the first match is kept: if several dumps
# matched, a multi-line value would make the -f test below fail spuriously.
HLO_FILE=$(search_file "$BASE_OUTPUT_PATH/$RUN_NAME/HLO_dumps/" "$FILE_PATTERN" | head -n 1)

if [ ! -f "$HLO_FILE" ]; then
  echo "Error: $HLO_FILE file does not exist."
  exit 1
fi

# Expected unrolled collective counts in the optimized HLO for llama2-7b with
# ici_tensor_parallelism=8 under the XLA_FLAGS above. NOTE(review): these are
# empirical for this exact config/compiler version — re-baseline when either
# changes.
EXPECTED_UNROLLED_AG=17
EXPECTED_UNROLLED_RS=9

# Quote $HLO_FILE so a path containing spaces stays a single argument.
python3 end_to_end/gpu/test_feature.py collective_matmul "$HLO_FILE" "$EXPECTED_UNROLLED_AG" "$EXPECTED_UNROLLED_RS"
37 changes: 37 additions & 0 deletions end_to_end/gpu/test_feature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# pylint: skip-file
from absl import app
from typing import Sequence

from jax.test_util import XlaGpuFeatureTestCase

def test_collective_matmul(test_case, hlo_file, expected_unrolled_ag, expected_unrolled_rs):
"""
Test collective matmul correctness in HLO file.

Args:
test_case: The JAX test case object.
hlo_file: Path to the HLO file.
expected_unrolled_ag: Expected number of unrolled all-gather operations.
expected_unrolled_rs: Expected number of unrolled reduce-scatter operations.
"""
with open(hlo_file, 'r') as hlo_file:
hlo_content = hlo_file.read()
test_case.check_collective_matmul(hlo_content, expected_unrolled_ag, expected_unrolled_rs)
print('collective matmul test passed.')

def test_fp8_gemm(hlo_file):
pass

def main(argv: Sequence[str]) -> None:
  """Run one GPU feature check selected on the command line.

  Expected argv layout: ``[prog, test_scenario, *scenario_args]`` where the
  scenario args are forwarded verbatim to the matching test function.
  """
  feature_case = XlaGpuFeatureTestCase()
  _, scenario, *scenario_args = argv

  # Dispatch table: scenario name -> test function taking (case, *args).
  handlers = {
      'collective_matmul': test_collective_matmul,
      'fp8_gemm': test_fp8_gemm,
  }
  handler = handlers.get(scenario)
  if handler is None:
    raise ValueError(f"Unrecognized test_scenario {scenario}")
  handler(feature_case, *scenario_args)


if __name__ == "__main__":
  app.run(main)