Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GPU] Enable f4_e2m1 jit gemm #2442

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/common/type_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ inline size_t data_type_size(data_type_t data_type) {
inline size_t elements_to_bytes(data_type_t data_type, size_t count) {
using namespace data_type;
switch ((int)data_type) {
case f4_e2m1:
case s4:
case u4: return (count + 1) >> 1;
default: return data_type_size(data_type) * count;
Expand All @@ -127,6 +128,7 @@ inline size_t elements_to_bytes(data_type_t data_type, size_t count) {
inline size_t bytes_to_elements(data_type_t data_type, size_t bytes) {
using namespace data_type;
switch ((int)data_type) {
case f4_e2m1:
case s4:
case u4: return bytes * 2;
default: return utils::div_up(bytes, data_type_size(data_type));
Expand Down
2 changes: 2 additions & 0 deletions src/gpu/intel/compute/kernel_arg_list.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ enum class kernel_arg_kind_t {
enum class scalar_type_t {
undef,
_char,
_f4_e2m1,
_hfloat8,
_bfloat8,
_bfloat16,
Expand Down Expand Up @@ -75,6 +76,7 @@ inline std::string to_string(scalar_type_t type) {
switch (type) {
CASE(undef);
CASE(_char);
CASE(_f4_e2m1);
CASE(_hfloat8);
CASE(_bfloat8);
CASE(_bfloat16);
Expand Down
2 changes: 1 addition & 1 deletion src/gpu/intel/jit/gemm/gen_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ struct gen_gemm_t : public gpu_gemm_t {
VERBOSE_INCONSISTENT_DT, "a", "acc");
} else if (!wei_decomp_) {
VDISPATCH_GEMM(utils::one_of(d->a_type(), f64, f32, f16,
f8_e5m2, f8_e4m3),
f8_e5m2, f8_e4m3, f4_e2m1),
VERBOSE_UNSUPPORTED_DT);
VDISPATCH_GEMM(d->b_type() == d->a_type()
|| (utils::one_of(d->a_type(), f8_e5m2, f8_e4m3)
Expand Down
7 changes: 6 additions & 1 deletion src/gpu/intel/jit/gemm/gen_gemm_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ compute::scalar_type_t gen_gemm_kernel_desc_t::scalar_type() const {
case Type::u32: return compute::scalar_type_t::_uint;
case Type::s64: return compute::scalar_type_t::_long;
case Type::u64: return compute::scalar_type_t::_ulong;
case Type::f4_e2m1: return compute::scalar_type_t::_f4_e2m1;
case Type::bf8: return compute::scalar_type_t::_bfloat8;
case Type::hf8: return compute::scalar_type_t::_hfloat8;
case Type::bf16: return compute::scalar_type_t::_bfloat16;
Expand Down Expand Up @@ -79,7 +80,7 @@ status_t gen_gemm_kernel_desc_t::finalize(const char *tags) {
entry_->restrictions.alignment[2]));
}

problem_.CO.setAlignment(problem_.Tco.size());
problem_.CO.setAlignment(problem_.Tco.paddedSize());

// Parse strategy string.
strategy_ = GEMMStrategy(hw_, stepping_);
Expand Down Expand Up @@ -587,6 +588,10 @@ status_t gen_gemm_nocopy_kernel_desc_t::select_kernel(compute::gpu_arch_t arch,
= match_params.back().selector.precisions[1];
}
}
add_mode_matches(true, [](Type dt) -> const char * {
if (dt.isFP4()) return "[EH]";
return nullptr;
});

EvaluateParams eval_params;

Expand Down
1 change: 1 addition & 0 deletions src/gpu/intel/jit/gemm/gen_gemm_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ static inline Type convert_dnnl_to_kernel_type(data_type_t type) {
case data_type::bf16: return Type::bf16;
case data_type::f8_e5m2: return Type::bf8;
case data_type::f8_e4m3: return Type::hf8;
case data_type::f4_e2m1: return Type::f4_e2m1;
case data_type::s32: return Type::s32;
case data_type::u8: return Type::u8;
case data_type::s8: return Type::s8;
Expand Down
6 changes: 3 additions & 3 deletions src/gpu/intel/jit/gemm/generator/pieces/c_update.cxx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2019-2024 Intel Corporation
* Copyright 2019-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -168,7 +168,7 @@ bool BLASKernelGenerator<hw>::gemmAccessC(COperation op, const GEMMProblem &prob
block2DCRemainder |= block2DCFull;
block2DCRemainder &= !strategy.C.atomic;
block2DCFull &= !strategy.C.atomic;
bool altCRemainder = strategy.altCRemainder && !strategy.C.padded && (remainderM || remainderN || problem.gemmt());
bool altCRemainder = !problem.Tc_ext.is4Bit() && strategy.altCRemainder && !strategy.C.padded && (remainderM || remainderN || problem.gemmt());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check should be moved to GEMMStrategy::preflight, something like

altCRemainder &= (problem.Tc_ext.bits() >= 8);

bool stdCRemainder = !(altCRemainder && (strategy.remHandling[LoopM] == RemainderHandling::KnownRemainder)
&& (strategy.remHandling[LoopN] == RemainderHandling::KnownRemainder));

Expand Down Expand Up @@ -429,7 +429,7 @@ bool BLASKernelGenerator<hw>::gemmAccessC(COperation op, const GEMMProblem &prob

for (int q = 0; q < state.C_count; q++) {
bool checkAlign = (problem.C.alignment % align) != 0;
bool checkWidth = (q == 0 && Tc_ext.size() < 4 && op != COperation::Load);
bool checkWidth = (q == 0 && Tc_ext.paddedSize() < 4 && op != COperation::Load);
auto &labelNonBlock2DRem = altCRemainder ? labelAltCRemainder : labelStdCRemainder;

if (checkAlign) {
Expand Down
Loading
Loading