Enabled the flexible split_k support in ck-based 3D convolution backward weight solver. (#3425)

* move ck op define to header

* add split_k value to 3d wrw solver

* deprecate f16f8f16 conv solvers

* add the split_k condition for the 3d bilinear and scale kernels

* move the split_k condition check into a common function

* clang format fix

* guard the new function with the use-ck-kernel macro

* minor fix

* add ++id in the places where solvers were removed

* minor fix

* mi300 db updates for 3d + ck

* Fix failing tests by removing invalid entries from existing DBs

---------

Co-authored-by: Christopher Erb <[email protected]>
Co-authored-by: Brian Harrison <[email protected]>
3 people authored Dec 10, 2024
1 parent a325bdb commit 75078dc
Showing 20 changed files with 159 additions and 2,006 deletions.
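For orientation: split-K tuning partitions the reduction (K) dimension of the backward-weight GEMM across workgroups, so each workgroup accumulates a partial weight gradient and the partials are reduced afterwards. The sketch below illustrates only the idea on a plain dot product; it is not MIOpen or CK code, and every name in it is hypothetical.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical host-side illustration of split-K: split a length-K reduction
// into `split_k` independent chunks (one per workgroup in a real kernel),
// then reduce the partial sums in a second pass. Assumes split_k >= 1.
float dot_split_k(const std::vector<float>& a, const std::vector<float>& b, std::size_t split_k)
{
    const std::size_t k     = a.size();
    const std::size_t chunk = (k + split_k - 1) / split_k;

    std::vector<float> partials(split_k, 0.0f);
    for(std::size_t s = 0; s < split_k; ++s) // each s maps to one workgroup
    {
        const std::size_t begin = s * chunk;
        const std::size_t end   = std::min(k, begin + chunk);
        for(std::size_t i = begin; i < end; ++i)
            partials[s] += a[i] * b[i];
    }

    float result = 0.0f; // second-pass reduction over the partials
    for(const float p : partials)
        result += p;
    return result;
}
```

Larger split_k values expose more parallelism when the output (weight) tensor is small but K is large — the typical WrW shape — at the cost of the extra reduction pass.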
3 changes: 0 additions & 3 deletions src/CMakeLists.txt
@@ -266,9 +266,6 @@ set( MIOpen_Source
solver/conv/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp
solver/conv/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp
solver/conv/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp
solver/conv/conv_hip_implicit_gemm_f16f8f16_fwd_xdlops.cpp
solver/conv/conv_hip_implicit_gemm_f16f8f16_bwd_xdlops.cpp
solver/conv/conv_hip_implicit_gemm_f16f8f16_wrw_xdlops.cpp
solver/conv/conv_hip_implicit_gemm_nonxdlops_common.cpp
solver/conv/conv_hip_implicit_gemm_wrw_v4r4.cpp
solver/conv/conv_hip_implicit_gemm_wrw_v4r4_xdlops.cpp
220 changes: 1 addition & 219 deletions src/include/miopen/conv/solvers.hpp
@@ -4576,6 +4576,7 @@ struct PerformanceConfigHipImplicitGemm3DGroupWrwXdlops
: PerfConfigBaseCK<PerformanceConfigHipImplicitGemm3DGroupWrwXdlops>
{
int index;
int split_k;
std::string kernel_id;
std::vector<std::string> valid_kernels;
PerformanceConfigHipImplicitGemm3DGroupWrwXdlops(int idx, std::string kernl_id)
@@ -4911,225 +4912,6 @@ struct ConvHipImplicitGemmGroupWrwXdlops final
bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const;
};

struct PerformanceConfigHipImplicitGemmF16F8F16FwdXdlops
: PerfConfigBaseCK<PerformanceConfigHipImplicitGemmF16F8F16FwdXdlops>
{
int index = 0;
std::string kernel_id = "";
std::vector<std::string> valid_kernels;

PerformanceConfigHipImplicitGemmF16F8F16FwdXdlops(int idx, std::string kernl_id)
: index(idx), kernel_id(kernl_id)
{
}

PerformanceConfigHipImplicitGemmF16F8F16FwdXdlops() = default;

explicit PerformanceConfigHipImplicitGemmF16F8F16FwdXdlops(bool)
: PerformanceConfigHipImplicitGemmF16F8F16FwdXdlops(0, "")
{
}
MIOPEN_INTERNALS_EXPORT void HeuristicInit(const miopen::conv::ProblemDescription&);
MIOPEN_INTERNALS_EXPORT bool SetNextValue(const miopen::conv::ProblemDescription&);
MIOPEN_INTERNALS_EXPORT bool IsValidValue() const;
bool IsValid(const ExecutionContext&, const miopen::conv::ProblemDescription& problem) const
{
return IsValid(problem);
}
MIOPEN_INTERNALS_EXPORT bool IsValid(const miopen::conv::ProblemDescription&) const;
MIOPEN_INTERNALS_EXPORT bool
operator==(const PerformanceConfigHipImplicitGemmF16F8F16FwdXdlops& other) const;

private:
template <typename DataType, typename ComputeType>
void Init(const miopen::conv::ProblemDescription&);
template <typename DataType, typename ComputeType>
bool CheckIsSupportCKArgs(const miopen::conv::ProblemDescription&) const;
};

struct ConvHipImplicitGemmF16F8F16FwdXdlops final
: ConvTunableSolver<PerformanceConfigHipImplicitGemmF16F8F16FwdXdlops>
{
const std::string& SolverDbId() const override
{
return GetSolverDbId<ConvHipImplicitGemmF16F8F16FwdXdlops>();
}

MIOPEN_INTERNALS_EXPORT PerformanceConfigHipImplicitGemmF16F8F16FwdXdlops
GetDefaultPerformanceConfig(const ExecutionContext&,
const miopen::conv::ProblemDescription&) const override;
MIOPEN_INTERNALS_EXPORT bool IsValidPerformanceConfig(
const ExecutionContext&,
const miopen::conv::ProblemDescription&,
const PerformanceConfigHipImplicitGemmF16F8F16FwdXdlops&) const override;
MIOPEN_INTERNALS_EXPORT PerformanceConfigHipImplicitGemmF16F8F16FwdXdlops
Search(const ExecutionContext&,
const miopen::conv::ProblemDescription&,
const AnyInvokeParams& invoke_ctx) const override;
MIOPEN_INTERNALS_EXPORT bool
IsApplicable(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override;
bool IsDynamic() const override { return true; }
MIOPEN_INTERNALS_EXPORT ConvSolution
GetSolution(const ExecutionContext&,
const miopen::conv::ProblemDescription&,
const PerformanceConfigHipImplicitGemmF16F8F16FwdXdlops&) const override;
/// \ref igemm_get_wti_magic_number
float GetWti(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override
{
return 0.02f;
};

private:
template <typename DataType, typename ComputeType>
bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const;
};

struct PerformanceConfigHipImplicitGemmF16F8F16BwdXdlops
: PerfConfigBaseCK<PerformanceConfigHipImplicitGemmF16F8F16BwdXdlops>
{
int index;
std::string kernel_id;
std::vector<std::string> valid_kernels;
PerformanceConfigHipImplicitGemmF16F8F16BwdXdlops(int idx, std::string kernl_id)
: index(idx), kernel_id(kernl_id)
{
}
PerformanceConfigHipImplicitGemmF16F8F16BwdXdlops()
: PerformanceConfigHipImplicitGemmF16F8F16BwdXdlops(0, "")
{
}
PerformanceConfigHipImplicitGemmF16F8F16BwdXdlops(bool)
: PerformanceConfigHipImplicitGemmF16F8F16BwdXdlops(0, "")
{
}
MIOPEN_INTERNALS_EXPORT void HeuristicInit(const miopen::conv::ProblemDescription&);
MIOPEN_INTERNALS_EXPORT bool SetNextValue(const miopen::conv::ProblemDescription&);
MIOPEN_INTERNALS_EXPORT bool IsValidValue() const;
bool IsValid(const ExecutionContext&, const miopen::conv::ProblemDescription& problem) const
{
return IsValid(problem);
}
MIOPEN_INTERNALS_EXPORT bool IsValid(const miopen::conv::ProblemDescription&) const;
MIOPEN_INTERNALS_EXPORT bool
operator==(const PerformanceConfigHipImplicitGemmF16F8F16BwdXdlops& other) const;

private:
template <typename DataType, typename OutComputeType, typename WeiComputeType>
void Init(const miopen::conv::ProblemDescription&);
template <typename DataType, typename OutComputeType, typename WeiComputeType>
bool CheckIsSupportCKArgs(const miopen::conv::ProblemDescription&) const;
};

struct ConvHipImplicitGemmF16F8F16BwdXdlops final
: ConvTunableSolver<PerformanceConfigHipImplicitGemmF16F8F16BwdXdlops>
{
const std::string& SolverDbId() const override
{
return GetSolverDbId<ConvHipImplicitGemmF16F8F16BwdXdlops>();
}

MIOPEN_INTERNALS_EXPORT PerformanceConfigHipImplicitGemmF16F8F16BwdXdlops
GetDefaultPerformanceConfig(const ExecutionContext&,
const miopen::conv::ProblemDescription&) const override;
MIOPEN_INTERNALS_EXPORT bool IsValidPerformanceConfig(
const ExecutionContext&,
const miopen::conv::ProblemDescription&,
const PerformanceConfigHipImplicitGemmF16F8F16BwdXdlops&) const override;
MIOPEN_INTERNALS_EXPORT PerformanceConfigHipImplicitGemmF16F8F16BwdXdlops
Search(const ExecutionContext&,
const miopen::conv::ProblemDescription&,
const AnyInvokeParams& invoke_ctx) const override;
MIOPEN_INTERNALS_EXPORT bool
IsApplicable(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override;
bool IsDynamic() const override { return true; }
MIOPEN_INTERNALS_EXPORT ConvSolution
GetSolution(const ExecutionContext&,
const miopen::conv::ProblemDescription&,
const PerformanceConfigHipImplicitGemmF16F8F16BwdXdlops&) const override;
/// \ref igemm_get_wti_magic_number
float GetWti(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override
{
return 0.02f;
};

private:
template <typename DataType, typename OutComputeType, typename WeiComputeType>
bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const;
};

struct PerformanceConfigHipImplicitGemmF16F8F16WrwXdlops
: PerfConfigBaseCK<PerformanceConfigHipImplicitGemmF16F8F16WrwXdlops>
{
int index;
std::string kernel_id;
std::vector<std::string> valid_kernels;
PerformanceConfigHipImplicitGemmF16F8F16WrwXdlops(int idx, std::string kernl_id)
: index(idx), kernel_id(kernl_id)
{
}
PerformanceConfigHipImplicitGemmF16F8F16WrwXdlops()
: PerformanceConfigHipImplicitGemmF16F8F16WrwXdlops(0, "")
{
}
PerformanceConfigHipImplicitGemmF16F8F16WrwXdlops(bool)
: PerformanceConfigHipImplicitGemmF16F8F16WrwXdlops(0, "")
{
}
MIOPEN_INTERNALS_EXPORT void HeuristicInit(const miopen::conv::ProblemDescription&);
MIOPEN_INTERNALS_EXPORT bool SetNextValue(const miopen::conv::ProblemDescription&);
MIOPEN_INTERNALS_EXPORT bool IsValidValue() const;
bool IsValid(const ExecutionContext&, const miopen::conv::ProblemDescription& problem) const
{
return IsValid(problem);
}
MIOPEN_INTERNALS_EXPORT bool IsValid(const miopen::conv::ProblemDescription&) const;
MIOPEN_INTERNALS_EXPORT bool
operator==(const PerformanceConfigHipImplicitGemmF16F8F16WrwXdlops& other) const;

private:
template <typename DataType, typename OutComputeType, typename InComputeType>
void Init(const miopen::conv::ProblemDescription&);
template <typename DataType, typename OutComputeType, typename InComputeType>
bool CheckIsSupportCKArgs(const miopen::conv::ProblemDescription&) const;
};

struct ConvHipImplicitGemmF16F8F16WrwXdlops final
: ConvTunableSolver<PerformanceConfigHipImplicitGemmF16F8F16WrwXdlops>
{
const std::string& SolverDbId() const override
{
return GetSolverDbId<ConvHipImplicitGemmF16F8F16WrwXdlops>();
}

MIOPEN_INTERNALS_EXPORT PerformanceConfigHipImplicitGemmF16F8F16WrwXdlops
GetDefaultPerformanceConfig(const ExecutionContext&,
const miopen::conv::ProblemDescription&) const override;
MIOPEN_INTERNALS_EXPORT bool IsValidPerformanceConfig(
const ExecutionContext&,
const miopen::conv::ProblemDescription&,
const PerformanceConfigHipImplicitGemmF16F8F16WrwXdlops&) const override;
MIOPEN_INTERNALS_EXPORT PerformanceConfigHipImplicitGemmF16F8F16WrwXdlops
Search(const ExecutionContext&,
const miopen::conv::ProblemDescription&,
const AnyInvokeParams& invoke_ctx) const override;
MIOPEN_INTERNALS_EXPORT bool
IsApplicable(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override;
bool IsDynamic() const override { return true; }
MIOPEN_INTERNALS_EXPORT ConvSolution
GetSolution(const ExecutionContext&,
const miopen::conv::ProblemDescription&,
const PerformanceConfigHipImplicitGemmF16F8F16WrwXdlops&) const override;
/// \ref igemm_get_wti_magic_number
float GetWti(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override
{
return 0.02f;
};

private:
template <typename DataType, typename OutComputeType, typename InComputeType>
bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const;
};

} // namespace conv
} // namespace solver
} // namespace miopen
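The hunk above adds a split_k field next to kernel_id in the 3D group WrW performance config. In the implicitgemm_ck_util.hpp changes below, such kernel ids carry the chosen split_k after a trailing '+' (see the find_last_of('+') check). A minimal decoding sketch under that assumption — the helper name and error handling are mine; only the '+' convention comes from the diff:

```cpp
#include <exception>
#include <optional>
#include <string>
#include <utility>

// Hypothetical helper: split "kernel_name+4" into ("kernel_name", 4).
// Returns nullopt when there is no '+' suffix or the suffix is not a number.
inline std::optional<std::pair<std::string, int>> DecodeSplitK(const std::string& kernel_id)
{
    const auto pos = kernel_id.find_last_of('+');
    if(pos == std::string::npos)
        return std::nullopt;
    try
    {
        return std::make_pair(kernel_id.substr(0, pos), std::stoi(kernel_id.substr(pos + 1)));
    }
    catch(const std::exception&) // not a number, or out of range
    {
        return std::nullopt;
    }
}
```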
107 changes: 94 additions & 13 deletions src/include/miopen/solver/implicitgemm_ck_util.hpp
@@ -36,6 +36,8 @@
#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
#include <ck/utility/data_type.hpp>
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp>
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp>
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp>
#endif // MIOPEN_USE_COMPOSABLEKERNEL

namespace miopen {
@@ -62,6 +64,72 @@ using DeviceOpGWrw = ck::tensor_operation::device::DeviceGroupedConvBwdWeight<
template <typename DataType>
using DeviceOpGWrwPtrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<DeviceOpGWrw<DataType>>;

using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using Scale = ck::tensor_operation::element_wise::Scale;

template <typename DataType>
using DeviceOpGBwdWeightDefault =
ck::tensor_operation::device::DeviceGroupedConvBwdWeight<3,
InLayout,
WeiLayout,
OutLayout,
DataType,
DataType,
DataType,
PassThrough,
PassThrough,
PassThrough>;

template <typename DataType>
using DeviceOpGBwdWeightBilinear =
ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD<3,
InLayout,
WeiLayout,
OutLayout,
ck::Tuple<WeiLayout>,
DataType,
DataType,
DataType,
ck::Tuple<DataType>,
PassThrough,
Bilinear,
PassThrough>;

template <typename DataType>
using DeviceOpGBwdWeightScale =
ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD<3,
InLayout,
WeiLayout,
OutLayout,
ck::Tuple<>,
DataType,
DataType,
DataType,
ck::Tuple<>,
PassThrough,
Scale,
PassThrough>;

template <typename DataType>
using DeviceOpGBwdWeightDefaultPtrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOpGBwdWeightDefault<DataType>>;

template <typename DataType>
using DeviceOpGBwdWeightBilinearPtrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOpGBwdWeightBilinear<DataType>>;

template <typename DataType>
using DeviceOpGBwdWeightScalePtrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOpGBwdWeightScale<DataType>>;

} // namespace conv
#endif

@@ -125,6 +193,29 @@ std::vector<std::string> FillValidKernelsIDs(const ProblemDescriptionType& probl
return valid_kernels;
}

#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
template <typename DeviceOpType>
inline constexpr bool IsSplitKNeeded()
{
return std::is_same_v<DeviceOpType, conv::DeviceOpGWrwPtrs<ck::half_t>> ||
std::is_same_v<DeviceOpType, conv::DeviceOpGWrwPtrs<float>> ||
std::is_same_v<DeviceOpType, conv::DeviceOpGWrwPtrs<int8_t>> ||
std::is_same_v<DeviceOpType, conv::DeviceOpGWrwPtrs<ck::bhalf_t>> ||
std::is_same_v<DeviceOpType, conv::DeviceOpGBwdWeightDefaultPtrs<ck::half_t>> ||
std::is_same_v<DeviceOpType, conv::DeviceOpGBwdWeightDefaultPtrs<float>> ||
std::is_same_v<DeviceOpType, conv::DeviceOpGBwdWeightDefaultPtrs<int8_t>> ||
std::is_same_v<DeviceOpType, conv::DeviceOpGBwdWeightDefaultPtrs<ck::bhalf_t>> ||
std::is_same_v<DeviceOpType, conv::DeviceOpGBwdWeightBilinearPtrs<ck::half_t>> ||
std::is_same_v<DeviceOpType, conv::DeviceOpGBwdWeightBilinearPtrs<float>> ||
std::is_same_v<DeviceOpType, conv::DeviceOpGBwdWeightBilinearPtrs<int8_t>> ||
std::is_same_v<DeviceOpType, conv::DeviceOpGBwdWeightBilinearPtrs<ck::bhalf_t>> ||
std::is_same_v<DeviceOpType, conv::DeviceOpGBwdWeightScalePtrs<ck::half_t>> ||
std::is_same_v<DeviceOpType, conv::DeviceOpGBwdWeightScalePtrs<float>> ||
std::is_same_v<DeviceOpType, conv::DeviceOpGBwdWeightScalePtrs<int8_t>> ||
std::is_same_v<DeviceOpType, conv::DeviceOpGBwdWeightScalePtrs<ck::bhalf_t>>;
}
#endif

template <typename DeviceOpType,
typename CKArgsType,
typename ProblemDescriptionType = miopen::conv::ProblemDescription,
Expand All @@ -135,11 +226,7 @@ bool IsCKArgsSupported(const ProblemDescriptionType& problem, const std::string&
if(!kernel_id.empty())
{
auto conv_ptrs = DeviceOpType::GetInstances();
if constexpr(std::is_same_v<DeviceOpType, conv::DeviceOpGWrwPtrs<ck::half_t>> ||
std::is_same_v<DeviceOpType, conv::DeviceOpGWrwPtrs<float>> ||
std::is_same_v<DeviceOpType, conv::DeviceOpGWrwPtrs<int8_t>> ||
std::is_same_v<DeviceOpType, conv::DeviceOpGWrwPtrs<ck::bhalf_t>> ||
CheckSplitK)
if constexpr(IsSplitKNeeded<DeviceOpType>() || CheckSplitK)
{
auto pos = kernel_id.find_last_of('+');
if(pos == std::string::npos)
@@ -789,10 +876,7 @@ ConvSolution InitInvokerFactoryNCHW(const ExecutionContext& ctx,

auto invoker_ptr = sh_conv_ptr->MakeInvokerPointer();
std::unique_ptr<ck::tensor_operation::device::BaseArgument> argument_ptr;
if constexpr(std::is_same_v<DeviceOpType, conv::DeviceOpGWrwPtrs<ck::half_t>> ||
std::is_same_v<DeviceOpType, conv::DeviceOpGWrwPtrs<float>> ||
std::is_same_v<DeviceOpType, conv::DeviceOpGWrwPtrs<int8_t>> ||
std::is_same_v<DeviceOpType, conv::DeviceOpGWrwPtrs<ck::bhalf_t>>)
if constexpr(IsSplitKNeeded<DeviceOpType>())
{
if(split_k.has_value())
{
@@ -882,10 +966,7 @@ ConvSolution InitInvokerFactoryNHWC(const ExecutionContext&,
const Handle& handle, const AnyInvokeParams& primitive_parameters) {
const auto& data_ctx = primitive_parameters.CastTo<CastType>();
std::unique_ptr<ck::tensor_operation::device::BaseArgument> argument_ptr;
if constexpr(std::is_same_v<DeviceOpType, conv::DeviceOpGWrwPtrs<ck::half_t>> ||
std::is_same_v<DeviceOpType, conv::DeviceOpGWrwPtrs<float>> ||
std::is_same_v<DeviceOpType, conv::DeviceOpGWrwPtrs<int8_t>> ||
std::is_same_v<DeviceOpType, conv::DeviceOpGWrwPtrs<ck::bhalf_t>>)
if constexpr(IsSplitKNeeded<DeviceOpType>())
{
argument_ptr = ck_args.MakeArgPtr(sh_conv_ptr,
data_ctx.tensors,
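The new IsSplitKNeeded trait centralizes a predicate that the old code repeated verbatim at three if constexpr sites (IsCKArgsSupported and both invoker factories). Below is a standalone sketch of the same compile-time dispatch pattern, using stand-in types rather than the real CK factories:

```cpp
#include <type_traits>

// Stand-ins for the CK device-op factory types (hypothetical).
template <typename DataType> struct WrwFactory {};
template <typename DataType> struct FwdFactory {};

// Compile-time predicate over a fixed set of instantiations, in the same
// spirit as IsSplitKNeeded above.
template <typename Op>
inline constexpr bool NeedsSplitK = std::is_same_v<Op, WrwFactory<float>> ||
                                    std::is_same_v<Op, WrwFactory<int>>;

// The caller branches at compile time, so the non-split-K path is never
// instantiated for split-K-aware factories (and vice versa).
template <typename Op>
int MakeArgCount(int split_k)
{
    if constexpr(NeedsSplitK<Op>)
        return split_k; // forward the tuned split_k to the argument builder
    else
        return 1;       // default path: split_k is not part of the signature
}

static_assert(NeedsSplitK<WrwFactory<float>>);
static_assert(!NeedsSplitK<FwdFactory<int>>);
```

Folding the predicate into one trait means a new split-K-capable device op only needs to be added in one place instead of three.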
Binary file modified src/kernels/gfx90a68.HIP.fdb.txt.bz2
Binary file removed src/kernels/gfx90a68.db.bz2
Binary file added src/kernels/gfx90a68.db.txt.bz2
4 changes: 2 additions & 2 deletions src/kernels/gfx942.kdb.bz2
Binary file modified src/kernels/gfx942130.HIP.fdb.txt.bz2
Binary file modified src/kernels/gfx942130.db.txt.bz2
Binary file modified src/kernels/gfx942e4.HIP.fdb.txt.bz2
Binary file modified src/kernels/gfx942e4.db.txt.bz2