diff --git a/driver/bn_driver.hpp b/driver/bn_driver.hpp
index bfa1b91aef..e56680b806 100644
--- a/driver/bn_driver.hpp
+++ b/driver/bn_driver.hpp
@@ -195,10 +195,10 @@ int BatchNormDriver::GetandSetData()
     SetBNParametersFromCmdLineArgs();
 
     in.AllocOnHost(tensor{bn_layout, in_len});
-    for(size_t i = 0; i < in.GetVector().size(); i++)
-    {
-        in.GetVector()[i] = prng::gen_canonical();
-    }
+    // 0.0 to 2.0 (since unsigned)
+    in.GetTensor().generate([](auto...) {
+        return prng::gen_descreet_unsigned(2e-3 /*scale*/, 1000 /*range*/);
+    });
 
     auto derivedBnDesc = miopen::TensorDescriptor{};
     miopen::DeriveBNTensorDescriptor(derivedBnDesc, in.GetTensor().desc, bn_mode);
@@ -208,20 +208,25 @@
         out.AllocOnHost(tensor{bn_layout, in_len});
         scale.AllocOnHost(tensor{bn_layout, derivedBnDesc.GetLengths()});
         bias.AllocOnHost(tensor{bn_layout, derivedBnDesc.GetLengths()});
-
-        for(int i = 0; i < scale.GetVector().size(); i++)
-        {
-            scale.GetVector()[i] = prng::gen_canonical();
-            bias.GetVector()[i]  = prng::gen_canonical();
-        }
+        // -2.0 to 2.0
+        scale.GetTensor().generate([](auto...) {
+            return prng::gen_descreet_uniform_sign(2e-3 /*scale*/, 1000 /*range*/);
+        });
+        bias.GetTensor().generate([](auto...) {
+            return prng::gen_descreet_uniform_sign(2e-3 /*scale*/, 1000 /*range*/);
+        });
     }
 
     if(isFwdInfer)
     {
         estMean.AllocOnHost(tensor{bn_layout, derivedBnDesc.GetLengths()});
         estVariance.AllocOnHost(tensor{bn_layout, derivedBnDesc.GetLengths()});
-        auto gen_value_emean = [](auto...) { return prng::gen_descreet_unsigned(1e-2, 100); };
-        estMean.InitHostData(estMean.GetTensor().desc.GetElementSize(), true, gen_value_emean);
+        // 0.0 to 1.0
+        estMean.InitHostData(estMean.GetTensor().desc.GetElementSize(), true, [](auto...) {
+            return prng::gen_descreet_uniform_sign(2e-3 /*scale*/, 1000 /*range*/);
+        });
+        estVariance.GetTensor().generate(
+            [](auto...) { return static_cast(2e-3 * (prng::gen_0_to_B(1000) + 1)); });
     }
     else if(isFwdTrain)
     {
@@ -230,11 +235,13 @@
         runMean.AllocOnHost(tensor{bn_layout, derivedBnDesc.GetLengths()});
         runVariance.AllocOnHost(tensor{bn_layout, derivedBnDesc.GetLengths()});
 
-        for(int i = 0; i < runVariance.GetVector().size(); i++)
-        {
-            runMean.GetVector()[i]     = prng::gen_canonical();
-            runVariance.GetVector()[i] = prng::gen_canonical();
-        }
+        // -2.0 to 2.0
+        runMean.GetTensor().generate([](auto...) {
+            return prng::gen_descreet_uniform_sign(2e-3 /*scale*/, 1000 /*range*/);
+        });
+        runVariance.GetTensor().generate([](auto...) {
+            return prng::gen_descreet_uniform_sign(2e-3 /*scale*/, 1000 /*range*/);
+        });
     }
     else if(isBwd)
     {
@@ -242,33 +249,33 @@
         bnScale.AllocOnHost(tensor{bn_layout, derivedBnDesc.GetLengths()});
 
         dy.AllocOnHost(tensor{bn_layout, in_len});
-
-        auto gen_var_bwd = [](auto...) {
-            return static_cast(1e-2 * (prng::gen_0_to_B(100) + 1));
-        };
-
-        dy.InitHostData(dy.GetTensor().desc.GetElementSize(), true, gen_var_bwd);
+        // -2.0 to 2.0
+        dy.InitHostData(dy.GetTensor().desc.GetElementSize(), true, [](auto...) {
+            return prng::gen_descreet_uniform_sign(2e-3, 1000);
+        });
 
         dScale.AllocOnHost(tensor{bn_layout, derivedBnDesc.GetLengths()});
         dBias.AllocOnHost(tensor{bn_layout, derivedBnDesc.GetLengths()});
 
         savedMean.AllocOnHost(tensor{bn_layout, derivedBnDesc.GetLengths()});
         savedInvVar.AllocOnHost(tensor{bn_layout, derivedBnDesc.GetLengths()});
 
-        auto gen_value = [](auto...) { return prng::gen_descreet_unsigned(1e-2, 100); };
-        bnScale.InitHostData(bnScale.GetTensor().desc.GetElementSize(), true, gen_value);
-
-        auto gen_in_var = [](auto...) {
-            return static_cast(1e-2 * (prng::gen_0_to_B(100) + 1));
+        auto gen_value_bnScale = [](auto...) {
+            return prng::gen_descreet_uniform_sign(2e-3, 1000);
         };
-        savedMean.InitHostData(savedMean.GetTensor().desc.GetElementSize(), true, gen_in_var);
-        savedInvVar.InitHostData(savedInvVar.GetTensor().desc.GetElementSize(), true, gen_in_var);
+        bnScale.InitHostData(bnScale.GetTensor().desc.GetElementSize(), true, gen_value_bnScale);
+        // -2.0 to 2.0
+        savedMean.InitHostData(savedMean.GetTensor().desc.GetElementSize(), true, [](auto...) {
+            return prng::gen_descreet_uniform_sign(2e-3, 1000);
+        });
+        savedInvVar.InitHostData(savedInvVar.GetTensor().desc.GetElementSize(), true, [](auto...) {
+            return prng::gen_descreet_uniform_sign(2e-3, 1000);
+        });
     }
     else
     {
         std::cout << "\nUnknown batch norm state!\n";
         exit(EXIT_FAILURE);
     }
-
     return miopenStatusSuccess;
 }
diff --git a/src/kernels/MIOpenBatchNormFwdInferPerAct.cl b/src/kernels/MIOpenBatchNormFwdInferPerAct.cl
index f516a076ea..a505fd6b2a 100644
--- a/src/kernels/MIOpenBatchNormFwdInferPerAct.cl
+++ b/src/kernels/MIOpenBatchNormFwdInferPerAct.cl
@@ -43,6 +43,8 @@ MIOpenBatchNormFwdInferPerActivationEst(const __global _FLOAT* in,
                                         const __global _FLOAT_PREC* __restrict bias,
                                         double epsilon,
                                         unsigned int batchSize,
+                                        unsigned int cLen,
+                                        unsigned int cStride,
                                         unsigned int imageDims,
                                         unsigned int batchStride)
 {
@@ -58,7 +60,7 @@ MIOpenBatchNormFwdInferPerActivationEst(const __global _FLOAT* in,
 
     for(int img_offset = ygid; img_offset < imageDims; img_offset += yglb_sz)
     {
-        adjIndex    = (grpid * imageDims) + img_offset;
+        adjIndex    = (grpid * cStride) + img_offset * cLen;
         mean        = estimatedMean[adjIndex];
         variance    = estimatedVariance[adjIndex];
         invVariance = rsqrt(fabs(variance + epsilon));
diff --git a/src/kernels/MIOpenBatchNormFwdInferSpatial.cl b/src/kernels/MIOpenBatchNormFwdInferSpatial.cl
index a81db2a03b..24e268fd0d 100644
--- a/src/kernels/MIOpenBatchNormFwdInferSpatial.cl
+++ b/src/kernels/MIOpenBatchNormFwdInferSpatial.cl
@@ -43,6 +43,8 @@ MIOpenBatchNormFwdInferSpatialEst(const __global _FLOAT* __restrict in, /* x inp
                                   const __global _FLOAT_PREC* __restrict bias,
                                   double epsilon,
                                   unsigned int batchSize,
+                                  unsigned int cLen,
+                                  unsigned int cStride,
                                   unsigned int imageDims,
                                   unsigned int batchStride)
 {
@@ -66,7 +68,7 @@ MIOpenBatchNormFwdInferSpatialEst(const __global _FLOAT* __restrict in, /* x inp
     {
         for(int n = 0; n < batchSize; n++)
        {
-            index      = (n * batchStride) + (xgid * imageDims) + idx;
+            index      = (n * batchStride) + (xgid * cStride) + idx * cLen;
             inhat      = (FLOAT2FLOATPREC(*(in + index)) - mean) * invVariance;
             out[index] = FLOATPREC2FLOAT(mad(pscale, inhat, pbias));
         }
diff --git a/src/ocl/batchnormocl.cpp b/src/ocl/batchnormocl.cpp
index 6232759b0b..33a8243789 100644
--- a/src/ocl/batchnormocl.cpp
+++ b/src/ocl/batchnormocl.cpp
@@ -152,9 +152,9 @@ void BatchNormForwardTraining(Handle& handle,
     }();
 
     const auto solvers = solver::SolverContainer{};
+    // solver::batchnorm::BnCKFwdTraining>{};
 
     solvers.ExecutePrimitive(handle, problem, algo, invoke_params);
@@ -250,9 +250,8 @@ void BatchNormForwardInference(Handle& handle,
     }();
 
     const auto algo = AlgorithmName{"miopenBatchNormalizationForwardInference"};
-    const auto solvers = solver::SolverContainer{};
+    const auto solvers = solver::SolverContainer{};
+    // solver::batchnorm::BnCKFwdInference>{};
 
     solvers.ExecutePrimitive(handle, problem, algo, invoke_params);
 }
@@ -395,9 +394,9 @@ void BatchNormBackward(Handle& handle,
     }();
 
     const auto solvers = solver::SolverContainer{};
+    // solver::batchnorm::BnCKBwdBackward>{};
 
     solvers.ExecutePrimitive(handle, problem, algo, invoke_params);
diff --git a/src/solver/batchnorm/backward_spatial_multiple.cpp b/src/solver/batchnorm/backward_spatial_multiple.cpp
index e26922f478..2fa80fe145 100644
--- a/src/solver/batchnorm/backward_spatial_multiple.cpp
+++ b/src/solver/batchnorm/backward_spatial_multiple.cpp
@@ -38,9 +38,33 @@ namespace solver {
 
 namespace batchnorm {
 
+bool BNBwdIsCaseVariant2(const miopen::batchnorm::ProblemDescription& problem)
+{
+    size_t n, c, h, w;
+    std::tie(n, c, h, w) = tien<4>(problem.GetXDesc().GetLengths());
+
+    size_t in_cstride = h * w;
+    size_t in_nhw     = n * in_cstride;
+
+    if((in_nhw >= static_cast(32 * 1024 * 1024) || in_cstride <= 1024) && in_cstride > 512)
+    {
+        return true;
+    }
+    else
+        return false;
+}
+
 bool BnBwdTrainingSpatialMultiple::IsApplicable(
     const ExecutionContext& context, const miopen::batchnorm::ProblemDescription& problem) const
 {
+    if(!problem.IsLayoutNCHW())
+        return false;
+    // NCHW is Applicable for variant = 2 only
+    if(!BNBwdIsCaseVariant2(problem))
+    {
+        return false;
+    }
+
     if(problem.GetDirection() != miopen::batchnorm::Direction::Backward ||
        problem.GetMode() != miopenBNSpatial)
         return false;
diff --git a/src/solver/batchnorm/forward_inference.cpp b/src/solver/batchnorm/forward_inference.cpp
index a05fce5105..7aab643be7 100644
--- a/src/solver/batchnorm/forward_inference.cpp
+++ b/src/solver/batchnorm/forward_inference.cpp
@@ -41,8 +41,6 @@ namespace batchnorm {
 bool BnFwdInference::IsApplicable(const ExecutionContext&,
                                   const miopen::batchnorm::ProblemDescription& bn_problem) const
 {
-    if(bn_problem.IsLayoutNHWC())
-        return false;
     if(bn_problem.GetDirection() != miopen::batchnorm::Direction::ForwardInference)
         return false;
     if(!(bn_problem.IsFp32() or bn_problem.IsFp16() or bn_problem.IsBFp16()))
@@ -149,16 +147,36 @@ ConvSolution BnFwdInference::GetSolution(const ExecutionContext& context,
             unsigned int in_nstride_ = c_ * h_ * w_;
             unsigned int in_cstride_ = h_ * w_;
 
-            kernel(params.x,
-                   params.y,
-                   params.estimatedMean,
-                   params.estimatedVariance,
-                   params.bnScale,
-                   params.bnBias,
-                   params.epsilon,
-                   n_,
-                   in_cstride_,
-                   in_nstride_);
+            if(params.xDesc->GetLayout_t() == miopenTensorNHWC)
+            {
+                kernel(params.x,
+                       params.y,
+                       params.estimatedMean,
+                       params.estimatedVariance,
+                       params.bnScale,
+                       params.bnBias,
+                       params.epsilon,
+                       n_,
+                       c_, // nhwc = c
+                       1,
+                       in_cstride_,
+                       in_nstride_);
+            }
+            else
+            {
+                kernel(params.x,
+                       params.y,
+                       params.estimatedMean,
+                       params.estimatedVariance,
+                       params.bnScale,
+                       params.bnBias,
+                       params.epsilon,
+                       n_,
+                       1, // nchw 1
+                       h_ * w_,
+                       in_cstride_,
+                       in_nstride_);
+            }
         };
     };
 
diff --git a/src/solver/batchnorm/forward_spatial_multiple.cpp b/src/solver/batchnorm/forward_spatial_multiple.cpp
index 6a2c42743b..a7a0f871ac 100644
--- a/src/solver/batchnorm/forward_spatial_multiple.cpp
+++ b/src/solver/batchnorm/forward_spatial_multiple.cpp
@@ -40,9 +40,40 @@ namespace solver {
 
 namespace batchnorm {
 
+bool BNFwdTrainIsCaseVariant2(const miopen::batchnorm::ProblemDescription& problem)
+{
+    const auto& xDesc = problem.GetXDesc();
+    size_t n, c, h, w;
+    std::tie(n, c, h, w) = tien<4>(xDesc.GetLengths());
+    size_t in_cstride    = h * w;
+    size_t in_nhw        = n * in_cstride;
+    bool bfp32parm       = xDesc.GetType() == miopenFloat;
+    bool bfpmixparm      = (xDesc.GetType() == miopenHalf || xDesc.GetType() == miopenBFloat16) &&
+                           problem.GetBnScale().GetType() == miopenFloat;
+
+    // NCHW is Applicable for variant = 2 only
+    // these number comes from BnFwdTrainingSpatialMultiple::GetSolution of
+    // forward_spatial_multiple.cpp
+    if((n >= 3 && in_cstride > 512 && (in_nhw >= 33554432 || in_cstride <= 1024) &&
+        ((n < 256) || (in_cstride <= 60) || !bfpmixparm) && (!bfpmixparm || in_cstride <= 512)) ||
+       (n <= 768 || in_cstride <= 150 || !bfp32parm))
+    {
+        return true;
+    }
+    else
+        return false;
+}
+
 bool BnFwdTrainingSpatialMultiple::IsApplicable(
     const ExecutionContext& context, const miopen::batchnorm::ProblemDescription& problem) const
 {
+    // if NCHW check if variant is 2 else false (for all data type)
+    // update get solution to not change variant
+    if(!BNFwdTrainIsCaseVariant2(problem))
+    {
+        return false;
+    }
+
     if(problem.GetDirection() != miopen::batchnorm::Direction::ForwardTraining ||
        problem.GetMode() != miopenBNSpatial)
         return false;
diff --git a/test/gtest/bn_bwd.cpp b/test/gtest/bn_bwd.cpp
index a84a8a8feb..9857c31f6f 100644
--- a/test/gtest/bn_bwd.cpp
+++ b/test/gtest/bn_bwd.cpp
@@ -95,7 +95,7 @@ INSTANTIATE_TEST_SUITE_P(Smoke,
 INSTANTIATE_TEST_SUITE_P(Smoke,
                          GPU_BN_OCL_BWD_Large_FP16,
                          testing::Combine(testing::ValuesIn(NetworkLarge()),
-                                          testing::ValuesIn({miopenTensorNCHW}),
+                                          testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
                                           testing::ValuesIn({testBNAPIV2})),
                          TestNameGenerator());
 
@@ -110,7 +110,7 @@ INSTANTIATE_TEST_SUITE_P(Smoke,
 INSTANTIATE_TEST_SUITE_P(Smoke,
                          GPU_BN_OCL_BWD_Large_BFP16,
                          testing::Combine(testing::ValuesIn(NetworkLarge()),
-                                          testing::ValuesIn({miopenTensorNCHW}),
+                                          testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
                                           testing::ValuesIn({testBNAPIV2})),
                          TestNameGenerator());
 
@@ -118,14 +118,14 @@ INSTANTIATE_TEST_SUITE_P(Smoke,
 INSTANTIATE_TEST_SUITE_P(Smoke,
                          GPU_BN_BWD_Small_FP32,
                          testing::Combine(testing::ValuesIn(NetworkSmall()),
-                                          testing::ValuesIn({miopenTensorNCHW}),
+                                          testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
                                           testing::ValuesIn({testBNAPIV1})),
                          TestNameGenerator());
 
 INSTANTIATE_TEST_SUITE_P(Smoke,
                          GPU_BN_BWD_Large_FP32,
                          testing::Combine(testing::ValuesIn(NetworkLarge()),
-                                          testing::ValuesIn({miopenTensorNCHW}),
+                                          testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
                                           testing::ValuesIn({testBNAPIV2})),
                          TestNameGenerator());
 // // fp64
diff --git a/test/gtest/bn_fwd_train.cpp b/test/gtest/bn_fwd_train.cpp
index 9b4722aaf8..c9db49e8a3 100644
--- a/test/gtest/bn_fwd_train.cpp
+++ b/test/gtest/bn_fwd_train.cpp
@@ -101,7 +101,7 @@ INSTANTIATE_TEST_SUITE_P(Smoke,
 INSTANTIATE_TEST_SUITE_P(Smoke,
                          GPU_BN_OCL_FWD_Train_Large_FP16,
                          testing::Combine(testing::ValuesIn(NetworkLarge()),
-                                          testing::ValuesIn({miopenTensorNCHW}),
+                                          testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
                                           testing::ValuesIn({testBNAPIV1, testBNAPIV2})),
                          TestNameGenerator());
 
@@ -116,7 +116,7 @@ INSTANTIATE_TEST_SUITE_P(Smoke,
 INSTANTIATE_TEST_SUITE_P(Smoke,
                          GPU_BN_OCL_FWD_Train_Large_BFP16,
                          testing::Combine(testing::ValuesIn(NetworkLarge()),
-                                          testing::ValuesIn({miopenTensorNCHW}),
+                                          testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
                                           testing::ValuesIn({testBNAPIV1, testBNAPIV2})),
                          TestNameGenerator());
 
@@ -124,14 +124,14 @@ INSTANTIATE_TEST_SUITE_P(Smoke,
 INSTANTIATE_TEST_SUITE_P(Smoke,
                          GPU_BN_FWD_Train_Small_FP32,
                          testing::Combine(testing::ValuesIn(NetworkSmall()),
-                                          testing::ValuesIn({miopenTensorNCHW}),
+                                          testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
                                           testing::ValuesIn({testBNAPIV1})),
                          TestNameGenerator());
 
 INSTANTIATE_TEST_SUITE_P(Smoke,
                          GPU_BN_FWD_Train_Large_FP32,
                          testing::Combine(testing::ValuesIn(NetworkLarge()),
-                                          testing::ValuesIn({miopenTensorNCHW}),
+                                          testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
                                           testing::ValuesIn({testBNAPIV2})),
                          TestNameGenerator());
 // // fp64
diff --git a/test/gtest/bn_infer.cpp b/test/gtest/bn_infer.cpp
index 591cbd0b1a..b1027e819c 100644
--- a/test/gtest/bn_infer.cpp
+++ b/test/gtest/bn_infer.cpp
@@ -102,7 +102,7 @@ INSTANTIATE_TEST_SUITE_P(Smoke,
 INSTANTIATE_TEST_SUITE_P(Smoke,
                          GPU_BN_OCL_Infer_Large_FP16,
                          testing::Combine(testing::ValuesIn(NetworkLarge()),
-                                          testing::ValuesIn({miopenTensorNCHW}),
+                                          testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
                                           testing::ValuesIn({testBNAPIV1, testBNAPIV2})),
                          TestNameGenerator());
 // bfp16
@@ -116,7 +116,7 @@ INSTANTIATE_TEST_SUITE_P(Smoke,
 INSTANTIATE_TEST_SUITE_P(Smoke,
                          GPU_BN_OCL_Infer_Large_BFP16,
                          testing::Combine(testing::ValuesIn(NetworkLarge()),
-                                          testing::ValuesIn({miopenTensorNCHW}),
+                                          testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
                                           testing::ValuesIn({testBNAPIV1, testBNAPIV2})),
                          TestNameGenerator());
 
@@ -124,14 +124,14 @@ INSTANTIATE_TEST_SUITE_P(Smoke,
 INSTANTIATE_TEST_SUITE_P(Smoke,
                          GPU_BN_Infer_Small_FP32,
                          testing::Combine(testing::ValuesIn(NetworkLarge()),
-                                          testing::ValuesIn({miopenTensorNCHW}),
+                                          testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
                                           testing::ValuesIn({testBNAPIV1})),
                          TestNameGenerator());
 
 INSTANTIATE_TEST_SUITE_P(Smoke,
                          GPU_BN_Infer_Large_FP32,
                          testing::Combine(testing::ValuesIn(NetworkSmall()),
-                                          testing::ValuesIn({miopenTensorNCHW}),
+                                          testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
                                           testing::ValuesIn({testBNAPIV2})),
                          TestNameGenerator());
 // fp64
diff --git a/test/gtest/bn_test_data.hpp b/test/gtest/bn_test_data.hpp
index d3b1c6b073..c22acd7f0e 100644
--- a/test/gtest/bn_test_data.hpp
+++ b/test/gtest/bn_test_data.hpp
@@ -67,6 +67,7 @@ inline std::vector NetworkLarge()
     // pyt_mlperf_resnet50v1.5
     return {
         {192, 1, 8, 8, miopenBNSpatial, miopen::batchnorm::Direction::Backward, 1, 0},
+        {12, 40, 122, 122, miopenBNSpatial, miopen::batchnorm::Direction::Backward, 1, 0},
         {64, 2048, 7, 7, miopenBNSpatial, miopen::batchnorm::Direction::Backward, 0, 1},
         {64, 2048, 7, 7, miopenBNSpatial, miopen::batchnorm::Direction::ForwardTraining, 1, 1},
         {64, 2048, 7, 7, miopenBNSpatial, miopen::batchnorm::Direction::ForwardInference, 1, 0},
@@ -101,7 +102,7 @@ inline std::vector NetworkSmall()
 {
     // pyt_mlperf_resnet50v1.5
     return {
-        {192, 2, 8, 8, miopenBNSpatial, miopen::batchnorm::Direction::Backward, 1, 0},
+        {12, 40, 122, 122, miopenBNSpatial, miopen::batchnorm::Direction::Backward, 1, 0},
         {16, 8, 132, 28, miopenBNSpatial, miopen::batchnorm::Direction::Backward, 1, 0},
         {16, 8, 128, 256, miopenBNSpatial, miopen::batchnorm::Direction::ForwardTraining, 1, 0},
         {64, 2048, 17, 17, miopenBNSpatial, miopen::batchnorm::Direction::Backward, 0, 1},
@@ -148,8 +149,10 @@ struct BNTestData
 
     void InitTensorsWithRandValue()
     {
-        input.generate(
-            [](auto...) { return prng::gen_descreet_uniform_sign(1e-2, 100); });
+        // 0.0 to 2.0 (since unsigned)
+        input.generate([](auto...) {
+            return prng::gen_descreet_unsigned(2e-3 /*scale*/, 1000 /*range*/);
+        });
     }
 
     void SetDirection() { direction = bn_config.Direction; }
@@ -212,15 +215,17 @@ struct BNInferTestData : public BNTestData
 
     void InitTensorsWithRandValue()
     {
-        auto gen_value = [](auto...) {
-            return prng::gen_descreet_uniform_sign(1e-2, 100);
-        };
-        scale.generate(gen_value);
-        shift.generate(gen_value);
-        estMean.generate(gen_value);
-
+        // -2.0 to 2.0
+        scale.generate(
+            [](auto...) { return prng::gen_descreet_uniform_sign(2e-3, 1000); });
+        shift.generate(
+            [](auto...) { return prng::gen_descreet_uniform_sign(2e-3, 1000); });
+        estMean.generate(
+            [](auto...) { return prng::gen_descreet_uniform_sign(2e-3, 1000); });
+
+        // 0.0 to 2.0
         auto gen_var = [](auto...) {
-            return static_cast(1e-2 * (prng::gen_0_to_B(100) + 1));
+            return static_cast(2e-3 * (prng::gen_0_to_B(1000) + 1));
         };
         estVariance.generate(gen_var);
     }
@@ -303,17 +308,17 @@ struct BNBwdTestData : public BNTestData
 
     void InitTensorsWithRandValue()
     {
-        auto gen_value = [](auto...) {
-            return prng::gen_descreet_uniform_sign(1e-2, 100);
-        };
-        dy.generate(gen_value);
-        bnScale.generate(gen_value);
-        savedMean.generate(gen_value);
-
-        auto gen_var = [](auto...) {
-            return static_cast(1e-2 * (prng::gen_0_to_B(100) + 1));
-        };
-        savedInvVar.generate(gen_var);
+        // -2.0 to 2.0
+        dy.generate(
+            [](auto...) { return prng::gen_descreet_uniform_sign(2e-3, 1000); });
+        bnScale.generate(
+            [](auto...) { return prng::gen_descreet_uniform_sign(2e-3, 1000); });
+        savedMean.generate(
+            [](auto...) { return prng::gen_descreet_uniform_sign(2e-3, 1000); });
+        // 0.0 to 2.0
+        savedInvVar.generate([](auto...) {
+            return static_cast(2e-3 * (prng::gen_0_to_B(1000) + 1));
+        });
 
         std::fill(dScale.begin(), dScale.end(), 0.);
         std::fill(dBias.begin(), dBias.end(), 0.);
@@ -400,14 +405,14 @@ struct BNFwdTrainTestData : public BNTestData
 
     void InitTensorsWithRandValue()
     {
-        auto gen_value = [](auto...) {
-            return prng::gen_descreet_uniform_sign(1e-2, 100);
-        };
-        scale.generate(gen_value);
-        shift.generate(gen_value);
-
+        // -2.0 to 2.0
+        scale.generate(
+            [](auto...) { return prng::gen_descreet_uniform_sign(2e-3, 1000); });
+        shift.generate(
+            [](auto...) { return prng::gen_descreet_uniform_sign(2e-3, 1000); });
+        // 0.0 to 2.0
         auto gen_var = [](auto...) {
-            return static_cast(1e-2 * (prng::gen_0_to_B(100) + 1));
+            return static_cast(2e-3 * (prng::gen_0_to_B(1000) + 1));
         };
         runMean.generate(gen_var);
         runVariance.generate(gen_var);