Skip to content

Commit

Permalink
Op2dTensorGeneric kernel upgrade (#3305)
Browse files Browse the repository at this point in the history
  • Loading branch information
novakovicdj authored Oct 14, 2024
1 parent 16773ac commit 31ba13d
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 610 deletions.
80 changes: 0 additions & 80 deletions src/kernels/MIOpenTensorKernels.cl
Original file line number Diff line number Diff line change
Expand Up @@ -1105,86 +1105,6 @@ __kernel void Op2dTensorSquash(const global MIOPEN_TYPE* a,
}
#endif

#ifdef USE_2D_TENSOR_GENERIC
// NC
// Generic 2D (N x C) element-wise tensor operation:
//   c = MIOPEN_TENSOR_OP(a * alpha0, b * alpha1) [+ beta * c]
// with b broadcast over the dimensions selected by `bitmap`.
//
// Work decomposition: each logical workgroup id maps to exactly one element
// of b; the threads of that group then stride over the `work_per_wg` output
// elements of c that share that b value.
//
// Parameters:
//   a, b, c       - tensor buffers (a and b inputs, c output), NC layout
//   *_nstride     - element stride between consecutive N rows of each tensor
//   b_c, c_c      - size of the C dimension of b and c
//   alpha0/alpha1 - scale factors applied to a and b respectively
//   beta          - blend factor for the existing contents of c
//   bitmap        - broadcast selector; from the index math below, bit 0 set
//                   appears to mean "b spans C" and bit 1 set "b spans N"
//                   (i.e. that dimension is taken from the group id rather
//                   than derived from the thread id) — confirm against the
//                   host-side bitmap construction in tensorocl.cpp
//   work_per_wg   - number of c elements each workgroup must produce
//   Aoffset/Boffset/Coffset - element offsets into the raw buffers
//   num_wg        - logical number of workgroups required to cover b; the
//                   actual launch uses at most MAX_NUM_WG groups, so each
//                   launched group loops over several logical ids
__kernel void Op2dTensorGeneric(global MIOPEN_TYPE* a,
const int a_nstride,
global MIOPEN_TYPE* b,
const int b_c,
const int b_nstride,
global MIOPEN_TYPE* c,
const int c_c,
const int c_nstride,
const MIOPEN_TYPE alpha0,
const MIOPEN_TYPE alpha1,
const MIOPEN_TYPE beta,
const unsigned int bitmap,
const int work_per_wg,
const long Aoffset,
const long Boffset,
const long Coffset,
const int num_wg)
{
int gid = get_group_id(0);

// Apply the caller-supplied element offsets once up front.
global MIOPEN_TYPE* a_off = a + Aoffset;
global MIOPEN_TYPE* b_off = b + Boffset;
global MIOPEN_TYPE* c_off = c + Coffset;

// Divisor used to recover the N index from the flat thread index `lid` when
// bitmap bit 1 is clear: if bit 0 is set, `lid` enumerates N directly
// (divide by 1); otherwise `lid` enumerates N*C and must be divided by c_c.
int o_n_div = bitmap & (1 << 0) ? 1 : c_c;

// num_wg: the number of workgroups that would be needed to cover all of b.
// MAX_NUM_WG: the maximum number of workgroups actually launched; each
// launched group therefore strides over logical group ids in steps of it.
if(beta == (MIOPEN_TYPE)0)
{
// Fast path: beta == 0, so c is written without ever being read.
for(; gid < num_wg; gid += MAX_NUM_WG)
{

int lid = get_local_id(0);
// Decompose the logical group id into the (N, C) coordinates of the
// single b element this group is responsible for.
int o_c_gid_off = gid % b_c;
int o_n_gid_off = gid / b_c;

int bindex = o_n_gid_off * b_nstride + o_c_gid_off;
// Pre-scaled b value shared by every output element this group produces.
MIOPEN_TYPE operand = b_off[bindex] * alpha1;

while(lid < work_per_wg)
{
// For each output dimension: take the coordinate from the group id when
// the corresponding bitmap bit is set, otherwise derive it from `lid`.
int o_c = (bitmap & (1 << 0)) ? o_c_gid_off : lid % c_c;
int o_n = (bitmap & (1 << 1)) ? o_n_gid_off : lid / o_n_div;
int aindex = o_n * a_nstride + o_c;
int cindex = o_n * c_nstride + o_c;
c_off[cindex] = MIOPEN_TENSOR_OP(a_off[aindex] * alpha0, operand);
lid += get_local_size(0);
}
}
}
else
{
// General path: identical indexing, but blends beta * (old c) into the
// result, so c is read before being overwritten.
for(; gid < num_wg; gid += MAX_NUM_WG)
{
int lid = get_local_id(0);
int o_c_gid_off = gid % b_c;
int o_n_gid_off = gid / b_c;

int bindex = o_n_gid_off * b_nstride + o_c_gid_off;
MIOPEN_TYPE operand = b_off[bindex] * alpha1;

while(lid < work_per_wg)
{
int o_c = (bitmap & (1 << 0)) ? o_c_gid_off : lid % c_c;
int o_n = (bitmap & (1 << 1)) ? o_n_gid_off : lid / o_n_div;
int aindex = o_n * a_nstride + o_c;
int cindex = o_n * c_nstride + o_c;
c_off[cindex] =
MIOPEN_TENSOR_OP(a_off[aindex] * alpha0, operand) + beta * c_off[cindex];
lid += get_local_size(0);
}
}
}
}

#endif

#ifdef USE_4D_TENSOR_LITE
// N - batch size
// C - # of maps
Expand Down
109 changes: 41 additions & 68 deletions src/kernels/MIOpenTensorKernelsHip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
* SOFTWARE.
*
*******************************************************************************/

#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
Expand Down Expand Up @@ -101,82 +100,56 @@ extern "C" __global__ void Op1dTensorGeneric(const MIOPEN_TYPE* a,

#ifdef USE_2D_TENSOR_GENERIC
// NC
extern "C" __global__ void Op2dTensorGeneric(MIOPEN_TYPE* a,
const int a_nstride,
MIOPEN_TYPE* b,
const int b_c,
const int b_nstride,
extern "C" __global__ void Op2dTensorGeneric(const MIOPEN_TYPE* a,
const MIOPEN_TYPE* b,
MIOPEN_TYPE* c,
const int c_c,
const int c_nstride,
const MIOPEN_TYPE alpha0,
const MIOPEN_TYPE alpha1,
const MIOPEN_TYPE beta,
const unsigned int bitmap,
const int work_per_wg,
const long Aoffset,
const long Boffset,
const long Coffset,
const int num_wg)
const uint32_t b_c,
const uint32_t c_c,
const uint32_t a_nstride,
const uint32_t a_cstride,
const uint32_t b_nstride,
const uint32_t b_cstride,
const uint32_t c_nstride,
const uint32_t c_cstride,
const MIOPEN_TYPE alpha0,
const MIOPEN_TYPE alpha1,
const MIOPEN_TYPE beta,
const uint32_t total_work,
const bool use_beta)
{
int gid = blockIdx.x;
const MIOPEN_TYPE* a_off = a + Aoffset;
const MIOPEN_TYPE* b_off = b + Boffset;
MIOPEN_TYPE* c_off = c + Coffset;

MIOPEN_TYPE* a_off = a + Aoffset;
MIOPEN_TYPE* b_off = b + Boffset;
MIOPEN_TYPE* c_off = c + Coffset;
auto gid = blockIdx.x * blockDim.x + threadIdx.x;
const auto* a_ptr = a_off + (gid / c_c) * a_nstride + (gid % c_c) * a_cstride;
auto* c_ptr = c_off + (gid / c_c) * c_nstride + (gid % c_c) * c_cstride;

int o_n_div = (bitmap & (1 << 0)) ? 1 : c_c;
const auto step = gridDim.x * blockDim.x;
const auto a_step = (step / c_c) * a_nstride + (step % c_c) * a_cstride;
const auto c_step = (step / c_c) * c_nstride + (step % c_c) * c_cstride;

// num_wg: the number of workgroups should be launched
// MAX_NUM_WG: the maximum number of workgroups actually launched
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wfloat-equal"
if(beta == static_cast<MIOPEN_TYPE>(0))
#pragma clang diagnostic pop
{
for(; gid < num_wg; gid += MAX_NUM_WG)
{

int lid = threadIdx.x;
int o_c_gid_off = gid % b_c;
int o_n_gid_off = gid / b_c;

int bindex = o_n_gid_off * b_nstride + o_c_gid_off;
MIOPEN_TYPE operand = b_off[bindex] * alpha1;

while(lid < work_per_wg)
{
int o_c = (bitmap & (1 << 0)) ? o_c_gid_off : lid % c_c;
int o_n = (bitmap & (1 << 1)) ? o_n_gid_off : lid / o_n_div;
int aindex = o_n * a_nstride + o_c;
int cindex = o_n * c_nstride + o_c;
c_off[cindex] = MIOPEN_TENSOR_OP(a_off[aindex] * alpha0, operand);
lid += blockDim.x;
}
}
}
else
const auto c_end = c_off + total_work * c_nstride;
while(c_ptr < c_end)
{
for(; gid < num_wg; gid += MAX_NUM_WG)
{
int lid = threadIdx.x;
int o_c_gid_off = gid % b_c;
int o_n_gid_off = gid / b_c;

int bindex = o_n_gid_off * b_nstride + o_c_gid_off;
MIOPEN_TYPE operand = b_off[bindex] * alpha1;

while(lid < work_per_wg)
{
int o_c = (bitmap & (1 << 0)) ? o_c_gid_off : lid % c_c;
int o_n = (bitmap & (1 << 1)) ? o_n_gid_off : lid / o_n_div;
int aindex = o_n * a_nstride + o_c;
int cindex = o_n * c_nstride + o_c;
c_off[cindex] =
MIOPEN_TENSOR_OP(a_off[aindex] * alpha0, operand) + beta * c_off[cindex];
lid += blockDim.x;
}
}
const auto* b_ptr = b_off;
if(b_nstride != 0)
b_ptr += (gid / b_c) * b_nstride;

if(b_cstride != 0)
b_ptr += (gid % b_c) * b_cstride;

auto b_val = *b_ptr;
auto a_val = *a_ptr;
auto c_val = use_beta ? *c_ptr : static_cast<MIOPEN_TYPE>(0);
*c_ptr = MIOPEN_TENSOR_OP(b_val * alpha1, a_val * alpha0) + c_val * beta;

a_ptr += a_step;
c_ptr += c_step;
gid += step;
}
}

Expand Down
66 changes: 38 additions & 28 deletions src/ocl/tensorocl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -977,7 +977,7 @@ void OpTensorOther(const Handle& handle,
const bool case_2d = bsize == 2;
const bool case_5d = bsize == 5;

const bool use_hip = case_1d;
const bool use_hip = case_1d || case_2d;

// first_not_one is incorrect if btensor size equal to 1
auto first_not_one = std::find_if(blens.rbegin(), blens.rend(), [](int i) { return i != 1; });
Expand Down Expand Up @@ -1005,22 +1005,28 @@ void OpTensorOther(const Handle& handle,

int num_wg_orig = num_wg;
int max_num_wg = 4096;
num_wg = num_wg > max_num_wg ? max_num_wg : num_wg;

size_t local_threads = 256;

std::string program_name = use_hip ? "MIOpenTensorKernelsHip.cpp" : "MIOpenTensorKernels.cl";
if(case_2d)
local_threads = 32;

if(case_1d)
num_wg = std::clamp(clens[0] / local_threads, size_t(1), size_t(max_num_wg));
if(case_2d)
num_wg = std::clamp((clens[0] * clens[1]) / local_threads, size_t(1), size_t(max_num_wg));
num_wg = num_wg > max_num_wg ? max_num_wg : num_wg;

const std::vector<size_t> vld{local_threads, 1, 1};

// Special case for adding tensors in place
size_t global_threads;
global_threads =
(case_1d ? std::clamp(clens[0] / local_threads, size_t(1), size_t(max_num_wg)) : num_wg) *
local_threads;
global_threads = num_wg * local_threads;

const std::vector<size_t> vgd{global_threads, 1, 1};

std::string program_name = use_hip ? "MIOpenTensorKernelsHip.cpp" : "MIOpenTensorKernels.cl";

std::string network_config{};
network_config += std::to_string(bTensorDesc.GetType()) + "-" +
std::to_string(aTensorDesc.GetType()) + "-" + std::to_string(tensorOp) + "-" +
Expand Down Expand Up @@ -1081,22 +1087,24 @@ void OpTensorOther(const Handle& handle,
{
auto kernel = kernels.front();
kernel(ATensor,
static_cast<int>(astrides[0]),
BTensor,
static_cast<int>(blens[1]),
static_cast<int>(bstrides[0]),
CTensor,
static_cast<int>(clens[1]),
static_cast<int>(cstrides[0]),
static_cast<long>(Aoffset),
static_cast<long>(Boffset),
static_cast<long>(Coffset),
static_cast<uint32_t>(blens[1] == 1 ? clens[1] : blens[1]),
static_cast<uint32_t>(clens[1]),
static_cast<uint32_t>(astrides[0]),
static_cast<uint32_t>(astrides[1]),
static_cast<uint32_t>(blens[0] == 1 ? 0 : bstrides[0]),
static_cast<uint32_t>(blens[1] == 1 ? 0 : bstrides[1]),
static_cast<uint32_t>(cstrides[0]),
static_cast<uint32_t>(cstrides[1]),
miopen_alpha0,
miopen_alpha1,
miopen_beta,
bitmap,
work_per_wg,
static_cast<int64_t>(Aoffset),
static_cast<int64_t>(Boffset),
static_cast<int64_t>(Coffset),
static_cast<int>(num_wg_orig));
static_cast<uint32_t>(clens[0]),
!float_equal(miopen_beta, 0.0));
return;
}
}
Expand Down Expand Up @@ -1194,22 +1202,24 @@ void OpTensorOther(const Handle& handle,
vld,
vgd,
parms)(ATensor,
static_cast<int>(astrides[0]),
BTensor,
static_cast<int>(blens[1]),
static_cast<int>(bstrides[0]),
CTensor,
static_cast<int>(clens[1]),
static_cast<int>(cstrides[0]),
static_cast<long>(Aoffset),
static_cast<long>(Boffset),
static_cast<long>(Coffset),
static_cast<uint32_t>(blens[1] == 1 ? clens[1] : blens[1]),
static_cast<uint32_t>(clens[1]),
static_cast<uint32_t>(astrides[0]),
static_cast<uint32_t>(astrides[1]),
static_cast<uint32_t>(blens[0] == 1 ? 0 : bstrides[0]),
static_cast<uint32_t>(blens[1] == 1 ? 0 : bstrides[1]),
static_cast<uint32_t>(cstrides[0]),
static_cast<uint32_t>(cstrides[1]),
miopen_alpha0,
miopen_alpha1,
miopen_beta,
bitmap,
work_per_wg,
static_cast<int64_t>(Aoffset),
static_cast<int64_t>(Boffset),
static_cast<int64_t>(Coffset),
static_cast<int>(num_wg_orig));
static_cast<uint32_t>(clens[0]),
!float_equal(miopen_beta, 0.0));
}
else if(case_1d)
{
Expand Down
Loading

0 comments on commit 31ba13d

Please sign in to comment.