diff --git a/src/batch_norm.cpp b/src/batch_norm.cpp index 2c5486f307..55dbefd15a 100644 --- a/src/batch_norm.cpp +++ b/src/batch_norm.cpp @@ -44,7 +44,7 @@ void DeriveBNTensorDescriptor(TensorDescriptor& derivedBnDesc, { auto lengths = xDesc.GetLengths(); - std::vector newlens(lengths.size()); + std::vector newlens(lengths.size()); newlens[1] = lengths[1]; if(bn_mode == miopenBNSpatial) { diff --git a/src/include/miopen/pooling.hpp b/src/include/miopen/pooling.hpp index 0ab5ffa1c7..d51da9b5ce 100644 --- a/src/include/miopen/pooling.hpp +++ b/src/include/miopen/pooling.hpp @@ -131,7 +131,7 @@ struct MIOPEN_EXPORT PoolingDescriptor : miopenPoolingDescriptor std::tuple GetForwardOutputDim(const TensorDescriptor& xDesc) const; - void GetForwardOutputDimNd(const TensorDescriptor& xDesc, int dims, int* tensorDimArr) const; + void GetForwardOutputDimNd(const TensorDescriptor& xDesc, int dims, size_t* tensorDimArr) const; TensorDescriptor GetForwardOutputTensor(const TensorDescriptor& xDesc) const; diff --git a/src/include/miopen/rnn.hpp b/src/include/miopen/rnn.hpp index eba00e2a29..a5b8fac0b7 100644 --- a/src/include/miopen/rnn.hpp +++ b/src/include/miopen/rnn.hpp @@ -110,7 +110,7 @@ struct MIOPEN_INTERNALS_EXPORT RNNDescriptor : miopenRNNDescriptor size_t paramsOffsetCalculation(const TensorDescriptor& xDesc, int layer, int paramID) const; - std::vector + std::vector pTensorLengthsCalculation(const TensorDescriptor& xDesc, int layer, int paramID) const; static SeqTensorDescriptor makeSeqTensorDescriptor(miopenDataType_t t, @@ -538,7 +538,7 @@ struct MIOPEN_INTERNALS_EXPORT RNNDescriptor : miopenRNNDescriptor size_t reserveSpaceSize) const; void RNNForwardMS(Handle& handle, - std::vector& seq_array, + std::vector& seq_array, const TensorDescriptor& xDesc, ConstData_t x, const TensorDescriptor& hxDesc, diff --git a/src/include/miopen/rnn_util.hpp b/src/include/miopen/rnn_util.hpp index 92876b8a9a..1bb38dfb9b 100644 --- a/src/include/miopen/rnn_util.hpp +++ b/src/include/miopen/rnn_util.hpp @@ -171,7 +171,7 @@ struct RNNTensorPaddingConverter { static void ConvertTensorData(const Handle& handle, const TensorDescriptor& padded_tensor_desc, - std::vector& bsize_per_time, + std::vector& bsize_per_time, ConstData_t src, Data_t dst, bool is_src_padded); diff --git a/src/include/miopen/tensor.hpp b/src/include/miopen/tensor.hpp index 48a05a5a98..d4edc4047c 100644 --- a/src/include/miopen/tensor.hpp +++ b/src/include/miopen/tensor.hpp @@ -151,15 +151,10 @@ struct MIOPEN_INTERNALS_EXPORT TensorDescriptor : miopenTensorDescriptor // The delegation constructor should be placed above the target constructor in the // code for better dependency tracking - TensorDescriptor(miopenDataType_t t, const std::initializer_list& lens_in); - TensorDescriptor(miopenDataType_t t, const std::vector& lens_in); TensorDescriptor(miopenDataType_t t, const std::initializer_list& lens_in); TensorDescriptor(miopenDataType_t t, const std::vector& lens_in); TensorDescriptor(miopenDataType_t t, std::vector&& lens_in); - TensorDescriptor(miopenDataType_t t, - miopenTensorLayout_t layout_in, - const std::vector& lens_in); TensorDescriptor(miopenDataType_t t, miopenTensorLayout_t layout_in, const std::initializer_list& lens_in); @@ -170,9 +165,6 @@ struct MIOPEN_INTERNALS_EXPORT TensorDescriptor : miopenTensorDescriptor miopenTensorLayout_t layout_in, std::vector&& lens_in); - TensorDescriptor(miopenDataType_t t, - const std::vector& lens_in, - const std::vector& strides_in); TensorDescriptor(miopenDataType_t t, const std::initializer_list& lens_in, const std::initializer_list& strides_in); diff --git a/src/ocl/ctcocl.cpp b/src/ocl/ctcocl.cpp index 35001e2855..165b418d8c 100644 --- a/src/ocl/ctcocl.cpp +++ b/src/ocl/ctcocl.cpp @@ -193,7 +193,7 @@ void CTCLossDescriptor::CTCLoss(Handle& handle, float time = 0.; if(apply_softmax_layer) { - std::vector sfm_size(4, 1); + std::vector sfm_size(4, 1); sfm_size[0] = max_time_step * batch_size; sfm_size[1] = class_sz; auto sfm_desc = miopen::TensorDescriptor(probsDesc.GetType(), sfm_size); diff --git a/src/ocl/rnnocl.cpp b/src/ocl/rnnocl.cpp index 2700034f93..d9247caedc 100644 --- a/src/ocl/rnnocl.cpp +++ b/src/ocl/rnnocl.cpp @@ -174,7 +174,7 @@ miopenStatus_t ReducAddBias(miopen::Handle& handle, int lda = k, ldb = ws_desc.GetStrides()[1], ldc = n; const miopen::TensorDescriptor red_matrix{ - red_type, std::vector{1, 1, k}, std::vector{k, k, 1}}; + red_type, std::vector{1, 1, k}, std::vector{k, k, 1}}; SetTensor(handle, red_matrix, red_workSpace, &alpha1); @@ -254,7 +254,7 @@ miopenStatus_t ReducAddBias(miopen::Handle& handle, } // namespace void RNNDescriptor::RNNForwardMS(Handle& handle, - std::vector& seq_array, + std::vector& seq_array, const TensorDescriptor& xDesc, ConstData_t x, const TensorDescriptor& hxDesc, @@ -271,7 +271,7 @@ void RNNDescriptor::RNNForwardMS(Handle& handle, miopenRNNFWDMode_t fwd_mode) const { #if MIOPEN_USE_GEMM && MIOPEN_BACKEND_HIP - std::vector in_n; + std::vector in_n; int in_vec = xDesc.GetLengths()[1]; // input vector size int out_vec = yDesc.GetLengths()[1]; // output vector size @@ -289,7 +289,7 @@ void RNNDescriptor::RNNForwardMS(Handle& handle, ms_controller.ChangeActiveStream(root_stream_id); int total_batch_size = 0; - std::vector bacc_per_time(seq_len + 1); + std::vector bacc_per_time(seq_len + 1); for(int i = 0; i < seq_len; i++) { @@ -583,13 +583,13 @@ void RNNDescriptor::RNNForwardMS(Handle& handle, const auto bias_desc = miopen::TensorDescriptor(wDesc.GetType(), - std::vector{1, 1, WeiBuf.bias_vector_mul_gate()}, - std::vector{bias_stride, bias_stride, 1}); + std::vector{1, 1, WeiBuf.bias_vector_mul_gate()}, + std::vector{bias_stride, bias_stride, 1}); const auto hidden_interim_desc = miopen::TensorDescriptor( wDesc.GetType(), - std::vector{1, RBuff.batches, WeiBuf.bias_vector_mul_gate()}, - std::vector{ + std::vector{1, RBuff.batches, WeiBuf.bias_vector_mul_gate()}, + std::vector{ RBuff.batches * RBuff.gemm_write_stride(), RBuff.gemm_write_stride(), 1}); const auto RB_layer_out_off = RBuff.layer_offset(layer); @@ -1064,7 +1064,7 @@ void RNNDescriptor::RNNForwardMS(Handle& handle, } else { - std::vector layer_stream_id(nLayers, 2); + std::vector layer_stream_id(nLayers, 2); layer_stream_id[0] = 1; auto dispatch_next_chunk = [&layer_upd_cur_time, @@ -1244,7 +1244,7 @@ void RNNDescriptor::RNNForwardInference(Handle& handle, // RNNTensorPaddingConverter::CreatePackedDescriptor() // for future developments: as long as we don't use strides from xDesc and yDesc // we ignoring conversion of this descriptors. - std::vector in_n(seqLen); + std::vector in_n(seqLen); for(int i = 0; i < seqLen; i++) { @@ -1327,7 +1327,7 @@ void RNNDescriptor::RNNForwardInferencePacked(Handle& handle, // reset kernel timer profileRNNkernels(handle, 0, ctime); - std::vector in_n; + std::vector in_n; int in_h = xDesc[0].GetLengths()[1]; // input vector size int hy_d = hyDesc.GetLengths()[0]; // biNumLayers int hy_n = hyDesc.GetLengths()[1]; // max batch size @@ -1421,7 +1421,7 @@ void RNNDescriptor::RNNForwardInferencePacked(Handle& handle, float alpha0, alpha1, beta_t; float alpha = 1, beta = 0; - std::vector sp_size(3, 1), sp_stride(3, 1), w_size(3, 1), w_stride(3, 1), x_size(3, 1), + std::vector sp_size(3, 1), sp_stride(3, 1), w_size(3, 1), w_stride(3, 1), x_size(3, 1), x_stride(3, 1), y_size(3, 1), y_stride(3, 1), hx_size(3, 1), hx_stride(3, 1); miopen::TensorDescriptor sp_desc, w_desc, x_desc, y_desc, hx_desc; @@ -2635,7 +2635,7 @@ void RNNDescriptor::RNNForwardTraining(Handle& handle, // RNNTensorPaddingConverter::CreatePackedDescriptor() // for future developments: as long as we don't use strides from xDesc and yDesc // we ignoring conversion of this descriptors. - std::vector in_n(seqLen); + std::vector in_n(seqLen); for(int i = 0; i < seqLen; i++) { @@ -2749,7 +2749,7 @@ void RNNDescriptor::RNNForwardTrainingPackedTensors( } int batch_n = 0; - std::vector in_n; + std::vector in_n; for(int i = 0; i < seqLen; i++) { int batchval, batchvalout; @@ -2842,7 +2842,7 @@ void RNNDescriptor::RNNForwardTrainingPackedTensors( float alpha0, alpha1, beta_t; float alpha = 1, beta = 0; - std::vector sp_size(3, 1), sp_stride(3, 1), w_size(3, 1), w_stride(3, 1), x_size(3, 1), + std::vector sp_size(3, 1), sp_stride(3, 1), w_size(3, 1), w_stride(3, 1), x_size(3, 1), x_stride(3, 1), y_size(3, 1), y_stride(3, 1), hx_size(3, 1), hx_stride(3, 1); miopen::TensorDescriptor sp_desc, w_desc, x_desc, y_desc, hx_desc; @@ -2990,7 +2990,7 @@ void RNNDescriptor::RNNForwardTrainingPackedTensors( if(use_dropout) { - std::vector drop_size(2), drop_in_str(2, 1), drop_out_str(2, 1); + std::vector drop_size(2), drop_in_str(2, 1), drop_out_str(2, 1); drop_size[0] = batch_n; drop_size[1] = hy_h * bi; drop_in_str[0] = hy_stride; @@ -4139,7 +4139,7 @@ void RNNDescriptor::RNNBackwardData(Handle& handle, (packedDYSize + packedDXSize)); auto shifted_workSpace_size = workSpaceSize - (packedDYSize + packedDXSize); - std::vector in_n(seqLen); + std::vector in_n(seqLen); for(int i = 0; i < seqLen; i++) { @@ -4244,7 +4244,7 @@ void RNNDescriptor::RNNBackwardDataPackedTensors( auto rnn_data_type = dhxDesc.GetType(); - std::vector in_n; + std::vector in_n; int in_h = dxDesc[0].GetLengths()[1]; int hy_d = dhxDesc.GetLengths()[0]; int hy_n = dhxDesc.GetLengths()[1]; @@ -4345,7 +4345,7 @@ void RNNDescriptor::RNNBackwardDataPackedTensors( float alpha0, alpha1, beta_t; float alpha = 1, beta = 0; - std::vector sp_size(3, 1), sp_stride(3, 1), x_size(3, 1), x_stride(3, 1), y_size(3, 1), + std::vector sp_size(3, 1), sp_stride(3, 1), x_size(3, 1), x_stride(3, 1), y_size(3, 1), y_stride(3, 1), hx_size(3, 1), hx_stride(3, 1); miopen::TensorDescriptor sp_desc, x_desc, y_desc, hx_desc; @@ -4497,7 +4497,7 @@ void RNNDescriptor::RNNBackwardDataPackedTensors( if(use_dropout) { - std::vector drop_size(2), drop_in_str(2, 1); + std::vector drop_size(2), drop_in_str(2, 1); drop_size[0] = batch_n; drop_size[1] = hy_h * bi; drop_in_str[0] = hy_stride; @@ -5685,7 +5685,7 @@ void RNNDescriptor::RNNBackwardDataPackedTensors( // dinput if(inputMode == miopenRNNskip) { - const std::vector dx_size{1, batch_n, hy_h}; + const std::vector dx_size{1, batch_n, hy_h}; x_desc = miopen::TensorDescriptor(rnn_data_type, dx_size, x_stride); sp_desc = miopen::TensorDescriptor(rnn_data_type, dx_size, sp_stride); @@ -5828,7 +5828,7 @@ void RNNDescriptor::RNNBackwardWeights(Handle& handle, (packedXSize + WA_workSpace_bug)); auto shifted_workSpace_size = workSpaceSize - (packedXSize + WA_workSpace_bug); - std::vector in_n(seqLen); + std::vector in_n(seqLen); for(int i = 0; i < seqLen; i++) { @@ -5917,7 +5917,7 @@ void RNNDescriptor::RNNBackwardWeightsPackedTensors( } std::string network_config; - std::vector in_n; + std::vector in_n; int in_h = xDesc[0].GetLengths()[1]; int hy_d = hxDesc.GetLengths()[0]; int hy_n = hxDesc.GetLengths()[1]; @@ -6012,7 +6012,7 @@ void RNNDescriptor::RNNBackwardWeightsPackedTensors( float alpha0, alpha1, beta_t = 0; - std::vector sp_size(3, 1), sp_stride(3, 1), w_size(3, 1), w_stride(3, 1); + std::vector sp_size(3, 1), sp_stride(3, 1), w_size(3, 1), w_stride(3, 1); miopen::TensorDescriptor sp_desc, w_desc; sp_stride[0] = batch_n * hy_stride; @@ -6233,7 +6233,7 @@ void RNNDescriptor::RNNBackwardWeightsPackedTensors( else { // second dw bias equal to the first, so just copy reduction result - const std::vector dw_bias_strides{wei_stride, wei_stride, 1}; + const std::vector dw_bias_strides{wei_stride, wei_stride, 1}; const miopen::TensorDescriptor dw_desc{ rnn_data_t, {1, 1, wei_stride}, dw_bias_strides}; diff --git a/src/pooling.cpp b/src/pooling.cpp index a65cb3c0ab..2f10086b15 100644 --- a/src/pooling.cpp +++ b/src/pooling.cpp @@ -157,11 +157,11 @@ PoolingDescriptor::GetForwardOutputDim(const TensorDescriptor& xDesc) const void PoolingDescriptor::GetForwardOutputDimNd(const TensorDescriptor& xDesc, int dims, - int* tensorDimArr) const + size_t* tensorDimArr) const { assert(xDesc.GetLengths().size() == dims && xDesc.GetLengths().size() <= 5 && xDesc.GetLengths().size() >= 4); // currently only support 2D/3D pooling - std::vector out_dim; + std::vector out_dim; auto input_dim = xDesc.GetLengths(); auto strs = GetStrides(); auto padd = GetPads(); @@ -175,8 +175,8 @@ void PoolingDescriptor::GetForwardOutputDimNd(const TensorDescriptor& xDesc, assert(std::all_of(padd.begin(), padd.end(), [](int s) { return s >= 0; })); auto in_itr = input_dim.begin(); - out_dim.push_back(int(*(in_itr++))); // n - out_dim.push_back(int(*(in_itr++))); // c + out_dim.push_back(*(in_itr++)); // n + out_dim.push_back(*(in_itr++)); // c auto str_itr = strs.begin(); auto pad_itr = padd.begin(); @@ -215,12 +215,12 @@ void PoolingDescriptor::GetForwardOutputDimNd(const TensorDescriptor& xDesc, TensorDescriptor PoolingDescriptor::GetForwardOutputTensor(const TensorDescriptor& xDesc) const { - std::vector out_dim(xDesc.GetNumDims()); + std::vector out_dim(xDesc.GetNumDims()); GetForwardOutputDimNd(xDesc, xDesc.GetNumDims(), out_dim.data()); const std::string default_layout = tensor_layout_get_default(xDesc.GetNumDims()); const std::string in_layout = xDesc.GetLayout(default_layout); - std::vector out_strides; + std::vector out_strides; tensor_layout_to_strides(out_dim, default_layout, in_layout, out_strides); return {xDesc.GetType(), out_dim, out_strides}; diff --git a/src/pooling_api.cpp b/src/pooling_api.cpp index 8c475b57d5..9fed494dfe 100644 --- a/src/pooling_api.cpp +++ b/src/pooling_api.cpp @@ -267,8 +267,9 @@ miopenGetPoolingNdForwardOutputDim(const miopenPoolingDescriptor_t poolDesc, MIOPEN_LOG_FUNCTION(poolDesc, tensorDesc, dims); return miopen::try_([&] { - miopen::deref(poolDesc).GetForwardOutputDimNd( - miopen::deref(tensorDesc), dims, tensorDimArr); + std::vector tmp(dims); + miopen::deref(poolDesc).GetForwardOutputDimNd(miopen::deref(tensorDesc), dims, tmp.data()); + std::copy_n(tmp.data(), dims, tensorDimArr); }); } diff --git a/src/rnn.cpp b/src/rnn.cpp index 5d23c79018..f46c65c269 100644 --- a/src/rnn.cpp +++ b/src/rnn.cpp @@ -213,7 +213,7 @@ size_t RNNDescriptor::paramsOffsetCalculation(const TensorDescriptor& xDesc, return layerJump; } -std::vector RNNDescriptor::pTensorLengthsCalculation(const TensorDescriptor& xDesc, +std::vector RNNDescriptor::pTensorLengthsCalculation(const TensorDescriptor& xDesc, const int layer, const int paramID) const { @@ -223,7 +223,7 @@ std::vector RNNDescriptor::pTensorLengthsCalculation(const TensorDescriptor inputVectorLen = 0; } - std::vector tdim(2, 0); + std::vector tdim(2, 0); if(dirMode != 0u) { @@ -686,7 +686,7 @@ void RNNDescriptor::GetParamsDescriptor(Handle& /* handle */, // Create weight super tensor descriptor int bi = (dirMode == miopenRNNbidirection) ? 2 : 1; - std::vector weight_lens(2, 0); + std::vector weight_lens(2, 0); weight_lens[0] = inputVectorLen + ((nLayers - 1) * (bi + 1) + 1) * hsize; weight_lens[1] = bi * hsize * nHiddenTensorsPerLayer; if(biasMode == miopenRNNwithBias) @@ -842,11 +842,11 @@ void RNNDescriptor::SetLayerParam(const Handle& handle, auto poffset = paramsOffsetCalculation(xDesc, layer, paramID); // 2. Calculate the strides for the matrix - std::vector pstride(2, 1); + std::vector pstride(2, 1); pstride[1] = paramDesc.GetLengths()[0]; - std::vector intLens(paramDesc.GetLengths().begin(), paramDesc.GetLengths().end()); + std::vector intLens(paramDesc.GetLengths().begin(), paramDesc.GetLengths().end()); // 3. Construct descriptor to access into w auto paramSrc = miopen::TensorDescriptor(dataType, intLens, pstride); @@ -895,9 +895,9 @@ void RNNDescriptor::SetLayerBias(const Handle& handle, auto boffset = biasOffsetCalculation(xDesc, layer, biasID) + poffset; // 2. Calculate the strides for the matrix - std::vector bstride(1, 1); + std::vector bstride(1, 1); - std::vector intLens(biasDesc.GetLengths().begin(), biasDesc.GetLengths().end()); + std::vector intLens(biasDesc.GetLengths().begin(), biasDesc.GetLengths().end()); // 3. Construct descriptor to access into w auto biasSrc = miopen::TensorDescriptor(dataType, intLens, bstride); @@ -1064,7 +1064,7 @@ SeqTensorDescriptor RNNDescriptor::makeSeqTensorDescriptor(miopenDataType_t t, const int* lensPerSeq, const void* padding_marker_ptr) { - const std::vector lens = {batchSize, maxSeqLength, vectorSize}; + const std::vector lens = {batchSize, maxSeqLength, vectorSize}; const auto [dim_order, padded_sequences] = convertRNNBaseLayout(layout); @@ -1080,7 +1080,7 @@ SeqTensorDescriptor RNNDescriptor::makeSeqTensorDescriptor(miopenDataType_t t, return {t, dim_order, lens, - std::vector(lensPerSeq, lensPerSeq + batchSize), + std::vector(lensPerSeq, lensPerSeq + batchSize), padding_marker_in, true, padded_sequences}; diff --git a/src/rnn/rnn_util.cpp b/src/rnn/rnn_util.cpp index 0c1b0f7cf9..09fce4e45e 100644 --- a/src/rnn/rnn_util.cpp +++ b/src/rnn/rnn_util.cpp @@ -37,7 +37,7 @@ int getReductionAlgo() { return env::value_or(MIOPEN_RNNWRW_REDUCTION, 1); } void RNNTensorPaddingConverter::ConvertTensorData(const Handle& handle, const TensorDescriptor& padded_tensor_desc, - std::vector& bsize_per_time, + std::vector& bsize_per_time, ConstData_t src, Data_t dst, bool is_src_padded) diff --git a/src/tensor.cpp b/src/tensor.cpp index 2d1b48faf9..e988087ace 100644 --- a/src/tensor.cpp +++ b/src/tensor.cpp @@ -88,28 +88,20 @@ std::optional GetDefaultLayout(unsigned num_dims) } template -bool CheckLengths(const std::vector& lens, T maxval = 0) +bool CheckLengthsValues(const std::vector& lens) { - if(lens.empty()) - return false; - if(!std::all_of(lens.cbegin(), lens.cend(), [](T x) { return x > 0; })) - return false; - if(maxval) - { - if(!std::all_of(lens.cbegin(), lens.cend(), [maxval](T x) { return x <= maxval; })) - return false; - } - return true; + return std::all_of(lens.cbegin(), lens.cend(), [](T x) { + return x > 0 && + static_cast(x) <= static_cast(std::numeric_limits::max()); + }); } -std::vector ConvertLengthsOrThrow(const std::vector& lens_in, - [[maybe_unused]] const std::string& err_msg) +std::vector ConvertLengthsOrThrow(const std::vector& lens) { - if(!CheckLengths(lens_in)) - MIOPEN_THROW(miopenStatusBadParm, err_msg); + if(!CheckLengthsValues(lens)) + MIOPEN_THROW(miopenStatusBadParm, std::string{"Length/Stride values must be > 0"}); - std::vector lens(lens_in.cbegin(), lens_in.cend()); - return lens; + return std::vector(lens.cbegin(), lens.cend()); } std::string GetStorageLayout4D5D(unsigned num_dims, bool is_CHWNc = false) @@ -253,20 +245,6 @@ TensorDescriptor::TensorDescriptor(miopenDataType_t t) : packed(true), type(t) { // The delegation constructor should be placed above the target constructor in the // code for better dependency tracking -TensorDescriptor::TensorDescriptor(miopenDataType_t t, const std::initializer_list& lens_in) - : TensorDescriptor(t, std::vector(lens_in)) -{ -} - -TensorDescriptor::TensorDescriptor(miopenDataType_t t, const std::vector& lens_in) - : TensorDescriptor(t, - GetDefaultLayout(lens_in.size()), - ConvertLengthsOrThrow(lens_in, "Lengths must be > 0"), - {}, - false) -{ -} - TensorDescriptor::TensorDescriptor(miopenDataType_t t, const std::initializer_list& lens_in) : TensorDescriptor(t, std::vector(lens_in)) @@ -283,13 +261,6 @@ TensorDescriptor::TensorDescriptor(miopenDataType_t t, std::vector& { } -TensorDescriptor::TensorDescriptor(miopenDataType_t t, - miopenTensorLayout_t layout_in, - const std::vector& lens_in) - : TensorDescriptor(t, layout_in, ConvertLengthsOrThrow(lens_in, "Lengths must be > 0")) -{ -} - TensorDescriptor::TensorDescriptor(miopenDataType_t t, miopenTensorLayout_t layout_in, const std::initializer_list& lens_in) @@ -311,15 +282,6 @@ TensorDescriptor::TensorDescriptor(miopenDataType_t t, { } -TensorDescriptor::TensorDescriptor(miopenDataType_t t, - const std::vector& lens_in, - const std::vector& strides_in) - : TensorDescriptor(t, - ConvertLengthsOrThrow(lens_in, "Lengths must be > 0"), - ConvertLengthsOrThrow(strides_in, "Strides must be > 0")) -{ -} - TensorDescriptor::TensorDescriptor(miopenDataType_t t, const std::initializer_list& lens_in, const std::initializer_list& strides_in) @@ -395,7 +357,7 @@ void TensorDescriptor::CheckArgsAndInit(bool use_strides) if(tensorLayout && !IsLayoutSupported(tensorLayout.value(), lens.size())) MIOPEN_THROW(miopenStatusBadParm, "Unsupported layout"); - if(!CheckLengths(lens, static_cast(std::numeric_limits::max()))) + if(!CheckLengthsValues(lens)) MIOPEN_THROW(miopenStatusBadParm, "Lengths must be > 0 and <= INT64_MAX"); vector_length = GetVectorLengthForLayout(tensorLayout); @@ -405,7 +367,7 @@ void TensorDescriptor::CheckArgsAndInit(bool use_strides) if(lens.size() != strides.size()) MIOPEN_THROW(miopenStatusBadParm, "Lengths and strides dimensions must be equal"); - if(!CheckLengths(strides, static_cast(std::numeric_limits::max()))) + if(!CheckLengthsValues(strides)) MIOPEN_THROW(miopenStatusBadParm, "Strides must be > 0 and <= INT64_MAX"); packed = (this->GetElementSize() == this->GetElementSpace()); @@ -438,7 +400,7 @@ TensorDescriptor TensorDescriptor::MakeDescriptor(miopenDataType_t t, const int* if(plens == nullptr || size <= 0) MIOPEN_THROW(miopenStatusInvalidValue); - return {t, std::vector(plens, plens + size)}; + return {t, ConvertLengthsOrThrow(std::vector(plens, plens + size))}; } TensorDescriptor @@ -458,7 +420,7 @@ TensorDescriptor TensorDescriptor::MakeDescriptor(miopenDataType_t t, if(plens == nullptr || size <= 0) MIOPEN_THROW(miopenStatusInvalidValue); - return {t, layout, std::vector(plens, plens + size)}; + return {t, layout, ConvertLengthsOrThrow(std::vector(plens, plens + size))}; } TensorDescriptor TensorDescriptor::MakeDescriptor(miopenDataType_t t, @@ -480,7 +442,9 @@ TensorDescriptor TensorDescriptor::MakeDescriptor(miopenDataType_t t, if(plens == nullptr || pstrides == nullptr || size <= 0) MIOPEN_THROW(miopenStatusInvalidValue); - return {t, std::vector(plens, plens + size), std::vector(pstrides, pstrides + size)}; + return {t, + ConvertLengthsOrThrow(std::vector(plens, plens + size)), + ConvertLengthsOrThrow(std::vector(pstrides, pstrides + size))}; } TensorDescriptor TensorDescriptor::MakeDescriptor(miopenDataType_t t, diff --git a/src/tensor_api.cpp b/src/tensor_api.cpp index 2f3edda564..ff1b014d9d 100644 --- a/src/tensor_api.cpp +++ b/src/tensor_api.cpp @@ -56,8 +56,9 @@ extern "C" miopenStatus_t miopenSet4dTensorDescriptor( MIOPEN_LOG_FUNCTION(tensorDesc, dataType, n, c, h, w); return miopen::try_([&] { - std::initializer_list lens = {n, c, h, w}; - miopen::deref(tensorDesc) = miopen::TensorDescriptor(dataType, lens); + const std::array lens = {n, c, h, w}; + miopen::deref(tensorDesc) = + miopen::TensorDescriptor::MakeDescriptor(dataType, lens.data(), lens.size()); }); } @@ -92,9 +93,11 @@ extern "C" miopenStatus_t miopenSet4dTensorDescriptorEx(miopenTensorDescriptor_t { MIOPEN_LOG_FUNCTION(tensorDesc, dataType, n, c, h, w, nStride, cStride, hStride, wStride); return miopen::try_([&] { - std::initializer_list lens = {n, c, h, w}; - std::initializer_list strides = {nStride, cStride, hStride, wStride}; - miopen::deref(tensorDesc) = miopen::TensorDescriptor(dataType, lens, strides); + static constexpr int size = 4; + const std::array lens = {n, c, h, w}; + const std::array strides = {nStride, cStride, hStride, wStride}; + miopen::deref(tensorDesc) = + miopen::TensorDescriptor::MakeDescriptor(dataType, lens.data(), strides.data(), size); }); } @@ -342,15 +345,15 @@ static void LogCmdTensorOp(miopenTensorOp_t tensorOp, if(!is_set && !is_scale) { // clang-format off - ss << " -A " << std::to_string(*static_cast(alpha)) - << " -B " << std::to_string(*static_cast(alpha2)) + ss << " -A " << std::to_string(*static_cast(alpha)) + << " -B " << std::to_string(*static_cast(alpha2)) << " -G " << std::to_string(*static_cast(beta)); // clang-format on } // clang-format off - ss << " -n " << miopen::deref(aDesc).GetLengths()[0] + ss << " -n " << miopen::deref(aDesc).GetLengths()[0] << " -c " << miopen::deref(aDesc).GetLengths()[1] - << " -H " << miopen::deref(aDesc).GetLengths()[2] + << " -H " << miopen::deref(aDesc).GetLengths()[2] << " -W " << miopen::deref(aDesc).GetLengths()[3]; // clag-format on if(is_set)