@@ -7435,14 +7435,17 @@ void EmitPass::EmitGenIntrinsicMessage(llvm::GenIntrinsicInst* inst)
7435
7435
case GenISAIntrinsic::GenISA_WaveBallot:
7436
7436
emitWaveBallot(inst);
7437
7437
break;
7438
+ case GenISAIntrinsic::GenISA_WaveInverseBallot:
7439
+ emitWaveInverseBallot(inst);
7440
+ break;
7438
7441
case GenISAIntrinsic::GenISA_WaveShuffleIndex:
7439
7442
emitSimdShuffle(inst);
7440
7443
break;
7441
7444
case GenISAIntrinsic::GenISA_WavePrefix:
7442
- emitWavePrefix(inst);
7445
+ emitWavePrefix(cast<WavePrefixIntrinsic>( inst) );
7443
7446
break;
7444
7447
case GenISAIntrinsic::GenISA_QuadPrefix:
7445
- emitWavePrefix( inst, true );
7448
+ emitQuadPrefix(cast<QuadPrefixIntrinsic>( inst) );
7446
7449
break;
7447
7450
case GenISAIntrinsic::GenISA_WaveAll:
7448
7451
emitWaveAll(inst);
@@ -9981,16 +9984,21 @@ void EmitPass::emitReductionAll(
9981
9984
m_encoder->Push();
9982
9985
}
9983
9986
9984
- void EmitPass::emitPreOrPostFixOp(e_opcode op, uint64_t identityValue, VISA_Type type, bool negateSrc, CVariable* pSrc, CVariable* pSrcsArr[2], bool isPrefix, bool isQuad)
9987
+ void EmitPass::emitPreOrPostFixOp(
9988
+ e_opcode op, uint64_t identityValue, VISA_Type type, bool negateSrc,
9989
+ CVariable* pSrc, CVariable* pSrcsArr[2], CVariable *Flag,
9990
+ bool isPrefix, bool isQuad)
9985
9991
{
9986
- // This is to handle cases when not all lanes are enabled. In that case we fill the lanes with 0.
9987
-
9988
9992
if (m_currShader->m_Platform->doScalar64bScan() && CEncoder::GetCISADataTypeSize(type) == 8 && !isQuad)
9989
9993
{
9990
- emitPreOrPostFixOpScalar(op, identityValue, type, negateSrc, pSrc, pSrcsArr, isPrefix);
9994
+ emitPreOrPostFixOpScalar(
9995
+ op, identityValue, type, negateSrc,
9996
+ pSrc, pSrcsArr, Flag,
9997
+ isPrefix);
9991
9998
return;
9992
9999
}
9993
10000
10001
+ // This is to handle cases when not all lanes are enabled. In that case we fill the lanes with 0.
9994
10002
bool isSimd32 = (m_currShader->m_dispatchSize == SIMDMode::SIMD32);
9995
10003
int counter = 1;
9996
10004
if (isSimd32)
@@ -10004,18 +10012,20 @@ void EmitPass::emitPreOrPostFixOp(e_opcode op, uint64_t identityValue, VISA_Type
10004
10012
IGC::EALIGN_GRF,
10005
10013
false);
10006
10014
10007
- // Set the GRF to 0 with no mask. This will set all the registers to 0
10015
+ // Set the GRF to <identity> with no mask. This will set all the registers to <identity>
10008
10016
CVariable* pIdentityValue = m_currShader->ImmToVariable(identityValue, type);
10009
10017
m_encoder->SetNoMask();
10010
10018
m_encoder->Copy(pSrcCopy, pIdentityValue);
10011
10019
m_encoder->Push();
10012
10020
10013
- // Now copy the src with a mask so the disabled lanes still keep their 0
10021
+ // Now copy the src with a mask so the disabled lanes still keep their <identity>
10014
10022
if (negateSrc)
10015
10023
{
10016
10024
m_encoder->SetSrcModifier(0, EMOD_NEG);
10017
10025
}
10018
10026
m_encoder->SetSecondHalf(i == 1);
10027
+ if (Flag)
10028
+ m_encoder->SetPredicate(Flag);
10019
10029
m_encoder->Copy(pSrcCopy, pSrc);
10020
10030
m_encoder->Push();
10021
10031
@@ -10063,8 +10073,9 @@ void EmitPass::emitPreOrPostFixOp(e_opcode op, uint64_t identityValue, VISA_Type
10063
10073
{
10064
10074
/*
10065
10075
Copy the adjacent elements.
10066
- for example: r10 be the register
10067
- ____ ____ ____ ____
10076
+ for example: let r10 be the register
10077
+ Assume we are performing addition for this example
10078
+ ____ ____ ____ ____
10068
10079
__|____|____|____|____|____|____|____|_
10069
10080
| 7 | 6 | 5 | 4 | 9 | 5 | 3 | 2 |
10070
10081
---------------------------------------
@@ -10095,10 +10106,10 @@ void EmitPass::emitPreOrPostFixOp(e_opcode op, uint64_t identityValue, VISA_Type
10095
10106
}
10096
10107
10097
10108
/*
10098
- ____ ____
10099
- _______|____|________________|____|______ ___________________________________________
10109
+ ____ ____
10110
+ _______|____|________________|____|______ ___________________________________________
10100
10111
| 13 | 6 | 9 | 4 | 14 | 5 | 5 | 2 | ==> | 13 | 15 | 9 | 4 | 14 | 10 | 5 | 2 |
10101
- ----------------------------------------- -------------------------------------------
10112
+ ----------------------------------------- -------------------------------------------
10102
10113
*/
10103
10114
// Now we have a weird copy happening. This will be done by SIMD 2 instructions.
10104
10115
@@ -10127,7 +10138,7 @@ void EmitPass::emitPreOrPostFixOp(e_opcode op, uint64_t identityValue, VISA_Type
10127
10138
}
10128
10139
10129
10140
/*
10130
- ___________ ___________
10141
+ ___________ ___________
10131
10142
__|___________|_________|___________|______ ___________________________________________
10132
10143
| 13 | 15 | 9 | 4 | 14 | 10 | 5 | 2 | ==> | 22 | 15 | 9 | 4 | 19 | 10 | 5 | 2 |
10133
10144
------------------------------------------- -------------------------------------------
@@ -10164,21 +10175,21 @@ void EmitPass::emitPreOrPostFixOp(e_opcode op, uint64_t identityValue, VISA_Type
10164
10175
}
10165
10176
10166
10177
/*
10167
- ____
10178
+ ____
10168
10179
__________________|____|_________________ ____________________________________________
10169
10180
| 22 | 15 | 9 | 4 | 19 | 10 | 5 | 2 | ==> | 22 | 15 | 9 | 23 | 19 | 10 | 5 | 2 |
10170
10181
----------------------------------------- --------------------------------------------
10171
- _________
10182
+ _________
10172
10183
_____________|_________|_________________ _____________________________________________
10173
10184
| 22 | 15 | 9 | 4 | 19 | 10 | 5 | 2 | ==> | 22 | 15 | 28 | 23 | 19 | 10 | 5 | 2 |
10174
10185
----------------------------------------- ---------------------------------------------
10175
10186
10176
- ______________
10187
+ ______________
10177
10188
________|______________|_________________ _____________________________________________
10178
10189
| 22 | 15 | 9 | 4 | 19 | 10 | 5 | 2 | ==> | 22 | 34 | 28 | 23 | 19 | 10 | 5 | 2 |
10179
10190
----------------------------------------- ---------------------------------------------
10180
10191
10181
- ____________________
10192
+ ____________________
10182
10193
__|____________________|_________________ _____________________________________________
10183
10194
| 22 | 15 | 9 | 4 | 19 | 10 | 5 | 2 | ==> | 41 | 34 | 28 | 23 | 19 | 10 | 5 | 2 |
10184
10195
----------------------------------------- ---------------------------------------------
@@ -10239,6 +10250,7 @@ void EmitPass::emitPreOrPostFixOpScalar(
10239
10250
bool negateSrc,
10240
10251
CVariable* src,
10241
10252
CVariable* result[2],
10253
+ CVariable* Flag,
10242
10254
bool isPrefix)
10243
10255
{
10244
10256
// This is to handle cases when not all lanes are enabled. In that case we fill the lanes with 0.
@@ -10259,19 +10271,21 @@ void EmitPass::emitPreOrPostFixOpScalar(
10259
10271
IGC::EALIGN_GRF,
10260
10272
false);
10261
10273
10262
- // Set the GRF to 0 with no mask. This will set all the registers to 0
10274
+ // Set the GRF to <identity> with no mask. This will set all the registers to <identity>
10263
10275
CVariable* pIdentityValue = m_currShader->ImmToVariable(identityValue, type);
10264
10276
m_encoder->SetSecondHalf(i == 1);
10265
10277
m_encoder->SetNoMask();
10266
10278
m_encoder->Copy(pSrcCopy[i], pIdentityValue);
10267
10279
m_encoder->Push();
10268
10280
10269
- // Now copy the src with a mask so the disabled lanes still keep their 0
10281
+ // Now copy the src with a mask so the disabled lanes still keep their <identity>
10270
10282
if (negateSrc)
10271
10283
{
10272
10284
m_encoder->SetSrcModifier(0, EMOD_NEG);
10273
10285
}
10274
10286
m_encoder->SetSecondHalf(i == 1);
10287
+ if (Flag)
10288
+ m_encoder->SetPredicate(Flag);
10275
10289
m_encoder->Copy(pSrcCopy[i], src);
10276
10290
m_encoder->Push();
10277
10291
@@ -14326,6 +14340,33 @@ void EmitPass::emitWaveBallot(llvm::GenIntrinsicInst* inst)
14326
14340
}
14327
14341
}
14328
14342
14343
+ void EmitPass::emitWaveInverseBallot(llvm::GenIntrinsicInst* inst)
14344
+ {
14345
+ CVariable *Mask = GetSymbol(inst->getOperand(0));
14346
+
14347
+ if (Mask->IsUniform())
14348
+ {
14349
+ if (m_encoder->IsSecondHalf())
14350
+ return;
14351
+
14352
+ m_encoder->SetP(m_destination, Mask);
14353
+ return;
14354
+ }
14355
+
14356
+ // The uniform case should by far be the most common. Otherwise,
14357
+ // fall back and compute:
14358
+ //
14359
+ // (val & (1 << id)) != 0
14360
+ CVariable *Temp = m_currShader->GetNewVariable(
14361
+ numLanes(m_currShader->m_SIMDSize), ISA_TYPE_UD, EALIGN_GRF);
14362
+
14363
+ m_currShader->GetSimdOffsetBase(Temp);
14364
+ m_encoder->Shl(Temp, m_currShader->ImmToVariable(1, ISA_TYPE_UD), Temp);
14365
+ m_encoder->And(Temp, Mask, Temp);
14366
+ m_encoder->Cmp(EPREDICATE_NE,
14367
+ m_destination, Temp, m_currShader->ImmToVariable(0, ISA_TYPE_UD));
14368
+ }
14369
+
14329
14370
static void GetReductionOp(WaveOps op, Type* opndTy, uint64_t& identity, e_opcode& opcode, VISA_Type& type)
14330
14371
{
14331
14372
auto getISAType = [](Type* ty, bool isSigned = true)
@@ -14468,17 +14509,49 @@ static void GetReductionOp(WaveOps op, Type* opndTy, uint64_t& identity, e_opcod
14468
14509
}
14469
14510
}
14470
14511
14471
- void EmitPass::emitWavePrefix(llvm::GenIntrinsicInst* inst, bool isQuad)
14512
+ void EmitPass::emitWavePrefix(WavePrefixIntrinsic* I)
14513
+ {
14514
+ Value *Mask = I->getMask();
14515
+ if (auto *CI = dyn_cast<ConstantInt>(Mask))
14516
+ {
14517
+ // If the mask is all set, then we just pass a null
14518
+ // mask to emitScan() indicating we don't want to
14519
+ // emit any predication.
14520
+ if (CI->isAllOnesValue())
14521
+ Mask = nullptr;
14522
+ }
14523
+ emitScan(
14524
+ I->getSrc(), I->getOpKind(), I->isInclusiveScan(), Mask, false);
14525
+ }
14526
+
14527
+ void EmitPass::emitQuadPrefix(QuadPrefixIntrinsic* I)
14528
+ {
14529
+ emitScan(
14530
+ I->getSrc(), I->getOpKind(), I->isInclusiveScan(), nullptr, true);
14531
+ }
14532
+
14533
+ void EmitPass::emitScan(
14534
+ Value *Src, IGC::WaveOps Op,
14535
+ bool isInclusiveScan, Value *Mask, bool isQuad)
14472
14536
{
14473
- WaveOps op = static_cast<WaveOps>(cast<llvm::ConstantInt>(inst->getOperand(1))->getZExtValue());
14474
- bool isInclusiveScan = cast<llvm::ConstantInt>(inst->getOperand(2))->getZExtValue() != 0;
14475
14537
VISA_Type type;
14476
14538
e_opcode opCode;
14477
14539
uint64_t identity = 0;
14478
- GetReductionOp(op, inst->getOperand(0) ->getType(), identity, opCode, type);
14479
- CVariable* src = GetSymbol(inst->getOperand(0) );
14540
+ GetReductionOp(Op, Src ->getType(), identity, opCode, type);
14541
+ CVariable* src = GetSymbol(Src );
14480
14542
CVariable *dst[2] = { nullptr, nullptr };
14481
- emitPreOrPostFixOp(opCode, identity, type, false, src, dst, !isInclusiveScan, isQuad);
14543
+ CVariable *Flag = Mask ? GetSymbol(Mask) : nullptr;
14544
+
14545
+ emitPreOrPostFixOp(
14546
+ opCode, identity, type,
14547
+ false, src, dst, Flag,
14548
+ !isInclusiveScan, isQuad);
14549
+
14550
+ // Now that we've computed the result in temporary registers,
14551
+ // make sure we only write the results to lanes participating in the
14552
+ // scan as specified by 'mask'.
14553
+ if (Flag)
14554
+ m_encoder->SetPredicate(Flag);
14482
14555
m_encoder->Copy(m_destination, dst[0]);
14483
14556
if (m_currShader->m_dispatchSize == SIMDMode::SIMD32)
14484
14557
{
0 commit comments