Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GPU] Enable f4_e2m1 jit gemm #2442

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open

Conversation

kealan-barbieri
Copy link
Contributor

Description

  • Enable f4_e2m1 in jit::gemm.
  • Enable dumping copy_plan during creation when DNNL_DEV_MODE enabled.

Partially covers MFDNN-124711

Checklist

General

  • Do all unit and benchdnn tests (make test and make test_benchdnn_*) pass locally for each commit?
  • Have you formatted the code using clang-format?

@kealan-barbieri kealan-barbieri added the platform:gpu-intel Codeowner: @oneapi-src/onednn-gpu-intel label Jan 17, 2025
@kealan-barbieri kealan-barbieri requested review from a team as code owners January 17, 2025 18:23
@github-actions github-actions bot added the component:tests Codeowner: @oneapi-src/onednn-arch label Jan 17, 2025
// cmp (ge) t0:w, y:w, 31
// shr y:uw, 10
// csel (ge) y:fp16, 0x7bff, y:fp16, t0:fp16
// csel (ze) y:fp16, NaN:fp16, y:fp16, t1:fp16
Copy link
Contributor

@petercad petercad Jan 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Side note: there's a much faster sequence, though this is OK for now:

shl   t0:ud  x:ub   24
add   t0:ud  t0:ud  1
mov   y:hf   t0:f

@kealan-barbieri kealan-barbieri force-pushed the kealanba/f4_e2m1_gemm branch 3 times, most recently from 4a4aabf to b823c25 Compare January 18, 2025 00:20
@kealan-barbieri
Copy link
Contributor Author

make test
disable test_device_cpu
disable build_cpu_runtime_omp
disable build_cpu_runtime_sycl
disable build_cpu_runtime_tbb
disable benchdnn_all
enable benchdnn_matmul
enable benchdnn_ip

@echeresh
Copy link
Contributor

@kealan-barbieri Do we have f4_e2m1 coverage in benchdnn input files? If missing, can you please add some?

In the long term #2434 should help with that.

@kealan-barbieri
Copy link
Contributor Author

@echeresh there is existing coverage: https://github.com/oneapi-src/oneDNN/blob/main/tests/benchdnn/inputs/matmul/test_matmul_fp4

@kealan-barbieri kealan-barbieri force-pushed the kealanba/f4_e2m1_gemm branch 2 times, most recently from 91cfc9b to 8a4fc5e Compare January 30, 2025 22:04
@github-actions github-actions bot removed the component:tests Codeowner: @oneapi-src/onednn-arch label Jan 30, 2025
@kealan-barbieri kealan-barbieri force-pushed the kealanba/f4_e2m1_gemm branch 2 times, most recently from 6cf6814 to 895871f Compare January 31, 2025 23:03
@kealan-barbieri
Copy link
Contributor Author

make test
disable test_device_cpu
disable build_cpu_runtime_omp
disable build_cpu_runtime_sycl
disable build_cpu_runtime_tbb
disable benchdnn_all
enable benchdnn_matmul
enable benchdnn_ip

@@ -168,7 +168,7 @@ bool BLASKernelGenerator<hw>::gemmAccessC(COperation op, const GEMMProblem &prob
block2DCRemainder |= block2DCFull;
block2DCRemainder &= !strategy.C.atomic;
block2DCFull &= !strategy.C.atomic;
bool altCRemainder = strategy.altCRemainder && !strategy.C.padded && (remainderM || remainderN || problem.gemmt());
bool altCRemainder = !problem.Tc_ext.is4Bit() && strategy.altCRemainder && !strategy.C.padded && (remainderM || remainderN || problem.gemmt());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check should be moved to GEMMStrategy::preflight, something like

altCRemainder &= (problem.Tc_ext.bits() >= 8);

case MatrixLayout::T: x = j0; y = i0; break;
}
emad(1, offsetC, offsetC, x, xstride, strategy, state);
if(xstride == 0){
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a slightly more future-proof check:

Suggested change
if(xstride == 0){
if (Tc_ext.is4Bit()) {

@@ -1439,6 +1444,7 @@ bool BLASKernelGenerator<hw>::gemmAccumulateCSetup(GEMMProblem &problem, GEMMStr
// Get register layouts for A/B/C.
if (!getRegLayout(Ta_load, state.A_layout, unrollM, strategy.ka_load, remM_A, remK_A, false, AvoidFragment, 0, 0, problem.A, strategy.A)) return false;
if (!getRegLayout(Tb_load, state.B_layout, strategy.kb_load, unrollN, remK_B, remN_B, false, AvoidFragment, 0, 0, problem.B, strategy.B)) return false;
if (cColMajor && strategy.systolic && (!state.A_layout[0].colMajor || !state.B_layout[0].colMajor)) stub();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This condition should be checked already inside outerProductSystolic -- is it being missed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not being missed inside outerProductSystolic, I found it useful to have an assert nearby the source of the error. If you prefer I'll drop it.

{
if (i.src0.neg || i.sat || i.hasCMod()) stub("Unsupported modifier");
int simd = i.simd;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to provide a comment here with assembly pseudo-code of the emulation sequence (similar to the other complicated conversion sequences).

@@ -1030,6 +1030,14 @@ auto _CATALOG_ = kcatalog::toFlatCatalog({
{{'F', "gemm", {"[FO]", "O", "S"}, {"N", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "#I"}, "aB32 aS32 aB wg 1x4x8 kr cab4 ks32 af hi pt bk0 grf256 sys l4 dm sr br", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 2}, {524288, 131072, 16777216}, {524288, 131072, 16777216}, {32, 8, 32}, {1, 4, 8}, 1, (WGType) 1, 261, 8192, 8192, {1, 1, 4}, {true, true, true}}, {'E', 17, {1.0714e+06, 814336, -806.186, 79098.3, 0, 0, 1.57352, 2.83206, 2.69271, 7.17928, 0.0520116, 0.0520116, 0, 0.665122, 1.00198, 1.00049, 9.78834e-15}}},
{{'F', "gemm", {"[FO]", "O", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 8, 1}, "IAB"}, "at32x2+m64@96 am64+m32@128 aB wg 4x2x4 kr xaf st hi pt sr br sb128 bk0 sm sn grf256 sys kv afb", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 2}, {262144, 262144, 16777216}, {262144, 262144, 64}, {16, 16, 64}, {4, 2, 4}, 1, (WGType) 1, 445, 0, 8192, {4, 8, 4}, {true, true, true}}, {'E', 17, {1.12606e+06, -137008, -24369.3, 228077, 2.14139e+06, 1.80224e+06, 0.232306, 0.352856, 0.366483, 1.01866, 0.0096663, 0.00801051, 0.00318484, 0.842293, 1.37613, 0.921541, 4.97764e-12}}},
{{'F', "gemm", {"[FO]", "O", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 8, 1}, "ABIps"}, "at32+m128@96 am64+m64@112 aB wg 4x8 xaf st hi pt sr br sb128 bk0 sm sn grf256 sys kv afb", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 255}, {1048576, 655360, 16777216}, {1048576, 655360, 64}, {16, 40, 32}, {4, 8, 1}, 1, (WGType) 1, 441, 0, 0, {4, 8, 4}, {true, true, true}}, {'E', 17, {884649, 705009, 0, 0, 6.08092e+06, 1.03629e+07, 0.496687, 0.364134, 0.840897, 1.28383, 0.00200956, 0.00200956, 0, 1, 2.31371, 1.15477, 1.64496e-12}}},
{{'F', "gemm", {"[EH]", "[EH]", "S"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, 16, -1}, {1, 1, 1}, "#I"}, "aB16+m32@32 aB32 aB wg 2x8 af vav li nmk pt sr br ca3 bk0 sys kv dm afb l4", {16, (LoopType) 255, 128, {(LoopType) 225, (LoopType) 255, (LoopType) 255}, {1048576, 32768, 16777216}, {1048576, 32768, 32}, {64, 2, 32}, {2, 8, 1}, 1, (WGType) 1, 441, 24576, 0, {4, 2, 4}, {true, true, true}}, {'E', 17, {1.32162e+06, 161954, 0, 0, 2.32817e+06, 0, 0.71806, 4.15517, 0.786689, 1.40778, 0.0341164, 0.0131941, 0.0256486, 0.947188, 1.39057, 0.987284, 5.0128e-12}}},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can these just be EES instead of [EH][EH]S? I think you have the automatic upconversion already set up in gemmAutoTypeConversions.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
platform:gpu-intel Codeowner: @oneapi-src/onednn-gpu-intel
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants