-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[GPU] Enable f4_e2m1 jit gemm #2442
base: main
Are you sure you want to change the base?
Conversation
ebac6cb
to
44b218e
Compare
44b218e
to
021d757
Compare
// cmp (ge) t0:w, y:w, 31 | ||
// shr y:uw, 10 | ||
// csel (ge) y:fp16, 0x7bff, y:fp16, t0:fp16 | ||
// csel (ze) y:fp16, NaN:fp16, y:fp16, t1:fp16 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Side note: there's a much faster sequence, though this is OK for now:
shl t0:ud x:ub 24
add t0:ud t0:ud 1
mov y:hf t0:f
4a4aabf
to
b823c25
Compare
make test |
@kealan-barbieri Do we have f4_e2m1 coverage in benchdnn input files? If missing, can you please add some? In the long term #2434 should help with that. |
@echeresh there is existing coverage: https://github.com/oneapi-src/oneDNN/blob/main/tests/benchdnn/inputs/matmul/test_matmul_fp4 |
91cfc9b
to
8a4fc5e
Compare
6cf6814
to
895871f
Compare
make test |
@@ -168,7 +168,7 @@ bool BLASKernelGenerator<hw>::gemmAccessC(COperation op, const GEMMProblem &prob | |||
block2DCRemainder |= block2DCFull; | |||
block2DCRemainder &= !strategy.C.atomic; | |||
block2DCFull &= !strategy.C.atomic; | |||
bool altCRemainder = strategy.altCRemainder && !strategy.C.padded && (remainderM || remainderN || problem.gemmt()); | |||
bool altCRemainder = !problem.Tc_ext.is4Bit() && strategy.altCRemainder && !strategy.C.padded && (remainderM || remainderN || problem.gemmt()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This check should be moved to GEMMStrategy::preflight
, something like
altCRemainder &= (problem.Tc_ext.bits() >= 8);
case MatrixLayout::T: x = j0; y = i0; break; | ||
} | ||
emad(1, offsetC, offsetC, x, xstride, strategy, state); | ||
if(xstride == 0){ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe a slightly more future-proof check:
if(xstride == 0){ | |
if (Tc_ext.is4Bit()) { |
@@ -1439,6 +1444,7 @@ bool BLASKernelGenerator<hw>::gemmAccumulateCSetup(GEMMProblem &problem, GEMMStr | |||
// Get register layouts for A/B/C. | |||
if (!getRegLayout(Ta_load, state.A_layout, unrollM, strategy.ka_load, remM_A, remK_A, false, AvoidFragment, 0, 0, problem.A, strategy.A)) return false; | |||
if (!getRegLayout(Tb_load, state.B_layout, strategy.kb_load, unrollN, remK_B, remN_B, false, AvoidFragment, 0, 0, problem.B, strategy.B)) return false; | |||
if (cColMajor && strategy.systolic && (!state.A_layout[0].colMajor || !state.B_layout[0].colMajor)) stub(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This condition should be checked already inside outerProductSystolic
-- is it being missed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is not being missed inside outerProductSystolic
, I found it useful to have an assert nearby the source of the error. If you prefer I'll drop it.
{ | ||
if (i.src0.neg || i.sat || i.hasCMod()) stub("Unsupported modifier"); | ||
int simd = i.simd; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be nice to provide a comment here with assembly pseudo-code of the emulation sequence (similar to the other complicated conversion sequences).
@@ -1030,6 +1030,14 @@ auto _CATALOG_ = kcatalog::toFlatCatalog({ | |||
{{'F', "gemm", {"[FO]", "O", "S"}, {"N", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "#I"}, "aB32 aS32 aB wg 1x4x8 kr cab4 ks32 af hi pt bk0 grf256 sys l4 dm sr br", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 2}, {524288, 131072, 16777216}, {524288, 131072, 16777216}, {32, 8, 32}, {1, 4, 8}, 1, (WGType) 1, 261, 8192, 8192, {1, 1, 4}, {true, true, true}}, {'E', 17, {1.0714e+06, 814336, -806.186, 79098.3, 0, 0, 1.57352, 2.83206, 2.69271, 7.17928, 0.0520116, 0.0520116, 0, 0.665122, 1.00198, 1.00049, 9.78834e-15}}}, | |||
{{'F', "gemm", {"[FO]", "O", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 8, 1}, "IAB"}, "at32x2+m64@96 am64+m32@128 aB wg 4x2x4 kr xaf st hi pt sr br sb128 bk0 sm sn grf256 sys kv afb", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 2}, {262144, 262144, 16777216}, {262144, 262144, 64}, {16, 16, 64}, {4, 2, 4}, 1, (WGType) 1, 445, 0, 8192, {4, 8, 4}, {true, true, true}}, {'E', 17, {1.12606e+06, -137008, -24369.3, 228077, 2.14139e+06, 1.80224e+06, 0.232306, 0.352856, 0.366483, 1.01866, 0.0096663, 0.00801051, 0.00318484, 0.842293, 1.37613, 0.921541, 4.97764e-12}}}, | |||
{{'F', "gemm", {"[FO]", "O", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 8, 1}, "ABIps"}, "at32+m128@96 am64+m64@112 aB wg 4x8 xaf st hi pt sr br sb128 bk0 sm sn grf256 sys kv afb", {16, (LoopType) 255, 256, {(LoopType) 208, (LoopType) 255, (LoopType) 255}, {1048576, 655360, 16777216}, {1048576, 655360, 64}, {16, 40, 32}, {4, 8, 1}, 1, (WGType) 1, 441, 0, 0, {4, 8, 4}, {true, true, true}}, {'E', 17, {884649, 705009, 0, 0, 6.08092e+06, 1.03629e+07, 0.496687, 0.364134, 0.840897, 1.28383, 0.00200956, 0.00200956, 0, 1, 2.31371, 1.15477, 1.64496e-12}}}, | |||
{{'F', "gemm", {"[EH]", "[EH]", "S"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, 16, -1}, {1, 1, 1}, "#I"}, "aB16+m32@32 aB32 aB wg 2x8 af vav li nmk pt sr br ca3 bk0 sys kv dm afb l4", {16, (LoopType) 255, 128, {(LoopType) 225, (LoopType) 255, (LoopType) 255}, {1048576, 32768, 16777216}, {1048576, 32768, 32}, {64, 2, 32}, {2, 8, 1}, 1, (WGType) 1, 441, 24576, 0, {4, 2, 4}, {true, true, true}}, {'E', 17, {1.32162e+06, 161954, 0, 0, 2.32817e+06, 0, 0.71806, 4.15517, 0.786689, 1.40778, 0.0341164, 0.0131941, 0.0256486, 0.947188, 1.39057, 0.987284, 5.0128e-12}}}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can these just be EES
instead of [EH][EH]S
? I think you have the automatic upconversion already set up in gemmAutoTypeConversions.
895871f
to
cbde4ec
Compare
Description
copy_plan
during creation whenDNNL_DEV_MODE
enabled.Partially covers MFDNN-124711
Checklist
General
make test
andmake test_benchdnn_*
) pass locally for each commit?