Skip to content

[NVPTX] Fix v2i8 call lowering, use generic ld/st nodes for call params #146930

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion clang/test/CodeGenCUDA/bf16.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ __device__ __bf16 external_func( __bf16 in);
// CHECK: .param .align 2 .b8 _Z9test_callDF16b_param_0[2]
__device__ __bf16 test_call( __bf16 in) {
// CHECK: ld.param.b16 %[[R:rs[0-9]+]], [_Z9test_callDF16b_param_0];
// CHECK: st.param.b16 [param0], %[[R]];
// CHECK: .param .align 2 .b8 retval0[2];
// CHECK: st.param.b16 [param0], %[[R]];
// CHECK: call.uni (retval0), _Z13external_funcDF16b, (param0);
// CHECK: ld.param.b16 %[[RET:rs[0-9]+]], [retval0];
return external_func(in);
Expand Down
273 changes: 0 additions & 273 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,18 +145,6 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
if (tryStoreVector(N))
return;
break;
case NVPTXISD::LoadParam:
case NVPTXISD::LoadParamV2:
case NVPTXISD::LoadParamV4:
if (tryLoadParam(N))
return;
break;
case NVPTXISD::StoreParam:
case NVPTXISD::StoreParamV2:
case NVPTXISD::StoreParamV4:
if (tryStoreParam(N))
return;
break;
case ISD::INTRINSIC_W_CHAIN:
if (tryIntrinsicChain(N))
return;
Expand Down Expand Up @@ -1429,267 +1417,6 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
return true;
}

/// Select the machine instruction for an NVPTXISD::LoadParam, LoadParamV2 or
/// LoadParamV4 node.
///
/// On success \p Node is replaced by a LoadParamMem* machine node that
/// produces one, two or four values followed by a chain and a glue result.
///
/// \returns true if \p Node was replaced; false if its opcode is not one of
/// the LoadParam variants or no machine opcode exists for its memory type.
bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) {
  // Operands read here: 0 = chain, 2 = offset, 3 = glue.
  SDValue Chain = Node->getOperand(0);
  SDValue Offset = Node->getOperand(2);
  SDValue Glue = Node->getOperand(3);
  SDLoc DL(Node);
  MemSDNode *Mem = cast<MemSDNode>(Node);

  // Number of value results produced by the load.
  unsigned NumResults;
  switch (Node->getOpcode()) {
  case NVPTXISD::LoadParam:
    NumResults = 1;
    break;
  case NVPTXISD::LoadParamV2:
    NumResults = 2;
    break;
  case NVPTXISD::LoadParamV4:
    NumResults = 4;
    break;
  default:
    return false;
  }

  const EVT ResultVT = Node->getValueType(0);
  const MVT::SimpleValueType MemTy =
      Mem->getMemoryVT().getSimpleVT().SimpleTy;

  // The machine opcode depends on both the result count and the memory type.
  std::optional<unsigned> Opcode;
  if (NumResults == 1)
    Opcode = pickOpcodeForVT(MemTy, NVPTX::LoadParamMemI8,
                             NVPTX::LoadParamMemI16, NVPTX::LoadParamMemI32,
                             NVPTX::LoadParamMemI64);
  else if (NumResults == 2)
    Opcode = pickOpcodeForVT(MemTy, NVPTX::LoadParamMemV2I8,
                             NVPTX::LoadParamMemV2I16,
                             NVPTX::LoadParamMemV2I32,
                             NVPTX::LoadParamMemV2I64);
  else
    Opcode = pickOpcodeForVT(MemTy, NVPTX::LoadParamMemV4I8,
                             NVPTX::LoadParamMemV4I16,
                             NVPTX::LoadParamMemV4I32, {/* no v4i64 */});
  if (!Opcode)
    return false;

  // Result list: NumResults copies of the element type, then chain and glue.
  SDVTList VTs;
  switch (NumResults) {
  case 1:
    VTs = CurDAG->getVTList(ResultVT, MVT::Other, MVT::Glue);
    break;
  case 2:
    VTs = CurDAG->getVTList(ResultVT, ResultVT, MVT::Other, MVT::Glue);
    break;
  default: {
    EVT ResultVTs[] = {ResultVT, ResultVT, ResultVT,
                       ResultVT, MVT::Other, MVT::Glue};
    VTs = CurDAG->getVTList(ResultVTs);
    break;
  }
  }

  // The offset is passed to the machine node as an i32 target constant.
  const unsigned OffsetVal = Offset->getAsZExtVal();
  SDValue Ops[] = {CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), Chain,
                   Glue};

  ReplaceNode(Node, CurDAG->getMachineNode(*Opcode, DL, VTs, Ops));
  return true;
}

// Helpers for constructing a vector StoreParam opcode name from an element
// type suffix plus one letter per element: 'i' for an immediate operand and
// 'r' for a register operand (ex: NVPTX::StoreParamV4F32_iiri).
//
// `ty` is the type suffix token (I8, I16, I32, I64, F32, F64) and `isimm` is
// anything indexable with [] that yields a bool per element.
//
// Every expansion is fully parenthesized so that each macro behaves as a
// single expression even when it is used as part of a larger one.

// Paste together the V2 opcode name for a fixed pair of operand kinds.
#define getOpcV2H(ty, opKind0, opKind1)                                        \
  (NVPTX::StoreParamV2##ty##_##opKind0##opKind1)

// Choose the second operand kind of a V2 opcode at run time.
#define getOpcV2H1(ty, opKind0, isImm1)                                        \
  ((isImm1) ? getOpcV2H(ty, opKind0, i) : getOpcV2H(ty, opKind0, r))

// Choose both operand kinds of a V2 opcode at run time.
#define getOpcodeForVectorStParamV2(ty, isimm)                                 \
  (((isimm)[0]) ? getOpcV2H1(ty, i, (isimm)[1])                                \
                : getOpcV2H1(ty, r, (isimm)[1]))

// Paste together the V4 opcode name for four fixed operand kinds.
#define getOpcV4H(ty, opKind0, opKind1, opKind2, opKind3)                      \
  (NVPTX::StoreParamV4##ty##_##opKind0##opKind1##opKind2##opKind3)

// Choose the fourth operand kind of a V4 opcode at run time.
#define getOpcV4H3(ty, opKind0, opKind1, opKind2, isImm3)                      \
  ((isImm3) ? getOpcV4H(ty, opKind0, opKind1, opKind2, i)                      \
            : getOpcV4H(ty, opKind0, opKind1, opKind2, r))

// Choose the third and fourth operand kinds of a V4 opcode at run time.
#define getOpcV4H2(ty, opKind0, opKind1, isImm2, isImm3)                       \
  ((isImm2) ? getOpcV4H3(ty, opKind0, opKind1, i, isImm3)                      \
            : getOpcV4H3(ty, opKind0, opKind1, r, isImm3))

// Choose the second through fourth operand kinds of a V4 opcode at run time.
#define getOpcV4H1(ty, opKind0, isImm1, isImm2, isImm3)                        \
  ((isImm1) ? getOpcV4H2(ty, opKind0, i, isImm2, isImm3)                       \
            : getOpcV4H2(ty, opKind0, r, isImm2, isImm3))

// Choose all four operand kinds of a V4 opcode at run time.
#define getOpcodeForVectorStParamV4(ty, isimm)                                 \
  (((isimm)[0]) ? getOpcV4H1(ty, i, (isimm)[1], (isimm)[2], (isimm)[3])        \
                : getOpcV4H1(ty, r, (isimm)[1], (isimm)[2], (isimm)[3]))

// Dispatch on the element count: 2 selects the V2 family, anything else the
// V4 family.
#define getOpcodeForVectorStParam(n, ty, isimm)                                \
  (((n) == 2) ? getOpcodeForVectorStParamV2(ty, isimm)                         \
              : getOpcodeForVectorStParamV4(ty, isimm))

/// Pick the StoreParamV2* / StoreParamV4* machine opcode for a vector
/// st.param and rewrite constant value operands into target constants.
///
/// \param Ops     Operand list whose first \p NumElts entries are the values
///                being stored; constant entries are replaced in place.
/// \param NumElts Number of stored elements; 2 selects a V2 opcode, any other
///                value a V4 opcode.
/// \param MemTy   Simple memory type of the store.
/// \param CurDAG  DAG used to create the target-constant nodes.
/// \param DL      Debug location attached to the created nodes.
/// \returns the selected NVPTX machine opcode.
static unsigned pickOpcodeForVectorStParam(SmallVector<SDValue, 8> &Ops,
                                           unsigned NumElts,
                                           MVT::SimpleValueType MemTy,
                                           SelectionDAG *CurDAG, SDLoc DL) {
  // Only f32/f64 stores treat their constants as floating-point immediates;
  // every other memory type goes through the integer-constant path.
  const bool IsFPMem = MemTy == MVT::f32 || MemTy == MVT::f64;

  // Record, per element, whether the operand is an immediate, and swap each
  // immediate for the equivalent target constant.
  SmallVector<bool, 4> ImmMask(NumElts, false);
  for (unsigned Idx = 0; Idx != NumElts; ++Idx) {
    SDValue &Op = Ops[Idx];
    if (!isa<ConstantSDNode>(Op) && !isa<ConstantFPSDNode>(Op))
      continue;

    ImmMask[Idx] = true;
    if (IsFPMem) {
      const ConstantFP *FPVal =
          cast<ConstantFPSDNode>(Op)->getConstantFPValue();
      Op = CurDAG->getTargetConstantFP(*FPVal, DL, Op->getValueType(0));
    } else {
      const ConstantInt *IntVal =
          cast<ConstantSDNode>(Op)->getConstantIntValue();
      Op = CurDAG->getTargetConstant(*IntVal, DL, Op->getValueType(0));
    }
  }

  // Map (memory type, element count, register/immediate pattern) to an opcode.
  switch (MemTy) {
  case MVT::i8:
    return getOpcodeForVectorStParam(NumElts, I8, ImmMask);
  case MVT::i16:
    return getOpcodeForVectorStParam(NumElts, I16, ImmMask);
  case MVT::i32:
    return getOpcodeForVectorStParam(NumElts, I32, ImmMask);
  case MVT::f32:
    return getOpcodeForVectorStParam(NumElts, F32, ImmMask);

  // 64-bit elements only exist in the two-element form.
  case MVT::i64:
    assert(NumElts == 2 && "MVT too large for NumElts > 2");
    return getOpcodeForVectorStParamV2(I64, ImmMask);
  case MVT::f64:
    assert(NumElts == 2 && "MVT too large for NumElts > 2");
    return getOpcodeForVectorStParamV2(F64, ImmMask);

  // The remaining types have no immediate forms: always use the all-register
  // opcode of the matching width and let moves be generated.
  // NOTE(review): the loop above still rewrites constant operands of these
  // types into target constants (via the integer path) before the
  // all-register opcode is returned -- confirm callers never pass constant
  // operands with these memory types.
  case MVT::i1:
    if (NumElts == 2)
      return NVPTX::StoreParamV2I8_rr;
    return NVPTX::StoreParamV4I8_rrrr;
  case MVT::f16:
  case MVT::bf16:
    if (NumElts == 2)
      return NVPTX::StoreParamV2I16_rr;
    return NVPTX::StoreParamV4I16_rrrr;
  case MVT::v2f16:
  case MVT::v2bf16:
  case MVT::v2i16:
  case MVT::v4i8:
    if (NumElts == 2)
      return NVPTX::StoreParamV2I32_rr;
    return NVPTX::StoreParamV4I32_rrrr;

  default:
    llvm_unreachable("Cannot select st.param for unknown MemTy");
  }
}

/// Select the machine instruction for an NVPTXISD::StoreParam, StoreParamV2
/// or StoreParamV4 node.
///
/// Operands read from \p N: 0 = chain, 1 = parameter index (constant),
/// 2 = offset (constant), 3.. = the values to store, last = glue. On success
/// \p N is replaced by a StoreParam* machine node that carries the original
/// memory operand and produces a chain and a glue result.
///
/// \returns true if \p N was replaced; false if no machine opcode exists for
/// the scalar memory type, so that the caller can fall back to its default
/// selection path instead of dereferencing an empty optional.
bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
  SDLoc DL(N);
  SDValue Chain = N->getOperand(0);
  SDValue Param = N->getOperand(1);
  unsigned ParamVal = Param->getAsZExtVal();
  SDValue Offset = N->getOperand(2);
  unsigned OffsetVal = Offset->getAsZExtVal();
  MemSDNode *Mem = cast<MemSDNode>(N);
  SDValue Glue = N->getOperand(N->getNumOperands() - 1);

  // How many elements do we have?
  unsigned NumElts;
  switch (N->getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode");
  case NVPTXISD::StoreParam:
    NumElts = 1;
    break;
  case NVPTXISD::StoreParamV2:
    NumElts = 2;
    break;
  case NVPTXISD::StoreParamV4:
    NumElts = 4;
    break;
  }

  // Build vector of operands: stored values first, then the parameter index
  // and offset as i32 target constants, then chain and glue.
  SmallVector<SDValue, 8> Ops;
  for (unsigned i = 0; i < NumElts; ++i)
    Ops.push_back(N->getOperand(i + 3));
  Ops.append({CurDAG->getTargetConstant(ParamVal, DL, MVT::i32),
              CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), Chain, Glue});

  const MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;

  // Determine target opcode
  // If we have an i1, use an 8-bit store. The lowering code in
  // NVPTXISelLowering will have already emitted an upcast.
  std::optional<unsigned> Opcode;
  switch (NumElts) {
  default:
    llvm_unreachable("Unexpected NumElts");
  case 1: {
    SDValue Imm = Ops[0];
    if (MemTy != MVT::f16 && MemTy != MVT::bf16 &&
        (isa<ConstantSDNode>(Imm) || isa<ConstantFPSDNode>(Imm))) {
      // Convert immediate to target constant
      if (MemTy == MVT::f32 || MemTy == MVT::f64) {
        const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
        const ConstantFP *CF = ConstImm->getConstantFPValue();
        Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
      } else {
        const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
        const ConstantInt *CI = ConstImm->getConstantIntValue();
        Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
      }
      Ops[0] = Imm;
      // Use immediate version of store param
      Opcode =
          pickOpcodeForVT(MemTy, NVPTX::StoreParamI8_i, NVPTX::StoreParamI16_i,
                          NVPTX::StoreParamI32_i, NVPTX::StoreParamI64_i);
    } else {
      Opcode =
          pickOpcodeForVT(MemTy, NVPTX::StoreParamI8_r, NVPTX::StoreParamI16_r,
                          NVPTX::StoreParamI32_r, NVPTX::StoreParamI64_r);
    }
    if (Opcode == NVPTX::StoreParamI8_r) {
      // Fine tune the opcode depending on the size of the operand.
      // This helps to avoid creating redundant COPY instructions in
      // InstrEmitter::AddRegisterOperand().
      switch (Ops[0].getSimpleValueType().SimpleTy) {
      default:
        break;
      case MVT::i32:
        Opcode = NVPTX::StoreParamI8TruncI32_r;
        break;
      case MVT::i64:
        Opcode = NVPTX::StoreParamI8TruncI64_r;
        break;
      }
    }
    break;
  }
  case 2:
  case 4:
    Opcode = pickOpcodeForVectorStParam(Ops, NumElts, MemTy, CurDAG, DL);
    break;
  }

  // pickOpcodeForVT yields an empty optional for memory types it has no
  // opcode for; bail out rather than dereference it (mirrors tryLoadParam).
  if (!Opcode)
    return false;

  SDVTList RetVTs = CurDAG->getVTList(MVT::Other, MVT::Glue);
  SDNode *Ret = CurDAG->getMachineNode(*Opcode, DL, RetVTs, Ops);
  // Carry the memory operand over to the selected machine node.
  MachineMemOperand *MemRef = Mem->getMemOperand();
  CurDAG->setNodeMemRefs(cast<MachineSDNode>(Ret), {MemRef});

  ReplaceNode(N, Ret);
  return true;
}

/// SelectBFE - Look for instruction sequences that can be made more efficient
/// by using the 'bfe' (bit-field extract) PTX instruction
bool NVPTXDAGToDAGISel::tryBFE(SDNode *N) {
Expand Down
2 changes: 0 additions & 2 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,6 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
bool tryLDG(MemSDNode *N);
bool tryStore(SDNode *N);
bool tryStoreVector(SDNode *N);
bool tryLoadParam(SDNode *N);
bool tryStoreParam(SDNode *N);
bool tryFence(SDNode *N);
void SelectAddrSpaceCast(SDNode *N);
bool tryBFE(SDNode *N);
Expand Down
Loading