-
Notifications
You must be signed in to change notification settings - Fork 233
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BN] Enable NHWC in OCL #3399
base: develop
Are you sure you want to change the base?
[BN] Enable NHWC in OCL #3399
Changes from 2 commits
d0e7b78
9b62286
54ef272
38486ec
7e21c04
dcc58f8
281b230
f6cfb74
12a8920
4f08f31
aa999e0
bcc0ae7
c853342
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -38,9 +38,34 @@ namespace solver { | |||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
namespace batchnorm { | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
bool BNBwdIsCaseVariant2(const miopen::batchnorm::ProblemDescription& problem) | ||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||
int n, c, h, w; | ||||||||||||||||||||||||||||||||||||||||||||
std::tie(n, c, h, w) = tien<4>(problem.GetXDesc().GetLengths()); | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
unsigned int in_cstride = h * w; | ||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But actually you can avoid all the prevois
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||
unsigned int in_nhw = n * in_cstride; | ||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
if(!(in_nhw < (32 * 1024 * 1024) && in_cstride > 1024) && | ||||||||||||||||||||||||||||||||||||||||||||
!(in_nhw < (32 * 1024 * 1024) && in_cstride > 512) && !(in_cstride <= 512)) | ||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not that complex as the condition from src/solver/batchnorm/forward_spatial_multiple.cpp, but still can be simplified, since it contains redundant statements. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One example of the transformations:
So it basically means If there are any doubts about those transformations, here is a proof (tested in excel, lol):
|
||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||
return true; | ||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||
else | ||||||||||||||||||||||||||||||||||||||||||||
return false; | ||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
bool BnBwdTrainingSpatialMultiple::IsApplicable( | ||||||||||||||||||||||||||||||||||||||||||||
const ExecutionContext& context, const miopen::batchnorm::ProblemDescription& problem) const | ||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||
if(!problem.IsLayoutNCHW()) | ||||||||||||||||||||||||||||||||||||||||||||
return false; | ||||||||||||||||||||||||||||||||||||||||||||
// NCHW is Applicable for variant = 2 only | ||||||||||||||||||||||||||||||||||||||||||||
if(!BNBwdIsCaseVariant2(problem)) | ||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||
return false; | ||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||
if(problem.GetDirection() != miopen::batchnorm::Direction::Backward || | ||||||||||||||||||||||||||||||||||||||||||||
problem.GetMode() != miopenBNSpatial) | ||||||||||||||||||||||||||||||||||||||||||||
return false; | ||||||||||||||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,9 +40,42 @@ namespace solver { | |
|
||
namespace batchnorm { | ||
|
||
bool BNFwdTrainIsCaseVariant2(const miopen::batchnorm::ProblemDescription& problem) | ||
{ | ||
const auto& xDesc = problem.GetXDesc(); | ||
int n, c, h, w; | ||
std::tie(n, c, h, w) = tien<4>(xDesc.GetLengths()); | ||
unsigned int in_cstride = h * w; | ||
unsigned int in_nhw = n * in_cstride; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See the comments from src/solver/batchnorm/backward_spatial_multiple.cpp |
||
bool bfp32parm = xDesc.GetType() == miopenFloat; | ||
bool bfpmixparm = (xDesc.GetType() == miopenHalf || xDesc.GetType() == miopenBFloat16) && | ||
problem.GetBnScale().GetType() == miopenFloat; | ||
|
||
// NCHW is Applicable for variant = 2 only | ||
if((!(n < 3) && | ||
!((in_nhw < 33554432 && in_cstride > 1024) || | ||
((n >= 256) && (in_cstride > 60) && bfpmixparm) || ((in_cstride > 512) && bfpmixparm)) && | ||
!(in_cstride <= 512)) || | ||
!((n > 768) && (in_cstride > 150) && bfp32parm)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's barely readable and probably redundant condition. |
||
{ | ||
return true; | ||
} | ||
else | ||
return false; | ||
} | ||
|
||
bool BnFwdTrainingSpatialMultiple::IsApplicable( | ||
const ExecutionContext& context, const miopen::batchnorm::ProblemDescription& problem) const | ||
{ | ||
if(!problem.IsLayoutNCHW()) | ||
return false; | ||
|
||
if(!BNFwdTrainIsCaseVariant2(problem)) | ||
{ | ||
return false; | ||
} | ||
// if NCHW check if variant is 2 else false (for all data type) | ||
// update get solution to not change variant | ||
if(problem.GetDirection() != miopen::batchnorm::Direction::ForwardTraining || | ||
problem.GetMode() != miopenBNSpatial) | ||
return false; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.