Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding 3d bn gtest #3385

Open
wants to merge 8 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 30 additions & 20 deletions test/gtest/bn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ static std::string LayoutToString(int tensor_format)
switch(tensor_format)
{
case miopenTensorNCHW: return "NCHW";
case miopenTensorNCDHW: return "NCDHW";
case miopenTensorNHWC: return "NHWC";
case miopenTensorNDHWC: return "NDHWC";
default: return "UnknownTensorFormat";
}
}
Expand All @@ -61,34 +63,40 @@ static std::string ApiVerisonToString(int api_version)
}

// Custom test name generator to handle enums
template <typename TestCase>
struct TestNameGenerator
{
std::string operator()(
const testing::TestParamInfo<std::tuple<BNTestCase, miopenTensorLayout_t, BNApiType>>& info)
const testing::TestParamInfo<std::tuple<TestCase, miopenTensorLayout_t, BNApiType>>& info)
const
{
std::string dimension = std::is_same<TestCase, BN2DTestCase>::value ? "2D"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should start enforcing invalid types with compile-time failures (not just here, but for all of our tests). You could use std::enable_if, but it's so ugly...here's a cleaner idea:

constexpr int dimension = std::is_same<TestCase, BN2DTestCase>::value ? 2
            : std::is_same<TestCase, BN3DTestCase>::value ? 3
            : -1;
static_assert(dimension > 0);
<snip>
std::ostringstream oss;
oss << tensor_name + "_" + api_name + dimension + "D_" + info.index;
return oss.str();

: std::is_same<TestCase, BN3DTestCase>::value ? "3D"
: "Unknown";

const auto& layout_type = std::get<1>(info.param);
const auto& api_type = std::get<2>(info.param);

std::string tensor_name = LayoutToString(layout_type);
std::string api_name = ApiVerisonToString(api_type);

return tensor_name + "_" + api_name + "_" + std::to_string(info.index);
return tensor_name + "_" + api_name + dimension + "_" + std::to_string(info.index);
}
};

template <typename XDataType,
typename YDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType>
typename MeanVarDataType,
typename TestCase>
struct BNInferTest
: public ::testing::TestWithParam<std::tuple<BNTestCase, miopenTensorLayout_t, BNApiType>>
: public ::testing::TestWithParam<std::tuple<TestCase, miopenTensorLayout_t, BNApiType>>
{
protected:
void SetUp() override
{
std::tie(bn_config, tensor_layout, api_type) = GetParam();
std::tie(bn_config, tensor_layout, api_type) = this->GetParam();
bn_infer_test_data.SetUpImpl(bn_config, tensor_layout);

auto&& handle = get_handle();
Expand Down Expand Up @@ -150,7 +158,7 @@ struct BNInferTest

void TearDown() override
{
if(test_skipped || Test::HasFailure())
if(test_skipped || ::testing::Test::HasFailure())
{
return;
}
Expand All @@ -163,9 +171,9 @@ struct BNInferTest
test::CompareTensor<YDataType>(bn_infer_test_data.output, bn_infer_test_data.ref_out, 4e-3);
}

BNTestCase bn_config;
TestCase bn_config;
bool test_skipped = false;
BNInferTestData<XDataType, YDataType, ScaleDataType, BiasDataType, MeanVarDataType, BNTestCase>
BNInferTestData<XDataType, YDataType, ScaleDataType, BiasDataType, MeanVarDataType, TestCase>
bn_infer_test_data;
miopenTensorLayout_t tensor_layout;
BNApiType api_type;
Expand All @@ -177,14 +185,15 @@ template <typename XDataType,
typename AccDataType,
typename ScaleDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType>
typename MeanVarDataType,
typename TestCase>
struct BNBwdTest
: public ::testing::TestWithParam<std::tuple<BNTestCase, miopenTensorLayout_t, BNApiType>>
: public ::testing::TestWithParam<std::tuple<TestCase, miopenTensorLayout_t, BNApiType>>
{
protected:
void SetUp() override
{
std::tie(bn_config, tensor_layout, api_type) = GetParam();
std::tie(bn_config, tensor_layout, api_type) = this->GetParam();
bn_bwd_test_data.SetUpImpl(bn_config, tensor_layout);

auto&& handle = get_handle();
Expand Down Expand Up @@ -255,7 +264,7 @@ struct BNBwdTest

void TearDown() override
{
if(test_skipped || Test::HasFailure())
if(test_skipped || ::testing::Test::HasFailure())
{
return;
}
Expand All @@ -277,7 +286,7 @@ struct BNBwdTest
bn_bwd_test_data.dBias, bn_bwd_test_data.dBias_ref, bwd_tol);
}

BNTestCase bn_config;
TestCase bn_config;
bool test_skipped = false;
BNBwdTestData<XDataType,
DxDataType,
Expand All @@ -286,7 +295,7 @@ struct BNBwdTest
ScaleDataType,
DscaleDbiasDataType,
MeanVarDataType,
BNTestCase>
TestCase>
bn_bwd_test_data;
miopenTensorLayout_t tensor_layout;
BNApiType api_type;
Expand All @@ -297,14 +306,15 @@ template <typename XDataType,
typename YDataType,
typename ScaleDataType,
typename BiasDataType,
typename AccDataType>
typename AccDataType,
typename TestCase>
struct BNFwdTrainTest
: public ::testing::TestWithParam<std::tuple<BNTestCase, miopenTensorLayout_t, BNApiType>>
: public ::testing::TestWithParam<std::tuple<TestCase, miopenTensorLayout_t, BNApiType>>
{
protected:
void SetUp() override
{
std::tie(bn_config, tensor_layout, api_type) = GetParam();
std::tie(bn_config, tensor_layout, api_type) = this->GetParam();
bn_fwd_train_test_data.SetUpImpl(bn_config, tensor_layout);

auto&& handle = get_handle();
Expand Down Expand Up @@ -379,7 +389,7 @@ struct BNFwdTrainTest

void TearDown() override
{
if(test_skipped || Test::HasFailure())
if(test_skipped || ::testing::Test::HasFailure())
{
return;
}
Expand Down Expand Up @@ -413,9 +423,9 @@ struct BNFwdTrainTest
bn_fwd_train_test_data.runVariance, bn_fwd_train_test_data.runVariance_ref, 4e-3);
}

BNTestCase bn_config;
TestCase bn_config;
bool test_skipped = false;
BNFwdTrainTestData<XDataType, YDataType, ScaleDataType, BiasDataType, AccDataType, BNTestCase>
BNFwdTrainTestData<XDataType, YDataType, ScaleDataType, BiasDataType, AccDataType, TestCase>
bn_fwd_train_test_data;
miopenTensorLayout_t tensor_layout;
BNApiType api_type;
Expand Down
153 changes: 104 additions & 49 deletions test/gtest/bn_bwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,112 +33,167 @@
typename DscaleDbiasDataType,
typename MeanVarDataType> */

struct GPU_BN_CK_BWD_Large_FP16
: BNBwdTest<half_float::half, float, float, float, half_float::half, float, float>
struct GPU_BNBWDSmall_FP32
: BNBwdTest<half_float::half, float, float, float, half_float::half, float, float, BN2DTestCase>
{
};

struct GPU_BN_OCL_BWD_Large_FP16
: BNBwdTest<half_float::half, half_float::half, half_float::half, float, float, float, float>
struct GPU_BNOCLBWDLarge2D_FP16 : BNBwdTest<half_float::half,
half_float::half,
half_float::half,
float,
float,
float,
float,
BN2DTestCase>
{
};

struct GPU_BN_CK_BWD_Large_BFP16 : BNBwdTest<bfloat16, float, float, float, bfloat16, float, float>
struct GPU_BNOCLBWDLarge3D_FP16 : BNBwdTest<half_float::half,
half_float::half,
half_float::half,
float,
float,
float,
float,
BN3DTestCase>
{
};

struct GPU_BN_OCL_BWD_Large_BFP16
: BNBwdTest<bfloat16, bfloat16, bfloat16, float, float, float, float>
struct GPU_BNCKBWDLarge2D_BFP16
: BNBwdTest<bfloat16, float, float, float, bfloat16, float, float, BN2DTestCase>
{
};

struct GPU_BN_BWD_Small_FP32 : BNBwdTest<float, float, float, float, float, float, float>
struct GPU_BNOCLBWDLarge2D_BFP16
: BNBwdTest<bfloat16, bfloat16, bfloat16, float, float, float, float, BN2DTestCase>
{
};

struct GPU_BN_BWD_Large_FP32 : BNBwdTest<float, float, float, float, float, float, float>
struct GPU_BNOCLBWDLarge3D_BFP16
: BNBwdTest<bfloat16, bfloat16, bfloat16, float, float, float, float, BN3DTestCase>
{
};

struct GPU_BN_BWD_Small_FP64 : BNBwdTest<double, double, double, double, double, double, double>
struct GPU_BNBWDSmall2D_FP32
: BNBwdTest<float, float, float, float, float, float, float, BN2DTestCase>
{
};

struct GPU_BN_BWD_Large_FP64 : BNBwdTest<double, double, double, double, double, double, double>
struct GPU_BNBWDLarge2D_FP32
: BNBwdTest<float, float, float, float, float, float, float, BN2DTestCase>
{
};

struct GPU_BNBWDLarge3D_FP32
: BNBwdTest<float, float, float, float, float, float, float, BN3DTestCase>
{
};

struct GPU_BNBWDSmall2D_FP64
: BNBwdTest<double, double, double, double, double, double, double, BN2DTestCase>
{
};

struct GPU_BNBWDLarge2D_FP64
: BNBwdTest<double, double, double, double, double, double, double, BN2DTestCase>
{
};

// fp16
TEST_P(GPU_BN_CK_BWD_Large_FP16, DISABLED_BnV2LargeBWDCKfp16) {}
TEST_P(GPU_BN_OCL_BWD_Large_FP16, BnV2LargeBWDOCLfp16) {}
TEST_P(GPU_BNBWDSmall_FP32, DISABLED_BnV2LargeBWDCK2D_fp16) {}
TEST_P(GPU_BNOCLBWDLarge2D_FP16, BnV2LargeBWDOCL2D_fp16) {}
TEST_P(GPU_BNOCLBWDLarge3D_FP16, BnV2LargeBWDOCL3D_fp16) {}

// bfp16
TEST_P(GPU_BN_CK_BWD_Large_BFP16, DISABLED_BnV2LargeBWDCKbfp16) {}
TEST_P(GPU_BN_OCL_BWD_Large_BFP16, BnV2LargeBWDOCLbfp16) {}
// // // bfp16
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extra //'s

TEST_P(GPU_BNCKBWDLarge2D_BFP16, DISABLED_BnV2LargeBWDCKbfp16_2D) {}
TEST_P(GPU_BNOCLBWDLarge2D_BFP16, BnV2LargeBWDOCLbfp16_2D) {}
TEST_P(GPU_BNOCLBWDLarge3D_BFP16, BnV2LargeBWDOCLbfp16_3D) {}

// fp32 (float)
TEST_P(GPU_BN_BWD_Small_FP32, BnV1SmallBWDCKfp32) {}
TEST_P(GPU_BN_BWD_Large_FP32, BnV2LargeBWDCKfp32) {}
TEST_P(GPU_BNBWDSmall2D_FP32, BnV1SmallBWDCKfp32_2D) {}
TEST_P(GPU_BNBWDLarge2D_FP32, BnV2LargeBWDCKfp32_2D) {}
TEST_P(GPU_BNBWDLarge3D_FP32, BnV2LargeBWDCKfp32_3D) {}

// fp64
TEST_P(GPU_BN_BWD_Small_FP64, DISABLED_BnV1SmallBWDCKfp64) {}
TEST_P(GPU_BN_BWD_Large_FP64, DISABLED_BnV2LargeBWDCKfp64) {}
TEST_P(GPU_BNBWDSmall2D_FP64, DISABLED_BnV1SmallBWDCKfp64_2D) {}
TEST_P(GPU_BNBWDLarge2D_FP64, DISABLED_BnV2LargeBWDCKfp64_2D) {}

// fp16
// // fp16
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extra //'s

INSTANTIATE_TEST_SUITE_P(Smoke,
GPU_BN_CK_BWD_Large_FP16,
testing::Combine(testing::ValuesIn(NetworkSmall<BNTestCase>()),
GPU_BNBWDSmall_FP32,
testing::Combine(testing::ValuesIn(Network2DSmall<BN2DTestCase>()),
testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
testing::ValuesIn({testBNAPIV2})),
TestNameGenerator());
TestNameGenerator<BN2DTestCase>());

INSTANTIATE_TEST_SUITE_P(Smoke,
GPU_BN_OCL_BWD_Large_FP16,
testing::Combine(testing::ValuesIn(NetworkLarge<BNTestCase>()),
GPU_BNOCLBWDLarge2D_FP16,
testing::Combine(testing::ValuesIn(Network2DLarge<BN2DTestCase>()),
testing::ValuesIn({miopenTensorNCHW}),
testing::ValuesIn({testBNAPIV2})),
TestNameGenerator());
TestNameGenerator<BN2DTestCase>());

// bfp16
INSTANTIATE_TEST_SUITE_P(Smoke,
GPU_BN_CK_BWD_Large_BFP16,
testing::Combine(testing::ValuesIn(NetworkLarge<BNTestCase>()),
GPU_BNOCLBWDLarge3D_FP16,
testing::Combine(testing::ValuesIn(Network3DBN<BN3DTestCase>()),
testing::ValuesIn({miopenTensorNCDHW}),
testing::ValuesIn({testBNAPIV2})),
TestNameGenerator<BN3DTestCase>());

// // bfp16
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extra //'s

INSTANTIATE_TEST_SUITE_P(Smoke,
GPU_BNCKBWDLarge2D_BFP16,
testing::Combine(testing::ValuesIn(Network2DLarge<BN2DTestCase>()),
testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
testing::ValuesIn({testBNAPIV2})),
TestNameGenerator());
TestNameGenerator<BN2DTestCase>());

INSTANTIATE_TEST_SUITE_P(Smoke,
GPU_BN_OCL_BWD_Large_BFP16,
testing::Combine(testing::ValuesIn(NetworkLarge<BNTestCase>()),
GPU_BNOCLBWDLarge2D_BFP16,
testing::Combine(testing::ValuesIn(Network2DLarge<BN2DTestCase>()),
testing::ValuesIn({miopenTensorNCHW}),
testing::ValuesIn({testBNAPIV2})),
TestNameGenerator());
TestNameGenerator<BN2DTestCase>());

INSTANTIATE_TEST_SUITE_P(Smoke,
GPU_BNOCLBWDLarge3D_BFP16,
testing::Combine(testing::ValuesIn(Network3DBN<BN3DTestCase>()),
testing::ValuesIn({miopenTensorNCDHW}),
testing::ValuesIn({testBNAPIV2})),
TestNameGenerator<BN3DTestCase>());

// fp32
// // fp32
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extra //'s

INSTANTIATE_TEST_SUITE_P(Smoke,
GPU_BN_BWD_Small_FP32,
testing::Combine(testing::ValuesIn(NetworkSmall<BNTestCase>()),
GPU_BNBWDSmall2D_FP32,
testing::Combine(testing::ValuesIn(Network2DSmall<BN2DTestCase>()),
testing::ValuesIn({miopenTensorNCHW}),
testing::ValuesIn({testBNAPIV1})),
TestNameGenerator());
TestNameGenerator<BN2DTestCase>());

INSTANTIATE_TEST_SUITE_P(Smoke,
GPU_BN_BWD_Large_FP32,
testing::Combine(testing::ValuesIn(NetworkLarge<BNTestCase>()),
GPU_BNBWDLarge2D_FP32,
testing::Combine(testing::ValuesIn(Network2DLarge<BN2DTestCase>()),
testing::ValuesIn({miopenTensorNCHW}),
testing::ValuesIn({testBNAPIV2})),
TestNameGenerator());
// // fp64
TestNameGenerator<BN2DTestCase>());
INSTANTIATE_TEST_SUITE_P(Smoke,
GPU_BN_BWD_Small_FP64,
testing::Combine(testing::ValuesIn(NetworkSmall<BNTestCase>()),
GPU_BNBWDLarge3D_FP32,
testing::Combine(testing::ValuesIn(Network3DBN<BN3DTestCase>()),
testing::ValuesIn({miopenTensorNCDHW}),
testing::ValuesIn({testBNAPIV2})),
TestNameGenerator<BN3DTestCase>());
// fp64
INSTANTIATE_TEST_SUITE_P(Smoke,
GPU_BNBWDSmall2D_FP64,
testing::Combine(testing::ValuesIn(Network2DSmall<BN2DTestCase>()),
testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
testing::ValuesIn({testBNAPIV1})),
TestNameGenerator());

TestNameGenerator<BN2DTestCase>());
//
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extra //'s

INSTANTIATE_TEST_SUITE_P(Smoke,
GPU_BN_BWD_Large_FP64,
testing::Combine(testing::ValuesIn(NetworkLarge<BNTestCase>()),
GPU_BNBWDLarge2D_FP64,
testing::Combine(testing::ValuesIn(Network2DLarge<BN2DTestCase>()),
testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
testing::ValuesIn({testBNAPIV2})),
TestNameGenerator());
TestNameGenerator<BN2DTestCase>());
Loading
Loading