Parallel SageAttention Inference #50

Open

wants to merge 13 commits into base: main
15 changes: 15 additions & 0 deletions example/README.md
@@ -46,3 +46,18 @@ with the following code:
is_causal=not is_full
)
```
## Parallel SageAttention Inference

Install xDiT (xfuser >= 0.3.5) and diffusers (>= 0.32.0.dev0, built from source), then run:

```bash
# install the latest xDiT (xfuser).
pip install "xfuser[flash_attn]"
# install the latest diffusers (>= 0.32.0.dev0) from source; needed by the latest xDiT.
git clone https://github.com/huggingface/diffusers.git
cd diffusers && python3 setup.py bdist_wheel && cd dist && python3 -m pip install *.whl
# then run parallel sage attention inference.
./run_parallel.sh
```
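
Note that `run_parallel.sh` also takes an optional frame-count argument and falls back to 49 frames when it is omitted, e.g. `./run_parallel.sh 17` (the value 17 is only illustrative; use any frame count the model supports).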


113 changes: 113 additions & 0 deletions example/parallel_sageattn_cogvideo.py
@@ -0,0 +1,113 @@
"""
modified from: https://github.com/xdit-project/xDiT/blob/main/examples/cogvideox_example.py
sh ./run_parallel.sh
"""

import torch
import torch.distributed
from xfuser import xFuserCogVideoXPipeline, xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
    get_world_group,
    get_runtime_state,
    is_dp_last_group,
)
from diffusers.utils import export_to_video
import time
import torch.nn.functional as F
from functools import partial
import sageattention

torch.set_grad_enabled(False)


def main():
    parser = FlexibleArgumentParser(description="xFuser Arguments")
    parser.add_argument("--use_sage_attn_fp16", action="store_true", help="Use Sage Attention fp16 or not.")
    parser.add_argument("--use_sage_attn_fp8", action="store_true", help="Use Sage Attention fp8 or not.")
    args = xFuserArgs.add_cli_args(parser).parse_args()

    engine_args = xFuserArgs.from_cli_args(args)
    # Check if ulysses_degree is valid
    num_heads = 30
    if engine_args.ulysses_degree > 0 and num_heads % engine_args.ulysses_degree != 0:
        raise ValueError(
            f"ulysses_degree ({engine_args.ulysses_degree}) must be a divisor of the number of heads ({num_heads})"
        )

    # Init distributed env here
    engine_config, input_config = engine_args.create_config()
    local_rank = get_world_group().local_rank

    # Route F.scaled_dot_product_attention to the selected SageAttention kernel.
    sage_tag = "sage+None"
    if args.use_sage_attn_fp16:
        F.scaled_dot_product_attention = partial(
            sageattention.sageattn_qk_int8_pv_fp16_cuda,
            pv_accum_dtype="fp32")
        sage_tag = "sage+fp16"
    elif args.use_sage_attn_fp8:
        F.scaled_dot_product_attention = partial(
            sageattention.sageattn_qk_int8_pv_fp8_cuda,
            pv_accum_dtype="fp32+fp32")
        sage_tag = "sage+fp8"  # PV accumulation in fp32

    pipe = xFuserCogVideoXPipeline.from_pretrained(
        pretrained_model_name_or_path=engine_config.model_config.model,
        engine_config=engine_config,
        torch_dtype=torch.bfloat16,
    )

    if args.enable_sequential_cpu_offload:
        pipe.enable_model_cpu_offload(gpu_id=local_rank)
    else:
        device = torch.device(f"cuda:{local_rank}")
        pipe = pipe.to(device)

    # Always enable tiling and slicing to avoid VAE OOM when batch size > 1
    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()

    torch.cuda.reset_peak_memory_stats()

    start_time = time.time()
    output = pipe(
        height=input_config.height,
        width=input_config.width,
        num_frames=input_config.num_frames,
        prompt=input_config.prompt,
        num_inference_steps=input_config.num_inference_steps,
        generator=torch.Generator().manual_seed(input_config.seed),
        guidance_scale=6,
        use_dynamic_cfg=True,
        latents=None,  # Load local latents here, or leave as None.
    ).frames[0]

    end_time = time.time()
    elapsed_time = end_time - start_time
    peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

    parallel_info = (
        f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_"
        f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
        f"tp{engine_args.tensor_parallel_degree}_"
        f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}_"
        f"compile{engine_config.runtime_config.use_torch_compile}"
    )

    if is_dp_last_group():
        world_size = get_world_group().world_size
        prompt_tag: str = input_config.prompt[0]
        prompt_tag = prompt_tag.replace(" ", "_").replace(".", "")
        resolution = f"{input_config.width}x{input_config.height}x{input_config.num_frames}"
        output_filename = (f"results/cogvideox_{parallel_info}_{sage_tag}_{world_size}gpu_"
                           f"{resolution}_{prompt_tag}.mp4")
        export_to_video(output, output_filename, fps=8)
        print(f"output saved to {output_filename}")

    if get_world_group().rank == get_world_group().world_size - 1:
        print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB")
    get_runtime_state().destory_distributed_env()


if __name__ == "__main__":
    main()
61 changes: 61 additions & 0 deletions example/run_parallel.sh
@@ -0,0 +1,61 @@
set -x

export PYTHONPATH=$PWD:$PYTHONPATH

# Select the model type.
# MODEL_ID can point to a model downloaded to a specific location on disk,
# or simply be the model's ID on Hugging Face, in which case it will be
# downloaded to the default Hugging Face cache path.

export MODEL_TYPE="CogVideoX"
# Configuration for different model types
# script, model_id, inference_step
declare -A MODEL_CONFIGS=(
    ["CogVideoX"]="parallel_sageattn_cogvideo.py THUDM/CogVideoX-2b 50"
)

if [[ -v MODEL_CONFIGS[$MODEL_TYPE] ]]; then
    IFS=' ' read -r SCRIPT MODEL_ID INFERENCE_STEP <<< "${MODEL_CONFIGS[$MODEL_TYPE]}"
    export SCRIPT MODEL_ID INFERENCE_STEP
else
    echo "Invalid MODEL_TYPE: $MODEL_TYPE"
    exit 1
fi

mkdir -p ./results

# task args
NUM_FRAMES=$1
if [ "$NUM_FRAMES" = "" ]; then
    NUM_FRAMES=49
fi

if [ "$MODEL_TYPE" = "CogVideoX" ]; then
TASK_ARGS="--height 480 --width 720 --num_frames ${NUM_FRAMES} --max_sequence_length 226"
fi

# CogVideoX asserts sp_degree == ulysses_degree*ring_degree <= 2. Also, do not set the pipefusion degree.
if [ "$MODEL_TYPE" = "CogVideoX" ]; then
    N_GPUS=2
    # Only use CFG parallelism for 2 GPUs since it has minimal communication cost.
    PARALLEL_ARGS="--ulysses_degree 1 --ring_degree 1"
    CFG_ARGS="--use_cfg_parallel"
fi

# COMPILE_FLAG=--use_torch_compile
# SAGE_ATTN_FLAG=--use_sage_attn_fp16
SAGE_ATTN_FLAG=--use_sage_attn_fp8
torchrun --nproc_per_node=$N_GPUS ./$SCRIPT \
    --model $MODEL_ID \
    $PARALLEL_ARGS \
    $TASK_ARGS \
    $PIPEFUSION_ARGS \
    $OUTPUT_ARGS \
    --num_inference_steps $INFERENCE_STEP \
    --warmup_steps 0 \
    $CFG_ARGS \
    $PARALLEL_VAE \
    $COMPILE_FLAG \
    $SAGE_ATTN_FLAG \
    --prompt \
    "A small dog."
20 changes: 18 additions & 2 deletions sageattention/core.py
@@ -17,6 +17,7 @@
import torch
import triton
import triton.language as tl
import torch.distributed as dist

from .triton.quant_per_block import per_block_int8 as per_block_int8_triton
from .triton.quant_per_block_varlen import per_block_int8 as per_block_int8_varlen_triton
@@ -41,6 +42,7 @@
from typing import Any, List, Literal, Optional, Tuple, Union
import warnings


def get_cuda_arch_versions():
    cuda_archs = []
    for i in range(torch.cuda.device_count()):
@@ -436,7 +438,14 @@ def sageattn_qk_int8_pv_fp16_cuda(
    assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
    assert q.device == k.device == v.device, "All tensors must be on the same device."
    assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."

    # FIXME(DefTruth): make SageAttention compatible with distributed envs,
    # e.g. xDiT launched via torchrun. Without this workaround, SageAttention
    # runs into an illegal memory access error after the first inference step
    # of multi-GPU inference in a distributed env. This small workaround also
    # makes SageAttention compatible with torch.compile in non-fullgraph
    # compile mode.
    if dist.is_initialized() and dist.get_world_size() > 1:
        torch.cuda.set_device(v.device)

    _tensor_layout = 0 if tensor_layout == "NHD" else 1
    _is_caual = 1 if is_causal else 0
@@ -581,6 +590,13 @@ def sageattn_qk_int8_pv_fp8_cuda(
    assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
    assert q.device == k.device == v.device, "All tensors must be on the same device."
    assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
    # FIXME(DefTruth): make SageAttention compatible with distributed envs,
    # e.g. xDiT launched via torchrun. Without this workaround, SageAttention
    # runs into an illegal memory access error after the first inference step
    # of multi-GPU inference in a distributed env. This small workaround also
    # makes SageAttention compatible with torch.compile.
    if dist.is_initialized() and dist.get_world_size() > 1:
        torch.cuda.set_device(v.device)

    _tensor_layout = 0 if tensor_layout == "NHD" else 1
    _is_caual = 1 if is_causal else 0
@@ -628,4 +644,4 @@ def sageattn_qk_int8_pv_fp8_cuda(
    if return_lse:
        return o, lse / 1.44269504 + lse_correction * sm_scale if smooth_k else lse / 1.44269504
    else:
        return o
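
For context, the situation this guard targets can be sketched outside the diff: each rank of a torchrun launch holds its q/k/v tensors on its own GPU while the process's current CUDA device may still be cuda:0. A minimal, hypothetical sketch (the file name, tensor shapes, head dim, and the two-GPU launch are illustrative assumptions, not part of this PR):

```python
# minimal_dist_sage.py (hypothetical) -- launch with:
#   torchrun --nproc_per_node=2 minimal_dist_sage.py
import torch
import torch.distributed as dist
import sageattention

dist.init_process_group(backend="nccl")
rank = dist.get_rank()  # single-node case, so rank == local rank
device = torch.device(f"cuda:{rank}")

# (batch, heads, seq_len, head_dim), fp16, on this rank's GPU
q = torch.randn(1, 30, 1024, 64, dtype=torch.float16, device=device)
k = torch.randn(1, 30, 1024, 64, dtype=torch.float16, device=device)
v = torch.randn(1, 30, 1024, 64, dtype=torch.float16, device=device)

# On ranks other than 0 the current CUDA device is still cuda:0 unless it is
# set explicitly; the guard added in core.py pins it to v.device before the
# quantization and attention kernels are launched.
o = sageattention.sageattn_qk_int8_pv_fp8_cuda(q, k, v, pv_accum_dtype="fp32+fp32")

dist.destroy_process_group()
```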
2 changes: 1 addition & 1 deletion setup.py
@@ -172,4 +172,4 @@ def get_torch_arch_list() -> Set[str]:
    python_requires='>=3.9',
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension},
)