-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Add RotaryEmbeddings(23) - CUDA #25178
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
25b2315
09d60d1
3dcb960
6f8d35f
b54e965
f806a93
f5031ea
ce1029f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#include "core/providers/cuda/cuda_common.h" | ||
#include "core/providers/cpu/llm/rotary_embedding_helper.h" | ||
#include "core/providers/cuda/llm/rotary_embedding.h" | ||
#include "core/providers/cuda/llm/rotary_embedding_impl.h" | ||
|
||
using namespace onnxruntime::cuda; | ||
Check warning on line 9 in onnxruntime/core/providers/cuda/llm/rotary_embedding.cc
|
||
using namespace ::onnxruntime::common; | ||
Check warning on line 10 in onnxruntime/core/providers/cuda/llm/rotary_embedding.cc
|
||
using namespace ONNX_NAMESPACE; | ||
Check warning on line 11 in onnxruntime/core/providers/cuda/llm/rotary_embedding.cc
|
||
using namespace onnxruntime::rotary_embedding_helper; | ||
Check warning on line 12 in onnxruntime/core/providers/cuda/llm/rotary_embedding.cc
|
||
|
||
namespace onnxruntime { | ||
namespace cuda { | ||
|
||
#define REGISTER_KERNEL_TYPED(T) \ | ||
ONNX_OPERATOR_TYPED_KERNEL_EX( \ | ||
RotaryEmbedding, \ | ||
kOnnxDomain, \ | ||
23, \ | ||
T, \ | ||
kCudaExecutionProvider, \ | ||
(*KernelDefBuilder::Create()) \ | ||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \ | ||
.TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()), \ | ||
RotaryEmbedding<T>); | ||
|
||
REGISTER_KERNEL_TYPED(float) | ||
REGISTER_KERNEL_TYPED(MLFloat16) | ||
REGISTER_KERNEL_TYPED(BFloat16) | ||
|
||
template <typename T> | ||
RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info) { | ||
rotary_embedding_dim = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0)); | ||
num_heads = static_cast<int>(info.GetAttrOrDefault<int64_t>("num_heads", 0)); | ||
interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1); | ||
} | ||
|
||
template <typename T> | ||
Status RotaryEmbedding<T>::ComputeInternal(OpKernelContext* context) const { | ||
const Tensor* input = context->Input<Tensor>(0); | ||
const Tensor* cos_cache = context->Input<Tensor>(1); | ||
const Tensor* sin_cache = context->Input<Tensor>(2); | ||
const Tensor* position_ids = context->Input<Tensor>(3); // Optional, can be nullptr | ||
|
||
RotaryParameters parameters = {}; | ||
ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs<Tensor>(input, | ||
position_ids, | ||
cos_cache, | ||
sin_cache, | ||
num_heads, | ||
rotary_embedding_dim, | ||
¶meters)); | ||
|
||
Tensor* output = context->Output(0, input->Shape()); | ||
|
||
// Launch rotary embedding kernel | ||
typedef typename ToCudaType<T>::MappedType CudaT; | ||
auto& device_prop = GetDeviceProp(); | ||
|
||
// Handle optional position_ids - pass nullptr if position_ids is null | ||
const int64_t* position_ids_data = (position_ids != nullptr) ? position_ids->Data<int64_t>() : nullptr; | ||
|
||
return LaunchRotaryEmbeddingKernel<CudaT>( | ||
Stream(context), | ||
reinterpret_cast<CudaT*>(output->template MutableData<T>()), | ||
reinterpret_cast<const CudaT*>(input->template Data<T>()), | ||
position_ids_data, | ||
reinterpret_cast<const CudaT*>(cos_cache->template Data<T>()), | ||
reinterpret_cast<const CudaT*>(sin_cache->template Data<T>()), | ||
parameters.batch_size, | ||
parameters.sequence_length, | ||
parameters.num_heads, | ||
parameters.head_size, | ||
parameters.rotary_embedding_dim, | ||
parameters.max_sequence_length, | ||
parameters.position_ids_format, | ||
interleaved, | ||
device_prop.maxThreadsPerBlock, | ||
parameters.transposed); | ||
} | ||
|
||
} // namespace cuda | ||
} // namespace onnxruntime |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
#include "core/common/common.h" | ||
#include "core/providers/cuda/cuda_kernel.h" | ||
|
||
namespace onnxruntime { | ||
namespace cuda { | ||
|
||
using namespace onnxruntime::cuda; | ||
Check warning on line 11 in onnxruntime/core/providers/cuda/llm/rotary_embedding.h
|
||
|
||
template <typename T> | ||
class RotaryEmbedding final : public CudaKernel { | ||
public: | ||
RotaryEmbedding(const OpKernelInfo& info); | ||
Status ComputeInternal(OpKernelContext* context) const override; | ||
|
||
protected: | ||
int num_heads; | ||
int rotary_embedding_dim; | ||
int interleaved; | ||
}; | ||
|
||
} // namespace cuda | ||
} // namespace onnxruntime |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
/* | ||
Copyright (c) Microsoft Corporation. | ||
Licensed under the MIT License. | ||
*/ | ||
|
||
/* | ||
Kernel implementation for rotary embeddings. | ||
*/ | ||
|
||
#include "core/providers/cuda/llm/rotary_embedding_impl.h" | ||
#include "core/providers/cuda/cu_inc/common.cuh" | ||
#include <cuda_fp16.h> | ||
Check warning on line 12 in onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.cu
|
||
|
||
using namespace onnxruntime::cuda; | ||
Check warning on line 14 in onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.cu
|
||
|
||
namespace onnxruntime { | ||
namespace cuda { | ||
|
||
template <typename T> | ||
__global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH | ||
const T* input, // BxSxNxH | ||
const T* cos_cache, // BxSx(H/2) or Mx(H/2) | ||
const T* sin_cache, // BxSx(H/2) or Mx(H/2) | ||
const int64_t* position_ids, // (0) or BxS | ||
const int sequence_length, const int num_heads, const int head_size, | ||
const int rotary_embedding_dim, const int position_ids_format, | ||
const bool interleaved, | ||
int4 in_strides, int4 out_strides // strides in bnsh coord, h is always contiguous | ||
) { | ||
// B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length | ||
// Use .x in innermost loop to access global memory efficiently | ||
|
||
const int b = blockIdx.y; | ||
const int s = blockIdx.x; | ||
const int n = blockIdx.z; | ||
|
||
const int i = threadIdx.x; | ||
|
||
if (i >= head_size) { | ||
return; | ||
} | ||
|
||
const T* input_data = input + b * in_strides.x + s * in_strides.z + n * in_strides.y; | ||
T* output_data = output + b * out_strides.x + s * out_strides.z + n * out_strides.y; | ||
|
||
if (i >= rotary_embedding_dim) { | ||
output_data[i] = input_data[i]; | ||
return; | ||
} | ||
|
||
// Cache is (M, H/2) | ||
const int half_rotary_embedding_dim = rotary_embedding_dim / 2; | ||
int cache_offset; | ||
|
||
// position_ids_format == 0 means position_ids is nullptr | ||
// position_ids_format == 1 means position_ids is a 2D array of size (batch_size, sequence_length) | ||
if (position_ids_format == 0) { | ||
cache_offset = (b * sequence_length + s) * half_rotary_embedding_dim; | ||
} else { | ||
// Cache is (M, H/2) or (M, rotary_embedding_dim/2) | ||
const int position_id = static_cast<int>(position_ids[b * sequence_length + s]); | ||
cache_offset = position_id * half_rotary_embedding_dim; | ||
} | ||
const T* cos_data = cos_cache + cache_offset; | ||
const T* sin_data = sin_cache + cache_offset; | ||
|
||
int cache_idx = 0; | ||
T sign = 0; | ||
int j = 0; | ||
if (interleaved) { | ||
cache_idx = (i / 2) % half_rotary_embedding_dim; | ||
sign = (i % 2 == 0) ? -1 : 1; | ||
j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign | ||
} else { | ||
cache_idx = i % half_rotary_embedding_dim; | ||
sign = (i < half_rotary_embedding_dim) ? -1 : 1; | ||
j = (i + half_rotary_embedding_dim) % rotary_embedding_dim; | ||
} | ||
output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; | ||
} | ||
|
||
template <typename T> | ||
Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids, | ||
const T* cos_cache, const T* sin_cache, const int batch_size, | ||
const int sequence_length, const int num_heads, const int head_size, | ||
const int rotary_embedding_dim, const int max_sequence_length, | ||
const int position_ids_format, const bool interleaved, | ||
const int max_threads_per_block, const bool is_input_bnsh_format) { | ||
int4 in_strides; | ||
int4 out_strides; | ||
if (is_input_bnsh_format) { | ||
int in_head_stride = sequence_length * head_size; | ||
int out_head_stride = sequence_length * head_size; | ||
in_strides = int4{num_heads * in_head_stride, in_head_stride, in_head_stride / sequence_length, 1}; | ||
out_strides = int4{num_heads * out_head_stride, out_head_stride, out_head_stride / sequence_length, 1}; | ||
} else { | ||
int in_head_stride = head_size; | ||
int out_head_stride = head_size; | ||
in_strides = int4{sequence_length * num_heads * in_head_stride, in_head_stride, num_heads * in_head_stride, 1}; | ||
out_strides = int4{sequence_length * num_heads * out_head_stride, out_head_stride, num_heads * out_head_stride, 1}; | ||
} | ||
return LaunchRotaryEmbeddingKernel<T>( | ||
stream, output, input, position_ids, | ||
cos_cache, sin_cache, batch_size, | ||
sequence_length, num_heads, head_size, | ||
rotary_embedding_dim, max_sequence_length, | ||
position_ids_format, interleaved, | ||
max_threads_per_block, | ||
in_strides, out_strides); | ||
} | ||
|
||
template <typename T> | ||
Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids, | ||
const T* cos_cache, const T* sin_cache, const int batch_size, | ||
const int sequence_length, const int num_heads, const int head_size, | ||
const int rotary_embedding_dim, const int /*max_sequence_length*/, | ||
const int position_ids_format, const bool interleaved, | ||
const int max_threads_per_block, | ||
int4 in_strides, int4 out_strides // strides in bnsh coord | ||
) { | ||
// Note: Current implementation assumes head_size <= max_threads_per_block | ||
// because head_size is currently large for LLaMA-2. For smaller head_size | ||
// and num_heads values, we can create a block as `block(num_heads, head_size, 1)` | ||
// instead. This will require kernel changes to support. | ||
ORT_ENFORCE(head_size <= max_threads_per_block, "Rotary embedding dim must be <= max_threads_per_block"); | ||
// strides in canonical bnsh coord, h is always contiguous (dim_stride == 1) | ||
ORT_ENFORCE(in_strides.w == 1 && out_strides.w == 1, "head dim must contiguous"); | ||
|
||
int tpb = (head_size + 31) / 32 * 32; | ||
|
||
const dim3 block(tpb); | ||
const dim3 grid(sequence_length, batch_size, num_heads); | ||
|
||
assert(head_size <= max_threads_per_block); | ||
RotaryEmbeddingBSNH<<<grid, block, 0, stream>>>(output, input, cos_cache, sin_cache, position_ids, sequence_length, | ||
num_heads, head_size, rotary_embedding_dim, position_ids_format, | ||
interleaved, in_strides, out_strides); | ||
return CUDA_CALL(cudaGetLastError()); | ||
} | ||
|
||
template Status LaunchRotaryEmbeddingKernel<float>(cudaStream_t stream, float* output, const float* input, | ||
const int64_t* position_ids, const float* cos_cache, | ||
const float* sin_cache, const int batch_size, | ||
const int sequence_length, const int num_heads, const int head_size, | ||
const int rotary_embedding_dim, const int max_sequence_length, | ||
const int position_ids_format, const bool interleaved, | ||
const int max_threads_per_block, const bool is_input_bnsh_format); | ||
|
||
template Status LaunchRotaryEmbeddingKernel<half>(cudaStream_t stream, half* output, const half* input, | ||
const int64_t* position_ids, const half* cos_cache, | ||
const half* sin_cache, const int batch_size, | ||
const int sequence_length, const int num_heads, const int head_size, | ||
const int rotary_embedding_dim, const int max_sequence_length, | ||
const int position_ids_format, const bool interleaved, | ||
const int max_threads_per_block, const bool is_input_bnsh_format); | ||
|
||
template Status LaunchRotaryEmbeddingKernel<BFloat16>( | ||
cudaStream_t stream, BFloat16* output, const BFloat16* input, const int64_t* position_ids, | ||
const BFloat16* cos_cache, const BFloat16* sin_cache, const int batch_size, const int sequence_length, | ||
const int num_heads, const int head_size, const int rotary_embedding_dim, const int max_sequence_length, | ||
const int position_ids_format, const bool interleaved, const int max_threads_per_block, | ||
const bool is_input_bnsh_format); | ||
|
||
} // namespace cuda | ||
} // namespace onnxruntime |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
#include "core/common/common.h" | ||
#include "core/providers/cuda/shared_inc/cuda_utils.h" | ||
|
||
namespace onnxruntime { | ||
namespace cuda { | ||
|
||
template <typename T> | ||
Status LaunchRotaryEmbeddingKernel( | ||
cudaStream_t stream, | ||
T* output, | ||
const T* input, | ||
const int64_t* position_ids, | ||
const T* cos_cache, | ||
const T* sin_cache, | ||
const int batch_size, | ||
const int sequence_length, | ||
const int num_heads, | ||
const int head_size, | ||
const int rotary_embedding_dim, | ||
const int max_sequence_length, | ||
const int position_ids_format, | ||
const bool interleaved, | ||
const int max_threads_per_block, | ||
const bool is_input_bnsh_format); | ||
|
||
template <typename T> | ||
Status LaunchRotaryEmbeddingKernel( | ||
cudaStream_t stream, | ||
T* output, | ||
const T* input, | ||
const int64_t* position_ids, | ||
const T* cos_cache, | ||
const T* sin_cache, | ||
const int batch_size, | ||
const int sequence_length, | ||
const int num_heads, | ||
const int head_size, | ||
const int rotary_embedding_dim, | ||
const int max_sequence_length, | ||
const int position_ids_format, | ||
const bool interleaved, | ||
const int max_threads_per_block, | ||
int4 in_strides, | ||
int4 out_strides); | ||
|
||
} // namespace cuda | ||
} // namespace onnxruntime |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should avoid copying this and re-use the same existing kernel
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wanted to separate them because
position_ids
is optional to standard ONNX op: https://github.com/onnx/onnx/blob/main/docs/Operators.md#RotaryEmbedding.onnxruntime/onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.cu
Lines 51 to 65 in ce1029f
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just for my education, is the benefit of having fewer cuda kernels that it would be more efficient to onnxruntime? Or it's for cleaner coding? If it's latter, I think coupling contrib op and onnx op piles up complexity because the inputs (shape/value) of the function are different. I also give
position_ids_format
different definition where 0 means nullptr, and 1 means 2D now.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we update how
position_ids_format
is defined below to have an option for whenposition_ids = nullptr
?onnxruntime/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
Lines 13 to 26 in 47ddaaa
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have gone another direction to create the new CheckInputs though:
onnxruntime/onnxruntime/core/providers/cpu/llm/rotary_embedding_helper.h
Line 23 in 47ddaaa
These are the lines we need to add into helper if we want to merge them.
onnxruntime/onnxruntime/core/providers/cpu/llm/rotary_embedding_helper.h
Lines 82 to 115 in 47ddaaa
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Size 1 is banned in onnx rotary embedding to respect the spec:
onnxruntime/onnxruntime/core/providers/cpu/llm/rotary_embedding_helper.h
Lines 134 to 137 in 47ddaaa