Skip to content

Commit

Permalink
Merge pull request #563 from AMDComputeLibraries/bnDSfix
Browse files Browse the repository at this point in the history
BN dScale fix (bnDSfix)
  • Loading branch information
pfultz2 authored Oct 16, 2017
2 parents 5058fb1 + 73ccbb2 commit e1566a4
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 73 deletions.
6 changes: 3 additions & 3 deletions driver/bn_driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
#define MIO_BN_DEBUG 1
#define MIO_BN_MAX_DEBUGLOOP 65536

#define EPSILON 1e-6
#define EPSILON 1e-4

#define ERRTOL 1e-6
#define RMSTOL 1e-6
Expand Down Expand Up @@ -1286,8 +1286,8 @@ int BatchNormDriver<T>::VerifyBackward()
if(!back)
return miopenStatusSuccess;

const double tolerance = ERRTOL;
const double maxrms = RMSTOL;
const double tolerance = ERRTOL * 1000;
const double maxrms = RMSTOL * 1000;
double diff = 0.;
bool anError = false;

Expand Down
15 changes: 6 additions & 9 deletions driver/miopen_BatchNormHost.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,17 @@
#include <cmath>
#include <iomanip>

#define MIO_HEIRARCH_SEL 0
#define MIO_HEIRARCH_SEL 1

#if(MIO_HEIRARCH_SEL == 1)
#define MIO_BN_DIST 32
#endif

template <typename T>
int miopenBNFwdTrainPerActivationRunHost(
/* T alpha,
T beta,
/*
T alpha,
T beta,
*/
int n_batchs,
int channels,
Expand Down Expand Up @@ -743,7 +744,6 @@ int miopenBNBwdPerActivationRunHost(
dxhat += tmp1;
dxhathat += tmp1 * xhat[xhat_index];
} // end for(n_batchs)
dscale_ptr[adjIndex] /= double(n_batchs);

for(int bidx = 0; bidx < n_batchs; bidx++)
{ // via mini_batch
Expand Down Expand Up @@ -812,7 +812,6 @@ int miopenBNBwdPerActivationRunHost(
dxhat += tmp1;
dxhathat += tmp1 * xhat[xhat_index];
} // end for(n_batchs)
dscale_ptr[adjIndex] /= double(n_batchs);

for(int bidx = 0; bidx < n_batchs; bidx++)
{ // via mini_batch
Expand Down Expand Up @@ -891,7 +890,7 @@ int miopenBNBwdSpatialRunHost(
} // end for(n_batchs)
} // for (column)
} // for (row)
dscale_ptr[cidx] /= NHW;

// process the batch per channel
for(int row = 0; row < height; row++)
{ // via rows
Expand Down Expand Up @@ -1087,17 +1086,15 @@ int miopenBNBwdSpatialRunHost(
}
#endif

dscale_ptr[cidx] /= NHW;
// printf("dscale: %f\n",dscale_ptr[cidx]);
// printf("dbias: %f\n",dbias_ptr[cidx]);
// printf("HELLO BASTARDS!!!");

#if(MIO_HEIRARCH_SEL == 0)
for(int row = 0; row < height; row++)
{ // via rows
for(int column = 0; column < width; column++)
{ // via columns
adjIndex = Csubindex + width * row + column;

for(int bidx = 0; bidx < n_batchs; bidx++)
{ // via mini_batch
index = in_nstride * bidx + adjIndex;
Expand Down
4 changes: 2 additions & 2 deletions include/miopen/miopen.h
Original file line number Diff line number Diff line change
Expand Up @@ -376,8 +376,8 @@ MIOPEN_EXPORT miopenStatus_t miopenGet4dTensorDescriptor(miopenTensorDescriptor_
* @param tensorDesc Tensor descriptor type (input)
* @param dataType Currently only miopenFloat is implemented (input)
* @param nbDims Number of dimensions in the dimsA array (input)
* @param dimsA Array containing the size of dimensions (output)
* @param stridesA Array containing the size of stride (output)
* @param dimsA Array containing the size of dimensions (input)
* @param stridesA Array containing the size of stride (input)
* @return miopenStatus_t
*/
MIOPEN_EXPORT miopenStatus_t miopenSetTensorDescriptor(miopenTensorDescriptor_t tensorDesc,
Expand Down
10 changes: 8 additions & 2 deletions src/batch_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
#include <miopen/errors.hpp>
#include <miopen/batch_norm.hpp>

#define MIOPEN_BN_SYNCH 0

namespace miopen {

void DeriveBNTensorDescriptor(TensorDescriptor& derivedBnDesc,
Expand Down Expand Up @@ -73,7 +75,9 @@ inline void profileSequence(Handle& handle, unsigned char select)
}
else
{
#if(MIOPEN_BN_SYNCH)
handle.Finish();
#endif
}
break;
case 1:
Expand All @@ -89,7 +93,9 @@ inline void profileSequence(Handle& handle, unsigned char select)
}
else
{
#if(MIOPEN_BN_SYNCH)
handle.Finish();
#endif
}
break;

Expand Down Expand Up @@ -382,7 +388,7 @@ void bnBwdTrainSelectMulti(Handle& handle,

kernel_subname = kernel_name + "FinalDScale";
handle.GetKernel(algo_name, network_config, program_name, kernel_subname, vld, vgd, parms)(
dx, dScale, inhw);
dx, dScale);
profileSequence(handle, 1);

kernel_subname = kernel_name + "DX";
Expand Down Expand Up @@ -433,7 +439,7 @@ void bnBwdTrainSelectMulti(Handle& handle,

kernel_subname = kernel_name + "FinalDScale";
handle.GetKernel(algo_name, network_config, program_name, kernel_subname, vld, vgd, parms)(
dx, dScale, inhw);
dx, dScale);
profileSequence(handle, 1);

kernel_subname = kernel_name + "DX";
Expand Down
2 changes: 0 additions & 2 deletions src/kernels/MIOpenBatchNormBwdPerAct.cl
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ __kernel void BatchNormBwdPerActivationSaved(const __global _FLOAT* x_in,
dxhat += tmp1;
dxhathat = mad(tmp1, xhat, dxhathat);
} // end for(n)
pvt_dscale /= (_FLOAT)N;

for(int n = 0; n < N; n++)
{
Expand Down Expand Up @@ -246,7 +245,6 @@ __kernel void BatchNormBwdPerActivation(const __global _FLOAT* x_in,
dxhat += tmp1;
dxhathat = mad(tmp1, xhat, dxhathat);
} // end for(n)
pvt_dscale /= (_FLOAT)N;

for(int n = 0; n < N; n++)
{
Expand Down
62 changes: 26 additions & 36 deletions src/kernels/MIOpenBatchNormBwdSpatial.cl
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,11 @@
#undef __AMDGCN__
#endif

//#ifdef __AMDGCN__
//#undef __AMDGCN__
//#endif
/*
#ifdef __AMDGCN__
#undef __AMDGCN__
#endif
*/

// Disable specific warnings
#ifdef __clang__
Expand Down Expand Up @@ -409,15 +411,15 @@ BatchNormBwdSpatial(const __global _FLOAT* __restrict x_in,
lcl_data[lid] += lcl_data[lid + red];
barrier(CLK_LOCAL_MEM_FENCE);
}
dppLDSReduce64(&ds, lcl_data, lid, INHW);
dppLDSReduce64(&ds, lcl_data, lid, 1);
#else
for(unsigned int red = (MIO_BN_GRP0 >> 1); red > 256; red >>= 1)
{
if(lid < red)
lcl_data[lid] += lcl_data[lid + red];
barrier(CLK_LOCAL_MEM_FENCE);
}
regLDSreduce(&ds, lcl_data, lid, INHW);
regLDSreduce(&ds, lcl_data, lid, 1);
#endif

if(lid < MIO_BN_SEGMENT)
Expand Down Expand Up @@ -649,16 +651,14 @@ BatchNormBwdSpatial(const __global _FLOAT* __restrict x_in,
#ifdef __AMDGCN__

#if(MIO_BN_HW > 16)
dppRegReduce64(&ds, INHW);
dppRegReduce64(&ds, 1);
#elif(MIO_BN_HW > 1)
dppRegReduce16(&ds, INHW);
#else
ds *= INHW;
dppRegReduce16(&ds, 1);
#endif // HW
#else // if not GCN

#if(MIO_BN_HW > 16)
regLDSreduce(&ds, lcl_data, ylid, INHW);
regLDSreduce(&ds, lcl_data, ylid, 1);
#elif(MIO_BN_HW > 1)
lcl_data[ylid] = ds;
barrier(CLK_LOCAL_MEM_FENCE);
Expand All @@ -668,9 +668,6 @@ BatchNormBwdSpatial(const __global _FLOAT* __restrict x_in,
{
ds += lcl_data[i];
}
ds *= INHW;
#else
ds *= INHW;
#endif // HW
#endif // GCN
//===========================================
Expand All @@ -686,7 +683,7 @@ BatchNormBwdSpatial(const __global _FLOAT* __restrict x_in,
{
index = n * MIO_BN_CHW + cidx + ylid;
tmp1 = mad(NHW, dyvalues[n], -db);
tmp2 = -batchvalues[n] * ds;
tmp2 = -(batchvalues[n]) * ds;
tmp3 = (pscale * invVar) * INHW;
dx_out[index] = tmp3 * (tmp2 + tmp1);
}
Expand Down Expand Up @@ -901,7 +898,7 @@ BatchNormBwdSpatial(const __global _FLOAT* __restrict x_in,
#else // GCN

#if(MIO_BN_N > 16)
regLDSreduce(&db, lcl_data, ylid, INHW);
regLDSreduce(&db, lcl_data, ylid, 1);
#elif(MIO_BN_N > 1)
lcl_data[ylid] = db;
barrier(CLK_LOCAL_MEM_FENCE);
Expand All @@ -917,7 +914,7 @@ BatchNormBwdSpatial(const __global _FLOAT* __restrict x_in,
#ifdef __AMDGCN__

#if(MIO_BN_N > 16)
dppRegReduce64(&ds, INHW);
dppRegReduce64(&ds, 1);
#elif(MIO_BN_N > 1)
lcl_data[ylid] = ds;
barrier(CLK_LOCAL_MEM_FENCE);
Expand All @@ -927,14 +924,11 @@ BatchNormBwdSpatial(const __global _FLOAT* __restrict x_in,
{
ds += lcl_data[i];
}
ds *= INHW;
#else
ds *= INHW;
#endif // N
#else // if not GCN

#if(MIO_BN_N > 16)
regLDSreduce(&ds, lcl_data, ylid, INHW);
regLDSreduce(&ds, lcl_data, ylid, 1);
#elif(MIO_BN_N > 1)
lcl_data[ylid] = ds;
barrier(CLK_LOCAL_MEM_FENCE);
Expand All @@ -944,9 +938,6 @@ BatchNormBwdSpatial(const __global _FLOAT* __restrict x_in,
{
ds += lcl_data[i];
}
ds *= INHW;
#else
ds *= INHW;
#endif // N
#endif // GCN
//===========================================
Expand All @@ -962,7 +953,7 @@ BatchNormBwdSpatial(const __global _FLOAT* __restrict x_in,
{
index = ylid * MIO_BN_CHW + cidx + hw;
tmp1 = mad(NHW, dyvalues[hw], -db);
tmp2 = -batchvalues[hw] * ds;
tmp2 = -(batchvalues[hw]) * ds;
tmp3 = (pscale * invVar) * INHW;
dx_out[index] = tmp3 * (tmp2 + tmp1);
}
Expand Down Expand Up @@ -1231,7 +1222,7 @@ BatchNormBwdSpatial(const __global _FLOAT* __restrict x_in,
barrier(CLK_LOCAL_MEM_FENCE);
lcl_data[ylid] = ds;
barrier(CLK_LOCAL_MEM_FENCE);
dppLDSReduce64(&ds, lcl_data, ylid, INHW);
dppLDSReduce64(&ds, lcl_data, ylid, 1);

#else
for(unsigned int red = (MIO_BN_GRP1 >> 1); red > 256; red >>= 1)
Expand All @@ -1240,7 +1231,7 @@ BatchNormBwdSpatial(const __global _FLOAT* __restrict x_in,
lcl_data[ylid] += lcl_data[ylid + red];
barrier(CLK_LOCAL_MEM_FENCE);
}
regLDSreduce(&ds, lcl_data, ylid, INHW);
regLDSreduce(&ds, lcl_data, ylid, 1);
#endif
//===========================================

Expand All @@ -1256,7 +1247,7 @@ BatchNormBwdSpatial(const __global _FLOAT* __restrict x_in,
index = n * MIO_BN_CHW + cidx + ylid;
#if(MIO_BN_N < MIO_BN_MAXN)
tmp1 = mad(NHW, dyvalues[n], -db);
tmp2 = -batchvalues[n] * ds;
tmp2 = -(batchvalues[n]) * ds;
#else
tmp1 = mad(NHW, dy_in[index], -db);
tmp2 = -(x_in[index] - mean) * invVar * ds;
Expand Down Expand Up @@ -1761,7 +1752,7 @@ BatchNormBwdSpatialDScale(const __global _FLOAT* x_in,
unsigned int varstashindex = cidx + ygrp_sz * ygrp_id + 3;
lmean = buff[meanstashindex]; // load stashed mean
livar = buff[varstashindex];
#else // SAVED
#else // NO SAVED
lmean = savedMean[xgid];
livar = savedInvVariance[xgid];
#endif // SAVED
Expand All @@ -1781,6 +1772,7 @@ BatchNormBwdSpatialDScale(const __global _FLOAT* x_in,
elemStd = x_in[index] - mean; // (x_i - mean)
xhat = elemStd * invVar;
dscale = mad(xhat, dy_in[index], dscale);
// dscale += 1.;
} // end for
} // end if

Expand Down Expand Up @@ -1809,7 +1801,6 @@ BatchNormBwdSpatialDScale(const __global _FLOAT* x_in,
regLDSreduce(&dscale, lcl_data, ylid, 1);

#endif // GCN

if(ylid == 0)
{
unsigned int gammaindex = cidx + ygrp_sz * ygrp_id + 4;
Expand All @@ -1818,7 +1809,7 @@ BatchNormBwdSpatialDScale(const __global _FLOAT* x_in,
}

__attribute__((reqd_work_group_size(MIO_BN_GRP0, MIO_BN_GRP1, MIO_BN_GRP2))) __kernel void
BatchNormBwdSpatialFinalDScale(__global _FLOAT* buff, __global _FLOAT* delta_scale, _FLOAT INHW)
BatchNormBwdSpatialFinalDScale(__global _FLOAT* buff, __global _FLOAT* delta_scale)
{

__private _FLOAT ds = 0.;
Expand Down Expand Up @@ -1852,29 +1843,29 @@ BatchNormBwdSpatialFinalDScale(__global _FLOAT* buff, __global _FLOAT* delta_sca
lcl_data[ylid] += lcl_data[ylid + red];
barrier(CLK_LOCAL_MEM_FENCE);
}
dppLDSReduce64(&ds, lcl_data, ylid, INHW);
dppLDSReduce64(&ds, lcl_data, ylid, 1);
#else // GCN
for(unsigned int red = (MIO_BN_GRP1 >> 1); red > 256; red >>= 1)
{
if(ylid < red)
lcl_data[ylid] += lcl_data[ylid + red];
barrier(CLK_LOCAL_MEM_FENCE);
}
regLDSreduce(&ds, lcl_data, ylid, INHW);
regLDSreduce(&ds, lcl_data, ylid, 1);
#endif // GCN

#elif(MIO_BN_NGRPS <= 64)

#ifdef __AMDGCN__
dppRegReduce64(&ds, INHW);
dppRegReduce64(&ds, 1);
#else // GCN
__local _FLOAT lcl_data[MIO_BN_LDS_SIZE];
regLDSreduce(&ds, lcl_data, ylid, INHW);
regLDSreduce(&ds, lcl_data, ylid, 1);
#endif // GCN
#else // else < 16

#ifdef __AMDGCN__
dppRegReduce16(&ds, INHW);
dppRegReduce16(&ds, 1);
#else // GCN
__local _FLOAT lcl_data[MIO_BN_LDS_SIZE];
lcl_data[ylid] = ds;
Expand All @@ -1885,7 +1876,6 @@ BatchNormBwdSpatialFinalDScale(__global _FLOAT* buff, __global _FLOAT* delta_sca
{
ds += lcl_data[i];
}
ds *= INHW;
#endif // end AMDGCN
#endif // NGRPS

Expand Down
Loading

0 comments on commit e1566a4

Please sign in to comment.