Skip to content

[ILUVATAR_GPU] Add logic to apply patches to python files in install script && Fix the segment fault that occurred after linking with the NCCL library. #1762

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions backends/iluvatar_gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ cmake_minimum_required(VERSION 3.10)
set(PROJ_NAME "paddle-iluvatar-gpu")
project(${PROJ_NAME} CXX C CUDA)

set(PLUGIN_VERSION "0.0.1")
set(TARGET_NAME ${PROJ_NAME})

set(CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake")
Expand All @@ -35,6 +34,7 @@ include(external/xxhash)
include(external/zlib)
include(external/protobuf)

set(PLUGIN_VERSION ${PADDLE_VERSION})
set(PROTO_FILE "${PADDLE_SOURCE_DIR}/paddle/phi/core/external_error.proto")
get_filename_component(PROTO_WE "${PROTO_FILE}" NAME_WE)

Expand Down Expand Up @@ -253,7 +253,6 @@ target_link_libraries(
protobuf
external_error_proto
cuinfer
# May cause a segment fault when the program exits
nccl)

include_directories(BEFORE ${PADDLE_SOURCE_DIR})
Expand Down
1 change: 0 additions & 1 deletion backends/iluvatar_gpu/build_paddle.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ fi
BUILD_TEST=${BUILD_TEST:-1}
COREX_ARCH=${COREX_ARCH:-ivcore11}
export CMAKE_CUDA_ARCHITECTURES=${COREX_ARCH}
export PADDLE_VERSION=${PADDLE_VERSION:-3.0.0}

CURRENT_DIR=$(pwd)
PADDLE_SOURCE_DIR="${CURRENT_DIR}/../../Paddle"
Expand Down
3 changes: 3 additions & 0 deletions backends/iluvatar_gpu/install_paddle.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ PYTHON_PATH=$(which python3)
PYTHON_DIST_PATH=${TARGET_DIR}/lib/python3/dist-packages

PKG_DIR="build_pip"
PKGCPU_NAME="paddlepaddle"
PKG_NAME="paddle_iluvatar_gpu"

if [[ ! -d ${PKG_DIR} ]]; then
Expand All @@ -43,6 +44,8 @@ if [[ "${TARGET_DIR}" != "" ]]; then
rm -rf ./tmp
echo "Paddle installed in ${PYTHON_DIST_PATH}; please add it to your PYTHONPATH."
else
${PYTHON_PATH} -m pip uninstall ${PKGCPU_NAME} -y
${PYTHON_PATH} -m pip install --pre paddlepaddle -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/
${PYTHON_PATH} -m pip uninstall ${PKG_NAME} -y
${PYTHON_PATH} -m pip install ${PKG_DIR}/${latest_pkg} || exit
fi
Expand Down
148 changes: 148 additions & 0 deletions backends/iluvatar_gpu/kernels/cuda_kernels/c_embedding_grad_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "glog/logging.h"
#include "paddle/common/flags.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/c_embedding_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/embedding_grad.h"

COMMON_DECLARE_int64(embedding_deterministic);

namespace phi {

static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;

static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaximumNumBlocks);
}

template <typename T, typename IndexT>
__global__ void CEmbeddingGrad(T* table,
const T* output,
const IndexT* ids,
const int rows,
const int columns,
const int64_t N,
const int64_t start_idx,
const int64_t end_idx,
const int64_t limit) {
CUDA_KERNEL_LOOP(i, limit) {
size_t row = i / columns;
size_t col = i % columns;
auto id = ids[row];
if (id >= start_idx && id < end_idx) {
auto real_idx = id - start_idx;
phi::CudaAtomicAdd(&table[real_idx * columns + col], output[i]);
}
}
}

template <typename T, typename Context>
void CEmbeddingGradKernel(const Context& dev_ctx,
const DenseTensor& w,
const DenseTensor& ids,
const DenseTensor& out_grad,
int64_t start_index,
DenseTensor* w_grad) {
int N = w_grad->dims()[0];
int D = w_grad->dims()[1];
int K = ids.numel();

auto limit = K * D;
int blocks = NumBlocks(limit);
int threads = kNumCUDAThreads;

const T* d_output = out_grad.data<T>();
T* d_table = dev_ctx.template Alloc<T>(w_grad);

auto t = EigenVector<T>::Flatten(*w_grad);
t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));

const auto& index_type = ids.dtype();
if (FLAGS_embedding_deterministic == 1) {
if (index_type == phi::DataType::INT32) {
phi::funcs::LaunchEmbeddingGradDeterministicKernel<T, int32_t>(
dev_ctx,
ids.data<int32_t>(),
d_output,
d_table,
N,
D,
K,
start_index);
return;
} else if (index_type == phi::DataType::INT64) {
phi::funcs::LaunchEmbeddingGradDeterministicKernel<T, int64_t>(
dev_ctx,
ids.data<int64_t>(),
d_output,
d_table,
N,
D,
K,
start_index);
return;
}
} else {
if (FLAGS_embedding_deterministic > 1) {
VLOG(2) << "Run grad kernel of embedding with single thread.";
blocks = 1;
}
const int64_t end_idx = start_index + N;
if (index_type == phi::DataType::INT32) {
CEmbeddingGrad<T, int32_t>
<<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
d_output,
ids.data<int32_t>(),
K,
D,
N,
start_index,
end_idx,
limit);
return;
} else if (index_type == phi::DataType::INT64) {
CEmbeddingGrad<T, int64_t>
<<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
d_output,
ids.data<int64_t>(),
K,
D,
N,
start_index,
end_idx,
limit);
return;
}
}
PADDLE_THROW(common::errors::InvalidArgument(
"The data type of Input(Ids) must be int32 or int64."));
}

} // namespace phi

PD_REGISTER_PLUGIN_KERNEL(c_embedding_grad,
iluvatar_gpu,
ALL_LAYOUT,
phi::CEmbeddingGradKernel,
float,
phi::dtype::bfloat16,
phi::dtype::float16,
phi::dtype::complex<float>) {}

This file was deleted.

122 changes: 122 additions & 0 deletions backends/iluvatar_gpu/kernels/cuda_kernels/c_embedding_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/c_embedding_kernel.h"

namespace phi {

static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;

static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaximumNumBlocks);
}

template <typename T, typename IndexT>
__global__ void CEmbedding(T* out,
const T* table,
const IndexT* ids,
const int rows,
const int columns,
const int64_t N,
const int64_t start_idx,
const int64_t end_idx,
const int64_t limit,
const int64_t vocab_size) {
CUDA_KERNEL_LOOP(i, limit) {
size_t row = i / columns;
size_t col = i % columns;
auto id = ids[row];

PADDLE_ENFORCE(
id >= 0 && (vocab_size < 0 || id < vocab_size),
"The index is out of bounds, "
"please check whether the dimensions of index and "
"input meet the requirements. It should "
"be less than [%d] and greater than or equal to 0, but received [%d]",
vocab_size,
id);
if (id >= start_idx && id < end_idx) {
auto real_idx = id - start_idx;
out[i] = table[real_idx * columns + col];
} else {
out[i] = static_cast<T>(0);
}
}
}

template <typename T, typename Context>
void CEmbeddingKernel(const Context& dev_ctx,
const DenseTensor& w,
const DenseTensor& ids,
int64_t start_index,
int64_t vocab_size,
DenseTensor* out) {
size_t N = w.dims()[0];
size_t D = w.dims()[1];
size_t K = ids.numel();

const int64_t end_idx = start_index + N;

auto* table = w.data<T>();
auto* output = dev_ctx.template Alloc<T>(out);

auto limit = K * D;
int blocks = NumBlocks(limit);
int threads = kNumCUDAThreads;

const auto& index_type = ids.dtype();
if (index_type == phi::DataType::INT32) {
CEmbedding<T, int32_t>
<<<blocks, threads, 0, dev_ctx.stream()>>>(output,
table,
ids.data<int32_t>(),
K,
D,
N,
start_index,
end_idx,
limit,
vocab_size);

} else if (index_type == phi::DataType::INT64) {
CEmbedding<T, int64_t>
<<<blocks, threads, 0, dev_ctx.stream()>>>(output,
table,
ids.data<int64_t>(),
K,
D,
N,
start_index,
end_idx,
limit,
vocab_size);
} else {
PADDLE_THROW(common::errors::Unavailable(
"GPU c_embedding ids only support int32 or int64."));
}
}
} // namespace phi

PD_REGISTER_PLUGIN_KERNEL(c_embedding,
iluvatar_gpu,
ALL_LAYOUT,
phi::CEmbeddingKernel,
float,
phi::dtype::bfloat16,
phi::dtype::float16,
phi::dtype::complex<float>) {}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ void FlashAttnUnpaddedBaseKernel(
flashAttnInfo.softmax_scale = std::sqrt(1.f / head_size);
flashAttnInfo.dropout_prob = is_test ? 0.0f : dropout;
flashAttnInfo.is_causal = causal;
flashAttnInfo.causal_mode = 1;
// flashAttnInfo.is_alibi = use_alibi;
// flashAttnInfo.alibi_mode = alibi_mode;
flashAttnInfo.return_softmax_lse = true;
Expand Down
Loading