From 3419b32ab8f539a28a911137db84a94668205fed Mon Sep 17 00:00:00 2001 From: Shurale-nkn Date: Wed, 30 Oct 2024 16:25:13 +0100 Subject: [PATCH 01/21] fwd dynamic lstm --- src/include/miopen/rnn.hpp | 6 + src/include/miopen/rnn/solvers.hpp | 319 ++++++++++++++++---- src/include/miopen/rnn/tmp_buffer_utils.hpp | 12 + src/ocl/rnnocl.cpp | 10 +- src/rnn.cpp | 28 +- src/rnn/Solutions/Base/fw_data_modular.cpp | 92 ++++++ src/rnn/Solutions/fwd_s_stream.cpp | 62 ++++ src/rnn/selector.cpp | 24 +- 8 files changed, 485 insertions(+), 68 deletions(-) diff --git a/src/include/miopen/rnn.hpp b/src/include/miopen/rnn.hpp index eba00e2a29..aa17890cee 100644 --- a/src/include/miopen/rnn.hpp +++ b/src/include/miopen/rnn.hpp @@ -155,6 +155,10 @@ struct MIOPEN_INTERNALS_EXPORT RNNDescriptor : miopenRNNDescriptor miopenRNNFWDMode_t fwdMode) const; size_t GetMaxReserveSize(Handle& handle, const SeqTensorDescriptor& xDesc) const; + std::tuple GetTmpSpaceSizeDynamicAlgo(Handle& handle, + const SeqTensorDescriptor& xDesc, + miopenRNNFWDMode_t fwdMode) const; + size_t GetParamsSize(Handle& handle, const TensorDescriptor& xDesc, miopenDataType_t dtype) const; size_t GetParamsSize(size_t inputVector) const; @@ -534,6 +538,8 @@ struct MIOPEN_INTERNALS_EXPORT RNNDescriptor : miopenRNNDescriptor Data_t hy, const TensorDescriptor& cyDesc, Data_t cy, + Data_t workSpace, + size_t workSpaceSize, Data_t reserveSpace, size_t reserveSpaceSize) const; diff --git a/src/include/miopen/rnn/solvers.hpp b/src/include/miopen/rnn/solvers.hpp index 429bcee752..3718e284dc 100644 --- a/src/include/miopen/rnn/solvers.hpp +++ b/src/include/miopen/rnn/solvers.hpp @@ -49,6 +49,28 @@ struct runtimeArgsFwd class RNNModuleAlgoBase { +protected: + static GeneralLstmTempBuffer backwardInterimInfoBuilder(const RNNDescriptor& rnnDesc, + const SeqTensorDescriptor& xDesc) + { + auto layers_cnt = static_cast(rnnDesc.nLayers); + const size_t seq_directions = rnnDesc.dirMode == miopenRNNbidirection ? 
2 : 1; + auto hidden_vec_sz = rnnDesc.hsize; + + return GeneralLstmTempBuffer::build( + layers_cnt, xDesc.GetTotalSequenceLen(), seq_directions, hidden_vec_sz); + } + + static GeneralLstmRedBuffer forwardInterimInfoBuilder(const RNNDescriptor& rnnDesc, + const SeqTensorDescriptor& xDesc) + { + auto layers_cnt = static_cast(rnnDesc.nLayers); + const size_t seq_directions = rnnDesc.dirMode == miopenRNNbidirection ? 2 : 1; + auto hidden_vec_sz = rnnDesc.hsize; + + return GeneralLstmRedBuffer::build( + layers_cnt, xDesc.GetTotalSequenceLen(), seq_directions, hidden_vec_sz); + } public: static RNNModuleAlgoBase create(const RNNDescriptor& rnnDesc, @@ -71,14 +93,11 @@ class RNNModuleAlgoBase // class update req assert(!is_seq_bidir); - const size_t seq_directions = is_seq_bidir ? 2 : 1; // TODO all size_t - GeneralLstmRedBuffer rb_layout = GeneralLstmRedBuffer::build( - layers_cnt, xDesc.GetTotalSequenceLen(), seq_directions, hidden_vec_sz); + GeneralLstmRedBuffer rb_layout = forwardInterimInfoBuilder(rnnDesc, xDesc); - GeneralLstmTempBuffer workspace_info = GeneralLstmTempBuffer::build( - layers_cnt, xDesc.GetTotalSequenceLen(), seq_directions, hidden_vec_sz); + GeneralLstmTempBuffer workspace_info = backwardInterimInfoBuilder(rnnDesc, xDesc); WeightsBufferDescriptor weights_layout = WeightsBufferDescriptor::create(static_cast(input_vec_sz), @@ -162,9 +181,88 @@ class RNNModuleAlgoBase return layer_id * (isBidirectSeq ? 2 : 1) + (direction == SequenceDirection::Forward ? 
0 : 1); } + + template + inline miopen::TensorDescriptor BuildLstmTmpBlockDesc2D(const BufType& buf_info, + const size_t batch_size) const + { + const std::array& tmp_block_stride = buf_info.getGateBlockStride(); + const std::array& tmp_block_size = buf_info.getGateBlockSize(); + + // batch, gateBlock_elements + return miopen::TensorDescriptor{rnnDesc.dataType, + {batch_size, tmp_block_size[3]}, + {tmp_block_stride[1], tmp_block_stride[3]}}; + } + + inline miopen::TensorDescriptor BuildLstmFilterXDesc2D(int layer_id) const + { + assert(rnnDesc.inputMode == 0 || layer_id != 0); + // TODO replace by stride + auto x_vec = layer_id != 0 ? weightsLayout.xInVec : weightsLayout.inVec; + + // gateBlock_elements, ht_vec + return miopen::TensorDescriptor{ + rnnDesc.dataType, {weightsLayout.gatesCnt * weightsLayout.hVec, x_vec}, {x_vec, 1}}; + } + + inline miopen::TensorDescriptor BuildLstmFilterHidDesc2D() const + { + // TODO replace by stride + auto h_vec = weightsLayout.hVec; + + // gateBlock_elements, ht_vec + return miopen::TensorDescriptor{ + rnnDesc.dataType, {weightsLayout.gatesCnt * weightsLayout.hVec, h_vec}, {h_vec, 1}}; + } + + inline miopen::TensorDescriptor BuildWsHtDesc2D(size_t batch_size) const + { + auto& ht_stride = workspaceInfo.getHiddenStateStride(); + auto& ht_size = workspaceInfo.hStateSizes; + + // batch, gateBlock_elements + return miopen::TensorDescriptor{ + rnnDesc.dataType, {batch_size, ht_size[3]}, {ht_stride[1], ht_stride[3]}}; + } + + // 2 dims batch, vec + inline miopen::TensorDescriptor BuildHxCxDesc2D(size_t batch_size) const + { + const std::vector hx_size{batch_size, hiddenHxCxInfo.getHiddenSize()}; + const std::vector hx_stride{hiddenHxCxInfo.getStrides()[1], + hiddenHxCxInfo.getStrides()[2]}; + + return miopen::TensorDescriptor{rnnDesc.dataType, hx_size, hx_stride}; + } + + // 3 dims layer, batch, vec + inline miopen::TensorDescriptor BuildHxCxDesc3D(size_t layer_size, size_t batch_size) const + { + const std::vector hx_accum_size{ + 
layer_size, batch_size, hiddenHxCxInfo.getHiddenSize()}; + + return miopen::TensorDescriptor{ + rnnDesc.dataType, hx_accum_size, hiddenHxCxInfo.getStrides()}; + } + + // 3 dims layer, batch, vec + inline miopen::TensorDescriptor BuildTempDhtDesc3D(size_t layer_size, size_t batch_size) const + { + const std::vector dy_dhy_accum_size{ + layer_size, batch_size, hiddenHxCxInfo.getHiddenSize()}; + + const auto ws_dy_stride = [](const auto& ws_4dim_strides) -> std::vector { + // convert 4dim stride to 3 dim without direction + // TODO change hiddenBufferDesc + return std::vector{ws_4dim_strides[0], ws_4dim_strides[1], ws_4dim_strides[3]}; + }(workspaceInfo.getHiddenStateStride()); + + return miopen::TensorDescriptor{rnnDesc.dataType, dy_dhy_accum_size, ws_dy_stride}; + } }; -class RNNForwardDataModularAlgo : RNNModuleAlgoBase +class RNNForwardDataModularAlgo : protected RNNModuleAlgoBase { public: // Compute API @@ -252,87 +350,150 @@ class RNNForwardDataModularAlgo : RNNModuleAlgoBase #endif // MIOPEN_USE_GEMM&& MIOPEN_BACKEND_HIP } - RNNForwardDataModularAlgo(RNNModuleAlgoBase base) : RNNModuleAlgoBase(std::move(base)) {} - -private: - template - inline miopen::TensorDescriptor BuildLstmTmpBlockDesc2D(const BufType& buf_info, - const size_t batch_size) const + std::tuple getTempBuffersSize() const { - const std::array& tmp_block_stride = buf_info.getGateBlockStride(); - const std::array& tmp_block_size = buf_info.getGateBlockSize(); - // batch, gateBlock_elements - return miopen::TensorDescriptor{rnnDesc.dataType, - {batch_size, tmp_block_size[3]}, - {tmp_block_stride[1], tmp_block_stride[3]}}; + return std::make_tuple(workspaceInfo.getBufferSize() * GetTypeSize(rnnDesc.dataType), + reservLayout.getBufferSize() * GetTypeSize(rnnDesc.dataType)); } - inline miopen::TensorDescriptor BuildLstmFilterXDesc2D(int layer_id) const + static std::tuple getTempBuffersSize(const RNNDescriptor& rnnD, + const SeqTensorDescriptor& xDesc) { - assert(rnnDesc.inputMode == 0 || layer_id 
!= 0); - // TODO replace by stride - auto x_vec = layer_id != 0 ? weightsLayout.xInVec : weightsLayout.inVec; + auto wsInfo = backwardInterimInfoBuilder(rnnD, xDesc); + auto reservInfo = forwardInterimInfoBuilder(rnnD, xDesc); - // gateBlock_elements, ht_vec - return miopen::TensorDescriptor{ - rnnDesc.dataType, {weightsLayout.gatesCnt * weightsLayout.hVec, x_vec}, {x_vec, 1}}; + return std::make_tuple(wsInfo.getBufferSize() * GetTypeSize(rnnD.dataType), + reservInfo.getBufferSize() * GetTypeSize(rnnD.dataType)); } - inline miopen::TensorDescriptor BuildLstmFilterHidDesc2D() const - { - // TODO replace by stride - auto h_vec = weightsLayout.hVec; + RNNForwardDataModularAlgo(RNNModuleAlgoBase base) : RNNModuleAlgoBase(std::move(base)) {} - // gateBlock_elements, ht_vec - return miopen::TensorDescriptor{ - rnnDesc.dataType, {weightsLayout.gatesCnt * weightsLayout.hVec, h_vec}, {h_vec, 1}}; +private: +}; + +class RNNModuleAlgoDynamic : public RNNForwardDataModularAlgo +{ + static SeqTensorDescriptor buildDynamicVirtual(const SeqTensorDescriptor& desc) + { + std::vector def_layout{1, 0, 2}; + return {desc.GetType(), def_layout, desc.GetLengths(), false}; } - inline miopen::TensorDescriptor BuildWsHtDesc2D(size_t batch_size) const + static SeqTensorDescriptor buildRealToDynamicMapTmp(const SeqTensorDescriptor& desc) { - auto& ht_stride = workspaceInfo.getHiddenStateStride(); - auto& ht_size = workspaceInfo.hStateSizes; + std::vector def_layout{1, 0, 2}; + return {desc.GetType(), + def_layout, + desc.GetLengths(), + desc.GetSequenceLengthsVector(), + std::vector{}, + true, + true}; + } - // batch, gateBlock_elements - return miopen::TensorDescriptor{ - rnnDesc.dataType, {batch_size, ht_size[3]}, {ht_stride[1], ht_stride[3]}}; +public: + RNNModuleAlgoDynamic(const RNNDescriptor& rnnD, + const SeqTensorDescriptor& xTDesc, + const SeqTensorDescriptor& yTDesc, + const TensorDescriptor& hDesc, + miopenRNNFWDMode_t mode) + : RNNForwardDataModularAlgo(RNNModuleAlgoBase::create( 
+ rnnD, buildDynamicVirtual(xTDesc), buildDynamicVirtual(yTDesc), hDesc, mode)), + realBatchController(BatchController::Create(xTDesc)), + realXDesc(xTDesc), + realYDesc(yTDesc), + tmpMapXDesc(buildRealToDynamicMapTmp(xTDesc)), + tmpMapYDesc(buildRealToDynamicMapTmp(yTDesc)) + { } - // 2 dims batch, vec - inline miopen::TensorDescriptor BuildHxCxDesc2D(size_t batch_size) const + struct runtimeArgsFwdDynamicExt { - const std::vector hx_size{batch_size, hiddenHxCxInfo.getHiddenSize()}; - const std::vector hx_stride{hiddenHxCxInfo.getStrides()[1], - hiddenHxCxInfo.getStrides()[2]}; + const ConstData_t realX; + const Data_t tempX; + const ConstData_t hx; + const ConstData_t cx; + const Data_t realY; + const Data_t tempY; + const Data_t hy; + const Data_t cy; + const ConstData_t w; + const Data_t workSpace; + const Data_t reserveSpace; + }; - return miopen::TensorDescriptor{rnnDesc.dataType, hx_size, hx_stride}; + runtimeArgsFwdDynamicExt createRuntimeArgsExt(const runtimeArgsFwd& runtimeArgs) const + { + const Data_t temp_x = + moveDataPtr(runtimeArgs.workSpace, workspaceInfo.getBufferSizeImpl(), rnnDesc.dataType); + + const Data_t temp_y = moveDataPtrByte(temp_x, tmpMapXDesc.GetTensorMaxByteSpace()); + + return { + runtimeArgs.x, + temp_x, + runtimeArgs.hx, + runtimeArgs.cx, + runtimeArgs.y, + temp_y, + runtimeArgs.hy, + runtimeArgs.cy, + runtimeArgs.w, + runtimeArgs.workSpace, + runtimeArgs.reserveSpace, + }; } - // 3 dims layer, batch, vec - inline miopen::TensorDescriptor BuildHxCxDesc3D(size_t layer_size, size_t batch_size) const + auto getTempBuffersSize() const { - const std::vector hx_accum_size{ - layer_size, batch_size, hiddenHxCxInfo.getHiddenSize()}; + auto [ws_size, reserve_size] = RNNForwardDataModularAlgo::getTempBuffersSize(); - return miopen::TensorDescriptor{ - rnnDesc.dataType, hx_accum_size, hiddenHxCxInfo.getStrides()}; + return std::make_tuple(ws_size + tmpMapXDesc.GetTensorMaxByteSpace() + + tmpMapYDesc.GetTensorMaxByteSpace(), + reserve_size); } 
- // 3 dims layer, batch, vec - inline miopen::TensorDescriptor BuildTempDhtDesc3D(size_t layer_size, size_t batch_size) const + static auto getTempBuffersSize(const RNNDescriptor& rnnD, const SeqTensorDescriptor& xDesc) { - const std::vector dy_dhy_accum_size{ - layer_size, batch_size, hiddenHxCxInfo.getHiddenSize()}; + auto y_desc = [](const RNNDescriptor& rnnD, const SeqTensorDescriptor& xDesc) { + std::vector y_lenghts{xDesc.GetLengths()}; + y_lenghts[2] = rnnD.hsize * (rnnD.dirMode == miopenRNNbidirection ? 2 : 1); + return SeqTensorDescriptor{xDesc.GetType(), y_lenghts}; + }(rnnD, xDesc); - const auto ws_dy_stride = [](const auto& ws_4dim_strides) -> std::vector { - // convert 4dim stride to 3 dim without direction - // TODO change hiddenBufferDesc - return std::vector{ws_4dim_strides[0], ws_4dim_strides[1], ws_4dim_strides[3]}; - }(workspaceInfo.getHiddenStateStride()); + auto temp_x_desc = buildDynamicVirtual(xDesc); + auto temp_y_desc = buildDynamicVirtual(y_desc); - return miopen::TensorDescriptor{rnnDesc.dataType, dy_dhy_accum_size, ws_dy_stride}; + auto [ws_size, reserve_size] = + RNNForwardDataModularAlgo::getTempBuffersSize(rnnD, temp_x_desc); + + return std::make_tuple(ws_size + temp_x_desc.GetTensorMaxByteSpace() + + temp_y_desc.GetTensorMaxByteSpace(), + reserve_size); } + + void realXProp(const Handle& handle, const runtimeArgsFwdDynamicExt& runtimeArgsExt) const; + + void realYProp(const Handle& handle, const runtimeArgsFwdDynamicExt& runtimeArgsExt) const; + + void PrepareWriteBuffers(const Handle& handle, + const runtimeArgsFwdDynamicExt& runtimeArgsExt, + const runtimeArgsFwd& runtimeArgs) const; + + void PropHyCy(const Handle& handle, + const runtimeArgsFwd& runtimeArgs, + size_t layer, + const SequenceIterator& currentSeq, + SequenceDirection direction) const; + +private: + BatchController realBatchController; + + SeqTensorDescriptor realXDesc; + SeqTensorDescriptor realYDesc; + SeqTensorDescriptor tmpMapXDesc; + SeqTensorDescriptor 
tmpMapYDesc; }; class RNNBackwardDataModularAlgo : RNNModuleAlgoBase @@ -554,6 +715,44 @@ class RNNModularSingleStreamFWD const size_t max_seq_len; }; +class RNNDynamicModularSingleStreamFWD +{ +private: +public: + RNNDynamicModularSingleStreamFWD(const RNNDescriptor& rnn, + const SeqTensorDescriptor& xDesc, + const SeqTensorDescriptor& yDesc, + const TensorDescriptor& hDesc, + miopenRNNFWDMode_t mode) + : rnnAlgoModules(rnn, xDesc, yDesc, hDesc, mode), + rnnDesc(rnn), + max_seq_len(xDesc.GetMaxSequenceLength()) + { + } + + static bool IsApplicable() + { +#if MIOPEN_USE_GEMM && MIOPEN_BACKEND_HIP + return true; +#else + return false; +#endif // MIOPEN_USE_GEMM&& MIOPEN_BACKEND_HIP + } + + auto getTempBuffersSize() const { return rnnAlgoModules.getTempBuffersSize(); } + + static auto getTempBuffersSize(const RNNDescriptor& rnn, const SeqTensorDescriptor& xDesc) + { + return rnn_base::RNNModuleAlgoDynamic::getTempBuffersSize(rnn, xDesc); + } + + void ComputeFWD(Handle& handle, const runtimeArgsFwd& runtimeArgs) const; + + const rnn_base::RNNModuleAlgoDynamic rnnAlgoModules; + const RNNDescriptor& rnnDesc; + const size_t max_seq_len; +}; + class RNNModularSingleStreamBWD { public: diff --git a/src/include/miopen/rnn/tmp_buffer_utils.hpp b/src/include/miopen/rnn/tmp_buffer_utils.hpp index 39806aa73d..9a499b9215 100644 --- a/src/include/miopen/rnn/tmp_buffer_utils.hpp +++ b/src/include/miopen/rnn/tmp_buffer_utils.hpp @@ -64,6 +64,16 @@ OutputIt exclusive_scan_wa(InputIt first, InputIt last, OutputIt d_first, T init } // namespace WA_RHEL +inline Data_t moveDataPtrByte(Data_t ptr, size_t byteOffset) +{ + return static_cast(reinterpret_cast(ptr) + byteOffset); +} + +inline Data_t moveDataPtr(Data_t ptr, size_t elementOffset, miopenDataType_t elementType) +{ + return moveDataPtrByte(ptr, elementOffset * GetTypeSize(elementType)); +} + namespace rnn_base { enum class SequenceDirection @@ -914,6 +924,8 @@ class IOBufferDescriptor return {lengths[1] * lengths[2], lengths[2], 
1}; } + inline size_t getBufferSizeImpl() const { return packedLens[0]; } + // private: // local caching diff --git a/src/ocl/rnnocl.cpp b/src/ocl/rnnocl.cpp index 2700034f93..cee195819b 100644 --- a/src/ocl/rnnocl.cpp +++ b/src/ocl/rnnocl.cpp @@ -2617,6 +2617,8 @@ void RNNDescriptor::RNNForwardTraining(Handle& handle, hy, cyDesc, cy, + workSpace, + workSpaceSize, reserveSpace, reserveSpaceSize); } @@ -2673,6 +2675,8 @@ void RNNDescriptor::RNNForwardTraining(Handle& handle, hy, cyDesc, cy, + workSpace, + workSpaceSize, reserveSpace, reserveSpaceSize); @@ -2708,6 +2712,8 @@ void RNNDescriptor::RNNForwardTrainingPackedTensors( Data_t hy, const TensorDescriptor& cyDesc, Data_t cy, + Data_t workSpace, + size_t workSpaceSize, Data_t reserveSpace, size_t reserveSpaceSize) const { @@ -2824,8 +2830,8 @@ void RNNDescriptor::RNNForwardTrainingPackedTensors( cy, y_seq, y, - nullptr, - 0, + workSpace, + workSpaceSize, reserveSpace, reserveSpaceSize); } diff --git a/src/rnn.cpp b/src/rnn.cpp index 5d23c79018..a37a9ff065 100644 --- a/src/rnn.cpp +++ b/src/rnn.cpp @@ -33,12 +33,16 @@ #include #include +#include "miopen/env.hpp" + // Disable specific warnings #define MIO_RNN_DEBUG 0 #define MIOPEN_RNN_SYNCH 0 #define MIO_RNN_CPP_PROF 0 +MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_RNN_DYNAMIC_EXP) + namespace miopen { void profileRNNkernels(const Handle& handle, unsigned char select, float& ctime) @@ -483,6 +487,15 @@ size_t RNNDescriptor::GetWorkspaceSize(Handle& handle, dirMode == miopenRNNbidirection, dataType); + if(env::enabled(MIOPEN_RNN_DYNAMIC_EXP)) + { + auto [ws, rs] = GetTmpSpaceSizeDynamicAlgo(handle, xDesc, miopenRNNTraining); + return std::max(ws, + GetMainSolWorkspaceSize( + total_sequence_len, fwdMode, miopenRNNDataSeqMajorNotPadded)) + + transformer_tmp_space + reduction_ws; + } + return transformer_tmp_space + reduction_ws + GetMainSolWorkspaceSize(total_sequence_len, fwdMode, miopenRNNDataSeqMajorNotPadded); } @@ -564,18 +577,25 @@ size_t 
RNNDescriptor::GetReserveSize(size_t batchLenSum) const // with tensor with maximum sequence length and maximum count of non empty sequences. // The previous version of this function returned a size sufficient only for the current tensor // size. -size_t RNNDescriptor::GetMaxReserveSize(Handle& /* handle */, - const SeqTensorDescriptor& xDesc) const +size_t RNNDescriptor::GetMaxReserveSize(Handle& handle, const SeqTensorDescriptor& xDesc) const { if(xDesc.GetType() != dataType) { MIOPEN_THROW(miopenStatusBadParm, "Data type mismatch between descriptors"); } + + if(env::enabled(MIOPEN_RNN_DYNAMIC_EXP)) + { + auto [ws, rs] = GetTmpSpaceSizeDynamicAlgo(handle, xDesc, miopenRNNTraining); + return std::max( + rs, GetReserveSize(xDesc.GetMaxSequenceLength() * xDesc.GetMaxCountOfSequences())); + } + return GetReserveSize(xDesc.GetMaxSequenceLength() * xDesc.GetMaxCountOfSequences()); } // Legacy. -size_t RNNDescriptor::GetReserveSize(Handle& /* handle */, +size_t RNNDescriptor::GetReserveSize(Handle& handle, const int seqLength, c_array_view xDesc) const { @@ -1211,6 +1231,8 @@ void RNNDescriptor::RNNVanillaForward(Handle& handle, hy, cDesc, cy, + workSpace, + workSpaceSize, reserveSpace, reserveSpaceSize); } diff --git a/src/rnn/Solutions/Base/fw_data_modular.cpp b/src/rnn/Solutions/Base/fw_data_modular.cpp index ca6d18d294..458092c999 100644 --- a/src/rnn/Solutions/Base/fw_data_modular.cpp +++ b/src/rnn/Solutions/Base/fw_data_modular.cpp @@ -527,6 +527,98 @@ void RNNForwardDataModularAlgo::PropY(const Handle& handle, const runtimeArgsFwd } } +void RNNModuleAlgoDynamic::realXProp(const Handle& handle, + const runtimeArgsFwdDynamicExt& runtimeArgsExt) const +{ + + RNNTensorBaseLayoutConverter::ConvertInputTensorGPUData( + handle, realXDesc, runtimeArgsExt.realX, tmpMapXDesc, runtimeArgsExt.tempX, nullptr, false); +} + +void RNNModuleAlgoDynamic::realYProp(const Handle& handle, + const runtimeArgsFwdDynamicExt& runtimeArgsExt) const +{ + 
RNNTensorBaseLayoutConverter::ConvertInputTensorGPUData( + handle, tmpMapYDesc, runtimeArgsExt.tempY, realYDesc, runtimeArgsExt.realY, nullptr, false); +} + +void RNNModuleAlgoDynamic::PrepareWriteBuffers(const Handle& handle, + const runtimeArgsFwdDynamicExt& runtimeArgsExt, + const runtimeArgsFwd& runtimeArgs) const +{ + RNNForwardDataModularAlgo::PrepareWriteBuffers(handle, runtimeArgs); + realXProp(handle, runtimeArgsExt); +} + +void RNNModuleAlgoDynamic::PropHyCy(const Handle& handle, + const runtimeArgsFwd& runtimeArgs, + size_t layer, + const SequenceIterator& currentSeq, + SequenceDirection direction) const +{ + if(runtimeArgs.hy != nullptr || (runtimeArgs.cy != nullptr)) + { + const auto gap_batch_size = [&]() { + if(currentSeq.isLast()) + { + return realBatchController.getBatchSize(currentSeq.getPhisVal()); + } + else + { + if(direction == SequenceDirection::Forward) + { + return realBatchController.getBatchSize(currentSeq.getPhisVal()) - + realBatchController.getBatchSize(currentSeq.getNext().getPhisVal()); + } + else + return static_cast(0); + } + }(); + + const auto gap_batch_offset = [&]() { + if(currentSeq.isLast()) + return static_cast(0); + else + return realBatchController.getBatchSize(currentSeq.getPhisVal()) - gap_batch_size; + }(); + + if(gap_batch_size > 0) + { + + auto src_desc = BuildTempDhtDesc3D(1, gap_batch_size); + + auto dst_desc = BuildHxCxDesc3D(1, gap_batch_size); + + size_t tmp_batch_offset = + batchController.getBatchSum(currentSeq.getPhisVal()) + gap_batch_offset; + + if(runtimeArgs.hy != nullptr) + { + CopyTensor(handle, + src_desc, + runtimeArgs.reserveSpace, + dst_desc, + runtimeArgs.hy, + reservLayout.getGasOffset( + layer, tmp_batch_offset, direction, LstmGateAndState::Ht), + hiddenHxCxInfo.getOffset(layer, gap_batch_offset)); + } + + if(runtimeArgs.cy != nullptr) + { + CopyTensor(handle, + src_desc, + runtimeArgs.reserveSpace, + dst_desc, + runtimeArgs.cy, + reservLayout.getGasOffset( + layer, tmp_batch_offset, direction, 
LstmGateAndState::St), + hiddenHxCxInfo.getOffset(layer, gap_batch_offset)); + } + } + } +} + // // // diff --git a/src/rnn/Solutions/fwd_s_stream.cpp b/src/rnn/Solutions/fwd_s_stream.cpp index a9c1ed9751..74a2c0f632 100644 --- a/src/rnn/Solutions/fwd_s_stream.cpp +++ b/src/rnn/Solutions/fwd_s_stream.cpp @@ -77,5 +77,67 @@ void RNNModularSingleStreamFWD::ComputeFWD(Handle& handle, const runtimeArgsFwd& rnnAlgoModules.PropY(handle, runtimeArgs); } +void RNNDynamicModularSingleStreamFWD::ComputeFWD(Handle& handle, + const runtimeArgsFwd& realRuntimeArgs) const +{ + + if(rnnDesc.nLayers == 0 || max_seq_len == 0) + return; + + auto sequence_directions = + rnnDesc.dirMode == miopenRNNDirectionMode_t::miopenRNNbidirection ? 2 : 1; + + const auto runtimeArgsExt = rnnAlgoModules.createRuntimeArgsExt(realRuntimeArgs); + const auto runtimeArgs = runtimeArgsFwd{runtimeArgsExt.tempX, + runtimeArgsExt.hx, + runtimeArgsExt.cx, + runtimeArgsExt.tempY, + runtimeArgsExt.hy, + runtimeArgsExt.cy, + runtimeArgsExt.w, + runtimeArgsExt.workSpace, + runtimeArgsExt.reserveSpace}; + + rnnAlgoModules.PrepareWriteBuffers(handle, runtimeArgsExt, runtimeArgs); + + // skip or linear + // copy or gemm + rnnAlgoModules.PropX(handle, runtimeArgs); + + rnnAlgoModules.AddBias(handle, runtimeArgs); + + for(auto layer_i = 0; layer_i < rnnDesc.nLayers; ++layer_i) + { + + for(int dir = 0; dir < sequence_directions; dir++) + { + const auto seq_dir = dir == 0 ? 
rnn_base::SequenceDirection::Forward + : rnn_base::SequenceDirection::Reverse; + + if(layer_i != 0) + rnnAlgoModules.PropHiddenY(handle, runtimeArgs, layer_i, seq_dir); + + for(int ti = 0; ti < max_seq_len; ti++) + { + const rnn_base::SequenceIterator cur_seq(ti, seq_dir, max_seq_len, true); + + if(ti == 0) + rnnAlgoModules.PropHxCx(handle, runtimeArgs, layer_i, cur_seq, seq_dir); + else + rnnAlgoModules.PropHiddenHt(handle, runtimeArgs, layer_i, cur_seq, seq_dir); + + rnnAlgoModules.UpdateHStatePerTimeSeq( + handle, runtimeArgs, layer_i, cur_seq, seq_dir); + + rnnAlgoModules.PropHyCy(handle, runtimeArgs, layer_i, cur_seq, seq_dir); + } + } + } + + rnnAlgoModules.PropY(handle, runtimeArgs); + + rnnAlgoModules.realYProp(handle, runtimeArgsExt); +} + } // namespace rnn_base } // namespace miopen diff --git a/src/rnn/selector.cpp b/src/rnn/selector.cpp index e7966eaffa..c0c448d6de 100644 --- a/src/rnn/selector.cpp +++ b/src/rnn/selector.cpp @@ -43,6 +43,8 @@ MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_RNNBWDMS_EXP) MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_RNNBWMS_EXP) +MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_RNN_DYNAMIC_EXP) + namespace miopen { bool RNNBwdMSIsFast(const int seqLen) @@ -65,6 +67,12 @@ bool RNNBwWeightMSIsFast(const int seqLen) return false; } +std::tuple RNNDescriptor::GetTmpSpaceSizeDynamicAlgo( + Handle& /*handle*/, const SeqTensorDescriptor& xDesc, miopenRNNFWDMode_t fwdMode) const +{ + return rnn_base::RNNDynamicModularSingleStreamFWD::getTempBuffersSize(*this, xDesc); +} + void RNNDescriptor::ModularForward(Handle& handle, miopenRNNFWDMode_t fwdMode, ConstData_t w, @@ -83,9 +91,19 @@ void RNNDescriptor::ModularForward(Handle& handle, Data_t reserveSpace, size_t /*reserveSpaceSize*/) const { - rnn_base::RNNModularSingleStreamFWD single_stream{*this, xDesc, yDesc, hDesc, fwdMode}; - single_stream.ComputeFWD( - handle, rnn_base::runtimeArgsFwd{x, hx, cx, y, hy, cy, w, workSpace, reserveSpace}); + if(env::enabled(MIOPEN_RNN_DYNAMIC_EXP)) + { + 
rnn_base::RNNDynamicModularSingleStreamFWD single_stream{ + *this, xDesc, yDesc, hDesc, fwdMode}; + single_stream.ComputeFWD( + handle, rnn_base::runtimeArgsFwd{x, hx, cx, y, hy, cy, w, workSpace, reserveSpace}); + } + else + { + rnn_base::RNNModularSingleStreamFWD single_stream{*this, xDesc, yDesc, hDesc, fwdMode}; + single_stream.ComputeFWD( + handle, rnn_base::runtimeArgsFwd{x, hx, cx, y, hy, cy, w, workSpace, reserveSpace}); + } } void RNNDescriptor::ModularBackward(Handle& handle, From 467b85d50c31a961edd2284d41a73feed0d33f96 Mon Sep 17 00:00:00 2001 From: Shurale-nkn Date: Wed, 30 Oct 2024 17:24:42 +0100 Subject: [PATCH 02/21] tidy --- src/rnn/selector.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rnn/selector.cpp b/src/rnn/selector.cpp index c0c448d6de..d7f60c7ea4 100644 --- a/src/rnn/selector.cpp +++ b/src/rnn/selector.cpp @@ -68,7 +68,7 @@ bool RNNBwWeightMSIsFast(const int seqLen) } std::tuple RNNDescriptor::GetTmpSpaceSizeDynamicAlgo( - Handle& /*handle*/, const SeqTensorDescriptor& xDesc, miopenRNNFWDMode_t fwdMode) const + Handle& /*handle*/, const SeqTensorDescriptor& xDesc, miopenRNNFWDMode_t /*fwdMode*/) const { return rnn_base::RNNDynamicModularSingleStreamFWD::getTempBuffersSize(*this, xDesc); } From fdccaf72518e5c83aa164d8e114c3e0d07375e63 Mon Sep 17 00:00:00 2001 From: Shurale-nkn Date: Thu, 31 Oct 2024 21:13:14 +0100 Subject: [PATCH 03/21] test update --- src/include/miopen/rnn.hpp | 3 ++ src/rnn.cpp | 89 ++++++++++++++++++++++++-------------- src/rnn/selector.cpp | 16 ++++++- test/rnn_seq_api.cpp | 3 ++ test/rnn_seq_api.hpp | 55 ++++++++++++++++++++--- 5 files changed, 125 insertions(+), 41 deletions(-) diff --git a/src/include/miopen/rnn.hpp b/src/include/miopen/rnn.hpp index aa17890cee..d7a7a94700 100644 --- a/src/include/miopen/rnn.hpp +++ b/src/include/miopen/rnn.hpp @@ -158,6 +158,9 @@ struct MIOPEN_INTERNALS_EXPORT RNNDescriptor : miopenRNNDescriptor std::tuple GetTmpSpaceSizeDynamicAlgo(Handle& handle, const 
SeqTensorDescriptor& xDesc, miopenRNNFWDMode_t fwdMode) const; + bool CheckDynamicAlgoSelection(Handle& handle, + const SeqTensorDescriptor& xDesc, + miopenRNNFWDMode_t fwdMode) const; size_t GetParamsSize(Handle& handle, const TensorDescriptor& xDesc, miopenDataType_t dtype) const; diff --git a/src/rnn.cpp b/src/rnn.cpp index a37a9ff065..1d3f95eeac 100644 --- a/src/rnn.cpp +++ b/src/rnn.cpp @@ -41,8 +41,6 @@ #define MIOPEN_RNN_SYNCH 0 #define MIO_RNN_CPP_PROF 0 -MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_RNN_DYNAMIC_EXP) - namespace miopen { void profileRNNkernels(const Handle& handle, unsigned char select, float& ctime) @@ -487,17 +485,20 @@ size_t RNNDescriptor::GetWorkspaceSize(Handle& handle, dirMode == miopenRNNbidirection, dataType); - if(env::enabled(MIOPEN_RNN_DYNAMIC_EXP)) + size_t solution_ws = 0; + + if(CheckDynamicAlgoSelection(handle, xDesc, fwdMode)) { auto [ws, rs] = GetTmpSpaceSizeDynamicAlgo(handle, xDesc, miopenRNNTraining); - return std::max(ws, - GetMainSolWorkspaceSize( - total_sequence_len, fwdMode, miopenRNNDataSeqMajorNotPadded)) + - transformer_tmp_space + reduction_ws; + solution_ws = ws; + } + else + { + solution_ws = + GetMainSolWorkspaceSize(total_sequence_len, fwdMode, miopenRNNDataSeqMajorNotPadded); } - return transformer_tmp_space + reduction_ws + - GetMainSolWorkspaceSize(total_sequence_len, fwdMode, miopenRNNDataSeqMajorNotPadded); + return transformer_tmp_space + reduction_ws + solution_ws; } size_t RNNDescriptor::GetMaxWorkspaceSize(Handle& handle, @@ -535,23 +536,35 @@ size_t RNNDescriptor::GetWorkspaceSize(Handle& handle, padding_converter_tmp_space = packedXInSpace + packedYOutSpace; } - std::size_t total_sequence_len = 0; - total_sequence_len = std::accumulate( - xDesc.data, xDesc.data + seqLength, 0ULL, [](size_t x, miopenTensorDescriptor_t y) { - return x + deref(y).GetLengths()[0]; - }); + SeqTensorDescriptor xSeqTDesc = + makeSeqTensorDescriptor(xDesc, seqLength, miopenRNNDataSeqMajorNotPadded); - size_t reduction_ws = 
ReductionWorkspaceSize(handle, - total_sequence_len, - nHiddenTensorsPerLayer, - workspaceScale, - hsize, - dirMode == miopenRNNbidirection, - dataType); + if(CheckDynamicAlgoSelection(handle, xSeqTDesc, miopenRNNTraining)) + { + auto [ws, rs] = GetTmpSpaceSizeDynamicAlgo(handle, xSeqTDesc, miopenRNNTraining); + return ws + padding_converter_tmp_space; + } + else + { - return padding_converter_tmp_space + reduction_ws + - GetMainSolWorkspaceSize( - total_sequence_len, miopenRNNInference, miopenRNNDataSeqMajorNotPadded); + std::size_t total_sequence_len = 0; + total_sequence_len = std::accumulate( + xDesc.data, xDesc.data + seqLength, 0ULL, [](size_t x, miopenTensorDescriptor_t y) { + return x + deref(y).GetLengths()[0]; + }); + + size_t reduction_ws = ReductionWorkspaceSize(handle, + total_sequence_len, + nHiddenTensorsPerLayer, + workspaceScale, + hsize, + dirMode == miopenRNNbidirection, + dataType); + + return padding_converter_tmp_space + reduction_ws + + GetMainSolWorkspaceSize( + total_sequence_len, miopenRNNInference, miopenRNNDataSeqMajorNotPadded); + } } ///////////////////////////////// @@ -584,11 +597,10 @@ size_t RNNDescriptor::GetMaxReserveSize(Handle& handle, const SeqTensorDescripto MIOPEN_THROW(miopenStatusBadParm, "Data type mismatch between descriptors"); } - if(env::enabled(MIOPEN_RNN_DYNAMIC_EXP)) + if(CheckDynamicAlgoSelection(handle, xDesc, miopenRNNTraining)) { auto [ws, rs] = GetTmpSpaceSizeDynamicAlgo(handle, xDesc, miopenRNNTraining); - return std::max( - rs, GetReserveSize(xDesc.GetMaxSequenceLength() * xDesc.GetMaxCountOfSequences())); + return rs; } return GetReserveSize(xDesc.GetMaxSequenceLength() * xDesc.GetMaxCountOfSequences()); @@ -604,12 +616,23 @@ size_t RNNDescriptor::GetReserveSize(Handle& handle, { MIOPEN_THROW(miopenStatusBadParm, "Data type mismatch between descriptors"); } - std::size_t inputBatchLenSum = 0; - inputBatchLenSum = std::accumulate( - xDesc.data, xDesc.data + seqLength, 0ULL, [](size_t x, 
miopenTensorDescriptor_t y) { - return x + deref(y).GetLengths()[0]; - }); - return GetReserveSize(inputBatchLenSum); + SeqTensorDescriptor xSeqTDesc = + makeSeqTensorDescriptor(xDesc, seqLength, miopenRNNDataSeqMajorNotPadded); + + if(CheckDynamicAlgoSelection(handle, xSeqTDesc, miopenRNNTraining)) + { + auto [ws, rs] = GetTmpSpaceSizeDynamicAlgo(handle, xSeqTDesc, miopenRNNTraining); + return rs; + } + else + { + std::size_t inputBatchLenSum = 0; + inputBatchLenSum = std::accumulate( + xDesc.data, xDesc.data + seqLength, 0ULL, [](size_t x, miopenTensorDescriptor_t y) { + return x + deref(y).GetLengths()[0]; + }); + return GetReserveSize(inputBatchLenSum); + } } size_t RNNDescriptor::GetParamsSize(size_t inputVector) const diff --git a/src/rnn/selector.cpp b/src/rnn/selector.cpp index d7f60c7ea4..1cc1909004 100644 --- a/src/rnn/selector.cpp +++ b/src/rnn/selector.cpp @@ -73,6 +73,20 @@ std::tuple RNNDescriptor::GetTmpSpaceSizeDynamicAlgo( return rnn_base::RNNDynamicModularSingleStreamFWD::getTempBuffersSize(*this, xDesc); } +bool RNNDescriptor::CheckDynamicAlgoSelection(Handle& /*handle*/, + const SeqTensorDescriptor& xDesc, + miopenRNNFWDMode_t /*fwdMode*/) const +{ + bool use_dropout = !float_equal(miopen::deref(dropoutDesc).dropout, 0); + bool rnn_config_match = (dirMode == 0 && inputMode == miopenRNNlinear && + rnnMode == miopenLSTM && !use_dropout && algoMode == miopenRNNdefault); + if(rnn_config_match && env::enabled(MIOPEN_RNN_DYNAMIC_EXP)) + { + return true; + } + return false; +} + void RNNDescriptor::ModularForward(Handle& handle, miopenRNNFWDMode_t fwdMode, ConstData_t w, @@ -91,7 +105,7 @@ void RNNDescriptor::ModularForward(Handle& handle, Data_t reserveSpace, size_t /*reserveSpaceSize*/) const { - if(env::enabled(MIOPEN_RNN_DYNAMIC_EXP)) + if(CheckDynamicAlgoSelection(handle, xDesc, fwdMode)) { rnn_base::RNNDynamicModularSingleStreamFWD single_stream{ *this, xDesc, yDesc, hDesc, fwdMode}; diff --git a/test/rnn_seq_api.cpp b/test/rnn_seq_api.cpp index 
7b69ac92aa..7fec633f48 100644 --- a/test/rnn_seq_api.cpp +++ b/test/rnn_seq_api.cpp @@ -63,6 +63,9 @@ struct rnn_seq_driver : rnn_seq_api_test_driver this->add(this->nocy, "nocy", this->generate_data({false, true})); this->add(this->pytorchTensorDescriptorFormat, "pyDescFormat", this->generate_data(modes)); + this->add(this->skip_backward_data, "disable-backward-data", this->generate_data({false})); + this->add( + this->skip_backward_weights, "disable-backward-weights", this->generate_data({false})); } rnn_seq_driver(bool) : rnn_seq_api_test_driver() {} diff --git a/test/rnn_seq_api.hpp b/test/rnn_seq_api.hpp index 58cf9170d6..177fa4d6ff 100644 --- a/test/rnn_seq_api.hpp +++ b/test/rnn_seq_api.hpp @@ -1002,6 +1002,9 @@ struct verify_train_rnn : verify_rnn_api_base bool nodhy{}; bool nodcy{}; + bool skip_backward_data{}; + bool skip_backward_weights{}; + using verify_rnn_api_base::is_padded_verification; using verify_rnn_api_base::padding_symbol; @@ -1044,12 +1047,16 @@ struct verify_train_rnn : verify_rnn_api_base tensor& dhy, tensor& dcy, std::vector& w, - const bool pnohx = false, - const bool pnocx = false, - const bool pnohy = false, - const bool pnocy = false, - T* paddingSymbol = nullptr) + const bool pnohx = false, + const bool pnocx = false, + const bool pnohy = false, + const bool pnocy = false, + const bool skip_bw_data = false, + const bool skip_bw_wei = false, + T* paddingSymbol = nullptr) : verify_rnn_api_base(pRD, x, y, hx, cx, w, pnohx, pnocx, pnohy, pnocy, paddingSymbol), + skip_backward_data(skip_bw_data), + skip_backward_weights(skip_bw_wei), dyHiddenState(dhy), dyCellState(dcy), dOutput(dy) @@ -1080,6 +1087,9 @@ struct verify_train_rnn : verify_rnn_api_base nocx, nohy, nocy); + if(skip_backward_data) + return result_tuple( + std::move(fwd_y), std::move(fwd_hy), std::move(fwd_cy), {}, {}, {}, {}); auto [bwd_din, bwd_dhx, bwd_dcx] = refMethod.bwd(input.desc, output.desc, @@ -1098,6 +1108,15 @@ struct verify_train_rnn : verify_rnn_api_base nohx, 
nocx); + if(skip_backward_weights) + return result_tuple(std::move(fwd_y), + std::move(fwd_hy), + std::move(fwd_cy), + std::move(bwd_din), + std::move(bwd_dhx), + std::move(bwd_dcx), + {}); + auto wrw_res = refMethod.wrw(input.desc, output.desc, input.data, @@ -1176,6 +1195,9 @@ struct verify_train_rnn : verify_rnn_api_base const auto fwd_hy = readTFromGPUOrEmpty(handle, hy_dev, xHiddenState, nohy); const auto fwd_cy = readTFromGPUOrEmpty(handle, cy_dev, xCellState, nocy); + if(skip_backward_data) + return result_tuple(fwd_y, fwd_hy, fwd_cy, {}, {}, {}, {}); + const auto dy_dev = transferTensorToGPUOrNullptr(handle, dOutput, false); const auto dhy_dev = transferTensorToGPUOrNullptr(handle, dyHiddenState, nodhy); const auto dcy_dev = transferTensorToGPUOrNullptr(handle, dyCellState, nodcy); @@ -1211,6 +1233,9 @@ struct verify_train_rnn : verify_rnn_api_base const auto bwd_dhx = readTFromGPUOrEmpty(handle, dhx_dev, xHiddenState, nodhx); const auto bwd_dcx = readTFromGPUOrEmpty(handle, dcx_dev, xCellState, nodcx); + if(skip_backward_data) + return result_tuple(fwd_y, fwd_hy, fwd_cy, bwd_din, bwd_dhx, bwd_dcx, {}); + std::vector workSpace_bwd_out(workSpace_TCnt); handle.ReadTo(workSpace_bwd_out.data(), workSpace_dev, workSpaceByteSize); @@ -1402,6 +1427,9 @@ struct rnn_seq_api_test_driver : test_driver bool pytorchTensorDescriptorFormat{}; + bool skip_backward_data{false}; + bool skip_backward_weights{false}; + rnn_seq_api_test_driver() {} bool check_GPU_mem_limit(miopen::Handle& handle, @@ -1658,7 +1686,20 @@ struct rnn_seq_api_test_driver : test_driver tolerance = 80; } - auto fwdTrain = verify(verify_train_rnn{ - rnnDesc, input, output, dy, hx, cx, dhy, dcy, weights, nohx, nocx, nohy, nocy}); + auto fwdTrain = verify(verify_train_rnn{rnnDesc, + input, + output, + dy, + hx, + cx, + dhy, + dcy, + weights, + nohx, + nocx, + nohy, + nocy, + skip_backward_data, + skip_backward_weights}); } }; From 6f2d903341ab7563f07d8677cd8dbad1b7fc21ca Mon Sep 17 00:00:00 2001 From: 
Shurale-nkn Date: Thu, 31 Oct 2024 21:14:40 +0100 Subject: [PATCH 04/21] tidy --- src/rnn/selector.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rnn/selector.cpp b/src/rnn/selector.cpp index 1cc1909004..6241500e91 100644 --- a/src/rnn/selector.cpp +++ b/src/rnn/selector.cpp @@ -74,7 +74,7 @@ std::tuple RNNDescriptor::GetTmpSpaceSizeDynamicAlgo( } bool RNNDescriptor::CheckDynamicAlgoSelection(Handle& /*handle*/, - const SeqTensorDescriptor& xDesc, + const SeqTensorDescriptor& /*xDesc*/, miopenRNNFWDMode_t /*fwdMode*/) const { bool use_dropout = !float_equal(miopen::deref(dropoutDesc).dropout, 0); From 2df3712e8140d6b0de3526e2b51740ffa10fe662 Mon Sep 17 00:00:00 2001 From: Shurale-nkn Date: Fri, 8 Nov 2024 17:45:36 +0100 Subject: [PATCH 05/21] buffer fix --- src/include/miopen/rnn/tmp_buffer_utils.hpp | 60 +++++++++++++-------- 1 file changed, 39 insertions(+), 21 deletions(-) diff --git a/src/include/miopen/rnn/tmp_buffer_utils.hpp b/src/include/miopen/rnn/tmp_buffer_utils.hpp index 9a499b9215..3c54ab8cb8 100644 --- a/src/include/miopen/rnn/tmp_buffer_utils.hpp +++ b/src/include/miopen/rnn/tmp_buffer_utils.hpp @@ -188,10 +188,11 @@ class BaseRnnWsBufferPacked */ ////// -class GeneralRNNTempBuffer : public BaseRnnWsBufferPacked +template +class GeneralRNNTempBufferTemplate : public BaseRnnWsBufferPacked { protected: - GeneralRNNTempBuffer(const std::array& hstate_strides, + GeneralRNNTempBufferTemplate(const std::array& hstate_strides, const std::array& hstate_sizes, size_t total_element_cnt) : hStateStrides{hstate_strides}, @@ -201,7 +202,7 @@ class GeneralRNNTempBuffer : public BaseRnnWsBufferPacked } public: - static GeneralRNNTempBuffer + static GeneralRNNTempBufferTemplate build(size_t layers_cnt, size_t vectors_per_layer, size_t directions, size_t hidden_vec_sz) { const std::array h_state_sizes{ @@ -213,7 +214,7 @@ class GeneralRNNTempBuffer : public BaseRnnWsBufferPacked std::multiplies{}); auto total_element_cnt = 
h_state_strides[0] * h_state_sizes[0]; - return GeneralRNNTempBuffer{h_state_strides, h_state_sizes, total_element_cnt}; + return {h_state_strides, h_state_sizes, total_element_cnt}; } size_t getHiddenStateOffsetImpl(const size_t layer_id, @@ -260,18 +261,21 @@ class GeneralRNNTempBuffer : public BaseRnnWsBufferPacked *} */ -class GeneralLstmTempBuffer : public GeneralRNNTempBuffer, - public GeneralLstmWsExt, - public LstmWsGateBlockExt + +template +class GeneralLstmInternalBuffTemplate : public GeneralRNNTempBufferTemplate, + public GeneralLstmWsExt, + public LstmWsGateBlockExt { + using RNNBufferTemplate = GeneralRNNTempBufferTemplate; protected: - GeneralLstmTempBuffer(const std::array& h_state_strides, + GeneralLstmInternalBuffTemplate(const std::array& h_state_strides, const std::array& h_state_sizes, const std::array& lstm_gate_sizes, const std::array& lstm_gate_strides, const std::array& lstm_gates_block_sizes, size_t total_element_cnt) - : GeneralRNNTempBuffer{h_state_strides, h_state_sizes, total_element_cnt}, + : RNNBufferTemplate{h_state_strides, h_state_sizes, total_element_cnt}, gateSizes{lstm_gate_sizes}, gateStride{lstm_gate_strides}, gateBlockSizes{lstm_gates_block_sizes} @@ -279,7 +283,7 @@ class GeneralLstmTempBuffer : public GeneralRNNTempBuffer, } public: - static GeneralLstmTempBuffer + static GeneralLstmInternalBuffTemplate build(size_t layers_cnt, size_t comp_dim_per_layer, size_t directions, size_t hidden_vec_sz) { @@ -355,7 +359,8 @@ class GeneralLstmTempBuffer : public GeneralRNNTempBuffer, const size_t vector_id, const SequenceDirection direction) const { - return getGasOffset(layer_id, vector_id, direction, LstmGateAndState::Ht); + return GeneralLstmWsExt::getGasOffset( + layer_id, vector_id, direction, LstmGateAndState::Ht); } size_t getGasOffsetImpl(const size_t layer_id, @@ -367,20 +372,22 @@ class GeneralLstmTempBuffer : public GeneralRNNTempBuffer, if(gas == LstmGateAndState::Ht || gas == LstmGateAndState::St) return start_ident + - 
GeneralRNNTempBuffer::getHiddenStateOffset(layer_id, vector_id, direction); + RNNBufferTemplate::getHiddenStateOffset(layer_id, vector_id, direction); const std::array pos{layer_id, vector_id, static_cast(direction)}; return start_ident + - std::inner_product( - pos.cbegin(), pos.cend(), hStateStrides.cbegin(), static_cast(0)); + std::inner_product(pos.cbegin(), + pos.cend(), + RNNBufferTemplate::hStateStrides.cbegin(), + static_cast(0)); } // layer, minor dim(seq or sample), directions, element const std::array& getGateAndStateStrideImpl(LstmGateAndState gas) const { if(gas == LstmGateAndState::Ht || gas == LstmGateAndState::St) - return getHiddenStateStride(); + return RNNBufferTemplate::getHiddenStateStride(); return gateStride; } @@ -402,12 +409,23 @@ class GeneralLstmTempBuffer : public GeneralRNNTempBuffer, case LstmGateAndState::O: return static_cast(gas) * gateStride[3] * gateSizes[3]; case LstmGateAndState::St: return gateStride[2] * gateSizes[2]; // direction DIM case LstmGateAndState::Ht: - return (gateStride[2] + getHiddenStateStride()[2]) * gateSizes[2]; + return (gateStride[2] + RNNBufferTemplate::getHiddenStateStride()[2]) * gateSizes[2]; } return 0; } }; + +// final +class GeneralLstmTempBuffer : public GeneralLstmInternalBuffTemplate +{ +public: + GeneralLstmTempBuffer(const GeneralLstmInternalBuffTemplate base ) + : GeneralLstmInternalBuffTemplate{base} + { + } +}; + /* *struct ReserveSpace_LSTM{ //packed * struct layer{ @@ -436,16 +454,16 @@ class GeneralLstmTempBuffer : public GeneralRNNTempBuffer, *} */ -class GeneralLstmRedBuffer : public GeneralLstmTempBuffer, +class GeneralLstmRedBuffer : public GeneralLstmInternalBuffTemplate, public LstmActiveCellExt { protected: - GeneralLstmRedBuffer(const GeneralLstmTempBuffer& base, + GeneralLstmRedBuffer(const GeneralLstmInternalBuffTemplate& base, const std::array& active_cells_sizes, const std::array& active_cells_strides, size_t active_cells_ident, size_t active_cell_elements) - : 
GeneralLstmTempBuffer{base}, + : GeneralLstmInternalBuffTemplate{base}, activeCellSize{active_cells_sizes}, activeCellStride{active_cells_strides}, activeCellsIdent{active_cells_ident}, @@ -458,8 +476,8 @@ class GeneralLstmRedBuffer : public GeneralLstmTempBuffer, build(size_t layers_cnt, size_t comp_dim_per_layer, size_t directions, size_t hidden_vec_sz) { - auto base = - GeneralLstmTempBuffer::build(layers_cnt, comp_dim_per_layer, directions, hidden_vec_sz); + auto base = GeneralLstmInternalBuffTemplate::build( + layers_cnt, comp_dim_per_layer, directions, hidden_vec_sz); auto active_cells_ident = base.gateStride[0] * base.gateSizes[0]; From c4838e0b1dcf5a8b625a455ed60ade5464ace972 Mon Sep 17 00:00:00 2001 From: Shurale-nkn Date: Fri, 8 Nov 2024 17:55:04 +0100 Subject: [PATCH 06/21] dynamic BWD --- src/include/miopen/rnn/solvers.hpp | 285 +++++++++---- src/rnn/Solutions/Base/bw_data_modular.cpp | 446 +++++++++++++++------ src/rnn/Solutions/bwd_s_stream.cpp | 72 ++++ src/rnn/selector.cpp | 20 +- 4 files changed, 622 insertions(+), 201 deletions(-) diff --git a/src/include/miopen/rnn/solvers.hpp b/src/include/miopen/rnn/solvers.hpp index 3718e284dc..9e745f4cef 100644 --- a/src/include/miopen/rnn/solvers.hpp +++ b/src/include/miopen/rnn/solvers.hpp @@ -47,6 +47,21 @@ struct runtimeArgsFwd const Data_t reserveSpace; }; +struct runtimeArgsBwd +{ + const Handle* handle; + const ConstData_t dy; + const ConstData_t dhy; + const Data_t dhx; + const ConstData_t cx; + const ConstData_t dcy; + const Data_t dcx; + const Data_t dx; + const ConstData_t w; + const Data_t workSpace; + const Data_t reserveSpace; +}; + class RNNModuleAlgoBase { protected: @@ -176,6 +191,23 @@ class RNNModuleAlgoBase const bool isBidirectSeq; + std::tuple getTempBuffersSize() const + { + + return std::make_tuple(workspaceInfo.getBufferSize() * GetTypeSize(rnnDesc.dataType), + reservLayout.getBufferSize() * GetTypeSize(rnnDesc.dataType)); + } + + static std::tuple getTempBuffersSize(const 
RNNDescriptor& rnnD, + const SeqTensorDescriptor& xDesc) + { + auto wsInfo = backwardInterimInfoBuilder(rnnD, xDesc); + auto reservInfo = forwardInterimInfoBuilder(rnnD, xDesc); + + return std::make_tuple(wsInfo.getBufferSize() * GetTypeSize(rnnD.dataType), + reservInfo.getBufferSize() * GetTypeSize(rnnD.dataType)); + } + inline size_t getVirtualLayer(const size_t layer_id, SequenceDirection direction) const { return layer_id * (isBidirectSeq ? 2 : 1) + @@ -496,7 +528,7 @@ class RNNModuleAlgoDynamic : public RNNForwardDataModularAlgo SeqTensorDescriptor tmpMapYDesc; }; -class RNNBackwardDataModularAlgo : RNNModuleAlgoBase +class RNNBackwardDataModularAlgo : protected RNNModuleAlgoBase { public: void PrepareWriteBuffers(const Handle& handle, Data_t dhx, Data_t dcx, Data_t workSpace) const; @@ -525,6 +557,19 @@ class RNNBackwardDataModularAlgo : RNNModuleAlgoBase const SequenceIterator& seq, SequenceDirection direction) const; + void UpdateHStatePerTimeSeq(const Handle& handle, + ConstData_t dcy, + ConstData_t cx, + Data_t, + Data_t workSpace, + Data_t reserveSpace, + size_t batchSizeUpdate, + size_t useDcyIfGtBatch, + size_t useCxIfGTBatch, + int layer, + const SequenceIterator& seq, + SequenceDirection direction) const; + void PropDhxDcx(const Handle& handle, ConstData_t w, Data_t dhx, @@ -593,92 +638,155 @@ class RNNBackwardDataModularAlgo : RNNModuleAlgoBase } RNNBackwardDataModularAlgo(RNNModuleAlgoBase&& base) : RNNModuleAlgoBase(std::move(base)) {} +}; -private: - template - inline miopen::TensorDescriptor BuildLstmTmpBlockDesc2D(const BufType& buf_info, - const size_t batch_size) const +class RNNBackwardModuleAlgoDynamic : public RNNBackwardDataModularAlgo +{ + using BaseBWDModuleT = rnn_base::RNNBackwardDataModularAlgo; + static SeqTensorDescriptor buildDynamicVirtual(const SeqTensorDescriptor& desc) { - const std::array& tmp_block_stride = buf_info.getGateBlockStride(); - const std::array& tmp_block_size = buf_info.getGateBlockSize(); - - // batch, 
gateBlock_elements - return miopen::TensorDescriptor{rnnDesc.dataType, - {batch_size, tmp_block_size[3]}, - {tmp_block_stride[1], tmp_block_stride[3]}}; + std::vector def_layout{1, 0, 2}; + return {desc.GetType(), def_layout, desc.GetLengths(), false}; } - inline miopen::TensorDescriptor BuildLstmFilterXDesc2D(int layer_id) const + static SeqTensorDescriptor buildRealToDynamicMapTmp(const SeqTensorDescriptor& desc) { - assert(rnnDesc.inputMode == 0 || layer_id != 0); - // TODO replace by stride - auto x_vec = layer_id != 0 ? weightsLayout.xInVec : weightsLayout.inVec; - - // gateBlock_elements, ht_vec - return miopen::TensorDescriptor{ - rnnDesc.dataType, {weightsLayout.gatesCnt * weightsLayout.hVec, x_vec}, {x_vec, 1}}; + std::vector def_layout{1, 0, 2}; + return {desc.GetType(), + def_layout, + desc.GetLengths(), + desc.GetSequenceLengthsVector(), + std::vector{}, + true, + true}; } - inline miopen::TensorDescriptor BuildLstmFilterHidDesc2D() const +public: + RNNBackwardModuleAlgoDynamic(const RNNDescriptor& rnnD, + const SeqTensorDescriptor& xTDesc, + const SeqTensorDescriptor& yTDesc, + const TensorDescriptor& hDesc, + miopenRNNFWDMode_t mode) + : BaseBWDModuleT(RNNModuleAlgoBase::create( + rnnD, buildDynamicVirtual(xTDesc), buildDynamicVirtual(yTDesc), hDesc, mode)), + realBatchController(BatchController::Create(xTDesc)), + realDxDesc(xTDesc), + realDyDesc(yTDesc), + tmpMapDxDesc(buildRealToDynamicMapTmp(xTDesc)), + tmpMapDyDesc(buildRealToDynamicMapTmp(yTDesc)) { - // TODO replace by stride - auto h_vec = weightsLayout.hVec; - - // gateBlock_elements, ht_vec - return miopen::TensorDescriptor{ - rnnDesc.dataType, {weightsLayout.gatesCnt * weightsLayout.hVec, h_vec}, {h_vec, 1}}; } - inline miopen::TensorDescriptor BuildWsHtDesc2D(size_t batch_size) const + struct runtimeArgsBwdDynamicExt { - auto& ht_stride = workspaceInfo.getHiddenStateStride(); - auto& ht_size = workspaceInfo.hStateSizes; - - // batch, gateBlock_elements - return miopen::TensorDescriptor{ - 
rnnDesc.dataType, {batch_size, ht_size[3]}, {ht_stride[1], ht_stride[3]}}; - } + const ConstData_t realDy; + const Data_t tempDy; + const ConstData_t dhy; + const Data_t dhx; + const ConstData_t cx; + const ConstData_t dcy; + const Data_t dcx; + const Data_t realDx; + const Data_t tempDx; + const ConstData_t w; + const Data_t workSpace; + const Data_t reserveSpace; + }; - // 2 dims batch, vec - inline miopen::TensorDescriptor BuildHxCxDesc2D(size_t batch_size) const + runtimeArgsBwdDynamicExt createRuntimeArgsExt(const runtimeArgsBwd& runtimeArgs) const { - const std::vector hx_size{batch_size, hiddenHxCxInfo.getHiddenSize()}; - const std::vector hx_stride{hiddenHxCxInfo.getStrides()[1], - hiddenHxCxInfo.getStrides()[2]}; + const Data_t temp_dx = + moveDataPtr(runtimeArgs.workSpace, workspaceInfo.getBufferSizeImpl(), rnnDesc.dataType); - return miopen::TensorDescriptor{rnnDesc.dataType, hx_size, hx_stride}; + const Data_t temp_dy = moveDataPtrByte(temp_dx, tmpMapDxDesc.GetTensorMaxByteSpace()); + + return { + runtimeArgs.dy, + temp_dy, + runtimeArgs.dhy, + runtimeArgs.dhx, + runtimeArgs.cx, + runtimeArgs.dcy, + runtimeArgs.dcx, + runtimeArgs.dx, + temp_dx, + runtimeArgs.w, + runtimeArgs.workSpace, + runtimeArgs.reserveSpace, + }; } - // 3 dims layer, batch, vec - inline miopen::TensorDescriptor BuildHxCxDesc3D(size_t layer_size, size_t batch_size) const + auto getTempBuffersSize() const { - const std::vector hx_accum_size{ - layer_size, batch_size, hiddenHxCxInfo.getHiddenSize()}; + auto [ws_size, reserve_size] = BaseBWDModuleT::getTempBuffersSize(); - return miopen::TensorDescriptor{ - rnnDesc.dataType, hx_accum_size, hiddenHxCxInfo.getStrides()}; + return std::make_tuple(ws_size + tmpMapDxDesc.GetTensorMaxByteSpace() + + tmpMapDyDesc.GetTensorMaxByteSpace(), + reserve_size); } - // 3 dims layer, batch, vec - inline miopen::TensorDescriptor BuildTempDhtDesc3D(size_t layer_size, size_t batch_size) const + static auto getTempBuffersSize(const RNNDescriptor& rnnD, 
const SeqTensorDescriptor& xDesc) { - const std::vector dy_dhy_accum_size{ - layer_size, batch_size, hiddenHxCxInfo.getHiddenSize()}; + auto y_desc = [](const RNNDescriptor& rnnD, const SeqTensorDescriptor& xDesc) { + std::vector y_lenghts{xDesc.GetLengths()}; + y_lenghts[2] = rnnD.hsize * (rnnD.dirMode == miopenRNNbidirection ? 2 : 1); + return SeqTensorDescriptor{xDesc.GetType(), y_lenghts}; + }(rnnD, xDesc); - const auto ws_dy_stride = [](const auto& ws_4dim_strides) -> std::vector { - // convert 4dim stride to 3 dim without direction - // TODO change hiddenBufferDesc - return std::vector{ws_4dim_strides[0], ws_4dim_strides[1], ws_4dim_strides[3]}; - }(workspaceInfo.getHiddenStateStride()); + auto temp_x_desc = buildDynamicVirtual(xDesc); + auto temp_y_desc = buildDynamicVirtual(y_desc); - return miopen::TensorDescriptor{rnnDesc.dataType, dy_dhy_accum_size, ws_dy_stride}; - } + auto [ws_size, reserve_size] = + RNNForwardDataModularAlgo::getTempBuffersSize(rnnD, temp_x_desc); - inline size_t getVirtualLayer(const size_t layer_id, SequenceDirection direction) const - { - return layer_id * (isBidirectSeq ? 2 : 1) + - (direction == SequenceDirection::Forward ? 
0 : 1); + return std::make_tuple(ws_size + temp_x_desc.GetTensorMaxByteSpace() + + temp_y_desc.GetTensorMaxByteSpace(), + reserve_size); } + + void realDxProp(const Handle& handle, const runtimeArgsBwdDynamicExt& runtimeArgsExt) const; + + void realDyProp(const Handle& handle, const runtimeArgsBwdDynamicExt& runtimeArgsExt) const; + + void realPropDhy(const Handle& handle, + ConstData_t dhy, + Data_t workSpace, + unsigned int layer, + const SequenceIterator& currentSeq, + SequenceDirection direction) const; + + void realUpdateHStatePerTimeSeq(const Handle& handle, + ConstData_t dcy, + ConstData_t cx, + Data_t, + Data_t workSpace, + Data_t reserveSpace, + int layer, + const SequenceIterator& seq, + SequenceDirection direction) const; + + void PrepareWriteBuffers(const Handle& handle, + const runtimeArgsBwdDynamicExt& runtimeArgsExt) const; + + void HtHiddenDataZeroing() const; + + // void PrepareWriteBuffers(const Handle& handle, + // const runtimeArgsBwdDynamicExt& runtimeArgsExt, + // const runtimeArgsFwd& runtimeArgs) const; + + void PropHyCy(const Handle& handle, + const runtimeArgsFwd& runtimeArgs, + size_t layer, + const SequenceIterator& currentSeq, + SequenceDirection direction) const; + +private: + BatchController realBatchController; + + SeqTensorDescriptor realDxDesc; + SeqTensorDescriptor realDyDesc; + SeqTensorDescriptor tmpMapDxDesc; + SeqTensorDescriptor tmpMapDyDesc; }; class RNNModularSingleStreamFWD @@ -796,6 +904,44 @@ class RNNModularSingleStreamBWD const size_t max_seq_len; }; +class RNNDynamicModularSingleStreamBWD +{ +private: +public: + RNNDynamicModularSingleStreamBWD(const RNNDescriptor& rnn, + const SeqTensorDescriptor& xDesc, + const SeqTensorDescriptor& yDesc, + const TensorDescriptor& hDesc, + miopenRNNFWDMode_t mode) + : rnnAlgoModules(rnn, xDesc, yDesc, hDesc, mode), + rnnDesc(rnn), + max_seq_len(xDesc.GetMaxSequenceLength()) + { + } + + static bool IsApplicable() + { +#if MIOPEN_USE_GEMM && MIOPEN_BACKEND_HIP + return true; +#else + 
return false; +#endif // MIOPEN_USE_GEMM&& MIOPEN_BACKEND_HIP + } + + auto getTempBuffersSize() const { return rnnAlgoModules.getTempBuffersSize(); } + + static auto getTempBuffersSize(const RNNDescriptor& rnn, const SeqTensorDescriptor& xDesc) + { + return decltype(rnnAlgoModules)::getTempBuffersSize(rnn, xDesc); + } + + void ComputeBWD(Handle& handle, const runtimeArgsBwd& runtimeArgs) const; + + const rnn_base::RNNBackwardModuleAlgoDynamic rnnAlgoModules; + const RNNDescriptor& rnnDesc; + const size_t max_seq_len; +}; + class RNNModularMultiStreamBWD { public: @@ -822,21 +968,6 @@ class RNNModularMultiStreamBWD // TODO static size_t GetWsSize() { return 0; }; - struct runtimeArgsBwd - { - const Handle* handle; - ConstData_t dy; - ConstData_t dhy; - Data_t dhx; - ConstData_t cx; - ConstData_t dcy; - Data_t dcx; - Data_t dx; - ConstData_t w; - Data_t workSpace; - Data_t reserveSpace; - }; - void ComputeBWD(Handle& handle, ConstData_t dy, ConstData_t dhy, diff --git a/src/rnn/Solutions/Base/bw_data_modular.cpp b/src/rnn/Solutions/Base/bw_data_modular.cpp index 04bbfd780e..a8984482a2 100644 --- a/src/rnn/Solutions/Base/bw_data_modular.cpp +++ b/src/rnn/Solutions/Base/bw_data_modular.cpp @@ -172,127 +172,39 @@ void RNNBackwardDataModularAlgo::UpdateHStatePerTimeSeq(const Handle& handle, const SequenceIterator& seq, SequenceDirection direction) const { - // Inited - const size_t hidden_vec = rnnDesc.hsize; - auto rnn_data_type = rnnDesc.dataType; - auto rnn_mode = rnnDesc.rnnMode; - auto rnn_algo_mode = rnnDesc.algoMode; - - if(rnn_mode == miopenRNNRELU || rnn_mode == miopenRNNTANH) - { - // float alpha = 1; - // float beta = 0; - // - //// activation - // auto& activDesc = rnn_mode == miopenRNNRELU ? 
reluDesc : tanhDesc; - - /* - activDesc.Backward(handle, - &alpha, - dht_desc, - reserveSpace, - dht_desc, - workSpace, - dht_desc, - reserveSpace, - &beta, - sp_desc, - workSpace, - offset + static_cast(ri) * wei_len + - static_cast(nLayers) * batch_n * hy_stride, - offset + static_cast(ri) * wei_len, - offset + static_cast(ri) * wei_len, - offset + static_cast(ri) * wei_len); - */ - } - else if(rnn_mode == miopenLSTM) - { - if(rnn_algo_mode == miopenRNNdefault) - { - - size_t cur_batch = batchController.getBatchSize(seq.getPhisVal()); - - const auto [dcy_use_batch, cx_use_batch] = [](const auto& seq, - const BatchController& batch_c, - const SequenceDirection dir) { - auto current_batch = batch_c.getBatchSize(seq.getPhisVal()); - if(dir == SequenceDirection::Forward) - { - const auto dcy_batch = seq.isFirst() - ? current_batch - : batch_c.getBatchSize(seq.getPrev().getPhisVal()); - const auto cx_batch = current_batch; - return std::make_tuple(dcy_batch, cx_batch); - } - else - { - const auto dcy_batch = current_batch; - const auto cx_batch = seq.isLast() - ? current_batch - : batch_c.getBatchSize(seq.getNext().getPhisVal()); - return std::make_tuple(dcy_batch, cx_batch); - } - }(seq, batchController, direction); - size_t cur_comb_dim = batchController.getBatchSum(seq.getPhisVal()); - size_t prev_comb_dim = !seq.isFirst() - ? batchController.getBatchSum(seq.getPrev().getPhisVal()) - : batchController.getBatchSum(seq.getPhisVal()); - size_t next_comb_dim = !seq.isLast() - ? 
batchController.getBatchSum(seq.getNext().getPhisVal()) - : batchController.getBatchSum(seq.getPhisVal()); - - LSTMBackwardHiddenStateUpdate( - handle, - rnn_data_type, - seq.isLast(), // ti == 0, - seq.isFirst(), // ti == seqLen - 1, - static_cast(direction), - batchController.getBatchSize(0), - cur_batch, - dcy_use_batch, - cx_use_batch, - hidden_vec, - reservLayout.gateStride[1], - -666, // unused - -666, // unused - cx, - hiddenHxCxInfo.getOffset(getVirtualLayer(layer, direction), 0), - reserveSpace, - reservLayout.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::I), - reservLayout.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::F), - reservLayout.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::O), - reservLayout.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::G), - reservLayout.getActiveCellOffset(layer, cur_comb_dim, direction), - reservLayout.getGasOffset( // TODO - layer, - next_comb_dim, - direction, - LstmGateAndState::St), - dcy, - hiddenHxCxInfo.getOffset(getVirtualLayer(layer, direction), 0), - workSpace, - workspaceInfo.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::I), - workspaceInfo.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::F), - workspaceInfo.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::O), - workspaceInfo.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::G), - workspaceInfo.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::St), - workspaceInfo.getGasOffset(layer, prev_comb_dim, direction, LstmGateAndState::St), - workspaceInfo.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::Ht), - workspaceInfo.getGasOffset(layer, prev_comb_dim, direction, LstmGateAndState::F)); - } - else - { - MIOPEN_THROW(miopenStatusInternalError, - "TODO implementation algoMode != miopenRNNdefault"); - // TODO implementation - } - } - else if(rnn_mode == miopenGRU) - { - MIOPEN_THROW(miopenStatusInternalError, "TODO 
implementation miopenGRU"); - // TODO implementation - } + const auto [cur_batch, dcy_use_batch, cx_use_batch] = + [](const auto& seq, const BatchController& batch_c, const SequenceDirection dir) { + auto current_batch = batch_c.getBatchSize(seq.getPhisVal()); + if(dir == SequenceDirection::Forward) + { + const auto dcy_batch = seq.isFirst() + ? current_batch + : batch_c.getBatchSize(seq.getPrev().getPhisVal()); + const auto cx_batch = current_batch; + return std::make_tuple(current_batch, dcy_batch, cx_batch); + } + else + { + const auto dcy_batch = current_batch; + const auto cx_batch = + seq.isLast() ? current_batch : batch_c.getBatchSize(seq.getNext().getPhisVal()); + return std::make_tuple(current_batch, dcy_batch, cx_batch); + } + }(seq, batchController, direction); + + return UpdateHStatePerTimeSeq(handle, + dcy, + cx, + nullptr, + workSpace, + reserveSpace, + cur_batch, + dcy_use_batch, + cx_use_batch, + layer, + seq, + direction); } void RNNBackwardDataModularAlgo::PropDhxDcx(const Handle& handle, @@ -672,5 +584,299 @@ void RNNBackwardDataModularAlgo::PropDx(const Handle& handle, false); } +void RNNBackwardDataModularAlgo::UpdateHStatePerTimeSeq(const Handle& handle, + ConstData_t dcy, + ConstData_t cx, + Data_t, + Data_t workSpace, + Data_t reserveSpace, + size_t batchSizeUpdate, + size_t useDcyIfGtBatch, + size_t useCxIfGTBatch, + int layer, + const SequenceIterator& seq, + SequenceDirection direction) const +{ + // Inited + const size_t hidden_vec = rnnDesc.hsize; + auto rnn_data_type = rnnDesc.dataType; + auto rnn_mode = rnnDesc.rnnMode; + auto rnn_algo_mode = rnnDesc.algoMode; + + if(rnn_mode == miopenRNNRELU || rnn_mode == miopenRNNTANH) {} + else if(rnn_mode == miopenLSTM) + { + if(rnn_algo_mode == miopenRNNdefault) + { + size_t cur_comb_dim = batchController.getBatchSum(seq.getPhisVal()); + size_t prev_comb_dim = !seq.isFirst() + ? 
batchController.getBatchSum(seq.getPrev().getPhisVal()) + : batchController.getBatchSum(seq.getPhisVal()); + size_t next_comb_dim = !seq.isLast() + ? batchController.getBatchSum(seq.getNext().getPhisVal()) + : batchController.getBatchSum(seq.getPhisVal()); + + LSTMBackwardHiddenStateUpdate( + handle, + rnn_data_type, + seq.isLast(), // ti == 0, + seq.isFirst(), // ti == seqLen - 1, + static_cast(direction), + batchController.getBatchSize(0), + batchSizeUpdate, + useDcyIfGtBatch, + useCxIfGTBatch, + hidden_vec, + reservLayout.gateStride[1], + -666, // unused + -666, // unused + cx, + hiddenHxCxInfo.getOffset(getVirtualLayer(layer, direction), 0), + reserveSpace, + reservLayout.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::I), + reservLayout.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::F), + reservLayout.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::O), + reservLayout.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::G), + reservLayout.getActiveCellOffset(layer, cur_comb_dim, direction), + reservLayout.getGasOffset( // TODO + layer, + next_comb_dim, + direction, + LstmGateAndState::St), + dcy, + hiddenHxCxInfo.getOffset(getVirtualLayer(layer, direction), 0), + workSpace, + workspaceInfo.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::I), + workspaceInfo.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::F), + workspaceInfo.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::O), + workspaceInfo.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::G), + workspaceInfo.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::St), + workspaceInfo.getGasOffset(layer, prev_comb_dim, direction, LstmGateAndState::St), + workspaceInfo.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::Ht), + workspaceInfo.getGasOffset(layer, prev_comb_dim, direction, LstmGateAndState::F)); + } + else + { + MIOPEN_THROW(miopenStatusInternalError, + "TODO 
implementation algoMode != miopenRNNdefault"); + // TODO implementation + } + } + else if(rnn_mode == miopenGRU) + { + MIOPEN_THROW(miopenStatusInternalError, "TODO implementation miopenGRU"); + // TODO implementation + } +} + +void RNNBackwardModuleAlgoDynamic::realDyProp(const Handle& handle, + const runtimeArgsBwdDynamicExt& runtimeArgsExt) const +{ + RNNTensorBaseLayoutConverter::ConvertInputTensorGPUData(handle, + realDyDesc, + runtimeArgsExt.realDy, + tmpMapDyDesc, + runtimeArgsExt.tempDy, + nullptr, + false); +} + +void RNNBackwardModuleAlgoDynamic::realDxProp(const Handle& handle, + const runtimeArgsBwdDynamicExt& runtimeArgsExt) const +{ + RNNTensorBaseLayoutConverter::ConvertInputTensorGPUData(handle, + tmpMapDxDesc, + runtimeArgsExt.tempDx, + realDxDesc, + runtimeArgsExt.realDx, + nullptr, + false); +} + +void RNNBackwardModuleAlgoDynamic::realPropDhy(const Handle& handle, + ConstData_t dhy, + Data_t workSpace, + unsigned int layer, + const SequenceIterator& currentSeq, + SequenceDirection direction) const +{ + if(dhy == nullptr) + return; + + if(direction == SequenceDirection::Reverse && !currentSeq.isFirst()) + return; + + const auto [copy_batch_size, copy_batch_offset_id] = [](const SequenceIterator& current_seq, + const BatchController& b_c) { + const auto cur_time_batch = b_c.getBatchSize(current_seq.getPhisVal()); + const auto prev_time_batch = + current_seq.isFirst() ? 
0 : b_c.getBatchSize(current_seq.getPrev().getPhisVal()); + + size_t dst_batch_offset_id_ = prev_time_batch; + size_t dst_batch_size_ = cur_time_batch - prev_time_batch; + return std::make_tuple(dst_batch_size_, dst_batch_offset_id_); + }(currentSeq, realBatchController); + + // no data so return + if(copy_batch_size <= 0) + return; + + // ws_dy + dhy + const float alpha0 = 1; + const float alpha1 = 1; + const float beta_t = 0; + + // TODO remove virtual in implementation change getOffset + auto virtual_layer = getVirtualLayer(layer, direction); + size_t dhy_layer_offset = hiddenHxCxInfo.getOffset(virtual_layer, copy_batch_offset_id); + + size_t time_batch_offset_id = batchController.getBatchSum(currentSeq.getPhisVal()); + size_t workspace_dy_offset = workspaceInfo.getHiddenStateOffset( + layer, time_batch_offset_id + copy_batch_offset_id, direction); + + const auto dhy_desc = BuildHxCxDesc3D(1, copy_batch_size); + + const auto workspace_dy_desc = BuildTempDhtDesc3D(1, copy_batch_size); + + OpTensor(handle, + miopenTensorOpAdd, + &alpha0, + dhy_desc, + dhy, + &alpha1, + workspace_dy_desc, + workSpace, + &beta_t, + workspace_dy_desc, + workSpace, + dhy_layer_offset, + workspace_dy_offset, + workspace_dy_offset); +} + +void RNNBackwardModuleAlgoDynamic::realUpdateHStatePerTimeSeq(const Handle& handle, + ConstData_t dcy, + ConstData_t cx, + Data_t, + Data_t workSpace, + Data_t reserveSpace, + int layer, + const SequenceIterator& seq, + SequenceDirection direction) const +{ + // Inited + + const auto [cur_batch, dcy_use_batch, cx_use_batch] = + [](const auto& seq, const BatchController& batch_c, const SequenceDirection dir) { + auto current_batch = batch_c.getBatchSize(seq.getPhisVal()); + if(dir == SequenceDirection::Forward) + { + const auto dcy_batch = seq.isFirst() + ? 
current_batch + : batch_c.getBatchSize(seq.getPrev().getPhisVal()); + const auto cx_batch = current_batch; + return std::make_tuple(current_batch, dcy_batch, cx_batch); + } + else + { + const auto dcy_batch = current_batch; + const auto cx_batch = + seq.isLast() ? current_batch : batch_c.getBatchSize(seq.getNext().getPhisVal()); + return std::make_tuple(current_batch, dcy_batch, cx_batch); + } + }(seq, realBatchController, direction); + + return UpdateHStatePerTimeSeq(handle, + dcy, + cx, + nullptr, + workSpace, + reserveSpace, + cur_batch, + dcy_use_batch, + cx_use_batch, + layer, + seq, + direction); +} + +void RNNBackwardModuleAlgoDynamic::PrepareWriteBuffers( + const Handle& handle, const runtimeArgsBwdDynamicExt& runtimeArgsExt) const +{ + RNNBackwardDataModularAlgo::PrepareWriteBuffers( + handle, runtimeArgsExt.dhx, runtimeArgsExt.dcx, runtimeArgsExt.workSpace); + + // realDxProp(handle, runtimeArgsExt); +} + +void RNNBackwardModuleAlgoDynamic::PropHyCy(const Handle& handle, + const runtimeArgsFwd& runtimeArgs, + size_t layer, + const SequenceIterator& currentSeq, + SequenceDirection direction) const +{ + if(runtimeArgs.hy != nullptr || (runtimeArgs.cy != nullptr)) + { + const auto gap_batch_size = [&]() { + if(currentSeq.isLast()) + { + return realBatchController.getBatchSize(currentSeq.getPhisVal()); + } + else + { + if(direction == SequenceDirection::Forward) + { + return realBatchController.getBatchSize(currentSeq.getPhisVal()) - + realBatchController.getBatchSize(currentSeq.getNext().getPhisVal()); + } + else + return static_cast(0); + } + }(); + + const auto gap_batch_offset = [&]() { + if(currentSeq.isLast()) + return static_cast(0); + else + return realBatchController.getBatchSize(currentSeq.getPhisVal()) - gap_batch_size; + }(); + + if(gap_batch_size > 0) + { + + auto src_desc = BuildTempDhtDesc3D(1, gap_batch_size); + + auto dst_desc = BuildHxCxDesc3D(1, gap_batch_size); + + size_t tmp_batch_offset = + 
batchController.getBatchSum(currentSeq.getPhisVal()) + gap_batch_offset; + + if(runtimeArgs.hy != nullptr) + { + CopyTensor(handle, + src_desc, + runtimeArgs.reserveSpace, + dst_desc, + runtimeArgs.hy, + reservLayout.getGasOffset( + layer, tmp_batch_offset, direction, LstmGateAndState::Ht), + hiddenHxCxInfo.getOffset(layer, gap_batch_offset)); + } + + if(runtimeArgs.cy != nullptr) + { + CopyTensor(handle, + src_desc, + runtimeArgs.reserveSpace, + dst_desc, + runtimeArgs.cy, + reservLayout.getGasOffset( + layer, tmp_batch_offset, direction, LstmGateAndState::St), + hiddenHxCxInfo.getOffset(layer, gap_batch_offset)); + } + } + } +} + } // namespace rnn_base } // namespace miopen diff --git a/src/rnn/Solutions/bwd_s_stream.cpp b/src/rnn/Solutions/bwd_s_stream.cpp index 5244aff4bd..d84caa6e83 100644 --- a/src/rnn/Solutions/bwd_s_stream.cpp +++ b/src/rnn/Solutions/bwd_s_stream.cpp @@ -134,5 +134,77 @@ void RNNModularSingleStreamBWD::ComputeBWD(Handle& handle, #endif } +void RNNDynamicModularSingleStreamBWD::ComputeBWD(Handle& handle, + const runtimeArgsBwd& realRuntimeArgs) const +{ + auto layer_i = rnnDesc.nLayers; + + if(layer_i == 0 || max_seq_len == 0) + return; + + auto sequence_directions = + rnnDesc.dirMode == miopenRNNDirectionMode_t::miopenRNNbidirection ? 2 : 1; + + const auto runtimeArgsExt = rnnAlgoModules.createRuntimeArgsExt(realRuntimeArgs); + const auto [real_dy, + temp_dy, + dhy, + dhx, + cx, + dcy, + dcx, + real_dx, + temp_dx, + w, + workSpace, + reserveSpace] = runtimeArgsExt; + + rnnAlgoModules.PrepareWriteBuffers(handle, runtimeArgsExt); + + rnnAlgoModules.realDyProp(handle, runtimeArgsExt); + + rnnAlgoModules.PropDy(handle, temp_dy, workSpace); + + do + { + layer_i--; + + for(int dir = 0; dir < sequence_directions; dir++) + { + const auto seq_dir = dir == 0 ? 
rnn_base::SequenceDirection::Forward + : rnn_base::SequenceDirection::Reverse; + + auto ti = max_seq_len; + do + { + const rnn_base::SequenceIterator cur_seq(--ti, seq_dir, max_seq_len, false); + + rnnAlgoModules.realPropDhy(handle, dhy, workSpace, layer_i, cur_seq, seq_dir); + + // rnnAlgoModules.HtHiddenDataZeroing(); + + rnnAlgoModules.realUpdateHStatePerTimeSeq( + handle, dcy, cx, dcx, workSpace, reserveSpace, layer_i, cur_seq, seq_dir); + + // GEMM + if(ti != 0) + rnnAlgoModules.PropHiddenDht(handle, w, workSpace, layer_i, cur_seq, seq_dir); + else + rnnAlgoModules.PropDhxDcx( + handle, w, dhx, dcx, workSpace, reserveSpace, layer_i, cur_seq, seq_dir); + + } while(ti != 0); + + if(layer_i != 0) + rnnAlgoModules.PropHiddenDy(handle, w, workSpace, reserveSpace, layer_i, seq_dir); + else + rnnAlgoModules.PropDx(handle, w, workSpace, temp_dx, seq_dir); + } + + } while(layer_i != 0); + + rnnAlgoModules.realDxProp(handle, runtimeArgsExt); +} + } // namespace rnn_base } // namespace miopen diff --git a/src/rnn/selector.cpp b/src/rnn/selector.cpp index 6241500e91..aeab61ce93 100644 --- a/src/rnn/selector.cpp +++ b/src/rnn/selector.cpp @@ -147,10 +147,22 @@ void RNNDescriptor::ModularBackward(Handle& handle, } else { - rnn_base::RNNModularSingleStreamBWD single_stream{ - *this, xDesc, yDesc, hDesc, miopenRNNFWDMode_t::miopenRNNTraining}; - single_stream.ComputeBWD( - handle, dy, dhy, dhx, cx, dcy, dcx, dx, w, workSpace, reserveSpace); + if(CheckDynamicAlgoSelection(handle, xDesc, miopenRNNFWDMode_t::miopenRNNTraining)) + { + rnn_base::RNNDynamicModularSingleStreamBWD single_stream{ + *this, xDesc, yDesc, hDesc, miopenRNNFWDMode_t::miopenRNNTraining}; + single_stream.ComputeBWD( + handle, + rnn_base::runtimeArgsBwd{ + &handle, dy, dhy, dhx, cx, dcy, dcx, dx, w, workSpace, reserveSpace}); + } + else + { + rnn_base::RNNModularSingleStreamBWD single_stream{ + *this, xDesc, yDesc, hDesc, miopenRNNFWDMode_t::miopenRNNTraining}; + single_stream.ComputeBWD( + handle, dy, dhy, 
dhx, cx, dcy, dcx, dx, w, workSpace, reserveSpace); + } } } From bc6e9a6d761d71b734c1882e22eac8571febf220 Mon Sep 17 00:00:00 2001 From: Shurale-nkn Date: Wed, 13 Nov 2024 16:50:34 +0100 Subject: [PATCH 07/21] file separation --- .../rnn/algorithms/default_algo_utils.hpp | 710 +++++++++++ .../rnn/algorithms/dynamic_algo_utils.hpp | 312 +++++ src/include/miopen/rnn/solvers.hpp | 1113 +---------------- src/rnn/Solutions/Base/bw_data_modular.cpp | 4 +- src/rnn/Solutions/Base/fw_data_modular.cpp | 4 +- src/rnn/Solutions/bww_multi_stream.cpp | 4 +- 6 files changed, 1061 insertions(+), 1086 deletions(-) create mode 100644 src/include/miopen/rnn/algorithms/default_algo_utils.hpp create mode 100644 src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp diff --git a/src/include/miopen/rnn/algorithms/default_algo_utils.hpp b/src/include/miopen/rnn/algorithms/default_algo_utils.hpp new file mode 100644 index 0000000000..4b0839431d --- /dev/null +++ b/src/include/miopen/rnn/algorithms/default_algo_utils.hpp @@ -0,0 +1,710 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#pragma once +#include + +namespace miopen { + +namespace rnn_base { + +struct runtimeArgsFwd +{ + const ConstData_t x; + const ConstData_t hx; + const ConstData_t cx; + const Data_t y; + const Data_t hy; + const Data_t cy; + const ConstData_t w; + const Data_t workSpace; + const Data_t reserveSpace; +}; + +struct runtimeArgsBwd +{ + const Handle* handle; + const ConstData_t dy; + const ConstData_t dhy; + const Data_t dhx; + const ConstData_t cx; + const ConstData_t dcy; + const Data_t dcx; + const Data_t dx; + const ConstData_t w; + const Data_t workSpace; + const Data_t reserveSpace; +}; + +struct runtimeArgsBWWeights +{ + const Handle* handle; + const ConstData_t x; + const ConstData_t hx; + const Data_t dw; + const Data_t workSpace; + const ConstData_t reserveSpace; +}; + + +class RNNModuleAlgoBase +{ +protected: + static GeneralLstmTempBuffer backwardInterimInfoBuilder(const RNNDescriptor& rnnDesc, + const SeqTensorDescriptor& xDesc) + { + auto layers_cnt = static_cast(rnnDesc.nLayers); + const size_t seq_directions = rnnDesc.dirMode == miopenRNNbidirection ? 
2 : 1; + auto hidden_vec_sz = rnnDesc.hsize; + + return GeneralLstmTempBuffer::build( + layers_cnt, xDesc.GetTotalSequenceLen(), seq_directions, hidden_vec_sz); + } + + static GeneralLstmRedBuffer forwardInterimInfoBuilder(const RNNDescriptor& rnnDesc, + const SeqTensorDescriptor& xDesc) + { + auto layers_cnt = static_cast(rnnDesc.nLayers); + const size_t seq_directions = rnnDesc.dirMode == miopenRNNbidirection ? 2 : 1; + auto hidden_vec_sz = rnnDesc.hsize; + + return GeneralLstmRedBuffer::build( + layers_cnt, xDesc.GetTotalSequenceLen(), seq_directions, hidden_vec_sz); + } + +public: + static RNNModuleAlgoBase create(const RNNDescriptor& rnnDesc, + const SeqTensorDescriptor& xDesc, + const SeqTensorDescriptor& yDesc, + const TensorDescriptor& hDesc, + miopenRNNFWDMode_t mode) + { + auto [max_layers_hid, max_batch_hid, hidden_vec_sz] = miopen::tien<3>(hDesc.GetLengths()); + auto [max_batch_in, max_seq, input_vec_sz] = miopen::tien<3>(xDesc.GetLengths()); + + assert(max_batch_in <= max_batch_hid); + + auto layers_cnt = static_cast(rnnDesc.nLayers); + const bool is_seq_bidir = rnnDesc.dirMode == miopenRNNbidirection; + + assert(static_cast(layers_cnt) * (is_seq_bidir ? 
2 : 1) <= max_layers_hid); + + auto gates_cnt = static_cast(rnnDesc.nHiddenTensorsPerLayer); + + // class update req + assert(!is_seq_bidir); + + // TODO all size_t + GeneralLstmRedBuffer rb_layout = forwardInterimInfoBuilder(rnnDesc, xDesc); + + GeneralLstmTempBuffer workspace_info = backwardInterimInfoBuilder(rnnDesc, xDesc); + + WeightsBufferDescriptor weights_layout = + WeightsBufferDescriptor::create(static_cast(input_vec_sz), + static_cast(hidden_vec_sz), + layers_cnt, + rnnDesc.biasMode, + rnnDesc.inputMode, + gates_cnt, + is_seq_bidir); + + BatchController batch_controller = BatchController::Create(xDesc); + + HiddenBuffersDescriptor hidden_hxcx_info{hDesc}; + + IOBufferDescriptor x_info{IOBufferDescriptor::build(xDesc)}; + IOBufferDescriptor y_info{IOBufferDescriptor::build(yDesc)}; + + return {std::move(rb_layout), + workspace_info, + weights_layout, + hidden_hxcx_info, + x_info, + y_info, + rnnDesc, + batch_controller, + mode}; + } + + RNNModuleAlgoBase(RNNModuleAlgoBase&&) = default; + // RNNModuleAlgoBase(RNNModuleAlgoBase const&) = default; + + RNNModuleAlgoBase(GeneralLstmRedBuffer rb_layout, + GeneralLstmTempBuffer workspace_info, + WeightsBufferDescriptor weights_layout, + HiddenBuffersDescriptor hidden_hxcx_info, + IOBufferDescriptor x_info, + IOBufferDescriptor y_info, + const RNNDescriptor& rnn_desc, + BatchController batch_controller, + miopenRNNFWDMode_t fwd_mode) + : reservLayout(std::move(rb_layout)), + workspaceInfo(std::move(workspace_info)), + weightsLayout(std::move(weights_layout)), + hiddenHxCxInfo(std::move(hidden_hxcx_info)), + xInfo(std::move(x_info)), + yInfo(std::move(y_info)), + rnnDesc(rnn_desc), + tanhDesc{miopenActivationTANH, 1, 1, 1}, + sigDesc{miopenActivationLOGISTIC, 1, 0, 1}, + reluDesc{miopenActivationRELU, 1, 0, 1}, + batchController(std::move(batch_controller)), + fwdMode(fwd_mode), + isBidirectSeq(false) + { + } + + const GeneralLstmRedBuffer reservLayout; + // const WorkspaceBufferDescriptor workspaceInfo; + const 
GeneralLstmTempBuffer workspaceInfo; + + const WeightsBufferDescriptor weightsLayout; + const HiddenBuffersDescriptor hiddenHxCxInfo; + + const IOBufferDescriptor xInfo; + const IOBufferDescriptor yInfo; + + const RNNDescriptor& rnnDesc; + + const ActivationDescriptor tanhDesc; + const ActivationDescriptor sigDesc; + const ActivationDescriptor reluDesc; + + const BatchController batchController; + + const miopenRNNFWDMode_t fwdMode; + + const bool isBidirectSeq; + + std::tuple getTempBuffersSize() const + { + + return std::make_tuple(workspaceInfo.getBufferSize() * GetTypeSize(rnnDesc.dataType), + reservLayout.getBufferSize() * GetTypeSize(rnnDesc.dataType)); + } + + static std::tuple getTempBuffersSize(const RNNDescriptor& rnnD, + const SeqTensorDescriptor& xDesc) + { + auto wsInfo = backwardInterimInfoBuilder(rnnD, xDesc); + auto reservInfo = forwardInterimInfoBuilder(rnnD, xDesc); + + return std::make_tuple(wsInfo.getBufferSize() * GetTypeSize(rnnD.dataType), + reservInfo.getBufferSize() * GetTypeSize(rnnD.dataType)); + } + + inline size_t getVirtualLayer(const size_t layer_id, SequenceDirection direction) const + { + return layer_id * (isBidirectSeq ? 2 : 1) + + (direction == SequenceDirection::Forward ? 0 : 1); + } + + template + inline miopen::TensorDescriptor BuildLstmTmpBlockDesc2D(const BufType& buf_info, + const size_t batch_size) const + { + const std::array& tmp_block_stride = buf_info.getGateBlockStride(); + const std::array& tmp_block_size = buf_info.getGateBlockSize(); + + // batch, gateBlock_elements + return miopen::TensorDescriptor{rnnDesc.dataType, + {batch_size, tmp_block_size[3]}, + {tmp_block_stride[1], tmp_block_stride[3]}}; + } + + inline miopen::TensorDescriptor BuildLstmFilterXDesc2D(int layer_id) const + { + assert(rnnDesc.inputMode == 0 || layer_id != 0); + // TODO replace by stride + auto x_vec = layer_id != 0 ? 
weightsLayout.xInVec : weightsLayout.inVec; + + // gateBlock_elements, ht_vec + return miopen::TensorDescriptor{ + rnnDesc.dataType, {weightsLayout.gatesCnt * weightsLayout.hVec, x_vec}, {x_vec, 1}}; + } + + inline miopen::TensorDescriptor BuildLstmFilterHidDesc2D() const + { + // TODO replace by stride + auto h_vec = weightsLayout.hVec; + + // gateBlock_elements, ht_vec + return miopen::TensorDescriptor{ + rnnDesc.dataType, {weightsLayout.gatesCnt * weightsLayout.hVec, h_vec}, {h_vec, 1}}; + } + + template + inline miopen::TensorDescriptor BuildTmpHtDesc2D(const BufType& tmpSpace, size_t batch_size) const + { + auto& ht_stride = tmpSpace.getHiddenStateStride(); + auto& ht_size = tmpSpace.hStateSizes; + + // batch, gateBlock_elements + return miopen::TensorDescriptor{ + rnnDesc.dataType, {batch_size, ht_size[3]}, {ht_stride[1], ht_stride[3]}}; + } + + // 2 dims batch, vec + inline miopen::TensorDescriptor BuildHxCxDesc2D(size_t batch_size) const + { + const std::vector hx_size{batch_size, hiddenHxCxInfo.getHiddenSize()}; + const std::vector hx_stride{hiddenHxCxInfo.getStrides()[1], + hiddenHxCxInfo.getStrides()[2]}; + + return miopen::TensorDescriptor{rnnDesc.dataType, hx_size, hx_stride}; + } + + // 3 dims layer, batch, vec + inline miopen::TensorDescriptor BuildHxCxDesc3D(size_t layer_size, size_t batch_size) const + { + const std::vector hx_accum_size{ + layer_size, batch_size, hiddenHxCxInfo.getHiddenSize()}; + + return miopen::TensorDescriptor{ + rnnDesc.dataType, hx_accum_size, hiddenHxCxInfo.getStrides()}; + } + + // 3 dims layer, batch, vec + inline miopen::TensorDescriptor BuildTempDhtDesc3D(size_t layer_size, size_t batch_size) const + { + const std::vector dy_dhy_accum_size{ + layer_size, batch_size, hiddenHxCxInfo.getHiddenSize()}; + + const auto ws_dy_stride = [](const auto& ws_4dim_strides) -> std::vector { + // convert 4dim stride to 3 dim without direction + // TODO change hiddenBufferDesc + return std::vector{ws_4dim_strides[0], ws_4dim_strides[1], 
ws_4dim_strides[3]}; + }(workspaceInfo.getHiddenStateStride()); + + return miopen::TensorDescriptor{rnnDesc.dataType, dy_dhy_accum_size, ws_dy_stride}; + } + + // 3 dims layer, batch, vec + inline miopen::TensorDescriptor BuildWeiBiasDesc2D() const + { + const std::vector bias_size = [](const auto& wei_4dim_size) -> std::vector { + // wei_4dim_size{layer, dir, gate, vec} + return {1, wei_4dim_size[1] * wei_4dim_size[2] * wei_4dim_size[3]}; + }(weightsLayout.getBiasSize()); + + const auto bias_stride = [](const auto& wei_4dim_strides) -> std::vector { + // convert 4dim stride to 2 dim without direction + return std::vector{wei_4dim_strides[0], wei_4dim_strides[3]}; + }(weightsLayout.getBiasStride()); + + return miopen::TensorDescriptor{rnnDesc.dataType, bias_size, bias_stride}; + } +}; + +class RNNForwardDataModularAlgo : protected RNNModuleAlgoBase +{ +public: + // Compute API + // base API + void PrepareWriteBuffers(const Handle& handle, const runtimeArgsFwd& runtimeArgs) const; + + void PropX(const Handle& handle, const runtimeArgsFwd& runtimeArgs) const; + + void AddBias(const Handle& handle, const runtimeArgsFwd& runtimeArgs) const; + void PropHxCx(const Handle& handle, + const runtimeArgsFwd& runtimeArgs, + unsigned int layer, + const SequenceIterator& currentSeq, + SequenceDirection direction) const; + + void PropHiddenHt(const Handle& handle, + const runtimeArgsFwd& runtimeArgs, + int layer, + const SequenceIterator& currentSeq, + SequenceDirection direction) const; + + void UpdateHStatePerTimeSeq(const Handle& handle, + const runtimeArgsFwd& runtimeArgs, + int layer, + const SequenceIterator& seq, + SequenceDirection direction) const; + + void PropHyCy(const Handle& handle, + const runtimeArgsFwd& runtimeArgs, + size_t layer, + const SequenceIterator& currentSeq, + SequenceDirection direction) const; + + void PropHiddenY(const Handle& handle, + const runtimeArgsFwd& runtimeArgs, + size_t layer, + SequenceDirection direction) const; + + void PropY(const 
Handle& handle, const runtimeArgsFwd& runtimeArgs) const; + + // ext API + void PropX(const Handle& handle, + const runtimeArgsFwd& runtimeArgs, + size_t gemm_batch_offset, + size_t gemm_batch_size) const; + + void PropHiddenY(const Handle& handle, + const runtimeArgsFwd& runtimeArgs, + size_t layer, + SequenceDirection direction, + const SequenceIterator& firstSeq, + const SequenceIterator& lastSeq) const; + + void PropHiddenY(const Handle& handle, + const runtimeArgsFwd& runtimeArgs, + size_t layer, + SequenceDirection direction, + size_t gemm_batch_size, + size_t gemm_batch_offset) const; + + void PropX(const Handle& handle, + const runtimeArgsFwd& runtimeArgs, + SequenceDirection direction, + const SequenceIterator& firstSeq, + const SequenceIterator& lastSeq) const; + + void PropX(const Handle& handle, + const runtimeArgsFwd& runtimeArgs, + SequenceDirection direction) const; + + void PropX(const Handle& handle, + const runtimeArgsFwd& runtimeArgs, + SequenceDirection direction, + size_t gemm_batch_offset, + size_t gemm_batch_size) const; + + /// end compute API + + static bool IsApplicable() + { +#if MIOPEN_USE_GEMM && MIOPEN_BACKEND_HIP + return true; +#else + return false; +#endif // MIOPEN_USE_GEMM&& MIOPEN_BACKEND_HIP + } + + std::tuple getTempBuffersSize() const + { + + return std::make_tuple(workspaceInfo.getBufferSize() * GetTypeSize(rnnDesc.dataType), + reservLayout.getBufferSize() * GetTypeSize(rnnDesc.dataType)); + } + + static std::tuple getTempBuffersSize(const RNNDescriptor& rnnD, + const SeqTensorDescriptor& xDesc) + { + auto wsInfo = backwardInterimInfoBuilder(rnnD, xDesc); + auto reservInfo = forwardInterimInfoBuilder(rnnD, xDesc); + + return std::make_tuple(wsInfo.getBufferSize() * GetTypeSize(rnnD.dataType), + reservInfo.getBufferSize() * GetTypeSize(rnnD.dataType)); + } + + RNNForwardDataModularAlgo(RNNModuleAlgoBase base) : RNNModuleAlgoBase(std::move(base)) {} + +private: +}; + +class RNNBackwardDataModularAlgo : protected 
RNNModuleAlgoBase +{ +public: + void PrepareWriteBuffers(const Handle& handle, Data_t dhx, Data_t dcx, Data_t workSpace) const; + + void PropDhy(const Handle& handle, + ConstData_t dhy, + Data_t workSpace, + unsigned int layer, + const SequenceIterator& currentSeq, + SequenceDirection direction) const; + + void PropHiddenDht(const Handle& handle, + ConstData_t w, + Data_t workSpace, + int layer, + const SequenceIterator& currentSeq, + SequenceDirection direction) const; + + void UpdateHStatePerTimeSeq(const Handle& handle, + ConstData_t dcy, + ConstData_t cx, + Data_t, + Data_t workSpace, + Data_t reserveSpace, + int layer, + const SequenceIterator& seq, + SequenceDirection direction) const; + + void UpdateHStatePerTimeSeq(const Handle& handle, + ConstData_t dcy, + ConstData_t cx, + Data_t, + Data_t workSpace, + Data_t reserveSpace, + size_t batchSizeUpdate, + size_t useDcyIfGtBatch, + size_t useCxIfGTBatch, + int layer, + const SequenceIterator& seq, + SequenceDirection direction) const; + + void PropDhxDcx(const Handle& handle, + ConstData_t w, + Data_t dhx, + Data_t dcx, + Data_t workSpace, + Data_t reserveSpace, + size_t layer, + const SequenceIterator& currentSeq, + SequenceDirection direction) const; + + void PropDy(const Handle& handle, ConstData_t dy, Data_t workSpace) const; + + void PropHiddenDy(const Handle& handle, + ConstData_t w, + Data_t workSpace, + Data_t reserveSpace, + size_t layer, + SequenceDirection direction) const; + + void PropHiddenDy(const Handle& handle, + ConstData_t w, + Data_t workSpace, + Data_t reserveSpace, + size_t layer, + SequenceDirection direction, + const SequenceIterator& firstSeq, + const SequenceIterator& lastSeq) const; + + void PropHiddenDy(const Handle& handle, + ConstData_t w, + Data_t workSpace, + Data_t reserveSpace, + size_t layer, + SequenceDirection direction, + size_t gemm_batch_size, + size_t gemm_batch_offset) const; + + void PropDx(const Handle& handle, + ConstData_t w, + ConstData_t workSpace, + Data_t dx, + 
SequenceDirection direction, + const SequenceIterator& firstSeq, + const SequenceIterator& lastSeq) const; + + void PropDx(const Handle& handle, + ConstData_t w, + ConstData_t workSpace, + Data_t dx, + SequenceDirection direction) const; + + void PropDx(const Handle& handle, + ConstData_t w, + ConstData_t workSpace, + Data_t dx, + SequenceDirection direction, + size_t gemm_batch_offset, + size_t gemm_batch_size) const; + static bool IsApplicable() + { +#if MIOPEN_USE_GEMM && MIOPEN_BACKEND_HIP + return true; +#else + return false; +#endif // MIOPEN_USE_GEMM&& MIOPEN_BACKEND_HIP + } + + RNNBackwardDataModularAlgo(RNNModuleAlgoBase&& base) : RNNModuleAlgoBase(std::move(base)) {} +}; + +class RNNBackwardWeightsModularAlgo : public RNNModuleAlgoBase +{ +public: + void PrepareWriteBuffers(const Handle& handle, Data_t w) const; + + void PhisXInputWeights(const Handle& handle, Data_t dw, Data_t workSpace, ConstData_t x) const; + + void HiddenXInputWeights(const Handle& handle, + Data_t dw, + ConstData_t workSpace, + ConstData_t reserveSpace, + size_t layer) const; + + void BiasUpdate(const Handle& handle, + Data_t dw, + Data_t workSpace, + size_t layer, + size_t workSpaceSize) const; + + void HiddenHStateWeights(const Handle& handle, + Data_t dw, + ConstData_t workSpace, + ConstData_t reserveSpace, + const SequenceIterator& seq, + size_t layer, + SequenceDirection direction) const + { + const size_t gemm_batch_size = [&]() -> size_t { + if(seq.isFirst()) + return 0; + + if(direction == SequenceDirection::Reverse) + return batchController.getBatchSize(seq.getPhisVal()); + else + return batchController.getBatchSize(seq.getPrev().getPhisVal()); + }(); + + if(gemm_batch_size != 0) + return HiddenHStateWeights_Unchecked( + handle, dw, workSpace, reserveSpace, seq, layer, direction, gemm_batch_size); + } + + void HiddenHStateWeights(const Handle& handle, + Data_t dw, + ConstData_t workSpace, + ConstData_t reserveSpace, + size_t layer, + size_t max_seq_len, + const 
SequenceDirection direction) const + { + size_t start_seq_id = 0; + const size_t last_seq = max_seq_len - 1; + for(auto i = start_seq_id + 1; i <= last_seq; i++) + { + + if(batchController.getBatchSize(i) != batchController.getBatchSize(start_seq_id) || + i == last_seq) + { + const size_t gemm_batch_size = (batchController.getBatchSum(i - 1) - + batchController.getBatchSum(start_seq_id)) + + batchController.getBatchSize(i); + + if(gemm_batch_size != 0) + { + const auto first_logical_val = direction == SequenceDirection::Forward + ? start_seq_id + : (max_seq_len - 1) - start_seq_id - 1; + const auto seq = + SequenceIterator(first_logical_val, direction, max_seq_len, false); + + HiddenHStateWeights_Unchecked(handle, + dw, + workSpace, + reserveSpace, + seq, + layer, + direction, + gemm_batch_size); + } + start_seq_id = i; + } + } + } + + void PhisHStateWeights(const Handle& handle, + Data_t dw, + ConstData_t workSpace, + ConstData_t hx, + size_t layer, + size_t max_seq_len, + SequenceDirection direction) const + { + if(hx == nullptr) + return; + + for(auto i = max_seq_len; i > 0; i--) + { + const auto seq = SequenceIterator(i - 1, direction, max_seq_len, false); + + PhisHStateWeights(handle, dw, workSpace, hx, seq, layer, direction); + } + } + + static bool IsApplicable() + { +#if MIOPEN_USE_GEMM && MIOPEN_BACKEND_HIP + return true; +#else + return false; +#endif // MIOPEN_USE_GEMM&& MIOPEN_BACKEND_HIP + } + + std::tuple getTempBuffersSize() const + { + + return std::make_tuple(workspaceInfo.getBufferSize() * GetTypeSize(rnnDesc.dataType), + reservLayout.getBufferSize() * GetTypeSize(rnnDesc.dataType)); + } + + static std::tuple getTempBuffersSize(const RNNDescriptor& rnnD, + const SeqTensorDescriptor& xDesc) + { + auto wsInfo = backwardInterimInfoBuilder(rnnD, xDesc); + auto reservInfo = forwardInterimInfoBuilder(rnnD, xDesc); + + return std::make_tuple(wsInfo.getBufferSize() * GetTypeSize(rnnD.dataType), + reservInfo.getBufferSize() * GetTypeSize(rnnD.dataType)); + 
} + + RNNBackwardWeightsModularAlgo(RNNModuleAlgoBase base) : RNNModuleAlgoBase(std::move(base)) {} + +protected: + void HiddenHStateWeights_Unchecked(const Handle& handle, + Data_t dw, + ConstData_t workSpace, + ConstData_t reserveSpace, + const SequenceIterator& seq, + size_t layer, + SequenceDirection direction, + size_t gemm_batch_size) const; + + void PhisHStateWeights(const Handle& handle, + Data_t dw, + ConstData_t workSpace, + ConstData_t hx, + const SequenceIterator& seq, + size_t layer, + SequenceDirection direction) const; + + static size_t getHxBatchSizeReadAtTime(const SequenceIterator& seq, + const BatchController& batchInfo, + SequenceDirection direction) + { + if(seq.isLast()) + return batchInfo.getBatchSize(seq.getPhisVal()); + + if(direction == SequenceDirection::Reverse) + { + return batchInfo.getBatchSize(seq.getPhisVal()) - + batchInfo.getBatchSize(seq.getPrev().getPhisVal()); + } + return 0; + } +}; + +} // namespace rnn_base +} // namespace miopen diff --git a/src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp b/src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp new file mode 100644 index 0000000000..8bb99d4133 --- /dev/null +++ b/src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp @@ -0,0 +1,312 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#pragma once +#include + +#include "miopen/rnn/algorithms/default_algo_utils.hpp" + +namespace miopen { + +namespace rnn_base { + +inline SeqTensorDescriptor buildDynamicVirtual(const SeqTensorDescriptor& desc) +{ + std::vector def_layout{1, 0, 2}; + return {desc.GetType(), def_layout, desc.GetLengths(), false}; +} + +inline SeqTensorDescriptor buildRealToDynamicMapTmp(const SeqTensorDescriptor& desc) +{ + std::vector def_layout{1, 0, 2}; + return {desc.GetType(), + def_layout, + desc.GetLengths(), + desc.GetSequenceLengthsVector(), + std::vector{}, + true, + true}; +} + +class RNNModuleAlgoDynamic : public RNNForwardDataModularAlgo +{ + +public: + RNNModuleAlgoDynamic(const RNNDescriptor& rnnD, + const SeqTensorDescriptor& xTDesc, + const SeqTensorDescriptor& yTDesc, + const TensorDescriptor& hDesc, + miopenRNNFWDMode_t mode) + : RNNForwardDataModularAlgo(RNNModuleAlgoBase::create( + rnnD, buildDynamicVirtual(xTDesc), buildDynamicVirtual(yTDesc), hDesc, mode)), + realBatchController(BatchController::Create(xTDesc)), + realXDesc(xTDesc), + realYDesc(yTDesc), + tmpMapXDesc(buildRealToDynamicMapTmp(xTDesc)), + tmpMapYDesc(buildRealToDynamicMapTmp(yTDesc)) + { + } + + struct runtimeArgsFwdDynamicExt + { + const ConstData_t realX; + const Data_t tempX; + const ConstData_t hx; + const ConstData_t cx; + const Data_t realY; + const Data_t tempY; + const Data_t hy; + const Data_t cy; + const ConstData_t w; + const Data_t workSpace; + const Data_t reserveSpace; + }; + + runtimeArgsFwdDynamicExt createRuntimeArgsExt(const runtimeArgsFwd& runtimeArgs) const + { + const Data_t temp_x = + moveDataPtr(runtimeArgs.workSpace, workspaceInfo.getBufferSizeImpl(), rnnDesc.dataType); + + const Data_t temp_y = moveDataPtrByte(temp_x, tmpMapXDesc.GetTensorMaxByteSpace()); + + return { + runtimeArgs.x, + temp_x, + runtimeArgs.hx, + runtimeArgs.cx, + runtimeArgs.y, + temp_y, + runtimeArgs.hy, + 
runtimeArgs.cy, + runtimeArgs.w, + runtimeArgs.workSpace, + runtimeArgs.reserveSpace, + }; + } + + auto getTempBuffersSize() const + { + auto [ws_size, reserve_size] = RNNForwardDataModularAlgo::getTempBuffersSize(); + + return std::make_tuple(ws_size + tmpMapXDesc.GetTensorMaxByteSpace() + + tmpMapYDesc.GetTensorMaxByteSpace(), + reserve_size); + } + + static auto getTempBuffersSize(const RNNDescriptor& rnnD, const SeqTensorDescriptor& xDesc) + { + auto y_desc = [](const RNNDescriptor& rnnD, const SeqTensorDescriptor& xDesc) { + std::vector y_lenghts{xDesc.GetLengths()}; + y_lenghts[2] = rnnD.hsize * (rnnD.dirMode == miopenRNNbidirection ? 2 : 1); + return SeqTensorDescriptor{xDesc.GetType(), y_lenghts}; + }(rnnD, xDesc); + + auto temp_x_desc = buildDynamicVirtual(xDesc); + auto temp_y_desc = buildDynamicVirtual(y_desc); + + auto [ws_size, reserve_size] = + RNNForwardDataModularAlgo::getTempBuffersSize(rnnD, temp_x_desc); + + return std::make_tuple(ws_size + temp_x_desc.GetTensorMaxByteSpace() + + temp_y_desc.GetTensorMaxByteSpace(), + reserve_size); + } + + void realXProp(const Handle& handle, const runtimeArgsFwdDynamicExt& runtimeArgsExt) const; + + void realYProp(const Handle& handle, const runtimeArgsFwdDynamicExt& runtimeArgsExt) const; + + void PrepareWriteBuffers(const Handle& handle, + const runtimeArgsFwdDynamicExt& runtimeArgsExt, + const runtimeArgsFwd& runtimeArgs) const; + + void PropHyCy(const Handle& handle, + const runtimeArgsFwd& runtimeArgs, + size_t layer, + const SequenceIterator& currentSeq, + SequenceDirection direction) const; + +private: + BatchController realBatchController; + + SeqTensorDescriptor realXDesc; + SeqTensorDescriptor realYDesc; + SeqTensorDescriptor tmpMapXDesc; + SeqTensorDescriptor tmpMapYDesc; +}; + +class RNNBackwardModuleAlgoDynamic : public RNNBackwardDataModularAlgo +{ + using BaseBWDModuleT = rnn_base::RNNBackwardDataModularAlgo; + static SeqTensorDescriptor buildDynamicVirtual(const SeqTensorDescriptor& desc) + { + 
std::vector def_layout{1, 0, 2}; + return {desc.GetType(), def_layout, desc.GetLengths(), false}; + } + + static SeqTensorDescriptor buildRealToDynamicMapTmp(const SeqTensorDescriptor& desc) + { + std::vector def_layout{1, 0, 2}; + return {desc.GetType(), + def_layout, + desc.GetLengths(), + desc.GetSequenceLengthsVector(), + std::vector{}, + true, + true}; + } + +public: + RNNBackwardModuleAlgoDynamic(const RNNDescriptor& rnnD, + const SeqTensorDescriptor& xTDesc, + const SeqTensorDescriptor& yTDesc, + const TensorDescriptor& hDesc, + miopenRNNFWDMode_t mode) + : BaseBWDModuleT(RNNModuleAlgoBase::create( + rnnD, buildDynamicVirtual(xTDesc), buildDynamicVirtual(yTDesc), hDesc, mode)), + realBatchController(BatchController::Create(xTDesc)), + realDxDesc(xTDesc), + realDyDesc(yTDesc), + tmpMapDxDesc(buildRealToDynamicMapTmp(xTDesc)), + tmpMapDyDesc(buildRealToDynamicMapTmp(yTDesc)) + { + } + + struct runtimeArgsBwdDynamicExt + { + const ConstData_t realDy; + const Data_t tempDy; + const ConstData_t dhy; + const Data_t dhx; + const ConstData_t cx; + const ConstData_t dcy; + const Data_t dcx; + const Data_t realDx; + const Data_t tempDx; + const ConstData_t w; + const Data_t workSpace; + const Data_t reserveSpace; + }; + + runtimeArgsBwdDynamicExt createRuntimeArgsExt(const runtimeArgsBwd& runtimeArgs) const + { + const Data_t temp_dx = + moveDataPtr(runtimeArgs.workSpace, workspaceInfo.getBufferSizeImpl(), rnnDesc.dataType); + + const Data_t temp_dy = moveDataPtrByte(temp_dx, tmpMapDxDesc.GetTensorMaxByteSpace()); + + return { + runtimeArgs.dy, + temp_dy, + runtimeArgs.dhy, + runtimeArgs.dhx, + runtimeArgs.cx, + runtimeArgs.dcy, + runtimeArgs.dcx, + runtimeArgs.dx, + temp_dx, + runtimeArgs.w, + runtimeArgs.workSpace, + runtimeArgs.reserveSpace, + }; + } + + auto getTempBuffersSize() const + { + auto [ws_size, reserve_size] = BaseBWDModuleT::getTempBuffersSize(); + + return std::make_tuple(ws_size + tmpMapDxDesc.GetTensorMaxByteSpace() + + 
tmpMapDyDesc.GetTensorMaxByteSpace(), + reserve_size); + } + + static auto getTempBuffersSize(const RNNDescriptor& rnnD, const SeqTensorDescriptor& xDesc) + { + auto y_desc = [](const RNNDescriptor& rnnD, const SeqTensorDescriptor& xDesc) { + std::vector y_lenghts{xDesc.GetLengths()}; + y_lenghts[2] = rnnD.hsize * (rnnD.dirMode == miopenRNNbidirection ? 2 : 1); + return SeqTensorDescriptor{xDesc.GetType(), y_lenghts}; + }(rnnD, xDesc); + + auto temp_x_desc = buildDynamicVirtual(xDesc); + auto temp_y_desc = buildDynamicVirtual(y_desc); + + auto [ws_size, reserve_size] = + RNNForwardDataModularAlgo::getTempBuffersSize(rnnD, temp_x_desc); + + return std::make_tuple(ws_size + temp_x_desc.GetTensorMaxByteSpace() + + temp_y_desc.GetTensorMaxByteSpace(), + reserve_size); + } + + void realDxProp(const Handle& handle, const runtimeArgsBwdDynamicExt& runtimeArgsExt) const; + + void realDyProp(const Handle& handle, const runtimeArgsBwdDynamicExt& runtimeArgsExt) const; + + void realPropDhy(const Handle& handle, + ConstData_t dhy, + Data_t workSpace, + unsigned int layer, + const SequenceIterator& currentSeq, + SequenceDirection direction) const; + + void realUpdateHStatePerTimeSeq(const Handle& handle, + ConstData_t dcy, + ConstData_t cx, + Data_t, + Data_t workSpace, + Data_t reserveSpace, + int layer, + const SequenceIterator& seq, + SequenceDirection direction) const; + + void PrepareWriteBuffers(const Handle& handle, + const runtimeArgsBwdDynamicExt& runtimeArgsExt) const; + + void HtHiddenDataZeroing() const; + + // void PrepareWriteBuffers(const Handle& handle, + // const runtimeArgsBwdDynamicExt& runtimeArgsExt, + // const runtimeArgsFwd& runtimeArgs) const; + + void PropHyCy(const Handle& handle, + const runtimeArgsFwd& runtimeArgs, + size_t layer, + const SequenceIterator& currentSeq, + SequenceDirection direction) const; + +private: + BatchController realBatchController; + + SeqTensorDescriptor realDxDesc; + SeqTensorDescriptor realDyDesc; + SeqTensorDescriptor 
tmpMapDxDesc; + SeqTensorDescriptor tmpMapDyDesc; +}; + + +} // namespace rnn_base +} // namespace miopen diff --git a/src/include/miopen/rnn/solvers.hpp b/src/include/miopen/rnn/solvers.hpp index 9e745f4cef..f575bdbfaf 100644 --- a/src/include/miopen/rnn/solvers.hpp +++ b/src/include/miopen/rnn/solvers.hpp @@ -30,764 +30,17 @@ #include #include "miopen/rnn/tmp_buffer_utils.hpp" -namespace miopen { - -namespace rnn_base { - -struct runtimeArgsFwd -{ - const ConstData_t x; - const ConstData_t hx; - const ConstData_t cx; - const Data_t y; - const Data_t hy; - const Data_t cy; - const ConstData_t w; - const Data_t workSpace; - const Data_t reserveSpace; -}; - -struct runtimeArgsBwd -{ - const Handle* handle; - const ConstData_t dy; - const ConstData_t dhy; - const Data_t dhx; - const ConstData_t cx; - const ConstData_t dcy; - const Data_t dcx; - const Data_t dx; - const ConstData_t w; - const Data_t workSpace; - const Data_t reserveSpace; -}; - -class RNNModuleAlgoBase -{ -protected: - static GeneralLstmTempBuffer backwardInterimInfoBuilder(const RNNDescriptor& rnnDesc, - const SeqTensorDescriptor& xDesc) - { - auto layers_cnt = static_cast(rnnDesc.nLayers); - const size_t seq_directions = rnnDesc.dirMode == miopenRNNbidirection ? 2 : 1; - auto hidden_vec_sz = rnnDesc.hsize; - - return GeneralLstmTempBuffer::build( - layers_cnt, xDesc.GetTotalSequenceLen(), seq_directions, hidden_vec_sz); - } - - static GeneralLstmRedBuffer forwardInterimInfoBuilder(const RNNDescriptor& rnnDesc, - const SeqTensorDescriptor& xDesc) - { - auto layers_cnt = static_cast(rnnDesc.nLayers); - const size_t seq_directions = rnnDesc.dirMode == miopenRNNbidirection ? 
2 : 1; - auto hidden_vec_sz = rnnDesc.hsize; - - return GeneralLstmRedBuffer::build( - layers_cnt, xDesc.GetTotalSequenceLen(), seq_directions, hidden_vec_sz); - } - -public: - static RNNModuleAlgoBase create(const RNNDescriptor& rnnDesc, - const SeqTensorDescriptor& xDesc, - const SeqTensorDescriptor& yDesc, - const TensorDescriptor& hDesc, - miopenRNNFWDMode_t mode) - { - auto [max_layers_hid, max_batch_hid, hidden_vec_sz] = miopen::tien<3>(hDesc.GetLengths()); - auto [max_batch_in, max_seq, input_vec_sz] = miopen::tien<3>(xDesc.GetLengths()); - - assert(max_batch_in <= max_batch_hid); - - auto layers_cnt = static_cast(rnnDesc.nLayers); - const bool is_seq_bidir = rnnDesc.dirMode == miopenRNNbidirection; - - assert(static_cast(layers_cnt) * (is_seq_bidir ? 2 : 1) <= max_layers_hid); - - auto gates_cnt = static_cast(rnnDesc.nHiddenTensorsPerLayer); - - // class update req - assert(!is_seq_bidir); - - // TODO all size_t - GeneralLstmRedBuffer rb_layout = forwardInterimInfoBuilder(rnnDesc, xDesc); - - GeneralLstmTempBuffer workspace_info = backwardInterimInfoBuilder(rnnDesc, xDesc); - - WeightsBufferDescriptor weights_layout = - WeightsBufferDescriptor::create(static_cast(input_vec_sz), - static_cast(hidden_vec_sz), - layers_cnt, - rnnDesc.biasMode, - rnnDesc.inputMode, - gates_cnt, - is_seq_bidir); - - BatchController batch_controller = BatchController::Create(xDesc); - - HiddenBuffersDescriptor hidden_hxcx_info{hDesc}; - - IOBufferDescriptor x_info{IOBufferDescriptor::build(xDesc)}; - IOBufferDescriptor y_info{IOBufferDescriptor::build(yDesc)}; - - return {std::move(rb_layout), - workspace_info, - weights_layout, - hidden_hxcx_info, - x_info, - y_info, - rnnDesc, - batch_controller, - mode}; - } - - RNNModuleAlgoBase(RNNModuleAlgoBase&&) = default; - // RNNModuleAlgoBase(RNNModuleAlgoBase const&) = default; - - RNNModuleAlgoBase(GeneralLstmRedBuffer rb_layout, - GeneralLstmTempBuffer workspace_info, - WeightsBufferDescriptor weights_layout, - 
HiddenBuffersDescriptor hidden_hxcx_info, - IOBufferDescriptor x_info, - IOBufferDescriptor y_info, - const RNNDescriptor& rnn_desc, - BatchController batch_controller, - miopenRNNFWDMode_t fwd_mode) - : reservLayout(std::move(rb_layout)), - workspaceInfo(std::move(workspace_info)), - weightsLayout(std::move(weights_layout)), - hiddenHxCxInfo(std::move(hidden_hxcx_info)), - xInfo(std::move(x_info)), - yInfo(std::move(y_info)), - rnnDesc(rnn_desc), - tanhDesc{miopenActivationTANH, 1, 1, 1}, - sigDesc{miopenActivationLOGISTIC, 1, 0, 1}, - reluDesc{miopenActivationRELU, 1, 0, 1}, - batchController(std::move(batch_controller)), - fwdMode(fwd_mode), - isBidirectSeq(false) - { - } - - const GeneralLstmRedBuffer reservLayout; - // const WorkspaceBufferDescriptor workspaceInfo; - const GeneralLstmTempBuffer workspaceInfo; - - const WeightsBufferDescriptor weightsLayout; - const HiddenBuffersDescriptor hiddenHxCxInfo; - - const IOBufferDescriptor xInfo; - const IOBufferDescriptor yInfo; - - const RNNDescriptor& rnnDesc; - - const ActivationDescriptor tanhDesc; - const ActivationDescriptor sigDesc; - const ActivationDescriptor reluDesc; - - const BatchController batchController; - - const miopenRNNFWDMode_t fwdMode; - - const bool isBidirectSeq; - - std::tuple getTempBuffersSize() const - { - - return std::make_tuple(workspaceInfo.getBufferSize() * GetTypeSize(rnnDesc.dataType), - reservLayout.getBufferSize() * GetTypeSize(rnnDesc.dataType)); - } - - static std::tuple getTempBuffersSize(const RNNDescriptor& rnnD, - const SeqTensorDescriptor& xDesc) - { - auto wsInfo = backwardInterimInfoBuilder(rnnD, xDesc); - auto reservInfo = forwardInterimInfoBuilder(rnnD, xDesc); - - return std::make_tuple(wsInfo.getBufferSize() * GetTypeSize(rnnD.dataType), - reservInfo.getBufferSize() * GetTypeSize(rnnD.dataType)); - } - - inline size_t getVirtualLayer(const size_t layer_id, SequenceDirection direction) const - { - return layer_id * (isBidirectSeq ? 
2 : 1) + - (direction == SequenceDirection::Forward ? 0 : 1); - } - - template - inline miopen::TensorDescriptor BuildLstmTmpBlockDesc2D(const BufType& buf_info, - const size_t batch_size) const - { - const std::array& tmp_block_stride = buf_info.getGateBlockStride(); - const std::array& tmp_block_size = buf_info.getGateBlockSize(); - - // batch, gateBlock_elements - return miopen::TensorDescriptor{rnnDesc.dataType, - {batch_size, tmp_block_size[3]}, - {tmp_block_stride[1], tmp_block_stride[3]}}; - } - - inline miopen::TensorDescriptor BuildLstmFilterXDesc2D(int layer_id) const - { - assert(rnnDesc.inputMode == 0 || layer_id != 0); - // TODO replace by stride - auto x_vec = layer_id != 0 ? weightsLayout.xInVec : weightsLayout.inVec; - - // gateBlock_elements, ht_vec - return miopen::TensorDescriptor{ - rnnDesc.dataType, {weightsLayout.gatesCnt * weightsLayout.hVec, x_vec}, {x_vec, 1}}; - } - - inline miopen::TensorDescriptor BuildLstmFilterHidDesc2D() const - { - // TODO replace by stride - auto h_vec = weightsLayout.hVec; - - // gateBlock_elements, ht_vec - return miopen::TensorDescriptor{ - rnnDesc.dataType, {weightsLayout.gatesCnt * weightsLayout.hVec, h_vec}, {h_vec, 1}}; - } - - inline miopen::TensorDescriptor BuildWsHtDesc2D(size_t batch_size) const - { - auto& ht_stride = workspaceInfo.getHiddenStateStride(); - auto& ht_size = workspaceInfo.hStateSizes; - - // batch, gateBlock_elements - return miopen::TensorDescriptor{ - rnnDesc.dataType, {batch_size, ht_size[3]}, {ht_stride[1], ht_stride[3]}}; - } - - // 2 dims batch, vec - inline miopen::TensorDescriptor BuildHxCxDesc2D(size_t batch_size) const - { - const std::vector hx_size{batch_size, hiddenHxCxInfo.getHiddenSize()}; - const std::vector hx_stride{hiddenHxCxInfo.getStrides()[1], - hiddenHxCxInfo.getStrides()[2]}; - - return miopen::TensorDescriptor{rnnDesc.dataType, hx_size, hx_stride}; - } - - // 3 dims layer, batch, vec - inline miopen::TensorDescriptor BuildHxCxDesc3D(size_t layer_size, size_t 
batch_size) const - { - const std::vector hx_accum_size{ - layer_size, batch_size, hiddenHxCxInfo.getHiddenSize()}; - - return miopen::TensorDescriptor{ - rnnDesc.dataType, hx_accum_size, hiddenHxCxInfo.getStrides()}; - } +#include "miopen/rnn/algorithms/default_algo_utils.hpp" +#include "miopen/rnn/algorithms/dynamic_algo_utils.hpp" - // 3 dims layer, batch, vec - inline miopen::TensorDescriptor BuildTempDhtDesc3D(size_t layer_size, size_t batch_size) const - { - const std::vector dy_dhy_accum_size{ - layer_size, batch_size, hiddenHxCxInfo.getHiddenSize()}; - - const auto ws_dy_stride = [](const auto& ws_4dim_strides) -> std::vector { - // convert 4dim stride to 3 dim without direction - // TODO change hiddenBufferDesc - return std::vector{ws_4dim_strides[0], ws_4dim_strides[1], ws_4dim_strides[3]}; - }(workspaceInfo.getHiddenStateStride()); - - return miopen::TensorDescriptor{rnnDesc.dataType, dy_dhy_accum_size, ws_dy_stride}; - } -}; - -class RNNForwardDataModularAlgo : protected RNNModuleAlgoBase -{ -public: - // Compute API - // base API - void PrepareWriteBuffers(const Handle& handle, const runtimeArgsFwd& runtimeArgs) const; - - void PropX(const Handle& handle, const runtimeArgsFwd& runtimeArgs) const; - - void AddBias(const Handle& handle, const runtimeArgsFwd& runtimeArgs) const; - void PropHxCx(const Handle& handle, - const runtimeArgsFwd& runtimeArgs, - unsigned int layer, - const SequenceIterator& currentSeq, - SequenceDirection direction) const; - - void PropHiddenHt(const Handle& handle, - const runtimeArgsFwd& runtimeArgs, - int layer, - const SequenceIterator& currentSeq, - SequenceDirection direction) const; - - void UpdateHStatePerTimeSeq(const Handle& handle, - const runtimeArgsFwd& runtimeArgs, - int layer, - const SequenceIterator& seq, - SequenceDirection direction) const; - - void PropHyCy(const Handle& handle, - const runtimeArgsFwd& runtimeArgs, - size_t layer, - const SequenceIterator& currentSeq, - SequenceDirection direction) const; - - 
void PropHiddenY(const Handle& handle, - const runtimeArgsFwd& runtimeArgs, - size_t layer, - SequenceDirection direction) const; - - void PropY(const Handle& handle, const runtimeArgsFwd& runtimeArgs) const; - - // ext API - void PropX(const Handle& handle, - const runtimeArgsFwd& runtimeArgs, - size_t gemm_batch_offset, - size_t gemm_batch_size) const; - - void PropHiddenY(const Handle& handle, - const runtimeArgsFwd& runtimeArgs, - size_t layer, - SequenceDirection direction, - const SequenceIterator& firstSeq, - const SequenceIterator& lastSeq) const; - - void PropHiddenY(const Handle& handle, - const runtimeArgsFwd& runtimeArgs, - size_t layer, - SequenceDirection direction, - size_t gemm_batch_size, - size_t gemm_batch_offset) const; - - void PropX(const Handle& handle, - const runtimeArgsFwd& runtimeArgs, - SequenceDirection direction, - const SequenceIterator& firstSeq, - const SequenceIterator& lastSeq) const; - - void PropX(const Handle& handle, - const runtimeArgsFwd& runtimeArgs, - SequenceDirection direction) const; - - void PropX(const Handle& handle, - const runtimeArgsFwd& runtimeArgs, - SequenceDirection direction, - size_t gemm_batch_offset, - size_t gemm_batch_size) const; - - /// end compute API - - static bool IsApplicable() - { -#if MIOPEN_USE_GEMM && MIOPEN_BACKEND_HIP - return true; -#else - return false; -#endif // MIOPEN_USE_GEMM&& MIOPEN_BACKEND_HIP - } - - std::tuple getTempBuffersSize() const - { - - return std::make_tuple(workspaceInfo.getBufferSize() * GetTypeSize(rnnDesc.dataType), - reservLayout.getBufferSize() * GetTypeSize(rnnDesc.dataType)); - } - - static std::tuple getTempBuffersSize(const RNNDescriptor& rnnD, - const SeqTensorDescriptor& xDesc) - { - auto wsInfo = backwardInterimInfoBuilder(rnnD, xDesc); - auto reservInfo = forwardInterimInfoBuilder(rnnD, xDesc); - - return std::make_tuple(wsInfo.getBufferSize() * GetTypeSize(rnnD.dataType), - reservInfo.getBufferSize() * GetTypeSize(rnnD.dataType)); - } - - 
RNNForwardDataModularAlgo(RNNModuleAlgoBase base) : RNNModuleAlgoBase(std::move(base)) {} - -private: -}; - -class RNNModuleAlgoDynamic : public RNNForwardDataModularAlgo -{ - static SeqTensorDescriptor buildDynamicVirtual(const SeqTensorDescriptor& desc) - { - std::vector def_layout{1, 0, 2}; - return {desc.GetType(), def_layout, desc.GetLengths(), false}; - } - - static SeqTensorDescriptor buildRealToDynamicMapTmp(const SeqTensorDescriptor& desc) - { - std::vector def_layout{1, 0, 2}; - return {desc.GetType(), - def_layout, - desc.GetLengths(), - desc.GetSequenceLengthsVector(), - std::vector{}, - true, - true}; - } - -public: - RNNModuleAlgoDynamic(const RNNDescriptor& rnnD, - const SeqTensorDescriptor& xTDesc, - const SeqTensorDescriptor& yTDesc, - const TensorDescriptor& hDesc, - miopenRNNFWDMode_t mode) - : RNNForwardDataModularAlgo(RNNModuleAlgoBase::create( - rnnD, buildDynamicVirtual(xTDesc), buildDynamicVirtual(yTDesc), hDesc, mode)), - realBatchController(BatchController::Create(xTDesc)), - realXDesc(xTDesc), - realYDesc(yTDesc), - tmpMapXDesc(buildRealToDynamicMapTmp(xTDesc)), - tmpMapYDesc(buildRealToDynamicMapTmp(yTDesc)) - { - } - - struct runtimeArgsFwdDynamicExt - { - const ConstData_t realX; - const Data_t tempX; - const ConstData_t hx; - const ConstData_t cx; - const Data_t realY; - const Data_t tempY; - const Data_t hy; - const Data_t cy; - const ConstData_t w; - const Data_t workSpace; - const Data_t reserveSpace; - }; - - runtimeArgsFwdDynamicExt createRuntimeArgsExt(const runtimeArgsFwd& runtimeArgs) const - { - const Data_t temp_x = - moveDataPtr(runtimeArgs.workSpace, workspaceInfo.getBufferSizeImpl(), rnnDesc.dataType); - - const Data_t temp_y = moveDataPtrByte(temp_x, tmpMapXDesc.GetTensorMaxByteSpace()); - - return { - runtimeArgs.x, - temp_x, - runtimeArgs.hx, - runtimeArgs.cx, - runtimeArgs.y, - temp_y, - runtimeArgs.hy, - runtimeArgs.cy, - runtimeArgs.w, - runtimeArgs.workSpace, - runtimeArgs.reserveSpace, - }; - } - - auto 
getTempBuffersSize() const - { - auto [ws_size, reserve_size] = RNNForwardDataModularAlgo::getTempBuffersSize(); - - return std::make_tuple(ws_size + tmpMapXDesc.GetTensorMaxByteSpace() + - tmpMapYDesc.GetTensorMaxByteSpace(), - reserve_size); - } - - static auto getTempBuffersSize(const RNNDescriptor& rnnD, const SeqTensorDescriptor& xDesc) - { - auto y_desc = [](const RNNDescriptor& rnnD, const SeqTensorDescriptor& xDesc) { - std::vector y_lenghts{xDesc.GetLengths()}; - y_lenghts[2] = rnnD.hsize * (rnnD.dirMode == miopenRNNbidirection ? 2 : 1); - return SeqTensorDescriptor{xDesc.GetType(), y_lenghts}; - }(rnnD, xDesc); - - auto temp_x_desc = buildDynamicVirtual(xDesc); - auto temp_y_desc = buildDynamicVirtual(y_desc); - - auto [ws_size, reserve_size] = - RNNForwardDataModularAlgo::getTempBuffersSize(rnnD, temp_x_desc); - - return std::make_tuple(ws_size + temp_x_desc.GetTensorMaxByteSpace() + - temp_y_desc.GetTensorMaxByteSpace(), - reserve_size); - } - - void realXProp(const Handle& handle, const runtimeArgsFwdDynamicExt& runtimeArgsExt) const; - - void realYProp(const Handle& handle, const runtimeArgsFwdDynamicExt& runtimeArgsExt) const; - - void PrepareWriteBuffers(const Handle& handle, - const runtimeArgsFwdDynamicExt& runtimeArgsExt, - const runtimeArgsFwd& runtimeArgs) const; - - void PropHyCy(const Handle& handle, - const runtimeArgsFwd& runtimeArgs, - size_t layer, - const SequenceIterator& currentSeq, - SequenceDirection direction) const; - -private: - BatchController realBatchController; - - SeqTensorDescriptor realXDesc; - SeqTensorDescriptor realYDesc; - SeqTensorDescriptor tmpMapXDesc; - SeqTensorDescriptor tmpMapYDesc; -}; - -class RNNBackwardDataModularAlgo : protected RNNModuleAlgoBase -{ -public: - void PrepareWriteBuffers(const Handle& handle, Data_t dhx, Data_t dcx, Data_t workSpace) const; - - void PropDhy(const Handle& handle, - ConstData_t dhy, - Data_t workSpace, - unsigned int layer, - const SequenceIterator& currentSeq, - 
SequenceDirection direction) const; - void PropHiddenDht(const Handle& handle, - ConstData_t w, - Data_t workSpace, - int layer, - const SequenceIterator& currentSeq, - SequenceDirection direction) const; - - void UpdateHStatePerTimeSeq(const Handle& handle, - ConstData_t dcy, - ConstData_t cx, - Data_t, - Data_t workSpace, - Data_t reserveSpace, - int layer, - const SequenceIterator& seq, - SequenceDirection direction) const; - - void UpdateHStatePerTimeSeq(const Handle& handle, - ConstData_t dcy, - ConstData_t cx, - Data_t, - Data_t workSpace, - Data_t reserveSpace, - size_t batchSizeUpdate, - size_t useDcyIfGtBatch, - size_t useCxIfGTBatch, - int layer, - const SequenceIterator& seq, - SequenceDirection direction) const; - - void PropDhxDcx(const Handle& handle, - ConstData_t w, - Data_t dhx, - Data_t dcx, - Data_t workSpace, - Data_t reserveSpace, - size_t layer, - const SequenceIterator& currentSeq, - SequenceDirection direction) const; - - void PropDy(const Handle& handle, ConstData_t dy, Data_t workSpace) const; - - void PropHiddenDy(const Handle& handle, - ConstData_t w, - Data_t workSpace, - Data_t reserveSpace, - size_t layer, - SequenceDirection direction) const; - - void PropHiddenDy(const Handle& handle, - ConstData_t w, - Data_t workSpace, - Data_t reserveSpace, - size_t layer, - SequenceDirection direction, - const SequenceIterator& firstSeq, - const SequenceIterator& lastSeq) const; - - void PropHiddenDy(const Handle& handle, - ConstData_t w, - Data_t workSpace, - Data_t reserveSpace, - size_t layer, - SequenceDirection direction, - size_t gemm_batch_size, - size_t gemm_batch_offset) const; - - void PropDx(const Handle& handle, - ConstData_t w, - ConstData_t workSpace, - Data_t dx, - SequenceDirection direction, - const SequenceIterator& firstSeq, - const SequenceIterator& lastSeq) const; - - void PropDx(const Handle& handle, - ConstData_t w, - ConstData_t workSpace, - Data_t dx, - SequenceDirection direction) const; - - void PropDx(const Handle& 
handle, - ConstData_t w, - ConstData_t workSpace, - Data_t dx, - SequenceDirection direction, - size_t gemm_batch_offset, - size_t gemm_batch_size) const; - static bool IsApplicable() - { -#if MIOPEN_USE_GEMM && MIOPEN_BACKEND_HIP - return true; -#else - return false; -#endif // MIOPEN_USE_GEMM&& MIOPEN_BACKEND_HIP - } - - RNNBackwardDataModularAlgo(RNNModuleAlgoBase&& base) : RNNModuleAlgoBase(std::move(base)) {} -}; - -class RNNBackwardModuleAlgoDynamic : public RNNBackwardDataModularAlgo -{ - using BaseBWDModuleT = rnn_base::RNNBackwardDataModularAlgo; - static SeqTensorDescriptor buildDynamicVirtual(const SeqTensorDescriptor& desc) - { - std::vector def_layout{1, 0, 2}; - return {desc.GetType(), def_layout, desc.GetLengths(), false}; - } - - static SeqTensorDescriptor buildRealToDynamicMapTmp(const SeqTensorDescriptor& desc) - { - std::vector def_layout{1, 0, 2}; - return {desc.GetType(), - def_layout, - desc.GetLengths(), - desc.GetSequenceLengthsVector(), - std::vector{}, - true, - true}; - } - -public: - RNNBackwardModuleAlgoDynamic(const RNNDescriptor& rnnD, - const SeqTensorDescriptor& xTDesc, - const SeqTensorDescriptor& yTDesc, - const TensorDescriptor& hDesc, - miopenRNNFWDMode_t mode) - : BaseBWDModuleT(RNNModuleAlgoBase::create( - rnnD, buildDynamicVirtual(xTDesc), buildDynamicVirtual(yTDesc), hDesc, mode)), - realBatchController(BatchController::Create(xTDesc)), - realDxDesc(xTDesc), - realDyDesc(yTDesc), - tmpMapDxDesc(buildRealToDynamicMapTmp(xTDesc)), - tmpMapDyDesc(buildRealToDynamicMapTmp(yTDesc)) - { - } - - struct runtimeArgsBwdDynamicExt - { - const ConstData_t realDy; - const Data_t tempDy; - const ConstData_t dhy; - const Data_t dhx; - const ConstData_t cx; - const ConstData_t dcy; - const Data_t dcx; - const Data_t realDx; - const Data_t tempDx; - const ConstData_t w; - const Data_t workSpace; - const Data_t reserveSpace; - }; - - runtimeArgsBwdDynamicExt createRuntimeArgsExt(const runtimeArgsBwd& runtimeArgs) const - { - const Data_t 
temp_dx = - moveDataPtr(runtimeArgs.workSpace, workspaceInfo.getBufferSizeImpl(), rnnDesc.dataType); - - const Data_t temp_dy = moveDataPtrByte(temp_dx, tmpMapDxDesc.GetTensorMaxByteSpace()); - - return { - runtimeArgs.dy, - temp_dy, - runtimeArgs.dhy, - runtimeArgs.dhx, - runtimeArgs.cx, - runtimeArgs.dcy, - runtimeArgs.dcx, - runtimeArgs.dx, - temp_dx, - runtimeArgs.w, - runtimeArgs.workSpace, - runtimeArgs.reserveSpace, - }; - } - - auto getTempBuffersSize() const - { - auto [ws_size, reserve_size] = BaseBWDModuleT::getTempBuffersSize(); - - return std::make_tuple(ws_size + tmpMapDxDesc.GetTensorMaxByteSpace() + - tmpMapDyDesc.GetTensorMaxByteSpace(), - reserve_size); - } - - static auto getTempBuffersSize(const RNNDescriptor& rnnD, const SeqTensorDescriptor& xDesc) - { - auto y_desc = [](const RNNDescriptor& rnnD, const SeqTensorDescriptor& xDesc) { - std::vector y_lenghts{xDesc.GetLengths()}; - y_lenghts[2] = rnnD.hsize * (rnnD.dirMode == miopenRNNbidirection ? 2 : 1); - return SeqTensorDescriptor{xDesc.GetType(), y_lenghts}; - }(rnnD, xDesc); - - auto temp_x_desc = buildDynamicVirtual(xDesc); - auto temp_y_desc = buildDynamicVirtual(y_desc); - - auto [ws_size, reserve_size] = - RNNForwardDataModularAlgo::getTempBuffersSize(rnnD, temp_x_desc); - - return std::make_tuple(ws_size + temp_x_desc.GetTensorMaxByteSpace() + - temp_y_desc.GetTensorMaxByteSpace(), - reserve_size); - } - - void realDxProp(const Handle& handle, const runtimeArgsBwdDynamicExt& runtimeArgsExt) const; - - void realDyProp(const Handle& handle, const runtimeArgsBwdDynamicExt& runtimeArgsExt) const; - - void realPropDhy(const Handle& handle, - ConstData_t dhy, - Data_t workSpace, - unsigned int layer, - const SequenceIterator& currentSeq, - SequenceDirection direction) const; - - void realUpdateHStatePerTimeSeq(const Handle& handle, - ConstData_t dcy, - ConstData_t cx, - Data_t, - Data_t workSpace, - Data_t reserveSpace, - int layer, - const SequenceIterator& seq, - SequenceDirection 
direction) const; - - void PrepareWriteBuffers(const Handle& handle, - const runtimeArgsBwdDynamicExt& runtimeArgsExt) const; - - void HtHiddenDataZeroing() const; - - // void PrepareWriteBuffers(const Handle& handle, - // const runtimeArgsBwdDynamicExt& runtimeArgsExt, - // const runtimeArgsFwd& runtimeArgs) const; - - void PropHyCy(const Handle& handle, - const runtimeArgsFwd& runtimeArgs, - size_t layer, - const SequenceIterator& currentSeq, - SequenceDirection direction) const; +namespace miopen { -private: - BatchController realBatchController; +namespace rnn_base { - SeqTensorDescriptor realDxDesc; - SeqTensorDescriptor realDyDesc; - SeqTensorDescriptor tmpMapDxDesc; - SeqTensorDescriptor tmpMapDyDesc; -}; +// +// Forward data +// class RNNModularSingleStreamFWD { @@ -861,6 +114,10 @@ class RNNDynamicModularSingleStreamFWD const size_t max_seq_len; }; +// +// Backward Data +// + class RNNModularSingleStreamBWD { public: @@ -993,159 +250,21 @@ class RNNModularMultiStreamBWD const size_t max_seq_len; }; -class RNNBackwardWeightsModularAlgo +// +//Backward Weights +// + +class RNNModularSingleStreamBWWeights { public: - static RNNBackwardWeightsModularAlgo create(const RNNDescriptor& rnnDesc, + RNNModularSingleStreamBWWeights(const RNNDescriptor& rnn, const SeqTensorDescriptor& xDesc, const SeqTensorDescriptor& yDesc, const TensorDescriptor& hDesc) + : rnnAlgoModules(RNNModuleAlgoBase::create(rnn, xDesc, yDesc, hDesc, miopenRNNTraining)), + rnnDesc(rnn), + max_seq_len(xDesc.GetMaxSequenceLength()) { - auto [max_layers_hid, max_batch_hid, hidden_vec_sz] = miopen::tien<3>(hDesc.GetLengths()); - auto [max_batch_in, max_seq, input_vec_sz] = miopen::tien<3>(xDesc.GetLengths()); - - assert(max_batch_in <= max_batch_hid); - - size_t layers_cnt = rnnDesc.nLayers; - const bool is_seq_bidir = rnnDesc.dirMode == miopenRNNbidirection; - - assert(layers_cnt * (is_seq_bidir ? 
2 : 1) <= max_layers_hid); - - auto gates_cnt = static_cast(rnnDesc.nHiddenTensorsPerLayer); - - // class update req - assert(!is_seq_bidir); - const size_t seq_directions = is_seq_bidir ? 2 : 1; - - GeneralLstmRedBuffer rb_layout = GeneralLstmRedBuffer::build( - layers_cnt, xDesc.GetTotalSequenceLen(), seq_directions, hidden_vec_sz); - - GeneralLstmTempBuffer workspace_info = GeneralLstmTempBuffer::build( - layers_cnt, xDesc.GetTotalSequenceLen(), seq_directions, hidden_vec_sz); - - WeightsBufferDescriptor weights_layout = WeightsBufferDescriptor::create(input_vec_sz, - hidden_vec_sz, - layers_cnt, - rnnDesc.biasMode, - rnnDesc.inputMode, - gates_cnt, - is_seq_bidir); - - BatchController batch_controller = BatchController::Create(xDesc); - - HiddenBuffersDescriptor hidden_hxcx_info{hDesc}; - - IOBufferDescriptor x_info{IOBufferDescriptor::build(xDesc)}; - IOBufferDescriptor y_info{IOBufferDescriptor::build(yDesc)}; - - return {std::move(rb_layout), - workspace_info, - weights_layout, - hidden_hxcx_info, - x_info, - y_info, - rnnDesc, - batch_controller}; - } - - void PrepareWriteBuffers(const Handle& handle, Data_t w) const; - - void PhisXInputWeights(const Handle& handle, Data_t dw, Data_t workSpace, ConstData_t x) const; - - void HiddenXInputWeights(const Handle& handle, - Data_t dw, - ConstData_t workSpace, - ConstData_t reserveSpace, - size_t layer) const; - - void BiasUpdate(const Handle& handle, - Data_t dw, - Data_t workSpace, - size_t layer, - size_t workSpaceSize) const; - - void HiddenHStateWeights(const Handle& handle, - Data_t dw, - ConstData_t workSpace, - ConstData_t reserveSpace, - const SequenceIterator& seq, - size_t layer, - SequenceDirection direction) const - { - const size_t gemm_batch_size = [&]() -> size_t { - if(seq.isFirst()) - return 0; - - if(direction == SequenceDirection::Reverse) - return batchController.getBatchSize(seq.getPhisVal()); - else - return batchController.getBatchSize(seq.getPrev().getPhisVal()); - }(); - - 
if(gemm_batch_size != 0) - return HiddenHStateWeights_Unchecked( - handle, dw, workSpace, reserveSpace, seq, layer, direction, gemm_batch_size); - } - - void HiddenHStateWeights(const Handle& handle, - Data_t dw, - ConstData_t workSpace, - ConstData_t reserveSpace, - size_t layer, - size_t max_seq_len, - const SequenceDirection direction) const - { - size_t start_seq_id = 0; - const size_t last_seq = max_seq_len - 1; - for(auto i = start_seq_id + 1; i <= last_seq; i++) - { - - if(batchController.getBatchSize(i) != batchController.getBatchSize(start_seq_id) || - i == last_seq) - { - const size_t gemm_batch_size = (batchController.getBatchSum(i - 1) - - batchController.getBatchSum(start_seq_id)) + - batchController.getBatchSize(i); - - if(gemm_batch_size != 0) - { - const auto first_logical_val = direction == SequenceDirection::Forward - ? start_seq_id - : (max_seq_len - 1) - start_seq_id - 1; - const auto seq = - SequenceIterator(first_logical_val, direction, max_seq_len, false); - - HiddenHStateWeights_Unchecked(handle, - dw, - workSpace, - reserveSpace, - seq, - layer, - direction, - gemm_batch_size); - } - start_seq_id = i; - } - } - } - - void PhisHStateWeights(const Handle& handle, - Data_t dw, - ConstData_t workSpace, - ConstData_t hx, - size_t layer, - size_t max_seq_len, - SequenceDirection direction) const - { - if(hx == nullptr) - return; - - for(auto i = max_seq_len; i > 0; i--) - { - const auto seq = SequenceIterator(i - 1, direction, max_seq_len, false); - - PhisHStateWeights(handle, dw, workSpace, hx, seq, layer, direction); - } } static bool IsApplicable() @@ -1157,179 +276,21 @@ class RNNBackwardWeightsModularAlgo #endif // MIOPEN_USE_GEMM&& MIOPEN_BACKEND_HIP } -private: - RNNBackwardWeightsModularAlgo(GeneralLstmRedBuffer rb_layout, - GeneralLstmTempBuffer workspace_info, - WeightsBufferDescriptor weights_layout, - HiddenBuffersDescriptor hidden_hxcx_info, - IOBufferDescriptor x_info, - IOBufferDescriptor y_info, - const RNNDescriptor& rnn_desc, - 
BatchController batch_controller) - : reservLayout(std::move(rb_layout)), - workspaceInfo(std::move(workspace_info)), - weightsLayout(std::move(weights_layout)), - hiddenHxCxInfo(std::move(hidden_hxcx_info)), - xInfo(std::move(x_info)), - yInfo(std::move(y_info)), - rnnDesc(rnn_desc), - batchController(std::move(batch_controller)) - { - } + // TODO + static size_t GetWsSize() { return 0; }; - void HiddenHStateWeights_Unchecked(const Handle& handle, + void Compute(const Handle& handle, + ConstData_t x, + ConstData_t hx, Data_t dw, - ConstData_t workSpace, + Data_t workSpace, + size_t /*workSpaceSize*/, ConstData_t reserveSpace, - const SequenceIterator& seq, - size_t layer, - SequenceDirection direction, - size_t gemm_batch_size) const; - - void PhisHStateWeights(const Handle& handle, - Data_t dw, - ConstData_t workSpace, - ConstData_t hx, - const SequenceIterator& seq, - size_t layer, - SequenceDirection direction) const; - - template - inline miopen::TensorDescriptor BuildLstmTmpBlockDesc2D(const BufType& buf_info, - const size_t batch_size) const - { - const std::array& tmp_block_stride = buf_info.getGateBlockStride(); - const std::array& tmp_block_size = buf_info.getGateBlockSize(); - - // batch, gateBlock_elements - return miopen::TensorDescriptor{rnnDesc.dataType, - {batch_size, tmp_block_size[3]}, - {tmp_block_stride[1], tmp_block_stride[3]}}; - } - - inline miopen::TensorDescriptor BuildLstmFilterXDesc2D(int layer_id) const - { - assert(rnnDesc.inputMode == 0 || layer_id != 0); - // TODO replace by stride - auto x_vec = layer_id != 0 ? 
weightsLayout.xInVec : weightsLayout.inVec; - - // gateBlock_elements, ht_vec - return miopen::TensorDescriptor{ - rnnDesc.dataType, {weightsLayout.gatesCnt * weightsLayout.hVec, x_vec}, {x_vec, 1}}; - } - - inline miopen::TensorDescriptor BuildLstmFilterHidDesc2D() const - { - // TODO replace by stride - auto h_vec = weightsLayout.hVec; - - // gateBlock_elements, ht_vec - return miopen::TensorDescriptor{ - rnnDesc.dataType, {weightsLayout.gatesCnt * weightsLayout.hVec, h_vec}, {h_vec, 1}}; - } - - template - inline miopen::TensorDescriptor BuildTmpHtDesc2D(const BufType& buf_info, - size_t batch_size) const - { - auto& ht_stride = buf_info.getHiddenStateStride(); - auto& ht_size = buf_info.hStateSizes; - - // batch, gateBlock_elements - return miopen::TensorDescriptor{ - rnnDesc.dataType, {batch_size, ht_size[3]}, {ht_stride[1], ht_stride[3]}}; - } - - // 2 dims batch, vec - inline miopen::TensorDescriptor BuildHxCxDesc2D(size_t batch_size) const - { - const std::vector hx_size{batch_size, hiddenHxCxInfo.getHiddenSize()}; - const std::vector hx_stride{hiddenHxCxInfo.getStrides()[1], - hiddenHxCxInfo.getStrides()[2]}; - - return miopen::TensorDescriptor{rnnDesc.dataType, hx_size, hx_stride}; - } - - // 3 dims layer, batch, vec - inline miopen::TensorDescriptor BuildHxCxDesc3D(size_t layer_size, size_t batch_size) const - { - const std::vector hx_accum_size{ - layer_size, batch_size, hiddenHxCxInfo.getHiddenSize()}; - - return miopen::TensorDescriptor{ - rnnDesc.dataType, hx_accum_size, hiddenHxCxInfo.getStrides()}; - } - - // 3 dims layer, batch, vec - inline miopen::TensorDescriptor BuildTempDhtDesc3D(size_t layer_size, size_t batch_size) const - { - const std::vector dy_dhy_accum_size{ - layer_size, batch_size, hiddenHxCxInfo.getHiddenSize()}; - - const auto ws_dy_stride = [](const auto& ws_4dim_strides) -> std::vector { - // convert 4dim stride to 3 dim without direction - // TODO change hiddenBufferDesc - return std::vector{ws_4dim_strides[0], 
ws_4dim_strides[1], ws_4dim_strides[3]}; - }(workspaceInfo.getHiddenStateStride()); - - return miopen::TensorDescriptor{rnnDesc.dataType, dy_dhy_accum_size, ws_dy_stride}; - } - - // 3 dims layer, batch, vec - inline miopen::TensorDescriptor BuildWeiBiasDesc2D() const - { - const std::vector bias_size = [](const auto& wei_4dim_size) -> std::vector { - // wei_4dim_size{layer, dir, gate, vec} - return {1, wei_4dim_size[1] * wei_4dim_size[2] * wei_4dim_size[3]}; - }(weightsLayout.getBiasSize()); - - const auto bias_stride = [](const auto& wei_4dim_strides) -> std::vector { - // convert 4dim stride to 2 dim without direction - return std::vector{wei_4dim_strides[0], wei_4dim_strides[3]}; - }(weightsLayout.getBiasStride()); - - return miopen::TensorDescriptor{rnnDesc.dataType, bias_size, bias_stride}; - } - - inline size_t getVirtualLayer(const size_t layer_id, SequenceDirection direction) const - { - return layer_id * (isBidirectSeq ? 2 : 1) + - (direction == SequenceDirection::Forward ? 0 : 1); - } - - inline size_t getHxBatchSizeReadAtTime(const SequenceIterator& seq, - SequenceDirection direction) const - { - if(seq.isLast()) - return batchController.getBatchSize(seq.getPhisVal()); - - if(direction == SequenceDirection::Reverse) - { - return batchController.getBatchSize(seq.getPhisVal()) - - batchController.getBatchSize(seq.getPrev().getPhisVal()); - } - return 0; - } - - const GeneralLstmRedBuffer reservLayout; - // const WorkspaceBufferDescriptor workspaceInfo; - const GeneralLstmTempBuffer workspaceInfo; - - const WeightsBufferDescriptor weightsLayout; - const HiddenBuffersDescriptor hiddenHxCxInfo; - - const IOBufferDescriptor xInfo; - const IOBufferDescriptor yInfo; + size_t /*reserveSpaceSize*/) const; + const rnn_base::RNNBackwardWeightsModularAlgo rnnAlgoModules; const RNNDescriptor& rnnDesc; - - const ActivationDescriptor tanhDesc = {miopenActivationTANH, 1, 1, 1}; - const ActivationDescriptor sigDesc = {miopenActivationLOGISTIC, 1, 0, 1}; - const 
ActivationDescriptor reluDesc = {miopenActivationRELU, 1, 0, 1}; - - const BatchController batchController; - - const bool isBidirectSeq = false; + const size_t max_seq_len; }; class RNNModularSingleStreamBWWeights @@ -1378,7 +339,7 @@ class RNNModularMultiStreamBWWeights const SeqTensorDescriptor& xDesc, const SeqTensorDescriptor& yDesc, const TensorDescriptor& hDesc) - : rnnAlgoModules(RNNBackwardWeightsModularAlgo::create(rnn, xDesc, yDesc, hDesc)), + : rnnAlgoModules(RNNModuleAlgoBase::create(rnn, xDesc, yDesc, hDesc, miopenRNNTraining)), rnnDesc(rnn), max_seq_len(xDesc.GetMaxSequenceLength()) { @@ -1396,15 +357,7 @@ class RNNModularMultiStreamBWWeights // TODO static size_t GetWsSize() { return 0; }; - struct runtimeArgsBww - { - const Handle* handle; - ConstData_t x; - ConstData_t hx; - Data_t dw; - Data_t workSpace; - ConstData_t reserveSpace; - }; + void Compute(const Handle& handle, ConstData_t x, @@ -1416,7 +369,7 @@ class RNNModularMultiStreamBWWeights size_t /*reserveSpaceSize*/) const; private: - void PrologueDispatch(const runtimeArgsBww& args) const; + void PrologueDispatch(const runtimeArgsBWWeights& args) const; const rnn_base::RNNBackwardWeightsModularAlgo rnnAlgoModules; const RNNDescriptor& rnnDesc; diff --git a/src/rnn/Solutions/Base/bw_data_modular.cpp b/src/rnn/Solutions/Base/bw_data_modular.cpp index a8984482a2..b9042ad13c 100644 --- a/src/rnn/Solutions/Base/bw_data_modular.cpp +++ b/src/rnn/Solutions/Base/bw_data_modular.cpp @@ -145,7 +145,7 @@ void RNNBackwardDataModularAlgo::PropHiddenDht(const Handle& handle, const miopen::TensorDescriptor& filter_src_dsc = BuildLstmFilterHidDesc2D(); - const miopen::TensorDescriptor& ht_dest_dsc = BuildWsHtDesc2D(gemm_batch_size); + const miopen::TensorDescriptor& ht_dest_dsc = BuildTmpHtDesc2D(workspaceInfo ,gemm_batch_size); RnnBaseFunctions::BWD_GEMM_Hidden_Prop( handle, @@ -444,7 +444,7 @@ void RNNBackwardDataModularAlgo::PropHiddenDy(const Handle& handle, const auto filter_src_dsc = 
BuildLstmFilterXDesc2D(layer); - const auto ht_x_desc = BuildWsHtDesc2D(gemm_batch_size); + const auto ht_x_desc = BuildTmpHtDesc2D(workspaceInfo, gemm_batch_size); RnnBaseFunctions::BWD_GEMM_Hidden_Prop(handle, workSpace, diff --git a/src/rnn/Solutions/Base/fw_data_modular.cpp b/src/rnn/Solutions/Base/fw_data_modular.cpp index 458092c999..235c291d3d 100644 --- a/src/rnn/Solutions/Base/fw_data_modular.cpp +++ b/src/rnn/Solutions/Base/fw_data_modular.cpp @@ -290,7 +290,7 @@ void RNNForwardDataModularAlgo::PropHiddenHt(const Handle& handle, const auto filter_offset = weightsLayout.getMatrixHidOff(layer, static_cast(direction)); - const miopen::TensorDescriptor& ht_dest_dsc = BuildWsHtDesc2D(gemm_batch_size); + const miopen::TensorDescriptor& ht_dest_dsc = BuildTmpHtDesc2D(workspaceInfo, gemm_batch_size); const miopen::TensorDescriptor tmp_block_src_dsc = BuildLstmTmpBlockDesc2D(workspaceInfo, gemm_batch_size); @@ -437,7 +437,7 @@ void RNNForwardDataModularAlgo::PropHiddenY(const Handle& handle, const miopen::TensorDescriptor tmp_block_src_dsc = BuildLstmTmpBlockDesc2D(reservLayout, gemm_batch_size); - const auto tmp_ht_desc = BuildWsHtDesc2D(gemm_batch_size); + const auto tmp_ht_desc = BuildTmpHtDesc2D(reservLayout, gemm_batch_size); if(rnnDesc.rnnMode == miopenLSTM) { diff --git a/src/rnn/Solutions/bww_multi_stream.cpp b/src/rnn/Solutions/bww_multi_stream.cpp index 1f480afdea..503cda068d 100644 --- a/src/rnn/Solutions/bww_multi_stream.cpp +++ b/src/rnn/Solutions/bww_multi_stream.cpp @@ -33,7 +33,7 @@ namespace miopen { namespace rnn_base { -void RNNModularMultiStreamBWWeights::PrologueDispatch(const runtimeArgsBww& args) const +void RNNModularMultiStreamBWWeights::PrologueDispatch(const runtimeArgsBWWeights& args) const { rnnAlgoModules.PrepareWriteBuffers(*args.handle, args.dw); } @@ -51,7 +51,7 @@ void RNNModularMultiStreamBWWeights::Compute(const Handle& handle, if(rnnDesc.nLayers == 0 || max_seq_len == 0) return; - const runtimeArgsBww args{&handle, x, hx, dw, 
workSpace, reserveSpace}; + const runtimeArgsBWWeights args{&handle, x, hx, dw, workSpace, reserveSpace}; MultiStreamController ms_controller{handle, env::value_or(MIOPEN_RNN_MS_STREAM_CNT, 4)}; From 633d0a4f2c35db31a109b0276da439f334355bcf Mon Sep 17 00:00:00 2001 From: Shurale-nkn Date: Wed, 13 Nov 2024 16:51:00 +0100 Subject: [PATCH 08/21] test fix --- test/rnn_seq_api.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rnn_seq_api.hpp b/test/rnn_seq_api.hpp index 177fa4d6ff..91eb3e772b 100644 --- a/test/rnn_seq_api.hpp +++ b/test/rnn_seq_api.hpp @@ -1233,7 +1233,7 @@ struct verify_train_rnn : verify_rnn_api_base const auto bwd_dhx = readTFromGPUOrEmpty(handle, dhx_dev, xHiddenState, nodhx); const auto bwd_dcx = readTFromGPUOrEmpty(handle, dcx_dev, xCellState, nodcx); - if(skip_backward_data) + if(skip_backward_weights) return result_tuple(fwd_y, fwd_hy, fwd_cy, bwd_din, bwd_dhx, bwd_dcx, {}); std::vector workSpace_bwd_out(workSpace_TCnt); From 757df2b8133a0d827526c6daa05465ceba4ca161 Mon Sep 17 00:00:00 2001 From: Shurale-nkn Date: Wed, 13 Nov 2024 17:14:09 +0100 Subject: [PATCH 09/21] back weights --- .../rnn/algorithms/dynamic_algo_utils.hpp | 137 ++++++++++++++++++ src/include/miopen/rnn/solvers.hpp | 35 +++-- src/rnn/Solutions/Base/bw_weights_modular.cpp | 41 +++++- src/rnn/Solutions/bww_s_steam.cpp | 46 ++++++ src/rnn/selector.cpp | 16 +- 5 files changed, 257 insertions(+), 18 deletions(-) diff --git a/src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp b/src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp index 8bb99d4133..4bf96c8e97 100644 --- a/src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp +++ b/src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp @@ -308,5 +308,142 @@ class RNNBackwardModuleAlgoDynamic : public RNNBackwardDataModularAlgo }; +class RNNBackwardWeiModuleAlgoDynamic : public RNNBackwardWeightsModularAlgo +{ + using BaseBWDModuleT = rnn_base::RNNBackwardWeightsModularAlgo; + + static 
SeqTensorDescriptor buildDynamicVirtual(const SeqTensorDescriptor& desc) + { + std::vector def_layout{1, 0, 2}; + return {desc.GetType(), def_layout, desc.GetLengths(), false}; + } + + static SeqTensorDescriptor buildRealToDynamicMapTmp(const SeqTensorDescriptor& desc) + { + std::vector def_layout{1, 0, 2}; + return {desc.GetType(), + def_layout, + desc.GetLengths(), + desc.GetSequenceLengthsVector(), + std::vector{}, + true, + true}; + } + +public: + RNNBackwardWeiModuleAlgoDynamic(const RNNDescriptor& rnnD, + const SeqTensorDescriptor& xTDesc, + const SeqTensorDescriptor& yTDesc, + const TensorDescriptor& hDesc, + miopenRNNFWDMode_t mode) + : BaseBWDModuleT(RNNModuleAlgoBase::create( + rnnD, buildDynamicVirtual(xTDesc), buildDynamicVirtual(yTDesc), hDesc, mode)), + realBatchController(BatchController::Create(xTDesc)), + realXDesc(xTDesc), + tmpMapXDesc(buildRealToDynamicMapTmp(xTDesc)) + + { + } + + struct runtimeArgsBwWeiDynamicExt + { + const ConstData_t realX; + const Data_t tempX; + const ConstData_t hx; + const Data_t dw; + const Data_t workSpace; + const ConstData_t reserveSpace; + }; + + runtimeArgsBwWeiDynamicExt createRuntimeArgsExt(const runtimeArgsBWWeights& runtimeArgs) const + { + const Data_t temp_x = + moveDataPtr(runtimeArgs.workSpace, workspaceInfo.getBufferSizeImpl(), rnnDesc.dataType); + + return { + runtimeArgs.x, + temp_x, + runtimeArgs.hx, + runtimeArgs.dw, + runtimeArgs.workSpace, + runtimeArgs.reserveSpace, + }; + } + + auto getTempBuffersSize() const + { + auto [ws_size, reserve_size] = BaseBWDModuleT::getTempBuffersSize(); + + return std::make_tuple(ws_size + tmpMapXDesc.GetTensorMaxByteSpace() + + reserve_size); + } + + static auto getTempBuffersSize(const RNNDescriptor& rnnD, const SeqTensorDescriptor& xDesc) + { + auto y_desc = [](const RNNDescriptor& rnnD, const SeqTensorDescriptor& xDesc) { + std::vector y_lenghts{xDesc.GetLengths()}; + y_lenghts[2] = rnnD.hsize * (rnnD.dirMode == miopenRNNbidirection ? 
2 : 1); + return SeqTensorDescriptor{xDesc.GetType(), y_lenghts}; + }(rnnD, xDesc); + + auto temp_x_desc = buildDynamicVirtual(xDesc); + auto temp_y_desc = buildDynamicVirtual(y_desc); + + auto [ws_size, reserve_size] = + RNNForwardDataModularAlgo::getTempBuffersSize(rnnD, temp_x_desc); + + return std::make_tuple(ws_size + temp_x_desc.GetTensorMaxByteSpace() + + temp_y_desc.GetTensorMaxByteSpace(), + reserve_size); + } + + void PhisHStateWeights(const Handle& handle, + Data_t dw, + ConstData_t workSpace, + ConstData_t hx, + const SequenceIterator& seq, + size_t layer, + SequenceDirection direction) const; + + void PhisHStateWeights(const Handle& handle, + Data_t dw, + ConstData_t workSpace, + ConstData_t hx, + size_t layer, + size_t max_seq_len, + SequenceDirection direction) const + { + if(hx == nullptr) + return; + + for(auto i = max_seq_len; i > 0; i--) + { + const auto seq = SequenceIterator(i - 1, direction, max_seq_len, false); + + PhisHStateWeights(handle, dw, workSpace, hx, seq, layer, direction); + } + } + + void realXProp(const Handle& handle, + const runtimeArgsBwWeiDynamicExt& runtimeArgsExt) const + { + + RNNTensorBaseLayoutConverter::ConvertInputTensorGPUData(handle, + realXDesc, + runtimeArgsExt.realX, + tmpMapXDesc, + runtimeArgsExt.tempX, + nullptr, + false); + } + +private: + BatchController realBatchController; + + SeqTensorDescriptor realXDesc; + SeqTensorDescriptor tmpMapXDesc; +}; + + } // namespace rnn_base } // namespace miopen diff --git a/src/include/miopen/rnn/solvers.hpp b/src/include/miopen/rnn/solvers.hpp index f575bdbfaf..aee295e0ed 100644 --- a/src/include/miopen/rnn/solvers.hpp +++ b/src/include/miopen/rnn/solvers.hpp @@ -258,9 +258,9 @@ class RNNModularSingleStreamBWWeights { public: RNNModularSingleStreamBWWeights(const RNNDescriptor& rnn, - const SeqTensorDescriptor& xDesc, - const SeqTensorDescriptor& yDesc, - const TensorDescriptor& hDesc) + const SeqTensorDescriptor& xDesc, + const SeqTensorDescriptor& yDesc, + const 
TensorDescriptor& hDesc) : rnnAlgoModules(RNNModuleAlgoBase::create(rnn, xDesc, yDesc, hDesc, miopenRNNTraining)), rnnDesc(rnn), max_seq_len(xDesc.GetMaxSequenceLength()) @@ -282,10 +282,10 @@ class RNNModularSingleStreamBWWeights void Compute(const Handle& handle, ConstData_t x, ConstData_t hx, - Data_t dw, + Data_t dw, Data_t workSpace, size_t /*workSpaceSize*/, - ConstData_t reserveSpace, + ConstData_t reserveSpace, size_t /*reserveSpaceSize*/) const; const rnn_base::RNNBackwardWeightsModularAlgo rnnAlgoModules; @@ -293,14 +293,16 @@ class RNNModularSingleStreamBWWeights const size_t max_seq_len; }; -class RNNModularSingleStreamBWWeights +class RNNDynamicModularSingleStreamBWWeights { +private: public: - RNNModularSingleStreamBWWeights(const RNNDescriptor& rnn, - const SeqTensorDescriptor& xDesc, - const SeqTensorDescriptor& yDesc, - const TensorDescriptor& hDesc) - : rnnAlgoModules(RNNBackwardWeightsModularAlgo::create(rnn, xDesc, yDesc, hDesc)), + RNNDynamicModularSingleStreamBWWeights(const RNNDescriptor& rnn, + const SeqTensorDescriptor& xDesc, + const SeqTensorDescriptor& yDesc, + const TensorDescriptor& hDesc, + miopenRNNFWDMode_t mode) + : rnnAlgoModules(rnn, xDesc, yDesc, hDesc, mode), rnnDesc(rnn), max_seq_len(xDesc.GetMaxSequenceLength()) { @@ -315,8 +317,13 @@ class RNNModularSingleStreamBWWeights #endif // MIOPEN_USE_GEMM&& MIOPEN_BACKEND_HIP } - // TODO - static size_t GetWsSize() { return 0; }; + auto getTempBuffersSize() const { return rnnAlgoModules.getTempBuffersSize(); } + + static auto getTempBuffersSize(const RNNDescriptor& rnn, const SeqTensorDescriptor& xDesc) + { + return decltype(rnnAlgoModules)::getTempBuffersSize(rnn, xDesc); + } + void Compute(const Handle& handle, ConstData_t x, @@ -327,7 +334,7 @@ class RNNModularSingleStreamBWWeights ConstData_t reserveSpace, size_t /*reserveSpaceSize*/) const; - const rnn_base::RNNBackwardWeightsModularAlgo rnnAlgoModules; + const rnn_base::RNNBackwardWeiModuleAlgoDynamic rnnAlgoModules; const 
RNNDescriptor& rnnDesc; const size_t max_seq_len; }; diff --git a/src/rnn/Solutions/Base/bw_weights_modular.cpp b/src/rnn/Solutions/Base/bw_weights_modular.cpp index 598d002cb0..87c44de985 100644 --- a/src/rnn/Solutions/Base/bw_weights_modular.cpp +++ b/src/rnn/Solutions/Base/bw_weights_modular.cpp @@ -323,7 +323,46 @@ void RNNBackwardWeightsModularAlgo::PhisHStateWeights(const Handle& handle, size_t layer, SequenceDirection direction) const { - const size_t gemm_batch_size = getHxBatchSizeReadAtTime(seq, direction); + const size_t gemm_batch_size = getHxBatchSizeReadAtTime(seq, batchController, direction); + + if(gemm_batch_size == 0 || hx == nullptr) + return; + + const size_t batch_shift = batchController.getBatchSum(seq.getPhisVal()) + + (batchController.getBatchSize(seq.getPhisVal()) - gemm_batch_size); + + const auto virt_layer = getVirtualLayer(layer, direction); + + const size_t block_offset = workspaceInfo.getGateBlockOffset(layer, batch_shift, direction); + const size_t hx_offset = hiddenHxCxInfo.getOffset(virt_layer, batch_shift); + const size_t filter_offset = weightsLayout.getMatrixHidOff(layer, static_cast(direction)); + + const TensorDescriptor block_dsc = BuildLstmTmpBlockDesc2D(workspaceInfo, gemm_batch_size); + const TensorDescriptor hx_desc = BuildHxCxDesc2D(gemm_batch_size); + const TensorDescriptor filter_dsc = BuildLstmFilterHidDesc2D(); + + RnnBaseFunctions::BWWei_GEMM(handle, + workSpace, + block_offset, + block_dsc, + hx, + hx_offset, + hx_desc, + dw, + filter_offset, + filter_dsc, + true); +} + +void RNNBackwardWeiModuleAlgoDynamic::PhisHStateWeights(const Handle& handle, + Data_t dw, + ConstData_t workSpace, + ConstData_t hx, + const SequenceIterator& seq, + size_t layer, + SequenceDirection direction) const +{ + const size_t gemm_batch_size = getHxBatchSizeReadAtTime(seq, realBatchController, direction); if(gemm_batch_size == 0 || hx == nullptr) return; diff --git a/src/rnn/Solutions/bww_s_steam.cpp b/src/rnn/Solutions/bww_s_steam.cpp 
index 736d8cfde3..450a1e6873 100644 --- a/src/rnn/Solutions/bww_s_steam.cpp +++ b/src/rnn/Solutions/bww_s_steam.cpp @@ -70,5 +70,51 @@ void RNNModularSingleStreamBWWeights::Compute(const Handle& handle, } } + +void RNNDynamicModularSingleStreamBWWeights::Compute(const Handle& handle, + ConstData_t x, + ConstData_t hx, + Data_t dw, + Data_t workSpace, + size_t workSpaceSize, + ConstData_t reserveSpace, + size_t /*reserveSpaceSize*/) const +{ + const auto args_ext = rnnAlgoModules.createRuntimeArgsExt( + runtimeArgsBWWeights{&handle, x, hx, dw, workSpace, reserveSpace}); + + if(rnnDesc.nLayers == 0 || max_seq_len == 0) + return; + + auto sequence_directions = + rnnDesc.dirMode == miopenRNNDirectionMode_t::miopenRNNbidirection ? 2 : 1; + + rnnAlgoModules.PrepareWriteBuffers(handle, dw); + + rnnAlgoModules.realXProp(handle, args_ext); + + for(int layer_i = 0; layer_i < rnnDesc.nLayers; layer_i++) + { + if(layer_i == 0) + rnnAlgoModules.PhisXInputWeights(handle, dw, workSpace, args_ext.tempX); + else + rnnAlgoModules.HiddenXInputWeights(handle, dw, workSpace, reserveSpace, layer_i); + + rnnAlgoModules.BiasUpdate(handle, dw, workSpace, layer_i, workSpaceSize); + + for(int dir = 0; dir < sequence_directions; dir++) + { + const auto seq_dir = dir == 0 ? 
rnn_base::SequenceDirection::Forward + : rnn_base::SequenceDirection::Reverse; + + rnnAlgoModules.PhisHStateWeights( + handle, dw, workSpace, hx, layer_i, max_seq_len, seq_dir); + + rnnAlgoModules.HiddenHStateWeights( + handle, dw, workSpace, reserveSpace, layer_i, max_seq_len, seq_dir); + } + } +} + } // namespace rnn_base } // namespace miopen diff --git a/src/rnn/selector.cpp b/src/rnn/selector.cpp index aeab61ce93..34f45712b2 100644 --- a/src/rnn/selector.cpp +++ b/src/rnn/selector.cpp @@ -186,9 +186,19 @@ void RNNDescriptor::ModularBackwardWeights(Handle& handle, } else { - rnn_base::RNNModularSingleStreamBWWeights single_stream{*this, xDesc, yDesc, hDesc}; - single_stream.Compute( - handle, x, hx, dw, workSpace, workSpaceSize, reserveSpace, reserveSpaceSize); + if(CheckDynamicAlgoSelection(handle, xDesc, miopenRNNFWDMode_t::miopenRNNTraining)) + { + rnn_base::RNNDynamicModularSingleStreamBWWeights single_stream{ + *this, xDesc, yDesc, hDesc, miopenRNNFWDMode_t::miopenRNNTraining}; + single_stream.Compute( + handle, x, hx, dw, workSpace, workSpaceSize, reserveSpace, reserveSpaceSize); + } + else + { + rnn_base::RNNModularSingleStreamBWWeights single_stream{*this, xDesc, yDesc, hDesc}; + single_stream.Compute( + handle, x, hx, dw, workSpace, workSpaceSize, reserveSpace, reserveSpaceSize); + } } } From cceef900435afd563b527ab4d171aa9ccb6146b7 Mon Sep 17 00:00:00 2001 From: Shurale-nkn Date: Fri, 15 Nov 2024 18:36:33 +0100 Subject: [PATCH 10/21] data management fix --- .../rnn/algorithms/default_algo_utils.hpp | 10 ++-- .../rnn/algorithms/dynamic_algo_utils.hpp | 50 +++++++++++++------ src/include/miopen/rnn/solvers.hpp | 21 +++++++- src/rnn/Solutions/Base/bw_weights_modular.cpp | 38 +++++++------- src/rnn/Solutions/bww_multi_stream.cpp | 20 +++++--- src/rnn/Solutions/bww_s_steam.cpp | 39 ++++++++++----- 6 files changed, 122 insertions(+), 56 deletions(-) diff --git a/src/include/miopen/rnn/algorithms/default_algo_utils.hpp 
b/src/include/miopen/rnn/algorithms/default_algo_utils.hpp index 4b0839431d..0e86149d50 100644 --- a/src/include/miopen/rnn/algorithms/default_algo_utils.hpp +++ b/src/include/miopen/rnn/algorithms/default_algo_utils.hpp @@ -65,8 +65,10 @@ struct runtimeArgsBWWeights const ConstData_t x; const ConstData_t hx; const Data_t dw; - const Data_t workSpace; - const ConstData_t reserveSpace; + const ConstData_t backData; + const ConstData_t forwardData; + const Data_t freeWorkSpace; + const size_t freeWorkSpaceSize; }; @@ -546,7 +548,8 @@ class RNNBackwardWeightsModularAlgo : public RNNModuleAlgoBase public: void PrepareWriteBuffers(const Handle& handle, Data_t w) const; - void PhisXInputWeights(const Handle& handle, Data_t dw, Data_t workSpace, ConstData_t x) const; + void + PhisXInputWeights(const Handle& handle, Data_t dw, ConstData_t workSpace, ConstData_t x) const; void HiddenXInputWeights(const Handle& handle, Data_t dw, @@ -556,6 +559,7 @@ class RNNBackwardWeightsModularAlgo : public RNNModuleAlgoBase void BiasUpdate(const Handle& handle, Data_t dw, + ConstData_t backData, Data_t workSpace, size_t layer, size_t workSpaceSize) const; diff --git a/src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp b/src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp index 4bf96c8e97..4d76eaf3f3 100644 --- a/src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp +++ b/src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp @@ -42,11 +42,17 @@ inline SeqTensorDescriptor buildDynamicVirtual(const SeqTensorDescriptor& desc) inline SeqTensorDescriptor buildRealToDynamicMapTmp(const SeqTensorDescriptor& desc) { std::vector def_layout{1, 0, 2}; + + auto zero_val_padding = [](miopenDataType_t type) { + std ::vector padding_fill(GetTypeSize(type),0); + return padding_fill; + }; + return {desc.GetType(), def_layout, desc.GetLengths(), desc.GetSequenceLengthsVector(), - std::vector{}, + zero_val_padding(desc.GetType()), true, true}; } @@ -351,23 +357,27 @@ class 
RNNBackwardWeiModuleAlgoDynamic : public RNNBackwardWeightsModularAlgo const Data_t tempX; const ConstData_t hx; const Data_t dw; - const Data_t workSpace; - const ConstData_t reserveSpace; + const ConstData_t backData; + const ConstData_t forwardData; + const Data_t freeWorkSpace; + const size_t freeWorkSpaceSize; }; runtimeArgsBwWeiDynamicExt createRuntimeArgsExt(const runtimeArgsBWWeights& runtimeArgs) const { - const Data_t temp_x = - moveDataPtr(runtimeArgs.workSpace, workspaceInfo.getBufferSizeImpl(), rnnDesc.dataType); - - return { - runtimeArgs.x, - temp_x, - runtimeArgs.hx, - runtimeArgs.dw, - runtimeArgs.workSpace, - runtimeArgs.reserveSpace, - }; + const Data_t temp_x = runtimeArgs.freeWorkSpace; + const auto temp_x_byte_size = tmpMapXDesc.GetTensorMaxByteSpace(); + + const Data_t free_ws = moveDataPtrByte(temp_x, temp_x_byte_size); + + return {runtimeArgs.x, + temp_x, + runtimeArgs.hx, + runtimeArgs.dw, + runtimeArgs.backData, + runtimeArgs.forwardData, + free_ws, + runtimeArgs.freeWorkSpaceSize - temp_x_byte_size}; } auto getTempBuffersSize() const @@ -424,9 +434,17 @@ class RNNBackwardWeiModuleAlgoDynamic : public RNNBackwardWeightsModularAlgo } } - void realXProp(const Handle& handle, - const runtimeArgsBwWeiDynamicExt& runtimeArgsExt) const + void realXProp(const Handle& handle, const runtimeArgsBwWeiDynamicExt& runtimeArgsExt) const { + const auto normalized_tensor_size = + tmpMapXDesc.GetTensorMaxByteSpace() / GetTypeSize(rnnDesc.dataType); + + const auto normalized_desc = miopen::TensorDescriptor( + rnnDesc.dataType, {1, normalized_tensor_size}, {normalized_tensor_size, 1}); + + const float beta = 0.; + + SetTensor(handle, normalized_desc, runtimeArgsExt.tempX, &beta); RNNTensorBaseLayoutConverter::ConvertInputTensorGPUData(handle, realXDesc, diff --git a/src/include/miopen/rnn/solvers.hpp b/src/include/miopen/rnn/solvers.hpp index aee295e0ed..d1ed335426 100644 --- a/src/include/miopen/rnn/solvers.hpp +++ b/src/include/miopen/rnn/solvers.hpp @@ 
-324,6 +324,25 @@ class RNNDynamicModularSingleStreamBWWeights return decltype(rnnAlgoModules)::getTempBuffersSize(rnn, xDesc); } + runtimeArgsBWWeights createRuntimeArgsBase(const Handle& handle, + ConstData_t x, + ConstData_t hx, + Data_t dw, + Data_t workSpace, + size_t workSpaceSize, + ConstData_t reserveSpace, + size_t reserveSpaceSize) const + { + const ConstData_t back_data_space = workSpace; + const auto back_data_byte_size = + rnnAlgoModules.workspaceInfo.getBufferSizeImpl() * GetTypeSize(rnnDesc.dataType); + + const Data_t free_ws = moveDataPtrByte(workSpace, back_data_byte_size); + const auto free_ws_size = workSpaceSize - back_data_byte_size; + + return runtimeArgsBWWeights{ + &handle, x, hx, dw, back_data_space, reserveSpace, free_ws, free_ws_size}; + } void Compute(const Handle& handle, ConstData_t x, @@ -334,7 +353,7 @@ class RNNDynamicModularSingleStreamBWWeights ConstData_t reserveSpace, size_t /*reserveSpaceSize*/) const; - const rnn_base::RNNBackwardWeiModuleAlgoDynamic rnnAlgoModules; + const RNNBackwardWeiModuleAlgoDynamic rnnAlgoModules; const RNNDescriptor& rnnDesc; const size_t max_seq_len; }; diff --git a/src/rnn/Solutions/Base/bw_weights_modular.cpp b/src/rnn/Solutions/Base/bw_weights_modular.cpp index 87c44de985..44f88e6350 100644 --- a/src/rnn/Solutions/Base/bw_weights_modular.cpp +++ b/src/rnn/Solutions/Base/bw_weights_modular.cpp @@ -34,11 +34,11 @@ namespace miopen { namespace rnn_base { miopenStatus_t ReducAddBias(const miopen::Handle& handle, Data_t dw, - const Data_t workSpace, + const ConstData_t backDataSpace, const miopen::TensorDescriptor& dw_desc, const miopen::TensorDescriptor& ws_desc, size_t dw_bias_offset, - size_t ws_bias_offset, + size_t back_bias_offset, Data_t red_workSpace, size_t red_workSpace_size_bytes) { @@ -61,12 +61,12 @@ miopenStatus_t ReducAddBias(const miopen::Handle& handle, dw, &alpha1, ws_desc, - workSpace, + backDataSpace, &beta_t, dw_desc, dw, dw_bias_offset, - ws_bias_offset, + back_bias_offset, 
dw_bias_offset, true); } @@ -82,8 +82,8 @@ miopenStatus_t ReducAddBias(const miopen::Handle& handle, miopenReduceTensorIndices_t::MIOPEN_REDUCE_TENSOR_NO_INDICES, miopenIndicesType_t::MIOPEN_32BIT_INDICES}; - Data_t srcA_with_offset = - static_cast(workSpace) + ws_bias_offset * GetTypeSize(dw_desc.GetType()); + ConstData_t srcA_with_offset = static_cast(backDataSpace) + + back_bias_offset * GetTypeSize(dw_desc.GetType()); Data_t dstC_with_offset = static_cast(dw) + dw_bias_offset * GetTypeSize(dw_desc.GetType()); @@ -120,7 +120,7 @@ miopenStatus_t ReducAddBias(const miopen::Handle& handle, { // nothing to reduce // just copy data from workspace to dw - CopyTensor(handle, ws_desc, workSpace, dw_desc, dw, ws_bias_offset, dw_bias_offset); + CopyTensor(handle, ws_desc, backDataSpace, dw_desc, dw, back_bias_offset, dw_bias_offset); } return miopenStatusSuccess; @@ -143,7 +143,7 @@ void RNNBackwardWeightsModularAlgo::PrepareWriteBuffers(const Handle& handle, Da void RNNBackwardWeightsModularAlgo::PhisXInputWeights(const Handle& handle, Data_t dw, - Data_t workSpace, + ConstData_t workSpace, ConstData_t x) const { const size_t gemm_batch_size = xInfo.getFullSeqMajorSize()[0]; @@ -242,8 +242,12 @@ void RNNBackwardWeightsModularAlgo::HiddenXInputWeights(const Handle& handle, true); } -void RNNBackwardWeightsModularAlgo::BiasUpdate( - const Handle& handle, Data_t dw, Data_t workSpace, size_t layer, size_t workSpaceSize) const +void RNNBackwardWeightsModularAlgo::BiasUpdate(const Handle& handle, + Data_t dw, + ConstData_t backData, + Data_t workSpace, + size_t layer, + size_t workSpaceSize) const { if(rnnDesc.biasMode != 0u) { @@ -253,23 +257,23 @@ void RNNBackwardWeightsModularAlgo::BiasUpdate( const miopen::TensorDescriptor dw_desc = BuildWeiBiasDesc2D(); - size_t main_ws_size = workspaceInfo.getBufferSize() * GetTypeSize(rnnDesc.dataType); - - size_t reduction_ws_size = workSpaceSize - main_ws_size; + //size_t main_ws_size = workspaceInfo.getBufferSize() * 
GetTypeSize(rnnDesc.dataType); + // + //size_t reduction_ws_size = workSpaceSize - main_ws_size; - Data_t reduction_workSpace = static_cast(workSpace) + main_ws_size; + //Data_t reduction_workSpace = static_cast(workSpace) + main_ws_size; size_t dw_bias_offset = weightsLayout.getBiasXinOff(layer, static_cast(SequenceDirection::Forward), 0); ReducAddBias(handle, dw, - workSpace, + backData, dw_desc, block_dsc, dw_bias_offset, workspaceInfo.getGateBlockOffset(layer, 0, SequenceDirection::Forward), - reduction_workSpace, - reduction_ws_size); + workSpace, + workSpaceSize); // second dw bias equal to the first, so just copy reduction result size_t dw_bias_2_offset = diff --git a/src/rnn/Solutions/bww_multi_stream.cpp b/src/rnn/Solutions/bww_multi_stream.cpp index 503cda068d..466c290c52 100644 --- a/src/rnn/Solutions/bww_multi_stream.cpp +++ b/src/rnn/Solutions/bww_multi_stream.cpp @@ -51,7 +51,15 @@ void RNNModularMultiStreamBWWeights::Compute(const Handle& handle, if(rnnDesc.nLayers == 0 || max_seq_len == 0) return; - const runtimeArgsBWWeights args{&handle, x, hx, dw, workSpace, reserveSpace}; + const ConstData_t back_data_space = workSpace; + const auto back_data_byte_size = + rnnAlgoModules.workspaceInfo.getBufferSizeImpl() * GetTypeSize(rnnDesc.dataType); + + const Data_t free_ws = moveDataPtrByte(workSpace, back_data_byte_size); + const auto free_ws_size = workSpaceSize - back_data_byte_size; + + const runtimeArgsBWWeights args{ + &handle, x, hx, dw, back_data_space, reserveSpace, free_ws, free_ws_size}; MultiStreamController ms_controller{handle, env::value_or(MIOPEN_RNN_MS_STREAM_CNT, 4)}; @@ -74,7 +82,7 @@ void RNNModularMultiStreamBWWeights::Compute(const Handle& handle, ms_controller.ChangeActiveStream(bias_stream); for(int layer_i = 0; layer_i < rnnDesc.nLayers; layer_i++) - rnnAlgoModules.BiasUpdate(handle, dw, workSpace, layer_i, workSpaceSize); + rnnAlgoModules.BiasUpdate(handle, dw, back_data_space, free_ws, layer_i, free_ws_size); auto 
sequence_directions = rnnDesc.dirMode == miopenRNNDirectionMode_t::miopenRNNbidirection ? 2 : 1; @@ -85,9 +93,9 @@ void RNNModularMultiStreamBWWeights::Compute(const Handle& handle, ms_controller.ChangeActiveStream(dispatch_stream_id); if(layer_i == 0) - rnnAlgoModules.PhisXInputWeights(handle, dw, workSpace, x); + rnnAlgoModules.PhisXInputWeights(handle, dw, back_data_space, x); else - rnnAlgoModules.HiddenXInputWeights(handle, dw, workSpace, reserveSpace, layer_i); + rnnAlgoModules.HiddenXInputWeights(handle, dw, back_data_space, reserveSpace, layer_i); for(int dir = 0; dir < sequence_directions; dir++) { @@ -95,10 +103,10 @@ void RNNModularMultiStreamBWWeights::Compute(const Handle& handle, : rnn_base::SequenceDirection::Reverse; rnnAlgoModules.PhisHStateWeights( - handle, dw, workSpace, hx, layer_i, max_seq_len, seq_dir); + handle, dw, back_data_space, hx, layer_i, max_seq_len, seq_dir); rnnAlgoModules.HiddenHStateWeights( - handle, dw, workSpace, reserveSpace, layer_i, max_seq_len, seq_dir); + handle, dw, back_data_space, reserveSpace, layer_i, max_seq_len, seq_dir); } } diff --git a/src/rnn/Solutions/bww_s_steam.cpp b/src/rnn/Solutions/bww_s_steam.cpp index 450a1e6873..43a9fc2466 100644 --- a/src/rnn/Solutions/bww_s_steam.cpp +++ b/src/rnn/Solutions/bww_s_steam.cpp @@ -45,16 +45,23 @@ void RNNModularSingleStreamBWWeights::Compute(const Handle& handle, auto sequence_directions = rnnDesc.dirMode == miopenRNNDirectionMode_t::miopenRNNbidirection ? 
2 : 1; + const ConstData_t back_data_space = workSpace; + const auto back_data_byte_size = + rnnAlgoModules.workspaceInfo.getBufferSizeImpl() * GetTypeSize(rnnDesc.dataType); + + const Data_t free_ws = moveDataPtrByte(workSpace, back_data_byte_size); + const auto free_ws_size = workSpaceSize - back_data_byte_size; + rnnAlgoModules.PrepareWriteBuffers(handle, dw); for(int layer_i = 0; layer_i < rnnDesc.nLayers; layer_i++) { if(layer_i == 0) - rnnAlgoModules.PhisXInputWeights(handle, dw, workSpace, x); + rnnAlgoModules.PhisXInputWeights(handle, dw, back_data_space, x); else - rnnAlgoModules.HiddenXInputWeights(handle, dw, workSpace, reserveSpace, layer_i); + rnnAlgoModules.HiddenXInputWeights(handle, dw, back_data_space, reserveSpace, layer_i); - rnnAlgoModules.BiasUpdate(handle, dw, workSpace, layer_i, workSpaceSize); + rnnAlgoModules.BiasUpdate(handle, dw, back_data_space, free_ws, layer_i, free_ws_size); for(int dir = 0; dir < sequence_directions; dir++) { @@ -62,10 +69,10 @@ void RNNModularSingleStreamBWWeights::Compute(const Handle& handle, : rnn_base::SequenceDirection::Reverse; rnnAlgoModules.PhisHStateWeights( - handle, dw, workSpace, hx, layer_i, max_seq_len, seq_dir); + handle, dw, back_data_space, hx, layer_i, max_seq_len, seq_dir); rnnAlgoModules.HiddenHStateWeights( - handle, dw, workSpace, reserveSpace, layer_i, max_seq_len, seq_dir); + handle, dw, back_data_space, reserveSpace, layer_i, max_seq_len, seq_dir); } } } @@ -78,10 +85,8 @@ void RNNDynamicModularSingleStreamBWWeights::Compute(const Handle& handle, Data_t workSpace, size_t workSpaceSize, ConstData_t reserveSpace, - size_t /*reserveSpaceSize*/) const + size_t reserveSpaceSize) const { - const auto args_ext = rnnAlgoModules.createRuntimeArgsExt( - runtimeArgsBWWeights{&handle, x, hx, dw, workSpace, reserveSpace}); if(rnnDesc.nLayers == 0 || max_seq_len == 0) return; @@ -89,6 +94,13 @@ void RNNDynamicModularSingleStreamBWWeights::Compute(const Handle& handle, auto sequence_directions = 
rnnDesc.dirMode == miopenRNNDirectionMode_t::miopenRNNbidirection ? 2 : 1; + auto args_ext = rnnAlgoModules.createRuntimeArgsExt(createRuntimeArgsBase( + handle, x, hx, dw, workSpace, workSpaceSize, reserveSpace, reserveSpaceSize)); + + const auto back_data_space = args_ext.backData; + const auto free_work_space = args_ext.freeWorkSpace; + const auto free_work_space_size = args_ext.freeWorkSpaceSize; + rnnAlgoModules.PrepareWriteBuffers(handle, dw); rnnAlgoModules.realXProp(handle, args_ext); @@ -96,11 +108,12 @@ void RNNDynamicModularSingleStreamBWWeights::Compute(const Handle& handle, for(int layer_i = 0; layer_i < rnnDesc.nLayers; layer_i++) { if(layer_i == 0) - rnnAlgoModules.PhisXInputWeights(handle, dw, workSpace, args_ext.tempX); + rnnAlgoModules.PhisXInputWeights(handle, dw, back_data_space, args_ext.tempX); else - rnnAlgoModules.HiddenXInputWeights(handle, dw, workSpace, reserveSpace, layer_i); + rnnAlgoModules.HiddenXInputWeights(handle, dw, back_data_space, reserveSpace, layer_i); - rnnAlgoModules.BiasUpdate(handle, dw, workSpace, layer_i, workSpaceSize); + rnnAlgoModules.BiasUpdate( + handle, dw, back_data_space, free_work_space, layer_i, free_work_space_size); for(int dir = 0; dir < sequence_directions; dir++) { @@ -108,10 +121,10 @@ void RNNDynamicModularSingleStreamBWWeights::Compute(const Handle& handle, : rnn_base::SequenceDirection::Reverse; rnnAlgoModules.PhisHStateWeights( - handle, dw, workSpace, hx, layer_i, max_seq_len, seq_dir); + handle, dw, back_data_space, hx, layer_i, max_seq_len, seq_dir); rnnAlgoModules.HiddenHStateWeights( - handle, dw, workSpace, reserveSpace, layer_i, max_seq_len, seq_dir); + handle, dw, back_data_space, reserveSpace, layer_i, max_seq_len, seq_dir); } } } From 779b7aed5c9a47d1ac18d746f8a36c4014787c53 Mon Sep 17 00:00:00 2001 From: Shurale-nkn Date: Fri, 15 Nov 2024 18:39:09 +0100 Subject: [PATCH 11/21] clang format --- .../rnn/algorithms/default_algo_utils.hpp | 6 ++-- .../rnn/algorithms/dynamic_algo_utils.hpp | 33 
+++++++++---------- src/include/miopen/rnn/solvers.hpp | 29 ++++++++-------- src/include/miopen/rnn/tmp_buffer_utils.hpp | 24 +++++++------- src/rnn/Solutions/Base/bw_data_modular.cpp | 2 +- src/rnn/Solutions/Base/bw_weights_modular.cpp | 18 +++++----- src/rnn/Solutions/bww_s_steam.cpp | 5 ++- 7 files changed, 54 insertions(+), 63 deletions(-) diff --git a/src/include/miopen/rnn/algorithms/default_algo_utils.hpp b/src/include/miopen/rnn/algorithms/default_algo_utils.hpp index 0e86149d50..d1af4f7097 100644 --- a/src/include/miopen/rnn/algorithms/default_algo_utils.hpp +++ b/src/include/miopen/rnn/algorithms/default_algo_utils.hpp @@ -71,7 +71,6 @@ struct runtimeArgsBWWeights const size_t freeWorkSpaceSize; }; - class RNNModuleAlgoBase { protected: @@ -259,7 +258,8 @@ class RNNModuleAlgoBase } template - inline miopen::TensorDescriptor BuildTmpHtDesc2D(const BufType& tmpSpace, size_t batch_size) const + inline miopen::TensorDescriptor BuildTmpHtDesc2D(const BufType& tmpSpace, + size_t batch_size) const { auto& ht_stride = tmpSpace.getHiddenStateStride(); auto& ht_size = tmpSpace.hStateSizes; @@ -304,7 +304,7 @@ class RNNModuleAlgoBase return miopen::TensorDescriptor{rnnDesc.dataType, dy_dhy_accum_size, ws_dy_stride}; } - // 3 dims layer, batch, vec + // 3 dims layer, batch, vec inline miopen::TensorDescriptor BuildWeiBiasDesc2D() const { const std::vector bias_size = [](const auto& wei_4dim_size) -> std::vector { diff --git a/src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp b/src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp index 4d76eaf3f3..ddae328b6e 100644 --- a/src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp +++ b/src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp @@ -44,7 +44,7 @@ inline SeqTensorDescriptor buildRealToDynamicMapTmp(const SeqTensorDescriptor& d std::vector def_layout{1, 0, 2}; auto zero_val_padding = [](miopenDataType_t type) { - std ::vector padding_fill(GetTypeSize(type),0); + std ::vector 
padding_fill(GetTypeSize(type), 0); return padding_fill; }; @@ -313,7 +313,6 @@ class RNNBackwardModuleAlgoDynamic : public RNNBackwardDataModularAlgo SeqTensorDescriptor tmpMapDyDesc; }; - class RNNBackwardWeiModuleAlgoDynamic : public RNNBackwardWeightsModularAlgo { using BaseBWDModuleT = rnn_base::RNNBackwardWeightsModularAlgo; @@ -338,10 +337,10 @@ class RNNBackwardWeiModuleAlgoDynamic : public RNNBackwardWeightsModularAlgo public: RNNBackwardWeiModuleAlgoDynamic(const RNNDescriptor& rnnD, - const SeqTensorDescriptor& xTDesc, - const SeqTensorDescriptor& yTDesc, - const TensorDescriptor& hDesc, - miopenRNNFWDMode_t mode) + const SeqTensorDescriptor& xTDesc, + const SeqTensorDescriptor& yTDesc, + const TensorDescriptor& hDesc, + miopenRNNFWDMode_t mode) : BaseBWDModuleT(RNNModuleAlgoBase::create( rnnD, buildDynamicVirtual(xTDesc), buildDynamicVirtual(yTDesc), hDesc, mode)), realBatchController(BatchController::Create(xTDesc)), @@ -365,7 +364,7 @@ class RNNBackwardWeiModuleAlgoDynamic : public RNNBackwardWeightsModularAlgo runtimeArgsBwWeiDynamicExt createRuntimeArgsExt(const runtimeArgsBWWeights& runtimeArgs) const { - const Data_t temp_x = runtimeArgs.freeWorkSpace; + const Data_t temp_x = runtimeArgs.freeWorkSpace; const auto temp_x_byte_size = tmpMapXDesc.GetTensorMaxByteSpace(); const Data_t free_ws = moveDataPtrByte(temp_x, temp_x_byte_size); @@ -384,8 +383,7 @@ class RNNBackwardWeiModuleAlgoDynamic : public RNNBackwardWeightsModularAlgo { auto [ws_size, reserve_size] = BaseBWDModuleT::getTempBuffersSize(); - return std::make_tuple(ws_size + tmpMapXDesc.GetTensorMaxByteSpace() + - reserve_size); + return std::make_tuple(ws_size + tmpMapXDesc.GetTensorMaxByteSpace() + reserve_size); } static auto getTempBuffersSize(const RNNDescriptor& rnnD, const SeqTensorDescriptor& xDesc) @@ -408,12 +406,12 @@ class RNNBackwardWeiModuleAlgoDynamic : public RNNBackwardWeightsModularAlgo } void PhisHStateWeights(const Handle& handle, - Data_t dw, - ConstData_t workSpace, - 
ConstData_t hx, - const SequenceIterator& seq, - size_t layer, - SequenceDirection direction) const; + Data_t dw, + ConstData_t workSpace, + ConstData_t hx, + const SequenceIterator& seq, + size_t layer, + SequenceDirection direction) const; void PhisHStateWeights(const Handle& handle, Data_t dw, @@ -425,11 +423,11 @@ class RNNBackwardWeiModuleAlgoDynamic : public RNNBackwardWeightsModularAlgo { if(hx == nullptr) return; - + for(auto i = max_seq_len; i > 0; i--) { const auto seq = SequenceIterator(i - 1, direction, max_seq_len, false); - + PhisHStateWeights(handle, dw, workSpace, hx, seq, layer, direction); } } @@ -462,6 +460,5 @@ class RNNBackwardWeiModuleAlgoDynamic : public RNNBackwardWeightsModularAlgo SeqTensorDescriptor tmpMapXDesc; }; - } // namespace rnn_base } // namespace miopen diff --git a/src/include/miopen/rnn/solvers.hpp b/src/include/miopen/rnn/solvers.hpp index d1ed335426..f8ae6be693 100644 --- a/src/include/miopen/rnn/solvers.hpp +++ b/src/include/miopen/rnn/solvers.hpp @@ -33,7 +33,6 @@ #include "miopen/rnn/algorithms/default_algo_utils.hpp" #include "miopen/rnn/algorithms/dynamic_algo_utils.hpp" - namespace miopen { namespace rnn_base { @@ -251,7 +250,7 @@ class RNNModularMultiStreamBWD }; // -//Backward Weights +// Backward Weights // class RNNModularSingleStreamBWWeights @@ -298,10 +297,10 @@ class RNNDynamicModularSingleStreamBWWeights private: public: RNNDynamicModularSingleStreamBWWeights(const RNNDescriptor& rnn, - const SeqTensorDescriptor& xDesc, - const SeqTensorDescriptor& yDesc, - const TensorDescriptor& hDesc, - miopenRNNFWDMode_t mode) + const SeqTensorDescriptor& xDesc, + const SeqTensorDescriptor& yDesc, + const TensorDescriptor& hDesc, + miopenRNNFWDMode_t mode) : rnnAlgoModules(rnn, xDesc, yDesc, hDesc, mode), rnnDesc(rnn), max_seq_len(xDesc.GetMaxSequenceLength()) @@ -324,14 +323,14 @@ class RNNDynamicModularSingleStreamBWWeights return decltype(rnnAlgoModules)::getTempBuffersSize(rnn, xDesc); } - runtimeArgsBWWeights 
createRuntimeArgsBase(const Handle& handle, - ConstData_t x, - ConstData_t hx, - Data_t dw, - Data_t workSpace, - size_t workSpaceSize, - ConstData_t reserveSpace, - size_t reserveSpaceSize) const + runtimeArgsBWWeights createRuntimeArgsBase(const Handle& handle, + ConstData_t x, + ConstData_t hx, + Data_t dw, + Data_t workSpace, + size_t workSpaceSize, + ConstData_t reserveSpace, + size_t reserveSpaceSize) const { const ConstData_t back_data_space = workSpace; const auto back_data_byte_size = @@ -383,8 +382,6 @@ class RNNModularMultiStreamBWWeights // TODO static size_t GetWsSize() { return 0; }; - - void Compute(const Handle& handle, ConstData_t x, ConstData_t hx, diff --git a/src/include/miopen/rnn/tmp_buffer_utils.hpp b/src/include/miopen/rnn/tmp_buffer_utils.hpp index 3c54ab8cb8..5bd6883170 100644 --- a/src/include/miopen/rnn/tmp_buffer_utils.hpp +++ b/src/include/miopen/rnn/tmp_buffer_utils.hpp @@ -193,8 +193,8 @@ class GeneralRNNTempBufferTemplate : public BaseRnnWsBufferPacked { protected: GeneralRNNTempBufferTemplate(const std::array& hstate_strides, - const std::array& hstate_sizes, - size_t total_element_cnt) + const std::array& hstate_sizes, + size_t total_element_cnt) : hStateStrides{hstate_strides}, hStateSizes{hstate_sizes}, totalElementCnt{total_element_cnt} @@ -261,20 +261,20 @@ class GeneralRNNTempBufferTemplate : public BaseRnnWsBufferPacked *} */ - template class GeneralLstmInternalBuffTemplate : public GeneralRNNTempBufferTemplate, public GeneralLstmWsExt, public LstmWsGateBlockExt { using RNNBufferTemplate = GeneralRNNTempBufferTemplate; + protected: GeneralLstmInternalBuffTemplate(const std::array& h_state_strides, - const std::array& h_state_sizes, - const std::array& lstm_gate_sizes, - const std::array& lstm_gate_strides, - const std::array& lstm_gates_block_sizes, - size_t total_element_cnt) + const std::array& h_state_sizes, + const std::array& lstm_gate_sizes, + const std::array& lstm_gate_strides, + const std::array& 
lstm_gates_block_sizes, + size_t total_element_cnt) : RNNBufferTemplate{h_state_strides, h_state_sizes, total_element_cnt}, gateSizes{lstm_gate_sizes}, gateStride{lstm_gate_strides}, @@ -376,8 +376,7 @@ class GeneralLstmInternalBuffTemplate : public GeneralRNNTempBufferTemplate pos{layer_id, vector_id, static_cast(direction)}; - return start_ident + - std::inner_product(pos.cbegin(), + return start_ident + std::inner_product(pos.cbegin(), pos.cend(), RNNBufferTemplate::hStateStrides.cbegin(), static_cast(0)); @@ -415,12 +414,11 @@ class GeneralLstmInternalBuffTemplate : public GeneralRNNTempBufferTemplate { public: - GeneralLstmTempBuffer(const GeneralLstmInternalBuffTemplate base ) + GeneralLstmTempBuffer(const GeneralLstmInternalBuffTemplate base) : GeneralLstmInternalBuffTemplate{base} { } diff --git a/src/rnn/Solutions/Base/bw_data_modular.cpp b/src/rnn/Solutions/Base/bw_data_modular.cpp index b9042ad13c..99cc213e50 100644 --- a/src/rnn/Solutions/Base/bw_data_modular.cpp +++ b/src/rnn/Solutions/Base/bw_data_modular.cpp @@ -145,7 +145,7 @@ void RNNBackwardDataModularAlgo::PropHiddenDht(const Handle& handle, const miopen::TensorDescriptor& filter_src_dsc = BuildLstmFilterHidDesc2D(); - const miopen::TensorDescriptor& ht_dest_dsc = BuildTmpHtDesc2D(workspaceInfo ,gemm_batch_size); + const miopen::TensorDescriptor& ht_dest_dsc = BuildTmpHtDesc2D(workspaceInfo, gemm_batch_size); RnnBaseFunctions::BWD_GEMM_Hidden_Prop( handle, diff --git a/src/rnn/Solutions/Base/bw_weights_modular.cpp b/src/rnn/Solutions/Base/bw_weights_modular.cpp index 44f88e6350..32ed4b6923 100644 --- a/src/rnn/Solutions/Base/bw_weights_modular.cpp +++ b/src/rnn/Solutions/Base/bw_weights_modular.cpp @@ -257,11 +257,11 @@ void RNNBackwardWeightsModularAlgo::BiasUpdate(const Handle& handle, const miopen::TensorDescriptor dw_desc = BuildWeiBiasDesc2D(); - //size_t main_ws_size = workspaceInfo.getBufferSize() * GetTypeSize(rnnDesc.dataType); + // size_t main_ws_size = workspaceInfo.getBufferSize() * 
GetTypeSize(rnnDesc.dataType); // - //size_t reduction_ws_size = workSpaceSize - main_ws_size; + // size_t reduction_ws_size = workSpaceSize - main_ws_size; - //Data_t reduction_workSpace = static_cast(workSpace) + main_ws_size; + // Data_t reduction_workSpace = static_cast(workSpace) + main_ws_size; size_t dw_bias_offset = weightsLayout.getBiasXinOff(layer, static_cast(SequenceDirection::Forward), 0); @@ -359,12 +359,12 @@ void RNNBackwardWeightsModularAlgo::PhisHStateWeights(const Handle& handle, } void RNNBackwardWeiModuleAlgoDynamic::PhisHStateWeights(const Handle& handle, - Data_t dw, - ConstData_t workSpace, - ConstData_t hx, - const SequenceIterator& seq, - size_t layer, - SequenceDirection direction) const + Data_t dw, + ConstData_t workSpace, + ConstData_t hx, + const SequenceIterator& seq, + size_t layer, + SequenceDirection direction) const { const size_t gemm_batch_size = getHxBatchSizeReadAtTime(seq, realBatchController, direction); diff --git a/src/rnn/Solutions/bww_s_steam.cpp b/src/rnn/Solutions/bww_s_steam.cpp index 43a9fc2466..d513b150b0 100644 --- a/src/rnn/Solutions/bww_s_steam.cpp +++ b/src/rnn/Solutions/bww_s_steam.cpp @@ -77,7 +77,6 @@ void RNNModularSingleStreamBWWeights::Compute(const Handle& handle, } } - void RNNDynamicModularSingleStreamBWWeights::Compute(const Handle& handle, ConstData_t x, ConstData_t hx, @@ -97,8 +96,8 @@ void RNNDynamicModularSingleStreamBWWeights::Compute(const Handle& handle, auto args_ext = rnnAlgoModules.createRuntimeArgsExt(createRuntimeArgsBase( handle, x, hx, dw, workSpace, workSpaceSize, reserveSpace, reserveSpaceSize)); - const auto back_data_space = args_ext.backData; - const auto free_work_space = args_ext.freeWorkSpace; + const auto back_data_space = args_ext.backData; + const auto free_work_space = args_ext.freeWorkSpace; const auto free_work_space_size = args_ext.freeWorkSpaceSize; rnnAlgoModules.PrepareWriteBuffers(handle, dw); From f488519bed255f9798a032ea994a025787e1963f Mon Sep 17 00:00:00 2001 
From: Shurale-nkn Date: Tue, 26 Nov 2024 15:44:47 +0100 Subject: [PATCH 12/21] tidy --- .../rnn/algorithms/default_algo_utils.hpp | 40 ++++++++++--------- src/include/miopen/rnn/solvers.hpp | 2 +- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/src/include/miopen/rnn/algorithms/default_algo_utils.hpp b/src/include/miopen/rnn/algorithms/default_algo_utils.hpp index d1af4f7097..f5574fa1a1 100644 --- a/src/include/miopen/rnn/algorithms/default_algo_utils.hpp +++ b/src/include/miopen/rnn/algorithms/default_algo_utils.hpp @@ -139,7 +139,7 @@ class RNNModuleAlgoBase IOBufferDescriptor x_info{IOBufferDescriptor::build(xDesc)}; IOBufferDescriptor y_info{IOBufferDescriptor::build(yDesc)}; - return {std::move(rb_layout), + return {rb_layout, workspace_info, weights_layout, hidden_hxcx_info, @@ -150,29 +150,30 @@ class RNNModuleAlgoBase mode}; } - RNNModuleAlgoBase(RNNModuleAlgoBase&&) = default; + RNNModuleAlgoBase(RNNModuleAlgoBase&&) = default; + RNNModuleAlgoBase(const RNNModuleAlgoBase&) = default; // RNNModuleAlgoBase(RNNModuleAlgoBase const&) = default; - RNNModuleAlgoBase(GeneralLstmRedBuffer rb_layout, - GeneralLstmTempBuffer workspace_info, - WeightsBufferDescriptor weights_layout, - HiddenBuffersDescriptor hidden_hxcx_info, - IOBufferDescriptor x_info, - IOBufferDescriptor y_info, + RNNModuleAlgoBase(const GeneralLstmRedBuffer& rb_layout, + const GeneralLstmTempBuffer& workspace_info, + const WeightsBufferDescriptor& weights_layout, + const HiddenBuffersDescriptor& hidden_hxcx_info, + const IOBufferDescriptor& x_info, + const IOBufferDescriptor& y_info, const RNNDescriptor& rnn_desc, - BatchController batch_controller, + const BatchController& batch_controller, miopenRNNFWDMode_t fwd_mode) - : reservLayout(std::move(rb_layout)), - workspaceInfo(std::move(workspace_info)), - weightsLayout(std::move(weights_layout)), - hiddenHxCxInfo(std::move(hidden_hxcx_info)), - xInfo(std::move(x_info)), - yInfo(std::move(y_info)), + : reservLayout(rb_layout), + 
workspaceInfo(workspace_info), + weightsLayout(weights_layout), + hiddenHxCxInfo(hidden_hxcx_info), + xInfo(x_info), + yInfo(y_info), rnnDesc(rnn_desc), tanhDesc{miopenActivationTANH, 1, 1, 1}, sigDesc{miopenActivationLOGISTIC, 1, 0, 1}, reluDesc{miopenActivationRELU, 1, 0, 1}, - batchController(std::move(batch_controller)), + batchController((batch_controller)), fwdMode(fwd_mode), isBidirectSeq(false) { @@ -426,7 +427,8 @@ class RNNForwardDataModularAlgo : protected RNNModuleAlgoBase reservInfo.getBufferSize() * GetTypeSize(rnnD.dataType)); } - RNNForwardDataModularAlgo(RNNModuleAlgoBase base) : RNNModuleAlgoBase(std::move(base)) {} + RNNForwardDataModularAlgo(RNNModuleAlgoBase&& base) : RNNModuleAlgoBase(std::move(base)) {} + RNNForwardDataModularAlgo(const RNNModuleAlgoBase& base) : RNNModuleAlgoBase(base) {} private: }; @@ -541,6 +543,7 @@ class RNNBackwardDataModularAlgo : protected RNNModuleAlgoBase } RNNBackwardDataModularAlgo(RNNModuleAlgoBase&& base) : RNNModuleAlgoBase(std::move(base)) {} + RNNBackwardDataModularAlgo(const RNNModuleAlgoBase& base) : RNNModuleAlgoBase(base) {} }; class RNNBackwardWeightsModularAlgo : public RNNModuleAlgoBase @@ -674,7 +677,8 @@ class RNNBackwardWeightsModularAlgo : public RNNModuleAlgoBase reservInfo.getBufferSize() * GetTypeSize(rnnD.dataType)); } - RNNBackwardWeightsModularAlgo(RNNModuleAlgoBase base) : RNNModuleAlgoBase(std::move(base)) {} + RNNBackwardWeightsModularAlgo(RNNModuleAlgoBase&& base) : RNNModuleAlgoBase(std::move(base)) {} + RNNBackwardWeightsModularAlgo(const RNNModuleAlgoBase& base) : RNNModuleAlgoBase(base) {} protected: void HiddenHStateWeights_Unchecked(const Handle& handle, diff --git a/src/include/miopen/rnn/solvers.hpp b/src/include/miopen/rnn/solvers.hpp index f8ae6be693..353760c8af 100644 --- a/src/include/miopen/rnn/solvers.hpp +++ b/src/include/miopen/rnn/solvers.hpp @@ -330,7 +330,7 @@ class RNNDynamicModularSingleStreamBWWeights Data_t workSpace, size_t workSpaceSize, ConstData_t reserveSpace, 
- size_t reserveSpaceSize) const + size_t /*reserveSpaceSize*/) const { const ConstData_t back_data_space = workSpace; const auto back_data_byte_size = From 3ccae40c69dc2fc9ed98b8e36ce05ea611b33d50 Mon Sep 17 00:00:00 2001 From: Shurale-nkn Date: Wed, 27 Nov 2024 22:57:22 +0100 Subject: [PATCH 13/21] add new type in miopenRNNAlgo_t --- include/miopen/miopen.h | 4 +- .../rnn/algorithms/dynamic_algo_utils.hpp | 35 +++++++------ src/ocl/rnnocl.cpp | 34 +++++++++++-- src/rnn.cpp | 2 +- src/rnn/Solutions/Base/bw_data_modular.cpp | 20 ++------ src/rnn/selector.cpp | 50 +++++++++++-------- test/rnn_seq_api.cpp | 15 +++++- 7 files changed, 96 insertions(+), 64 deletions(-) diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h index 67652ab832..31fb12a0d3 100644 --- a/include/miopen/miopen.h +++ b/include/miopen/miopen.h @@ -3751,7 +3751,9 @@ typedef enum miopenRNNdefault = 0, /*!< Use dedicated gate-operation kernel for LSTM and fundamental algorithm for vanilla RNN & GRU */ miopenRNNfundamental = - 1, /*!< Function by basic tesnsor operations, supported for vanilla RNN, LSTM, GRU */ + 1, /*!< Deprecated, low performance. Function by basic tensor operations, supported for vanilla RNN, LSTM, GRU */ miopenRNNroundedDynamic = 2, /*!< The algorithm rounds some RNN parameters upwards + to utilize the most optimal GEMM kernel in the computation.*/ } miopenRNNAlgo_t; /*!
@enum miopenRNNDirectionMode_t diff --git a/src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp b/src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp index ddae328b6e..b57877a2a0 100644 --- a/src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp +++ b/src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp @@ -33,10 +33,26 @@ namespace miopen { namespace rnn_base { +inline std::vector roundedDynamicLengths(const SeqTensorDescriptor& desc) +{ + auto src_lens = desc.GetLengths(); + src_lens[1] = [](size_t v) { + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v++; + return v; + }(src_lens[1]); + return src_lens; +} + inline SeqTensorDescriptor buildDynamicVirtual(const SeqTensorDescriptor& desc) { std::vector def_layout{1, 0, 2}; - return {desc.GetType(), def_layout, desc.GetLengths(), false}; + return {desc.GetType(), def_layout, roundedDynamicLengths(desc), false}; } inline SeqTensorDescriptor buildRealToDynamicMapTmp(const SeqTensorDescriptor& desc) @@ -167,23 +183,6 @@ class RNNModuleAlgoDynamic : public RNNForwardDataModularAlgo class RNNBackwardModuleAlgoDynamic : public RNNBackwardDataModularAlgo { using BaseBWDModuleT = rnn_base::RNNBackwardDataModularAlgo; - static SeqTensorDescriptor buildDynamicVirtual(const SeqTensorDescriptor& desc) - { - std::vector def_layout{1, 0, 2}; - return {desc.GetType(), def_layout, desc.GetLengths(), false}; - } - - static SeqTensorDescriptor buildRealToDynamicMapTmp(const SeqTensorDescriptor& desc) - { - std::vector def_layout{1, 0, 2}; - return {desc.GetType(), - def_layout, - desc.GetLengths(), - desc.GetSequenceLengthsVector(), - std::vector{}, - true, - true}; - } public: RNNBackwardModuleAlgoDynamic(const RNNDescriptor& rnnD, diff --git a/src/ocl/rnnocl.cpp b/src/ocl/rnnocl.cpp index cee195819b..22c05811c3 100644 --- a/src/ocl/rnnocl.cpp +++ b/src/ocl/rnnocl.cpp @@ -2789,7 +2789,10 @@ void RNNDescriptor::RNNForwardTrainingPackedTensors( // input check end bool use_dropout = 
!float_equal(miopen::deref(dropoutDesc).dropout, 0); - if(RNNForwardMSIsSupported(*this, false) && RNNForwardMSIsFast(seqLen)) + + // high priority for DynamicAlgo + if(!CheckDynamicAlgoSelection(handle, {}, miopenRNNTraining) && + RNNForwardMSIsSupported(*this, false) && RNNForwardMSIsFast(seqLen)) { return RNNForwardMS(handle, in_n, @@ -2809,7 +2812,7 @@ void RNNDescriptor::RNNForwardTrainingPackedTensors( miopenRNNFWDMode_t::miopenRNNTraining); } else if(dirMode == 0 && inputMode == miopenRNNlinear && rnnMode == miopenLSTM && !use_dropout && - algoMode == miopenRNNdefault) + (algoMode == miopenRNNdefault || algoMode == miopenRNNroundedDynamic)) { SeqTensorDescriptor x_seq = makeSeqTensorDescriptor(xDesc, seqLen, miopenRNNDataSeqMajorNotPadded); @@ -2836,6 +2839,13 @@ void RNNDescriptor::RNNForwardTrainingPackedTensors( reserveSpaceSize); } + if(algoMode == miopenRNNroundedDynamic) + { + MIOPEN_THROW(miopenStatusBadParm, + "This configuration is not supported with algoMode=miopenRNNroundedDynamic, " + "use miopenRNNdefault "); + } + int in_stride = xDesc[0].GetLengths()[1]; int hy_stride = hy_h * bi * static_cast(workspaceScale); int out_stride = out_h; @@ -4301,7 +4311,7 @@ void RNNDescriptor::RNNBackwardDataPackedTensors( bool use_dropout = !float_equal(miopen::deref(dropoutDesc).dropout, 0); if(dirMode == 0 && inputMode == miopenRNNlinear && rnnMode == miopenLSTM && !use_dropout && - algoMode == miopenRNNdefault) + (algoMode == miopenRNNdefault || algoMode == miopenRNNroundedDynamic)) { SeqTensorDescriptor dx_seq = makeSeqTensorDescriptor(dxDesc, seqLen, miopenRNNDataSeqMajorNotPadded); @@ -4329,6 +4339,13 @@ void RNNDescriptor::RNNBackwardDataPackedTensors( reserveSpaceSize); } + if(algoMode == miopenRNNroundedDynamic) + { + MIOPEN_THROW(miopenStatusBadParm, + "This configuration is not supported with algoMode=miopenRNNroundedDynamic, " + "use miopenRNNdefault "); + } + int in_stride = in_h; int hy_stride = hy_h * bi * static_cast(workspaceScale); int 
out_stride = out_h; @@ -5992,8 +6009,8 @@ void RNNDescriptor::RNNBackwardWeightsPackedTensors( in_h = 0; } - if(dirMode == 0 && rnnMode == miopenLSTM && !use_dropout && algoMode == miopenRNNdefault && - !env::disabled(MIOPEN_RNNWRW_EXP)) + if(dirMode == 0 && rnnMode == miopenLSTM && !use_dropout && inputMode == miopenRNNlinear && + (algoMode == miopenRNNdefault || algoMode == miopenRNNroundedDynamic)) { SeqTensorDescriptor x_seq = makeSeqTensorDescriptor(xDesc, seqLen, miopenRNNDataSeqMajorNotPadded); @@ -6014,6 +6031,13 @@ void RNNDescriptor::RNNBackwardWeightsPackedTensors( reserveSpaceSize); } + if(algoMode == miopenRNNroundedDynamic) + { + MIOPEN_THROW(miopenStatusBadParm, + "This configuration is not supported with algoMode=miopenRNNroundedDynamic, " + "use miopenRNNdefault "); + } + size_t wei_shift_bias = (in_h + hy_h + (bi * hy_h + hy_h) * (nLayers - 1)) * wei_stride; float alpha0, alpha1, beta_t = 0; diff --git a/src/rnn.cpp b/src/rnn.cpp index 1d3f95eeac..8096fc97d5 100644 --- a/src/rnn.cpp +++ b/src/rnn.cpp @@ -572,7 +572,7 @@ size_t RNNDescriptor::GetWorkspaceSize(Handle& handle, size_t RNNDescriptor::GetReserveSize(size_t batchLenSum) const { auto x = 2 * workspaceScale * nLayers * batchLenSum * hsize * typeSize; - if(algoMode == miopenRNNdefault && rnnMode == miopenLSTM) + if( (algoMode == miopenRNNdefault || algoMode == miopenRNNroundedDynamic) && rnnMode == miopenLSTM) { x /= 2; x += nLayers * batchLenSum * hsize * typeSize; diff --git a/src/rnn/Solutions/Base/bw_data_modular.cpp b/src/rnn/Solutions/Base/bw_data_modular.cpp index 99cc213e50..ecbf60d333 100644 --- a/src/rnn/Solutions/Base/bw_data_modular.cpp +++ b/src/rnn/Solutions/Base/bw_data_modular.cpp @@ -275,15 +275,10 @@ void RNNBackwardDataModularAlgo::PropDhxDcx(const Handle& handle, const float alpha1 = 1; const float beta_t = 1; - const auto bOffset = rnnDesc.algoMode == miopenRNNdefault - ? 
reservLayout.getGasOffset(layer, + const auto bOffset = reservLayout.getGasOffset(layer, acc_batch_offset, direction, - LstmGateAndState::F) - : reservLayout.getActiveCellOffset( // TODO double check - layer, - acc_batch_offset, - direction); + LstmGateAndState::F); const auto a_offset = workspaceInfo.getGasOffset( layer, acc_batch_offset, direction, LstmGateAndState::St); @@ -601,13 +596,11 @@ void RNNBackwardDataModularAlgo::UpdateHStatePerTimeSeq(const Handle& handle, const size_t hidden_vec = rnnDesc.hsize; auto rnn_data_type = rnnDesc.dataType; auto rnn_mode = rnnDesc.rnnMode; - auto rnn_algo_mode = rnnDesc.algoMode; if(rnn_mode == miopenRNNRELU || rnn_mode == miopenRNNTANH) {} else if(rnn_mode == miopenLSTM) { - if(rnn_algo_mode == miopenRNNdefault) - { + size_t cur_comb_dim = batchController.getBatchSum(seq.getPhisVal()); size_t prev_comb_dim = !seq.isFirst() ? batchController.getBatchSum(seq.getPrev().getPhisVal()) @@ -655,13 +648,6 @@ void RNNBackwardDataModularAlgo::UpdateHStatePerTimeSeq(const Handle& handle, workspaceInfo.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::Ht), workspaceInfo.getGasOffset(layer, prev_comb_dim, direction, LstmGateAndState::F)); } - else - { - MIOPEN_THROW(miopenStatusInternalError, - "TODO implementation algoMode != miopenRNNdefault"); - // TODO implementation - } - } else if(rnn_mode == miopenGRU) { MIOPEN_THROW(miopenStatusInternalError, "TODO implementation miopenGRU"); diff --git a/src/rnn/selector.cpp b/src/rnn/selector.cpp index 34f45712b2..625730d59c 100644 --- a/src/rnn/selector.cpp +++ b/src/rnn/selector.cpp @@ -43,7 +43,7 @@ MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_RNNBWDMS_EXP) MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_RNNBWMS_EXP) -MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_RNN_DYNAMIC_EXP) +MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_RNN_DYNAMIC_FORCE) namespace miopen { @@ -67,6 +67,7 @@ bool RNNBwWeightMSIsFast(const int seqLen) return false; } + std::tuple RNNDescriptor::GetTmpSpaceSizeDynamicAlgo( Handle& /*handle*/, const 
SeqTensorDescriptor& xDesc, miopenRNNFWDMode_t /*fwdMode*/) const { @@ -75,12 +76,19 @@ std::tuple RNNDescriptor::GetTmpSpaceSizeDynamicAlgo( bool RNNDescriptor::CheckDynamicAlgoSelection(Handle& /*handle*/, const SeqTensorDescriptor& /*xDesc*/, - miopenRNNFWDMode_t /*fwdMode*/) const + miopenRNNFWDMode_t fwdMode) const { + if(fwdMode == miopenRNNInference) + return false; + + bool algo_mode_match = algoMode == miopenRNNroundedDynamic || + (algoMode == miopenRNNdefault && env::enabled(MIOPEN_RNN_DYNAMIC_FORCE)); + bool use_dropout = !float_equal(miopen::deref(dropoutDesc).dropout, 0); bool rnn_config_match = (dirMode == 0 && inputMode == miopenRNNlinear && - rnnMode == miopenLSTM && !use_dropout && algoMode == miopenRNNdefault); - if(rnn_config_match && env::enabled(MIOPEN_RNN_DYNAMIC_EXP)) + rnnMode == miopenLSTM && !use_dropout); + + if(rnn_config_match && algo_mode_match) { return true; } @@ -139,22 +147,24 @@ void RNNDescriptor::ModularBackward(Handle& handle, Data_t reserveSpace, size_t /*reserveSpaceSize*/) const { - if(RNNBwdMSIsFast(xDesc.GetMaxSequenceLength())) + + if(CheckDynamicAlgoSelection(handle, xDesc, miopenRNNFWDMode_t::miopenRNNTraining)) { - rnn_base::RNNModularMultiStreamBWD multi_stream{ + rnn_base::RNNDynamicModularSingleStreamBWD single_stream{ *this, xDesc, yDesc, hDesc, miopenRNNFWDMode_t::miopenRNNTraining}; - multi_stream.ComputeBWD(handle, dy, dhy, dhx, cx, dcy, dcx, dx, w, workSpace, reserveSpace); + single_stream.ComputeBWD( + handle, + rnn_base::runtimeArgsBwd{ + &handle, dy, dhy, dhx, cx, dcy, dcx, dx, w, workSpace, reserveSpace}); } else { - if(CheckDynamicAlgoSelection(handle, xDesc, miopenRNNFWDMode_t::miopenRNNTraining)) + if(RNNBwdMSIsFast(xDesc.GetMaxSequenceLength())) { - rnn_base::RNNDynamicModularSingleStreamBWD single_stream{ + rnn_base::RNNModularMultiStreamBWD multi_stream{ *this, xDesc, yDesc, hDesc, miopenRNNFWDMode_t::miopenRNNTraining}; - single_stream.ComputeBWD( - handle, - rnn_base::runtimeArgsBwd{ - &handle, dy, 
dhy, dhx, cx, dcy, dcx, dx, w, workSpace, reserveSpace}); + multi_stream.ComputeBWD( + handle, dy, dhy, dhx, cx, dcy, dcx, dx, w, workSpace, reserveSpace); } else { @@ -178,19 +188,19 @@ void RNNDescriptor::ModularBackwardWeights(Handle& handle, ConstData_t reserveSpace, size_t reserveSpaceSize) const { - if(RNNBwWeightMSIsFast(xDesc.GetMaxSequenceLength())) + if(CheckDynamicAlgoSelection(handle, xDesc, miopenRNNFWDMode_t::miopenRNNTraining)) { - rnn_base::RNNModularMultiStreamBWWeights multi_stream{*this, xDesc, yDesc, hDesc}; - multi_stream.Compute( + rnn_base::RNNDynamicModularSingleStreamBWWeights single_stream{ + *this, xDesc, yDesc, hDesc, miopenRNNFWDMode_t::miopenRNNTraining}; + single_stream.Compute( handle, x, hx, dw, workSpace, workSpaceSize, reserveSpace, reserveSpaceSize); } else { - if(CheckDynamicAlgoSelection(handle, xDesc, miopenRNNFWDMode_t::miopenRNNTraining)) + if(RNNBwWeightMSIsFast(xDesc.GetMaxSequenceLength())) { - rnn_base::RNNDynamicModularSingleStreamBWWeights single_stream{ - *this, xDesc, yDesc, hDesc, miopenRNNFWDMode_t::miopenRNNTraining}; - single_stream.Compute( + rnn_base::RNNModularMultiStreamBWWeights multi_stream{*this, xDesc, yDesc, hDesc}; + multi_stream.Compute( handle, x, hx, dw, workSpace, workSpaceSize, reserveSpace, reserveSpaceSize); } else diff --git a/test/rnn_seq_api.cpp b/test/rnn_seq_api.cpp index 7fec633f48..91c3e17ae2 100644 --- a/test/rnn_seq_api.cpp +++ b/test/rnn_seq_api.cpp @@ -41,7 +41,7 @@ struct rnn_seq_driver : rnn_seq_api_test_driver this->add(this->biasMode, "bias-mode", this->generate_data({1})); this->add(this->dirMode, "dir-mode", this->generate_data(modes)); this->add(this->rnnMode, "rnn-mode", this->generate_data({2, 1, 3}, 2)); - this->add(this->algoMode, "algo-mode", this->generate_data({0})); + this->add(this->algoMode, "algo-mode", this->generate_data({0, 2})); this->add(this->numLayers, "num-layers", this->generate_data({1, 3}, 3)); this->add(this->io_layout, "io_layout", this->generate_data({2, 
1, 3}, 3)); this->add(this->batchSize, "batch-size", this->generate_data({1, 4, 6}, 6)); @@ -119,10 +119,21 @@ struct rnn_seq_driver : rnn_seq_api_test_driver return true; } + bool is_dynamic_algo_skip_case() + { + if(this->algoMode != 2) + return false; + + if(this->dirMode == 0 && this->rnnMode == miopenLSTM && this->useDropout == 0 && + this->inputMode == miopenRNNlinear) + return false; + return true; + } + void run() { - if(!this->full_set || (is_correct_params() && !is_skip_comb())) + if(!this->full_set || (is_correct_params() && !is_skip_comb() && !is_dynamic_algo_skip_case())) rnn_seq_api_test_driver::run(); else { From 1758984f29581b0ed86374eafb9fb631a580eeaa Mon Sep 17 00:00:00 2001 From: Shurale-nkn Date: Wed, 27 Nov 2024 23:00:50 +0100 Subject: [PATCH 14/21] bug fix --- .../rnn/algorithms/default_algo_utils.hpp | 9 + .../rnn/algorithms/dynamic_algo_utils.hpp | 41 ++--- src/include/miopen/rnn/solvers.hpp | 9 +- src/include/miopen/rnn/tmp_buffer_utils.hpp | 7 +- src/rnn/Solutions/Base/bw_data_modular.cpp | 170 ++++++------------ src/rnn/Solutions/Base/fw_data_modular.cpp | 10 +- src/rnn/Solutions/bwd_s_stream.cpp | 37 ++-- src/rnn/Solutions/fwd_s_stream.cpp | 18 +- 8 files changed, 132 insertions(+), 169 deletions(-) diff --git a/src/include/miopen/rnn/algorithms/default_algo_utils.hpp b/src/include/miopen/rnn/algorithms/default_algo_utils.hpp index f5574fa1a1..7fddc0a327 100644 --- a/src/include/miopen/rnn/algorithms/default_algo_utils.hpp +++ b/src/include/miopen/rnn/algorithms/default_algo_utils.hpp @@ -224,6 +224,11 @@ class RNNModuleAlgoBase (direction == SequenceDirection::Forward ? 
0 : 1); } + inline size_t getTimeSeqSize() const + { + return batchController.size(); + } + template inline miopen::TensorDescriptor BuildLstmTmpBlockDesc2D(const BufType& buf_info, const size_t batch_size) const @@ -427,6 +432,8 @@ class RNNForwardDataModularAlgo : protected RNNModuleAlgoBase reservInfo.getBufferSize() * GetTypeSize(rnnD.dataType)); } + inline size_t getTimeSeqSize() const { return RNNModuleAlgoBase::getTimeSeqSize(); } + RNNForwardDataModularAlgo(RNNModuleAlgoBase&& base) : RNNModuleAlgoBase(std::move(base)) {} RNNForwardDataModularAlgo(const RNNModuleAlgoBase& base) : RNNModuleAlgoBase(base) {} @@ -542,6 +549,8 @@ class RNNBackwardDataModularAlgo : protected RNNModuleAlgoBase #endif // MIOPEN_USE_GEMM&& MIOPEN_BACKEND_HIP } + inline size_t getTimeSeqSize() const { return RNNModuleAlgoBase::getTimeSeqSize(); } + RNNBackwardDataModularAlgo(RNNModuleAlgoBase&& base) : RNNModuleAlgoBase(std::move(base)) {} RNNBackwardDataModularAlgo(const RNNModuleAlgoBase& base) : RNNModuleAlgoBase(base) {} }; diff --git a/src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp b/src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp index b57877a2a0..f1d3388f26 100644 --- a/src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp +++ b/src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp @@ -166,24 +166,26 @@ class RNNModuleAlgoDynamic : public RNNForwardDataModularAlgo const runtimeArgsFwd& runtimeArgs) const; void PropHyCy(const Handle& handle, - const runtimeArgsFwd& runtimeArgs, + const runtimeArgsFwdDynamicExt& runtimeArgs, size_t layer, const SequenceIterator& currentSeq, SequenceDirection direction) const; + inline size_t getRealTimeSeqSize() const { return realBatchController.size(); } + private: - BatchController realBatchController; + const BatchController realBatchController; - SeqTensorDescriptor realXDesc; - SeqTensorDescriptor realYDesc; - SeqTensorDescriptor tmpMapXDesc; - SeqTensorDescriptor tmpMapYDesc; + const SeqTensorDescriptor 
realXDesc; + const SeqTensorDescriptor realYDesc; + const SeqTensorDescriptor tmpMapXDesc; + const SeqTensorDescriptor tmpMapYDesc; }; class RNNBackwardModuleAlgoDynamic : public RNNBackwardDataModularAlgo { using BaseBWDModuleT = rnn_base::RNNBackwardDataModularAlgo; - + public: RNNBackwardModuleAlgoDynamic(const RNNDescriptor& rnnD, const SeqTensorDescriptor& xTDesc, @@ -297,11 +299,8 @@ class RNNBackwardModuleAlgoDynamic : public RNNBackwardDataModularAlgo // const runtimeArgsBwdDynamicExt& runtimeArgsExt, // const runtimeArgsFwd& runtimeArgs) const; - void PropHyCy(const Handle& handle, - const runtimeArgsFwd& runtimeArgs, - size_t layer, - const SequenceIterator& currentSeq, - SequenceDirection direction) const; + inline size_t getRealTimeSeqSize() const { return realBatchController.size(); } + private: BatchController realBatchController; @@ -316,24 +315,6 @@ class RNNBackwardWeiModuleAlgoDynamic : public RNNBackwardWeightsModularAlgo { using BaseBWDModuleT = rnn_base::RNNBackwardWeightsModularAlgo; - static SeqTensorDescriptor buildDynamicVirtual(const SeqTensorDescriptor& desc) - { - std::vector def_layout{1, 0, 2}; - return {desc.GetType(), def_layout, desc.GetLengths(), false}; - } - - static SeqTensorDescriptor buildRealToDynamicMapTmp(const SeqTensorDescriptor& desc) - { - std::vector def_layout{1, 0, 2}; - return {desc.GetType(), - def_layout, - desc.GetLengths(), - desc.GetSequenceLengthsVector(), - std::vector{}, - true, - true}; - } - public: RNNBackwardWeiModuleAlgoDynamic(const RNNDescriptor& rnnD, const SeqTensorDescriptor& xTDesc, diff --git a/src/include/miopen/rnn/solvers.hpp b/src/include/miopen/rnn/solvers.hpp index 353760c8af..d5a3b854ba 100644 --- a/src/include/miopen/rnn/solvers.hpp +++ b/src/include/miopen/rnn/solvers.hpp @@ -78,6 +78,7 @@ class RNNModularSingleStreamFWD class RNNDynamicModularSingleStreamFWD { private: + public: RNNDynamicModularSingleStreamFWD(const RNNDescriptor& rnn, const SeqTensorDescriptor& xDesc, @@ -85,8 +86,7 
@@ class RNNDynamicModularSingleStreamFWD const TensorDescriptor& hDesc, miopenRNNFWDMode_t mode) : rnnAlgoModules(rnn, xDesc, yDesc, hDesc, mode), - rnnDesc(rnn), - max_seq_len(xDesc.GetMaxSequenceLength()) + rnnDesc(rnn) { } @@ -110,7 +110,6 @@ class RNNDynamicModularSingleStreamFWD const rnn_base::RNNModuleAlgoDynamic rnnAlgoModules; const RNNDescriptor& rnnDesc; - const size_t max_seq_len; }; // @@ -170,8 +169,7 @@ class RNNDynamicModularSingleStreamBWD const TensorDescriptor& hDesc, miopenRNNFWDMode_t mode) : rnnAlgoModules(rnn, xDesc, yDesc, hDesc, mode), - rnnDesc(rnn), - max_seq_len(xDesc.GetMaxSequenceLength()) + rnnDesc(rnn) { } @@ -195,7 +193,6 @@ class RNNDynamicModularSingleStreamBWD const rnn_base::RNNBackwardModuleAlgoDynamic rnnAlgoModules; const RNNDescriptor& rnnDesc; - const size_t max_seq_len; }; class RNNModularMultiStreamBWD diff --git a/src/include/miopen/rnn/tmp_buffer_utils.hpp b/src/include/miopen/rnn/tmp_buffer_utils.hpp index 5bd6883170..cd9c6c40a9 100644 --- a/src/include/miopen/rnn/tmp_buffer_utils.hpp +++ b/src/include/miopen/rnn/tmp_buffer_utils.hpp @@ -561,6 +561,11 @@ class BatchController return batchPrefSumAtTime[time_id]; } + size_t size() const + { + return batchAtTime.size(); + } + private: template >::value, bool> = true> explicit BatchController(T&& batch_at_time, T&& batch_prefix_sums) @@ -940,8 +945,6 @@ class IOBufferDescriptor return {lengths[1] * lengths[2], lengths[2], 1}; } - inline size_t getBufferSizeImpl() const { return packedLens[0]; } - // private: // local caching diff --git a/src/rnn/Solutions/Base/bw_data_modular.cpp b/src/rnn/Solutions/Base/bw_data_modular.cpp index ecbf60d333..464f4993ef 100644 --- a/src/rnn/Solutions/Base/bw_data_modular.cpp +++ b/src/rnn/Solutions/Base/bw_data_modular.cpp @@ -601,53 +601,53 @@ void RNNBackwardDataModularAlgo::UpdateHStatePerTimeSeq(const Handle& handle, else if(rnn_mode == miopenLSTM) { - size_t cur_comb_dim = batchController.getBatchSum(seq.getPhisVal()); - size_t 
prev_comb_dim = !seq.isFirst() - ? batchController.getBatchSum(seq.getPrev().getPhisVal()) - : batchController.getBatchSum(seq.getPhisVal()); - size_t next_comb_dim = !seq.isLast() - ? batchController.getBatchSum(seq.getNext().getPhisVal()) - : batchController.getBatchSum(seq.getPhisVal()); - - LSTMBackwardHiddenStateUpdate( - handle, - rnn_data_type, - seq.isLast(), // ti == 0, - seq.isFirst(), // ti == seqLen - 1, - static_cast(direction), - batchController.getBatchSize(0), - batchSizeUpdate, - useDcyIfGtBatch, - useCxIfGTBatch, - hidden_vec, - reservLayout.gateStride[1], - -666, // unused - -666, // unused - cx, - hiddenHxCxInfo.getOffset(getVirtualLayer(layer, direction), 0), - reserveSpace, - reservLayout.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::I), - reservLayout.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::F), - reservLayout.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::O), - reservLayout.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::G), - reservLayout.getActiveCellOffset(layer, cur_comb_dim, direction), - reservLayout.getGasOffset( // TODO - layer, - next_comb_dim, - direction, - LstmGateAndState::St), - dcy, - hiddenHxCxInfo.getOffset(getVirtualLayer(layer, direction), 0), - workSpace, - workspaceInfo.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::I), - workspaceInfo.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::F), - workspaceInfo.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::O), - workspaceInfo.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::G), - workspaceInfo.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::St), - workspaceInfo.getGasOffset(layer, prev_comb_dim, direction, LstmGateAndState::St), - workspaceInfo.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::Ht), - workspaceInfo.getGasOffset(layer, prev_comb_dim, direction, LstmGateAndState::F)); - } + size_t cur_comb_dim = 
batchController.getBatchSum(seq.getPhisVal()); + size_t prev_comb_dim = !seq.isFirst() + ? batchController.getBatchSum(seq.getPrev().getPhisVal()) + : batchController.getBatchSum(seq.getPhisVal()); + size_t next_comb_dim = !seq.isLast() + ? batchController.getBatchSum(seq.getNext().getPhisVal()) + : batchController.getBatchSum(seq.getPhisVal()); + + LSTMBackwardHiddenStateUpdate( + handle, + rnn_data_type, + seq.isLast(), // ti == 0, + seq.isFirst(), // ti == seqLen - 1, + static_cast(direction), + batchController.getBatchSize(0), + batchSizeUpdate, + useDcyIfGtBatch, + useCxIfGTBatch, + hidden_vec, + reservLayout.gateStride[1], + -666, // unused + -666, // unused + cx, + hiddenHxCxInfo.getOffset(getVirtualLayer(layer, direction), 0), + reserveSpace, + reservLayout.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::I), + reservLayout.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::F), + reservLayout.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::O), + reservLayout.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::G), + reservLayout.getActiveCellOffset(layer, cur_comb_dim, direction), + reservLayout.getGasOffset( // TODO + layer, + next_comb_dim, + direction, + LstmGateAndState::St), + dcy, + hiddenHxCxInfo.getOffset(getVirtualLayer(layer, direction), 0), + workSpace, + workspaceInfo.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::I), + workspaceInfo.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::F), + workspaceInfo.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::O), + workspaceInfo.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::G), + workspaceInfo.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::St), + workspaceInfo.getGasOffset(layer, prev_comb_dim, direction, LstmGateAndState::St), + workspaceInfo.getGasOffset(layer, cur_comb_dim, direction, LstmGateAndState::Ht), + workspaceInfo.getGasOffset(layer, prev_comb_dim, direction, 
LstmGateAndState::F)); + } else if(rnn_mode == miopenGRU) { MIOPEN_THROW(miopenStatusInternalError, "TODO implementation miopenGRU"); @@ -792,77 +792,17 @@ void RNNBackwardModuleAlgoDynamic::PrepareWriteBuffers( RNNBackwardDataModularAlgo::PrepareWriteBuffers( handle, runtimeArgsExt.dhx, runtimeArgsExt.dcx, runtimeArgsExt.workSpace); - // realDxProp(handle, runtimeArgsExt); -} - -void RNNBackwardModuleAlgoDynamic::PropHyCy(const Handle& handle, - const runtimeArgsFwd& runtimeArgs, - size_t layer, - const SequenceIterator& currentSeq, - SequenceDirection direction) const -{ - if(runtimeArgs.hy != nullptr || (runtimeArgs.cy != nullptr)) { - const auto gap_batch_size = [&]() { - if(currentSeq.isLast()) - { - return realBatchController.getBatchSize(currentSeq.getPhisVal()); - } - else - { - if(direction == SequenceDirection::Forward) - { - return realBatchController.getBatchSize(currentSeq.getPhisVal()) - - realBatchController.getBatchSize(currentSeq.getNext().getPhisVal()); - } - else - return static_cast(0); - } - }(); - - const auto gap_batch_offset = [&]() { - if(currentSeq.isLast()) - return static_cast(0); - else - return realBatchController.getBatchSize(currentSeq.getPhisVal()) - gap_batch_size; - }(); - - if(gap_batch_size > 0) - { - - auto src_desc = BuildTempDhtDesc3D(1, gap_batch_size); - - auto dst_desc = BuildHxCxDesc3D(1, gap_batch_size); - - size_t tmp_batch_offset = - batchController.getBatchSum(currentSeq.getPhisVal()) + gap_batch_offset; - - if(runtimeArgs.hy != nullptr) - { - CopyTensor(handle, - src_desc, - runtimeArgs.reserveSpace, - dst_desc, - runtimeArgs.hy, - reservLayout.getGasOffset( - layer, tmp_batch_offset, direction, LstmGateAndState::Ht), - hiddenHxCxInfo.getOffset(layer, gap_batch_offset)); - } - - if(runtimeArgs.cy != nullptr) - { - CopyTensor(handle, - src_desc, - runtimeArgs.reserveSpace, - dst_desc, - runtimeArgs.cy, - reservLayout.getGasOffset( - layer, tmp_batch_offset, direction, LstmGateAndState::St), - 
hiddenHxCxInfo.getOffset(layer, gap_batch_offset)); - } - } + float beta = 0; + auto temp_dy_size = buildDynamicVirtual(realDyDesc).GetElementCount(); + miopen::TensorDescriptor temp_dy_desk{ + rnnDesc.dataType, {1, temp_dy_size}, {temp_dy_size, 1}}; + SetTensor(handle, temp_dy_desk, runtimeArgsExt.tempDy, &beta); } + + // realDxProp(handle, runtimeArgsExt); } + } // namespace rnn_base } // namespace miopen diff --git a/src/rnn/Solutions/Base/fw_data_modular.cpp b/src/rnn/Solutions/Base/fw_data_modular.cpp index 235c291d3d..ee86a4a5a9 100644 --- a/src/rnn/Solutions/Base/fw_data_modular.cpp +++ b/src/rnn/Solutions/Base/fw_data_modular.cpp @@ -547,11 +547,19 @@ void RNNModuleAlgoDynamic::PrepareWriteBuffers(const Handle& handle, const runtimeArgsFwd& runtimeArgs) const { RNNForwardDataModularAlgo::PrepareWriteBuffers(handle, runtimeArgs); + + { + float beta = 0; + auto temp_x_size = buildDynamicVirtual(realXDesc).GetElementCount(); + miopen::TensorDescriptor temp_x_desk{rnnDesc.dataType, {1, temp_x_size}, {temp_x_size, 1}}; + SetTensor(handle, temp_x_desk, runtimeArgsExt.tempX, &beta); + } + realXProp(handle, runtimeArgsExt); } void RNNModuleAlgoDynamic::PropHyCy(const Handle& handle, - const runtimeArgsFwd& runtimeArgs, + const runtimeArgsFwdDynamicExt& runtimeArgs, size_t layer, const SequenceIterator& currentSeq, SequenceDirection direction) const diff --git a/src/rnn/Solutions/bwd_s_stream.cpp b/src/rnn/Solutions/bwd_s_stream.cpp index d84caa6e83..8289879f74 100644 --- a/src/rnn/Solutions/bwd_s_stream.cpp +++ b/src/rnn/Solutions/bwd_s_stream.cpp @@ -139,7 +139,10 @@ void RNNDynamicModularSingleStreamBWD::ComputeBWD(Handle& handle, { auto layer_i = rnnDesc.nLayers; - if(layer_i == 0 || max_seq_len == 0) + auto seq_iterations = rnnAlgoModules.getTimeSeqSize(); + auto real_seq_iterations = rnnAlgoModules.getRealTimeSeqSize(); + + if(layer_i == 0 || seq_iterations == 0 || real_seq_iterations == 0) return; auto sequence_directions = @@ -174,17 +177,31 @@ void 
RNNDynamicModularSingleStreamBWD::ComputeBWD(Handle& handle, const auto seq_dir = dir == 0 ? rnn_base::SequenceDirection::Forward : rnn_base::SequenceDirection::Reverse; - auto ti = max_seq_len; + auto ti = seq_iterations; do { - const rnn_base::SequenceIterator cur_seq(--ti, seq_dir, max_seq_len, false); - - rnnAlgoModules.realPropDhy(handle, dhy, workSpace, layer_i, cur_seq, seq_dir); - - // rnnAlgoModules.HtHiddenDataZeroing(); - - rnnAlgoModules.realUpdateHStatePerTimeSeq( - handle, dcy, cx, dcx, workSpace, reserveSpace, layer_i, cur_seq, seq_dir); + const rnn_base::SequenceIterator cur_seq(--ti, seq_dir, seq_iterations, false); + + if(ti < real_seq_iterations) + { + const rnn_base::SequenceIterator real_cur_seq( + ti, seq_dir, real_seq_iterations, false); + + rnnAlgoModules.realPropDhy( + handle, dhy, workSpace, layer_i, real_cur_seq, seq_dir); + + // rnnAlgoModules.HtHiddenDataZeroing(); + + rnnAlgoModules.realUpdateHStatePerTimeSeq(handle, + dcy, + cx, + dcx, + workSpace, + reserveSpace, + layer_i, + real_cur_seq, + seq_dir); + } // GEMM if(ti != 0) diff --git a/src/rnn/Solutions/fwd_s_stream.cpp b/src/rnn/Solutions/fwd_s_stream.cpp index 74a2c0f632..6147f313a0 100644 --- a/src/rnn/Solutions/fwd_s_stream.cpp +++ b/src/rnn/Solutions/fwd_s_stream.cpp @@ -80,8 +80,10 @@ void RNNModularSingleStreamFWD::ComputeFWD(Handle& handle, const runtimeArgsFwd& void RNNDynamicModularSingleStreamFWD::ComputeFWD(Handle& handle, const runtimeArgsFwd& realRuntimeArgs) const { + auto seq_iterations = rnnAlgoModules.getTimeSeqSize(); + auto real_seq_iterations = rnnAlgoModules.getRealTimeSeqSize(); - if(rnnDesc.nLayers == 0 || max_seq_len == 0) + if(rnnDesc.nLayers == 0 || seq_iterations == 0) return; auto sequence_directions = @@ -105,7 +107,7 @@ void RNNDynamicModularSingleStreamFWD::ComputeFWD(Handle& handle, rnnAlgoModules.PropX(handle, runtimeArgs); rnnAlgoModules.AddBias(handle, runtimeArgs); - + for(auto layer_i = 0; layer_i < rnnDesc.nLayers; ++layer_i) { @@ -117,9 
+119,9 @@ void RNNDynamicModularSingleStreamFWD::ComputeFWD(Handle& handle, if(layer_i != 0) rnnAlgoModules.PropHiddenY(handle, runtimeArgs, layer_i, seq_dir); - for(int ti = 0; ti < max_seq_len; ti++) + for(int ti = 0; ti < seq_iterations; ti++) { - const rnn_base::SequenceIterator cur_seq(ti, seq_dir, max_seq_len, true); + const rnn_base::SequenceIterator cur_seq(ti, seq_dir, seq_iterations, true); if(ti == 0) rnnAlgoModules.PropHxCx(handle, runtimeArgs, layer_i, cur_seq, seq_dir); @@ -129,7 +131,13 @@ void RNNDynamicModularSingleStreamFWD::ComputeFWD(Handle& handle, rnnAlgoModules.UpdateHStatePerTimeSeq( handle, runtimeArgs, layer_i, cur_seq, seq_dir); - rnnAlgoModules.PropHyCy(handle, runtimeArgs, layer_i, cur_seq, seq_dir); + if(ti < real_seq_iterations) + { + const rnn_base::SequenceIterator real_cur_seq( + ti, seq_dir, real_seq_iterations, true); + + rnnAlgoModules.PropHyCy(handle, runtimeArgsExt, layer_i, real_cur_seq, seq_dir); + } } } } From 3ff3954e0a9e4b14b12018118906e78cf9679e08 Mon Sep 17 00:00:00 2001 From: Shurale-nkn Date: Wed, 27 Nov 2024 23:01:54 +0100 Subject: [PATCH 15/21] clang-format --- include/miopen/miopen.h | 8 ++++---- src/include/miopen/rnn/algorithms/default_algo_utils.hpp | 5 +---- src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp | 3 +-- src/include/miopen/rnn/solvers.hpp | 7 ++----- src/include/miopen/rnn/tmp_buffer_utils.hpp | 5 +---- src/ocl/rnnocl.cpp | 1 - src/rnn/Solutions/Base/bw_data_modular.cpp | 7 ++----- src/rnn/Solutions/fwd_s_stream.cpp | 4 ++-- src/rnn/selector.cpp | 7 +++---- 9 files changed, 16 insertions(+), 31 deletions(-) diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h index 31fb12a0d3..2b9b125512 100644 --- a/include/miopen/miopen.h +++ b/include/miopen/miopen.h @@ -3748,10 +3748,10 @@ typedef enum */ typedef enum { - miopenRNNdefault = 0, /*!< Use dedicated gate-operation kernel for LSTM and fundamental - algorithm for vanilla RNN & GRU */ - miopenRNNfundamental = - 1, /*!< Deprecated, low 
performance. Function by basic tesnsor operations, supported for vanilla RNN, LSTM, GRU */ + miopenRNNdefault = 0, /*!< Use dedicated gate-operation kernel for LSTM and fundamental + algorithm for vanilla RNN & GRU */ + miopenRNNfundamental = 1, /*!< Deprecated, low performance. Function by basic tesnsor + operations, supported for vanilla RNN, LSTM, GRU */ miopenRNNroundedDynamic = 2, /*!< The algorithm rounds some RNN parametrs upwards to utilize the most optimal GEMM kernel in the computation.*/ } miopenRNNAlgo_t; diff --git a/src/include/miopen/rnn/algorithms/default_algo_utils.hpp b/src/include/miopen/rnn/algorithms/default_algo_utils.hpp index 7fddc0a327..f0503d6654 100644 --- a/src/include/miopen/rnn/algorithms/default_algo_utils.hpp +++ b/src/include/miopen/rnn/algorithms/default_algo_utils.hpp @@ -224,10 +224,7 @@ class RNNModuleAlgoBase (direction == SequenceDirection::Forward ? 0 : 1); } - inline size_t getTimeSeqSize() const - { - return batchController.size(); - } + inline size_t getTimeSeqSize() const { return batchController.size(); } template inline miopen::TensorDescriptor BuildLstmTmpBlockDesc2D(const BufType& buf_info, diff --git a/src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp b/src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp index f1d3388f26..2771724294 100644 --- a/src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp +++ b/src/include/miopen/rnn/algorithms/dynamic_algo_utils.hpp @@ -185,7 +185,7 @@ class RNNModuleAlgoDynamic : public RNNForwardDataModularAlgo class RNNBackwardModuleAlgoDynamic : public RNNBackwardDataModularAlgo { using BaseBWDModuleT = rnn_base::RNNBackwardDataModularAlgo; - + public: RNNBackwardModuleAlgoDynamic(const RNNDescriptor& rnnD, const SeqTensorDescriptor& xTDesc, @@ -301,7 +301,6 @@ class RNNBackwardModuleAlgoDynamic : public RNNBackwardDataModularAlgo inline size_t getRealTimeSeqSize() const { return realBatchController.size(); } - private: BatchController realBatchController; diff --git 
a/src/include/miopen/rnn/solvers.hpp b/src/include/miopen/rnn/solvers.hpp index d5a3b854ba..96c4c40f5f 100644 --- a/src/include/miopen/rnn/solvers.hpp +++ b/src/include/miopen/rnn/solvers.hpp @@ -78,15 +78,13 @@ class RNNModularSingleStreamFWD class RNNDynamicModularSingleStreamFWD { private: - public: RNNDynamicModularSingleStreamFWD(const RNNDescriptor& rnn, const SeqTensorDescriptor& xDesc, const SeqTensorDescriptor& yDesc, const TensorDescriptor& hDesc, miopenRNNFWDMode_t mode) - : rnnAlgoModules(rnn, xDesc, yDesc, hDesc, mode), - rnnDesc(rnn) + : rnnAlgoModules(rnn, xDesc, yDesc, hDesc, mode), rnnDesc(rnn) { } @@ -168,8 +166,7 @@ class RNNDynamicModularSingleStreamBWD const SeqTensorDescriptor& yDesc, const TensorDescriptor& hDesc, miopenRNNFWDMode_t mode) - : rnnAlgoModules(rnn, xDesc, yDesc, hDesc, mode), - rnnDesc(rnn) + : rnnAlgoModules(rnn, xDesc, yDesc, hDesc, mode), rnnDesc(rnn) { } diff --git a/src/include/miopen/rnn/tmp_buffer_utils.hpp b/src/include/miopen/rnn/tmp_buffer_utils.hpp index cd9c6c40a9..b595ef27e0 100644 --- a/src/include/miopen/rnn/tmp_buffer_utils.hpp +++ b/src/include/miopen/rnn/tmp_buffer_utils.hpp @@ -561,10 +561,7 @@ class BatchController return batchPrefSumAtTime[time_id]; } - size_t size() const - { - return batchAtTime.size(); - } + size_t size() const { return batchAtTime.size(); } private: template >::value, bool> = true> diff --git a/src/ocl/rnnocl.cpp b/src/ocl/rnnocl.cpp index 22c05811c3..b1c35071f2 100644 --- a/src/ocl/rnnocl.cpp +++ b/src/ocl/rnnocl.cpp @@ -2789,7 +2789,6 @@ void RNNDescriptor::RNNForwardTrainingPackedTensors( // input check end bool use_dropout = !float_equal(miopen::deref(dropoutDesc).dropout, 0); - // high priority for DynamicAlgo if(!CheckDynamicAlgoSelection(handle, {}, miopenRNNTraining) && RNNForwardMSIsSupported(*this, false) && RNNForwardMSIsFast(seqLen)) diff --git a/src/rnn/Solutions/Base/bw_data_modular.cpp b/src/rnn/Solutions/Base/bw_data_modular.cpp index 464f4993ef..95c1a2d239 100644 --- 
a/src/rnn/Solutions/Base/bw_data_modular.cpp +++ b/src/rnn/Solutions/Base/bw_data_modular.cpp @@ -275,10 +275,8 @@ void RNNBackwardDataModularAlgo::PropDhxDcx(const Handle& handle, const float alpha1 = 1; const float beta_t = 1; - const auto bOffset = reservLayout.getGasOffset(layer, - acc_batch_offset, - direction, - LstmGateAndState::F); + const auto bOffset = + reservLayout.getGasOffset(layer, acc_batch_offset, direction, LstmGateAndState::F); const auto a_offset = workspaceInfo.getGasOffset( layer, acc_batch_offset, direction, LstmGateAndState::St); @@ -803,6 +801,5 @@ void RNNBackwardModuleAlgoDynamic::PrepareWriteBuffers( // realDxProp(handle, runtimeArgsExt); } - } // namespace rnn_base } // namespace miopen diff --git a/src/rnn/Solutions/fwd_s_stream.cpp b/src/rnn/Solutions/fwd_s_stream.cpp index 6147f313a0..3a6de5e672 100644 --- a/src/rnn/Solutions/fwd_s_stream.cpp +++ b/src/rnn/Solutions/fwd_s_stream.cpp @@ -80,7 +80,7 @@ void RNNModularSingleStreamFWD::ComputeFWD(Handle& handle, const runtimeArgsFwd& void RNNDynamicModularSingleStreamFWD::ComputeFWD(Handle& handle, const runtimeArgsFwd& realRuntimeArgs) const { - auto seq_iterations = rnnAlgoModules.getTimeSeqSize(); + auto seq_iterations = rnnAlgoModules.getTimeSeqSize(); auto real_seq_iterations = rnnAlgoModules.getRealTimeSeqSize(); if(rnnDesc.nLayers == 0 || seq_iterations == 0) @@ -107,7 +107,7 @@ void RNNDynamicModularSingleStreamFWD::ComputeFWD(Handle& handle, rnnAlgoModules.PropX(handle, runtimeArgs); rnnAlgoModules.AddBias(handle, runtimeArgs); - + for(auto layer_i = 0; layer_i < rnnDesc.nLayers; ++layer_i) { diff --git a/src/rnn/selector.cpp b/src/rnn/selector.cpp index 625730d59c..1c91fc12d3 100644 --- a/src/rnn/selector.cpp +++ b/src/rnn/selector.cpp @@ -67,7 +67,6 @@ bool RNNBwWeightMSIsFast(const int seqLen) return false; } - std::tuple RNNDescriptor::GetTmpSpaceSizeDynamicAlgo( Handle& /*handle*/, const SeqTensorDescriptor& xDesc, miopenRNNFWDMode_t /*fwdMode*/) const { @@ -84,9 +83,9 @@ 
bool RNNDescriptor::CheckDynamicAlgoSelection(Handle& /*handle*/, bool algo_mode_match = algoMode == miopenRNNroundedDynamic || (algoMode == miopenRNNdefault && env::enabled(MIOPEN_RNN_DYNAMIC_FORCE)); - bool use_dropout = !float_equal(miopen::deref(dropoutDesc).dropout, 0); - bool rnn_config_match = (dirMode == 0 && inputMode == miopenRNNlinear && - rnnMode == miopenLSTM && !use_dropout); + bool use_dropout = !float_equal(miopen::deref(dropoutDesc).dropout, 0); + bool rnn_config_match = + (dirMode == 0 && inputMode == miopenRNNlinear && rnnMode == miopenLSTM && !use_dropout); if(rnn_config_match && algo_mode_match) { From 88f97c3a1c99e81b8513770e865954c343a32672 Mon Sep 17 00:00:00 2001 From: Shurale-nkn Date: Thu, 28 Nov 2024 00:51:44 +0100 Subject: [PATCH 16/21] format --- src/rnn.cpp | 3 ++- test/rnn_seq_api.cpp | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/rnn.cpp b/src/rnn.cpp index 8096fc97d5..cab19c0f86 100644 --- a/src/rnn.cpp +++ b/src/rnn.cpp @@ -572,7 +572,8 @@ size_t RNNDescriptor::GetWorkspaceSize(Handle& handle, size_t RNNDescriptor::GetReserveSize(size_t batchLenSum) const { auto x = 2 * workspaceScale * nLayers * batchLenSum * hsize * typeSize; - if( (algoMode == miopenRNNdefault || algoMode == miopenRNNroundedDynamic) && rnnMode == miopenLSTM) + if((algoMode == miopenRNNdefault || algoMode == miopenRNNroundedDynamic) && + rnnMode == miopenLSTM) { x /= 2; x += nLayers * batchLenSum * hsize * typeSize; diff --git a/test/rnn_seq_api.cpp b/test/rnn_seq_api.cpp index 91c3e17ae2..c6300a5603 100644 --- a/test/rnn_seq_api.cpp +++ b/test/rnn_seq_api.cpp @@ -133,7 +133,8 @@ struct rnn_seq_driver : rnn_seq_api_test_driver void run() { - if(!this->full_set || (is_correct_params() && !is_skip_comb() && !is_dynamic_algo_skip_case())) + if(!this->full_set || + (is_correct_params() && !is_skip_comb() && !is_dynamic_algo_skip_case())) rnn_seq_api_test_driver::run(); else { From 69774498101d42929646e1590e42689553e7341a Mon Sep 17 
00:00:00 2001 From: Shurale-nkn Date: Thu, 28 Nov 2024 13:50:46 +0100 Subject: [PATCH 17/21] tidy --- test/rnn_seq_api.hpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/rnn_seq_api.hpp b/test/rnn_seq_api.hpp index 91eb3e772b..739d4c8b44 100644 --- a/test/rnn_seq_api.hpp +++ b/test/rnn_seq_api.hpp @@ -1088,8 +1088,10 @@ struct verify_train_rnn : verify_rnn_api_base nohy, nocy); if(skip_backward_data) + { return result_tuple( std::move(fwd_y), std::move(fwd_hy), std::move(fwd_cy), {}, {}, {}, {}); + } auto [bwd_din, bwd_dhx, bwd_dcx] = refMethod.bwd(input.desc, output.desc, @@ -1109,6 +1111,7 @@ struct verify_train_rnn : verify_rnn_api_base nocx); if(skip_backward_weights) + { return result_tuple(std::move(fwd_y), std::move(fwd_hy), std::move(fwd_cy), @@ -1116,6 +1119,7 @@ struct verify_train_rnn : verify_rnn_api_base std::move(bwd_dhx), std::move(bwd_dcx), {}); + } auto wrw_res = refMethod.wrw(input.desc, output.desc, From 6a8c777ccd0ee923a8f3934c138732c136d6e9a2 Mon Sep 17 00:00:00 2001 From: Shurale-nkn Date: Mon, 2 Dec 2024 18:44:10 +0100 Subject: [PATCH 18/21] cast --- src/include/miopen/rnn/tmp_buffer_utils.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/miopen/rnn/tmp_buffer_utils.hpp b/src/include/miopen/rnn/tmp_buffer_utils.hpp index b595ef27e0..04ab1720a6 100644 --- a/src/include/miopen/rnn/tmp_buffer_utils.hpp +++ b/src/include/miopen/rnn/tmp_buffer_utils.hpp @@ -66,7 +66,7 @@ OutputIt exclusive_scan_wa(InputIt first, InputIt last, OutputIt d_first, T init inline Data_t moveDataPtrByte(Data_t ptr, size_t byteOffset) { - return static_cast(reinterpret_cast(ptr) + byteOffset); + return static_cast(ptr) + byteOffset; } inline Data_t moveDataPtr(Data_t ptr, size_t elementOffset, miopenDataType_t elementType) From e1248134071991730310b51a58f20b514482455c Mon Sep 17 00:00:00 2001 From: Kamil Nasyrov Date: Wed, 4 Dec 2024 20:05:28 +0100 Subject: [PATCH 19/21] Update src/rnn/Solutions/bwd_s_stream.cpp Co-authored-by: 
Alex Eremin --- src/rnn/Solutions/bwd_s_stream.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rnn/Solutions/bwd_s_stream.cpp b/src/rnn/Solutions/bwd_s_stream.cpp index 8289879f74..fc931131dc 100644 --- a/src/rnn/Solutions/bwd_s_stream.cpp +++ b/src/rnn/Solutions/bwd_s_stream.cpp @@ -149,7 +149,7 @@ void RNNDynamicModularSingleStreamBWD::ComputeBWD(Handle& handle, rnnDesc.dirMode == miopenRNNDirectionMode_t::miopenRNNbidirection ? 2 : 1; const auto runtimeArgsExt = rnnAlgoModules.createRuntimeArgsExt(realRuntimeArgs); - const auto [real_dy, + const auto& [real_dy, temp_dy, dhy, dhx, From a29b9de3a3bcc3488c4ecd4f2cca297e4acbd961 Mon Sep 17 00:00:00 2001 From: Shurale-nkn Date: Thu, 5 Dec 2024 19:59:30 +0100 Subject: [PATCH 20/21] format --- src/rnn/Solutions/bwd_s_stream.cpp | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/rnn/Solutions/bwd_s_stream.cpp b/src/rnn/Solutions/bwd_s_stream.cpp index fc931131dc..b59b2a37d5 100644 --- a/src/rnn/Solutions/bwd_s_stream.cpp +++ b/src/rnn/Solutions/bwd_s_stream.cpp @@ -148,19 +148,19 @@ void RNNDynamicModularSingleStreamBWD::ComputeBWD(Handle& handle, auto sequence_directions = rnnDesc.dirMode == miopenRNNDirectionMode_t::miopenRNNbidirection ? 
2 : 1; - const auto runtimeArgsExt = rnnAlgoModules.createRuntimeArgsExt(realRuntimeArgs); + const auto runtimeArgsExt = rnnAlgoModules.createRuntimeArgsExt(realRuntimeArgs); const auto& [real_dy, - temp_dy, - dhy, - dhx, - cx, - dcy, - dcx, - real_dx, - temp_dx, - w, - workSpace, - reserveSpace] = runtimeArgsExt; + temp_dy, + dhy, + dhx, + cx, + dcy, + dcx, + real_dx, + temp_dx, + w, + workSpace, + reserveSpace] = runtimeArgsExt; rnnAlgoModules.PrepareWriteBuffers(handle, runtimeArgsExt); From 52d11c13fa64cc7493a20917ccace35bf44ab4f8 Mon Sep 17 00:00:00 2001 From: Shurale-nkn Date: Tue, 10 Dec 2024 22:12:35 +0100 Subject: [PATCH 21/21] update --- src/rnn.cpp | 42 ++++++++++++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/src/rnn.cpp b/src/rnn.cpp index cab19c0f86..e8bec2be62 100644 --- a/src/rnn.cpp +++ b/src/rnn.cpp @@ -34,6 +34,7 @@ #include #include "miopen/env.hpp" +#include "miopen/rnn/solvers.hpp" // Disable specific warnings #define MIO_RNN_DEBUG 0 @@ -475,15 +476,7 @@ size_t RNNDescriptor::GetWorkspaceSize(Handle& handle, transformer_tmp_space = RNNTransformerWorkspaceSize(xDesc, fwdMode); } - const std::size_t total_sequence_len = xDesc.GetTotalSequenceLen(); - - size_t reduction_ws = ReductionWorkspaceSize(handle, - total_sequence_len, - nHiddenTensorsPerLayer, - workspaceScale, - hsize, - dirMode == miopenRNNbidirection, - dataType); + std::size_t total_sequence_len = 0; size_t solution_ws = 0; @@ -491,12 +484,23 @@ size_t RNNDescriptor::GetWorkspaceSize(Handle& handle, { auto [ws, rs] = GetTmpSpaceSizeDynamicAlgo(handle, xDesc, miopenRNNTraining); solution_ws = ws; + auto lens = rnn_base::roundedDynamicLengths(xDesc); + + total_sequence_len = lens[0] * lens[1]; } else { + total_sequence_len = xDesc.GetTotalSequenceLen(); solution_ws = GetMainSolWorkspaceSize(total_sequence_len, fwdMode, miopenRNNDataSeqMajorNotPadded); } + size_t reduction_ws = ReductionWorkspaceSize(handle, + total_sequence_len, + 
nHiddenTensorsPerLayer, + workspaceScale, + hsize, + dirMode == miopenRNNbidirection, + dataType); return transformer_tmp_space + reduction_ws + solution_ws; } @@ -539,16 +543,30 @@ size_t RNNDescriptor::GetWorkspaceSize(Handle& handle, SeqTensorDescriptor xSeqTDesc = makeSeqTensorDescriptor(xDesc, seqLength, miopenRNNDataSeqMajorNotPadded); + std::size_t total_sequence_len = 0; + if(CheckDynamicAlgoSelection(handle, xSeqTDesc, miopenRNNTraining)) { auto [ws, rs] = GetTmpSpaceSizeDynamicAlgo(handle, xSeqTDesc, miopenRNNTraining); - return ws + padding_converter_tmp_space; + + auto lens = rnn_base::roundedDynamicLengths(xSeqTDesc); + + total_sequence_len = lens[0] * lens[1]; + + return ws + padding_converter_tmp_space + + ReductionWorkspaceSize(handle, + total_sequence_len, + nHiddenTensorsPerLayer, + workspaceScale, + hsize, + dirMode == miopenRNNbidirection, + dataType); + ; } else { - std::size_t total_sequence_len = 0; - total_sequence_len = std::accumulate( + total_sequence_len = std::accumulate( xDesc.data, xDesc.data + seqLength, 0ULL, [](size_t x, miopenTensorDescriptor_t y) { return x + deref(y).GetLengths()[0]; });