Skip to content

Commit

Permalink
Merge pull request #3390 from ROCm/release/rocm-rel-6.3-staging
Browse files Browse the repository at this point in the history
Staging fixes for batch normalization (BN) and Find 2.0 issues
  • Loading branch information
vamovsik authored Nov 18, 2024
2 parents 7a3c295 + cd1a3f1 commit a85ca8a
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 96 deletions.
19 changes: 13 additions & 6 deletions src/batch_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,27 +66,34 @@ void DeriveBNTensorDescriptor(TensorDescriptor& derivedBnDesc,

/// Collapse a 5-D (NCDHW / NDHWC) tensor descriptor into an equivalent 4-D
/// descriptor by folding the depth dimension into the height dimension
/// (D and H become a single D*H axis). Batch-norm kernels only understand
/// 4-D layouts, so 5-D inputs are reshaped before dispatch.
///
/// @param tDesc  5-D tensor descriptor in NCDHW or NDHWC layout.
/// @return       4-D descriptor (NCHW or NHWC) with the same element count.
///               Any other layout is a hard error: the process exits.
TensorDescriptor BuildReshaped4DTensorDescriptor(const miopen::TensorDescriptor& tDesc)
{
    std::vector<size_t> dims(tDesc.GetLengths());

    auto dataType = tDesc.GetType();
    auto layout   = tDesc.GetLayout_t();
    if(layout == miopenTensorNCDHW)
    {
        layout = miopenTensorNCHW;

        // NxCxDxHxW -> NxCx(D*H)xW
        dims[2] *= dims[3];
        dims[3] = dims[4];
        dims.pop_back();
    }
    else if(layout == miopenTensorNDHWC)
    {
        layout = miopenTensorNHWC;

        // NxDxHxWxC -> Nx(D*H)xWxC
        dims[1] *= dims[2];
        dims[2] = dims[3];
        dims[3] = dims[4];
        dims.pop_back();
    }
    else
    {
        // Diagnostics belong on stderr, not stdout.
        std::cerr << "Cannot handle layout : " << layout << "\n";
        exit(EXIT_FAILURE); // NOLINT (concurrency-mt-unsafe)
    }

    return {dataType, layout, dims};
}
Expand Down
160 changes: 81 additions & 79 deletions src/batch_norm_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ miopenBatchNormalizationForwardInference_V2(miopenHandle_t handle,
const miopenTensorDescriptor_t yDesc,
void* y,
const miopenTensorDescriptor_t scaleDesc,
const miopenTensorDescriptor_t BiasDesc,
const miopenTensorDescriptor_t biasDesc,
const miopenTensorDescriptor_t estMeanDesc,
const miopenTensorDescriptor_t estVarianceDesc,
void* bnScale,
Expand All @@ -222,7 +222,7 @@ miopenBatchNormalizationForwardInference_V2(miopenHandle_t handle,
yDesc,
y,
scaleDesc,
BiasDesc,
biasDesc,
estMeanDesc,
estVarianceDesc,
bnScale,
Expand All @@ -239,31 +239,31 @@ miopenBatchNormalizationForwardInference_V2(miopenHandle_t handle,
nullptr,
nullptr,
miopen::debug::BatchNormDirection_t::ForwardInference);

// In case of NxCxDxHxW
int size{0};
miopenGetTensorDescriptorSize(xDesc, &size);
// In case of NxCxDxHxW
auto ReshapeIfNeeded = [size](const auto desc) {
return (size == 5) ? miopen::BuildReshaped4DTensorDescriptor(miopen::deref(desc))
: miopen::deref(desc);
};
return miopen::try_([&] {
miopen::BatchNormForwardInference(
miopen::deref(handle),
bn_mode,
alpha,
beta,
(size == 5) ? miopen::BuildReshaped4DTensorDescriptor(miopen::deref(xDesc))
: miopen::deref(xDesc),
DataCast(x),
(size == 5) ? miopen::BuildReshaped4DTensorDescriptor(miopen::deref(yDesc))
: miopen::deref(yDesc),
DataCast(y),
miopen::deref(scaleDesc),
miopen::deref(BiasDesc),
miopen::deref(estMeanDesc),
miopen::deref(estVarianceDesc),
DataCast(bnScale),
DataCast(bnBias),
DataCast(estimatedMean),
DataCast(estimatedVariance),
epsilon);
miopen::BatchNormForwardInference(miopen::deref(handle),
bn_mode,
alpha,
beta,
ReshapeIfNeeded(xDesc),
DataCast(x),
ReshapeIfNeeded(yDesc),
DataCast(y),
ReshapeIfNeeded(scaleDesc),
ReshapeIfNeeded(biasDesc),
ReshapeIfNeeded(estMeanDesc),
ReshapeIfNeeded(estVarianceDesc),
DataCast(bnScale),
DataCast(bnBias),
DataCast(estimatedMean),
DataCast(estimatedVariance),
epsilon);
});
}

Expand All @@ -277,7 +277,7 @@ miopenBatchNormalizationForwardTraining_V2(miopenHandle_t handle,
const miopenTensorDescriptor_t yDesc,
void* y,
const miopenTensorDescriptor_t scaleDesc,
const miopenTensorDescriptor_t BiasDesc,
const miopenTensorDescriptor_t biasDesc,
const miopenTensorDescriptor_t savedMeanDesc,
const miopenTensorDescriptor_t savedVarianceDesc,
void* bnScale,
Expand All @@ -296,7 +296,7 @@ miopenBatchNormalizationForwardTraining_V2(miopenHandle_t handle,
yDesc,
y,
scaleDesc,
BiasDesc,
biasDesc,
savedMeanDesc,
savedVarianceDesc,
bnScale,
Expand All @@ -316,33 +316,35 @@ miopenBatchNormalizationForwardTraining_V2(miopenHandle_t handle,
resultSaveMean,
resultSaveInvVariance,
miopen::debug::BatchNormDirection_t::ForwardTraining);
// In case of NxCxDxHxW

int size{0};
miopenGetTensorDescriptorSize(xDesc, &size);
// In case of NxCxDxHxW
auto ReshapeIfNeeded = [size](const auto desc) {
return (size == 5) ? miopen::BuildReshaped4DTensorDescriptor(miopen::deref(desc))
: miopen::deref(desc);
};
return miopen::try_([&] {
miopen::BatchNormForwardTraining(
miopen::deref(handle),
bn_mode,
alpha,
beta,
(size == 5) ? miopen::BuildReshaped4DTensorDescriptor(miopen::deref(xDesc))
: miopen::deref(xDesc),
DataCast(x),
(size == 5) ? miopen::BuildReshaped4DTensorDescriptor(miopen::deref(yDesc))
: miopen::deref(yDesc),
DataCast(y),
miopen::deref(scaleDesc),
miopen::deref(BiasDesc),
miopen::deref(savedMeanDesc),
miopen::deref(savedVarianceDesc),
DataCast(bnScale),
DataCast(bnBias),
expAvgFactor,
DataCast(resultRunningMean),
DataCast(resultRunningVariance),
epsilon,
DataCast(resultSaveMean),
DataCast(resultSaveInvVariance));
miopen::BatchNormForwardTraining(miopen::deref(handle),
bn_mode,
alpha,
beta,
ReshapeIfNeeded(xDesc),
DataCast(x),
ReshapeIfNeeded(yDesc),
DataCast(y),
ReshapeIfNeeded(scaleDesc),
ReshapeIfNeeded(biasDesc),
ReshapeIfNeeded(savedMeanDesc),
ReshapeIfNeeded(savedVarianceDesc),
DataCast(bnScale),
DataCast(bnBias),
expAvgFactor,
DataCast(resultRunningMean),
DataCast(resultRunningVariance),
epsilon,
DataCast(resultSaveMean),
DataCast(resultSaveInvVariance));
});
}

Expand All @@ -360,7 +362,7 @@ miopenBatchNormalizationBackward_V2(miopenHandle_t handle,
const miopenTensorDescriptor_t dxDesc,
void* dx,
const miopenTensorDescriptor_t scaleDesc,
const miopenTensorDescriptor_t BiasDesc,
const miopenTensorDescriptor_t biasDesc,
const miopenTensorDescriptor_t savedMeanDesc,
const miopenTensorDescriptor_t savedVarianceDesc,
const void* bnScale,
Expand All @@ -379,7 +381,7 @@ miopenBatchNormalizationBackward_V2(miopenHandle_t handle,
dxDesc,
dx,
scaleDesc,
BiasDesc,
biasDesc,
savedMeanDesc,
savedVarianceDesc,
bnScale,
Expand All @@ -396,35 +398,35 @@ miopenBatchNormalizationBackward_V2(miopenHandle_t handle,
savedMean,
savedInvVariance,
miopen::debug::BatchNormDirection_t::Backward);
// In case of NxCxDxHxW
int size{0};
miopenGetTensorDescriptorSize(xDesc, &size);
// In case of NxCxDxHxW
auto ReshapeIfNeeded = [size](const auto desc) {
return (size == 5) ? miopen::BuildReshaped4DTensorDescriptor(miopen::deref(desc))
: miopen::deref(desc);
};
return miopen::try_([&] {
miopen::BatchNormBackward(
miopen::deref(handle),
bn_mode,
alphaDataDiff,
betaDataDiff,
alphaParamDiff,
betaParamDiff,
(size == 5) ? miopen::BuildReshaped4DTensorDescriptor(miopen::deref(xDesc))
: miopen::deref(xDesc),
DataCast(x),
(size == 5) ? miopen::BuildReshaped4DTensorDescriptor(miopen::deref(dyDesc))
: miopen::deref(dyDesc),
DataCast(dy),
(size == 5) ? miopen::BuildReshaped4DTensorDescriptor(miopen::deref(dxDesc))
: miopen::deref(dxDesc),
DataCast(dx),
miopen::deref(scaleDesc),
miopen::deref(BiasDesc),
miopen::deref(savedMeanDesc),
miopen::deref(savedVarianceDesc),
DataCast(bnScale),
DataCast(resultBnScaleDiff),
DataCast(resultBnBiasDiff),
epsilon,
DataCast(savedMean),
DataCast(savedInvVariance));
miopen::BatchNormBackward(miopen::deref(handle),
bn_mode,
alphaDataDiff,
betaDataDiff,
alphaParamDiff,
betaParamDiff,
ReshapeIfNeeded(xDesc),
DataCast(x),
ReshapeIfNeeded(dyDesc),
DataCast(dy),
ReshapeIfNeeded(dxDesc),
DataCast(dx),
ReshapeIfNeeded(scaleDesc),
ReshapeIfNeeded(biasDesc),
ReshapeIfNeeded(savedMeanDesc),
ReshapeIfNeeded(savedVarianceDesc),
DataCast(bnScale),
DataCast(resultBnScaleDiff),
DataCast(resultBnBiasDiff),
epsilon,
DataCast(savedMean),
DataCast(savedInvVariance));
});
}
22 changes: 11 additions & 11 deletions src/problem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,7 @@ Problem::FindSolutions(Handle& handle, const FindOptions& options, std::size_t m
auto ret = std::visit(
boost::hof::match(
[&](const ConvolutionDescriptor& op_desc) {
if(op_desc.mode == miopenTranspose)
return MakeTransposed().FindSolutionsImpl(
handle, options, max_solutions, buffers, op_desc);
else
return FindSolutionsImpl(handle, options, max_solutions, buffers, op_desc);
return FindSolutionsImpl(handle, options, max_solutions, buffers, op_desc);
},
[&](const SoftmaxDescriptor& op_desc) {
return FindSolutionsImpl(handle, options, max_solutions, buffers, op_desc);
Expand Down Expand Up @@ -481,17 +477,21 @@ std::vector<Solution> Problem::FindSolutionsImpl(Handle& handle,
const auto& w = buffers.at(miopenTensorConvolutionW);
auto y = buffers.at(miopenTensorConvolutionY);

if(conv_desc.mode == miopenTranspose)
std::swap(x, y);

const auto conv_problem = AsConvolution();

ValidateGroupCount(x_desc, w_desc, conv_desc);
const auto conv_problem =
conv_desc.mode == miopenTranspose ? MakeTransposed().AsConvolution() : AsConvolution();

std::size_t workspace_size;
Allocator::ManageDataPtr owned_workspace;
Data_t workspace;

if(conv_desc.mode == miopenTranspose)
{
std::swap(x, y);
std::swap(x_desc, y_desc);
}

ValidateGroupCount(x_desc, w_desc, conv_desc);

if(options.preallocated_workspace)
{
workspace = options.preallocated_workspace->buffer;
Expand Down

0 comments on commit a85ca8a

Please sign in to comment.