Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BN] Enable NHWC in OCL #3399

Open
wants to merge 13 commits into
base: develop
Choose a base branch
from
13 changes: 6 additions & 7 deletions src/ocl/batchnormocl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ void BatchNormForwardTraining(Handle& handle,
}();

const auto solvers = solver::SolverContainer<solver::batchnorm::BnFwdTrainingSpatialSingle,
// solver::batchnorm::BnCKFwdTraining,
solver::batchnorm::BnFwdTrainingSpatialMultiple,
solver::batchnorm::BnFwdTrainingPerActivation>{};
solver::batchnorm::BnFwdTrainingPerActivation,
solver::batchnorm::BnCKFwdTraining>{};

solvers.ExecutePrimitive(handle, problem, algo, invoke_params);

Expand Down Expand Up @@ -250,9 +250,8 @@ void BatchNormForwardInference(Handle& handle,
}();

const auto algo = AlgorithmName{"miopenBatchNormalizationForwardInference"};
const auto solvers = solver::SolverContainer<solver::batchnorm::BnFwdInference
// solver::batchnorm::BnCKFwdInference
>{};
const auto solvers = solver::SolverContainer<solver::batchnorm::BnFwdInference,
solver::batchnorm::BnCKFwdInference>{};

solvers.ExecutePrimitive(handle, problem, algo, invoke_params);
}
Expand Down Expand Up @@ -395,9 +394,9 @@ void BatchNormBackward(Handle& handle,
}();

const auto solvers = solver::SolverContainer<solver::batchnorm::BnBwdTrainingSpatialSingle,
// solver::batchnorm::BnCKBwdBackward,
solver::batchnorm::BnBwdTrainingSpatialMultiple,
solver::batchnorm::BnBwdTrainingPerActivation>{};
solver::batchnorm::BnBwdTrainingPerActivation,
solver::batchnorm::BnCKBwdBackward>{};

solvers.ExecutePrimitive(handle, problem, algo, invoke_params);

Expand Down
25 changes: 25 additions & 0 deletions src/solver/batchnorm/backward_spatial_multiple.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,34 @@ namespace solver {

namespace batchnorm {

bool BNBwdIsCaseVariant2(const miopen::batchnorm::ProblemDescription& problem)
{
int n, c, h, w;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
int n, c, h, w;
size_t n, c, h, w;

std::tie(n, c, h, w) = tien<4>(problem.GetXDesc().GetLengths());

unsigned int in_cstride = h * w;
Copy link
Contributor

@CAHEK7 CAHEK7 Nov 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But actually you can avoid the whole previous int-related comment with the following code:

Suggested change
int n, c, h, w;
std::tie(n, c, h, w) = tien<4>(problem.GetXDesc().GetLengths());
unsigned int in_cstride = h * w;
auto [n, c, h, w] = tien<4>(problem.GetXDesc().GetLengths());
auto in_cstride = problem.GetXDesc().GetStrides()[1];

unsigned int in_nhw = n * in_cstride;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
unsigned int in_cstride = h * w;
unsigned int in_nhw = n * in_cstride;
size_t in_cstride = h * w;
size_t in_nhw = n * in_cstride;


if(!(in_nhw < (32 * 1024 * 1024) && in_cstride > 1024) &&
!(in_nhw < (32 * 1024 * 1024) && in_cstride > 512) && !(in_cstride <= 512))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not as complex as the condition from src/solver/batchnorm/forward_spatial_multiple.cpp, but it can still be simplified, since it contains redundant statements.

Copy link
Contributor

@CAHEK7 CAHEK7 Nov 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One example of the transformations:

X == in_nhw < (32 * 1024 * 1024)
Y == in_cstride > 1024
Z == in_cstride > 512

!(X & Y) & !(X & Z) & !(!Z)
(!X | !Y) & (!X | !Z) & Z
(!X | !Y) & (!X & Z | !Z & Z) // !Z & Z -> false
(!X | !Y) & !X & Z
!X & Z | !Y & !X & Z
!X & Z & (1 | !Y) // (1 | !Y) -> true
!X & Z

So it basically means return (in_nhw >= (32 * 1024 * 1024)) && (in_cstride > 512);
Which probably can be simplified more, since in_nhw is n * in_cstride and we know for sure that in_cstride must be greater than 512.

If there are any doubts about those transformations, here is a proof (tested in excel, lol):
Using the fact that if Y is true, then Z must always be true, we can even exclude a few cases:

X Y Z old new result
0 0 0 FALSE FALSE TRUE
1 0 0 FALSE FALSE TRUE
0 0 1 TRUE TRUE TRUE
1 0 1 FALSE FALSE TRUE
0 1 1 TRUE TRUE TRUE
1 1 1 FALSE FALSE TRUE

{
return true;
}
else
return false;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if(!(in_nhw < (32 * 1024 * 1024) && in_cstride > 1024) &&
!(in_nhw < (32 * 1024 * 1024) && in_cstride > 512) && !(in_cstride <= 512))
{
return true;
}
else
return false;
}
return !(in_nhw < (32 * 1024 * 1024) && in_cstride > 1024) &&
!(in_nhw < (32 * 1024 * 1024) && in_cstride > 512) && !(in_cstride <= 512);


bool BnBwdTrainingSpatialMultiple::IsApplicable(
const ExecutionContext& context, const miopen::batchnorm::ProblemDescription& problem) const
{
if(!problem.IsLayoutNCHW())
return false;
// NCHW is Applicable for variant = 2 only
if(!BNBwdIsCaseVariant2(problem))
{
return false;
}

if(problem.GetDirection() != miopen::batchnorm::Direction::Backward ||
problem.GetMode() != miopenBNSpatial)
return false;
Expand Down
2 changes: 1 addition & 1 deletion src/solver/batchnorm/forward_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ namespace batchnorm {
bool BnFwdInference::IsApplicable(const ExecutionContext&,
const miopen::batchnorm::ProblemDescription& bn_problem) const
{
if(bn_problem.IsLayoutNHWC())
if(!problem.IsLayoutNCHW())
return false;
if(bn_problem.GetDirection() != miopen::batchnorm::Direction::ForwardInference)
return false;
Expand Down
33 changes: 33 additions & 0 deletions src/solver/batchnorm/forward_spatial_multiple.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,42 @@ namespace solver {

namespace batchnorm {

bool BNFwdTrainIsCaseVariant2(const miopen::batchnorm::ProblemDescription& problem)
{
const auto& xDesc = problem.GetXDesc();
int n, c, h, w;
std::tie(n, c, h, w) = tien<4>(xDesc.GetLengths());
unsigned int in_cstride = h * w;
unsigned int in_nhw = n * in_cstride;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See the comments from src/solver/batchnorm/backward_spatial_multiple.cpp

bool bfp32parm = xDesc.GetType() == miopenFloat;
bool bfpmixparm = (xDesc.GetType() == miopenHalf || xDesc.GetType() == miopenBFloat16) &&
problem.GetBnScale().GetType() == miopenFloat;

// NCHW is Applicable for variant = 2 only
if((!(n < 3) &&
!((in_nhw < 33554432 && in_cstride > 1024) ||
((n >= 256) && (in_cstride > 60) && bfpmixparm) || ((in_cstride > 512) && bfpmixparm)) &&
!(in_cstride <= 512)) ||
!((n > 768) && (in_cstride > 150) && bfp32parm))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a barely readable and probably redundant condition.
May I ask you to do the math and simplify it? (at least replace !(n < 3) with (n >= 3), but there are more simplifications possible)

{
return true;
}
else
return false;
}

bool BnFwdTrainingSpatialMultiple::IsApplicable(
const ExecutionContext& context, const miopen::batchnorm::ProblemDescription& problem) const
{
if(!problem.IsLayoutNCHW())
return false;

if(!BNFwdTrainIsCaseVariant2(problem))
{
return false;
}
// if NCHW check if variant is 2 else false (for all data type)
// update get solution to not change variant
if(problem.GetDirection() != miopen::batchnorm::Direction::ForwardTraining ||
problem.GetMode() != miopenBNSpatial)
return false;
Expand Down
8 changes: 4 additions & 4 deletions test/gtest/bn_bwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ INSTANTIATE_TEST_SUITE_P(Smoke,
INSTANTIATE_TEST_SUITE_P(Smoke,
GPU_BN_OCL_BWD_Large_FP16,
testing::Combine(testing::ValuesIn(NetworkLarge<BNTestCase>()),
testing::ValuesIn({miopenTensorNCHW}),
testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
testing::ValuesIn({testBNAPIV2})),
TestNameGenerator());

Expand All @@ -110,22 +110,22 @@ INSTANTIATE_TEST_SUITE_P(Smoke,
INSTANTIATE_TEST_SUITE_P(Smoke,
GPU_BN_OCL_BWD_Large_BFP16,
testing::Combine(testing::ValuesIn(NetworkLarge<BNTestCase>()),
testing::ValuesIn({miopenTensorNCHW}),
testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
testing::ValuesIn({testBNAPIV2})),
TestNameGenerator());

// fp32
INSTANTIATE_TEST_SUITE_P(Smoke,
GPU_BN_BWD_Small_FP32,
testing::Combine(testing::ValuesIn(NetworkSmall<BNTestCase>()),
testing::ValuesIn({miopenTensorNCHW}),
testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
testing::ValuesIn({testBNAPIV1})),
TestNameGenerator());

INSTANTIATE_TEST_SUITE_P(Smoke,
GPU_BN_BWD_Large_FP32,
testing::Combine(testing::ValuesIn(NetworkLarge<BNTestCase>()),
testing::ValuesIn({miopenTensorNCHW}),
testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
testing::ValuesIn({testBNAPIV2})),
TestNameGenerator());
// // fp64
Expand Down
8 changes: 4 additions & 4 deletions test/gtest/bn_fwd_train.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ INSTANTIATE_TEST_SUITE_P(Smoke,
INSTANTIATE_TEST_SUITE_P(Smoke,
GPU_BN_OCL_FWD_Train_Large_FP16,
testing::Combine(testing::ValuesIn(NetworkLarge<BNTestCase>()),
testing::ValuesIn({miopenTensorNCHW}),
testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
testing::ValuesIn({testBNAPIV1, testBNAPIV2})),
TestNameGenerator());

Expand All @@ -116,22 +116,22 @@ INSTANTIATE_TEST_SUITE_P(Smoke,
INSTANTIATE_TEST_SUITE_P(Smoke,
GPU_BN_OCL_FWD_Train_Large_BFP16,
testing::Combine(testing::ValuesIn(NetworkLarge<BNTestCase>()),
testing::ValuesIn({miopenTensorNCHW}),
testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
testing::ValuesIn({testBNAPIV1, testBNAPIV2})),
TestNameGenerator());

// // fp32
INSTANTIATE_TEST_SUITE_P(Smoke,
GPU_BN_FWD_Train_Small_FP32,
testing::Combine(testing::ValuesIn(NetworkSmall<BNTestCase>()),
testing::ValuesIn({miopenTensorNCHW}),
testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
testing::ValuesIn({testBNAPIV1})),
TestNameGenerator());

INSTANTIATE_TEST_SUITE_P(Smoke,
GPU_BN_FWD_Train_Large_FP32,
testing::Combine(testing::ValuesIn(NetworkLarge<BNTestCase>()),
testing::ValuesIn({miopenTensorNCHW}),
testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
testing::ValuesIn({testBNAPIV2})),
TestNameGenerator());
// // fp64
Expand Down
8 changes: 4 additions & 4 deletions test/gtest/bn_infer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ INSTANTIATE_TEST_SUITE_P(Smoke,
INSTANTIATE_TEST_SUITE_P(Smoke,
GPU_BN_OCL_Infer_Large_FP16,
testing::Combine(testing::ValuesIn(NetworkLarge<BNTestCase>()),
testing::ValuesIn({miopenTensorNCHW}),
testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
testing::ValuesIn({testBNAPIV1, testBNAPIV2})),
TestNameGenerator());
// bfp16
Expand All @@ -116,22 +116,22 @@ INSTANTIATE_TEST_SUITE_P(Smoke,
INSTANTIATE_TEST_SUITE_P(Smoke,
GPU_BN_OCL_Infer_Large_BFP16,
testing::Combine(testing::ValuesIn(NetworkLarge<BNTestCase>()),
testing::ValuesIn({miopenTensorNCHW}),
testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
testing::ValuesIn({testBNAPIV1, testBNAPIV2})),
TestNameGenerator());

// fp32
INSTANTIATE_TEST_SUITE_P(Smoke,
GPU_BN_Infer_Small_FP32,
testing::Combine(testing::ValuesIn(NetworkLarge<BNTestCase>()),
testing::ValuesIn({miopenTensorNCHW}),
testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
testing::ValuesIn({testBNAPIV1})),
TestNameGenerator());

INSTANTIATE_TEST_SUITE_P(Smoke,
GPU_BN_Infer_Large_FP32,
testing::Combine(testing::ValuesIn(NetworkSmall<BNTestCase>()),
testing::ValuesIn({miopenTensorNCHW}),
testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
testing::ValuesIn({testBNAPIV2})),
TestNameGenerator());
// fp64
Expand Down
Loading