From 6c045de3a93dcfef3617a007bdf1fd4dc2172543 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Mon, 25 Nov 2024 14:38:40 +0800 Subject: [PATCH 01/13] relax nvcc version check for sm_89 --- setup.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 868d39c..ce8824b 100644 --- a/setup.py +++ b/setup.py @@ -114,10 +114,11 @@ def get_torch_arch_list() -> Set[str]: # Validate the NVCC CUDA version. if nvcc_cuda_version < Version("12.0"): raise RuntimeError("CUDA 12.0 or higher is required to build the package.") -if nvcc_cuda_version < Version("12.4"): +if nvcc_cuda_version < Version("12.3"): if any(cc.startswith("8.9") for cc in compute_capabilities): raise RuntimeError( - "CUDA 12.4 or higher is required for compute capability 8.9.") + "CUDA 12.3 or higher is required for compute capability 8.9.") +if nvcc_cuda_version < Version("12.4"): if any(cc.startswith("9.0") for cc in compute_capabilities): raise RuntimeError( "CUDA 12.4 or higher is required for compute capability 9.0.") @@ -172,4 +173,4 @@ def get_torch_arch_list() -> Set[str]: python_requires='>=3.9', ext_modules=ext_modules, cmdclass={"build_ext": BuildExtension}, -) \ No newline at end of file +) From 525b7e91b12474c5442514f11554db51521ffd39 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Tue, 26 Nov 2024 09:09:57 +0800 Subject: [PATCH 02/13] Update setup.py --- setup.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index ce8824b..1b6ce3a 100644 --- a/setup.py +++ b/setup.py @@ -114,11 +114,10 @@ def get_torch_arch_list() -> Set[str]: # Validate the NVCC CUDA version. if nvcc_cuda_version < Version("12.0"): raise RuntimeError("CUDA 12.0 or higher is required to build the package.") -if nvcc_cuda_version < Version("12.3"): +if nvcc_cuda_version < Version("12.4"): if any(cc.startswith("8.9") for cc in compute_capabilities): raise RuntimeError( - "CUDA 12.3 or higher is required for compute capability 8.9.") -if nvcc_cuda_version < Version("12.4"): + "CUDA 12.4 or higher is required for compute capability 8.9.") if any(cc.startswith("9.0") for cc in compute_capabilities): raise RuntimeError( "CUDA 12.4 or higher is required for compute capability 9.0.") From f5b06508e84c42ea245b688f65492e261ca0bb98 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Tue, 26 Nov 2024 11:34:35 +0800 Subject: [PATCH 03/13] workaround for distributed inference --- sageattention/core.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/sageattention/core.py b/sageattention/core.py index 53873b2..11c2827 100644 --- a/sageattention/core.py +++ b/sageattention/core.py @@ -41,6 +41,7 @@ from typing import Any, List, Literal, Optional, Tuple, Union import warnings + def get_cuda_arch_versions(): cuda_archs = [] for i in range(torch.cuda.device_count()): @@ -436,7 +437,14 @@ def sageattn_qk_int8_pv_fp16_cuda( assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16" assert q.device == k.device == v.device, "All tensors must be on the same device." assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." - + # FIXME(DefTruth): make sage atttention work compatible with distributed + # env, for eaxmple, xDiT which launch by torchrun. 
Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + if torch.distributed.get_world_size() > 1: + torch.cuda.set_device(v.device) _tensor_layout = 0 if tensor_layout == "NHD" else 1 _is_caual = 1 if is_causal else 0 @@ -472,7 +480,7 @@ def sageattn_qk_int8_pv_fp16_cuda( smooth_v = False if pv_accum_dtype == 'fp32': - v = v.to(torch.float16) + v = v.to(dtype=torch.float16) lse = qk_int8_sv_f16_accum_f32_attn_per_warp(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, sm_scale, _return_lse) elif pv_accum_dtype == "fp16": if smooth_v: @@ -581,6 +589,14 @@ def sageattn_qk_int8_pv_fp8_cuda( assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16" assert q.device == k.device == v.device, "All tensors must be on the same device." assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + # FIXME(DefTruth): make sage atttention work compatible with distributed + # env, for eaxmple, xDiT which launch by torchrun. Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + if torch.distributed.get_world_size() > 1: + torch.cuda.set_device(v.device) _tensor_layout = 0 if tensor_layout == "NHD" else 1 _is_caual = 1 if is_causal else 0 @@ -628,4 +644,4 @@ def sageattn_qk_int8_pv_fp8_cuda( if return_lse: return o, lse / 1.44269504 + lse_correction * sm_scale if smooth_k else lse / 1.44269504 else: - return o \ No newline at end of file + return o From d5f5d2bab1b8a681046550ab632a1c85acc123e0 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Tue, 26 Nov 2024 12:01:43 +0800 Subject: [PATCH 04/13] Create parallel_sageattn_cogvideo.py --- example/parallel_sageattn_cogvideo.py | 113 ++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 example/parallel_sageattn_cogvideo.py diff --git a/example/parallel_sageattn_cogvideo.py b/example/parallel_sageattn_cogvideo.py new file mode 100644 index 0000000..531a8b5 --- /dev/null +++ b/example/parallel_sageattn_cogvideo.py @@ -0,0 +1,113 @@ +""" +modified from: https://github.com/xdit-project/xDiT/blob/main/examples/cogvideox_example.py +sh ./run_parallel.sh +""" + +import torch +import torch.distributed +from xfuser import xFuserCogVideoXPipeline, xFuserArgs +from xfuser.config import FlexibleArgumentParser +from xfuser.core.distributed import ( + get_world_group, + get_runtime_state, + is_dp_last_group, +) +from diffusers.utils import export_to_video +import time +import torch.nn.functional as F +from functools import partial +import sageattention + +torch.set_grad_enabled(False) + + +def main(): + parser = FlexibleArgumentParser(description="xFuser Arguments") + parser.add_argument("--use_sage_attn_fp16", action="store_true", help="Use Sage Attention fp16 or not.") + parser.add_argument("--use_sage_attn_fp8", action="store_true", help="Use Sage Attention fp8 or not.") + args = xFuserArgs.add_cli_args(parser).parse_args() + + engine_args = xFuserArgs.from_cli_args(args) + # Check if ulysses_degree is valid + num_heads = 30 + if 
engine_args.ulysses_degree > 0 and num_heads % engine_args.ulysses_degree != 0: + raise ValueError( + f"ulysses_degree ({engine_args.ulysses_degree}) must be a divisor of the number of heads ({num_heads})" + ) + + # Init distributed env here + engine_config, input_config = engine_args.create_config() + local_rank = get_world_group().local_rank + + sage_tag = "sage+None" + if args.use_sage_attn_fp16: + F.scaled_dot_product_attention = partial( + sageattention.sageattn_qk_int8_pv_fp16_cuda, + pv_accum_dtype="fp32") + sage_tag = f"sage+fp16" + elif args.use_sage_attn_fp8: + F.scaled_dot_product_attention = partial( + sageattention.sageattn_qk_int8_pv_fp8_cuda, + pv_accum_dtype="fp32+fp32") + sage_tag = f"sage+fp8" # acc fp32 + + pipe = xFuserCogVideoXPipeline.from_pretrained( + pretrained_model_name_or_path=engine_config.model_config.model, + engine_config=engine_config, + torch_dtype=torch.bfloat16, + ) + + if args.enable_sequential_cpu_offload: + pipe.enable_model_cpu_offload(gpu_id=local_rank) + else: + device = torch.device(f"cuda:{local_rank}") + pipe = pipe.to(device) + + # Always enable tiling and slicing to avoid VAE OOM while batch size > 1 + pipe.vae.enable_slicing() + pipe.vae.enable_tiling() + + torch.cuda.reset_peak_memory_stats() + + start_time = time.time() + output = pipe( + height=input_config.height, + width=input_config.width, + num_frames=input_config.num_frames, + prompt=input_config.prompt, + num_inference_steps=input_config.num_inference_steps, + generator=torch.Generator().manual_seed(input_config.seed), + guidance_scale=6, + use_dynamic_cfg=True, + latents=None # Load local latents or let it None. + ).frames[0] + + end_time = time.time() + elapsed_time = end_time - start_time + peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") + + parallel_info = ( + f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_" + f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_" + f"tp{engine_args.tensor_parallel_degree}_" + f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}_" + f"compile{engine_config.runtime_config.use_torch_compile}" + ) + + if is_dp_last_group(): + world_size = get_world_group().world_size + prompt_tag: str = input_config.prompt[0] + prompt_tag = prompt_tag.replace(" ", "_").replace(".", "") + resolution = f"{input_config.width}x{input_config.height}x{input_config.num_frames}" + output_filename = (f"results/cogvideox_{parallel_info}_{sage_tag}_{world_size}gpu_" + f"{resolution}_{prompt_tag}.mp4") + export_to_video(output, output_filename, fps=8) + print(f"output saved to {output_filename}") + + if get_world_group().rank == get_world_group().world_size - 1: + print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB") + get_runtime_state().destory_distributed_env() + + +if __name__ == "__main__": + main() From 03d61929b65403922b047f1a329652156058562b Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Tue, 26 Nov 2024 12:02:41 +0800 Subject: [PATCH 05/13] Create run_parallel.sh --- example/run_parallel.sh | 61 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 example/run_parallel.sh diff --git a/example/run_parallel.sh b/example/run_parallel.sh new file mode 100644 index 0000000..c660d16 --- /dev/null +++ b/example/run_parallel.sh @@ -0,0 +1,61 @@ +set -x + +export PYTHONPATH=$PWD:$PYTHONPATH + +# Select the model type +# The model is downloaded to a specified location 
on disk, +# or you can simply use the model's ID on Hugging Face, +# which will then be downloaded to the default cache path on Hugging Face. + +export MODEL_TYPE="CogVideoX" +# Configuration for different model types +# script, model_id, inference_step +declare -A MODEL_CONFIGS=( + ["CogVideoX"]="parallel_sageattn_cogvideo.py THUDM/CogVideoX-2b 50" +) + +if [[ -v MODEL_CONFIGS[$MODEL_TYPE] ]]; then + IFS=' ' read -r SCRIPT MODEL_ID INFERENCE_STEP <<< "${MODEL_CONFIGS[$MODEL_TYPE]}" + export SCRIPT MODEL_ID INFERENCE_STEP +else + echo "Invalid MODEL_TYPE: $MODEL_TYPE" + exit 1 +fi + +mkdir -p ./results + +# task args +NUM_FRAMES=$1 +if [ "$NUM_FRAMES" = "" ]; then + NUM_FRAMES=49 +fi + +if [ "$MODEL_TYPE" = "CogVideoX" ]; then + TASK_ARGS="--height 480 --width 720 --num_frames ${NUM_FRAMES} --max_sequence_length 226" +fi + +# CogVideoX asserts sp_degree == ulysses_degree*ring_degree <= 2. Also, do not set the pipefusion degree. +if [ "$MODEL_TYPE" = "CogVideoX" ]; then +N_GPUS=2 +# Only use CFG parallelism for 2 GPUs since it has minimal communication cost. +PARALLEL_ARGS="--ulysses_degree 1 --ring_degree 1" +CFG_ARGS="--use_cfg_parallel" +fi + +# COMPILE_FLAG=--use_torch_compile +# SAGE_ATTN_FLAG=--use_sage_attn_fp16 +SAGE_ATTN_FLAG=--use_sage_attn_fp8 +torchrun --nproc_per_node=$N_GPUS ./$SCRIPT \ +--model $MODEL_ID \ +$PARALLEL_ARGS \ +$TASK_ARGS \ +$PIPEFUSION_ARGS \ +$OUTPUT_ARGS \ +--num_inference_steps $INFERENCE_STEP \ +--warmup_steps 0 \ +$CFG_ARGS \ +$PARALLLEL_VAE \ +$COMPILE_FLAG \ +$SAGE_ATTN_FLAG \ +--prompt \ +"A small dog." From 0f4aa9af655e5d67b75c229793936710c319f552 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Tue, 26 Nov 2024 12:07:12 +0800 Subject: [PATCH 06/13] Update README.md --- example/README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/example/README.md b/example/README.md index a3ac725..1ffb244 100644 --- a/example/README.md +++ b/example/README.md @@ -46,3 +46,12 @@ with the following code: is_causal=not is_full ) ``` +## Parallel SageAttention Inference + +Install xDiT(xfuser >= 0.3.5) and diffusers(>=0.32.0.dev0) from sources and run: + +```bash +./run_parallel.sh +``` + + From 810a4496609b9d953fc5da41c9e3a80dd1a1807c Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Tue, 26 Nov 2024 12:17:40 +0800 Subject: [PATCH 07/13] Update core.py --- sageattention/core.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sageattention/core.py b/sageattention/core.py index 11c2827..2042e4d 100644 --- a/sageattention/core.py +++ b/sageattention/core.py @@ -437,8 +437,8 @@ def sageattn_qk_int8_pv_fp16_cuda( assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16" assert q.device == k.device == v.device, "All tensors must be on the same device." assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." - # FIXME(DefTruth): make sage atttention work compatible with distributed - # env, for eaxmple, xDiT which launch by torchrun. Without this workaround, + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. Without this workaround, # sage attention will run into illegal memory access error after first # inference step in distributed env for multi gpus inference. 
This small # workaround also make sage attention work compatible with torch.compile @@ -589,12 +589,11 @@ def sageattn_qk_int8_pv_fp8_cuda( assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16" assert q.device == k.device == v.device, "All tensors must be on the same device." assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." - # FIXME(DefTruth): make sage atttention work compatible with distributed - # env, for eaxmple, xDiT which launch by torchrun. Without this workaround, + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. Without this workaround, # sage attention will run into illegal memory access error after first # inference step in distributed env for multi gpus inference. This small # workaround also make sage attention work compatible with torch.compile - # through non-fullgraph compile mode. if torch.distributed.get_world_size() > 1: torch.cuda.set_device(v.device) From 1edb8809baf83f3e1150d0526d1ec9e608f7a50b Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Tue, 26 Nov 2024 12:19:27 +0800 Subject: [PATCH 08/13] Update run_parallel.sh --- example/run_parallel.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example/run_parallel.sh b/example/run_parallel.sh index c660d16..f12ccef 100644 --- a/example/run_parallel.sh +++ b/example/run_parallel.sh @@ -43,8 +43,8 @@ CFG_ARGS="--use_cfg_parallel" fi # COMPILE_FLAG=--use_torch_compile -# SAGE_ATTN_FLAG=--use_sage_attn_fp16 -SAGE_ATTN_FLAG=--use_sage_attn_fp8 +SAGE_ATTN_FLAG=--use_sage_attn_fp16 +# SAGE_ATTN_FLAG=--use_sage_attn_fp8 torchrun --nproc_per_node=$N_GPUS ./$SCRIPT \ --model $MODEL_ID \ $PARALLEL_ARGS \ From 0ff690c97c7b87bc5c54219ac7ef16596804e30f Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Tue, 26 Nov 2024 12:23:53 +0800 Subject: [PATCH 09/13] Update run_parallel.sh --- example/run_parallel.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example/run_parallel.sh b/example/run_parallel.sh index f12ccef..c660d16 100644 --- a/example/run_parallel.sh +++ b/example/run_parallel.sh @@ -43,8 +43,8 @@ CFG_ARGS="--use_cfg_parallel" fi # COMPILE_FLAG=--use_torch_compile -SAGE_ATTN_FLAG=--use_sage_attn_fp16 -# SAGE_ATTN_FLAG=--use_sage_attn_fp8 +# SAGE_ATTN_FLAG=--use_sage_attn_fp16 +SAGE_ATTN_FLAG=--use_sage_attn_fp8 torchrun --nproc_per_node=$N_GPUS ./$SCRIPT \ --model $MODEL_ID \ $PARALLEL_ARGS \ From 2a5a2ace041e7e9e011e2e538da3a0d734f29fa5 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Tue, 26 Nov 2024 14:12:52 +0800 Subject: [PATCH 10/13] Update core.py --- sageattention/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sageattention/core.py b/sageattention/core.py index 2042e4d..f24c6c6 100644 --- a/sageattention/core.py +++ b/sageattention/core.py @@ -480,7 +480,7 @@ def sageattn_qk_int8_pv_fp16_cuda( smooth_v = False if pv_accum_dtype == 'fp32': - v = v.to(dtype=torch.float16) + v = v.to(torch.float16) lse = qk_int8_sv_f16_accum_f32_attn_per_warp(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, sm_scale, _return_lse) elif pv_accum_dtype == "fp16": if smooth_v: From 898aa618b0b773413b66436473b29973991af78a Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Tue, 26 Nov 2024 15:05:19 +0800 Subject: [PATCH 
11/13] Update README.md --- example/README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/example/README.md b/example/README.md index 1ffb244..4562d16 100644 --- a/example/README.md +++ b/example/README.md @@ -51,6 +51,12 @@ with the following code: Install xDiT(xfuser >= 0.3.5) and diffusers(>=0.32.0.dev0) from sources and run: ```bash +# install latest xDiT(xfuser). +pip install "xfuser[flash_attn]" +# install latest diffusers (>=0.32.0.dev0), need by latest xDiT. +git clone https://github.com/huggingface/diffusers.git +cd diffusers && python3 setup.py bdist_wheel && cd dist && python3 -m pip install *.whl +# then run parallel sage attention inference. ./run_parallel.sh ``` From 6972d5a4408f5e7fb4fca6d394fdb0778b336d01 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Thu, 28 Nov 2024 15:57:25 +0800 Subject: [PATCH 12/13] Update core.py --- sageattention/core.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sageattention/core.py b/sageattention/core.py index f24c6c6..451ed02 100644 --- a/sageattention/core.py +++ b/sageattention/core.py @@ -17,6 +17,7 @@ import torch import triton import triton.language as tl +import torch.distributed as dist from .triton.quant_per_block import per_block_int8 as per_block_int8_triton from .triton.quant_per_block_varlen import per_block_int8 as per_block_int8_varlen_triton @@ -443,7 +444,7 @@ def sageattn_qk_int8_pv_fp16_cuda( # inference step in distributed env for multi gpus inference. This small # workaround also make sage attention work compatible with torch.compile # through non-fullgraph compile mode. - if torch.distributed.get_world_size() > 1: + if dist.is_initialized() and dist.get_world_size() > 1: torch.cuda.set_device(v.device) _tensor_layout = 0 if tensor_layout == "NHD" else 1 @@ -480,7 +481,7 @@ def sageattn_qk_int8_pv_fp16_cuda( smooth_v = False if pv_accum_dtype == 'fp32': - v = v.to(torch.float16) + v = v.to(dtype=torch.float16) lse = qk_int8_sv_f16_accum_f32_attn_per_warp(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, sm_scale, _return_lse) elif pv_accum_dtype == "fp16": if smooth_v: @@ -594,7 +595,7 @@ def sageattn_qk_int8_pv_fp8_cuda( # sage attention will run into illegal memory access error after first # inference step in distributed env for multi gpus inference. This small # workaround also make sage attention work compatible with torch.compile - if torch.distributed.get_world_size() > 1: + if dist.is_initialized() and dist.get_world_size() > 1: torch.cuda.set_device(v.device) _tensor_layout = 0 if tensor_layout == "NHD" else 1 From 61c7eb2c65fb48b5a4c8cb01414d343b2f805046 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Thu, 28 Nov 2024 15:58:46 +0800 Subject: [PATCH 13/13] Update core.py --- sageattention/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sageattention/core.py b/sageattention/core.py index 451ed02..f04191d 100644 --- a/sageattention/core.py +++ b/sageattention/core.py @@ -481,7 +481,7 @@ def sageattn_qk_int8_pv_fp16_cuda( smooth_v = False if pv_accum_dtype == 'fp32': - v = v.to(dtype=torch.float16) + v = v.to(torch.float16) lse = qk_int8_sv_f16_accum_f32_attn_per_warp(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, sm_scale, _return_lse) elif pv_accum_dtype == "fp16": if smooth_v:
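
Editorial note for readers skimming the series: PATCH 01 relaxes the sm_89 gate to CUDA 12.3, and PATCH 02 immediately restores it, so the two patches net out to the original requirement of CUDA 12.4 for both compute capability 8.9 and 9.0 (plus the restored trailing newline in setup.py). Below is a standalone sketch of the resulting check, assuming `Version` comes from `packaging.version` as in the surrounding setup.py; the wrapper function and its name are added here purely for illustration, they are not part of the patches.

```python
from typing import Set
from packaging.version import Version

def validate_nvcc(nvcc_cuda_version: Version, compute_capabilities: Set[str]) -> None:
    # Net state of the version gate after PATCH 01 + PATCH 02.
    if nvcc_cuda_version < Version("12.0"):
        raise RuntimeError("CUDA 12.0 or higher is required to build the package.")
    if nvcc_cuda_version < Version("12.4"):
        if any(cc.startswith("8.9") for cc in compute_capabilities):
            raise RuntimeError(
                "CUDA 12.4 or higher is required for compute capability 8.9.")
        if any(cc.startswith("9.0") for cc in compute_capabilities):
            raise RuntimeError(
                "CUDA 12.4 or higher is required for compute capability 9.0.")
```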
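
The distributed workaround introduced in PATCH 03 and hardened in PATCH 12 amounts to binding the current CUDA device to the input tensors' device before launching the custom kernel, but only when running under a multi-GPU launcher such as torchrun. PATCH 12 adds the `dist.is_initialized()` check because `torch.distributed.get_world_size()` raises when no default process group exists, which would break ordinary single-process use. A minimal sketch of the final guard follows; the helper name is chosen here for illustration only.

```python
import torch
import torch.distributed as dist

def _bind_current_device(v: torch.Tensor) -> None:
    # Only act when a process group exists and spans more than one rank
    # (e.g. xDiT launched via torchrun); otherwise this is a no-op.
    if dist.is_initialized() and dist.get_world_size() > 1:
        torch.cuda.set_device(v.device)
```

Per the comment in PATCH 03, the same guard also keeps the kernels usable under torch.compile in non-fullgraph compile mode.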
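
For the attention path itself, the 113-line example added in PATCH 04 reduces to monkey-patching `torch.nn.functional.scaled_dot_product_attention` with a partially applied SageAttention kernel before the pipeline is built. The condensed form below lifts the two variants straight from the example, where they are selected via --use_sage_attn_fp16 / --use_sage_attn_fp8.

```python
from functools import partial
import torch.nn.functional as F
import sageattention

# INT8 QK, FP16 PV with FP32 accumulation (--use_sage_attn_fp16 in the example)
F.scaled_dot_product_attention = partial(
    sageattention.sageattn_qk_int8_pv_fp16_cuda, pv_accum_dtype="fp32")

# INT8 QK, FP8 PV with FP32+FP32 accumulation (--use_sage_attn_fp8 in the example)
# F.scaled_dot_product_attention = partial(
#     sageattention.sageattn_qk_int8_pv_fp8_cuda, pv_accum_dtype="fp32+fp32")
```

Note that this replaces SDPA globally for the process, so every module that calls F.scaled_dot_product_attention, not only the CogVideoX transformer, will run through SageAttention.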