Add RotaryEmbeddings(23) - CUDA #25178

Open · wants to merge 8 commits into base: main
1 change: 1 addition & 0 deletions docs/OperatorKernels.md
@@ -828,6 +828,7 @@ Do not modify directly.*
|||10|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
|ReverseSequence|*in* input:**T**<br> *in* sequence_lens:**tensor(int64)**<br> *out* Y:**T**|10+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|RoiAlign|*in* X:**T1**<br> *in* rois:**T1**<br> *in* batch_indices:**T2**<br> *out* Y:**T1**|10+|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(int64)|
|RotaryEmbedding|*in* X:**T**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**M**<br> *out* Y:**T**|23+|**M** = tensor(int64)<br/> **T** = tensor(bfloat16), tensor(float), tensor(float16)|
|Round|*in* X:**T**<br> *out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)|
|ScaledTanh|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Scan|*in* initial_state_and_scan_inputs:**V**<br> *out* final_state_and_scan_outputs:**V**<br><br>or<br><br>*in* sequence_lens:**I**<br> *in* initial_state_and_scan_inputs:**V**<br> *out* final_state_and_scan_outputs:**V**|19+|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
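With this registration, an opset-23 model that contains RotaryEmbedding (T = float, float16, or bfloat16) can be placed on the CUDA execution provider. A minimal host-side sketch of creating such a session (the model path is a placeholder, not part of this PR):

#include <onnxruntime_cxx_api.h>

int main() {
  // Request the CUDA EP so an opset-23 RotaryEmbedding node can be assigned to the
  // kernels registered here; ORT falls back to the CPU EP for unsupported nodes.
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "rotary_demo");
  Ort::SessionOptions session_options;
  OrtCUDAProviderOptions cuda_options{};  // defaults, device 0
  session_options.AppendExecutionProvider_CUDA(cuda_options);
  Ort::Session session(env, "rotary_model.onnx", session_options);  // placeholder path
  return 0;
}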
6 changes: 6 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1459,6 +1459,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16_BFloat16, RMSNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float_MLFloat16, RMSNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16_float, RMSNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, RotaryEmbedding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, RotaryEmbedding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, RotaryEmbedding);

#endif

@@ -2449,6 +2452,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16_BFloat16, RMSNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float_MLFloat16, RMSNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16_float, RMSNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, RotaryEmbedding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, RotaryEmbedding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, RotaryEmbedding)>,
#endif
};

85 changes: 85 additions & 0 deletions onnxruntime/core/providers/cuda/llm/rotary_embedding.cc
@@ -0,0 +1,85 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cpu/llm/rotary_embedding_helper.h"
#include "core/providers/cuda/llm/rotary_embedding.h"
#include "core/providers/cuda/llm/rotary_embedding_impl.h"

using namespace onnxruntime::cuda;

[cpplint] rotary_embedding.cc:9: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]
using namespace ::onnxruntime::common;

[cpplint] rotary_embedding.cc:10: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]
using namespace ONNX_NAMESPACE;

[cpplint] rotary_embedding.cc:11: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]
using namespace onnxruntime::rotary_embedding_helper;

[cpplint] rotary_embedding.cc:12: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]

namespace onnxruntime {
namespace cuda {

#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
RotaryEmbedding, \
kOnnxDomain, \
23, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()), \
RotaryEmbedding<T>);

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
REGISTER_KERNEL_TYPED(BFloat16)

template <typename T>
RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info) {
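// An attribute value of 0 means "not set"; CheckInputs applies the defaults later
// (e.g. rotary_embedding_dim == 0 falls back to head_size, per the cos_cache shape check).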
rotary_embedding_dim = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
num_heads = static_cast<int>(info.GetAttrOrDefault<int64_t>("num_heads", 0));
interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);
}

template <typename T>
Status RotaryEmbedding<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* input = context->Input<Tensor>(0);
const Tensor* cos_cache = context->Input<Tensor>(1);
const Tensor* sin_cache = context->Input<Tensor>(2);
const Tensor* position_ids = context->Input<Tensor>(3); // Optional, can be nullptr

RotaryParameters parameters = {};
ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs<Tensor>(input,
position_ids,
cos_cache,
sin_cache,
num_heads,
rotary_embedding_dim,
&parameters));

Tensor* output = context->Output(0, input->Shape());

// Launch rotary embedding kernel
typedef typename ToCudaType<T>::MappedType CudaT;
auto& device_prop = GetDeviceProp();

// Handle optional position_ids - pass nullptr if position_ids is null
const int64_t* position_ids_data = (position_ids != nullptr) ? position_ids->Data<int64_t>() : nullptr;

return LaunchRotaryEmbeddingKernel<CudaT>(
Stream(context),
reinterpret_cast<CudaT*>(output->template MutableData<T>()),
reinterpret_cast<const CudaT*>(input->template Data<T>()),
position_ids_data,
reinterpret_cast<const CudaT*>(cos_cache->template Data<T>()),
reinterpret_cast<const CudaT*>(sin_cache->template Data<T>()),
parameters.batch_size,
parameters.sequence_length,
parameters.num_heads,
parameters.head_size,
parameters.rotary_embedding_dim,
parameters.max_sequence_length,
parameters.position_ids_format,
interleaved,
device_prop.maxThreadsPerBlock,
parameters.transposed);
}

} // namespace cuda
} // namespace onnxruntime
26 changes: 26 additions & 0 deletions onnxruntime/core/providers/cuda/llm/rotary_embedding.h
@@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/common/common.h"
#include "core/providers/cuda/cuda_kernel.h"

namespace onnxruntime {
namespace cuda {

using namespace onnxruntime::cuda;

[cpplint] rotary_embedding.h:11: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]

template <typename T>
class RotaryEmbedding final : public CudaKernel {
public:
RotaryEmbedding(const OpKernelInfo& info);
Status ComputeInternal(OpKernelContext* context) const override;

protected:
int num_heads;
int rotary_embedding_dim;
int interleaved;
};

} // namespace cuda
} // namespace onnxruntime
165 changes: 165 additions & 0 deletions onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.cu
@@ -0,0 +1,165 @@
/*
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.
*/

/*
Kernel implementation for rotary embeddings.
*/

#include "core/providers/cuda/llm/rotary_embedding_impl.h"
#include "core/providers/cuda/cu_inc/common.cuh"
#include <cuda_fp16.h>

[cpplint] rotary_embedding_impl.cu:12: Found C system header after other header. Should be: rotary_embedding_impl.h, c system, c++ system, other. [build/include_order] [4]

using namespace onnxruntime::cuda;

[cpplint] rotary_embedding_impl.cu:14: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]

namespace onnxruntime {
namespace cuda {

template <typename T>
__global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH
Review comment (Contributor): We should avoid copying this and re-use the same existing kernel.

Reply (@titaiwangms, author, Jun 26, 2025): I wanted to separate them because position_ids is optional in the standard ONNX op: https://github.com/onnx/onnx/blob/main/docs/Operators.md#RotaryEmbedding.

Reply (@titaiwangms, author, Jun 26, 2025): Just for my education, is the benefit of having fewer CUDA kernels that it is more efficient for onnxruntime, or is it for cleaner coding? If it's the latter, I think coupling the contrib op and the ONNX op piles up complexity because the inputs (shape/value) of the function are different. I also give position_ids_format a different definition now, where 0 means position_ids is nullptr and 1 means it is 2D.

Reply (Contributor): Can we update how position_ids_format is defined below to have an option for when position_ids = nullptr?

struct RotaryParameters {
int batch_size; // Batch size used by input
int sequence_length; // Sequence length used by input
int hidden_size; // Hidden size used by input
int head_size; // Head size
int rotary_embedding_dim; // Rotary embedding dimension.
int num_heads; // num_heads = hidden_size / head_size
int max_sequence_length; // Sequence length used by cos/sin cache
int head_stride; // Head stride
int seq_stride; // Sequence stride
int batch_stride; // Batch stride
int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length)
bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden)
};

Reply (@titaiwangms, author): I have gone another direction and created a new CheckInputs instead:

int position_ids_format; // Format of position ids - 0 is (0), 1 is (batch_size, sequence_length)

These are the lines we would need to add to the helper if we want to merge them:

if (nullptr == position_ids) {
// Check cos_cache and sin_cache
const auto& cos_cache_dims = cos_cache->Shape().GetDims();
if (cos_cache_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' is expected to have 3 dimensions, got ",
cos_cache_dims.size());
}
const auto& sin_cache_dims = sin_cache->Shape().GetDims();
if (sin_cache_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' is expected to have 3 dimensions, got ",
sin_cache_dims.size());
}
if (cos_cache_dims[0] != sin_cache_dims[0] || cos_cache_dims[1] != sin_cache_dims[1] || cos_cache_dims[2] != sin_cache_dims[2]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'cos_cache' and 'sin_cache' are expected to have ",
"the same shape");
}
// Make sure cos_cache and sin_cache have the same batch size and sequence length as input x
// when position_ids is not provided.
if (cos_cache_dims[0] != batch_size || cos_cache_dims[1] != sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'cos_cache' and 'sin_cache' are expected to have ",
"the same shape as input 'x', got ", cos_cache_dims[0], " and ", cos_cache_dims[1]);
}
max_sequence_length = static_cast<int>(cos_cache_dims[1]);
if (rotary_embedding_dim > 0 && rotary_embedding_dim > head_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "rotary_embedding_dim must be less than or equal to ",
"head_size");
}
// Check cos_cache input shapes
if (cos_cache_dims[2] != (rotary_embedding_dim > 0 ? rotary_embedding_dim : head_size) / 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 2 should be same as ",
"head_size / 2 or rotary_embedding_dim / 2, got ", cos_cache_dims[2]);
}

Reply (@titaiwangms, author): A size-1 position_ids is banned in the ONNX RotaryEmbedding to respect the spec:

if (position_ids_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' is expected to have 2 ",
"dimensions, got ", position_ids_dims.size());
}

const T* input, // BxSxNxH
const T* cos_cache, // BxSx(H/2) or Mx(H/2)
const T* sin_cache, // BxSx(H/2) or Mx(H/2)
const int64_t* position_ids, // (0) or BxS
const int sequence_length, const int num_heads, const int head_size,
const int rotary_embedding_dim, const int position_ids_format,
const bool interleaved,
int4 in_strides, int4 out_strides // strides in bnsh coord, h is always contiguous
) {
// B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length
// Use .x in innermost loop to access global memory efficiently

const int b = blockIdx.y;
const int s = blockIdx.x;
const int n = blockIdx.z;

const int i = threadIdx.x;

if (i >= head_size) {
return;
}

const T* input_data = input + b * in_strides.x + s * in_strides.z + n * in_strides.y;
T* output_data = output + b * out_strides.x + s * out_strides.z + n * out_strides.y;

if (i >= rotary_embedding_dim) {
output_data[i] = input_data[i];
return;
}

// Cache is (M, H/2)
const int half_rotary_embedding_dim = rotary_embedding_dim / 2;
int cache_offset;

// position_ids_format == 0 means position_ids is nullptr
// position_ids_format == 1 means position_ids is a 2D array of size (batch_size, sequence_length)
if (position_ids_format == 0) {
cache_offset = (b * sequence_length + s) * half_rotary_embedding_dim;
} else {
// Cache is (M, H/2) or (M, rotary_embedding_dim/2)
const int position_id = static_cast<int>(position_ids[b * sequence_length + s]);
cache_offset = position_id * half_rotary_embedding_dim;
}
const T* cos_data = cos_cache + cache_offset;
const T* sin_data = sin_cache + cache_offset;

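// Pair element i with its partner j and rotate by the cached angle:
//   output[i] = x[i] * cos(theta) + sign * x[j] * sin(theta)
// where sign is -1 for the first element of the pair and +1 for the second.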
int cache_idx = 0;
T sign = 0;
int j = 0;
if (interleaved) {
cache_idx = (i / 2) % half_rotary_embedding_dim;
sign = (i % 2 == 0) ? -1 : 1;
j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign
} else {
cache_idx = i % half_rotary_embedding_dim;
sign = (i < half_rotary_embedding_dim) ? -1 : 1;
j = (i + half_rotary_embedding_dim) % rotary_embedding_dim;
}
output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
}

template <typename T>
Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids,
const T* cos_cache, const T* sin_cache, const int batch_size,
const int sequence_length, const int num_heads, const int head_size,
const int rotary_embedding_dim, const int max_sequence_length,
const int position_ids_format, const bool interleaved,
const int max_threads_per_block, const bool is_input_bnsh_format) {
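// int4 strides are ordered (batch, head, seq, dim) = (x, y, z, w), matching the
// b * strides.x + n * strides.y + s * strides.z addressing used in the kernel.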
int4 in_strides;
int4 out_strides;
if (is_input_bnsh_format) {
int in_head_stride = sequence_length * head_size;
int out_head_stride = sequence_length * head_size;
in_strides = int4{num_heads * in_head_stride, in_head_stride, in_head_stride / sequence_length, 1};
out_strides = int4{num_heads * out_head_stride, out_head_stride, out_head_stride / sequence_length, 1};
} else {
int in_head_stride = head_size;
int out_head_stride = head_size;
in_strides = int4{sequence_length * num_heads * in_head_stride, in_head_stride, num_heads * in_head_stride, 1};
out_strides = int4{sequence_length * num_heads * out_head_stride, out_head_stride, num_heads * out_head_stride, 1};
}
return LaunchRotaryEmbeddingKernel<T>(
stream, output, input, position_ids,
cos_cache, sin_cache, batch_size,
sequence_length, num_heads, head_size,
rotary_embedding_dim, max_sequence_length,
position_ids_format, interleaved,
max_threads_per_block,
in_strides, out_strides);
}

template <typename T>
Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids,
const T* cos_cache, const T* sin_cache, const int batch_size,
const int sequence_length, const int num_heads, const int head_size,
const int rotary_embedding_dim, const int /*max_sequence_length*/,
const int position_ids_format, const bool interleaved,
const int max_threads_per_block,
int4 in_strides, int4 out_strides // strides in bnsh coord
) {
// Note: Current implementation assumes head_size <= max_threads_per_block
// because head_size is currently large for LLaMA-2. For smaller head_size
// and num_heads values, we can create a block as `block(num_heads, head_size, 1)`
// instead. This will require kernel changes to support.
ORT_ENFORCE(head_size <= max_threads_per_block, "head_size must be <= max_threads_per_block");
// strides in canonical bnsh coord, h is always contiguous (dim_stride == 1)
ORT_ENFORCE(in_strides.w == 1 && out_strides.w == 1, "head dim must be contiguous");

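// Round the block size up to a whole number of warps (32 threads) so it covers head_size.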
int tpb = (head_size + 31) / 32 * 32;

const dim3 block(tpb);
const dim3 grid(sequence_length, batch_size, num_heads);

assert(head_size <= max_threads_per_block);
RotaryEmbeddingBSNH<<<grid, block, 0, stream>>>(output, input, cos_cache, sin_cache, position_ids, sequence_length,
num_heads, head_size, rotary_embedding_dim, position_ids_format,
interleaved, in_strides, out_strides);
return CUDA_CALL(cudaGetLastError());
}

template Status LaunchRotaryEmbeddingKernel<float>(cudaStream_t stream, float* output, const float* input,
const int64_t* position_ids, const float* cos_cache,
const float* sin_cache, const int batch_size,
const int sequence_length, const int num_heads, const int head_size,
const int rotary_embedding_dim, const int max_sequence_length,
const int position_ids_format, const bool interleaved,
const int max_threads_per_block, const bool is_input_bnsh_format);

template Status LaunchRotaryEmbeddingKernel<half>(cudaStream_t stream, half* output, const half* input,
const int64_t* position_ids, const half* cos_cache,
const half* sin_cache, const int batch_size,
const int sequence_length, const int num_heads, const int head_size,
const int rotary_embedding_dim, const int max_sequence_length,
const int position_ids_format, const bool interleaved,
const int max_threads_per_block, const bool is_input_bnsh_format);

template Status LaunchRotaryEmbeddingKernel<BFloat16>(
cudaStream_t stream, BFloat16* output, const BFloat16* input, const int64_t* position_ids,
const BFloat16* cos_cache, const BFloat16* sin_cache, const int batch_size, const int sequence_length,
const int num_heads, const int head_size, const int rotary_embedding_dim, const int max_sequence_length,
const int position_ids_format, const bool interleaved, const int max_threads_per_block,
const bool is_input_bnsh_format);

} // namespace cuda
} // namespace onnxruntime
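For reference, RotaryEmbeddingBSNH pairs element i with a partner j and computes output[i] = x[i] * cos(theta) + sign * x[j] * sin(theta); in the non-interleaved case this is the usual y_i = x_i * cos - x_{i+d/2} * sin and y_{i+d/2} = x_{i+d/2} * cos + x_i * sin. A host-side C++ sketch of the same indexing for a single head vector (illustration only; RotateOneHead is a hypothetical helper, not part of this PR):

#include <vector>

// Reference rotation of one head vector x (length head_size).
// cos_row / sin_row point at the cache row selected by cache_offset in the kernel,
// each holding rotary_dim / 2 entries.
std::vector<float> RotateOneHead(const std::vector<float>& x,
                                 const float* cos_row, const float* sin_row,
                                 int rotary_dim, bool interleaved) {
  const int half = rotary_dim / 2;
  std::vector<float> y(x.size());
  for (int i = 0; i < static_cast<int>(x.size()); ++i) {
    if (i >= rotary_dim) {
      y[i] = x[i];  // dims beyond rotary_embedding_dim pass through unchanged
      continue;
    }
    int cache_idx, j;
    float sign;
    if (interleaved) {
      cache_idx = (i / 2) % half;
      sign = (i % 2 == 0) ? -1.0f : 1.0f;
      j = (i % 2 == 0) ? i + 1 : i - 1;
    } else {
      cache_idx = i % half;
      sign = (i < half) ? -1.0f : 1.0f;
      j = (i + half) % rotary_dim;
    }
    y[i] = x[i] * cos_row[cache_idx] + sign * x[j] * sin_row[cache_idx];
  }
  return y;
}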
51 changes: 51 additions & 0 deletions onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.h
@@ -0,0 +1,51 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/common/common.h"
#include "core/providers/cuda/shared_inc/cuda_utils.h"

namespace onnxruntime {
namespace cuda {

template <typename T>
Status LaunchRotaryEmbeddingKernel(
cudaStream_t stream,
T* output,
const T* input,
const int64_t* position_ids,
const T* cos_cache,
const T* sin_cache,
const int batch_size,
const int sequence_length,
const int num_heads,
const int head_size,
const int rotary_embedding_dim,
const int max_sequence_length,
const int position_ids_format,
const bool interleaved,
const int max_threads_per_block,
const bool is_input_bnsh_format);

template <typename T>
Status LaunchRotaryEmbeddingKernel(
cudaStream_t stream,
T* output,
const T* input,
const int64_t* position_ids,
const T* cos_cache,
const T* sin_cache,
const int batch_size,
const int sequence_length,
const int num_heads,
const int head_size,
const int rotary_embedding_dim,
const int max_sequence_length,
const int position_ids_format,
const bool interleaved,
const int max_threads_per_block,
int4 in_strides,
int4 out_strides);

} // namespace cuda
} // namespace onnxruntime