Skip to content

Commit 6a6a81d

Browse files
scottp101gfxbot
authored andcommitted
Add mask argument to wavePrefix and waveInverseBallot. This is support for coming commits
Change-Id: Ic524f8083e729796041b3f277ab5a75c4f03dd5f
1 parent fc83b6f commit 6a6a81d

File tree

8 files changed

+189
-34
lines changed

8 files changed

+189
-34
lines changed

IGC/Compiler/CISACodeGen/EmitVISAPass.cpp

+99-26
Original file line numberDiff line numberDiff line change
@@ -7435,14 +7435,17 @@ void EmitPass::EmitGenIntrinsicMessage(llvm::GenIntrinsicInst* inst)
74357435
case GenISAIntrinsic::GenISA_WaveBallot:
74367436
emitWaveBallot(inst);
74377437
break;
7438+
case GenISAIntrinsic::GenISA_WaveInverseBallot:
7439+
emitWaveInverseBallot(inst);
7440+
break;
74387441
case GenISAIntrinsic::GenISA_WaveShuffleIndex:
74397442
emitSimdShuffle(inst);
74407443
break;
74417444
case GenISAIntrinsic::GenISA_WavePrefix:
7442-
emitWavePrefix(inst);
7445+
emitWavePrefix(cast<WavePrefixIntrinsic>(inst));
74437446
break;
74447447
case GenISAIntrinsic::GenISA_QuadPrefix:
7445-
emitWavePrefix(inst, true);
7448+
emitQuadPrefix(cast<QuadPrefixIntrinsic>(inst));
74467449
break;
74477450
case GenISAIntrinsic::GenISA_WaveAll:
74487451
emitWaveAll(inst);
@@ -9981,16 +9984,21 @@ void EmitPass::emitReductionAll(
99819984
m_encoder->Push();
99829985
}
99839986

9984-
void EmitPass::emitPreOrPostFixOp(e_opcode op, uint64_t identityValue, VISA_Type type, bool negateSrc, CVariable* pSrc, CVariable* pSrcsArr[2], bool isPrefix, bool isQuad)
9987+
void EmitPass::emitPreOrPostFixOp(
9988+
e_opcode op, uint64_t identityValue, VISA_Type type, bool negateSrc,
9989+
CVariable* pSrc, CVariable* pSrcsArr[2], CVariable *Flag,
9990+
bool isPrefix, bool isQuad)
99859991
{
9986-
// This is to handle cases when not all lanes are enabled. In that case we fill the lanes with 0.
9987-
99889992
if (m_currShader->m_Platform->doScalar64bScan() && CEncoder::GetCISADataTypeSize(type) == 8 && !isQuad)
99899993
{
9990-
emitPreOrPostFixOpScalar(op, identityValue, type, negateSrc, pSrc, pSrcsArr, isPrefix);
9994+
emitPreOrPostFixOpScalar(
9995+
op, identityValue, type, negateSrc,
9996+
pSrc, pSrcsArr, Flag,
9997+
isPrefix);
99919998
return;
99929999
}
999310000

10001+
// This is to handle cases when not all lanes are enabled. In that case we fill the lanes with 0.
999410002
bool isSimd32 = (m_currShader->m_dispatchSize == SIMDMode::SIMD32);
999510003
int counter = 1;
999610004
if (isSimd32)
@@ -10004,18 +10012,20 @@ void EmitPass::emitPreOrPostFixOp(e_opcode op, uint64_t identityValue, VISA_Type
1000410012
IGC::EALIGN_GRF,
1000510013
false);
1000610014

10007-
// Set the GRF to 0 with no mask. This will set all the registers to 0
10015+
// Set the GRF to <identity> with no mask. This will set all the registers to <identity>
1000810016
CVariable* pIdentityValue = m_currShader->ImmToVariable(identityValue, type);
1000910017
m_encoder->SetNoMask();
1001010018
m_encoder->Copy(pSrcCopy, pIdentityValue);
1001110019
m_encoder->Push();
1001210020

10013-
// Now copy the src with a mask so the disabled lanes still keep their 0
10021+
// Now copy the src with a mask so the disabled lanes still keep their <identity>
1001410022
if (negateSrc)
1001510023
{
1001610024
m_encoder->SetSrcModifier(0, EMOD_NEG);
1001710025
}
1001810026
m_encoder->SetSecondHalf(i == 1);
10027+
if (Flag)
10028+
m_encoder->SetPredicate(Flag);
1001910029
m_encoder->Copy(pSrcCopy, pSrc);
1002010030
m_encoder->Push();
1002110031

@@ -10063,8 +10073,9 @@ void EmitPass::emitPreOrPostFixOp(e_opcode op, uint64_t identityValue, VISA_Type
1006310073
{
1006410074
/*
1006510075
Copy the adjacent elements.
10066-
for example: r10 be the register
10067-
____ ____ ____ ____
10076+
for example: let r10 be the register
10077+
Assume we are performing addition for this example
10078+
____ ____ ____ ____
1006810079
__|____|____|____|____|____|____|____|_
1006910080
| 7 | 6 | 5 | 4 | 9 | 5 | 3 | 2 |
1007010081
---------------------------------------
@@ -10095,10 +10106,10 @@ void EmitPass::emitPreOrPostFixOp(e_opcode op, uint64_t identityValue, VISA_Type
1009510106
}
1009610107

1009710108
/*
10098-
____ ____
10099-
_______|____|________________|____|______ ___________________________________________
10109+
____ ____
10110+
_______|____|________________|____|______ ___________________________________________
1010010111
| 13 | 6 | 9 | 4 | 14 | 5 | 5 | 2 | ==> | 13 | 15 | 9 | 4 | 14 | 10 | 5 | 2 |
10101-
----------------------------------------- -------------------------------------------
10112+
----------------------------------------- -------------------------------------------
1010210113
*/
1010310114
// Now we have a weird copy happening. This will be done by SIMD 2 instructions.
1010410115

@@ -10127,7 +10138,7 @@ void EmitPass::emitPreOrPostFixOp(e_opcode op, uint64_t identityValue, VISA_Type
1012710138
}
1012810139

1012910140
/*
10130-
___________ ___________
10141+
___________ ___________
1013110142
__|___________|_________|___________|______ ___________________________________________
1013210143
| 13 | 15 | 9 | 4 | 14 | 10 | 5 | 2 | ==> | 22 | 15 | 9 | 4 | 19 | 10 | 5 | 2 |
1013310144
------------------------------------------- -------------------------------------------
@@ -10164,21 +10175,21 @@ void EmitPass::emitPreOrPostFixOp(e_opcode op, uint64_t identityValue, VISA_Type
1016410175
}
1016510176

1016610177
/*
10167-
____
10178+
____
1016810179
__________________|____|_________________ ____________________________________________
1016910180
| 22 | 15 | 9 | 4 | 19 | 10 | 5 | 2 | ==> | 22 | 15 | 9 | 23 | 19 | 10 | 5 | 2 |
1017010181
----------------------------------------- --------------------------------------------
10171-
_________
10182+
_________
1017210183
_____________|_________|_________________ _____________________________________________
1017310184
| 22 | 15 | 9 | 4 | 19 | 10 | 5 | 2 | ==> | 22 | 15 | 28 | 23 | 19 | 10 | 5 | 2 |
1017410185
----------------------------------------- ---------------------------------------------
1017510186

10176-
______________
10187+
______________
1017710188
________|______________|_________________ _____________________________________________
1017810189
| 22 | 15 | 9 | 4 | 19 | 10 | 5 | 2 | ==> | 22 | 34 | 28 | 23 | 19 | 10 | 5 | 2 |
1017910190
----------------------------------------- ---------------------------------------------
1018010191

10181-
____________________
10192+
____________________
1018210193
__|____________________|_________________ _____________________________________________
1018310194
| 22 | 15 | 9 | 4 | 19 | 10 | 5 | 2 | ==> | 41 | 34 | 28 | 23 | 19 | 10 | 5 | 2 |
1018410195
----------------------------------------- ---------------------------------------------
@@ -10239,6 +10250,7 @@ void EmitPass::emitPreOrPostFixOpScalar(
1023910250
bool negateSrc,
1024010251
CVariable* src,
1024110252
CVariable* result[2],
10253+
CVariable* Flag,
1024210254
bool isPrefix)
1024310255
{
1024410256
// This is to handle cases when not all lanes are enabled. In that case we fill the lanes with 0.
@@ -10259,19 +10271,21 @@ void EmitPass::emitPreOrPostFixOpScalar(
1025910271
IGC::EALIGN_GRF,
1026010272
false);
1026110273

10262-
// Set the GRF to 0 with no mask. This will set all the registers to 0
10274+
// Set the GRF to <identity> with no mask. This will set all the registers to <identity>
1026310275
CVariable* pIdentityValue = m_currShader->ImmToVariable(identityValue, type);
1026410276
m_encoder->SetSecondHalf(i == 1);
1026510277
m_encoder->SetNoMask();
1026610278
m_encoder->Copy(pSrcCopy[i], pIdentityValue);
1026710279
m_encoder->Push();
1026810280

10269-
// Now copy the src with a mask so the disabled lanes still keep their 0
10281+
// Now copy the src with a mask so the disabled lanes still keep their <identity>
1027010282
if (negateSrc)
1027110283
{
1027210284
m_encoder->SetSrcModifier(0, EMOD_NEG);
1027310285
}
1027410286
m_encoder->SetSecondHalf(i == 1);
10287+
if (Flag)
10288+
m_encoder->SetPredicate(Flag);
1027510289
m_encoder->Copy(pSrcCopy[i], src);
1027610290
m_encoder->Push();
1027710291

@@ -14326,6 +14340,33 @@ void EmitPass::emitWaveBallot(llvm::GenIntrinsicInst* inst)
1432614340
}
1432714341
}
1432814342

14343+
void EmitPass::emitWaveInverseBallot(llvm::GenIntrinsicInst* inst)
14344+
{
14345+
CVariable *Mask = GetSymbol(inst->getOperand(0));
14346+
14347+
if (Mask->IsUniform())
14348+
{
14349+
if (m_encoder->IsSecondHalf())
14350+
return;
14351+
14352+
m_encoder->SetP(m_destination, Mask);
14353+
return;
14354+
}
14355+
14356+
// The uniform case should by far be the most common. Otherwise,
14357+
// fall back and compute:
14358+
//
14359+
// (val & (1 << id)) != 0
14360+
CVariable *Temp = m_currShader->GetNewVariable(
14361+
numLanes(m_currShader->m_SIMDSize), ISA_TYPE_UD, EALIGN_GRF);
14362+
14363+
m_currShader->GetSimdOffsetBase(Temp);
14364+
m_encoder->Shl(Temp, m_currShader->ImmToVariable(1, ISA_TYPE_UD), Temp);
14365+
m_encoder->And(Temp, Mask, Temp);
14366+
m_encoder->Cmp(EPREDICATE_NE,
14367+
m_destination, Temp, m_currShader->ImmToVariable(0, ISA_TYPE_UD));
14368+
}
14369+
1432914370
static void GetReductionOp(WaveOps op, Type* opndTy, uint64_t& identity, e_opcode& opcode, VISA_Type& type)
1433014371
{
1433114372
auto getISAType = [](Type* ty, bool isSigned = true)
@@ -14468,17 +14509,49 @@ static void GetReductionOp(WaveOps op, Type* opndTy, uint64_t& identity, e_opcod
1446814509
}
1446914510
}
1447014511

14471-
void EmitPass::emitWavePrefix(llvm::GenIntrinsicInst* inst, bool isQuad)
14512+
void EmitPass::emitWavePrefix(WavePrefixIntrinsic* I)
14513+
{
14514+
Value *Mask = I->getMask();
14515+
if (auto *CI = dyn_cast<ConstantInt>(Mask))
14516+
{
14517+
// If the mask is all set, then we just pass a null
14518+
// mask to emitScan() indicating we don't want to
14519+
// emit any predication.
14520+
if (CI->isAllOnesValue())
14521+
Mask = nullptr;
14522+
}
14523+
emitScan(
14524+
I->getSrc(), I->getOpKind(), I->isInclusiveScan(), Mask, false);
14525+
}
14526+
14527+
void EmitPass::emitQuadPrefix(QuadPrefixIntrinsic* I)
14528+
{
14529+
emitScan(
14530+
I->getSrc(), I->getOpKind(), I->isInclusiveScan(), nullptr, true);
14531+
}
14532+
14533+
void EmitPass::emitScan(
14534+
Value *Src, IGC::WaveOps Op,
14535+
bool isInclusiveScan, Value *Mask, bool isQuad)
1447214536
{
14473-
WaveOps op = static_cast<WaveOps>(cast<llvm::ConstantInt>(inst->getOperand(1))->getZExtValue());
14474-
bool isInclusiveScan = cast<llvm::ConstantInt>(inst->getOperand(2))->getZExtValue() != 0;
1447514537
VISA_Type type;
1447614538
e_opcode opCode;
1447714539
uint64_t identity = 0;
14478-
GetReductionOp(op, inst->getOperand(0)->getType(), identity, opCode, type);
14479-
CVariable* src = GetSymbol(inst->getOperand(0));
14540+
GetReductionOp(Op, Src->getType(), identity, opCode, type);
14541+
CVariable* src = GetSymbol(Src);
1448014542
CVariable *dst[2] = { nullptr, nullptr };
14481-
emitPreOrPostFixOp(opCode, identity, type, false, src, dst, !isInclusiveScan, isQuad);
14543+
CVariable *Flag = Mask ? GetSymbol(Mask) : nullptr;
14544+
14545+
emitPreOrPostFixOp(
14546+
opCode, identity, type,
14547+
false, src, dst, Flag,
14548+
!isInclusiveScan, isQuad);
14549+
14550+
// Now that we've computed the result in temporary registers,
14551+
// make sure we only write the results to lanes participating in the
14552+
// scan as specified by 'mask'.
14553+
if (Flag)
14554+
m_encoder->SetPredicate(Flag);
1448214555
m_encoder->Copy(m_destination, dst[0]);
1448314556
if (m_currShader->m_dispatchSize == SIMDMode::SIMD32)
1448414557
{

IGC/Compiler/CISACodeGen/EmitVISAPass.hpp

+9-2
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ class EmitPass : public llvm::FunctionPass
259259
bool negateSrc,
260260
CVariable* src,
261261
CVariable* result[2],
262+
CVariable* Flag = nullptr,
262263
bool isPrefix = false,
263264
bool isQuad = false);
264265

@@ -269,7 +270,8 @@ class EmitPass : public llvm::FunctionPass
269270
bool negateSrc,
270271
CVariable* src,
271272
CVariable* result[2],
272-
bool isPrefix = false);
273+
CVariable* Flag,
274+
bool isPrefix);
273275

274276
bool IsUniformAtomic(llvm::Instruction* pInst);
275277
void emitAtomicRaw(llvm::GenIntrinsicInst* pInst);
@@ -360,8 +362,10 @@ class EmitPass : public llvm::FunctionPass
360362

361363
// CrossLane Instructions
362364
void emitWaveBallot(llvm::GenIntrinsicInst* inst);
365+
void emitWaveInverseBallot(llvm::GenIntrinsicInst* inst);
363366
void emitWaveShuffleIndex(llvm::GenIntrinsicInst* inst);
364-
void emitWavePrefix(llvm::GenIntrinsicInst* inst, bool isQuad = false);
367+
void emitWavePrefix(llvm::WavePrefixIntrinsic* I);
368+
void emitQuadPrefix(llvm::QuadPrefixIntrinsic* I);
365369
void emitWaveAll(llvm::GenIntrinsicInst* inst);
366370

367371
// Those three "vector" version shall be combined with
@@ -501,6 +505,9 @@ class EmitPass : public llvm::FunctionPass
501505
void emitSetMessagePhaseType(llvm::GenIntrinsicInst* inst, VISA_Type type);
502506
void emitSetMessagePhaseType_legacy(llvm::GenIntrinsicInst* inst, VISA_Type type);
503507

508+
void emitScan(llvm::Value *Src, IGC::WaveOps Op,
509+
bool isInclusiveScan, llvm::Value *Mask, bool isQuad);
510+
504511
// Cached per lane offset variables. This is a per basic block data
505512
// structure. For each entry, the first item is the scalar type size in
506513
// bytes, the second item is the corresponding symbol.

IGC/Compiler/Optimizer/OpenCLPasses/SubGroupFuncs/SubGroupFuncsResolution.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ void SubGroupFuncsResolution::subGroupScan(WaveOps op, CallInst &CI)
463463
IRBuilder<> IRB(&CI);
464464
Value* arg = CI.getArgOperand(0);
465465
Value* opVal = IRB.getInt8((uint8_t)op);
466-
Value* args[3] = { arg, opVal, IRB.getInt1(false) };
466+
Value* args[] = { arg, opVal, IRB.getInt1(false), IRB.getInt1(true) };
467467
Function* waveScan = GenISAIntrinsic::getDeclaration(CI.getCalledFunction()->getParent(),
468468
GenISAIntrinsic::GenISA_WavePrefix,
469469
arg->getType());

IGC/GenISAIntrinsics/GenIntrinsicInst.h

+46-1
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ class GenIntrinsicInst : public CallInst {
9696
return isa<CallInst>(V) && classof(cast<CallInst>(V));
9797
}
9898

99-
uint64_t getImm64Operand(unsigned idx) {
99+
uint64_t getImm64Operand(unsigned idx) const {
100100
assert(isa<ConstantInt>(getOperand(idx)));
101101
return valueToImm64(getOperand(idx));
102102
}
@@ -531,6 +531,51 @@ class SGVIntrinsic : public GenIntrinsicInst {
531531
}
532532
};
533533

534+
class WavePrefixIntrinsic : public GenIntrinsicInst
535+
{
536+
public:
537+
Value *getSrc() const { return getOperand(0); }
538+
IGC::WaveOps getOpKind() const
539+
{
540+
return static_cast<IGC::WaveOps>(getImm64Operand(1));
541+
}
542+
bool isInclusiveScan() const
543+
{
544+
return getImm64Operand(2) != 0;
545+
}
546+
Value *getMask() const { return getOperand(3); }
547+
548+
// Methods for support type inquiry through isa, cast, and dyn_cast:
549+
static inline bool classof(const GenIntrinsicInst *I) {
550+
return I->getIntrinsicID() == GenISAIntrinsic::GenISA_WavePrefix;
551+
}
552+
static inline bool classof(const Value *V) {
553+
return isa<GenIntrinsicInst>(V) && classof(cast<GenIntrinsicInst>(V));
554+
}
555+
};
556+
557+
class QuadPrefixIntrinsic : public GenIntrinsicInst
558+
{
559+
public:
560+
Value *getSrc() const { return getOperand(0); }
561+
IGC::WaveOps getOpKind() const
562+
{
563+
return static_cast<IGC::WaveOps>(getImm64Operand(1));
564+
}
565+
bool isInclusiveScan() const
566+
{
567+
return getImm64Operand(2) != 0;
568+
}
569+
570+
// Methods for support type inquiry through isa, cast, and dyn_cast:
571+
static inline bool classof(const GenIntrinsicInst *I) {
572+
return I->getIntrinsicID() == GenISAIntrinsic::GenISA_QuadPrefix;
573+
}
574+
static inline bool classof(const Value *V) {
575+
return isa<GenIntrinsicInst>(V) && classof(cast<GenIntrinsicInst>(V));
576+
}
577+
};
578+
534579
template <class X, class Y>
535580
inline bool isa(const Y &Val, GenISAIntrinsic::ID id)
536581
{

IGC/GenISAIntrinsics/Intrinsic_definitions.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -265,9 +265,18 @@
265265
"GenISA_pair_to_ptr": ["anyptr",["int","int"],"NoMem"],
266266
"GenISA_ptr_to_pair": [["int","int"],["anyptr"],"NoMem"],
267267
"GenISA_WaveBallot": ["int",["bool"],"Convergent,InaccessibleMemOnly"],
268+
# Arg 0 - Mask value
269+
# Return - assigns each lane the value of its corresponding bit.
270+
"GenISA_WaveInverseBallot": ["bool",["int"],"Convergent,InaccessibleMemOnly"],
268271
"GenISA_WaveShuffleIndex": ["anyint",[0,"int"],"Convergent,NoMem"],
269272
"GenISA_WaveAll": ["anyint",[0,"char"],"Convergent,InaccessibleMemOnly"],
270-
"GenISA_WavePrefix": ["anyint",[0,"char","bool"],"Convergent,InaccessibleMemOnly"],
273+
# Arg 0 - Src value
274+
# Arg 1 - Operation type
275+
# Arg 2 - Is the operation inclusive (1) or exclusive (0)?
276+
# Arg 3 - a mask that specifies a subset of lanes to participate
277+
# in the computation.
278+
# Return - The computed prefix/postfix result
279+
"GenISA_WavePrefix": ["anyint",[0,"char","bool","bool"],"Convergent,InaccessibleMemOnly"],
271280
"GenISA_QuadPrefix": ["anyint",[0,"char","bool"],"Convergent,InaccessibleMemOnly"],
272281
"GenISA_InitDiscardMask": ["bool",[],"None"],
273282
"GenISA_UpdateDiscardMask": ["bool",["bool","bool"],"None"],

IGC/GenISAIntrinsics/Intrinsics_ReadMe.csv

+1
Original file line numberDiff line numberDiff line change
@@ -1238,6 +1238,7 @@ GenISA_WavePrefix,,"// Description: Accumulate and keep the intermediate results
12381238
,anyint,
12391239
,0,
12401240
,char,specify the type: sum / Prod / Min/ Max
1241+
,bool,a mask that specifies a subset of lanes to participate in the computation.
12411242
GenISA_setMessagePhaseV,,
12421243
,anyvector,new message phase result
12431244
,0,cur message phase

0 commit comments

Comments
 (0)