From 6b1d059eda21c1bd421f3d352786fca2cab61954 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20G=C3=B3rny?= Date: Sat, 18 Jan 2025 05:18:37 +0100 Subject: [PATCH 001/102] Support ROCM builds from source distribution, and improve error handling (#1446) * Always update both submodules to include them in sdist Always update both submodules, irrespectively of whether a CUDA or a ROCM build is being done, to ensure that the necessary files from both are present in sdist. Otherwise, attempt to perform a ROCM build from sdist fails because of missing `composable_kernel` srouces. * Include `*.py` files from composable_kernel in sdist Include the `*.py` files from `csrc` in sdist, to ensure that the `generate.py` script is present. * Replace the `os.system()` calls in `setup.py` with `subprocess.run()` * Add error checking to `subprocess.run()` calls in `setup.py` Add error checking to ensure that `setup.py` fails immediately if one of the commands fail. Otherwise, the failures result only in messages to stderr that could be missed, and could lead to more confusing errors later in the build process. * Call git in `setup.py` only when working in a git repository Call git commands in `setup.py` only when the `.git` directory is present, indicating that we are working in a git checkout. Otherwise, just assert that the needed files are there. With this, building from a source distribution no longer attempts to call git in an incorrect directory. --- MANIFEST.in | 1 + setup.py | 24 ++++++++++++++++-------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index 021b4d0f7..d3c4b4eda 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,6 +3,7 @@ recursive-include csrc *.h recursive-include csrc *.cuh recursive-include csrc *.cpp recursive-include csrc *.hpp +recursive-include csrc *.py recursive-include flash_attn *.cu recursive-include flash_attn *.h diff --git a/setup.py b/setup.py index a802a7e65..264b0eed5 100644 --- a/setup.py +++ b/setup.py @@ -145,11 +145,19 @@ def validate_and_update_archs(archs): # We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp # files included in the source distribution, in case the user compiles from source. -if IS_ROCM: - if not USE_TRITON_ROCM: - subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"]) +if os.path.isdir(".git"): + subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"], check=True) + subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"], check=True) else: - subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) + if IS_ROCM: + if not USE_TRITON_ROCM: + assert ( + os.path.exists("csrc/composable_kernel/example/ck_tile/01_fmha/generate.py") + ), "csrc/composable_kernel is missing, please use source distribution or git clone" + else: + assert ( + os.path.exists("csrc/cutlass/include/cutlass/cutlass.h") + ), "csrc/cutlass is missing, please use source distribution or git clone" if not SKIP_CUDA_BUILD and not IS_ROCM: print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) @@ -324,10 +332,10 @@ def validate_and_update_archs(archs): if not os.path.exists("./build"): os.makedirs("build") - os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd --output_dir build --receipt 2") - os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd_appendkv --output_dir build --receipt 2") - os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd_splitkv --output_dir build --receipt 2") - os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d bwd --output_dir build --receipt 2") + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd", "--output_dir", "build", "--receipt", "2"], check=True) + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_appendkv", "--output_dir", "build", "--receipt", "2"], check=True) + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_splitkv", "--output_dir", "build", "--receipt", "2"], check=True) + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "bwd", "--output_dir", "build", "--receipt", "2"], check=True) # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h # See https://github.com/pytorch/pytorch/pull/70650 From cd393e0ace51f8b0812b6e4f071ef2094082056a Mon Sep 17 00:00:00 2001 From: Aman Karmani Date: Wed, 29 Jan 2025 13:27:59 -0800 Subject: [PATCH 002/102] [Build] Update version of setuptools used to generate core package (#1460) --- .github/workflows/publish.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 4746c7149..5dffc0d14 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -203,7 +203,9 @@ jobs: - name: Install dependencies run: | - pip install ninja packaging setuptools wheel twine + pip install ninja packaging wheel twine + # Install latest setuptools with support for pypi metadata 2.2 (improved compat w/ uv) + pip install setuptools==75.8.0 # We don't want to download anything CUDA-related here pip install torch --index-url https://download.pytorch.org/whl/cpu From bb135af07c362236bde418e9fe3db029d1e7ed88 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 29 Jan 2025 16:31:54 -0500 Subject: [PATCH 003/102] Don't compile for CUDA 11, compile for official pytorch 2.6.0 --- .github/workflows/publish.yml | 8 ++++---- README.md | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 5dffc0d14..3d67cfbf6 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -44,8 +44,8 @@ jobs: # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. os: [ubuntu-20.04] python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] - torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0.dev20241001'] - cuda-version: ['11.8.0', '12.3.2'] + torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0'] + cuda-version: ['12.4.1'] # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) @@ -113,7 +113,7 @@ jobs: run: | pip install --upgrade pip # For some reason torch 2.2.0 on python 3.12 errors saying no setuptools - pip install setuptools==68.0.0 + pip install setuptools==75.8.0 # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable pip install typing-extensions==4.12.2 @@ -149,7 +149,7 @@ jobs: # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6 # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810 # However this still fails so I'm using a newer version of setuptools - pip install setuptools==68.0.0 + pip install setuptools==75.8.0 pip install ninja packaging wheel export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH diff --git a/README.md b/README.md index 033dba410..9f57bd56c 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ flash_attn_interface.flash_attn_func() ## Installation and features **Requirements:** - CUDA toolkit or ROCm toolkit -- PyTorch 1.12 and above. +- PyTorch 2.1 and above. - `packaging` Python package (`pip install packaging`) - `ninja` Python package (`pip install ninja`) * - Linux. Might work for Windows starting v2.3.2 (we've seen a few positive [reports](https://github.com/Dao-AILab/flash-attention/issues/595)) but Windows compilation still requires more testing. If you have ideas on how to set up prebuilt CUDA wheels for Windows, please reach out via Github issue. @@ -98,7 +98,7 @@ MAX_JOBS=4 pip install flash-attn --no-build-isolation ### NVIDIA CUDA Support **Requirements:** -- CUDA 11.7 and above. +- CUDA 12.0 and above. We recommend the [Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) From 979702c87a8713a8e0a5e9fee122b90d2ef13be5 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 29 Jan 2025 16:34:02 -0500 Subject: [PATCH 004/102] Bump to v2.7.4 --- flash_attn/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index 07d16cd0f..094b3233d 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.7.3" +__version__ = "2.7.4" from flash_attn.flash_attn_interface import ( flash_attn_func, From 5231d95fe13733fb534c01895f7ea88c6a6c7793 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 29 Jan 2025 16:42:56 -0500 Subject: [PATCH 005/102] Drop Pytorch 2.1 --- .github/workflows/publish.yml | 11 +++-------- README.md | 2 +- flash_attn/__init__.py | 2 +- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 3d67cfbf6..6f227d1ab 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -44,7 +44,7 @@ jobs: # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. os: [ubuntu-20.04] python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] - torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0'] + torch-version: ['2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0'] cuda-version: ['12.4.1'] # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. @@ -53,12 +53,7 @@ jobs: cxx11_abi: ['FALSE', 'TRUE'] exclude: # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix - # Pytorch < 2.2 does not support Python 3.12 - - torch-version: '2.1.2' - python-version: '3.12' # Pytorch < 2.5 does not support Python 3.13 - - torch-version: '2.1.2' - python-version: '3.13' - torch-version: '2.2.2' python-version: '3.13' - torch-version: '2.3.1' @@ -122,8 +117,8 @@ jobs: # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix # This code is ugly, maybe there's a better way to do this. export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \ - minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118}[env['MATRIX_TORCH_VERSION']]; \ - maxv = {'2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 124}[env['MATRIX_TORCH_VERSION']]; \ + minv = {'2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118}[env['MATRIX_TORCH_VERSION']]; \ + maxv = {'2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 124}[env['MATRIX_TORCH_VERSION']]; \ print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \ ) if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then diff --git a/README.md b/README.md index 9f57bd56c..aa545ceb0 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ flash_attn_interface.flash_attn_func() ## Installation and features **Requirements:** - CUDA toolkit or ROCm toolkit -- PyTorch 2.1 and above. +- PyTorch 2.2 and above. - `packaging` Python package (`pip install packaging`) - `ninja` Python package (`pip install ninja`) * - Linux. Might work for Windows starting v2.3.2 (we've seen a few positive [reports](https://github.com/Dao-AILab/flash-attention/issues/595)) but Windows compilation still requires more testing. If you have ideas on how to set up prebuilt CUDA wheels for Windows, please reach out via Github issue. diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index 094b3233d..db131242d 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.7.4" +__version__ = "2.7.4.post1" from flash_attn.flash_attn_interface import ( flash_attn_func, From 454ce31594aaf0978e394ff9a21635b6f6ce56c4 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 29 Jan 2025 18:01:58 -0500 Subject: [PATCH 006/102] [FA3] Compile with nvcc 12.8 instead of 12.3 --- README.md | 2 +- hopper/flash_fwd_launch_template.h | 3 +- hopper/setup.py | 47 +++++++++++++++++------------- 3 files changed, 29 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index aa545ceb0..c5d68536d 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ Currently released: Requirements: H100 / H800 GPU, CUDA >= 12.3. -For now, we highly recommend CUDA 12.3 for best performance. +We highly recommend CUDA 12.8 for best performance. To install: ```sh diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 16701f160..57d64d6a7 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -191,7 +191,8 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { // Only needed here to decide if we should use cluster static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKV, Has_softcap)) : 128; - static constexpr bool Enable_cluster = Arch >= 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen; + // On nvcc 12.8, hdim 128, without cluster is faster (730 vs 700 TFLOPS) + static constexpr bool Enable_cluster = Arch >= 90 && (sizeof(T) == 2 ? (kHeadDim >= 192) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen; APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { // Only use Cluster if number of tiles along seqlen_q is even and not varlen CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { diff --git a/hopper/setup.py b/hopper/setup.py index d95be9ad4..0104819c6 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -333,22 +333,19 @@ def open_url(url): return urllib.request.urlopen(request, timeout=300) -def download_and_copy(name, src_path, dst_path, version, url_func): +def download_and_copy(name, src_func, dst_path, version, url_func): if is_offline_build(): return flashattn_cache_path = get_flashattn_cache_path() base_dir = os.path.dirname(__file__) system = platform.system() - try: - arch = {"x86_64": "64", "arm64": "aarch64", "aarch64": "aarch64"}[platform.machine()] - except KeyError: - arch = platform.machine() + arch = platform.machine() + arch = {"arm64": "aarch64"}.get(arch, arch) supported = {"Linux": "linux", "Darwin": "linux"} url = url_func(supported[system], arch, version) + src_path = src_func(supported[system], arch, version) tmp_path = os.path.join(flashattn_cache_path, "nvidia", name) # path to cache the download dst_path = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", dst_path) # final binary path - platform_name = "sbsa-linux" if arch == "aarch64" else "x86_64-linux" - src_path = src_path(platform_name, version) if callable(src_path) else src_path src_path = os.path.join(tmp_path, src_path) download = not os.path.exists(src_path) if download: @@ -364,11 +361,12 @@ def download_and_copy(name, src_path, dst_path, version, url_func): def nvcc_threads_args(): - nvcc_threads = os.getenv("NVCC_THREADS") or "4" + nvcc_threads = os.getenv("NVCC_THREADS") or "2" return ["--threads", nvcc_threads] -NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.3.107"} +# NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.3.107"} +NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.8.61"} exe_extension = sysconfig.get_config_var("EXE") @@ -389,24 +387,31 @@ def nvcc_threads_args(): if bare_metal_version < Version("12.3"): raise RuntimeError("FlashAttention-3 is only supported on CUDA 12.3 and above") - if bare_metal_version != Version("12.3"): # nvcc 12.3 gives the best perf currently + if bare_metal_version != Version("12.8"): # nvcc 12.8 gives the best perf currently download_and_copy( - name="nvcc", src_path=f"bin", dst_path="bin", - version=NVIDIA_TOOLCHAIN_VERSION["nvcc"], url_func=lambda system, arch, version: - ((lambda version_major, version_minor1, version_minor2: - f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/{system}-{arch}/cuda-nvcc-{version}-0.tar.bz2") - (*version.split('.')))) + name="nvcc", + # src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin/ptxas{exe_extension}", + src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin", + dst_path="bin", + version=NVIDIA_TOOLCHAIN_VERSION["nvcc"], + url_func=lambda system, arch, version: + f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz", + ) download_and_copy( - name="nvcc", src_path=f"nvvm/bin", dst_path="bin", - version=NVIDIA_TOOLCHAIN_VERSION["nvcc"], url_func=lambda system, arch, version: - ((lambda version_major, version_minor1, version_minor2: - f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/{system}-{arch}/cuda-nvcc-{version}-0.tar.bz2") - (*version.split('.')))) + name="nvcc", + # src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin/ptxas{exe_extension}", + src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/nvvm/bin", + dst_path="nvvm/bin", + version=NVIDIA_TOOLCHAIN_VERSION["nvcc"], + url_func=lambda system, arch, version: + f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz", + ) base_dir = os.path.dirname(__file__) ctk_path_new = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", "bin") nvcc_path_new = os.path.join(ctk_path_new, f"nvcc{exe_extension}") # Need to append to path otherwise nvcc can't find cicc in nvvm/bin/cicc - os.environ["PATH"] = ctk_path_new + os.pathsep + os.environ["PATH"] + # nvcc 12.8 seems to hard-code looking for cicc in ../nvvm/bin/cicc + # os.environ["PATH"] = ctk_path_new + os.pathsep + os.environ["PATH"] os.environ["PYTORCH_NVCC"] = nvcc_path_new # Make nvcc executable, sometimes after the copy it loses its permissions os.chmod(nvcc_path_new, os.stat(nvcc_path_new).st_mode | stat.S_IEXEC) From 803f609aa1c2b7c0f0ddea3a0e7e9fdeaa77e071 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 29 Jan 2025 21:44:20 -0500 Subject: [PATCH 007/102] Fix comment in assert --- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index dbbf2f8f8..3af51566b 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -68,7 +68,7 @@ struct CollectiveMainloopFwdSm90 { // Leaving this option here for reference. static constexpr bool Mma0_is_RS = false; // We can have Mma1 (P @ V) with P in smem in rmem to reduce register pressure at the cost of more smem. - static_assert(!(!Mma1_is_RS && !IntraWGOverlap), "Mma1 must be RS if IntraWGOverlap is enabled"); + static_assert(!(!Mma1_is_RS && !IntraWGOverlap), "Mma1 must be RS if IntraWGOverlap is disabled"); static_assert(!(!Mma1_is_RS && Is_FP8), "Mma1 must be RS if FP8"); static_assert(!(!Mma1_is_RS && Transpose_V), "Mma1 must be RS if Transpose_V"); From 02541ac9e8382f4d8e17f1f2ba0d7de2c792390c Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 29 Jan 2025 21:48:03 -0500 Subject: [PATCH 008/102] [CE] Assert logit_scale > 0 --- flash_attn/ops/triton/cross_entropy.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flash_attn/ops/triton/cross_entropy.py b/flash_attn/ops/triton/cross_entropy.py index 7b0315b97..1b5a415b7 100644 --- a/flash_attn/ops/triton/cross_entropy.py +++ b/flash_attn/ops/triton/cross_entropy.py @@ -166,6 +166,7 @@ def forward( if labels.dtype == torch.long and labels.data_ptr() % 16 != 0: labels = F.pad(labels, (0, 1))[..., :-1] assert labels.data_ptr() % 16 == 0 + assert logit_scale > 0.0 n_rows, n_cols = logits.shape assert labels.shape == (n_rows,) world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group) From 2a204125ae71d2010bd3c9634d72a81c63967f3b Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 3 Feb 2025 00:19:25 -0500 Subject: [PATCH 009/102] Implement HeadDim_V != HeadDim_QK, support hdimQK=192, hdimV=128 --- hopper/benchmark_attn.py | 34 +- hopper/epilogue_fwd.hpp | 44 +- hopper/flash.h | 7 +- hopper/flash_api.cpp | 119 ++- hopper/flash_fwd_combine.cu | 3 + hopper/flash_fwd_combine_launch_template.h | 2 +- hopper/flash_fwd_kernel_sm90.h | 27 +- hopper/flash_fwd_launch_template.h | 20 +- hopper/generate_kernels.py | 22 +- .../flash_fwd_hdim128_bf16_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim128_bf16_paged_sm80.cu | 4 +- .../flash_fwd_hdim128_bf16_paged_sm90.cu | 2 +- ...ash_fwd_hdim128_bf16_paged_softcap_sm80.cu | 4 +- ...ash_fwd_hdim128_bf16_paged_softcap_sm90.cu | 2 +- ...flash_fwd_hdim128_bf16_paged_split_sm80.cu | 4 +- ...flash_fwd_hdim128_bf16_paged_split_sm90.cu | 2 +- ...d_hdim128_bf16_paged_split_softcap_sm80.cu | 4 +- ...d_hdim128_bf16_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim128_bf16_sm80.cu | 4 +- .../flash_fwd_hdim128_bf16_sm90.cu | 2 +- ...h_fwd_hdim128_bf16_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim128_bf16_softcap_sm80.cu | 4 +- .../flash_fwd_hdim128_bf16_softcap_sm90.cu | 2 +- .../flash_fwd_hdim128_bf16_split_sm80.cu | 4 +- .../flash_fwd_hdim128_bf16_split_sm90.cu | 2 +- ...ash_fwd_hdim128_bf16_split_softcap_sm80.cu | 4 +- ...ash_fwd_hdim128_bf16_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim128_e4m3_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim128_e4m3_paged_sm90.cu | 2 +- ...ash_fwd_hdim128_e4m3_paged_softcap_sm90.cu | 2 +- ...flash_fwd_hdim128_e4m3_paged_split_sm90.cu | 2 +- ...d_hdim128_e4m3_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim128_e4m3_sm90.cu | 2 +- ...h_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim128_e4m3_softcap_sm90.cu | 2 +- .../flash_fwd_hdim128_e4m3_split_sm90.cu | 2 +- ...ash_fwd_hdim128_e4m3_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim128_fp16_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim128_fp16_paged_sm80.cu | 4 +- .../flash_fwd_hdim128_fp16_paged_sm90.cu | 2 +- ...ash_fwd_hdim128_fp16_paged_softcap_sm80.cu | 4 +- ...ash_fwd_hdim128_fp16_paged_softcap_sm90.cu | 2 +- ...flash_fwd_hdim128_fp16_paged_split_sm80.cu | 4 +- ...flash_fwd_hdim128_fp16_paged_split_sm90.cu | 2 +- ...d_hdim128_fp16_paged_split_softcap_sm80.cu | 4 +- ...d_hdim128_fp16_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim128_fp16_sm80.cu | 4 +- .../flash_fwd_hdim128_fp16_sm90.cu | 2 +- ...h_fwd_hdim128_fp16_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim128_fp16_softcap_sm80.cu | 4 +- .../flash_fwd_hdim128_fp16_softcap_sm90.cu | 2 +- .../flash_fwd_hdim128_fp16_split_sm80.cu | 4 +- .../flash_fwd_hdim128_fp16_split_sm90.cu | 2 +- ...ash_fwd_hdim128_fp16_split_softcap_sm80.cu | 4 +- ...ash_fwd_hdim128_fp16_split_softcap_sm90.cu | 2 +- ...flash_fwd_hdim192_128_bf16_packgqa_sm90.cu | 9 + .../flash_fwd_hdim192_128_bf16_paged_sm90.cu | 9 + ...fwd_hdim192_128_bf16_paged_softcap_sm90.cu | 9 + ...h_fwd_hdim192_128_bf16_paged_split_sm90.cu | 9 + ...im192_128_bf16_paged_split_softcap_sm90.cu | 9 + .../flash_fwd_hdim192_128_bf16_sm90.cu | 9 + ...d_hdim192_128_bf16_softcap_packgqa_sm90.cu | 9 + ...flash_fwd_hdim192_128_bf16_softcap_sm90.cu | 9 + .../flash_fwd_hdim192_128_bf16_split_sm90.cu | 9 + ...fwd_hdim192_128_bf16_split_softcap_sm90.cu | 9 + ...flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu | 9 + .../flash_fwd_hdim192_128_e4m3_paged_sm90.cu | 9 + ...fwd_hdim192_128_e4m3_paged_softcap_sm90.cu | 9 + ...h_fwd_hdim192_128_e4m3_paged_split_sm90.cu | 9 + ...im192_128_e4m3_paged_split_softcap_sm90.cu | 9 + .../flash_fwd_hdim192_128_e4m3_sm90.cu | 9 + ...d_hdim192_128_e4m3_softcap_packgqa_sm90.cu | 9 + ...flash_fwd_hdim192_128_e4m3_softcap_sm90.cu | 9 + .../flash_fwd_hdim192_128_e4m3_split_sm90.cu | 9 + ...fwd_hdim192_128_e4m3_split_softcap_sm90.cu | 9 + ...flash_fwd_hdim192_128_fp16_packgqa_sm90.cu | 9 + .../flash_fwd_hdim192_128_fp16_paged_sm90.cu | 9 + ...fwd_hdim192_128_fp16_paged_softcap_sm90.cu | 9 + ...h_fwd_hdim192_128_fp16_paged_split_sm90.cu | 9 + ...im192_128_fp16_paged_split_softcap_sm90.cu | 9 + .../flash_fwd_hdim192_128_fp16_sm90.cu | 9 + ...d_hdim192_128_fp16_softcap_packgqa_sm90.cu | 9 + ...flash_fwd_hdim192_128_fp16_softcap_sm90.cu | 9 + .../flash_fwd_hdim192_128_fp16_split_sm90.cu | 9 + ...fwd_hdim192_128_fp16_split_softcap_sm90.cu | 9 + .../flash_fwd_hdim192_bf16_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim192_bf16_paged_sm80.cu | 4 +- .../flash_fwd_hdim192_bf16_paged_sm90.cu | 2 +- ...ash_fwd_hdim192_bf16_paged_softcap_sm80.cu | 4 +- ...ash_fwd_hdim192_bf16_paged_softcap_sm90.cu | 2 +- ...flash_fwd_hdim192_bf16_paged_split_sm80.cu | 4 +- ...flash_fwd_hdim192_bf16_paged_split_sm90.cu | 2 +- ...d_hdim192_bf16_paged_split_softcap_sm80.cu | 4 +- ...d_hdim192_bf16_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim192_bf16_sm80.cu | 4 +- .../flash_fwd_hdim192_bf16_sm90.cu | 2 +- ...h_fwd_hdim192_bf16_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim192_bf16_softcap_sm80.cu | 4 +- .../flash_fwd_hdim192_bf16_softcap_sm90.cu | 2 +- .../flash_fwd_hdim192_bf16_split_sm80.cu | 4 +- .../flash_fwd_hdim192_bf16_split_sm90.cu | 2 +- ...ash_fwd_hdim192_bf16_split_softcap_sm80.cu | 4 +- ...ash_fwd_hdim192_bf16_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim192_e4m3_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim192_e4m3_paged_sm90.cu | 2 +- ...ash_fwd_hdim192_e4m3_paged_softcap_sm90.cu | 2 +- ...flash_fwd_hdim192_e4m3_paged_split_sm90.cu | 2 +- ...d_hdim192_e4m3_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim192_e4m3_sm90.cu | 2 +- ...h_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim192_e4m3_softcap_sm90.cu | 2 +- .../flash_fwd_hdim192_e4m3_split_sm90.cu | 2 +- ...ash_fwd_hdim192_e4m3_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim192_fp16_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim192_fp16_paged_sm80.cu | 4 +- .../flash_fwd_hdim192_fp16_paged_sm90.cu | 2 +- ...ash_fwd_hdim192_fp16_paged_softcap_sm80.cu | 4 +- ...ash_fwd_hdim192_fp16_paged_softcap_sm90.cu | 2 +- ...flash_fwd_hdim192_fp16_paged_split_sm80.cu | 4 +- ...flash_fwd_hdim192_fp16_paged_split_sm90.cu | 2 +- ...d_hdim192_fp16_paged_split_softcap_sm80.cu | 4 +- ...d_hdim192_fp16_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim192_fp16_sm80.cu | 4 +- .../flash_fwd_hdim192_fp16_sm90.cu | 2 +- ...h_fwd_hdim192_fp16_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim192_fp16_softcap_sm80.cu | 4 +- .../flash_fwd_hdim192_fp16_softcap_sm90.cu | 2 +- .../flash_fwd_hdim192_fp16_split_sm80.cu | 4 +- .../flash_fwd_hdim192_fp16_split_sm90.cu | 2 +- ...ash_fwd_hdim192_fp16_split_softcap_sm80.cu | 4 +- ...ash_fwd_hdim192_fp16_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim256_bf16_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim256_bf16_paged_sm80.cu | 4 +- .../flash_fwd_hdim256_bf16_paged_sm90.cu | 2 +- ...ash_fwd_hdim256_bf16_paged_softcap_sm80.cu | 4 +- ...ash_fwd_hdim256_bf16_paged_softcap_sm90.cu | 2 +- ...flash_fwd_hdim256_bf16_paged_split_sm80.cu | 4 +- ...flash_fwd_hdim256_bf16_paged_split_sm90.cu | 2 +- ...d_hdim256_bf16_paged_split_softcap_sm80.cu | 4 +- ...d_hdim256_bf16_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim256_bf16_sm80.cu | 4 +- .../flash_fwd_hdim256_bf16_sm90.cu | 2 +- ...h_fwd_hdim256_bf16_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim256_bf16_softcap_sm80.cu | 4 +- .../flash_fwd_hdim256_bf16_softcap_sm90.cu | 2 +- .../flash_fwd_hdim256_bf16_split_sm80.cu | 4 +- .../flash_fwd_hdim256_bf16_split_sm90.cu | 2 +- ...ash_fwd_hdim256_bf16_split_softcap_sm80.cu | 4 +- ...ash_fwd_hdim256_bf16_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim256_e4m3_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim256_e4m3_paged_sm90.cu | 2 +- ...ash_fwd_hdim256_e4m3_paged_softcap_sm90.cu | 2 +- ...flash_fwd_hdim256_e4m3_paged_split_sm90.cu | 2 +- ...d_hdim256_e4m3_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim256_e4m3_sm90.cu | 2 +- ...h_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim256_e4m3_softcap_sm90.cu | 2 +- .../flash_fwd_hdim256_e4m3_split_sm90.cu | 2 +- ...ash_fwd_hdim256_e4m3_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim256_fp16_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim256_fp16_paged_sm80.cu | 4 +- .../flash_fwd_hdim256_fp16_paged_sm90.cu | 2 +- ...ash_fwd_hdim256_fp16_paged_softcap_sm80.cu | 4 +- ...ash_fwd_hdim256_fp16_paged_softcap_sm90.cu | 2 +- ...flash_fwd_hdim256_fp16_paged_split_sm80.cu | 4 +- ...flash_fwd_hdim256_fp16_paged_split_sm90.cu | 2 +- ...d_hdim256_fp16_paged_split_softcap_sm80.cu | 4 +- ...d_hdim256_fp16_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim256_fp16_sm80.cu | 4 +- .../flash_fwd_hdim256_fp16_sm90.cu | 2 +- ...h_fwd_hdim256_fp16_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim256_fp16_softcap_sm80.cu | 4 +- .../flash_fwd_hdim256_fp16_softcap_sm90.cu | 2 +- .../flash_fwd_hdim256_fp16_split_sm80.cu | 4 +- .../flash_fwd_hdim256_fp16_split_sm90.cu | 2 +- ...ash_fwd_hdim256_fp16_split_softcap_sm80.cu | 4 +- ...ash_fwd_hdim256_fp16_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim64_bf16_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim64_bf16_paged_sm80.cu | 4 +- .../flash_fwd_hdim64_bf16_paged_sm90.cu | 2 +- ...lash_fwd_hdim64_bf16_paged_softcap_sm80.cu | 4 +- ...lash_fwd_hdim64_bf16_paged_softcap_sm90.cu | 2 +- .../flash_fwd_hdim64_bf16_paged_split_sm80.cu | 4 +- .../flash_fwd_hdim64_bf16_paged_split_sm90.cu | 2 +- ...wd_hdim64_bf16_paged_split_softcap_sm80.cu | 4 +- ...wd_hdim64_bf16_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim64_bf16_sm80.cu | 4 +- .../flash_fwd_hdim64_bf16_sm90.cu | 2 +- ...sh_fwd_hdim64_bf16_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim64_bf16_softcap_sm80.cu | 4 +- .../flash_fwd_hdim64_bf16_softcap_sm90.cu | 2 +- .../flash_fwd_hdim64_bf16_split_sm80.cu | 4 +- .../flash_fwd_hdim64_bf16_split_sm90.cu | 2 +- ...lash_fwd_hdim64_bf16_split_softcap_sm80.cu | 4 +- ...lash_fwd_hdim64_bf16_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim64_e4m3_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim64_e4m3_paged_sm90.cu | 2 +- ...lash_fwd_hdim64_e4m3_paged_softcap_sm90.cu | 2 +- .../flash_fwd_hdim64_e4m3_paged_split_sm90.cu | 2 +- ...wd_hdim64_e4m3_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim64_e4m3_sm90.cu | 2 +- ...sh_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim64_e4m3_softcap_sm90.cu | 2 +- .../flash_fwd_hdim64_e4m3_split_sm90.cu | 2 +- ...lash_fwd_hdim64_e4m3_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim64_fp16_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim64_fp16_paged_sm80.cu | 4 +- .../flash_fwd_hdim64_fp16_paged_sm90.cu | 2 +- ...lash_fwd_hdim64_fp16_paged_softcap_sm80.cu | 4 +- ...lash_fwd_hdim64_fp16_paged_softcap_sm90.cu | 2 +- .../flash_fwd_hdim64_fp16_paged_split_sm80.cu | 4 +- .../flash_fwd_hdim64_fp16_paged_split_sm90.cu | 2 +- ...wd_hdim64_fp16_paged_split_softcap_sm80.cu | 4 +- ...wd_hdim64_fp16_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim64_fp16_sm80.cu | 4 +- .../flash_fwd_hdim64_fp16_sm90.cu | 2 +- ...sh_fwd_hdim64_fp16_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim64_fp16_softcap_sm80.cu | 4 +- .../flash_fwd_hdim64_fp16_softcap_sm90.cu | 2 +- .../flash_fwd_hdim64_fp16_split_sm80.cu | 4 +- .../flash_fwd_hdim64_fp16_split_sm90.cu | 2 +- ...lash_fwd_hdim64_fp16_split_softcap_sm80.cu | 4 +- ...lash_fwd_hdim64_fp16_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim96_bf16_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim96_bf16_paged_sm80.cu | 4 +- .../flash_fwd_hdim96_bf16_paged_sm90.cu | 2 +- ...lash_fwd_hdim96_bf16_paged_softcap_sm80.cu | 4 +- ...lash_fwd_hdim96_bf16_paged_softcap_sm90.cu | 2 +- .../flash_fwd_hdim96_bf16_paged_split_sm80.cu | 4 +- .../flash_fwd_hdim96_bf16_paged_split_sm90.cu | 2 +- ...wd_hdim96_bf16_paged_split_softcap_sm80.cu | 4 +- ...wd_hdim96_bf16_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim96_bf16_sm80.cu | 4 +- .../flash_fwd_hdim96_bf16_sm90.cu | 2 +- ...sh_fwd_hdim96_bf16_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim96_bf16_softcap_sm80.cu | 4 +- .../flash_fwd_hdim96_bf16_softcap_sm90.cu | 2 +- .../flash_fwd_hdim96_bf16_split_sm80.cu | 4 +- .../flash_fwd_hdim96_bf16_split_sm90.cu | 2 +- ...lash_fwd_hdim96_bf16_split_softcap_sm80.cu | 4 +- ...lash_fwd_hdim96_bf16_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim96_e4m3_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim96_e4m3_paged_sm90.cu | 2 +- ...lash_fwd_hdim96_e4m3_paged_softcap_sm90.cu | 2 +- .../flash_fwd_hdim96_e4m3_paged_split_sm90.cu | 2 +- ...wd_hdim96_e4m3_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim96_e4m3_sm90.cu | 2 +- ...sh_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim96_e4m3_softcap_sm90.cu | 2 +- .../flash_fwd_hdim96_e4m3_split_sm90.cu | 2 +- ...lash_fwd_hdim96_e4m3_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim96_fp16_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim96_fp16_paged_sm80.cu | 4 +- .../flash_fwd_hdim96_fp16_paged_sm90.cu | 2 +- ...lash_fwd_hdim96_fp16_paged_softcap_sm80.cu | 4 +- ...lash_fwd_hdim96_fp16_paged_softcap_sm90.cu | 2 +- .../flash_fwd_hdim96_fp16_paged_split_sm80.cu | 4 +- .../flash_fwd_hdim96_fp16_paged_split_sm90.cu | 2 +- ...wd_hdim96_fp16_paged_split_softcap_sm80.cu | 4 +- ...wd_hdim96_fp16_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim96_fp16_sm80.cu | 4 +- .../flash_fwd_hdim96_fp16_sm90.cu | 2 +- ...sh_fwd_hdim96_fp16_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim96_fp16_softcap_sm80.cu | 4 +- .../flash_fwd_hdim96_fp16_softcap_sm90.cu | 2 +- .../flash_fwd_hdim96_fp16_split_sm80.cu | 4 +- .../flash_fwd_hdim96_fp16_split_sm90.cu | 2 +- ...lash_fwd_hdim96_fp16_split_softcap_sm80.cu | 4 +- ...lash_fwd_hdim96_fp16_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdimall_bf16_packgqa_sm90.cu | 1 + .../flash_fwd_hdimall_bf16_paged_sm90.cu | 1 + ...ash_fwd_hdimall_bf16_paged_softcap_sm90.cu | 1 + ...flash_fwd_hdimall_bf16_paged_split_sm90.cu | 1 + ...d_hdimall_bf16_paged_split_softcap_sm90.cu | 1 + .../flash_fwd_hdimall_bf16_sm90.cu | 1 + ...h_fwd_hdimall_bf16_softcap_packgqa_sm90.cu | 1 + .../flash_fwd_hdimall_bf16_softcap_sm90.cu | 1 + .../flash_fwd_hdimall_bf16_split_sm90.cu | 1 + ...ash_fwd_hdimall_bf16_split_softcap_sm90.cu | 1 + .../flash_fwd_hdimall_e4m3_packgqa_sm90.cu | 1 + .../flash_fwd_hdimall_e4m3_paged_sm90.cu | 1 + ...ash_fwd_hdimall_e4m3_paged_softcap_sm90.cu | 1 + ...flash_fwd_hdimall_e4m3_paged_split_sm90.cu | 1 + ...d_hdimall_e4m3_paged_split_softcap_sm90.cu | 1 + .../flash_fwd_hdimall_e4m3_sm90.cu | 1 + ...h_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu | 1 + .../flash_fwd_hdimall_e4m3_softcap_sm90.cu | 1 + .../flash_fwd_hdimall_e4m3_split_sm90.cu | 1 + ...ash_fwd_hdimall_e4m3_split_softcap_sm90.cu | 1 + .../flash_fwd_hdimall_fp16_packgqa_sm90.cu | 1 + .../flash_fwd_hdimall_fp16_paged_sm90.cu | 1 + ...ash_fwd_hdimall_fp16_paged_softcap_sm90.cu | 1 + ...flash_fwd_hdimall_fp16_paged_split_sm90.cu | 1 + ...d_hdimall_fp16_paged_split_softcap_sm90.cu | 1 + .../flash_fwd_hdimall_fp16_sm90.cu | 1 + ...h_fwd_hdimall_fp16_softcap_packgqa_sm90.cu | 1 + .../flash_fwd_hdimall_fp16_softcap_sm90.cu | 1 + .../flash_fwd_hdimall_fp16_split_sm90.cu | 1 + ...ash_fwd_hdimall_fp16_split_softcap_sm90.cu | 1 + hopper/mainloop_fwd_sm80.hpp | 15 +- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 89 +- hopper/paged_kv.h | 62 +- hopper/setup.py | 4 +- hopper/test_flash_attn.py | 850 +++++++++--------- hopper/test_util.py | 9 +- hopper/tile_size.h | 6 +- 306 files changed, 1312 insertions(+), 921 deletions(-) create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu diff --git a/hopper/benchmark_attn.py b/hopper/benchmark_attn.py index 5f7522a8a..e61cea9e6 100644 --- a/hopper/benchmark_attn.py +++ b/hopper/benchmark_attn.py @@ -56,7 +56,7 @@ def time_fwd(func, *args, repeats=30, verbose=True, desc="", **kwargs): return Timing(do_bench(lambda: func(*args, **kwargs), warmup=5, rep=repeats) * 1e-3) -def flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, window_size=(-1, -1)): +def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(-1, -1)): if causal: avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2 else: @@ -67,7 +67,7 @@ def flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, window_size= col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0)) col_right = torch.minimum(row_idx + seqlen_k - seqlen_q - window_size[1], torch.tensor(seqlen_k - 1)) avg_seqlen = (col_right - col_left + 1).float().mean().item() - return batch * nheads * 2 * seqlen_q * avg_seqlen * headdim * 2 + return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v) def convert_to_cudnn_type(torch_type): @@ -263,7 +263,7 @@ def run(*args, **kwargs): # for headdim in [64, 96, 128]: # for headdim in [64, 128, 256]: # for headdim in [64, 96, 128, 192, 256]: -for headdim in [128]: +for headdim in [192]: nheads = dim // headdim # headdim = 64 # batch_size = 64 @@ -272,6 +272,8 @@ def run(*args, **kwargs): # headdim = 128 nheads_kv = nheads # nheads_kv = nheads // 4 + headdim_v = headdim + # headdim_v = 128 for batch_size, seqlen in bs_seqlen_vals: num_splits = 1 @@ -285,15 +287,15 @@ def run(*args, **kwargs): # leftpad_k = torch.full((batch_size,), 0, device=device, dtype=torch.int32) q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True) k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=True) - v = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=True) + v = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, device=device, dtype=dtype_gen, requires_grad=True) q, k, v = [x.detach().to(dtype).requires_grad_() for x in [q, k, v]] v_colmajor = v.detach().transpose(-1, -3).contiguous().transpose(-1, -3).requires_grad_() v_fa3 = v if not V_colmajor else v_colmajor # q = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype) # k = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype) - # v = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype) - g = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True) - o = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True) + # v = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim_v), device=device, dtype=torch.int32).to(dtype) + g = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen, requires_grad=True) + o = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen, requires_grad=True) stats = torch.randn(batch_size, seqlen_q, nheads, 1, device=device, dtype=torch.float32) a = torch.randn(batch_size, seqlen, seqlen, device=device, dtype=dtype_gen) b = torch.randn(batch_size, dim * 2, seqlen, device=device, dtype=dtype_gen).transpose(-1, -2) @@ -320,14 +322,14 @@ def run(*args, **kwargs): for causal in [False, True]: # for causal in [False]: print(f"\n### {headdim = }, {causal = }, {seqlen = } ###") - nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim, causal=causal, window_size=window_size) + nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim, headdim_v, causal=causal, window_size=window_size) if cudnn is not None: # if False: - if headdim <= 256 and dtype != torch.float8_e4m3fn: + if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v: cudnn_spda = cudnn_spda_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal, window_size_left=window_size[0]) cudnn_spda_bwd = cudnn_spda_bwd_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), o.transpose(1, 2), g.transpose(1, 2), stats.transpose(1, 2), causal=causal, window_size_left=window_size[0]) # _, m0 = benchmark_forward(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2') - if dtype != torch.float8_e4m3fn: + if dtype != torch.float8_e4m3fn and headdim == headdim_v: # if False: if not varlen: m0 = time_fwd(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2') @@ -343,7 +345,7 @@ def run(*args, **kwargs): repeats=repeats, verbose=False, desc='Fav2') time_b[(causal, headdim, batch_size, seqlen), "Flash2"] = m0b.mean # pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=True) - if headdim <= 256 and dtype != torch.float8_e4m3fn: + if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v: if triton_attention is not None: qt, kt, vt = [x.detach().transpose(1, 2).contiguous().requires_grad_() for x in [q, k, v]] time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark @@ -356,7 +358,7 @@ def run(*args, **kwargs): # # pytorch_profiler(triton_attention, q.transpose(1, 2).contiguous(), k.transpose(1, 2).contiguous(), v.transpose(1, 2).contiguous(), causal, 1 / math.sqrt(headdim), backward=True) if cudnn is not None: # if False: - if headdim <= 256 and dtype != torch.float8_e4m3fn: + if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v: time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark m2 = time_fwd(cudnn_spda, repeats=repeats, verbose=verbose, desc='CuDNN') time_f[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2.mean @@ -380,7 +382,7 @@ def run(*args, **kwargs): # nFLOPS_matmul = nFLOPS # nFLOPS_matmul = 2 * x.shape[0] * x.shape[1] * w.shape[1] # m5 = time_fwd(torch.matmul, x, w, desc='cuBLAS') - if dtype != torch.float8_e4m3fn: + if dtype != torch.float8_e4m3fn and headdim == headdim_v: time.sleep(1) if not varlen: _, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, deterministic=deterministic, @@ -396,11 +398,11 @@ def run(*args, **kwargs): # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, deterministic=deterministic, backward=True) # benchmark_forward(torch.clone, k, repeats=repeats, verbose=verbose, desc='Memcpy') - if dtype != torch.float8_e4m3fn: + if dtype != torch.float8_e4m3fn and headdim == headdim_v: # if False: print(f'Fav2 fwd: {m0.mean * 1e3:.3f}ms, {(nFLOPS / m0.mean * 1e-12):.1f} TFLOPS') print(f'Fav2 bwd: {m0b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m0b.mean * 1e-12):.1f} TFLOPS') - if headdim <= 256 and dtype != torch.float8_e4m3fn: + if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v: if triton_attention is not None: print(f'Triton fwd: {m3.mean * 1e3:.3f}ms, {(nFLOPS / m3.mean * 1e-12):.1f} TFLOPS') # if causal: @@ -409,7 +411,7 @@ def run(*args, **kwargs): print(f'CuDNN fwd: {m2.mean * 1e3:.3f}ms, {(nFLOPS / m2.mean * 1e-12):.1f} TFLOPS') print(f'CuDNN bwd: {m2b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m2b.mean * 1e-12):.1f} TFLOPS') print(f'Fav3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS') - if dtype != torch.float8_e4m3fn: + if dtype != torch.float8_e4m3fn and headdim == headdim_v: print(f'Fav3 bwd: {m1b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b.mean * 1e-12):.1f} TFLOPS') # benchmark_forward(torch.square, k) # print(f'cuBLAS: {m5.mean * 1e3:.3f}ms, {(nFLOPS_matmul / m5.mean * 1e-12):.1f} TFLOPS') diff --git a/hopper/epilogue_fwd.hpp b/hopper/epilogue_fwd.hpp index 0f9160602..d8f2c15c9 100644 --- a/hopper/epilogue_fwd.hpp +++ b/hopper/epilogue_fwd.hpp @@ -20,11 +20,11 @@ namespace flash { using namespace cute; -template struct CollectiveEpilogueFwd { - using TileShape_MNK = TileShape_MNK_; + using TileShape_MNK_PV = TileShape_MNK_PV_; using ClusterShape = ClusterShape_; using Element = Element_; using ArchTag = ArchTag_; @@ -37,21 +37,21 @@ struct CollectiveEpilogueFwd { static_assert(ArchTag::kMinComputeCapability >= 80); static_assert(ArchTag::kMinComputeCapability >= 90 || CUTE_STATIC_V(size(ClusterShape{})) == 1); - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); + static constexpr int kHeadDimV = get<1>(TileShape_MNK_PV{}); using GmemTiledCopyOTMA = cute::SM90_TMA_STORE; // These are for storing the output tensor without TMA (e.g., for setting output to zero) static constexpr int kGmemElemsPerStore = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore"); + static_assert(kHeadDimV % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore"); // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). We want each thread to have 4 elements // in the M direction and 2 elements in the K direction. In the case of PackGQA, this reduces the number of times // we need to call divmod. - static constexpr int kBytePerRow = kHeadDim * sizeof(Element); + static constexpr int kBytePerRow = kHeadDimV * sizeof(Element); static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); - // static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); - // static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerStore, NumEpilogueThreads); + // static constexpr int kBlockKGmem = kHeadDimV % 128 == 0 ? 128 : (kHeadDimV % 64 == 0 ? 64 : 32); + // static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDimV / kGmemElemsPerStore, NumEpilogueThreads); static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerStore; // If PackGQA, we split the work of compute O_ptr among threads in the same row, so we need this to within a warp static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0); @@ -65,15 +65,15 @@ struct CollectiveEpilogueFwd { Layout>>{})); // Val layout, 8 or 16 vals per store using SmemLayoutAtomOTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutOTMA = decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 2>(TileShape_MNK{}))); + decltype(cute::get<0>(TileShape_MNK_PV{})), decltype(cute::get<1>(TileShape_MNK_PV{}))>()); + using SmemLayoutOTMA = decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 1>(TileShape_MNK_PV{}))); static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1)); static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4); using SmemLayoutAtomO = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); - using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{}))); + using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_MNK_PV{}))); using SmemLayoutO = std::conditional_t= 90, SmemLayoutOTMA, SmemLayoutOSTS>; using ShapeO = cute::Shape; // (seqlen_q, d, head, batch, num_splits) @@ -109,7 +109,7 @@ struct CollectiveEpilogueFwd { GmemTiledCopyOTMA{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeO{}, StrideO{}), SmemLayoutOTMA{}, - select<0, 2>(TileShape_MNK{}), + select<0, 1>(TileShape_MNK_PV{}), _1{})), // no mcast for O std::nullptr_t >; @@ -148,7 +148,7 @@ struct CollectiveEpilogueFwd { Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O); TMA_O tma_store_O = [&]{ if constexpr (Use_TMA_O) { - return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 2>(TileShape_MNK{}), _1{}); // no mcast + return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 1>(TileShape_MNK_PV{}), _1{}); // no mcast } else { return nullptr; } @@ -243,14 +243,14 @@ struct CollectiveEpilogueFwd { // Step 2: Write LSE from rmem -> gmem auto thread_mma = tiled_mma.get_thread_slice(thread_idx); // (MMA,MMA_M,MMA_K) - Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}))); + Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); static_assert(decltype(size<0, 0>(taccOcO))::value == 2); static_assert(decltype(size<0, 1>(taccOcO))::value == 2); Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M - using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumEpilogueThreads, Element>; + using PackGQAt = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>; Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset_o * get<0>(params.stride_LSE)), params.shape_LSE_packed, params.stride_LSE_packed)(_, bidh, !is_varlen ? bidb : 0, split_idx); // if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); } @@ -267,7 +267,7 @@ struct CollectiveEpilogueFwd { // Step 3: Write O from smem -> gmem if constexpr (Use_TMA_O) { Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx); - Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) auto block_tma_O = params.tma_store_O.get_slice(_0{}); Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K) Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K) @@ -287,7 +287,7 @@ struct CollectiveEpilogueFwd { } } else { // Don't use TMA in Varlen case since we don't want to overwrite the output of another sequence Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx); - Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) // if (thread_idx == 0) { printf("Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast(&mO(0)) - reinterpret_cast(params.ptr_O)); } if constexpr (Use_smem) { GmemTiledCopyO gmem_tiled_copy_O; @@ -305,7 +305,7 @@ struct CollectiveEpilogueFwd { } if constexpr (!PackGQA) { // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}))); + Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); Tensor tOpO = make_tensor(make_shape(size<2>(tOsO))); #pragma unroll for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } @@ -361,7 +361,7 @@ struct CollectiveEpilogueFwd { int thread_idx, cute::tuple const& block_coord ) { - static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); auto [m_block, bidh, bidb, split_idx] = block_coord; flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused}; bool const is_varlen = Varlen && params.cu_seqlens; @@ -391,12 +391,12 @@ struct CollectiveEpilogueFwd { GmemTiledCopyO gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); - Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}))); + Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); if constexpr (!PackGQA) { Tensor tOpO = make_tensor(make_shape(size<2>(tOcO))); #pragma unroll for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } - Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) Tensor tOgO = gmem_thr_copy_O.partition_D(gO); Tensor tOrO = make_fragment_like(tOgO); cute::clear(tOrO); @@ -406,7 +406,7 @@ struct CollectiveEpilogueFwd { ); } else { // If PackGQA, we split the work of compute O_ptr among threads in the same row - using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumEpilogueThreads, Element>; + using PackGQAt = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>; Tensor tOrO = make_tensor(make_shape(Shape<_1, Int>{}, size<1>(tOcO), size<2>(tOcO))); cute::clear(tOrO); PackGQAt::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); diff --git a/hopper/flash.h b/hopper/flash.h index 4559a1352..9f8cb1bca 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -65,6 +65,7 @@ struct Flash_fwd_params : public Qkv_params { int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; int total_q, total_k, total_knew; int b_k; // When having KV cache and with cache_batch_idx, K & V might have larger batch size than Q + int dv, dv_rounded; // For the case where V headdim is different from Q/K headdim // The scaling factors for the kernel. float scale_softmax; @@ -197,9 +198,9 @@ struct Flash_bwd_params : public Flash_fwd_params { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); -template +template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); -template +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 82643d9ff..94fcf5d78 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -271,36 +271,48 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { if (!params.is_e4m3) { if (params.is_bf16) { #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { return run_mha_fwd_(params, stream); } + if (params.d <= 64) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_fwd_(params, stream); } + if (params.d <= 96) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_fwd_(params, stream); } + if (params.d <= 128) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (params.d <= 192) { return run_mha_fwd_(params, stream); } + if (params.d <= 192) { + if (params.dv <= 128 && Arch == 90) { + return run_mha_fwd_(params, stream); + } else { + return run_mha_fwd_(params, stream); + } + } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_fwd_(params, stream); } + if (params.d <= 256) { return run_mha_fwd_(params, stream); } #endif } else { #ifndef FLASHATTENTION_DISABLE_FP16 #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { return run_mha_fwd_(params, stream); } + if (params.d <= 64) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_fwd_(params, stream); } + if (params.d <= 96) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_fwd_(params, stream); } + if (params.d <= 128) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (params.d <= 192) { return run_mha_fwd_(params, stream); } + if (params.d <= 192) { + if (params.dv <= 128 && Arch == 90) { + return run_mha_fwd_(params, stream); + } else { + return run_mha_fwd_(params, stream); + } + } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_fwd_(params, stream); } + if (params.d <= 256) { return run_mha_fwd_(params, stream); } #endif #else TORCH_CHECK(false, "This flash attention build does not support FP16."); @@ -309,19 +321,25 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { } else { #ifndef FLASHATTENTION_DISABLE_FP8 #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } + if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } + if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } + if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (params.d <= 192) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } + if (params.d <= 192) { + if (params.dv <= 128 && Arch == 90) { + return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKV, Has_softcap, PackGQA>(params, stream); + } else { + return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKV, Has_softcap, PackGQA>(params, stream); + } + } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } + if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } #endif #else TORCH_CHECK(false, "This flash attention build does not support FP8."); @@ -339,28 +357,34 @@ void run_mha_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { // If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively // so that kBlockM is smaller and we have more parallelism. if (params.is_fp32) { - if (params.d <= 64) { + if (params.dv <= 64) { run_mha_fwd_combine_(params, stream); - } else if (params.d <= 128) { + } else if (params.dv <= 128) { run_mha_fwd_combine_(params, stream); - } else { + } else if (params.dv <= 256) { run_mha_fwd_combine_(params, stream); + } else { + run_mha_fwd_combine_(params, stream); } } else if (params.is_bf16) { - if (params.d <= 64) { + if (params.dv <= 64) { run_mha_fwd_combine_(params, stream); - } else if (params.d <= 128) { + } else if (params.dv <= 128) { run_mha_fwd_combine_(params, stream); - } else { + } else if (params.dv <= 256) { run_mha_fwd_combine_(params, stream); + } else { + run_mha_fwd_combine_(params, stream); } } else { - if (params.d <= 64) { + if (params.dv <= 64) { run_mha_fwd_combine_(params, stream); - } else if (params.d <= 128) { + } else if (params.dv <= 128) { run_mha_fwd_combine_(params, stream); - } else { + } else if (params.dv <= 256) { run_mha_fwd_combine_(params, stream); + } else { + run_mha_fwd_combine_(params, stream); } } #else @@ -378,7 +402,7 @@ inline bool get_pack_gqa(Flash_fwd_params const& params) { // params.page_table must already be set if (params.h == params.h_k) { return false; } // This needs to match the kernel configs - auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table, params.softcap > 0.f); + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table, params.softcap > 0.f); int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM); #endif @@ -392,10 +416,10 @@ inline int get_num_splits(Flash_fwd_params const& params) { // params.page_table must already be set // This needs to match the kernel configs bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k; - auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table, params.softcap > 0.f); + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table, params.softcap > 0.f); // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits // has not been set here. It's OK though because we might just underestimate kBlockN a bit - auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr); + auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr); int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k); @@ -460,10 +484,10 @@ inline int round_up_headdim(int head_size) { std::vector mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q const at::Tensor &k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. - const at::Tensor &v, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. + const at::Tensor &v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table. std::optional &k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new - std::optional &v_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new - std::optional &out_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + std::optional &v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new + std::optional &out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q std::optional &cu_seqlens_q_, // b+1 std::optional &cu_seqlens_k_, // b+1 std::optional &cu_seqlens_k_new_, // b+1 @@ -551,6 +575,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq int total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0]; int num_heads = q.size(-2); int const head_size = q.size(-1); + int const head_size_v = v.size(-1); int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1); int const num_pages = !paged_KV ? 0 : k.size(0); int const page_size = !paged_KV ? 1 : k.size(1); @@ -564,6 +589,10 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq int const max_headdim = get_max_headdim(); TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + if (head_size_v != head_size) { + TORCH_CHECK(head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128, "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128]"); + TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim"); + } // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM // TODO: check this @@ -583,15 +612,15 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq if (!paged_KV) { if (!is_varlen_k) { CHECK_SHAPE(k, batch_size_k, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size_v); } else { CHECK_SHAPE(k, total_k, num_heads_k, head_size); - CHECK_SHAPE(v, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_v); CHECK_SHAPE(cu_seqlens_k, batch_size + 1); } } else { CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size); - CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size); + CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v); CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq); } @@ -610,6 +639,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq int const alignment = q_type == torch::kFloat8_e4m3fn ? 16 : 8; TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment)); + TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment)); auto opts = q.options(); auto out_type = q_type == at::ScalarType::Float8_e4m3fn ? at::ScalarType::BFloat16 : q_type; @@ -620,16 +650,19 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq CHECK_DEVICE(out); TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); if (!is_varlen_q) { - CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v); } else { - CHECK_SHAPE(out, total_q, num_heads, head_size); + CHECK_SHAPE(out, total_q, num_heads, head_size_v); } } else { - out = torch::empty_like(q, opts.dtype(out_type)); + out = !is_varlen_q + ? torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type)) + : torch::empty({total_q, num_heads, head_size_v}, opts.dtype(out_type)); } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; int const head_size_rounded = round_up_headdim(head_size); + int const head_size_v_rounded = round_up_headdim(head_size_v); int const seqlen_q_rounded = round_multiple(seqlen_q, 128); int const seqlen_k_rounded = round_multiple(seqlen_k, 128); @@ -667,6 +700,8 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq params.total_k = total_k; params.sink_token_length = sink_token_length; params.b_k = batch_size_k; + params.dv = head_size_v; + params.dv_rounded = head_size_v_rounded; if (paged_KV) { params.page_table = page_table.data_ptr(); @@ -702,10 +737,10 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq int total_k_new = !is_varlen_k_new ? batch_size * k_new.size(1): k_new.size(0); if (!is_varlen_k_new) { CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size); - CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size); + CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size_v); } else { CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size); - CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size); + CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size_v); CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1); } params.seqlen_knew = seqlen_k_new; @@ -772,12 +807,12 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq if (params.num_splits > 1) { TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported"); if (!is_varlen_q) { - out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size}, opts.dtype(outaccum_type)); + out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_v}, opts.dtype(outaccum_type)); softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); params.oaccum_batch_stride = out_accum.stride(1); params.lseaccum_batch_stride = softmax_lse_accum.stride(1); } else { - out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size}, opts.dtype(outaccum_type)); + out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size_v}, opts.dtype(outaccum_type)); softmax_lse_accum = torch::empty({params.num_splits, num_heads, total_q}, opts.dtype(at::kFloat)); } params.is_fp32 = false; @@ -1258,7 +1293,7 @@ mha_combine(const at::Tensor &out_partial, // num_splits x batch_size x const int seqlen = sizes[2]; const int num_heads = sizes[3]; const int head_size_og = sizes[4]; - TORCH_CHECK(head_size_og <= 256, "FlashAttention combine only supports head dimension at most 256"); + TORCH_CHECK(head_size_og <= 512, "FlashAttention combine only supports head dimension at most 512"); TORCH_CHECK(num_splits <= 256, "FlashAttention combine only supports num_splits at most 256"); CHECK_SHAPE(out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og); @@ -1307,7 +1342,7 @@ mha_combine(const at::Tensor &out_partial, // num_splits x batch_size x params.b = batch_size; params.h = num_heads; params.seqlen_q = seqlen; - params.d = head_size; + params.dv = head_size; params.num_splits = num_splits; params.oaccum_split_stride = out_partial_padded.stride(0); params.oaccum_row_stride = out_partial_padded.stride(2); diff --git a/hopper/flash_fwd_combine.cu b/hopper/flash_fwd_combine.cu index 5b7d9eed6..57392ee75 100644 --- a/hopper/flash_fwd_combine.cu +++ b/hopper/flash_fwd_combine.cu @@ -6,11 +6,14 @@ template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/hopper/flash_fwd_combine_launch_template.h b/hopper/flash_fwd_combine_launch_template.h index 33e66c21f..5cbed2b0c 100644 --- a/hopper/flash_fwd_combine_launch_template.h +++ b/hopper/flash_fwd_combine_launch_template.h @@ -24,7 +24,7 @@ void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { typename CombineKernel::Arguments args { static_cast(params.oaccum_ptr), - {!Varlen ? params.seqlen_q : params.total_q, params.d, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_O_partial + {!Varlen ? params.seqlen_q : params.total_q, params.dv, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_O_partial {params.oaccum_row_stride, _1{}, params.oaccum_split_stride, params.oaccum_head_stride, !Varlen ? params.oaccum_batch_stride : 0}, // stride_O_partial static_cast(params.softmax_lseaccum_ptr), {!Varlen ? params.seqlen_q : params.total_q, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_LSE_partial diff --git a/hopper/flash_fwd_kernel_sm90.h b/hopper/flash_fwd_kernel_sm90.h index e5411042d..05ce4d0ae 100644 --- a/hopper/flash_fwd_kernel_sm90.h +++ b/hopper/flash_fwd_kernel_sm90.h @@ -45,10 +45,11 @@ class FlashAttnFwdSm90 { static constexpr bool Use_TMA_O = CollectiveEpilogue::Use_TMA_O; static constexpr bool PackGQA = CollectiveMainloop::PackGQA; static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads; + static constexpr bool SameHeadDim = CollectiveMainloop::SameHeadDim; using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t; // Mainloop derived types - using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK; + using TileShape_MNK_PV = typename CollectiveMainloop::TileShape_MNK_PV; using TiledMma0 = typename CollectiveMainloop::TiledMma0; using TiledMma1 = typename CollectiveMainloop::TiledMma1; using ArchTag = typename CollectiveMainloop::ArchTag; @@ -176,7 +177,7 @@ class FlashAttnFwdSm90 { static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup; static constexpr int MmaThreadOffset = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup; - static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); using MainloopPipelineK = typename CollectiveMainloop::MainloopPipelineK; using MainloopPipelineV = typename CollectiveMainloop::MainloopPipelineV; @@ -222,6 +223,11 @@ class FlashAttnFwdSm90 { pipeline_params_k.producer_arv_count = NumProducerThreads; } + PipelineParamsV pipeline_params_v = pipeline_params_k; + if constexpr (Use_TMA_KV && !SameHeadDim) { + pipeline_params_v.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV; + } + MainloopPipelineK pipeline_k = [&] { if constexpr (Use_TMA_KV) { return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k, ClusterShape{}); @@ -234,9 +240,9 @@ class FlashAttnFwdSm90 { if constexpr (!Transpose_V) { static_assert(is_same_v); if constexpr (Use_TMA_KV) { - return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_k, ClusterShape{}); + return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v, ClusterShape{}); } else { - return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_k); + return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v); } } else { PipelineParamsV pipeline_params_v; @@ -256,11 +262,11 @@ class FlashAttnFwdSm90 { // However, the thread role isn't used in the pipeline implementation. MainloopPipelineVt pipeline_vt = [&] { if constexpr (Use_TMA_KV) { - pipeline_params_k.num_consumers = NumProducerThreads; // TMA_V is only consumed by the producer WG - return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_k, ClusterShape{}); + pipeline_params_v.num_consumers = NumProducerThreads; // TMA_V is only consumed by the producer WG + return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_v, ClusterShape{}); } else { - pipeline_params_k.consumer_arv_count = NumProducerThreads; // TMA_V is only consumed by the producer WG - return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_k); + pipeline_params_v.consumer_arv_count = NumProducerThreads; // TMA_V is only consumed by the producer WG + return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_v); } }(); @@ -272,6 +278,9 @@ class FlashAttnFwdSm90 { pipeline_params_kv_new.is_leader = warp_group_thread_idx == 0; pipeline_params_kv_new.num_consumers = NumMmaThreads; auto pipeline_k_new = cute::conditional_return(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_k_new, pipeline_params_kv_new, ClusterShape{}), nullptr); + if constexpr (!SameHeadDim) { + pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV; + } auto pipeline_v_new = cute::conditional_return(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_v_new, pipeline_params_kv_new, ClusterShape{}), nullptr); CollectiveMainloop collective_mainloop; @@ -357,7 +366,7 @@ class FlashAttnFwdSm90 { work_tile_info.is_valid(params.scheduler); work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info)) { // Attention output (GEMM-II) accumulator. - Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{})); + Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 1>(TileShape_MNK_PV{})); float softmax_scale_log2 = params.mainloop.softmax_scale_log2; // If there's tanh softcap, the scaling will be done before tanh. auto block_coord = work_tile_info.get_block_coord(params.scheduler); diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 57d64d6a7..3f4bea96e 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -23,7 +23,7 @@ using namespace cute; -template void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { @@ -35,8 +35,8 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using ArchTag = std::conditional_t= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>; // Can't use structured binding since it's not compatible with constexpr - static constexpr std::tuple kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKV, Has_softcap); - static constexpr std::tuple kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKV, Varlen && Split, Has_softcap, AppendKV); + static constexpr std::tuple kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKV, Has_softcap); + static constexpr std::tuple kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKV, Varlen && Split, Has_softcap, AppendKV); static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS); static constexpr int kBlockN = Arch >= 90 ? std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS); static constexpr bool Mma1_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap); @@ -46,13 +46,14 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static constexpr bool Q_in_regs = Arch >= 90 ? false : std::get<4>(kBlockMN_kNWarps_Stages_RS); using TileShape_MNK = cute::Shape, Int, Int>; + using TileShape_MNK_PV = cute::Shape, Int, Int>; using ClusterShape = cute::Shape, _1, _1>; using CollectiveMainloop = std::conditional_t< Arch >= 90, - flash::CollectiveMainloopFwdSm90, - flash::CollectiveMainloopFwdSm80 + flash::CollectiveMainloopFwdSm90, + flash::CollectiveMainloopFwdSm80 >; - using CollectiveEpilogue = flash::CollectiveEpilogueFwd; + using CollectiveEpilogue = flash::CollectiveEpilogueFwd; static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads; using SchedulerPersistent = std::conditional_t(params.v_ptr), + params.dv, // headdim_v v_strides, // stride_V static_cast(params.knew_ptr), {!is_varlen_k_new ? params.seqlen_knew : params.total_knew, params.d, params.h_k, !is_varlen_k_new ? params.b : 1}, // shape_K_new @@ -179,7 +181,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { CHECK_CUDA_KERNEL_LAUNCH(); } -template +template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { static_assert(sizeof(T) == 2 || sizeof(T) == 1, "Only 16bit and 8bit are supported"); static constexpr bool Is_FP8 = cute::is_same_v || cute::is_same_v; @@ -189,7 +191,7 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1; VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] { // Only needed here to decide if we should use cluster - static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKV, Has_softcap)) : 128; + static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKV, Has_softcap)) : 128; // On nvcc 12.8, hdim 128, without cluster is faster (730 vs 700 TFLOPS) static constexpr bool Enable_cluster = Arch >= 90 && (sizeof(T) == 2 ? (kHeadDim >= 192) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen; @@ -197,7 +199,7 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { // Only use Cluster if number of tiles along seqlen_q is even and not varlen CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1; - run_flash_fwd(params, stream); + run_flash_fwd(params, stream); }); }); }); diff --git a/hopper/generate_kernels.py b/hopper/generate_kernels.py index e741c1382..7a5eb47d0 100644 --- a/hopper/generate_kernels.py +++ b/hopper/generate_kernels.py @@ -38,7 +38,7 @@ KERNEL_IMPL_TEMPLATE_FWD_SM90 = """#include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM{HEAD_DIM} -template void run_mha_fwd_<{ARCH}, {DTYPE}, {HEAD_DIM}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<{ARCH}, {DTYPE}, {HEAD_DIM}, {HEAD_DIM_V}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif """ @@ -46,8 +46,8 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM{HEAD_DIM} -template void run_mha_fwd_<80, {DTYPE}, {HEAD_DIM}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, {DTYPE}, {HEAD_DIM}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, {DTYPE}, {HEAD_DIM}, {HEAD_DIM_V}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, {DTYPE}, {HEAD_DIM}, {HEAD_DIM_V}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif """ @@ -85,6 +85,7 @@ class Kernel: sm: int dtype: str head_dim: int + head_dim_v: int split: bool paged_kv: bool softcap: bool @@ -98,14 +99,15 @@ def template(self) -> str: # Always enable PackGQA for PagedKV or Split to reduce compilation packgqa = self.packgqa or self.paged_kv or self.split return KERNEL_IMPL_TEMPLATE_FWD_SM90.format( - ARCH=str(self.sm), DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, + ARCH=str(self.sm), DTYPE=DTYPE_MAP[self.dtype], + HEAD_DIM=self.head_dim, HEAD_DIM_V=self.head_dim_v, SPLIT=str(self.split).lower(), PAGEDKV=str(self.paged_kv).lower(), SOFTCAP=str(self.softcap).lower(), PACKGQA=str(packgqa).lower() ) else: # Always enable PackGQA for Sm8x to reduce compilation return KERNEL_IMPL_TEMPLATE_FWD_SM8x.format( - DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, + DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, HEAD_DIM_V=self.head_dim_v, SPLIT=str(self.split).lower(), PAGEDKV=str(self.paged_kv).lower(), SOFTCAP=str(self.softcap).lower(), PACKGQA=str(True).lower() ) @@ -117,13 +119,13 @@ def template(self) -> str: ) else: return KERNEL_IMPL_TEMPLATE_BWD_SM8x.format( - DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, + DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, SOFTCAP=str(self.softcap).lower() ) @property def filename(self) -> str: - return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}{'_paged' if self.paged_kv else ''}{'_split' if self.split else ''}{'_softcap' if self.softcap else ''}{'_packgqa' if self.packgqa else ''}_sm{self.sm}.cu" + return f"flash_{self.direction}_hdim{self.head_dim}{f'_{self.head_dim_v}' if self.head_dim_v != self.head_dim else ''}_{self.dtype}{'_paged' if self.paged_kv else ''}{'_split' if self.split else ''}{'_softcap' if self.softcap else ''}{'_packgqa' if self.packgqa else ''}_sm{self.sm}.cu" def get_all_kernels() -> List[Kernel]: @@ -133,9 +135,11 @@ def get_all_kernels() -> List[Kernel]: if packgqa and (sm < 90 or (sm >= 90 and (paged_kv or split))): continue if sm >= 90 or dtype in DTYPE_MAP_FWD_SM8x: - yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=head_dim, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") + if sm == 90 and head_dim == 192: + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=128, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") for dtype, head_dim, softcap, sm in itertools.product(DTYPE_MAP_BWD.keys(), HEAD_DIMENSIONS, SOFTCAP, SM): - yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, split=False, paged_kv=False, softcap=softcap, packgqa=False, direction="bwd") + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=head_dim, split=False, paged_kv=False, softcap=softcap, packgqa=False, direction="bwd") def batch_hdim(kernels_all) -> List[KERNEL_BATCH]: diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu index 18879eff6..affc7a4dd 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu index 35c0ad78f..7e13614bf 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu index 7a39869a0..670041341 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu index fb7ba5cae..f315fbb45 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu index 296ec9e91..bde3024a4 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu index 8cffb6de8..2724463e6 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu index 12d564ce3..a38a1d5cf 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu index 845b1fa5d..284eeba18 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu index 25fbfda38..0c40ddba8 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_sm80.cu index 1130ca747..cc89c4d5d 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_sm90.cu index 502bc1d17..3a236b712 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu index 537e42ba5..8449104c5 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu index 2255e7949..b152b90ba 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu index 086f55b35..8cc4fed17 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu index 54590eebb..1db3f1e6d 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu index af322d1d1..9b3e294f1 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu index 3e83398e7..07bd687fc 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu index 3f917d26a..5f44833b1 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu index 87c78f289..9f95ca29f 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu index e56b64c3d..ad97737d4 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu index 8202bfadd..d77d37ec0 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu index ee7439b27..ae05c7ce5 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu index 812239ef5..bc52a9f35 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_sm90.cu index 74e52315b..480d485d0 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu index fe0bff6a1..d3da5f4e6 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu index 55df1a666..1c1c2d820 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu index 03a9c61e4..371d933e3 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu index 67ba153c6..7491148dc 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu index 9f7bcec9e..d04159a62 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu index 7116702f3..28ad6c149 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu index 04f18ac0f..7afb267e3 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu index c7c7c9e69..69758584c 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu index b4ea8bc33..3be45956b 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu index ec99965c9..698095dad 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu index d1dd96452..16d443a9a 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu index 83274ca3f..1e8f6af71 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu index 80e9eb0e2..4ec688861 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_sm80.cu index fbbc273b7..670b5952d 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_sm90.cu index f4f4829f3..b9778dc92 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu index c768a89fd..446e917c7 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu index 89c2db39e..fd62a2c54 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu index 5b87286ae..0a397f4ac 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu index 750609782..4d3c553e2 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu index d3b7b0f87..77621846f 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu index 4d8625cd6..7d217ac27 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu index f6f129c55..0b6430abc 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu new file mode 100644 index 000000000..ea1e266f8 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu new file mode 100644 index 000000000..2d7488fef --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu new file mode 100644 index 000000000..8718571e3 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu new file mode 100644 index 000000000..f7dfc18fc --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu new file mode 100644 index 000000000..935f5a0fe --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu new file mode 100644 index 000000000..3f4d858ff --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu new file mode 100644 index 000000000..54d720efe --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu new file mode 100644 index 000000000..b9b93af4f --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu new file mode 100644 index 000000000..39d9167b9 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu new file mode 100644 index 000000000..0f8645801 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu new file mode 100644 index 000000000..bd6f4df8f --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu new file mode 100644 index 000000000..1824b86c6 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu new file mode 100644 index 000000000..87dd01725 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu new file mode 100644 index 000000000..6594d5601 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu new file mode 100644 index 000000000..d7dc84ebc --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu new file mode 100644 index 000000000..b9d6e54cb --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu new file mode 100644 index 000000000..a8c47652e --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu new file mode 100644 index 000000000..32d17c766 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu new file mode 100644 index 000000000..365017c25 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu new file mode 100644 index 000000000..82cfdf040 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu new file mode 100644 index 000000000..f3254936a --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu new file mode 100644 index 000000000..931a6dbf8 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu new file mode 100644 index 000000000..5c8877a75 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu new file mode 100644 index 000000000..1e230ab08 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu new file mode 100644 index 000000000..03716c862 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu new file mode 100644 index 000000000..54c66c955 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu new file mode 100644 index 000000000..e5e0ec47d --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu new file mode 100644 index 000000000..e4411b5db --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu new file mode 100644 index 000000000..157ed06dd --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu new file mode 100644 index 000000000..7ef5adc9e --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu index 96243edf0..bf8386b82 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu index a51a89458..cbc6f9884 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu index 515d88a11..d5aa15b5c 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu index e5a154c18..b8593612d 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu index 2bd860c77..a03514d91 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu index 6e1d80378..df547749e 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu index 942685e14..1ddb19162 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu index d6050520e..cefffcd21 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu index 7ee500a80..3d4333b9e 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_sm80.cu index 1f9d8bfd5..35a2abef8 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_sm90.cu index 0313ad1b2..99e34ac0b 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu index 8d87eb21f..ed1cf22d5 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu index 081bb31b1..4527d9a27 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu index a9b5aa0de..41fcf8001 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu index d465545ef..704cbcb33 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu index 68c571455..e0ea08215 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu index e1d656e5a..a9c00408a 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu index 57d1c73d8..1497e7aa8 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu index 5104d4398..c66ea9bac 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu index cbc61f27e..a7e472b47 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu index f08ba1459..9f090aeed 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu index e413758de..2205168a6 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu index c8205c160..2a01898b5 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_sm90.cu index f0db959e0..888e241a9 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu index 249cae97f..2a6bde7a3 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu index 14b073deb..3d315187b 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu index 8152dbaa6..3c3d09380 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu index d0b0df027..4ca103566 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu index 24f3e128d..16debf277 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu index 6eabe0ee2..43c261571 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu index 5c780da81..d9d483838 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu index 5a9436601..70543998d 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu index 9815dd135..c30c7e3b8 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu index 66fc2cb8a..7ae26e69c 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu index 2ceddd8ca..155b5a539 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu index 4c64bc61c..3e6173c31 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu index 6ad1a1529..e1e3191a2 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_sm80.cu index f0ee8c015..8272ecb76 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_sm90.cu index 4a9583196..74606c393 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu index 2b65a88f0..89a58502b 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu index e324a9326..b13373806 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu index a8be65709..1335fad7f 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu index 1ad82d7ed..18c31bdfc 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu index 75f53ee4f..18a5603cf 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu index 09f765263..4e99c7db0 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu index e5299154c..82f8204aa 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu index 364579e1b..cb851a771 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu index a5f821bec..ae2871c16 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu index 364bd2b3a..ed24fbffe 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu index 3d2e337e1..ffca9c7f8 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu index 310c4a5c3..57a06bd6e 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu index 96f5bbf3a..ccdcf21e4 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu index 7d3131bd5..c2bc77877 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu index 7715a5253..6bba953fc 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu index 686bdfa5c..25c96174c 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_sm80.cu index 97fdc0094..f172239e5 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_sm90.cu index 25a90d3be..9dde6adb0 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu index 4c91ee5bc..2317adef8 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu index ef12a584c..b9b3b7486 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu index e4e746f9d..c57a5a30a 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu index 99924af52..4f59a6aea 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu index 705582b9f..2c2de1574 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu index 7e9690120..0dbd062c7 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu index 058eca375..bee54c702 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu index 679066d54..c02e68334 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu index e4ce6f9aa..02b50b98b 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu index 03eff4c6f..6599de63b 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu index 26df5e592..a1cdc775c 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu index 57de7421d..6d01be60f 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_sm90.cu index 53974f3e6..968bbf36f 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu index 24e1f6356..d564a6221 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu index a2fc325da..cb5bccc17 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu index 2c1f5f56f..146a7bc34 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu index 7cbdff3e8..a195e0931 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu index b81bf0b99..045fc71be 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu index 88a00e912..a31da2edd 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu index c28edfd8f..7382b58a2 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu index dbcd16330..87ca31ce9 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu index 63620ec90..60f4d6ebb 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu index d8c11ee6a..e0d5d318b 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu index 4af31d0bf..dec7db046 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu index c7a04dc47..7b71f4352 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu index 9bca3a1c5..08fc989af 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_sm80.cu index acd0fa660..2cc8b5b86 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_sm90.cu index a38430fb3..644e26846 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu index 03bb0516f..1ebcec8b3 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu index 8ea90bd41..780ade7f6 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu index f91443264..bfcffe2a3 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu index e7e1cecd1..ba4ba78ad 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu index 18b79da92..f04260ba4 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu index 1c1c9470d..33c78e530 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu index 6cadc2641..838842092 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu index 4b650f53c..4134d7d80 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu index 29cb3fe18..11e3503b0 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu index 2612bc9c9..67e39bd73 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu index 4c5fae060..c37844daa 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu index c0b58521b..f0c40e2f8 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu index 0a0588472..3ed969490 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu index b42119971..4a16aae66 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu index 7f337595b..b5b5fc26b 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu index c4c35a18c..3b29be627 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_sm80.cu index 9ea549e11..5f1c298c4 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_sm90.cu index 8ffc852e3..64895643d 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu index 7143da2f7..dd508590d 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu index 4f7cd4f8e..8411b6fcc 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu index 5a9bb1420..b5b4f4077 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu index dc9b71a5b..e608da04b 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu index 4c5440436..c69b78ac3 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu index d988a48f9..170cdb5cb 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu index c6ae246e7..ef0d1e921 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu index 761a62556..6a7fc29dd 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu index a74d7c2c3..faeb6c487 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu index 6d48fb099..655258d51 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu index 0e49f26aa..4bd8ad8f2 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu index f780a8eb7..657820f28 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_sm90.cu index 948c8b17c..cb0955d1a 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu index 519783851..357b64e83 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu index d5392ef3b..c12079258 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu index 06086d408..21687f893 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu index a15ab4c60..4df8ed64d 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu index 7038c0ad7..b601195d7 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu index 9a805fd3e..ced475318 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu index b23cb43e7..03090f73c 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu index c18f470fc..d6fe1559c 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu index d61b04a07..7b5ae4a56 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu index 1d33fe12e..6c603b4dc 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu index 03ac4d2f8..26d25fc19 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu index 7b031a490..05a0baf18 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu index 77dbc5812..3a4577653 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_sm80.cu index 6bae5faa5..9b80bae51 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_sm90.cu index 30f666a73..f6810efaf 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu index 358e813ec..98c018893 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu index f5df3f502..a10dfaca7 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu index f16185c3a..b912a8144 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu index 796e4d63a..8603c396e 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu index 6eeb97741..dc55dbc66 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu index aa1d2cd05..ef4884497 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu index 5a92ebddd..b1c0ead6e 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu index 78c390e5e..5d76d0fff 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu index 2b5aaff0d..44ea823d2 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu index f0fa3ac63..30fe62350 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu index 0d9407b2c..6eb12dc80 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu index 223b6783e..b806fc9d5 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu index 2f49d5f5a..8f0a26da0 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu index 9661156d8..6de2819a1 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu index b5f6d7f87..16927295b 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu index 82b827e18..084130720 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_sm80.cu index 042dd0cc7..7d4dcdc29 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_sm90.cu index 4712aed6c..b4dfbf7f8 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu index 8295033de..1fa048752 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu index 21c43e6db..e0b6a75e6 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu index d3317ad62..e257b42f7 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu index 86218988c..f97ab4733 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu index 7a6450373..cee43ef94 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu index 34c1a3d3f..0442e1f94 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu index 96affd254..bc71fa9e7 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu index 489717ff2..b61dd7188 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu index 69917aa1e..f47e1f5cd 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu index 3e3cc66f6..215752f1b 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu index e5f53e49c..207afc792 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu index 0899aa898..6c38c0833 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_sm90.cu index 22f4cf6b1..dc2eb35dc 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu index d601d694d..f04e8bca6 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu index 1c5ba9b00..2697f6910 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu index 8073b677a..e7a98b2e6 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu index 857be3592..98fb39c86 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu index 6931ffa27..cb938ad93 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu index 84facb47e..e2dc45c79 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu index 878d160ff..64f99c05a 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu index e5561f7d6..3fdbbf23b 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu index 30474d354..ffe202ee3 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu index 074f7232f..42740f022 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu index 734abb7b0..829929980 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu index 285e7ef52..d6a330432 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu index d552e45db..39c774e6f 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_sm80.cu index 64ca02345..bc54be11e 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_sm90.cu index 3d8bb7c27..a68790500 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu index 6fab8802c..3bca3065c 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu index 1fb30696d..985692b9f 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu index af9b88d9a..3c99cb6b5 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu index 5f9794a98..cf77a1ae8 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu index c906649ac..f9a46a44d 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu index 2d7ac26e2..9b4dbbba5 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu index 171f28e9c..da5373fd1 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu index 8b659e832..e8ed21cda 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_bf16_packgqa_sm90.cu" #include "flash_fwd_hdim128_bf16_packgqa_sm90.cu" #include "flash_fwd_hdim192_bf16_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_packgqa_sm90.cu" #include "flash_fwd_hdim256_bf16_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu index c84d02b6d..f7de8fa20 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_bf16_paged_sm90.cu" #include "flash_fwd_hdim128_bf16_paged_sm90.cu" #include "flash_fwd_hdim192_bf16_paged_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_paged_sm90.cu" #include "flash_fwd_hdim256_bf16_paged_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu index 6aaf7d12f..64e5ce4a3 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_bf16_paged_softcap_sm90.cu" #include "flash_fwd_hdim128_bf16_paged_softcap_sm90.cu" #include "flash_fwd_hdim192_bf16_paged_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu" #include "flash_fwd_hdim256_bf16_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu index 117121414..44619cce5 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_bf16_paged_split_sm90.cu" #include "flash_fwd_hdim128_bf16_paged_split_sm90.cu" #include "flash_fwd_hdim192_bf16_paged_split_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_paged_split_sm90.cu" #include "flash_fwd_hdim256_bf16_paged_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu index 617572308..a05973582 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu index 2aac1970b..daea288fe 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_bf16_sm90.cu" #include "flash_fwd_hdim128_bf16_sm90.cu" #include "flash_fwd_hdim192_bf16_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_sm90.cu" #include "flash_fwd_hdim256_bf16_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu index be0c5af08..62640192c 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu index fd5893c59..79b0d52fa 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_bf16_softcap_sm90.cu" #include "flash_fwd_hdim128_bf16_softcap_sm90.cu" #include "flash_fwd_hdim192_bf16_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_softcap_sm90.cu" #include "flash_fwd_hdim256_bf16_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu index bcde9c945..333406cb4 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_bf16_split_sm90.cu" #include "flash_fwd_hdim128_bf16_split_sm90.cu" #include "flash_fwd_hdim192_bf16_split_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_split_sm90.cu" #include "flash_fwd_hdim256_bf16_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu index 160eb3a18..b6c1fb54c 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_bf16_split_softcap_sm90.cu" #include "flash_fwd_hdim128_bf16_split_softcap_sm90.cu" #include "flash_fwd_hdim192_bf16_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu" #include "flash_fwd_hdim256_bf16_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu index 28819a690..abf0b10e4 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_e4m3_packgqa_sm90.cu" #include "flash_fwd_hdim128_e4m3_packgqa_sm90.cu" #include "flash_fwd_hdim192_e4m3_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu" #include "flash_fwd_hdim256_e4m3_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu index 933ad9827..22b310e5a 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_e4m3_paged_sm90.cu" #include "flash_fwd_hdim128_e4m3_paged_sm90.cu" #include "flash_fwd_hdim192_e4m3_paged_sm90.cu" +#include "flash_fwd_hdim192_128_e4m3_paged_sm90.cu" #include "flash_fwd_hdim256_e4m3_paged_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu index a934f7d99..f9eed0732 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu" #include "flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu" #include "flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu" #include "flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu index 8475e878a..b91c7f85a 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_e4m3_paged_split_sm90.cu" #include "flash_fwd_hdim128_e4m3_paged_split_sm90.cu" #include "flash_fwd_hdim192_e4m3_paged_split_sm90.cu" +#include "flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu" #include "flash_fwd_hdim256_e4m3_paged_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu index dd1405b17..a6b215bfd 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_sm90.cu index 7e7d806c6..ddec44c68 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_e4m3_sm90.cu" #include "flash_fwd_hdim128_e4m3_sm90.cu" #include "flash_fwd_hdim192_e4m3_sm90.cu" +#include "flash_fwd_hdim192_128_e4m3_sm90.cu" #include "flash_fwd_hdim256_e4m3_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu index f973a4e41..81601b9ec 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu index 30390838d..ae9a362c1 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_e4m3_softcap_sm90.cu" #include "flash_fwd_hdim128_e4m3_softcap_sm90.cu" #include "flash_fwd_hdim192_e4m3_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_e4m3_softcap_sm90.cu" #include "flash_fwd_hdim256_e4m3_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu index 0b629bd2b..163ee761b 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_e4m3_split_sm90.cu" #include "flash_fwd_hdim128_e4m3_split_sm90.cu" #include "flash_fwd_hdim192_e4m3_split_sm90.cu" +#include "flash_fwd_hdim192_128_e4m3_split_sm90.cu" #include "flash_fwd_hdim256_e4m3_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu index 818c7fafb..ba2d427dd 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_e4m3_split_softcap_sm90.cu" #include "flash_fwd_hdim128_e4m3_split_softcap_sm90.cu" #include "flash_fwd_hdim192_e4m3_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu" #include "flash_fwd_hdim256_e4m3_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu index 6652824d0..34d176348 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_fp16_packgqa_sm90.cu" #include "flash_fwd_hdim128_fp16_packgqa_sm90.cu" #include "flash_fwd_hdim192_fp16_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_packgqa_sm90.cu" #include "flash_fwd_hdim256_fp16_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu index 05d11e2e2..326a2ea90 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_fp16_paged_sm90.cu" #include "flash_fwd_hdim128_fp16_paged_sm90.cu" #include "flash_fwd_hdim192_fp16_paged_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_paged_sm90.cu" #include "flash_fwd_hdim256_fp16_paged_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu index b638138eb..a9e032a07 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_fp16_paged_softcap_sm90.cu" #include "flash_fwd_hdim128_fp16_paged_softcap_sm90.cu" #include "flash_fwd_hdim192_fp16_paged_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu" #include "flash_fwd_hdim256_fp16_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu index 3619a2175..d7cc300b8 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_fp16_paged_split_sm90.cu" #include "flash_fwd_hdim128_fp16_paged_split_sm90.cu" #include "flash_fwd_hdim192_fp16_paged_split_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_paged_split_sm90.cu" #include "flash_fwd_hdim256_fp16_paged_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu index 3a408ceac..fa4de4e29 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_sm90.cu index eec11be91..cb3455866 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_fp16_sm90.cu" #include "flash_fwd_hdim128_fp16_sm90.cu" #include "flash_fwd_hdim192_fp16_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_sm90.cu" #include "flash_fwd_hdim256_fp16_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu index ca2a1e1b8..5dbd70ec5 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu index 8cf31a8a8..9a97b9604 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_fp16_softcap_sm90.cu" #include "flash_fwd_hdim128_fp16_softcap_sm90.cu" #include "flash_fwd_hdim192_fp16_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_softcap_sm90.cu" #include "flash_fwd_hdim256_fp16_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu index 5ee7ace63..5aacbf026 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_fp16_split_sm90.cu" #include "flash_fwd_hdim128_fp16_split_sm90.cu" #include "flash_fwd_hdim192_fp16_split_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_split_sm90.cu" #include "flash_fwd_hdim256_fp16_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu index 4da0ee704..cfaabd990 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_fp16_split_softcap_sm90.cu" #include "flash_fwd_hdim128_fp16_split_softcap_sm90.cu" #include "flash_fwd_hdim192_fp16_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu" #include "flash_fwd_hdim256_fp16_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/mainloop_fwd_sm80.hpp b/hopper/mainloop_fwd_sm80.hpp index e43904518..2d2ba06f2 100644 --- a/hopper/mainloop_fwd_sm80.hpp +++ b/hopper/mainloop_fwd_sm80.hpp @@ -22,7 +22,7 @@ namespace flash { using namespace cute; -template struct CollectiveMainloopFwdSm80 { @@ -30,6 +30,7 @@ struct CollectiveMainloopFwdSm80 { static constexpr int kStages = Stages; static_assert(kStages > 0, "kStages must be greater than 0"); using TileShape_MNK = TileShape_MNK_; + using TileShape_MNK_PV = Shape(TileShape_MNK{})), Int, decltype(get<1>(TileShape_MNK{}))>; using Element = Element_; using ElementAccum = ElementAccum_; using ArchTag = ArchTag_; @@ -177,6 +178,7 @@ struct CollectiveMainloopFwdSm80 { ShapeQKV const shape_K; StrideQK const stride_K; Element* const ptr_V; + int32_t const headdim_v; StrideV const stride_V; Element const* const ptr_K_new; ShapeQKV const shape_K_new; @@ -218,6 +220,7 @@ struct CollectiveMainloopFwdSm80 { ShapeQKV const shape_K; StrideQK const stride_K; Element* const ptr_V; + int32_t const headdim_v; StrideV const stride_V; Element const* const ptr_K_new; ShapeQKV const shape_K_new; @@ -272,7 +275,7 @@ struct CollectiveMainloopFwdSm80 { // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e) // (assigning it to params.softmax_scale_log2). return {args.ptr_Q, args.shape_Q, args.stride_Q, shape_Q_packed, stride_Q_packed, - args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.stride_V, + args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.headdim_v, args.stride_V, args.ptr_K_new, args.shape_K_new, args.stride_K_new, args.ptr_V_new, args.stride_V_new, args.ptr_rotary_cos, args.shape_rotary, args.stride_rotary_cos, args.ptr_rotary_sin, args.stride_rotary_sin, args.is_rotary_interleaved, @@ -430,11 +433,11 @@ struct CollectiveMainloopFwdSm80 { } cute::cp_async_fence(); - using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumMmaThreads, Element, true /*KV_Same_Iter*/>; + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/>; PagedKVManager_t paged_kv_manager( params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, params.ptr_K, params.shape_K, params.stride_K, - params.ptr_V, params.stride_V, + params.ptr_V, params.headdim_v, params.stride_V, params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k ); @@ -730,11 +733,11 @@ struct CollectiveMainloopFwdSm80 { params.ptr_rotary_sin, params.stride_rotary_sin, params.is_rotary_interleaved, thread_idx, seqlen_k_new, offset_rotary); - using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>; + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>; PagedKVManager_t paged_kv_manager( params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, params.ptr_K, params.shape_K, params.stride_K, - params.ptr_V, params.stride_V, + params.ptr_V, params.headdim_v, params.stride_V, params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k // passing offset_k instead of leftpad_k will move the PageTable pointer to the right position ); diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 3af51566b..da5f902ea 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -27,7 +27,7 @@ namespace flash { using namespace cute; -template struct CollectiveMainloopFwdSm90 { @@ -35,6 +35,7 @@ struct CollectiveMainloopFwdSm90 { static constexpr int kStages = Stages; using ClusterShape = ClusterShape_; using TileShape_MNK = TileShape_MNK_; + using TileShape_MNK_PV = Shape(TileShape_MNK{})), Int, decltype(get<1>(TileShape_MNK{}))>; using Element = Element_; using ElementAccum = ElementAccum_; using ArchTag = ArchTag_; @@ -53,6 +54,7 @@ struct CollectiveMainloopFwdSm90 { static constexpr bool Use_TMA_KV = !PagedKV; static_assert(Use_TMA_KV || CUTE_STATIC_V(size(ClusterShape{})) == 1, "If not using TMA for KV, ClusterShape must be 1"); static_assert(Use_TMA_KV || !V_colmajor, "If not using TMA for KV, V_colmajor is not supported"); + static constexpr bool SameHeadDim = get<2>(TileShape_MNK{}) == kHeadDimV; using SeqlenInfo_t = flash::SeqlenInfoQKNewK; static_assert(ArchTag::kMinComputeCapability >= 90); @@ -84,9 +86,9 @@ struct CollectiveMainloopFwdSm90 { std::conditional_t< !Mma1_is_RS, decltype(cute::GMMA::ss_op_selector(TileShape_MNK{})), GMMA::Major::K, MmaMajorV>()), + TileShape_MNK_PV, GMMA::Major::K, MmaMajorV>()), decltype(cute::GMMA::rs_op_selector(TileShape_MNK{})), GMMA::Major::K, MmaMajorV>()) + TileShape_MNK_PV, GMMA::Major::K, MmaMajorV>()) >{}, AtomLayoutMNK{})); @@ -107,25 +109,25 @@ struct CollectiveMainloopFwdSm90 { make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); using SmemLayoutAtomVt = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); + decltype(cute::get<1>(TileShape_MNK_PV{})), decltype(cute::get<2>(TileShape_MNK_PV{}))>()); using SmemLayoutVt = decltype(tile_to_shape( SmemLayoutAtomVt{}, - make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int{}), + make_shape(shape<1>(TileShape_MNK_PV{}), shape<2>(TileShape_MNK_PV{}), Int{}), std::conditional_t, cute::Step<_2, _1, _3>>{})); using SmemLayoutAtomVtMma = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); + decltype(cute::get<1>(TileShape_MNK_PV{})), decltype(cute::get<2>(TileShape_MNK_PV{}))>()); using SmemLayoutVtMma = decltype(tile_to_shape( SmemLayoutAtomVtMma{}, - make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int{}), + make_shape(shape<1>(TileShape_MNK_PV{}), shape<2>(TileShape_MNK_PV{}), Int{}), std::conditional_t, cute::Step<_2, _1, _3>>{})); // Only used if we're using cp.async to load V using SmemLayoutAtomVCpAsync = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + decltype(cute::get<1>(TileShape_MNK{})), Int>()); using SmemLayoutVCpAsync = decltype(tile_to_shape( SmemLayoutAtomVCpAsync{}, - make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + make_shape(shape<1>(TileShape_MNK{}), Int{}, Int{}))); using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); @@ -135,26 +137,26 @@ struct CollectiveMainloopFwdSm90 { // Use LDSM.T and STSM to transpose V in the case of FP8 and V being row-major. // For FP16/BF16 we don't do any transposing. - static_assert(!Transpose_V || (kHeadDim % 32 == 0 && kBlockN % 32 == 0)); - static constexpr bool kHeadDim_multiple_64 = kHeadDim % 64 == 0; - // Either kHeadDim is a multiple of 64 (in which case we use a block size of 64 x 32 for the transpose), + static_assert(!Transpose_V || (kHeadDimV % 32 == 0 && kBlockN % 32 == 0)); + static constexpr bool kHeadDimV_multiple_64 = kHeadDimV % 64 == 0; + // Either kHeadDimV is a multiple of 64 (in which case we use a block size of 64 x 32 for the transpose), // or we need kBlockN to be a multiple of 64 (in which case we use a block size of 32 x 64 for the transpose). - static_assert(!Transpose_V || (kHeadDim_multiple_64 || kBlockN % 64 == 0)); - using LDSM_thread_shape = std::conditional_t, Shape<_16, _4, _1, _2>>; - using LDSM_thread_stride = std::conditional_t, Stride<_4, _1, _0, _64>>; + static_assert(!Transpose_V || (kHeadDimV_multiple_64 || kBlockN % 64 == 0)); + using LDSM_thread_shape = std::conditional_t, Shape<_16, _4, _1, _2>>; + using LDSM_thread_stride = std::conditional_t, Stride<_4, _1, _0, _64>>; using LDSM_value_shape = Shape<_2, _2, _1, _4>; using LDSM_value_stride = Stride<_1, _2, _16, _4>; - using LDSM_divide_shape = std::conditional_t, Shape<_32, _8>>; + using LDSM_divide_shape = std::conditional_t, Shape<_32, _8>>; using S2RTiledCopyVt = decltype(make_tiled_copy( Copy_Atom{}, Layout{}, Layout{})); - using STSM_thread_shape = std::conditional_t, Shape<_8, _4, _2, _2>>; - using STSM_thread_stride = std::conditional_t, Stride<_4, _1, _32, _64>>; + using STSM_thread_shape = std::conditional_t, Shape<_8, _4, _2, _2>>; + using STSM_thread_stride = std::conditional_t, Stride<_4, _1, _32, _64>>; using STSM_value_shape = Shape<_1, _4, _2, _2>; using STSM_value_stride = Stride<_0, _1, _4, _8>; using STSM_divide_shape = Shape<_8, _16>; - // These will not permute the columns of V (the kHeadDim dimension) but incur bank conflicts + // These will not permute the columns of V (the kHeadDimV dimension) but incur bank conflicts // so a little slower (e.g. 1150 TFLOPS for hdim 256 instead of 1200 TFLOPS). // Instead we will permute the cols of V, and un-permute the cols of O in the epilogue. // using STSM_value_shape = Shape<_2, _4, _1, _2>; @@ -168,14 +170,15 @@ struct CollectiveMainloopFwdSm90 { using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{}))); // We use CpAsync for K and V if PagedKV and AppendKV, since TMA doesn't work there + static constexpr int kHeadDimGCD = cute::gcd(kHeadDim, kHeadDimV); static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); + static_assert(kHeadDimGCD % kGmemElemsPerLoad == 0, "Headdim and HeaddimV must be a multiple of kGmemElemsPerLoad"); // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. // We want each thread to have at least 2 loads in the K direction since in the case of non-interleaved // rotary (combining elements at indices 0 and rotary_dim/2, 1 and rotary_dim/2+1, etc), each thread will // load twice from the same row. - static constexpr int kBytePerHalfRow = kHeadDim / 2 * sizeof(Element); + static constexpr int kBytePerHalfRow = kHeadDimGCD / 2 * sizeof(Element); static constexpr int kBlockKGmem = (kBytePerHalfRow % 128 == 0 ? 128 : (kBytePerHalfRow % 64 == 0 ? 64 : 32)) / sizeof(Element); static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow"); @@ -221,14 +224,13 @@ struct CollectiveMainloopFwdSm90 { GmemTiledCopyKV{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, select<1, 0, 2, 3>(StrideV{})), take<0, 2>(SmemLayoutVt{}), - select<2, 1>(TileShape_MNK{}), + select<1, 2>(TileShape_MNK_PV{}), size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any // Set the bytes transferred in this TMA transaction (may involve multiple issues) static constexpr uint32_t TmaTransactionBytesQ = static_cast(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesK = static_cast(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesV = static_cast(size(take<0, 2>(SmemLayoutVt{})) * cutlass::sizeof_bits_v / 8); - static_assert(TmaTransactionBytesK == TmaTransactionBytesV); using PipelineTmaAsync = std::conditional_t, typename cutlass::PipelineTmaAsync>; using MainloopPipelineK = std::conditional_t>; @@ -294,6 +296,7 @@ struct CollectiveMainloopFwdSm90 { ShapeQKV const shape_K; StrideQK const stride_K; Element* const ptr_V; + int32_t const headdim_v; StrideV const stride_V; Element const* const ptr_K_new; ShapeQKV const shape_K_new; @@ -335,6 +338,7 @@ struct CollectiveMainloopFwdSm90 { ShapeQKV const shape_K; StrideQK const stride_K; Element* const ptr_V; + int32_t const headdim_v; StrideV const stride_V; Element const* const ptr_K_new; ShapeQKV const shape_K_new; @@ -388,12 +392,14 @@ struct CollectiveMainloopFwdSm90 { take<0, 2>(SmemLayoutK{}), TileShape_MNK{}, ClusterShape{}); // mcast along M mode for this N load, if any - Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), select<1, 0, 2, 3>(args.shape_K), select<1, 0, 2, 3>(args.stride_V)); + Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), + make_shape(args.headdim_v, get<0>(args.shape_K), get<2>(args.shape_K), get<3>(args.shape_K)), + select<1, 0, 2, 3>(args.stride_V)); TMA_V tma_load_V = make_tma_copy( GmemTiledCopyKV{}, mV, take<0, 2>(SmemLayoutVt{}), - select<2, 1>(TileShape_MNK{}), + select<1, 2>(TileShape_MNK_PV{}), size<0>(ClusterShape{})); // mcast along M mode for this N load, if any Tensor mKnew = make_tensor(make_gmem_ptr(args.ptr_K_new), args.shape_K_new, args.stride_K_new); TMA_K tma_load_K_new = make_tma_copy_B_sm90( @@ -402,12 +408,14 @@ struct CollectiveMainloopFwdSm90 { take<0, 2>(SmemLayoutK{}), TileShape_MNK{}, ClusterShape{}); // mcast along M mode for this N load, if any - Tensor mVnew = make_tensor(make_gmem_ptr(args.ptr_V_new), select<1, 0, 2, 3>(args.shape_K_new), select<1, 0, 2, 3>(args.stride_V_new)); + Tensor mVnew = make_tensor(make_gmem_ptr(args.ptr_V_new), + make_shape(args.headdim_v, get<0>(args.shape_K_new), get<2>(args.shape_K_new), get<3>(args.shape_K_new)), + select<1, 0, 2, 3>(args.stride_V_new)); TMA_V tma_load_V_new = make_tma_copy( GmemTiledCopyKV{}, cute::conditional_return(mVnew, mV), take<0, 2>(SmemLayoutVt{}), - select<2, 1>(TileShape_MNK{}), + select<1, 2>(TileShape_MNK_PV{}), size<0>(ClusterShape{})); // mcast along M mode for this N load, if any // If PackGQA, reshape Q to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size) int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K)); @@ -429,7 +437,7 @@ struct CollectiveMainloopFwdSm90 { // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e) // (assigning it to params.softmax_scale_log2). return {args.ptr_Q, args.shape_Q, args.stride_Q, shape_Q_packed, stride_Q_packed, - args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.stride_V, + args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.headdim_v, args.stride_V, args.ptr_K_new, args.shape_K_new, args.stride_K_new, args.ptr_V_new, args.stride_V_new, args.ptr_rotary_cos, args.shape_rotary, args.stride_rotary_cos, args.ptr_rotary_sin, args.stride_rotary_sin, args.is_rotary_interleaved, @@ -555,12 +563,13 @@ struct CollectiveMainloopFwdSm90 { bool const is_varlen_k = Varlen && params.cu_seqlens_k; Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); Tensor mK_TMA = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); - Tensor mVt_TMA = params.tma_load_V.get_tma_tensor(select<1, 0, 2, 3>(params.shape_K))(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); + auto shape_V = make_shape(params.headdim_v, get<0>(params.shape_K), get<2>(params.shape_K), get<3>(params.shape_K)); + Tensor mVt_TMA = params.tma_load_V.get_tma_tensor(shape_V)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) // if (cute::thread0()) { printf("Varlen = %d, params.leftpad_k = %p, leftpad_k = %d\n", Varlen, params.leftpad_k, leftpad_k); } Tensor gK_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mK_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor gVt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k), mVt_TMA), select<2, 1>(TileShape_MNK{}), make_coord(_0{}, _)); // (K, N, _) + Tensor gVt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k), mVt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _)); // (K, N, _) auto block_tma_Q = params.tma_load_Q.get_slice(_0{}); Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ)); // (TMA) @@ -573,11 +582,11 @@ struct CollectiveMainloopFwdSm90 { Tensor tVgVt_TMA = group_modes<0, 3>(block_tma_V.partition_S(gVt_TMA)); // (TMA, k) Tensor tVsVt_TMA = group_modes<0, 3>(block_tma_V.partition_D(sVt)); // (TMA, PIPE) - using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumProducerThreads, Element, Transpose_V || !IntraWGOverlap /*KV_Same_Iter*/>; + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumProducerThreads, Element, Transpose_V || !IntraWGOverlap /*KV_Same_Iter*/>; PagedKVManager_t paged_kv_manager( params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, params.ptr_K, params.shape_K, params.stride_K, - params.ptr_V, params.stride_V, + params.ptr_V, params.headdim_v, params.stride_V, params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k ); @@ -1210,7 +1219,7 @@ struct CollectiveMainloopFwdSm90 { Tensor mVnewt_TMA = params.tma_load_V_new.get_tma_tensor(select<1, 0, 2, 3>(params.shape_K_new))(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); Tensor gKnew_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mKnew_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor gVnewt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k_new), mVnewt_TMA), select<2, 1>(TileShape_MNK{}), make_coord(_0{}, _)); // (K, N, _) + Tensor gVnewt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k_new), mVnewt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _)); // (K, N, _) auto block_tma_K_new = params.tma_load_K_new.get_slice(cluster_local_block_id.x); Tensor tKgKnew_TMA = group_modes<0, 3>(block_tma_K_new.partition_S(gKnew_TMA)); // (TMA, k) @@ -1306,7 +1315,7 @@ struct CollectiveMainloopFwdSm90 { int const offset_k = seqlen_info.offset_k + seqlen_info.seqlen_k_og; Tensor gK = local_tile(domain_offset(make_coord(offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor gV = local_tile(domain_offset(make_coord(offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + Tensor gV = local_tile(domain_offset(make_coord(offset_k, _0{}), mV), select<2, 1>(TileShape_MNK_PV{}), make_coord(_, _0{})); // (N, K_v, _) static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); @@ -1317,11 +1326,11 @@ struct CollectiveMainloopFwdSm90 { params.ptr_rotary_sin, params.stride_rotary_sin, params.is_rotary_interleaved, thread_idx, seqlen_k_new, offset_rotary); - using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>; + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>; PagedKVManager_t paged_kv_manager( params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, params.ptr_K, params.shape_K, params.stride_K, - params.ptr_V, params.stride_V, + params.ptr_V, params.headdim_v, params.stride_V, params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k // passing offset_k instead of leftpad_k will move the PageTable pointer to the right position ); @@ -1347,6 +1356,12 @@ struct CollectiveMainloopFwdSm90 { Tensor tKpK = make_tensor(make_shape(size<2>(tKsK))); #pragma unroll for (int k = 0; k < size(tKpK); ++k) { tKpK(k) = get<1>(tKcK(_0{}, _0{}, k)) < get<1>(params.shape_K); } + Tensor cV = cute::make_identity_tensor(select<2, 1>(TileShape_MNK_PV{})); // (BLK_N,BLK_K_V) -> (blk_n,blk_k_v) + Tensor tVcV = cute::conditional_return(tKcK, gmem_thr_copy_kv.partition_D(cV)); + Tensor tVpV_ = make_tensor(make_shape(size<2>(tVsV))); + #pragma unroll + for (int k = 0; k < size(tVpV_); ++k) { tVpV_(k) = get<1>(tVcV(_0{}, _0{}, k)) < params.headdim_v; } + Tensor tVpV = cute::conditional_return(tKpK, tVpV_); auto store_K = [&] (int const n_block, auto const& smem_pipe_read) { int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN); @@ -1392,7 +1407,7 @@ struct CollectiveMainloopFwdSm90 { Tensor tVgV_cur = tVgV(_, _, _, n_block); // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( - gmem_tiled_copy_kv, tVsV_cur, tVgV_cur, tKcK, tKpK, n_limit); + gmem_tiled_copy_kv, tVsV_cur, tVgV_cur, tVcV, tVpV, n_limit); } else { paged_kv_manager.store_V(n_block, tVsV_cur); } diff --git a/hopper/paged_kv.h b/hopper/paged_kv.h index 0f710e549..9431f384f 100644 --- a/hopper/paged_kv.h +++ b/hopper/paged_kv.h @@ -14,7 +14,7 @@ namespace flash { using namespace cute; -template +template struct PagedKVManager { // If KV_Same_Iter=false, then we do load_page_table(0), load_K(0), load_page_table(1), load_K(1), load_V(0), // load_page_table(2), load_K(2), load_V(1), etc. @@ -23,14 +23,17 @@ struct PagedKVManager { // LoadsPerRow_LB is the lower bound on number of loads per row in the K direction. This is useful for // rotary where we want each thread to have at least 2 loads per row. + static constexpr bool SameHeadDim = (kHeadDim == kHeadDimV); + static constexpr int kHeadDimGCD = cute::gcd(kHeadDim, kHeadDimV); + // We use CpAsync for K and V if PagedKV, since TMA doesn't work there static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); + static_assert(kHeadDimGCD % kGmemElemsPerLoad == 0, "Headdim and HeaddimV must be a multiple of kGmemElemsPerLoad"); // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. // In the case of PackGQA, this reduces the number of times we need to call divmod. - static_assert(kHeadDim % LoadsPerRow_LB == 0, "Headdim must be a multiple of LoadsPerRow_LB"); - static constexpr int kBytePerRow = kHeadDim / LoadsPerRow_LB * sizeof(Element); + static_assert(kHeadDimGCD % LoadsPerRow_LB == 0, "Headdim and HeaddimV must be a multiple of LoadsPerRow_LB"); + static constexpr int kBytePerRow = kHeadDimGCD / LoadsPerRow_LB * sizeof(Element); static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; static_assert(NumThreads % kGmemThreadsPerRow == 0, "NumThreads must be a multiple of kGmemThreadsPerRow"); @@ -59,6 +62,8 @@ struct PagedKVManager { using GmemThrCopyKVCpAsync = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0))); using TensortKcK = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape, Int>{}))); using TensortKpK = decltype(make_tensor(make_shape(size<1>(TensortKcK{}), size<2>(TensortKcK{})), Stride<_0, _1>{})); + using TensortVcV = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape, Int>{}))); + using TensortVpV = decltype(make_tensor(make_shape(size<1>(TensortVcV{}), size<2>(TensortVcV{})), Stride<_0, _1>{})); // For PagedKV, it's expensive the calculate the pointers to K and V for each page table entry, // since those require int64_t arithmetic. We optimize by having threads split this work. @@ -66,6 +71,7 @@ struct PagedKVManager { // that each thread needs to load for the case of hdim 128 and kBlockN = 176. // So each of those 8 threads will calculate the K_ptr and V_ptr for 11 / 8 = 2 rows. // We then use __shfl_sync to broadcast the pointers to the other threads in the warp. + static_assert(CUTE_STATIC_V(size<1>(TensortKcK{})) == CUTE_STATIC_V(size<1>(TensortVcV{}))); static constexpr int kPageEntryPerThread = cute::ceil_div(size<1>(TensortKcK{}), kGmemThreadsPerRow); using TensorPageOffset = decltype(make_tensor>(Shape>{})); using TensorKVPtr = decltype(make_tensor(Shape>{})); @@ -79,15 +85,15 @@ struct PagedKVManager { TensorPageTable mPageTable; TensorKV mK_paged, mV_paged; TensortKpK tKpK; + TensortVpV tVpV; TensorPageOffset tPrPageOffset; TensorKVPtr tPrVPtr; - CUTLASS_DEVICE PagedKVManager(int const* const ptr_page_table, ShapePageTable const &shape_pagetable, StridePageTable const &stride_pagetable, Element* const ptr_K, ShapeKV const &shape_K, StrideKV const &stride_K, - Element* const ptr_V, StrideKV const &stride_V, + Element* const ptr_V, int const headdim_v, StrideKV const &stride_V, cutlass::FastDivmod const &page_size_divmod, int const bidb, int const bidh, int const thread_idx, int const seqlen_k, int const leftpad_k ) @@ -100,13 +106,19 @@ struct PagedKVManager { { mPageTable = make_tensor(make_gmem_ptr(ptr_page_table), shape_pagetable, stride_pagetable)(bidb, _); mK_paged = make_tensor(make_gmem_ptr(ptr_K), shape_K, stride_K)(_, _, bidh, _); - mV_paged = make_tensor(make_gmem_ptr(ptr_V), shape_K, stride_V)(_, _, bidh, _); + auto shape_V = make_shape(get<0>(shape_K), headdim_v, get<2>(shape_K), get<3>(shape_K)); + mV_paged = make_tensor(make_gmem_ptr(ptr_V), shape_V, stride_V)(_, _, bidh, _); tKpK = make_tensor(make_shape(size<1>(TensortKcK{}), size<2>(TensortKcK{})), Stride<_0, _1>{}); - Tensor cK = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) Tensor tKcK = gmem_thr_copy_kv.partition_S(cK); #pragma unroll for (int k = 0; k < size<1>(tKpK); ++k) { tKpK(_0{}, k) = get<1>(tKcK(_0{}, _0{}, k)) < get<1>(shape_K); } + Tensor tVpV_ = make_tensor(make_shape(size<1>(TensortVcV{}), size<2>(TensortVcV{})), Stride<_0, _1>{}); + Tensor cV = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tVcV = gmem_thr_copy_kv.partition_S(cV); + #pragma unroll + for (int k = 0; k < size<1>(tVpV_); ++k) { tVpV_(_0{}, k) = get<1>(tVcV(_0{}, _0{}, k)) < get<1>(shape_K); } + tVpV = cute::conditional_return(tKpK, tVpV_); }; template @@ -200,27 +212,27 @@ struct PagedKVManager { // Only for index calculation, since all the indices of thread 0 are known at compile time auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{}); Tensor tVsV = gmem_thr_copy_kv.partition_D(sV); - Tensor cK = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor cV = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) // Repeat the partitioning with identity layouts - Tensor tKcK = gmem_thr_copy_kv.partition_S(cK); - Tensor t0KcK = gmem_thr0_copy_kv.partition_S(cK); + Tensor tVcV = gmem_thr_copy_kv.partition_S(cV); + Tensor t0VcV = gmem_thr0_copy_kv.partition_S(cV); - int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - get<0>(tKcK(_0{}, _0{}, _0{})); + int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - get<0>(tVcV(_0{}, _0{}, _0{})); #pragma unroll for (int m = 0; m < size<1>(tVsV); ++m) { // Faster to rely on the cp.async to clear smem that are out of bound, // rather than calling cute::clear directly. // We have to be careful not to write to smem past `kBlockN` if !EvenN. // If kBlockN doesn't evenly divide the tiled copy, only the last `m` needs to checked - if (EvenN || m < size<1>(tVsV) - 1 || get<0>(tKcK(_0{}, m, _0{})) < kBlockN) { - bool const should_load = !Seqlenk_mask || get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit; + if (EvenN || m < size<1>(tVsV) - 1 || get<0>(tVcV(_0{}, m, _0{})) < kBlockN) { + bool const should_load = !Seqlenk_mask || get<0>(t0VcV(_0{}, m, _0{})) < seqlenk_row_limit; Element const* v_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrVPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); - Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape>{}); + Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape>{}); Tensor mV_paged_cur_copy = cute::tiled_divide(mV_paged_cur, Shape>{}); #pragma unroll for (int k = 0; k < size<2>(tVsV); ++k) { - int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad; - cute::copy(gmem_tiled_copy_kv.with(tKpK(_0{}, k) && should_load), mV_paged_cur_copy(_, ki), tVsV(_, m, k)); + int const ki = get<1>(tVcV(_0{}, _0{}, k)) / kGmemElemsPerLoad; + cute::copy(gmem_tiled_copy_kv.with(tVpV(_0{}, k) && should_load), mV_paged_cur_copy(_, ki), tVsV(_, m, k)); } } } @@ -269,24 +281,24 @@ struct PagedKVManager { if constexpr (KV_Same_Iter) { compute_V_ptr(); } // Only for index calculation, since all the indices of thread 0 are known at compile time auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{}); - Tensor cK = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor cV = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) // Repeat the partitioning with identity layouts - Tensor tKcK = gmem_thr_copy_kv.partition_S(cK); - Tensor t0KcK = gmem_thr0_copy_kv.partition_S(cK); + Tensor tVcV = gmem_thr_copy_kv.partition_S(cV); + Tensor t0VcV = gmem_thr0_copy_kv.partition_S(cV); GmemTiledCopyKVStore gmem_tiled_copy_kv_store; - int const seqlenk_row_limit = std::min(seqlen_k - n_block * kBlockN, kBlockN) - get<0>(tKcK(_0{}, _0{}, _0{})); + int const seqlenk_row_limit = std::min(seqlen_k - n_block * kBlockN, kBlockN) - get<0>(tVcV(_0{}, _0{}, _0{})); #pragma unroll for (int m = 0; m < size<1>(tVrV); ++m) { - bool const should_load = get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit; + bool const should_load = get<0>(t0VcV(_0{}, m, _0{})) < seqlenk_row_limit; Element* v_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrVPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); - Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape>{}); + Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape>{}); Tensor mV_paged_cur_copy = cute::tiled_divide(mV_paged_cur, Shape>{}); if (should_load) { #pragma unroll for (int k = 0; k < size<2>(tVrV); ++k) { - int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad; - if (tKpK(_0{}, k)) { + int const ki = get<1>(tVcV(_0{}, _0{}, k)) / kGmemElemsPerLoad; + if (tVpV(_0{}, k)) { cute::copy(gmem_tiled_copy_kv_store, tVrV(_, m, k), mV_paged_cur_copy(_, ki)); } } diff --git a/hopper/setup.py b/hopper/setup.py index 0104819c6..db8990255 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -182,10 +182,10 @@ def sanitize_flags(flags): # to make this work on Windows too. nvcc_gendeps = '--generate-dependencies-with-compile --dependency-output $out.d' cuda_compile_rule_sm80 = ['rule cuda_compile_sm80'] + cuda_compile_rule[1:] + [ - f' command = $nvcc {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm80' + f' command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm80' ] cuda_compile_rule_sm80_sm90 = ['rule cuda_compile_sm80_sm90'] + cuda_compile_rule[1:] + [ - f' command = $nvcc {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm80_sm90' + f' command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm80_sm90' ] cuda_compile_rule.append( f' command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags') diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 1fe43e21f..d0590b5f1 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -113,88 +113,89 @@ def test_flash_attn_output( # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) - if softcap > 0.0: - # Ensure the values of qk are at least within softcap range. - q_ref = (q_ref * softcap / 4) - q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() - k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() - v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() - # Put window_size after QKV randn so that window_size changes from test to test - window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) - # window_size = (-1, -1) if not local else (16, 0) - if dtype == torch.float8_e4m3fn: - q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] - else: - q_descale, k_descale, v_descale = None, None, None - q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] - if V_colmajor: - v = rearrange(rearrange(v.detach(), "b s h d -> b h d s").contiguous(), "b h d s -> b s h d").requires_grad_() - out_ref, attn_ref = attention_ref( - q_ref, - k_ref, - v_ref, - None, - None, - causal=causal, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, - window_size=window_size, - sink_token_length=sink_token_length, - softcap=softcap - ) - out_pt, attn_pt = attention_ref( - q_ref, - k_ref, - v_ref, - None, - None, - causal=causal, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, - window_size=window_size, - sink_token_length=sink_token_length, - softcap=softcap, - upcast=False, - reorder_ops=True, - intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, - ) - - # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float() - # m = qk.amax(-1, keepdim=True) - # s_tmp = torch.exp((qk - m) / math.sqrt(d)) - # exp_sum = s_tmp.sum(-1) - # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) - # lse_ref = torch.logsumexp(qk, dim=-1) - - # Numerical error if we just do any arithmetic on out_ref - fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() - rtol = 2 if softcap == 0.0 else 3 - - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] - num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] - for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): - out, lse = flash_attn_func( - q, - k, - v, + for dv in [128, d] if d > 128 and d <= 192 else [d]: + q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. + q_ref = (q_ref * softcap / 4) + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + # window_size = (-1, -1) if not local else (16, 0) + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] + if V_colmajor: + v = rearrange(rearrange(v.detach(), "b s h d -> b h d s").contiguous(), "b h d s -> b s h d").requires_grad_() + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + sink_token_length=sink_token_length, + softcap=softcap + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, causal=causal, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, - pack_gqa=pack_gqa, - num_splits=num_splits + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, ) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - # if not causal: - # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") - # breakpoint() - # Check that FlashAttention's numerical error is at most twice the numerical error - # of a Pytorch implementation. - assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float() + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # exp_sum = s_tmp.sum(-1) + # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) + # lse_ref = torch.logsumexp(qk, dim=-1) + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] + num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + out, lse = flash_attn_func( + q, + k, + v, + causal=causal, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + sink_token_length=sink_token_length, + softcap=softcap, + pack_gqa=pack_gqa, + num_splits=num_splits + ) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor: g = torch.randn_like(out) @@ -320,132 +321,133 @@ def test_flash_attn_varlen_output( # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) - if softcap > 0.0: - # Ensure the values of qk are at least within softcap range. - q_ref = (q_ref * softcap / 4).detach().requires_grad_() - q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() - k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() - v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() - # Put window_size after QKV randn so that window_size changes from test to test - window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) - if dtype == torch.float8_e4m3fn: - q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] - else: - q_descale, k_descale, v_descale = None, None, None - q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] - query_padding_mask = generate_random_padding_mask( - seqlen_q, batch_size, device, mode="random", zero_lengths=False - ) - key_padding_mask = generate_random_padding_mask( - seqlen_k, batch_size, device, mode="random", zero_lengths=True - ) - - def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): - if add_unused: - another_mask = generate_random_padding_mask(max_seq_len, bs, device) - attn_mask = torch.logical_and(padding_mask, another_mask) - unused_mask = torch.logical_xor( - torch.logical_or(padding_mask, another_mask), attn_mask - ) + for dv in [128, d] if d > 128 and d <= 192 else [d]: + q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. + q_ref = (q_ref * softcap / 4).detach().requires_grad_() + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] else: - attn_mask = padding_mask - unused_mask = None - return attn_mask, unused_mask - - query_padding_mask, query_unused_mask = _gen_unused_masks( - query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device - ) - key_padding_mask, key_unused_mask = _gen_unused_masks( - key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device - ) - - ( - q_unpad, - k_unpad, - v_unpad, - cu_seqlens_q, - cu_seqlens_k, - seqused_q, - seqused_k, - max_seqlen_q, - max_seqlen_k, - q, - k, - v, - output_pad_fn, - dq_pad_fn, - dk_pad_fn, - ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False, - query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask) - q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)] - out_ref, attn_ref = attention_ref( - q_ref, - k_ref, - v_ref, - query_padding_mask, - key_padding_mask, - causal=causal, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, - window_size=window_size, - softcap=softcap - ) - out_pt, attn_pt = attention_ref( - q_ref, - k_ref, - v_ref, - query_padding_mask, - key_padding_mask, - causal=causal, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, - window_size=window_size, - softcap=softcap, - upcast=False, - reorder_ops=True, - intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, - ) - - - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] + query_padding_mask = generate_random_padding_mask( + seqlen_q, batch_size, device, mode="random", zero_lengths=False + ) + key_padding_mask = generate_random_padding_mask( + seqlen_k, batch_size, device, mode="random", zero_lengths=True + ) - if query_unused_mask is not None: - q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") + def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): + if add_unused: + another_mask = generate_random_padding_mask(max_seq_len, bs, device) + attn_mask = torch.logical_and(padding_mask, another_mask) + unused_mask = torch.logical_xor( + torch.logical_or(padding_mask, another_mask), attn_mask + ) + else: + attn_mask = padding_mask + unused_mask = None + return attn_mask, unused_mask - # Numerical error if we just do any arithmetic on out_ref - fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() - rtol = 2 if softcap == 0.0 else 3 + query_padding_mask, query_unused_mask = _gen_unused_masks( + query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device + ) + key_padding_mask, key_unused_mask = _gen_unused_masks( + key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device + ) - pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] - num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] - for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): - out_unpad, lse = flash_attn_varlen_func( + ( q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, - seqused_q, seqused_k, + seqused_q, + seqused_k, max_seqlen_q, max_seqlen_k, + q, + k, + v, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False, + query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask) + q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)] + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + softcap=softcap + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, causal=causal, - q_descale=q_descale, - k_descale=k_descale, v_descale=v_descale, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, ) - out = output_pad_fn(out_unpad) + + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + if query_unused_mask is not None: - out.masked_fill_(q_zero_masking, 0.0) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - # if not causal: - # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") - # breakpoint() + q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") - # Check that FlashAttention's numerical error is at most 3x the numerical error - # of a Pytorch implementation. - assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] + num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + out_unpad, lse = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, seqused_k, + max_seqlen_q, + max_seqlen_k, + causal=causal, + q_descale=q_descale, + k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + softcap=softcap, + ) + out = output_pad_fn(out_unpad) + if query_unused_mask is not None: + out.masked_fill_(q_zero_masking, 0.0) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most 3x the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn: @@ -557,7 +559,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) -@pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("d", [192]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -614,261 +617,262 @@ def test_flash_attn_kvcache( nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) assert nheads % nheads_k == 0 dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - if varlen_q: - query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") - q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(q, query_padding_mask) - output_pad_fn = lambda output_unpad: pad_input( - output_unpad, indices_q, batch_size, seqlen_q - ) - else: - query_padding_mask = None - q_unpad = q - cu_seqlens_q, max_seqlen_q = None, None - # Put window_size after QKV randn so that window_size changes from test to test - window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + for dv in [128, d] if d > 128 and d <= 192 else [d]: + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + if varlen_q: + query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(q, query_padding_mask) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + else: + query_padding_mask = None + q_unpad = q + cu_seqlens_q, max_seqlen_q = None, None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) - seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() - cu_seqlens_k_new = None - key_new_padding_mask = None - if new_kv: - k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - if varlen_q: # k & v are also varlen - key_new_padding_mask = generate_random_padding_mask(seqlen_new, batch_size, device, mode="random") - k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(k, key_new_padding_mask) - v_unpad, *rest = unpad_input(v, key_new_padding_mask) + seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() + cu_seqlens_k_new = None + key_new_padding_mask = None + if new_kv: + k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + v = torch.randn(batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + if varlen_q: # k & v are also varlen + key_new_padding_mask = generate_random_padding_mask(seqlen_new, batch_size, device, mode="random") + k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(k, key_new_padding_mask) + v_unpad, *rest = unpad_input(v, key_new_padding_mask) + else: + k_unpad, v_unpad = k, v else: - k_unpad, v_unpad = k, v - else: - k, v, k_unpad, v_unpad = None, None, None, None - if page_size is None: - k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - page_table = None - else: - ( - k_cache, - v_cache, - page_table, - k_cache_paged, - v_cache_paged, - num_blocks, - ) = _generate_block_kvcache( - seqlen_k, page_size, batch_size_cache, nheads_k, d, device, dtype_ref - ) - cache_seqlens = torch.randint( - 0 if new_kv else 1, - # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough - ( - (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) - if new_kv - else (seqlen_k + 1) - ), - (batch_size,), - dtype=torch.int32, - device=device, - ) - if has_leftpad: - cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device) - if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device) - for i in range(batch_size)]) - else: - cache_leftpad = None - if has_batch_idx: - cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[ - :batch_size - ] - else: - cache_batch_idx = None - arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") - cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") - if not new_kv: - key_padding_mask = arange < cache_seqlens_expanded - else: - k_new_seqlens = key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new - key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens - if has_leftpad: - key_padding_mask = torch.logical_and( - key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) - ) - # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) - if rotary_dim > 0: - angle = ( - torch.rand( - seqlen_k if page_size is None else num_blocks * page_size, - rotary_dim // 2, - device=device, + k, v, k_unpad, v_unpad = None, None, None, None + if page_size is None: + k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + page_table = None + else: + ( + k_cache, + v_cache, + page_table, + k_cache_paged, + v_cache_paged, + num_blocks, + ) = _generate_block_kvcache( + seqlen_k, page_size, batch_size_cache, nheads_k, d, dv, device, dtype_ref ) - * 2 - * math.pi + cache_seqlens = torch.randint( + 0 if new_kv else 1, + # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough + ( + (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) + if new_kv + else (seqlen_k + 1) + ), + (batch_size,), + dtype=torch.int32, + device=device, ) - cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) - sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) - if causal or local: - q_ro = apply_rotary_emb( - q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved + if has_leftpad: + cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device) + if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device) + for i in range(batch_size)]) + else: + cache_leftpad = None + if has_batch_idx: + cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[ + :batch_size + ] + else: + cache_batch_idx = None + arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + if not new_kv: + key_padding_mask = arange < cache_seqlens_expanded + else: + k_new_seqlens = key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new + key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens + if has_leftpad: + key_padding_mask = torch.logical_and( + key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) + ) + # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) + if rotary_dim > 0: + angle = ( + torch.rand( + seqlen_k if page_size is None else num_blocks * page_size, + rotary_dim // 2, + device=device, + ) + * 2 + * math.pi + ) + cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + if causal or local: + q_ro = apply_rotary_emb( + q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved + ) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=cache_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=seqlen_q, + ) + # q_ro = q + k_ro = apply_rotary_emb( + k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved ) else: - q_ro = rearrange( - apply_rotary_emb( - rearrange(q, "b s h d -> b 1 (s h) d"), - cos, - sin, - seqlen_offsets=cache_seqlens, - interleaved=rotary_interleaved, - ), - "b 1 (s h) d -> b s h d", - s=seqlen_q, + cos, sin = None, None + q_ro, k_ro = q, k + # k_cache[:, 64:] = -1 + k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone() + v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone() + if new_kv: + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + k_new_seqlens ) - # q_ro = q - k_ro = apply_rotary_emb( - k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved + k_to_update = rearrange(k_ro, "b s ... -> (b s) ...") + v_to_update = rearrange(v, "b s ... -> (b s) ...") + if varlen_q: + k_to_update = k_to_update[indices_k] + v_to_update = v_to_update[indices_k] + k_cache_ref[update_mask] = k_to_update + v_cache_ref[update_mask] = v_to_update + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) + out_ref, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + window_size=window_size, + key_leftpad=cache_leftpad, + ) + out_pt, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + window_size=window_size, + upcast=False, + reorder_ops=True, + key_leftpad=cache_leftpad, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None ) - else: - cos, sin = None, None - q_ro, k_ro = q, k - # k_cache[:, 64:] = -1 - k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone() - v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone() - if new_kv: - update_mask = torch.logical_and( - cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + k_new_seqlens + q = q.to(dtype) + q_unpad = q_unpad.to(dtype) if varlen_q else None + k_cache = k_cache.to(dtype) + v_cache = v_cache.to(dtype) + k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None + v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None + k = k.to(dtype) if k is not None else None + v = v.to(dtype) if v is not None else None + k_unpad = k_unpad.to(dtype) if k_unpad is not None else None + v_unpad = v_unpad.to(dtype) if v_unpad is not None else None + cos = cos.to(dtype) if cos is not None else None + sin = sin.to(dtype) if sin is not None else None + out, lse, *rest = flash_attn_with_kvcache( + q if not varlen_q else q_unpad, + k_cache if page_size is None else k_cache_paged, + v_cache if page_size is None else v_cache_paged, + k if not new_kv or not varlen_q else k_unpad, + v if not new_kv or not varlen_q else v_unpad, + rotary_cos=cos, + rotary_sin=sin, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, + max_seqlen_q=max_seqlen_q, + causal=causal, + window_size=window_size, + rotary_interleaved=rotary_interleaved, + num_splits=num_splits, + return_softmax_lse=True ) - k_to_update = rearrange(k_ro, "b s ... -> (b s) ...") - v_to_update = rearrange(v, "b s ... -> (b s) ...") if varlen_q: - k_to_update = k_to_update[indices_k] - v_to_update = v_to_update[indices_k] - k_cache_ref[update_mask] = k_to_update - v_cache_ref[update_mask] = v_to_update - k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) - v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) - out_ref, _ = attention_ref( - q_ro, - k_cache_rep, - v_cache_rep, - query_padding_mask, - key_padding_mask, - causal=causal, - window_size=window_size, - key_leftpad=cache_leftpad, - ) - out_pt, _ = attention_ref( - q_ro, - k_cache_rep, - v_cache_rep, - query_padding_mask, - key_padding_mask, - causal=causal, - window_size=window_size, - upcast=False, - reorder_ops=True, - key_leftpad=cache_leftpad, - intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None - ) - q = q.to(dtype) - q_unpad = q_unpad.to(dtype) if varlen_q else None - k_cache = k_cache.to(dtype) - v_cache = v_cache.to(dtype) - k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None - v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None - k = k.to(dtype) if k is not None else None - v = v.to(dtype) if v is not None else None - k_unpad = k_unpad.to(dtype) if k_unpad is not None else None - v_unpad = v_unpad.to(dtype) if v_unpad is not None else None - cos = cos.to(dtype) if cos is not None else None - sin = sin.to(dtype) if sin is not None else None - out, lse, *rest = flash_attn_with_kvcache( - q if not varlen_q else q_unpad, - k_cache if page_size is None else k_cache_paged, - v_cache if page_size is None else v_cache_paged, - k if not new_kv or not varlen_q else k_unpad, - v if not new_kv or not varlen_q else v_unpad, - rotary_cos=cos, - rotary_sin=sin, - cache_seqlens=cache_seqlens, - cache_batch_idx=cache_batch_idx, - cache_leftpad=cache_leftpad, - page_table=page_table, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k_new=cu_seqlens_k_new, - max_seqlen_q=max_seqlen_q, - causal=causal, - window_size=window_size, - rotary_interleaved=rotary_interleaved, - num_splits=num_splits, - return_softmax_lse=True - ) - if varlen_q: - out = output_pad_fn(out) - # out = flash_attn_with_kvcache( - # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size - # ) - # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) - # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) - # m = qk.amax(-1, keepdim=True) - # s_tmp = torch.exp((qk - m) / math.sqrt(d)) - # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) - # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) - # probs = torch.softmax(qk, dim=-1) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - # breakpoint() - - # Check that FlashAttention's numerical error is at most twice the numerical error - # of a Pytorch implementation. - if new_kv: - if page_size is None: - k_cache_select = ( - k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx] - ) - v_cache_select = ( - v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx] - ) - else: - k_cache_select = rearrange( - k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], - "(b nblocks) block_size ... -> b (nblocks block_size) ...", - b=batch_size, - )[:, :seqlen_k].to(dtype_ref) - v_cache_select = rearrange( - v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], - "(b nblocks) block_size ... -> b (nblocks block_size) ...", - b=batch_size, - )[:, :seqlen_k].to(dtype_ref) - k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) - v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) - if dtype is not torch.float8_e4m3fn: - assert torch.equal(v_cache_select, v_cache_ref) - else: - assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3) + out = output_pad_fn(out) + # out = flash_attn_with_kvcache( + # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size + # ) + # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) + # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) + # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) + # probs = torch.softmax(qk, dim=-1) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") # breakpoint() - # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: - if rotary_dim == 0: - assert torch.equal(k_cache_select, k_cache_ref) - else: - # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): - # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + if new_kv: + if page_size is None: + k_cache_select = ( + k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx] + ) + v_cache_select = ( + v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx] + ) + else: + k_cache_select = rearrange( + k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + v_cache_select = rearrange( + v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) + v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) if dtype is not torch.float8_e4m3fn: - assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) + assert torch.equal(v_cache_select, v_cache_ref) else: - assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1) - mult = 4 if dtype == torch.float8_e4m3fn else 2 - assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 - mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 - assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item() + assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3) + # breakpoint() + # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: + if rotary_dim == 0: + assert torch.equal(k_cache_select, k_cache_ref) + else: + # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): + # breakpoint() + if dtype is not torch.float8_e4m3fn: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) + else: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1) + mult = 4 if dtype == torch.float8_e4m3fn else 2 + assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 + mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 + assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item() -def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, device, dtype): +def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype): num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 k_cache_paged = torch.randn( num_blocks, page_size, nheads_k, d, device=device, dtype=dtype ) v_cache_paged = torch.randn( - num_blocks, page_size, nheads_k, d, device=device, dtype=dtype + num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype ) page_table = rearrange( torch.randperm(num_blocks, dtype=torch.int32, device=device), @@ -990,12 +994,12 @@ def attention_combine_ref(out_partial, lse_partial): @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.float32]) # @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) -@pytest.mark.parametrize("d", [64, 96, 128, 192, 256]) +@pytest.mark.parametrize("d", [64, 96, 128, 192, 256, 512]) # @pytest.mark.parametrize("d", [128]) -@pytest.mark.parametrize("seqlen", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024, 2048]) +@pytest.mark.parametrize("seqlen", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024]) # @pytest.mark.parametrize("seqlen", [12, 32, 64, 256, 112, 108, 640, 1024, 2048, 8192]) # @pytest.mark.parametrize("seqlen", [15]) -@pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 17, 32, 55, 97, 155]) +@pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 17, 32, 55, 97, 133]) # @pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 11]) # @pytest.mark.parametrize("num_splits", [128]) def test_flash_attn_combine(num_splits, seqlen, d, dtype): diff --git a/hopper/test_util.py b/hopper/test_util.py index 54eb195eb..cbf441031 100644 --- a/hopper/test_util.py +++ b/hopper/test_util.py @@ -37,15 +37,16 @@ def generate_qkv( Arguments: q: (batch_size, seqlen_q, nheads, d) k: (batch_size, seqlen_k, nheads_k, d) - v: (batch_size, seqlen_k, nheads_k, d) + v: (batch_size, seqlen_k, nheads_k, d_v) query_padding_mask: (batch_size, seqlen), bool key_padding_mask: (batch_size, seqlen), bool """ assert not (kvpacked and qkvpacked) batch_size, seqlen_q, nheads, d = q.shape + d_v = v.shape[-1] _, seqlen_k, nheads_k, _ = k.shape assert k.shape == (batch_size, seqlen_k, nheads_k, d) - assert v.shape == (batch_size, seqlen_k, nheads_k, d) + assert v.shape == (batch_size, seqlen_k, nheads_k, d_v) if query_unused_mask is not None or key_unused_mask is not None: assert not kvpacked assert not qkvpacked @@ -208,7 +209,7 @@ def attention_ref( Arguments: q: (batch_size, seqlen_q, nheads, head_dim) k: (batch_size, seqlen_k, nheads, head_dim) - v: (batch_size, seqlen_k, nheads, head_dim) + v: (batch_size, seqlen_k, nheads, head_dim_v) query_padding_mask: (batch_size, seqlen_q) key_padding_mask: (batch_size, seqlen_k) attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) @@ -221,7 +222,7 @@ def attention_ref( without changing the math. This is to estimate the numerical error from operation reordering. Output: - output: (batch_size, seqlen_q, nheads, head_dim) + output: (batch_size, seqlen_q, nheads, head_dim_v) attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout """ if causal: diff --git a/hopper/tile_size.h b/hopper/tile_size.h index 127f518bb..66ab1a7fd 100644 --- a/hopper/tile_size.h +++ b/hopper/tile_size.h @@ -8,7 +8,7 @@ // Return {kBlockM, kBlockN, Mma1_is_RS, IntraWGOverlap} constexpr std::tuple tile_size_fwd_sm90( - int headdim, bool is_causal, bool is_local, int element_size=2, + int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2, bool v_colmajor=false, bool paged_kv=false, bool softcap=false) { if (element_size == 2) { if (headdim <= 64) { @@ -22,7 +22,7 @@ constexpr std::tuple tile_size_fwd_sm90( // {128, 192, false, false} and {192, 128, false, true} are quite good too // 128 x 192 hits the limit of smem if Mma1_is_RS, 128 x 144 hits the limit if !Mma1_is_RS } else if (headdim <= 192) { - return {128, paged_kv || is_local ? 96 : 112, true, true}; // 128 x 112 hits the limit of smem + return {128, paged_kv || is_local ? 96 : (headdim_v <= 128 ? 128 : 112), true, true}; // 128 x 112 hits the limit of smem } else { return {128, is_local ? 64 : 80, true, true}; // 128 x 80 hits the limit of smem } @@ -43,7 +43,7 @@ constexpr std::tuple tile_size_fwd_sm90( // Return {kBlockM, kBlockN, kNWarps, kStages, Q_in_regs} constexpr std::tuple tile_size_fwd_sm8x( - bool sm86_or_89, int headdim, bool is_causal, bool is_local, int element_size=2, + bool sm86_or_89, int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2, bool paged_kv=false, bool varlen_and_split=false, bool softcap=false, bool append_kv=false) { if (element_size == 2) { From 6d199aa20721fbb51340aff6ec19d70cb03063b9 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 3 Feb 2025 20:06:33 -0500 Subject: [PATCH 010/102] Fix shape_O in epilogue params when kHeadDimV != kHeadDim --- hopper/flash_fwd_launch_template.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 3f4bea96e..de17b39c9 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -126,7 +126,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { }; typename CollectiveEpilogue::Arguments epilogue_args { static_cast(!Split ? params.o_ptr : params.oaccum_ptr), - {seqlen_q, params.d, params.h, batch_q, params.num_splits}, // shape_O + {seqlen_q, params.dv, params.h, batch_q, params.num_splits}, // shape_O {!Split ? params.o_row_stride : params.oaccum_row_stride, _1{}, !Split ? params.o_head_stride : params.oaccum_head_stride, From 86bcd0552ff5e817c23d58e3b476e1185dfd2965 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 3 Feb 2025 20:12:13 -0500 Subject: [PATCH 011/102] Remove old combine.h --- hopper/combine.h | 248 ----------------------------------------------- 1 file changed, 248 deletions(-) delete mode 100644 hopper/combine.h diff --git a/hopper/combine.h b/hopper/combine.h deleted file mode 100644 index c26f7ea56..000000000 --- a/hopper/combine.h +++ /dev/null @@ -1,248 +0,0 @@ - -#pragma once - -#include - -#include -#include "cutlass/layout/layout.h" -#include -#include - -#include "kernel_traits.h" -#include "utils.h" - -namespace flash { - -using namespace cute; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct SharedStorageLSE { - cute::array_aligned> smem_lse; - cute::array_aligned> smem_valid_splits; -}; - -// DONT use Kernel_traits here to avoid redundant compilation. -// template -template -__global__ void combine_attn_seqk_parallel(Params const params) { - // using Element = typename Kernel_traits::OutputType; - // using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = int64_t; // Kernel_traits::index_t - constexpr int kMaxSplits = 1 << Log_max_splits; - // constexpr int kHeadDim = Kernel_traits::kHeadDim; - constexpr int kNThreads = 128; //Kernel_traits::kNThreads; - - static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); - static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32"); - static_assert(kNThreads == 128, "We assume that each block has 128 threads"); - - // Shared memory. - // kBlockM + 1 instead of kBlockM to reduce bank conflicts. - //__shared__ __align__(16) ElementAccum sLSE[kMaxSplits][kBlockM+1]; - extern __shared__ char smem_[]; - using SharedStorage = SharedStorageLSE, Int>, Shape>>; - SharedStorage &shared_storage = - *reinterpret_cast(smem_); - Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse.data()), Shape, Int>{}); - Tensor sValidSplits = make_tensor(make_smem_ptr(shared_storage.smem_valid_splits.data()), Shape>{}); - - // The thread and block index. - const int tidx = threadIdx.x; - const int bidx = blockIdx.x; - - const index_t lse_size = params.b * params.h * params.seqlen_q; - //if (cute::thread0()) print ("final %d %d %d %d\n", params.b, params.h, params.seqlen_q, params.b * params.h * params.seqlen_q); - - const index_t row_offset_lse = bidx * kBlockM; - Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lse), - Shape, Int>{}, - make_stride(lse_size, _1{})); - - // LSE format is different depending on params.unpadded_lse and params.seqlenq_ngroups_swapped, see comment in get_lse_tile. - // This tensor's layout maps row_offset_lse to {bidb, bidh, q_offset}. - Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), - Shape>{}, Stride<_1>{}); - - // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}. - Layout flat_layout = make_layout(lse_size); - Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b)); - auto transposed_stride = params.seqlenq_ngroups_swapped ? make_stride(params.b, params.seqlen_q * params.b, 1) : make_stride(1, params.seqlen_q * params.b, params.seqlen_q); - Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride); - Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout)); - - Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), final_layout); - - constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads; - - // Read the LSE values from gmem and store them in shared memory, then transpose them. - constexpr int kRowsPerLoadLSE = kNThreads / kBlockM; - #pragma unroll - for (int l = 0; l < kNLsePerThread; ++l) { - const int row = l * kRowsPerLoadLSE + tidx / kBlockM; - const int col = tidx % kBlockM; - ElementAccum lse = (row < params.num_splits && col < lse_size - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY; - if (row < kMaxSplits) { sLSE(row,col) = lse; } - // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); } - } - __syncthreads(); - - // Reduce along the kBlockM dimension to determine valid splits (store in SMEM) - // One thread per split. Know NumThreads = 128 >= NumMaxSplits - if (tidx < kMaxSplits) { - bool is_valid_split = false; - #pragma unroll - for (int col = 0; col < kBlockM; ++col) { - if(sLSE(tidx,col) != -INFINITY) { - is_valid_split = true; - } - } - sValidSplits(tidx) = is_valid_split; - } - __syncthreads(); - // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); } - - Tensor lse_accum = make_tensor(Shape>{}); - constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits); - // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits - // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads, - // kBlockM rows, so each time we load we can load 128 / kBlockM rows). - // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose; - // static_assert(kThreadsPerSplit <= 32); - static_assert(kRowsPerLoadTranspose <= 32); - static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits); - #pragma unroll - for (int l = 0; l < kNLsePerThread; ++l) { - const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; - const int col = tidx / kRowsPerLoadTranspose; - //if (bidx == 0 && tidx < 128) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } - lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE(row,col) : -INFINITY; - - } - //return; - - // Compute the logsumexp of the LSE along the split dimension. - ElementAccum lse_max = lse_accum(0); - #pragma unroll - for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); } - MaxOp max_op; - lse_max = Allreduce::run(lse_max, max_op); - lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf - float lse_sum = expf(lse_accum(0) - lse_max); - #pragma unroll - for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); } - SumOp sum_op; - lse_sum = Allreduce::run(lse_sum, sum_op); - // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise - // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. - ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max; - // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } - if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { - if (params.unpadded_lse) { - const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose; - if (lse_offset < lse_size) { - gLSE_unpadded(lse_offset) = lse_logsum; - } - } else { - gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; - } - } - //if (cute::thread0()) printf ("lse_logsum = %f\n", lse_logsum); - - // Store the scales exp(lse - lse_logsum) in shared memory. - #pragma unroll - for (int l = 0; l < kNLsePerThread; ++l) { - const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; - const int col = tidx / kRowsPerLoadTranspose; - if (row < params.num_splits && col < kBlockM) { sLSE(row,col) = expf(lse_accum(l) - lse_logsum); } - } - __syncthreads(); - - const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded; - Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), - Shape, Int>{}, - Stride, _1>{}); - constexpr int kBlockN = kNThreads / kBlockM; - using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>; - using GmemTiledCopyOaccum = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtomOaccum{}, - Layout>{})); // Val layout, 4 vals per store - GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; - auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); - Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); - Tensor tOrO = make_tensor(shape(tOgOaccum)); - Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); - clear(tOrO); - - // Predicates - Tensor cOaccum = make_identity_tensor(Shape, Int>{}); - //if (cute::thread0()) print_tensor (cOaccum); - // Repeat the partitioning with identity layouts - Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum); - Tensor tOpOaccum = make_tensor(make_shape(size<2>(tOgOaccum))); - if (!Is_even_K) { - #pragma unroll - for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; } - } - // Load Oaccum in then scale and accumulate to O - for (int split = 0; split < params.num_splits; ++split) { - // DONT copy in Oaccum if lse(split) = -inf for all kBlockM. - if(sValidSplits(split)) { - flash::copy( - gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM - ); - #pragma unroll - for (int m = 0; m < size<1>(tOrOaccum); ++m) { - int row = get<0>(tOcOaccum(0, m, 0)); - ElementAccum lse_scale = sLSE(split,row); - if (lse_scale != 0.f) { - #pragma unroll - for (int k = 0; k < size<2>(tOrOaccum); ++k) { - #pragma unroll - for (int i = 0; i < size<0>(tOrOaccum); ++i) { - tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k); - //tOrO(i, m, k) += tOrOaccum(i, m, k); - } - } - } - //if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE(split, 0), sLSE(split, 1)); print_tensor(tOrOaccum); } - } - } - tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded; - } - //if (cute::thread0()) { print_tensor(tOrO); } - - Tensor rO = flash::convert_type(tOrO); - // Write to gO - #pragma unroll - for (int m = 0; m < size<1>(rO); ++m) { - const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0)); - //if (cute::thread0()) print ("final %d %d %d %d %d\n", idx, params.b, params.h, params.seqlen_q, params.b * params.h * params.seqlen_q); - if (idx < params.b * params.h * params.seqlen_q) { - //print ("final2\n"); - const int batch_idx = idx / (params.h * params.seqlen_q); - const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q; - // The index to the rows of Q - const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q; - auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride - + head_idx * params.o_head_stride + row * params.o_row_stride; - #pragma unroll - for (int k = 0; k < size<2>(rO); ++k) { - if (Is_even_K || tOpOaccum(k)) { - const int col = get<1>(tOcOaccum(0, m, k)); - Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col), - Shape(rO))::value>>{}, Stride<_1>{}); - // TODO: Should check if this is using vectorized store, but it seems pretty fast - copy(rO(_, m, k), gO); - //if (cute::thread0()) { print ("final\n"); print_tensor(gO); } - // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); } - // reinterpret_cast(o_ptr)[col / 4] = recast(rO)(0, m, k); - } - } - } - } -} - -} // namespace flash From e3b2400a31e1a094411102dfd474b3c582a42305 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 4 Feb 2025 01:31:38 -0500 Subject: [PATCH 012/102] Fix loading paged V when kHeadDimV != kHeadDim --- hopper/paged_kv.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hopper/paged_kv.h b/hopper/paged_kv.h index 9431f384f..80ee61b9a 100644 --- a/hopper/paged_kv.h +++ b/hopper/paged_kv.h @@ -117,7 +117,7 @@ struct PagedKVManager { Tensor cV = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) Tensor tVcV = gmem_thr_copy_kv.partition_S(cV); #pragma unroll - for (int k = 0; k < size<1>(tVpV_); ++k) { tVpV_(_0{}, k) = get<1>(tVcV(_0{}, _0{}, k)) < get<1>(shape_K); } + for (int k = 0; k < size<1>(tVpV_); ++k) { tVpV_(_0{}, k) = get<1>(tVcV(_0{}, _0{}, k)) < get<1>(shape_V); } tVpV = cute::conditional_return(tKpK, tVpV_); }; From 9e07d6d3cfc3a5ab3ea134af70e3d879d855aa70 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 4 Feb 2025 02:26:10 -0500 Subject: [PATCH 013/102] Fix shape_V for storing new KV when kHeadDimV != kHeadDim --- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index da5f902ea..0a1bf98a1 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -1216,7 +1216,8 @@ struct CollectiveMainloopFwdSm90 { bool const is_varlen_k_new = Varlen && params.cu_seqlens_k_new; Tensor mKnew_TMA = params.tma_load_K_new.get_tma_tensor(params.shape_K_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); - Tensor mVnewt_TMA = params.tma_load_V_new.get_tma_tensor(select<1, 0, 2, 3>(params.shape_K_new))(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); + auto shape_Vnew = make_shape(params.headdim_v, get<0>(params.shape_K_new), get<2>(params.shape_K_new), get<3>(params.shape_K_new)); + Tensor mVnewt_TMA = params.tma_load_V_new.get_tma_tensor(shape_Vnew)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); Tensor gKnew_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mKnew_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) Tensor gVnewt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k_new), mVnewt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _)); // (K, N, _) @@ -1311,7 +1312,8 @@ struct CollectiveMainloopFwdSm90 { bool const is_varlen_k = Varlen && params.cu_seqlens_k; Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); - Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), params.shape_K, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); + auto shape_V = make_shape(params.headdim_v, get<0>(params.shape_K), get<2>(params.shape_K), get<3>(params.shape_K)); + Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), shape_V, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); int const offset_k = seqlen_info.offset_k + seqlen_info.seqlen_k_og; Tensor gK = local_tile(domain_offset(make_coord(offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) From f0f25239bd0c5a39c0b481cc5686a835f6c746f5 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 4 Feb 2025 02:28:16 -0500 Subject: [PATCH 014/102] Implement the case of LargeHeadDimV --- hopper/epilogue_fwd.hpp | 19 ++- hopper/flash_fwd_kernel_sm90.h | 35 +++- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 200 +++++++++++++++++++++-- hopper/named_barrier.hpp | 2 + 4 files changed, 222 insertions(+), 34 deletions(-) diff --git a/hopper/epilogue_fwd.hpp b/hopper/epilogue_fwd.hpp index d8f2c15c9..1c13988eb 100644 --- a/hopper/epilogue_fwd.hpp +++ b/hopper/epilogue_fwd.hpp @@ -40,6 +40,8 @@ struct CollectiveEpilogueFwd { static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); static constexpr int kHeadDimV = get<1>(TileShape_MNK_PV{}); + static constexpr bool LargeHeadDimV = kHeadDimV > 256; + using GmemTiledCopyOTMA = cute::SM90_TMA_STORE; // These are for storing the output tensor without TMA (e.g., for setting output to zero) @@ -239,6 +241,7 @@ struct CollectiveEpilogueFwd { bool is_varlen = Varlen && params.cu_seqlens; int offset_o = seqlen_info.offset; int seqlen_o = seqlen_info.seqlen; + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); // Step 2: Write LSE from rmem -> gmem auto thread_mma = tiled_mma.get_thread_slice(thread_idx); @@ -254,14 +257,16 @@ struct CollectiveEpilogueFwd { Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset_o * get<0>(params.stride_LSE)), params.shape_LSE_packed, params.stride_LSE_packed)(_, bidh, !is_varlen ? bidb : 0, split_idx); // if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); } - if constexpr (!PackGQA) { - #pragma unroll - for (int mi = 0; mi < size(lse); ++mi) { - int const row = m_block * kBlockM + get<0>(taccOcO_row(mi)); - if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o) { mLSE(row) = lse(mi); } + if (!LargeHeadDimV || warp_group_idx == 0) { + if constexpr (!PackGQA) { + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + int const row = m_block * kBlockM + get<0>(taccOcO_row(mi)); + if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o) { mLSE(row) = lse(mi); } + } + } else { + PackGQAt::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); } - } else { - PackGQAt::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); } // Step 3: Write O from smem -> gmem diff --git a/hopper/flash_fwd_kernel_sm90.h b/hopper/flash_fwd_kernel_sm90.h index 05ce4d0ae..5e1dceb09 100644 --- a/hopper/flash_fwd_kernel_sm90.h +++ b/hopper/flash_fwd_kernel_sm90.h @@ -46,11 +46,12 @@ class FlashAttnFwdSm90 { static constexpr bool PackGQA = CollectiveMainloop::PackGQA; static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads; static constexpr bool SameHeadDim = CollectiveMainloop::SameHeadDim; + static constexpr bool LargeHeadDimV = CollectiveMainloop::LargeHeadDimV; + static_assert(CollectiveMainloop::LargeHeadDimV == CollectiveEpilogue::LargeHeadDimV); using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t; // Mainloop derived types using TileShape_MNK_PV = typename CollectiveMainloop::TileShape_MNK_PV; - using TiledMma0 = typename CollectiveMainloop::TiledMma0; using TiledMma1 = typename CollectiveMainloop::TiledMma1; using ArchTag = typename CollectiveMainloop::ArchTag; using ClusterShape = typename CollectiveMainloop::ClusterShape; @@ -69,8 +70,8 @@ class FlashAttnFwdSm90 { using TileSchedulerParams = typename TileScheduler::Params; static constexpr uint32_t NumLoadWarpGroups = 1; - static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMma0{})) / cutlass::NumThreadsPerWarpGroup; - static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma0{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup); + static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMma1{})) / cutlass::NumThreadsPerWarpGroup; + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma1{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); @@ -217,15 +218,18 @@ class FlashAttnFwdSm90 { if constexpr (Use_TMA_KV) { pipeline_params_k.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK; pipeline_params_k.is_leader = warp_group_thread_idx == 0; - pipeline_params_k.num_consumers = NumMmaThreads; + pipeline_params_k.num_consumers = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup; } else { - pipeline_params_k.consumer_arv_count = NumMmaThreads; + pipeline_params_k.consumer_arv_count = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup; pipeline_params_k.producer_arv_count = NumProducerThreads; } PipelineParamsV pipeline_params_v = pipeline_params_k; if constexpr (Use_TMA_KV && !SameHeadDim) { pipeline_params_v.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV; + if constexpr (LargeHeadDimV) { pipeline_params_v.num_consumers = NumMmaThreads; } + } else { + if constexpr (LargeHeadDimV) { pipeline_params_v.consumer_arv_count = NumMmaThreads; } } MainloopPipelineK pipeline_k = [&] { @@ -378,7 +382,7 @@ class FlashAttnFwdSm90 { float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)]; softmax_scale_log2 *= q_descale * k_descale; } - flash::Softmax<2 * (2 * kBlockM / NumMmaThreads), /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2); + flash::Softmax softmax(softmax_scale_log2); SeqlenInfo_t seqlen_info{ bidb, @@ -404,9 +408,22 @@ class FlashAttnFwdSm90 { // if (threadIdx.x == 128) { printf("Consumer: After sync\n"); } } } - bool tile_valid = collective_mainloop.mma( - params.mainloop, pipeline_k, pipeline_v, smem_pipe_read, - tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage); + bool tile_valid; + if constexpr (!LargeHeadDimV) { + tile_valid = collective_mainloop.mma( + params.mainloop, pipeline_k, pipeline_v, smem_pipe_read, + tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage); + } else { // mma1_only might not compile if !LargeHeadDimV + if (warp_group_idx == 1) { + tile_valid = collective_mainloop.mma( + params.mainloop, pipeline_k, pipeline_v, smem_pipe_read, + tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage); + } else { + tile_valid = collective_mainloop.mma1_only( + params.mainloop, pipeline_v, smem_pipe_read, + tOrO, softmax, threadIdx.x - MmaThreadOffset, seqlen_info, block_coord, shared_storage); + } + } if (tile_valid) { // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); } collective_epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma1, diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 0a1bf98a1..67f645e60 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -55,6 +55,7 @@ struct CollectiveMainloopFwdSm90 { static_assert(Use_TMA_KV || CUTE_STATIC_V(size(ClusterShape{})) == 1, "If not using TMA for KV, ClusterShape must be 1"); static_assert(Use_TMA_KV || !V_colmajor, "If not using TMA for KV, V_colmajor is not supported"); static constexpr bool SameHeadDim = get<2>(TileShape_MNK{}) == kHeadDimV; + static constexpr bool LargeHeadDimV = kHeadDimV > 256; using SeqlenInfo_t = flash::SeqlenInfoQKNewK; static_assert(ArchTag::kMinComputeCapability >= 90); @@ -66,6 +67,10 @@ struct CollectiveMainloopFwdSm90 { static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + static_assert(!LargeHeadDimV || kHeadDimV % 256 == 0); + static_assert(!LargeHeadDimV || kBlockM <= 64, "kBlockM must be 64 or less for large Headdim_V"); + static_assert(!LargeHeadDimV || !Mma1_is_RS, "Mma1 must be SS for large Headdim_V"); + // Register bandwidth is actually a bottleneck so we don't want Q to be in registers. // Leaving this option here for reference. static constexpr bool Mma0_is_RS = false; @@ -74,26 +79,34 @@ struct CollectiveMainloopFwdSm90 { static_assert(!(!Mma1_is_RS && Is_FP8), "Mma1 must be RS if FP8"); static_assert(!(!Mma1_is_RS && Transpose_V), "Mma1 must be RS if Transpose_V"); - using AtomLayoutMNK = Layout, _1, _1>>; + using AtomLayoutQK = Layout, _1, _1>>; using TiledMma0 = decltype(cute::make_tiled_mma( std::conditional_t< !Mma0_is_RS, decltype(cute::GMMA::ss_op_selector()), decltype(cute::GMMA::rs_op_selector()) >{}, - AtomLayoutMNK{})); + AtomLayoutQK{})); + using AtomLayoutPV = std::conditional_t< + !LargeHeadDimV, + AtomLayoutQK, + Layout, _1>> + >; + using TileShapeAtomPV = Shape, Int, Int>; using TiledMma1 = decltype(cute::make_tiled_mma( std::conditional_t< !Mma1_is_RS, decltype(cute::GMMA::ss_op_selector()), + TileShapeAtomPV, GMMA::Major::K, MmaMajorV>()), decltype(cute::GMMA::rs_op_selector()) + TileShapeAtomPV, GMMA::Major::K, MmaMajorV>()) >{}, - AtomLayoutMNK{})); + AtomLayoutPV{})); - static constexpr int NumMmaThreads = size(TiledMma0{}); + static constexpr int NumMmaThreadsMma0 = size(TiledMma0{}); + static constexpr int NumMmaThreads = size(TiledMma1{}); static constexpr int NumProducerThreads = !Transpose_V && Use_TMA_KV && Use_TMA_Q ? cutlass::NumThreadsPerWarp : cutlass::NumThreadsPerWarpGroup; + static_assert(NumMmaThreadsMma0 % cutlass::NumThreadsPerWarpGroup == 0); static_assert(NumMmaThreads % cutlass::NumThreadsPerWarpGroup == 0); static constexpr int NumMmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup; static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); @@ -133,6 +146,9 @@ struct CollectiveMainloopFwdSm90 { decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{}))); + // Only for LargeHeadDimV where WG0 sends WG1 the scales + using SmemLayoutScale = cute::Layout, Int>>; + using SmemCopyAtomP = Copy_Atom; // Use LDSM.T and STSM to transpose V in the case of FP8 and V being row-major. @@ -251,6 +267,7 @@ struct CollectiveMainloopFwdSm90 { static_assert(SmemAlignmentP >= 128, "Require at least 128B alignment"); using SmemP_t = std::conditional_t, cute::array_aligned, SmemAlignmentP>>; + using SmemScale_t = std::conditional_t, cute::array_aligned, 128>>; // Sometimes even with SmemP_t = cute::array, putting it in the TensorStorage struct causes // smem size to go from 227KB to 228KB and we get "invalid argument". @@ -266,8 +283,19 @@ struct CollectiveMainloopFwdSm90 { cute::array_aligned, SmemAlignmentK> smem_k; SmemP_t smem_p; }; + struct TensorStorageWithPScaleNoTranspose : cute::aligned_struct { + cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; + cute::array_aligned, SmemAlignmentQ> smem_q; + cute::array_aligned, SmemAlignmentK> smem_k; + SmemP_t smem_p; + SmemScale_t smem_scale; + }; - using TensorStorageNoTranspose = std::conditional_t; + using TensorStorageNoTranspose = std::conditional_t< + Mma1_is_RS, + TensorStorageWithoutPNoTranspose, + std::conditional_t + >; static constexpr size_t SmemAlignmentVt = cutlass::detail::alignment_for_swizzle(SmemLayoutVt{}); static constexpr size_t SmemAlignmentV = cutlass::detail::alignment_for_swizzle(SmemLayoutVtMma{}); @@ -277,14 +305,16 @@ struct CollectiveMainloopFwdSm90 { cute::array_aligned, SmemAlignmentVt> smem_vt; cute::array_aligned, SmemAlignmentQ> smem_q; cute::array_aligned, SmemAlignmentK> smem_k; + SmemScale_t smem_scale; }; using TensorStorage = std::conditional_t; // These are tuned for speed. They don't affect correctness. - static constexpr bool UseSchedulerBarrier = IntraWGOverlap + static constexpr bool UseSchedulerBarrier = (IntraWGOverlap ? (NumMmaWarpGroups >= 2) && (!Is_FP8 ? kHeadDim <= 128 : kHeadDim >= 128) - : NumMmaWarpGroups == 2; + : NumMmaWarpGroups == 2) + && !LargeHeadDimV; static constexpr bool RescaleOBeforeGemm = kHeadDim > 128 && (!Is_FP8 || V_colmajor); // Host side kernel arguments @@ -699,7 +729,7 @@ struct CollectiveMainloopFwdSm90 { if constexpr (Use_TMA_Q) { // Wait for the MMA warpgroups to signal that smem_q is ready if (SingleProducerWarp || warp_idx_in_warpgroup == 0) { - cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + cutlass::arch::NamedBarrier::sync(NumMmaThreadsMma0 + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); } if ((SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync()) { @@ -708,7 +738,7 @@ struct CollectiveMainloopFwdSm90 { tQgQ, tQsQ); } } else { // Load Q with cp.async - cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + cutlass::arch::NamedBarrier::sync(NumMmaThreadsMma0 + NumProducerThreads, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q + seqlen_info.offset_q * get<0>(params.stride_Q)), params.shape_Q_packed, params.stride_Q_packed)(_, _, bidh, !is_varlen_q ? bidb : 0); Tensor sQ_pi = cute::as_position_independent_swizzle_tensor(sQ); using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumProducerThreads, Element>; @@ -830,13 +860,19 @@ struct CollectiveMainloopFwdSm90 { CUTLASS_DEVICE void mma_init() { + int warp_group_idx = flash::canonical_warp_group_idx_nosync(); // Tell producers that smem_q is ready - cutlass::arch::NamedBarrier::arrive(NumMmaThreads + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + if (!LargeHeadDimV || warp_group_idx == 1) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreadsMma0 + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + } + if (LargeHeadDimV && warp_group_idx > 1) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + } if constexpr (UseSchedulerBarrier) { // We have NamedBarrier for up to 3 WGs static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); // WG1 needs the very first signal to start - if (flash::canonical_warp_group_idx_nosync() == 1) { + if (warp_group_idx == 1) { cutlass::arch::NamedBarrier::arrive(2 * cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) /*id*/); } } @@ -883,6 +919,13 @@ struct CollectiveMainloopFwdSm90 { return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutP{}); } }(); + Tensor sScale = [&] { + if constexpr (LargeHeadDimV) { + return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_scale.data()), SmemLayoutScale{}); + } else { // won't be used, just a placeholder + return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutScale{}); + } + }(); if constexpr (!Mma0_is_RS) { static_assert(stride<0>(typename TiledMma0::ALayout{}) == 0 and @@ -891,7 +934,7 @@ struct CollectiveMainloopFwdSm90 { size<0>(typename TiledMma0::BLayout{}) == cutlass::NumThreadsPerWarpGroup, "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); } - constexpr int MmaWarpGroups = size(TiledMma0{}) / cutlass::NumThreadsPerWarpGroup; + static constexpr int MmaWarpGroups = size(TiledMma1{}) / cutlass::NumThreadsPerWarpGroup; Layout warp_group_thread_layout = make_layout(make_shape(Int{}), make_stride(Int{})); @@ -911,6 +954,21 @@ struct CollectiveMainloopFwdSm90 { Tensor tOsP = wg_mma1.partition_fragment_A(sP); Tensor tPsP = smem_thr_copy_P.partition_D(cute::as_position_independent_swizzle_tensor(sP)); + // For storing scales to smem, only used when LargeHeadDimV + auto thread_mma1 = tiled_mma1.get_thread_slice(thread_idx); + Tensor taccOcO = thread_mma1.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); + Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); + Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); + auto store_scales = [&](auto& scales, int stage) { + static_assert(CUTE_STATIC_V(size(scales)) == CUTE_STATIC_V(size(taccOcO_row))); + #pragma unroll + for (int mi = 0; mi < size(taccOcO_row); ++mi) { + if (get<1>(taccOcO_row(_0{})) == 0) { + sScale(get<0>(taccOcO_row(mi)), stage) = scales(mi); + } + } + }; + auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); pipeline.consumer_wait(smem_pipe_read, barrier_token); @@ -947,7 +1005,7 @@ struct CollectiveMainloopFwdSm90 { } else { if (get<1>(params.shape_rotary) > 0) { // Apply rotary to Q int const offset_rotary = seqlen_info.seqlen_k_og + seqlen_info.leftpad_k; - using Rotary_t = Rotary; + using Rotary_t = Rotary; Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, params.ptr_rotary_sin, params.stride_rotary_sin, params.is_rotary_interleaved, thread_idx, seqlen_q, offset_rotary); @@ -970,7 +1028,7 @@ struct CollectiveMainloopFwdSm90 { } // SMEM fence to make sure the rotated Q is visible to GMMA cutlass::arch::fence_view_async_shared(); - cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::QueryRotated) /*id*/); + cutlass::arch::NamedBarrier::sync(NumMmaThreadsMma0, static_cast(FwdNamedBarriers::QueryRotated) /*id*/); } else { barrier_Q.wait(work_idx % 2); } @@ -996,6 +1054,8 @@ struct CollectiveMainloopFwdSm90 { mask.template apply(tSrS, m_block, n_block); Tensor scores_scale = softmax.template max_get_scale(tSrS); + // Don't need to store scales to send to WG1 (in the case of LargeHeadDimV) since it's 1.f + softmax.template online_softmax(tSrS); if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); @@ -1003,9 +1063,15 @@ struct CollectiveMainloopFwdSm90 { convert_type_out(tOrP_acc, tOrP); if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } if constexpr (!Mma1_is_RS) { + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + } cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); cutlass::arch::fence_view_async_shared(); __syncwarp(); // Only need syncwarp since each warp is using its own P values for Mma1 + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + } } --n_block; @@ -1027,17 +1093,24 @@ struct CollectiveMainloopFwdSm90 { scoremod_premask_fn(tSrS); mask_fn(tSrS, n_block); cute::copy(softmax.template max_get_scale(tSrS), scores_scale); + if constexpr (LargeHeadDimV) { store_scales(scores_scale, smem_pipe_read_v.index()); } softmax.template online_softmax(tSrS); warpgroup_wait<0>(); pipeline_v.consumer_release(smem_pipe_read_v); // release V if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } convert_type_out(make_tensor(tSrS.data(), tOrP.layout()), tOrP); if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + } if constexpr (!Mma1_is_RS) { cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); } if constexpr (!RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } if constexpr (!Mma1_is_RS) { cutlass::arch::fence_view_async_shared(); __syncwarp(); + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + } } }; @@ -1077,12 +1150,17 @@ struct CollectiveMainloopFwdSm90 { // } } // Tell producers that smem_q is ready - cutlass::arch::NamedBarrier::arrive(NumMmaThreads + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + cutlass::arch::NamedBarrier::arrive(NumMmaThreadsMma0 + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } consumer_wait(pipeline_v, smem_pipe_read); flash::gemm(tiled_mma1, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; cute::copy(softmax.finalize(v_descale), scores_scale); + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + store_scales(scores_scale, smem_pipe_read.index()); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + } warpgroup_wait<0>(); pipeline_v.consumer_release(smem_pipe_read); // release V, otherwise producers will hang softmax.rescale_o(tOrO, scores_scale); @@ -1158,7 +1236,7 @@ struct CollectiveMainloopFwdSm90 { } warp_scheduler_barrier_arrive(); // Tell producers that smem_q is ready - cutlass::arch::NamedBarrier::arrive(NumMmaThreads + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + cutlass::arch::NamedBarrier::arrive(NumMmaThreadsMma0 + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; Tensor scores_scale = softmax.finalize(v_descale); softmax.rescale_o(tOrO, scores_scale); @@ -1168,6 +1246,92 @@ struct CollectiveMainloopFwdSm90 { return true; } + template + CUTLASS_DEVICE bool + mma1_only(Params const& params, + MainloopPipelineV pipeline_v, + PipelineState& smem_pipe_read, + FrgTensorO& tOrO, + Softmax& softmax, + int const thread_idx, + SeqlenInfo_t const& seqlen_info, + cute::tuple block_coord, + SharedStorage& shared_storage + ) { + static_assert(is_rmem::value, "O tensor must be rmem resident."); + // can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda + int const m_block = get<0>(block_coord); + int const bidb = get<2>(block_coord); + int const split_idx = get<3>(block_coord); + auto [n_block_min, n_block_max] = get_n_block_min_max(params, seqlen_info, m_block, bidb, split_idx, params.num_splits); + // It's possible to have n_block_max <= n_block_min. We don't want to load Q or change any barrier + if constexpr (Is_causal || Is_local || Varlen || Split) { + if (n_block_max <= n_block_min) { return false; } + } + + Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{}); + Tensor sP = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutP{}); + Tensor sScale = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_scale.data()), SmemLayoutScale{}); + static constexpr int MmaWarpGroups = size(TiledMma1{}) / cutlass::NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(make_shape(Int{}), + make_stride(Int{})); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); + TiledMma1 tiled_mma1; + auto wg_mma1 = tiled_mma1.get_slice(warp_group_thread_layout(warp_group_idx)); + + // Allocate "fragments/descriptors" + Tensor tOrV = wg_mma1.partition_fragment_B(sV); + Tensor tOsP = wg_mma1.partition_fragment_A(sP); + + // For load scales to smem, pretend thread_idx is thread_idx % 128 + auto thread_mma1 = tiled_mma1.get_thread_slice(thread_idx % cutlass::NumThreadsPerWarpGroup); + Tensor taccOcO = thread_mma1.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); + Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); + Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); + auto load_scales = [&](auto& scales, int stage) { + static_assert(CUTE_STATIC_V(size(scales)) == CUTE_STATIC_V(size(taccOcO_row))); + #pragma unroll + for (int mi = 0; mi < size(taccOcO_row); ++mi) { + scales(mi) = sScale(get<0>(taccOcO_row(mi)), stage); + } + }; + + clear(tOrO); + // tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero; + + typename Softmax::TensorT scores_scale; + + int n_block = n_block_max - 1; + pipeline_v.consumer_wait(smem_pipe_read); + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + flash::gemm(tiled_mma1, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + pipeline_v.consumer_release(smem_pipe_read); // release V + --n_block; + + for (; n_block >= n_block_min; --n_block) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + load_scales(scores_scale, smem_pipe_read.index()); + softmax.rescale_o(tOrO, scores_scale); + ++smem_pipe_read; + auto barrier_token = pipeline_v.consumer_try_wait(smem_pipe_read); + pipeline_v.consumer_wait(smem_pipe_read, barrier_token); + flash::gemm(tiled_mma1, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + pipeline_v.consumer_release(smem_pipe_read); // release V + }; + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + load_scales(scores_scale, smem_pipe_read.index()); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + // if (thread_idx == 128) { print_tensor(scores_scale); } + // if (thread_idx == 128) { print_tensor(sScale); } + softmax.rescale_o(tOrO, scores_scale); + if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); } + ++smem_pipe_read; + return true; + } + CUTLASS_DEVICE cute::tuple get_n_block_k_new_min_max(Params const& params, SeqlenInfo_t const& seqlen_info, int m_block, int bidb, int split_idx=0, int num_splits=1) { diff --git a/hopper/named_barrier.hpp b/hopper/named_barrier.hpp index f77ea7782..8d07f6aa2 100644 --- a/hopper/named_barrier.hpp +++ b/hopper/named_barrier.hpp @@ -57,6 +57,8 @@ enum class FwdNamedBarriers { WarpSchedulerWG3 = 6, AppendKV = 7, QueryRotated = 8, + PFull = 9, + PEmpty = 6, // HACK: PEmpty is only used when we don't have 3 WGs }; enum class BwdNamedBarriers { From 4c8819d8c68e8021cb82cf5b2df38c5eb5340531 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 7 Feb 2025 13:55:44 -0500 Subject: [PATCH 015/102] Rename Mma0->MmaQK, Mma1->MmaPV, use Cluster only if hdimV >= 192 --- hopper/flash_fwd_kernel_sm90.h | 16 +-- hopper/flash_fwd_launch_template.h | 6 +- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 162 +++++++++++------------ hopper/tile_size.h | 7 +- 4 files changed, 96 insertions(+), 95 deletions(-) diff --git a/hopper/flash_fwd_kernel_sm90.h b/hopper/flash_fwd_kernel_sm90.h index 5e1dceb09..aad099bd3 100644 --- a/hopper/flash_fwd_kernel_sm90.h +++ b/hopper/flash_fwd_kernel_sm90.h @@ -52,7 +52,7 @@ class FlashAttnFwdSm90 { // Mainloop derived types using TileShape_MNK_PV = typename CollectiveMainloop::TileShape_MNK_PV; - using TiledMma1 = typename CollectiveMainloop::TiledMma1; + using TiledMmaPV = typename CollectiveMainloop::TiledMmaPV; using ArchTag = typename CollectiveMainloop::ArchTag; using ClusterShape = typename CollectiveMainloop::ClusterShape; using MainloopArguments = typename CollectiveMainloop::Arguments; @@ -70,8 +70,8 @@ class FlashAttnFwdSm90 { using TileSchedulerParams = typename TileScheduler::Params; static constexpr uint32_t NumLoadWarpGroups = 1; - static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMma1{})) / cutlass::NumThreadsPerWarpGroup; - static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma1{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup); + static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaPV{})) / cutlass::NumThreadsPerWarpGroup; + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaPV{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); @@ -354,7 +354,7 @@ class FlashAttnFwdSm90 { TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); // Initialize matmul objects. - TiledMma1 tiled_mma1; + TiledMmaPV tiled_mma_pv; PipelineState smem_pipe_read; PipelineState smem_pipe_read_new; @@ -370,7 +370,7 @@ class FlashAttnFwdSm90 { work_tile_info.is_valid(params.scheduler); work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info)) { // Attention output (GEMM-II) accumulator. - Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 1>(TileShape_MNK_PV{})); + Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_MNK_PV{})); float softmax_scale_log2 = params.mainloop.softmax_scale_log2; // If there's tanh softcap, the scaling will be done before tanh. auto block_coord = work_tile_info.get_block_coord(params.scheduler); @@ -413,20 +413,20 @@ class FlashAttnFwdSm90 { tile_valid = collective_mainloop.mma( params.mainloop, pipeline_k, pipeline_v, smem_pipe_read, tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage); - } else { // mma1_only might not compile if !LargeHeadDimV + } else { // mma_pv might not compile if !LargeHeadDimV if (warp_group_idx == 1) { tile_valid = collective_mainloop.mma( params.mainloop, pipeline_k, pipeline_v, smem_pipe_read, tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage); } else { - tile_valid = collective_mainloop.mma1_only( + tile_valid = collective_mainloop.mma_pv( params.mainloop, pipeline_v, smem_pipe_read, tOrO, softmax, threadIdx.x - MmaThreadOffset, seqlen_info, block_coord, shared_storage); } } if (tile_valid) { // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); } - collective_epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma1, + collective_epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma_pv, threadIdx.x - MmaThreadOffset, block_coord); } else { // Write 0 to gO and -inf to gLSE. diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index de17b39c9..f8a98a08f 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -39,7 +39,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static constexpr std::tuple kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKV, Varlen && Split, Has_softcap, AppendKV); static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS); static constexpr int kBlockN = Arch >= 90 ? std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS); - static constexpr bool Mma1_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap); + static constexpr bool MmaPV_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap); static constexpr bool IntraWGOverlap = std::get<3>(kBlockMN_RS_IntraWGOverlap); static constexpr int kNWarps = std::get<2>(kBlockMN_kNWarps_Stages_RS); static constexpr int kStages = Arch >= 90 ? 2 : std::get<3>(kBlockMN_kNWarps_Stages_RS); @@ -50,7 +50,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using ClusterShape = cute::Shape, _1, _1>; using CollectiveMainloop = std::conditional_t< Arch >= 90, - flash::CollectiveMainloopFwdSm90, + flash::CollectiveMainloopFwdSm90, flash::CollectiveMainloopFwdSm80 >; using CollectiveEpilogue = flash::CollectiveEpilogueFwd; @@ -194,7 +194,7 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKV, Has_softcap)) : 128; // On nvcc 12.8, hdim 128, without cluster is faster (730 vs 700 TFLOPS) - static constexpr bool Enable_cluster = Arch >= 90 && (sizeof(T) == 2 ? (kHeadDim >= 192) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen; + static constexpr bool Enable_cluster = Arch >= 90 && (sizeof(T) == 2 ? (kHeadDimV >= 192) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen; APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { // Only use Cluster if number of tiles along seqlen_q is even and not varlen CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 67f645e60..c4911c359 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -69,20 +69,20 @@ struct CollectiveMainloopFwdSm90 { static_assert(!LargeHeadDimV || kHeadDimV % 256 == 0); static_assert(!LargeHeadDimV || kBlockM <= 64, "kBlockM must be 64 or less for large Headdim_V"); - static_assert(!LargeHeadDimV || !Mma1_is_RS, "Mma1 must be SS for large Headdim_V"); + static_assert(!LargeHeadDimV || !MmaPV_is_RS, "MmaPV must be SS for large Headdim_V"); // Register bandwidth is actually a bottleneck so we don't want Q to be in registers. // Leaving this option here for reference. - static constexpr bool Mma0_is_RS = false; - // We can have Mma1 (P @ V) with P in smem in rmem to reduce register pressure at the cost of more smem. - static_assert(!(!Mma1_is_RS && !IntraWGOverlap), "Mma1 must be RS if IntraWGOverlap is disabled"); - static_assert(!(!Mma1_is_RS && Is_FP8), "Mma1 must be RS if FP8"); - static_assert(!(!Mma1_is_RS && Transpose_V), "Mma1 must be RS if Transpose_V"); + static constexpr bool MmaQK_is_RS = false; + // We can have MmaPV with P in smem in rmem to reduce register pressure at the cost of more smem. + static_assert(!(!MmaPV_is_RS && !IntraWGOverlap), "MmaPV must be RS if IntraWGOverlap is disabled"); + static_assert(!(!MmaPV_is_RS && Is_FP8), "MmaPV must be RS if FP8"); + static_assert(!(!MmaPV_is_RS && Transpose_V), "MmaPV must be RS if Transpose_V"); using AtomLayoutQK = Layout, _1, _1>>; - using TiledMma0 = decltype(cute::make_tiled_mma( + using TiledMmaQK = decltype(cute::make_tiled_mma( std::conditional_t< - !Mma0_is_RS, + !MmaQK_is_RS, decltype(cute::GMMA::ss_op_selector()), decltype(cute::GMMA::rs_op_selector()) >{}, @@ -93,9 +93,9 @@ struct CollectiveMainloopFwdSm90 { Layout, _1>> >; using TileShapeAtomPV = Shape, Int, Int>; - using TiledMma1 = decltype(cute::make_tiled_mma( + using TiledMmaPV = decltype(cute::make_tiled_mma( std::conditional_t< - !Mma1_is_RS, + !MmaPV_is_RS, decltype(cute::GMMA::ss_op_selector()), decltype(cute::GMMA::rs_op_selector{}, AtomLayoutPV{})); - static constexpr int NumMmaThreadsMma0 = size(TiledMma0{}); - static constexpr int NumMmaThreads = size(TiledMma1{}); + static constexpr int NumMmaThreadsQK = size(TiledMmaQK{}); + static constexpr int NumMmaThreads = size(TiledMmaPV{}); static constexpr int NumProducerThreads = !Transpose_V && Use_TMA_KV && Use_TMA_Q ? cutlass::NumThreadsPerWarp : cutlass::NumThreadsPerWarpGroup; - static_assert(NumMmaThreadsMma0 % cutlass::NumThreadsPerWarpGroup == 0); + static_assert(NumMmaThreadsQK % cutlass::NumThreadsPerWarpGroup == 0); static_assert(NumMmaThreads % cutlass::NumThreadsPerWarpGroup == 0); static constexpr int NumMmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup; static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); @@ -259,14 +259,14 @@ struct CollectiveMainloopFwdSm90 { // If PackGQA, we use cp.async (instead of TMA) to load Q, so we want smem_q to be aligned // and have sQ being position_independent_swizzle_tensor. // If !Use_TMA_KV, we use cp.async (instead of TMA) to load K & V, so we want smem_k and smem_v to be aligned. - static constexpr size_t SmemAlignmentQ = Use_TMA_Q && !AppendKV && !Mma0_is_RS ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQ{}); + static constexpr size_t SmemAlignmentQ = Use_TMA_Q && !AppendKV && !MmaQK_is_RS ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQ{}); static constexpr size_t SmemAlignmentK = Use_TMA_KV && !AppendKV ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutK{}); static constexpr size_t SmemAlignmentVtNoTranspose = cutlass::detail::alignment_for_swizzle(SmemLayoutVt{}); static_assert(SmemAlignmentQ >= 128 and SmemAlignmentK >= 128 && SmemAlignmentVtNoTranspose >= 128, "Require at least 128B alignment"); static constexpr size_t SmemAlignmentP = cutlass::detail::alignment_for_swizzle(SmemLayoutP{}); static_assert(SmemAlignmentP >= 128, "Require at least 128B alignment"); - using SmemP_t = std::conditional_t, cute::array_aligned, SmemAlignmentP>>; + using SmemP_t = std::conditional_t, cute::array_aligned, SmemAlignmentP>>; using SmemScale_t = std::conditional_t, cute::array_aligned, 128>>; // Sometimes even with SmemP_t = cute::array, putting it in the TensorStorage struct causes // smem size to go from 227KB to 228KB and we get "invalid argument". @@ -292,7 +292,7 @@ struct CollectiveMainloopFwdSm90 { }; using TensorStorageNoTranspose = std::conditional_t< - Mma1_is_RS, + MmaPV_is_RS, TensorStorageWithoutPNoTranspose, std::conditional_t >; @@ -729,7 +729,7 @@ struct CollectiveMainloopFwdSm90 { if constexpr (Use_TMA_Q) { // Wait for the MMA warpgroups to signal that smem_q is ready if (SingleProducerWarp || warp_idx_in_warpgroup == 0) { - cutlass::arch::NamedBarrier::sync(NumMmaThreadsMma0 + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); } if ((SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync()) { @@ -738,7 +738,7 @@ struct CollectiveMainloopFwdSm90 { tQgQ, tQsQ); } } else { // Load Q with cp.async - cutlass::arch::NamedBarrier::sync(NumMmaThreadsMma0 + NumProducerThreads, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + NumProducerThreads, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q + seqlen_info.offset_q * get<0>(params.stride_Q)), params.shape_Q_packed, params.stride_Q_packed)(_, _, bidh, !is_varlen_q ? bidb : 0); Tensor sQ_pi = cute::as_position_independent_swizzle_tensor(sQ); using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumProducerThreads, Element>; @@ -863,7 +863,7 @@ struct CollectiveMainloopFwdSm90 { int warp_group_idx = flash::canonical_warp_group_idx_nosync(); // Tell producers that smem_q is ready if (!LargeHeadDimV || warp_group_idx == 1) { - cutlass::arch::NamedBarrier::arrive(NumMmaThreadsMma0 + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); } if (LargeHeadDimV && warp_group_idx > 1) { cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); @@ -912,8 +912,8 @@ struct CollectiveMainloopFwdSm90 { Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{}); Tensor sP = [&] { - if constexpr (Mma1_is_RS) { - // We might not have smem_p if !Mma1_is_RS1, just use smem_q as a placeholder since we don't use it + if constexpr (MmaPV_is_RS) { + // We might not have smem_p if !MmaPV_is_RS, just use smem_q as a placeholder since we don't use it return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutP{}); } else { return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutP{}); @@ -927,36 +927,36 @@ struct CollectiveMainloopFwdSm90 { } }(); - if constexpr (!Mma0_is_RS) { - static_assert(stride<0>(typename TiledMma0::ALayout{}) == 0 and - stride<0>(typename TiledMma0::BLayout{}) == 0 and - size<0>(typename TiledMma0::ALayout{}) == cutlass::NumThreadsPerWarpGroup and - size<0>(typename TiledMma0::BLayout{}) == cutlass::NumThreadsPerWarpGroup, + if constexpr (!MmaQK_is_RS) { + static_assert(stride<0>(typename TiledMmaQK::ALayout{}) == 0 and + stride<0>(typename TiledMmaQK::BLayout{}) == 0 and + size<0>(typename TiledMmaQK::ALayout{}) == cutlass::NumThreadsPerWarpGroup and + size<0>(typename TiledMmaQK::BLayout{}) == cutlass::NumThreadsPerWarpGroup, "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); } - static constexpr int MmaWarpGroups = size(TiledMma1{}) / cutlass::NumThreadsPerWarpGroup; + static constexpr int MmaWarpGroups = size(TiledMmaPV{}) / cutlass::NumThreadsPerWarpGroup; Layout warp_group_thread_layout = make_layout(make_shape(Int{}), make_stride(Int{})); int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); - TiledMma0 tiled_mma0; - TiledMma1 tiled_mma1; - auto wg_mma0 = tiled_mma0.get_slice(warp_group_thread_layout(warp_group_idx)); - auto wg_mma1 = tiled_mma1.get_slice(warp_group_thread_layout(warp_group_idx)); + TiledMmaQK tiled_mma_qk; + TiledMmaPV tiled_mma_pv; + auto wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx)); + auto wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)); - auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtomP{}, tiled_mma0); + auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtomP{}, tiled_mma_qk); auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx); // Allocate "fragments/descriptors" - Tensor tSrQ = wg_mma0.partition_fragment_A(sQ); - Tensor tSrK = wg_mma0.partition_fragment_B(sK); - Tensor tOrV = wg_mma1.partition_fragment_B(sV); - Tensor tOsP = wg_mma1.partition_fragment_A(sP); + Tensor tSrQ = wg_mma_qk.partition_fragment_A(sQ); + Tensor tSrK = wg_mma_qk.partition_fragment_B(sK); + Tensor tOrV = wg_mma_pv.partition_fragment_B(sV); + Tensor tOsP = wg_mma_pv.partition_fragment_A(sP); Tensor tPsP = smem_thr_copy_P.partition_D(cute::as_position_independent_swizzle_tensor(sP)); // For storing scales to smem, only used when LargeHeadDimV - auto thread_mma1 = tiled_mma1.get_thread_slice(thread_idx); - Tensor taccOcO = thread_mma1.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); + auto thread_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx); + Tensor taccOcO = thread_mma_pv.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); auto store_scales = [&](auto& scales, int stage) { @@ -976,13 +976,13 @@ struct CollectiveMainloopFwdSm90 { // Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter clear(tOrO); - // tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero; + // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero; int const seqlen_q = seqlen_info.seqlen_q; int const seqlen_k = seqlen_info.seqlen_k; int n_block = n_block_max - 1; - flash::Mask mask( + flash::Mask mask( thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, params.sink_token_length, params.qhead_per_khead_divmod ); @@ -1005,7 +1005,7 @@ struct CollectiveMainloopFwdSm90 { } else { if (get<1>(params.shape_rotary) > 0) { // Apply rotary to Q int const offset_rotary = seqlen_info.seqlen_k_og + seqlen_info.leftpad_k; - using Rotary_t = Rotary; + using Rotary_t = Rotary; Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, params.ptr_rotary_sin, params.stride_rotary_sin, params.is_rotary_interleaved, thread_idx, seqlen_q, offset_rotary); @@ -1028,15 +1028,15 @@ struct CollectiveMainloopFwdSm90 { } // SMEM fence to make sure the rotated Q is visible to GMMA cutlass::arch::fence_view_async_shared(); - cutlass::arch::NamedBarrier::sync(NumMmaThreadsMma0, static_cast(FwdNamedBarriers::QueryRotated) /*id*/); + cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK, static_cast(FwdNamedBarriers::QueryRotated) /*id*/); } else { barrier_Q.wait(work_idx % 2); } } - if constexpr (Mma0_is_RS) { + if constexpr (MmaQK_is_RS) { using SmemCopyAtomQ = Copy_Atom; - auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtomQ{}, tiled_mma0); + auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtomQ{}, tiled_mma_qk); auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(thread_idx); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); Tensor tSsQ_copy_view = smem_thr_copy_Q.partition_S(cute::as_position_independent_swizzle_tensor(sQ)); @@ -1045,9 +1045,9 @@ struct CollectiveMainloopFwdSm90 { // TODO: check the case where n_block_max <= n_block_min but there are sink tokens if constexpr (IntraWGOverlap) { - Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); consumer_wait(pipeline_k, smem_pipe_read); - flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); + flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); warpgroup_wait<0>(); pipeline_k.consumer_release(smem_pipe_read); scoremod_premask_fn(tSrS); @@ -1058,17 +1058,17 @@ struct CollectiveMainloopFwdSm90 { softmax.template online_softmax(tSrS); if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } - Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); + Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); Tensor tOrP = make_tensor_like(tOrP_acc); convert_type_out(tOrP_acc, tOrP); if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } - if constexpr (!Mma1_is_RS) { + if constexpr (!MmaPV_is_RS) { if constexpr (LargeHeadDimV) { cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); } cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); cutlass::arch::fence_view_async_shared(); - __syncwarp(); // Only need syncwarp since each warp is using its own P values for Mma1 + __syncwarp(); // Only need syncwarp since each warp is using its own P values for MmaPV if constexpr (LargeHeadDimV) { cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); } @@ -1080,13 +1080,13 @@ struct CollectiveMainloopFwdSm90 { static constexpr bool Check_inf = decltype(check_inf_type)::value; PipelineState smem_pipe_read_v(smem_pipe_read.index(), smem_pipe_read.phase(), smem_pipe_read.count()); ++smem_pipe_read; - Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_k, smem_pipe_read); } warp_scheduler_barrier_sync(); - flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); + flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_v, smem_pipe_read_v); } - flash::gemm(tiled_mma1, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); warp_scheduler_barrier_arrive(); warpgroup_wait<1>(); pipeline_k.consumer_release(smem_pipe_read); // release K @@ -1103,9 +1103,9 @@ struct CollectiveMainloopFwdSm90 { if constexpr (LargeHeadDimV) { cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); } - if constexpr (!Mma1_is_RS) { cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); } + if constexpr (!MmaPV_is_RS) { cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); } if constexpr (!RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } - if constexpr (!Mma1_is_RS) { + if constexpr (!MmaPV_is_RS) { cutlass::arch::fence_view_async_shared(); __syncwarp(); if constexpr (LargeHeadDimV) { @@ -1150,10 +1150,10 @@ struct CollectiveMainloopFwdSm90 { // } } // Tell producers that smem_q is ready - cutlass::arch::NamedBarrier::arrive(NumMmaThreadsMma0 + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } consumer_wait(pipeline_v, smem_pipe_read); - flash::gemm(tiled_mma1, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; cute::copy(softmax.finalize(v_descale), scores_scale); if constexpr (LargeHeadDimV) { @@ -1174,9 +1174,9 @@ struct CollectiveMainloopFwdSm90 { auto fwd_step = [&](int const n_block, auto mask_fn, auto is_first_iter_type, auto check_inf_type) { static constexpr bool Is_first_iter = decltype(is_first_iter_type)::value; static constexpr bool Check_inf = decltype(check_inf_type)::value; - Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); consumer_wait(pipeline_k, smem_pipe_read); - flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); + flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); warp_scheduler_barrier_arrive(); warpgroup_wait<0>(); pipeline_k.consumer_release(smem_pipe_read); // release K @@ -1185,14 +1185,14 @@ struct CollectiveMainloopFwdSm90 { Tensor scores_scale = softmax.template max_get_scale(tSrS); softmax.template online_softmax(tSrS); if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } - Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); + Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); Tensor tOrP = make_tensor_like(tOrP_acc); convert_type_out(tOrP_acc, tOrP); if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } if constexpr (!Is_first_iter) { softmax.rescale_o(tOrO, scores_scale); } consumer_wait(pipeline_v, smem_pipe_read); warp_scheduler_barrier_sync(); - flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + flash::gemm(tiled_mma_pv, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); pipeline_v.consumer_release(smem_pipe_read); // release V ++smem_pipe_read; }; @@ -1236,7 +1236,7 @@ struct CollectiveMainloopFwdSm90 { } warp_scheduler_barrier_arrive(); // Tell producers that smem_q is ready - cutlass::arch::NamedBarrier::arrive(NumMmaThreadsMma0 + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; Tensor scores_scale = softmax.finalize(v_descale); softmax.rescale_o(tOrO, scores_scale); @@ -1248,16 +1248,16 @@ struct CollectiveMainloopFwdSm90 { template CUTLASS_DEVICE bool - mma1_only(Params const& params, - MainloopPipelineV pipeline_v, - PipelineState& smem_pipe_read, - FrgTensorO& tOrO, - Softmax& softmax, - int const thread_idx, - SeqlenInfo_t const& seqlen_info, - cute::tuple block_coord, - SharedStorage& shared_storage - ) { + mma_pv(Params const& params, + MainloopPipelineV pipeline_v, + PipelineState& smem_pipe_read, + FrgTensorO& tOrO, + Softmax& softmax, + int const thread_idx, + SeqlenInfo_t const& seqlen_info, + cute::tuple block_coord, + SharedStorage& shared_storage + ) { static_assert(is_rmem::value, "O tensor must be rmem resident."); // can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda int const m_block = get<0>(block_coord); @@ -1272,21 +1272,21 @@ struct CollectiveMainloopFwdSm90 { Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{}); Tensor sP = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutP{}); Tensor sScale = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_scale.data()), SmemLayoutScale{}); - static constexpr int MmaWarpGroups = size(TiledMma1{}) / cutlass::NumThreadsPerWarpGroup; + static constexpr int MmaWarpGroups = size(TiledMmaPV{}) / cutlass::NumThreadsPerWarpGroup; Layout warp_group_thread_layout = make_layout(make_shape(Int{}), make_stride(Int{})); int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); - TiledMma1 tiled_mma1; - auto wg_mma1 = tiled_mma1.get_slice(warp_group_thread_layout(warp_group_idx)); + TiledMmaPV tiled_mma_pv; + auto wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)); // Allocate "fragments/descriptors" - Tensor tOrV = wg_mma1.partition_fragment_B(sV); - Tensor tOsP = wg_mma1.partition_fragment_A(sP); + Tensor tOrV = wg_mma_pv.partition_fragment_B(sV); + Tensor tOsP = wg_mma_pv.partition_fragment_A(sP); // For load scales to smem, pretend thread_idx is thread_idx % 128 - auto thread_mma1 = tiled_mma1.get_thread_slice(thread_idx % cutlass::NumThreadsPerWarpGroup); - Tensor taccOcO = thread_mma1.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); + auto thread_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx % cutlass::NumThreadsPerWarpGroup); + Tensor taccOcO = thread_mma_pv.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); auto load_scales = [&](auto& scales, int stage) { @@ -1298,14 +1298,14 @@ struct CollectiveMainloopFwdSm90 { }; clear(tOrO); - // tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero; + // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero; typename Softmax::TensorT scores_scale; int n_block = n_block_max - 1; pipeline_v.consumer_wait(smem_pipe_read); cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); - flash::gemm(tiled_mma1, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); pipeline_v.consumer_release(smem_pipe_read); // release V --n_block; @@ -1317,7 +1317,7 @@ struct CollectiveMainloopFwdSm90 { ++smem_pipe_read; auto barrier_token = pipeline_v.consumer_try_wait(smem_pipe_read); pipeline_v.consumer_wait(smem_pipe_read, barrier_token); - flash::gemm(tiled_mma1, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); pipeline_v.consumer_release(smem_pipe_read); // release V }; diff --git a/hopper/tile_size.h b/hopper/tile_size.h index 66ab1a7fd..997664bcb 100644 --- a/hopper/tile_size.h +++ b/hopper/tile_size.h @@ -6,13 +6,14 @@ #include -// Return {kBlockM, kBlockN, Mma1_is_RS, IntraWGOverlap} +// Return {kBlockM, kBlockN, MmaPV_is_RS, IntraWGOverlap} constexpr std::tuple tile_size_fwd_sm90( int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2, bool v_colmajor=false, bool paged_kv=false, bool softcap=false) { if (element_size == 2) { if (headdim <= 64) { - return {192, 128, true, true}; + bool same_hdim = (headdim == headdim_v); // if not same hdim, we're targeting hdimv=512 + return {same_hdim ? 192 : 64, same_hdim ? 128 : 64, same_hdim, true}; // Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen // return {192, is_causal || is_local ? 192 : 176, true, false}; } else if (headdim <= 96) { @@ -20,7 +21,7 @@ constexpr std::tuple tile_size_fwd_sm90( } else if (headdim <= 128) { return {128, is_causal || is_local || paged_kv ? 128 : 176, true, true}; // {128, 192, false, false} and {192, 128, false, true} are quite good too - // 128 x 192 hits the limit of smem if Mma1_is_RS, 128 x 144 hits the limit if !Mma1_is_RS + // 128 x 192 hits the limit of smem if MmaPV_is_RS, 128 x 144 hits the limit if !MmaPV_is_RS } else if (headdim <= 192) { return {128, paged_kv || is_local ? 96 : (headdim_v <= 128 ? 128 : 112), true, true}; // 128 x 112 hits the limit of smem } else { From dd876913f435b3349cb15a2d83b7b5af366f0acd Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 7 Feb 2025 21:52:35 -0500 Subject: [PATCH 016/102] Pass _1 or _0 to cute::aligned_struct --- hopper/flash_fwd_kernel_sm90.h | 5 ++--- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 17 +++++++---------- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/hopper/flash_fwd_kernel_sm90.h b/hopper/flash_fwd_kernel_sm90.h index aad099bd3..b6ab92e0b 100644 --- a/hopper/flash_fwd_kernel_sm90.h +++ b/hopper/flash_fwd_kernel_sm90.h @@ -90,7 +90,7 @@ class FlashAttnFwdSm90 { static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage)) - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v))); static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_; struct SharedStorage { - struct TensorStorage : cute::aligned_struct<128> { + struct TensorStorage : cute::aligned_struct<128, _1> { union { struct { cute::array padding_; @@ -100,8 +100,7 @@ class FlashAttnFwdSm90 { typename CollectiveEpilogue::TensorStorage epilogue; }; } tensors; - - struct PipelineStorage : cute::aligned_struct<16> { + struct PipelineStorage : cute::aligned_struct<16, _1> { alignas(16) BarrierQ barrier_Q; alignas(16) cutlass::arch::ClusterBarrier barrier_O; alignas(16) typename CollectiveMainloop::MainloopPipelineK::SharedStorage pipeline_k; diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index c4911c359..797c88d64 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -92,14 +92,13 @@ struct CollectiveMainloopFwdSm90 { AtomLayoutQK, Layout, _1>> >; - using TileShapeAtomPV = Shape, Int, Int>; using TiledMmaPV = decltype(cute::make_tiled_mma( std::conditional_t< !MmaPV_is_RS, decltype(cute::GMMA::ss_op_selector()), + TileShape_MNK_PV, GMMA::Major::K, MmaMajorV>()), decltype(cute::GMMA::rs_op_selector()) + TileShape_MNK_PV, GMMA::Major::K, MmaMajorV>()) >{}, AtomLayoutPV{})); @@ -259,7 +258,7 @@ struct CollectiveMainloopFwdSm90 { // If PackGQA, we use cp.async (instead of TMA) to load Q, so we want smem_q to be aligned // and have sQ being position_independent_swizzle_tensor. // If !Use_TMA_KV, we use cp.async (instead of TMA) to load K & V, so we want smem_k and smem_v to be aligned. - static constexpr size_t SmemAlignmentQ = Use_TMA_Q && !AppendKV && !MmaQK_is_RS ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQ{}); + static constexpr size_t SmemAlignmentQ = Use_TMA_Q && !MmaQK_is_RS ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQ{}); static constexpr size_t SmemAlignmentK = Use_TMA_KV && !AppendKV ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutK{}); static constexpr size_t SmemAlignmentVtNoTranspose = cutlass::detail::alignment_for_swizzle(SmemLayoutVt{}); static_assert(SmemAlignmentQ >= 128 and SmemAlignmentK >= 128 && SmemAlignmentVtNoTranspose >= 128, "Require at least 128B alignment"); @@ -271,19 +270,19 @@ struct CollectiveMainloopFwdSm90 { // Sometimes even with SmemP_t = cute::array, putting it in the TensorStorage struct causes // smem size to go from 227KB to 228KB and we get "invalid argument". - struct TensorStorageWithoutPNoTranspose : cute::aligned_struct { + struct TensorStorageWithoutPNoTranspose : cute::aligned_struct { cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; cute::array_aligned, SmemAlignmentQ> smem_q; cute::array_aligned, SmemAlignmentK> smem_k; }; - struct TensorStorageWithPNoTranspose : cute::aligned_struct { + struct TensorStorageWithPNoTranspose : cute::aligned_struct { cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; cute::array_aligned, SmemAlignmentQ> smem_q; cute::array_aligned, SmemAlignmentK> smem_k; SmemP_t smem_p; }; - struct TensorStorageWithPScaleNoTranspose : cute::aligned_struct { + struct TensorStorageWithPScaleNoTranspose : cute::aligned_struct { cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; cute::array_aligned, SmemAlignmentQ> smem_q; cute::array_aligned, SmemAlignmentK> smem_k; @@ -300,7 +299,7 @@ struct CollectiveMainloopFwdSm90 { static constexpr size_t SmemAlignmentVt = cutlass::detail::alignment_for_swizzle(SmemLayoutVt{}); static constexpr size_t SmemAlignmentV = cutlass::detail::alignment_for_swizzle(SmemLayoutVtMma{}); static_assert(SmemAlignmentVt >= 128 and SmemAlignmentV >= 128, "Require at least 128B alignment"); - struct TensorStorageTransposeV : cute::aligned_struct { + struct TensorStorageTransposeV : cute::aligned_struct { cute::array_aligned, SmemAlignmentV> smem_v; cute::array_aligned, SmemAlignmentVt> smem_vt; cute::array_aligned, SmemAlignmentQ> smem_q; @@ -1324,8 +1323,6 @@ struct CollectiveMainloopFwdSm90 { cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); load_scales(scores_scale, smem_pipe_read.index()); cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); - // if (thread_idx == 128) { print_tensor(scores_scale); } - // if (thread_idx == 128) { print_tensor(sScale); } softmax.rescale_o(tOrO, scores_scale); if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); } ++smem_pipe_read; From ed53b5fc4c3b01a6d98d747f21380e444056e042 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 7 Feb 2025 22:25:17 -0500 Subject: [PATCH 017/102] Fix compilation for FP8 when kHeadDimV != kHeadDim --- hopper/flash_api.cpp | 4 ++++ hopper/flash_fwd_kernel_sm90.h | 22 +++++++++++----------- hopper/flash_fwd_launch_template.h | 2 +- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 94fcf5d78..7fd8dfc3e 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -592,6 +592,10 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq if (head_size_v != head_size) { TORCH_CHECK(head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128, "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128]"); TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim"); + if (head_size_v > 256) { + TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, + "HeaddimV > 256 requires fp16 and bf16 data type"); + } } // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM diff --git a/hopper/flash_fwd_kernel_sm90.h b/hopper/flash_fwd_kernel_sm90.h index b6ab92e0b..aeb81977c 100644 --- a/hopper/flash_fwd_kernel_sm90.h +++ b/hopper/flash_fwd_kernel_sm90.h @@ -223,12 +223,13 @@ class FlashAttnFwdSm90 { pipeline_params_k.producer_arv_count = NumProducerThreads; } - PipelineParamsV pipeline_params_v = pipeline_params_k; + static_assert(is_same_v); + PipelineParamsVt pipeline_params_vt = pipeline_params_k; if constexpr (Use_TMA_KV && !SameHeadDim) { - pipeline_params_v.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV; - if constexpr (LargeHeadDimV) { pipeline_params_v.num_consumers = NumMmaThreads; } + pipeline_params_vt.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV; + if constexpr (LargeHeadDimV) { pipeline_params_vt.num_consumers = NumMmaThreads; } } else { - if constexpr (LargeHeadDimV) { pipeline_params_v.consumer_arv_count = NumMmaThreads; } + if constexpr (LargeHeadDimV) { pipeline_params_vt.consumer_arv_count = NumMmaThreads; } } MainloopPipelineK pipeline_k = [&] { @@ -243,9 +244,9 @@ class FlashAttnFwdSm90 { if constexpr (!Transpose_V) { static_assert(is_same_v); if constexpr (Use_TMA_KV) { - return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v, ClusterShape{}); + return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt, ClusterShape{}); } else { - return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v); + return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt); } } else { PipelineParamsV pipeline_params_v; @@ -257,7 +258,6 @@ class FlashAttnFwdSm90 { return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v); } }(); - static_assert(is_same_v); // If we need to transpose V (e.g. FP8 and V is row-major), we use pipeline_vt for the TMA, then // the producer WG will read from pipeline_vt and write to pipeline_v. // If we don't need to transpose V, we use pipeline_v for the TMA, and pipeline_vt won't be used. @@ -265,11 +265,11 @@ class FlashAttnFwdSm90 { // However, the thread role isn't used in the pipeline implementation. MainloopPipelineVt pipeline_vt = [&] { if constexpr (Use_TMA_KV) { - pipeline_params_v.num_consumers = NumProducerThreads; // TMA_V is only consumed by the producer WG - return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_v, ClusterShape{}); + pipeline_params_vt.num_consumers = NumProducerThreads; // TMA_V is only consumed by the producer WG + return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt, ClusterShape{}); } else { - pipeline_params_v.consumer_arv_count = NumProducerThreads; // TMA_V is only consumed by the producer WG - return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_v); + pipeline_params_vt.consumer_arv_count = NumProducerThreads; // TMA_V is only consumed by the producer WG + return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt); } }(); diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index f8a98a08f..118ccb26b 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -194,7 +194,7 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKV, Has_softcap)) : 128; // On nvcc 12.8, hdim 128, without cluster is faster (730 vs 700 TFLOPS) - static constexpr bool Enable_cluster = Arch >= 90 && (sizeof(T) == 2 ? (kHeadDimV >= 192) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen; + static constexpr bool Enable_cluster = Arch >= 90 && (sizeof(T) == 2 ? (kHeadDim >= 192) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen; APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { // Only use Cluster if number of tiles along seqlen_q is even and not varlen CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { From 4e8496a78179416ea18ae111508dfa4341dc1e37 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 7 Feb 2025 22:55:07 -0500 Subject: [PATCH 018/102] Support Qv --- hopper/flash.h | 5 + hopper/flash_api.cpp | 25 +++++ hopper/flash_attn_interface.py | 29 +++-- hopper/flash_fwd_kernel_sm90.h | 5 + hopper/flash_fwd_launch_template.h | 19 ++-- hopper/mainloop_fwd_sm80.hpp | 2 + hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 133 +++++++++++++++++++++-- hopper/test_flash_attn.py | 59 ++++++++-- hopper/test_util.py | 22 +++- 9 files changed, 260 insertions(+), 39 deletions(-) diff --git a/hopper/flash.h b/hopper/flash.h index 9f8cb1bca..9cce795b7 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -104,6 +104,11 @@ struct Flash_fwd_params : public Qkv_params { index_t knew_head_stride; index_t vnew_head_stride; + void *__restrict__ qv_ptr; + index_t qv_batch_stride; + index_t qv_row_stride; + index_t qv_head_stride; + // The cos and sin matrices for rotary embedding. void * __restrict__ rotary_cos_ptr; void * __restrict__ rotary_sin_ptr; diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 7fd8dfc3e..54ec78bce 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -487,6 +487,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq const at::Tensor &v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table. std::optional &k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new std::optional &v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new + std::optional &q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q std::optional &out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q std::optional &cu_seqlens_q_, // b+1 std::optional &cu_seqlens_k_, // b+1 @@ -765,6 +766,30 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq } } + if (q_v_.has_value()) { + TORCH_CHECK(false, "q_v should be None for now"); + TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); + TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, + "q_v is only supported for fp16 and bf16 data type"); + TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs"); + at::Tensor q_v = q_v_.value(); + TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query"); + CHECK_DEVICE(q_v); + TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension"); + if (!is_varlen_q) { + CHECK_SHAPE(q_v, batch_size, seqlen_q, num_heads, head_size_v); + } else { + CHECK_SHAPE(q_v, total_q, num_heads, head_size_v); + } + params.qv_ptr = q_v.data_ptr(); + // All stride are in elements, not bytes. + params.qv_row_stride = q_v.stride(-3); + params.qv_head_stride = q_v.stride(-2); + if (!is_varlen_q) { + params.qv_batch_stride = q_v.stride(0); + } + } + if (leftpad_k_.has_value()) { auto leftpad_k = leftpad_k_.value(); TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 5f1e4899c..adee1a0ff 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -22,6 +22,7 @@ def _flash_attn_forward( v, k_new, v_new, + qv, out, cu_seqlens_q, cu_seqlens_k, @@ -64,6 +65,7 @@ def _flash_attn_forward( v, k_new, v_new, + qv, out, cu_seqlens_q, cu_seqlens_k, @@ -239,6 +241,7 @@ def forward( v, softmax_scale, causal, + qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), sink_token_length=0, @@ -249,13 +252,14 @@ def forward( sm_margin=0, ): if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) + softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward( out, softmax_lse, *rest = _flash_attn_forward( q, k, v, None, None, # k_new, v_new + qv, # qv None, # out None, None, None, # cu_seqlens_q/k/k_new None, None, # seqused_q/k @@ -311,7 +315,7 @@ def backward(ctx, dout, *args): dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None class FlashAttnVarlenFunc(torch.autograd.Function): @@ -330,6 +334,7 @@ def forward( max_seqlen_k, softmax_scale, causal, + qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), sink_token_length=0, @@ -340,13 +345,14 @@ def forward( sm_margin=0, ): if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) + softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward( out, softmax_lse, *rest = _flash_attn_forward( q, k, v, None, None, # k_new, v_new + qv, # qv None, # out cu_seqlens_q, cu_seqlens_k, @@ -411,7 +417,7 @@ def backward(ctx, dout, *args): dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None def flash_attn_qkvpacked_func( @@ -478,6 +484,7 @@ def flash_attn_func( v, softmax_scale=None, causal=False, + qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), sink_token_length=0, @@ -538,6 +545,7 @@ def flash_attn_func( v, softmax_scale, causal, + qv, q_descale, k_descale, v_descale, window_size, sink_token_length, @@ -561,6 +569,7 @@ def flash_attn_varlen_func( max_seqlen_k, softmax_scale=None, causal=False, + qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), sink_token_length=0, @@ -582,6 +591,7 @@ def flash_attn_varlen_func( max_seqlen_k, softmax_scale, causal, + qv, q_descale, k_descale, v_descale, window_size, sink_token_length, @@ -603,6 +613,7 @@ def flash_attn_with_kvcache( v_cache, k=None, v=None, + qv=None, rotary_cos=None, rotary_sin=None, cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, @@ -673,11 +684,12 @@ def flash_attn_with_kvcache( k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table, or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache) page_block_size must be a multiple of 256. - v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no _table, - or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache) + v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table, + or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache) k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate k with k_cache, starting at the indices specified by cache_seqlens. - v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k. + v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k. + qv [optional]: (batch_size, seqlen, nheads, headdim_v) rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16. rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. @@ -714,7 +726,7 @@ def flash_attn_with_kvcache( assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) + softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) if cache_seqlens is not None and isinstance(cache_seqlens, int): cache_seqlens = torch.full( (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device @@ -726,6 +738,7 @@ def flash_attn_with_kvcache( v_cache, k, v, + qv, None, # out cu_seqlens_q, None, # cu_seqlens_k diff --git a/hopper/flash_fwd_kernel_sm90.h b/hopper/flash_fwd_kernel_sm90.h index aeb81977c..c7fec6df5 100644 --- a/hopper/flash_fwd_kernel_sm90.h +++ b/hopper/flash_fwd_kernel_sm90.h @@ -40,6 +40,7 @@ class FlashAttnFwdSm90 { static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8; static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V; static constexpr bool AppendKV = CollectiveMainloop::AppendKV; + static constexpr bool HasQv = CollectiveMainloop::HasQv; static constexpr bool Use_TMA_Q = CollectiveMainloop::Use_TMA_Q; static constexpr bool Use_TMA_KV = CollectiveMainloop::Use_TMA_KV; static constexpr bool Use_TMA_O = CollectiveEpilogue::Use_TMA_O; @@ -102,6 +103,7 @@ class FlashAttnFwdSm90 { } tensors; struct PipelineStorage : cute::aligned_struct<16, _1> { alignas(16) BarrierQ barrier_Q; + alignas(16) BarrierQ barrier_Qv; alignas(16) cutlass::arch::ClusterBarrier barrier_O; alignas(16) typename CollectiveMainloop::MainloopPipelineK::SharedStorage pipeline_k; alignas(16) typename CollectiveMainloop::MainloopPipelineV::SharedStorage pipeline_v; @@ -206,6 +208,9 @@ class FlashAttnFwdSm90 { if (warp_idx == 0 && lane_predicate) { shared_storage.pipelines.barrier_Q.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/); + if constexpr (HasQv) { + shared_storage.pipelines.barrier_Qv.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/); + } shared_storage.pipelines.barrier_O.init(size(ClusterShape{}) * (Use_TMA_O ? 1 : NumMmaThreads) /*numThreads*/); } diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 118ccb26b..b4f80a04e 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -24,7 +24,7 @@ using namespace cute; template void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time"); @@ -50,7 +50,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using ClusterShape = cute::Shape, _1, _1>; using CollectiveMainloop = std::conditional_t< Arch >= 90, - flash::CollectiveMainloopFwdSm90, + flash::CollectiveMainloopFwdSm90, flash::CollectiveMainloopFwdSm80 >; using CollectiveEpilogue = flash::CollectiveEpilogueFwd; @@ -101,6 +101,8 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { {params.knew_row_stride, _1{}, params.knew_head_stride, !is_varlen_k_new ? params.knew_batch_stride : 0}, // stride_K_new static_cast(params.vnew_ptr), {params.vnew_row_stride, _1{}, params.vnew_head_stride, !is_varlen_k_new ? params.vnew_batch_stride : 0}, // stride_V_new + static_cast(params.qv_ptr), + {params.qv_row_stride, _1{}, params.qv_head_stride, !is_varlen_q ? params.qv_batch_stride : 0}, // stride_Qv static_cast(params.rotary_cos_ptr), {params.seqlen_k, params.rotary_dim / 2}, // shape_rotary, the seqlen shape doesn't matter {params.rotary_dim / 2, _1{}}, // stride_rotary_cos @@ -195,11 +197,14 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { // On nvcc 12.8, hdim 128, without cluster is faster (730 vs 700 TFLOPS) static constexpr bool Enable_cluster = Arch >= 90 && (sizeof(T) == 2 ? (kHeadDim >= 192) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen; - APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { - // Only use Cluster if number of tiles along seqlen_q is even and not varlen - CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { - static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1; - run_flash_fwd(params, stream); + BOOL_SWITCH(params.qv_ptr, HasQV_, [&] { + static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 and false; + APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { + // Only use Cluster if number of tiles along seqlen_q is even and not varlen + CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { + static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1; + run_flash_fwd(params, stream); + }); }); }); }); diff --git a/hopper/mainloop_fwd_sm80.hpp b/hopper/mainloop_fwd_sm80.hpp index 2d2ba06f2..0fb32c7a9 100644 --- a/hopper/mainloop_fwd_sm80.hpp +++ b/hopper/mainloop_fwd_sm80.hpp @@ -185,6 +185,8 @@ struct CollectiveMainloopFwdSm80 { StrideQK const stride_K_new; Element const* const ptr_V_new; StrideV const stride_V_new; + Element const* const ptr_Qv; + StrideQK const stride_Qv; Element const* const ptr_rotary_cos; ShapeRotary const shape_rotary; StrideRotary const stride_rotary_cos; diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 797c88d64..1834f200c 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -28,14 +28,15 @@ namespace flash { using namespace cute; template + bool Is_causal_, bool Is_local_, bool Has_softcap_, bool Varlen_, bool PagedKV_, bool AppendKV_, bool HasQv_, + bool MmaPV_is_RS, bool IntraWGOverlap, bool PackGQA_, bool Split_, bool V_colmajor_> struct CollectiveMainloopFwdSm90 { static constexpr int kStages = Stages; using ClusterShape = ClusterShape_; using TileShape_MNK = TileShape_MNK_; using TileShape_MNK_PV = Shape(TileShape_MNK{})), Int, decltype(get<1>(TileShape_MNK{}))>; + using TileShape_MNK_QV = Shape(TileShape_MNK{})), decltype(get<1>(TileShape_MNK{})), Int>; using Element = Element_; using ElementAccum = ElementAccum_; using ArchTag = ArchTag_; @@ -46,6 +47,7 @@ struct CollectiveMainloopFwdSm90 { static constexpr bool Varlen = Varlen_; static constexpr bool PagedKV = PagedKV_; static constexpr bool AppendKV = AppendKV_; + static constexpr bool HasQv = HasQv_; static constexpr bool PackGQA = PackGQA_; static constexpr bool Split = Split_; static constexpr bool V_colmajor = V_colmajor_; @@ -70,6 +72,7 @@ struct CollectiveMainloopFwdSm90 { static_assert(!LargeHeadDimV || kHeadDimV % 256 == 0); static_assert(!LargeHeadDimV || kBlockM <= 64, "kBlockM must be 64 or less for large Headdim_V"); static_assert(!LargeHeadDimV || !MmaPV_is_RS, "MmaPV must be SS for large Headdim_V"); + static_assert(!(HasQv && !IntraWGOverlap), "HasQv requires IntraWGOverlap"); // Register bandwidth is actually a bottleneck so we don't want Q to be in registers. // Leaving this option here for reference. @@ -101,6 +104,9 @@ struct CollectiveMainloopFwdSm90 { TileShape_MNK_PV, GMMA::Major::K, MmaMajorV>()) >{}, AtomLayoutPV{})); + using TiledMmaQV = decltype(cute::make_tiled_mma( + cute::GMMA::ss_op_selector(), + AtomLayoutQK{})); static constexpr int NumMmaThreadsQK = size(TiledMmaQK{}); static constexpr int NumMmaThreads = size(TiledMmaPV{}); @@ -134,6 +140,16 @@ struct CollectiveMainloopFwdSm90 { make_shape(shape<1>(TileShape_MNK_PV{}), shape<2>(TileShape_MNK_PV{}), Int{}), std::conditional_t, cute::Step<_2, _1, _3>>{})); + using SmemLayoutAtomQv = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_QV{})), decltype(cute::get<2>(TileShape_MNK_QV{}))>()); + using SmemLayoutQv = decltype(tile_to_shape(SmemLayoutAtomQv{}, select<0, 2>(TileShape_MNK_QV{}))); + using SmemLayoutAtomVMmaQV = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_QV{})), decltype(cute::get<2>(TileShape_MNK_QV{}))>()); + using SmemLayoutVMmaQV = decltype(tile_to_shape( + SmemLayoutAtomVMmaQV{}, + make_shape(shape<1>(TileShape_MNK_QV{}), shape<2>(TileShape_MNK_QV{}), Int{}))); + static_assert(CUTE_STATIC_V(size(SmemLayoutVMmaQV{})) == size(SmemLayoutVtMma{})); + // Only used if we're using cp.async to load V using SmemLayoutAtomVCpAsync = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), Int>()); @@ -242,10 +258,19 @@ struct CollectiveMainloopFwdSm90 { select<1, 2>(TileShape_MNK_PV{}), size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + using TMA_Qv_ = decltype(make_tma_copy_A_sm90( + GmemTiledCopyQ{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), + SmemLayoutQv{}, + TileShape_MNK_QV{}, + ClusterShape{})); + using TMA_Qv = std::conditional_t; + // Set the bytes transferred in this TMA transaction (may involve multiple issues) static constexpr uint32_t TmaTransactionBytesQ = static_cast(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesK = static_cast(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesV = static_cast(size(take<0, 2>(SmemLayoutVt{})) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t TmaTransactionBytesQv = static_cast(size(SmemLayoutQv{}) * cutlass::sizeof_bits_v / 8); using PipelineTmaAsync = std::conditional_t, typename cutlass::PipelineTmaAsync>; using MainloopPipelineK = std::conditional_t>; @@ -261,12 +286,14 @@ struct CollectiveMainloopFwdSm90 { static constexpr size_t SmemAlignmentQ = Use_TMA_Q && !MmaQK_is_RS ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQ{}); static constexpr size_t SmemAlignmentK = Use_TMA_KV && !AppendKV ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutK{}); static constexpr size_t SmemAlignmentVtNoTranspose = cutlass::detail::alignment_for_swizzle(SmemLayoutVt{}); + static constexpr size_t SmemAlignmentQv = Use_TMA_Q ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQv{}); static_assert(SmemAlignmentQ >= 128 and SmemAlignmentK >= 128 && SmemAlignmentVtNoTranspose >= 128, "Require at least 128B alignment"); static constexpr size_t SmemAlignmentP = cutlass::detail::alignment_for_swizzle(SmemLayoutP{}); static_assert(SmemAlignmentP >= 128, "Require at least 128B alignment"); using SmemP_t = std::conditional_t, cute::array_aligned, SmemAlignmentP>>; using SmemScale_t = std::conditional_t, cute::array_aligned, 128>>; + using SmemQv_t = std::conditional_t, cute::array_aligned, SmemAlignmentQv>>; // Sometimes even with SmemP_t = cute::array, putting it in the TensorStorage struct causes // smem size to go from 227KB to 228KB and we get "invalid argument". @@ -274,18 +301,21 @@ struct CollectiveMainloopFwdSm90 { cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; cute::array_aligned, SmemAlignmentQ> smem_q; cute::array_aligned, SmemAlignmentK> smem_k; + SmemQv_t smem_qv; }; struct TensorStorageWithPNoTranspose : cute::aligned_struct { cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; cute::array_aligned, SmemAlignmentQ> smem_q; cute::array_aligned, SmemAlignmentK> smem_k; + SmemQv_t smem_qv; SmemP_t smem_p; }; struct TensorStorageWithPScaleNoTranspose : cute::aligned_struct { cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; cute::array_aligned, SmemAlignmentQ> smem_q; cute::array_aligned, SmemAlignmentK> smem_k; + SmemQv_t smem_qv; SmemP_t smem_p; SmemScale_t smem_scale; }; @@ -304,6 +334,7 @@ struct CollectiveMainloopFwdSm90 { cute::array_aligned, SmemAlignmentVt> smem_vt; cute::array_aligned, SmemAlignmentQ> smem_q; cute::array_aligned, SmemAlignmentK> smem_k; + SmemQv_t smem_qv; SmemScale_t smem_scale; }; @@ -332,6 +363,8 @@ struct CollectiveMainloopFwdSm90 { StrideQK const stride_K_new; Element const* const ptr_V_new; StrideV const stride_V_new; + Element const* const ptr_Qv; + StrideQK const stride_Qv; Element const* const ptr_rotary_cos; ShapeRotary const shape_rotary; StrideRotary const stride_rotary_cos; @@ -374,6 +407,10 @@ struct CollectiveMainloopFwdSm90 { StrideQK const stride_K_new; Element const* const ptr_V_new; StrideV const stride_V_new; + Element const* const ptr_Qv; + StrideV const stride_Qv; + ShapeQPacked const shape_Qv_packed; + StrideQPacked const stride_Qv_packed; Element const* const ptr_rotary_cos; ShapeRotary const shape_rotary; StrideRotary const stride_rotary_cos; @@ -390,6 +427,7 @@ struct CollectiveMainloopFwdSm90 { TMA_V tma_load_V; TMA_K tma_load_K_new; TMA_V tma_load_V_new; + TMA_Qv tma_load_Qv; float const softmax_scale_log2; float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; @@ -446,6 +484,20 @@ struct CollectiveMainloopFwdSm90 { take<0, 2>(SmemLayoutVt{}), select<1, 2>(TileShape_MNK_PV{}), size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + auto shape_Qv = make_shape(get<0>(args.shape_Q), args.headdim_v, get<2>(args.shape_Q), get<3>(args.shape_Q)); + Tensor mQv = make_tensor(make_gmem_ptr(args.ptr_Qv), shape_Qv, args.stride_Qv); + TMA_Qv tma_load_Qv = [&] { + if constexpr (HasQv) { + return make_tma_copy_A_sm90( + GmemTiledCopyQ{}, + mQv, + SmemLayoutQv{}, + TileShape_MNK_QV{}, + ClusterShape{}); // no mcast for Qv + } else { + return nullptr; + } + }(); // If PackGQA, reshape Q to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size) int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K)); auto const shape_Q_packed = cute::conditional_return( @@ -456,6 +508,14 @@ struct CollectiveMainloopFwdSm90 { args.stride_Q, make_stride(make_stride(get<2>(args.stride_Q), get<0>(args.stride_Q)), get<1>(args.stride_Q), get<2>(args.stride_Q) * qhead_per_khead, get<3>(args.stride_Q)) ); + auto const shape_Qv_packed = cute::conditional_return( + shape_Qv, + make_shape(make_shape(qhead_per_khead, get<0>(shape_Qv)), get<1>(shape_Qv), get<2>(args.shape_K), get<3>(shape_Qv)) + ); + auto const stride_Qv_packed = cute::conditional_return( + args.stride_Qv, + make_stride(make_stride(get<2>(args.stride_Qv), get<0>(args.stride_Qv)), get<1>(args.stride_Qv), get<2>(args.stride_Qv) * qhead_per_khead, get<3>(args.stride_Qv)) + ); if (get<1>(args.shape_rotary) > 0) { assert(args.ptr_rotary_cos != nullptr && args.ptr_rotary_sin != nullptr); } @@ -468,12 +528,13 @@ struct CollectiveMainloopFwdSm90 { return {args.ptr_Q, args.shape_Q, args.stride_Q, shape_Q_packed, stride_Q_packed, args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.headdim_v, args.stride_V, args.ptr_K_new, args.shape_K_new, args.stride_K_new, args.ptr_V_new, args.stride_V_new, + args.ptr_Qv, args.stride_Qv, shape_Qv_packed, stride_Qv_packed, args.ptr_rotary_cos, args.shape_rotary, args.stride_rotary_cos, args.ptr_rotary_sin, args.stride_rotary_sin, args.is_rotary_interleaved, args.ptr_pagetable, args.shape_pagetable, args.stride_pagetable, cutlass::FastDivmod(int(get<0>(args.shape_K))), cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), - tma_load_Q, tma_load_K, tma_load_V, tma_load_K_new, tma_load_V_new, + tma_load_Q, tma_load_K, tma_load_V, tma_load_K_new, tma_load_V_new, tma_load_Qv, !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), args.ptr_q_descale, args.ptr_k_descale, args.ptr_v_descale, args.stride_q_descale, args.stride_k_descale, args.stride_v_descale, @@ -490,6 +551,9 @@ struct CollectiveMainloopFwdSm90 { static void prefetch_tma_descriptors(Params const& params) { if constexpr (Use_TMA_Q) { cute::prefetch_tma_descriptor(params.tma_load_Q.get_tma_descriptor()); + if constexpr (HasQv) { + cute::prefetch_tma_descriptor(params.tma_load_Qv.get_tma_descriptor()); + } } if constexpr (Use_TMA_KV) { cute::prefetch_tma_descriptor(params.tma_load_K.get_tma_descriptor()); @@ -546,7 +610,11 @@ struct CollectiveMainloopFwdSm90 { int &work_idx ) { - auto [m_block, bidh, bidb, split_idx] = block_coord; + // some of these are captured in lambda so can't use structured binding + int const m_block = get<0>(block_coord); + int const bidh = get<1>(block_coord); + int const bidb = get<2>(block_coord); + int const split_idx = get<3>(block_coord); auto [n_block_min, n_block_max] = get_n_block_min_max(params, seqlen_info, m_block, bidb, split_idx, params.num_splits); // It's possible to have n_block_max <= n_block_min. Loading K can cause illegal memory access. if constexpr (Is_causal || Is_local || Varlen || Split) { @@ -578,6 +646,7 @@ struct CollectiveMainloopFwdSm90 { return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVCpAsync{})); } }(); + Tensor sQv = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_qv.data()), SmemLayoutQv{}); int const thread_idx = threadIdx.x % NumProducerThreads; int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; @@ -610,6 +679,19 @@ struct CollectiveMainloopFwdSm90 { auto block_tma_V = params.tma_load_V.get_slice(cluster_local_block_id.x); Tensor tVgVt_TMA = group_modes<0, 3>(block_tma_V.partition_S(gVt_TMA)); // (TMA, k) Tensor tVsVt_TMA = group_modes<0, 3>(block_tma_V.partition_D(sVt)); // (TMA, PIPE) + auto [tQvgQv, tQvsQv] = [&] { + if constexpr (HasQv) { + auto shape_Qv = make_shape(get<0>(params.shape_Q), params.headdim_v, get<2>(params.shape_Q), get<3>(params.shape_Q)); + Tensor mQv = params.tma_load_Qv.get_tma_tensor(shape_Qv)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor gQv = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQv), select<0, 2>(TileShape_MNK_QV{}), make_coord(m_block, _0{})); // (M, Kv) + auto block_tma_Qv = params.tma_load_Qv.get_slice(_0{}); + Tensor tQvgQv = group_modes<0, 3>(block_tma_Qv.partition_S(gQv)); // (TMA) + Tensor tQvsQv = group_modes<0, 3>(block_tma_Qv.partition_D(sQv)); // (TMA) + return cute::make_tuple(tQvgQv, tQvsQv); + } else { + return cute::make_tuple(nullptr, nullptr); + } + }(); using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumProducerThreads, Element, Transpose_V || !IntraWGOverlap /*KV_Same_Iter*/>; PagedKVManager_t paged_kv_manager( @@ -735,6 +817,11 @@ struct CollectiveMainloopFwdSm90 { shared_storage.pipelines.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); copy(params.tma_load_Q.with(reinterpret_cast(shared_storage.pipelines.barrier_Q), 0 /*mcast_mask*/, !Split ? TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST), tQgQ, tQsQ); + if constexpr (HasQv) { + shared_storage.pipelines.barrier_Qv.arrive_and_expect_tx(TmaTransactionBytesQv); + copy(params.tma_load_Qv.with(reinterpret_cast(shared_storage.pipelines.barrier_Qv), 0 /*mcast_mask*/, !Split ? TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST), + tQvgQv, tQvsQv); + } } } else { // Load Q with cp.async cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + NumProducerThreads, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); @@ -745,6 +832,15 @@ struct CollectiveMainloopFwdSm90 { auto &barrier_Q = shared_storage.pipelines.barrier_Q; cutlass::arch::cpasync_barrier_arrive(reinterpret_cast(&barrier_Q)); barrier_Q.arrive(); + if constexpr (HasQv) { + Tensor mQv = make_tensor(make_gmem_ptr(params.ptr_Qv + seqlen_info.offset_q * get<0>(params.stride_Qv)), params.shape_Qv_packed, params.stride_Qv_packed)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor sQv_pi = cute::as_position_independent_swizzle_tensor(sQv); + using PackGQAt = flash::PackGQAManager(TileShape_MNK_QV{}), get<2>(TileShape_MNK_QV{}), NumProducerThreads, Element>; + PackGQAt::load_Q(mQv, sQv_pi, params.qhead_per_khead_divmod, thread_idx, seqlen_info.seqlen_q, m_block); + auto &barrier_Qv = shared_storage.pipelines.barrier_Qv; + cutlass::arch::cpasync_barrier_arrive(reinterpret_cast(&barrier_Qv)); + barrier_Qv.arrive(); + } } // Wait for the MMA WGs to signal that smem_v are ready and V can be copied from gmem @@ -925,6 +1021,8 @@ struct CollectiveMainloopFwdSm90 { return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutScale{}); } }(); + Tensor sQv = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_qv.data()), SmemLayoutQv{}); + Tensor sVMmaQV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVMmaQV{}); if constexpr (!MmaQK_is_RS) { static_assert(stride<0>(typename TiledMmaQK::ALayout{}) == 0 and @@ -940,8 +1038,10 @@ struct CollectiveMainloopFwdSm90 { int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); TiledMmaQK tiled_mma_qk; TiledMmaPV tiled_mma_pv; + TiledMmaQV tiled_mma_qv; auto wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx)); auto wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)); + auto wg_mma_qv = tiled_mma_qv.get_slice(warp_group_thread_layout(warp_group_idx)); auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtomP{}, tiled_mma_qk); auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx); @@ -951,6 +1051,8 @@ struct CollectiveMainloopFwdSm90 { Tensor tSrK = wg_mma_qk.partition_fragment_B(sK); Tensor tOrV = wg_mma_pv.partition_fragment_B(sV); Tensor tOsP = wg_mma_pv.partition_fragment_A(sP); + Tensor tSrQv = wg_mma_qv.partition_fragment_A(sQv); + Tensor tSrV = wg_mma_qv.partition_fragment_B(sVMmaQV); Tensor tPsP = smem_thr_copy_P.partition_D(cute::as_position_independent_swizzle_tensor(sP)); // For storing scales to smem, only used when LargeHeadDimV @@ -1049,6 +1151,11 @@ struct CollectiveMainloopFwdSm90 { flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); warpgroup_wait<0>(); pipeline_k.consumer_release(smem_pipe_read); + if constexpr (HasQv) { + shared_storage.pipelines.barrier_Qv.wait(work_idx % 2); + consumer_wait(pipeline_v, smem_pipe_read); + flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS); + } scoremod_premask_fn(tSrS); mask.template apply(tSrS, m_block, n_block); @@ -1084,18 +1191,28 @@ struct CollectiveMainloopFwdSm90 { warp_scheduler_barrier_sync(); flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } - if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_v, smem_pipe_read_v); } + if constexpr(!HasQv) { + if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_v, smem_pipe_read_v); } + } flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); warp_scheduler_barrier_arrive(); warpgroup_wait<1>(); pipeline_k.consumer_release(smem_pipe_read); // release K + if constexpr (HasQv) { + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); // release V + consumer_wait(pipeline_v, smem_pipe_read); + flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS); + } scoremod_premask_fn(tSrS); mask_fn(tSrS, n_block); cute::copy(softmax.template max_get_scale(tSrS), scores_scale); if constexpr (LargeHeadDimV) { store_scales(scores_scale, smem_pipe_read_v.index()); } softmax.template online_softmax(tSrS); - warpgroup_wait<0>(); - pipeline_v.consumer_release(smem_pipe_read_v); // release V + if constexpr (!HasQv) { + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); // release V + } if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } convert_type_out(make_tensor(tSrS.data(), tOrP.layout()), tOrP); if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } @@ -1151,7 +1268,7 @@ struct CollectiveMainloopFwdSm90 { // Tell producers that smem_q is ready cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } - consumer_wait(pipeline_v, smem_pipe_read); + if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; cute::copy(softmax.finalize(v_descale), scores_scale); diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index d0590b5f1..6d5d8f8e2 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -50,6 +50,8 @@ # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) @@ -96,7 +98,7 @@ ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) def test_flash_attn_output( - seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, mha_type, dtype + seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv, mha_type, dtype ): # sink_token_length = 0 if not local else 4 sink_token_length = 0 if not local else 0 @@ -121,6 +123,10 @@ def test_flash_attn_output( q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + if has_qv: + qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv_ref = None # Put window_size after QKV randn so that window_size changes from test to test window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) # window_size = (-1, -1) if not local else (16, 0) @@ -129,6 +135,7 @@ def test_flash_attn_output( else: q_descale, k_descale, v_descale = None, None, None q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None if V_colmajor: v = rearrange(rearrange(v.detach(), "b s h d -> b h d s").contiguous(), "b h d s -> b s h d").requires_grad_() out_ref, attn_ref = attention_ref( @@ -138,6 +145,7 @@ def test_flash_attn_output( None, None, causal=causal, + qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, sink_token_length=sink_token_length, @@ -150,6 +158,7 @@ def test_flash_attn_output( None, None, causal=causal, + qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, sink_token_length=sink_token_length, @@ -160,6 +169,8 @@ def test_flash_attn_output( ) # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float() + # if qv is not None: + # qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float() # m = qk.amax(-1, keepdim=True) # s_tmp = torch.exp((qk - m) / math.sqrt(d)) # exp_sum = s_tmp.sum(-1) @@ -180,6 +191,7 @@ def test_flash_attn_output( k, v, causal=causal, + qv=qv, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, sink_token_length=sink_token_length, @@ -197,7 +209,7 @@ def test_flash_attn_output( # of a Pytorch implementation. assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor: + if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor and not has_qv: g = torch.randn_like(out) do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) # import flash_attn_3_cuda @@ -249,7 +261,7 @@ def test_flash_attn_output( # breakpoint() - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor: + if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor and not has_qv: dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) @@ -264,6 +276,8 @@ def test_flash_attn_output( # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) @@ -308,7 +322,7 @@ def test_flash_attn_output( ], ) def test_flash_attn_varlen_output( - seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, mha_type, dtype + seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype ): device = "cuda" # set seed @@ -329,6 +343,10 @@ def test_flash_attn_varlen_output( q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + if has_qv: + qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv_ref = None # Put window_size after QKV randn so that window_size changes from test to test window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) if dtype == torch.float8_e4m3fn: @@ -336,6 +354,7 @@ def test_flash_attn_varlen_output( else: q_descale, k_descale, v_descale = None, None, None q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach() if has_qv else None query_padding_mask = generate_random_padding_mask( seqlen_q, batch_size, device, mode="random", zero_lengths=False ) @@ -366,6 +385,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): q_unpad, k_unpad, v_unpad, + qv_unpad, cu_seqlens_q, cu_seqlens_k, seqused_q, @@ -375,10 +395,11 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): q, k, v, + qv, output_pad_fn, dq_pad_fn, dk_pad_fn, - ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, qv=qv, kvpacked=False, query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask) q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)] out_ref, attn_ref = attention_ref( @@ -388,6 +409,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): query_padding_mask, key_padding_mask, causal=causal, + qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, softcap=softcap @@ -399,6 +421,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): query_padding_mask, key_padding_mask, causal=causal, + qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, softcap=softcap, @@ -431,6 +454,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): max_seqlen_q, max_seqlen_k, causal=causal, + qv=qv_unpad, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, @@ -450,7 +474,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn: + if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv: g_unpad = torch.randn_like(out_unpad) do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) # import flash_attn_3_cuda @@ -518,7 +542,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") # breakpoint() - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn: + if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv: dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) @@ -554,7 +578,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @pytest.mark.parametrize("has_batch_idx", [False, True]) # @pytest.mark.parametrize("has_batch_idx", [False]) @pytest.mark.parametrize("varlen_q", [False, True]) -# @pytest.mark.parametrize("varlen_q", [True]) +# @pytest.mark.parametrize("varlen_q", [False]) # @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) @@ -572,8 +596,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): (3, 799), (64, 2048), (16, 20000), - (1, 128 * 1024), - (16, 128 * 1024), + # (1, 128 * 1024), + # (16, 128 * 1024), (128, 128), (256, 512), # To test appending KV with more than 1 block (2048, 3577), # Enough tile to test persistent scheduler @@ -617,17 +641,25 @@ def test_flash_attn_kvcache( nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) assert nheads % nheads_k == 0 dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - for dv in [128, d] if d > 128 and d <= 192 else [d]: + dv_vals = [128, d] if d > 128 and d <= 192 else [d] + has_qv_vals = [False] + for dv, has_qv in itertools.product(dv_vals, has_qv_vals): q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + if has_qv: + qv = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv = None if varlen_q: query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(q, query_padding_mask) output_pad_fn = lambda output_unpad: pad_input( output_unpad, indices_q, batch_size, seqlen_q ) + qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if has_qv else None else: query_padding_mask = None q_unpad = q + qv_unpad = qv cu_seqlens_q, max_seqlen_q = None, None # Put window_size after QKV randn so that window_size changes from test to test window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) @@ -755,6 +787,7 @@ def test_flash_attn_kvcache( query_padding_mask, key_padding_mask, causal=causal, + qv=qv, window_size=window_size, key_leftpad=cache_leftpad, ) @@ -765,6 +798,7 @@ def test_flash_attn_kvcache( query_padding_mask, key_padding_mask, causal=causal, + qv=qv, window_size=window_size, upcast=False, reorder_ops=True, @@ -781,6 +815,8 @@ def test_flash_attn_kvcache( v = v.to(dtype) if v is not None else None k_unpad = k_unpad.to(dtype) if k_unpad is not None else None v_unpad = v_unpad.to(dtype) if v_unpad is not None else None + qv = qv.to(dtype) if qv is not None else None + qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None cos = cos.to(dtype) if cos is not None else None sin = sin.to(dtype) if sin is not None else None out, lse, *rest = flash_attn_with_kvcache( @@ -789,6 +825,7 @@ def test_flash_attn_kvcache( v_cache if page_size is None else v_cache_paged, k if not new_kv or not varlen_q else k_unpad, v if not new_kv or not varlen_q else v_unpad, + qv=qv if not varlen_q else qv_unpad, rotary_cos=cos, rotary_sin=sin, cache_seqlens=cache_seqlens, diff --git a/hopper/test_util.py b/hopper/test_util.py index cbf441031..b7ea3d3b7 100644 --- a/hopper/test_util.py +++ b/hopper/test_util.py @@ -30,7 +30,7 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", def generate_qkv( - q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False, + q, k, v, query_padding_mask=None, key_padding_mask=None, qv=None, kvpacked=False, qkvpacked=False, query_unused_mask=None, key_unused_mask=None, ): """ @@ -58,6 +58,7 @@ def generate_qkv( output_pad_fn = lambda output_unpad: pad_input( output_unpad, indices_q, batch_size, seqlen_q ) + qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if qv is not None else None else: q_unpad = rearrange(q, "b s h d -> (b s) h d") cu_seqlens_q = torch.arange( @@ -68,6 +69,7 @@ def generate_qkv( output_pad_fn = lambda output_unpad: rearrange( output_unpad, "(b s) h d -> b s h d", b=batch_size ) + qv_unpad = rearrange(qv, "b s ... -> (b s) ...") if qv is not None else None if key_padding_mask is not None: k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input( @@ -135,6 +137,7 @@ def generate_qkv( q_unpad.detach().requires_grad_(), k_unpad.detach().requires_grad_(), v_unpad.detach().requires_grad_(), + qv_unpad.detach() if qv is not None else None, cu_seqlens_q, cu_seqlens_k, seqused_q, @@ -144,6 +147,7 @@ def generate_qkv( q.detach().requires_grad_(), k.detach().requires_grad_(), v.detach().requires_grad_(), + qv.detach() if qv is not None else None, output_pad_fn, dq_pad_fn, dk_pad_fn, @@ -197,6 +201,7 @@ def attention_ref( dropout_p=0.0, dropout_mask=None, causal=False, + qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), # -1 means infinite window size sink_token_length=0, @@ -210,6 +215,7 @@ def attention_ref( q: (batch_size, seqlen_q, nheads, head_dim) k: (batch_size, seqlen_k, nheads, head_dim) v: (batch_size, seqlen_k, nheads, head_dim_v) + qv: (batch_size, seqlen_q, nheads, head_dim_v) query_padding_mask: (batch_size, seqlen_q) key_padding_mask: (batch_size, seqlen_k) attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) @@ -230,9 +236,11 @@ def attention_ref( dtype_og = q.dtype if upcast: q, k, v = q.float(), k.float(), v.float() + qv = qv.float() if qv is not None else None if q_descale is not None: - q_descale = repeat(q_descale, "b h -> b (h g)", g = q.shape[2] // k.shape[2]) - q = (q.float() * rearrange(q_descale, "b h -> b 1 h 1")).to(dtype=q.dtype) + q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g = q.shape[2] // k.shape[2]).to(dtype=q.dtype) + q = q.float() * q_descale + qv = qv.float() * q_descale if qv is not None else None if k_descale is not None: k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype) if v_descale is not None: @@ -241,10 +249,14 @@ def attention_ref( k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) d = q.shape[-1] + dv = v.shape[-1] + softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv) if not reorder_ops: - scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) + scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k) else: - scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) + scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) + if qv is not None: + scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v) if softcap > 0: scores = torch.tanh(scores / softcap) * softcap if key_padding_mask is not None: From 893a22ab5703ab3d61eda256f3a9a73a66b4444c Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 7 Feb 2025 23:04:20 -0500 Subject: [PATCH 019/102] Test varlen_q=True by default for kvcache --- hopper/test_flash_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 6d5d8f8e2..e9cd8c9d6 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -578,7 +578,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @pytest.mark.parametrize("has_batch_idx", [False, True]) # @pytest.mark.parametrize("has_batch_idx", [False]) @pytest.mark.parametrize("varlen_q", [False, True]) -# @pytest.mark.parametrize("varlen_q", [False]) +# @pytest.mark.parametrize("varlen_q", [True]) # @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) From 5fab938555597b5e6150b16b190415d3420b1c67 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 8 Feb 2025 01:25:09 -0500 Subject: [PATCH 020/102] Fix num_splits heuristic being called before get_pack_gqa --- hopper/flash_api.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 54ec78bce..6820f9341 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -715,8 +715,9 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq params.page_size = page_size; params.num_pages = num_pages; - params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); + // get_num_splits need params.pack_gqa to decide + params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; if (k_new_.has_value()) { at::Tensor k_new, v_new; From 5fc5ebf82b27adc47ffb364a3e0c654fc266a321 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 8 Feb 2025 16:21:44 -0500 Subject: [PATCH 021/102] Fix num_splits heuristic again when PackGQA --- hopper/flash_api.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 6820f9341..402e1a6aa 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -429,7 +429,8 @@ inline int get_num_splits(Flash_fwd_params const& params) { : std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM)); int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN; int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM; - return num_splits_heuristic(params.b * (!params.pack_gqa ? params.h : params.h_k) * num_m_blocks, params.num_sm, num_n_blocks, 128); + // Always enable PackGQA for Split + return num_splits_heuristic(params.b * params.h_k * num_m_blocks, params.num_sm, num_n_blocks, 128); // return num_splits_heuristic(params.b * params.h_k * num_m_blocks, params.b * params.h_k, // params.num_sm, num_n_blocks, 128, params.d_rounded); #endif @@ -715,9 +716,9 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq params.page_size = page_size; params.num_pages = num_pages; - params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); - // get_num_splits need params.pack_gqa to decide params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; + // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide + params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); if (k_new_.has_value()) { at::Tensor k_new, v_new; From 5378bc3204bf9a2d959f1c66fe2f9bf60d582b43 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 8 Feb 2025 16:30:51 -0500 Subject: [PATCH 022/102] Tile fwd_combine kernel along headdim, don't need kBlockM > 128 --- hopper/flash.h | 2 +- hopper/flash_api.cpp | 18 +----- hopper/flash_fwd_combine.cu | 6 -- hopper/flash_fwd_combine_kernel.h | 65 +++++++++++++--------- hopper/flash_fwd_combine_launch_template.h | 23 ++++---- 5 files changed, 55 insertions(+), 59 deletions(-) diff --git a/hopper/flash.h b/hopper/flash.h index 9cce795b7..8e95f5ff7 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -207,5 +207,5 @@ template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); -template +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 402e1a6aa..7dad5b9c7 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -359,32 +359,20 @@ void run_mha_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { if (params.is_fp32) { if (params.dv <= 64) { run_mha_fwd_combine_(params, stream); - } else if (params.dv <= 128) { - run_mha_fwd_combine_(params, stream); - } else if (params.dv <= 256) { - run_mha_fwd_combine_(params, stream); } else { - run_mha_fwd_combine_(params, stream); + run_mha_fwd_combine_(params, stream); } } else if (params.is_bf16) { if (params.dv <= 64) { run_mha_fwd_combine_(params, stream); - } else if (params.dv <= 128) { - run_mha_fwd_combine_(params, stream); - } else if (params.dv <= 256) { - run_mha_fwd_combine_(params, stream); } else { - run_mha_fwd_combine_(params, stream); + run_mha_fwd_combine_(params, stream); } } else { if (params.dv <= 64) { run_mha_fwd_combine_(params, stream); - } else if (params.dv <= 128) { - run_mha_fwd_combine_(params, stream); - } else if (params.dv <= 256) { - run_mha_fwd_combine_(params, stream); } else { - run_mha_fwd_combine_(params, stream); + run_mha_fwd_combine_(params, stream); } } #else diff --git a/hopper/flash_fwd_combine.cu b/hopper/flash_fwd_combine.cu index 57392ee75..a1725cf2a 100644 --- a/hopper/flash_fwd_combine.cu +++ b/hopper/flash_fwd_combine.cu @@ -5,15 +5,9 @@ template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/hopper/flash_fwd_combine_kernel.h b/hopper/flash_fwd_combine_kernel.h index aaec31e58..20685a156 100644 --- a/hopper/flash_fwd_combine_kernel.h +++ b/hopper/flash_fwd_combine_kernel.h @@ -40,11 +40,11 @@ class FlashAttnFwdCombine { static constexpr uint32_t MinBlocksPerMultiprocessor = 2; static constexpr int kBlockM = get<0>(TileShape_MK{}); - static constexpr int kHeadDim = get<1>(TileShape_MK{}); + static constexpr int kBlockK = get<1>(TileShape_MK{}); static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(ElementPartial); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); - static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static_assert(kBlockK % kGmemElemsPerLoad == 0, "kBlockK must be a multiple of kGmemElemsPerLoad"); + static constexpr int kBlockKGmem = kBlockK % 128 == 0 ? 128 : (kBlockK % 64 == 0 ? 64 : 32); static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow"); using GmemCopyAtom = std::conditional_t< @@ -98,8 +98,8 @@ class FlashAttnFwdCombine { Stride, _1>>{})); using SmemLayoutLSE = decltype(tile_to_shape(SmemLayoutAtomLSE{}, Shape, Int>{})); - using SmemLayoutO = Layout, Int, Int>, - Stride, _1, Int>>; + using SmemLayoutO = Layout, Int, Int>, + Stride, _1, Int>>; // We want each column (kMaxSplits) to be processed by threads in the same warp. // To reduce the number of shuffles, we want as few threads on the same column as possible. @@ -194,7 +194,8 @@ class FlashAttnFwdCombine { Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o_partial.data()), SmemLayoutO{}); int const thread_idx = threadIdx.x; - int const m_block = blockIdx.x; + int const k_block = blockIdx.x; + int const m_block = blockIdx.y; int const batch = !Varlen ? 0 : blockIdx.y; int const num_splits = get<1>(params.shape_LSE_partial); flash::SeqlenInfo seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused}; @@ -254,7 +255,8 @@ class FlashAttnFwdCombine { Tensor cO = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) // Repeat the partitioning with identity layouts Tensor tOcO = gmem_thr_copy_O_partial.partition_D(cO); - Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset * get<0>(params.stride_O_partial)), params.shape_O_partial, params.stride_O_partial); // (seqlen, d, num_splits, head, batch) + Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset * get<0>(params.stride_O_partial)), + params.shape_O_partial, params.stride_O_partial); // (seqlen, d, num_splits, head, batch) // Precompute these values to avoid recomputing them in the loop Tensor tOmidx = make_tensor(make_shape(size<1>(tOcO))); @@ -271,7 +273,7 @@ class FlashAttnFwdCombine { tObidh[m] = seqlen_divmod_dynamic.divmod(tOmidx(m), idx); tObidb[m] = 0; } - tOrOptr[m] = &mOpartial(tOmidx(m), _0{}, _0{}, tObidh(m), tObidb(m)); + tOrOptr[m] = &mOpartial(tOmidx(m), k_block * kBlockK, _0{}, tObidh(m), tObidb(m)); if (idx >= max_idx) { tObidb[m] = -1; } @@ -280,7 +282,7 @@ class FlashAttnFwdCombine { Tensor tOpO = make_tensor(make_shape(size<2>(tOcO))); if constexpr (!(Is_even_K)) { #pragma unroll - for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O_partial); } + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O_partial) - k_block * kBlockK; } } Tensor tOsOpartial = gmem_thr_copy_O_partial.partition_D(sO); @@ -358,26 +360,36 @@ class FlashAttnFwdCombine { // Store the scales exp(lse - lse_logsum) back to smem cute::copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE); - // Step 5: store final LSE back to gmem - auto shape_LSE = select<0, 2, 3>(params.shape_LSE_partial); - Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset * get<0>(params.stride_LSE)), shape_LSE, params.stride_LSE); + // Store max_valid_split to smem #pragma unroll for (int m = 0; m < size<2>(ts2rrLSE); ++m) { - if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to gmem + if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to smem int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m))); - int idx = m_block * kBlockM + mi; - if (idx < max_idx) { - int m_idx, bidh, bidb; - if constexpr (!Varlen) { - bidb = params.head_divmod.divmod(bidh, params.seqlen_divmod.divmod(m_idx, idx)); - } else { - bidh = seqlen_divmod_dynamic.divmod(m_idx, idx); - bidb = 0; + if (mi < kBlockM) { sMaxValidSplit[mi] = max_valid_split[m]; } + } + } + + // Step 5: store final LSE back to gmem + if (k_block == 0) { + auto shape_LSE = select<0, 2, 3>(params.shape_LSE_partial); + Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset * get<0>(params.stride_LSE)), shape_LSE, params.stride_LSE); + #pragma unroll + for (int m = 0; m < size<2>(ts2rrLSE); ++m) { + if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to gmem + int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m))); + int idx = m_block * kBlockM + mi; + if (idx < max_idx) { + int m_idx, bidh, bidb; + if constexpr (!Varlen) { + bidb = params.head_divmod.divmod(bidh, params.seqlen_divmod.divmod(m_idx, idx)); + } else { + bidh = seqlen_divmod_dynamic.divmod(m_idx, idx); + bidb = 0; + } + // printf("thread_idx = %d, m = %d, mi = %d, idx = %d, m_idx = %d, bidh = %d, bidb = %d, lse_sum = %f\n", thread_idx, m, mi, idx, m_idx, bidh, bidb, lse_sum(m)); + mLSE(m_idx, bidh, bidb) = lse_sum(m); } - // printf("thread_idx = %d, m = %d, mi = %d, idx = %d, m_idx = %d, bidh = %d, bidb = %d, lse_sum = %f\n", thread_idx, m, mi, idx, m_idx, bidh, bidb, lse_sum(m)); - mLSE(m_idx, bidh, bidb) = lse_sum(m); } - if (mi < kBlockM) { sMaxValidSplit[mi] = max_valid_split[m]; } } } @@ -427,8 +439,9 @@ class FlashAttnFwdCombine { // Step 7: Write the final O to gmem Tensor rO = make_tensor_like(tOrO); flash::convert_type_out(tOrO, rO); - auto shape_O = select<0, 1, 3, 4>(params.shape_O_partial); - Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset * get<0>(params.stride_O)), shape_O, params.stride_O); + auto shape_O = make_shape(get<0>(params.shape_O_partial), get<1>(params.shape_O_partial) - k_block * kBlockK, get<3>(params.shape_O_partial), get<4>(params.shape_O_partial)); + Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset * get<0>(params.stride_O) + k_block * kBlockK * get<1>(params.stride_O)), + shape_O, params.stride_O); Tensor mO_copy = cute::tiled_divide(mO, Shape<_1, Int>{}); GmemTiledCopy gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); diff --git a/hopper/flash_fwd_combine_launch_template.h b/hopper/flash_fwd_combine_launch_template.h index 5cbed2b0c..101f894b2 100644 --- a/hopper/flash_fwd_combine_launch_template.h +++ b/hopper/flash_fwd_combine_launch_template.h @@ -16,9 +16,9 @@ using namespace cute; -template +template void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { - using TileShape_MK = cute::Shape, Int>; + using TileShape_MK = cute::Shape, Int>; using CombineKernel = flash::FlashAttnFwdCombine; @@ -37,8 +37,9 @@ void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { }; typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args); + int num_blocks_k = cute::ceil_div(params.dv, kBlockK); int num_blocks_m = cute::ceil_div(params.seqlen_q * params.h * (!Varlen ? params.b : 1), kBlockM); - dim3 grid_m(num_blocks_m, !Varlen ? 1 : params.b); + dim3 grid_m(num_blocks_k, num_blocks_m, !Varlen ? 1 : params.b); auto kernel = cutlass::device_kernel; int smem_size = CombineKernel::SharedStorageSize; if (smem_size >= 48 * 1024) { @@ -48,27 +49,27 @@ void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { CHECK_CUDA_KERNEL_LAUNCH(); } -template +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream) { // We want kBlockM to be as small as possible to maximize parallelism. // E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats). - static_assert(kHeadDim % 32 == 0, "kHeadDim must be a multiple of 32"); - static constexpr int kBlockM = kHeadDim % 128 == 0 ? 8 : (kHeadDim % 64 == 0 ? 16 : 32); + static_assert(kBlockK % 32 == 0, "kBlockK must be a multiple of 32"); + static constexpr int kBlockM = kBlockK % 128 == 0 ? 8 : (kBlockK % 64 == 0 ? 16 : 32); BOOL_SWITCH(params.seqused_q != nullptr, Varlen, [&] { if constexpr (kBlockM >= 16) { // If kBlockM == 8 then the minimum number of splits is 32. if (params.num_splits <= 16) { - run_flash_fwd_combine(params, stream); + run_flash_fwd_combine(params, stream); return; } } if (params.num_splits <= 32) { - run_flash_fwd_combine(params, stream); + run_flash_fwd_combine(params, stream); } else if (params.num_splits <= 64) { - run_flash_fwd_combine(params, stream); + run_flash_fwd_combine(params, stream); } else if (params.num_splits <= 128) { - run_flash_fwd_combine(params, stream); + run_flash_fwd_combine(params, stream); } else { - run_flash_fwd_combine(params, stream); + run_flash_fwd_combine(params, stream); } }); } From db8ca796092463a38db8faf1089bde4f29def745 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 9 Feb 2025 13:02:28 -0500 Subject: [PATCH 023/102] Use bf16 instead of fp16 in benchmark_gemm.py --- benchmarks/benchmark_gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/benchmark_gemm.py b/benchmarks/benchmark_gemm.py index df0d56b8f..3f3639e0b 100644 --- a/benchmarks/benchmark_gemm.py +++ b/benchmarks/benchmark_gemm.py @@ -26,7 +26,7 @@ def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, **kwinputs torch.manual_seed(0) repeats = 30 -dtype = torch.float16 +dtype = torch.bfloat16 device = 'cuda' verbose = False m, n = 8192, 8192 From 982c480c57c1b9a8e8ec3f70358957c69355f47a Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 9 Feb 2025 15:52:15 -0500 Subject: [PATCH 024/102] Update Cutlass to 3.7 --- csrc/cutlass | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cutlass b/csrc/cutlass index c506e1678..b78588d16 160000 --- a/csrc/cutlass +++ b/csrc/cutlass @@ -1 +1 @@ -Subproject commit c506e16788cb08416a4a57e11a9067beeee29420 +Subproject commit b78588d1630aa6643bf021613717bafb705df4ef From 58ebfa5865516c7fb4ad83783501c802484260bb Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 9 Feb 2025 16:02:41 -0500 Subject: [PATCH 025/102] Use nvcc 12.6 but ptxas 12.8 --- hopper/benchmark_attn.py | 8 ++++---- hopper/setup.py | 23 ++++++++++++++++------- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/hopper/benchmark_attn.py b/hopper/benchmark_attn.py index e61cea9e6..6dc253e00 100644 --- a/hopper/benchmark_attn.py +++ b/hopper/benchmark_attn.py @@ -261,9 +261,9 @@ def run(*args, **kwargs): # for headdim in [64, 96, 128, 192]: # for headdim in [64, 96, 128, 192, 256]: # for headdim in [64, 96, 128]: -# for headdim in [64, 128, 256]: +for headdim in [64, 128, 256]: # for headdim in [64, 96, 128, 192, 256]: -for headdim in [192]: +# for headdim in [128]: nheads = dim // headdim # headdim = 64 # batch_size = 64 @@ -276,7 +276,7 @@ def run(*args, **kwargs): # headdim_v = 128 for batch_size, seqlen in bs_seqlen_vals: - num_splits = 1 + num_splits = 0 window_size = (-1, -1) # window_size = (seqlen // 2 - 1, 0) sink_token_length = 0 @@ -320,7 +320,7 @@ def run(*args, **kwargs): page_table = None for causal in [False, True]: - # for causal in [False]: + # for causal in [True]: print(f"\n### {headdim = }, {causal = }, {seqlen = } ###") nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim, headdim_v, causal=causal, window_size=window_size) if cudnn is not None: diff --git a/hopper/setup.py b/hopper/setup.py index db8990255..1fb22acae 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -366,7 +366,7 @@ def nvcc_threads_args(): # NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.3.107"} -NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.8.61"} +NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.6.85", "ptxas": "12.8.61", "cicc": "12.8.61"} exe_extension = sysconfig.get_config_var("EXE") @@ -387,10 +387,12 @@ def nvcc_threads_args(): if bare_metal_version < Version("12.3"): raise RuntimeError("FlashAttention-3 is only supported on CUDA 12.3 and above") - if bare_metal_version != Version("12.8"): # nvcc 12.8 gives the best perf currently + # ptxas 12.8 gives the best perf currently + # We want to use the nvcc front end from 12.6 however, since if we use nvcc 12.8 + # Cutlass 3.8 will expect the new data types in cuda.h from CTK 12.8, which we don't have. + if bare_metal_version != Version("12.8"): download_and_copy( name="nvcc", - # src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin/ptxas{exe_extension}", src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin", dst_path="bin", version=NVIDIA_TOOLCHAIN_VERSION["nvcc"], @@ -398,11 +400,18 @@ def nvcc_threads_args(): f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz", ) download_and_copy( - name="nvcc", - # src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin/ptxas{exe_extension}", + name="ptxas", + src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin/ptxas", + dst_path="bin", + version=NVIDIA_TOOLCHAIN_VERSION["ptxas"], + url_func=lambda system, arch, version: + f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz", + ) + download_and_copy( + name="cicc", src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/nvvm/bin", dst_path="nvvm/bin", - version=NVIDIA_TOOLCHAIN_VERSION["nvcc"], + version=NVIDIA_TOOLCHAIN_VERSION["cicc"], url_func=lambda system, arch, version: f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz", ) @@ -411,7 +420,7 @@ def nvcc_threads_args(): nvcc_path_new = os.path.join(ctk_path_new, f"nvcc{exe_extension}") # Need to append to path otherwise nvcc can't find cicc in nvvm/bin/cicc # nvcc 12.8 seems to hard-code looking for cicc in ../nvvm/bin/cicc - # os.environ["PATH"] = ctk_path_new + os.pathsep + os.environ["PATH"] + os.environ["PATH"] = ctk_path_new + os.pathsep + os.environ["PATH"] os.environ["PYTORCH_NVCC"] = nvcc_path_new # Make nvcc executable, sometimes after the copy it loses its permissions os.chmod(nvcc_path_new, os.stat(nvcc_path_new).st_mode | stat.S_IEXEC) From ed435c6b364288b3a98a1ec26975adfa9f645f6b Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 9 Feb 2025 16:11:22 -0500 Subject: [PATCH 026/102] cicc uses the same version as ptxas --- hopper/setup.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/hopper/setup.py b/hopper/setup.py index 1fb22acae..30063dd93 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -366,7 +366,7 @@ def nvcc_threads_args(): # NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.3.107"} -NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.6.85", "ptxas": "12.8.61", "cicc": "12.8.61"} +NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.6.85", "ptxas": "12.8.61"} exe_extension = sysconfig.get_config_var("EXE") @@ -408,10 +408,10 @@ def nvcc_threads_args(): f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz", ) download_and_copy( - name="cicc", + name="ptxas", src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/nvvm/bin", dst_path="nvvm/bin", - version=NVIDIA_TOOLCHAIN_VERSION["cicc"], + version=NVIDIA_TOOLCHAIN_VERSION["ptxas"], url_func=lambda system, arch, version: f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz", ) From 86688236356ad19f560a698525c47f99b06531f2 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 9 Feb 2025 16:32:59 -0500 Subject: [PATCH 027/102] Split hdimdiff into a separate translation unit --- hopper/generate_kernels.py | 11 ++++++++++- .../flash_fwd_hdim64_512_bf16_packgqa_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_bf16_paged_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_bf16_paged_split_sm90.cu | 9 +++++++++ ...sh_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu | 9 +++++++++ .../instantiations/flash_fwd_hdim64_512_bf16_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_bf16_softcap_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_bf16_split_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_fp16_packgqa_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_fp16_paged_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_fp16_paged_split_sm90.cu | 9 +++++++++ ...sh_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu | 9 +++++++++ .../instantiations/flash_fwd_hdim64_512_fp16_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_fp16_softcap_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_fp16_split_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu | 9 +++++++++ .../flash_fwd_hdimall_bf16_packgqa_sm90.cu | 1 - .../flash_fwd_hdimall_bf16_paged_sm90.cu | 1 - .../flash_fwd_hdimall_bf16_paged_softcap_sm90.cu | 1 - .../flash_fwd_hdimall_bf16_paged_split_sm90.cu | 1 - ...flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu | 1 - hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu | 1 - .../flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu | 1 - .../flash_fwd_hdimall_bf16_softcap_sm90.cu | 1 - .../flash_fwd_hdimall_bf16_split_sm90.cu | 1 - .../flash_fwd_hdimall_bf16_split_softcap_sm90.cu | 1 - .../flash_fwd_hdimall_e4m3_packgqa_sm90.cu | 1 - .../flash_fwd_hdimall_e4m3_paged_sm90.cu | 1 - .../flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu | 1 - .../flash_fwd_hdimall_e4m3_paged_split_sm90.cu | 1 - ...flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu | 1 - hopper/instantiations/flash_fwd_hdimall_e4m3_sm90.cu | 1 - .../flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu | 1 - .../flash_fwd_hdimall_e4m3_softcap_sm90.cu | 1 - .../flash_fwd_hdimall_e4m3_split_sm90.cu | 1 - .../flash_fwd_hdimall_e4m3_split_softcap_sm90.cu | 1 - .../flash_fwd_hdimall_fp16_packgqa_sm90.cu | 1 - .../flash_fwd_hdimall_fp16_paged_sm90.cu | 1 - .../flash_fwd_hdimall_fp16_paged_softcap_sm90.cu | 1 - .../flash_fwd_hdimall_fp16_paged_split_sm90.cu | 1 - ...flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu | 1 - hopper/instantiations/flash_fwd_hdimall_fp16_sm90.cu | 1 - .../flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu | 1 - .../flash_fwd_hdimall_fp16_softcap_sm90.cu | 1 - .../flash_fwd_hdimall_fp16_split_sm90.cu | 1 - .../flash_fwd_hdimall_fp16_split_softcap_sm90.cu | 1 - .../flash_fwd_hdimdiff_bf16_packgqa_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_bf16_paged_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_bf16_paged_split_sm90.cu | 6 ++++++ ...lash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu | 6 ++++++ hopper/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_bf16_softcap_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_bf16_split_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu | 5 +++++ .../flash_fwd_hdimdiff_e4m3_paged_sm90.cu | 5 +++++ .../flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu | 5 +++++ .../flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu | 5 +++++ ...lash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu | 5 +++++ hopper/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu | 5 +++++ .../flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu | 5 +++++ .../flash_fwd_hdimdiff_e4m3_softcap_sm90.cu | 5 +++++ .../flash_fwd_hdimdiff_e4m3_split_sm90.cu | 5 +++++ .../flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu | 5 +++++ .../flash_fwd_hdimdiff_fp16_packgqa_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_fp16_paged_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_fp16_paged_split_sm90.cu | 6 ++++++ ...lash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu | 6 ++++++ hopper/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_fp16_softcap_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_fp16_split_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu | 6 ++++++ hopper/setup.py | 2 +- 82 files changed, 361 insertions(+), 32 deletions(-) create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu diff --git a/hopper/generate_kernels.py b/hopper/generate_kernels.py index 7a5eb47d0..19a6e90d3 100644 --- a/hopper/generate_kernels.py +++ b/hopper/generate_kernels.py @@ -138,6 +138,8 @@ def get_all_kernels() -> List[Kernel]: yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=head_dim, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") if sm == 90 and head_dim == 192: yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=128, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") + if sm == 90 and head_dim == 64 and dtype in ["bf16", "fp16"]: + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=512, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") for dtype, head_dim, softcap, sm in itertools.product(DTYPE_MAP_BWD.keys(), HEAD_DIMENSIONS, SOFTCAP, SM): yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=head_dim, split=False, paged_kv=False, softcap=softcap, packgqa=False, direction="bwd") @@ -146,11 +148,18 @@ def batch_hdim(kernels_all) -> List[KERNEL_BATCH]: for dtype, split, paged_kv, softcap, packgqa, sm in itertools.product(DTYPE_MAP.keys(), SPLIT, PAGEDKV, SOFTCAP, PACKGQA, SM): if sm < 90: continue - kernels = [k for k in kernels_all if k.direction == "fwd" and k.dtype == dtype and k.split == split and k.paged_kv == paged_kv and k.softcap == softcap and k.packgqa == packgqa and k.sm == sm] + # Same hdim and hdimv + kernels = [k for k in kernels_all if k.direction == "fwd" and k.dtype == dtype and k.split == split and k.paged_kv == paged_kv and k.softcap == softcap and k.packgqa == packgqa and k.sm == sm and k.head_dim == k.head_dim_v] if len(kernels) > 0: filename = f"flash_fwd_hdimall_{dtype}{'_paged' if paged_kv else ''}{'_split' if split else ''}{'_softcap' if softcap else ''}{'_packgqa' if packgqa else ''}_sm{sm}.cu" template = "\n".join([f"#include \"{k.filename}\"" for k in kernels]) yield KERNEL_BATCH(template, filename) + # Different hdim and hdimv + kernels = [k for k in kernels_all if k.direction == "fwd" and k.dtype == dtype and k.split == split and k.paged_kv == paged_kv and k.softcap == softcap and k.packgqa == packgqa and k.sm == sm and k.head_dim != k.head_dim_v] + if len(kernels) > 0: + filename = f"flash_fwd_hdimdiff_{dtype}{'_paged' if paged_kv else ''}{'_split' if split else ''}{'_softcap' if softcap else ''}{'_packgqa' if packgqa else ''}_sm{sm}.cu" + template = "\n".join([f"#include \"{k.filename}\"" for k in kernels]) + yield KERNEL_BATCH(template, filename) def batch_softcap(kernels_all) -> List[KERNEL_BATCH]: diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu new file mode 100644 index 000000000..2f4ceaaed --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu new file mode 100644 index 000000000..5fd59af34 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu new file mode 100644 index 000000000..e0f885b0f --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu new file mode 100644 index 000000000..6dcda0196 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu new file mode 100644 index 000000000..5d20be6d2 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu new file mode 100644 index 000000000..47463a715 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu new file mode 100644 index 000000000..622b5533c --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu new file mode 100644 index 000000000..c83f44722 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu new file mode 100644 index 000000000..5c9130f86 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu new file mode 100644 index 000000000..a152022cb --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu new file mode 100644 index 000000000..ef05aa203 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu new file mode 100644 index 000000000..19fe6d94f --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu new file mode 100644 index 000000000..6eb2d3d13 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu new file mode 100644 index 000000000..ffbc99821 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu new file mode 100644 index 000000000..3d35075b4 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu new file mode 100644 index 000000000..c2af33cf5 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu new file mode 100644 index 000000000..e07547c92 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu new file mode 100644 index 000000000..1a04eb01f --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu new file mode 100644 index 000000000..da9afc115 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu new file mode 100644 index 000000000..5e63a1551 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu index e8ed21cda..8b659e832 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_bf16_packgqa_sm90.cu" #include "flash_fwd_hdim128_bf16_packgqa_sm90.cu" #include "flash_fwd_hdim192_bf16_packgqa_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_packgqa_sm90.cu" #include "flash_fwd_hdim256_bf16_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu index f7de8fa20..c84d02b6d 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_bf16_paged_sm90.cu" #include "flash_fwd_hdim128_bf16_paged_sm90.cu" #include "flash_fwd_hdim192_bf16_paged_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_paged_sm90.cu" #include "flash_fwd_hdim256_bf16_paged_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu index 64e5ce4a3..6aaf7d12f 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_bf16_paged_softcap_sm90.cu" #include "flash_fwd_hdim128_bf16_paged_softcap_sm90.cu" #include "flash_fwd_hdim192_bf16_paged_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu" #include "flash_fwd_hdim256_bf16_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu index 44619cce5..117121414 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_bf16_paged_split_sm90.cu" #include "flash_fwd_hdim128_bf16_paged_split_sm90.cu" #include "flash_fwd_hdim192_bf16_paged_split_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_paged_split_sm90.cu" #include "flash_fwd_hdim256_bf16_paged_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu index a05973582..617572308 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu index daea288fe..2aac1970b 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_bf16_sm90.cu" #include "flash_fwd_hdim128_bf16_sm90.cu" #include "flash_fwd_hdim192_bf16_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_sm90.cu" #include "flash_fwd_hdim256_bf16_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu index 62640192c..be0c5af08 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu index 79b0d52fa..fd5893c59 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_bf16_softcap_sm90.cu" #include "flash_fwd_hdim128_bf16_softcap_sm90.cu" #include "flash_fwd_hdim192_bf16_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_softcap_sm90.cu" #include "flash_fwd_hdim256_bf16_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu index 333406cb4..bcde9c945 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_bf16_split_sm90.cu" #include "flash_fwd_hdim128_bf16_split_sm90.cu" #include "flash_fwd_hdim192_bf16_split_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_split_sm90.cu" #include "flash_fwd_hdim256_bf16_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu index b6c1fb54c..160eb3a18 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_bf16_split_softcap_sm90.cu" #include "flash_fwd_hdim128_bf16_split_softcap_sm90.cu" #include "flash_fwd_hdim192_bf16_split_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu" #include "flash_fwd_hdim256_bf16_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu index abf0b10e4..28819a690 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_e4m3_packgqa_sm90.cu" #include "flash_fwd_hdim128_e4m3_packgqa_sm90.cu" #include "flash_fwd_hdim192_e4m3_packgqa_sm90.cu" -#include "flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu" #include "flash_fwd_hdim256_e4m3_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu index 22b310e5a..933ad9827 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_e4m3_paged_sm90.cu" #include "flash_fwd_hdim128_e4m3_paged_sm90.cu" #include "flash_fwd_hdim192_e4m3_paged_sm90.cu" -#include "flash_fwd_hdim192_128_e4m3_paged_sm90.cu" #include "flash_fwd_hdim256_e4m3_paged_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu index f9eed0732..a934f7d99 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu" #include "flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu" #include "flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu" #include "flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu index b91c7f85a..8475e878a 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_e4m3_paged_split_sm90.cu" #include "flash_fwd_hdim128_e4m3_paged_split_sm90.cu" #include "flash_fwd_hdim192_e4m3_paged_split_sm90.cu" -#include "flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu" #include "flash_fwd_hdim256_e4m3_paged_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu index a6b215bfd..dd1405b17 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_sm90.cu index ddec44c68..7e7d806c6 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_e4m3_sm90.cu" #include "flash_fwd_hdim128_e4m3_sm90.cu" #include "flash_fwd_hdim192_e4m3_sm90.cu" -#include "flash_fwd_hdim192_128_e4m3_sm90.cu" #include "flash_fwd_hdim256_e4m3_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu index 81601b9ec..f973a4e41 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu" -#include "flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu index ae9a362c1..30390838d 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_e4m3_softcap_sm90.cu" #include "flash_fwd_hdim128_e4m3_softcap_sm90.cu" #include "flash_fwd_hdim192_e4m3_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_e4m3_softcap_sm90.cu" #include "flash_fwd_hdim256_e4m3_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu index 163ee761b..0b629bd2b 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_e4m3_split_sm90.cu" #include "flash_fwd_hdim128_e4m3_split_sm90.cu" #include "flash_fwd_hdim192_e4m3_split_sm90.cu" -#include "flash_fwd_hdim192_128_e4m3_split_sm90.cu" #include "flash_fwd_hdim256_e4m3_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu index ba2d427dd..818c7fafb 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_e4m3_split_softcap_sm90.cu" #include "flash_fwd_hdim128_e4m3_split_softcap_sm90.cu" #include "flash_fwd_hdim192_e4m3_split_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu" #include "flash_fwd_hdim256_e4m3_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu index 34d176348..6652824d0 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_fp16_packgqa_sm90.cu" #include "flash_fwd_hdim128_fp16_packgqa_sm90.cu" #include "flash_fwd_hdim192_fp16_packgqa_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_packgqa_sm90.cu" #include "flash_fwd_hdim256_fp16_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu index 326a2ea90..05d11e2e2 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_fp16_paged_sm90.cu" #include "flash_fwd_hdim128_fp16_paged_sm90.cu" #include "flash_fwd_hdim192_fp16_paged_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_paged_sm90.cu" #include "flash_fwd_hdim256_fp16_paged_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu index a9e032a07..b638138eb 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_fp16_paged_softcap_sm90.cu" #include "flash_fwd_hdim128_fp16_paged_softcap_sm90.cu" #include "flash_fwd_hdim192_fp16_paged_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu" #include "flash_fwd_hdim256_fp16_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu index d7cc300b8..3619a2175 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_fp16_paged_split_sm90.cu" #include "flash_fwd_hdim128_fp16_paged_split_sm90.cu" #include "flash_fwd_hdim192_fp16_paged_split_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_paged_split_sm90.cu" #include "flash_fwd_hdim256_fp16_paged_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu index fa4de4e29..3a408ceac 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_sm90.cu index cb3455866..eec11be91 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_fp16_sm90.cu" #include "flash_fwd_hdim128_fp16_sm90.cu" #include "flash_fwd_hdim192_fp16_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_sm90.cu" #include "flash_fwd_hdim256_fp16_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu index 5dbd70ec5..ca2a1e1b8 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu index 9a97b9604..8cf31a8a8 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_fp16_softcap_sm90.cu" #include "flash_fwd_hdim128_fp16_softcap_sm90.cu" #include "flash_fwd_hdim192_fp16_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_softcap_sm90.cu" #include "flash_fwd_hdim256_fp16_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu index 5aacbf026..5ee7ace63 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_fp16_split_sm90.cu" #include "flash_fwd_hdim128_fp16_split_sm90.cu" #include "flash_fwd_hdim192_fp16_split_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_split_sm90.cu" #include "flash_fwd_hdim256_fp16_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu index cfaabd990..4da0ee704 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_fp16_split_softcap_sm90.cu" #include "flash_fwd_hdim128_fp16_split_softcap_sm90.cu" #include "flash_fwd_hdim192_fp16_split_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu" #include "flash_fwd_hdim256_fp16_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu new file mode 100644 index 000000000..cc3a8a7c9 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu new file mode 100644 index 000000000..d6d6df0d4 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_paged_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_paged_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu new file mode 100644 index 000000000..bd85f7608 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu new file mode 100644 index 000000000..733511adb --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_paged_split_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_paged_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu new file mode 100644 index 000000000..c62ccf28d --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu new file mode 100644 index 000000000..b7e51551a --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu new file mode 100644 index 000000000..0dbd00454 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu new file mode 100644 index 000000000..51a143712 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu new file mode 100644 index 000000000..24a64e8e4 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_split_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu new file mode 100644 index 000000000..50c78f3d5 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu new file mode 100644 index 000000000..526a51fb7 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu new file mode 100644 index 000000000..4e5d9cc4f --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_paged_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu new file mode 100644 index 000000000..f553af139 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu new file mode 100644 index 000000000..aa2a8260d --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu new file mode 100644 index 000000000..bbc4449ba --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu new file mode 100644 index 000000000..02ca85ad6 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu new file mode 100644 index 000000000..d090fde97 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu new file mode 100644 index 000000000..d48f60ad7 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu new file mode 100644 index 000000000..9dda19d1c --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu new file mode 100644 index 000000000..f3e51fc9e --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu new file mode 100644 index 000000000..453282a4f --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu new file mode 100644 index 000000000..72736d8ef --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_paged_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_paged_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu new file mode 100644 index 000000000..97895aa70 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu new file mode 100644 index 000000000..423c42221 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_paged_split_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_paged_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu new file mode 100644 index 000000000..98c895721 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu new file mode 100644 index 000000000..69108d025 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu new file mode 100644 index 000000000..da39ba273 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu new file mode 100644 index 000000000..be6496d19 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu new file mode 100644 index 000000000..a5a809090 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_split_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu new file mode 100644 index 000000000..62fe14256 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/setup.py b/hopper/setup.py index 30063dd93..560ddcc1c 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -470,7 +470,7 @@ def nvcc_threads_args(): + ([192] if not DISABLE_HDIM192 else []) + ([256] if not DISABLE_HDIM256 else []) ) - HEAD_DIMENSIONS_FWD = ["all"] + HEAD_DIMENSIONS_FWD = ["all", "diff"] HEAD_DIMENSIONS_FWD_SM80 = HEAD_DIMENSIONS_BWD SPLIT = [""] + (["_split"] if not DISABLE_SPLIT else []) PAGEDKV = [""] + (["_paged"] if not DISABLE_PAGEDKV else []) From b2fc79d17526ab56d7561091441a62f241056a4b Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 9 Feb 2025 16:39:17 -0500 Subject: [PATCH 028/102] Update benchmark script --- hopper/benchmark_attn.py | 28 ++-------------------------- 1 file changed, 2 insertions(+), 26 deletions(-) diff --git a/hopper/benchmark_attn.py b/hopper/benchmark_attn.py index 6dc253e00..5d1f53692 100644 --- a/hopper/benchmark_attn.py +++ b/hopper/benchmark_attn.py @@ -242,28 +242,13 @@ def run(*args, **kwargs): time_f = {} time_b = {} -# tflops_matmul = {} -# m, n = 8192, 8192 -# for k in [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4608, 5120, 5632, 6144, 6656, 7168, 7680, 8192]: -# a = torch.randn(m, k, device=device, dtype=dtype) -# b = torch.randn(n, k, device=device, dtype=dtype).transpose(-1, -2) -# nFLOPS_matmul = 2 * m * n * k -# m5 = time_fwd(torch.matmul, a, b, desc='cuBLAS') -# print(f'cuBLAS: {m5.mean * 1e3:.3f}ms, {(nFLOPS_matmul / m5.mean * 1e-12):.1f} TFLOPS') -# tflops_matmul[k] = nFLOPS_matmul / m5.mean * 1e-12 -# # import pickle -# # # with open(f'flash3_attn_time_h100_hdim{headdim}_causal.plk', 'wb') as fp: -# # with open(f'flash3_matmul_tflops_h100.plk', 'wb') as fp: -# # pickle.dump(tflops_matmul, fp, protocol=pickle.HIGHEST_PROTOCOL) -# exit(0) - # for headdim in [64, 128, 256]: # for headdim in [64, 96, 128, 192]: # for headdim in [64, 96, 128, 192, 256]: # for headdim in [64, 96, 128]: -for headdim in [64, 128, 256]: +# for headdim in [64, 128, 256]: # for headdim in [64, 96, 128, 192, 256]: -# for headdim in [128]: +for headdim in [128]: nheads = dim // headdim # headdim = 64 # batch_size = 64 @@ -297,10 +282,6 @@ def run(*args, **kwargs): g = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen, requires_grad=True) o = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen, requires_grad=True) stats = torch.randn(batch_size, seqlen_q, nheads, 1, device=device, dtype=torch.float32) - a = torch.randn(batch_size, seqlen, seqlen, device=device, dtype=dtype_gen) - b = torch.randn(batch_size, dim * 2, seqlen, device=device, dtype=dtype_gen).transpose(-1, -2) - # x = torch.randn(batch_size * seqlen, 4096, device=device, dtype=dtype) - # w = torch.randn(4096 * 2, 4096, device=device, dtype=dtype).transpose(-1, -2) if varlen: q_unpad, k_unpad, v_unpad = [rearrange(x.detach(), "b s h d -> (b s) h d").requires_grad_() for x in [q, k, v]] cu_seqlens_q = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen_q @@ -377,11 +358,6 @@ def run(*args, **kwargs): m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, None, None, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits) time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean - # time.sleep(1) - # m5 = time_fwd(torch.bmm, a, b, desc='cuBLAS', repeats=repeats, verbose=False) - # nFLOPS_matmul = nFLOPS - # nFLOPS_matmul = 2 * x.shape[0] * x.shape[1] * w.shape[1] - # m5 = time_fwd(torch.matmul, x, w, desc='cuBLAS') if dtype != torch.float8_e4m3fn and headdim == headdim_v: time.sleep(1) if not varlen: From c09154572015e803123a5c875e7548cef423cd90 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 9 Feb 2025 18:32:23 -0500 Subject: [PATCH 029/102] Update Cutlass to 3.8 --- csrc/cutlass | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cutlass b/csrc/cutlass index b78588d16..833f6990e 160000 --- a/csrc/cutlass +++ b/csrc/cutlass @@ -1 +1 @@ -Subproject commit b78588d1630aa6643bf021613717bafb705df4ef +Subproject commit 833f6990e031b48b4cd2fcf55e0849c51ef6bac2 From 5e39b100b421e104c3dca3011353e9889e8839ea Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 9 Feb 2025 18:47:23 -0500 Subject: [PATCH 030/102] Adjust tile size for hdim 64 --- hopper/tile_size.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/hopper/tile_size.h b/hopper/tile_size.h index 997664bcb..5d0bd6e26 100644 --- a/hopper/tile_size.h +++ b/hopper/tile_size.h @@ -13,7 +13,11 @@ constexpr std::tuple tile_size_fwd_sm90( if (element_size == 2) { if (headdim <= 64) { bool same_hdim = (headdim == headdim_v); // if not same hdim, we're targeting hdimv=512 - return {same_hdim ? 192 : 64, same_hdim ? 128 : 64, same_hdim, true}; + // return {same_hdim ? 192 : 64, same_hdim ? 128 : 64, same_hdim, true}; + // With this workaround in Cutlass 3.8, tile size 192 x 128 got slower for non-causal, idk why + // https://github.com/NVIDIA/cutlass/blob/833f6990e031b48b4cd2fcf55e0849c51ef6bac2/include/cute/container/tuple.hpp#L131 + // Switch to tile size 192 x 192 for now + return {same_hdim ? 192 : 64, same_hdim ? 192 : 64, false, true}; // Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen // return {192, is_causal || is_local ? 192 : 176, true, false}; } else if (headdim <= 96) { From 1a7f4dfa9e51f6a90177a3244a5bc9c687894cdd Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 9 Feb 2025 19:01:26 -0500 Subject: [PATCH 031/102] Adjust ninja build file --- hopper/setup.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/hopper/setup.py b/hopper/setup.py index 560ddcc1c..f638558a0 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -150,6 +150,8 @@ def sanitize_flags(flags): flags.append(f'cuda_post_cflags_sm80 = {" ".join(cuda_post_cflags_sm80)}') cuda_post_cflags_sm80_sm90 = cuda_post_cflags + ['-gencode', 'arch=compute_80,code=sm_80'] flags.append(f'cuda_post_cflags_sm80_sm90 = {" ".join(cuda_post_cflags_sm80_sm90)}') + cuda_post_cflags_sm100 = [s if s != 'arch=compute_90a,code=sm_90a' else 'arch=compute_100a,code=sm_100a' for s in cuda_post_cflags] + flags.append(f'cuda_post_cflags_sm100 = {" ".join(cuda_post_cflags_sm100)}') flags.append(f'cuda_dlink_post_cflags = {" ".join(cuda_dlink_post_cflags)}') flags.append(f'ldflags = {" ".join(ldflags)}') @@ -187,6 +189,9 @@ def sanitize_flags(flags): cuda_compile_rule_sm80_sm90 = ['rule cuda_compile_sm80_sm90'] + cuda_compile_rule[1:] + [ f' command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm80_sm90' ] + cuda_compile_rule_sm100 = ['rule cuda_compile_sm100'] + cuda_compile_rule[1:] + [ + f' command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm100' + ] cuda_compile_rule.append( f' command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags') @@ -199,6 +204,8 @@ def sanitize_flags(flags): rule = 'cuda_compile' elif source_file.endswith('_sm80.cu'): rule = 'cuda_compile_sm80' + elif source_file.endswith('_sm100.cu'): + rule = 'cuda_compile_sm100' else: rule = 'cuda_compile_sm80_sm90' else: @@ -244,6 +251,7 @@ def sanitize_flags(flags): blocks.append(cuda_compile_rule) # type: ignore[possibly-undefined] blocks.append(cuda_compile_rule_sm80) # type: ignore[possibly-undefined] blocks.append(cuda_compile_rule_sm80_sm90) # type: ignore[possibly-undefined] + blocks.append(cuda_compile_rule_sm100) # type: ignore[possibly-undefined] blocks += [devlink_rule, link_rule, build, devlink, link, default] content = "\n\n".join("\n".join(b) for b in blocks) # Ninja requires a new lines at the end of the .ninja file From 15cf7ee4357d1880b8ba5b1356fdea03f6ee5df9 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 11 Feb 2025 15:59:33 -0500 Subject: [PATCH 032/102] Rename collective_mainloop -> mainloop, move tile_scheduler variable --- hopper/flash_bwd_kernel_sm80.h | 15 +++---- hopper/flash_bwd_kernel_sm90.h | 31 ++++++------- hopper/flash_fwd_kernel_sm80.h | 14 +++--- hopper/flash_fwd_kernel_sm90.h | 35 ++++++++------- hopper/flash_fwd_launch_template.h | 2 +- hopper/utils.h | 70 ++++++++++++++++++++++++++++++ 6 files changed, 116 insertions(+), 51 deletions(-) diff --git a/hopper/flash_bwd_kernel_sm80.h b/hopper/flash_bwd_kernel_sm80.h index b4fe26285..aaec00dbe 100644 --- a/hopper/flash_bwd_kernel_sm80.h +++ b/hopper/flash_bwd_kernel_sm80.h @@ -133,8 +133,8 @@ class FlashAttnBwdSm80 { SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - CollectiveMainloop collective_mainloop; - CollectiveEpilogue collective_epilogue; + CollectiveMainloop mainloop; + CollectiveEpilogue epilogue; TileScheduler scheduler(reinterpret_cast(&shared_storage.smem_scheduler)); // Initialize matmul objects. @@ -155,15 +155,14 @@ class FlashAttnBwdSm80 { // dK and dV output accumulator. Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select(TileShape_MNK{})); Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select(TileShape_MNK{})); - bool tile_valid = collective_mainloop.mma( - params.mainloop, tdKrdK, tdVrdV, threadIdx.x, block_coord, - shared_storage); + bool tile_valid = mainloop.mma(params.mainloop, tdKrdK, tdVrdV, threadIdx.x, + block_coord, shared_storage); scheduler.prefetch_next_work(params.scheduler, work_tile_info); if (tile_valid) { - collective_epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV, - threadIdx.x, block_coord); + epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV, + threadIdx.x, block_coord); } else { - collective_epilogue.store_zero(params.epilogue, threadIdx.x, block_coord); + epilogue.store_zero(params.epilogue, threadIdx.x, block_coord); } } diff --git a/hopper/flash_bwd_kernel_sm90.h b/hopper/flash_bwd_kernel_sm90.h index 7aa32a846..b93a02191 100644 --- a/hopper/flash_bwd_kernel_sm90.h +++ b/hopper/flash_bwd_kernel_sm90.h @@ -195,8 +195,8 @@ class FlashAttnBwdSm90 { PipelineParams_dO pipeline_params_dO {pipeline_params.transaction_bytes, role_dO, pipeline_params.is_leader, pipeline_params.num_consumers}; MainloopPipeline_dO pipeline_do(shared_storage.pipelines.pipeline_do, cute::conditional_return(pipeline_params, pipeline_params_dO), ClusterShape{}); - CollectiveMainloop collective_mainloop; - CollectiveEpilogue collective_epilogue; + CollectiveMainloop mainloop; + CollectiveEpilogue epilogue; // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster if constexpr (size(ClusterShape{}) > 1) { @@ -206,6 +206,8 @@ class FlashAttnBwdSm90 { __syncthreads(); } + TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); + if (warp_group_idx == 0) { // Producer cutlass::arch::warpgroup_reg_dealloc(); @@ -213,8 +215,6 @@ class FlashAttnBwdSm90 { if (warp_idx_in_warpgroup == 0) { // Load K, V, and do TMA on Q and dO PipelineState smem_pipe_write = cutlass::make_producer_start_state(); PipelineState_dO smem_pipe_write_do = cutlass::make_producer_start_state(); - - TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); for (auto work_tile_info = scheduler.template get_initial_work(params.scheduler); work_tile_info.is_valid(params.scheduler); work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info)) { @@ -224,32 +224,29 @@ class FlashAttnBwdSm90 { auto scheduler_prefetch = [&scheduler, ¶ms, &work_tile_info]() { scheduler.prefetch_next_work(params.scheduler, work_tile_info); }; - collective_mainloop.load(params.mainloop, pipeline_q, pipeline_do, smem_pipe_write, - smem_pipe_write_do, shared_storage, scheduler_prefetch, block_coord); + mainloop.load(params.mainloop, pipeline_q, pipeline_do, smem_pipe_write, + smem_pipe_write_do, shared_storage, scheduler_prefetch, block_coord); } - collective_mainloop.load_tail(pipeline_q, pipeline_do, smem_pipe_write, smem_pipe_write_do); + mainloop.load_tail(pipeline_q, pipeline_do, smem_pipe_write, smem_pipe_write_do); } else if (warp_idx_in_warpgroup == 1) { - TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); for (auto work_tile_info = scheduler.template get_initial_work(params.scheduler); work_tile_info.is_valid(params.scheduler); work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info)) { auto block_coord_ = work_tile_info.get_block_coord(params.scheduler); auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_; cute::tuple block_coord = {n_block, bidh, bidb}; - collective_mainloop.store_dq(params.mainloop, shared_storage, block_coord); + mainloop.store_dq(params.mainloop, shared_storage, block_coord); } } } else { // Consumer cutlass::arch::warpgroup_reg_alloc(); - - TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); // Initialize matmul objects. TiledMmadKV tiled_mma_dKV; PipelineState smem_pipe_read; PipelineState_dO smem_pipe_read_do; - collective_mainloop.mma_init(); + mainloop.mma_init(); scheduler.init_consumer(); int work_idx = 0; @@ -264,18 +261,18 @@ class FlashAttnBwdSm90 { // dK and dV output accumulator. Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select(TileShape_MNK{})); Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select(TileShape_MNK{})); - bool tile_valid = collective_mainloop.mma( + bool tile_valid = mainloop.mma( params.mainloop, pipeline_q, pipeline_do, smem_pipe_read, smem_pipe_read_do, tdKrdK, tdVrdV, threadIdx.x - NumCopyThreads, work_idx, block_coord, shared_storage); if (tile_valid) { - collective_epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV, - threadIdx.x - NumCopyThreads, block_coord); + epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV, + threadIdx.x - NumCopyThreads, block_coord); } else { - collective_epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord); + epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord); } } - collective_epilogue.store_tail(); + epilogue.store_tail(); } } diff --git a/hopper/flash_fwd_kernel_sm80.h b/hopper/flash_fwd_kernel_sm80.h index a2f550478..71071d722 100644 --- a/hopper/flash_fwd_kernel_sm80.h +++ b/hopper/flash_fwd_kernel_sm80.h @@ -151,8 +151,8 @@ class FlashAttnFwdSm80 { SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - CollectiveMainloop collective_mainloop; - CollectiveEpilogue collective_epilogue; + CollectiveMainloop mainloop; + CollectiveEpilogue epilogue; TileScheduler scheduler(reinterpret_cast(&shared_storage.smem_scheduler)); // Initialize matmul objects. @@ -189,23 +189,23 @@ class FlashAttnFwdSm80 { params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, }; if constexpr (AppendKV) { - bool tile_new_valid = collective_mainloop.store_kv_new( + bool tile_new_valid = mainloop.store_kv_new( params.mainloop, threadIdx.x, shared_storage, seqlen_info, block_coord); if (tile_new_valid) { __syncthreads(); } } - bool tile_valid = collective_mainloop.mma( + bool tile_valid = mainloop.mma( params.mainloop, tOrO, softmax, threadIdx.x, seqlen_info, block_coord, shared_storage); scheduler.prefetch_next_work(params.scheduler, work_tile_info); if (tile_valid) { // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); } - collective_epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma, - threadIdx.x, block_coord); + epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma, + threadIdx.x, block_coord); } else { // Write 0 to gO and -inf to gLSE. // If Split, we don't have to write 0 to O if the mha_combine kernel is used, since it will // not use the value of O if LSE is -inf. - collective_epilogue.template store_zero(params.epilogue, threadIdx.x, block_coord); + epilogue.template store_zero(params.epilogue, threadIdx.x, block_coord); } } diff --git a/hopper/flash_fwd_kernel_sm90.h b/hopper/flash_fwd_kernel_sm90.h index c7fec6df5..9cfb2d9e5 100644 --- a/hopper/flash_fwd_kernel_sm90.h +++ b/hopper/flash_fwd_kernel_sm90.h @@ -291,8 +291,8 @@ class FlashAttnFwdSm90 { } auto pipeline_v_new = cute::conditional_return(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_v_new, pipeline_params_kv_new, ClusterShape{}), nullptr); - CollectiveMainloop collective_mainloop; - CollectiveEpilogue collective_epilogue; + CollectiveMainloop mainloop; + CollectiveEpilogue epilogue; // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster if constexpr (size(ClusterShape{}) > 1) { @@ -302,6 +302,8 @@ class FlashAttnFwdSm90 { __syncthreads(); } + TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); + if (warp_group_idx == 0) { // Producer cutlass::arch::warpgroup_reg_dealloc(); @@ -312,8 +314,6 @@ class FlashAttnFwdSm90 { PipelineState smem_pipe_write = cutlass::make_producer_start_state(); PipelineState smem_pipe_write_new = cutlass::make_producer_start_state(); int work_idx = 0; - - TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp; if constexpr (SingleProducerWarp) { @@ -336,7 +336,7 @@ class FlashAttnFwdSm90 { params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, }; if constexpr (AppendKV) { - bool tile_new_valid = collective_mainloop.load_kv_new( + bool tile_new_valid = mainloop.load_kv_new( params.mainloop, pipeline_k_new, pipeline_v_new, smem_pipe_write_new, shared_storage, seqlen_info, block_coord, work_idx); if (tile_new_valid) { @@ -349,14 +349,13 @@ class FlashAttnFwdSm90 { scheduler.prefetch_next_work(params.scheduler, work_tile_info); }; // pipeline_vt won't be used if we don't need to transpose V. - collective_mainloop.load(params.mainloop, pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, + mainloop.load(params.mainloop, pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, shared_storage, scheduler_prefetch, seqlen_info, block_coord, work_idx); } - collective_mainloop.load_tail(pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, shared_storage, work_idx); + mainloop.load_tail(pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, shared_storage, work_idx); } else { // Consumer cutlass::arch::warpgroup_reg_alloc(); - TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); // Initialize matmul objects. TiledMmaPV tiled_mma_pv; @@ -366,7 +365,7 @@ class FlashAttnFwdSm90 { // (like in Cutlass's gemm) because the read and release pipeline states are always the same. scheduler.init_consumer(); - collective_mainloop.mma_init(); + mainloop.mma_init(); int work_idx = 0; CUTLASS_PRAGMA_NO_UNROLL @@ -397,7 +396,7 @@ class FlashAttnFwdSm90 { params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, }; if constexpr (AppendKV) { - bool tile_new_valid = collective_mainloop.store_kv_new( + bool tile_new_valid = mainloop.store_kv_new( params.mainloop, pipeline_k_new, pipeline_v_new, smem_pipe_read_new, threadIdx.x - MmaThreadOffset, shared_storage, seqlen_info, block_coord); if (tile_new_valid) { @@ -414,33 +413,33 @@ class FlashAttnFwdSm90 { } bool tile_valid; if constexpr (!LargeHeadDimV) { - tile_valid = collective_mainloop.mma( + tile_valid = mainloop.mma( params.mainloop, pipeline_k, pipeline_v, smem_pipe_read, tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage); } else { // mma_pv might not compile if !LargeHeadDimV if (warp_group_idx == 1) { - tile_valid = collective_mainloop.mma( + tile_valid = mainloop.mma( params.mainloop, pipeline_k, pipeline_v, smem_pipe_read, tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage); } else { - tile_valid = collective_mainloop.mma_pv( + tile_valid = mainloop.mma_pv( params.mainloop, pipeline_v, smem_pipe_read, tOrO, softmax, threadIdx.x - MmaThreadOffset, seqlen_info, block_coord, shared_storage); } } if (tile_valid) { // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); } - collective_epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma_pv, - threadIdx.x - MmaThreadOffset, block_coord); + epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma_pv, + threadIdx.x - MmaThreadOffset, block_coord); } else { // Write 0 to gO and -inf to gLSE. // If Split, we don't have to write 0 to O if the mha_combine kernel is used, since it will // not use the value of O if LSE is -inf. - collective_epilogue.template store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord); - // collective_epilogue.store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord); + epilogue.template store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord); + // epilogue.store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord); } } - collective_epilogue.store_tail(); + epilogue.store_tail(); } } diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index b4f80a04e..71eabc2a1 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -196,7 +196,7 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKV, Has_softcap)) : 128; // On nvcc 12.8, hdim 128, without cluster is faster (730 vs 700 TFLOPS) - static constexpr bool Enable_cluster = Arch >= 90 && (sizeof(T) == 2 ? (kHeadDim >= 192) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen; + static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 192) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen; BOOL_SWITCH(params.qv_ptr, HasQV_, [&] { static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 and false; APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { diff --git a/hopper/utils.h b/hopper/utils.h index fa8938c85..e14ca1574 100644 --- a/hopper/utils.h +++ b/hopper/utils.h @@ -354,6 +354,69 @@ CUTLASS_DEVICE void gemm_rs_sm80(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Ten } } +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void gemm_sm100(Atom& atom, TA const& tA, TB const& tB, TC&& tC) { + static constexpr int rA = decltype(rank(tA))::value; + static constexpr int rB = decltype(rank(tB))::value; + static constexpr int rC = decltype(rank(tC))::value; + static_assert(rA == 3 && rB == 3 && rC == 3); + + if constexpr (zero_init) { atom.accumulate_ = decltype(atom.accumulate_)::Zero; } + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tA); k_block++) { + cute::gemm(atom, tA(_,_,k_block), tB(_,_,k_block), tC); + atom.accumulate_ = decltype(atom.accumulate_)::One; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE constexpr +auto +to_tiled_mma_sm100_ts( + TiledMMA, cute::C, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant>, + TAs...>, TMs...>) { + + return TiledMMA>, + TAs...>, TMs...>{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +to_tiled_mma_sm100_ts( + TiledMMA, + TAs...>, TMs...>) { + return TiledMMA, + TAs...>, TMs...>{}; +} //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -562,6 +625,13 @@ CUTLASS_DEVICE auto calculate_dtanh(Tensor &tensor){ //////////////////////////////////////////////////////////////////////////////////////////////////// +template +CUTE_DEVICE T warp_uniform(T a) { + return __shfl_sync(0xffffffff, a, 0); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + CUTLASS_DEVICE int canonical_warp_group_idx_nosync() { return threadIdx.x / cutlass::NumThreadsPerWarpGroup; From 9f313c7073ffa4b10d6daea86003e0f76764f134 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 12 Feb 2025 07:10:05 -0500 Subject: [PATCH 033/102] Move functions getting number of m/n blocks to a separate file --- hopper/block.h | 89 ++++++++++++++++++++++++ hopper/mainloop_bwd_sm80.hpp | 31 +++------ hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp | 37 ++++------ hopper/mainloop_fwd_sm80.hpp | 56 +++------------ hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 70 ++++++------------- 5 files changed, 140 insertions(+), 143 deletions(-) create mode 100644 hopper/block.h diff --git a/hopper/block.h b/hopper/block.h new file mode 100644 index 000000000..d06744c3b --- /dev/null +++ b/hopper/block.h @@ -0,0 +1,89 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +namespace flash { + +template +struct BlockMN { + + static + CUTLASS_DEVICE + cute::tuple get_n_block_min_max( + SeqlenInfo_t const& seqlen_info, + int const m_block, int const bidb, int const split_idx, int const num_splits, + int const window_size_left, int const window_size_right, + cutlass::FastDivmod const& qhead_per_khead_divmod) { + + int const seqlen_k = seqlen_info.seqlen_k; + int const seqlen_q = seqlen_info.seqlen_q; + int n_block_max = cute::ceil_div(seqlen_k, kBlockN); + if constexpr (Is_causal || Is_local) { + int m_idx_max = (m_block + 1) * kBlockM; + // TODO: check off-by-1 error + if (PackGQA) { m_idx_max = qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; } + n_block_max = std::min(n_block_max, + cute::ceil_div(m_idx_max + seqlen_k - seqlen_q + window_size_right, kBlockN)); + } + int n_block_min = 0; + if constexpr (Is_local) { + int m_idx_min = m_block * kBlockM; + if (PackGQA) { m_idx_min = qhead_per_khead_divmod.divide(m_idx_min); } + n_block_min = std::max(int(0), (m_idx_min + seqlen_k - seqlen_q - window_size_left) / kBlockN); + } + // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); } + if constexpr (Split) { + int num_n_blocks_per_split = n_block_max <= n_block_min ? 0 : cute::ceil_div(n_block_max - n_block_min, num_splits); + n_block_min = n_block_min + split_idx * num_n_blocks_per_split; + n_block_max = std::min(n_block_min + num_n_blocks_per_split, n_block_max); + } + // if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); } + return {n_block_min, n_block_max}; + } + + static + CUTLASS_DEVICE + cute::tuple get_n_block_k_new_min_max( + SeqlenInfo_t const& seqlen_info, + int const m_block, int const bidb, int const split_idx, int const num_splits, + int const window_size_left, int const window_size_right, + cutlass::FastDivmod const& qhead_per_khead_divmod) { + + auto [n_block_min, n_block_max] = get_n_block_min_max( + seqlen_info, m_block, bidb, split_idx, num_splits, + window_size_left, window_size_right, qhead_per_khead_divmod); + int const idx_k_new_min = std::max(n_block_min * kBlockN - seqlen_info.seqlen_k_og, 0); + int const idx_k_new_max = std::min(n_block_max * kBlockN - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new); + int const n_block_new_min = idx_k_new_min / kBlockN; + int const n_block_new_max = idx_k_new_max > idx_k_new_min ? cute::ceil_div(idx_k_new_max, kBlockN) : n_block_new_min; + // if (threadIdx.x == 128 && m_block == 0) { printf("bidb = %d, seqlen_k_new = %d, seqlen_k_og = %d, n_block_min = %d, n_block_max = %d, idx_k_new_min = %d, idx_k_new_max = %d, n_block_new_min = %d, n_block_new_max = %d\n", bidb, seqlen_k_new, seqlen_k_og, n_block_min, n_block_max, idx_k_new_min, idx_k_new_max, n_block_new_min, n_block_new_max);} + return {n_block_new_min, n_block_new_max}; + } + + static + CUTLASS_DEVICE + cute::tuple get_m_block_min_max( + SeqlenInfo_t const& seqlen_info, + int const n_block, int const bidb, + int const window_size_left, int const window_size_right, int const sink_token_length) { + + int const seqlen_q = seqlen_info.seqlen_q; + int const seqlen_k = seqlen_info.seqlen_k; + int m_block_max = cute::ceil_div(seqlen_q, kBlockM); + if constexpr (Is_local) { + if (n_block >= cute::ceil_div(sink_token_length, kBlockN)) { + m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + window_size_left, kBlockM)); + } + } + int m_block_min = 0; + if constexpr (Is_causal || Is_local) { + m_block_min = std::max(m_block_min, (n_block * kBlockN + seqlen_q - seqlen_k - window_size_right) / kBlockM); + } + return {m_block_min, m_block_max}; + } + +}; + +} // namespace flash diff --git a/hopper/mainloop_bwd_sm80.hpp b/hopper/mainloop_bwd_sm80.hpp index e7b3d2dea..eb0503c93 100644 --- a/hopper/mainloop_bwd_sm80.hpp +++ b/hopper/mainloop_bwd_sm80.hpp @@ -13,6 +13,7 @@ #include "seqlen.h" #include "mask.h" +#include "mask.h" #include "softmax.h" #include "utils.h" @@ -38,7 +39,6 @@ struct CollectiveMainloopBwdSm80 { static constexpr bool Is_local = Is_local_; static constexpr bool Has_softcap = Has_softcap_; static constexpr bool Varlen = Varlen_; - using SeqlenInfo_t = flash::SeqlenInfoQK(TileShape_MNK{}))>; static constexpr int NumMmaWarps = NumMmaWarpGroups * cutlass::NumWarpsPerWarpGroup; static constexpr bool SdP_swapAB = SdP_swapAB_; @@ -51,6 +51,9 @@ struct CollectiveMainloopBwdSm80 { static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + using SeqlenInfo_t = flash::SeqlenInfoQK; + using BlockMN_t = flash::BlockMN; + static_assert(ArchTag::kMinComputeCapability >= 80); static constexpr bool Has_cp_async = ArchTag::kMinComputeCapability >= 80; @@ -362,26 +365,6 @@ struct CollectiveMainloopBwdSm80 { args.cu_seqlens_q, args.cu_seqlens_k, args.seqused_q, args.seqused_k}; } - CUTLASS_DEVICE - cute::tuple get_m_block_min_max(Params const& params, SeqlenInfo_t const& seqlen_info, - int n_block, int bidb) { - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - int const seqlen_q = seqlen_info.seqlen_q; - int const seqlen_k = seqlen_info.seqlen_k; - int m_block_max = cute::ceil_div(seqlen_q, kBlockM); - if constexpr (Is_local) { - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - if (n_block >= cute::ceil_div(params.sink_token_length, kBlockN)) { - m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + params.window_size_left, kBlockM)); - } - } - int m_block_min = 0; - if constexpr (Is_causal || Is_local) { - m_block_min = std::max(m_block_min, (n_block * kBlockN + seqlen_q - seqlen_k - params.window_size_right) / kBlockM); - } - return {m_block_min, m_block_max}; - } - template CUTLASS_DEVICE bool mma(Params const& params, @@ -400,7 +383,9 @@ struct CollectiveMainloopBwdSm80 { bidb, get<0>(params.shape_Q), size<0>(params.shape_K), params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k }; - auto m_block_min_max = get_m_block_min_max(params, seqlen_info, n_block, bidb); + auto m_block_min_max = BlockMN_t::get_m_block_min_max( + seqlen_info, n_block, bidb, + params.window_size_left, params.window_size_right, params.sink_token_length); int const m_block_min = get<0>(m_block_min_max); int const m_block_max = get<1>(m_block_min_max); // It's possible to have m_block_max <= m_block_min. Exit early @@ -861,7 +846,7 @@ struct CollectiveMainloopBwdSm80 { tdKrdK, tdKrdS, tdKrQ, tdKsdSt, tdKsQ_cur, tiled_mma_dKV, smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt, cute::conditional_return<(kStages > 1)>(nullptr, load_dO_next)); } - if constexpr (kStages == 1) { + if constexpr (kStages == 1) { __syncthreads(); do_mma_dQ(load_Q_next); } diff --git a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp index 393a6e581..e3b296068 100644 --- a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp @@ -17,6 +17,7 @@ #include "named_barrier.hpp" #include "seqlen.h" +#include "block.h" #include "mask.h" #include "softmax.h" #include "utils.h" @@ -48,7 +49,6 @@ struct CollectiveMainloopBwdSm90 { static constexpr bool Is_local = Is_local_; static constexpr bool Has_softcap = Has_softcap_; static constexpr bool Varlen = Varlen_; - using SeqlenInfo_t = flash::SeqlenInfoQK(TileShape_MNK{}))>; static constexpr bool SdP_swapAB = SdP_swapAB_; static constexpr bool dKV_swapAB = dKV_swapAB_; @@ -60,6 +60,9 @@ struct CollectiveMainloopBwdSm90 { static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + using SeqlenInfo_t = flash::SeqlenInfoQK; + using BlockMN_t = flash::BlockMN; + static_assert(ArchTag::kMinComputeCapability >= 90); static_assert(get<0>(ClusterShape{}) == 1 && get<2>(ClusterShape{}) == 1); @@ -406,26 +409,6 @@ struct CollectiveMainloopBwdSm90 { cute::prefetch_tma_descriptor(params.tma_load_V.get_tma_descriptor()); } - CUTLASS_DEVICE - cute::tuple get_m_block_min_max(Params const& params, SeqlenInfo_t const& seqlen_info, - int n_block, int bidb) { - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - int const seqlen_q = seqlen_info.seqlen_q; - int const seqlen_k = seqlen_info.seqlen_k; - int m_block_max = cute::ceil_div(seqlen_q, kBlockM); - if constexpr (Is_local) { - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - if (n_block >= cute::ceil_div(params.sink_token_length, kBlockN)) { - m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + params.window_size_left, kBlockM)); - } - } - int m_block_min = 0; - if constexpr (Is_causal || Is_local) { - m_block_min = std::max(m_block_min, (n_block * kBlockN + seqlen_q - seqlen_k - params.window_size_right) / kBlockM); - } - return {m_block_min, m_block_max}; - } - template CUTLASS_DEVICE void load(Params const& params, @@ -443,7 +426,9 @@ struct CollectiveMainloopBwdSm90 { bidb, get<0>(params.shape_Q), size<0>(params.shape_K), params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k }; - auto [m_block_min, m_block_max] = get_m_block_min_max(params, seqlen_info, n_block, bidb); + auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max( + seqlen_info, n_block, bidb, + params.window_size_left, params.window_size_right, params.sink_token_length); // It's possible to have m_block_max <= m_block_min. Loading Q, K can cause illegal memory access. if constexpr (Is_causal || Is_local || Varlen) { if (m_block_max <= m_block_min) { @@ -609,7 +594,9 @@ struct CollectiveMainloopBwdSm90 { bidb, get<0>(params.shape_Q), size<0>(params.shape_K), params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k }; - auto [m_block_min, m_block_max] = get_m_block_min_max(params, seqlen_info, n_block, bidb); + auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max( + seqlen_info, n_block, bidb, params.window_size_left, + params.window_size_right, params.sink_token_length); // It's possible to have m_block_max <= m_block_min. Exit early if constexpr (Is_causal || Is_local || Varlen) { if (m_block_max <= m_block_min) { return; } @@ -697,7 +684,9 @@ struct CollectiveMainloopBwdSm90 { bidb, get<0>(params.shape_Q), size<0>(params.shape_K), params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k }; - auto [m_block_min, m_block_max] = get_m_block_min_max(params, seqlen_info, n_block, bidb); + auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max( + seqlen_info, n_block, bidb, params.window_size_left, + params.window_size_right, params.sink_token_length); // It's possible to have m_block_max <= m_block_min. Exit early if constexpr (Is_causal || Is_local || Varlen) { if (m_block_max <= m_block_min) { return false; } diff --git a/hopper/mainloop_fwd_sm80.hpp b/hopper/mainloop_fwd_sm80.hpp index 0fb32c7a9..909654d34 100644 --- a/hopper/mainloop_fwd_sm80.hpp +++ b/hopper/mainloop_fwd_sm80.hpp @@ -12,6 +12,7 @@ #include "cute/tensor.hpp" #include "seqlen.h" +#include "block.h" #include "mask.h" #include "pack_gqa.h" #include "paged_kv.h" @@ -44,7 +45,6 @@ struct CollectiveMainloopFwdSm80 { static constexpr bool PackGQA = PackGQA_; static constexpr bool Split = Split_; static constexpr bool Transpose_V = Is_FP8; - using SeqlenInfo_t = flash::SeqlenInfoQKNewK; static_assert(ArchTag::kMinComputeCapability >= 80); @@ -54,6 +54,9 @@ struct CollectiveMainloopFwdSm80 { static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + using SeqlenInfo_t = flash::SeqlenInfoQKNewK; + using BlockMN_t = flash::BlockMN; + using MMA_Atom_Arch = std::conditional_t< ArchTag::kMinComputeCapability >= 80, std::conditional_t< @@ -295,36 +298,6 @@ struct CollectiveMainloopFwdSm80 { args.seqused_q, args.seqused_k, args.leftpad_k}; } - CUTLASS_DEVICE - cute::tuple get_n_block_min_max(Params const& params, SeqlenInfo_t const& seqlen_info, - int m_block, int bidb, int split_idx=0, int num_splits=1) { - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - int const seqlen_k = seqlen_info.seqlen_k; - int const seqlen_q = seqlen_info.seqlen_q; - int n_block_max = cute::ceil_div(seqlen_k, kBlockN); - if constexpr (Is_causal || Is_local) { - int m_idx_max = (m_block + 1) * kBlockM; - if (PackGQA) { m_idx_max = params.qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; } - n_block_max = std::min(n_block_max, - cute::ceil_div(m_idx_max + seqlen_k - seqlen_q + params.window_size_right, kBlockN)); - } - int n_block_min = 0; - if constexpr (Is_local) { - int m_idx_min = m_block * kBlockM; - if (PackGQA) { m_idx_min = params.qhead_per_khead_divmod.divide(m_idx_min); } - n_block_min = std::max(int(0), (m_idx_min + seqlen_k - seqlen_q - params.window_size_left) / kBlockN); - } - // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); } - if constexpr (Split) { - int num_n_blocks_per_split = n_block_max <= n_block_min ? 0 : cute::ceil_div(n_block_max - n_block_min, num_splits); - n_block_min = n_block_min + split_idx * num_n_blocks_per_split; - n_block_max = std::min(n_block_min + num_n_blocks_per_split, n_block_max); - } - // if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); } - return {n_block_min, n_block_max}; - } - template CUTLASS_DEVICE bool mma(Params const& params, @@ -345,7 +318,9 @@ struct CollectiveMainloopFwdSm80 { int const bidb = get<2>(block_coord); int const split_idx = get<3>(block_coord); int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; - auto n_block_min_max = get_n_block_min_max(params, seqlen_info, m_block, bidb, split_idx, params.num_splits); + auto n_block_min_max = BlockMN_t::get_n_block_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); int const n_block_min = get<0>(n_block_min_max); int const n_block_max = get<1>(n_block_min_max); // It's possible to have n_block_max <= n_block_min. We don't want to load Q or change any barrier @@ -679,19 +654,6 @@ struct CollectiveMainloopFwdSm80 { return true; } - CUTLASS_DEVICE - cute::tuple get_n_block_k_new_min_max(Params const& params, SeqlenInfo_t const& seqlen_info, - int m_block, int bidb, int split_idx=0, int num_splits=1) { - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - auto [n_block_min, n_block_max] = get_n_block_min_max(params, seqlen_info, m_block, bidb, split_idx, num_splits); - int const idx_k_new_min = std::max(n_block_min * kBlockN - seqlen_info.seqlen_k_og, 0); - int const idx_k_new_max = std::min(n_block_max * kBlockN - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new); - int const n_block_new_min = idx_k_new_min / kBlockN; - int const n_block_new_max = idx_k_new_max > idx_k_new_min ? cute::ceil_div(idx_k_new_max, kBlockN) : n_block_new_min; - // if (threadIdx.x == 128 && m_block == 0) { printf("bidb = %d, seqlen_k_new = %d, seqlen_k_og = %d, n_block_min = %d, n_block_max = %d, idx_k_new_min = %d, idx_k_new_max = %d, n_block_new_min = %d, n_block_new_max = %d\n", bidb, seqlen_k_new, seqlen_k_og, n_block_min, n_block_max, idx_k_new_min, idx_k_new_max, n_block_new_min, n_block_new_max);} - return {n_block_new_min, n_block_new_max}; - } - template CUTLASS_DEVICE bool store_kv_new(Params const& params, @@ -701,7 +663,9 @@ struct CollectiveMainloopFwdSm80 { cute::tuple block_coord ) { auto [m_block, bidh, bidb, split_idx] = block_coord; - auto n_block_new_min_max = get_n_block_k_new_min_max(params, seqlen_info, m_block, bidb, split_idx, params.num_splits); + auto n_block_new_min_max = BlockMN_t::get_n_block_k_new_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); int const n_block_new_min = get<0>(n_block_new_min_max); int const n_block_new_max = get<1>(n_block_new_min_max); if (n_block_new_max <= n_block_new_min) { return false; } diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 1834f200c..4f2e7a35a 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -16,6 +16,7 @@ #include "named_barrier.hpp" #include "seqlen.h" +#include "block.h" #include "mask.h" #include "pack_gqa.h" #include "paged_kv.h" @@ -58,7 +59,6 @@ struct CollectiveMainloopFwdSm90 { static_assert(Use_TMA_KV || !V_colmajor, "If not using TMA for KV, V_colmajor is not supported"); static constexpr bool SameHeadDim = get<2>(TileShape_MNK{}) == kHeadDimV; static constexpr bool LargeHeadDimV = kHeadDimV > 256; - using SeqlenInfo_t = flash::SeqlenInfoQKNewK; static_assert(ArchTag::kMinComputeCapability >= 90); @@ -69,6 +69,9 @@ struct CollectiveMainloopFwdSm90 { static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + using SeqlenInfo_t = flash::SeqlenInfoQKNewK; + using BlockMN_t = flash::BlockMN; + static_assert(!LargeHeadDimV || kHeadDimV % 256 == 0); static_assert(!LargeHeadDimV || kBlockM <= 64, "kBlockM must be 64 or less for large Headdim_V"); static_assert(!LargeHeadDimV || !MmaPV_is_RS, "MmaPV must be SS for large Headdim_V"); @@ -565,37 +568,6 @@ struct CollectiveMainloopFwdSm90 { } } - CUTLASS_DEVICE - cute::tuple get_n_block_min_max(Params const& params, SeqlenInfo_t const& seqlen_info, - int m_block, int bidb, int split_idx=0, int num_splits=1) { - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - int const seqlen_k = seqlen_info.seqlen_k; - int const seqlen_q = seqlen_info.seqlen_q; - int n_block_max = cute::ceil_div(seqlen_k, kBlockN); - if constexpr (Is_causal || Is_local) { - int m_idx_max = (m_block + 1) * kBlockM; - // TODO: check off-by-1 error - if (PackGQA) { m_idx_max = params.qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; } - n_block_max = std::min(n_block_max, - cute::ceil_div(m_idx_max + seqlen_k - seqlen_q + params.window_size_right, kBlockN)); - } - int n_block_min = 0; - if constexpr (Is_local) { - int m_idx_min = m_block * kBlockM; - if (PackGQA) { m_idx_min = params.qhead_per_khead_divmod.divide(m_idx_min); } - n_block_min = std::max(int(0), (m_idx_min + seqlen_k - seqlen_q - params.window_size_left) / kBlockN); - } - // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); } - if constexpr (Split) { - int num_n_blocks_per_split = n_block_max <= n_block_min ? 0 : cute::ceil_div(n_block_max - n_block_min, num_splits); - n_block_min = n_block_min + split_idx * num_n_blocks_per_split; - n_block_max = std::min(n_block_min + num_n_blocks_per_split, n_block_max); - } - // if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); } - return {n_block_min, n_block_max}; - } - template CUTLASS_DEVICE void load(Params const& params, @@ -615,7 +587,9 @@ struct CollectiveMainloopFwdSm90 { int const bidh = get<1>(block_coord); int const bidb = get<2>(block_coord); int const split_idx = get<3>(block_coord); - auto [n_block_min, n_block_max] = get_n_block_min_max(params, seqlen_info, m_block, bidb, split_idx, params.num_splits); + auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); // It's possible to have n_block_max <= n_block_min. Loading K can cause illegal memory access. if constexpr (Is_causal || Is_local || Varlen || Split) { if (n_block_max <= n_block_min) { @@ -997,7 +971,9 @@ struct CollectiveMainloopFwdSm90 { int const bidb = get<2>(block_coord); int const split_idx = get<3>(block_coord); int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; - auto [n_block_min, n_block_max] = get_n_block_min_max(params, seqlen_info, m_block, bidb, split_idx, params.num_splits); + auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); // It's possible to have n_block_max <= n_block_min. We don't want to load Q or change any barrier if constexpr (Is_causal || Is_local || Varlen || Split) { if (n_block_max <= n_block_min) { return false; } @@ -1379,7 +1355,9 @@ struct CollectiveMainloopFwdSm90 { int const m_block = get<0>(block_coord); int const bidb = get<2>(block_coord); int const split_idx = get<3>(block_coord); - auto [n_block_min, n_block_max] = get_n_block_min_max(params, seqlen_info, m_block, bidb, split_idx, params.num_splits); + auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); // It's possible to have n_block_max <= n_block_min. We don't want to load Q or change any barrier if constexpr (Is_causal || Is_local || Varlen || Split) { if (n_block_max <= n_block_min) { return false; } @@ -1446,19 +1424,6 @@ struct CollectiveMainloopFwdSm90 { return true; } - CUTLASS_DEVICE - cute::tuple get_n_block_k_new_min_max(Params const& params, SeqlenInfo_t const& seqlen_info, - int m_block, int bidb, int split_idx=0, int num_splits=1) { - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - auto [n_block_min, n_block_max] = get_n_block_min_max(params, seqlen_info, m_block, bidb, split_idx, num_splits); - int const idx_k_new_min = std::max(n_block_min * kBlockN - seqlen_info.seqlen_k_og, 0); - int const idx_k_new_max = std::min(n_block_max * kBlockN - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new); - int const n_block_new_min = idx_k_new_min / kBlockN; - int const n_block_new_max = idx_k_new_max > idx_k_new_min ? cute::ceil_div(idx_k_new_max, kBlockN) : n_block_new_min; - // if (threadIdx.x == 128 && m_block == 0) { printf("bidb = %d, seqlen_k_new = %d, seqlen_k_og = %d, n_block_min = %d, n_block_max = %d, idx_k_new_min = %d, idx_k_new_max = %d, n_block_new_min = %d, n_block_new_max = %d\n", bidb, seqlen_k_new, seqlen_k_og, n_block_min, n_block_max, idx_k_new_min, idx_k_new_max, n_block_new_min, n_block_new_max);} - return {n_block_new_min, n_block_new_max}; - } - template CUTLASS_DEVICE bool load_kv_new(Params const& params, @@ -1472,7 +1437,10 @@ struct CollectiveMainloopFwdSm90 { ) { auto [m_block, bidh, bidb, split_idx] = block_coord; - auto [n_block_new_min, n_block_new_max] = get_n_block_k_new_min_max(params, seqlen_info, m_block, bidb, split_idx, params.num_splits); + auto [n_block_new_min, n_block_new_max] = BlockMN_t::get_n_block_k_new_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + if (n_block_new_max <= n_block_new_min) { return false; } Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); @@ -1571,7 +1539,9 @@ struct CollectiveMainloopFwdSm90 { cute::tuple block_coord ) { auto [m_block, bidh, bidb, split_idx] = block_coord; - auto [n_block_new_min, n_block_new_max] = get_n_block_k_new_min_max(params, seqlen_info, m_block, bidb, split_idx, params.num_splits); + auto [n_block_new_min, n_block_new_max] = BlockMN_t::get_n_block_k_new_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); if (n_block_new_max <= n_block_new_min) { return false; } // as_position_independent_swizzle_tensor makes address calculation easier From eafd53c2f1f6efc2e4816eb18f5c79a2463eb6c0 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 12 Feb 2025 07:21:10 -0500 Subject: [PATCH 034/102] Update cutlass 3.8 to fix error w cudaGetDriverEntryPointByVersion --- csrc/cutlass | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cutlass b/csrc/cutlass index 833f6990e..e9627ce55 160000 --- a/csrc/cutlass +++ b/csrc/cutlass @@ -1 +1 @@ -Subproject commit 833f6990e031b48b4cd2fcf55e0849c51ef6bac2 +Subproject commit e9627ce55b42fd2599f58cd4396da9380954def0 From fa445ff6c215026438cca496a97242b8269aa428 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 12 Feb 2025 07:50:45 -0500 Subject: [PATCH 035/102] Fix FP8 test --- hopper/test_util.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/hopper/test_util.py b/hopper/test_util.py index b7ea3d3b7..8c10e2d5d 100644 --- a/hopper/test_util.py +++ b/hopper/test_util.py @@ -238,9 +238,9 @@ def attention_ref( q, k, v = q.float(), k.float(), v.float() qv = qv.float() if qv is not None else None if q_descale is not None: - q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g = q.shape[2] // k.shape[2]).to(dtype=q.dtype) - q = q.float() * q_descale - qv = qv.float() * q_descale if qv is not None else None + q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2]) + q = (q.float() * q_descale).to(q.dtype) + qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None if k_descale is not None: k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype) if v_descale is not None: From a09abcd32d3cae4d83b313446e887f38d02b799f Mon Sep 17 00:00:00 2001 From: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com> Date: Sun, 16 Feb 2025 02:16:32 +0100 Subject: [PATCH 036/102] make seqused optional on top level interface (#1497) --- hopper/benchmark_attn.py | 4 ++-- hopper/flash_attn_interface.py | 4 ++-- hopper/test_flash_attn.py | 3 ++- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/hopper/benchmark_attn.py b/hopper/benchmark_attn.py index 5d1f53692..36f0bf6d0 100644 --- a/hopper/benchmark_attn.py +++ b/hopper/benchmark_attn.py @@ -355,7 +355,7 @@ def run(*args, **kwargs): m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') # pytorch_profiler(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa) else: - m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, None, None, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') + m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits) time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean if dtype != torch.float8_e4m3fn and headdim == headdim_v: @@ -364,7 +364,7 @@ def run(*args, **kwargs): _, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, deterministic=deterministic, repeats=repeats, verbose=False, desc='Fav3') else: - _, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, None, None, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic, + _, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic, repeats=repeats, verbose=False, desc='Fav3') time_b[(causal, headdim, batch_size, seqlen), "Flash3"] = m1b.mean # time.sleep(1) diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index adee1a0ff..78cfe1cb9 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -563,10 +563,10 @@ def flash_attn_varlen_func( v, cu_seqlens_q, cu_seqlens_k, - seqused_q, - seqused_k, max_seqlen_q, max_seqlen_k, + seqused_q=None, + seqused_k=None, softmax_scale=None, causal=False, qv=None, diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index e9cd8c9d6..ddd687f1f 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -450,9 +450,10 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): v_unpad, cu_seqlens_q, cu_seqlens_k, - seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, + seqused_q=seqused_q, + seqused_k=seqused_k, causal=causal, qv=qv_unpad, q_descale=q_descale, From 40cbd529e4ef4c09abc923ab6166b30cda841550 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 18 Feb 2025 10:10:31 -0500 Subject: [PATCH 037/102] Temporarily change package name of FA3 to allow FA2 & FA3 install --- hopper/setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hopper/setup.py b/hopper/setup.py index f638558a0..6798de67a 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -33,7 +33,7 @@ # ninja build does not work unless include_dirs are abs path this_dir = os.path.dirname(os.path.abspath(__file__)) -PACKAGE_NAME = "flash_attn" +PACKAGE_NAME = "flash_attn_3" BASE_WHEEL_URL = "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}" @@ -390,7 +390,7 @@ def nvcc_threads_args(): TORCH_MAJOR = int(torch.__version__.split(".")[0]) TORCH_MINOR = int(torch.__version__.split(".")[1]) - check_if_cuda_home_none("flash_attn") + check_if_cuda_home_none(PACKAGE_NAME) _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) if bare_metal_version < Version("12.3"): raise RuntimeError("FlashAttention-3 is only supported on CUDA 12.3 and above") From 91917b406bcf5b87dc88d67e4a37b3e80adf7d25 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 18 Feb 2025 13:41:09 -0500 Subject: [PATCH 038/102] Update benchmark_split_kv.py to work w new API --- hopper/benchmark_split_kv.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/hopper/benchmark_split_kv.py b/hopper/benchmark_split_kv.py index d3d83590a..c54b51824 100644 --- a/hopper/benchmark_split_kv.py +++ b/hopper/benchmark_split_kv.py @@ -38,7 +38,7 @@ def main(): ).multi_processor_count max_splits = 129 - check_all_splits = False + check_all_splits = True causal = True # causal = False @@ -139,7 +139,7 @@ def main(): cache_seqlens=cache_seqlens, cache_batch_idx=cache_idxs, causal=causal, - gqa_parallel=False, + pack_gqa=False, num_splits=1, ) * 1000. * 1000. @@ -151,9 +151,9 @@ def main(): cache_seqlens=cache_seqlens, cache_batch_idx=cache_idxs, causal=causal, - gqa_parallel=True, + pack_gqa=True, num_splits=0, - max_seqlen_k_hint=context_seqlen + # max_seqlen_k_hint=context_seqlen ) * 1000. * 1000. if check_all_splits: @@ -170,7 +170,7 @@ def main(): cache_seqlens=cache_seqlens, cache_batch_idx=cache_idxs, causal=causal, - gqa_parallel=False, + pack_gqa=False, num_splits=num_splits ) * 1000. * 1000. @@ -181,7 +181,7 @@ def main(): cache_seqlens=cache_seqlens, cache_batch_idx=cache_idxs, causal=causal, - gqa_parallel=False, + pack_gqa=False, num_splits=num_splits ) @@ -192,7 +192,7 @@ def main(): cache_seqlens=cache_seqlens, cache_batch_idx=cache_idxs, causal=causal, - gqa_parallel=False, + pack_gqa=False, num_splits=1 ) @@ -220,7 +220,7 @@ def main(): cache_seqlens=cache_seqlens, cache_batch_idx=cache_idxs, causal=causal, - gqa_parallel=True, + pack_gqa=True, num_splits=num_splits ) * 1000. * 1000. @@ -231,7 +231,7 @@ def main(): cache_seqlens=cache_seqlens, cache_batch_idx=cache_idxs, causal=causal, - gqa_parallel=True, + pack_gqa=True, num_splits=num_splits ) @@ -242,7 +242,7 @@ def main(): cache_seqlens=cache_seqlens, cache_batch_idx=cache_idxs, causal=causal, - gqa_parallel=True, + pack_gqa=True, num_splits=1 ) @@ -271,11 +271,11 @@ def main(): cache_seqlens=cache_seqlens, cache_batch_idx=cache_idxs, causal=causal, - gqa_parallel=True, + pack_gqa=True, # num_splits=num_splits_select, # num_splits=1, num_splits=0, - max_seqlen_k_hint=context_seqlen + # max_seqlen_k_hint=context_seqlen ) * 1000. * 1000. fa3_fastest_splitk_time_gqa = timeit( @@ -286,7 +286,7 @@ def main(): cache_seqlens=cache_seqlens, cache_batch_idx=cache_idxs, causal=causal, - gqa_parallel=True, + pack_gqa=True, num_splits=fa3_fastest_num_splits_gqa ) * 1000. * 1000. @@ -322,4 +322,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() From ea3ecea97a1393c092863330aff9a162bb5ce443 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 18 Feb 2025 14:24:13 -0500 Subject: [PATCH 039/102] Add tp_degree to benchmark_split_kv --- hopper/benchmark_split_kv.py | 36 +++++++++++++++++++++--------------- hopper/epilogue_bwd.hpp | 2 +- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/hopper/benchmark_split_kv.py b/hopper/benchmark_split_kv.py index c54b51824..f3c8af917 100644 --- a/hopper/benchmark_split_kv.py +++ b/hopper/benchmark_split_kv.py @@ -18,13 +18,13 @@ def timeit(fn, *args, **kwargs): # Warmup for _ in range(5): fn(*args, **kwargs) - + # Benchmark using PyTorch Timer t = benchmark.Timer( stmt='fn(*args, **kwargs)', globals={'fn': fn, 'args': args, 'kwargs': kwargs} ) - + # Measure execution time measurement = t.timeit(20) # Runs the function 20 times # measurement = t.blocked_autorange(min_run_time=1) @@ -44,8 +44,9 @@ def main(): # causal = False # dtype=torch.float16 dtype=torch.bfloat16 + tp_degree = 1 - torch.manual_seed(42) + torch.manual_seed(42) model_configs = [ # ("Gemma-2-2B", 8, 4, 256), @@ -56,6 +57,7 @@ def main(): # ("Qwen-2.5-7B", 28, 4, 128), # ("Llama-3.1-8B", 32, 8, 128), ("Llama-3.1-70B", 64, 8, 128), + # ("Mistral Large", 96, 8, 128), # ("Llama-3.1-405B", 128, 8, 128), # ("Llama-3.2-1B", 32, 8, 64), # ("Llama-3.2-3B", 24, 8, 128), @@ -66,28 +68,32 @@ def main(): all_batch_configs.extend(itertools.product( # [1024, 2048, 4096, 8192, 16384, 32768, 131072], # context_seqlen - [4096, 16384, 65536], # context_seqlen - # [131072], # context_seqlen + # [4096, 16384, 65536], # context_seqlen + [131072], # context_seqlen # [i for i in range(1, (num_sms) + 1)], # num_requests [1, 4, 8, 16], # num_requests # [1], # num_requests - [1, 4, 8, 16], # query_seqlen - # [1], # query_seqlen + # [1, 4, 8, 16], # query_seqlen + [1], # query_seqlen )) num_caches = max(reqs for _, reqs, _ in all_batch_configs) cache_seqlen = max(seqlen for seqlen, _, _ in all_batch_configs) for model_name, nheads_q, nheads_kv, headdim in model_configs: + assert nheads_kv % tp_degree == 0 + print(f"***{model_name}***") + print(f"QHEADS:{nheads_q}, KVHEADS:{nheads_kv}, HEADDIM:{headdim}, TP:{tp_degree}") + nheads_q //= tp_degree + nheads_kv //= tp_degree + k_cache = torch.randn( (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype ) v_cache = torch.randn( (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype ) - print(f"***{model_name}***") - print(f"QHEADS:{nheads_q}, KVHEADS:{nheads_kv}, HEADDIM:{headdim}") - + if check_all_splits is False: print(f"{'CONTEXT':<9}{'BSZ':<5}{'QLEN':<6}{'FA2':<10}{'FA3':<9}{'RATIO':<7}{'GB/s':<10}") @@ -157,10 +163,10 @@ def main(): ) * 1000. * 1000. if check_all_splits: - + fa3_fastest_num_splits = 0 fa3_fastest_splitk_time = float("inf") - + for num_splits in range(1, max_splits): t = timeit( flash_attn_interface.flash_attn_with_kvcache, @@ -257,7 +263,7 @@ def main(): if t < fa3_fastest_splitk_time_gqa: fa3_fastest_splitk_time_gqa = t fa3_fastest_num_splits_gqa = num_splits - + efficiency = (num_work_tiles * fa3_fastest_num_splits_gqa)/num_sms heuristic_ratio = fa3_time_gqa_heuristic/fa3_fastest_splitk_time_gqa # remeasure to smooth anomalies @@ -288,7 +294,7 @@ def main(): causal=causal, pack_gqa=True, num_splits=fa3_fastest_num_splits_gqa - ) * 1000. * 1000. + ) * 1000. * 1000. if check_all_splits is True: print( @@ -308,7 +314,7 @@ def main(): # f"RATIO (FA2/3):{fa2_time_heuristic/fa3_time_gqa_heuristic:.2f}, " f"RATIO:{fa3_time_gqa_heuristic/fa3_fastest_splitk_time_gqa:.2f}, " f"EFF:{efficiency:.2f}, " - f"GB/s:{bytes_kv/fa3_time_gqa_heuristic * 1e-3:.2f}" + f"GB/s:{bytes_kv/fa3_time_gqa_heuristic * 1e-3:.2f}" ) if check_all_splits is False: diff --git a/hopper/epilogue_bwd.hpp b/hopper/epilogue_bwd.hpp index f99dfe918..811d0d1f1 100644 --- a/hopper/epilogue_bwd.hpp +++ b/hopper/epilogue_bwd.hpp @@ -238,7 +238,7 @@ struct CollectiveEpilogueBwd { Tensor tdKVsdK = gmem_thr_copy_dKV.partition_S(sdK); // (TMA, TMA_M, TMA_K) Tensor tdKVrdV = make_fragment_like(tdKVgdV); Tensor tdKVrdK = make_fragment_like(tdKVgdK); - Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_N,BLK_K) -> (blk_n,blk_k) // Repeat the partitioning with identity layouts Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKVgdV))); From 74dfa43c8d22f46999f5a9554faa72c30d81fe64 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 18 Feb 2025 22:23:15 -0500 Subject: [PATCH 040/102] Fix divide by 0 in causal tile_scheduler for large seqlen --- hopper/flash_fwd_combine_kernel.h | 4 ++-- hopper/flash_fwd_combine_launch_template.h | 2 +- hopper/tile_scheduler.hpp | 3 ++- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/hopper/flash_fwd_combine_kernel.h b/hopper/flash_fwd_combine_kernel.h index 20685a156..8957ae41a 100644 --- a/hopper/flash_fwd_combine_kernel.h +++ b/hopper/flash_fwd_combine_kernel.h @@ -194,8 +194,8 @@ class FlashAttnFwdCombine { Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o_partial.data()), SmemLayoutO{}); int const thread_idx = threadIdx.x; - int const k_block = blockIdx.x; - int const m_block = blockIdx.y; + int const m_block = blockIdx.x; + int const k_block = blockIdx.y; int const batch = !Varlen ? 0 : blockIdx.y; int const num_splits = get<1>(params.shape_LSE_partial); flash::SeqlenInfo seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused}; diff --git a/hopper/flash_fwd_combine_launch_template.h b/hopper/flash_fwd_combine_launch_template.h index 101f894b2..eb7dd404c 100644 --- a/hopper/flash_fwd_combine_launch_template.h +++ b/hopper/flash_fwd_combine_launch_template.h @@ -39,7 +39,7 @@ void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args); int num_blocks_k = cute::ceil_div(params.dv, kBlockK); int num_blocks_m = cute::ceil_div(params.seqlen_q * params.h * (!Varlen ? params.b : 1), kBlockM); - dim3 grid_m(num_blocks_k, num_blocks_m, !Varlen ? 1 : params.b); + dim3 grid_m(num_blocks_m, num_blocks_k, !Varlen ? 1 : params.b); auto kernel = cutlass::device_kernel; int smem_size = CombineKernel::SharedStorageSize; if (smem_size >= 48 * 1024) { diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index 0b74d0e1f..e67abf89a 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -236,7 +236,8 @@ class DynamicPersistentTileScheduler { int const size_l2 = 32 * 1024 * 1024; // 32 MB for K & V // Swizzle is the size of each "section". Round swizzle to a power of 2 // If not PackGQA already, the size of each section can increase by qhead_per_khead - int const swizzle = (1 << cutlass::find_log2(size_l2 / size_one_kv_head)) * (PackGQA ? 1 : args.qhead_per_khead); + // Need to be careful about the case where only one head will fit + int const swizzle = (size_l2 < size_one_kv_head ? 1 : (1 << cutlass::find_log2(size_l2 / size_one_kv_head))) * (PackGQA ? 1 : args.qhead_per_khead); // If we're in the last section (called residual), we don't want to divide by // swizzle. Instead we want to divide by the remainder. int const num_hb_remainder = (args.num_head * args.num_batch) % swizzle; From b36ad4ef767d2d5536ff8af2e3f720ae4eba731c Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 19 Feb 2025 02:08:07 -0500 Subject: [PATCH 041/102] Use split for super long sequences that don't fit into L2 --- hopper/flash_api.cpp | 3 ++- hopper/heuristics.h | 15 +++++++++++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 7dad5b9c7..e400c63d5 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -417,8 +417,9 @@ inline int get_num_splits(Flash_fwd_params const& params) { : std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM)); int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN; int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM; + int const size_one_kv_head = params.seqlen_k * (params.d + params.dv) * (params.is_e4m3 ? 1 : 2); // Always enable PackGQA for Split - return num_splits_heuristic(params.b * params.h_k * num_m_blocks, params.num_sm, num_n_blocks, 128); + return num_splits_heuristic(params.b * params.h_k * num_m_blocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128); // return num_splits_heuristic(params.b * params.h_k * num_m_blocks, params.b * params.h_k, // params.num_sm, num_n_blocks, 128, params.d_rounded); #endif diff --git a/hopper/heuristics.h b/hopper/heuristics.h index 8e7b4a314..03fd391ff 100644 --- a/hopper/heuristics.h +++ b/hopper/heuristics.h @@ -22,9 +22,20 @@ inline bool should_pack_gqa(bool varlen_q, int seqlen_q, int qhead_per_khead, in // splits as that would incur more HBM reads/writes. // So we find the best efficiency, then find the smallest number of splits that gets 85% // of the best efficiency. -inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) { +inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int num_m_blocks, int size_one_kv_head, bool is_causal_or_local, int max_splits) { // If we have enough to almost fill the SMs, then just use 1 split - if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; } + // However, in the case of super long seqlen where each head of KV doesn't even fit into + // L2 (we assume conservatively that L2 size is 50MB), we want to split. + if (batch_nheads_mblocks >= 0.8f * num_SMs) { + int const size_l2 = 50 * 1024 * 1024; + // Only split if there are enough queries to go over the KV at least twice + // Don't split if causal + if (size_one_kv_head > size_l2 && num_m_blocks >= num_SMs * 2 && !is_causal_or_local) { + return std::min((size_one_kv_head + size_l2 - 1) / size_l2, max_splits); + } else { + return 1; + } + } // If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512. if (num_n_blocks <= 4) { return 1; } max_splits = std::min({max_splits, num_SMs, num_n_blocks}); From ecdb528dea98904bcf6aa7b436a38f1e2e4cbd71 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 22 Feb 2025 16:04:58 -0500 Subject: [PATCH 042/102] Make rotary test optional in FA3 --- hopper/test_flash_attn.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index ddd687f1f..16cfb2384 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -7,7 +7,10 @@ import torch.nn.functional as F from einops import rearrange, repeat -from flash_attn.layers.rotary import apply_rotary_emb +try: + from flash_attn.layers.rotary import apply_rotary_emb +except ImportError: + apply_rotary_emb = None from padding import pad_input, unpad_input from test_util import ( @@ -570,7 +573,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) @pytest.mark.parametrize("rotary_interleaved", [False, True] if not DISABLE_APPENDKV else [False]) # @pytest.mark.parametrize("rotary_interleaved", [True]) -@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0] if not DISABLE_APPENDKV else [0.0]) +@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0] if (not DISABLE_APPENDKV) and (apply_rotary_emb is not None) else [0.0]) # @pytest.mark.parametrize("rotary_fraction", [0.0]) @pytest.mark.parametrize("page_size", [None] + ([1, 4, 128] if not DISABLE_PAGEDKV else [])) # @pytest.mark.parametrize("page_size", [None]) From 06e34f62d18d3a721bc515d4b331a46d5d4c8c09 Mon Sep 17 00:00:00 2001 From: Ted Zadouri Date: Sat, 22 Feb 2025 21:24:44 -0500 Subject: [PATCH 043/102] Enable MLA flag in FA3 (rope=64, latent=512) (#1504) * Enable MLA flag in FA3 (rope=64, latent=512) * updated HasQv in flash_fwd_launch_template.h --- hopper/flash_api.cpp | 24 ++++++++++++++++++++---- hopper/flash_fwd_launch_template.h | 2 +- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index e400c63d5..4e3737663 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -271,7 +271,14 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { if (!params.is_e4m3) { if (params.is_bf16) { #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { return run_mha_fwd_(params, stream); } + if (params.d <= 64) { + if (params.dv > 64 && Arch == 90) { + return run_mha_fwd_(params, stream); + } + else { + return run_mha_fwd_(params, stream); + } + } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 if (params.d <= 96) { return run_mha_fwd_(params, stream); } @@ -294,7 +301,14 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { } else { #ifndef FLASHATTENTION_DISABLE_FP16 #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { return run_mha_fwd_(params, stream); } + if (params.d <= 64) { + if (params.dv > 64 && Arch == 90) { + return run_mha_fwd_(params, stream); + } + else { + return run_mha_fwd_(params, stream); + } + } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 if (params.d <= 96) { return run_mha_fwd_(params, stream); } @@ -581,7 +595,10 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); if (head_size_v != head_size) { - TORCH_CHECK(head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128, "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128]"); + TORCH_CHECK((head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128) || + (head_size <= 64 && head_size_v <= 512), + "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128], " + "or (Q/K <= 64 and V <= 512)."); TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim"); if (head_size_v > 256) { TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, @@ -758,7 +775,6 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq } if (q_v_.has_value()) { - TORCH_CHECK(false, "q_v should be None for now"); TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, "q_v is only supported for fp16 and bf16 data type"); diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 71eabc2a1..15f439296 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -198,7 +198,7 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { // On nvcc 12.8, hdim 128, without cluster is faster (730 vs 700 TFLOPS) static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 192) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen; BOOL_SWITCH(params.qv_ptr, HasQV_, [&] { - static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 and false; + static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV == 512; APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { // Only use Cluster if number of tiles along seqlen_q is even and not varlen CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { From 6aed835dd9ba0184db43712d73e40b7dec34878d Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 24 Feb 2025 01:48:05 -0500 Subject: [PATCH 044/102] Add simple script to benchmark MLA decode --- hopper/benchmark_mla_decode.py | 61 ++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 hopper/benchmark_mla_decode.py diff --git a/hopper/benchmark_mla_decode.py b/hopper/benchmark_mla_decode.py new file mode 100644 index 000000000..ed6e903fd --- /dev/null +++ b/hopper/benchmark_mla_decode.py @@ -0,0 +1,61 @@ +import torch + +from triton.testing import do_bench, do_bench_cudagraph + +from einops import rearrange + +from flash_attn_interface import flash_attn_with_kvcache + +try: + from flash_attn.utils.benchmark import pytorch_profiler +except ImportError: + pytorch_profiler = None + +device = "cuda" +dtype = torch.bfloat16 +seqlen = 64 * 1024 +nheads = 16 +nheads_kv = 1 +headdim = 64 +headdim_v = 512 +has_qv = True +seqlen_q = 1 +# page_size = None +page_size = 1 + +torch.manual_seed(0) + +batch_size = 4 +cache_seqlens = torch.tensor([seqlen - 1] * batch_size, device=device, dtype=torch.int) +# cache_seqlens = torch.tensor([seqlen - 1, 1024, 1024, 1024], device=device, dtype=torch.int32) +# cache_seqlens = torch.tensor([1024] * batch_size, device=device, dtype=torch.int) +# cache_seqlens = torch.tensor([seqlen - 1, 1024, 1024, 1024], device=device, dtype=torch.int) +# cache_seqlens = torch.tensor([4500, 45000, 1800, 1800], dtype=torch.int32, device=device) + +num_splits = 0 +q = torch.randn(batch_size, seqlen_q, nheads, headdim, dtype=dtype, device=device) * 3 +v_cache = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, dtype=dtype, device=device) +k_cache = torch.randn(batch_size, seqlen, nheads_kv, headdim, dtype=dtype, device=device) * 3 +if page_size is not None: + assert seqlen % page_size == 0 + k_cache, v_cache = [rearrange(x, "b (n p) h d -> (b n) p h d", p=page_size) for x in [k_cache, v_cache]] + page_table = rearrange(torch.arange(batch_size * seqlen // page_size, device=device, dtype=torch.int32), + "(b s) -> b s", s=seqlen // page_size) +else: + page_table = None +qv = torch.randn(batch_size, 1, nheads, headdim_v, dtype=dtype, device=device) if has_qv else None + +# Time in ms +fn = lambda: flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, num_splits=num_splits, qv=qv, page_table=page_table, causal=True) +t0 = do_bench(fn, warmup=1, rep=10) +with torch.cuda.stream(torch.cuda.Stream()): + t1 = do_bench_cudagraph(fn, rep=10) + +mem_io = cache_seqlens.sum().item() * nheads_kv * (headdim + headdim_v) * 2 +ideal_h100_time = mem_io / 3.35e12 * 1e6 +print(f"Time: {t0 * 1e3:.0f} us, {mem_io * 1e-9 / (t0 * 1e-3):.0f} GB/s") +print(f"Time w CUDA Graph: {t1 * 1e3:.0f} us, {mem_io * 1e-9 / (t1 * 1e-3):.0f} GB/s") +print(f"Ideal time: {ideal_h100_time:.0f} us") + +if pytorch_profiler is not None: + pytorch_profiler(flash_attn_with_kvcache, q, k_cache, v_cache, cache_seqlens=cache_seqlens, num_splits=num_splits, qv=qv, page_table=page_table, causal=False) From 6752d62aa4196fe27cda621e80bcf8a10e03b206 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 24 Feb 2025 03:37:05 -0500 Subject: [PATCH 045/102] Add dynamic splits --- hopper/block.h | 9 +- hopper/epilogue_bwd.hpp | 4 +- hopper/epilogue_fwd.hpp | 2 + hopper/flash.h | 4 + hopper/flash_api.cpp | 29 +++-- hopper/flash_fwd_combine_kernel.h | 11 +- hopper/flash_fwd_combine_launch_template.h | 4 +- hopper/flash_fwd_launch_template.h | 12 +- hopper/flash_prepare_scheduler.cu | 126 +++++++++++++++++++++ hopper/heuristics.h | 6 +- hopper/setup.py | 1 + hopper/test_flash_attn.py | 10 +- hopper/tile_scheduler.hpp | 108 ++++++++++++------ hopper/utils.h | 13 +++ 14 files changed, 278 insertions(+), 61 deletions(-) create mode 100644 hopper/flash_prepare_scheduler.cu diff --git a/hopper/block.h b/hopper/block.h index d06744c3b..eda7eaa1c 100644 --- a/hopper/block.h +++ b/hopper/block.h @@ -35,9 +35,14 @@ struct BlockMN { } // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); } if constexpr (Split) { - int num_n_blocks_per_split = n_block_max <= n_block_min ? 0 : cute::ceil_div(n_block_max - n_block_min, num_splits); - n_block_min = n_block_min + split_idx * num_n_blocks_per_split; + uint32_t num_splits_dynamic_u = reinterpret_cast(split_idx) >> 16; // first 16 bits are for num_splits + int num_splits_dynamic = reinterpret_cast(num_splits_dynamic_u); + int split_idx_actual = split_idx & 0x0000FFFF; + int num_splits_actual = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits; + int num_n_blocks_per_split = n_block_max <= n_block_min ? 0 : cute::ceil_div(n_block_max - n_block_min, num_splits_actual); + n_block_min = n_block_min + split_idx_actual * num_n_blocks_per_split; n_block_max = std::min(n_block_min + num_n_blocks_per_split, n_block_max); + // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, num_splits_dynamic = %d, num_splits_actual = %d, num_n_blocks_per_split = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, num_splits_dynamic, num_splits_actual, num_n_blocks_per_split, n_block_min, n_block_max); } } // if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); } return {n_block_min, n_block_max}; diff --git a/hopper/epilogue_bwd.hpp b/hopper/epilogue_bwd.hpp index 811d0d1f1..9362b0404 100644 --- a/hopper/epilogue_bwd.hpp +++ b/hopper/epilogue_bwd.hpp @@ -4,8 +4,8 @@ #pragma once -#include -#include +#include "cutlass/cutlass.h" +#include "cutlass/barrier.h" #include "cute/tensor.hpp" #include "cutlass/gemm/collective/builders/sm90_common.inl" diff --git a/hopper/epilogue_fwd.hpp b/hopper/epilogue_fwd.hpp index 1c13988eb..f3815ea73 100644 --- a/hopper/epilogue_fwd.hpp +++ b/hopper/epilogue_fwd.hpp @@ -200,6 +200,7 @@ struct CollectiveEpilogueFwd { ) { auto [m_block, bidh, bidb, split_idx] = block_coord; + split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx Tensor sO = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_o.data()), SmemLayoutO{}); // Tensor sO_pi = cute::as_position_independent_swizzle_tensor(sO); @@ -368,6 +369,7 @@ struct CollectiveEpilogueFwd { ) { static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); auto [m_block, bidh, bidb, split_idx] = block_coord; + split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused}; bool const is_varlen = Varlen && params.cu_seqlens; int offset_o = seqlen_info.offset; diff --git a/hopper/flash.h b/hopper/flash.h index 8e95f5ff7..d9f007dfb 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -150,6 +150,9 @@ struct Flash_fwd_params : public Qkv_params { bool pack_gqa; int * __restrict__ tile_count_semaphore; + int * __restrict__ num_m_blocks_ptr; + int * __restrict__ num_n_blocks_ptr; + int * __restrict__ num_splits_dynamic_ptr; int arch; int num_sm; @@ -205,6 +208,7 @@ struct Flash_bwd_params : public Flash_fwd_params { template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); +void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, int blockM, int blockN); template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); template diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 4e3737663..805513e14 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -433,9 +433,11 @@ inline int get_num_splits(Flash_fwd_params const& params) { int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM; int const size_one_kv_head = params.seqlen_k * (params.d + params.dv) * (params.is_e4m3 ? 1 : 2); // Always enable PackGQA for Split - return num_splits_heuristic(params.b * params.h_k * num_m_blocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128); - // return num_splits_heuristic(params.b * params.h_k * num_m_blocks, params.b * params.h_k, - // params.num_sm, num_n_blocks, 128, params.d_rounded); + // If varlen, we use dynamic split, so this heuristic just needs to get an upper bound on num_splits. + // We assume the case where there's 1 long sequence and the rest are short, i.e. pretending + // that batch = 1. + int total_mblocks = (!varlen ? params.b : 1) * params.h_k * num_m_blocks; + return num_splits_heuristic(total_mblocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128); #endif } @@ -861,14 +863,21 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq params.lseaccum_head_stride = softmax_lse_accum.stride(-2); } - at::Tensor tile_count_semaphore; + at::Tensor tile_count_semaphore, num_m_n_blocks_splits; // We don't use the persistent scheduler if Split and not Varlen bool const persistent_scheduler = params.arch >= 90 ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1)); if (persistent_scheduler) { - tile_count_semaphore = torch::zeros({1}, opts.dtype(torch::kInt32)); + tile_count_semaphore = torch::empty({1}, opts.dtype(torch::kInt32)); + if (!is_varlen) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing params.tile_count_semaphore = tile_count_semaphore.data_ptr(); + if (is_varlen) { + num_m_n_blocks_splits = torch::empty({batch_size * 3}, opts.dtype(torch::kInt32)); + params.num_m_blocks_ptr = num_m_n_blocks_splits.data_ptr(); + params.num_n_blocks_ptr = num_m_n_blocks_splits.data_ptr() + batch_size; + params.num_splits_dynamic_ptr = num_m_n_blocks_splits.data_ptr() + batch_size * 2; + } } else { params.tile_count_semaphore = nullptr; } @@ -935,11 +944,13 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq } // Unless there's seqused_q, for the purpose of attn_combine, we can just treat it as batch=1 // and seqlen = total_q, and don't need to dispatch to Varlen there. + // However, with dynamic split, each row needs to know which batch it belongs to + // to read the number of splits, so we just use the varlen version of combine kernel. // if (is_varlen_q && !seqused_q_.has_value()) { - if (is_varlen_q) { - params.b = 1; - params.seqlen_q = total_q; - } + // if (is_varlen_q) { + // params.b = 1; + // params.seqlen_q = total_q; + // } run_mha_fwd_combine(params, stream); } } else if (total_q > 0 && num_heads_k > 0) { diff --git a/hopper/flash_fwd_combine_kernel.h b/hopper/flash_fwd_combine_kernel.h index 8957ae41a..8e9146d18 100644 --- a/hopper/flash_fwd_combine_kernel.h +++ b/hopper/flash_fwd_combine_kernel.h @@ -128,7 +128,6 @@ class FlashAttnFwdCombine { static constexpr int SharedStorageSize = sizeof(SharedStorage); - // Device side arguments struct Arguments { ElementPartial const* ptr_O_partial; @@ -143,6 +142,7 @@ class FlashAttnFwdCombine { StrideLSE const stride_LSE; int const* cu_seqlens = nullptr; int const* seqused = nullptr; + int const* num_splits_dynamic_ptr = nullptr; }; // Kernel entry point API @@ -160,6 +160,7 @@ class FlashAttnFwdCombine { cutlass::FastDivmod seqlen_divmod, head_divmod; int const* cu_seqlens = nullptr; int const* seqused = nullptr; + int const* num_splits_dynamic_ptr = nullptr; }; // Convert to underlying arguments. In this case, a simple copy for the aliased type. @@ -180,7 +181,8 @@ class FlashAttnFwdCombine { args.stride_LSE, cutlass::FastDivmod(get<0>(args.shape_LSE_partial)), cutlass::FastDivmod(get<2>(args.shape_LSE_partial)), args.cu_seqlens, - args.seqused + args.seqused, + args.num_splits_dynamic_ptr }; } @@ -196,7 +198,7 @@ class FlashAttnFwdCombine { int const thread_idx = threadIdx.x; int const m_block = blockIdx.x; int const k_block = blockIdx.y; - int const batch = !Varlen ? 0 : blockIdx.y; + int const batch = !Varlen ? 0 : blockIdx.z; int const num_splits = get<1>(params.shape_LSE_partial); flash::SeqlenInfo seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused}; int const offset = seqlen_info.offset; @@ -229,12 +231,13 @@ class FlashAttnFwdCombine { bidh = seqlen_divmod_dynamic.divmod(m_idx, idx); bidb = 0; } + int num_splits_actual = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[!Varlen ? bidb : batch] : num_splits; Tensor mLSEpartial_cur_copy = mLSEpartial_copy(_, _, m_idx, bidh, bidb); #pragma unroll for (int s = 0; s < size<1>(tLSEcLSE); ++s) { int si = get<0>(tLSEcLSE(_0{}, s, _0{})); // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && thread_idx < 32) { printf("thread_idx = %d, m = %d, s = %d, addr = %p, bank = %d\n", thread_idx, m, s, reinterpret_cast(&(tLSEsLSE(_0{}, s, m))), reinterpret_cast(&(tLSEsLSE(_0{}, s, m))) / 4 % 32);} - if (si < num_splits) { + if (si < num_splits_actual) { cute::copy(gmem_tiled_copy_LSE, mLSEpartial_cur_copy(_, si), tLSEsLSE(_, s, m)); } else { cute::fill(tLSEsLSE(_, s, m), -INFINITY); diff --git a/hopper/flash_fwd_combine_launch_template.h b/hopper/flash_fwd_combine_launch_template.h index eb7dd404c..e4ac21fd0 100644 --- a/hopper/flash_fwd_combine_launch_template.h +++ b/hopper/flash_fwd_combine_launch_template.h @@ -33,7 +33,7 @@ void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { {params.o_row_stride, _1{}, params.o_head_stride, !Varlen ? params.o_batch_stride : 0}, // stride_O static_cast(params.softmax_lse_ptr), {_1{}, !Varlen ? params.seqlen_q : params.total_q, !Varlen ? params.h * params.seqlen_q : 0}, // stride_LSE - params.cu_seqlens_q, params.seqused_q + params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr }; typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args); @@ -55,7 +55,7 @@ void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream) { // E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats). static_assert(kBlockK % 32 == 0, "kBlockK must be a multiple of 32"); static constexpr int kBlockM = kBlockK % 128 == 0 ? 8 : (kBlockK % 64 == 0 ? 16 : 32); - BOOL_SWITCH(params.seqused_q != nullptr, Varlen, [&] { + BOOL_SWITCH(params.cu_seqlens_q || params.seqused_q, Varlen, [&] { if constexpr (kBlockM >= 16) { // If kBlockM == 8 then the minimum number of splits is 32. if (params.num_splits <= 16) { run_flash_fwd_combine(params, stream); diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 15f439296..6b80af44c 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -68,7 +68,8 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // However, if Varlen (e.g., during decode where we have max_seqlens), using PersistentScheduler is better // since we'll avoid launching a bunch of thread blocks that immediately exit. // On Sm80, noncausal persistent seems a bit slower. - using Scheduler = std::conditional_t= 90 ? (Split && !Varlen) : !((Is_causal && !Varlen) || (Varlen && Split)), SchedulerSingleTile, SchedulerPersistent>; + static constexpr bool UsePersistentScheduler = Arch >= 90 ? !(Split && !Varlen) : ((Is_causal && !Varlen) || (Varlen && Split)); + using Scheduler = std::conditional_t; using AttnKernel = std::conditional_t< Arch >= 90, flash::enable_sm90_or_later>, @@ -148,9 +149,16 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { params.h / params.h_k, params.seqlen_q, params.seqlen_k, params.d, sizeof(Element), - params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q + params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q, + // params.num_m_blocks_ptr, params.num_splits_dynamic_ptr, + params.num_splits_dynamic_ptr, }; + if constexpr (Varlen && UsePersistentScheduler) { + prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN); + CHECK_CUDA_KERNEL_LAUNCH(); + } + int device; CHECK_CUDA(cudaGetDevice(&device)); typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({ diff --git a/hopper/flash_prepare_scheduler.cu b/hopper/flash_prepare_scheduler.cu new file mode 100644 index 000000000..e108347ec --- /dev/null +++ b/hopper/flash_prepare_scheduler.cu @@ -0,0 +1,126 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#include "cutlass/fast_math.h" +#include "cutlass/barrier.h" +#include "cutlass/arch/barrier.h" + +#include "flash.h" + +namespace flash { + +__global__ void prepare_varlen_num_blocks_kernel( + int seqlen_q_static, int seqlen_k_static, int seqlen_k_new_static, + int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new, + int const* const seqused_q, int const* const seqused_k, int const* const leftpad_k_ptr, + int num_batch, int num_head, int qhead_per_khead, int num_sm, int num_splits_static, + cutlass::FastDivmod blockm_divmod, cutlass::FastDivmod blockn_divmod, + int* const tile_count_semaphore, int* const num_m_blocks_ptr, int* const num_n_blocks_ptr, + int* const num_splits_dynamic_ptr) { + + static constexpr int kNumBatchPerWarp = cutlass::NumThreadsPerWarp - 1; + // Assume that there's only one block in the grid + __shared__ int smem[1]; + + if (threadIdx.x == 0) { smem[0] = 0; } + __syncthreads(); + + if (threadIdx.x == 0) { *tile_count_semaphore = 0; } + + int lane = threadIdx.x % cutlass::NumThreadsPerWarp; + + auto get_num_m_blocks = [&](int bidb_start) { + int batch_idx = lane + bidb_start; + int seqlen; + if (seqused_q) { + seqlen = batch_idx < num_batch ? seqused_q[batch_idx] : 0; + } else if (cu_seqlens_q) { + int cur_cu_seqlen = batch_idx <= num_batch ? cu_seqlens_q[batch_idx] : 0; + int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); + seqlen = next_cu_seqlen - cur_cu_seqlen; + } else { + seqlen = seqlen_q_static; + } + seqlen *= qhead_per_khead; + return batch_idx < num_batch && lane < kNumBatchPerWarp + ? blockm_divmod.div(seqlen + blockm_divmod.divisor - 1) : 0; + }; + + auto get_num_n_blocks = [&](int bidb_start) { + int batch_idx = lane + bidb_start; + int leftpad_k = batch_idx < num_batch && leftpad_k_ptr != nullptr ? leftpad_k_ptr[batch_idx] : 0; + int seqlen; + if (seqused_k) { + seqlen = batch_idx < num_batch ? seqused_k[batch_idx] : 0; + } else if (cu_seqlens_k) { + int cur_cu_seqlen = batch_idx <= num_batch ? cu_seqlens_k[batch_idx] : 0; + int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); + seqlen = next_cu_seqlen - cur_cu_seqlen; + } else { + seqlen = seqlen_k_static; + } + int seqlen_new; + if (cu_seqlens_k_new) { + int cur_cu_seqlen_new = batch_idx <= num_batch ? cu_seqlens_k_new[batch_idx] : 0; + int next_cu_seqlen_new = __shfl_down_sync(0xffffffff, cur_cu_seqlen_new, 1); + seqlen_new = next_cu_seqlen_new - cur_cu_seqlen_new; + } else { + seqlen_new = seqlen_k_new_static; + } + // if (threadIdx.x == 0) { printf("seqlen = %d, seqlen_new = %d, leftpad_k = %d\n", seqlen, seqlen_new, leftpad_k); } + seqlen = seqlen - leftpad_k + seqlen_new; + return batch_idx < num_batch && lane < kNumBatchPerWarp + ? blockn_divmod.div(seqlen + blockn_divmod.divisor - 1) : 0; + }; + + int total_blocks = 0; + int warp_idx = threadIdx.x / cutlass::NumThreadsPerWarp; + for (int bidb_start = kNumBatchPerWarp * warp_idx; bidb_start < num_batch; bidb_start += kNumBatchPerWarp) { + int num_m_blocks = get_num_m_blocks(bidb_start); + int num_n_blocks = get_num_n_blocks(bidb_start); + if (bidb_start + lane < num_batch && lane < kNumBatchPerWarp) { + num_m_blocks_ptr[bidb_start + lane] = num_m_blocks; + num_n_blocks_ptr[bidb_start + lane] = num_n_blocks; + // printf("idx = %d, num_m = %d, num_n = %d\n", bidb_start + lane, num_m_blocks, num_n_blocks); + } + total_blocks += num_m_blocks * num_n_blocks; + } + + // Warp sum + #pragma unroll + for (int i = cutlass::NumThreadsPerWarp / 2; i >= 1; i /= 2) { + total_blocks += __shfl_down_sync(0xffffffff, total_blocks, i); + } + if (lane == 0) { atomicAdd(smem, total_blocks); } + __syncthreads(); + total_blocks = smem[0]; + // 20% margin + int blocks_per_sm = static_cast(ceilf(float(total_blocks) * 1.2f * float(num_head) / float(num_sm))); + // blocks_per_sm = std::max(1, blocks_per_sm); // 1 is the minimum number of blocks per SM + for (int bidb_start = kNumBatchPerWarp * warp_idx; bidb_start < num_batch; bidb_start += kNumBatchPerWarp) { + bool is_valid = bidb_start + lane < num_batch && lane < kNumBatchPerWarp; + int num_n_blocks = is_valid ? num_n_blocks_ptr[bidb_start + lane] : 0; + int num_split_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1); + if (is_valid) { + num_splits_dynamic_ptr[bidb_start + lane] = num_split_dynamic; + // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_split_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_split_dynamic); + } + } + +} + +} // flash + +void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, + int blockM, int blockN) { + int qhead_per_khead = !packgqa ? 1 : cutlass::ceil_div(params.h, params.h_k); + flash::prepare_varlen_num_blocks_kernel<<<1 /*grid*/, 256 /*block*/, 0, stream>>>( + params.seqlen_q, params.seqlen_k, params.seqlen_knew, + params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, + params.seqused_q, params.seqused_k, params.leftpad_k, + params.b, !packgqa ? params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits, + cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN), + params.tile_count_semaphore, params.num_m_blocks_ptr, params.num_n_blocks_ptr, + params.num_splits_dynamic_ptr); +} diff --git a/hopper/heuristics.h b/hopper/heuristics.h index 03fd391ff..868d4ad59 100644 --- a/hopper/heuristics.h +++ b/hopper/heuristics.h @@ -22,11 +22,11 @@ inline bool should_pack_gqa(bool varlen_q, int seqlen_q, int qhead_per_khead, in // splits as that would incur more HBM reads/writes. // So we find the best efficiency, then find the smallest number of splits that gets 85% // of the best efficiency. -inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int num_m_blocks, int size_one_kv_head, bool is_causal_or_local, int max_splits) { +inline int num_splits_heuristic(int total_mblocks, int num_SMs, int num_n_blocks, int num_m_blocks, int size_one_kv_head, bool is_causal_or_local, int max_splits) { // If we have enough to almost fill the SMs, then just use 1 split // However, in the case of super long seqlen where each head of KV doesn't even fit into // L2 (we assume conservatively that L2 size is 50MB), we want to split. - if (batch_nheads_mblocks >= 0.8f * num_SMs) { + if (total_mblocks >= 0.8f * num_SMs) { int const size_l2 = 50 * 1024 * 1024; // Only split if there are enough queries to go over the KV at least twice // Don't split if causal @@ -43,7 +43,7 @@ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n std::vector efficiency; efficiency.reserve(max_splits); for (int num_splits = 1; num_splits <= max_splits; num_splits++) { - float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs; + float n_waves = float(total_mblocks * num_splits) / num_SMs; float eff = n_waves / ceil(n_waves); // printf("num_splits = %d, eff = %f\n", num_splits, eff); if (eff > max_efficiency) { max_efficiency = eff; } diff --git a/hopper/setup.py b/hopper/setup.py index 6798de67a..433c3bb3a 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -506,6 +506,7 @@ def nvcc_threads_args(): ) if not DISABLE_SPLIT: sources += ["flash_fwd_combine.cu"] + sources += ["flash_prepare_scheduler.cu"] nvcc_flags = [ "-O3", "-std=c++17", diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 16cfb2384..abd9046ef 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -587,8 +587,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) -# @pytest.mark.parametrize("d", [128]) -@pytest.mark.parametrize("d", [192]) +@pytest.mark.parametrize("d", [64]) +# @pytest.mark.parametrize("d", [192]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -645,9 +645,9 @@ def test_flash_attn_kvcache( nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) assert nheads % nheads_k == 0 dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - dv_vals = [128, d] if d > 128 and d <= 192 else [d] - has_qv_vals = [False] - for dv, has_qv in itertools.product(dv_vals, has_qv_vals): + dv_vals = [128, d] if d > 128 and d <= 192 else ([512, d] if d <= 64 else [d]) + for dv in dv_vals: + has_qv = d == 64 and dv == 512 q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) if has_qv: qv = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index e67abf89a..5272c361a 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -8,6 +8,7 @@ #include "cutlass/arch/barrier.h" #include "named_barrier.hpp" +#include "utils.h" namespace flash { @@ -23,6 +24,8 @@ struct TileSchedulerArguments { int* const tile_count_semaphore = nullptr; int* const cu_seqlens = nullptr; int* const seqused = nullptr; + // int* const num_m_blocks_ptr = nullptr; + int* const num_splits_dynamic_ptr = nullptr; }; /////////////////////////////////////////////////////////////////////////////// @@ -341,7 +344,6 @@ class DynamicPersistentTileScheduler { }; - template class VarlenDynamicPersistentTileScheduler { @@ -365,6 +367,8 @@ class VarlenDynamicPersistentTileScheduler { int* const tile_count_semaphore; int* const cu_seqlens; int* const seqused; + // int* const num_m_blocks_ptr; + int* const num_splits_dynamic_ptr; }; static Params @@ -372,10 +376,15 @@ class VarlenDynamicPersistentTileScheduler { // If Split, for the purpose of scheduling, we pretend that instead there are // (args.num_splits * args.num_head) number of heads. assert(args.tile_count_semaphore != nullptr); - return {args.num_head * (!Split ? 1 : args.num_splits), args.num_batch, + assert(!Split || args.num_splits_dynamic_ptr != nullptr); + assert(num_head < (1 << 16)); // We use the top 16 bits to store num_splits & split_idx + assert(!Split || args.num_splits < (1 << 8)); // We use the top 8 bits to store num_splits + return {args.num_head, args.num_batch, args.qhead_per_khead, args.seqlen, cutlass::FastDivmod(!Split ? 1 : args.num_splits), - args.tile_count_semaphore, args.cu_seqlens, args.seqused}; + args.tile_count_semaphore, args.cu_seqlens, args.seqused, + // args.num_m_blocks_ptr, args.num_splits_dynamic_ptr}; + args.num_splits_dynamic_ptr}; } static dim3 @@ -399,8 +408,18 @@ class VarlenDynamicPersistentTileScheduler { if constexpr (!Split) { return {block, bidh, bidb, 0 /*split_idx*/}; } else { - int split_idx; - int bidh_actual = params.nsplits_divmod.divmod(split_idx, bidh); + // the top 8 bits of bidh store num_splits and the next 8 bits store split_idx + // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift + uint32_t bidh_packed = reinterpret_cast(bidh); + uint32_t bidh_actual_u = bidh_packed & 0x0000FFFF; + int bidh_actual = reinterpret_cast(bidh_actual_u); + // Use the top 16 bits of split_idx to store num_splits and the next 16 bits to store split_idx + uint32_t split_idx_u = ((bidh_packed & 0x00FF0000) >> 16) + ((bidh_packed & 0xFF000000) >> 8); + int split_idx = reinterpret_cast(split_idx_u); + // int bidh_actual = params.nsplits_divmod.divmod(split_idx, bidh); + // if (threadIdx.x == 128) { + // printf("blockIdx.x = %d, bidb = %d, bidh = %d, bidh_actual = %d, split_idx = %d\n", blockIdx.x, bidb, bidh, bidh_actual, split_idx); + // } return {block, bidh_actual, bidb, split_idx}; } } @@ -412,43 +431,53 @@ class VarlenDynamicPersistentTileScheduler { CUTLASS_DEVICE WorkTileInfo tile_idx_to_work_tile(Params const& params, int next_tile_idx, WorkTileInfo const& current_work) const { - auto prefix_sum = [](int val) { - int lane = threadIdx.x % cutlass::NumThreadsPerWarp; - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < cutlass::NumThreadsPerWarp; i <<= 1) { - int32_t partial_sum = __shfl_up_sync(0xffffffff, val, i); - if (lane >= i) { val += partial_sum; } + int lane = threadIdx.x % cutlass::NumThreadsPerWarp; + auto get_num_m_blocks = [&] (int bidb_start) { + int batch_idx = lane + bidb_start; + int seqlen = params.seqlen * (!PackGQA ? 1 : params.qhead_per_khead); + if (seqlen > kBlock) { + if (params.seqused) { + seqlen = batch_idx < params.num_batch ? params.seqused[batch_idx] : 0; + } else if (params.cu_seqlens) { + int cur_cu_seqlen = batch_idx <= params.num_batch ? params.cu_seqlens[batch_idx] : 0; + int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); + seqlen = next_cu_seqlen - cur_cu_seqlen; + } else { + seqlen = params.seqlen; + } + if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } } - return val; + return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 + ? cute::ceil_div(seqlen, kBlock) : 0; + // ? params.num_m_blocks_ptr[batch_idx] : 0; }; - auto get_num_m_blocks = [&](int bidb_start) { - int lane = threadIdx.x % cutlass::NumThreadsPerWarp; - int seqlen; - if (params.seqused) { - seqlen = lane + bidb_start < params.num_batch ? params.seqused[lane + bidb_start] : 0; - } else if (params.cu_seqlens) { - int cur_cu_seqlen = lane + bidb_start <= params.num_batch ? params.cu_seqlens[lane + bidb_start] : 0; - int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); - seqlen = next_cu_seqlen - cur_cu_seqlen; - } else { - seqlen = params.seqlen; - } - if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } - return lane + bidb_start < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 - ? cute::ceil_div(seqlen, kBlock) : 0; + auto get_num_splits = [&] (int bidb_start) { + int batch_idx = lane + bidb_start; + return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 + ? (!Split ? 1 : params.num_splits_dynamic_ptr[batch_idx]) + : 0; }; int num_m_blocks = get_num_m_blocks(current_work.bidb); // Different for each lane + int num_splits = get_num_splits(current_work.bidb); + int num_split_m_blocks = !Split ? num_m_blocks : num_m_blocks * num_splits; // Cumulative number of blocks for the next 31 batches - int num_m_blocks_cumulative = prefix_sum(num_m_blocks); + int num_m_blocks_cumulative = warp_prefix_sum(num_split_m_blocks); // Total number of blocks for the next 31 batches int m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1); - int group_end_tile = current_work.tile_idx - current_work.block - current_work.bidh * __shfl_sync(0xffffffff, num_m_blocks, 0 /*lane*/) + m_blocks_in_group * params.num_head; // Same for all lanes + // Only the lower 16 bits are the actual bidh + int current_bidh = !Split ? current_work.bidh : (current_work.bidh & 0x0000FFFF); + int group_end_tile = current_work.tile_idx - current_work.block - current_bidh * __shfl_sync(0xffffffff, num_split_m_blocks, 0 /*lane*/) + m_blocks_in_group * params.num_head; // Same for all lanes + if constexpr (Split) { + int current_split_idx = (current_work.bidh & 0x00FF0000) >> 16; + group_end_tile -= current_split_idx * __shfl_sync(0xffffffff, num_m_blocks, 0 /*lane*/); + } int bidb = current_work.bidb; // if (blockIdx.x <= 9 && threadIdx.x == 0) { - // printf("Before while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group); + // printf("Before while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, cur tile_idx = %d, cur block = %d, cur bidh = %d, num_split_m_blocks = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, current_work.bidb, num_m_blocks, next_tile_idx, current_work.tile_idx, current_work.block, current_bidh, num_split_m_blocks, group_end_tile, m_blocks_in_group); // } + // if (threadIdx.x == 0 && blockIdx.x == 0) { printf("tile_idx = %d, group_end_tile = %d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d\n", current_work.tile_idx, group_end_tile, num_m_blocks_cumulative, m_blocks_in_group); } while (group_end_tile <= next_tile_idx) { bidb += cutlass::NumThreadsPerWarp - 1; if (bidb >= params.num_batch) { @@ -458,7 +487,9 @@ class VarlenDynamicPersistentTileScheduler { return {next_tile_idx, 0, 0, params.num_batch}; } num_m_blocks = get_num_m_blocks(bidb); - num_m_blocks_cumulative = prefix_sum(num_m_blocks); + num_splits = get_num_splits(bidb); + num_split_m_blocks = !Split ? num_m_blocks : num_m_blocks * num_splits; + num_m_blocks_cumulative = warp_prefix_sum(num_split_m_blocks); m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1); group_end_tile += m_blocks_in_group * params.num_head; // if (blockIdx.x <= 9 && threadIdx.x == 0) { @@ -469,13 +500,26 @@ class VarlenDynamicPersistentTileScheduler { // The next problem to process is the first one that does not have ending tile position // that is greater than or equal to tile index. int batch_idx_in_group = __popc(__ballot_sync(0xffffffff, group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx)); + // if (threadIdx.x == 31 || threadIdx.x == 0) { printf("blockIdx.x = %d, tidx %d, group_start_tile = %d, num_m_blocks_cumulative = %d, num_head = %d, next_tile_idx = %d, ballot = %x, batch_idx_in_group = %d\n", blockIdx.x, threadIdx.x, group_start_tile, num_m_blocks_cumulative, params.num_head, next_tile_idx, tmp, batch_idx_in_group); } bidb += batch_idx_in_group; num_m_blocks = __shfl_sync(0xffffffff, num_m_blocks, batch_idx_in_group); + if constexpr (Split) { num_splits = __shfl_sync(0xffffffff, num_splits, batch_idx_in_group); } int mh_block = next_tile_idx - group_start_tile - (batch_idx_in_group == 0 ? 0 : __shfl_sync(0xffffffff, num_m_blocks_cumulative, batch_idx_in_group - 1)) * params.num_head; int bidh = mh_block / num_m_blocks; int block = mh_block - bidh * num_m_blocks; + if constexpr (Split) { + int bidh_actual = bidh / num_splits; + int split_idx = bidh - bidh_actual * num_splits; + // Use the top 8 bits to store num_splits and the next 8 bits to store split_idx + // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift + uint32_t bidh_packed = reinterpret_cast(bidh_actual) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); + // if (threadIdx.x == 0) { + // printf("blockIdx.x = %d, group_start_tiled = %d, bidb = %d, batch_idx_in_group = %d, mh_block = %d, num_m_blocks = %d, bidh = %d, bidh_actual = %d, split_idx = %d, num_splits = %d, bidh_packed = %d\n", blockIdx.x, group_start_tile, bidb, batch_idx_in_group, mh_block, num_m_blocks, bidh, bidh_actual, split_idx, num_splits, bidh_packed); + // } + bidh = reinterpret_cast(bidh_packed); + } // if (blockIdx.x <= 9 && threadIdx.x == 0) { - // printf("blockIdx.x = %d, threadIdx.x = %d, batch_idx_in_group = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d, mh_block = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, batch_idx_in_group, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group, mh_block, bidh, block); + // printf("Before returning, blockIdx.x = %d, threadIdx.x = %d, group_start_tile = %d, batch_idx_in_group = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d, mh_block = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, group_start_tile, batch_idx_in_group, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group, mh_block, bidh, block); // } return {next_tile_idx, block, bidh, bidb}; } diff --git a/hopper/utils.h b/hopper/utils.h index e14ca1574..f821b19a4 100644 --- a/hopper/utils.h +++ b/hopper/utils.h @@ -625,6 +625,19 @@ CUTLASS_DEVICE auto calculate_dtanh(Tensor &tensor){ //////////////////////////////////////////////////////////////////////////////////////////////////// +template +CUTE_DEVICE T warp_prefix_sum(T val) { + int lane = threadIdx.x % cutlass::NumThreadsPerWarp; + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < cutlass::NumThreadsPerWarp; i <<= 1) { + T partial_sum = __shfl_up_sync(0xffffffff, val, i); + if (lane >= i) { val += partial_sum; } + } + return val; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + template CUTE_DEVICE T warp_uniform(T a) { return __shfl_sync(0xffffffff, a, 0); From cdda5bfdd75c891e81dca228929d1b2a8fb02fab Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 24 Feb 2025 03:38:21 -0500 Subject: [PATCH 046/102] Update to Cutlass 3.8.0 tag --- csrc/cutlass | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cutlass b/csrc/cutlass index e9627ce55..afa177220 160000 --- a/csrc/cutlass +++ b/csrc/cutlass @@ -1 +1 @@ -Subproject commit e9627ce55b42fd2599f58cd4396da9380954def0 +Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008 From 9505c7436eab3d9469c9d3646cfe19f8e3d27c7b Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 24 Feb 2025 11:48:50 -0500 Subject: [PATCH 047/102] Adjust seqlen_q in MLA decode benchmark script --- hopper/benchmark_mla_decode.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hopper/benchmark_mla_decode.py b/hopper/benchmark_mla_decode.py index ed6e903fd..e8a773e71 100644 --- a/hopper/benchmark_mla_decode.py +++ b/hopper/benchmark_mla_decode.py @@ -19,7 +19,7 @@ headdim = 64 headdim_v = 512 has_qv = True -seqlen_q = 1 +seqlen_q = 4 # page_size = None page_size = 1 @@ -43,7 +43,7 @@ "(b s) -> b s", s=seqlen // page_size) else: page_table = None -qv = torch.randn(batch_size, 1, nheads, headdim_v, dtype=dtype, device=device) if has_qv else None +qv = torch.randn(batch_size, seqlen_q, nheads, headdim_v, dtype=dtype, device=device) if has_qv else None # Time in ms fn = lambda: flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, num_splits=num_splits, qv=qv, page_table=page_table, causal=True) From 3b5047d2ce742848f45d44b143d511f211eba2d2 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 24 Feb 2025 22:54:05 -0500 Subject: [PATCH 048/102] Fix loop in prepare_scheduler.cu (h/t Jay Shah) Only affects the case where batch size > 256 --- hopper/flash_prepare_scheduler.cu | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/hopper/flash_prepare_scheduler.cu b/hopper/flash_prepare_scheduler.cu index e108347ec..0f4c1963f 100644 --- a/hopper/flash_prepare_scheduler.cu +++ b/hopper/flash_prepare_scheduler.cu @@ -76,7 +76,8 @@ __global__ void prepare_varlen_num_blocks_kernel( int total_blocks = 0; int warp_idx = threadIdx.x / cutlass::NumThreadsPerWarp; - for (int bidb_start = kNumBatchPerWarp * warp_idx; bidb_start < num_batch; bidb_start += kNumBatchPerWarp) { + int num_warps = blockDim.x / cutlass::NumThreadsPerWarp; + for (int bidb_start = kNumBatchPerWarp * warp_idx; bidb_start < num_batch; bidb_start += kNumBatchPerWarp * num_warps) { int num_m_blocks = get_num_m_blocks(bidb_start); int num_n_blocks = get_num_n_blocks(bidb_start); if (bidb_start + lane < num_batch && lane < kNumBatchPerWarp) { @@ -98,7 +99,7 @@ __global__ void prepare_varlen_num_blocks_kernel( // 20% margin int blocks_per_sm = static_cast(ceilf(float(total_blocks) * 1.2f * float(num_head) / float(num_sm))); // blocks_per_sm = std::max(1, blocks_per_sm); // 1 is the minimum number of blocks per SM - for (int bidb_start = kNumBatchPerWarp * warp_idx; bidb_start < num_batch; bidb_start += kNumBatchPerWarp) { + for (int bidb_start = kNumBatchPerWarp * warp_idx; bidb_start < num_batch; bidb_start += kNumBatchPerWarp * num_warps) { bool is_valid = bidb_start + lane < num_batch && lane < kNumBatchPerWarp; int num_n_blocks = is_valid ? num_n_blocks_ptr[bidb_start + lane] : 0; int num_split_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1); From dec83a10c4e91938ffe4344da22324b9e53f979f Mon Sep 17 00:00:00 2001 From: "Jiang, Zhiwei" Date: Fri, 28 Feb 2025 23:54:59 +0800 Subject: [PATCH 049/102] fix: add "typename" prior to dependent type name (#1517) This project uses c++17 which still has this requirement. Signed-off-by: Jiang, Zhiwei --- csrc/flash_attn/src/flash_fwd_kernel.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 1ba07da15..d492c87b5 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -362,7 +362,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); // if (cute::thread0()) { print(tOrP); } FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); // if (cute::thread0()) { print(scores); } @@ -424,7 +424,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } @@ -922,7 +922,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); @@ -987,7 +987,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } From 08f4c802c450708a86a92b226cba5663be81aead Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 28 Feb 2025 14:48:26 -0500 Subject: [PATCH 050/102] Add FLOPS to MLA decode benchmark --- hopper/benchmark_mla_decode.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/hopper/benchmark_mla_decode.py b/hopper/benchmark_mla_decode.py index e8a773e71..58224a0e9 100644 --- a/hopper/benchmark_mla_decode.py +++ b/hopper/benchmark_mla_decode.py @@ -14,12 +14,12 @@ device = "cuda" dtype = torch.bfloat16 seqlen = 64 * 1024 -nheads = 16 +nheads = 128 nheads_kv = 1 headdim = 64 headdim_v = 512 has_qv = True -seqlen_q = 4 +seqlen_q = 1 # page_size = None page_size = 1 @@ -33,9 +33,9 @@ # cache_seqlens = torch.tensor([4500, 45000, 1800, 1800], dtype=torch.int32, device=device) num_splits = 0 -q = torch.randn(batch_size, seqlen_q, nheads, headdim, dtype=dtype, device=device) * 3 +q = torch.randn(batch_size, seqlen_q, nheads, headdim, dtype=dtype, device=device) v_cache = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, dtype=dtype, device=device) -k_cache = torch.randn(batch_size, seqlen, nheads_kv, headdim, dtype=dtype, device=device) * 3 +k_cache = torch.randn(batch_size, seqlen, nheads_kv, headdim, dtype=dtype, device=device) if page_size is not None: assert seqlen % page_size == 0 k_cache, v_cache = [rearrange(x, "b (n p) h d -> (b n) p h d", p=page_size) for x in [k_cache, v_cache]] @@ -52,9 +52,12 @@ t1 = do_bench_cudagraph(fn, rep=10) mem_io = cache_seqlens.sum().item() * nheads_kv * (headdim + headdim_v) * 2 -ideal_h100_time = mem_io / 3.35e12 * 1e6 -print(f"Time: {t0 * 1e3:.0f} us, {mem_io * 1e-9 / (t0 * 1e-3):.0f} GB/s") -print(f"Time w CUDA Graph: {t1 * 1e3:.0f} us, {mem_io * 1e-9 / (t1 * 1e-3):.0f} GB/s") +flops = seqlen_q * cache_seqlens.float().sum().item() * nheads * (headdim + headdim_v * 2) * 2 +ideal_h100_time_mem = mem_io / 3.35e12 * 1e6 +ideal_h100_time_flop = flops / 989e12 * 1e6 +ideal_h100_time = max(ideal_h100_time_mem, ideal_h100_time_flop) +print(f"Time: {t0 * 1e3:.0f} us, {mem_io * 1e-9 / (t0 * 1e-3):.0f} GB/s, {flops * 1e-12 / (t0 * 1e-3):.0f} TFLOPS/s") +print(f"Time w CUDA Graph: {t1 * 1e3:.0f} us, {mem_io * 1e-9 / (t1 * 1e-3):.0f} GB/s, {flops * 1e-12 / (t1 * 1e-3):.0f} TFLOPS/s") print(f"Ideal time: {ideal_h100_time:.0f} us") if pytorch_profiler is not None: From 085ce5864a6fee05e1b8cba26143943df91ebb63 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 28 Feb 2025 17:05:24 -0500 Subject: [PATCH 051/102] Change margin in prepare_scheduler.cu from 20% to 10% --- hopper/flash_prepare_scheduler.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hopper/flash_prepare_scheduler.cu b/hopper/flash_prepare_scheduler.cu index 0f4c1963f..c8fe8fc5e 100644 --- a/hopper/flash_prepare_scheduler.cu +++ b/hopper/flash_prepare_scheduler.cu @@ -96,8 +96,8 @@ __global__ void prepare_varlen_num_blocks_kernel( if (lane == 0) { atomicAdd(smem, total_blocks); } __syncthreads(); total_blocks = smem[0]; - // 20% margin - int blocks_per_sm = static_cast(ceilf(float(total_blocks) * 1.2f * float(num_head) / float(num_sm))); + // 10% margin + int blocks_per_sm = static_cast(ceilf(float(total_blocks) * 1.1f * float(num_head) / float(num_sm))); // blocks_per_sm = std::max(1, blocks_per_sm); // 1 is the minimum number of blocks per SM for (int bidb_start = kNumBatchPerWarp * warp_idx; bidb_start < num_batch; bidb_start += kNumBatchPerWarp * num_warps) { bool is_valid = bidb_start + lane < num_batch && lane < kNumBatchPerWarp; From 39e71975642daab365a5a37c959182c93ed5fc8a Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 28 Feb 2025 22:42:16 -0500 Subject: [PATCH 052/102] Fix cuda 12.1 build (#1511) Signed-off-by: Lucas Wilkinson --- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 4f2e7a35a..3589534c1 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -994,7 +994,7 @@ struct CollectiveMainloopFwdSm90 { if constexpr (LargeHeadDimV) { return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_scale.data()), SmemLayoutScale{}); } else { // won't be used, just a placeholder - return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutScale{}); + return make_tensor(make_smem_ptr(static_cast(nullptr)), SmemLayoutScale{}); } }(); Tensor sQv = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_qv.data()), SmemLayoutQv{}); From 20b84d636324f00e53923d555a559e965683d4ba Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 1 Mar 2025 20:13:49 -0500 Subject: [PATCH 053/102] Don't use IntraWGOverlap for hdim 64,512 --- hopper/benchmark_attn.py | 7 ++- hopper/flash_prepare_scheduler.cu | 6 +-- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 55 ++++++++++++++++++------ hopper/test_flash_attn.py | 5 ++- hopper/tile_scheduler.hpp | 5 +++ hopper/tile_size.h | 2 +- 6 files changed, 60 insertions(+), 20 deletions(-) diff --git a/hopper/benchmark_attn.py b/hopper/benchmark_attn.py index 36f0bf6d0..4272dab26 100644 --- a/hopper/benchmark_attn.py +++ b/hopper/benchmark_attn.py @@ -1,6 +1,7 @@ from collections import namedtuple from functools import partial import math +import os from typing import NamedTuple import torch import torch.nn as nn @@ -34,6 +35,8 @@ triton_attention = None triton_attention = None +DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE" + def time_fwd(func, *args, repeats=30, verbose=True, desc="", **kwargs): # # Warmup @@ -358,7 +361,7 @@ def run(*args, **kwargs): m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits) time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean - if dtype != torch.float8_e4m3fn and headdim == headdim_v: + if dtype != torch.float8_e4m3fn and headdim == headdim_v and not DISABLE_BACKWARD: time.sleep(1) if not varlen: _, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, deterministic=deterministic, @@ -387,7 +390,7 @@ def run(*args, **kwargs): print(f'CuDNN fwd: {m2.mean * 1e3:.3f}ms, {(nFLOPS / m2.mean * 1e-12):.1f} TFLOPS') print(f'CuDNN bwd: {m2b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m2b.mean * 1e-12):.1f} TFLOPS') print(f'Fav3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS') - if dtype != torch.float8_e4m3fn and headdim == headdim_v: + if dtype != torch.float8_e4m3fn and headdim == headdim_v and not DISABLE_BACKWARD: print(f'Fav3 bwd: {m1b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b.mean * 1e-12):.1f} TFLOPS') # benchmark_forward(torch.square, k) # print(f'cuBLAS: {m5.mean * 1e3:.3f}ms, {(nFLOPS_matmul / m5.mean * 1e-12):.1f} TFLOPS') diff --git a/hopper/flash_prepare_scheduler.cu b/hopper/flash_prepare_scheduler.cu index c8fe8fc5e..9befcf438 100644 --- a/hopper/flash_prepare_scheduler.cu +++ b/hopper/flash_prepare_scheduler.cu @@ -102,10 +102,10 @@ __global__ void prepare_varlen_num_blocks_kernel( for (int bidb_start = kNumBatchPerWarp * warp_idx; bidb_start < num_batch; bidb_start += kNumBatchPerWarp * num_warps) { bool is_valid = bidb_start + lane < num_batch && lane < kNumBatchPerWarp; int num_n_blocks = is_valid ? num_n_blocks_ptr[bidb_start + lane] : 0; - int num_split_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1); + int num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1); if (is_valid) { - num_splits_dynamic_ptr[bidb_start + lane] = num_split_dynamic; - // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_split_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_split_dynamic); + num_splits_dynamic_ptr[bidb_start + lane] = num_splits_dynamic; + // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic); } } diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 3589534c1..8a9aed08c 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -75,13 +75,11 @@ struct CollectiveMainloopFwdSm90 { static_assert(!LargeHeadDimV || kHeadDimV % 256 == 0); static_assert(!LargeHeadDimV || kBlockM <= 64, "kBlockM must be 64 or less for large Headdim_V"); static_assert(!LargeHeadDimV || !MmaPV_is_RS, "MmaPV must be SS for large Headdim_V"); - static_assert(!(HasQv && !IntraWGOverlap), "HasQv requires IntraWGOverlap"); // Register bandwidth is actually a bottleneck so we don't want Q to be in registers. // Leaving this option here for reference. static constexpr bool MmaQK_is_RS = false; // We can have MmaPV with P in smem in rmem to reduce register pressure at the cost of more smem. - static_assert(!(!MmaPV_is_RS && !IntraWGOverlap), "MmaPV must be RS if IntraWGOverlap is disabled"); static_assert(!(!MmaPV_is_RS && Is_FP8), "MmaPV must be RS if FP8"); static_assert(!(!MmaPV_is_RS && Transpose_V), "MmaPV must be RS if Transpose_V"); @@ -1266,27 +1264,51 @@ struct CollectiveMainloopFwdSm90 { auto fwd_step = [&](int const n_block, auto mask_fn, auto is_first_iter_type, auto check_inf_type) { static constexpr bool Is_first_iter = decltype(is_first_iter_type)::value; static constexpr bool Check_inf = decltype(check_inf_type)::value; + auto smem_pipe_read_prev = smem_pipe_read; + if constexpr (!Is_first_iter) { ++smem_pipe_read; } Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); consumer_wait(pipeline_k, smem_pipe_read); flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); warp_scheduler_barrier_arrive(); - warpgroup_wait<0>(); - pipeline_k.consumer_release(smem_pipe_read); // release K + if constexpr (!HasQv) { + warpgroup_wait<0>(); + pipeline_k.consumer_release(smem_pipe_read); // release K + } else { + if constexpr (Is_first_iter) { + shared_storage.pipelines.barrier_Qv.wait(work_idx % 2); + } + consumer_wait(pipeline_v, smem_pipe_read); + flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS); + pipeline_k.consumer_release(smem_pipe_read); // release K + warpgroup_wait<0>(); + } scoremod_premask_fn(tSrS); mask_fn(tSrS, n_block); Tensor scores_scale = softmax.template max_get_scale(tSrS); + if constexpr (LargeHeadDimV && !Is_first_iter) { store_scales(scores_scale, smem_pipe_read_prev.index()); } softmax.template online_softmax(tSrS); if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); Tensor tOrP = make_tensor_like(tOrP_acc); convert_type_out(tOrP_acc, tOrP); if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + } + if constexpr (!MmaPV_is_RS) { cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); } if constexpr (!Is_first_iter) { softmax.rescale_o(tOrO, scores_scale); } - consumer_wait(pipeline_v, smem_pipe_read); + if constexpr (!MmaPV_is_RS) { + cutlass::arch::fence_view_async_shared(); + __syncwarp(); + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + } + } + if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } warp_scheduler_barrier_sync(); - flash::gemm(tiled_mma_pv, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); + warpgroup_wait<0>(); pipeline_v.consumer_release(smem_pipe_read); // release V - ++smem_pipe_read; }; auto first_iter_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; @@ -1331,8 +1353,14 @@ struct CollectiveMainloopFwdSm90 { cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; Tensor scores_scale = softmax.finalize(v_descale); + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + store_scales(scores_scale, smem_pipe_read.index()); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + } softmax.rescale_o(tOrO, scores_scale); if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); } + ++smem_pipe_read; } ++work_idx; return true; @@ -1391,15 +1419,16 @@ struct CollectiveMainloopFwdSm90 { } }; - clear(tOrO); + // clear(tOrO); // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero; typename Softmax::TensorT scores_scale; int n_block = n_block_max - 1; - pipeline_v.consumer_wait(smem_pipe_read); + // If HasQv, then by the time P is ready, V must be ready as well + if constexpr (!HasQv) { pipeline_v.consumer_wait(smem_pipe_read); } cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); - flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); pipeline_v.consumer_release(smem_pipe_read); // release V --n_block; @@ -1409,8 +1438,10 @@ struct CollectiveMainloopFwdSm90 { load_scales(scores_scale, smem_pipe_read.index()); softmax.rescale_o(tOrO, scores_scale); ++smem_pipe_read; - auto barrier_token = pipeline_v.consumer_try_wait(smem_pipe_read); - pipeline_v.consumer_wait(smem_pipe_read, barrier_token); + if constexpr (!HasQv) { + auto barrier_token = pipeline_v.consumer_try_wait(smem_pipe_read); + pipeline_v.consumer_wait(smem_pipe_read, barrier_token); + } flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); pipeline_v.consumer_release(smem_pipe_read); // release V diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index abd9046ef..dd9a1d0d3 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -118,7 +118,8 @@ def test_flash_attn_output( # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - for dv in [128, d] if d > 128 and d <= 192 else [d]: + dv_vals = [128, d] if d > 128 and d <= 192 else ([512, d] if d <= 64 else [d]) + for dv in dv_vals: q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) if softcap > 0.0: # Ensure the values of qk are at least within softcap range. @@ -582,7 +583,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @pytest.mark.parametrize("has_batch_idx", [False, True]) # @pytest.mark.parametrize("has_batch_idx", [False]) @pytest.mark.parametrize("varlen_q", [False, True]) -# @pytest.mark.parametrize("varlen_q", [True]) +# @pytest.mark.parametrize("varlen_q", [False]) # @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index 5272c361a..b39c7aeb2 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -363,6 +363,7 @@ class VarlenDynamicPersistentTileScheduler { int num_head, num_batch; int const qhead_per_khead; int const seqlen; + cutlass::FastDivmod head_divmod; cutlass::FastDivmod nsplits_divmod; int* const tile_count_semaphore; int* const cu_seqlens; @@ -381,6 +382,7 @@ class VarlenDynamicPersistentTileScheduler { assert(!Split || args.num_splits < (1 << 8)); // We use the top 8 bits to store num_splits return {args.num_head, args.num_batch, args.qhead_per_khead, args.seqlen, + cutlass::FastDivmod(args.num_head), cutlass::FastDivmod(!Split ? 1 : args.num_splits), args.tile_count_semaphore, args.cu_seqlens, args.seqused, // args.num_m_blocks_ptr, args.num_splits_dynamic_ptr}; @@ -510,6 +512,9 @@ class VarlenDynamicPersistentTileScheduler { if constexpr (Split) { int bidh_actual = bidh / num_splits; int split_idx = bidh - bidh_actual * num_splits; + // TODO: idk why this gives wrong answer nondeterministically + // int bidh_actual, split_idx; + // split_idx = params.head_divmod.divmod(bidh_actual, bidh); // Use the top 8 bits to store num_splits and the next 8 bits to store split_idx // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift uint32_t bidh_packed = reinterpret_cast(bidh_actual) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); diff --git a/hopper/tile_size.h b/hopper/tile_size.h index 5d0bd6e26..12a4839eb 100644 --- a/hopper/tile_size.h +++ b/hopper/tile_size.h @@ -17,7 +17,7 @@ constexpr std::tuple tile_size_fwd_sm90( // With this workaround in Cutlass 3.8, tile size 192 x 128 got slower for non-causal, idk why // https://github.com/NVIDIA/cutlass/blob/833f6990e031b48b4cd2fcf55e0849c51ef6bac2/include/cute/container/tuple.hpp#L131 // Switch to tile size 192 x 192 for now - return {same_hdim ? 192 : 64, same_hdim ? 192 : 64, false, true}; + return {same_hdim ? 192 : 64, same_hdim ? 192 : 64, false, same_hdim}; // Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen // return {192, is_causal || is_local ? 192 : 176, true, false}; } else if (headdim <= 96) { From 5458c78e6da05138d76a4f67b5d339ede1b43e9e Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 1 Mar 2025 21:03:47 -0500 Subject: [PATCH 054/102] Remove sink token It wasn't working anyway --- hopper/benchmark_attn.py | 7 ++-- hopper/flash.h | 1 - hopper/flash_api.cpp | 4 -- hopper/flash_attn_interface.py | 28 +------------- hopper/flash_bwd_launch_template.h | 2 +- hopper/flash_fwd_launch_template.h | 2 +- hopper/mainloop_bwd_sm80.hpp | 10 ++--- hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp | 14 +++---- hopper/mainloop_fwd_sm80.hpp | 14 ++----- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 48 ++---------------------- hopper/test_flash_attn.py | 6 --- 11 files changed, 26 insertions(+), 110 deletions(-) diff --git a/hopper/benchmark_attn.py b/hopper/benchmark_attn.py index 4272dab26..fbca7829a 100644 --- a/hopper/benchmark_attn.py +++ b/hopper/benchmark_attn.py @@ -267,7 +267,6 @@ def run(*args, **kwargs): num_splits = 0 window_size = (-1, -1) # window_size = (seqlen // 2 - 1, 0) - sink_token_length = 0 pack_gqa = None # seqlen_q = 64 seqlen_q = seqlen @@ -354,8 +353,8 @@ def run(*args, **kwargs): time.sleep(1) if not varlen: - # m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, cache_leftpad = leftpad_k, page_table=page_table, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') - m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') + # m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, cache_leftpad = leftpad_k, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') + m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') # pytorch_profiler(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa) else: m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') @@ -364,7 +363,7 @@ def run(*args, **kwargs): if dtype != torch.float8_e4m3fn and headdim == headdim_v and not DISABLE_BACKWARD: time.sleep(1) if not varlen: - _, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, deterministic=deterministic, + _, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic, repeats=repeats, verbose=False, desc='Fav3') else: _, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic, diff --git a/hopper/flash.h b/hopper/flash.h index d9f007dfb..c192830b7 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -133,7 +133,6 @@ struct Flash_fwd_params : public Qkv_params { // Local window size int window_size_left, window_size_right; - int sink_token_length; // Pointer to the RNG seed (idx 0) and offset (idx 1). uint64_t * rng_state; diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 805513e14..624372f8b 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -515,7 +515,6 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq bool is_causal, int window_size_left, int window_size_right, - int sink_token_length, float const softcap, bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 int num_splits, @@ -712,7 +711,6 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq sm_margin); params.total_q = total_q; params.total_k = total_k; - params.sink_token_length = sink_token_length; params.b_k = batch_size_k; params.dv = head_size_v; params.dv_rounded = head_size_v_rounded; @@ -1041,7 +1039,6 @@ std::vector mha_bwd( bool is_causal, int window_size_left, int window_size_right, - int const sink_token_length, float const softcap, bool const deterministic, int const sm_margin) { @@ -1275,7 +1272,6 @@ std::vector mha_bwd( params.total_q = total_q; params.total_k = total_k; params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr(); - params.sink_token_length = sink_token_length; // auto tile_count_semaphore = (params.is_causal || params.is_local) ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32)); // params.tile_count_semaphore = tile_count_semaphore.data_ptr(); diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 78cfe1cb9..469266e52 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -42,13 +42,11 @@ def _flash_attn_forward( softmax_scale, causal, window_size=(-1, -1), - sink_token_length=0, softcap=0.0, rotary_interleaved=True, num_splits=1, pack_gqa=None, sm_margin=0): - assert sink_token_length == 0, "sink_token_length not supported yet" q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)] v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [ @@ -86,7 +84,6 @@ def _flash_attn_forward( causal, window_size[0], window_size[1], - sink_token_length, softcap, rotary_interleaved, num_splits, @@ -115,12 +112,10 @@ def _flash_attn_backward( softmax_scale, causal, window_size=(-1, -1), - sink_token_length=0, softcap=0.0, deterministic=False, sm_margin=0, ): - assert sink_token_length == 0, "sink_token_length not supported yet" # dq, dk, dv are allocated by us so they should already be contiguous dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd( @@ -143,7 +138,6 @@ def _flash_attn_backward( causal, window_size[0], window_size[1], - sink_token_length, softcap, deterministic, sm_margin, @@ -160,7 +154,6 @@ def forward( causal, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), - sink_token_length=0, softcap=0.0, deterministic=False, num_heads_q=None, @@ -183,14 +176,13 @@ def forward( softmax_scale, causal=causal, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, - window_size=window_size, sink_token_length=sink_token_length, + window_size=window_size, softcap=softcap, ) ctx.save_for_backward(q, k, v, out_padded, softmax_lse) ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size - ctx.sink_token_length = sink_token_length ctx.softcap = softcap ctx.deterministic = deterministic ctx.ndim = qkv.dim() @@ -223,7 +215,6 @@ def backward(ctx, dout, *args): ctx.softmax_scale, ctx.causal, ctx.window_size, - ctx.sink_token_length, ctx.softcap, ctx.deterministic, ) @@ -244,7 +235,6 @@ def forward( qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), - sink_token_length=0, softcap=0.0, num_splits=1, pack_gqa=None, @@ -270,7 +260,6 @@ def forward( softmax_scale, causal=causal, window_size=window_size, - sink_token_length=sink_token_length, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, @@ -281,7 +270,6 @@ def forward( ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size - ctx.sink_token_length = sink_token_length ctx.softcap = softcap ctx.deterministic = deterministic ctx.sm_margin = sm_margin @@ -307,7 +295,6 @@ def backward(ctx, dout, *args): ctx.softmax_scale, ctx.causal, ctx.window_size, - ctx.sink_token_length, ctx.softcap, ctx.deterministic, ctx.sm_margin, @@ -337,7 +324,6 @@ def forward( qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), - sink_token_length=0, softcap=0.0, num_splits=1, pack_gqa=None, @@ -367,7 +353,6 @@ def forward( softmax_scale, causal=causal, window_size=window_size, - sink_token_length=sink_token_length, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, @@ -380,7 +365,6 @@ def forward( ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size - ctx.sink_token_length = sink_token_length ctx.softcap = softcap ctx.deterministic = deterministic ctx.sm_margin = sm_margin @@ -409,7 +393,6 @@ def backward(ctx, dout, *args): ctx.softmax_scale, ctx.causal, ctx.window_size, - ctx.sink_token_length, ctx.softcap, ctx.deterministic, ctx.sm_margin, @@ -426,7 +409,6 @@ def flash_attn_qkvpacked_func( causal=False, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), - sink_token_length=0, softcap=0.0, deterministic=False, num_heads_q=None, @@ -471,7 +453,6 @@ def flash_attn_qkvpacked_func( causal, q_descale, k_descale, v_descale, window_size, - sink_token_length, softcap, deterministic, num_heads_q, @@ -487,7 +468,6 @@ def flash_attn_func( qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), - sink_token_length=0, softcap=0.0, num_splits=1, pack_gqa=None, @@ -548,7 +528,6 @@ def flash_attn_func( qv, q_descale, k_descale, v_descale, window_size, - sink_token_length, softcap, num_splits, pack_gqa, @@ -572,7 +551,6 @@ def flash_attn_varlen_func( qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), - sink_token_length=0, softcap=0.0, num_splits=1, pack_gqa=None, @@ -594,7 +572,6 @@ def flash_attn_varlen_func( qv, q_descale, k_descale, v_descale, window_size, - sink_token_length, softcap, num_splits, pack_gqa, @@ -629,7 +606,6 @@ def flash_attn_with_kvcache( softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window - sink_token_length=0, softcap=0.0, # 0.0 means deactivated rotary_interleaved=True, num_splits=0, # Can be tuned for speed @@ -722,7 +698,6 @@ def flash_attn_with_kvcache( logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). """ - assert sink_token_length == 0 assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" if softmax_scale is None: @@ -756,7 +731,6 @@ def flash_attn_with_kvcache( softmax_scale, causal=causal, window_size=window_size, - sink_token_length=sink_token_length, softcap=softcap, rotary_interleaved=rotary_interleaved, num_splits=num_splits, diff --git a/hopper/flash_bwd_launch_template.h b/hopper/flash_bwd_launch_template.h index 635228eeb..65d010b46 100644 --- a/hopper/flash_bwd_launch_template.h +++ b/hopper/flash_bwd_launch_template.h @@ -120,7 +120,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { static_cast(params.dsoftmax_sum), {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_dPsum params.scale_softmax, - params.window_size_left, params.window_size_right, params.sink_token_length, + params.window_size_left, params.window_size_right, params.softcap, params.b, params.dq_semaphore, diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 6b80af44c..420538178 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -119,7 +119,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { {params.q_descale_batch_stride, params.q_descale_head_stride}, {params.k_descale_batch_stride, params.k_descale_head_stride}, {params.v_descale_batch_stride, params.v_descale_head_stride}, - params.window_size_left, params.window_size_right, params.sink_token_length, + params.window_size_left, params.window_size_right, params.softcap, params.num_splits, params.kv_batch_idx, diff --git a/hopper/mainloop_bwd_sm80.hpp b/hopper/mainloop_bwd_sm80.hpp index eb0503c93..0a79670f4 100644 --- a/hopper/mainloop_bwd_sm80.hpp +++ b/hopper/mainloop_bwd_sm80.hpp @@ -296,7 +296,7 @@ struct CollectiveMainloopBwdSm80 { float const* const ptr_dPsum; StrideLSE const stride_dPsum; float const softmax_scale; - int const window_size_left, window_size_right, sink_token_length; + int const window_size_left, window_size_right; float const softcap_val; int const num_batch; int* const dq_semaphore; @@ -328,7 +328,7 @@ struct CollectiveMainloopBwdSm80 { float const* const ptr_dPsum; StrideLSE const stride_dPsum; float const softmax_scale, softmax_scale_log2; - int const window_size_left, window_size_right, sink_token_length; + int const window_size_left, window_size_right; float const softcap_val; int const num_batch; int *const dq_semaphore; @@ -359,7 +359,7 @@ struct CollectiveMainloopBwdSm80 { args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum, args.softmax_scale, !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), - args.window_size_left, args.window_size_right, args.sink_token_length, + args.window_size_left, args.window_size_right, !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, args.num_batch, args.dq_semaphore, args.cu_seqlens_q, args.cu_seqlens_k, args.seqused_q, args.seqused_k}; @@ -385,7 +385,7 @@ struct CollectiveMainloopBwdSm80 { }; auto m_block_min_max = BlockMN_t::get_m_block_min_max( seqlen_info, n_block, bidb, - params.window_size_left, params.window_size_right, params.sink_token_length); + params.window_size_left, params.window_size_right, 0 /*sink_token_length*/); int const m_block_min = get<0>(m_block_min_max); int const m_block_max = get<1>(m_block_min_max); // It's possible to have m_block_max <= m_block_min. Exit early @@ -532,7 +532,7 @@ struct CollectiveMainloopBwdSm80 { int const seqlen_k = seqlen_info.seqlen_k; flash::Mask mask( - thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, params.sink_token_length, + thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, params.qhead_per_khead_divmod ); diff --git a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp index e3b296068..71cfb0204 100644 --- a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp @@ -310,7 +310,7 @@ struct CollectiveMainloopBwdSm90 { float const* const ptr_dPsum; StrideLSE const stride_dPsum; float const softmax_scale; - int const window_size_left, window_size_right, sink_token_length; + int const window_size_left, window_size_right; float const softcap_val; int const num_batch; int* const dq_semaphore; @@ -337,7 +337,7 @@ struct CollectiveMainloopBwdSm90 { float const* const ptr_dPsum; StrideLSE const stride_dPsum; float const softmax_scale, softmax_scale_log2; - int const window_size_left, window_size_right, sink_token_length; + int const window_size_left, window_size_right; float const softcap_val; int const num_batch; int* const dq_semaphore; @@ -394,7 +394,7 @@ struct CollectiveMainloopBwdSm90 { args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum, args.softmax_scale, !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), - args.window_size_left, args.window_size_right, args.sink_token_length, + args.window_size_left, args.window_size_right, !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, args.num_batch, args.dq_semaphore, args.cu_seqlens_q, args.cu_seqlens_k, args.seqused_q, args.seqused_k}; @@ -428,7 +428,7 @@ struct CollectiveMainloopBwdSm90 { }; auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max( seqlen_info, n_block, bidb, - params.window_size_left, params.window_size_right, params.sink_token_length); + params.window_size_left, params.window_size_right, 0 /*sink_token_length*/); // It's possible to have m_block_max <= m_block_min. Loading Q, K can cause illegal memory access. if constexpr (Is_causal || Is_local || Varlen) { if (m_block_max <= m_block_min) { @@ -596,7 +596,7 @@ struct CollectiveMainloopBwdSm90 { }; auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max( seqlen_info, n_block, bidb, params.window_size_left, - params.window_size_right, params.sink_token_length); + params.window_size_right, 0 /*sink_token_length*/); // It's possible to have m_block_max <= m_block_min. Exit early if constexpr (Is_causal || Is_local || Varlen) { if (m_block_max <= m_block_min) { return; } @@ -686,7 +686,7 @@ struct CollectiveMainloopBwdSm90 { }; auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max( seqlen_info, n_block, bidb, params.window_size_left, - params.window_size_right, params.sink_token_length); + params.window_size_right, 0 /*sink_token_length*/); // It's possible to have m_block_max <= m_block_min. Exit early if constexpr (Is_causal || Is_local || Varlen) { if (m_block_max <= m_block_min) { return false; } @@ -792,7 +792,7 @@ struct CollectiveMainloopBwdSm90 { // if (blockIdx.x == 0 && threadIdx.x == 128) { print(mdQaccum); printf("\n"); print(gdQaccum_); printf("\n"); print(gdQaccum); printf("\n"); print(tdQgdQaccum); printf("\n"); } flash::Mask mask( - thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, params.sink_token_length, + thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, params.qhead_per_khead_divmod ); diff --git a/hopper/mainloop_fwd_sm80.hpp b/hopper/mainloop_fwd_sm80.hpp index 909654d34..84c0fd0e5 100644 --- a/hopper/mainloop_fwd_sm80.hpp +++ b/hopper/mainloop_fwd_sm80.hpp @@ -202,7 +202,7 @@ struct CollectiveMainloopFwdSm80 { float const softmax_scale; float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; - int const window_size_left = -1, window_size_right = -1, sink_token_length = 0; + int const window_size_left = -1, window_size_right = -1; float const softcap_val; int const num_splits; int const* const kv_batch_idx = nullptr; @@ -247,7 +247,7 @@ struct CollectiveMainloopFwdSm80 { float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; float const softcap_val; - int const window_size_left, window_size_right, sink_token_length; + int const window_size_left, window_size_right; int const num_splits; int const* const kv_batch_idx = nullptr; int const* const cu_seqlens_q = nullptr; @@ -291,7 +291,7 @@ struct CollectiveMainloopFwdSm80 { args.ptr_q_descale, args.ptr_k_descale, args.ptr_v_descale, args.stride_q_descale, args.stride_k_descale, args.stride_v_descale, !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, - args.window_size_left, args.window_size_right, args.sink_token_length, + args.window_size_left, args.window_size_right, !Split ? 1 : args.num_splits, args.kv_batch_idx, args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, @@ -541,7 +541,7 @@ struct CollectiveMainloopFwdSm80 { if constexpr (!Share_QV_Smem) { preprocess_Q(); } flash::Mask mask( - thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, params.sink_token_length, + thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, params.qhead_per_khead_divmod ); @@ -640,12 +640,6 @@ struct CollectiveMainloopFwdSm80 { for (; n_block >= n_block_min; --n_block) { fwd_step(n_block, local_mask_fn, cute::false_type{} /*is_first_iter*/, cute::bool_constant{} /*check_inf*/); } - // Disable sink token code for now - // int n_block_sink_max = cute::ceil_div(params.sink_token_length, kBlockN); - // #pragma unroll 1 - // for (n_block = std::min(n_block, n_block_sink_max - 1); n_block >= 0; --n_block) { - // fwd_step(n_block, local_mask_fn, cute::false_type{} /*is_first_iter*/, cute::bool_constant{} /*check_inf*/); - // } } float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; Tensor scores_scale = softmax.finalize(v_descale); diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 8a9aed08c..823826d93 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -378,7 +378,7 @@ struct CollectiveMainloopFwdSm90 { float const softmax_scale; float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; - int const window_size_left = -1, window_size_right = -1, sink_token_length = 0; + int const window_size_left = -1, window_size_right = -1; float const softcap_val; int const num_splits; int const* const kv_batch_idx = nullptr; @@ -433,7 +433,7 @@ struct CollectiveMainloopFwdSm90 { float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; float const softcap_val; - int const window_size_left, window_size_right, sink_token_length; + int const window_size_left, window_size_right; int const num_splits; int const* const kv_batch_idx = nullptr; int const* const cu_seqlens_q = nullptr; @@ -540,7 +540,7 @@ struct CollectiveMainloopFwdSm90 { args.ptr_q_descale, args.ptr_k_descale, args.ptr_v_descale, args.stride_q_descale, args.stride_k_descale, args.stride_v_descale, !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, - args.window_size_left, args.window_size_right, args.sink_token_length, + args.window_size_left, args.window_size_right, !Split ? 1 : args.num_splits, args.kv_batch_idx, args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, @@ -848,33 +848,6 @@ struct CollectiveMainloopFwdSm90 { n_block_prev = n_block; if constexpr (Transpose_V) { copy_Vt_to_V(smem_pipe_write_v); } } - // if constexpr (Is_local) { - // Disable sink token code for now - if constexpr (false && Is_local) { - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - int n_block_sink_max = cute::ceil_div(params.sink_token_length, kBlockN); - #pragma unroll 1 - for (n_block = std::min(n_block, n_block_sink_max - 1); n_block >= 0; --n_block) { - PipelineState smem_pipe_write_v = smem_pipe_write; // copy the state, write_v is always 1 step behind - ++smem_pipe_write; - if (should_load_KV) { - if constexpr (PagedKV) { - paged_kv_manager.template load_page_table(n_block); - } - if constexpr (Transpose_V) { load_V(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); } - load_K(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); - if constexpr (!Transpose_V) { - if constexpr (IntraWGOverlap) { - load_V(n_block_prev, smem_pipe_write_v, cute::true_type{} /*Seqlenk_mask*/); - } else { - load_V(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); - } - } - } - n_block_prev = n_block; - if constexpr (Transpose_V) { copy_Vt_to_V(smem_pipe_write_v); } - } - } scheduler_prefetch(); if constexpr (!Transpose_V && IntraWGOverlap) { if (should_load_KV) { load_V(n_block_prev, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); } @@ -1058,7 +1031,7 @@ struct CollectiveMainloopFwdSm90 { int n_block = n_block_max - 1; flash::Mask mask( - thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, params.sink_token_length, + thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, params.qhead_per_khead_divmod ); @@ -1118,7 +1091,6 @@ struct CollectiveMainloopFwdSm90 { cute::copy(smem_tiled_copy_Q, tSsQ_copy_view, tSrQ_copy_view); } - // TODO: check the case where n_block_max <= n_block_min but there are sink tokens if constexpr (IntraWGOverlap) { Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); consumer_wait(pipeline_k, smem_pipe_read); @@ -1232,12 +1204,6 @@ struct CollectiveMainloopFwdSm90 { for (; n_block >= n_block_min; --n_block) { fwd_step(n_block, local_mask_fn, cute::bool_constant{} /*check_inf*/); } - // Disable sink token code for now - // int n_block_sink_max = cute::ceil_div(params.sink_token_length, kBlockN); - // #pragma unroll 1 - // for (n_block = std::min(n_block, n_block_sink_max - 1); n_block >= 0; --n_block) { - // fwd_step(n_block, local_mask_fn, cute::bool_constant{} /*check_inf*/); - // } } // Tell producers that smem_q is ready cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); @@ -1341,12 +1307,6 @@ struct CollectiveMainloopFwdSm90 { for (; n_block >= n_block_min; --n_block) { fwd_step(n_block, local_mask_fn, cute::false_type{} /*is_first_iter*/, cute::bool_constant{} /*check_inf*/); } - // Disable sink token code for now - // int n_block_sink_max = cute::ceil_div(params.sink_token_length, kBlockN); - // #pragma unroll 1 - // for (n_block = std::min(n_block, n_block_sink_max - 1); n_block >= 0; --n_block) { - // fwd_step(n_block, local_mask_fn, cute::false_type{} /*is_first_iter*/, cute::bool_constant{} /*check_inf*/); - // } } warp_scheduler_barrier_arrive(); // Tell producers that smem_q is ready diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index dd9a1d0d3..54fdab17e 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -103,8 +103,6 @@ def test_flash_attn_output( seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv, mha_type, dtype ): - # sink_token_length = 0 if not local else 4 - sink_token_length = 0 if not local else 0 if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn): pytest.skip("V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn") device = "cuda" @@ -152,7 +150,6 @@ def test_flash_attn_output( qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, - sink_token_length=sink_token_length, softcap=softcap ) out_pt, attn_pt = attention_ref( @@ -165,7 +162,6 @@ def test_flash_attn_output( qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, - sink_token_length=sink_token_length, softcap=softcap, upcast=False, reorder_ops=True, @@ -198,7 +194,6 @@ def test_flash_attn_output( qv=qv, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, - sink_token_length=sink_token_length, softcap=softcap, pack_gqa=pack_gqa, num_splits=num_splits @@ -230,7 +225,6 @@ def test_flash_attn_output( # d ** (-0.5), # causal, # window_size[0], window_size[1], - # sink_token_length, # softcap, # deterministic, # 0, # sm_margin From 6865e6014501ee4ce2cb8f8e031f03dac244c8c1 Mon Sep 17 00:00:00 2001 From: xin-w8023 <43900898+xin-w8023@users.noreply.github.com> Date: Sun, 2 Mar 2025 10:18:28 +0800 Subject: [PATCH 055/102] fix: prompt index to type longlong to avoid numerical overflow (#1500) --- csrc/flash_attn/src/flash_bwd_kernel.h | 2 +- csrc/flash_attn/src/flash_bwd_preprocess_kernel.h | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index 8f42f0ae1..50af5f630 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -118,7 +118,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb) + (m_block_max - 1) * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride; const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) - + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded + + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128ll * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer. + (!params.deterministic ? 0 : blockIdx.x * params.dq_accum_split_stride); const index_t row_offset_lse = (params.unpadded_lse? bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb): (bidb * params.h + bidh) * params.seqlen_q) + (m_block_max - 1) * kBlockM; diff --git a/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h b/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h index 016a01070..e4875fe3a 100644 --- a/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h @@ -79,7 +79,7 @@ inline __device__ void compute_dot_do_o(const Params ¶ms) { const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) - + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; + + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128ll * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; // Regarding 128 * params.b see a comment in mha_varlen_bwd about padding of dq_accum and softmax_d const index_t row_offset_dpsum = (params.unpadded_lse ? (bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb): (bidb * params.h + bidh) * params.seqlen_q_rounded) + m_block * kBlockM; @@ -205,7 +205,7 @@ inline __device__ void convert_dQ(const Params ¶ms, const int nsplits) { const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb) + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride; const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) - + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; + + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128ll * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_ptr) + row_offset_dq), Shape, Int>{}, From 45c48afb2bf0bc148484960346615e4d66365f46 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 3 Mar 2025 23:53:59 -0500 Subject: [PATCH 056/102] Add option for WG1 to use RS MMA but WG2 using SS MMA --- hopper/flash_api.cpp | 12 +++---- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 43 +++++++++++++++++------- hopper/utils.h | 20 ++++++++++- 3 files changed, 55 insertions(+), 20 deletions(-) diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 624372f8b..ffe62bf70 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -276,7 +276,7 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { return run_mha_fwd_(params, stream); } else { - return run_mha_fwd_(params, stream); + return run_mha_fwd_(params, stream); } } #endif @@ -301,12 +301,12 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { } else { #ifndef FLASHATTENTION_DISABLE_FP16 #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { - if (params.dv > 64 && Arch == 90) { + if (params.d <= 64) { + if (params.dv > 64 && Arch == 90) { return run_mha_fwd_(params, stream); } else { - return run_mha_fwd_(params, stream); + return run_mha_fwd_(params, stream); } } #endif @@ -596,10 +596,10 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); if (head_size_v != head_size) { - TORCH_CHECK((head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128) || + TORCH_CHECK((head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128) || (head_size <= 64 && head_size_v <= 512), "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128], " - "or (Q/K <= 64 and V <= 512)."); + "or (Q/K <= 64 and V <= 512)."); TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim"); if (head_size_v > 256) { TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 823826d93..b53e4104e 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -83,6 +83,9 @@ struct CollectiveMainloopFwdSm90 { static_assert(!(!MmaPV_is_RS && Is_FP8), "MmaPV must be RS if FP8"); static_assert(!(!MmaPV_is_RS && Transpose_V), "MmaPV must be RS if Transpose_V"); + // Slightly faster in this case to have WG1 use RS instead of SS to avoid waiting for the P smem write + static constexpr bool MmaPV_use_RS_WG1 = !MmaPV_is_RS && kHeadDim == 64 && kHeadDimV == 512; + using AtomLayoutQK = Layout, _1, _1>>; using TiledMmaQK = decltype(cute::make_tiled_mma( std::conditional_t< @@ -108,6 +111,10 @@ struct CollectiveMainloopFwdSm90 { using TiledMmaQV = decltype(cute::make_tiled_mma( cute::GMMA::ss_op_selector(), AtomLayoutQK{})); + // For hdim64,512, WG1 can use RS but WG2 must use SS + using TiledMmaPV_RS = decltype(cute::make_tiled_mma( + cute::GMMA::rs_op_selector(), + AtomLayoutPV{})); static constexpr int NumMmaThreadsQK = size(TiledMmaQK{}); static constexpr int NumMmaThreads = size(TiledMmaPV{}); @@ -128,17 +135,17 @@ struct CollectiveMainloopFwdSm90 { make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); using SmemLayoutAtomVt = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_PV{})), decltype(cute::get<2>(TileShape_MNK_PV{}))>()); + Int, decltype(cute::get<2>(TileShape_MNK_PV{}))>()); using SmemLayoutVt = decltype(tile_to_shape( SmemLayoutAtomVt{}, - make_shape(shape<1>(TileShape_MNK_PV{}), shape<2>(TileShape_MNK_PV{}), Int{}), + make_shape(Int{}, shape<2>(TileShape_MNK_PV{}), Int{}), std::conditional_t, cute::Step<_2, _1, _3>>{})); using SmemLayoutAtomVtMma = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_PV{})), decltype(cute::get<2>(TileShape_MNK_PV{}))>()); + Int, decltype(cute::get<2>(TileShape_MNK_PV{}))>()); using SmemLayoutVtMma = decltype(tile_to_shape( SmemLayoutAtomVtMma{}, - make_shape(shape<1>(TileShape_MNK_PV{}), shape<2>(TileShape_MNK_PV{}), Int{}), + make_shape(Int{}, shape<2>(TileShape_MNK_PV{}), Int{}), std::conditional_t, cute::Step<_2, _1, _3>>{})); using SmemLayoutAtomQv = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_QV{})), decltype(cute::get<2>(TileShape_MNK_QV{}))>()); using SmemLayoutVMmaQV = decltype(tile_to_shape( SmemLayoutAtomVMmaQV{}, - make_shape(shape<1>(TileShape_MNK_QV{}), shape<2>(TileShape_MNK_QV{}), Int{}))); + make_shape(shape<1>(TileShape_MNK_QV{}), Int{}, Int{}))); static_assert(CUTE_STATIC_V(size(SmemLayoutVMmaQV{})) == size(SmemLayoutVtMma{})); // Only used if we're using cp.async to load V @@ -1263,16 +1270,25 @@ struct CollectiveMainloopFwdSm90 { } if constexpr (!MmaPV_is_RS) { cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); } if constexpr (!Is_first_iter) { softmax.rescale_o(tOrO, scores_scale); } - if constexpr (!MmaPV_is_RS) { - cutlass::arch::fence_view_async_shared(); - __syncwarp(); - if constexpr (LargeHeadDimV) { - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + auto arrive_P_write_barrier = [&] { + if constexpr (!MmaPV_is_RS) { + cutlass::arch::fence_view_async_shared(); + __syncwarp(); + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + } } - } + }; + if constexpr (!MmaPV_use_RS_WG1) { arrive_P_write_barrier(); } if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } warp_scheduler_barrier_sync(); - flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); + if constexpr (!MmaPV_use_RS_WG1) { + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); + } else { + TiledMmaPV_RS tiled_mma_pv_rs; + flash::gemm(tiled_mma_pv_rs, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + } + if constexpr (MmaPV_use_RS_WG1) { arrive_P_write_barrier(); } warpgroup_wait<0>(); pipeline_v.consumer_release(smem_pipe_read); // release V }; @@ -1385,7 +1401,7 @@ struct CollectiveMainloopFwdSm90 { typename Softmax::TensorT scores_scale; int n_block = n_block_max - 1; - // If HasQv, then by the time P is ready, V must be ready as well + // If HasQv, then by the time P is ready, V must have been ready as well if constexpr (!HasQv) { pipeline_v.consumer_wait(smem_pipe_read); } cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); @@ -1393,6 +1409,7 @@ struct CollectiveMainloopFwdSm90 { pipeline_v.consumer_release(smem_pipe_read); // release V --n_block; + #pragma unroll 1 for (; n_block >= n_block_min; --n_block) { cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); load_scales(scores_scale, smem_pipe_read.index()); diff --git a/hopper/utils.h b/hopper/utils.h index f821b19a4..d9468af55 100644 --- a/hopper/utils.h +++ b/hopper/utils.h @@ -272,9 +272,11 @@ CUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const if constexpr (zero_init) { tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; } + static constexpr int kNumKIters = CUTE_STATIC_V(size<2>(tCrA)); + static constexpr int kMaxKIters = 16; // Unroll the K mode manually to set scale D to 1 CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + for (int k_block = 0; k_block < std::min(kNumKIters, kMaxKIters); ++k_block) { if constexpr (!SwapAB) { cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); } else { @@ -282,6 +284,22 @@ CUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const } tiled_mma.accumulate_ = GMMA::ScaleOut::One; } + // In the case of large kNumKIters, the compiler chooses to store the smem addresses + // in registers, causing spills. This loop forces the compiler to recompute the addresses. + if constexpr (kNumKIters > kMaxKIters) { + // This will always be zero, just a way to force the compiler to recompute the smem + // addresses. This results in USEL instructions. There's probably a better way to do this. + int const k_offset = cutlass::canonical_warp_group_idx() < 128 ? 0 : 1; + CUTLASS_PRAGMA_UNROLL + for (int k_block = kMaxKIters; k_block < kNumKIters; ++k_block) { + if constexpr (!SwapAB) { + cute::gemm(tiled_mma, tCrA(_,_,k_block + k_offset), tCrB(_,_,k_block + k_offset), tCrC); + } else { + cute::gemm(tiled_mma, tCrB(_,_,k_block + k_offset), tCrA(_,_,k_block + k_offset), tCrC); + } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } warpgroup_commit_batch(); if constexpr (wg_wait >= 0) { warpgroup_wait(); } warpgroup_fence_operand(tCrC); From 3edf7e0daa62662cd2dd2ec8fd999dd7f254415c Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 4 Mar 2025 11:41:25 -0500 Subject: [PATCH 057/102] Add kwargs to _write_ninja_file for compatibility with new torch --- hopper/setup.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/hopper/setup.py b/hopper/setup.py index 433c3bb3a..cf3d23934 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -90,7 +90,9 @@ def _write_ninja_file(path, objects, ldflags, library_target, - with_cuda) -> None: + with_cuda, + **kwargs, # kwargs (ignored) to absorb new flags in torch.utils.cpp_extension + ) -> None: r"""Write a ninja file that does the desired compiling and linking. `path`: Where to write this file From 4f0640d534888c579a448fd89c2d4e064905d798 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 5 Mar 2025 01:40:01 -0500 Subject: [PATCH 058/102] Move writing P to smem as separate function --- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 67 +++++++++--------------- 1 file changed, 26 insertions(+), 41 deletions(-) diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index b53e4104e..03b812d76 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -1029,10 +1029,6 @@ struct CollectiveMainloopFwdSm90 { pipeline.consumer_wait(smem_pipe_read, barrier_token); }; - // Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter - clear(tOrO); - // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero; - int const seqlen_q = seqlen_info.seqlen_q; int const seqlen_k = seqlen_info.seqlen_k; int n_block = n_block_max - 1; @@ -1054,6 +1050,21 @@ struct CollectiveMainloopFwdSm90 { if constexpr (Has_softcap) { flash::apply_softcap(tSrS, softcap_val); } }; + auto write_P_to_smem = [&](auto& tOrP) { + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + } + cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); + }; + + auto arrive_on_P_write_barrier = [&] { + cutlass::arch::fence_view_async_shared(); + __syncwarp(); // Only need syncwarp since each warp is using its own P values for MmaPV + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + } + }; + auto &barrier_Q = shared_storage.pipelines.barrier_Q; if constexpr (!AppendKV) { barrier_Q.wait(work_idx % 2); @@ -1098,6 +1109,10 @@ struct CollectiveMainloopFwdSm90 { cute::copy(smem_tiled_copy_Q, tSsQ_copy_view, tSrQ_copy_view); } + // Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter + clear(tOrO); + // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero; + if constexpr (IntraWGOverlap) { Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); consumer_wait(pipeline_k, smem_pipe_read); @@ -1121,17 +1136,8 @@ struct CollectiveMainloopFwdSm90 { Tensor tOrP = make_tensor_like(tOrP_acc); convert_type_out(tOrP_acc, tOrP); if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } - if constexpr (!MmaPV_is_RS) { - if constexpr (LargeHeadDimV) { - cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); - } - cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); - cutlass::arch::fence_view_async_shared(); - __syncwarp(); // Only need syncwarp since each warp is using its own P values for MmaPV - if constexpr (LargeHeadDimV) { - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); - } - } + if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); } + if constexpr (!MmaPV_is_RS) { arrive_on_P_write_barrier(); } --n_block; // Each step does gemm0 for iter n_block, gemm1 for iter n_block + 1, and softmax for iter n_block. @@ -1169,18 +1175,9 @@ struct CollectiveMainloopFwdSm90 { if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } convert_type_out(make_tensor(tSrS.data(), tOrP.layout()), tOrP); if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } - if constexpr (LargeHeadDimV) { - cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); - } - if constexpr (!MmaPV_is_RS) { cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); } + if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); } if constexpr (!RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } - if constexpr (!MmaPV_is_RS) { - cutlass::arch::fence_view_async_shared(); - __syncwarp(); - if constexpr (LargeHeadDimV) { - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); - } - } + if constexpr (!MmaPV_is_RS) { arrive_on_P_write_barrier(); } }; if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking @@ -1265,21 +1262,9 @@ struct CollectiveMainloopFwdSm90 { Tensor tOrP = make_tensor_like(tOrP_acc); convert_type_out(tOrP_acc, tOrP); if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } - if constexpr (LargeHeadDimV) { - cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); - } - if constexpr (!MmaPV_is_RS) { cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); } + if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); } if constexpr (!Is_first_iter) { softmax.rescale_o(tOrO, scores_scale); } - auto arrive_P_write_barrier = [&] { - if constexpr (!MmaPV_is_RS) { - cutlass::arch::fence_view_async_shared(); - __syncwarp(); - if constexpr (LargeHeadDimV) { - cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); - } - } - }; - if constexpr (!MmaPV_use_RS_WG1) { arrive_P_write_barrier(); } + if constexpr (!MmaPV_is_RS && !MmaPV_use_RS_WG1) { arrive_on_P_write_barrier(); } if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } warp_scheduler_barrier_sync(); if constexpr (!MmaPV_use_RS_WG1) { @@ -1288,7 +1273,7 @@ struct CollectiveMainloopFwdSm90 { TiledMmaPV_RS tiled_mma_pv_rs; flash::gemm(tiled_mma_pv_rs, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); } - if constexpr (MmaPV_use_RS_WG1) { arrive_P_write_barrier(); } + if constexpr (!MmaPV_is_RS && MmaPV_use_RS_WG1) { arrive_on_P_write_barrier(); } warpgroup_wait<0>(); pipeline_v.consumer_release(smem_pipe_read); // release V }; From d82bbf26924c492064af8b27ab299ff4808d1bf6 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 5 Mar 2025 16:51:48 -0500 Subject: [PATCH 059/102] Fix causal scheduler not considering hdim_v != hdim --- hopper/flash_api.cpp | 3 +++ hopper/flash_bwd_launch_template.h | 2 +- hopper/flash_fwd_launch_template.h | 2 +- hopper/heuristics.h | 2 +- hopper/tile_scheduler.hpp | 4 ++-- 5 files changed, 8 insertions(+), 5 deletions(-) diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index ffe62bf70..5806e7150 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -611,6 +611,8 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq // TODO: check this if (window_size_left >= seqlen_k - 1) { window_size_left = -1; } if (window_size_right >= seqlen_q - 1) { window_size_right = -1; } + // causal=true is the same as causal=false in this case + if (seqlen_q == 1 && window_size_left == -1 && window_size_right == -1) { is_causal = false; } if (is_causal) { window_size_right = 0; } // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_fprop will set params.is_causal=true. // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM. @@ -1272,6 +1274,7 @@ std::vector mha_bwd( params.total_q = total_q; params.total_k = total_k; params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr(); + params.dv = head_size; // We don't support hdim_v being different from hdim_qk for now // auto tile_count_semaphore = (params.is_causal || params.is_local) ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32)); // params.tile_count_semaphore = tile_count_semaphore.data_ptr(); diff --git a/hopper/flash_bwd_launch_template.h b/hopper/flash_bwd_launch_template.h index 65d010b46..76ded0407 100644 --- a/hopper/flash_bwd_launch_template.h +++ b/hopper/flash_bwd_launch_template.h @@ -165,7 +165,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { num_blocks_n, params.h, params.b, 1 /*num_splits*/, params.h / params.h_k, params.seqlen_k, - params.seqlen_q, params.d, sizeof(Element), + params.seqlen_q, params.d, params.dv, sizeof(Element), params.tile_count_semaphore, params.cu_seqlens_k, params.seqused_k }; diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 420538178..b08826153 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -148,7 +148,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { num_blocks_m, !PackGQA ? params.h : params.h_k, params.b, params.num_splits, params.h / params.h_k, params.seqlen_q, - params.seqlen_k, params.d, sizeof(Element), + params.seqlen_k, params.d, params.dv, sizeof(Element), params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q, // params.num_m_blocks_ptr, params.num_splits_dynamic_ptr, params.num_splits_dynamic_ptr, diff --git a/hopper/heuristics.h b/hopper/heuristics.h index 868d4ad59..031ea44a0 100644 --- a/hopper/heuristics.h +++ b/hopper/heuristics.h @@ -25,7 +25,7 @@ inline bool should_pack_gqa(bool varlen_q, int seqlen_q, int qhead_per_khead, in inline int num_splits_heuristic(int total_mblocks, int num_SMs, int num_n_blocks, int num_m_blocks, int size_one_kv_head, bool is_causal_or_local, int max_splits) { // If we have enough to almost fill the SMs, then just use 1 split // However, in the case of super long seqlen where each head of KV doesn't even fit into - // L2 (we assume conservatively that L2 size is 50MB), we want to split. + // L2 (we assume that L2 size is 50MB), we want to split. if (total_mblocks >= 0.8f * num_SMs) { int const size_l2 = 50 * 1024 * 1024; // Only split if there are enough queries to go over the KV at least twice diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index b39c7aeb2..9d2c83f2c 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -20,7 +20,7 @@ struct TileSchedulerArguments { int const num_blocks, num_head, num_batch, num_splits; int const qhead_per_khead; int const seqlen; // Only used if Varlen and cu_seqlens == nullptr and seqused == nullptr - int const seqlen_k, headdim, element_size; // Used to calculate L2 swizzling + int const seqlen_k, headdim, headdim_v, element_size; // Used to calculate L2 swizzling int* const tile_count_semaphore = nullptr; int* const cu_seqlens = nullptr; int* const seqused = nullptr; @@ -235,7 +235,7 @@ class DynamicPersistentTileScheduler { static Params to_underlying_arguments(TileSchedulerArguments const& args) { - int const size_one_kv_head = args.seqlen_k * args.headdim * args.element_size * 2; + int const size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size * 2; int const size_l2 = 32 * 1024 * 1024; // 32 MB for K & V // Swizzle is the size of each "section". Round swizzle to a power of 2 // If not PackGQA already, the size of each section can increase by qhead_per_khead From 9c036e466a3574fc75fe8a98f242dd6c1235d506 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 7 Mar 2025 15:42:57 -0500 Subject: [PATCH 060/102] Always split fwd_combine_kernel on batch --- hopper/flash_fwd_combine_kernel.h | 54 +++++++++++----------- hopper/flash_fwd_combine_launch_template.h | 4 +- hopper/flash_prepare_scheduler.cu | 5 +- 3 files changed, 32 insertions(+), 31 deletions(-) diff --git a/hopper/flash_fwd_combine_kernel.h b/hopper/flash_fwd_combine_kernel.h index 8e9146d18..42dac2a69 100644 --- a/hopper/flash_fwd_combine_kernel.h +++ b/hopper/flash_fwd_combine_kernel.h @@ -198,17 +198,22 @@ class FlashAttnFwdCombine { int const thread_idx = threadIdx.x; int const m_block = blockIdx.x; int const k_block = blockIdx.y; - int const batch = !Varlen ? 0 : blockIdx.z; - int const num_splits = get<1>(params.shape_LSE_partial); + int const batch = blockIdx.z; + int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial); flash::SeqlenInfo seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused}; int const offset = seqlen_info.offset; int const seqlen = seqlen_info.seqlen; - int max_idx = seqlen * get<2>(params.shape_LSE_partial) * get<3>(params.shape_LSE_partial); + int max_idx = seqlen * get<2>(params.shape_LSE_partial); + if constexpr (Varlen) { + if (m_block * kBlockM >= max_idx) { return; } + } cutlass::FastDivmod seqlen_divmod_dynamic(seqlen); // Step 1: load LSE_partial from gmem -> smem - Tensor mLSEpartial = make_tensor(make_gmem_ptr(params.ptr_LSE_partial + offset * get<0>(params.stride_LSE_partial)), select<1, 0, 2, 3>(params.shape_LSE_partial), select<1, 0, 2, 3>(params.stride_LSE_partial)); // (num_splits, seqlen, head, batch) + Tensor mLSEpartial = make_tensor(make_gmem_ptr(params.ptr_LSE_partial + offset * get<0>(params.stride_LSE_partial)), + select<1, 0, 2, 3>(params.shape_LSE_partial), + select<1, 0, 2, 3>(params.stride_LSE_partial))(_, _, _, !Varlen ? batch : 0); // (num_splits, seqlen, head) Tensor mLSEpartial_copy = cute::tiled_divide(mLSEpartial, Shape<_1, Int>{}); GmemTiledCopyLSE gmem_tiled_copy_LSE; auto gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_thread_slice(thread_idx); @@ -224,20 +229,18 @@ class FlashAttnFwdCombine { int mi = int(get<1>(tLSEcLSE(_0{}, _0{}, m))); int idx = m_block * kBlockM + mi; if (idx < max_idx) { - int m_idx, bidh, bidb; + int m_idx, bidh; if constexpr (!Varlen) { - bidb = params.head_divmod.divmod(bidh, params.seqlen_divmod.divmod(m_idx, idx)); + bidh = params.seqlen_divmod.divmod(m_idx, idx); } else { bidh = seqlen_divmod_dynamic.divmod(m_idx, idx); - bidb = 0; } - int num_splits_actual = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[!Varlen ? bidb : batch] : num_splits; - Tensor mLSEpartial_cur_copy = mLSEpartial_copy(_, _, m_idx, bidh, bidb); + Tensor mLSEpartial_cur_copy = mLSEpartial_copy(_, _, m_idx, bidh); #pragma unroll for (int s = 0; s < size<1>(tLSEcLSE); ++s) { int si = get<0>(tLSEcLSE(_0{}, s, _0{})); // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && thread_idx < 32) { printf("thread_idx = %d, m = %d, s = %d, addr = %p, bank = %d\n", thread_idx, m, s, reinterpret_cast(&(tLSEsLSE(_0{}, s, m))), reinterpret_cast(&(tLSEsLSE(_0{}, s, m))) / 4 % 32);} - if (si < num_splits_actual) { + if (si < num_splits) { cute::copy(gmem_tiled_copy_LSE, mLSEpartial_cur_copy(_, si), tLSEsLSE(_, s, m)); } else { cute::fill(tLSEsLSE(_, s, m), -INFINITY); @@ -259,26 +262,24 @@ class FlashAttnFwdCombine { // Repeat the partitioning with identity layouts Tensor tOcO = gmem_thr_copy_O_partial.partition_D(cO); Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset * get<0>(params.stride_O_partial)), - params.shape_O_partial, params.stride_O_partial); // (seqlen, d, num_splits, head, batch) + params.shape_O_partial, params.stride_O_partial)(_, _, _, _, !Varlen ? batch : 0); // (seqlen, d, num_splits, head) // Precompute these values to avoid recomputing them in the loop Tensor tOmidx = make_tensor(make_shape(size<1>(tOcO))); Tensor tObidh = make_tensor(make_shape(size<1>(tOcO))); - Tensor tObidb = make_tensor(make_shape(size<1>(tOcO))); Tensor tOrOptr = make_tensor(make_shape(size<1>(tOcO))); #pragma unroll for (int m = 0; m < size<1>(tOcO); ++m) { int mi = get<0>(tOcO(_0{}, m, _0{})); int idx = m_block * kBlockM + mi; if constexpr (!Varlen) { - tObidb[m] = params.head_divmod.divmod(tObidh(m), params.seqlen_divmod.divmod(tOmidx(m), idx)); + tObidh(m) = params.seqlen_divmod.divmod(tOmidx(m), idx); } else { tObidh[m] = seqlen_divmod_dynamic.divmod(tOmidx(m), idx); - tObidb[m] = 0; } - tOrOptr[m] = &mOpartial(tOmidx(m), k_block * kBlockK, _0{}, tObidh(m), tObidb(m)); + tOrOptr[m] = &mOpartial(tOmidx(m), k_block * kBlockK, _0{}, tObidh(m)); if (idx >= max_idx) { - tObidb[m] = -1; + tObidh[m] = -1; } } @@ -294,8 +295,8 @@ class FlashAttnFwdCombine { Tensor tOsOpartial_cur = tOsOpartial(_, _, _, stage); #pragma unroll for (int m = 0; m < size<1>(tOcO); ++m) { - if (tObidb(m) >= 0) { - Tensor mOpartial_cur = make_tensor(make_gmem_ptr(tOrOptr[m]), mOpartial(_0{}, _, _, _0{}, _0{}).layout()); + if (tObidh(m) >= 0) { + Tensor mOpartial_cur = make_tensor(make_gmem_ptr(tOrOptr[m]), mOpartial(_0{}, _, _, _0{}).layout()); Tensor mOpartial_cur_copy = cute::tiled_divide(mOpartial_cur, Shape>{}); #pragma unroll for (int k = 0; k < size<2>(tOcO); ++k) { @@ -375,22 +376,21 @@ class FlashAttnFwdCombine { // Step 5: store final LSE back to gmem if (k_block == 0) { auto shape_LSE = select<0, 2, 3>(params.shape_LSE_partial); - Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset * get<0>(params.stride_LSE)), shape_LSE, params.stride_LSE); + Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset * get<0>(params.stride_LSE)), shape_LSE, params.stride_LSE)(_, _, !Varlen ? batch : 0); #pragma unroll for (int m = 0; m < size<2>(ts2rrLSE); ++m) { if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to gmem int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m))); int idx = m_block * kBlockM + mi; if (idx < max_idx) { - int m_idx, bidh, bidb; + int m_idx, bidh; if constexpr (!Varlen) { - bidb = params.head_divmod.divmod(bidh, params.seqlen_divmod.divmod(m_idx, idx)); + bidh = params.seqlen_divmod.divmod(m_idx, idx); } else { bidh = seqlen_divmod_dynamic.divmod(m_idx, idx); - bidb = 0; } // printf("thread_idx = %d, m = %d, mi = %d, idx = %d, m_idx = %d, bidh = %d, bidb = %d, lse_sum = %f\n", thread_idx, m, mi, idx, m_idx, bidh, bidb, lse_sum(m)); - mLSE(m_idx, bidh, bidb) = lse_sum(m); + mLSE(m_idx, bidh) = lse_sum(m); } } } @@ -423,7 +423,7 @@ class FlashAttnFwdCombine { #pragma unroll for (int m = 0; m < size<1>(tOrOpartial); ++m) { - if (tObidb(m) >= 0 && scale(m) > 0.f) { + if (tObidh(m) >= 0 && scale(m) > 0.f) { #pragma unroll for (int k = 0; k < size<2>(tOrOpartial); ++k) { if (Is_even_K || tOpO(k)) { @@ -444,19 +444,19 @@ class FlashAttnFwdCombine { flash::convert_type_out(tOrO, rO); auto shape_O = make_shape(get<0>(params.shape_O_partial), get<1>(params.shape_O_partial) - k_block * kBlockK, get<3>(params.shape_O_partial), get<4>(params.shape_O_partial)); Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset * get<0>(params.stride_O) + k_block * kBlockK * get<1>(params.stride_O)), - shape_O, params.stride_O); + shape_O, params.stride_O)(_, _, _, !Varlen ? batch : 0); Tensor mO_copy = cute::tiled_divide(mO, Shape<_1, Int>{}); GmemTiledCopy gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); #pragma unroll for (int m = 0; m < size<1>(tOcO); ++m) { - if (tObidb(m) >= 0) { + if (tObidh(m) >= 0) { #pragma unroll for (int k = 0; k < size<2>(tOcO); ++k) { int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad; if (Is_even_K || tOpO(k)) { - cute::copy(gmem_tiled_copy_O, rO(_, m, k), mO_copy(_, tOmidx(m), k_idx, tObidh(m), tObidb(m))); + cute::copy(gmem_tiled_copy_O, rO(_, m, k), mO_copy(_, tOmidx(m), k_idx, tObidh(m))); } } } diff --git a/hopper/flash_fwd_combine_launch_template.h b/hopper/flash_fwd_combine_launch_template.h index e4ac21fd0..b0472b2c4 100644 --- a/hopper/flash_fwd_combine_launch_template.h +++ b/hopper/flash_fwd_combine_launch_template.h @@ -38,8 +38,8 @@ void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args); int num_blocks_k = cute::ceil_div(params.dv, kBlockK); - int num_blocks_m = cute::ceil_div(params.seqlen_q * params.h * (!Varlen ? params.b : 1), kBlockM); - dim3 grid_m(num_blocks_m, num_blocks_k, !Varlen ? 1 : params.b); + int num_blocks_m = cute::ceil_div(params.seqlen_q * params.h, kBlockM); + dim3 grid_m(num_blocks_m, num_blocks_k, params.b); auto kernel = cutlass::device_kernel; int smem_size = CombineKernel::SharedStorageSize; if (smem_size >= 48 * 1024) { diff --git a/hopper/flash_prepare_scheduler.cu b/hopper/flash_prepare_scheduler.cu index 9befcf438..6fde9084c 100644 --- a/hopper/flash_prepare_scheduler.cu +++ b/hopper/flash_prepare_scheduler.cu @@ -20,10 +20,11 @@ __global__ void prepare_varlen_num_blocks_kernel( int* const num_splits_dynamic_ptr) { static constexpr int kNumBatchPerWarp = cutlass::NumThreadsPerWarp - 1; + static constexpr int kSmemSize = 1; // Assume that there's only one block in the grid - __shared__ int smem[1]; + __shared__ int smem[kSmemSize]; - if (threadIdx.x == 0) { smem[0] = 0; } + if (threadIdx.x < kSmemSize) { smem[threadIdx.x] = 0; } __syncthreads(); if (threadIdx.x == 0) { *tile_count_semaphore = 0; } From 81643fa0ea63908064e26251b573cd315ca434fe Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 8 Mar 2025 14:50:54 -0500 Subject: [PATCH 061/102] For each batch, if num_splits=1, write to O instead of O_partial --- hopper/epilogue_fwd.hpp | 154 ++++++++++++++------- hopper/flash_api.cpp | 12 +- hopper/flash_fwd_combine_kernel.h | 38 +++-- hopper/flash_fwd_combine_launch_template.h | 2 +- hopper/flash_fwd_kernel_sm80.h | 4 +- hopper/flash_fwd_kernel_sm90.h | 5 +- hopper/flash_fwd_launch_template.h | 24 ++-- hopper/flash_prepare_scheduler.cu | 4 +- hopper/test_flash_attn.py | 9 +- hopper/tile_scheduler.hpp | 57 +++++--- 10 files changed, 194 insertions(+), 115 deletions(-) diff --git a/hopper/epilogue_fwd.hpp b/hopper/epilogue_fwd.hpp index f3815ea73..69102e8c4 100644 --- a/hopper/epilogue_fwd.hpp +++ b/hopper/epilogue_fwd.hpp @@ -21,21 +21,24 @@ namespace flash { using namespace cute; template + int NumEpilogueThreads_, bool Varlen_, bool PackGQA_, bool Split_, bool FP8PermuteCol=false> struct CollectiveEpilogueFwd { using TileShape_MNK_PV = TileShape_MNK_PV_; using ClusterShape = ClusterShape_; using Element = Element_; + using ElementPartial = float; using ArchTag = ArchTag_; static constexpr int NumEpilogueThreads = NumEpilogueThreads_; static constexpr bool Varlen = Varlen_; static constexpr bool PackGQA = PackGQA_; - static constexpr bool Use_smem = sizeof(Element) <= 2; - static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && Use_smem && !PackGQA; + static constexpr bool Split = Split_; + static constexpr bool Use_smem = !(Split && !Varlen); + static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && !Split && !PackGQA; static_assert(ArchTag::kMinComputeCapability >= 80); static_assert(ArchTag::kMinComputeCapability >= 90 || CUTE_STATIC_V(size(ClusterShape{})) == 1); + static_assert(sizeof(Element) <= 2); static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); static constexpr int kHeadDimV = get<1>(TileShape_MNK_PV{}); @@ -52,8 +55,6 @@ struct CollectiveEpilogueFwd { // we need to call divmod. static constexpr int kBytePerRow = kHeadDimV * sizeof(Element); static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); - // static constexpr int kBlockKGmem = kHeadDimV % 128 == 0 ? 128 : (kHeadDimV % 64 == 0 ? 64 : 32); - // static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDimV / kGmemElemsPerStore, NumEpilogueThreads); static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerStore; // If PackGQA, we split the work of compute O_ptr among threads in the same row, so we need this to within a warp static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0); @@ -121,8 +122,12 @@ struct CollectiveEpilogueFwd { Element* ptr_O; ShapeO const shape_O; StrideO const stride_O; + ElementPartial* ptr_O_partial; + StrideO const stride_O_partial; float* ptr_LSE; StrideLSE const stride_LSE; + float* ptr_LSE_partial; + StrideLSE const stride_LSE_partial; int32_t const nheads_kv; int const* cu_seqlens = nullptr; int const* seqused = nullptr; @@ -135,10 +140,16 @@ struct CollectiveEpilogueFwd { StrideO const stride_O; ShapeOPacked const shape_O_packed; StrideOPacked const stride_O_packed; + ElementPartial* ptr_O_partial; + StrideO const stride_O_partial; + StrideOPacked const stride_O_partial_packed; float* ptr_LSE; StrideLSE const stride_LSE; ShapeLSEPacked const shape_LSE_packed; StrideLSEPacked const stride_LSE_packed; + float* ptr_LSE_partial; + StrideLSE const stride_LSE_partial; + StrideLSEPacked const stride_LSE_partial_packed; cutlass::FastDivmod qhead_per_khead_divmod; TMA_O tma_store_O; int const* cu_seqlens = nullptr; @@ -165,6 +176,10 @@ struct CollectiveEpilogueFwd { args.stride_O, make_stride(make_stride(get<2>(args.stride_O), get<0>(args.stride_O)), get<1>(args.stride_O), get<2>(args.stride_O) * qhead_per_khead, get<3>(args.stride_O), get<4>(args.stride_O)) ); + auto const stride_O_partial_packed = cute::conditional_return( + args.stride_O_partial, + make_stride(make_stride(get<2>(args.stride_O_partial), get<0>(args.stride_O_partial)), get<1>(args.stride_O_partial), get<2>(args.stride_O_partial) * qhead_per_khead, get<3>(args.stride_O_partial), get<4>(args.stride_O_partial)) + ); // If PackGQA, Reshape LSE to be ((qhead_per_khead, seqlen_q), nhead_k, batch_size, num_splits) auto const shape_LSE_packed = cute::conditional_return( select<0, 2, 3, 4>(args.shape_O), @@ -174,8 +189,14 @@ struct CollectiveEpilogueFwd { args.stride_LSE, make_stride(make_stride(get<1>(args.stride_LSE), get<0>(args.stride_LSE)), get<1>(args.stride_LSE) * qhead_per_khead, get<2>(args.stride_LSE), get<3>(args.stride_LSE)) ); + auto const stride_LSE_partial_packed = cute::conditional_return( + args.stride_LSE_partial, + make_stride(make_stride(get<1>(args.stride_LSE_partial), get<0>(args.stride_LSE_partial)), get<1>(args.stride_LSE_partial) * qhead_per_khead, get<2>(args.stride_LSE_partial), get<3>(args.stride_LSE_partial)) + ); return {args.ptr_O, args.shape_O, args.stride_O, shape_O_packed, stride_O_packed, + args.ptr_O_partial, args.stride_O_partial, stride_O_partial_packed, args.ptr_LSE, args.stride_LSE, shape_LSE_packed, stride_LSE_packed, + args.ptr_LSE_partial, args.stride_LSE_partial, stride_LSE_partial_packed, cutlass::FastDivmod(qhead_per_khead), tma_store_O, args.cu_seqlens, args.seqused}; } @@ -191,7 +212,7 @@ struct CollectiveEpilogueFwd { template CUTLASS_DEVICE void store(Params const& params, - FrgTensorO const& tOrO, + FrgTensorO& tOrO, FrgTensorLSE const& lse, SharedStorage& shared_storage, TiledMma tiled_mma, @@ -200,13 +221,25 @@ struct CollectiveEpilogueFwd { ) { auto [m_block, bidh, bidb, split_idx] = block_coord; - split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx + int num_splits = get<4>(params.shape_O_packed); + if constexpr (Split && Varlen) { + uint32_t num_splits_dynamic_u = reinterpret_cast(split_idx) >> 16; // first 16 bits are for num_splits + int num_splits_dynamic = reinterpret_cast(num_splits_dynamic_u); + num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits; + split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx + } + bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1); + Tensor sO = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_o.data()), SmemLayoutO{}); // Tensor sO_pi = cute::as_position_independent_swizzle_tensor(sO); + static constexpr bool NeedFP8Permute = FP8PermuteCol && (sizeof(Element) == 2 || sizeof(Element) == 4); + // If we will possibly need tOrO in FP32, we'd want to permute tOrO before type conversion. + // Otherwise we can permute after conversion. + if constexpr (NeedFP8Permute && Split) { flash::permute_output_fp8_Vcolmajor(tOrO); } Tensor tOrO_out = make_tensor_like(tOrO); flash::convert_type_out(tOrO, tOrO_out); - if constexpr (FP8PermuteCol && (sizeof(Element) == 2 || sizeof(Element) == 4)) { flash::permute_output_fp8_Vcolmajor(tOrO_out); } + if constexpr (NeedFP8Permute && !Split) { flash::permute_output_fp8_Vcolmajor(tOrO_out); } // Make sure all WGs have finished reading V // Technically we don't need this if we're not using smem, but the mainloop makes the assumption that @@ -254,9 +287,12 @@ struct CollectiveEpilogueFwd { Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M - using PackGQAt = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>; + using PackGQA_t = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>; + using PackGQApartial_t = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, ElementPartial>; - Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset_o * get<0>(params.stride_LSE)), params.shape_LSE_packed, params.stride_LSE_packed)(_, bidh, !is_varlen ? bidb : 0, split_idx); + Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)), + params.shape_LSE_packed, + !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx); // if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); } if (!LargeHeadDimV || warp_group_idx == 0) { if constexpr (!PackGQA) { @@ -266,7 +302,7 @@ struct CollectiveEpilogueFwd { if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o) { mLSE(row) = lse(mi); } } } else { - PackGQAt::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); + PackGQA_t::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); } } @@ -292,10 +328,10 @@ struct CollectiveEpilogueFwd { } } } else { // Don't use TMA in Varlen case since we don't want to overwrite the output of another sequence - Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx); - Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) - // if (thread_idx == 0) { printf("Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast(&mO(0)) - reinterpret_cast(params.ptr_O)); } - if constexpr (Use_smem) { + if (!is_split) { + Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, _0{}); + Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) + // if (thread_idx == 0) { printf("Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast(&mO(0)) - reinterpret_cast(params.ptr_O)); } GmemTiledCopyO gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) @@ -322,17 +358,27 @@ struct CollectiveEpilogueFwd { ); } else { // If PackGQA, we split the work of compute O_ptr among threads in the same row - PackGQAt::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); + PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); } } else { - // We already arrived on barrier_O earlier + Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset_o * get<0>(params.stride_O_partial)), params.shape_O_packed, params.stride_O_partial_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx); + Tensor gOpartial = local_tile(mOpartial, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) + // We already arrived on barrier_O earlier if !Use_smem + if constexpr (Use_smem) { + if constexpr (ArchTag::kMinComputeCapability >= 90) { + #pragma unroll + for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { + shared_storage.pipelines.barrier_O.arrive(cta_id); + } + } + } if constexpr (!PackGQA) { static constexpr int kGmemElemsPerStoreDirect = 2; - cute::Copy_Atom, Element> gmem_copy_direct; + cute::Copy_Atom, ElementPartial> gmem_copy_direct; // Reshape acc from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) - Tensor tOrO_rowcol = make_tensor(tOrO_out.data(), flash::convert_layout_acc_rowcol(tOrO.layout())); + Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout())); Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int>{}); - Tensor tOgO = thread_mma.partition_C(gO); + Tensor tOgO = thread_mma.partition_C(gOpartial); Tensor tOgO_rowcol = make_tensor(tOgO.data(), flash::convert_layout_acc_rowcol(tOgO.layout())); Tensor tOgO_copy = cute::tiled_divide(tOgO_rowcol, Shape<_1, Int>{}); Tensor taccOcO_col = taccOcO_rowcol(_0{}, _); @@ -348,7 +394,7 @@ struct CollectiveEpilogueFwd { } } } else { - PackGQAt::store_O_direct(mO, tOrO_out, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); + PackGQApartial_t::store_O_direct(mOpartial, tOrO, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); } } } @@ -360,7 +406,6 @@ struct CollectiveEpilogueFwd { } // Write 0 to output and -inf to LSE - template CUTLASS_DEVICE void store_zero( Params const& params, @@ -369,14 +414,23 @@ struct CollectiveEpilogueFwd { ) { static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); auto [m_block, bidh, bidb, split_idx] = block_coord; - split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx + int num_splits = get<4>(params.shape_O_packed); + if constexpr (Split && Varlen) { + uint32_t num_splits_dynamic_u = reinterpret_cast(split_idx) >> 16; // first 16 bits are for num_splits + int num_splits_dynamic = reinterpret_cast(num_splits_dynamic_u); + num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits; + split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx + } + bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1); + flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused}; bool const is_varlen = Varlen && params.cu_seqlens; int offset_o = seqlen_info.offset; int seqlen_o = seqlen_info.seqlen; int qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor; - Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx); - Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset_o * get<0>(params.stride_LSE)), params.shape_LSE_packed, params.stride_LSE_packed)(_, bidh, !is_varlen ? bidb : 0, split_idx); + Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)), + params.shape_LSE_packed, + !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx); Tensor gLSE = local_tile(mLSE, Shape>{}, make_coord(m_block)); static_assert(kBlockM <= NumEpilogueThreads); @@ -388,35 +442,39 @@ struct CollectiveEpilogueFwd { if (row < seqlen_o * qhead_per_khead) { int m_idx, h_idx; m_idx = params.qhead_per_khead_divmod.divmod(h_idx, row); - // mLSE shape shape ((qhead_per_khead, seqlen_q)) and it's unhappy with just 1 "make_coord" + // mLSE has shape ((qhead_per_khead, seqlen_q)) and it's unhappy with just 1 "make_coord" mLSE(make_coord(make_coord(h_idx, m_idx))) = -INFINITY; } } } - if constexpr (!Clear_O) { return; } + // If split, we don't have to write 0 to mOpartial if the mha_combine kernel is used, + // since it will not use the value of O if LSE is -inf. + if (!is_split) { + Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, _0{}); - GmemTiledCopyO gmem_tiled_copy_O; - auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); - Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); - if constexpr (!PackGQA) { - Tensor tOpO = make_tensor(make_shape(size<2>(tOcO))); - #pragma unroll - for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } - Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) - Tensor tOgO = gmem_thr_copy_O.partition_D(gO); - Tensor tOrO = make_fragment_like(tOgO); - cute::clear(tOrO); - // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM - ); - } else { - // If PackGQA, we split the work of compute O_ptr among threads in the same row - using PackGQAt = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>; - Tensor tOrO = make_tensor(make_shape(Shape<_1, Int>{}, size<1>(tOcO), size<2>(tOcO))); - cute::clear(tOrO); - PackGQAt::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); + GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); + if constexpr (!PackGQA) { + Tensor tOpO = make_tensor(make_shape(size<2>(tOcO))); + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } + Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_fragment_like(tOgO); + cute::clear(tOrO); + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM + ); + } else { + // If PackGQA, we split the work of compute O_ptr among threads in the same row + using PackGQA_t = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>; + Tensor tOrO = make_tensor(make_shape(Shape<_1, Int>{}, size<1>(tOcO), size<2>(tOcO))); + cute::clear(tOrO); + PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); + } } } diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 5806e7150..565a9eb55 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -872,15 +872,15 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq tile_count_semaphore = torch::empty({1}, opts.dtype(torch::kInt32)); if (!is_varlen) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing params.tile_count_semaphore = tile_count_semaphore.data_ptr(); - if (is_varlen) { - num_m_n_blocks_splits = torch::empty({batch_size * 3}, opts.dtype(torch::kInt32)); - params.num_m_blocks_ptr = num_m_n_blocks_splits.data_ptr(); - params.num_n_blocks_ptr = num_m_n_blocks_splits.data_ptr() + batch_size; - params.num_splits_dynamic_ptr = num_m_n_blocks_splits.data_ptr() + batch_size * 2; - } } else { params.tile_count_semaphore = nullptr; } + if (is_varlen) { + num_m_n_blocks_splits = torch::empty({batch_size * 3}, opts.dtype(torch::kInt32)); + params.num_m_blocks_ptr = num_m_n_blocks_splits.data_ptr(); + params.num_n_blocks_ptr = num_m_n_blocks_splits.data_ptr() + batch_size; + params.num_splits_dynamic_ptr = num_m_n_blocks_splits.data_ptr() + batch_size * 2; + } if (q_type == at::ScalarType::Float8_e4m3fn) { if (q_descale_.has_value()) { diff --git a/hopper/flash_fwd_combine_kernel.h b/hopper/flash_fwd_combine_kernel.h index 42dac2a69..3e9a3c232 100644 --- a/hopper/flash_fwd_combine_kernel.h +++ b/hopper/flash_fwd_combine_kernel.h @@ -130,37 +130,39 @@ class FlashAttnFwdCombine { // Device side arguments struct Arguments { - ElementPartial const* ptr_O_partial; + ElementPartial const* const ptr_O_partial; ShapeOPartial const shape_O_partial; StrideOPartial const stride_O_partial; - float const* ptr_LSE_partial; + float const* const ptr_LSE_partial; ShapeLSEPartial const shape_LSE_partial; StrideLSEPartial const stride_LSE_partial; - Element* ptr_O; + Element* const ptr_O; StrideO const stride_O; - float* ptr_LSE; + float* const ptr_LSE; StrideLSE const stride_LSE; - int const* cu_seqlens = nullptr; - int const* seqused = nullptr; - int const* num_splits_dynamic_ptr = nullptr; + int const* const cu_seqlens = nullptr; + int const* const seqused = nullptr; + int const* const num_splits_dynamic_ptr = nullptr; + int* const semaphore_to_reset = nullptr; }; // Kernel entry point API struct Params { - ElementPartial const* ptr_O_partial; + ElementPartial const* const ptr_O_partial; ShapeOPartial const shape_O_partial; StrideOPartial const stride_O_partial; - float const* ptr_LSE_partial; + float const* const ptr_LSE_partial; ShapeLSEPartial const shape_LSE_partial; StrideLSEPartial const stride_LSE_partial; - Element* ptr_O; + Element* const ptr_O; StrideO const stride_O; - float* ptr_LSE; + float* const ptr_LSE; StrideLSE const stride_LSE; cutlass::FastDivmod seqlen_divmod, head_divmod; - int const* cu_seqlens = nullptr; - int const* seqused = nullptr; - int const* num_splits_dynamic_ptr = nullptr; + int const* const cu_seqlens = nullptr; + int const* const seqused = nullptr; + int const* const num_splits_dynamic_ptr = nullptr; + int* const semaphore_to_reset = nullptr; }; // Convert to underlying arguments. In this case, a simple copy for the aliased type. @@ -182,7 +184,8 @@ class FlashAttnFwdCombine { cutlass::FastDivmod(get<0>(args.shape_LSE_partial)), cutlass::FastDivmod(get<2>(args.shape_LSE_partial)), args.cu_seqlens, args.seqused, - args.num_splits_dynamic_ptr + args.num_splits_dynamic_ptr, + args.semaphore_to_reset }; } @@ -200,6 +203,11 @@ class FlashAttnFwdCombine { int const k_block = blockIdx.y; int const batch = blockIdx.z; int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial); + + if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1) { + *params.semaphore_to_reset = 0; + } + if (num_splits <= 1) { return; } flash::SeqlenInfo seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused}; int const offset = seqlen_info.offset; int const seqlen = seqlen_info.seqlen; diff --git a/hopper/flash_fwd_combine_launch_template.h b/hopper/flash_fwd_combine_launch_template.h index b0472b2c4..7cb9b64fd 100644 --- a/hopper/flash_fwd_combine_launch_template.h +++ b/hopper/flash_fwd_combine_launch_template.h @@ -33,7 +33,7 @@ void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { {params.o_row_stride, _1{}, params.o_head_stride, !Varlen ? params.o_batch_stride : 0}, // stride_O static_cast(params.softmax_lse_ptr), {_1{}, !Varlen ? params.seqlen_q : params.total_q, !Varlen ? params.h * params.seqlen_q : 0}, // stride_LSE - params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr + params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.tile_count_semaphore }; typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args); diff --git a/hopper/flash_fwd_kernel_sm80.h b/hopper/flash_fwd_kernel_sm80.h index 71071d722..4c35da4f0 100644 --- a/hopper/flash_fwd_kernel_sm80.h +++ b/hopper/flash_fwd_kernel_sm80.h @@ -203,9 +203,7 @@ class FlashAttnFwdSm80 { threadIdx.x, block_coord); } else { // Write 0 to gO and -inf to gLSE. - // If Split, we don't have to write 0 to O if the mha_combine kernel is used, since it will - // not use the value of O if LSE is -inf. - epilogue.template store_zero(params.epilogue, threadIdx.x, block_coord); + epilogue.store_zero(params.epilogue, threadIdx.x, block_coord); } } diff --git a/hopper/flash_fwd_kernel_sm90.h b/hopper/flash_fwd_kernel_sm90.h index 9cfb2d9e5..d54a2f53c 100644 --- a/hopper/flash_fwd_kernel_sm90.h +++ b/hopper/flash_fwd_kernel_sm90.h @@ -433,10 +433,7 @@ class FlashAttnFwdSm90 { threadIdx.x - MmaThreadOffset, block_coord); } else { // Write 0 to gO and -inf to gLSE. - // If Split, we don't have to write 0 to O if the mha_combine kernel is used, since it will - // not use the value of O if LSE is -inf. - epilogue.template store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord); - // epilogue.store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord); + epilogue.store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord); } } epilogue.store_tail(); diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index b08826153..231045567 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -53,7 +53,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { flash::CollectiveMainloopFwdSm90, flash::CollectiveMainloopFwdSm80 >; - using CollectiveEpilogue = flash::CollectiveEpilogueFwd; + using CollectiveEpilogue = flash::CollectiveEpilogueFwd; static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads; using SchedulerPersistent = std::conditional_t(!Split ? params.o_ptr : params.oaccum_ptr), + static_cast(params.o_ptr), {seqlen_q, params.dv, params.h, batch_q, params.num_splits}, // shape_O - {!Split ? params.o_row_stride : params.oaccum_row_stride, - _1{}, - !Split ? params.o_head_stride : params.oaccum_head_stride, - !is_varlen_q ? (!Split ? params.o_batch_stride : params.oaccum_batch_stride) : 0, - !Split ? 0 : params.oaccum_split_stride}, // stride_O - static_cast(!Split ? params.softmax_lse_ptr : params.softmax_lseaccum_ptr), - {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, !Split ? 0 : params.h * seqlen_q * batch_q}, // stride_LSE + {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0, 0}, // stride_O + static_cast(params.oaccum_ptr), + {params.oaccum_row_stride, _1{}, params.oaccum_head_stride, !is_varlen_q ? params.oaccum_batch_stride : 0, params.oaccum_split_stride}, // stride_O_partial + static_cast(params.softmax_lse_ptr), + {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, 0}, // stride_LSE + static_cast(params.softmax_lseaccum_ptr), + {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, params.h * seqlen_q * batch_q}, // stride_LSE_partial params.h_k, params.cu_seqlens_q, params.seqused_q }; @@ -150,11 +150,11 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { params.seqlen_q, params.seqlen_k, params.d, params.dv, sizeof(Element), params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q, - // params.num_m_blocks_ptr, params.num_splits_dynamic_ptr, + // params.num_m_blocks_ptr, params.num_splits_dynamic_ptr, }; - if constexpr (Varlen && UsePersistentScheduler) { + if constexpr (Varlen) { prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN); CHECK_CUDA_KERNEL_LAUNCH(); } @@ -195,7 +195,7 @@ template || cute::is_same_v; - using T_out = std::conditional_t, float>; + using T_out = std::conditional_t; CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] { VCOLMAJOR_SWITCH(params.v_dim_stride != 1, V_colmajor_, [&] { static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1; diff --git a/hopper/flash_prepare_scheduler.cu b/hopper/flash_prepare_scheduler.cu index 6fde9084c..8d1b3602b 100644 --- a/hopper/flash_prepare_scheduler.cu +++ b/hopper/flash_prepare_scheduler.cu @@ -27,7 +27,7 @@ __global__ void prepare_varlen_num_blocks_kernel( if (threadIdx.x < kSmemSize) { smem[threadIdx.x] = 0; } __syncthreads(); - if (threadIdx.x == 0) { *tile_count_semaphore = 0; } + if (threadIdx.x == 0 && tile_count_semaphore) { *tile_count_semaphore = 0; } int lane = threadIdx.x % cutlass::NumThreadsPerWarp; @@ -82,7 +82,7 @@ __global__ void prepare_varlen_num_blocks_kernel( int num_m_blocks = get_num_m_blocks(bidb_start); int num_n_blocks = get_num_n_blocks(bidb_start); if (bidb_start + lane < num_batch && lane < kNumBatchPerWarp) { - num_m_blocks_ptr[bidb_start + lane] = num_m_blocks; + // num_m_blocks_ptr[bidb_start + lane] = num_m_blocks; num_n_blocks_ptr[bidb_start + lane] = num_n_blocks; // printf("idx = %d, num_m = %d, num_n = %d\n", bidb_start + lane, num_m_blocks, num_n_blocks); } diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 54fdab17e..2ed394324 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -117,6 +117,8 @@ def test_flash_attn_output( nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype dv_vals = [128, d] if d > 128 and d <= 192 else ([512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] for dv in dv_vals: q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) if softcap > 0.0: @@ -333,7 +335,10 @@ def test_flash_attn_varlen_output( # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - for dv in [128, d] if d > 128 and d <= 192 else [d]: + dv_vals = [128, d] if d > 128 and d <= 192 else ([512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + for dv in dv_vals: q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) if softcap > 0.0: # Ensure the values of qk are at least within softcap range. @@ -641,6 +646,8 @@ def test_flash_attn_kvcache( assert nheads % nheads_k == 0 dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype dv_vals = [128, d] if d > 128 and d <= 192 else ([512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] for dv in dv_vals: has_qv = d == 64 and dv == 512 q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index 9d2c83f2c..a3aa794d6 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -22,10 +22,10 @@ struct TileSchedulerArguments { int const seqlen; // Only used if Varlen and cu_seqlens == nullptr and seqused == nullptr int const seqlen_k, headdim, headdim_v, element_size; // Used to calculate L2 swizzling int* const tile_count_semaphore = nullptr; - int* const cu_seqlens = nullptr; - int* const seqused = nullptr; - // int* const num_m_blocks_ptr = nullptr; - int* const num_splits_dynamic_ptr = nullptr; + int const* const cu_seqlens = nullptr; + int const* const seqused = nullptr; + // int const* const num_m_blocks_ptr = nullptr; + int const* const num_splits_dynamic_ptr = nullptr; }; /////////////////////////////////////////////////////////////////////////////// @@ -43,16 +43,20 @@ class SingleTileScheduler { int const qhead_per_khead; int const seqlen; cutlass::FastDivmod nsplits_divmod; - int* const cu_seqlens; - int* const seqused; + int const* const cu_seqlens; + int const* const seqused; + int const* const num_splits_dynamic_ptr = nullptr; }; static Params to_underlying_arguments(TileSchedulerArguments const& args) { + assert(!Split || !Varlen || args.num_splits_dynamic_ptr != nullptr); + assert(!Split || !Varlen || args.num_splits < (1 << 16)); // We use the top 16 bits to store num_splits return {args.num_blocks, args.num_head, args.num_batch, !Split ? 1 : args.num_splits, args.qhead_per_khead, args.seqlen, cutlass::FastDivmod(!Split ? 1 : args.num_splits), - !Varlen ? nullptr : args.cu_seqlens, !Varlen ? nullptr : args.seqused}; + !Varlen ? nullptr : args.cu_seqlens, !Varlen ? nullptr : args.seqused, + args.num_splits_dynamic_ptr}; } static dim3 @@ -64,24 +68,18 @@ class SingleTileScheduler { int block_idx = 0; int bidh = 0; int bidb = 0; - bool is_valid_tile = false; + int split_idx = 0; CUTLASS_DEVICE bool is_valid(Params const& params) const { - return is_valid_tile; + return bidb >= 0; } CUTLASS_DEVICE cute::tuple get_block_coord(Params const& params) const { - if constexpr (!Split) { - return {block_idx, bidh, bidb, 0 /*split_idx*/}; - } else { - int split_idx; - int bidh_actual = params.nsplits_divmod.divmod(split_idx, bidh); - return {block_idx, bidh_actual, bidb, split_idx}; - } + return {block_idx, bidh, bidb, !Split ? 0 : split_idx}; } }; @@ -93,14 +91,27 @@ class SingleTileScheduler { CUTLASS_DEVICE WorkTileInfo get_initial_work(Params const& params) const { - WorkTileInfo work_info {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), true}; + WorkTileInfo work_info {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), 0}; + if constexpr (Split) { + int split_idx; + work_info.bidh = params.nsplits_divmod.divmod(split_idx, work_info.bidh); + work_info.split_idx = split_idx; + } + bool is_valid_tile = true; if constexpr (Varlen) { int seqlen = params.seqused ? params.seqused[work_info.bidb] : (params.cu_seqlens ? params.cu_seqlens[work_info.bidb + 1] - params.cu_seqlens[work_info.bidb] : params.seqlen); if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } - work_info.is_valid_tile = work_info.block_idx * kBlock < seqlen; + is_valid_tile = work_info.block_idx * kBlock < seqlen; + } + if constexpr (Varlen && Split) { + int num_splits_dynamic = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[work_info.bidb] : params.num_splits; + // Use the top 16 bits to store num_splits + work_info.split_idx |= (num_splits_dynamic << 16); + is_valid_tile &= work_info.split_idx < num_splits_dynamic; } + work_info.bidb = is_valid_tile ? work_info.bidb : -1; return work_info; } @@ -116,7 +127,7 @@ class SingleTileScheduler { CUTLASS_DEVICE WorkTileInfo get_next_work(Params const& params, WorkTileInfo const& current_work) const { - return {-1, -1, -1, false}; + return {0, 0, -1, 0}; } }; @@ -366,10 +377,10 @@ class VarlenDynamicPersistentTileScheduler { cutlass::FastDivmod head_divmod; cutlass::FastDivmod nsplits_divmod; int* const tile_count_semaphore; - int* const cu_seqlens; - int* const seqused; + int const* const cu_seqlens; + int const* const seqused; // int* const num_m_blocks_ptr; - int* const num_splits_dynamic_ptr; + int const* const num_splits_dynamic_ptr; }; static Params @@ -385,7 +396,7 @@ class VarlenDynamicPersistentTileScheduler { cutlass::FastDivmod(args.num_head), cutlass::FastDivmod(!Split ? 1 : args.num_splits), args.tile_count_semaphore, args.cu_seqlens, args.seqused, - // args.num_m_blocks_ptr, args.num_splits_dynamic_ptr}; + // args.num_m_blocks_ptr, args.num_splits_dynamic_ptr}; } From 1d30bb4cd31513a1c0e1b66c88f7da2d420699c7 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 8 Mar 2025 22:19:39 -0500 Subject: [PATCH 062/102] Enable TMA when page size is a multiple of kBlockN --- hopper/flash.h | 3 +- hopper/flash_api.cpp | 114 +++++++++++++---------- hopper/flash_fwd_kernel_sm90.h | 7 +- hopper/flash_fwd_launch_template.h | 24 ++--- hopper/mainloop_fwd_sm80.hpp | 10 +- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 77 +++++++++------ hopper/paged_kv.h | 45 ++++++++- hopper/rotary.h | 16 ++-- hopper/tile_size.h | 14 +-- 9 files changed, 197 insertions(+), 113 deletions(-) diff --git a/hopper/flash.h b/hopper/flash.h index c192830b7..d5d7fa218 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -121,6 +121,7 @@ struct Flash_fwd_params : public Qkv_params { index_t page_table_batch_stride; int page_size; int num_pages; + bool pagedkv_tma; // The dropout probability (probability of keeping an activation). float p_dropout; @@ -205,7 +206,7 @@ struct Flash_bwd_params : public Flash_fwd_params { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, int blockM, int blockN); template diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 565a9eb55..27bedc1fc 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -263,70 +263,70 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { TORCH_CHECK(params.num_splits >= 1); ARCH_SWITCH(params.arch, Arch, [&] { SPLIT_SWITCH(params.num_splits > 1, Split, [&] { - PAGEDKV_SWITCH(params.page_table, PagedKV, [&] { + PAGEDKV_SWITCH(params.page_table && !params.pagedkv_tma, PagedKVNonTMA, [&] { PACKGQA_SWITCH(params.pack_gqa, PackGQA_, [&] { - // Always enable PackGQA for Sm8x or PagedKV or Split to reduce compilation - static constexpr bool PackGQA = PackGQA_ || Arch < 90 || PagedKV || Split; + // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation + static constexpr bool PackGQA = PackGQA_ || Arch < 90 || PagedKVNonTMA || Split; SOFTCAP_SWITCH(params.softcap > 0.0, Has_softcap, [&] { if (!params.is_e4m3) { if (params.is_bf16) { #ifndef FLASHATTENTION_DISABLE_HDIM64 if (params.d <= 64) { if (params.dv > 64 && Arch == 90) { - return run_mha_fwd_(params, stream); + return run_mha_fwd_(params, stream); } else { - return run_mha_fwd_(params, stream); + return run_mha_fwd_(params, stream); } } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_fwd_(params, stream); } + if (params.d <= 96) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_fwd_(params, stream); } + if (params.d <= 128) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 if (params.d <= 192) { if (params.dv <= 128 && Arch == 90) { - return run_mha_fwd_(params, stream); + return run_mha_fwd_(params, stream); } else { - return run_mha_fwd_(params, stream); + return run_mha_fwd_(params, stream); } } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_fwd_(params, stream); } + if (params.d <= 256) { return run_mha_fwd_(params, stream); } #endif } else { #ifndef FLASHATTENTION_DISABLE_FP16 #ifndef FLASHATTENTION_DISABLE_HDIM64 if (params.d <= 64) { if (params.dv > 64 && Arch == 90) { - return run_mha_fwd_(params, stream); + return run_mha_fwd_(params, stream); } else { - return run_mha_fwd_(params, stream); + return run_mha_fwd_(params, stream); } } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_fwd_(params, stream); } + if (params.d <= 96) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_fwd_(params, stream); } + if (params.d <= 128) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 if (params.d <= 192) { if (params.dv <= 128 && Arch == 90) { - return run_mha_fwd_(params, stream); + return run_mha_fwd_(params, stream); } else { - return run_mha_fwd_(params, stream); + return run_mha_fwd_(params, stream); } } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_fwd_(params, stream); } + if (params.d <= 256) { return run_mha_fwd_(params, stream); } #endif #else TORCH_CHECK(false, "This flash attention build does not support FP16."); @@ -335,25 +335,25 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { } else { #ifndef FLASHATTENTION_DISABLE_FP8 #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } + if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } + if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } + if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 if (params.d <= 192) { if (params.dv <= 128 && Arch == 90) { - return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKV, Has_softcap, PackGQA>(params, stream); + return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } else { - return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKV, Has_softcap, PackGQA>(params, stream); + return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } + if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } #endif #else TORCH_CHECK(false, "This flash attention build does not support FP8."); @@ -394,17 +394,25 @@ void run_mha_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { #endif } +inline bool get_pagedkv_tma(Flash_fwd_params const& params) { + if (params.arch < 90 || !params.page_table || params.leftpad_k || params.knew_ptr) { return false; } + // This needs to match the kernel configs + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, false /*paged_kv_non_TMA*/, params.softcap > 0.f); + int const kBlockN = std::get<1>(kBlockMN_kernel_args_sm90); + return params.page_size % kBlockN == 0; +} + inline bool get_pack_gqa(Flash_fwd_params const& params) { - // Always enable PackGQA for Sm8x or PagedKV or Split to reduce compilation and binary size. + // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation and binary size. // Has little effect on speed. - if (params.arch < 90 || params.page_table || params.num_splits > 1) { return true; } + if (params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1) { return true; } #ifdef FLASHATTENTION_DISABLE_PACKGQA return false; #else // params.page_table must already be set if (params.h == params.h_k) { return false; } // This needs to match the kernel configs - auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table, params.softcap > 0.f); + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM); #endif @@ -418,7 +426,7 @@ inline int get_num_splits(Flash_fwd_params const& params) { // params.page_table must already be set // This needs to match the kernel configs bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k; - auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table, params.softcap > 0.f); + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits // has not been set here. It's OK though because we might just underestimate kBlockN a bit auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr); @@ -569,11 +577,6 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq TORCH_CHECK(!paged_KV, "If cu_seqlens_k is passed in, then page table is not supported"); TORCH_CHECK(!kv_batch_idx_.has_value(), "If cu_seqlens_k is passed in, then page table is not supported"); } - // This is what we will template on - bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value(); - #ifdef FLASHATTENTION_DISABLE_VARLEN - TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen."); - #endif auto const sizes = q.sizes(); const int batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1; @@ -612,7 +615,12 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq if (window_size_left >= seqlen_k - 1) { window_size_left = -1; } if (window_size_right >= seqlen_q - 1) { window_size_right = -1; } // causal=true is the same as causal=false in this case - if (seqlen_q == 1 && window_size_left == -1 && window_size_right == -1) { is_causal = false; } + if (seqlen_q == 1 && window_size_left == -1 && window_size_right == -1) { + // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA + if ((head_size <= 64 || head_size > 128) || !paged_KV) { + is_causal = false; + } + } if (is_causal) { window_size_right = 0; } // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_fprop will set params.is_causal=true. // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM. @@ -652,6 +660,19 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq CHECK_SHAPE(seqused_k, batch_size); } + if (leftpad_k_.has_value()) { + auto leftpad_k = leftpad_k_.value(); + TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); + CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k); + CHECK_SHAPE(leftpad_k, batch_size); + } + + // This is what we will template on + bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value(); + #ifdef FLASHATTENTION_DISABLE_VARLEN + TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen."); + #endif + int const alignment = q_type == torch::kFloat8_e4m3fn ? 16 : 8; TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment)); TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment)); @@ -716,7 +737,9 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq params.b_k = batch_size_k; params.dv = head_size_v; params.dv_rounded = head_size_v_rounded; - + if (leftpad_k_.has_value()) { // This needs to be set before get_pagedkv_tma + params.leftpad_k = static_cast(leftpad_k_.value().data_ptr()); + } if (paged_KV) { params.page_table = page_table.data_ptr(); params.page_table_batch_stride = page_table.stride(0); @@ -724,11 +747,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq params.page_size = page_size; params.num_pages = num_pages; - params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; - // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide - params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); - - if (k_new_.has_value()) { + if (k_new_.has_value()) { // This needs to be set before get_pagedkv_tma at::Tensor k_new, v_new; TORCH_CHECK(v_new_.has_value(), "If k_new is supplied, v_new must also be passed in"); TORCH_CHECK(seqused_k_.has_value(), "If k_new is supplied, seqlens_k must also be passed in"); @@ -776,6 +795,11 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq } } + params.pagedkv_tma = get_pagedkv_tma(params); + params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; + // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide + params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); + if (q_v_.has_value()) { TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, @@ -799,14 +823,6 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq } } - if (leftpad_k_.has_value()) { - auto leftpad_k = leftpad_k_.value(); - TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); - CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k); - CHECK_SHAPE(leftpad_k, batch_size); - params.leftpad_k = static_cast(leftpad_k.data_ptr()); - } - if (rotary_cos_.has_value()) { TORCH_CHECK(k_new_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided"); auto rotary_cos = rotary_cos_.value(); @@ -925,10 +941,10 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq TORCH_CHECK(params.num_splits == 1, "This flash attention build does not support splits."); #endif #ifdef FLASHATTENTION_DISABLE_PACKGQA - TORCH_CHECK(!params.pack_gqa || params.arch < 90 || params.page_table || params.num_splits > 1, "This flash attention build does not support pack_gqa."); + TORCH_CHECK(!params.pack_gqa || params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1, "This flash attention build does not support pack_gqa."); #endif #ifdef FLASHATTENTION_DISABLE_PAGEDKV - TORCH_CHECK(!paged_KV, "This flash attention build does not support paged KV."); + TORCH_CHECK(!(params.page_table && !params.pagedkv_tma), "This flash attention build does not support paged KV."); #endif #ifdef FLASHATTENTION_DISABLE_APPENDKV TORCH_CHECK(!k_new_.has_value(), "This flash attention build does not support appending KV."); diff --git a/hopper/flash_fwd_kernel_sm90.h b/hopper/flash_fwd_kernel_sm90.h index d54a2f53c..1f841da46 100644 --- a/hopper/flash_fwd_kernel_sm90.h +++ b/hopper/flash_fwd_kernel_sm90.h @@ -35,7 +35,6 @@ class FlashAttnFwdSm90 { static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen); static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap; static constexpr bool Varlen = CollectiveMainloop::Varlen; - static constexpr bool PagedKV = CollectiveMainloop::PagedKV; static constexpr bool Split = CollectiveMainloop::Split; static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8; static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V; @@ -308,7 +307,7 @@ class FlashAttnFwdSm90 { cutlass::arch::warpgroup_reg_dealloc(); // The pipelines for AppendKV and main attention are different, since e.g. main attention - // might use cp.async to load KV (if PagedKV) while AppendKV always uses TMA to load + // might use cp.async to load KV (if PagedKVNonTMA) while AppendKV always uses TMA to load // KV_new. Since the pipeline states are different, we have to manually sync to make // sure the two pipelines don't race when accessing smem_k and smem_v. PipelineState smem_pipe_write = cutlass::make_producer_start_state(); @@ -330,7 +329,7 @@ class FlashAttnFwdSm90 { SeqlenInfo_t seqlen_info{ get<2>(block_coord) /*bidb*/, get<0>(params.mainloop.shape_Q), - !PagedKV ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), + !params.mainloop.ptr_pagetable ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), get<0>(params.mainloop.shape_K_new), params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, @@ -390,7 +389,7 @@ class FlashAttnFwdSm90 { SeqlenInfo_t seqlen_info{ bidb, get<0>(params.mainloop.shape_Q), - !PagedKV ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), + !params.mainloop.ptr_pagetable ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), get<0>(params.mainloop.shape_K_new), params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 231045567..ededa4a5e 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -24,7 +24,7 @@ using namespace cute; template void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time"); @@ -35,8 +35,8 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using ArchTag = std::conditional_t= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>; // Can't use structured binding since it's not compatible with constexpr - static constexpr std::tuple kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKV, Has_softcap); - static constexpr std::tuple kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKV, Varlen && Split, Has_softcap, AppendKV); + static constexpr std::tuple kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap); + static constexpr std::tuple kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKVNonTMA, Varlen && Split, Has_softcap, AppendKV); static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS); static constexpr int kBlockN = Arch >= 90 ? std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS); static constexpr bool MmaPV_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap); @@ -50,8 +50,8 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using ClusterShape = cute::Shape, _1, _1>; using CollectiveMainloop = std::conditional_t< Arch >= 90, - flash::CollectiveMainloopFwdSm90, - flash::CollectiveMainloopFwdSm80 + flash::CollectiveMainloopFwdSm90, + flash::CollectiveMainloopFwdSm80 >; using CollectiveEpilogue = flash::CollectiveEpilogueFwd; @@ -91,8 +91,8 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { {seqlen_q, params.d, params.h, batch_q}, // shape_Q {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0}, // stride_Q static_cast(params.k_ptr), - {!PagedKV ? (!is_varlen_k ? params.seqlen_k : params.total_k) : params.page_size, - params.d, params.h_k, !PagedKV ? batch_k : params.num_pages}, // shape_K + {!params.page_table ? (!is_varlen_k ? params.seqlen_k : params.total_k) : params.page_size, + params.d, params.h_k, !params.page_table ? batch_k : params.num_pages}, // shape_K {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0}, // stride_K static_cast(params.v_ptr), params.dv, // headdim_v @@ -112,7 +112,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { params.is_rotary_interleaved, params.page_table, // if page_size is not set, avoid dividing by zero - {params.kv_batch_idx ? params.b_k : params.b, !PagedKV ? 0 : params.seqlen_k / params.page_size}, // shape_page_table + {params.kv_batch_idx ? params.b_k : params.b, !params.page_table ? 0 : params.seqlen_k / params.page_size}, // shape_page_table {params.page_table_batch_stride, _1{}}, // stride_page_table params.scale_softmax, params.q_descale_ptr, params.k_descale_ptr, params.v_descale_ptr, @@ -191,7 +191,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { CHECK_CUDA_KERNEL_LAUNCH(); } -template +template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { static_assert(sizeof(T) == 2 || sizeof(T) == 1, "Only 16bit and 8bit are supported"); static constexpr bool Is_FP8 = cute::is_same_v || cute::is_same_v; @@ -201,17 +201,17 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1; VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] { // Only needed here to decide if we should use cluster - static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKV, Has_softcap)) : 128; + static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap)) : 128; // On nvcc 12.8, hdim 128, without cluster is faster (730 vs 700 TFLOPS) - static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 192) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen; + static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 192) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen; BOOL_SWITCH(params.qv_ptr, HasQV_, [&] { static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV == 512; APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { // Only use Cluster if number of tiles along seqlen_q is even and not varlen CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1; - run_flash_fwd(params, stream); + run_flash_fwd(params, stream); }); }); }); diff --git a/hopper/mainloop_fwd_sm80.hpp b/hopper/mainloop_fwd_sm80.hpp index 84c0fd0e5..a642fc74f 100644 --- a/hopper/mainloop_fwd_sm80.hpp +++ b/hopper/mainloop_fwd_sm80.hpp @@ -415,7 +415,10 @@ struct CollectiveMainloopFwdSm80 { params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, params.ptr_K, params.shape_K, params.stride_K, params.ptr_V, params.headdim_v, params.stride_V, - params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k + params.page_size_divmod, + params.page_size_divmod /*blockN_per_page_size_divmod, not used since we don't use TMA*/, + bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k, + 0 /*bidb_kv_idx, not used since we don't use TMA for Sm8x*/ ); auto load_K = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) { @@ -698,8 +701,11 @@ struct CollectiveMainloopFwdSm80 { params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, params.ptr_K, params.shape_K, params.stride_K, params.ptr_V, params.headdim_v, params.stride_V, - params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k + params.page_size_divmod, + params.page_size_divmod /*blockN_per_page_size_divmod, not used since we don't use TMA*/, + bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k, // passing offset_k instead of leftpad_k will move the PageTable pointer to the right position + 0 /*bidb_kv_idx, not used since we don't use TMA for Sm8x*/ ); static_assert(std::is_same_v); diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 03b812d76..c2f7ff7eb 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -29,7 +29,7 @@ namespace flash { using namespace cute; template struct CollectiveMainloopFwdSm90 { @@ -46,7 +46,7 @@ struct CollectiveMainloopFwdSm90 { static constexpr bool Is_local = Is_local_; static constexpr bool Has_softcap = Has_softcap_; static constexpr bool Varlen = Varlen_; - static constexpr bool PagedKV = PagedKV_; + static constexpr bool PagedKVNonTMA = PagedKVNonTMA_; static constexpr bool AppendKV = AppendKV_; static constexpr bool HasQv = HasQv_; static constexpr bool PackGQA = PackGQA_; @@ -54,7 +54,7 @@ struct CollectiveMainloopFwdSm90 { static constexpr bool V_colmajor = V_colmajor_; static constexpr bool Transpose_V = Is_FP8 && !V_colmajor; static constexpr bool Use_TMA_Q = !PackGQA; - static constexpr bool Use_TMA_KV = !PagedKV; + static constexpr bool Use_TMA_KV = !PagedKVNonTMA; static_assert(Use_TMA_KV || CUTE_STATIC_V(size(ClusterShape{})) == 1, "If not using TMA for KV, ClusterShape must be 1"); static_assert(Use_TMA_KV || !V_colmajor, "If not using TMA for KV, V_colmajor is not supported"); static constexpr bool SameHeadDim = get<2>(TileShape_MNK{}) == kHeadDimV; @@ -208,7 +208,7 @@ struct CollectiveMainloopFwdSm90 { using GmemTiledCopyQ = cute::SM90_TMA_LOAD; using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{}))); - // We use CpAsync for K and V if PagedKV and AppendKV, since TMA doesn't work there + // We use CpAsync for K and V if PagedKVNonTMA and AppendKV, since TMA doesn't work there static constexpr int kHeadDimGCD = cute::gcd(kHeadDim, kHeadDimV); static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDimGCD % kGmemElemsPerLoad == 0, "Headdim and HeaddimV must be a multiple of kGmemElemsPerLoad"); @@ -221,7 +221,7 @@ struct CollectiveMainloopFwdSm90 { static constexpr int kBlockKGmem = (kBytePerHalfRow % 128 == 0 ? 128 : (kBytePerHalfRow % 64 == 0 ? 64 : 32)) / sizeof(Element); static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow"); - // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKV where + // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKVNonTMA where // these threads share the same page table entry and share the work of computing pointers to paged K and paged V. static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp"); using GmemLayoutAtom = Layout, Int>, @@ -360,7 +360,7 @@ struct CollectiveMainloopFwdSm90 { Element const* const ptr_Q; ShapeQKV const shape_Q; StrideQK const stride_Q; - Element* const ptr_K; // Not Element const* since we might append to KV cache in-place + Element* const ptr_K; // not Element const* since we might append to KV cache in-place ShapeQKV const shape_K; StrideQK const stride_K; Element* const ptr_V; @@ -429,6 +429,7 @@ struct CollectiveMainloopFwdSm90 { ShapePageTable const shape_pagetable; StridePageTable const stride_pagetable; cutlass::FastDivmod page_size_divmod; + cutlass::FastDivmod blockN_per_page_size_divmod; cutlass::FastDivmod qhead_per_khead_divmod; TMA_Q tma_load_Q; TMA_K tma_load_K; @@ -528,6 +529,11 @@ struct CollectiveMainloopFwdSm90 { assert(args.ptr_rotary_cos != nullptr && args.ptr_rotary_sin != nullptr); } assert(args.num_splits >= 1); + int page_size = !args.ptr_pagetable ? 1 : get<0>(args.shape_K); + if (!PagedKVNonTMA && args.ptr_pagetable != nullptr) { + assert(page_size % kBlockN == 0); + assert(!args.leftpad_k); + } // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. // Right after this, we multiply by log2(e) before applying exp2. // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val @@ -540,7 +546,8 @@ struct CollectiveMainloopFwdSm90 { args.ptr_rotary_cos, args.shape_rotary, args.stride_rotary_cos, args.ptr_rotary_sin, args.stride_rotary_sin, args.is_rotary_interleaved, args.ptr_pagetable, args.shape_pagetable, args.stride_pagetable, - cutlass::FastDivmod(int(get<0>(args.shape_K))), + cutlass::FastDivmod(page_size), // page_size_divmod + cutlass::FastDivmod(!args.ptr_pagetable ? 1 : cute::ceil_div(page_size, kBlockN)), // blockN_per_page_size_divmod cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), tma_load_Q, tma_load_K, tma_load_V, tma_load_K_new, tma_load_V_new, tma_load_Qv, !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), @@ -639,24 +646,24 @@ struct CollectiveMainloopFwdSm90 { bool const is_varlen_q = Varlen && params.cu_seqlens_q; bool const is_varlen_k = Varlen && params.cu_seqlens_k; Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); - Tensor mK_TMA = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); + Tensor mK_TMA = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, _); auto shape_V = make_shape(params.headdim_v, get<0>(params.shape_K), get<2>(params.shape_K), get<3>(params.shape_K)); - Tensor mVt_TMA = params.tma_load_V.get_tma_tensor(shape_V)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); + Tensor mVt_TMA = params.tma_load_V.get_tma_tensor(shape_V)(_, _, bidh_kv, _); Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) // if (cute::thread0()) { printf("Varlen = %d, params.leftpad_k = %p, leftpad_k = %d\n", Varlen, params.leftpad_k, leftpad_k); } - Tensor gK_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mK_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor gVt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k), mVt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _)); // (K, N, _) + Tensor gK_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}, _0{}), mK_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}, _)); // (N, K, _, _) + Tensor gVt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k, _0{}), mVt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _, _)); // (K, N, _, _) auto block_tma_Q = params.tma_load_Q.get_slice(_0{}); Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ)); // (TMA) Tensor tQsQ = group_modes<0, 3>(block_tma_Q.partition_D(sQ)); // (TMA) // tma_partition doesn't handle position_independent_swizzle_tensor correctly, so we need to do it manually auto block_tma_K = params.tma_load_K.get_slice(cluster_local_block_id.x); - Tensor tKgK_TMA = group_modes<0, 3>(block_tma_K.partition_S(gK_TMA)); // (TMA, k) + Tensor tKgK_TMA = group_modes<0, 3>(block_tma_K.partition_S(gK_TMA)); // (TMA, k, batch) Tensor tKsK_TMA = group_modes<0, 3>(block_tma_K.partition_D(sK)); // (TMA, PIPE) auto block_tma_V = params.tma_load_V.get_slice(cluster_local_block_id.x); - Tensor tVgVt_TMA = group_modes<0, 3>(block_tma_V.partition_S(gVt_TMA)); // (TMA, k) + Tensor tVgVt_TMA = group_modes<0, 3>(block_tma_V.partition_S(gVt_TMA)); // (TMA, k, batch) Tensor tVsVt_TMA = group_modes<0, 3>(block_tma_V.partition_D(sVt)); // (TMA, PIPE) auto [tQvgQv, tQvsQv] = [&] { if constexpr (HasQv) { @@ -672,12 +679,16 @@ struct CollectiveMainloopFwdSm90 { } }(); + // This is used to index into the batch dimension of mK and mV + int const bidb_kv_idx = !is_varlen_k && !params.ptr_pagetable ? bidb_kv : 0; + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumProducerThreads, Element, Transpose_V || !IntraWGOverlap /*KV_Same_Iter*/>; PagedKVManager_t paged_kv_manager( params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, params.ptr_K, params.shape_K, params.stride_K, params.ptr_V, params.headdim_v, params.stride_V, - params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k + params.page_size_divmod, params.blockN_per_page_size_divmod, + bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k, bidb_kv_idx ); // Set up for transposing V, only used if Transpose_V @@ -729,9 +740,10 @@ struct CollectiveMainloopFwdSm90 { auto load_K = [&] (int const n_block, auto const& smem_pipe_write, auto need_seqlenk_masking_type) { pipeline_k.producer_acquire(smem_pipe_write); - if constexpr (!PagedKV) { + if constexpr (!PagedKVNonTMA) { + auto [n_block_idx, bidb_kv_idx] = paged_kv_manager.get_indices_for_K_TMA(); copy(params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_LAST), - tKgK_TMA(_, n_block), tKsK_TMA(_, smem_pipe_write.index())); + tKgK_TMA(_, n_block_idx, bidb_kv_idx), tKsK_TMA(_, smem_pipe_write.index())); } else { constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value; paged_kv_manager.template load_K(n_block, sK_pi(_, _, smem_pipe_write.index())); @@ -742,9 +754,10 @@ struct CollectiveMainloopFwdSm90 { auto load_V = [&] (int const n_block, auto const& smem_pipe_write, auto need_seqlenk_masking_type) { auto pipeline_v_load = cute::conditional_return(pipeline_v, pipeline_vt); pipeline_v_load.producer_acquire(smem_pipe_write); - if constexpr (!PagedKV) { + if constexpr (!PagedKVNonTMA) { + auto [n_block_idx, bidb_kv_idx] = paged_kv_manager.get_indices_for_V_TMA(); copy(params.tma_load_V.with(*pipeline_v_load.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_LAST), - tVgVt_TMA(_, n_block), tVsVt_TMA(_, smem_pipe_write.index())); + tVgVt_TMA(_, n_block_idx, bidb_kv_idx), tVsVt_TMA(_, smem_pipe_write.index())); } else { constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value; paged_kv_manager.template load_V(n_block, sVcpasync(_, _, smem_pipe_write.index())); @@ -777,8 +790,10 @@ struct CollectiveMainloopFwdSm90 { bool should_load_KV = !Use_TMA_KV || ((SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync()); if (should_load_KV) { - if constexpr (PagedKV) { + if constexpr (PagedKVNonTMA) { paged_kv_manager.template load_page_table(n_block); + } else { + paged_kv_manager.template load_page_table_TMA(n_block); } if constexpr (Transpose_V) { load_V(n_block, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); } // if (thread_idx == 0) { printf("Producer: main load, before load_K, index = %d\n", smem_pipe_write.index());} @@ -839,8 +854,10 @@ struct CollectiveMainloopFwdSm90 { PipelineState smem_pipe_write_v = smem_pipe_write; // copy the state, write_v is always 1 step behind ++smem_pipe_write; if (should_load_KV) { - if constexpr (PagedKV) { + if constexpr (PagedKVNonTMA) { paged_kv_manager.template load_page_table(n_block); + } else { + paged_kv_manager.load_page_table_TMA(n_block); } if constexpr (Transpose_V) { load_V(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); } load_K(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); @@ -1569,12 +1586,16 @@ struct CollectiveMainloopFwdSm90 { params.ptr_rotary_sin, params.stride_rotary_sin, params.is_rotary_interleaved, thread_idx, seqlen_k_new, offset_rotary); + // This is used to index into the batch dimension of mK and mV + int const bidb_kv_idx = !is_varlen_k && !params.ptr_pagetable ? bidb_kv : 0; + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>; PagedKVManager_t paged_kv_manager( params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, params.ptr_K, params.shape_K, params.stride_K, params.ptr_V, params.headdim_v, params.stride_V, - params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k + params.page_size_divmod, params.blockN_per_page_size_divmod, + bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k, bidb_kv_idx // passing offset_k instead of leftpad_k will move the PageTable pointer to the right position ); @@ -1587,7 +1608,7 @@ struct CollectiveMainloopFwdSm90 { } static_assert(std::is_same_v); - static_assert(!PagedKV || std::is_same_v); + static_assert(!PagedKVNonTMA || std::is_same_v); GmemTiledCopyAppendKV gmem_tiled_copy_kv; auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_thread_slice(thread_idx); Tensor tKsK = gmem_thr_copy_kv.partition_S(sK); // ((Atom,AtomNum),ATOM_M,ATOM_N) @@ -1611,7 +1632,7 @@ struct CollectiveMainloopFwdSm90 { if (get<1>(params.shape_rotary) <= 0) { pipeline_k_new.consumer_wait(smem_pipe_read); Tensor tKsK_cur = tKsK(_, _, _, smem_pipe_read.index()); - if constexpr (!PagedKV) { + if constexpr (!PagedKVNonTMA) { Tensor tKgK_cur = tKgK(_, _, _, n_block); // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( @@ -1622,15 +1643,15 @@ struct CollectiveMainloopFwdSm90 { } } else { Tensor gK_cur = gK(_, _, n_block); - auto tPrKPtr = cute::conditional_return(paged_kv_manager.compute_K_ptr(), nullptr); + auto tPrKPtr = cute::conditional_return(paged_kv_manager.compute_K_ptr(), nullptr); if (params.is_rotary_interleaved) { auto [tRrCos, tRrSin] = rotary.template load_cos_sin(n_block); pipeline_k_new.consumer_wait(smem_pipe_read); - rotary.template apply_K_interleaved(sK(_, _, smem_pipe_read.index()), gK_cur, tKpK, tRrCos, tRrSin, tPrKPtr, n_block); + rotary.template apply_K_interleaved(sK(_, _, smem_pipe_read.index()), gK_cur, tKpK, tRrCos, tRrSin, tPrKPtr, n_block); } else { auto [tRrCosCont, tRrSinCont] = rotary.template load_cos_sin(n_block); pipeline_k_new.consumer_wait(smem_pipe_read); - rotary.template apply_K_contiguous(sK(_, _, smem_pipe_read.index()), gK_cur, tKpK, tRrCosCont, tRrSinCont, tPrKPtr, n_block, get<1>(params.shape_K)); + rotary.template apply_K_contiguous(sK(_, _, smem_pipe_read.index()), gK_cur, tKpK, tRrCosCont, tRrSinCont, tPrKPtr, n_block, get<1>(params.shape_K)); } } // Without this sync I'm getting race condition when seqlen_k is large @@ -1646,7 +1667,7 @@ struct CollectiveMainloopFwdSm90 { pipeline_v_new.consumer_wait(smem_pipe_read); int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN); Tensor tVsV_cur = tVsV(_, _, _, smem_pipe_read.index()); - if constexpr (!PagedKV) { + if constexpr (!PagedKVNonTMA) { Tensor tVgV_cur = tVgV(_, _, _, n_block); // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( @@ -1661,7 +1682,7 @@ struct CollectiveMainloopFwdSm90 { #pragma unroll 1 for (int n_block = n_block_new_max - 1; n_block >= n_block_new_min; --n_block) { - if constexpr (PagedKV) { paged_kv_manager.template load_page_table(n_block); } + if constexpr (PagedKVNonTMA) { paged_kv_manager.template load_page_table(n_block); } store_K(n_block, smem_pipe_read); // if (thread_idx == 0) { printf("Done storing K, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } store_V(n_block, smem_pipe_read); diff --git a/hopper/paged_kv.h b/hopper/paged_kv.h index 80ee61b9a..9ea59bcc2 100644 --- a/hopper/paged_kv.h +++ b/hopper/paged_kv.h @@ -78,9 +78,11 @@ struct PagedKVManager { GmemTiledCopyKVCpAsync gmem_tiled_copy_kv; cutlass::FastDivmod const &page_size_divmod; + cutlass::FastDivmod const &blockN_per_page_size_divmod; int const thread_idx; int const seqlen_k; int const leftpad_k; + int const* const ptr_page_table; GmemThrCopyKVCpAsync const gmem_thr_copy_kv; TensorPageTable mPageTable; TensorKV mK_paged, mV_paged; @@ -88,20 +90,27 @@ struct PagedKVManager { TensortVpV tVpV; TensorPageOffset tPrPageOffset; TensorKVPtr tPrVPtr; + int bidb_kv_idx, bidb_kv_idx_prev, n_block_idx, n_block_idx_prev; // Only used for TMA CUTLASS_DEVICE - PagedKVManager(int const* const ptr_page_table, + PagedKVManager(int const* const ptr_page_table_, ShapePageTable const &shape_pagetable, StridePageTable const &stride_pagetable, Element* const ptr_K, ShapeKV const &shape_K, StrideKV const &stride_K, Element* const ptr_V, int const headdim_v, StrideKV const &stride_V, cutlass::FastDivmod const &page_size_divmod, - int const bidb, int const bidh, int const thread_idx, int const seqlen_k, int const leftpad_k + cutlass::FastDivmod const &blockN_per_page_size_divmod, + int const bidb, int const bidh, int const thread_idx, int const seqlen_k, int const leftpad_k, + int bidb_kv_idx ) : page_size_divmod(page_size_divmod) + , blockN_per_page_size_divmod(blockN_per_page_size_divmod) , thread_idx(thread_idx) , seqlen_k(seqlen_k) , leftpad_k(leftpad_k) + , ptr_page_table(ptr_page_table_) , gmem_thr_copy_kv(gmem_tiled_copy_kv.get_thread_slice(thread_idx)) + , bidb_kv_idx(bidb_kv_idx) + , bidb_kv_idx_prev(bidb_kv_idx) { mPageTable = make_tensor(make_gmem_ptr(ptr_page_table), shape_pagetable, stride_pagetable)(bidb, _); @@ -143,6 +152,38 @@ struct PagedKVManager { if constexpr (First_iter && !KV_Same_Iter) { compute_V_ptr(); } }; + template + CUTLASS_DEVICE + void load_page_table_TMA(const int n_block) { + // We require that page size is a multiple of kBlockN, and there's no leftpad_k + if (ptr_page_table) { + bidb_kv_idx = mPageTable[blockN_per_page_size_divmod.divmod(n_block_idx, n_block)]; + } else { + n_block_idx = n_block; + } + if constexpr (First_iter && !KV_Same_Iter) { + bidb_kv_idx_prev = bidb_kv_idx; + n_block_idx_prev = n_block_idx; + } + }; + + CUTLASS_DEVICE + cute::tuple get_indices_for_K_TMA() { + return {n_block_idx, bidb_kv_idx}; + }; + + CUTLASS_DEVICE + cute::tuple get_indices_for_V_TMA() { + if constexpr (KV_Same_Iter) { + return {n_block_idx, bidb_kv_idx}; + } else { + cute::tuple const indices = {n_block_idx_prev, bidb_kv_idx_prev}; + bidb_kv_idx_prev = bidb_kv_idx; + n_block_idx_prev = n_block_idx; + return indices; + } + }; + CUTLASS_DEVICE TensorKVPtr compute_K_ptr() { Tensor tPrKPtr = make_tensor(Shape>{}); diff --git a/hopper/rotary.h b/hopper/rotary.h index 5e30456c2..aa3602cc7 100644 --- a/hopper/rotary.h +++ b/hopper/rotary.h @@ -226,7 +226,7 @@ struct Rotary { // The main bottleneck here is actually instruction cache misses. - // Similar to PagedKV, it's expensive to compute the pointers. + // Similar to PagedKVNonTMA, it's expensive to compute the pointers. // We split the work among threads loading the same row, then __shfl_sync the pointers. static constexpr int NumPtrPerThread = cute::ceil_div(CUTE_STATIC_V(cute::size<1>(tRrCos)), kGmemThreadsPerRow); Tensor tPrCosPtr = make_tensor(Shape>{}); @@ -350,7 +350,7 @@ struct Rotary { } }; - template + template CUTLASS_DEVICE void apply_K_interleaved(TensorsK const &sK, // (kBlockN, kHeadDim) @@ -377,7 +377,7 @@ struct Rotary { CUTE_STATIC_ASSERT_V(size<0>(tRrCos) == size<0>(tRrSin)); static_assert(decltype(size<0>(tKsK))::value == decltype(size<0>(tRrCos))::value * 2); static_assert(decltype(size<0>(tRrCos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 - if constexpr (PagedKV) { + if constexpr (PagedKVNonTMA) { static_assert(decltype(size(tPrKPtr))::value == cute::ceil_div(size<1>(tKcK), kGmemThreadsPerRow)); } @@ -385,7 +385,7 @@ struct Rotary { for (int m = 0; m < size<1>(tKsK); ++m) { int const row = get<0>(tKcK(_0{}, m, _0{})); auto mK_cur_copy = [&] { - if constexpr (PagedKV) { + if constexpr (PagedKVNonTMA) { Element* k_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow)); Tensor mK_cur = make_tensor(make_gmem_ptr(k_ptr), Shape>{}); return cute::tiled_divide(mK_cur, Shape>{}); @@ -400,7 +400,7 @@ struct Rotary { Tensor rK = make_fragment_like(tKsK(_, m, k)); cute::copy(tiled_copy_k, tKsK(_, m, k), rK); if (tRpR(k)) { apply_rotary_interleaved(rK, tRrCos(_, m, k), tRrSin(_, m, k)); } - if constexpr (!PagedKV) { + if constexpr (!PagedKVNonTMA) { cute::copy(tiled_copy_k, rK, tKgK(_, m, k)); } else { int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad; @@ -412,7 +412,7 @@ struct Rotary { } }; - template + template CUTLASS_DEVICE void apply_K_contiguous(TensorsK const &sK, // (kBlockN, kHeadDim) @@ -439,7 +439,7 @@ struct Rotary { CUTE_STATIC_ASSERT_V(size<0>(tRrCosCont) == size<0>(tRrSinCont)); CUTE_STATIC_ASSERT_V(size<0>(tKcK) == size<0>(tRrCosCont)); static_assert(decltype(size<0>(tRrCosCont))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 - if constexpr (PagedKV) { + if constexpr (PagedKVNonTMA) { static_assert(decltype(size(tPrKPtr))::value == cute::ceil_div(size<1>(tKcK), kGmemThreadsPerRow)); } @@ -449,7 +449,7 @@ struct Rotary { for (int m = 0; m < size<1>(tKcK); ++m) { int const row = get<0>(tKcK(_0{}, m, _0{})); Tensor gK_cur_copy = [&] { - if constexpr (PagedKV) { + if constexpr (PagedKVNonTMA) { Element* k_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow)); Tensor mK_cur = make_tensor(make_gmem_ptr(k_ptr), Shape>{}); return cute::tiled_divide(mK_cur, Shape>{}); diff --git a/hopper/tile_size.h b/hopper/tile_size.h index 12a4839eb..487c70198 100644 --- a/hopper/tile_size.h +++ b/hopper/tile_size.h @@ -9,7 +9,7 @@ // Return {kBlockM, kBlockN, MmaPV_is_RS, IntraWGOverlap} constexpr std::tuple tile_size_fwd_sm90( int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2, - bool v_colmajor=false, bool paged_kv=false, bool softcap=false) { + bool v_colmajor=false, bool paged_kv_non_TMA=false, bool softcap=false) { if (element_size == 2) { if (headdim <= 64) { bool same_hdim = (headdim == headdim_v); // if not same hdim, we're targeting hdimv=512 @@ -21,13 +21,13 @@ constexpr std::tuple tile_size_fwd_sm90( // Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen // return {192, is_causal || is_local ? 192 : 176, true, false}; } else if (headdim <= 96) { - return {192, is_local || paged_kv ? 128 : 144, false, true}; + return {192, is_local || paged_kv_non_TMA ? 128 : 144, false, true}; } else if (headdim <= 128) { - return {128, is_causal || is_local || paged_kv ? 128 : 176, true, true}; + return {128, is_causal || is_local || paged_kv_non_TMA ? 128 : 176, true, true}; // {128, 192, false, false} and {192, 128, false, true} are quite good too // 128 x 192 hits the limit of smem if MmaPV_is_RS, 128 x 144 hits the limit if !MmaPV_is_RS } else if (headdim <= 192) { - return {128, paged_kv || is_local ? 96 : (headdim_v <= 128 ? 128 : 112), true, true}; // 128 x 112 hits the limit of smem + return {128, paged_kv_non_TMA || is_local ? 96 : (headdim_v <= 128 ? 128 : 112), true, true}; // 128 x 112 hits the limit of smem } else { return {128, is_local ? 64 : 80, true, true}; // 128 x 80 hits the limit of smem } @@ -37,11 +37,11 @@ constexpr std::tuple tile_size_fwd_sm90( } else if (headdim <= 96) { return {192, 128, true, true}; } else if (headdim <= 128) { - return {128, paged_kv ? 160 : (v_colmajor || (softcap && is_local) ? 192 : 224), true, true}; + return {128, paged_kv_non_TMA ? 160 : (v_colmajor || (softcap && is_local) ? 192 : 224), true, true}; } else if (headdim <= 192) { - return {128, (paged_kv || softcap) && is_local ? 128 : 160, true, true}; + return {128, (paged_kv_non_TMA || softcap) && is_local ? 128 : 160, true, true}; } else { - return {128, is_local ? 64 : 128, true, !paged_kv}; // PagedKV uses more registers so we disabled IntraWGOverlap + return {128, is_local ? 64 : 128, true, !paged_kv_non_TMA}; // PagedKV uses more registers so we disabled IntraWGOverlap } } } From a3a9cc567b44a873938322e81f0f89f3c0a9621a Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 8 Mar 2025 22:59:35 -0500 Subject: [PATCH 063/102] Update ptxas to 12.8.93 (i.e. 12.8.1) --- hopper/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hopper/setup.py b/hopper/setup.py index cf3d23934..121266ebd 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -376,7 +376,7 @@ def nvcc_threads_args(): # NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.3.107"} -NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.6.85", "ptxas": "12.8.61"} +NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.6.85", "ptxas": "12.8.93"} exe_extension = sysconfig.get_config_var("EXE") From 322bec97d411fefad03e85da8e0d9e0dda0469e8 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 8 Mar 2025 23:09:44 -0500 Subject: [PATCH 064/102] Use tile size 192 x 128 for hdim 64 causal --- hopper/tile_size.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/hopper/tile_size.h b/hopper/tile_size.h index 487c70198..2c440c6e2 100644 --- a/hopper/tile_size.h +++ b/hopper/tile_size.h @@ -13,11 +13,12 @@ constexpr std::tuple tile_size_fwd_sm90( if (element_size == 2) { if (headdim <= 64) { bool same_hdim = (headdim == headdim_v); // if not same hdim, we're targeting hdimv=512 - // return {same_hdim ? 192 : 64, same_hdim ? 128 : 64, same_hdim, true}; + // return {same_hdim ? 192 : 64, same_hdim ? 128 : 64, same_hdim, same_hdim}; // With this workaround in Cutlass 3.8, tile size 192 x 128 got slower for non-causal, idk why // https://github.com/NVIDIA/cutlass/blob/833f6990e031b48b4cd2fcf55e0849c51ef6bac2/include/cute/container/tuple.hpp#L131 // Switch to tile size 192 x 192 for now - return {same_hdim ? 192 : 64, same_hdim ? 192 : 64, false, same_hdim}; + bool const use_blockN_128 = is_causal || is_local; + return {same_hdim ? 192 : 64, same_hdim ? (use_blockN_128 ? 128 : 192) : 64, same_hdim && use_blockN_128, same_hdim}; // Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen // return {192, is_causal || is_local ? 192 : 176, true, false}; } else if (headdim <= 96) { From 5639b9d26dac63d912d6815cb4369250f6cef764 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 9 Mar 2025 00:08:46 -0500 Subject: [PATCH 065/102] Update benchmark_mla_decode.py --- hopper/benchmark_attn.py | 11 +++-- hopper/benchmark_mla_decode.py | 76 ++++++++++++++++++++-------------- 2 files changed, 53 insertions(+), 34 deletions(-) diff --git a/hopper/benchmark_attn.py b/hopper/benchmark_attn.py index fbca7829a..62ac2b63c 100644 --- a/hopper/benchmark_attn.py +++ b/hopper/benchmark_attn.py @@ -253,6 +253,7 @@ def run(*args, **kwargs): # for headdim in [64, 96, 128, 192, 256]: for headdim in [128]: nheads = dim // headdim + # nheads = 128 # headdim = 64 # batch_size = 64 # seqlen = 512 @@ -260,8 +261,11 @@ def run(*args, **kwargs): # headdim = 128 nheads_kv = nheads # nheads_kv = nheads // 4 + # nheads_kv = 1 headdim_v = headdim - # headdim_v = 128 + # headdim_v = 512 + has_qv = headdim == 64 and headdim_v == 512 + # has_qv = False for batch_size, seqlen in bs_seqlen_vals: num_splits = 0 @@ -278,6 +282,7 @@ def run(*args, **kwargs): q, k, v = [x.detach().to(dtype).requires_grad_() for x in [q, k, v]] v_colmajor = v.detach().transpose(-1, -3).contiguous().transpose(-1, -3).requires_grad_() v_fa3 = v if not V_colmajor else v_colmajor + qv = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) if has_qv else None # q = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype) # k = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype) # v = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim_v), device=device, dtype=torch.int32).to(dtype) @@ -305,7 +310,7 @@ def run(*args, **kwargs): for causal in [False, True]: # for causal in [True]: print(f"\n### {headdim = }, {causal = }, {seqlen = } ###") - nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim, headdim_v, causal=causal, window_size=window_size) + nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size) if cudnn is not None: # if False: if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v: @@ -354,7 +359,7 @@ def run(*args, **kwargs): time.sleep(1) if not varlen: # m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, cache_leftpad = leftpad_k, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') - m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') + m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, qv=qv, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') # pytorch_profiler(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa) else: m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') diff --git a/hopper/benchmark_mla_decode.py b/hopper/benchmark_mla_decode.py index 58224a0e9..2c90a3390 100644 --- a/hopper/benchmark_mla_decode.py +++ b/hopper/benchmark_mla_decode.py @@ -1,4 +1,6 @@ +import time import torch +import torch.nn.functional as F from triton.testing import do_bench, do_bench_cudagraph @@ -13,7 +15,8 @@ device = "cuda" dtype = torch.bfloat16 -seqlen = 64 * 1024 +# seqlen = 64 * 1024 +seqlen = 8192 nheads = 128 nheads_kv = 1 headdim = 64 @@ -21,44 +24,55 @@ has_qv = True seqlen_q = 1 # page_size = None -page_size = 1 +page_size = 64 + +use_bench_cudagraph = False torch.manual_seed(0) -batch_size = 4 -cache_seqlens = torch.tensor([seqlen - 1] * batch_size, device=device, dtype=torch.int) +batch_size = 128 +cache_seqlens = None +# cache_seqlens = torch.tensor([seqlen] * batch_size, device=device, dtype=torch.int) # cache_seqlens = torch.tensor([seqlen - 1, 1024, 1024, 1024], device=device, dtype=torch.int32) # cache_seqlens = torch.tensor([1024] * batch_size, device=device, dtype=torch.int) # cache_seqlens = torch.tensor([seqlen - 1, 1024, 1024, 1024], device=device, dtype=torch.int) # cache_seqlens = torch.tensor([4500, 45000, 1800, 1800], dtype=torch.int32, device=device) -num_splits = 0 -q = torch.randn(batch_size, seqlen_q, nheads, headdim, dtype=dtype, device=device) -v_cache = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, dtype=dtype, device=device) -k_cache = torch.randn(batch_size, seqlen, nheads_kv, headdim, dtype=dtype, device=device) -if page_size is not None: - assert seqlen % page_size == 0 - k_cache, v_cache = [rearrange(x, "b (n p) h d -> (b n) p h d", p=page_size) for x in [k_cache, v_cache]] - page_table = rearrange(torch.arange(batch_size * seqlen // page_size, device=device, dtype=torch.int32), - "(b s) -> b s", s=seqlen // page_size) -else: - page_table = None -qv = torch.randn(batch_size, seqlen_q, nheads, headdim_v, dtype=dtype, device=device) if has_qv else None +for seqlen in [s * 1024 for s in [1, 2, 4, 8, 16, 32, 64]]: +# for seqlen in [s * 1024 for s in [1]]: + cache_seqlens = torch.tensor([seqlen] * batch_size, device=device, dtype=torch.int) + num_splits = 0 + q = torch.randn(batch_size, seqlen_q, nheads, headdim, dtype=dtype, device=device) + v_cache = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, dtype=dtype, device=device) + k_cache = torch.randn(batch_size, seqlen, nheads_kv, headdim, dtype=dtype, device=device) + if page_size is not None: + assert seqlen % page_size == 0 + k_cache, v_cache = [rearrange(x, "b (n p) h d -> (b n) p h d", p=page_size) for x in [k_cache, v_cache]] + page_table = rearrange(torch.arange(batch_size * seqlen // page_size, device=device, dtype=torch.int32), + "(b s) -> b s", s=seqlen // page_size) + else: + page_table = None + qv = torch.randn(batch_size, seqlen_q, nheads, headdim_v, dtype=dtype, device=device) if has_qv else None -# Time in ms -fn = lambda: flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, num_splits=num_splits, qv=qv, page_table=page_table, causal=True) -t0 = do_bench(fn, warmup=1, rep=10) -with torch.cuda.stream(torch.cuda.Stream()): - t1 = do_bench_cudagraph(fn, rep=10) + # Time in ms + fn = lambda: flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, num_splits=num_splits, qv=qv, page_table=page_table, causal=True) + time.sleep(1) # to avoid power throttling + if not use_bench_cudagraph: + t0 = do_bench(fn, warmup=1, rep=10) + else: + with torch.cuda.stream(torch.cuda.Stream()): + t0 = do_bench_cudagraph(fn, rep=10) + # exit(0) -mem_io = cache_seqlens.sum().item() * nheads_kv * (headdim + headdim_v) * 2 -flops = seqlen_q * cache_seqlens.float().sum().item() * nheads * (headdim + headdim_v * 2) * 2 -ideal_h100_time_mem = mem_io / 3.35e12 * 1e6 -ideal_h100_time_flop = flops / 989e12 * 1e6 -ideal_h100_time = max(ideal_h100_time_mem, ideal_h100_time_flop) -print(f"Time: {t0 * 1e3:.0f} us, {mem_io * 1e-9 / (t0 * 1e-3):.0f} GB/s, {flops * 1e-12 / (t0 * 1e-3):.0f} TFLOPS/s") -print(f"Time w CUDA Graph: {t1 * 1e3:.0f} us, {mem_io * 1e-9 / (t1 * 1e-3):.0f} GB/s, {flops * 1e-12 / (t1 * 1e-3):.0f} TFLOPS/s") -print(f"Ideal time: {ideal_h100_time:.0f} us") + total_seqlen = seqlen * batch_size if cache_seqlens is None else cache_seqlens.sum().item() + mem_io = total_seqlen * nheads_kv * (headdim + headdim_v) * 2 + q.numel() * 2 + qv.numel() * 4 + flops = seqlen_q * total_seqlen * nheads * (headdim + headdim_v * (2 if has_qv else 1)) * 2 + ideal_h100_time_mem = mem_io / 3.35e12 * 1e6 + ideal_h100_time_flop = flops / 989e12 * 1e6 + ideal_h100_time = max(ideal_h100_time_mem, ideal_h100_time_flop) + print(f"Time{'' if not use_bench_cudagraph else ' w CUDA Graph'}: {t0 * 1e3:.0f} us, {mem_io * 1e-9 / (t0 * 1e-3):.0f} GB/s, {flops * 1e-12 / (t0 * 1e-3):.0f} TFLOPS/s") + print(f"Ideal time: {ideal_h100_time:.0f} us") -if pytorch_profiler is not None: - pytorch_profiler(flash_attn_with_kvcache, q, k_cache, v_cache, cache_seqlens=cache_seqlens, num_splits=num_splits, qv=qv, page_table=page_table, causal=False) + # if pytorch_profiler is not None: + # time.sleep(1) # to avoid power throttling + # pytorch_profiler(fn) From 48b3acbc44e8fd66b804d695f526c2be3586a760 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 11 Mar 2025 17:09:07 -0400 Subject: [PATCH 066/102] Benchmark MHA, GQA, MQA, MLA in the same script --- hopper/benchmark_mla_decode.py | 49 +++++++++++++++++++++------------- 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/hopper/benchmark_mla_decode.py b/hopper/benchmark_mla_decode.py index 2c90a3390..4b65a6339 100644 --- a/hopper/benchmark_mla_decode.py +++ b/hopper/benchmark_mla_decode.py @@ -1,3 +1,11 @@ +# Copyright (c) 2025, Ted Zadouri, Tri Dao. + +# We recommend locking GPU clocks before running the benchmark to ensure consistent results. +# This can be done using the following commands (1830 MHz is the clock for H100): +# sudo nvidia-smi -i 0 -pm 1 +# sudo nvidia-smi -i 0 --lock-gpu-clocks 1830,1830 +# See more here: https://github.com/triton-lang/triton/blob/d9f10ebdc5da53f73eb852fde73d8d7d80b679d1/python/triton/testing.py#L487 + import time import torch import torch.nn.functional as F @@ -13,18 +21,19 @@ except ImportError: pytorch_profiler = None +attn_variants = ["mha", "gqa", "mqa", "mla"] +attn_variant = attn_variants[3] device = "cuda" dtype = torch.bfloat16 -# seqlen = 64 * 1024 seqlen = 8192 nheads = 128 -nheads_kv = 1 -headdim = 64 -headdim_v = 512 -has_qv = True +nheads_kv = nheads if attn_variant == "mha" else (min(nheads // 8, 8) if attn_variant == "gqa" else 1) +headdim = 64 if attn_variant == "mla" else 128 +headdim_v = 512 if attn_variant == "mla" else headdim +has_qv = headdim == 64 and headdim_v == 512 seqlen_q = 1 # page_size = None -page_size = 64 +page_size = 64 if attn_variant == "mla" else 128 use_bench_cudagraph = False @@ -35,23 +44,27 @@ # cache_seqlens = torch.tensor([seqlen] * batch_size, device=device, dtype=torch.int) # cache_seqlens = torch.tensor([seqlen - 1, 1024, 1024, 1024], device=device, dtype=torch.int32) # cache_seqlens = torch.tensor([1024] * batch_size, device=device, dtype=torch.int) -# cache_seqlens = torch.tensor([seqlen - 1, 1024, 1024, 1024], device=device, dtype=torch.int) # cache_seqlens = torch.tensor([4500, 45000, 1800, 1800], dtype=torch.int32, device=device) +print(f"{attn_variant.upper()}, nheads_q = {nheads}, nheads_kv = {nheads_kv}, headdim = {headdim}, headdim_v = {headdim_v}, page_size = {page_size}") + for seqlen in [s * 1024 for s in [1, 2, 4, 8, 16, 32, 64]]: # for seqlen in [s * 1024 for s in [1]]: cache_seqlens = torch.tensor([seqlen] * batch_size, device=device, dtype=torch.int) num_splits = 0 q = torch.randn(batch_size, seqlen_q, nheads, headdim, dtype=dtype, device=device) - v_cache = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, dtype=dtype, device=device) - k_cache = torch.randn(batch_size, seqlen, nheads_kv, headdim, dtype=dtype, device=device) - if page_size is not None: - assert seqlen % page_size == 0 - k_cache, v_cache = [rearrange(x, "b (n p) h d -> (b n) p h d", p=page_size) for x in [k_cache, v_cache]] - page_table = rearrange(torch.arange(batch_size * seqlen // page_size, device=device, dtype=torch.int32), - "(b s) -> b s", s=seqlen // page_size) - else: - page_table = None + try: + v_cache = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, dtype=dtype, device=device) + k_cache = torch.randn(batch_size, seqlen, nheads_kv, headdim, dtype=dtype, device=device) + if page_size is not None: + assert seqlen % page_size == 0 + k_cache, v_cache = [rearrange(x, "b (n p) h d -> (b n) p h d", p=page_size) for x in [k_cache, v_cache]] + page_table = rearrange(torch.arange(batch_size * seqlen // page_size, device=device, dtype=torch.int32), + "(b s) -> b s", s=seqlen // page_size) + else: + page_table = None + except torch.OutOfMemoryError: + continue qv = torch.randn(batch_size, seqlen_q, nheads, headdim_v, dtype=dtype, device=device) if has_qv else None # Time in ms @@ -65,12 +78,12 @@ # exit(0) total_seqlen = seqlen * batch_size if cache_seqlens is None else cache_seqlens.sum().item() - mem_io = total_seqlen * nheads_kv * (headdim + headdim_v) * 2 + q.numel() * 2 + qv.numel() * 4 + mem_io = total_seqlen * nheads_kv * (headdim + headdim_v) * 2 + q.numel() * 2 + (qv.numel() * 2 if has_qv else 0) + q.numel() * headdim_v // headdim * 2 # last time is for the output flops = seqlen_q * total_seqlen * nheads * (headdim + headdim_v * (2 if has_qv else 1)) * 2 ideal_h100_time_mem = mem_io / 3.35e12 * 1e6 ideal_h100_time_flop = flops / 989e12 * 1e6 ideal_h100_time = max(ideal_h100_time_mem, ideal_h100_time_flop) - print(f"Time{'' if not use_bench_cudagraph else ' w CUDA Graph'}: {t0 * 1e3:.0f} us, {mem_io * 1e-9 / (t0 * 1e-3):.0f} GB/s, {flops * 1e-12 / (t0 * 1e-3):.0f} TFLOPS/s") + print(f"Seqlen = {seqlen}, time{'' if not use_bench_cudagraph else ' w CUDA Graph'}: {t0 * 1e3:.0f} us, {mem_io * 1e-9 / (t0 * 1e-3):.0f} GB/s, {flops * 1e-12 / (t0 * 1e-3):.0f} TFLOPS/s") print(f"Ideal time: {ideal_h100_time:.0f} us") # if pytorch_profiler is not None: From d904855e2dc0ec1c72984b1a9f6eba5cdcee1433 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 11 Mar 2025 17:56:53 -0400 Subject: [PATCH 067/102] Benchmark FlashMLA if it's available --- hopper/benchmark_mla_decode.py | 34 +++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/hopper/benchmark_mla_decode.py b/hopper/benchmark_mla_decode.py index 4b65a6339..294afc0b3 100644 --- a/hopper/benchmark_mla_decode.py +++ b/hopper/benchmark_mla_decode.py @@ -16,6 +16,11 @@ from flash_attn_interface import flash_attn_with_kvcache +try: + from flash_mla import flash_mla_with_kvcache, get_mla_metadata +except ImportError: + flash_mla_with_kvcache, get_mla_metadata = None, None + try: from flash_attn.utils.benchmark import pytorch_profiler except ImportError: @@ -26,8 +31,8 @@ device = "cuda" dtype = torch.bfloat16 seqlen = 8192 -nheads = 128 -nheads_kv = nheads if attn_variant == "mha" else (min(nheads // 8, 8) if attn_variant == "gqa" else 1) +nheads_q = 128 +nheads_kv = nheads_q if attn_variant == "mha" else (min(nheads_q // 8, 8) if attn_variant == "gqa" else 1) headdim = 64 if attn_variant == "mla" else 128 headdim_v = 512 if attn_variant == "mla" else headdim has_qv = headdim == 64 and headdim_v == 512 @@ -36,6 +41,7 @@ page_size = 64 if attn_variant == "mla" else 128 use_bench_cudagraph = False +should_run_flashmla = attn_variant == "mla" and page_size == 64 and flash_mla_with_kvcache is not None torch.manual_seed(0) @@ -46,13 +52,13 @@ # cache_seqlens = torch.tensor([1024] * batch_size, device=device, dtype=torch.int) # cache_seqlens = torch.tensor([4500, 45000, 1800, 1800], dtype=torch.int32, device=device) -print(f"{attn_variant.upper()}, nheads_q = {nheads}, nheads_kv = {nheads_kv}, headdim = {headdim}, headdim_v = {headdim_v}, page_size = {page_size}") +print(f"{attn_variant.upper()}, nheads_q = {nheads_q}, nheads_kv = {nheads_kv}, headdim = {headdim}, headdim_v = {headdim_v}, page_size = {page_size}") for seqlen in [s * 1024 for s in [1, 2, 4, 8, 16, 32, 64]]: # for seqlen in [s * 1024 for s in [1]]: cache_seqlens = torch.tensor([seqlen] * batch_size, device=device, dtype=torch.int) num_splits = 0 - q = torch.randn(batch_size, seqlen_q, nheads, headdim, dtype=dtype, device=device) + q = torch.randn(batch_size, seqlen_q, nheads_q, headdim, dtype=dtype, device=device) try: v_cache = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, dtype=dtype, device=device) k_cache = torch.randn(batch_size, seqlen, nheads_kv, headdim, dtype=dtype, device=device) @@ -65,7 +71,7 @@ page_table = None except torch.OutOfMemoryError: continue - qv = torch.randn(batch_size, seqlen_q, nheads, headdim_v, dtype=dtype, device=device) if has_qv else None + qv = torch.randn(batch_size, seqlen_q, nheads_q, headdim_v, dtype=dtype, device=device) if has_qv else None # Time in ms fn = lambda: flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, num_splits=num_splits, qv=qv, page_table=page_table, causal=True) @@ -76,14 +82,28 @@ with torch.cuda.stream(torch.cuda.Stream()): t0 = do_bench_cudagraph(fn, rep=10) # exit(0) + if should_run_flashmla: + # Separate out the preprocessing since this can be done once and reused for all layers + scheduler_metadata = get_mla_metadata(cache_seqlens, seqlen_q * nheads_q // nheads_kv, nheads_kv) + q_concat = torch.concat([q, qv], dim=-1) if has_qv else q + kv_cache_concat = torch.concat([v_cache, k_cache], dim=-1) + fn = lambda: flash_mla_with_kvcache(q_concat, kv_cache_concat, page_table, cache_seqlens, headdim_v, *scheduler_metadata, causal=True) + time.sleep(1) # to avoid power throttling + if not use_bench_cudagraph: + t1 = do_bench(fn, warmup=1, rep=10) + else: + with torch.cuda.stream(torch.cuda.Stream()): + t1 = do_bench_cudagraph(fn, rep=10) total_seqlen = seqlen * batch_size if cache_seqlens is None else cache_seqlens.sum().item() mem_io = total_seqlen * nheads_kv * (headdim + headdim_v) * 2 + q.numel() * 2 + (qv.numel() * 2 if has_qv else 0) + q.numel() * headdim_v // headdim * 2 # last time is for the output - flops = seqlen_q * total_seqlen * nheads * (headdim + headdim_v * (2 if has_qv else 1)) * 2 + flops = seqlen_q * total_seqlen * nheads_q * (headdim + headdim_v * (2 if has_qv else 1)) * 2 ideal_h100_time_mem = mem_io / 3.35e12 * 1e6 ideal_h100_time_flop = flops / 989e12 * 1e6 ideal_h100_time = max(ideal_h100_time_mem, ideal_h100_time_flop) - print(f"Seqlen = {seqlen}, time{'' if not use_bench_cudagraph else ' w CUDA Graph'}: {t0 * 1e3:.0f} us, {mem_io * 1e-9 / (t0 * 1e-3):.0f} GB/s, {flops * 1e-12 / (t0 * 1e-3):.0f} TFLOPS/s") + print(f"Seqlen = {seqlen}, FA3 time{'' if not use_bench_cudagraph else ' w CUDA Graph'}: {t0 * 1e3:.0f} us, {mem_io * 1e-9 / (t0 * 1e-3):.0f} GB/s, {flops * 1e-12 / (t0 * 1e-3):.0f} TFLOPS/s") + if should_run_flashmla: + print(f"Seqlen = {seqlen}, FlashMLA time{'' if not use_bench_cudagraph else ' w CUDA Graph'}: {t1 * 1e3:.0f} us, {mem_io * 1e-9 / (t1 * 1e-3):.0f} GB/s, {flops * 1e-12 / (t1 * 1e-3):.0f} TFLOPS/s") print(f"Ideal time: {ideal_h100_time:.0f} us") # if pytorch_profiler is not None: From cdaf2de6e95cb05400959b5ab984f66e4c7df317 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 11 Mar 2025 22:44:42 -0400 Subject: [PATCH 068/102] Run all 4 attn variants in benchmark --- hopper/benchmark_mla_decode.py | 159 +++++++++++++++++---------------- 1 file changed, 81 insertions(+), 78 deletions(-) diff --git a/hopper/benchmark_mla_decode.py b/hopper/benchmark_mla_decode.py index 294afc0b3..eabf6efa0 100644 --- a/hopper/benchmark_mla_decode.py +++ b/hopper/benchmark_mla_decode.py @@ -26,86 +26,89 @@ except ImportError: pytorch_profiler = None + attn_variants = ["mha", "gqa", "mqa", "mla"] -attn_variant = attn_variants[3] -device = "cuda" -dtype = torch.bfloat16 -seqlen = 8192 -nheads_q = 128 -nheads_kv = nheads_q if attn_variant == "mha" else (min(nheads_q // 8, 8) if attn_variant == "gqa" else 1) -headdim = 64 if attn_variant == "mla" else 128 -headdim_v = 512 if attn_variant == "mla" else headdim -has_qv = headdim == 64 and headdim_v == 512 -seqlen_q = 1 -# page_size = None -page_size = 64 if attn_variant == "mla" else 128 - -use_bench_cudagraph = False -should_run_flashmla = attn_variant == "mla" and page_size == 64 and flash_mla_with_kvcache is not None - -torch.manual_seed(0) - -batch_size = 128 -cache_seqlens = None -# cache_seqlens = torch.tensor([seqlen] * batch_size, device=device, dtype=torch.int) -# cache_seqlens = torch.tensor([seqlen - 1, 1024, 1024, 1024], device=device, dtype=torch.int32) -# cache_seqlens = torch.tensor([1024] * batch_size, device=device, dtype=torch.int) -# cache_seqlens = torch.tensor([4500, 45000, 1800, 1800], dtype=torch.int32, device=device) - -print(f"{attn_variant.upper()}, nheads_q = {nheads_q}, nheads_kv = {nheads_kv}, headdim = {headdim}, headdim_v = {headdim_v}, page_size = {page_size}") - -for seqlen in [s * 1024 for s in [1, 2, 4, 8, 16, 32, 64]]: -# for seqlen in [s * 1024 for s in [1]]: - cache_seqlens = torch.tensor([seqlen] * batch_size, device=device, dtype=torch.int) - num_splits = 0 - q = torch.randn(batch_size, seqlen_q, nheads_q, headdim, dtype=dtype, device=device) - try: - v_cache = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, dtype=dtype, device=device) - k_cache = torch.randn(batch_size, seqlen, nheads_kv, headdim, dtype=dtype, device=device) - if page_size is not None: - assert seqlen % page_size == 0 - k_cache, v_cache = [rearrange(x, "b (n p) h d -> (b n) p h d", p=page_size) for x in [k_cache, v_cache]] - page_table = rearrange(torch.arange(batch_size * seqlen // page_size, device=device, dtype=torch.int32), - "(b s) -> b s", s=seqlen // page_size) - else: - page_table = None - except torch.OutOfMemoryError: - continue - qv = torch.randn(batch_size, seqlen_q, nheads_q, headdim_v, dtype=dtype, device=device) if has_qv else None - - # Time in ms - fn = lambda: flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, num_splits=num_splits, qv=qv, page_table=page_table, causal=True) - time.sleep(1) # to avoid power throttling - if not use_bench_cudagraph: - t0 = do_bench(fn, warmup=1, rep=10) - else: - with torch.cuda.stream(torch.cuda.Stream()): - t0 = do_bench_cudagraph(fn, rep=10) - # exit(0) - if should_run_flashmla: - # Separate out the preprocessing since this can be done once and reused for all layers - scheduler_metadata = get_mla_metadata(cache_seqlens, seqlen_q * nheads_q // nheads_kv, nheads_kv) - q_concat = torch.concat([q, qv], dim=-1) if has_qv else q - kv_cache_concat = torch.concat([v_cache, k_cache], dim=-1) - fn = lambda: flash_mla_with_kvcache(q_concat, kv_cache_concat, page_table, cache_seqlens, headdim_v, *scheduler_metadata, causal=True) +# attn_variant = attn_variants[3] +for attn_variant in attn_variants: + device = "cuda" + dtype = torch.bfloat16 + seqlen = 8192 + nheads_q = 128 + nheads_kv = nheads_q if attn_variant == "mha" else (max(nheads_q // 8, 1) if attn_variant == "gqa" else 1) + headdim = 64 if attn_variant == "mla" else 128 + headdim_v = 512 if attn_variant == "mla" else headdim + has_qv = headdim == 64 and headdim_v == 512 + seqlen_q = 1 + # page_size = None + page_size = 64 if attn_variant == "mla" else 128 + + use_bench_cudagraph = False + should_run_flashmla = attn_variant == "mla" and page_size == 64 and flash_mla_with_kvcache is not None + + torch.manual_seed(0) + + batch_size = 128 + cache_seqlens = None + # cache_seqlens = torch.tensor([seqlen] * batch_size, device=device, dtype=torch.int) + # cache_seqlens = torch.tensor([seqlen - 1, 1024, 1024, 1024], device=device, dtype=torch.int32) + # cache_seqlens = torch.tensor([1024] * batch_size, device=device, dtype=torch.int) + # cache_seqlens = torch.tensor([4500, 45000, 1800, 1800], dtype=torch.int32, device=device) + + print(f"\n{attn_variant.upper()}, nheads_q = {nheads_q}, nheads_kv = {nheads_kv}, headdim = {headdim}, headdim_v = {headdim_v}, page_size = {page_size}") + + for seqlen in [s * 1024 for s in [1, 2, 4, 8, 16, 32, 64]]: + # for seqlen in [s * 1024 for s in [8]]: + cache_seqlens = torch.tensor([seqlen] * batch_size, device=device, dtype=torch.int) + num_splits = 0 + q = torch.randn(batch_size, seqlen_q, nheads_q, headdim, dtype=dtype, device=device) + try: + v_cache = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, dtype=dtype, device=device) + k_cache = torch.randn(batch_size, seqlen, nheads_kv, headdim, dtype=dtype, device=device) + if page_size is not None: + assert seqlen % page_size == 0 + k_cache, v_cache = [rearrange(x, "b (n p) h d -> (b n) p h d", p=page_size) for x in [k_cache, v_cache]] + page_table = rearrange(torch.arange(batch_size * seqlen // page_size, device=device, dtype=torch.int32), + "(b s) -> b s", s=seqlen // page_size) + else: + page_table = None + except torch.OutOfMemoryError: + continue + qv = torch.randn(batch_size, seqlen_q, nheads_q, headdim_v, dtype=dtype, device=device) if has_qv else None + + # Time in ms + fn = lambda: flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, num_splits=num_splits, qv=qv, page_table=page_table, causal=True) time.sleep(1) # to avoid power throttling if not use_bench_cudagraph: - t1 = do_bench(fn, warmup=1, rep=10) + t0 = do_bench(fn, warmup=1, rep=10) else: with torch.cuda.stream(torch.cuda.Stream()): - t1 = do_bench_cudagraph(fn, rep=10) - - total_seqlen = seqlen * batch_size if cache_seqlens is None else cache_seqlens.sum().item() - mem_io = total_seqlen * nheads_kv * (headdim + headdim_v) * 2 + q.numel() * 2 + (qv.numel() * 2 if has_qv else 0) + q.numel() * headdim_v // headdim * 2 # last time is for the output - flops = seqlen_q * total_seqlen * nheads_q * (headdim + headdim_v * (2 if has_qv else 1)) * 2 - ideal_h100_time_mem = mem_io / 3.35e12 * 1e6 - ideal_h100_time_flop = flops / 989e12 * 1e6 - ideal_h100_time = max(ideal_h100_time_mem, ideal_h100_time_flop) - print(f"Seqlen = {seqlen}, FA3 time{'' if not use_bench_cudagraph else ' w CUDA Graph'}: {t0 * 1e3:.0f} us, {mem_io * 1e-9 / (t0 * 1e-3):.0f} GB/s, {flops * 1e-12 / (t0 * 1e-3):.0f} TFLOPS/s") - if should_run_flashmla: - print(f"Seqlen = {seqlen}, FlashMLA time{'' if not use_bench_cudagraph else ' w CUDA Graph'}: {t1 * 1e3:.0f} us, {mem_io * 1e-9 / (t1 * 1e-3):.0f} GB/s, {flops * 1e-12 / (t1 * 1e-3):.0f} TFLOPS/s") - print(f"Ideal time: {ideal_h100_time:.0f} us") - - # if pytorch_profiler is not None: - # time.sleep(1) # to avoid power throttling - # pytorch_profiler(fn) + t0 = do_bench_cudagraph(fn, rep=10) + # exit(0) + if should_run_flashmla: + # Separate out the preprocessing since this can be done once and reused for all layers + scheduler_metadata = get_mla_metadata(cache_seqlens, seqlen_q * nheads_q // nheads_kv, nheads_kv) + q_concat = torch.concat([q, qv], dim=-1) if has_qv else q + kv_cache_concat = torch.concat([v_cache, k_cache], dim=-1) + fn = lambda: flash_mla_with_kvcache(q_concat, kv_cache_concat, page_table, cache_seqlens, headdim_v, *scheduler_metadata, causal=True) + time.sleep(1) # to avoid power throttling + if not use_bench_cudagraph: + t1 = do_bench(fn, warmup=1, rep=10) + else: + with torch.cuda.stream(torch.cuda.Stream()): + t1 = do_bench_cudagraph(fn, rep=10) + + total_seqlen = seqlen * batch_size if cache_seqlens is None else cache_seqlens.sum().item() + mem_io = total_seqlen * nheads_kv * (headdim + headdim_v) * 2 + q.numel() * 2 + (qv.numel() * 2 if has_qv else 0) + q.numel() * headdim_v // headdim * 2 # last time is for the output + flops = seqlen_q * total_seqlen * nheads_q * (headdim + headdim_v * (2 if has_qv else 1)) * 2 + ideal_h100_time_mem = mem_io / 3.35e12 * 1e6 + ideal_h100_time_flop = flops / 989e12 * 1e6 + ideal_h100_time = max(ideal_h100_time_mem, ideal_h100_time_flop) + print(f"Seqlen = {seqlen}, FA3 time{'' if not use_bench_cudagraph else ' w CUDA Graph'}: {t0 * 1e3:.0f} us, {mem_io * 1e-9 / (t0 * 1e-3):.0f} GB/s, {flops * 1e-12 / (t0 * 1e-3):.0f} TFLOPS/s") + if should_run_flashmla: + print(f"Seqlen = {seqlen}, FlashMLA time{'' if not use_bench_cudagraph else ' w CUDA Graph'}: {t1 * 1e3:.0f} us, {mem_io * 1e-9 / (t1 * 1e-3):.0f} GB/s, {flops * 1e-12 / (t1 * 1e-3):.0f} TFLOPS/s") + print(f"Arithmetic intensity: {flops / mem_io:.1f}") + print(f"Ideal time: {ideal_h100_time:.0f} us") + + # if pytorch_profiler is not None: + # time.sleep(1) # to avoid power throttling + # pytorch_profiler(fn) From cf1b80988c31009989123c7d474bdf88e1b91f5d Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 12 Mar 2025 16:28:31 -0400 Subject: [PATCH 069/102] Move scheduler.get_next_work to before the epilogue --- hopper/flash_fwd_kernel_sm90.h | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/hopper/flash_fwd_kernel_sm90.h b/hopper/flash_fwd_kernel_sm90.h index 1f841da46..c8bfc29b7 100644 --- a/hopper/flash_fwd_kernel_sm90.h +++ b/hopper/flash_fwd_kernel_sm90.h @@ -370,7 +370,8 @@ class FlashAttnFwdSm90 { CUTLASS_PRAGMA_NO_UNROLL for (auto work_tile_info = scheduler.template get_initial_work(params.scheduler); work_tile_info.is_valid(params.scheduler); - work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info)) { + // get_next_work will be called before the epilogue + ) { // Attention output (GEMM-II) accumulator. Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_MNK_PV{})); float softmax_scale_log2 = params.mainloop.softmax_scale_log2; @@ -426,6 +427,8 @@ class FlashAttnFwdSm90 { tOrO, softmax, threadIdx.x - MmaThreadOffset, seqlen_info, block_coord, shared_storage); } } + // Do this here before the epilogue so that the next tile is ready to go. + work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info); if (tile_valid) { // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); } epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma_pv, From 3cf8998e07b05c32c33098f8658222c6456a4fbc Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 12 Mar 2025 16:29:07 -0400 Subject: [PATCH 070/102] Enable Cluster for hdim128 back --- hopper/flash_fwd_launch_template.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index ededa4a5e..0a0d92f59 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -203,8 +203,7 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { // Only needed here to decide if we should use cluster static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap)) : 128; - // On nvcc 12.8, hdim 128, without cluster is faster (730 vs 700 TFLOPS) - static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 192) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen; + static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen; BOOL_SWITCH(params.qv_ptr, HasQV_, [&] { static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV == 512; APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { From 6063dc5b90f1084d7edd88abe4e17bc8cdade1dc Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 12 Mar 2025 18:25:04 -0400 Subject: [PATCH 071/102] Move tOrO init in mainloop --- hopper/flash_fwd_kernel_sm90.h | 25 ++++++++++++------------ hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 18 ++++++++--------- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/hopper/flash_fwd_kernel_sm90.h b/hopper/flash_fwd_kernel_sm90.h index c8bfc29b7..3b02c18ba 100644 --- a/hopper/flash_fwd_kernel_sm90.h +++ b/hopper/flash_fwd_kernel_sm90.h @@ -372,21 +372,8 @@ class FlashAttnFwdSm90 { work_tile_info.is_valid(params.scheduler); // get_next_work will be called before the epilogue ) { - // Attention output (GEMM-II) accumulator. - Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_MNK_PV{})); - float softmax_scale_log2 = params.mainloop.softmax_scale_log2; - // If there's tanh softcap, the scaling will be done before tanh. auto block_coord = work_tile_info.get_block_coord(params.scheduler); int const bidb = get<2>(block_coord); - if constexpr (Is_FP8 && !Has_softcap) { - int const bidh = get<1>(block_coord); - int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh; - float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)]; - float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)]; - softmax_scale_log2 *= q_descale * k_descale; - } - flash::Softmax softmax(softmax_scale_log2); - SeqlenInfo_t seqlen_info{ bidb, get<0>(params.mainloop.shape_Q), @@ -411,6 +398,18 @@ class FlashAttnFwdSm90 { // if (threadIdx.x == 128) { printf("Consumer: After sync\n"); } } } + // If there's tanh softcap, the scaling will be done before tanh. + float softmax_scale_log2 = params.mainloop.softmax_scale_log2; + if constexpr (Is_FP8 && !Has_softcap) { + int const bidh = get<1>(block_coord); + int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh; + float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)]; + float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)]; + softmax_scale_log2 *= q_descale * k_descale; + } + flash::Softmax softmax(softmax_scale_log2); + // Attention output (GEMM-II) accumulator. + Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_MNK_PV{})); bool tile_valid; if constexpr (!LargeHeadDimV) { tile_valid = mainloop.mma( diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index c2f7ff7eb..6a21078f7 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -353,7 +353,7 @@ struct CollectiveMainloopFwdSm90 { ? (NumMmaWarpGroups >= 2) && (!Is_FP8 ? kHeadDim <= 128 : kHeadDim >= 128) : NumMmaWarpGroups == 2) && !LargeHeadDimV; - static constexpr bool RescaleOBeforeGemm = kHeadDim > 128 && (!Is_FP8 || V_colmajor); + static constexpr bool RescaleOBeforeGemm = kHeadDim > 128 && (!Is_FP8 || V_colmajor) && IntraWGOverlap; // Host side kernel arguments struct Arguments { @@ -1061,8 +1061,8 @@ struct CollectiveMainloopFwdSm90 { float const k_descale = params.ptr_k_descale == nullptr ? 1.0f : params.ptr_k_descale[bidb * get<0>(params.stride_k_descale) + bidh_kv * get<1>(params.stride_k_descale)]; softcap_val *= q_descale * k_descale; } - // Softcapping needs to happen before masking since if we apply after masking, softcapping can turn - // -inf to e.g. -50.0, which can affect the attention softmax. + // Softcapping needs to happen before masking since if we apply after masking, softcapping + // can turn -inf to e.g. -50.0, which can affect the attention softmax. auto scoremod_premask_fn = [&](auto& tSrS) { if constexpr (Has_softcap) { flash::apply_softcap(tSrS, softcap_val); } }; @@ -1126,10 +1126,6 @@ struct CollectiveMainloopFwdSm90 { cute::copy(smem_tiled_copy_Q, tSsQ_copy_view, tSrQ_copy_view); } - // Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter - clear(tOrO); - // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero; - if constexpr (IntraWGOverlap) { Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); consumer_wait(pipeline_k, smem_pipe_read); @@ -1157,6 +1153,10 @@ struct CollectiveMainloopFwdSm90 { if constexpr (!MmaPV_is_RS) { arrive_on_P_write_barrier(); } --n_block; + // Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter + clear(tOrO); + // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero; + // Each step does gemm0 for iter n_block, gemm1 for iter n_block + 1, and softmax for iter n_block. auto fwd_step = [&](int const n_block, auto mask_fn, auto check_inf_type) { static constexpr bool Check_inf = decltype(check_inf_type)::value; @@ -1285,10 +1285,10 @@ struct CollectiveMainloopFwdSm90 { if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } warp_scheduler_barrier_sync(); if constexpr (!MmaPV_use_RS_WG1) { - flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); } else { TiledMmaPV_RS tiled_mma_pv_rs; - flash::gemm(tiled_mma_pv_rs, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + flash::gemm(tiled_mma_pv_rs, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); } if constexpr (!MmaPV_is_RS && MmaPV_use_RS_WG1) { arrive_on_P_write_barrier(); } warpgroup_wait<0>(); From 430954a8a173bdf2b757bfb5cd7cca08f2629859 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 12 Mar 2025 18:40:36 -0400 Subject: [PATCH 072/102] Adjust heuristic for get_pagedkv_tma --- hopper/flash_api.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 27bedc1fc..369bee25d 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -398,8 +398,11 @@ inline bool get_pagedkv_tma(Flash_fwd_params const& params) { if (params.arch < 90 || !params.page_table || params.leftpad_k || params.knew_ptr) { return false; } // This needs to match the kernel configs auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, false /*paged_kv_non_TMA*/, params.softcap > 0.f); + int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); int const kBlockN = std::get<1>(kBlockMN_kernel_args_sm90); - return params.page_size % kBlockN == 0; + // Heuristic: when seqlen_q <= kBlockM, we're not compute bound, and somehow using TMA is slower, + // at least for MLA. + return params.page_size % kBlockN == 0 && params.seqlen_q * (params.h / params.h_k) > kBlockM; } inline bool get_pack_gqa(Flash_fwd_params const& params) { From 000090d02f0398e9087a8823fc1f5242becfac99 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 12 Mar 2025 20:32:26 -0400 Subject: [PATCH 073/102] Enable PDL --- hopper/flash.h | 2 +- hopper/flash_api.cpp | 19 ++++----- hopper/flash_fwd_combine.cu | 12 +++--- hopper/flash_fwd_combine_kernel.h | 5 +++ hopper/flash_fwd_combine_launch_template.h | 45 ++++++++++++---------- hopper/flash_fwd_kernel_sm90.h | 9 +++++ hopper/flash_fwd_launch_template.h | 4 +- hopper/flash_prepare_scheduler.cu | 12 ++++-- hopper/setup.py | 2 +- 9 files changed, 69 insertions(+), 41 deletions(-) diff --git a/hopper/flash.h b/hopper/flash.h index d5d7fa218..cf1d0d4a0 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -212,4 +212,4 @@ void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bo template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); template -void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); +void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 369bee25d..e4a94144e 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -366,27 +366,27 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { }); } -void run_mha_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl=false) { #ifndef FLASHATTENTION_DISABLE_SPLIT // If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively // so that kBlockM is smaller and we have more parallelism. if (params.is_fp32) { if (params.dv <= 64) { - run_mha_fwd_combine_(params, stream); + run_mha_fwd_combine_(params, stream, enable_pdl); } else { - run_mha_fwd_combine_(params, stream); + run_mha_fwd_combine_(params, stream, enable_pdl); } } else if (params.is_bf16) { if (params.dv <= 64) { - run_mha_fwd_combine_(params, stream); + run_mha_fwd_combine_(params, stream, enable_pdl); } else { - run_mha_fwd_combine_(params, stream); + run_mha_fwd_combine_(params, stream, enable_pdl); } } else { if (params.dv <= 64) { - run_mha_fwd_combine_(params, stream); + run_mha_fwd_combine_(params, stream, enable_pdl); } else { - run_mha_fwd_combine_(params, stream); + run_mha_fwd_combine_(params, stream, enable_pdl); } } #else @@ -970,7 +970,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq // params.b = 1; // params.seqlen_q = total_q; // } - run_mha_fwd_combine(params, stream); + run_mha_fwd_combine(params, stream, true /*enable_pdl*/); } } else if (total_q > 0 && num_heads_k > 0) { // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. @@ -1419,10 +1419,11 @@ mha_combine(const at::Tensor &out_partial, // num_splits x batch_size x params.o_row_stride = out.stride(1); params.o_head_stride = out.stride(2); params.o_batch_stride = out.stride(0); + params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor; if (seqlen > 0 && batch_size > 0) { auto stream = at::cuda::getCurrentCUDAStream().stream(); - run_mha_fwd_combine(params, stream); + run_mha_fwd_combine(params, stream, false /*enable_pdl*/); } at::Tensor out_padded = out; diff --git a/hopper/flash_fwd_combine.cu b/hopper/flash_fwd_combine.cu index a1725cf2a..3e85a0a21 100644 --- a/hopper/flash_fwd_combine.cu +++ b/hopper/flash_fwd_combine.cu @@ -3,11 +3,11 @@ #include "flash_fwd_combine_launch_template.h" -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); diff --git a/hopper/flash_fwd_combine_kernel.h b/hopper/flash_fwd_combine_kernel.h index 3e9a3c232..a22e05969 100644 --- a/hopper/flash_fwd_combine_kernel.h +++ b/hopper/flash_fwd_combine_kernel.h @@ -12,6 +12,8 @@ #include #include +#include "cutlass/arch/grid_dependency_control.h" + #include "seqlen.h" #include "utils.h" @@ -205,6 +207,7 @@ class FlashAttnFwdCombine { int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial); if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1) { + cutlass::arch::wait_on_dependent_grids(); *params.semaphore_to_reset = 0; } if (num_splits <= 1) { return; } @@ -232,6 +235,8 @@ class FlashAttnFwdCombine { // Repeat the partitioning with identity layouts Tensor tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE); + cutlass::arch::wait_on_dependent_grids(); + #pragma unroll for (int m = 0; m < size<2>(tLSEcLSE); ++m) { int mi = int(get<1>(tLSEcLSE(_0{}, _0{}, m))); diff --git a/hopper/flash_fwd_combine_launch_template.h b/hopper/flash_fwd_combine_launch_template.h index 7cb9b64fd..11d422924 100644 --- a/hopper/flash_fwd_combine_launch_template.h +++ b/hopper/flash_fwd_combine_launch_template.h @@ -9,6 +9,7 @@ #include "cutlass/cutlass.h" #include "cutlass/arch/arch.h" // For cutlass::arch::Sm80 #include "cutlass/device_kernel.h" // For device_kernel +#include "cutlass/kernel_launch.h" // For kernel_launch #include "static_switch.h" #include "flash.h" @@ -16,11 +17,12 @@ using namespace cute; -template -void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { +template +void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl) { + using ArchTag = std::conditional_t= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>; using TileShape_MK = cute::Shape, Int>; using CombineKernel = flash::FlashAttnFwdCombine; + IsEvenK, Varlen, Element, ElementPartial, ArchTag>; typename CombineKernel::Arguments args { static_cast(params.oaccum_ptr), @@ -45,31 +47,34 @@ void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { if (smem_size >= 48 * 1024) { CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); } - kernel<<>>(kernel_params); + // kernel<<>>(kernel_params); + cutlass::kernel_launch(grid_m, CombineKernel::MaxThreadsPerBlock, smem_size, stream, kernel_params, Arch >= 90 && enable_pdl /*launch_with_pdl*/); CHECK_CUDA_KERNEL_LAUNCH(); } template -void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl) { // We want kBlockM to be as small as possible to maximize parallelism. // E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats). static_assert(kBlockK % 32 == 0, "kBlockK must be a multiple of 32"); static constexpr int kBlockM = kBlockK % 128 == 0 ? 8 : (kBlockK % 64 == 0 ? 16 : 32); - BOOL_SWITCH(params.cu_seqlens_q || params.seqused_q, Varlen, [&] { - if constexpr (kBlockM >= 16) { // If kBlockM == 8 then the minimum number of splits is 32. - if (params.num_splits <= 16) { - run_flash_fwd_combine(params, stream); - return; + ARCH_SWITCH(params.arch, Arch, [&] { + BOOL_SWITCH(params.cu_seqlens_q || params.seqused_q, Varlen, [&] { + if constexpr (kBlockM >= 16) { // If kBlockM == 8 then the minimum number of splits is 32. + if (params.num_splits <= 16) { + run_flash_fwd_combine(params, stream, enable_pdl); + return; + } } - } - if (params.num_splits <= 32) { - run_flash_fwd_combine(params, stream); - } else if (params.num_splits <= 64) { - run_flash_fwd_combine(params, stream); - } else if (params.num_splits <= 128) { - run_flash_fwd_combine(params, stream); - } else { - run_flash_fwd_combine(params, stream); - } + if (params.num_splits <= 32) { + run_flash_fwd_combine(params, stream, enable_pdl); + } else if (params.num_splits <= 64) { + run_flash_fwd_combine(params, stream, enable_pdl); + } else if (params.num_splits <= 128) { + run_flash_fwd_combine(params, stream, enable_pdl); + } else { + run_flash_fwd_combine(params, stream, enable_pdl); + } + }); }); } diff --git a/hopper/flash_fwd_kernel_sm90.h b/hopper/flash_fwd_kernel_sm90.h index 3b02c18ba..962283fe2 100644 --- a/hopper/flash_fwd_kernel_sm90.h +++ b/hopper/flash_fwd_kernel_sm90.h @@ -14,6 +14,8 @@ #include #include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/arch/grid_dependency_control.h" + #include "seqlen.h" #include "utils.h" #include "softmax.h" @@ -320,6 +322,8 @@ class FlashAttnFwdSm90 { } if (!SingleProducerWarp && warp_idx_in_warpgroup != 0) { scheduler.init_consumer(); } + cutlass::arch::wait_on_dependent_grids(); + // Load Q, K, V for (auto work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_initial_work(params.scheduler) : scheduler.template get_initial_work(params.scheduler); work_tile_info.is_valid(params.scheduler); @@ -428,6 +432,11 @@ class FlashAttnFwdSm90 { } // Do this here before the epilogue so that the next tile is ready to go. work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info); + if constexpr (Split && Varlen) { + if (!work_tile_info.is_valid(params.scheduler)) { // Last tile + cutlass::arch::launch_dependent_grids(); + } + } if (tile_valid) { // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); } epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma_pv, diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 0a0d92f59..4df7eec8c 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -10,6 +10,7 @@ #include "cutlass/device_kernel.h" // For device_kernel #include #include "cutlass/cluster_launch.hpp" +#include "cutlass/kernel_launch.h" #include "static_switch.h" #include "flash.h" @@ -186,7 +187,8 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { if (smem_size >= 48 * 1024) { CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); } - kernel<<>>(kernel_params); + // kernel<<>>(kernel_params); + cutlass::kernel_launch(grid_dims, block_dims, smem_size, stream, kernel_params, Arch >= 90 && Varlen /*launch_with_pdl*/); } CHECK_CUDA_KERNEL_LAUNCH(); } diff --git a/hopper/flash_prepare_scheduler.cu b/hopper/flash_prepare_scheduler.cu index 8d1b3602b..9ba793223 100644 --- a/hopper/flash_prepare_scheduler.cu +++ b/hopper/flash_prepare_scheduler.cu @@ -6,6 +6,8 @@ #include "cutlass/barrier.h" #include "cutlass/arch/barrier.h" +#include "cutlass/arch/grid_dependency_control.h" + #include "flash.h" namespace flash { @@ -16,7 +18,8 @@ __global__ void prepare_varlen_num_blocks_kernel( int const* const seqused_q, int const* const seqused_k, int const* const leftpad_k_ptr, int num_batch, int num_head, int qhead_per_khead, int num_sm, int num_splits_static, cutlass::FastDivmod blockm_divmod, cutlass::FastDivmod blockn_divmod, - int* const tile_count_semaphore, int* const num_m_blocks_ptr, int* const num_n_blocks_ptr, + int* const tile_count_semaphore, int* const num_n_blocks_ptr, + // int* const num_m_blocks_ptr, int* const num_splits_dynamic_ptr) { static constexpr int kNumBatchPerWarp = cutlass::NumThreadsPerWarp - 1; @@ -24,6 +27,9 @@ __global__ void prepare_varlen_num_blocks_kernel( // Assume that there's only one block in the grid __shared__ int smem[kSmemSize]; + // There's only 1 block in the grid, so might as well start launching the main attn kernel + cutlass::arch::launch_dependent_grids(); + if (threadIdx.x < kSmemSize) { smem[threadIdx.x] = 0; } __syncthreads(); @@ -109,7 +115,6 @@ __global__ void prepare_varlen_num_blocks_kernel( // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic); } } - } } // flash @@ -123,6 +128,7 @@ void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bo params.seqused_q, params.seqused_k, params.leftpad_k, params.b, !packgqa ? params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits, cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN), - params.tile_count_semaphore, params.num_m_blocks_ptr, params.num_n_blocks_ptr, + params.tile_count_semaphore, params.num_n_blocks_ptr, + // params.num_m_blocks_ptr, params.num_splits_dynamic_ptr); } diff --git a/hopper/setup.py b/hopper/setup.py index 121266ebd..f87d809eb 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -520,7 +520,7 @@ def nvcc_threads_args(): # f"--split-compile={os.getenv('NVCC_THREADS', '4')}", # split-compile is faster "-lineinfo", "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", # Necessary for the WGMMA shapes that we use - # "-DCUTLASS_ENABLE_GDC_FOR_SM90", # For PDL + "-DCUTLASS_ENABLE_GDC_FOR_SM90", # For PDL "-DCUTLASS_DEBUG_TRACE_LEVEL=0", # Can toggle for debugging "-DNDEBUG", # Important, otherwise performance is severely impacted ] From 46e1d4a1c762c08e73eab63a65fba128cf696a3d Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Thu, 13 Mar 2025 01:38:14 -0400 Subject: [PATCH 074/102] Simplify prepare_varlen_num_blocks_kernel, restrict to batch <= 992 --- hopper/flash.h | 4 +-- hopper/flash_api.cpp | 46 +++++++++++++++++------------- hopper/flash_fwd_launch_template.h | 2 +- hopper/flash_prepare_scheduler.cu | 42 +++++++++++---------------- hopper/tile_scheduler.hpp | 5 ++-- 5 files changed, 48 insertions(+), 51 deletions(-) diff --git a/hopper/flash.h b/hopper/flash.h index cf1d0d4a0..93b6b5165 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -150,8 +150,8 @@ struct Flash_fwd_params : public Qkv_params { bool pack_gqa; int * __restrict__ tile_count_semaphore; - int * __restrict__ num_m_blocks_ptr; - int * __restrict__ num_n_blocks_ptr; + // int * __restrict__ num_m_blocks_ptr; + // int * __restrict__ num_n_blocks_ptr; int * __restrict__ num_splits_dynamic_ptr; int arch; diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index e4a94144e..76eb32b86 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -447,7 +447,7 @@ inline int get_num_splits(Flash_fwd_params const& params) { // If varlen, we use dynamic split, so this heuristic just needs to get an upper bound on num_splits. // We assume the case where there's 1 long sequence and the rest are short, i.e. pretending // that batch = 1. - int total_mblocks = (!varlen ? params.b : 1) * params.h_k * num_m_blocks; + int total_mblocks = (params.num_splits_dynamic_ptr ? 1 : params.b) * params.h_k * num_m_blocks; return num_splits_heuristic(total_mblocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128); #endif } @@ -798,6 +798,31 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq } } + at::Tensor tile_count_semaphore; + // We don't use the persistent scheduler if Split and not Varlen + bool const persistent_scheduler = params.arch >= 90 + ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) + : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1)); + // 992 = 32 * 31 is the max supported batch in prepare_varlen_num_blocks kernel + bool const use_dynamic_split = is_varlen && params.b <= 992; + if (persistent_scheduler || use_dynamic_split) { // This needs to be set before get_num_splits + tile_count_semaphore = torch::empty({int(persistent_scheduler) + int(use_dynamic_split) * batch_size}, opts.dtype(torch::kInt32)); + if (persistent_scheduler) { + if (!is_varlen) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing + params.tile_count_semaphore = tile_count_semaphore.data_ptr(); + } else { + params.tile_count_semaphore = nullptr; + } + if (use_dynamic_split) { + // params.num_m_blocks_ptr = num_m_n_blocks_splits.data_ptr(); + // params.num_n_blocks_ptr = num_m_n_blocks_splits.data_ptr() + batch_size; + params.num_splits_dynamic_ptr = tile_count_semaphore.data_ptr() + 1; + } else { + params.num_splits_dynamic_ptr = nullptr; + } + } + + params.pagedkv_tma = get_pagedkv_tma(params); params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide @@ -882,25 +907,6 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq params.lseaccum_head_stride = softmax_lse_accum.stride(-2); } - at::Tensor tile_count_semaphore, num_m_n_blocks_splits; - // We don't use the persistent scheduler if Split and not Varlen - bool const persistent_scheduler = params.arch >= 90 - ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) - : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1)); - if (persistent_scheduler) { - tile_count_semaphore = torch::empty({1}, opts.dtype(torch::kInt32)); - if (!is_varlen) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing - params.tile_count_semaphore = tile_count_semaphore.data_ptr(); - } else { - params.tile_count_semaphore = nullptr; - } - if (is_varlen) { - num_m_n_blocks_splits = torch::empty({batch_size * 3}, opts.dtype(torch::kInt32)); - params.num_m_blocks_ptr = num_m_n_blocks_splits.data_ptr(); - params.num_n_blocks_ptr = num_m_n_blocks_splits.data_ptr() + batch_size; - params.num_splits_dynamic_ptr = num_m_n_blocks_splits.data_ptr() + batch_size * 2; - } - if (q_type == at::ScalarType::Float8_e4m3fn) { if (q_descale_.has_value()) { auto q_descale = q_descale_.value(); diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 4df7eec8c..fe54bd1c0 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -155,7 +155,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { params.num_splits_dynamic_ptr, }; - if constexpr (Varlen) { + if (Varlen && params.num_splits_dynamic_ptr) { prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN); CHECK_CUDA_KERNEL_LAUNCH(); } diff --git a/hopper/flash_prepare_scheduler.cu b/hopper/flash_prepare_scheduler.cu index 9ba793223..df5a19a1f 100644 --- a/hopper/flash_prepare_scheduler.cu +++ b/hopper/flash_prepare_scheduler.cu @@ -18,19 +18,19 @@ __global__ void prepare_varlen_num_blocks_kernel( int const* const seqused_q, int const* const seqused_k, int const* const leftpad_k_ptr, int num_batch, int num_head, int qhead_per_khead, int num_sm, int num_splits_static, cutlass::FastDivmod blockm_divmod, cutlass::FastDivmod blockn_divmod, - int* const tile_count_semaphore, int* const num_n_blocks_ptr, + int* const tile_count_semaphore, // int* const num_m_blocks_ptr, int* const num_splits_dynamic_ptr) { static constexpr int kNumBatchPerWarp = cutlass::NumThreadsPerWarp - 1; static constexpr int kSmemSize = 1; // Assume that there's only one block in the grid - __shared__ int smem[kSmemSize]; + __shared__ int total_blocks_smem[kSmemSize]; // There's only 1 block in the grid, so might as well start launching the main attn kernel cutlass::arch::launch_dependent_grids(); - if (threadIdx.x < kSmemSize) { smem[threadIdx.x] = 0; } + if (threadIdx.x < kSmemSize) { total_blocks_smem[threadIdx.x] = 0; } __syncthreads(); if (threadIdx.x == 0 && tile_count_semaphore) { *tile_count_semaphore = 0; } @@ -83,37 +83,26 @@ __global__ void prepare_varlen_num_blocks_kernel( int total_blocks = 0; int warp_idx = threadIdx.x / cutlass::NumThreadsPerWarp; - int num_warps = blockDim.x / cutlass::NumThreadsPerWarp; - for (int bidb_start = kNumBatchPerWarp * warp_idx; bidb_start < num_batch; bidb_start += kNumBatchPerWarp * num_warps) { - int num_m_blocks = get_num_m_blocks(bidb_start); - int num_n_blocks = get_num_n_blocks(bidb_start); - if (bidb_start + lane < num_batch && lane < kNumBatchPerWarp) { - // num_m_blocks_ptr[bidb_start + lane] = num_m_blocks; - num_n_blocks_ptr[bidb_start + lane] = num_n_blocks; - // printf("idx = %d, num_m = %d, num_n = %d\n", bidb_start + lane, num_m_blocks, num_n_blocks); - } - total_blocks += num_m_blocks * num_n_blocks; - } + int bidb_start = kNumBatchPerWarp * warp_idx; + int num_m_blocks = get_num_m_blocks(bidb_start); + int num_n_blocks = get_num_n_blocks(bidb_start); + total_blocks += num_m_blocks * num_n_blocks; // Warp sum #pragma unroll for (int i = cutlass::NumThreadsPerWarp / 2; i >= 1; i /= 2) { total_blocks += __shfl_down_sync(0xffffffff, total_blocks, i); } - if (lane == 0) { atomicAdd(smem, total_blocks); } + if (lane == 0) { atomicAdd(total_blocks_smem, total_blocks); } __syncthreads(); - total_blocks = smem[0]; + total_blocks = total_blocks_smem[0]; // 10% margin int blocks_per_sm = static_cast(ceilf(float(total_blocks) * 1.1f * float(num_head) / float(num_sm))); // blocks_per_sm = std::max(1, blocks_per_sm); // 1 is the minimum number of blocks per SM - for (int bidb_start = kNumBatchPerWarp * warp_idx; bidb_start < num_batch; bidb_start += kNumBatchPerWarp * num_warps) { - bool is_valid = bidb_start + lane < num_batch && lane < kNumBatchPerWarp; - int num_n_blocks = is_valid ? num_n_blocks_ptr[bidb_start + lane] : 0; - int num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1); - if (is_valid) { - num_splits_dynamic_ptr[bidb_start + lane] = num_splits_dynamic; - // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic); - } + int num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1); + if (bidb_start + lane < num_batch && lane < kNumBatchPerWarp) { + num_splits_dynamic_ptr[bidb_start + lane] = num_splits_dynamic; + // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic); } } @@ -121,14 +110,15 @@ __global__ void prepare_varlen_num_blocks_kernel( void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, int blockM, int blockN) { + // Only support batch <= 992 (32 warps, each with 31 batches) int qhead_per_khead = !packgqa ? 1 : cutlass::ceil_div(params.h, params.h_k); - flash::prepare_varlen_num_blocks_kernel<<<1 /*grid*/, 256 /*block*/, 0, stream>>>( + flash::prepare_varlen_num_blocks_kernel<<<1 /*grid*/, 1024 /*block*/, 0, stream>>>( params.seqlen_q, params.seqlen_k, params.seqlen_knew, params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, params.seqused_q, params.seqused_k, params.leftpad_k, params.b, !packgqa ? params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits, cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN), - params.tile_count_semaphore, params.num_n_blocks_ptr, + params.tile_count_semaphore, // params.num_m_blocks_ptr, params.num_splits_dynamic_ptr); } diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index a3aa794d6..f71324272 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -388,7 +388,6 @@ class VarlenDynamicPersistentTileScheduler { // If Split, for the purpose of scheduling, we pretend that instead there are // (args.num_splits * args.num_head) number of heads. assert(args.tile_count_semaphore != nullptr); - assert(!Split || args.num_splits_dynamic_ptr != nullptr); assert(num_head < (1 << 16)); // We use the top 16 bits to store num_splits & split_idx assert(!Split || args.num_splits < (1 << 8)); // We use the top 8 bits to store num_splits return {args.num_head, args.num_batch, @@ -468,7 +467,9 @@ class VarlenDynamicPersistentTileScheduler { auto get_num_splits = [&] (int bidb_start) { int batch_idx = lane + bidb_start; return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 - ? (!Split ? 1 : params.num_splits_dynamic_ptr[batch_idx]) + ? (!Split ? 1 : (params.num_splits_dynamic_ptr + ? params.num_splits_dynamic_ptr[batch_idx] + : params.nsplits_divmod.divisor)) : 0; }; From 897c84539a9009bac832093d55883010d0da25ff Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 14 Mar 2025 00:38:03 -0400 Subject: [PATCH 075/102] Fix: num_splits_dynamic_ptr needs to be set before get_num_splits --- hopper/flash_api.cpp | 31 +++++++++++++++++-------------- hopper/flash_prepare_scheduler.cu | 3 +-- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 76eb32b86..8bb80604a 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -798,17 +798,26 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq } } - at::Tensor tile_count_semaphore; + // 992 = 32 * 31 is the max supported batch in prepare_varlen_num_blocks kernel + bool const use_dynamic_split = is_varlen && params.b <= 992; + // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it + params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast(1); + + params.pagedkv_tma = get_pagedkv_tma(params); + params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; + // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide + params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); + + // This needs to be set after get_num_splits + at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic // We don't use the persistent scheduler if Split and not Varlen - bool const persistent_scheduler = params.arch >= 90 + bool const scheduler_needs_semaphore = params.arch >= 90 ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1)); - // 992 = 32 * 31 is the max supported batch in prepare_varlen_num_blocks kernel - bool const use_dynamic_split = is_varlen && params.b <= 992; - if (persistent_scheduler || use_dynamic_split) { // This needs to be set before get_num_splits - tile_count_semaphore = torch::empty({int(persistent_scheduler) + int(use_dynamic_split) * batch_size}, opts.dtype(torch::kInt32)); - if (persistent_scheduler) { - if (!is_varlen) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing + if (scheduler_needs_semaphore || use_dynamic_split) { // This needs to be set before get_num_splits + tile_count_semaphore = torch::empty({int(scheduler_needs_semaphore) + int(use_dynamic_split) * batch_size}, opts.dtype(torch::kInt32)); + if (scheduler_needs_semaphore) { + if (!use_dynamic_split) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing params.tile_count_semaphore = tile_count_semaphore.data_ptr(); } else { params.tile_count_semaphore = nullptr; @@ -822,12 +831,6 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq } } - - params.pagedkv_tma = get_pagedkv_tma(params); - params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; - // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide - params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); - if (q_v_.has_value()) { TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, diff --git a/hopper/flash_prepare_scheduler.cu b/hopper/flash_prepare_scheduler.cu index df5a19a1f..d1b2a4f2a 100644 --- a/hopper/flash_prepare_scheduler.cu +++ b/hopper/flash_prepare_scheduler.cu @@ -81,13 +81,12 @@ __global__ void prepare_varlen_num_blocks_kernel( ? blockn_divmod.div(seqlen + blockn_divmod.divisor - 1) : 0; }; - int total_blocks = 0; int warp_idx = threadIdx.x / cutlass::NumThreadsPerWarp; int bidb_start = kNumBatchPerWarp * warp_idx; int num_m_blocks = get_num_m_blocks(bidb_start); int num_n_blocks = get_num_n_blocks(bidb_start); - total_blocks += num_m_blocks * num_n_blocks; + int total_blocks = num_m_blocks * num_n_blocks; // Warp sum #pragma unroll for (int i = cutlass::NumThreadsPerWarp / 2; i >= 1; i /= 2) { From 90f27a29dd1db73b474112854730a7894b8c7f9b Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 15 Mar 2025 15:54:58 -0400 Subject: [PATCH 076/102] Loop on num_splits instead of parameterizing it in kvcache test --- hopper/test_flash_attn.py | 167 ++++++++++++++++++++------------------ 1 file changed, 87 insertions(+), 80 deletions(-) diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 2ed394324..3098d4e30 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -559,8 +559,6 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) -@pytest.mark.parametrize("num_splits", [1] + ([0] if not DISABLE_SPLIT else [])) -# @pytest.mark.parametrize("num_splits", [1]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) @pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else [])) @@ -623,7 +621,6 @@ def test_flash_attn_kvcache( local, new_kv, mha_type, - num_splits, dtype, ): if page_size is not None and seqlen_k % page_size != 0: @@ -825,88 +822,98 @@ def test_flash_attn_kvcache( qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None cos = cos.to(dtype) if cos is not None else None sin = sin.to(dtype) if sin is not None else None - out, lse, *rest = flash_attn_with_kvcache( - q if not varlen_q else q_unpad, - k_cache if page_size is None else k_cache_paged, - v_cache if page_size is None else v_cache_paged, - k if not new_kv or not varlen_q else k_unpad, - v if not new_kv or not varlen_q else v_unpad, - qv=qv if not varlen_q else qv_unpad, - rotary_cos=cos, - rotary_sin=sin, - cache_seqlens=cache_seqlens, - cache_batch_idx=cache_batch_idx, - cache_leftpad=cache_leftpad, - page_table=page_table, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k_new=cu_seqlens_k_new, - max_seqlen_q=max_seqlen_q, - causal=causal, - window_size=window_size, - rotary_interleaved=rotary_interleaved, - num_splits=num_splits, - return_softmax_lse=True - ) - if varlen_q: - out = output_pad_fn(out) - # out = flash_attn_with_kvcache( - # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size - # ) - # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) - # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) - # m = qk.amax(-1, keepdim=True) - # s_tmp = torch.exp((qk - m) / math.sqrt(d)) - # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) - # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) - # probs = torch.softmax(qk, dim=-1) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - # breakpoint() - - # Check that FlashAttention's numerical error is at most twice the numerical error - # of a Pytorch implementation. - if new_kv: + k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone() + v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone() + num_splits_vals = [1, 0] if not DISABLE_SPLIT else [1] + for num_splits in num_splits_vals: if page_size is None: - k_cache_select = ( - k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx] - ) - v_cache_select = ( - v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx] - ) + k_cache.copy_(k_cache_saved) + v_cache.copy_(v_cache_saved) else: - k_cache_select = rearrange( - k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], - "(b nblocks) block_size ... -> b (nblocks block_size) ...", - b=batch_size, - )[:, :seqlen_k].to(dtype_ref) - v_cache_select = rearrange( - v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], - "(b nblocks) block_size ... -> b (nblocks block_size) ...", - b=batch_size, - )[:, :seqlen_k].to(dtype_ref) - k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) - v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) - if dtype is not torch.float8_e4m3fn: - assert torch.equal(v_cache_select, v_cache_ref) - else: - assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3) + k_cache_paged.copy_(k_cache_saved) + v_cache_paged.copy_(v_cache_saved) + out, lse, *rest = flash_attn_with_kvcache( + q if not varlen_q else q_unpad, + k_cache if page_size is None else k_cache_paged, + v_cache if page_size is None else v_cache_paged, + k if not new_kv or not varlen_q else k_unpad, + v if not new_kv or not varlen_q else v_unpad, + qv=qv if not varlen_q else qv_unpad, + rotary_cos=cos, + rotary_sin=sin, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, + max_seqlen_q=max_seqlen_q, + causal=causal, + window_size=window_size, + rotary_interleaved=rotary_interleaved, + num_splits=num_splits, + return_softmax_lse=True + ) + if varlen_q: + out = output_pad_fn(out) + # out = flash_attn_with_kvcache( + # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size + # ) + # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) + # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) + # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) + # probs = torch.softmax(qk, dim=-1) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") # breakpoint() - # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: - if rotary_dim == 0: - assert torch.equal(k_cache_select, k_cache_ref) - else: - # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): - # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + if new_kv: + if page_size is None: + k_cache_select = ( + k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx] + ) + v_cache_select = ( + v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx] + ) + else: + k_cache_select = rearrange( + k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + v_cache_select = rearrange( + v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) + v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) if dtype is not torch.float8_e4m3fn: - assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) + assert torch.equal(v_cache_select, v_cache_ref) + else: + assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3) + # breakpoint() + # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: + if rotary_dim == 0: + assert torch.equal(k_cache_select, k_cache_ref) else: - assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1) - mult = 4 if dtype == torch.float8_e4m3fn else 2 - assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 - mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 - assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item() + # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): + # breakpoint() + if dtype is not torch.float8_e4m3fn: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) + else: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1) + mult = 4 if dtype == torch.float8_e4m3fn else 2 + assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 + mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 + assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item() def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype): From fa60e7cc97300b4b26721983df580a7da7a8ebea Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 15 Mar 2025 16:41:29 -0400 Subject: [PATCH 077/102] Add option to precompute scheduler metadata --- hopper/benchmark_attn.py | 5 +- hopper/cuda_check.h | 19 ++++ hopper/flash.h | 3 +- hopper/flash_api.cpp | 151 ++++++++++++++++++++++++++--- hopper/flash_attn_interface.py | 47 ++++++++- hopper/flash_fwd_launch_template.h | 7 +- hopper/flash_prepare_scheduler.cu | 9 +- hopper/test_flash_attn.py | 18 +++- hopper/utils.h | 12 +-- 9 files changed, 235 insertions(+), 36 deletions(-) create mode 100644 hopper/cuda_check.h diff --git a/hopper/benchmark_attn.py b/hopper/benchmark_attn.py index 62ac2b63c..33e5d2827 100644 --- a/hopper/benchmark_attn.py +++ b/hopper/benchmark_attn.py @@ -56,7 +56,7 @@ def time_fwd(func, *args, repeats=30, verbose=True, desc="", **kwargs): # time_f = benchmark_forward(lambda: graph.replay(), repeats=repeats, verbose=verbose, desc=desc) # # return time_f[1].mean # return time_f[1] - return Timing(do_bench(lambda: func(*args, **kwargs), warmup=5, rep=repeats) * 1e-3) + return Timing(do_bench(lambda: func(*args, **kwargs), warmup=3, rep=repeats) * 1e-3) def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(-1, -1)): @@ -404,7 +404,8 @@ def run(*args, **kwargs): # import pickle # # with open(f'flash3_attn_time_h100_hdim{headdim}_causal.plk', 'wb') as fp: # # with open(f'flash3_attn_time_h100_cudnn_triton_20241208.plk', 'wb') as fp: - # with open(f'flash3_attn_time_h100_fa3_20241208.plk', 'wb') as fp: + # with open(f'flash3_attn_time_h100_fa3_20250313.plk', 'wb') as fp: + # # with open(f'flash3_attn_time_h100_fa3_fp8_20250313.plk', 'wb') as fp: # # with open(f'flash3_attn_time_h100_fp8_hdim{headdim}.plk', 'wb') as fp: # # with open(f'flash3_attn_time_h100_hdim{headdim}_1031.plk', 'wb') as fp: # pickle.dump((time_f, time_b), fp, protocol=pickle.HIGHEST_PROTOCOL) diff --git a/hopper/cuda_check.h b/hopper/cuda_check.h new file mode 100644 index 000000000..b5e63aef7 --- /dev/null +++ b/hopper/cuda_check.h @@ -0,0 +1,19 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include + +#define CHECK_CUDA(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ + exit(1); \ + } \ + } while(0) + +#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError()) diff --git a/hopper/flash.h b/hopper/flash.h index 93b6b5165..69562d488 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -153,6 +153,7 @@ struct Flash_fwd_params : public Qkv_params { // int * __restrict__ num_m_blocks_ptr; // int * __restrict__ num_n_blocks_ptr; int * __restrict__ num_splits_dynamic_ptr; + bool skip_scheduler_metadata_computation; int arch; int num_sm; @@ -208,7 +209,7 @@ struct Flash_bwd_params : public Flash_fwd_params { template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); -void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, int blockM, int blockN); +void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, int blockM, int blockN, bool enable_pdl); template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); template diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 8bb80604a..0251c6c4e 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -15,6 +15,7 @@ #include "static_switch.h" #include "tile_size.h" #include "heuristics.h" +#include "cuda_check.h" // Copied from https://github.com/pytorch/pytorch/commit/7931eee5c5ebcdf468bff4d308510b03355cd909 // This is so that we can pass in torch.dtype as a parameter to the function. @@ -490,6 +491,127 @@ inline int round_up_headdim(int head_size) { return 256; } +// Only applicable to the case where seqused_k (i.e. cache_seqlens) is available +at::Tensor +mha_fwd_get_scheduler_metadata( + int batch_size, + int max_seqlen_q, + int max_seqlen_k, + int num_heads, + int num_heads_k, + int headdim, + int headdim_v, + at::ScalarType qkv_dtype, + const at::Tensor &seqused_k, // b + std::optional &cu_seqlens_q_, // b+1 + std::optional &cu_seqlens_k_, // b+1 + std::optional &cu_seqlens_k_new_, // b+1 + std::optional &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional &leftpad_k_, // b + std::optional page_size, + int max_seqlen_k_new, // 0 means we're not appending new KV + bool is_causal, + int window_size_left, + int window_size_right, + bool has_softcap, + int num_splits, + std::optional pack_gqa_, + int const sm_margin + ) { + + TORCH_CHECK(qkv_dtype == at::ScalarType::Half || qkv_dtype == at::ScalarType::BFloat16 || qkv_dtype == at::ScalarType::Float8_e4m3fn, + "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + // Reset the parameters + Flash_fwd_params params{}; + params.is_bf16 = qkv_dtype == at::ScalarType::BFloat16; + params.is_e4m3 = qkv_dtype == at::ScalarType::Float8_e4m3fn; + params.b = batch_size; + params.seqlen_q = max_seqlen_q; + params.seqlen_k = max_seqlen_k; + params.h = num_heads; + params.h_k = num_heads_k; + params.d = headdim; + params.dv = headdim_v; + params.d_rounded = round_up_headdim(headdim); + params.dv_rounded = round_up_headdim(headdim_v); + params.seqlen_knew = max_seqlen_k_new; + + bool const is_varlen_q = cu_seqlens_q_.has_value(); + params.cu_seqlens_q = is_varlen_q ? cu_seqlens_q_.value().data_ptr() : nullptr; + bool const is_varlen_k = cu_seqlens_k_.has_value(); + params.cu_seqlens_k = is_varlen_k ? cu_seqlens_k_.value().data_ptr() : nullptr; + params.cu_seqlens_knew = cu_seqlens_k_new_.has_value() ? cu_seqlens_k_new_.value().data_ptr() : nullptr; + params.seqused_q = seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr; + params.seqused_k = seqused_k.data_ptr(); + params.leftpad_k = leftpad_k_.has_value() ? leftpad_k_.value().data_ptr() : nullptr; + params.knew_ptr = params.seqlen_knew > 0 ? reinterpret_cast(1) : nullptr; + if (window_size_left >= max_seqlen_k - 1) { window_size_left = -1; } + if (window_size_right >= max_seqlen_q - 1) { window_size_right = -1; } + // causal=true is the same as causal=false in this case + if (max_seqlen_q == 1 && window_size_left == -1 && window_size_right == -1) { + // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA + if ((headdim <= 64 || headdim > 128) || !page_size.has_value()) { + is_causal = false; + } + } + if (is_causal) { window_size_right = 0; } + + params.is_causal = window_size_left < 0 && window_size_right == 0; + params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal; + if (window_size_left < 0 && window_size_right >= 0) { window_size_left = max_seqlen_k - 1; } + if (window_size_left >= 0 && window_size_right < 0) { window_size_right = max_seqlen_q - 1; } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor; + params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin; + params.softcap = has_softcap ? 1.0f : 0.0f; + + params.page_size = page_size.has_value() ? page_size.value() : 1; + params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast(1); + + bool const use_dynamic_split = params.b <= 992; + params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast(1); + + params.pagedkv_tma = get_pagedkv_tma(params); + params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; + // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide + params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); + + bool is_varlen = true; + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)seqused_k.get_device()}; + + auto opts = seqused_k.options(); + // This needs to be set after get_num_splits + at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic + bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1; + if (scheduler_needs_semaphore || use_dynamic_split) { + tile_count_semaphore = torch::empty({int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b}, opts.dtype(torch::kInt32)); + if (scheduler_needs_semaphore) { + if (!use_dynamic_split) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing + params.tile_count_semaphore = tile_count_semaphore.data_ptr(); + } else { + params.tile_count_semaphore = nullptr; + } + params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + 1 : nullptr; + } + + if (params.num_splits_dynamic_ptr) { + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); + auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr); + int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); + int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); + auto stream = at::cuda::getCurrentCUDAStream().stream(); + prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/); + CHECK_CUDA_KERNEL_LAUNCH(); + } + return tile_count_semaphore; +} + // b: batch_size // b_k: batch_size_k // s_q: seqlen_q @@ -528,6 +650,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq int window_size_right, float const softcap, bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 + std::optional &scheduler_metadata_, // (b + 1) int num_splits, std::optional pack_gqa_, int const sm_margin @@ -814,21 +937,24 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq bool const scheduler_needs_semaphore = params.arch >= 90 ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1)); - if (scheduler_needs_semaphore || use_dynamic_split) { // This needs to be set before get_num_splits - tile_count_semaphore = torch::empty({int(scheduler_needs_semaphore) + int(use_dynamic_split) * batch_size}, opts.dtype(torch::kInt32)); - if (scheduler_needs_semaphore) { - if (!use_dynamic_split) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing - params.tile_count_semaphore = tile_count_semaphore.data_ptr(); + if (scheduler_needs_semaphore || use_dynamic_split) { + int metadata_size = int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b; + params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value(); + if (scheduler_metadata_.has_value()) { + at::Tensor scheduler_metadata = scheduler_metadata_.value(); + CHECK_DEVICE(scheduler_metadata); + CHECK_SHAPE(scheduler_metadata, metadata_size); + CHECK_CONTIGUOUS(scheduler_metadata); + TORCH_CHECK(scheduler_metadata.dtype() == torch::kInt32, "scheduler_metadata must have dtype int32"); + tile_count_semaphore = scheduler_metadata; } else { - params.tile_count_semaphore = nullptr; + tile_count_semaphore = torch::empty({metadata_size}, opts.dtype(torch::kInt32)); } - if (use_dynamic_split) { - // params.num_m_blocks_ptr = num_m_n_blocks_splits.data_ptr(); - // params.num_n_blocks_ptr = num_m_n_blocks_splits.data_ptr() + batch_size; - params.num_splits_dynamic_ptr = tile_count_semaphore.data_ptr() + 1; - } else { - params.num_splits_dynamic_ptr = nullptr; + if (scheduler_needs_semaphore && !use_dynamic_split) { + tile_count_semaphore.zero_(); // If varlen we'll manually do the zero-ing } + params.tile_count_semaphore = scheduler_needs_semaphore ? tile_count_semaphore.data_ptr() : nullptr; + params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + 1 : nullptr; } if (q_v_.has_value()) { @@ -1449,4 +1575,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fwd", &mha_fwd, "Forward pass"); m.def("bwd", &mha_bwd, "Backward pass"); m.def("fwd_combine", &mha_combine, "Combine partial attention outputs"); + m.def("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata, "Get scheduler metadata for varlen forward pass"); } diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 469266e52..92b84096f 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -44,6 +44,7 @@ def _flash_attn_forward( window_size=(-1, -1), softcap=0.0, rotary_interleaved=True, + scheduler_metadata=None, num_splits=1, pack_gqa=None, sm_margin=0): @@ -86,11 +87,12 @@ def _flash_attn_forward( window_size[1], softcap, rotary_interleaved, + scheduler_metadata, num_splits, pack_gqa, sm_margin, ) - return (out, softmax_lse, *rest) + return out, softmax_lse, *rest def _flash_attn_backward( @@ -608,6 +610,7 @@ def flash_attn_with_kvcache( window_size=(-1, -1), # -1 means infinite context window softcap=0.0, # 0.0 means deactivated rotary_interleaved=True, + scheduler_metadata=None, num_splits=0, # Can be tuned for speed pack_gqa=None, # Can be tuned for speed sm_margin=0, # Can be tuned if some SMs are used for communication @@ -733,9 +736,51 @@ def flash_attn_with_kvcache( window_size=window_size, softcap=softcap, rotary_interleaved=rotary_interleaved, + scheduler_metadata=scheduler_metadata, num_splits=num_splits, pack_gqa=pack_gqa, sm_margin=sm_margin, ) # return (out, softmax_lse) if return_softmax_lse else out return (out, softmax_lse, *rest) if return_softmax_lse else out + + +def get_scheduler_metadata( + batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, + cache_seqlens: torch.Tensor, + qkv_dtype=torch.bfloat16, + headdim_v=None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_size: Optional[int] = None, + max_seqlen_k_new=0, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + has_softcap=False, + num_splits=0, # Can be tuned for speed + pack_gqa=None, # Can be tuned for speed + sm_margin=0, # Can be tuned if some SMs are used for communication +): + cache_seqlens = maybe_contiguous(cache_seqlens) + if headdim_v is None: + headdim_v = headdim + scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata( + batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v, + qkv_dtype, + cache_seqlens, + cu_seqlens_q, + None, # cu_seqlens_k + cu_seqlens_k_new, + None, # seqused_q + cache_leftpad, + page_size, + max_seqlen_k_new, + causal, + window_size[0], window_size[1], + has_softcap, + num_splits, + pack_gqa, + sm_margin, + ) + return scheduler_metadata diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index fe54bd1c0..006920493 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -155,8 +155,8 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { params.num_splits_dynamic_ptr, }; - if (Varlen && params.num_splits_dynamic_ptr) { - prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN); + if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) { + prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 /*enable_pdl*/); CHECK_CUDA_KERNEL_LAUNCH(); } @@ -188,7 +188,8 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); } // kernel<<>>(kernel_params); - cutlass::kernel_launch(grid_dims, block_dims, smem_size, stream, kernel_params, Arch >= 90 && Varlen /*launch_with_pdl*/); + cutlass::kernel_launch(grid_dims, block_dims, smem_size, stream, kernel_params, + Arch >= 90 && Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation /*launch_with_pdl*/); } CHECK_CUDA_KERNEL_LAUNCH(); } diff --git a/hopper/flash_prepare_scheduler.cu b/hopper/flash_prepare_scheduler.cu index d1b2a4f2a..7093fff32 100644 --- a/hopper/flash_prepare_scheduler.cu +++ b/hopper/flash_prepare_scheduler.cu @@ -20,7 +20,8 @@ __global__ void prepare_varlen_num_blocks_kernel( cutlass::FastDivmod blockm_divmod, cutlass::FastDivmod blockn_divmod, int* const tile_count_semaphore, // int* const num_m_blocks_ptr, - int* const num_splits_dynamic_ptr) { + int* const num_splits_dynamic_ptr, + bool enable_pdl) { static constexpr int kNumBatchPerWarp = cutlass::NumThreadsPerWarp - 1; static constexpr int kSmemSize = 1; @@ -28,7 +29,7 @@ __global__ void prepare_varlen_num_blocks_kernel( __shared__ int total_blocks_smem[kSmemSize]; // There's only 1 block in the grid, so might as well start launching the main attn kernel - cutlass::arch::launch_dependent_grids(); + if (enable_pdl) { cutlass::arch::launch_dependent_grids(); } if (threadIdx.x < kSmemSize) { total_blocks_smem[threadIdx.x] = 0; } __syncthreads(); @@ -108,7 +109,7 @@ __global__ void prepare_varlen_num_blocks_kernel( } // flash void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, - int blockM, int blockN) { + int blockM, int blockN, bool enable_pdl) { // Only support batch <= 992 (32 warps, each with 31 batches) int qhead_per_khead = !packgqa ? 1 : cutlass::ceil_div(params.h, params.h_k); flash::prepare_varlen_num_blocks_kernel<<<1 /*grid*/, 1024 /*block*/, 0, stream>>>( @@ -119,5 +120,5 @@ void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bo cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN), params.tile_count_semaphore, // params.num_m_blocks_ptr, - params.num_splits_dynamic_ptr); + params.num_splits_dynamic_ptr, enable_pdl); } diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 3098d4e30..a29ec8e9a 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -19,7 +19,8 @@ generate_random_padding_mask, ) -from flash_attn_interface import flash_attn_func, flash_attn_varlen_func, flash_attn_combine, flash_attn_with_kvcache +from flash_attn_interface import flash_attn_func, flash_attn_varlen_func, flash_attn_combine +from flash_attn_interface import flash_attn_with_kvcache, get_scheduler_metadata DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE" @@ -825,13 +826,25 @@ def test_flash_attn_kvcache( k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone() v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone() num_splits_vals = [1, 0] if not DISABLE_SPLIT else [1] - for num_splits in num_splits_vals: + precompute_metadata_vals = [False, True] + for num_splits, precompute_metadata in itertools.product(num_splits_vals, precompute_metadata_vals): if page_size is None: k_cache.copy_(k_cache_saved) v_cache.copy_(v_cache_saved) else: k_cache_paged.copy_(k_cache_saved) v_cache_paged.copy_(v_cache_saved) + if precompute_metadata: + scheduler_metadata = get_scheduler_metadata( + batch_size, seqlen_q, seqlen_k, nheads, nheads_k, d, + cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad, + max_seqlen_k_new=seqlen_new, page_size=page_size, + causal=causal, window_size=window_size, + num_splits=num_splits + ) + else: + scheduler_metadata = None out, lse, *rest = flash_attn_with_kvcache( q if not varlen_q else q_unpad, k_cache if page_size is None else k_cache_paged, @@ -851,6 +864,7 @@ def test_flash_attn_kvcache( causal=causal, window_size=window_size, rotary_interleaved=rotary_interleaved, + scheduler_metadata=scheduler_metadata, num_splits=num_splits, return_softmax_lse=True ) diff --git a/hopper/utils.h b/hopper/utils.h index d9468af55..3f76ea66e 100644 --- a/hopper/utils.h +++ b/hopper/utils.h @@ -21,17 +21,7 @@ #include #include - -#define CHECK_CUDA(call) \ - do { \ - cudaError_t status_ = call; \ - if (status_ != cudaSuccess) { \ - fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ - exit(1); \ - } \ - } while(0) - -#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError()) +#include "cuda_check.h" namespace flash { From 6c87fac478de8ba7d6d43cc064b3bd0f701ae6eb Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 15 Mar 2025 16:43:06 -0400 Subject: [PATCH 078/102] Update MLA decode benchmark to use get_scheduler_metadata --- hopper/benchmark_mla_decode.py | 54 +++++++++++++++++++++------------- 1 file changed, 34 insertions(+), 20 deletions(-) diff --git a/hopper/benchmark_mla_decode.py b/hopper/benchmark_mla_decode.py index eabf6efa0..9b7c05708 100644 --- a/hopper/benchmark_mla_decode.py +++ b/hopper/benchmark_mla_decode.py @@ -14,7 +14,7 @@ from einops import rearrange -from flash_attn_interface import flash_attn_with_kvcache +from flash_attn_interface import flash_attn_with_kvcache, get_scheduler_metadata try: from flash_mla import flash_mla_with_kvcache, get_mla_metadata @@ -27,22 +27,25 @@ pytorch_profiler = None +device = "cuda" +dtype = torch.bfloat16 +seqlen = 8192 +seqlen_q = 1 +# nheads_q = 16 +nheads_q = 128 + +use_bench_cudagraph = False + attn_variants = ["mha", "gqa", "mqa", "mla"] -# attn_variant = attn_variants[3] for attn_variant in attn_variants: - device = "cuda" - dtype = torch.bfloat16 - seqlen = 8192 - nheads_q = 128 +# for attn_variant in attn_variants[3:]: nheads_kv = nheads_q if attn_variant == "mha" else (max(nheads_q // 8, 1) if attn_variant == "gqa" else 1) headdim = 64 if attn_variant == "mla" else 128 headdim_v = 512 if attn_variant == "mla" else headdim has_qv = headdim == 64 and headdim_v == 512 - seqlen_q = 1 # page_size = None page_size = 64 if attn_variant == "mla" else 128 - use_bench_cudagraph = False should_run_flashmla = attn_variant == "mla" and page_size == 64 and flash_mla_with_kvcache is not None torch.manual_seed(0) @@ -57,7 +60,7 @@ print(f"\n{attn_variant.upper()}, nheads_q = {nheads_q}, nheads_kv = {nheads_kv}, headdim = {headdim}, headdim_v = {headdim_v}, page_size = {page_size}") for seqlen in [s * 1024 for s in [1, 2, 4, 8, 16, 32, 64]]: - # for seqlen in [s * 1024 for s in [8]]: + # for seqlen in [s * 1024 for s in [1]]: cache_seqlens = torch.tensor([seqlen] * batch_size, device=device, dtype=torch.int) num_splits = 0 q = torch.randn(batch_size, seqlen_q, nheads_q, headdim, dtype=dtype, device=device) @@ -75,27 +78,35 @@ continue qv = torch.randn(batch_size, seqlen_q, nheads_q, headdim_v, dtype=dtype, device=device) if has_qv else None - # Time in ms - fn = lambda: flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, num_splits=num_splits, qv=qv, page_table=page_table, causal=True) + # Precomputing this saves ~2us + scheduler_metadata = get_scheduler_metadata( + batch_size, seqlen_q, seqlen, nheads_q, nheads_kv, headdim, + cache_seqlens, q.dtype, headdim_v=headdim_v, page_size=page_size, causal=True + ) + # scheduler_metadata = None + fn0 = lambda: flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, num_splits=num_splits, qv=qv, page_table=page_table, causal=True, scheduler_metadata=scheduler_metadata) time.sleep(1) # to avoid power throttling + # Time in ms if not use_bench_cudagraph: - t0 = do_bench(fn, warmup=1, rep=10) + t0 = do_bench(fn0, warmup=1, rep=10) else: + torch.cuda.synchronize() # Gotta wait, otherwise e.g. k_cache might not be ready with torch.cuda.stream(torch.cuda.Stream()): - t0 = do_bench_cudagraph(fn, rep=10) + t0 = do_bench_cudagraph(fn0, rep=10) # exit(0) if should_run_flashmla: # Separate out the preprocessing since this can be done once and reused for all layers - scheduler_metadata = get_mla_metadata(cache_seqlens, seqlen_q * nheads_q // nheads_kv, nheads_kv) + mla_metadata = get_mla_metadata(cache_seqlens, seqlen_q * nheads_q // nheads_kv, nheads_kv) q_concat = torch.concat([q, qv], dim=-1) if has_qv else q kv_cache_concat = torch.concat([v_cache, k_cache], dim=-1) - fn = lambda: flash_mla_with_kvcache(q_concat, kv_cache_concat, page_table, cache_seqlens, headdim_v, *scheduler_metadata, causal=True) + fn1 = lambda: flash_mla_with_kvcache(q_concat, kv_cache_concat, page_table, cache_seqlens, headdim_v, *mla_metadata, causal=True) time.sleep(1) # to avoid power throttling if not use_bench_cudagraph: - t1 = do_bench(fn, warmup=1, rep=10) + t1 = do_bench(fn1, warmup=1, rep=10) else: + torch.cuda.synchronize() # Gotta wait, otherwise e.g. k_cache might not be ready with torch.cuda.stream(torch.cuda.Stream()): - t1 = do_bench_cudagraph(fn, rep=10) + t1 = do_bench_cudagraph(fn1, rep=10) total_seqlen = seqlen * batch_size if cache_seqlens is None else cache_seqlens.sum().item() mem_io = total_seqlen * nheads_kv * (headdim + headdim_v) * 2 + q.numel() * 2 + (qv.numel() * 2 if has_qv else 0) + q.numel() * headdim_v // headdim * 2 # last time is for the output @@ -103,12 +114,15 @@ ideal_h100_time_mem = mem_io / 3.35e12 * 1e6 ideal_h100_time_flop = flops / 989e12 * 1e6 ideal_h100_time = max(ideal_h100_time_mem, ideal_h100_time_flop) - print(f"Seqlen = {seqlen}, FA3 time{'' if not use_bench_cudagraph else ' w CUDA Graph'}: {t0 * 1e3:.0f} us, {mem_io * 1e-9 / (t0 * 1e-3):.0f} GB/s, {flops * 1e-12 / (t0 * 1e-3):.0f} TFLOPS/s") + print(f"Seqlen = {seqlen}, FA3 time{'' if not use_bench_cudagraph else ' w CUDA Graph'}: {t0 * 1e3:.1f} us, {mem_io * 1e-9 / (t0 * 1e-3):.0f} GB/s, {flops * 1e-12 / (t0 * 1e-3):.0f} TFLOPS/s") if should_run_flashmla: - print(f"Seqlen = {seqlen}, FlashMLA time{'' if not use_bench_cudagraph else ' w CUDA Graph'}: {t1 * 1e3:.0f} us, {mem_io * 1e-9 / (t1 * 1e-3):.0f} GB/s, {flops * 1e-12 / (t1 * 1e-3):.0f} TFLOPS/s") + print(f"Seqlen = {seqlen}, FlashMLA time{'' if not use_bench_cudagraph else ' w CUDA Graph'}: {t1 * 1e3:.1f} us, {mem_io * 1e-9 / (t1 * 1e-3):.0f} GB/s, {flops * 1e-12 / (t1 * 1e-3):.0f} TFLOPS/s") print(f"Arithmetic intensity: {flops / mem_io:.1f}") print(f"Ideal time: {ideal_h100_time:.0f} us") # if pytorch_profiler is not None: # time.sleep(1) # to avoid power throttling - # pytorch_profiler(fn) + # pytorch_profiler(fn0) + # if should_run_flashmla: + # time.sleep(1) # to avoid power throttling + # pytorch_profiler(fn1) From 4b5eeab1222ab8faab3024f408e90d1f6563eae1 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 15 Mar 2025 17:15:34 -0400 Subject: [PATCH 079/102] Fix FP8 test to quantize KV cache for reference impl as well --- hopper/test_flash_attn.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index a29ec8e9a..be27f14f6 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -695,7 +695,7 @@ def test_flash_attn_kvcache( v_cache_paged, num_blocks, ) = _generate_block_kvcache( - seqlen_k, page_size, batch_size_cache, nheads_k, d, dv, device, dtype_ref + seqlen_k, page_size, batch_size_cache, nheads_k, d, dv, device, dtype, dtype_ref ) cache_seqlens = torch.randint( 0 if new_kv else 1, @@ -930,14 +930,14 @@ def test_flash_attn_kvcache( assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item() -def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype): +def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref): num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 k_cache_paged = torch.randn( - num_blocks, page_size, nheads_k, d, device=device, dtype=dtype - ) + num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref + ).to(dtype).to(dtype_ref) v_cache_paged = torch.randn( - num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype - ) + num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref + ).to(dtype).to(dtype_ref) page_table = rearrange( torch.randperm(num_blocks, dtype=torch.int32, device=device), "(b nblocks) -> b nblocks", From 27f501dbe011f4371bff938fe7e09311ab3002fa Mon Sep 17 00:00:00 2001 From: schung-amd Date: Sat, 15 Mar 2025 19:23:11 -0400 Subject: [PATCH 080/102] Dynamic autotune configs for devices with warp size != 32 (#1534) Generate a list of autotune configs based on device warp size to avoid triton error if maximum threads per block is exceeded. --- flash_attn/ops/triton/layer_norm.py | 31 ++++++++++++++--------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/flash_attn/ops/triton/layer_norm.py b/flash_attn/ops/triton/layer_norm.py index addffe1f1..0d122aa08 100644 --- a/flash_attn/ops/triton/layer_norm.py +++ b/flash_attn/ops/triton/layer_norm.py @@ -15,6 +15,19 @@ import triton import triton.language as tl +def triton_autotune_configs(): + # Return configs with a valid warp count for the current device + configs=[] + # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024 + max_threads_per_block=1024 + # Default to warp size 32 if not defined by device + warp_size=getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32) + # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit + warp_count=1 + while warp_count*warp_size <= max_threads_per_block: + configs.append(triton.Config({}, num_warps=warp_count)) + warp_count*=2 + return configs def layer_norm_ref( x, @@ -126,14 +139,7 @@ def rms_norm_ref( @triton.autotune( - configs=[ - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - triton.Config({}, num_warps=16), - triton.Config({}, num_warps=32), - ], + configs=triton_autotune_configs(), key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], ) # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) @@ -393,14 +399,7 @@ def _layer_norm_fwd( @triton.autotune( - configs=[ - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - triton.Config({}, num_warps=16), - triton.Config({}, num_warps=32), - ], + configs=triton_autotune_configs(), key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"], ) # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) From 7ae5f8c8fe0c518ec0039352c07118c83bd33f1f Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 21 Mar 2025 01:51:04 -0700 Subject: [PATCH 081/102] Add option for rotary_seqlens --- hopper/flash.h | 1 + hopper/flash_api.cpp | 8 ++++++++ hopper/flash_attn_interface.py | 9 +++++++-- hopper/flash_fwd_kernel_sm80.h | 1 + hopper/flash_fwd_kernel_sm90.h | 2 ++ hopper/flash_fwd_launch_template.h | 2 +- hopper/mainloop_fwd_sm80.hpp | 12 +++++++----- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 12 +++++++----- hopper/seqlen.h | 6 ++++-- hopper/setup.py | 3 ++- hopper/test_flash_attn.py | 14 ++++++++++---- 11 files changed, 50 insertions(+), 20 deletions(-) diff --git a/hopper/flash.h b/hopper/flash.h index 69562d488..91fb5c812 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -112,6 +112,7 @@ struct Flash_fwd_params : public Qkv_params { // The cos and sin matrices for rotary embedding. void * __restrict__ rotary_cos_ptr; void * __restrict__ rotary_sin_ptr; + int *__restrict__ seqlens_rotary; // The indices to index into the KV cache. int * __restrict__ kv_batch_idx; diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 0251c6c4e..c79869483 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -641,6 +641,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq std::optional &leftpad_k_, // b std::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) std::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) + std::optional &seqlens_rotary_, // b std::optional &q_descale_, // (b, h_k), not (b, h) std::optional &k_descale_, // (b, h_k) std::optional &v_descale_, // (b, h_k) @@ -1002,6 +1003,13 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq params.rotary_cos_ptr = rotary_cos.data_ptr(); params.rotary_sin_ptr = rotary_sin.data_ptr(); params.is_rotary_interleaved = is_rotary_interleaved; + if (seqlens_rotary_.has_value()) { + at::Tensor seqlens_rotary = seqlens_rotary_.value(); + CHECK_DEVICE(seqlens_rotary); CHECK_CONTIGUOUS(seqlens_rotary); + TORCH_CHECK(seqlens_rotary.dtype() == torch::kInt32, "seqlens_rotary must have dtype torch.int32"); + CHECK_SHAPE(seqlens_rotary, batch_size); + params.seqlens_rotary = seqlens_rotary.data_ptr(); + } } else { params.rotary_dim = 0; } diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 92b84096f..59a5517ce 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -36,6 +36,7 @@ def _flash_attn_forward( leftpad_k, rotary_cos, rotary_sin, + seqlens_rotary, q_descale, k_descale, v_descale, @@ -58,6 +59,7 @@ def _flash_attn_forward( maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k) ] rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)] + seqlens_rotary = maybe_contiguous(seqlens_rotary) out, softmax_lse, *rest = flash_attn_3_cuda.fwd( q, k, @@ -78,6 +80,7 @@ def _flash_attn_forward( leftpad_k, rotary_cos, rotary_sin, + seqlens_rotary, q_descale, k_descale, v_descale, @@ -257,7 +260,7 @@ def forward( None, None, # seqused_q/k None, None, # max_seqlen_q/k None, None, None, # page_table, kv_batch_idx, leftpad_k, - None, None, # rotary_cos/sin + None, None, None, # rotary_cos/sin, seqlens_rotary q_descale, k_descale, v_descale, softmax_scale, causal=causal, @@ -350,7 +353,7 @@ def forward( max_seqlen_q, max_seqlen_k, None, None, None, # page_table, kv_batch_idx, leftpad_k, - None, None, # rotary_cos/sin + None, None, None, # rotary_cos/sin, seqlens_rotary q_descale, k_descale, v_descale, softmax_scale, causal=causal, @@ -602,6 +605,7 @@ def flash_attn_with_kvcache( cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k_new: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, + rotary_seqlens: Optional[torch.Tensor] = None, q_descale: Optional[torch.Tensor] = None, k_descale: Optional[torch.Tensor] = None, v_descale: Optional[torch.Tensor] = None, @@ -730,6 +734,7 @@ def flash_attn_with_kvcache( cache_leftpad, rotary_cos, rotary_sin, + rotary_seqlens, q_descale, k_descale, v_descale, softmax_scale, causal=causal, diff --git a/hopper/flash_fwd_kernel_sm80.h b/hopper/flash_fwd_kernel_sm80.h index 4c35da4f0..b308d2d1b 100644 --- a/hopper/flash_fwd_kernel_sm80.h +++ b/hopper/flash_fwd_kernel_sm80.h @@ -187,6 +187,7 @@ class FlashAttnFwdSm80 { get<0>(params.mainloop.shape_K_new), params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, + params.mainloop.seqlens_rotary }; if constexpr (AppendKV) { bool tile_new_valid = mainloop.store_kv_new( diff --git a/hopper/flash_fwd_kernel_sm90.h b/hopper/flash_fwd_kernel_sm90.h index 962283fe2..47b3817cd 100644 --- a/hopper/flash_fwd_kernel_sm90.h +++ b/hopper/flash_fwd_kernel_sm90.h @@ -337,6 +337,7 @@ class FlashAttnFwdSm90 { get<0>(params.mainloop.shape_K_new), params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, + params.mainloop.seqlens_rotary }; if constexpr (AppendKV) { bool tile_new_valid = mainloop.load_kv_new( @@ -385,6 +386,7 @@ class FlashAttnFwdSm90 { get<0>(params.mainloop.shape_K_new), params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, + params.mainloop.seqlens_rotary }; if constexpr (AppendKV) { bool tile_new_valid = mainloop.store_kv_new( diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 006920493..452fd61b7 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -126,7 +126,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { params.kv_batch_idx, params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, params.seqused_q, params.seqused_k, - params.leftpad_k, + params.leftpad_k, params.seqlens_rotary }; typename CollectiveEpilogue::Arguments epilogue_args { static_cast(params.o_ptr), diff --git a/hopper/mainloop_fwd_sm80.hpp b/hopper/mainloop_fwd_sm80.hpp index a642fc74f..905be872d 100644 --- a/hopper/mainloop_fwd_sm80.hpp +++ b/hopper/mainloop_fwd_sm80.hpp @@ -212,6 +212,7 @@ struct CollectiveMainloopFwdSm80 { int const* const seqused_q = nullptr; int const* const seqused_k = nullptr; int const* const leftpad_k = nullptr; + int const* const seqlens_rotary = nullptr; }; // Device side kernel params @@ -256,6 +257,7 @@ struct CollectiveMainloopFwdSm80 { int const* const seqused_q = nullptr; int const* const seqused_k = nullptr; int const* const leftpad_k = nullptr; + int const* const seqlens_rotary = nullptr; }; static Params @@ -295,7 +297,7 @@ struct CollectiveMainloopFwdSm80 { !Split ? 1 : args.num_splits, args.kv_batch_idx, args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, - args.seqused_q, args.seqused_k, args.leftpad_k}; + args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary}; } template @@ -472,11 +474,11 @@ struct CollectiveMainloopFwdSm80 { flash::cp_async_wait(); } else { if (get<1>(params.shape_rotary) > 0) { // Apply rotary to Q - int const offset_rotary = seqlen_info.seqlen_k_og + seqlen_info.leftpad_k; using Rotary_t = Rotary; Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, params.ptr_rotary_sin, params.stride_rotary_sin, - params.is_rotary_interleaved, thread_idx, seqlen_q, offset_rotary); + params.is_rotary_interleaved, thread_idx, seqlen_q, + seqlen_info.seqlen_rotary); int const qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor; if (params.is_rotary_interleaved) { auto [tRrCos, tRrSin] = cute::conditional_return( @@ -689,12 +691,12 @@ struct CollectiveMainloopFwdSm80 { static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); - int const offset_rotary = seqlen_info.seqlen_k_og + seqlen_info.leftpad_k; int const seqlen_k_new = seqlen_info.seqlen_k_new; using Rotary_t = Rotary; Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, params.ptr_rotary_sin, params.stride_rotary_sin, - params.is_rotary_interleaved, thread_idx, seqlen_k_new, offset_rotary); + params.is_rotary_interleaved, thread_idx, seqlen_k_new, + seqlen_info.seqlen_rotary); using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>; PagedKVManager_t paged_kv_manager( diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 6a21078f7..65d447da0 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -395,6 +395,7 @@ struct CollectiveMainloopFwdSm90 { int const* const seqused_q = nullptr; int const* const seqused_k = nullptr; int const* const leftpad_k = nullptr; + int const* const seqlens_rotary = nullptr; }; // Device side kernel params @@ -450,6 +451,7 @@ struct CollectiveMainloopFwdSm90 { int const* const seqused_q = nullptr; int const* const seqused_k = nullptr; int const* const leftpad_k = nullptr; + int const *const seqlens_rotary = nullptr; }; static Params @@ -558,7 +560,7 @@ struct CollectiveMainloopFwdSm90 { !Split ? 1 : args.num_splits, args.kv_batch_idx, args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, - args.seqused_q, args.seqused_k, args.leftpad_k}; + args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary}; } /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance @@ -1087,11 +1089,11 @@ struct CollectiveMainloopFwdSm90 { barrier_Q.wait(work_idx % 2); } else { if (get<1>(params.shape_rotary) > 0) { // Apply rotary to Q - int const offset_rotary = seqlen_info.seqlen_k_og + seqlen_info.leftpad_k; using Rotary_t = Rotary; Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, params.ptr_rotary_sin, params.stride_rotary_sin, - params.is_rotary_interleaved, thread_idx, seqlen_q, offset_rotary); + params.is_rotary_interleaved, thread_idx, seqlen_q, + seqlen_info.seqlen_rotary); Tensor sQ_pi = cute::as_position_independent_swizzle_tensor(sQ); int const qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor; if (params.is_rotary_interleaved) { @@ -1579,12 +1581,12 @@ struct CollectiveMainloopFwdSm90 { static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); - int const offset_rotary = seqlen_info.seqlen_k_og + seqlen_info.leftpad_k; int const seqlen_k_new = seqlen_info.seqlen_k_new; using Rotary_t = Rotary; Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, params.ptr_rotary_sin, params.stride_rotary_sin, - params.is_rotary_interleaved, thread_idx, seqlen_k_new, offset_rotary); + params.is_rotary_interleaved, thread_idx, seqlen_k_new, + seqlen_info.seqlen_rotary); // This is used to index into the batch dimension of mK and mV int const bidb_kv_idx = !is_varlen_k && !params.ptr_pagetable ? bidb_kv : 0; diff --git a/hopper/seqlen.h b/hopper/seqlen.h index 21a747128..5547238b3 100644 --- a/hopper/seqlen.h +++ b/hopper/seqlen.h @@ -64,12 +64,13 @@ struct SeqlenInfoQKNewK { int const leftpad_k; int const offset_q, offset_k, offset_k_new; - int const seqlen_q, seqlen_k_og, seqlen_k_new, seqlen_k; + int const seqlen_q, seqlen_k_og, seqlen_k_new, seqlen_k, seqlen_rotary; CUTLASS_DEVICE SeqlenInfoQKNewK(int const bidb, int const seqlen_q_static, int const seqlen_k_static, int const shape_K_new_0, int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new, - int const* const seqused_q, int const* const seqused_k, int const* const ptr_leftpad_k + int const* const seqused_q, int const* const seqused_k, int const* const ptr_leftpad_k, + int const* const seqlens_rotary ) : leftpad_k(ptr_leftpad_k ? ptr_leftpad_k[bidb] : 0) , offset_q(!Varlen || cu_seqlens_q == nullptr ? 0 : cu_seqlens_q[bidb]) @@ -85,6 +86,7 @@ struct SeqlenInfoQKNewK { ? 0 : (cu_seqlens_k_new ? cu_seqlens_k_new[bidb + 1] - cu_seqlens_k_new[bidb] : shape_K_new_0)) , seqlen_k(!AppendKV ? seqlen_k_og : seqlen_k_og + seqlen_k_new) + , seqlen_rotary(!AppendKV || !seqlens_rotary ? seqlen_k_og + leftpad_k : seqlens_rotary[bidb]) { } diff --git a/hopper/setup.py b/hopper/setup.py index f87d809eb..d9f4bad4c 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -377,6 +377,7 @@ def nvcc_threads_args(): # NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.3.107"} NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.6.85", "ptxas": "12.8.93"} + exe_extension = sysconfig.get_config_var("EXE") @@ -518,7 +519,7 @@ def nvcc_threads_args(): # "--ptxas-options=--verbose,--register-usage-level=5,--warn-on-local-memory-usage", # printing out number of registers "--resource-usage", # printing out number of registers # f"--split-compile={os.getenv('NVCC_THREADS', '4')}", # split-compile is faster - "-lineinfo", + "-lineinfo", # TODO: disable this for release to reduce binary size "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", # Necessary for the WGMMA shapes that we use "-DCUTLASS_ENABLE_GDC_FOR_SM90", # For PDL "-DCUTLASS_DEBUG_TRACE_LEVEL=0", # Can toggle for debugging diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index be27f14f6..fb014f719 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -564,12 +564,13 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # @pytest.mark.parametrize("mha_type", ["mha"]) @pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else [])) # @pytest.mark.parametrize("new_kv", [True]) -# @pytest.mark.parametrize("local", [False, True]) @pytest.mark.parametrize("causal,local", [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else [])) # @pytest.mark.parametrize("causal,local", [(False, False), (True, False)]) # @pytest.mark.parametrize("causal,local", [(False, False)]) @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True]) # @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) +@pytest.mark.parametrize("has_rotary_seqlens", [False, True]) +# @pytest.mark.parametrize("has_rotary_seqlens", [False]) @pytest.mark.parametrize("rotary_interleaved", [False, True] if not DISABLE_APPENDKV else [False]) # @pytest.mark.parametrize("rotary_interleaved", [True]) @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0] if (not DISABLE_APPENDKV) and (apply_rotary_emb is not None) else [0.0]) @@ -617,6 +618,7 @@ def test_flash_attn_kvcache( page_size, rotary_fraction, rotary_interleaved, + has_rotary_seqlens, seqlen_new_eq_seqlen_q, causal, local, @@ -630,6 +632,8 @@ def test_flash_attn_kvcache( pytest.skip() if not new_kv and rotary_fraction > 0.0: pytest.skip() + if rotary_fraction == 0.0 and has_rotary_seqlens: + pytest.skip() device = "cuda" # set seed torch.random.manual_seed(0) @@ -733,6 +737,7 @@ def test_flash_attn_kvcache( key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) ) # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) + rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2 if rotary_dim > 0: angle = ( torch.rand( @@ -747,7 +752,7 @@ def test_flash_attn_kvcache( sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) if causal or local: q_ro = apply_rotary_emb( - q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved + q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved ) else: q_ro = rearrange( @@ -755,7 +760,7 @@ def test_flash_attn_kvcache( rearrange(q, "b s h d -> b 1 (s h) d"), cos, sin, - seqlen_offsets=cache_seqlens, + seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved, ), "b 1 (s h) d -> b s h d", @@ -763,7 +768,7 @@ def test_flash_attn_kvcache( ) # q_ro = q k_ro = apply_rotary_emb( - k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved + k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved ) else: cos, sin = None, None @@ -861,6 +866,7 @@ def test_flash_attn_kvcache( cu_seqlens_q=cu_seqlens_q, cu_seqlens_k_new=cu_seqlens_k_new, max_seqlen_q=max_seqlen_q, + rotary_seqlens=rotary_seqlens, causal=causal, window_size=window_size, rotary_interleaved=rotary_interleaved, From fef4fcf2b0391aac7a7af486b6a870723d1e3a0a Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 21 Mar 2025 22:12:10 -0400 Subject: [PATCH 082/102] Use StreamkBarrier0/1 barriers instead of TileCountSmemEmpty/Full --- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 2 +- hopper/named_barrier.hpp | 32 ++++++++++-------------- hopper/tile_scheduler.hpp | 20 +++++++-------- 3 files changed, 24 insertions(+), 30 deletions(-) diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 65d447da0..b72906941 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -780,7 +780,7 @@ struct CollectiveMainloopFwdSm90 { pipeline_v.producer_commit(smem_pipe_write); // Very important: PipelineTmaAsync::consumer_release assumes that the warpgroup is synchronized // before calling. Without this we get race conditions. - cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::ProducerWG) /*id*/); + cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, cutlass::arch::ReservedNamedBarriers::TransposeBarrier /*id*/); pipeline_vt.consumer_release(smem_pipe_read); }; diff --git a/hopper/named_barrier.hpp b/hopper/named_barrier.hpp index 8d07f6aa2..a7dfb6439 100644 --- a/hopper/named_barrier.hpp +++ b/hopper/named_barrier.hpp @@ -49,30 +49,24 @@ static void named_barrier_arrive(uint32_t num_threads, cutlass::arch::ReservedNa enum class FwdNamedBarriers { QueryEmpty = 0, - ProducerWG = 1, - TileCountSmemEmpty = 2, - TileCountSmemFull = 3, - WarpSchedulerWG1 = 4, - WarpSchedulerWG2 = 5, - WarpSchedulerWG3 = 6, - AppendKV = 7, - QueryRotated = 8, - PFull = 9, - PEmpty = 6, // HACK: PEmpty is only used when we don't have 3 WGs + WarpSchedulerWG1 = 1, + WarpSchedulerWG2 = 2, + WarpSchedulerWG3 = 3, + AppendKV = 4, + QueryRotated = 5, + PFull = 6, + PEmpty = 7, }; enum class BwdNamedBarriers { KVEmpty = 0, PdS = 1, - // This needs to match FwdNamedBarriers::TileCountSmemEmpty since TileScheduler uses it - TileCountSmemEmpty = 2, - TileCountSmemFull = 3, - dQEmptyWG1 = 4, - dQEmptyWG2 = 5, - dQEmptyWG3 = 6, - dQFullWG1 = 7, - dQFullWG2 = 8, - dQFullWG3 = 9, + dQEmptyWG1 = 2, + dQEmptyWG2 = 3, + dQEmptyWG3 = 4, + dQFullWG1 = 5, + dQFullWG2 = 6, + dQFullWG3 = 7, }; } // flash diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index f71324272..344a5c03d 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -320,7 +320,7 @@ class DynamicPersistentTileScheduler { void init_consumer() const { if (WarpSpecialized || cutlass::canonical_warp_idx_sync() > 0) { - flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty } } @@ -339,16 +339,16 @@ class DynamicPersistentTileScheduler { if constexpr (IsProducerWarp) { // thread 0 already has the right tile_idx, just need to broadcast to the rest of warp 0 int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/); - flash::named_barrier_sync(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty if (threadIdx.x % NumProducerThreads == 0) { *tile_count_smem = current_work.tile_idx; } - flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull return {new_tile_idx}; } else { - flash::named_barrier_sync(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull int tile_idx = *tile_count_smem; - flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty return {tile_idx}; } } @@ -550,7 +550,7 @@ class VarlenDynamicPersistentTileScheduler { if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb); } - flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull return work_info; } else { return get_next_work(params, {0, 0, 0, 0}); @@ -580,16 +580,16 @@ class VarlenDynamicPersistentTileScheduler { int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/); WorkTileInfo work_info = {__shfl_sync(0xffffffff, current_work.tile_idx, 1 /*lane*/), current_work.block, current_work.bidh, current_work.bidb}; work_info = tile_idx_to_work_tile(params, new_tile_idx, work_info); - flash::named_barrier_sync(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb); } - flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull return work_info; } else { - flash::named_barrier_sync(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull int4 work_info = *work_info_smem; - flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty return WorkTileInfo{work_info.x, work_info.y, work_info.z, work_info.w}; } } From b1951a4e0126021657d1e2bcc05d934f9ebf90e3 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 22 Mar 2025 11:54:51 -0400 Subject: [PATCH 083/102] Update Cutlass to 3.9 --- csrc/cutlass | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cutlass b/csrc/cutlass index afa177220..62750a2b7 160000 --- a/csrc/cutlass +++ b/csrc/cutlass @@ -1 +1 @@ -Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008 +Subproject commit 62750a2b75c802660e4894434dc55e839f322277 From df11fcae2635b85e22e720ceab5d75f279c918d4 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 22 Mar 2025 16:10:08 -0400 Subject: [PATCH 084/102] Support hdim 64,256 --- hopper/flash_api.cpp | 15 ++++++++------- hopper/flash_fwd_launch_template.h | 2 +- hopper/generate_kernels.py | 1 + .../flash_fwd_hdim128_bf16_sm100.cu | 9 +++++++++ .../flash_fwd_hdim64_256_bf16_packgqa_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_256_bf16_paged_sm90.cu | 9 +++++++++ ...lash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_256_bf16_paged_split_sm90.cu | 9 +++++++++ ...wd_hdim64_256_bf16_paged_split_softcap_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_256_bf16_sm90.cu | 9 +++++++++ ...sh_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_256_bf16_softcap_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_256_bf16_split_sm90.cu | 9 +++++++++ ...lash_fwd_hdim64_256_bf16_split_softcap_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_256_fp16_packgqa_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_256_fp16_paged_sm90.cu | 9 +++++++++ ...lash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_256_fp16_paged_split_sm90.cu | 9 +++++++++ ...wd_hdim64_256_fp16_paged_split_softcap_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_256_fp16_sm90.cu | 9 +++++++++ ...sh_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_256_fp16_softcap_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_256_fp16_split_sm90.cu | 9 +++++++++ ...lash_fwd_hdim64_256_fp16_split_softcap_sm90.cu | 9 +++++++++ .../flash_fwd_hdimdiff_bf16_packgqa_sm90.cu | 1 + .../flash_fwd_hdimdiff_bf16_paged_sm90.cu | 1 + .../flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu | 1 + .../flash_fwd_hdimdiff_bf16_paged_split_sm90.cu | 1 + ..._fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu | 1 + .../flash_fwd_hdimdiff_bf16_sm90.cu | 1 + ...lash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu | 1 + .../flash_fwd_hdimdiff_bf16_softcap_sm90.cu | 1 + .../flash_fwd_hdimdiff_bf16_split_sm90.cu | 1 + .../flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu | 1 + .../flash_fwd_hdimdiff_fp16_packgqa_sm90.cu | 1 + .../flash_fwd_hdimdiff_fp16_paged_sm90.cu | 1 + .../flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu | 1 + .../flash_fwd_hdimdiff_fp16_paged_split_sm90.cu | 1 + ..._fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu | 1 + .../flash_fwd_hdimdiff_fp16_sm90.cu | 1 + ...lash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu | 1 + .../flash_fwd_hdimdiff_fp16_softcap_sm90.cu | 1 + .../flash_fwd_hdimdiff_fp16_split_sm90.cu | 1 + .../flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu | 1 + hopper/test_flash_attn.py | 8 ++++---- hopper/tile_size.h | 13 +++++++++---- 46 files changed, 232 insertions(+), 16 deletions(-) create mode 100644 hopper/instantiations/flash_fwd_hdim128_bf16_sm100.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_256_bf16_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_256_bf16_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_256_bf16_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_256_bf16_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_256_fp16_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_256_fp16_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_256_fp16_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_256_fp16_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index c79869483..58bc49da4 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -273,10 +273,11 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { if (params.is_bf16) { #ifndef FLASHATTENTION_DISABLE_HDIM64 if (params.d <= 64) { - if (params.dv > 64 && Arch == 90) { + if (params.dv > 256 && Arch == 90) { return run_mha_fwd_(params, stream); - } - else { + } else if (params.dv > 64 && Arch == 90) { + return run_mha_fwd_(params, stream); + } else { return run_mha_fwd_(params, stream); } } @@ -303,10 +304,11 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { #ifndef FLASHATTENTION_DISABLE_FP16 #ifndef FLASHATTENTION_DISABLE_HDIM64 if (params.d <= 64) { - if (params.dv > 64 && Arch == 90) { + if (params.dv > 256 && Arch == 90) { return run_mha_fwd_(params, stream); - } - else { + } else if (params.dv > 64 && Arch == 90) { + return run_mha_fwd_(params, stream); + } else { return run_mha_fwd_(params, stream); } } @@ -1501,7 +1503,6 @@ mha_combine(const at::Tensor &out_partial, // num_splits x batch_size x const int seqlen = sizes[2]; const int num_heads = sizes[3]; const int head_size_og = sizes[4]; - TORCH_CHECK(head_size_og <= 512, "FlashAttention combine only supports head dimension at most 512"); TORCH_CHECK(num_splits <= 256, "FlashAttention combine only supports num_splits at most 256"); CHECK_SHAPE(out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og); diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 452fd61b7..e9297e1b7 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -208,7 +208,7 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen; BOOL_SWITCH(params.qv_ptr, HasQV_, [&] { - static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV == 512; + static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256; APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { // Only use Cluster if number of tiles along seqlen_q is even and not varlen CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { diff --git a/hopper/generate_kernels.py b/hopper/generate_kernels.py index 19a6e90d3..b91a5b128 100644 --- a/hopper/generate_kernels.py +++ b/hopper/generate_kernels.py @@ -139,6 +139,7 @@ def get_all_kernels() -> List[Kernel]: if sm == 90 and head_dim == 192: yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=128, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") if sm == 90 and head_dim == 64 and dtype in ["bf16", "fp16"]: + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=256, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=512, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") for dtype, head_dim, softcap, sm in itertools.product(DTYPE_MAP_BWD.keys(), HEAD_DIMENSIONS, SOFTCAP, SM): yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=head_dim, split=False, paged_kv=False, softcap=softcap, packgqa=False, direction="bwd") diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_sm100.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_sm100.cu new file mode 100644 index 000000000..4fb8f71d0 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_sm100.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<100, cutlass::bfloat16_t, 128, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_packgqa_sm90.cu new file mode 100644 index 000000000..8d037153c --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_sm90.cu new file mode 100644 index 000000000..c62e0b8d8 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu new file mode 100644 index 000000000..5e22d67f7 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_split_sm90.cu new file mode 100644 index 000000000..1e005b3f0 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu new file mode 100644 index 000000000..96c4f55af --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_sm90.cu new file mode 100644 index 000000000..8a92fe291 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu new file mode 100644 index 000000000..f47cb3266 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_softcap_sm90.cu new file mode 100644 index 000000000..1915feb04 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_split_sm90.cu new file mode 100644 index 000000000..fbc157766 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu new file mode 100644 index 000000000..88445691f --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_packgqa_sm90.cu new file mode 100644 index 000000000..f7d051a34 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_sm90.cu new file mode 100644 index 000000000..c83c1741d --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu new file mode 100644 index 000000000..2e06c89a8 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_split_sm90.cu new file mode 100644 index 000000000..46479ec15 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu new file mode 100644 index 000000000..18681ec42 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_sm90.cu new file mode 100644 index 000000000..d2245aa13 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu new file mode 100644 index 000000000..022cdd395 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_softcap_sm90.cu new file mode 100644 index 000000000..67a324d52 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_split_sm90.cu new file mode 100644 index 000000000..664f88dbf --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu new file mode 100644 index 000000000..6bd6b9ab3 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu index cc3a8a7c9..ddd8bf07c 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_bf16_packgqa_sm90.cu" #include "flash_fwd_hdim64_512_bf16_packgqa_sm90.cu" #include "flash_fwd_hdim192_128_bf16_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu index d6d6df0d4..c9494c4f1 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_bf16_paged_sm90.cu" #include "flash_fwd_hdim64_512_bf16_paged_sm90.cu" #include "flash_fwd_hdim192_128_bf16_paged_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu index bd85f7608..4b2ec583c 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu" #include "flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu" #include "flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu index 733511adb..306722d45 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_bf16_paged_split_sm90.cu" #include "flash_fwd_hdim64_512_bf16_paged_split_sm90.cu" #include "flash_fwd_hdim192_128_bf16_paged_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu index c62ccf28d..e44b2d246 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu index b7e51551a..d52417dae 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_bf16_sm90.cu" #include "flash_fwd_hdim64_512_bf16_sm90.cu" #include "flash_fwd_hdim192_128_bf16_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu index 0dbd00454..6428c461a 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu index 51a143712..d0df6306e 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_bf16_softcap_sm90.cu" #include "flash_fwd_hdim64_512_bf16_softcap_sm90.cu" #include "flash_fwd_hdim192_128_bf16_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu index 24a64e8e4..e116d3ea7 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_bf16_split_sm90.cu" #include "flash_fwd_hdim64_512_bf16_split_sm90.cu" #include "flash_fwd_hdim192_128_bf16_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu index 50c78f3d5..bededf4a7 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu" #include "flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu" #include "flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu index 453282a4f..ea5310279 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_fp16_packgqa_sm90.cu" #include "flash_fwd_hdim64_512_fp16_packgqa_sm90.cu" #include "flash_fwd_hdim192_128_fp16_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu index 72736d8ef..10d86e5e9 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_fp16_paged_sm90.cu" #include "flash_fwd_hdim64_512_fp16_paged_sm90.cu" #include "flash_fwd_hdim192_128_fp16_paged_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu index 97895aa70..375197ef7 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu" #include "flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu" #include "flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu index 423c42221..4fc4831cf 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_fp16_paged_split_sm90.cu" #include "flash_fwd_hdim64_512_fp16_paged_split_sm90.cu" #include "flash_fwd_hdim192_128_fp16_paged_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu index 98c895721..a3d94a163 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu index 69108d025..9663103ae 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_fp16_sm90.cu" #include "flash_fwd_hdim64_512_fp16_sm90.cu" #include "flash_fwd_hdim192_128_fp16_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu index da39ba273..b7d2b07ca 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu index be6496d19..471b5abaa 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_fp16_softcap_sm90.cu" #include "flash_fwd_hdim64_512_fp16_softcap_sm90.cu" #include "flash_fwd_hdim192_128_fp16_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu index a5a809090..10f72182f 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_fp16_split_sm90.cu" #include "flash_fwd_hdim64_512_fp16_split_sm90.cu" #include "flash_fwd_hdim192_128_fp16_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu index 62fe14256..54db60c23 100644 --- a/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu @@ -2,5 +2,6 @@ // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" +#include "flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu" #include "flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu" #include "flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index fb014f719..d68384c83 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -117,7 +117,7 @@ def test_flash_attn_output( # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - dv_vals = [128, d] if d > 128 and d <= 192 else ([512, d] if d <= 64 else [d]) + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) if dtype == torch.float8_e4m3fn: dv_vals = [d] for dv in dv_vals: @@ -336,7 +336,7 @@ def test_flash_attn_varlen_output( # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - dv_vals = [128, d] if d > 128 and d <= 192 else ([512, d] if d <= 64 else [d]) + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) if dtype == torch.float8_e4m3fn: dv_vals = [d] for dv in dv_vals: @@ -647,11 +647,11 @@ def test_flash_attn_kvcache( nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) assert nheads % nheads_k == 0 dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - dv_vals = [128, d] if d > 128 and d <= 192 else ([512, d] if d <= 64 else [d]) + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) if dtype == torch.float8_e4m3fn: dv_vals = [d] for dv in dv_vals: - has_qv = d == 64 and dv == 512 + has_qv = d == 64 and dv >= 256 q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) if has_qv: qv = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) diff --git a/hopper/tile_size.h b/hopper/tile_size.h index 2c440c6e2..4414b53ac 100644 --- a/hopper/tile_size.h +++ b/hopper/tile_size.h @@ -12,13 +12,18 @@ constexpr std::tuple tile_size_fwd_sm90( bool v_colmajor=false, bool paged_kv_non_TMA=false, bool softcap=false) { if (element_size == 2) { if (headdim <= 64) { - bool same_hdim = (headdim == headdim_v); // if not same hdim, we're targeting hdimv=512 // return {same_hdim ? 192 : 64, same_hdim ? 128 : 64, same_hdim, same_hdim}; // With this workaround in Cutlass 3.8, tile size 192 x 128 got slower for non-causal, idk why // https://github.com/NVIDIA/cutlass/blob/833f6990e031b48b4cd2fcf55e0849c51ef6bac2/include/cute/container/tuple.hpp#L131 - // Switch to tile size 192 x 192 for now - bool const use_blockN_128 = is_causal || is_local; - return {same_hdim ? 192 : 64, same_hdim ? (use_blockN_128 ? 128 : 192) : 64, same_hdim && use_blockN_128, same_hdim}; + if (headdim_v == 512) { + return {64, 64, false, false}; + } else if (headdim_v == 256) { + return {128, 112, true, false}; + } else { + // Switch to tile size 192 x 192 for now + bool const use_blockN_128 = is_causal || is_local; + return {192, use_blockN_128 ? 128 : 192, use_blockN_128, true}; + } // Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen // return {192, is_causal || is_local ? 192 : 176, true, false}; } else if (headdim <= 96) { From f6a294a2442666bcfced83405bd44e57bce2595d Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 22 Mar 2025 17:15:16 -0400 Subject: [PATCH 085/102] Update benchmark with GLA --- hopper/benchmark_mla_decode.py | 21 +++++++++++---------- hopper/flash_api.cpp | 14 ++++++++++++-- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/hopper/benchmark_mla_decode.py b/hopper/benchmark_mla_decode.py index 9b7c05708..99b1b7a32 100644 --- a/hopper/benchmark_mla_decode.py +++ b/hopper/benchmark_mla_decode.py @@ -36,15 +36,15 @@ use_bench_cudagraph = False -attn_variants = ["mha", "gqa", "mqa", "mla"] -for attn_variant in attn_variants: -# for attn_variant in attn_variants[3:]: - nheads_kv = nheads_q if attn_variant == "mha" else (max(nheads_q // 8, 1) if attn_variant == "gqa" else 1) - headdim = 64 if attn_variant == "mla" else 128 - headdim_v = 512 if attn_variant == "mla" else headdim - has_qv = headdim == 64 and headdim_v == 512 +attn_variants = ["mha", "gqa", "mqa", "mla", "gla"] +# for attn_variant in attn_variants: +for attn_variant in attn_variants[3:5]: + nheads_kv = nheads_q if attn_variant == "mha" else (max(nheads_q // 8, 1) if attn_variant == "gqa" else (1 if attn_variant == "mla" else 2)) + headdim = 64 if attn_variant in ["mla", "gla"] else 128 + headdim_v = 512 if attn_variant == "mla" else (256 if attn_variant == "gla" else headdim) + has_qv = headdim == 64 and headdim_v > 64 # page_size = None - page_size = 64 if attn_variant == "mla" else 128 + page_size = 64 if attn_variant in ["mla", "gla"] else 128 should_run_flashmla = attn_variant == "mla" and page_size == 64 and flash_mla_with_kvcache is not None @@ -60,7 +60,7 @@ print(f"\n{attn_variant.upper()}, nheads_q = {nheads_q}, nheads_kv = {nheads_kv}, headdim = {headdim}, headdim_v = {headdim_v}, page_size = {page_size}") for seqlen in [s * 1024 for s in [1, 2, 4, 8, 16, 32, 64]]: - # for seqlen in [s * 1024 for s in [1]]: + # for seqlen in [s * 1024 for s in [8]]: cache_seqlens = torch.tensor([seqlen] * batch_size, device=device, dtype=torch.int) num_splits = 0 q = torch.randn(batch_size, seqlen_q, nheads_q, headdim, dtype=dtype, device=device) @@ -84,6 +84,7 @@ cache_seqlens, q.dtype, headdim_v=headdim_v, page_size=page_size, causal=True ) # scheduler_metadata = None + # breakpoint() fn0 = lambda: flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, num_splits=num_splits, qv=qv, page_table=page_table, causal=True, scheduler_metadata=scheduler_metadata) time.sleep(1) # to avoid power throttling # Time in ms @@ -109,7 +110,7 @@ t1 = do_bench_cudagraph(fn1, rep=10) total_seqlen = seqlen * batch_size if cache_seqlens is None else cache_seqlens.sum().item() - mem_io = total_seqlen * nheads_kv * (headdim + headdim_v) * 2 + q.numel() * 2 + (qv.numel() * 2 if has_qv else 0) + q.numel() * headdim_v // headdim * 2 # last time is for the output + mem_io = total_seqlen * nheads_kv * (headdim + headdim_v) * 2 + q.numel() * 2 + (qv.numel() * 2 if has_qv else 0) + q.numel() * headdim_v // headdim * 2 # last term is for the output flops = seqlen_q * total_seqlen * nheads_q * (headdim + headdim_v * (2 if has_qv else 1)) * 2 ideal_h100_time_mem = mem_io / 3.35e12 * 1e6 ideal_h100_time_flop = flops / 989e12 * 1e6 diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 58bc49da4..ef715d38b 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -493,6 +493,16 @@ inline int round_up_headdim(int head_size) { return 256; } +inline int round_up_headdimv(int head_size) { + if (head_size <= 64) { return 64; } + if (head_size <= 96) { return 96; } + if (head_size <= 128) { return 128; } + if (head_size <= 192) { return 192; } + if (head_size <= 256) { return 256; } + return 512; +} + + // Only applicable to the case where seqused_k (i.e. cache_seqlens) is available at::Tensor mha_fwd_get_scheduler_metadata( @@ -537,7 +547,7 @@ mha_fwd_get_scheduler_metadata( params.d = headdim; params.dv = headdim_v; params.d_rounded = round_up_headdim(headdim); - params.dv_rounded = round_up_headdim(headdim_v); + params.dv_rounded = headdim_v == headdim ? params.d_rounded : round_up_headdimv(headdim_v); params.seqlen_knew = max_seqlen_k_new; bool const is_varlen_q = cu_seqlens_q_.has_value(); @@ -827,7 +837,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; int const head_size_rounded = round_up_headdim(head_size); - int const head_size_v_rounded = round_up_headdim(head_size_v); + int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdimv(head_size_v); int const seqlen_q_rounded = round_multiple(seqlen_q, 128); int const seqlen_k_rounded = round_multiple(seqlen_k, 128); From 29ef580560761838c0e9e82bc0e98d04ba75f949 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 22 Mar 2025 17:46:12 -0400 Subject: [PATCH 086/102] Adjust warp scheduler sync for HasQv case --- hopper/flash_api.cpp | 1 - hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 6 ++++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index ef715d38b..6773ee7c1 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -502,7 +502,6 @@ inline int round_up_headdimv(int head_size) { return 512; } - // Only applicable to the case where seqused_k (i.e. cache_seqlens) is available at::Tensor mha_fwd_get_scheduler_metadata( diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index b72906941..be0d79a26 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -1258,8 +1258,8 @@ struct CollectiveMainloopFwdSm90 { Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); consumer_wait(pipeline_k, smem_pipe_read); flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); - warp_scheduler_barrier_arrive(); if constexpr (!HasQv) { + warp_scheduler_barrier_arrive(); warpgroup_wait<0>(); pipeline_k.consumer_release(smem_pipe_read); // release K } else { @@ -1267,7 +1267,9 @@ struct CollectiveMainloopFwdSm90 { shared_storage.pipelines.barrier_Qv.wait(work_idx % 2); } consumer_wait(pipeline_v, smem_pipe_read); - flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS); + flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS); + warp_scheduler_barrier_arrive(); + warpgroup_wait<1>(); pipeline_k.consumer_release(smem_pipe_read); // release K warpgroup_wait<0>(); } From 2f9ef0879a0935c3ca852f7a6a7b7a9c24f41e96 Mon Sep 17 00:00:00 2001 From: "Ye (Charlotte) Qi" Date: Tue, 25 Mar 2025 06:41:44 -0700 Subject: [PATCH 087/102] num_head -> args.num_head (#1552) Signed-off-by: Ye (Charlotte) Qi --- hopper/tile_scheduler.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index 344a5c03d..1e4f14201 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -388,7 +388,7 @@ class VarlenDynamicPersistentTileScheduler { // If Split, for the purpose of scheduling, we pretend that instead there are // (args.num_splits * args.num_head) number of heads. assert(args.tile_count_semaphore != nullptr); - assert(num_head < (1 << 16)); // We use the top 16 bits to store num_splits & split_idx + assert(args.num_head < (1 << 16)); // We use the top 16 bits to store num_splits & split_idx assert(!Split || args.num_splits < (1 << 8)); // We use the top 8 bits to store num_splits return {args.num_head, args.num_batch, args.qhead_per_khead, args.seqlen, From 1a58058a6da83bd7baaf4c512e8a1abe0240bb77 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 29 Mar 2025 01:29:01 -0400 Subject: [PATCH 088/102] Fix zeroing out the scheduler semaphore when reusing metadata --- hopper/flash_api.cpp | 4 + hopper/test_flash_attn.py | 172 +++++++++++++++++++------------------- 2 files changed, 91 insertions(+), 85 deletions(-) diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 6773ee7c1..b82b10b78 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -1124,7 +1124,11 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq // params.b = 1; // params.seqlen_q = total_q; // } + // This will zero out the semaphore if needed run_mha_fwd_combine(params, stream, true /*enable_pdl*/); + } else if (scheduler_needs_semaphore && params.skip_scheduler_metadata_computation) { + // need to zero out the semaphore in this case + tile_count_semaphore.index({torch::indexing::Slice(0, 1)}).zero_(); } } else if (total_q > 0 && num_heads_k > 0) { // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index d68384c83..4d20ff8af 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -833,12 +833,6 @@ def test_flash_attn_kvcache( num_splits_vals = [1, 0] if not DISABLE_SPLIT else [1] precompute_metadata_vals = [False, True] for num_splits, precompute_metadata in itertools.product(num_splits_vals, precompute_metadata_vals): - if page_size is None: - k_cache.copy_(k_cache_saved) - v_cache.copy_(v_cache_saved) - else: - k_cache_paged.copy_(k_cache_saved) - v_cache_paged.copy_(v_cache_saved) if precompute_metadata: scheduler_metadata = get_scheduler_metadata( batch_size, seqlen_q, seqlen_k, nheads, nheads_k, d, @@ -850,90 +844,98 @@ def test_flash_attn_kvcache( ) else: scheduler_metadata = None - out, lse, *rest = flash_attn_with_kvcache( - q if not varlen_q else q_unpad, - k_cache if page_size is None else k_cache_paged, - v_cache if page_size is None else v_cache_paged, - k if not new_kv or not varlen_q else k_unpad, - v if not new_kv or not varlen_q else v_unpad, - qv=qv if not varlen_q else qv_unpad, - rotary_cos=cos, - rotary_sin=sin, - cache_seqlens=cache_seqlens, - cache_batch_idx=cache_batch_idx, - cache_leftpad=cache_leftpad, - page_table=page_table, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k_new=cu_seqlens_k_new, - max_seqlen_q=max_seqlen_q, - rotary_seqlens=rotary_seqlens, - causal=causal, - window_size=window_size, - rotary_interleaved=rotary_interleaved, - scheduler_metadata=scheduler_metadata, - num_splits=num_splits, - return_softmax_lse=True - ) - if varlen_q: - out = output_pad_fn(out) - # out = flash_attn_with_kvcache( - # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size - # ) - # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) - # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) - # m = qk.amax(-1, keepdim=True) - # s_tmp = torch.exp((qk - m) / math.sqrt(d)) - # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) - # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) - # probs = torch.softmax(qk, dim=-1) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - # breakpoint() - - # Check that FlashAttention's numerical error is at most twice the numerical error - # of a Pytorch implementation. - if new_kv: + # Repeat to test metadata reuse + for _ in range(1 if not precompute_metadata else 2): if page_size is None: - k_cache_select = ( - k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx] - ) - v_cache_select = ( - v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx] - ) + k_cache.copy_(k_cache_saved) + v_cache.copy_(v_cache_saved) else: - k_cache_select = rearrange( - k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], - "(b nblocks) block_size ... -> b (nblocks block_size) ...", - b=batch_size, - )[:, :seqlen_k].to(dtype_ref) - v_cache_select = rearrange( - v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], - "(b nblocks) block_size ... -> b (nblocks block_size) ...", - b=batch_size, - )[:, :seqlen_k].to(dtype_ref) - k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) - v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) - if dtype is not torch.float8_e4m3fn: - assert torch.equal(v_cache_select, v_cache_ref) - else: - assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3) + k_cache_paged.copy_(k_cache_saved) + v_cache_paged.copy_(v_cache_saved) + out, lse, *rest = flash_attn_with_kvcache( + q if not varlen_q else q_unpad, + k_cache if page_size is None else k_cache_paged, + v_cache if page_size is None else v_cache_paged, + k if not new_kv or not varlen_q else k_unpad, + v if not new_kv or not varlen_q else v_unpad, + qv=qv if not varlen_q else qv_unpad, + rotary_cos=cos, + rotary_sin=sin, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, + max_seqlen_q=max_seqlen_q, + rotary_seqlens=rotary_seqlens, + causal=causal, + window_size=window_size, + rotary_interleaved=rotary_interleaved, + scheduler_metadata=scheduler_metadata, + num_splits=num_splits, + return_softmax_lse=True + ) + if varlen_q: + out = output_pad_fn(out) + # out = flash_attn_with_kvcache( + # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size + # ) + # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) + # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) + # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) + # probs = torch.softmax(qk, dim=-1) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") # breakpoint() - # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: - if rotary_dim == 0: - assert torch.equal(k_cache_select, k_cache_ref) - else: - # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): - # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + if new_kv: + if page_size is None: + k_cache_select = ( + k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx] + ) + v_cache_select = ( + v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx] + ) + else: + k_cache_select = rearrange( + k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + v_cache_select = rearrange( + v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) + v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) if dtype is not torch.float8_e4m3fn: - assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) + assert torch.equal(v_cache_select, v_cache_ref) + else: + assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3) + # breakpoint() + # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: + if rotary_dim == 0: + assert torch.equal(k_cache_select, k_cache_ref) else: - assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1) - mult = 4 if dtype == torch.float8_e4m3fn else 2 - assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 - mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 - assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item() + # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): + # breakpoint() + if dtype is not torch.float8_e4m3fn: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) + else: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1) + mult = 4 if dtype == torch.float8_e4m3fn else 2 + assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 + mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 + assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item() def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref): From 2dd8078adc1d9b74e315ee99718c0dea0de8eeb6 Mon Sep 17 00:00:00 2001 From: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com> Date: Tue, 1 Apr 2025 03:44:32 +0200 Subject: [PATCH 089/102] fix deprecation warning for newer torch versions (#1565) --- flash_attn/ops/fused_dense.py | 2 +- flash_attn/ops/triton/layer_norm.py | 4 +++- flash_attn/ops/triton/mlp.py | 2 +- flash_attn/utils/torch.py | 21 +++++++++++++++++++++ 4 files changed, 26 insertions(+), 3 deletions(-) create mode 100644 flash_attn/utils/torch.py diff --git a/flash_attn/ops/fused_dense.py b/flash_attn/ops/fused_dense.py index 1e45b8e60..6b4033d13 100644 --- a/flash_attn/ops/fused_dense.py +++ b/flash_attn/ops/fused_dense.py @@ -11,9 +11,9 @@ import torch.nn as nn import torch.nn.functional as F from torch import Tensor -from torch.cuda.amp import custom_bwd, custom_fwd from torch.distributed import ProcessGroup +from flash_attn.utils.torch import custom_fwd, custom_bwd from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_bwd, sqrelu_fwd from flash_attn.utils.distributed import ( all_gather_raw, diff --git a/flash_attn/ops/triton/layer_norm.py b/flash_attn/ops/triton/layer_norm.py index 0d122aa08..f073c827c 100644 --- a/flash_attn/ops/triton/layer_norm.py +++ b/flash_attn/ops/triton/layer_norm.py @@ -10,11 +10,13 @@ import torch import torch.nn.functional as F -from torch.cuda.amp import custom_fwd, custom_bwd import triton import triton.language as tl +from flash_attn.utils.torch import custom_fwd, custom_bwd + + def triton_autotune_configs(): # Return configs with a valid warp count for the current device configs=[] diff --git a/flash_attn/ops/triton/mlp.py b/flash_attn/ops/triton/mlp.py index b795310f1..059f4f8a5 100644 --- a/flash_attn/ops/triton/mlp.py +++ b/flash_attn/ops/triton/mlp.py @@ -4,8 +4,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.cuda.amp import custom_bwd, custom_fwd +from flash_attn.utils.torch import custom_fwd, custom_bwd from flash_attn.ops.activations import sqrelu_bwd, sqrelu_fwd from flash_attn.ops.triton.linear import triton_dgrad_act, triton_linear_act diff --git a/flash_attn/utils/torch.py b/flash_attn/utils/torch.py new file mode 100644 index 000000000..98cbf9a27 --- /dev/null +++ b/flash_attn/utils/torch.py @@ -0,0 +1,21 @@ +import torch +from typing import Callable + + +def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool): + def decorator(*args, **kwargs): + if cuda_amp_deprecated: + kwargs["device_type"] = "cuda" + return dec(*args, **kwargs) + return decorator + + +if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined] + deprecated = True + from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined] +else: + deprecated = False + from torch.cuda.amp import custom_fwd, custom_bwd + +custom_fwd = custom_amp_decorator(custom_fwd, deprecated) +custom_bwd = custom_amp_decorator(custom_bwd, deprecated) From 7ff1b621112ba8b538e2fc6a316f2a6b6f22e518 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 6 Apr 2025 22:41:59 -0400 Subject: [PATCH 090/102] Don't use FusedDense anymore to simplify code --- flash_attn/modules/mha.py | 47 +++++------------------- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 2 +- 2 files changed, 11 insertions(+), 38 deletions(-) diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py index 77640c2b2..2c0a4f1b8 100644 --- a/flash_attn/modules/mha.py +++ b/flash_attn/modules/mha.py @@ -23,9 +23,9 @@ flash_attn_with_kvcache = None try: - from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear + from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear except ImportError: - FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None + ColumnParallelLinear, RowParallelLinear = None, None, None try: from flash_attn.layers.rotary import RotaryEmbedding @@ -341,13 +341,6 @@ def forward(self, q, kv, causal=None, key_padding_mask=None): return output -class LinearResidual(nn.Linear): - """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense.""" - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return super().forward(input), input - - def _update_kv_cache(kv, inference_params, layer_idx): """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)""" # Pre-allocate memory for key-values for inference. @@ -452,13 +445,6 @@ def __init__( device=device, ) - if fused_bias_fc and FusedDense is None: - raise ImportError("fused_dense is not installed") - linear_cls = nn.Linear if not fused_bias_fc else FusedDense - linear_resid_cls = ( - LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True) - ) - wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls inner_attn_cls = ( partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size) if use_flash_attn @@ -470,10 +456,10 @@ def __init__( else CrossAttention ) if not self.cross_attn: - self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs) + self.Wqkv = nn.Linear(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs) else: - self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs) - self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs) + self.Wq = nn.Linear(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs) + self.Wkv = nn.Linear(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs) if self.dwconv: if self.num_heads_kv == self.num_heads: self.dwconv_qkv = nn.Conv1d( @@ -492,7 +478,7 @@ def __init__( self.inner_cross_attn = inner_cross_attn_cls( causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout ) - self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs) def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None): dtype = self.out_proj.weight.dtype if dtype is None else dtype @@ -646,10 +632,7 @@ def forward( batch, seqlen = x.shape[:2] if not self.cross_attn and self.num_heads_kv == self.num_heads: assert x_kv is None and mixer_subset is None - if not self.return_residual: - qkv = self.Wqkv(x) - else: - qkv, x = self.Wqkv(x) + qkv = self.Wqkv(x) if self.dwconv: qkv = rearrange( self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d" @@ -680,21 +663,11 @@ def forward( ) else: if self.cross_attn: - if not self.return_residual: - q = self.Wq(x if mixer_subset is None else x[:, mixer_subset]) - kv = self.Wkv(x_kv if x_kv is not None else x) - else: - if x_kv is not None: - kv, x_kv = self.Wkv(x_kv) - else: - kv, x = self.Wkv(x) - q = self.Wq(x if mixer_subset is None else x[:, mixer_subset]) + q = self.Wq(x if mixer_subset is None else x[:, mixer_subset]) + kv = self.Wkv(x_kv if x_kv is not None else x) else: assert self.num_heads_kv != self.num_heads - if not self.return_residual: - qkv = self.Wqkv(x) - else: - qkv, x = self.Wqkv(x) + qkv = self.Wqkv(x) q = qkv[..., : self.num_heads * self.head_dim] kv = qkv[..., self.num_heads * self.head_dim :] q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim) diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index be0d79a26..68988862e 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -1658,7 +1658,7 @@ struct CollectiveMainloopFwdSm90 { rotary.template apply_K_contiguous(sK(_, _, smem_pipe_read.index()), gK_cur, tKpK, tRrCosCont, tRrSinCont, tPrKPtr, n_block, get<1>(params.shape_K)); } } - // Without this sync I'm getting race condition when seqlen_k is large + // Without this fence I'm getting race condition when seqlen_k is large cutlass::arch::fence_view_async_shared(); // Very important: PipelineTmaAsync::consumer_release assumes that the warpgroup is synchronized // before calling. From aa04de66e22fb1810eeede8ba736ccd895f16274 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 7 Apr 2025 18:39:52 -0400 Subject: [PATCH 091/102] Fix FA3 qkvpacked interface --- hopper/flash_attn_interface.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 59a5517ce..9e8d6908e 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -174,17 +174,26 @@ def forward( num_heads_k = (qkv.shape[2] - num_heads_q) // 2 assert num_heads_k * 2 + num_heads_q == qkv.shape[2] q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2) - out, q, k, v, out_padded, softmax_lse = _flash_attn_forward( + out, softmax_lse, *rest = _flash_attn_forward( q, k, v, + None, None, # k_new, v_new + None, # qv + None, # out + None, None, None, # cu_seqlens_q/k/k_new + None, None, # seqused_q/k + None, None, # max_seqlen_q/k + None, None, None, # page_table, kv_batch_idx, leftpad_k, + None, None, None, # rotary_cos/sin, seqlens_rotary + q_descale, k_descale, v_descale, softmax_scale, causal=causal, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, softcap=softcap, ) - ctx.save_for_backward(q, k, v, out_padded, softmax_lse) + # ctx.save_for_backward(q, k, v, out_padded, softmax_lse) + ctx.save_for_backward(q, k, v, out, softmax_lse) ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size @@ -214,6 +223,9 @@ def backward(ctx, dout, *args): v, out, softmax_lse, + None, None, # cu_seqlens_q, cu_seqlens_k, + None, None, # sequed_q, sequed_k, + None, None, # max_seqlen_q, max_seqlen_k, dq, dk, dv, From 2afa43cdab1e173f81408c37a7457aadf3bda895 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 8 Apr 2025 12:41:26 -0400 Subject: [PATCH 092/102] Launch more thread blocks in layer_norm_bwd --- flash_attn/ops/triton/layer_norm.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/flash_attn/ops/triton/layer_norm.py b/flash_attn/ops/triton/layer_norm.py index f073c827c..0427e957e 100644 --- a/flash_attn/ops/triton/layer_norm.py +++ b/flash_attn/ops/triton/layer_norm.py @@ -637,7 +637,9 @@ def _layer_norm_bwd( BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) if N > BLOCK_N: raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") - sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the + # latency of the gmem reads/writes, but will increase the time of summing up dw / db. + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8 _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) _db = ( torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) @@ -1020,12 +1022,12 @@ def forward( norm_bias, eps, residual, - out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(), + out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_dtype("cuda"), residual_dtype=residual_dtype, is_rms_norm=is_rms_norm, ) y = y.reshape(x_shape_og) - dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype + dtype = torch.get_autocast_dtype("cuda") if torch.is_autocast_enabled() else y.dtype linear_weight = linear_weight.to(dtype) linear_bias = linear_bias.to(dtype) if linear_bias is not None else None out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) From 9f2d2ae3b843bfea602dbb2893b7c00f6b099824 Mon Sep 17 00:00:00 2001 From: jayhshah Date: Tue, 8 Apr 2025 22:18:35 -0700 Subject: [PATCH 093/102] check valid tile before storing num_splits in split_idx (#1578) --- hopper/tile_scheduler.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index 1e4f14201..53651d5c8 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -107,9 +107,9 @@ class SingleTileScheduler { } if constexpr (Varlen && Split) { int num_splits_dynamic = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[work_info.bidb] : params.num_splits; + is_valid_tile &= work_info.split_idx < num_splits_dynamic; // Use the top 16 bits to store num_splits work_info.split_idx |= (num_splits_dynamic << 16); - is_valid_tile &= work_info.split_idx < num_splits_dynamic; } work_info.bidb = is_valid_tile ? work_info.bidb : -1; return work_info; From d836a6bf09bf3838c6e71c9cf675b3708fea0d71 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 9 Apr 2025 14:39:14 -0400 Subject: [PATCH 094/102] Tune rotary kernel to use 2 warps if rotary_dim <= 64 --- flash_attn/ops/triton/rotary.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/flash_attn/ops/triton/rotary.py b/flash_attn/ops/triton/rotary.py index 0ee56d647..560c75d00 100644 --- a/flash_attn/ops/triton/rotary.py +++ b/flash_attn/ops/triton/rotary.py @@ -38,8 +38,8 @@ def rotary_kernel( BLOCK_M: tl.constexpr, ): pid_m = tl.program_id(axis=0) - pid_batch = tl.program_id(axis=1) - pid_head = tl.program_id(axis=2) + pid_head = tl.program_id(axis=1) + pid_batch = tl.program_id(axis=2) rotary_dim_half = rotary_dim // 2 if not IS_VARLEN: @@ -193,7 +193,7 @@ def apply_rotary( if rotary_dim <= 32 else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256)) ) - grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa + grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), nheads, batch) # noqa BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 128 else 4) # Need this, otherwise Triton tries to launch from cuda:0 and we get @@ -223,5 +223,6 @@ def apply_rotary( interleaved, conjugate, BLOCK_M, + num_warps=2 if rotary_dim <= 64 else 4, ) return output From 80461564ebdc312dad268d28c216c8df4a41a982 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 9 Apr 2025 22:26:19 +0000 Subject: [PATCH 095/102] update api Signed-off-by: Lucas Wilkinson --- hopper/flash_api_torch_lib.cpp | 2 ++ vllm_flash_attn/flash_attn_interface.py | 5 +++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/hopper/flash_api_torch_lib.cpp b/hopper/flash_api_torch_lib.cpp index f3f6a18b2..a2006f3c4 100644 --- a/hopper/flash_api_torch_lib.cpp +++ b/hopper/flash_api_torch_lib.cpp @@ -38,6 +38,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq std::optional &leftpad_k_, // b std::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) std::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) + std::optional &seqlens_rotary_, // b std::optional &q_descale_, // (b, h_k), not (b, h) std::optional &k_descale_, // (b, h_k) std::optional &v_descale_, // (b, h_k) @@ -104,6 +105,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor? leftpad_k," " Tensor? rotary_cos," " Tensor? rotary_sin," + " Tensor? seqlens_rotary," " Tensor? q_descale," " Tensor? k_descale," " Tensor? v_descale," diff --git a/vllm_flash_attn/flash_attn_interface.py b/vllm_flash_attn/flash_attn_interface.py index 6c524f9ed..30a160785 100644 --- a/vllm_flash_attn/flash_attn_interface.py +++ b/vllm_flash_attn/flash_attn_interface.py @@ -22,6 +22,7 @@ FA3_UNAVAILABLE_REASON = None FA3_AVAILABLE = True except ImportError as e: + raise e FA3_UNAVAILABLE_REASON = str(e) FA3_AVAILABLE = False @@ -262,7 +263,7 @@ def flash_attn_varlen_func( block_table, None, # kv_batch_idx None, # leftpad_k - None, None, # rotary_cos, rotary_sin + None, None, None, # rotary_cos, rotary_sin, seqlens_rotary q_descale, k_descale, v_descale, softmax_scale, causal, @@ -448,7 +449,7 @@ def flash_attn_with_kvcache( block_table, cache_batch_idx, # kv_batch_idx None, # leftpad_k - None, None, # rotary_cos, rotary_sin + None, None, None, # rotary_cos, rotary_sin, seqlens_rotary q_descale, k_descale, v_descale, softmax_scale, causal, From 70cd6257608134a2602c146acc8ffbd3e15845aa Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 11 Apr 2025 04:15:31 +0000 Subject: [PATCH 096/102] single wg for decode Signed-off-by: Lucas Wilkinson --- hopper/flash_api.cpp | 3 ++- hopper/flash_fwd_launch_template.h | 28 +++++++++++++++------------- hopper/tile_size.h | 8 ++++++-- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 5a595840a..acd84f48c 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -431,7 +431,8 @@ inline int get_num_splits(Flash_fwd_params const& params) { // params.page_table must already be set // This needs to match the kernel configs bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k; - auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); + bool use_one_mma_wg = params.seqlen_q * (!params.pack_gqa ? 1 : params.h / params.h_k) <= 64; + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, use_one_mma_wg); // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits // has not been set here. It's OK though because we might just underestimate kBlockN a bit auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr); diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index e9297e1b7..dee564961 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -26,7 +26,7 @@ using namespace cute; template + bool PackGQA, bool Split, bool V_colmajor, bool Use_one_mma_wg> void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time"); static_assert(!(AppendKV && V_colmajor), "AppendKV and V_colmajor cannot be enabled at the same time"); @@ -36,7 +36,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using ArchTag = std::conditional_t= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>; // Can't use structured binding since it's not compatible with constexpr - static constexpr std::tuple kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap); + static constexpr std::tuple kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap, Use_one_mma_wg); static constexpr std::tuple kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKVNonTMA, Varlen && Split, Has_softcap, AppendKV); static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS); static constexpr int kBlockN = Arch >= 90 ? std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS); @@ -203,17 +203,19 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { VCOLMAJOR_SWITCH(params.v_dim_stride != 1, V_colmajor_, [&] { static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1; VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] { - // Only needed here to decide if we should use cluster - static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap)) : 128; - - static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen; - BOOL_SWITCH(params.qv_ptr, HasQV_, [&] { - static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256; - APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { - // Only use Cluster if number of tiles along seqlen_q is even and not varlen - CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { - static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1; - run_flash_fwd(params, stream); + BOOL_SWITCH(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k) <= 64, Use_one_mma_wg, [&] { + // Only needed here to decide if we should use cluster + static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap, Use_one_mma_wg)) : 128; + + static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen; + BOOL_SWITCH(params.qv_ptr, HasQV_, [&] { + static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256; + APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { + // Only use Cluster if number of tiles along seqlen_q is even and not varlen + CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { + static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1; + run_flash_fwd(params, stream); + }); }); }); }); diff --git a/hopper/tile_size.h b/hopper/tile_size.h index 4414b53ac..b87a83aff 100644 --- a/hopper/tile_size.h +++ b/hopper/tile_size.h @@ -9,7 +9,7 @@ // Return {kBlockM, kBlockN, MmaPV_is_RS, IntraWGOverlap} constexpr std::tuple tile_size_fwd_sm90( int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2, - bool v_colmajor=false, bool paged_kv_non_TMA=false, bool softcap=false) { + bool v_colmajor=false, bool paged_kv_non_TMA=false, bool softcap=false, bool use_one_mma_wg=false) { if (element_size == 2) { if (headdim <= 64) { // return {same_hdim ? 192 : 64, same_hdim ? 128 : 64, same_hdim, same_hdim}; @@ -29,7 +29,11 @@ constexpr std::tuple tile_size_fwd_sm90( } else if (headdim <= 96) { return {192, is_local || paged_kv_non_TMA ? 128 : 144, false, true}; } else if (headdim <= 128) { - return {128, is_causal || is_local || paged_kv_non_TMA ? 128 : 176, true, true}; + if (use_one_mma_wg) { + return {64, is_causal || is_local || paged_kv_non_TMA ? 128 : 176, true, true}; + } else { + return {128, is_causal || is_local || paged_kv_non_TMA ? 128 : 176, true, true}; + } // {128, 192, false, false} and {192, 128, false, true} are quite good too // 128 x 192 hits the limit of smem if MmaPV_is_RS, 128 x 144 hits the limit if !MmaPV_is_RS } else if (headdim <= 192) { From 65c54add66e02cbae66535b125eb8f009fdbd887 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 11 Apr 2025 04:34:27 +0000 Subject: [PATCH 097/102] disable masking for pure decode Signed-off-by: Lucas Wilkinson --- hopper/flash_fwd_launch_template.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index dee564961..dab84880b 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -199,7 +199,7 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { static_assert(sizeof(T) == 2 || sizeof(T) == 1, "Only 16bit and 8bit are supported"); static constexpr bool Is_FP8 = cute::is_same_v || cute::is_same_v; using T_out = std::conditional_t; - CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] { + CAUSAL_LOCAL_SWITCH(params.is_causal && params.seqlen_q > 1, params.is_local, Is_causal, Is_local, [&] { VCOLMAJOR_SWITCH(params.v_dim_stride != 1, V_colmajor_, [&] { static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1; VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] { From 62c987bbd6b1671865778595570627d5eea2b28e Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 14 Apr 2025 16:08:12 +0000 Subject: [PATCH 098/102] Seperate out `get_n_block_min_max` Signed-off-by: Lucas Wilkinson --- hopper/epilogue_fwd.hpp | 45 +- hopper/flash_fwd_kernel_sm80.h | 30 +- hopper/flash_fwd_kernel_sm90.h | 57 +- hopper/flash_fwd_launch_template.h | 33 +- hopper/mainloop_fwd_sm80.hpp | 72 +-- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 107 ++-- hopper/tile_scheduler.hpp | 663 +++++++++++++++++------ 7 files changed, 653 insertions(+), 354 deletions(-) diff --git a/hopper/epilogue_fwd.hpp b/hopper/epilogue_fwd.hpp index 69102e8c4..357eb794c 100644 --- a/hopper/epilogue_fwd.hpp +++ b/hopper/epilogue_fwd.hpp @@ -217,18 +217,16 @@ struct CollectiveEpilogueFwd { SharedStorage& shared_storage, TiledMma tiled_mma, int thread_idx, - cute::tuple const& block_coord + BlockCoord const& block_coord ) { - auto [m_block, bidh, bidb, split_idx] = block_coord; - int num_splits = get<4>(params.shape_O_packed); - if constexpr (Split && Varlen) { - uint32_t num_splits_dynamic_u = reinterpret_cast(split_idx) >> 16; // first 16 bits are for num_splits - int num_splits_dynamic = reinterpret_cast(num_splits_dynamic_u); - num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits; - split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx - } - bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1); + int const m_block = block_coord.m_block; + int const bidh = block_coord.bidh; + int const bidb = block_coord.bidb; + int const peer_id = block_coord.peer_id; + int const num_peers = block_coord.num_peers; + + bool const is_split = !Split ? false : (!Varlen ? true : num_peers > 1); Tensor sO = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_o.data()), SmemLayoutO{}); // Tensor sO_pi = cute::as_position_independent_swizzle_tensor(sO); @@ -292,7 +290,7 @@ struct CollectiveEpilogueFwd { Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)), params.shape_LSE_packed, - !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx); + !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : peer_id); // if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); } if (!LargeHeadDimV || warp_group_idx == 0) { if constexpr (!PackGQA) { @@ -308,7 +306,7 @@ struct CollectiveEpilogueFwd { // Step 3: Write O from smem -> gmem if constexpr (Use_TMA_O) { - Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx); + Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, peer_id); Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) auto block_tma_O = params.tma_store_O.get_slice(_0{}); Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K) @@ -361,7 +359,7 @@ struct CollectiveEpilogueFwd { PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); } } else { - Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset_o * get<0>(params.stride_O_partial)), params.shape_O_packed, params.stride_O_partial_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx); + Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset_o * get<0>(params.stride_O_partial)), params.shape_O_packed, params.stride_O_partial_packed)(_, _, bidh, !is_varlen ? bidb : 0, peer_id); Tensor gOpartial = local_tile(mOpartial, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) // We already arrived on barrier_O earlier if !Use_smem if constexpr (Use_smem) { @@ -410,18 +408,17 @@ struct CollectiveEpilogueFwd { store_zero( Params const& params, int thread_idx, - cute::tuple const& block_coord + BlockCoord const& block_coord ) { + int const m_block = block_coord.m_block; + int const bidh = block_coord.bidh; + int const bidb = block_coord.bidb; + int const peer_id = block_coord.peer_id; + int const num_peers = block_coord.num_peers; + static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); - auto [m_block, bidh, bidb, split_idx] = block_coord; - int num_splits = get<4>(params.shape_O_packed); - if constexpr (Split && Varlen) { - uint32_t num_splits_dynamic_u = reinterpret_cast(split_idx) >> 16; // first 16 bits are for num_splits - int num_splits_dynamic = reinterpret_cast(num_splits_dynamic_u); - num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits; - split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx - } - bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1); + + bool const is_split = !Split ? false : (!Varlen ? true : num_peers > 1); flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused}; bool const is_varlen = Varlen && params.cu_seqlens; @@ -430,7 +427,7 @@ struct CollectiveEpilogueFwd { int qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor; Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)), params.shape_LSE_packed, - !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx); + !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : peer_id); Tensor gLSE = local_tile(mLSE, Shape>{}, make_coord(m_block)); static_assert(kBlockM <= NumEpilogueThreads); diff --git a/hopper/flash_fwd_kernel_sm80.h b/hopper/flash_fwd_kernel_sm80.h index b308d2d1b..e5bbbfdb8 100644 --- a/hopper/flash_fwd_kernel_sm80.h +++ b/hopper/flash_fwd_kernel_sm80.h @@ -154,7 +154,9 @@ class FlashAttnFwdSm80 { CollectiveMainloop mainloop; CollectiveEpilogue epilogue; - TileScheduler scheduler(reinterpret_cast(&shared_storage.smem_scheduler)); + TileScheduler scheduler( + reinterpret_cast(&shared_storage.smem_scheduler), + params.scheduler); // Initialize matmul objects. TiledMma tiled_mma; @@ -162,14 +164,17 @@ class FlashAttnFwdSm80 { int warp_idx = cutlass::canonical_warp_idx_sync(); CUTLASS_PRAGMA_NO_UNROLL - for (auto work_tile_info = warp_idx == 0 ? scheduler.template get_initial_work(params.scheduler) : scheduler.template get_initial_work(params.scheduler); - work_tile_info.is_valid(params.scheduler); - work_tile_info = warp_idx == 0 ? scheduler.template get_next_work(params.scheduler, work_tile_info) : scheduler.template get_next_work(params.scheduler, work_tile_info)) { + for (auto work_tile_info = warp_idx == 0 ? scheduler.template get_initial_work() : scheduler.template get_initial_work(); + scheduler.is_valid(work_tile_info); + work_tile_info = warp_idx == 0 ? scheduler.template get_next_work(work_tile_info) : scheduler.template get_next_work(work_tile_info)) { // Attention output (GEMM-II) accumulator. Tensor tOrO = partition_fragment_C(tiled_mma, select<0, 2>(TileShape_MNK{})); float softmax_scale_log2 = params.mainloop.softmax_scale_log2; // If there's tanh softcap, the scaling will be done before tanh. - auto block_coord = work_tile_info.get_block_coord(params.scheduler); + + auto block_coord = scheduler.get_block_coord(work_tile_info); + auto seqlen_info = scheduler.get_seqlen_info(work_tile_info); + int const bidb = get<2>(block_coord); if constexpr (Is_FP8 && !Has_softcap) { int const bidh = get<1>(block_coord); @@ -180,24 +185,17 @@ class FlashAttnFwdSm80 { } flash::Softmax<2 * (2 * kBlockM / NumThreads), /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2); - SeqlenInfo_t seqlen_info{ - bidb, - get<0>(params.mainloop.shape_Q), - !PagedKV ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), - get<0>(params.mainloop.shape_K_new), - params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, - params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, - params.mainloop.seqlens_rotary - }; if constexpr (AppendKV) { bool tile_new_valid = mainloop.store_kv_new( - params.mainloop, threadIdx.x, shared_storage, seqlen_info, block_coord); + params.mainloop, threadIdx.x, shared_storage, seqlen_info, + // upcast + static_cast>(block_coord)); if (tile_new_valid) { __syncthreads(); } } bool tile_valid = mainloop.mma( params.mainloop, tOrO, softmax, threadIdx.x, seqlen_info, block_coord, shared_storage); - scheduler.prefetch_next_work(params.scheduler, work_tile_info); + scheduler.prefetch_next_work(work_tile_info); if (tile_valid) { // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); } epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma, diff --git a/hopper/flash_fwd_kernel_sm90.h b/hopper/flash_fwd_kernel_sm90.h index 47b3817cd..07d750b7c 100644 --- a/hopper/flash_fwd_kernel_sm90.h +++ b/hopper/flash_fwd_kernel_sm90.h @@ -303,7 +303,9 @@ class FlashAttnFwdSm90 { __syncthreads(); } - TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); + TileScheduler scheduler( + reinterpret_cast(&shared_storage.pipelines.smem_scheduler), + params.scheduler); if (warp_group_idx == 0) { // Producer cutlass::arch::warpgroup_reg_dealloc(); @@ -325,20 +327,13 @@ class FlashAttnFwdSm90 { cutlass::arch::wait_on_dependent_grids(); // Load Q, K, V - for (auto work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_initial_work(params.scheduler) : scheduler.template get_initial_work(params.scheduler); - work_tile_info.is_valid(params.scheduler); - work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_next_work(params.scheduler, work_tile_info) : scheduler.template get_next_work(params.scheduler, work_tile_info)) { - - auto block_coord = work_tile_info.get_block_coord(params.scheduler); - SeqlenInfo_t seqlen_info{ - get<2>(block_coord) /*bidb*/, - get<0>(params.mainloop.shape_Q), - !params.mainloop.ptr_pagetable ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), - get<0>(params.mainloop.shape_K_new), - params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, - params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, - params.mainloop.seqlens_rotary - }; + for (auto work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_initial_work() : scheduler.template get_initial_work(); + scheduler.is_valid(work_tile_info); + work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_next_work(work_tile_info) : scheduler.template get_next_work(work_tile_info)) { + + auto block_coord = scheduler.get_block_coord(work_tile_info); + auto seqlen_info = scheduler.get_seqlen_info(work_tile_info); + if constexpr (AppendKV) { bool tile_new_valid = mainloop.load_kv_new( params.mainloop, pipeline_k_new, pipeline_v_new, @@ -349,8 +344,8 @@ class FlashAttnFwdSm90 { // if (threadIdx.x == 0) { printf("Producer: After sync\n"); } } } - auto scheduler_prefetch = [&scheduler, ¶ms, &work_tile_info]() { - scheduler.prefetch_next_work(params.scheduler, work_tile_info); + auto scheduler_prefetch = [&scheduler, &work_tile_info]() { + scheduler.prefetch_next_work(work_tile_info); }; // pipeline_vt won't be used if we don't need to transpose V. mainloop.load(params.mainloop, pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, @@ -373,21 +368,13 @@ class FlashAttnFwdSm90 { int work_idx = 0; CUTLASS_PRAGMA_NO_UNROLL - for (auto work_tile_info = scheduler.template get_initial_work(params.scheduler); - work_tile_info.is_valid(params.scheduler); + for (auto work_tile_info = scheduler.template get_initial_work(); + scheduler.is_valid(work_tile_info); // get_next_work will be called before the epilogue ) { - auto block_coord = work_tile_info.get_block_coord(params.scheduler); - int const bidb = get<2>(block_coord); - SeqlenInfo_t seqlen_info{ - bidb, - get<0>(params.mainloop.shape_Q), - !params.mainloop.ptr_pagetable ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), - get<0>(params.mainloop.shape_K_new), - params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, - params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, - params.mainloop.seqlens_rotary - }; + auto block_coord = scheduler.get_block_coord(work_tile_info); + auto seqlen_info = scheduler.get_seqlen_info(work_tile_info); + if constexpr (AppendKV) { bool tile_new_valid = mainloop.store_kv_new( params.mainloop, pipeline_k_new, pipeline_v_new, smem_pipe_read_new, @@ -407,11 +394,11 @@ class FlashAttnFwdSm90 { // If there's tanh softcap, the scaling will be done before tanh. float softmax_scale_log2 = params.mainloop.softmax_scale_log2; if constexpr (Is_FP8 && !Has_softcap) { - int const bidh = get<1>(block_coord); + int const bidh = block_coord.bidh; + int const bidb = block_coord.bidb; int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh; float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)]; - float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)]; - softmax_scale_log2 *= q_descale * k_descale; + softmax_scale_log2 = params.mainloop.softmax_scale_log2 * q_descale; } flash::Softmax softmax(softmax_scale_log2); // Attention output (GEMM-II) accumulator. @@ -433,9 +420,9 @@ class FlashAttnFwdSm90 { } } // Do this here before the epilogue so that the next tile is ready to go. - work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info); + work_tile_info = scheduler.template get_next_work(work_tile_info); if constexpr (Split && Varlen) { - if (!work_tile_info.is_valid(params.scheduler)) { // Last tile + if (!scheduler.is_valid(work_tile_info)) { // Last tile cutlass::arch::launch_dependent_grids(); } } diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index dab84880b..d05a65cdf 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -46,25 +46,26 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static constexpr int kStages = Arch >= 90 ? 2 : std::get<3>(kBlockMN_kNWarps_Stages_RS); static constexpr bool Q_in_regs = Arch >= 90 ? false : std::get<4>(kBlockMN_kNWarps_Stages_RS); + using SeqlenInfo_t = flash::SeqlenInfoQKNewK; using TileShape_MNK = cute::Shape, Int, Int>; using TileShape_MNK_PV = cute::Shape, Int, Int>; using ClusterShape = cute::Shape, _1, _1>; using CollectiveMainloop = std::conditional_t< Arch >= 90, - flash::CollectiveMainloopFwdSm90, - flash::CollectiveMainloopFwdSm80 + flash::CollectiveMainloopFwdSm90, + flash::CollectiveMainloopFwdSm80 >; using CollectiveEpilogue = flash::CollectiveEpilogueFwd; static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads; using SchedulerPersistent = std::conditional_t= 90 /*WarpSpecialized*/>, + flash::VarlenDynamicPersistentTileScheduler= 90 /*WarpSpecialized*/>, std::conditional_t, - flash::DynamicPersistentTileScheduler= 90 /*WarpSpecialized*/> + flash::StaticPersistentTileScheduler, + flash::DynamicPersistentTileScheduler= 90 /*WarpSpecialized*/> > >; - using SchedulerSingleTile = flash::SingleTileScheduler; + using SchedulerSingleTile = flash::SingleTileScheduler; // If Split then we probably don't have enough work for PersistentScheduler to be useful. // However, if Varlen (e.g., during decode where we have max_seqlens), using PersistentScheduler is better // since we'll avoid launching a bunch of thread blocks that immediately exit. @@ -122,11 +123,8 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { {params.v_descale_batch_stride, params.v_descale_head_stride}, params.window_size_left, params.window_size_right, params.softcap, - params.num_splits, params.kv_batch_idx, - params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, - params.seqused_q, params.seqused_k, - params.leftpad_k, params.seqlens_rotary + is_varlen_k_new, is_varlen_q, is_varlen_k }; typename CollectiveEpilogue::Arguments epilogue_args { static_cast(params.o_ptr), @@ -147,12 +145,21 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { num_blocks_m = cutlass::round_up(num_blocks_m, size<0>(ClusterShape{})); typename flash::TileSchedulerArguments scheduler_args { num_blocks_m, !PackGQA ? params.h : params.h_k, params.b, params.num_splits, - params.h / params.h_k, + qhead_per_khead, params.seqlen_q, params.seqlen_k, params.d, params.dv, sizeof(Element), - params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q, - // params.num_m_blocks_ptr, + params.tile_count_semaphore, + params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, + params.seqused_q, params.seqused_k, + params.leftpad_k, params.seqlens_rotary, + {seqlen_q, params.d, params.h, batch_q}, // shape_Q + {!params.page_table ? (!is_varlen_k ? params.seqlen_k : params.total_k) : params.page_size, + params.d, params.h_k, !params.page_table ? batch_k : params.num_pages}, // shape_K + {!is_varlen_k_new ? params.seqlen_knew : params.total_knew, params.d, params.h_k, !is_varlen_k_new ? params.b : 1}, // shape_K_new + params.page_table, + {params.kv_batch_idx ? params.b_k : params.b, !params.page_table ? 0 : params.seqlen_k / params.page_size}, // shape_page_table params.num_splits_dynamic_ptr, + params.window_size_left, params.window_size_right }; if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) { diff --git a/hopper/mainloop_fwd_sm80.hpp b/hopper/mainloop_fwd_sm80.hpp index 905be872d..c3952a749 100644 --- a/hopper/mainloop_fwd_sm80.hpp +++ b/hopper/mainloop_fwd_sm80.hpp @@ -23,7 +23,7 @@ namespace flash { using namespace cute; -template struct CollectiveMainloopFwdSm80 { @@ -44,7 +44,6 @@ struct CollectiveMainloopFwdSm80 { static constexpr bool AppendKV = AppendKV_; static constexpr bool PackGQA = PackGQA_; static constexpr bool Split = Split_; - static constexpr bool Transpose_V = Is_FP8; static_assert(ArchTag::kMinComputeCapability >= 80); @@ -54,8 +53,6 @@ struct CollectiveMainloopFwdSm80 { static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); - using SeqlenInfo_t = flash::SeqlenInfoQKNewK; - using BlockMN_t = flash::BlockMN; using MMA_Atom_Arch = std::conditional_t< ArchTag::kMinComputeCapability >= 80, @@ -204,15 +201,10 @@ struct CollectiveMainloopFwdSm80 { StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; int const window_size_left = -1, window_size_right = -1; float const softcap_val; - int const num_splits; int const* const kv_batch_idx = nullptr; - int const* const cu_seqlens_q = nullptr; - int const* const cu_seqlens_k = nullptr; - int const* const cu_seqlens_k_new = nullptr; - int const* const seqused_q = nullptr; - int const* const seqused_k = nullptr; - int const* const leftpad_k = nullptr; - int const* const seqlens_rotary = nullptr; + bool const is_varlen_knew = false; + bool const is_varlen_q = false; + bool const is_varlen_k = false; }; // Device side kernel params @@ -249,17 +241,15 @@ struct CollectiveMainloopFwdSm80 { StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; float const softcap_val; int const window_size_left, window_size_right; - int const num_splits; int const* const kv_batch_idx = nullptr; - int const* const cu_seqlens_q = nullptr; - int const* const cu_seqlens_k = nullptr; - int const* const cu_seqlens_k_new = nullptr; - int const* const seqused_q = nullptr; - int const* const seqused_k = nullptr; - int const* const leftpad_k = nullptr; - int const* const seqlens_rotary = nullptr; + bool const is_varlen_knew = false; + bool const is_varlen_q = false; + bool const is_varlen_k = false; }; + using SeqlenInfo_t = SeqlenInfo_t_; + using BlockMN_t = flash::BlockMN; + static Params to_underlying_arguments(Arguments const& args) { // If PackGQA, reshape Q to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size) @@ -294,10 +284,9 @@ struct CollectiveMainloopFwdSm80 { args.stride_q_descale, args.stride_k_descale, args.stride_v_descale, !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, args.window_size_left, args.window_size_right, - !Split ? 1 : args.num_splits, args.kv_batch_idx, - args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, - args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary}; + args.is_varlen_knew, args.is_varlen_q, args.is_varlen_k + }; } template @@ -307,7 +296,8 @@ struct CollectiveMainloopFwdSm80 { Softmax& softmax, int const thread_idx, SeqlenInfo_t const& seqlen_info, - cute::tuple block_coord, + // (m_block, bidh, bidb, n_block_min, n_block_max) + BlockCoord const& block_coord, SharedStorage& shared_storage ) { static_assert(is_rmem::value, "O tensor must be rmem resident."); @@ -315,16 +305,13 @@ struct CollectiveMainloopFwdSm80 { static constexpr int kBlockN = get<1>(TileShape_MNK{}); // can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda - int const m_block = get<0>(block_coord); - int const bidh = get<1>(block_coord); - int const bidb = get<2>(block_coord); - int const split_idx = get<3>(block_coord); + int const m_block = block_coord.m_block; + int const bidh = block_coord.bidh; int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; - auto n_block_min_max = BlockMN_t::get_n_block_min_max( - seqlen_info, m_block, bidb, split_idx, params.num_splits, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); - int const n_block_min = get<0>(n_block_min_max); - int const n_block_max = get<1>(n_block_min_max); + int const bidb = block_coord.bidb; + int const n_block_min = block_coord.n_block_min; + int const n_block_max = block_coord.n_block_max; + // It's possible to have n_block_max <= n_block_min. We don't want to load Q or change any barrier if constexpr (Is_causal || Is_local || Varlen || Split) { if (n_block_max <= n_block_min) { return false; } @@ -335,8 +322,8 @@ struct CollectiveMainloopFwdSm80 { Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{}); Tensor sVt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVt{}); - bool const is_varlen_q = Varlen && params.cu_seqlens_q; - bool const is_varlen_k = Varlen && params.cu_seqlens_k; + bool const is_varlen_q = Varlen && params.is_varlen_q; + bool const is_varlen_k = Varlen && params.is_varlen_k; int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb]; Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q + seqlen_info.offset_q * get<0>(params.stride_Q)), params.shape_Q_packed, params.stride_Q_packed)(_, _, bidh, !is_varlen_q ? bidb : 0); @@ -659,14 +646,14 @@ struct CollectiveMainloopFwdSm80 { int const thread_idx, SharedStorage &shared_storage, SeqlenInfo_t const& seqlen_info, - cute::tuple block_coord + BlockCoord const& block_coord ) { - auto [m_block, bidh, bidb, split_idx] = block_coord; - auto n_block_new_min_max = BlockMN_t::get_n_block_k_new_min_max( - seqlen_info, m_block, bidb, split_idx, params.num_splits, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); - int const n_block_new_min = get<0>(n_block_new_min_max); - int const n_block_new_max = get<1>(n_block_new_min_max); + int const bidh = block_coord.bidh; + int const bidb = block_coord.bidb; + int const n_block_new_min = block_coord.n_block_new_min; + int const n_block_new_max = block_coord.n_block_new_max; + bool const is_varlen_k_new = Varlen && params.is_varlen_knew; + if (n_block_new_max <= n_block_new_min) { return false; } Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); @@ -675,7 +662,6 @@ struct CollectiveMainloopFwdSm80 { int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb]; - bool const is_varlen_k_new = Varlen && params.cu_seqlens_k_new; Tensor mKnew = make_tensor(make_gmem_ptr(params.ptr_K_new), params.shape_K_new, params.stride_K_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); Tensor mVnew = make_tensor(make_gmem_ptr(params.ptr_V_new), params.shape_K_new, params.stride_V_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 68988862e..e706b7955 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -28,7 +28,7 @@ namespace flash { using namespace cute; -template struct CollectiveMainloopFwdSm90 { @@ -69,8 +69,8 @@ struct CollectiveMainloopFwdSm90 { static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); - using SeqlenInfo_t = flash::SeqlenInfoQKNewK; - using BlockMN_t = flash::BlockMN; + using SeqlenInfo_t = SeqlenInfo_t_; + using BlockMN_t = flash::BlockMN; static_assert(!LargeHeadDimV || kHeadDimV % 256 == 0); static_assert(!LargeHeadDimV || kBlockM <= 64, "kBlockM must be 64 or less for large Headdim_V"); @@ -387,15 +387,10 @@ struct CollectiveMainloopFwdSm90 { StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; int const window_size_left = -1, window_size_right = -1; float const softcap_val; - int const num_splits; int const* const kv_batch_idx = nullptr; - int const* const cu_seqlens_q = nullptr; - int const* const cu_seqlens_k = nullptr; - int const* const cu_seqlens_k_new = nullptr; - int const* const seqused_q = nullptr; - int const* const seqused_k = nullptr; - int const* const leftpad_k = nullptr; - int const* const seqlens_rotary = nullptr; + bool const is_varlen_knew = false; + bool const is_varlen_q = false; + bool const is_varlen_k = false; }; // Device side kernel params @@ -443,15 +438,10 @@ struct CollectiveMainloopFwdSm90 { StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; float const softcap_val; int const window_size_left, window_size_right; - int const num_splits; int const* const kv_batch_idx = nullptr; - int const* const cu_seqlens_q = nullptr; - int const* const cu_seqlens_k = nullptr; - int const* const cu_seqlens_k_new = nullptr; - int const* const seqused_q = nullptr; - int const* const seqused_k = nullptr; - int const* const leftpad_k = nullptr; - int const *const seqlens_rotary = nullptr; + bool const is_varlen_knew = false; + bool const is_varlen_q = false; + bool const is_varlen_k = false; }; static Params @@ -557,10 +547,9 @@ struct CollectiveMainloopFwdSm90 { args.stride_q_descale, args.stride_k_descale, args.stride_v_descale, !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, args.window_size_left, args.window_size_right, - !Split ? 1 : args.num_splits, args.kv_batch_idx, - args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, - args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary}; + args.is_varlen_knew, args.is_varlen_q, args.is_varlen_k + }; } /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance @@ -592,18 +581,17 @@ struct CollectiveMainloopFwdSm90 { SharedStorage &shared_storage, SchedulerPrefetch const& scheduler_prefetch, SeqlenInfo_t const& seqlen_info, - cute::tuple block_coord, + BlockCoord const& block_coord, int &work_idx ) { // some of these are captured in lambda so can't use structured binding - int const m_block = get<0>(block_coord); - int const bidh = get<1>(block_coord); - int const bidb = get<2>(block_coord); - int const split_idx = get<3>(block_coord); - auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max( - seqlen_info, m_block, bidb, split_idx, params.num_splits, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + int const m_block = block_coord.m_block; + int const bidh = block_coord.bidh; + int const bidb = block_coord.bidb; + int const n_block_min = block_coord.n_block_min; + int const n_block_max = block_coord.n_block_max; + // It's possible to have n_block_max <= n_block_min. Loading K can cause illegal memory access. if constexpr (Is_causal || Is_local || Varlen || Split) { if (n_block_max <= n_block_min) { @@ -645,8 +633,8 @@ struct CollectiveMainloopFwdSm90 { constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - bool const is_varlen_q = Varlen && params.cu_seqlens_q; - bool const is_varlen_k = Varlen && params.cu_seqlens_k; + bool const is_varlen_q = Varlen && params.is_varlen_q; + bool const is_varlen_k = Varlen && params.is_varlen_k; Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); Tensor mK_TMA = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, _); auto shape_V = make_shape(params.headdim_v, get<0>(params.shape_K), get<2>(params.shape_K), get<3>(params.shape_K)); @@ -955,22 +943,20 @@ struct CollectiveMainloopFwdSm90 { int const thread_idx, int &work_idx, SeqlenInfo_t const& seqlen_info, - cute::tuple block_coord, + BlockCoord const& block_coord, SharedStorage& shared_storage ) { static_assert(is_rmem::value, "O tensor must be rmem resident."); static constexpr int kBlockM = get<0>(TileShape_MNK{}); static constexpr int kBlockN = get<1>(TileShape_MNK{}); - // can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda - int const m_block = get<0>(block_coord); - int const bidh = get<1>(block_coord); - int const bidb = get<2>(block_coord); - int const split_idx = get<3>(block_coord); + int const m_block = block_coord.m_block; + int const bidb = block_coord.bidb; + int const bidh = block_coord.bidh; int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; - auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max( - seqlen_info, m_block, bidb, split_idx, params.num_splits, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + int const n_block_min = block_coord.n_block_min; + int const n_block_max = block_coord.n_block_max; + // It's possible to have n_block_max <= n_block_min. We don't want to load Q or change any barrier if constexpr (Is_causal || Is_local || Varlen || Split) { if (n_block_max <= n_block_min) { return false; } @@ -1357,17 +1343,16 @@ struct CollectiveMainloopFwdSm90 { Softmax& softmax, int const thread_idx, SeqlenInfo_t const& seqlen_info, - cute::tuple block_coord, + BlockCoord const& block_coord, SharedStorage& shared_storage ) { static_assert(is_rmem::value, "O tensor must be rmem resident."); // can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda - int const m_block = get<0>(block_coord); - int const bidb = get<2>(block_coord); - int const split_idx = get<3>(block_coord); - auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max( - seqlen_info, m_block, bidb, split_idx, params.num_splits, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + int const m_block = block_coord.m_block; + int const bidb = block_coord.bidb; + int const n_block_min = block_coord.n_block_min; + int const n_block_max = block_coord.n_block_max; + // It's possible to have n_block_max <= n_block_min. We don't want to load Q or change any barrier if constexpr (Is_causal || Is_local || Varlen || Split) { if (n_block_max <= n_block_min) { return false; } @@ -1446,14 +1431,14 @@ struct CollectiveMainloopFwdSm90 { PipelineState& smem_pipe_write, SharedStorage &shared_storage, SeqlenInfo_t const& seqlen_info, - cute::tuple block_coord, + BlockCoord const& block_coord, int const work_idx ) { - - auto [m_block, bidh, bidb, split_idx] = block_coord; - auto [n_block_new_min, n_block_new_max] = BlockMN_t::get_n_block_k_new_min_max( - seqlen_info, m_block, bidb, split_idx, params.num_splits, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + int const bidh = block_coord.bidh; + int const bidb = block_coord.bidb; + int const n_block_new_min = block_coord.n_block_new_min; + int const n_block_new_max = block_coord.n_block_new_max; + bool const is_varlen_k_new = Varlen && params.is_varlen_knew; if (n_block_new_max <= n_block_new_min) { return false; } @@ -1474,7 +1459,6 @@ struct CollectiveMainloopFwdSm90 { constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - bool const is_varlen_k_new = Varlen && params.cu_seqlens_k_new; Tensor mKnew_TMA = params.tma_load_K_new.get_tma_tensor(params.shape_K_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); auto shape_Vnew = make_shape(params.headdim_v, get<0>(params.shape_K_new), get<2>(params.shape_K_new), get<3>(params.shape_K_new)); Tensor mVnewt_TMA = params.tma_load_V_new.get_tma_tensor(shape_Vnew)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); @@ -1550,12 +1534,13 @@ struct CollectiveMainloopFwdSm90 { int const thread_idx, SharedStorage &shared_storage, SeqlenInfo_t const& seqlen_info, - cute::tuple block_coord + BlockCoord const& block_coord ) { - auto [m_block, bidh, bidb, split_idx] = block_coord; - auto [n_block_new_min, n_block_new_max] = BlockMN_t::get_n_block_k_new_min_max( - seqlen_info, m_block, bidb, split_idx, params.num_splits, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + int const bidh = block_coord.bidh; + int const bidb = block_coord.bidb; + int const n_block_new_min = block_coord.n_block_new_min; + int const n_block_new_max = block_coord.n_block_new_max; + if (n_block_new_max <= n_block_new_min) { return false; } // as_position_independent_swizzle_tensor makes address calculation easier @@ -1572,7 +1557,7 @@ struct CollectiveMainloopFwdSm90 { int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb]; - bool const is_varlen_k = Varlen && params.cu_seqlens_k; + bool const is_varlen_k = Varlen && params.is_varlen_k; Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); auto shape_V = make_shape(params.headdim_v, get<0>(params.shape_K), get<2>(params.shape_K), get<3>(params.shape_K)); Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), shape_V, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index 53651d5c8..326952b7b 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -8,12 +8,17 @@ #include "cutlass/arch/barrier.h" #include "named_barrier.hpp" +#include "block.h" #include "utils.h" +#include "seqlen.h" namespace flash { /////////////////////////////////////////////////////////////////////////////// +using ShapeQKV = cute::Shape; // (seqlen, d, head, batch) +using ShapePageTable = cute::Shape; // (batch, max_num_pages_per_seq) + // Host side kernel arguments struct TileSchedulerArguments { // num_head is num_head_q if not PackGQA, else num_head_k @@ -22,41 +27,92 @@ struct TileSchedulerArguments { int const seqlen; // Only used if Varlen and cu_seqlens == nullptr and seqused == nullptr int const seqlen_k, headdim, headdim_v, element_size; // Used to calculate L2 swizzling int* const tile_count_semaphore = nullptr; - int const* const cu_seqlens = nullptr; - int const* const seqused = nullptr; - // int const* const num_m_blocks_ptr = nullptr; + int const* const cu_seqlens_q = nullptr; + int const* const cu_seqlens_k = nullptr; + int const* const cu_seqlens_k_new = nullptr; + int const* const seqused_q = nullptr; + int const* const seqused_k = nullptr; + int const* const leftpad_k = nullptr; + int const* const seqlens_rotary = nullptr; + ShapeQKV const shape_Q; + ShapeQKV const shape_K; + ShapeQKV const shape_K_new; + int const* const ptr_pagetable = nullptr; + ShapePageTable const shape_pagetable; int const* const num_splits_dynamic_ptr = nullptr; + int const window_size_left = -1; + int const window_size_right = 0; +}; + +template +struct BlockCoord {}; + +template<> +struct BlockCoord { + int const m_block = -1; + int const bidh = -1; + int const bidb = -1; + int const n_block_min = 0; + int const n_block_max = 0; + int const peer_id = 0; // Where to write the partial results / split_k index + int const num_peers = 0; // Number of peers / num_splits +}; + +template<> +struct BlockCoord: public BlockCoord { + int const n_block_new_min = 0; + int const n_block_new_max = 0; }; /////////////////////////////////////////////////////////////////////////////// -template +template class SingleTileScheduler { public: - using SharedStorage = int; + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + using BlockMN_t = flash::BlockMN; // Device side kernel params struct Params { int const num_blocks, num_head, num_batch, num_splits; - int const qhead_per_khead; + cutlass::FastDivmod qhead_per_khead; int const seqlen; cutlass::FastDivmod nsplits_divmod; - int const* const cu_seqlens; - int const* const seqused; + int const* const cu_seqlens_q = nullptr; + int const* const cu_seqlens_k = nullptr; + int const* const cu_seqlens_k_new = nullptr; + int const* const seqused_q = nullptr; + int const* const seqused_k = nullptr; + int const* const leftpad_k = nullptr; + int const* const seqlens_rotary = nullptr; + ShapeQKV shape_Q; + ShapeQKV shape_K; + ShapeQKV shape_K_new; + int const* const ptr_pagetable = nullptr; + ShapePageTable shape_pagetable; int const* const num_splits_dynamic_ptr = nullptr; + int const window_size_left; + int const window_size_right; }; static Params to_underlying_arguments(TileSchedulerArguments const& args) { assert(!Split || !Varlen || args.num_splits_dynamic_ptr != nullptr); - assert(!Split || !Varlen || args.num_splits < (1 << 16)); // We use the top 16 bits to store num_splits + assert(!Split || !Varlen || args.num_splits < (1 << 16)); // We use the top 16 bits to store num_splits in VarlenDynamic return {args.num_blocks, args.num_head, args.num_batch, !Split ? 1 : args.num_splits, - args.qhead_per_khead, args.seqlen, + cutlass::FastDivmod(args.qhead_per_khead), + args.seqlen, cutlass::FastDivmod(!Split ? 1 : args.num_splits), - !Varlen ? nullptr : args.cu_seqlens, !Varlen ? nullptr : args.seqused, - args.num_splits_dynamic_ptr}; + args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, + args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary, + args.shape_Q, args.shape_K, args.shape_K_new, + args.ptr_pagetable, + args.shape_pagetable, + args.num_splits_dynamic_ptr, + args.window_size_left, args.window_size_right}; } static dim3 @@ -64,55 +120,93 @@ class SingleTileScheduler { return {uint32_t(params.num_blocks), uint32_t((!Split ? 1 : params.num_splits) * params.num_head), uint32_t(params.num_batch)}; } + Params const& params; struct WorkTileInfo { - int block_idx = 0; - int bidh = 0; - int bidb = 0; - int split_idx = 0; - - CUTLASS_DEVICE - bool - is_valid(Params const& params) const { - return bidb >= 0; - } + BlockCoord block_coord; + SeqlenInfo_t seqlen_info; + }; - CUTLASS_DEVICE - cute::tuple - get_block_coord(Params const& params) const { - return {block_idx, bidh, bidb, !Split ? 0 : split_idx}; - } + CUTLASS_DEVICE + bool + is_valid(WorkTileInfo const& work_tile) const { + return work_tile.block_coord.bidb >= 0; + } - }; + CUTLASS_DEVICE + BlockCoord + get_block_coord(WorkTileInfo const& work_tile) const { + return work_tile.block_coord; + } + + CUTLASS_DEVICE + SeqlenInfo_t + get_seqlen_info(WorkTileInfo const& work_tile) const { + return work_tile.seqlen_info; + } CUTLASS_DEVICE - SingleTileScheduler(SharedStorage* const smem_scheduler) { } + SingleTileScheduler(SharedStorage* const smem_scheduler, Params const& params) : params(params) { } template CUTLASS_DEVICE WorkTileInfo - get_initial_work(Params const& params) const { - WorkTileInfo work_info {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), 0}; + get_initial_work() const { + int const m_block = blockIdx.x; + int bidb = blockIdx.z; + int bidh = blockIdx.y; + int peer_id = 0; + int num_peers = 1; + + SeqlenInfo_t seqlen_info{ + bidb, + get<0>(params.shape_Q), + !params.ptr_pagetable ? size<0>(params.shape_K) : size<0>(params.shape_K) * size<1>(params.shape_pagetable), + get<0>(params.shape_K_new), + params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_k_new, + params.seqused_q, params.seqused_k, params.leftpad_k, + params.seqlens_rotary + }; + + bool is_valid_tile = true; if constexpr (Split) { - int split_idx; - work_info.bidh = params.nsplits_divmod.divmod(split_idx, work_info.bidh); - work_info.split_idx = split_idx; + peer_id = params.nsplits_divmod.divmod(peer_id, bidh); + num_peers = params.nsplits_divmod.divisor; } - bool is_valid_tile = true; if constexpr (Varlen) { - int seqlen = params.seqused - ? params.seqused[work_info.bidb] - : (params.cu_seqlens ? params.cu_seqlens[work_info.bidb + 1] - params.cu_seqlens[work_info.bidb] : params.seqlen); - if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } - is_valid_tile = work_info.block_idx * kBlock < seqlen; + int seqlen_q_ = params.seqused_q + ? params.seqused_q[bidb] + : (params.cu_seqlens_q ? params.cu_seqlens_q[bidb + 1] - params.cu_seqlens_q[bidb] : params.seqlen); + if constexpr (PackGQA) { seqlen_q_ *= params.qhead_per_khead_divmod.divisor; } + is_valid_tile = m_block * kBlock < seqlen_q_; + + if constexpr (Split) { + int num_splits_dynamic = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[bidb] : num_peers; + is_valid_tile &= peer_id < num_splits_dynamic; + num_peers = num_splits_dynamic; + } } - if constexpr (Varlen && Split) { - int num_splits_dynamic = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[work_info.bidb] : params.num_splits; - is_valid_tile &= work_info.split_idx < num_splits_dynamic; - // Use the top 16 bits to store num_splits - work_info.split_idx |= (num_splits_dynamic << 16); + if (!is_valid_tile) { bidb = -1; } + + // Calculate n_block_min/max based on causality and local window + auto n_block_min_max = BlockMN_t::get_n_block_min_max( + seqlen_info, m_block, bidb, peer_id & 0xFFFF /* Get actual peer_id */, num_peers, + params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + + if constexpr (AppendKV) { + auto n_block_min_max_new = BlockMN_t::get_n_block_k_new_min_max( + seqlen_info, m_block, bidb, peer_id & 0xFFFF /* Get actual peer_id */, num_peers, + params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + return { + {m_block, bidh, bidb, get<0>(n_block_min_max), get<1>(n_block_min_max), peer_id, num_peers, + get<0>(n_block_min_max_new), get<1>(n_block_min_max_new)}, + seqlen_info + }; + } else { + return { + {m_block, bidh, bidb, get<0>(n_block_min_max), get<1>(n_block_min_max), peer_id, num_peers}, + seqlen_info + }; } - work_info.bidb = is_valid_tile ? work_info.bidb : -1; - return work_info; } CUTLASS_DEVICE @@ -121,38 +215,69 @@ class SingleTileScheduler { CUTLASS_DEVICE void - prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} + prefetch_next_work(WorkTileInfo& current_work) const { + } template CUTLASS_DEVICE WorkTileInfo - get_next_work(Params const& params, WorkTileInfo const& current_work) const { - return {0, 0, -1, 0}; + get_next_work(WorkTileInfo const& current_work) const { + return { BlockCoord{}, {} }; } }; /////////////////////////////////////////////////////////////////////////////// -template +template class StaticPersistentTileScheduler { public: using SharedStorage = int; + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + using BlockMN_t = flash::BlockMN; // Device side kernel params struct Params { int total_blocks; cutlass::FastDivmod m_block_divmod, head_divmod; cutlass::FastDivmod nsplits_divmod; + int num_splits; // Static number of splits + cutlass::FastDivmod qhead_per_khead_divmod; + ShapeQKV shape_Q; + ShapeQKV shape_K; + ShapeQKV shape_K_new; + int const* cu_seqlens_q = nullptr; // Assuming null for non-varlen static + int const* cu_seqlens_k = nullptr; + int const* cu_seqlens_k_new = nullptr; + int const* seqused_q = nullptr; + int const* seqused_k = nullptr; + int const* leftpad_k = nullptr; + int const* seqlens_rotary = nullptr; + int const* ptr_pagetable = nullptr; + int window_size_left; + int window_size_right; + // Params specific to this scheduler + int num_m_blocks; // Needed for BlockMN_t }; static Params to_underlying_arguments(TileSchedulerArguments const& args) { - return {args.num_blocks * args.num_head * args.num_batch * (!Split ? 1 : args.num_splits), - cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head * (!Split ? 1 : args.num_splits)), - cutlass::FastDivmod(!Split ? 1 : args.num_splits)}; + int num_splits_val = !Split ? 1 : args.num_splits; + return {args.num_blocks * args.num_head * args.num_batch * num_splits_val, + cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head * num_splits_val), + cutlass::FastDivmod(num_splits_val), + num_splits_val, + cutlass::FastDivmod(args.qhead_per_khead), + args.shape_Q, args.shape_K, args.shape_K_new, + args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, + args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary, + args.ptr_pagetable, + args.window_size_left, args.window_size_right, + args.num_blocks + }; } static dim3 @@ -162,34 +287,80 @@ class StaticPersistentTileScheduler { struct WorkTileInfo { int tile_idx; + }; + + Params const& params; + + CUTLASS_DEVICE + bool + is_valid(WorkTileInfo const& work_tile) const { + return work_tile.tile_idx < params.total_blocks; + } - CUTLASS_DEVICE - bool - is_valid(Params const& params) const { - return tile_idx < params.total_blocks; + CUTLASS_DEVICE + BlockCoord + get_block_coord(WorkTileInfo const& work_tile) const { + int m_block, bidh, bidb; + bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, work_tile.tile_idx)); + int peer_id = 0; + int num_peers = 1; + if constexpr (Split) { + num_peers = params.nsplits_divmod.divisor; + bidh = params.nsplits_divmod.divmod(peer_id, bidh); } - CUTLASS_DEVICE - cute::tuple - get_block_coord(Params const& params) const { - int block, bidh, bidb; - bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(block, tile_idx)); - int split_idx = 0; - if constexpr (Split) { - bidh = params.nsplits_divmod.divmod(split_idx, bidh); - } - return {block, bidh, bidb, split_idx}; + SeqlenInfo_t seqlen_info{ + bidb, + get<0>(params.shape_Q), + !params.ptr_pagetable ? size<0>(params.shape_K) : size<0>(params.shape_K) * size<1>(params.shape_pagetable), + get<0>(params.shape_K_new), + params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_k_new, + params.seqused_q, params.seqused_k, params.leftpad_k, + params.seqlens_rotary + }; + + auto n_block_min_max = BlockMN_t::get_n_block_min_max( + seqlen_info, m_block, bidb, peer_id, num_peers, + params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + + if constexpr (AppendKV) { + auto n_block_min_max_new = BlockMN_t::get_n_block_k_new_min_max( + seqlen_info, m_block, bidb, peer_id, num_peers, + params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + return {m_block, bidh, bidb, get<0>(n_block_min_max), get<1>(n_block_min_max), peer_id, num_peers, + get<0>(n_block_min_max_new), get<1>(n_block_min_max_new)}; + } else { + return {m_block, bidh, bidb, get<0>(n_block_min_max), get<1>(n_block_min_max), peer_id, num_peers}; } + } + + CUTLASS_DEVICE + SeqlenInfo_t + get_seqlen_info(WorkTileInfo const& work_tile) const { + // Recompute block coord parts needed for SeqlenInfo + int m_block, bidh, bidb; + bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, work_tile.tile_idx)); + // No need for split info here as SeqlenInfo is per batch item (bidb) + + return SeqlenInfo_t { + bidb, + get<0>(params.shape_Q), + !params.ptr_pagetable ? size<0>(params.shape_K) : size<0>(params.shape_K) * size<1>(params.shape_pagetable), + get<0>(params.shape_K_new), + params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_k_new, + params.seqused_q, params.seqused_k, params.leftpad_k, + params.seqlens_rotary + }; + } - }; CUTLASS_DEVICE - StaticPersistentTileScheduler(SharedStorage* const smem_scheduler) {}; + StaticPersistentTileScheduler(SharedStorage* const smem_scheduler, Params const& params) : params(params) {}; template CUTLASS_DEVICE WorkTileInfo - get_initial_work(Params const& params) const { + get_initial_work() const { return {int(blockIdx.x)}; } @@ -199,19 +370,19 @@ class StaticPersistentTileScheduler { CUTLASS_DEVICE void - prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} + prefetch_next_work(WorkTileInfo& current_work) const { + } template CUTLASS_DEVICE WorkTileInfo - get_next_work(Params const& params, WorkTileInfo const& current_work) const { + get_next_work(WorkTileInfo const& current_work) const { return {current_work.tile_idx + int(gridDim.x)}; } }; -template +template class DynamicPersistentTileScheduler { // This scheduler targets the causal (or local) case where each tile takes different @@ -228,6 +399,9 @@ class DynamicPersistentTileScheduler { public: using SharedStorage = int; + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + using BlockMN_t = flash::BlockMN; protected: SharedStorage* const tile_count_smem; @@ -242,6 +416,28 @@ class DynamicPersistentTileScheduler { cutlass::FastDivmod const l2_minor_residual_divmod; int const num_hb_quotient; int* const tile_count_semaphore; + int const num_splits; // Static number of splits + cutlass::FastDivmod qhead_per_khead_divmod; + ShapeQKV shape_Q; + ShapeQKV shape_K; + ShapeQKV shape_K_new; + int const* cu_seqlens_q = nullptr; // Assuming null for non-varlen dynamic + int const* cu_seqlens_k = nullptr; + int const* cu_seqlens_k_new = nullptr; + int const* seqused_q = nullptr; + int const* seqused_k = nullptr; + int const* leftpad_k = nullptr; + int const* seqlens_rotary = nullptr; + int const* ptr_pagetable = nullptr; + int window_size_left; + int window_size_right; + // Params needed for L2 swizzling calculation + int const seqlen_k; + int const headdim; + int const headdim_v; + int const element_size; + // Num M blocks + int num_m_blocks; }; static Params @@ -256,6 +452,7 @@ class DynamicPersistentTileScheduler { // swizzle. Instead we want to divide by the remainder. int const num_hb_remainder = (args.num_head * args.num_batch) % swizzle; int const num_split_blocks = args.num_blocks * (!Split ? 1 : args.num_splits); + int const num_splits_val = !Split ? 1 : args.num_splits; // printf("num_split_blocks = %d, num_head = %d, num_batch = %d, swizzle = %d, PackGQA = %d, qhead_per_khead = %d, num_hb_remainder = %d\n", num_split_blocks, args.num_head, args.num_batch, swizzle, int(PackGQA), args.qhead_per_khead, num_hb_remainder); assert(args.tile_count_semaphore != nullptr); return {num_split_blocks * args.num_head * args.num_batch, @@ -264,7 +461,17 @@ class DynamicPersistentTileScheduler { // don't divide by 0 cutlass::FastDivmod(num_hb_remainder > 0 ? num_hb_remainder : 1), (args.num_head * args.num_batch) / swizzle, - args.tile_count_semaphore}; + args.tile_count_semaphore, + num_splits_val, + cutlass::FastDivmod(args.qhead_per_khead), + args.shape_Q, args.shape_K, args.shape_K_new, + args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, + args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary, + args.ptr_pagetable, + args.window_size_left, args.window_size_right, + args.seqlen_k, args.headdim, args.headdim_v, args.element_size, + args.num_blocks + }; } static dim3 @@ -274,45 +481,97 @@ class DynamicPersistentTileScheduler { struct WorkTileInfo { int tile_idx; + }; + + Params const& params; - CUTLASS_DEVICE - bool - is_valid(Params const& params) const { - return tile_idx < params.total_blocks; + CUTLASS_DEVICE + bool + is_valid(WorkTileInfo const& work_tile) const { + return work_tile.tile_idx < params.total_blocks; + } + + CUTLASS_DEVICE + BlockCoord + get_block_coord(WorkTileInfo const& work_tile) const { + int m_block, bidh, bidb; + int l2_mod, bidhb, bidhb_residual; + bidhb = params.l2_major_divmod.divmod(l2_mod, work_tile.tile_idx); + // If we're in the last section (called residual), we don't want to divide by + // swizzle. Instead we want to divide by the remainder. + if (bidhb < params.num_hb_quotient) { + m_block = params.l2_minor_divmod.divmod(bidhb_residual, l2_mod); + } else { + m_block = params.l2_minor_residual_divmod.divmod(bidhb_residual, l2_mod); } + bidb = params.head_divmod.divmod(bidh, bidhb * params.l2_minor_divmod.divisor + bidhb_residual); + int peer_id = 0; + int num_peers = 1; + if constexpr (Split) { + num_peers = params.num_splits; // Static splits for Dynamic scheduler + peer_id = params.m_block_divmod.divmod(m_block, m_block); // This uses m_block_divmod's divisor (num_m_blocks) + } + // Longest-processing-time-first means we process m_blocks in reverse order + m_block = params.m_block_divmod.divisor - 1 - m_block; + + SeqlenInfo_t seqlen_info{ + bidb, + get<0>(params.shape_Q), + !params.ptr_pagetable ? size<0>(params.shape_K) : size<0>(params.shape_K) * size<1>(params.shape_pagetable), + get<0>(params.shape_K_new), + params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_k_new, + params.seqused_q, params.seqused_k, params.leftpad_k, + params.seqlens_rotary + }; - CUTLASS_DEVICE - cute::tuple - get_block_coord(Params const& params) const { - int block, bidh, bidb; - int l2_mod, bidhb, bidhb_residual; - bidhb = params.l2_major_divmod.divmod(l2_mod, tile_idx); - // If we're in the last section (called residual), we don't want to divide by - // swizzle. Instead we want to divide by the remainder. - if (bidhb < params.num_hb_quotient) { - block = params.l2_minor_divmod.divmod(bidhb_residual, l2_mod); - } else { - block = params.l2_minor_residual_divmod.divmod(bidhb_residual, l2_mod); - } - bidb = params.head_divmod.divmod(bidh, bidhb * params.l2_minor_divmod.divisor + bidhb_residual); - int split_idx = 0; - if constexpr (Split) { - split_idx = params.m_block_divmod.divmod(block, block); - } - // Longest-processing-time-first - block = params.m_block_divmod.divisor - 1 - block; - return {block, bidh, bidb, split_idx}; + auto n_block_min_max = BlockMN_t::get_n_block_min_max( + seqlen_info, m_block, bidb, peer_id, num_peers, + params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + + if constexpr (AppendKV) { + auto n_block_min_max_new = BlockMN_t::get_n_block_k_new_min_max( + seqlen_info, m_block, bidb, peer_id, num_peers, + params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + return {m_block, bidh, bidb, get<0>(n_block_min_max), get<1>(n_block_min_max), peer_id, num_peers, + get<0>(n_block_min_max_new), get<1>(n_block_min_max_new)}; + } else { + return {m_block, bidh, bidb, get<0>(n_block_min_max), get<1>(n_block_min_max), peer_id, num_peers}; } + } + + CUTLASS_DEVICE + SeqlenInfo_t + get_seqlen_info(WorkTileInfo const& work_tile) const { + // Recompute block coord parts needed for SeqlenInfo + int m_block, bidh, bidb; + int l2_mod, bidhb, bidhb_residual; + bidhb = params.l2_major_divmod.divmod(l2_mod, work_tile.tile_idx); + if (bidhb < params.num_hb_quotient) { + m_block = params.l2_minor_divmod.divmod(bidhb_residual, l2_mod); + } else { + m_block = params.l2_minor_residual_divmod.divmod(bidhb_residual, l2_mod); + } + bidb = params.head_divmod.divmod(bidh, bidhb * params.l2_minor_divmod.divisor + bidhb_residual); + + return SeqlenInfo_t{ + bidb, + get<0>(params.shape_Q), + !params.ptr_pagetable ? size<0>(params.shape_K) : size<0>(params.shape_K) * size<1>(params.shape_pagetable), + get<0>(params.shape_K_new), + params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_k_new, + params.seqused_q, params.seqused_k, params.leftpad_k, + params.seqlens_rotary + }; + } - }; CUTLASS_DEVICE - DynamicPersistentTileScheduler(SharedStorage* const smem_scheduler) : tile_count_smem(smem_scheduler) {}; + DynamicPersistentTileScheduler(SharedStorage* const smem_scheduler, Params const& params) : tile_count_smem(smem_scheduler), params(params) {}; template CUTLASS_DEVICE WorkTileInfo - get_initial_work(Params const& params) const { + get_initial_work() const { return {int(blockIdx.x)}; } @@ -326,7 +585,7 @@ class DynamicPersistentTileScheduler { CUTLASS_DEVICE void - prefetch_next_work(Params const& params, WorkTileInfo& current_work) const { + prefetch_next_work(WorkTileInfo& current_work) const { if (threadIdx.x % NumProducerThreads == 0) { current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x); } @@ -335,16 +594,17 @@ class DynamicPersistentTileScheduler { template CUTLASS_DEVICE WorkTileInfo - get_next_work(Params const& params, WorkTileInfo const& current_work) const { + get_next_work(WorkTileInfo const& current_work) const { if constexpr (IsProducerWarp) { // thread 0 already has the right tile_idx, just need to broadcast to the rest of warp 0 int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/); + WorkTileInfo work_info = tile_idx_to_work_tile(new_tile_idx, current_work); flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty if (threadIdx.x % NumProducerThreads == 0) { *tile_count_smem = current_work.tile_idx; } flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull - return {new_tile_idx}; + return work_info; } else { flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull int tile_idx = *tile_count_smem; @@ -355,7 +615,10 @@ class DynamicPersistentTileScheduler { }; -template +template class VarlenDynamicPersistentTileScheduler { static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads); @@ -363,6 +626,9 @@ class VarlenDynamicPersistentTileScheduler { public: using SharedStorage = int4; + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + using BlockMN_t = flash::BlockMN; protected: SharedStorage* const work_info_smem; @@ -372,15 +638,32 @@ class VarlenDynamicPersistentTileScheduler { // Device side kernel params struct Params { int num_head, num_batch; - int const qhead_per_khead; - int const seqlen; + cutlass::FastDivmod qhead_per_khead_divmod; + int const seqlen; // Max seqlen cutlass::FastDivmod head_divmod; - cutlass::FastDivmod nsplits_divmod; + cutlass::FastDivmod nsplits_divmod; // Static num splits divisor + int const num_splits; // Static num splits int* const tile_count_semaphore; - int const* const cu_seqlens; - int const* const seqused; - // int* const num_m_blocks_ptr; - int const* const num_splits_dynamic_ptr; + // Sequence length info (needed for tile_idx_to_work_tile and SeqlenInfo) + int const* const cu_seqlens_q = nullptr; + int const* const cu_seqlens_k = nullptr; + int const* const cu_seqlens_k_new = nullptr; + int const* const seqused_q = nullptr; + int const* const seqused_k = nullptr; + int const* const leftpad_k = nullptr; + int const* const seqlens_rotary = nullptr; + // Shape info (needed for SeqlenInfo) + ShapeQKV const shape_Q; + ShapeQKV const shape_K; + ShapeQKV const shape_K_new; + // Paged attention table (needed for SeqlenInfo) + int const* const ptr_pagetable = nullptr; + ShapePageTable const shape_pagetable; + // Dynamic splits for Varlen + int const* const num_splits_dynamic_ptr = nullptr; + // Window sizes for local/causal attention (needed for BlockMN_t) + int const window_size_left; + int const window_size_right; }; static Params @@ -390,13 +673,22 @@ class VarlenDynamicPersistentTileScheduler { assert(args.tile_count_semaphore != nullptr); assert(args.num_head < (1 << 16)); // We use the top 16 bits to store num_splits & split_idx assert(!Split || args.num_splits < (1 << 8)); // We use the top 8 bits to store num_splits + int const num_splits_val = !Split ? 1 : args.num_splits; return {args.num_head, args.num_batch, - args.qhead_per_khead, args.seqlen, + cutlass::FastDivmod(args.qhead_per_khead), + args.seqlen, cutlass::FastDivmod(args.num_head), - cutlass::FastDivmod(!Split ? 1 : args.num_splits), - args.tile_count_semaphore, args.cu_seqlens, args.seqused, - // args.num_m_blocks_ptr, - args.num_splits_dynamic_ptr}; + cutlass::FastDivmod(num_splits_val), + num_splits_val, + args.tile_count_semaphore, + args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, + args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary, + args.shape_Q, args.shape_K, args.shape_K_new, + args.ptr_pagetable, + args.shape_pagetable, + args.num_splits_dynamic_ptr, + args.window_size_left, args.window_size_right + }; } static dim3 @@ -406,61 +698,108 @@ class VarlenDynamicPersistentTileScheduler { struct WorkTileInfo { int tile_idx, block, bidh, bidb; + }; + + Params const& params; + + CUTLASS_DEVICE + bool + is_valid(WorkTileInfo const& work_tile) const { + // if (blockIdx.x >= 0 && (threadIdx.x == 128 || threadIdx.x == 0)) { printf("blockIdx.x = %d, threadIdx.x = %d, checking valid, bidb = %d, params.num_batch = %d\n", blockIdx.x, threadIdx.x, work_tile.bidb, params.num_batch); } + return work_tile.bidb >= 0 && work_tile.bidb < params.num_batch; + } + + CUTLASS_DEVICE + BlockCoord + get_block_coord(WorkTileInfo const& work_tile) const { + int m_block = work_tile.block; + int bidh_in = work_tile.bidh; // This might be packed + int bidb = work_tile.bidb; + int peer_id = 0; + int num_peers = 1; + int bidh_actual = bidh_in; - CUTLASS_DEVICE - bool - is_valid(Params const& params) const { - // if (blockIdx.x >= 0 && (threadIdx.x == 128 || threadIdx.x == 0)) { printf("blockIdx.x = %d, threadIdx.x = %d, checking valid, bidb = %d, params.num_batch = %d\n", blockIdx.x, threadIdx.x, bidb, params.num_batch); } - return bidb < params.num_batch; + if constexpr (Split) { + // the top 8 bits of bidh store num_splits and the next 8 bits store split_idx + // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift + uint32_t bidh_packed = reinterpret_cast(bidh_in); + uint32_t bidh_actual_u = bidh_packed & 0x0000FFFF; + bidh_actual = reinterpret_cast(bidh_actual_u); + // Use the top 16 bits of split_idx to store num_splits and the next 16 bits to store split_idx + // Extract split_idx (lower 8 bits of upper 16) and num_splits (upper 8 bits of upper 16) + uint32_t split_idx_u = (bidh_packed >> 16) & 0xFF; + uint32_t num_peers_u = (bidh_packed >> 24) & 0xFF; + peer_id = reinterpret_cast(split_idx_u); + num_peers = reinterpret_cast(num_peers_u); } - CUTLASS_DEVICE - cute::tuple - get_block_coord(Params const& params) const { - if constexpr (!Split) { - return {block, bidh, bidb, 0 /*split_idx*/}; - } else { - // the top 8 bits of bidh store num_splits and the next 8 bits store split_idx - // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift - uint32_t bidh_packed = reinterpret_cast(bidh); - uint32_t bidh_actual_u = bidh_packed & 0x0000FFFF; - int bidh_actual = reinterpret_cast(bidh_actual_u); - // Use the top 16 bits of split_idx to store num_splits and the next 16 bits to store split_idx - uint32_t split_idx_u = ((bidh_packed & 0x00FF0000) >> 16) + ((bidh_packed & 0xFF000000) >> 8); - int split_idx = reinterpret_cast(split_idx_u); - // int bidh_actual = params.nsplits_divmod.divmod(split_idx, bidh); - // if (threadIdx.x == 128) { - // printf("blockIdx.x = %d, bidb = %d, bidh = %d, bidh_actual = %d, split_idx = %d\n", blockIdx.x, bidb, bidh, bidh_actual, split_idx); - // } - return {block, bidh_actual, bidb, split_idx}; - } + SeqlenInfo_t seqlen_info { + bidb, + get<0>(params.shape_Q), + !params.ptr_pagetable ? size<0>(params.shape_K) : size<0>(params.shape_K) * size<1>(params.shape_pagetable), + get<0>(params.shape_K_new), + params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_k_new, + params.seqused_q, params.seqused_k, params.leftpad_k, + params.seqlens_rotary + }; + + auto n_block_min_max = BlockMN_t::get_n_block_min_max( + seqlen_info, m_block, bidb, peer_id & 0xFFFF /* actual peer id*/, num_peers, + params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + + if constexpr (AppendKV) { + auto n_block_min_max_new = BlockMN_t::get_n_block_k_new_min_max( + seqlen_info, m_block, bidb, peer_id & 0xFFFF /* actual peer id*/, num_peers, + params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + return {m_block, bidh_actual, bidb, get<0>(n_block_min_max), get<1>(n_block_min_max), peer_id, num_peers, + get<0>(n_block_min_max_new), get<1>(n_block_min_max_new)}; + } else { + return {m_block, bidh_actual, bidb, get<0>(n_block_min_max), get<1>(n_block_min_max), peer_id, num_peers}; } - }; + } + + CUTLASS_DEVICE + SeqlenInfo_t + get_seqlen_info(WorkTileInfo const& work_tile) const { + // Extract bidb needed for SeqlenInfo + int bidb = work_tile.bidb; + + return SeqlenInfo_t { + bidb, + get<0>(params.shape_Q), + !params.ptr_pagetable ? size<0>(params.shape_K) : size<0>(params.shape_K) * size<1>(params.shape_pagetable), + get<0>(params.shape_K_new), + params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_k_new, + params.seqused_q, params.seqused_k, params.leftpad_k, + params.seqlens_rotary + }; + } + CUTLASS_DEVICE - VarlenDynamicPersistentTileScheduler(SharedStorage* const smem_scheduler) : work_info_smem(smem_scheduler) {}; + VarlenDynamicPersistentTileScheduler(SharedStorage* const smem_scheduler, Params const& params) : work_info_smem(smem_scheduler), params(params) {}; CUTLASS_DEVICE WorkTileInfo - tile_idx_to_work_tile(Params const& params, int next_tile_idx, WorkTileInfo const& current_work) const { + tile_idx_to_work_tile(int next_tile_idx, WorkTileInfo const& current_work) const { int lane = threadIdx.x % cutlass::NumThreadsPerWarp; auto get_num_m_blocks = [&] (int bidb_start) { int batch_idx = lane + bidb_start; - int seqlen = params.seqlen * (!PackGQA ? 1 : params.qhead_per_khead); - if (seqlen > kBlock) { - if (params.seqused) { - seqlen = batch_idx < params.num_batch ? params.seqused[batch_idx] : 0; - } else if (params.cu_seqlens) { - int cur_cu_seqlen = batch_idx <= params.num_batch ? params.cu_seqlens[batch_idx] : 0; + int seqlen = params.seqlen * (!PackGQA ? 1 : params.qhead_per_khead_divmod.divisor); + if (seqlen > get<0>(TileShape_MNK{})) { + if (params.seqused_q) { + seqlen = batch_idx < params.num_batch ? params.seqused_q[batch_idx] : 0; + } else if (params.cu_seqlens_q) { + int cur_cu_seqlen = batch_idx <= params.num_batch ? params.cu_seqlens_q[batch_idx] : 0; int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); seqlen = next_cu_seqlen - cur_cu_seqlen; } else { seqlen = params.seqlen; } - if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } + if constexpr (PackGQA) { seqlen *= params.qhead_per_khead_divmod.divisor; } } return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 - ? cute::ceil_div(seqlen, kBlock) : 0; + ? cute::ceil_div(seqlen, get<0>(TileShape_MNK{})) : 0; // ? params.num_m_blocks_ptr[batch_idx] : 0; }; @@ -544,16 +883,16 @@ class VarlenDynamicPersistentTileScheduler { template CUTLASS_DEVICE WorkTileInfo - get_initial_work(Params const& params) const { + get_initial_work() const { if constexpr (IsProducerWarp) { - WorkTileInfo work_info = tile_idx_to_work_tile(params, int(blockIdx.x), {0, 0, 0, 0}); + WorkTileInfo work_info = tile_idx_to_work_tile(int(blockIdx.x), {0, 0, 0, 0}); if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb); } flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull return work_info; } else { - return get_next_work(params, {0, 0, 0, 0}); + return get_next_work({0, 0, 0, 0}); } } @@ -565,7 +904,7 @@ class VarlenDynamicPersistentTileScheduler { CUTLASS_DEVICE void - prefetch_next_work(Params const& params, WorkTileInfo& current_work) const { + prefetch_next_work(WorkTileInfo& current_work) const { if (threadIdx.x % NumProducerThreads == 0) { current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x); } @@ -574,12 +913,12 @@ class VarlenDynamicPersistentTileScheduler { template CUTLASS_DEVICE WorkTileInfo - get_next_work(Params const& params, WorkTileInfo const& current_work) const { + get_next_work(WorkTileInfo const& current_work) const { if constexpr (IsProducerWarp) { // thread 0 has the next tile_idx, just need to broadcast to the rest of warp 0 int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/); WorkTileInfo work_info = {__shfl_sync(0xffffffff, current_work.tile_idx, 1 /*lane*/), current_work.block, current_work.bidh, current_work.bidb}; - work_info = tile_idx_to_work_tile(params, new_tile_idx, work_info); + work_info = tile_idx_to_work_tile(new_tile_idx, work_info); flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb); From 57635db03a3c430962ca42985c54c744b5248655 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 14 Apr 2025 19:50:20 +0000 Subject: [PATCH 099/102] add `TileSchedulerCommon` Signed-off-by: Lucas Wilkinson --- hopper/flash_fwd_launch_template.h | 2 +- hopper/tile_scheduler.hpp | 389 +++++++++-------------------- 2 files changed, 122 insertions(+), 269 deletions(-) diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index d05a65cdf..f78b37143 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -65,7 +65,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { flash::DynamicPersistentTileScheduler= 90 /*WarpSpecialized*/> > >; - using SchedulerSingleTile = flash::SingleTileScheduler; + using SchedulerSingleTile = flash::SingleTileScheduler; // If Split then we probably don't have enough work for PersistentScheduler to be useful. // However, if Varlen (e.g., during decode where we have max_seqlens), using PersistentScheduler is better // since we'll avoid launching a bunch of thread blocks that immediately exit. diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index 326952b7b..ad15c4a7f 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -66,21 +66,16 @@ struct BlockCoord: public BlockCoord { /////////////////////////////////////////////////////////////////////////////// -template -class SingleTileScheduler { - -public: - using SharedStorage = int; +template +struct TileSchedulerCommon { static constexpr int kBlockM = get<0>(TileShape_MNK{}); static constexpr int kBlockN = get<1>(TileShape_MNK{}); using BlockMN_t = flash::BlockMN; - // Device side kernel params struct Params { int const num_blocks, num_head, num_batch, num_splits; cutlass::FastDivmod qhead_per_khead; int const seqlen; - cutlass::FastDivmod nsplits_divmod; int const* const cu_seqlens_q = nullptr; int const* const cu_seqlens_k = nullptr; int const* const cu_seqlens_k_new = nullptr; @@ -98,6 +93,11 @@ class SingleTileScheduler { int const window_size_right; }; + Params const& params; + + CUTLASS_DEVICE + TileSchedulerCommon(Params const& params) : params(params) { } + static Params to_underlying_arguments(TileSchedulerArguments const& args) { assert(!Split || !Varlen || args.num_splits_dynamic_ptr != nullptr); @@ -105,7 +105,6 @@ class SingleTileScheduler { return {args.num_blocks, args.num_head, args.num_batch, !Split ? 1 : args.num_splits, cutlass::FastDivmod(args.qhead_per_khead), args.seqlen, - cutlass::FastDivmod(!Split ? 1 : args.num_splits), args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary, args.shape_Q, args.shape_K, args.shape_K_new, @@ -114,13 +113,66 @@ class SingleTileScheduler { args.num_splits_dynamic_ptr, args.window_size_left, args.window_size_right}; } + CUTLASS_DEVICE + SeqlenInfo_t + create_seqlen_info(int bidb) { + return { + bidb, + get<0>(params.shape_Q), + !params.ptr_pagetable ? size<0>(params.shape_K) : size<0>(params.shape_K) * size<1>(params.shape_pagetable), + get<0>(params.shape_K_new), + params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_k_new, + params.seqused_q, params.seqused_k, params.leftpad_k, + params.seqlens_rotary + }; + }; + + CUTLASS_DEVICE + BlockCoord + create_block_coord_split(SeqlenInfo_t const& seqlen_info, int m_block, int bidh, int bidb, int peer_id, int num_peers) { + // Calculate n_block_min/max based on causality and local window + auto n_block_min_max = BlockMN_t::get_n_block_min_max( + seqlen_info, m_block, bidb, peer_id, num_peers, + params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + + if constexpr (AppendKV) { + auto n_block_min_max_new = BlockMN_t::get_n_block_k_new_min_max( + seqlen_info, m_block, bidb, peer_id, num_peers, + params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + return {m_block, bidh, bidb, get<0>(n_block_min_max), get<1>(n_block_min_max), peer_id, num_peers, + get<0>(n_block_min_max_new), get<1>(n_block_min_max_new)}; + } else { + return {m_block, bidh, bidb, get<0>(n_block_min_max), get<1>(n_block_min_max), peer_id, num_peers}; + } + }; +}; + +template +class SingleTileScheduler: public TileSchedulerCommon { + +public: + using SharedStorage = int; + using Super = TileSchedulerCommon; + + struct Params: public Super::Params { + cutlass::FastDivmod nsplits_divmod; + }; + + using Super::create_seqlen_info; + using Super::create_block_coord_split; + + Params const& params; + + static Params + to_underlying_arguments(TileSchedulerArguments const& args) { + return {Super::to_underlying_arguments(args), cutlass::FastDivmod(!Split ? 1 : args.num_splits)}; + } static dim3 get_grid_shape(Params const& params, int num_sm) { return {uint32_t(params.num_blocks), uint32_t((!Split ? 1 : params.num_splits) * params.num_head), uint32_t(params.num_batch)}; } - Params const& params; struct WorkTileInfo { BlockCoord block_coord; SeqlenInfo_t seqlen_info; @@ -145,7 +197,7 @@ class SingleTileScheduler { } CUTLASS_DEVICE - SingleTileScheduler(SharedStorage* const smem_scheduler, Params const& params) : params(params) { } + SingleTileScheduler(SharedStorage* const smem_scheduler, Params const& params) : Super(params), params(params) { } template CUTLASS_DEVICE @@ -157,15 +209,7 @@ class SingleTileScheduler { int peer_id = 0; int num_peers = 1; - SeqlenInfo_t seqlen_info{ - bidb, - get<0>(params.shape_Q), - !params.ptr_pagetable ? size<0>(params.shape_K) : size<0>(params.shape_K) * size<1>(params.shape_pagetable), - get<0>(params.shape_K_new), - params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_k_new, - params.seqused_q, params.seqused_k, params.leftpad_k, - params.seqlens_rotary - }; + SeqlenInfo_t seqlen_info = create_seqlen_info(bidb); bool is_valid_tile = true; if constexpr (Split) { @@ -177,7 +221,7 @@ class SingleTileScheduler { ? params.seqused_q[bidb] : (params.cu_seqlens_q ? params.cu_seqlens_q[bidb + 1] - params.cu_seqlens_q[bidb] : params.seqlen); if constexpr (PackGQA) { seqlen_q_ *= params.qhead_per_khead_divmod.divisor; } - is_valid_tile = m_block * kBlock < seqlen_q_; + is_valid_tile = m_block * size<0>(TileShape_MNK{}) < seqlen_q_; if constexpr (Split) { int num_splits_dynamic = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[bidb] : num_peers; @@ -185,28 +229,12 @@ class SingleTileScheduler { num_peers = num_splits_dynamic; } } - if (!is_valid_tile) { bidb = -1; } + if (!is_valid_tile) return {BlockCoord{}, seqlen_info}; - // Calculate n_block_min/max based on causality and local window - auto n_block_min_max = BlockMN_t::get_n_block_min_max( - seqlen_info, m_block, bidb, peer_id & 0xFFFF /* Get actual peer_id */, num_peers, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); - - if constexpr (AppendKV) { - auto n_block_min_max_new = BlockMN_t::get_n_block_k_new_min_max( - seqlen_info, m_block, bidb, peer_id & 0xFFFF /* Get actual peer_id */, num_peers, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); - return { - {m_block, bidh, bidb, get<0>(n_block_min_max), get<1>(n_block_min_max), peer_id, num_peers, - get<0>(n_block_min_max_new), get<1>(n_block_min_max_new)}, - seqlen_info - }; - } else { - return { - {m_block, bidh, bidb, get<0>(n_block_min_max), get<1>(n_block_min_max), peer_id, num_peers}, - seqlen_info - }; - } + return { + create_block_coord_split(seqlen_info, m_block, bidh, bidb, peer_id, num_peers), + seqlen_info + }; } CUTLASS_DEVICE @@ -230,67 +258,44 @@ class SingleTileScheduler { /////////////////////////////////////////////////////////////////////////////// template -class StaticPersistentTileScheduler { +class StaticPersistentTileScheduler: public TileSchedulerCommon { public: - using SharedStorage = int; - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - using BlockMN_t = flash::BlockMN; + using Super = TileSchedulerCommon; // Device side kernel params - struct Params { + struct Params: public Super::Params { int total_blocks; cutlass::FastDivmod m_block_divmod, head_divmod; cutlass::FastDivmod nsplits_divmod; - int num_splits; // Static number of splits - cutlass::FastDivmod qhead_per_khead_divmod; - ShapeQKV shape_Q; - ShapeQKV shape_K; - ShapeQKV shape_K_new; - int const* cu_seqlens_q = nullptr; // Assuming null for non-varlen static - int const* cu_seqlens_k = nullptr; - int const* cu_seqlens_k_new = nullptr; - int const* seqused_q = nullptr; - int const* seqused_k = nullptr; - int const* leftpad_k = nullptr; - int const* seqlens_rotary = nullptr; - int const* ptr_pagetable = nullptr; - int window_size_left; - int window_size_right; - // Params specific to this scheduler - int num_m_blocks; // Needed for BlockMN_t }; + Params const& params; + + using Super::create_seqlen_info; + using Super::create_block_coord_split; + static Params to_underlying_arguments(TileSchedulerArguments const& args) { int num_splits_val = !Split ? 1 : args.num_splits; - return {args.num_blocks * args.num_head * args.num_batch * num_splits_val, + return { Super::to_underlying_arguments(args), + args.num_blocks * args.num_head * args.num_batch * num_splits_val, cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head * num_splits_val), - cutlass::FastDivmod(num_splits_val), - num_splits_val, - cutlass::FastDivmod(args.qhead_per_khead), - args.shape_Q, args.shape_K, args.shape_K_new, - args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, - args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary, - args.ptr_pagetable, - args.window_size_left, args.window_size_right, - args.num_blocks + cutlass::FastDivmod(num_splits_val) }; } static dim3 get_grid_shape(Params const& params, int num_sm) { - return {uint32_t(num_sm)}; + // This scheduler assumes that the grid shape is fixed and known at compile time. + return {uint32_t(num_sm), 1u, 1u}; // Persistent kernel uses SM count as grid dim } struct WorkTileInfo { int tile_idx; }; - Params const& params; - CUTLASS_DEVICE bool is_valid(WorkTileInfo const& work_tile) const { @@ -306,56 +311,27 @@ class StaticPersistentTileScheduler { int num_peers = 1; if constexpr (Split) { num_peers = params.nsplits_divmod.divisor; - bidh = params.nsplits_divmod.divmod(peer_id, bidh); + if constexpr (Varlen) { + int num_splits_dynamic = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[bidb] : num_peers; + if (peer_id >= num_splits_dynamic) return {}; // Invalid tile for this split + num_peers = num_splits_dynamic; + } } - SeqlenInfo_t seqlen_info{ - bidb, - get<0>(params.shape_Q), - !params.ptr_pagetable ? size<0>(params.shape_K) : size<0>(params.shape_K) * size<1>(params.shape_pagetable), - get<0>(params.shape_K_new), - params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_k_new, - params.seqused_q, params.seqused_k, params.leftpad_k, - params.seqlens_rotary - }; - - auto n_block_min_max = BlockMN_t::get_n_block_min_max( - seqlen_info, m_block, bidb, peer_id, num_peers, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); - - if constexpr (AppendKV) { - auto n_block_min_max_new = BlockMN_t::get_n_block_k_new_min_max( - seqlen_info, m_block, bidb, peer_id, num_peers, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); - return {m_block, bidh, bidb, get<0>(n_block_min_max), get<1>(n_block_min_max), peer_id, num_peers, - get<0>(n_block_min_max_new), get<1>(n_block_min_max_new)}; - } else { - return {m_block, bidh, bidb, get<0>(n_block_min_max), get<1>(n_block_min_max), peer_id, num_peers}; - } + SeqlenInfo_t seqlen_info = create_seqlen_info(bidb); + return create_block_coord_split(seqlen_info, m_block, bidh, bidb, peer_id, num_peers); } CUTLASS_DEVICE SeqlenInfo_t get_seqlen_info(WorkTileInfo const& work_tile) const { - // Recompute block coord parts needed for SeqlenInfo - int m_block, bidh, bidb; bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, work_tile.tile_idx)); - // No need for split info here as SeqlenInfo is per batch item (bidb) - - return SeqlenInfo_t { - bidb, - get<0>(params.shape_Q), - !params.ptr_pagetable ? size<0>(params.shape_K) : size<0>(params.shape_K) * size<1>(params.shape_pagetable), - get<0>(params.shape_K_new), - params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_k_new, - params.seqused_q, params.seqused_k, params.leftpad_k, - params.seqlens_rotary - }; + return create_seqlen_info(bidb); } - CUTLASS_DEVICE - StaticPersistentTileScheduler(SharedStorage* const smem_scheduler, Params const& params) : params(params) {}; + StaticPersistentTileScheduler(SharedStorage* const smem_scheduler, Params const& params) + : Super(params), params(params) {} template CUTLASS_DEVICE @@ -383,7 +359,7 @@ class StaticPersistentTileScheduler { }; template -class DynamicPersistentTileScheduler { +class DynamicPersistentTileScheduler: public TileSchedulerCommon { // This scheduler targets the causal (or local) case where each tile takes different // amount of time. We use longest-processing-time-first scheduling: @@ -395,51 +371,34 @@ class DynamicPersistentTileScheduler { // size of K & V and the L2 cache size. static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads); - static constexpr int NumThreads = WarpSpecialized ? NumMmaThreads + NumProducerThreads : NumMmaThreads; + static constexpr int NumThreads = NumMmaThreads; // Assuming non-warp-specialized usage for now -public: - using SharedStorage = int; - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - using BlockMN_t = flash::BlockMN; + using Super = TileSchedulerCommon; protected: SharedStorage* const tile_count_smem; public: - // Device side kernel params - struct Params { + struct Params : public Super::Params { int const total_blocks; cutlass::FastDivmod const m_block_divmod, head_divmod; cutlass::FastDivmod const l2_minor_divmod, l2_major_divmod; cutlass::FastDivmod const l2_minor_residual_divmod; int const num_hb_quotient; int* const tile_count_semaphore; - int const num_splits; // Static number of splits - cutlass::FastDivmod qhead_per_khead_divmod; - ShapeQKV shape_Q; - ShapeQKV shape_K; - ShapeQKV shape_K_new; - int const* cu_seqlens_q = nullptr; // Assuming null for non-varlen dynamic - int const* cu_seqlens_k = nullptr; - int const* cu_seqlens_k_new = nullptr; - int const* seqused_q = nullptr; - int const* seqused_k = nullptr; - int const* leftpad_k = nullptr; - int const* seqlens_rotary = nullptr; - int const* ptr_pagetable = nullptr; - int window_size_left; - int window_size_right; // Params needed for L2 swizzling calculation int const seqlen_k; int const headdim; int const headdim_v; int const element_size; - // Num M blocks - int num_m_blocks; }; + Params const& params; + + using Super::create_seqlen_info; + using Super::create_block_coord_split; + static Params to_underlying_arguments(TileSchedulerArguments const& args) { int const size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size * 2; @@ -455,22 +414,15 @@ class DynamicPersistentTileScheduler { int const num_splits_val = !Split ? 1 : args.num_splits; // printf("num_split_blocks = %d, num_head = %d, num_batch = %d, swizzle = %d, PackGQA = %d, qhead_per_khead = %d, num_hb_remainder = %d\n", num_split_blocks, args.num_head, args.num_batch, swizzle, int(PackGQA), args.qhead_per_khead, num_hb_remainder); assert(args.tile_count_semaphore != nullptr); - return {num_split_blocks * args.num_head * args.num_batch, + return {Super::to_underlying_arguments(args), + num_split_blocks * args.num_head * args.num_batch, cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head), cutlass::FastDivmod(swizzle), cutlass::FastDivmod(swizzle * num_split_blocks), // don't divide by 0 cutlass::FastDivmod(num_hb_remainder > 0 ? num_hb_remainder : 1), (args.num_head * args.num_batch) / swizzle, args.tile_count_semaphore, - num_splits_val, - cutlass::FastDivmod(args.qhead_per_khead), - args.shape_Q, args.shape_K, args.shape_K_new, - args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, - args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary, - args.ptr_pagetable, - args.window_size_left, args.window_size_right, - args.seqlen_k, args.headdim, args.headdim_v, args.element_size, - args.num_blocks + args.seqlen_k, args.headdim, args.headdim_v, args.element_size }; } @@ -483,8 +435,6 @@ class DynamicPersistentTileScheduler { int tile_idx; }; - Params const& params; - CUTLASS_DEVICE bool is_valid(WorkTileInfo const& work_tile) const { @@ -514,35 +464,13 @@ class DynamicPersistentTileScheduler { // Longest-processing-time-first means we process m_blocks in reverse order m_block = params.m_block_divmod.divisor - 1 - m_block; - SeqlenInfo_t seqlen_info{ - bidb, - get<0>(params.shape_Q), - !params.ptr_pagetable ? size<0>(params.shape_K) : size<0>(params.shape_K) * size<1>(params.shape_pagetable), - get<0>(params.shape_K_new), - params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_k_new, - params.seqused_q, params.seqused_k, params.leftpad_k, - params.seqlens_rotary - }; - - auto n_block_min_max = BlockMN_t::get_n_block_min_max( - seqlen_info, m_block, bidb, peer_id, num_peers, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); - - if constexpr (AppendKV) { - auto n_block_min_max_new = BlockMN_t::get_n_block_k_new_min_max( - seqlen_info, m_block, bidb, peer_id, num_peers, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); - return {m_block, bidh, bidb, get<0>(n_block_min_max), get<1>(n_block_min_max), peer_id, num_peers, - get<0>(n_block_min_max_new), get<1>(n_block_min_max_new)}; - } else { - return {m_block, bidh, bidb, get<0>(n_block_min_max), get<1>(n_block_min_max), peer_id, num_peers}; - } + return create_block_coord_split(seqlen_info, m_block, bidh, bidb, peer_id, num_peers); } + CUTLASS_DEVICE SeqlenInfo_t get_seqlen_info(WorkTileInfo const& work_tile) const { - // Recompute block coord parts needed for SeqlenInfo int m_block, bidh, bidb; int l2_mod, bidhb, bidhb_residual; bidhb = params.l2_major_divmod.divmod(l2_mod, work_tile.tile_idx); @@ -553,20 +481,14 @@ class DynamicPersistentTileScheduler { } bidb = params.head_divmod.divmod(bidh, bidhb * params.l2_minor_divmod.divisor + bidhb_residual); - return SeqlenInfo_t{ - bidb, - get<0>(params.shape_Q), - !params.ptr_pagetable ? size<0>(params.shape_K) : size<0>(params.shape_K) * size<1>(params.shape_pagetable), - get<0>(params.shape_K_new), - params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_k_new, - params.seqused_q, params.seqused_k, params.leftpad_k, - params.seqlens_rotary - }; + return create_seqlen_info(bidb); } - CUTLASS_DEVICE - DynamicPersistentTileScheduler(SharedStorage* const smem_scheduler, Params const& params) : tile_count_smem(smem_scheduler), params(params) {}; + DynamicPersistentTileScheduler(SharedStorage* const smem_scheduler, Params const& params) + : Super(params), params(params), tile_count_smem(smem_scheduler) { + } + template CUTLASS_DEVICE @@ -619,51 +541,24 @@ template -class VarlenDynamicPersistentTileScheduler { +class VarlenDynamicPersistentTileScheduler: public TileSchedulerCommon { static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads); static constexpr int NumThreads = WarpSpecialized ? NumMmaThreads + NumProducerThreads : NumMmaThreads; -public: - using SharedStorage = int4; - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - using BlockMN_t = flash::BlockMN; - protected: SharedStorage* const work_info_smem; + using Super = TileSchedulerCommon; + public: // Device side kernel params - struct Params { - int num_head, num_batch; - cutlass::FastDivmod qhead_per_khead_divmod; - int const seqlen; // Max seqlen + struct Params: public Super::Params { cutlass::FastDivmod head_divmod; cutlass::FastDivmod nsplits_divmod; // Static num splits divisor - int const num_splits; // Static num splits int* const tile_count_semaphore; - // Sequence length info (needed for tile_idx_to_work_tile and SeqlenInfo) - int const* const cu_seqlens_q = nullptr; - int const* const cu_seqlens_k = nullptr; - int const* const cu_seqlens_k_new = nullptr; - int const* const seqused_q = nullptr; - int const* const seqused_k = nullptr; - int const* const leftpad_k = nullptr; - int const* const seqlens_rotary = nullptr; - // Shape info (needed for SeqlenInfo) - ShapeQKV const shape_Q; - ShapeQKV const shape_K; - ShapeQKV const shape_K_new; - // Paged attention table (needed for SeqlenInfo) - int const* const ptr_pagetable = nullptr; - ShapePageTable const shape_pagetable; - // Dynamic splits for Varlen - int const* const num_splits_dynamic_ptr = nullptr; - // Window sizes for local/causal attention (needed for BlockMN_t) - int const window_size_left; - int const window_size_right; + }; static Params @@ -674,20 +569,10 @@ class VarlenDynamicPersistentTileScheduler { assert(args.num_head < (1 << 16)); // We use the top 16 bits to store num_splits & split_idx assert(!Split || args.num_splits < (1 << 8)); // We use the top 8 bits to store num_splits int const num_splits_val = !Split ? 1 : args.num_splits; - return {args.num_head, args.num_batch, - cutlass::FastDivmod(args.qhead_per_khead), - args.seqlen, + return {Super::to_underlying_arguments(args), cutlass::FastDivmod(args.num_head), cutlass::FastDivmod(num_splits_val), - num_splits_val, - args.tile_count_semaphore, - args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, - args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary, - args.shape_Q, args.shape_K, args.shape_K_new, - args.ptr_pagetable, - args.shape_pagetable, - args.num_splits_dynamic_ptr, - args.window_size_left, args.window_size_right + args.tile_count_semaphore }; } @@ -733,46 +618,14 @@ class VarlenDynamicPersistentTileScheduler { num_peers = reinterpret_cast(num_peers_u); } - SeqlenInfo_t seqlen_info { - bidb, - get<0>(params.shape_Q), - !params.ptr_pagetable ? size<0>(params.shape_K) : size<0>(params.shape_K) * size<1>(params.shape_pagetable), - get<0>(params.shape_K_new), - params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_k_new, - params.seqused_q, params.seqused_k, params.leftpad_k, - params.seqlens_rotary - }; - - auto n_block_min_max = BlockMN_t::get_n_block_min_max( - seqlen_info, m_block, bidb, peer_id & 0xFFFF /* actual peer id*/, num_peers, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); - - if constexpr (AppendKV) { - auto n_block_min_max_new = BlockMN_t::get_n_block_k_new_min_max( - seqlen_info, m_block, bidb, peer_id & 0xFFFF /* actual peer id*/, num_peers, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); - return {m_block, bidh_actual, bidb, get<0>(n_block_min_max), get<1>(n_block_min_max), peer_id, num_peers, - get<0>(n_block_min_max_new), get<1>(n_block_min_max_new)}; - } else { - return {m_block, bidh_actual, bidb, get<0>(n_block_min_max), get<1>(n_block_min_max), peer_id, num_peers}; - } + SeqlenInfo_t seqlen_info = create_seqlen_info(bidb); + return Super::create_block_coord(seqlen_info, m_block, bidh_actual, bidb, peer_id, num_peers); } CUTLASS_DEVICE SeqlenInfo_t get_seqlen_info(WorkTileInfo const& work_tile) const { - // Extract bidb needed for SeqlenInfo - int bidb = work_tile.bidb; - - return SeqlenInfo_t { - bidb, - get<0>(params.shape_Q), - !params.ptr_pagetable ? size<0>(params.shape_K) : size<0>(params.shape_K) * size<1>(params.shape_pagetable), - get<0>(params.shape_K_new), - params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_k_new, - params.seqused_q, params.seqused_k, params.leftpad_k, - params.seqlens_rotary - }; + return create_seqlen_info(work_tile.bidb); } From 5e026f6a2df71cb2ebb342929f50448c094a58af Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 15 Apr 2025 03:33:23 +0000 Subject: [PATCH 100/102] build Signed-off-by: Lucas Wilkinson --- hopper/streamk.h | 22 ++++++ hopper/tile_scheduler.hpp | 152 +++++++++++++++++++++++++++++++++++--- 2 files changed, 165 insertions(+), 9 deletions(-) create mode 100644 hopper/streamk.h diff --git a/hopper/streamk.h b/hopper/streamk.h new file mode 100644 index 000000000..9fd75410b --- /dev/null +++ b/hopper/streamk.h @@ -0,0 +1,22 @@ +#pragma once + +#include + +#include "cute/container/alignment.hpp" + +struct CUTE_ALIGNAS(16) StreamKWorkTile { + int const m_block = -1; + int const n_block_start = 0; + uint16_t const n_blocks = 0; // Max n_blocks per tile is 65535 + uint16_t const bidb = 0; // Max batch size is 65535 + uint16_t const bidh = 0; // Max num heads is 65535 + uint8_t const peer_id = 0; // Max 255 peers + uint8_t const num_peers = 0; // Max 255 peers +}; + +struct CUTE_ALIGNAS(8) StreamKCombineTile { + int const m_block; + uint16_t const bidh; + uint16_t const bidb; + uint16_t const num_peers; // Number of peers / num_splits +}; diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index ad15c4a7f..3de3e4b5e 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -11,6 +11,7 @@ #include "block.h" #include "utils.h" #include "seqlen.h" +#include "streamk.h" namespace flash { @@ -42,6 +43,8 @@ struct TileSchedulerArguments { int const* const num_splits_dynamic_ptr = nullptr; int const window_size_left = -1; int const window_size_right = 0; + int const* const sm_work_tile_ind_ptr = nullptr; + StreamKWorkTile const* const work_tiles_ptr = nullptr; }; template @@ -74,7 +77,7 @@ struct TileSchedulerCommon { struct Params { int const num_blocks, num_head, num_batch, num_splits; - cutlass::FastDivmod qhead_per_khead; + cutlass::FastDivmod qhead_per_khead_divmod; int const seqlen; int const* const cu_seqlens_q = nullptr; int const* const cu_seqlens_k = nullptr; @@ -115,7 +118,7 @@ struct TileSchedulerCommon { } CUTLASS_DEVICE SeqlenInfo_t - create_seqlen_info(int bidb) { + create_seqlen_info(int bidb) const { return { bidb, get<0>(params.shape_Q), @@ -129,7 +132,7 @@ struct TileSchedulerCommon { CUTLASS_DEVICE BlockCoord - create_block_coord_split(SeqlenInfo_t const& seqlen_info, int m_block, int bidh, int bidb, int peer_id, int num_peers) { + create_block_coord_split(SeqlenInfo_t const& seqlen_info, int m_block, int bidh, int bidb, int peer_id, int num_peers) const { // Calculate n_block_min/max based on causality and local window auto n_block_min_max = BlockMN_t::get_n_block_min_max( seqlen_info, m_block, bidb, peer_id, num_peers, @@ -325,7 +328,8 @@ class StaticPersistentTileScheduler: public TileSchedulerCommon; protected: @@ -463,6 +468,7 @@ class DynamicPersistentTileScheduler: public TileSchedulerCommon; + using SharedStorage = int4; + using Super::create_seqlen_info; + using Super::create_block_coord_split; -public: // Device side kernel params struct Params: public Super::Params { @@ -585,8 +592,11 @@ class VarlenDynamicPersistentTileScheduler: public TileSchedulerCommon +class StreamKPersistentTileScheduler: public TileSchedulerCommon { + + static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads); + static constexpr int NumThreads = WarpSpecialized ? NumMmaThreads + NumProducerThreads : NumMmaThreads; + +public: + using SharedStorage = int; + using Super = TileSchedulerCommon; + + using Super::create_seqlen_info; + using Super::create_block_coord_split; + + // Device side kernel params + struct Params: public Super::Params { + StreamKWorkTile const* const work_tiles_ptr; + int const* const sm_work_tile_ind_ptr; + }; + + static Params + to_underlying_arguments(TileSchedulerArguments const& args) { + // If Split, for the purpose of scheduling, we pretend that instead there are + // (args.num_splits * args.num_head) number of heads. + assert(args.tile_count_semaphore != nullptr); + assert(args.num_head < (1 << 16)); // We use the top 16 bits to store num_splits & split_idx + assert(!Split || args.num_splits < (1 << 8)); // We use the top 8 bits to store num_splits + int const num_splits_val = !Split ? 1 : args.num_splits; + return {Super::to_underlying_arguments(args), + args.work_tiles_ptr, + args.sm_work_tile_ind_ptr + }; + } + + static dim3 + get_grid_shape(Params const& params, int num_sm) { + return {uint32_t(num_sm)}; + } + + struct WorkTileInfo { + int tile_idx; + StreamKWorkTile work_tile; + }; + + Params const& params; + int const sm_idx; + int const work_tiles_start_idx; + int const work_tiles_end_idx; + + CUTLASS_DEVICE + bool + is_valid(WorkTileInfo const& work_tile) const { + // if (blockIdx.x >= 0 && (threadIdx.x == 128 || threadIdx.x == 0)) { printf("blockIdx.x = %d, threadIdx.x = %d, checking valid, bidb = %d, params.num_batch = %d\n", blockIdx.x, threadIdx.x, work_tile.bidb, params.num_batch); } + return work_tile.tile_idx < work_tiles_end_idx; + } + + CUTLASS_DEVICE + BlockCoord + get_block_coord(WorkTileInfo const& work_tile) const { + SeqlenInfo_t seqlen_info = create_seqlen_info(work_tile.bidb); + if constexpr(AppendKV) { + return { + seqlen_info, + work_tile.work_tile.m_block, + work_tile.work_tile.bidh, + work_tile.work_tile.bidb, + work_tile.work_tile.peer_id, + work_tile.work_tile.num_peers, + 0, 0 // TODO support appendKV + }; + } else { + return { + seqlen_info, + work_tile.work_tile.m_block, + work_tile.work_tile.bidh, + work_tile.work_tile.bidb, + work_tile.work_tile.peer_id, + work_tile.work_tile.num_peers + }; + } + } + + CUTLASS_DEVICE + SeqlenInfo_t + get_seqlen_info(WorkTileInfo const& work_tile) const { + return create_seqlen_info(work_tile.bidb); + } + + + CUTLASS_DEVICE + StreamKPersistentTileScheduler(SharedStorage* const smem_scheduler, Params const& params) : Super(params), params(params), + sm_idx(blockIdx.x), work_tiles_start_idx(params.sm_work_tile_ind_ptr[blockIdx.x]), work_tiles_end_idx(params.sm_work_tile_ind_ptr[blockIdx.x + 1]) {}; + + template + CUTLASS_DEVICE + WorkTileInfo + get_initial_work() const { + return {work_tiles_start_idx, params.work_tiles_ptr[work_tiles_start_idx]}; + } + + CUTLASS_DEVICE + void + init_consumer() const {} + + CUTLASS_DEVICE + void + prefetch_next_work(WorkTileInfo& current_work) const {} + + template + CUTLASS_DEVICE + WorkTileInfo + get_next_work(WorkTileInfo const& current_work) const { + if (current_work.tile_idx + 1 < work_tiles_end_idx) { + return {current_work.tile_idx + 1, params.work_tiles_ptr[current_work.tile_idx + 1]}; + } else { + return {work_tiles_start_idx + 1, StreamKWorkTile{}}; + } + } + +}; + } // flash From 626a8718afb3e2006962a2e10cf7b66a48d5a111 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 18 Apr 2025 13:38:03 +0000 Subject: [PATCH 101/102] building Signed-off-by: Lucas Wilkinson --- hopper/flash.h | 10 + hopper/flash_api.cpp | 89 ++++-- hopper/flash_api_torch_lib.cpp | 4 + hopper/flash_fwd_combine_kernel.h | 40 ++- hopper/flash_fwd_combine_launch_template.h | 43 +-- hopper/flash_fwd_launch_template.h | 21 +- hopper/flash_prepare_scheduler.cu | 1 + hopper/streamk.h | 322 ++++++++++++++++++++- hopper/tile_scheduler.hpp | 11 +- tests/test_vllm_flash_attn.py | 9 +- vllm_flash_attn/flash_attn_interface.py | 6 +- 11 files changed, 484 insertions(+), 72 deletions(-) diff --git a/hopper/flash.h b/hopper/flash.h index 91fb5c812..e88a6a7bf 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -7,6 +7,8 @@ #include #include +#include "streamk.h" + //////////////////////////////////////////////////////////////////////////////////////////////////// struct Qkv_params { @@ -158,6 +160,14 @@ struct Flash_fwd_params : public Qkv_params { int arch; int num_sm; + + // Streamk stuff + int * __restrict__ sm_work_tile_ind_ptr = nullptr; + StreamKWorkTile * __restrict__ work_tiles_ptr = nullptr; + StreamKCombineTile * __restrict__ combine_tiles_ptr = nullptr; + int num_combine_blocks = 0; + int streamk_m_block_size = 0; + bool use_one_mma_wg = false; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index acd84f48c..1a1c33a5d 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -15,6 +15,7 @@ #include "tile_size.h" #include "heuristics.h" #include "cuda_check.h" +#include "streamk.h" // Copied from https://github.com/pytorch/pytorch/commit/7931eee5c5ebcdf468bff4d308510b03355cd909 // This is so that we can pass in torch.dtype as a parameter to the function. @@ -503,7 +504,7 @@ inline int round_up_headdimv(int head_size) { } // Only applicable to the case where seqused_k (i.e. cache_seqlens) is available -at::Tensor +std::tuple mha_fwd_get_scheduler_metadata( int batch_size, int max_seqlen_q, @@ -582,11 +583,15 @@ mha_fwd_get_scheduler_metadata( params.page_size = page_size.has_value() ? page_size.value() : 1; params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast(1); - bool const use_dynamic_split = params.b <= 992; - params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast(1); + bool const use_stream_k = params.b <= 992; + params.num_splits_dynamic_ptr = !use_stream_k ? nullptr : reinterpret_cast(1); + + assert(use_stream_k && (num_splits <= 0 || num_splits == 2)); + num_splits = use_stream_k ? 2 : num_splits; params.pagedkv_tma = get_pagedkv_tma(params); params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; + // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); @@ -598,29 +603,43 @@ mha_fwd_get_scheduler_metadata( auto opts = seqused_k.options(); // This needs to be set after get_num_splits - at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic - bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1; - if (scheduler_needs_semaphore || use_dynamic_split) { - tile_count_semaphore = torch::empty({int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b}, opts.dtype(torch::kInt32)); - if (scheduler_needs_semaphore) { - if (!use_dynamic_split) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing - params.tile_count_semaphore = tile_count_semaphore.data_ptr(); - } else { - params.tile_count_semaphore = nullptr; - } - params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + 1 : nullptr; - } + at::Tensor device_metadata; // Contains the semaphore and optionally num_splits_dynamic + at::Tensor host_metadata; - if (params.num_splits_dynamic_ptr) { - auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); - auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr); - int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); - int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); - auto stream = at::cuda::getCurrentCUDAStream().stream(); - prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/); - CHECK_CUDA_KERNEL_LAUNCH(); + bool const scheduler_needs_semaphore = (params.arch >= 90 || params.num_splits > 1) && !use_stream_k; + + if (scheduler_needs_semaphore) { + device_metadata = torch::empty({1}, opts.dtype(torch::kInt32)); + } else { + std::tie(device_metadata, host_metadata) = streamk_schedule( + params.arch, params.num_sm, params.b, cu_seqlens_q_, seqused_k, params.seqlen_q, params.seqlen_k, + params.h, params.h_k, params.d, params.dv, params.is_causal, params.is_local, + params.is_e4m3 ? 1 : 2, false /*v_colmajor*/, true /*pagedkv*/, params.pagedkv_tma, + params.softcap, params.seqlen_knew > 0 + ); } - return tile_count_semaphore; + + // if (scheduler_needs_semaphore || use_stream_k) { + // tile_count_semaphore = torch::empty({int(scheduler_needs_semaphore) + int(use_stream_k) * params.b}, opts.dtype(torch::kInt32)); + // if (scheduler_needs_semaphore) { + // if (!use_stream_k) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing + // params.tile_count_semaphore = tile_count_semaphore.data_ptr(); + // } else { + // params.tile_count_semaphore = nullptr; + // } + // params.num_splits_dynamic_ptr = use_stream_k ? tile_count_semaphore.data_ptr() + 1 : nullptr; + // } + + // if (params.num_splits_dynamic_ptr) { + // auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); + // auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr); + // int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); + // int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); + // auto stream = at::cuda::getCurrentCUDAStream().stream(); + // prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/); + // CHECK_CUDA_KERNEL_LAUNCH(); + // } + return {device_metadata, host_metadata}; } // b: batch_size @@ -663,6 +682,8 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq float const softcap, bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 std::optional &scheduler_metadata_, // (b + 1) + std::optional &device_scheduler_metadata_, + std::optional &host_scheduler_metadata_, int num_splits, std::optional pack_gqa_, int const sm_margin @@ -943,6 +964,26 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); + // Assume streamk scheduling + if (host_scheduler_metadata_) { + auto host_metadata_ptr = reinterpret_cast(host_scheduler_metadata_->data_ptr()); + params.num_splits = host_metadata_ptr->max_num_peers; + params.num_combine_blocks = host_metadata_ptr->num_combine_blocks; + params.pack_gqa = host_metadata_ptr->pack_gqa; + params.use_one_mma_wg = host_metadata_ptr->use_one_mma_wg; + + int num_work_tiles = host_metadata_ptr->num_work_tiles; + assert(device_scheduler_metadata_.has_value()); + auto device_metadata_ptr = device_scheduler_metadata_->data_ptr(); + + auto [offsets, total_size] = get_device_metadata_offsets_and_size( + params.num_sm, num_work_tiles, params.num_combine_blocks); + + params.work_tiles_ptr = reinterpret_cast(device_metadata_ptr); + params.sm_work_tile_ind_ptr = reinterpret_cast(device_metadata_ptr + offsets.work_tiles_ind_ptr_offset); + params.combine_tiles_ptr = reinterpret_cast(device_metadata_ptr + offsets.combine_tiles_offset); + } + // This needs to be set after get_num_splits at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic // We don't use the persistent scheduler if Split and not Varlen diff --git a/hopper/flash_api_torch_lib.cpp b/hopper/flash_api_torch_lib.cpp index a2006f3c4..775d05608 100644 --- a/hopper/flash_api_torch_lib.cpp +++ b/hopper/flash_api_torch_lib.cpp @@ -49,6 +49,8 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq float const softcap, bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 std::optional &scheduler_metadata_, // (b + 1) + std::optional &device_scheduler_metadata_, + std::optional &host_scheduler_metadata_, int num_splits, std::optional pack_gqa_, int const sm_margin @@ -116,6 +118,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " float softcap," " bool is_rotary_interleaved," " Tensor? scheduler_metadata," + " Tensor? device_scheduler_metadata," + " Tensor? host_scheduler_metadata," " int num_splits," " bool? pack_gqa," " int sm_margin) -> Tensor[]"); diff --git a/hopper/flash_fwd_combine_kernel.h b/hopper/flash_fwd_combine_kernel.h index a22e05969..8fa677727 100644 --- a/hopper/flash_fwd_combine_kernel.h +++ b/hopper/flash_fwd_combine_kernel.h @@ -15,6 +15,7 @@ #include "cutlass/arch/grid_dependency_control.h" #include "seqlen.h" +#include "streamk.h" #include "utils.h" namespace flash { @@ -22,7 +23,7 @@ namespace flash { using namespace cute; template + bool Is_even_K, bool Varlen, class Element, class ElementPartial, class ArchTag_, bool StreamK = false> class FlashAttnFwdCombine { public: @@ -146,6 +147,9 @@ class FlashAttnFwdCombine { int const* const seqused = nullptr; int const* const num_splits_dynamic_ptr = nullptr; int* const semaphore_to_reset = nullptr; + + StreamKCombineTile const* const streamk_combine_tiles_ptr = nullptr; + int streamk_m_block_size = 0; }; // Kernel entry point API @@ -165,6 +169,9 @@ class FlashAttnFwdCombine { int const* const seqused = nullptr; int const* const num_splits_dynamic_ptr = nullptr; int* const semaphore_to_reset = nullptr; + + StreamKCombineTile const* const streamk_combine_tiles_ptr = nullptr; + int streamk_m_block_size = 0; }; // Convert to underlying arguments. In this case, a simple copy for the aliased type. @@ -187,7 +194,9 @@ class FlashAttnFwdCombine { args.cu_seqlens, args.seqused, args.num_splits_dynamic_ptr, - args.semaphore_to_reset + args.semaphore_to_reset, + args.streamk_combine_tiles_ptr, + args.streamk_m_block_size }; } @@ -201,10 +210,16 @@ class FlashAttnFwdCombine { Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o_partial.data()), SmemLayoutO{}); int const thread_idx = threadIdx.x; - int const m_block = blockIdx.x; + int m_block = blockIdx.x; int const k_block = blockIdx.y; - int const batch = blockIdx.z; - int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial); + int batch = blockIdx.z; + int num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial); + + if constexpr(StreamK) { + auto tile = params.streamk_combine_tiles_ptr[blockIdx.x]; + batch = tile.bidb; + num_splits = tile.num_peers; + } if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1) { cutlass::arch::wait_on_dependent_grids(); @@ -214,11 +229,26 @@ class FlashAttnFwdCombine { flash::SeqlenInfo seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused}; int const offset = seqlen_info.offset; int const seqlen = seqlen_info.seqlen; + int max_idx = seqlen * get<2>(params.shape_LSE_partial); if constexpr (Varlen) { if (m_block * kBlockM >= max_idx) { return; } } + // TODO(Lucas) This is very hacky, we should really write this kernel in a way that that more naturally + // supports StreamK + if constexpr(StreamK) { + auto tile = params.streamk_combine_tiles_ptr[blockIdx.x]; + // m_block is encdoded as (seqlen, num_head): (1, seqlen), + // i.e. ind2crd(m_block, (seqlen, num_head)) -> (mi, head) + int mi = tile.m_block * params.streamk_m_block_size + blockIdx.z * kBlockM; + int bidh = tile.bidh; + m_block = mi + bidh * seqlen; + // dont allow bleeding across heads since in streamk different heads can have a different number + // of peers + max_idx = seqlen + bidh * seqlen; + } + cutlass::FastDivmod seqlen_divmod_dynamic(seqlen); // Step 1: load LSE_partial from gmem -> smem diff --git a/hopper/flash_fwd_combine_launch_template.h b/hopper/flash_fwd_combine_launch_template.h index 11d422924..f95ee54c9 100644 --- a/hopper/flash_fwd_combine_launch_template.h +++ b/hopper/flash_fwd_combine_launch_template.h @@ -17,12 +17,12 @@ using namespace cute; -template +template void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl) { using ArchTag = std::conditional_t= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>; using TileShape_MK = cute::Shape, Int>; using CombineKernel = flash::FlashAttnFwdCombine; + IsEvenK, Varlen, Element, ElementPartial, ArchTag, StreamK>; typename CombineKernel::Arguments args { static_cast(params.oaccum_ptr), @@ -35,13 +35,20 @@ void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool e {params.o_row_stride, _1{}, params.o_head_stride, !Varlen ? params.o_batch_stride : 0}, // stride_O static_cast(params.softmax_lse_ptr), {_1{}, !Varlen ? params.seqlen_q : params.total_q, !Varlen ? params.h * params.seqlen_q : 0}, // stride_LSE - params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.tile_count_semaphore + params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.tile_count_semaphore, + params.combine_tiles_ptr, params.streamk_m_block_size }; typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args); int num_blocks_k = cute::ceil_div(params.dv, kBlockK); int num_blocks_m = cute::ceil_div(params.seqlen_q * params.h, kBlockM); dim3 grid_m(num_blocks_m, num_blocks_k, params.b); + + if constexpr(StreamK) { + grid_m.x = params.num_combine_blocks; + grid_m.z = cute::ceil_div(params.streamk_m_block_size, kBlockM); + } + auto kernel = cutlass::device_kernel; int smem_size = CombineKernel::SharedStorageSize; if (smem_size >= 48 * 1024) { @@ -60,21 +67,23 @@ void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool en static constexpr int kBlockM = kBlockK % 128 == 0 ? 8 : (kBlockK % 64 == 0 ? 16 : 32); ARCH_SWITCH(params.arch, Arch, [&] { BOOL_SWITCH(params.cu_seqlens_q || params.seqused_q, Varlen, [&] { - if constexpr (kBlockM >= 16) { // If kBlockM == 8 then the minimum number of splits is 32. - if (params.num_splits <= 16) { - run_flash_fwd_combine(params, stream, enable_pdl); - return; + BOOL_SWITCH(params.combine_tiles_ptr, StreamK, [&] { + if constexpr (kBlockM >= 16) { // If kBlockM == 8 then the minimum number of splits is 32. + if (params.num_splits <= 16) { + run_flash_fwd_combine(params, stream, enable_pdl); + return; + } + } + if (params.num_splits <= 32) { + run_flash_fwd_combine(params, stream, enable_pdl); + } else if (params.num_splits <= 64) { + run_flash_fwd_combine(params, stream, enable_pdl); + } else if (params.num_splits <= 128) { + run_flash_fwd_combine(params, stream, enable_pdl); + } else { + run_flash_fwd_combine(params, stream, enable_pdl); } - } - if (params.num_splits <= 32) { - run_flash_fwd_combine(params, stream, enable_pdl); - } else if (params.num_splits <= 64) { - run_flash_fwd_combine(params, stream, enable_pdl); - } else if (params.num_splits <= 128) { - run_flash_fwd_combine(params, stream, enable_pdl); - } else { - run_flash_fwd_combine(params, stream, enable_pdl); - } + }); }); }); } diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index f78b37143..c5ad405eb 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -26,7 +26,7 @@ using namespace cute; template + bool PackGQA, bool Split, bool V_colmajor, bool Use_one_mma_wg, bool StreamK> void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time"); static_assert(!(AppendKV && V_colmajor), "AppendKV and V_colmajor cannot be enabled at the same time"); @@ -58,11 +58,14 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using CollectiveEpilogue = flash::CollectiveEpilogueFwd; static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads; - using SchedulerPersistent = std::conditional_t= 90 /*WarpSpecialized*/>, - std::conditional_t, - flash::DynamicPersistentTileScheduler= 90 /*WarpSpecialized*/> + using SchedulerPersistent = std::conditional_t= 90 /*WarpSpecialized*/>, + std::conditional_t= 90 /*WarpSpecialized*/>, + std::conditional_t, + flash::DynamicPersistentTileScheduler= 90 /*WarpSpecialized*/> + > > >; using SchedulerSingleTile = flash::SingleTileScheduler; @@ -210,7 +213,7 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { VCOLMAJOR_SWITCH(params.v_dim_stride != 1, V_colmajor_, [&] { static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1; VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] { - BOOL_SWITCH(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k) <= 64, Use_one_mma_wg, [&] { + BOOL_SWITCH(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k) <= 64 || params.use_one_mma_wg, Use_one_mma_wg, [&] { // Only needed here to decide if we should use cluster static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap, Use_one_mma_wg)) : 128; @@ -221,7 +224,9 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { // Only use Cluster if number of tiles along seqlen_q is even and not varlen CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1; - run_flash_fwd(params, stream); + BOOL_SWITCH(params.sm_work_tile_ind_ptr, StreamK, [&] { + run_flash_fwd(params, stream); + }); }); }); }); diff --git a/hopper/flash_prepare_scheduler.cu b/hopper/flash_prepare_scheduler.cu index 7093fff32..2725ce230 100644 --- a/hopper/flash_prepare_scheduler.cu +++ b/hopper/flash_prepare_scheduler.cu @@ -9,6 +9,7 @@ #include "cutlass/arch/grid_dependency_control.h" #include "flash.h" +#include "streamk.h" namespace flash { diff --git a/hopper/streamk.h b/hopper/streamk.h index 9fd75410b..be8ca8a2a 100644 --- a/hopper/streamk.h +++ b/hopper/streamk.h @@ -1,17 +1,28 @@ #pragma once #include +#include #include "cute/container/alignment.hpp" +#include "flash.h" +#include "tile_size.h" + +struct CUTE_ALIGNAS(16) StreamKSchedulerDescisions { + int num_combine_blocks; + int num_work_tiles; + int max_num_peers; + bool pack_gqa; + bool use_one_mma_wg; +}; struct CUTE_ALIGNAS(16) StreamKWorkTile { - int const m_block = -1; - int const n_block_start = 0; - uint16_t const n_blocks = 0; // Max n_blocks per tile is 65535 - uint16_t const bidb = 0; // Max batch size is 65535 - uint16_t const bidh = 0; // Max num heads is 65535 - uint8_t const peer_id = 0; // Max 255 peers - uint8_t const num_peers = 0; // Max 255 peers + int m_block = -1; + int n_block_start = 0; + uint16_t n_blocks = 0; // Max n_blocks per tile is 65535 + uint16_t bidb = 0; // Max batch size is 65535 + uint16_t bidh = 0; // Max num heads is 65535 + uint8_t peer_id = 0; // Max 255 peers + uint8_t num_peers = 0; // Max 255 peers }; struct CUTE_ALIGNAS(8) StreamKCombineTile { @@ -20,3 +31,300 @@ struct CUTE_ALIGNAS(8) StreamKCombineTile { uint16_t const bidb; uint16_t const num_peers; // Number of peers / num_splits }; + +struct StreamKMetadataByteOffsets { int const work_tiles_offset; + int const combine_tiles_offset; + int const work_tiles_ind_ptr_offset; +}; + +inline std::tuple get_device_metadata_offsets_and_size( + int num_sms, + int num_work_tiles, + int num_combine_tiles +) { + auto round_up_to_16 = [](int size) { + return cutlass::round_up(size, 16); + }; + + int work_tiles_offset = 0; + int combine_tiles_offset = work_tiles_offset + round_up_to_16(num_work_tiles * sizeof(StreamKWorkTile)); + int work_tiles_ind_ptr_offset = combine_tiles_offset + round_up_to_16(num_combine_tiles * sizeof(StreamKCombineTile)); + int total_size = work_tiles_ind_ptr_offset + round_up_to_16(num_sms * sizeof(int)); + + StreamKMetadataByteOffsets metadata_offsets{ + work_tiles_offset, + combine_tiles_offset, + work_tiles_ind_ptr_offset + }; + + return std::make_tuple(metadata_offsets, total_size); +} + +inline std::tuple streamk_schedule( + int arch, + int num_sms, + int batch_size, + std::optional &cu_seqlens_q, + const at::Tensor &seqused_k, + int seqlen_q, + int seqlen_k, + int num_heads, + int num_heads_k, + int headdim, + int headdim_v, + bool is_causal, + bool is_local, + int element_size, + bool v_colmajor, + bool paged_kv, + bool paged_kv_non_TMA, + bool softcap, + bool append_kv +) { + assert (is_local == false && "StreamK + Local attention not supported yet"); + + std::optional cu_seqlens_q_cpu; + if (cu_seqlens_q) { + cu_seqlens_q_cpu.emplace(cu_seqlens_q->cpu()); + } + auto seqused_k_cpu = seqused_k.cpu(); + + auto get_tile_sizes = [&](bool use_one_mma_wg) -> std::tuple { + if (arch == 90) { + auto ts = tile_size_fwd_sm90(headdim, headdim_v, is_causal, is_local, element_size, v_colmajor, paged_kv_non_TMA, softcap); + return std::make_tuple(std::get<0>(ts), std::get<1>(ts)); + } else if (arch < 90) { + auto ts = tile_size_fwd_sm8x(headdim, headdim_v, is_causal, is_local, element_size, paged_kv, /* varlen_and_split */ true, softcap, append_kv); + return std::make_tuple(std::get<0>(ts), std::get<1>(ts)); + } else { + assert(false && "Unsupported architecture"); + return std::make_tuple(0, 0); + } + }; + + auto get_seqlen_k = [&](int bidb) { + return seqused_k_cpu.accessor()[bidb]; + }; + + + auto get_seqlen_q = [&](int bidb) { + if (cu_seqlens_q_cpu.has_value()) { + auto cu_seqlens_q = cu_seqlens_q_cpu.value().accessor(); + return cu_seqlens_q[bidb + 1] - cu_seqlens_q[bidb]; + } else { + return seqlen_q; + } + }; + + auto get_num_n_tiles = [&](int m_tile, int seqlen_k, int seqlen_q, int block_m, int block_n) { + int m_tile_k_end = std::max( + (seqlen_k - seqlen_q) + (m_tile + 1) * block_m, + seqlen_k + ); + return cutlass::ceil_div(m_tile_k_end, block_n); + }; + + auto tile_sizes = get_tile_sizes(false); + auto tile_sizes_one_mma_wg = get_tile_sizes(true); + + auto compute_tiles = [&, get_num_n_tiles = get_num_n_tiles]( + int num_heads, + int num_heads_k, + int seqlen_q, + int seqlen_k, + bool causal, + bool pack_gqa, + bool one_mma_wg + ) { + int block_m = one_mma_wg ? std::get<0>(tile_sizes_one_mma_wg) : std::get<0>(tile_sizes); + int block_n = one_mma_wg ? std::get<1>(tile_sizes_one_mma_wg) : std::get<1>(tile_sizes); + + seqlen_q *= pack_gqa ? num_heads / num_heads_k : 1; + num_heads = pack_gqa ? num_heads_k : num_heads; + int m_tiles = cutlass::ceil_div(seqlen_q, block_m); + int tiles_total = 0; + if (causal) { + tiles_total += m_tiles * cutlass::ceil_div(seqlen_k, block_n); + } else { + int block_m = one_mma_wg ? std::get<0>(tile_sizes_one_mma_wg) : std::get<0>(tile_sizes); + int block_n = one_mma_wg ? std::get<1>(tile_sizes_one_mma_wg) : std::get<1>(tile_sizes); + + for (int m_tile = 0; m_tile < m_tiles; m_tile++) { + tiles_total += get_num_n_tiles( + m_tile, seqlen_k, seqlen_q, block_m, block_n); + } + } + + return tiles_total * num_heads; + }; + + bool pack_gqa = false; + + // Determine if we should pack GQA by determining the the amount of + // available work that would benefit from packing GQA + // Assume not `use_one_mma_wg` for now, we determine this later + if (num_heads > num_heads_k) { + assert (num_heads % num_heads_k == 0); + + int total_tiles_pack_gqa = 0; + int total_tiles_no_pack_gqa = 0; + + for (int bidb = 0; bidb < batch_size; ++bidb) { + int seqlen_k = get_seqlen_k(bidb); + int seqlen_q = get_seqlen_q(bidb); + + total_tiles_pack_gqa += compute_tiles( + num_heads, num_heads_k, seqlen_q, seqlen_k, is_causal, true, false); + total_tiles_no_pack_gqa += compute_tiles( + num_heads, num_heads_k, seqlen_q, seqlen_k, is_causal, false, false); + } + + if (total_tiles_pack_gqa < (total_tiles_no_pack_gqa * 1.1f)) { + pack_gqa = true; + } + } + + bool use_one_mma_wg = false; + // Determine the amount of work that would benefit from using one MMA + // workgroup + + int total_tiles = 0; + int total_tiles_one_mma_wg = 0; + + for (int bidb = 0; bidb < batch_size; ++bidb) { + int seqlen_k = get_seqlen_k(bidb); + int seqlen_q = get_seqlen_q(bidb); + + total_tiles += compute_tiles( + num_heads, num_heads_k, seqlen_q, seqlen_k, is_causal, pack_gqa, false); + total_tiles_one_mma_wg += compute_tiles( + num_heads, num_heads_k, seqlen_q, seqlen_k, is_causal, pack_gqa, true); + } + + // if using one_mma_wg only increases the number of tiles by 50% or less + // then we should use it (since it performs each tile computes 1/2 as much) + // TODO(lucas): 50% is a guesstimate, we should do a more thorough analysis + if (total_tiles_one_mma_wg < (total_tiles * 1.5f)) { + use_one_mma_wg = true; + } + + tile_sizes = use_one_mma_wg ? tile_sizes_one_mma_wg : tile_sizes; + total_tiles = use_one_mma_wg ? total_tiles_one_mma_wg : total_tiles; + + int block_m = std::get<0>(tile_sizes); + int block_n = std::get<1>(tile_sizes); + + int target_tiles_per_sm = cutlass::ceil_div(total_tiles, num_sms); + + std::vector work_tiles; + work_tiles.reserve(1024); + std::vector work_tiles_ind_ptr; + work_tiles_ind_ptr.reserve(num_sms + 1); + work_tiles_ind_ptr.push_back(0); + std::vector combine_tiles; + combine_tiles.reserve(1024); + + int min_tiles = 2; + int max_num_peers = 0; + + int current_tile = 0; + int sm_target_tiles_remaining = target_tiles_per_sm; + int num_combine_tiles = 0; + + for (int bidb = 0; bidb < batch_size; ++bidb) { + for (int bidh = 0; bidh < num_heads; ++bidh) { + int seqlen_k = get_seqlen_k(bidb); + int seqlen_q = get_seqlen_q(bidb); + int m_tiles = cutlass::ceil_div(seqlen_q, block_m); + + for (int m_tile = 0; m_tile < m_tiles; m_tile++) { + int n_tiles = get_num_n_tiles( + m_tile, seqlen_k, seqlen_q, block_m, block_n); + + int m_tile_start_idx = current_tile; + int curr_n_tile_start = 0; + int curr_n_tiles_remaining = n_tiles; + int num_peers = 0; + + while (curr_n_tiles_remaining > 0) { + int n_tile = std::min(curr_n_tiles_remaining, sm_target_tiles_remaining); + + // if we would leave a residual tile that is less than the minimum tiles + // then we should just take the rest of the tiles + if (curr_n_tiles_remaining - n_tile < min_tiles) { + n_tile = curr_n_tiles_remaining; + } + + curr_n_tiles_remaining -= n_tile; + sm_target_tiles_remaining -= n_tile; + + work_tiles.emplace_back(StreamKWorkTile{ + /* m_block: */ m_tile, + /* n_block_start: */ curr_n_tile_start, + /* n_blocks: */ uint16_t(n_tile), + /* bidb: */ uint16_t(bidb), + /* bidh: */ uint16_t(bidh), + /* peer_id: */ uint8_t(num_peers), + /* num_peers: */ 0 + }); + + current_tile += 1; + num_peers += 1; + curr_n_tile_start += n_tile; + + if (sm_target_tiles_remaining <= 0) { + work_tiles_ind_ptr.push_back(current_tile); + sm_target_tiles_remaining = target_tiles_per_sm; + } + } + + if (num_peers > 1) { + combine_tiles.emplace_back(StreamKCombineTile{ + /* m_block: */ m_tile, + /* bidh: */ uint16_t(bidh), + /* bidb: */ uint16_t(bidb), + /* num_peers: */ uint16_t(num_peers) + }); + } + + if (num_peers > max_num_peers) { + max_num_peers = num_peers; + } + + for (int i = m_tile_start_idx; i < current_tile; ++i) { + work_tiles[i].num_peers = num_peers; + } + } + } + } + + auto [metadata_offsets, metadata_size] = get_device_metadata_offsets_and_size( + num_sms, + work_tiles.size(), + num_combine_tiles + ); + + auto device_metadata = torch::empty( + {int(work_tiles.size())}, + torch::TensorOptions().dtype(torch::kByte).device(torch::kCPU) + ); + + uint8_t *device_metadata_ptr = device_metadata.data_ptr(); + std::memcpy(device_metadata_ptr + metadata_offsets.work_tiles_offset, work_tiles.data(), work_tiles.size() * sizeof(StreamKWorkTile)); + std::memcpy(device_metadata_ptr + metadata_offsets.work_tiles_ind_ptr_offset, work_tiles_ind_ptr.data(), work_tiles_ind_ptr.size() * sizeof(int)); + std::memcpy(device_metadata_ptr + metadata_offsets.combine_tiles_offset, combine_tiles.data(), combine_tiles.size() * sizeof(StreamKCombineTile)); + + auto host_metadata = torch::empty( + {sizeof(StreamKSchedulerDescisions)}, + torch::TensorOptions().dtype(torch::kByte).device(torch::kCPU) + ); + + auto host_metadata_ptr = reinterpret_cast(host_metadata.data_ptr()); + host_metadata_ptr->num_combine_blocks = num_combine_tiles; + host_metadata_ptr->max_num_peers = max_num_peers; + host_metadata_ptr->pack_gqa = pack_gqa; + host_metadata_ptr->use_one_mma_wg = use_one_mma_wg; + host_metadata_ptr->num_work_tiles = work_tiles.size(); + + return {device_metadata, host_metadata}; +} diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index 3de3e4b5e..751ac6955 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -828,7 +828,6 @@ class StreamKPersistentTileScheduler: public TileSchedulerCommon get_block_coord(WorkTileInfo const& work_tile) const { - SeqlenInfo_t seqlen_info = create_seqlen_info(work_tile.bidb); + SeqlenInfo_t seqlen_info = create_seqlen_info(work_tile.work_tile.bidb); if constexpr(AppendKV) { return { - seqlen_info, work_tile.work_tile.m_block, work_tile.work_tile.bidh, work_tile.work_tile.bidb, + work_tile.work_tile.n_block_start, + work_tile.work_tile.n_block_start + work_tile.work_tile.n_blocks, work_tile.work_tile.peer_id, work_tile.work_tile.num_peers, 0, 0 // TODO support appendKV }; } else { return { - seqlen_info, work_tile.work_tile.m_block, work_tile.work_tile.bidh, work_tile.work_tile.bidb, + work_tile.work_tile.n_block_start, + work_tile.work_tile.n_block_start + work_tile.work_tile.n_blocks, work_tile.work_tile.peer_id, work_tile.work_tile.num_peers }; @@ -886,7 +887,7 @@ class StreamKPersistentTileScheduler: public TileSchedulerCommon Date: Sat, 3 May 2025 22:41:14 +0000 Subject: [PATCH 102/102] streamk working Signed-off-by: Lucas Wilkinson --- CMakeLists.txt | 7 +- hopper/flash.h | 5 +- hopper/flash_api.cpp | 83 +----- hopper/flash_api_torch_lib.cpp | 4 +- hopper/flash_fwd_combine_kernel.h | 22 +- hopper/flash_fwd_combine_launch_template.h | 6 +- hopper/flash_fwd_launch_template.h | 7 +- hopper/heuristics.h | 66 +++++ hopper/streamk.cu | 271 +++++++++++++++++++ hopper/streamk.h | 288 +++------------------ hopper/tile_scheduler.hpp | 8 +- tests/test_vllm_flash_attn.py | 24 +- vllm_flash_attn/flash_attn_interface.py | 10 +- 13 files changed, 449 insertions(+), 352 deletions(-) create mode 100644 hopper/streamk.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index e4423efcc..1d51094a9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -226,6 +226,7 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0) SOURCES hopper/flash_fwd_combine.cu hopper/flash_prepare_scheduler.cu + hopper/streamk.cu hopper/flash_api.cpp hopper/flash_api_torch_lib.cpp ${FA3_GEN_SRCS} @@ -244,11 +245,15 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0) FLASHATTENTION_DISABLE_BACKWARD FLASHATTENTION_DISABLE_DROPOUT # FLASHATTENTION_DISABLE_ALIBI - # FLASHATTENTION_DISABLE_SOFTCAP + FLASHATTENTION_DISABLE_SOFTCAP FLASHATTENTION_DISABLE_UNEVEN_K # FLASHATTENTION_DISABLE_LOCAL FLASHATTENTION_DISABLE_PYBIND FLASHATTENTION_VARLEN_ONLY # Custom flag to save on binary size + FLASHATTENTION_DISABLE_HDIM64 + FLASHATTENTION_DISABLE_HDIM96 + FLASHATTENTION_DISABLE_HDIM192 + FLASHATTENTION_DISABLE_HDIM256 ) elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 12.0) message(STATUS "FA3 is disabled because CUDA version is not 12.0 or later.") diff --git a/hopper/flash.h b/hopper/flash.h index e88a6a7bf..c0be32291 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -151,6 +151,7 @@ struct Flash_fwd_params : public Qkv_params { int num_splits; // For split-KV version bool pack_gqa; + bool use_one_mma_wg; int * __restrict__ tile_count_semaphore; // int * __restrict__ num_m_blocks_ptr; @@ -165,9 +166,7 @@ struct Flash_fwd_params : public Qkv_params { int * __restrict__ sm_work_tile_ind_ptr = nullptr; StreamKWorkTile * __restrict__ work_tiles_ptr = nullptr; StreamKCombineTile * __restrict__ combine_tiles_ptr = nullptr; - int num_combine_blocks = 0; - int streamk_m_block_size = 0; - bool use_one_mma_wg = false; + StreamKSchedulerDescisions const* host_scheduler_metadata_ptr = nullptr; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 14d9a67ad..131028b2a 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -408,53 +408,6 @@ inline bool get_pagedkv_tma(Flash_fwd_params const& params) { return params.page_size % kBlockN == 0 && params.seqlen_q * (params.h / params.h_k) > kBlockM; } -inline bool get_pack_gqa(Flash_fwd_params const& params) { - // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation and binary size. - // Has little effect on speed. - if (params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1) { return true; } - #ifdef FLASHATTENTION_DISABLE_PACKGQA - return false; - #else - // params.page_table must already be set - if (params.h == params.h_k) { return false; } - // This needs to match the kernel configs - auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); - int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); - return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM); - #endif -} - -inline int get_num_splits(Flash_fwd_params const& params) { - #ifdef FLASHATTENTION_DISABLE_SPLIT - return 1; - #else - // Always enable PackGQA for Split - // params.page_table must already be set - // This needs to match the kernel configs - bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k; - auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, use_one_mma_wg(params)); - // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits - // has not been set here. It's OK though because we might just underestimate kBlockN a bit - auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr); - int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); - int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); - int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k); - // If is_local, we're not going to load all of seqlen_k - int const seqlen_k_loaded = !params.is_local - ? params.seqlen_k - : std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM)); - int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN; - int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM; - int const size_one_kv_head = params.seqlen_k * (params.d + params.dv) * (params.is_e4m3 ? 1 : 2); - // Always enable PackGQA for Split - // If varlen, we use dynamic split, so this heuristic just needs to get an upper bound on num_splits. - // We assume the case where there's 1 long sequence and the rest are short, i.e. pretending - // that batch = 1. - int total_mblocks = (params.num_splits_dynamic_ptr ? 1 : params.b) * params.h_k * num_m_blocks; - return num_splits_heuristic(total_mblocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128); - #endif -} - inline int get_max_headdim() { #ifndef FLASHATTENTION_DISABLE_HDIM256 return 256; @@ -532,7 +485,8 @@ mha_fwd_get_scheduler_metadata( TORCH_CHECK(qkv_dtype == at::ScalarType::Half || qkv_dtype == at::ScalarType::BFloat16 || qkv_dtype == at::ScalarType::Float8_e4m3fn, "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type"); - TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + TORCH_CHECK(num_heads % num_heads_k == 0 && num_heads_k > 1, + "Number of heads in key/value must divide number of heads in query"); // Reset the parameters Flash_fwd_params params{}; @@ -585,15 +539,8 @@ mha_fwd_get_scheduler_metadata( bool const use_stream_k = params.b <= 992; params.num_splits_dynamic_ptr = !use_stream_k ? nullptr : reinterpret_cast(1); - assert(use_stream_k && (num_splits <= 0 || num_splits == 2)); - num_splits = use_stream_k ? 2 : num_splits; - params.pagedkv_tma = get_pagedkv_tma(params); - // Determine if we should pack GQA before num_splits since it impacts use_one_mma_wg (in get_num_splits) - params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); - params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; - // Always enable PackGQA for Split - params.pack_gqa = params.num_splits > 1; + determine_pack_gqa_splits_and_mma_wgs(params, num_splits, pack_gqa_, use_stream_k); bool is_varlen = true; @@ -617,6 +564,7 @@ mha_fwd_get_scheduler_metadata( params.is_e4m3 ? 1 : 2, false /*v_colmajor*/, true /*pagedkv*/, params.pagedkv_tma, params.softcap, params.seqlen_knew > 0 ); + return {host_metadata, device_metadata}; } if (params.num_splits_dynamic_ptr) { @@ -628,7 +576,7 @@ mha_fwd_get_scheduler_metadata( prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/); CHECK_CUDA_KERNEL_LAUNCH(); } - return tile_count_semaphore; + return {host_metadata, device_metadata}; } // b: batch_size @@ -944,33 +892,28 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq } // 992 = 32 * 31 is the max supported batch in prepare_varlen_num_blocks kernel - bool const use_dynamic_split = is_varlen && params.b <= 992; + bool const use_stream_k = device_scheduler_metadata_ && host_scheduler_metadata_; + bool const use_dynamic_split = is_varlen && params.b <= 992 && !use_stream_k; // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast(1); - - params.pagedkv_tma = get_pagedkv_tma(params); - // Determine if we should pack GQA before num_splits since it impacts use_one_mma_wg (in get_num_splits) - params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); - params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; - // Always enable PackGQA for Split - params.pack_gqa = params.num_splits > 1; + determine_pack_gqa_splits_and_mma_wgs(params, num_splits, pack_gqa_, use_stream_k); // Assume streamk scheduling - if (host_scheduler_metadata_) { + if (use_stream_k) { auto host_metadata_ptr = reinterpret_cast(host_scheduler_metadata_->data_ptr()); + params.host_scheduler_metadata_ptr = host_metadata_ptr; params.num_splits = host_metadata_ptr->max_num_peers; - params.num_combine_blocks = host_metadata_ptr->num_combine_blocks; params.pack_gqa = host_metadata_ptr->pack_gqa; params.use_one_mma_wg = host_metadata_ptr->use_one_mma_wg; - + int num_work_tiles = host_metadata_ptr->num_work_tiles; assert(device_scheduler_metadata_.has_value()); auto device_metadata_ptr = device_scheduler_metadata_->data_ptr(); auto [offsets, total_size] = get_device_metadata_offsets_and_size( - params.num_sm, num_work_tiles, params.num_combine_blocks); + params.num_sm, num_work_tiles, host_metadata_ptr->num_combine_blocks); - params.work_tiles_ptr = reinterpret_cast(device_metadata_ptr); + params.work_tiles_ptr = reinterpret_cast(device_metadata_ptr + offsets.work_tiles_offset); params.sm_work_tile_ind_ptr = reinterpret_cast(device_metadata_ptr + offsets.work_tiles_ind_ptr_offset); params.combine_tiles_ptr = reinterpret_cast(device_metadata_ptr + offsets.combine_tiles_offset); } diff --git a/hopper/flash_api_torch_lib.cpp b/hopper/flash_api_torch_lib.cpp index 775d05608..b01ab864d 100644 --- a/hopper/flash_api_torch_lib.cpp +++ b/hopper/flash_api_torch_lib.cpp @@ -57,7 +57,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq ); // Only applicable to the case where seqused_k (i.e. cache_seqlens) is available -at::Tensor +std::tuple mha_fwd_get_scheduler_metadata( int batch_size, int max_seqlen_q, @@ -148,7 +148,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " bool has_softcap," " int num_splits," " bool? pack_gqa," - " int sm_margin) -> Tensor"); + " int sm_margin) -> (Tensor, Tensor)"); ops.impl("get_scheduler_metadata", torch::kCUDA, make_pytorch_shim(&mha_fwd_get_scheduler_metadata)); } diff --git a/hopper/flash_fwd_combine_kernel.h b/hopper/flash_fwd_combine_kernel.h index 8fa677727..b7ecfd230 100644 --- a/hopper/flash_fwd_combine_kernel.h +++ b/hopper/flash_fwd_combine_kernel.h @@ -231,16 +231,13 @@ class FlashAttnFwdCombine { int const seqlen = seqlen_info.seqlen; int max_idx = seqlen * get<2>(params.shape_LSE_partial); - if constexpr (Varlen) { - if (m_block * kBlockM >= max_idx) { return; } - } // TODO(Lucas) This is very hacky, we should really write this kernel in a way that that more naturally // supports StreamK if constexpr(StreamK) { auto tile = params.streamk_combine_tiles_ptr[blockIdx.x]; // m_block is encdoded as (seqlen, num_head): (1, seqlen), - // i.e. ind2crd(m_block, (seqlen, num_head)) -> (mi, head) + // i.e. ind2crd(m_block, (seqlen, num_head)) -> (mi, head) int mi = tile.m_block * params.streamk_m_block_size + blockIdx.z * kBlockM; int bidh = tile.bidh; m_block = mi + bidh * seqlen; @@ -249,6 +246,17 @@ class FlashAttnFwdCombine { max_idx = seqlen + bidh * seqlen; } + // For StreamK we don't always start at kBlockM boundaries, so m_block + // encodes the offset instead of the block number + auto get_m_block_offset = [&](int m_block) { + if constexpr (StreamK) return m_block; + else return m_block * kBlockM; + }; + + if constexpr (Varlen) { + if (get_m_block_offset(m_block) >= max_idx) { return; } + } + cutlass::FastDivmod seqlen_divmod_dynamic(seqlen); // Step 1: load LSE_partial from gmem -> smem @@ -270,7 +278,7 @@ class FlashAttnFwdCombine { #pragma unroll for (int m = 0; m < size<2>(tLSEcLSE); ++m) { int mi = int(get<1>(tLSEcLSE(_0{}, _0{}, m))); - int idx = m_block * kBlockM + mi; + int idx = get_m_block_offset(m_block) + mi; if (idx < max_idx) { int m_idx, bidh; if constexpr (!Varlen) { @@ -314,7 +322,7 @@ class FlashAttnFwdCombine { #pragma unroll for (int m = 0; m < size<1>(tOcO); ++m) { int mi = get<0>(tOcO(_0{}, m, _0{})); - int idx = m_block * kBlockM + mi; + int idx = get_m_block_offset(m_block) + mi; if constexpr (!Varlen) { tObidh(m) = params.seqlen_divmod.divmod(tOmidx(m), idx); } else { @@ -424,7 +432,7 @@ class FlashAttnFwdCombine { for (int m = 0; m < size<2>(ts2rrLSE); ++m) { if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to gmem int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m))); - int idx = m_block * kBlockM + mi; + int idx = get_m_block_offset(m_block) + mi; if (idx < max_idx) { int m_idx, bidh; if constexpr (!Varlen) { diff --git a/hopper/flash_fwd_combine_launch_template.h b/hopper/flash_fwd_combine_launch_template.h index f95ee54c9..d083a5c1d 100644 --- a/hopper/flash_fwd_combine_launch_template.h +++ b/hopper/flash_fwd_combine_launch_template.h @@ -36,7 +36,7 @@ void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool e static_cast(params.softmax_lse_ptr), {_1{}, !Varlen ? params.seqlen_q : params.total_q, !Varlen ? params.h * params.seqlen_q : 0}, // stride_LSE params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.tile_count_semaphore, - params.combine_tiles_ptr, params.streamk_m_block_size + params.combine_tiles_ptr, params.host_scheduler_metadata_ptr->m_block_size }; typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args); @@ -45,8 +45,8 @@ void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool e dim3 grid_m(num_blocks_m, num_blocks_k, params.b); if constexpr(StreamK) { - grid_m.x = params.num_combine_blocks; - grid_m.z = cute::ceil_div(params.streamk_m_block_size, kBlockM); + grid_m.x = params.host_scheduler_metadata_ptr->num_combine_blocks; + grid_m.z = cute::ceil_div(params.host_scheduler_metadata_ptr->m_block_size, kBlockM); } auto kernel = cutlass::device_kernel; diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index dcfc39e4d..d71d191e8 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -163,7 +163,10 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { params.page_table, {params.kv_batch_idx ? params.b_k : params.b, !params.page_table ? 0 : params.seqlen_k / params.page_size}, // shape_page_table params.num_splits_dynamic_ptr, - params.window_size_left, params.window_size_right + params.window_size_left, params.window_size_right, + params.sm_work_tile_ind_ptr, + params.work_tiles_ptr, + params.host_scheduler_metadata_ptr->grid_size }; if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) { @@ -214,7 +217,7 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { VCOLMAJOR_SWITCH(params.v_dim_stride != 1, V_colmajor_, [&] { static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1; VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] { - BOOL_SWITCH(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k) <= 64 || params.use_one_mma_wg, Use_one_mma_wg, [&] { + BOOL_SWITCH(params.use_one_mma_wg, Use_one_mma_wg, [&] { // Only needed here to decide if we should use cluster static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap, Use_one_mma_wg)) : 128; diff --git a/hopper/heuristics.h b/hopper/heuristics.h index 43d06f548..5e36e4792 100644 --- a/hopper/heuristics.h +++ b/hopper/heuristics.h @@ -63,3 +63,69 @@ inline int num_splits_heuristic(int total_mblocks, int num_SMs, int num_n_blocks } return 1; } + +inline bool get_pack_gqa(Flash_fwd_params const& params) { + // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation and binary size. + // Has little effect on speed. + if (params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1) { return true; } + #ifdef FLASHATTENTION_DISABLE_PACKGQA + return false; + #else + // params.page_table must already be set + if (params.h == params.h_k) { return false; } + // This needs to match the kernel configs + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); + int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); + return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM); + #endif +} + +inline int get_num_splits(Flash_fwd_params const& params) { + #ifdef FLASHATTENTION_DISABLE_SPLIT + return 1; + #else + // Always enable PackGQA for Split + // params.page_table must already be set + // This needs to match the kernel configs + bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k; + + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, params.use_one_mma_wg); + // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits + // has not been set here. It's OK though because we might just underestimate kBlockN a bit + auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr); + int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); + int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); + int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k); + // If is_local, we're not going to load all of seqlen_k + int const seqlen_k_loaded = !params.is_local + ? params.seqlen_k + : std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM)); + int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN; + int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM; + int const size_one_kv_head = params.seqlen_k * (params.d + params.dv) * (params.is_e4m3 ? 1 : 2); + // Always enable PackGQA for Split + // If varlen, we use dynamic split, so this heuristic just needs to get an upper bound on num_splits. + // We assume the case where there's 1 long sequence and the rest are short, i.e. pretending + // that batch = 1. + int total_mblocks = (params.num_splits_dynamic_ptr ? 1 : params.b) * params.h_k * num_m_blocks; + return num_splits_heuristic(total_mblocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128); + #endif +} + + +inline void determine_pack_gqa_splits_and_mma_wgs( + Flash_fwd_params ¶ms, + int num_splits, + std::optional const& pack_gqa, + bool use_stream_k = false +) { + assert(use_stream_k && (num_splits <= 0 || num_splits == 2)); + num_splits = use_stream_k ? 2 : num_splits; + + // Determine if we should pack GQA before num_splits since it impacts use_one_mma_wg (in get_num_splits) + params.use_one_mma_wg = use_one_mma_wg(params); + params.pack_gqa = pack_gqa.has_value() ? pack_gqa.value() : get_pack_gqa(params); + params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; + // Always enable PackGQA for Split + params.pack_gqa = params.num_splits > 1; +} \ No newline at end of file diff --git a/hopper/streamk.cu b/hopper/streamk.cu new file mode 100644 index 000000000..e8ffc28b1 --- /dev/null +++ b/hopper/streamk.cu @@ -0,0 +1,271 @@ +#include + +std::tuple streamk_schedule( + int arch, + int num_sms, + int batch_size, + std::optional &cu_seqlens_q, + const at::Tensor &seqused_k, + int seqlen_q, + int seqlen_k, + int num_heads_q, + int num_heads_k, + int headdim, + int headdim_v, + bool is_causal, + bool is_local, + int element_size, + bool v_colmajor, + bool paged_kv, + bool paged_kv_non_TMA, + bool softcap, + bool append_kv +) { + assert (is_local == false && "StreamK + Local attention not supported yet"); + + std::optional cu_seqlens_q_cpu; + if (cu_seqlens_q) { + cu_seqlens_q_cpu.emplace(cu_seqlens_q->cpu()); + } + auto seqused_k_cpu = seqused_k.cpu(); + + auto get_tile_sizes = [&](bool use_one_mma_wg) -> std::tuple { + if (arch == 90) { + auto ts = tile_size_fwd_sm90(headdim, headdim_v, is_causal, is_local, element_size, v_colmajor, paged_kv_non_TMA, softcap, use_one_mma_wg); + return std::make_tuple(std::get<0>(ts), std::get<1>(ts)); + } else if (arch < 90) { + auto ts = tile_size_fwd_sm8x(headdim, headdim_v, is_causal, is_local, element_size, paged_kv, /* varlen_and_split */ true, softcap, append_kv); + return std::make_tuple(std::get<0>(ts), std::get<1>(ts)); + } else { + assert(false && "Unsupported architecture"); + return std::make_tuple(0, 0); + } + }; + + auto get_seqlen_k = [&](int bidb) { + return seqused_k_cpu.accessor()[bidb]; + }; + + + auto get_seqlen_q = [&](int bidb) { + if (cu_seqlens_q_cpu.has_value()) { + auto cu_seqlens_q = cu_seqlens_q_cpu.value().accessor(); + return cu_seqlens_q[bidb + 1] - cu_seqlens_q[bidb]; + } else { + return seqlen_q; + } + }; + + auto get_num_n_tiles = [&](int m_tile, int seqlen_k, int seqlen_q, int block_m, int block_n) { + int m_tile_k_end = std::min( + (seqlen_k - seqlen_q) + (m_tile + 1) * block_m, + seqlen_k + ); + return cutlass::ceil_div(m_tile_k_end, block_n); + }; + + auto tile_sizes = get_tile_sizes(false); + auto tile_sizes_one_mma_wg = get_tile_sizes(true); + + auto compute_tiles = [&, get_num_n_tiles = get_num_n_tiles]( + int num_heads_q, + int num_heads_k, + int seqlen_q, + int seqlen_k, + bool causal, + bool pack_gqa, + bool one_mma_wg + ) { + int block_m = one_mma_wg ? std::get<0>(tile_sizes_one_mma_wg) : std::get<0>(tile_sizes); + int block_n = one_mma_wg ? std::get<1>(tile_sizes_one_mma_wg) : std::get<1>(tile_sizes); + + seqlen_q *= pack_gqa ? num_heads_q / num_heads_k : 1; + int num_heads = pack_gqa ? num_heads_k : num_heads_q; + int m_tiles = cutlass::ceil_div(seqlen_q, block_m); + int tiles_total = 0; + if (causal) { + tiles_total += m_tiles * cutlass::ceil_div(seqlen_k, block_n); + } else { + for (int m_tile = 0; m_tile < m_tiles; m_tile++) { + tiles_total += get_num_n_tiles( + m_tile, seqlen_k, seqlen_q, block_m, block_n); + } + } + + return tiles_total * num_heads; + }; + + bool pack_gqa = false; + + // Determine if we should pack GQA by determining the the amount of + // available work that would benefit from packing GQA + // Assume not `use_one_mma_wg` for now, we determine this later + if (num_heads_q > num_heads_k) { + assert (num_heads_q % num_heads_k == 0); + + int total_tiles_pack_gqa = 0; + int total_tiles_no_pack_gqa = 0; + + for (int bidb = 0; bidb < batch_size; ++bidb) { + int seqlen_k = get_seqlen_k(bidb); + int seqlen_q = get_seqlen_q(bidb); + + total_tiles_pack_gqa += compute_tiles( + num_heads_q, num_heads_k, seqlen_q, seqlen_k, is_causal, true, false); + total_tiles_no_pack_gqa += compute_tiles( + num_heads_q, num_heads_k, seqlen_q, seqlen_k, is_causal, false, false); + } + + if (total_tiles_pack_gqa < (total_tiles_no_pack_gqa * 1.1f)) { + pack_gqa = true; + } + } + + bool use_one_mma_wg = false; + // Determine the amount of work that would benefit from using one MMA + // workgroup + + int total_tiles = 0; + int total_tiles_one_mma_wg = 0; + + for (int bidb = 0; bidb < batch_size; ++bidb) { + int seqlen_k = get_seqlen_k(bidb); + int seqlen_q = get_seqlen_q(bidb); + + total_tiles += compute_tiles( + num_heads_q, num_heads_k, seqlen_q, seqlen_k, is_causal, pack_gqa, false); + total_tiles_one_mma_wg += compute_tiles( + num_heads_q, num_heads_k, seqlen_q, seqlen_k, is_causal, pack_gqa, true); + } + + // if using one_mma_wg only increases the number of tiles by 50% or less + // then we should use it (since it performs each tile computes 1/2 as much) + // TODO(lucas): 50% is a guesstimate, we should do a more thorough analysis + if (total_tiles_one_mma_wg < (total_tiles * 1.5f)) { + use_one_mma_wg = true; + } + + tile_sizes = use_one_mma_wg ? tile_sizes_one_mma_wg : tile_sizes; + total_tiles = use_one_mma_wg ? total_tiles_one_mma_wg : total_tiles; + + int block_m = use_one_mma_wg ? std::get<0>(tile_sizes_one_mma_wg) : std::get<0>(tile_sizes); + int block_n = use_one_mma_wg ? std::get<1>(tile_sizes_one_mma_wg) : std::get<1>(tile_sizes); + + int target_tiles_per_sm = cutlass::ceil_div(total_tiles, num_sms); + + std::vector work_tiles; + work_tiles.reserve(1024); + std::vector work_tiles_ind_ptr; + work_tiles_ind_ptr.reserve(num_sms + 1); + work_tiles_ind_ptr.push_back(0); + std::vector combine_tiles; + combine_tiles.reserve(1024); + + int min_tiles = 1; + int max_num_peers = 0; + + int current_tile = 0; + int sm_target_tiles_remaining = target_tiles_per_sm; + + int tiles_total_allocated = 0; + int num_heads = (pack_gqa) ? num_heads_k : num_heads_q; + + for (int bidb = 0; bidb < batch_size; ++bidb) { + int seqlen_k = get_seqlen_k(bidb); + int seqlen_q = get_seqlen_q(bidb); + int m_tiles = cutlass::ceil_div(seqlen_q, block_m); + for (int bidh = 0; bidh < num_heads; ++bidh) { + for (int m_tile = 0; m_tile < m_tiles; m_tile++) { + int n_tiles = get_num_n_tiles( + m_tile, seqlen_k, seqlen_q, block_m, block_n); + + int m_tile_start_idx = current_tile; + int curr_n_tile_start = 0; + int curr_n_tiles_remaining = n_tiles; + int num_peers = 0; + + while (curr_n_tiles_remaining > 0) { + int n_tiles = std::min(curr_n_tiles_remaining, sm_target_tiles_remaining); + + // if we would leave a residual tile that is less than the minimum tiles + // then we should just take the rest of the tiles + if (curr_n_tiles_remaining - n_tiles < min_tiles) { + n_tiles = curr_n_tiles_remaining; + } + + curr_n_tiles_remaining -= n_tiles; + sm_target_tiles_remaining -= n_tiles; + + work_tiles.emplace_back(StreamKWorkTile{ + /* m_block: */ m_tile, + /* n_block_start: */ curr_n_tile_start, + /* n_blocks: */ uint16_t(n_tiles), + /* bidb: */ uint16_t(bidb), + /* bidh: */ uint16_t(bidh), + /* peer_id: */ uint8_t(num_peers), + /* num_peers: */ 0 + }); + + current_tile += 1; + num_peers += 1; + curr_n_tile_start += n_tiles; + tiles_total_allocated += n_tiles; + + if (sm_target_tiles_remaining <= 0) { + work_tiles_ind_ptr.push_back(current_tile); + sm_target_tiles_remaining = target_tiles_per_sm; + } + } + + if (num_peers > 1) { + combine_tiles.emplace_back(StreamKCombineTile{ + /* m_block: */ m_tile, + /* bidh: */ uint16_t(bidh), + /* bidb: */ uint16_t(bidb), + /* num_peers: */ uint16_t(num_peers) + }); + } + + if (num_peers > max_num_peers) { + max_num_peers = num_peers; + } + + for (int i = m_tile_start_idx; i < current_tile; ++i) { + work_tiles[i].num_peers = num_peers; + } + } + } + } + + auto [metadata_offsets, metadata_size] = get_device_metadata_offsets_and_size( + num_sms, + work_tiles.size(), + combine_tiles.size() + ); + + auto device_metadata = torch::empty( + {int(metadata_size)}, + torch::TensorOptions().dtype(torch::kByte).device(torch::kCPU) + ); + + uint8_t *device_metadata_ptr = device_metadata.data_ptr(); + std::memcpy(device_metadata_ptr + metadata_offsets.work_tiles_offset, work_tiles.data(), work_tiles.size() * sizeof(StreamKWorkTile)); + std::memcpy(device_metadata_ptr + metadata_offsets.work_tiles_ind_ptr_offset, work_tiles_ind_ptr.data(), work_tiles_ind_ptr.size() * sizeof(int)); + std::memcpy(device_metadata_ptr + metadata_offsets.combine_tiles_offset, combine_tiles.data(), combine_tiles.size() * sizeof(StreamKCombineTile)); + + auto host_metadata = torch::empty( + {sizeof(StreamKSchedulerDescisions)}, + torch::TensorOptions().dtype(torch::kByte).device(torch::kCPU) + ); + + auto host_metadata_ptr = reinterpret_cast(host_metadata.data_ptr()); + host_metadata_ptr->num_combine_blocks = combine_tiles.size(); + host_metadata_ptr->max_num_peers = max_num_peers; + host_metadata_ptr->pack_gqa = pack_gqa; + host_metadata_ptr->use_one_mma_wg = use_one_mma_wg; + host_metadata_ptr->num_work_tiles = work_tiles.size(); + host_metadata_ptr->grid_size = work_tiles_ind_ptr.size() - 1; + host_metadata_ptr->m_block_size = block_m; + + return {device_metadata, host_metadata}; +} \ No newline at end of file diff --git a/hopper/streamk.h b/hopper/streamk.h index be8ca8a2a..acb8a487a 100644 --- a/hopper/streamk.h +++ b/hopper/streamk.h @@ -4,17 +4,28 @@ #include #include "cute/container/alignment.hpp" -#include "flash.h" #include "tile_size.h" struct CUTE_ALIGNAS(16) StreamKSchedulerDescisions { int num_combine_blocks; int num_work_tiles; int max_num_peers; + int grid_size; + int m_block_size; bool pack_gqa; bool use_one_mma_wg; + + void print() const { + printf("StreamKSchedulerDescisions:\n"); + printf(" num_combine_blocks: %d\n", num_combine_blocks); + printf(" num_work_tiles: %d\n", num_work_tiles); + printf(" max_num_peers: %d\n", max_num_peers); + printf(" pack_gqa: %d\n", pack_gqa); + printf(" use_one_mma_wg: %d\n", use_one_mma_wg); + } }; +// Make sure this fits into 128bits struct CUTE_ALIGNAS(16) StreamKWorkTile { int m_block = -1; int n_block_start = 0; @@ -23,6 +34,17 @@ struct CUTE_ALIGNAS(16) StreamKWorkTile { uint16_t bidh = 0; // Max num heads is 65535 uint8_t peer_id = 0; // Max 255 peers uint8_t num_peers = 0; // Max 255 peers + + void print() const { + printf("StreamKWorkTile:\n"); + printf(" m_block: %d\n", m_block); + printf(" n_block_start: %d\n", n_block_start); + printf(" n_blocks: %d\n", n_blocks); + printf(" bidb: %d\n", bidb); + printf(" bidh: %d\n", bidh); + printf(" peer_id: %d\n", peer_id); + printf(" num_peers: %d\n", num_peers); + } }; struct CUTE_ALIGNAS(8) StreamKCombineTile { @@ -30,9 +52,18 @@ struct CUTE_ALIGNAS(8) StreamKCombineTile { uint16_t const bidh; uint16_t const bidb; uint16_t const num_peers; // Number of peers / num_splits + + void print() const { + printf("StreamKCombineTile:\n"); + printf(" m_block: %d\n", m_block); + printf(" bidh: %d\n", bidh); + printf(" bidb: %d\n", bidb); + printf(" num_peers: %d\n", num_peers); + } }; -struct StreamKMetadataByteOffsets { int const work_tiles_offset; +struct StreamKMetadataByteOffsets { + int const work_tiles_offset; int const combine_tiles_offset; int const work_tiles_ind_ptr_offset; }; @@ -49,7 +80,7 @@ inline std::tuple get_device_metadata_offsets_a int work_tiles_offset = 0; int combine_tiles_offset = work_tiles_offset + round_up_to_16(num_work_tiles * sizeof(StreamKWorkTile)); int work_tiles_ind_ptr_offset = combine_tiles_offset + round_up_to_16(num_combine_tiles * sizeof(StreamKCombineTile)); - int total_size = work_tiles_ind_ptr_offset + round_up_to_16(num_sms * sizeof(int)); + int total_size = work_tiles_ind_ptr_offset + round_up_to_16((num_sms + 1) * sizeof(int)); StreamKMetadataByteOffsets metadata_offsets{ work_tiles_offset, @@ -60,7 +91,7 @@ inline std::tuple get_device_metadata_offsets_a return std::make_tuple(metadata_offsets, total_size); } -inline std::tuple streamk_schedule( +std::tuple streamk_schedule( int arch, int num_sms, int batch_size, @@ -80,251 +111,4 @@ inline std::tuple streamk_schedule( bool paged_kv_non_TMA, bool softcap, bool append_kv -) { - assert (is_local == false && "StreamK + Local attention not supported yet"); - - std::optional cu_seqlens_q_cpu; - if (cu_seqlens_q) { - cu_seqlens_q_cpu.emplace(cu_seqlens_q->cpu()); - } - auto seqused_k_cpu = seqused_k.cpu(); - - auto get_tile_sizes = [&](bool use_one_mma_wg) -> std::tuple { - if (arch == 90) { - auto ts = tile_size_fwd_sm90(headdim, headdim_v, is_causal, is_local, element_size, v_colmajor, paged_kv_non_TMA, softcap); - return std::make_tuple(std::get<0>(ts), std::get<1>(ts)); - } else if (arch < 90) { - auto ts = tile_size_fwd_sm8x(headdim, headdim_v, is_causal, is_local, element_size, paged_kv, /* varlen_and_split */ true, softcap, append_kv); - return std::make_tuple(std::get<0>(ts), std::get<1>(ts)); - } else { - assert(false && "Unsupported architecture"); - return std::make_tuple(0, 0); - } - }; - - auto get_seqlen_k = [&](int bidb) { - return seqused_k_cpu.accessor()[bidb]; - }; - - - auto get_seqlen_q = [&](int bidb) { - if (cu_seqlens_q_cpu.has_value()) { - auto cu_seqlens_q = cu_seqlens_q_cpu.value().accessor(); - return cu_seqlens_q[bidb + 1] - cu_seqlens_q[bidb]; - } else { - return seqlen_q; - } - }; - - auto get_num_n_tiles = [&](int m_tile, int seqlen_k, int seqlen_q, int block_m, int block_n) { - int m_tile_k_end = std::max( - (seqlen_k - seqlen_q) + (m_tile + 1) * block_m, - seqlen_k - ); - return cutlass::ceil_div(m_tile_k_end, block_n); - }; - - auto tile_sizes = get_tile_sizes(false); - auto tile_sizes_one_mma_wg = get_tile_sizes(true); - - auto compute_tiles = [&, get_num_n_tiles = get_num_n_tiles]( - int num_heads, - int num_heads_k, - int seqlen_q, - int seqlen_k, - bool causal, - bool pack_gqa, - bool one_mma_wg - ) { - int block_m = one_mma_wg ? std::get<0>(tile_sizes_one_mma_wg) : std::get<0>(tile_sizes); - int block_n = one_mma_wg ? std::get<1>(tile_sizes_one_mma_wg) : std::get<1>(tile_sizes); - - seqlen_q *= pack_gqa ? num_heads / num_heads_k : 1; - num_heads = pack_gqa ? num_heads_k : num_heads; - int m_tiles = cutlass::ceil_div(seqlen_q, block_m); - int tiles_total = 0; - if (causal) { - tiles_total += m_tiles * cutlass::ceil_div(seqlen_k, block_n); - } else { - int block_m = one_mma_wg ? std::get<0>(tile_sizes_one_mma_wg) : std::get<0>(tile_sizes); - int block_n = one_mma_wg ? std::get<1>(tile_sizes_one_mma_wg) : std::get<1>(tile_sizes); - - for (int m_tile = 0; m_tile < m_tiles; m_tile++) { - tiles_total += get_num_n_tiles( - m_tile, seqlen_k, seqlen_q, block_m, block_n); - } - } - - return tiles_total * num_heads; - }; - - bool pack_gqa = false; - - // Determine if we should pack GQA by determining the the amount of - // available work that would benefit from packing GQA - // Assume not `use_one_mma_wg` for now, we determine this later - if (num_heads > num_heads_k) { - assert (num_heads % num_heads_k == 0); - - int total_tiles_pack_gqa = 0; - int total_tiles_no_pack_gqa = 0; - - for (int bidb = 0; bidb < batch_size; ++bidb) { - int seqlen_k = get_seqlen_k(bidb); - int seqlen_q = get_seqlen_q(bidb); - - total_tiles_pack_gqa += compute_tiles( - num_heads, num_heads_k, seqlen_q, seqlen_k, is_causal, true, false); - total_tiles_no_pack_gqa += compute_tiles( - num_heads, num_heads_k, seqlen_q, seqlen_k, is_causal, false, false); - } - - if (total_tiles_pack_gqa < (total_tiles_no_pack_gqa * 1.1f)) { - pack_gqa = true; - } - } - - bool use_one_mma_wg = false; - // Determine the amount of work that would benefit from using one MMA - // workgroup - - int total_tiles = 0; - int total_tiles_one_mma_wg = 0; - - for (int bidb = 0; bidb < batch_size; ++bidb) { - int seqlen_k = get_seqlen_k(bidb); - int seqlen_q = get_seqlen_q(bidb); - - total_tiles += compute_tiles( - num_heads, num_heads_k, seqlen_q, seqlen_k, is_causal, pack_gqa, false); - total_tiles_one_mma_wg += compute_tiles( - num_heads, num_heads_k, seqlen_q, seqlen_k, is_causal, pack_gqa, true); - } - - // if using one_mma_wg only increases the number of tiles by 50% or less - // then we should use it (since it performs each tile computes 1/2 as much) - // TODO(lucas): 50% is a guesstimate, we should do a more thorough analysis - if (total_tiles_one_mma_wg < (total_tiles * 1.5f)) { - use_one_mma_wg = true; - } - - tile_sizes = use_one_mma_wg ? tile_sizes_one_mma_wg : tile_sizes; - total_tiles = use_one_mma_wg ? total_tiles_one_mma_wg : total_tiles; - - int block_m = std::get<0>(tile_sizes); - int block_n = std::get<1>(tile_sizes); - - int target_tiles_per_sm = cutlass::ceil_div(total_tiles, num_sms); - - std::vector work_tiles; - work_tiles.reserve(1024); - std::vector work_tiles_ind_ptr; - work_tiles_ind_ptr.reserve(num_sms + 1); - work_tiles_ind_ptr.push_back(0); - std::vector combine_tiles; - combine_tiles.reserve(1024); - - int min_tiles = 2; - int max_num_peers = 0; - - int current_tile = 0; - int sm_target_tiles_remaining = target_tiles_per_sm; - int num_combine_tiles = 0; - - for (int bidb = 0; bidb < batch_size; ++bidb) { - for (int bidh = 0; bidh < num_heads; ++bidh) { - int seqlen_k = get_seqlen_k(bidb); - int seqlen_q = get_seqlen_q(bidb); - int m_tiles = cutlass::ceil_div(seqlen_q, block_m); - - for (int m_tile = 0; m_tile < m_tiles; m_tile++) { - int n_tiles = get_num_n_tiles( - m_tile, seqlen_k, seqlen_q, block_m, block_n); - - int m_tile_start_idx = current_tile; - int curr_n_tile_start = 0; - int curr_n_tiles_remaining = n_tiles; - int num_peers = 0; - - while (curr_n_tiles_remaining > 0) { - int n_tile = std::min(curr_n_tiles_remaining, sm_target_tiles_remaining); - - // if we would leave a residual tile that is less than the minimum tiles - // then we should just take the rest of the tiles - if (curr_n_tiles_remaining - n_tile < min_tiles) { - n_tile = curr_n_tiles_remaining; - } - - curr_n_tiles_remaining -= n_tile; - sm_target_tiles_remaining -= n_tile; - - work_tiles.emplace_back(StreamKWorkTile{ - /* m_block: */ m_tile, - /* n_block_start: */ curr_n_tile_start, - /* n_blocks: */ uint16_t(n_tile), - /* bidb: */ uint16_t(bidb), - /* bidh: */ uint16_t(bidh), - /* peer_id: */ uint8_t(num_peers), - /* num_peers: */ 0 - }); - - current_tile += 1; - num_peers += 1; - curr_n_tile_start += n_tile; - - if (sm_target_tiles_remaining <= 0) { - work_tiles_ind_ptr.push_back(current_tile); - sm_target_tiles_remaining = target_tiles_per_sm; - } - } - - if (num_peers > 1) { - combine_tiles.emplace_back(StreamKCombineTile{ - /* m_block: */ m_tile, - /* bidh: */ uint16_t(bidh), - /* bidb: */ uint16_t(bidb), - /* num_peers: */ uint16_t(num_peers) - }); - } - - if (num_peers > max_num_peers) { - max_num_peers = num_peers; - } - - for (int i = m_tile_start_idx; i < current_tile; ++i) { - work_tiles[i].num_peers = num_peers; - } - } - } - } - - auto [metadata_offsets, metadata_size] = get_device_metadata_offsets_and_size( - num_sms, - work_tiles.size(), - num_combine_tiles - ); - - auto device_metadata = torch::empty( - {int(work_tiles.size())}, - torch::TensorOptions().dtype(torch::kByte).device(torch::kCPU) - ); - - uint8_t *device_metadata_ptr = device_metadata.data_ptr(); - std::memcpy(device_metadata_ptr + metadata_offsets.work_tiles_offset, work_tiles.data(), work_tiles.size() * sizeof(StreamKWorkTile)); - std::memcpy(device_metadata_ptr + metadata_offsets.work_tiles_ind_ptr_offset, work_tiles_ind_ptr.data(), work_tiles_ind_ptr.size() * sizeof(int)); - std::memcpy(device_metadata_ptr + metadata_offsets.combine_tiles_offset, combine_tiles.data(), combine_tiles.size() * sizeof(StreamKCombineTile)); - - auto host_metadata = torch::empty( - {sizeof(StreamKSchedulerDescisions)}, - torch::TensorOptions().dtype(torch::kByte).device(torch::kCPU) - ); - - auto host_metadata_ptr = reinterpret_cast(host_metadata.data_ptr()); - host_metadata_ptr->num_combine_blocks = num_combine_tiles; - host_metadata_ptr->max_num_peers = max_num_peers; - host_metadata_ptr->pack_gqa = pack_gqa; - host_metadata_ptr->use_one_mma_wg = use_one_mma_wg; - host_metadata_ptr->num_work_tiles = work_tiles.size(); - - return {device_metadata, host_metadata}; -} +); diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index 751ac6955..c15a2aadc 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -43,8 +43,10 @@ struct TileSchedulerArguments { int const* const num_splits_dynamic_ptr = nullptr; int const window_size_left = -1; int const window_size_right = 0; + // StreamK specific int const* const sm_work_tile_ind_ptr = nullptr; StreamKWorkTile const* const work_tiles_ptr = nullptr; + int grid_size = 0; }; template @@ -819,6 +821,7 @@ class StreamKPersistentTileScheduler: public TileSchedulerCommon