Fix CI (OpenNMT#1747)
Fix CI + disable Flash Attention 2
minhthuc2502 authored Aug 7, 2024
1 parent 39f48f2 commit e6a8f94
Showing 5 changed files with 97 additions and 105 deletions.
38 changes: 4 additions & 34 deletions .github/workflows/ci.yml
@@ -17,7 +17,7 @@ jobs:
CT2_VERBOSE: 1
strategy:
matrix:
os: [ubuntu-20.04, macos-11]
os: [ubuntu-20.04]
backend: [mkl, dnnl]

steps:
@@ -33,17 +33,6 @@ jobs:
sudo sh -c 'echo "deb https://apt.repos.intel.com/oneapi all main" > /etc/apt/sources.list.d/oneAPI.list'
sudo apt-get update
- name: Store URL for downloading Intel oneAPI to environment variable
if: startsWith(matrix.os, 'macos')
run: |
echo 'ONEAPI_INSTALLER_URL=https://registrationcenter-download.intel.com/akdlm/irc_nas/19080/m_BaseKit_p_2023.0.0.25441_offline.dmg' >> $GITHUB_ENV
- name: Install Intel oneAPI
if: startsWith(matrix.os, 'macos')
run: |
wget -q $ONEAPI_INSTALLER_URL
hdiutil attach -noverify -noautofsck $(basename $ONEAPI_INSTALLER_URL)
- name: Configure with MKL
if: startsWith(matrix.os, 'ubuntu') && matrix.backend == 'mkl'
env:
@@ -61,20 +50,6 @@ jobs:
sudo apt-get install -y intel-oneapi-dnnl-devel=$DNNL_VERSION intel-oneapi-dnnl=$DNNL_VERSION
cmake -DCMAKE_INSTALL_PREFIX=$PWD/install -DBUILD_TESTS=ON -DWITH_MKL=OFF -DOPENMP_RUNTIME=COMP -DWITH_DNNL=ON .
- name: Configure with MKL
if: startsWith(matrix.os, 'macos') && matrix.backend == 'mkl'
env:
CT2_USE_MKL: 1
run: |
sudo /Volumes/$(basename $ONEAPI_INSTALLER_URL .dmg)/bootstrapper.app/Contents/MacOS/bootstrapper --silent --eula accept --components intel.oneapi.mac.mkl.devel
cmake -DCMAKE_INSTALL_PREFIX=$PWD/install -DBUILD_TESTS=ON .
- name: Configure with DNNL
if: startsWith(matrix.os, 'macos') && matrix.backend == 'dnnl'
run: |
sudo /Volumes/$(basename $ONEAPI_INSTALLER_URL .dmg)/bootstrapper.app/Contents/MacOS/bootstrapper --silent --eula accept --components intel.oneapi.mac.dnnl
cmake -DCMAKE_INSTALL_PREFIX=$PWD/install -DBUILD_TESTS=ON -DWITH_MKL=OFF -DWITH_DNNL=ON .
- name: Build
run: |
make install
@@ -157,12 +132,12 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-20.04, macos-11, windows-2019]
os: [ubuntu-20.04, windows-2019]
arch: [auto64]
include:
- os: ubuntu-20.04
arch: aarch64
- os: macos-11
- os: macos-12
arch: arm64

steps:
@@ -206,7 +181,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-20.04, macos-11, windows-2019]
os: [ubuntu-20.04, windows-2019]

steps:
- name: Set up Python 3.8
@@ -231,11 +206,6 @@ jobs:
run: |
pip install *cp38*manylinux*x86_64.whl
- name: Install wheel
if: startsWith(matrix.os, 'macos')
run: |
pip install *cp38*macosx*x86_64.whl
- name: Install wheel
if: startsWith(matrix.os, 'windows')
shell: bash
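
For reference, the Linux MKL job that remains in the matrix can be reproduced locally roughly as follows. This is a sketch, not part of the commit: it assumes an Ubuntu 20.04 host with the Intel oneAPI apt repository already configured as in the workflow above, and it installs the unpinned intel-oneapi-mkl-devel package (the exact version pinned in the collapsed part of the workflow may differ).

    # Local approximation of the Linux MKL CI build (assumptions noted above).
    sudo apt-get update
    sudo apt-get install -y intel-oneapi-mkl-devel
    cmake -DCMAKE_INSTALL_PREFIX=$PWD/install -DBUILD_TESTS=ON .
    make -j"$(nproc)" install
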
146 changes: 78 additions & 68 deletions CMakeLists.txt
@@ -21,6 +21,7 @@ option(BUILD_CLI "Compile the clients" ON)
option(BUILD_TESTS "Compile the tests" OFF)
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
option(WITH_TENSOR_PARALLEL "Compile with NCCL and MPI backend" OFF)
option(WITH_FLASH_ATTN "Compile with Flash Attention 2" OFF)

if(ENABLE_PROFILING)
message(STATUS "Enable profiling support")
@@ -485,8 +486,10 @@ if (WITH_CUDA)
set(CUDA_HOST_COMPILER ${CMAKE_CXX_COMPILER})

# flags for flash attention
list(APPEND CUDA_NVCC_FLAGS "--expt-relaxed-constexpr")
list(APPEND CUDA_NVCC_FLAGS "--expt-extended-lambda")
if (WITH_FLASH_ATTN)
list(APPEND CUDA_NVCC_FLAGS "--expt-relaxed-constexpr")
list(APPEND CUDA_NVCC_FLAGS "--expt-extended-lambda")
endif()

message(STATUS "NVCC host compiler: ${CUDA_HOST_COMPILER}")
message(STATUS "NVCC compilation flags: ${CUDA_NVCC_FLAGS}")
@@ -570,77 +573,84 @@ if (WITH_CUDA)
src/ops/topp_mask_gpu.cu
src/ops/quantize_gpu.cu
src/ops/nccl_ops_gpu.cu
src/ops/flash-attention/flash_fwd_hdim32_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim32_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim64_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim64_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim96_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim96_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim128_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim128_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim160_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim160_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim192_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim192_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim224_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim224_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim256_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim256_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim32_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim32_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim64_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim64_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim96_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim96_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim128_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim128_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim160_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim160_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim192_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim192_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim224_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu
src/ops/awq/gemm_gpu.cu
src/ops/awq/gemv_gpu.cu
src/ops/awq/dequantize_gpu.cu
)
if (WITH_FLASH_ATTN)
add_definitions(-DCT2_WITH_FLASH_ATTN)
cuda_add_library(${PROJECT_NAME}
src/ops/flash-attention/flash_fwd_hdim32_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim32_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim64_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim64_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim96_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim96_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim128_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim128_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim160_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim160_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim192_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim192_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim224_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim224_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim256_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim256_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim32_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim32_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim64_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim64_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim96_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim96_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim128_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim128_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim160_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim160_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim192_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim192_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim224_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu
)

set_source_files_properties(
src/ops/flash-attention/flash_fwd_hdim32_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim32_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim64_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim64_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim96_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim96_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim128_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim128_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim160_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim160_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim192_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim192_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim224_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim224_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim256_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim256_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim32_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim32_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim64_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim64_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim96_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim96_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim128_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim128_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim160_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim160_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim192_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim192_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim224_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu
PROPERTIES COMPILE_FLAGS "--use_fast_math")
endif()


set_source_files_properties(
src/ops/flash-attention/flash_fwd_hdim32_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim32_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim64_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim64_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim96_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim96_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim128_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim128_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim160_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim160_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim192_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim192_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim224_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim224_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim256_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_hdim256_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim32_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim32_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim64_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim64_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim96_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim96_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim128_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim128_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim160_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim160_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim192_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim192_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim224_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu
PROPERTIES COMPILE_FLAGS "--use_fast_math")
elseif(WITH_CUDNN)
message(FATAL_ERROR "WITH_CUDNN=ON requires WITH_CUDA=ON")
else()
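
With this change, a CUDA build compiles the Flash Attention 2 kernels only when the new WITH_FLASH_ATTN option is enabled; the default now skips them. A minimal configure sketch (the install prefix and the other flags are illustrative, not prescribed by the commit):

    # Default after this commit: CUDA build without the Flash Attention 2 kernels.
    cmake -DCMAKE_INSTALL_PREFIX=$PWD/install -DWITH_CUDA=ON -DWITH_CUDNN=ON .

    # Opt back in explicitly; this compiles the sm80 kernels listed above and
    # defines CT2_WITH_FLASH_ATTN for src/ops/flash_attention_gpu.cu.
    cmake -DCMAKE_INSTALL_PREFIX=$PWD/install -DWITH_CUDA=ON -DWITH_CUDNN=ON -DWITH_FLASH_ATTN=ON .
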
5 changes: 4 additions & 1 deletion python/tools/prepare_build_environment_linux.sh
@@ -19,9 +19,12 @@ if [ "$CIBW_ARCHS" == "aarch64" ]; then
rm -r OpenBLAS-*

else

# Install CUDA 12.2:
yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo
# error mirrorlist.centos.org doesn't exists anymore.
sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo
sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo
sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo
yum install --setopt=obsoletes=0 -y \
cuda-nvcc-12-2-12.2.140-1 \
cuda-cudart-devel-12-2-12.2.140-1 \
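
The sed rewrites above repoint the manylinux (CentOS 7) yum configuration from the retired mirrorlist.centos.org service to vault.centos.org before the CUDA 12.2 packages are installed. A quick sanity check one could run in the same container — an assumption, not part of the commit — might be:

    # Confirm the repo files now reference the vault and that yum metadata still resolves.
    grep -Rl 'vault.centos.org' /etc/yum.repos.d/ | head
    yum clean all && yum repolist
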
3 changes: 3 additions & 0 deletions python/tools/prepare_test_environment.sh
@@ -3,6 +3,9 @@
set -e
set -x

# force use pip < 24.1
python -m pip install 'pip<24.1'

# Install test rquirements
pip cache purge
pip --no-cache-dir install -r python/tests/requirements.txt --extra-index-url https://download.pytorch.org/whl/cpu
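
The pip pin keeps the test environment on a release older than 24.1; newer pip releases reject some legacy version metadata, which is the likely reason for the pin (the commit itself only states "force use pip < 24.1"). A hypothetical guard to confirm the pin took effect before the requirements install:

    # Exit non-zero if pip is 24.1 or newer (illustrative check, not in the commit).
    python -m pip --version
    python -c "import pip, sys; sys.exit(0 if tuple(map(int, pip.__version__.split('.')[:2])) < (24, 1) else 1)"
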
10 changes: 8 additions & 2 deletions src/ops/flash_attention_gpu.cu
@@ -1,8 +1,9 @@
#include "ctranslate2/ops/flash_attention.h"
#ifdef CT2_WITH_FLASH_ATTN
#include "ctranslate2/ops/flash-attention/flash.h"
#include "ctranslate2/ops/flash-attention/static_switch.h"
#endif
#include "ctranslate2/ops/transpose.h"
#include "ctranslate2/ops/slide.h"
#include "cuda/utils.h"

#include "dispatch.h"
@@ -13,6 +14,7 @@

namespace ctranslate2 {
namespace ops {
#ifdef CT2_WITH_FLASH_ATTN
static void set_params_fprop(Flash_fwd_params &params,
// sizes
const size_t b,
@@ -188,7 +190,7 @@ namespace ctranslate2 {
}

static const ops::Transpose transpose_op({0, 2, 1, 3});

#endif
template<>
void FlashAttention::compute<Device::CUDA>(StorageView& queries,
StorageView& keys,
@@ -203,6 +205,7 @@ namespace ctranslate2 {
const bool rotary_interleave,
StorageView* alibi,
dim_t offset) const {
#ifdef CT2_WITH_FLASH_ATTN
const Device device = queries.device();
const DataType dtype = queries.dtype();

@@ -357,6 +360,9 @@ namespace ctranslate2 {
output.reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
}
#else
throw std::runtime_error("Flash attention 2 is not supported");
#endif
}
}
}
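
Taken together, the guards mean a build without WITH_FLASH_ATTN still compiles and runs, but any code path that reaches FlashAttention::compute now throws the runtime error added above. A hedged smoke test, assuming the Python Translator exposes a flash_attention flag as in recent releases and given some converted model directory (the path below is a placeholder):

    # Hypothetical check on a CUDA machine: expect "Flash attention 2 is not supported"
    # from a wheel built without WITH_FLASH_ATTN.
    python -c "import ctranslate2; t = ctranslate2.Translator('ende_ctranslate2', device='cuda', flash_attention=True); print(t.translate_batch([['▁Hello', '▁world']]))"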
