Skip to content

[NVPTX] Fix v2i8 call lowering, use generic ld/st nodes for call params #146930

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

AlexMaclean
Copy link
Member

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Jul 3, 2025

@llvm/pr-subscribers-debuginfo
@llvm/pr-subscribers-clang

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

Changes

Patch is 241.60 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/146930.diff

37 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (-273)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h (-2)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+227-364)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+2-8)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+3-130)
  • (modified) llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll (+3-3)
  • (modified) llvm/test/CodeGen/NVPTX/byval-const-global.ll (+4-4)
  • (modified) llvm/test/CodeGen/NVPTX/call-with-alloca-buffer.ll (+5-5)
  • (modified) llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll (+8-8)
  • (modified) llvm/test/CodeGen/NVPTX/combine-mad.ll (+2-2)
  • (modified) llvm/test/CodeGen/NVPTX/compare-int.ll (+501-120)
  • (modified) llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll (+162-10)
  • (modified) llvm/test/CodeGen/NVPTX/dynamic_stackalloc.ll (+2-2)
  • (modified) llvm/test/CodeGen/NVPTX/f16x2-instructions.ll (+6-6)
  • (modified) llvm/test/CodeGen/NVPTX/fma.ll (+4-4)
  • (modified) llvm/test/CodeGen/NVPTX/forward-ld-param.ll (+1-1)
  • (modified) llvm/test/CodeGen/NVPTX/i128-param.ll (+10-10)
  • (modified) llvm/test/CodeGen/NVPTX/i16x2-instructions.ll (+6-6)
  • (modified) llvm/test/CodeGen/NVPTX/i8x2-instructions.ll (+93-28)
  • (modified) llvm/test/CodeGen/NVPTX/i8x4-instructions.ll (+6-6)
  • (modified) llvm/test/CodeGen/NVPTX/idioms.ll (+1-1)
  • (modified) llvm/test/CodeGen/NVPTX/indirect_byval.ll (+10-10)
  • (modified) llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll (+14-19)
  • (modified) llvm/test/CodeGen/NVPTX/lower-args.ll (+4-6)
  • (modified) llvm/test/CodeGen/NVPTX/lower-byval-args.ll (+1-1)
  • (modified) llvm/test/CodeGen/NVPTX/misched_func_call.ll (+7-8)
  • (modified) llvm/test/CodeGen/NVPTX/param-add.ll (+4-4)
  • (modified) llvm/test/CodeGen/NVPTX/param-load-store.ll (+124-134)
  • (modified) llvm/test/CodeGen/NVPTX/param-overalign.ll (+2-2)
  • (modified) llvm/test/CodeGen/NVPTX/param-vectorize-device.ll (+14-14)
  • (modified) llvm/test/CodeGen/NVPTX/proxy-reg-erasure.mir (+2-2)
  • (modified) llvm/test/CodeGen/NVPTX/st-param-imm.ll (+146-75)
  • (modified) llvm/test/CodeGen/NVPTX/store-undef.ll (+2-2)
  • (modified) llvm/test/CodeGen/NVPTX/tex-read-cuda.ll (+1-1)
  • (modified) llvm/test/CodeGen/NVPTX/unaligned-param-load-store.ll (+242-352)
  • (modified) llvm/test/CodeGen/NVPTX/vaargs.ll (+8-8)
  • (modified) llvm/test/CodeGen/NVPTX/variadics-backend.ll (+15-17)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 5631342ecc13e..fdd2671a5289e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -145,18 +145,6 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
     if (tryStoreVector(N))
       return;
     break;
-  case NVPTXISD::LoadParam:
-  case NVPTXISD::LoadParamV2:
-  case NVPTXISD::LoadParamV4:
-    if (tryLoadParam(N))
-      return;
-    break;
-  case NVPTXISD::StoreParam:
-  case NVPTXISD::StoreParamV2:
-  case NVPTXISD::StoreParamV4:
-    if (tryStoreParam(N))
-      return;
-    break;
   case ISD::INTRINSIC_W_CHAIN:
     if (tryIntrinsicChain(N))
       return;
@@ -1429,267 +1417,6 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
   return true;
 }
 
-bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) {
-  SDValue Chain = Node->getOperand(0);
-  SDValue Offset = Node->getOperand(2);
-  SDValue Glue = Node->getOperand(3);
-  SDLoc DL(Node);
-  MemSDNode *Mem = cast<MemSDNode>(Node);
-
-  unsigned VecSize;
-  switch (Node->getOpcode()) {
-  default:
-    return false;
-  case NVPTXISD::LoadParam:
-    VecSize = 1;
-    break;
-  case NVPTXISD::LoadParamV2:
-    VecSize = 2;
-    break;
-  case NVPTXISD::LoadParamV4:
-    VecSize = 4;
-    break;
-  }
-
-  EVT EltVT = Node->getValueType(0);
-  EVT MemVT = Mem->getMemoryVT();
-
-  std::optional<unsigned> Opcode;
-
-  switch (VecSize) {
-  default:
-    return false;
-  case 1:
-    Opcode = pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy,
-                             NVPTX::LoadParamMemI8, NVPTX::LoadParamMemI16,
-                             NVPTX::LoadParamMemI32, NVPTX::LoadParamMemI64);
-    break;
-  case 2:
-    Opcode =
-        pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy, NVPTX::LoadParamMemV2I8,
-                        NVPTX::LoadParamMemV2I16, NVPTX::LoadParamMemV2I32,
-                        NVPTX::LoadParamMemV2I64);
-    break;
-  case 4:
-    Opcode = pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy,
-                             NVPTX::LoadParamMemV4I8, NVPTX::LoadParamMemV4I16,
-                             NVPTX::LoadParamMemV4I32, {/* no v4i64 */});
-    break;
-  }
-  if (!Opcode)
-    return false;
-
-  SDVTList VTs;
-  if (VecSize == 1) {
-    VTs = CurDAG->getVTList(EltVT, MVT::Other, MVT::Glue);
-  } else if (VecSize == 2) {
-    VTs = CurDAG->getVTList(EltVT, EltVT, MVT::Other, MVT::Glue);
-  } else {
-    EVT EVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other, MVT::Glue };
-    VTs = CurDAG->getVTList(EVTs);
-  }
-
-  unsigned OffsetVal = Offset->getAsZExtVal();
-
-  SmallVector<SDValue, 2> Ops(
-      {CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), Chain, Glue});
-
-  ReplaceNode(Node, CurDAG->getMachineNode(*Opcode, DL, VTs, Ops));
-  return true;
-}
-
-// Helpers for constructing opcode (ex: NVPTX::StoreParamV4F32_iiri)
-#define getOpcV2H(ty, opKind0, opKind1)                                        \
-  NVPTX::StoreParamV2##ty##_##opKind0##opKind1
-
-#define getOpcV2H1(ty, opKind0, isImm1)                                        \
-  (isImm1) ? getOpcV2H(ty, opKind0, i) : getOpcV2H(ty, opKind0, r)
-
-#define getOpcodeForVectorStParamV2(ty, isimm)                                 \
-  (isimm[0]) ? getOpcV2H1(ty, i, isimm[1]) : getOpcV2H1(ty, r, isimm[1])
-
-#define getOpcV4H(ty, opKind0, opKind1, opKind2, opKind3)                      \
-  NVPTX::StoreParamV4##ty##_##opKind0##opKind1##opKind2##opKind3
-
-#define getOpcV4H3(ty, opKind0, opKind1, opKind2, isImm3)                      \
-  (isImm3) ? getOpcV4H(ty, opKind0, opKind1, opKind2, i)                       \
-           : getOpcV4H(ty, opKind0, opKind1, opKind2, r)
-
-#define getOpcV4H2(ty, opKind0, opKind1, isImm2, isImm3)                       \
-  (isImm2) ? getOpcV4H3(ty, opKind0, opKind1, i, isImm3)                       \
-           : getOpcV4H3(ty, opKind0, opKind1, r, isImm3)
-
-#define getOpcV4H1(ty, opKind0, isImm1, isImm2, isImm3)                        \
-  (isImm1) ? getOpcV4H2(ty, opKind0, i, isImm2, isImm3)                        \
-           : getOpcV4H2(ty, opKind0, r, isImm2, isImm3)
-
-#define getOpcodeForVectorStParamV4(ty, isimm)                                 \
-  (isimm[0]) ? getOpcV4H1(ty, i, isimm[1], isimm[2], isimm[3])                 \
-             : getOpcV4H1(ty, r, isimm[1], isimm[2], isimm[3])
-
-#define getOpcodeForVectorStParam(n, ty, isimm)                                \
-  (n == 2) ? getOpcodeForVectorStParamV2(ty, isimm)                            \
-           : getOpcodeForVectorStParamV4(ty, isimm)
-
-static unsigned pickOpcodeForVectorStParam(SmallVector<SDValue, 8> &Ops,
-                                           unsigned NumElts,
-                                           MVT::SimpleValueType MemTy,
-                                           SelectionDAG *CurDAG, SDLoc DL) {
-  // Determine which inputs are registers and immediates make new operators
-  // with constant values
-  SmallVector<bool, 4> IsImm(NumElts, false);
-  for (unsigned i = 0; i < NumElts; i++) {
-    IsImm[i] = (isa<ConstantSDNode>(Ops[i]) || isa<ConstantFPSDNode>(Ops[i]));
-    if (IsImm[i]) {
-      SDValue Imm = Ops[i];
-      if (MemTy == MVT::f32 || MemTy == MVT::f64) {
-        const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
-        const ConstantFP *CF = ConstImm->getConstantFPValue();
-        Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
-      } else {
-        const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
-        const ConstantInt *CI = ConstImm->getConstantIntValue();
-        Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
-      }
-      Ops[i] = Imm;
-    }
-  }
-
-  // Get opcode for MemTy, size, and register/immediate operand ordering
-  switch (MemTy) {
-  case MVT::i8:
-    return getOpcodeForVectorStParam(NumElts, I8, IsImm);
-  case MVT::i16:
-    return getOpcodeForVectorStParam(NumElts, I16, IsImm);
-  case MVT::i32:
-    return getOpcodeForVectorStParam(NumElts, I32, IsImm);
-  case MVT::i64:
-    assert(NumElts == 2 && "MVT too large for NumElts > 2");
-    return getOpcodeForVectorStParamV2(I64, IsImm);
-  case MVT::f32:
-    return getOpcodeForVectorStParam(NumElts, F32, IsImm);
-  case MVT::f64:
-    assert(NumElts == 2 && "MVT too large for NumElts > 2");
-    return getOpcodeForVectorStParamV2(F64, IsImm);
-
-  // These cases don't support immediates, just use the all register version
-  // and generate moves.
-  case MVT::i1:
-    return (NumElts == 2) ? NVPTX::StoreParamV2I8_rr
-                          : NVPTX::StoreParamV4I8_rrrr;
-  case MVT::f16:
-  case MVT::bf16:
-    return (NumElts == 2) ? NVPTX::StoreParamV2I16_rr
-                          : NVPTX::StoreParamV4I16_rrrr;
-  case MVT::v2f16:
-  case MVT::v2bf16:
-  case MVT::v2i16:
-  case MVT::v4i8:
-    return (NumElts == 2) ? NVPTX::StoreParamV2I32_rr
-                          : NVPTX::StoreParamV4I32_rrrr;
-  default:
-    llvm_unreachable("Cannot select st.param for unknown MemTy");
-  }
-}
-
-bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
-  SDLoc DL(N);
-  SDValue Chain = N->getOperand(0);
-  SDValue Param = N->getOperand(1);
-  unsigned ParamVal = Param->getAsZExtVal();
-  SDValue Offset = N->getOperand(2);
-  unsigned OffsetVal = Offset->getAsZExtVal();
-  MemSDNode *Mem = cast<MemSDNode>(N);
-  SDValue Glue = N->getOperand(N->getNumOperands() - 1);
-
-  // How many elements do we have?
-  unsigned NumElts;
-  switch (N->getOpcode()) {
-  default:
-    llvm_unreachable("Unexpected opcode");
-  case NVPTXISD::StoreParam:
-    NumElts = 1;
-    break;
-  case NVPTXISD::StoreParamV2:
-    NumElts = 2;
-    break;
-  case NVPTXISD::StoreParamV4:
-    NumElts = 4;
-    break;
-  }
-
-  // Build vector of operands
-  SmallVector<SDValue, 8> Ops;
-  for (unsigned i = 0; i < NumElts; ++i)
-    Ops.push_back(N->getOperand(i + 3));
-  Ops.append({CurDAG->getTargetConstant(ParamVal, DL, MVT::i32),
-              CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), Chain, Glue});
-
-  // Determine target opcode
-  // If we have an i1, use an 8-bit store. The lowering code in
-  // NVPTXISelLowering will have already emitted an upcast.
-  std::optional<unsigned> Opcode;
-  switch (NumElts) {
-  default:
-    llvm_unreachable("Unexpected NumElts");
-  case 1: {
-    MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
-    SDValue Imm = Ops[0];
-    if (MemTy != MVT::f16 && MemTy != MVT::bf16 &&
-        (isa<ConstantSDNode>(Imm) || isa<ConstantFPSDNode>(Imm))) {
-      // Convert immediate to target constant
-      if (MemTy == MVT::f32 || MemTy == MVT::f64) {
-        const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
-        const ConstantFP *CF = ConstImm->getConstantFPValue();
-        Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
-      } else {
-        const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
-        const ConstantInt *CI = ConstImm->getConstantIntValue();
-        Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
-      }
-      Ops[0] = Imm;
-      // Use immediate version of store param
-      Opcode =
-          pickOpcodeForVT(MemTy, NVPTX::StoreParamI8_i, NVPTX::StoreParamI16_i,
-                          NVPTX::StoreParamI32_i, NVPTX::StoreParamI64_i);
-    } else
-      Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
-                               NVPTX::StoreParamI8_r, NVPTX::StoreParamI16_r,
-                               NVPTX::StoreParamI32_r, NVPTX::StoreParamI64_r);
-    if (Opcode == NVPTX::StoreParamI8_r) {
-      // Fine tune the opcode depending on the size of the operand.
-      // This helps to avoid creating redundant COPY instructions in
-      // InstrEmitter::AddRegisterOperand().
-      switch (Ops[0].getSimpleValueType().SimpleTy) {
-      default:
-        break;
-      case MVT::i32:
-        Opcode = NVPTX::StoreParamI8TruncI32_r;
-        break;
-      case MVT::i64:
-        Opcode = NVPTX::StoreParamI8TruncI64_r;
-        break;
-      }
-    }
-    break;
-  }
-  case 2:
-  case 4: {
-    MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
-    Opcode = pickOpcodeForVectorStParam(Ops, NumElts, MemTy, CurDAG, DL);
-    break;
-  }
-  }
-
-  SDVTList RetVTs = CurDAG->getVTList(MVT::Other, MVT::Glue);
-  SDNode *Ret = CurDAG->getMachineNode(*Opcode, DL, RetVTs, Ops);
-  MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
-  CurDAG->setNodeMemRefs(cast<MachineSDNode>(Ret), {MemRef});
-
-  ReplaceNode(N, Ret);
-  return true;
-}
-
 /// SelectBFE - Look for instruction sequences that can be made more efficient
 /// by using the 'bfe' (bit-field extract) PTX instruction
 bool NVPTXDAGToDAGISel::tryBFE(SDNode *N) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 0e4dec1adca67..19b569d638b0c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -78,8 +78,6 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
   bool tryLDG(MemSDNode *N);
   bool tryStore(SDNode *N);
   bool tryStoreVector(SDNode *N);
-  bool tryLoadParam(SDNode *N);
-  bool tryStoreParam(SDNode *N);
   bool tryFence(SDNode *N);
   void SelectAddrSpaceCast(SDNode *N);
   bool tryBFE(SDNode *N);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index bb0aeb493ed48..3cda7df55ad58 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1049,12 +1049,6 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::DeclareArrayParam)
     MAKE_CASE(NVPTXISD::DeclareScalarParam)
     MAKE_CASE(NVPTXISD::CALL)
-    MAKE_CASE(NVPTXISD::LoadParam)
-    MAKE_CASE(NVPTXISD::LoadParamV2)
-    MAKE_CASE(NVPTXISD::LoadParamV4)
-    MAKE_CASE(NVPTXISD::StoreParam)
-    MAKE_CASE(NVPTXISD::StoreParamV2)
-    MAKE_CASE(NVPTXISD::StoreParamV4)
     MAKE_CASE(NVPTXISD::MoveParam)
     MAKE_CASE(NVPTXISD::UNPACK_VECTOR)
     MAKE_CASE(NVPTXISD::BUILD_VECTOR)
@@ -1293,105 +1287,6 @@ Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
   return DL.getABITypeAlign(Ty);
 }
 
-static bool adjustElementType(EVT &ElementType) {
-  switch (ElementType.getSimpleVT().SimpleTy) {
-  default:
-    return false;
-  case MVT::f16:
-  case MVT::bf16:
-    ElementType = MVT::i16;
-    return true;
-  case MVT::f32:
-  case MVT::v2f16:
-  case MVT::v2bf16:
-    ElementType = MVT::i32;
-    return true;
-  case MVT::f64:
-    ElementType = MVT::i64;
-    return true;
-  }
-}
-
-// Use byte-store when the param address of the argument value is unaligned.
-// This may happen when the return value is a field of a packed structure.
-//
-// This is called in LowerCall() when passing the param values.
-static SDValue LowerUnalignedStoreParam(SelectionDAG &DAG, SDValue Chain,
-                                        uint64_t Offset, EVT ElementType,
-                                        SDValue StVal, SDValue &InGlue,
-                                        unsigned ArgID, const SDLoc &dl) {
-  // Bit logic only works on integer types
-  if (adjustElementType(ElementType))
-    StVal = DAG.getNode(ISD::BITCAST, dl, ElementType, StVal);
-
-  // Store each byte
-  SDVTList StoreVTs = DAG.getVTList(MVT::Other, MVT::Glue);
-  for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
-    // Shift the byte to the last byte position
-    SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, StVal,
-                                   DAG.getConstant(i * 8, dl, MVT::i32));
-    SDValue StoreOperands[] = {Chain, DAG.getConstant(ArgID, dl, MVT::i32),
-                               DAG.getConstant(Offset + i, dl, MVT::i32),
-                               ShiftVal, InGlue};
-    // Trunc store only the last byte by using
-    //     st.param.b8
-    // The register type can be larger than b8.
-    Chain = DAG.getMemIntrinsicNode(
-        NVPTXISD::StoreParam, dl, StoreVTs, StoreOperands, MVT::i8,
-        MachinePointerInfo(), Align(1), MachineMemOperand::MOStore);
-    InGlue = Chain.getValue(1);
-  }
-  return Chain;
-}
-
-// Use byte-load when the param adress of the returned value is unaligned.
-// This may happen when the returned value is a field of a packed structure.
-static SDValue
-LowerUnalignedLoadRetParam(SelectionDAG &DAG, SDValue &Chain, uint64_t Offset,
-                           EVT ElementType, SDValue &InGlue,
-                           SmallVectorImpl<SDValue> &TempProxyRegOps,
-                           const SDLoc &dl) {
-  // Bit logic only works on integer types
-  EVT MergedType = ElementType;
-  adjustElementType(MergedType);
-
-  // Load each byte and construct the whole value. Initial value to 0
-  SDValue RetVal = DAG.getConstant(0, dl, MergedType);
-  // LoadParamMemI8 loads into i16 register only
-  SDVTList LoadVTs = DAG.getVTList(MVT::i16, MVT::Other, MVT::Glue);
-  for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
-    SDValue LoadOperands[] = {Chain, DAG.getConstant(1, dl, MVT::i32),
-                              DAG.getConstant(Offset + i, dl, MVT::i32),
-                              InGlue};
-    // This will be selected to LoadParamMemI8
-    SDValue LdVal =
-        DAG.getMemIntrinsicNode(NVPTXISD::LoadParam, dl, LoadVTs, LoadOperands,
-                                MVT::i8, MachinePointerInfo(), Align(1));
-    SDValue TmpLdVal = LdVal.getValue(0);
-    Chain = LdVal.getValue(1);
-    InGlue = LdVal.getValue(2);
-
-    TmpLdVal = DAG.getNode(NVPTXISD::ProxyReg, dl,
-                           TmpLdVal.getSimpleValueType(), TmpLdVal);
-    TempProxyRegOps.push_back(TmpLdVal);
-
-    SDValue CMask = DAG.getConstant(255, dl, MergedType);
-    SDValue CShift = DAG.getConstant(i * 8, dl, MVT::i32);
-    // Need to extend the i16 register to the whole width.
-    TmpLdVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MergedType, TmpLdVal);
-    // Mask off the high bits. Leave only the lower 8bits.
-    // Do this because we are using loadparam.b8.
-    TmpLdVal = DAG.getNode(ISD::AND, dl, MergedType, TmpLdVal, CMask);
-    // Shift and merge
-    TmpLdVal = DAG.getNode(ISD::SHL, dl, MergedType, TmpLdVal, CShift);
-    RetVal = DAG.getNode(ISD::OR, dl, MergedType, RetVal, TmpLdVal);
-  }
-  if (ElementType != MergedType)
-    RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal);
-
-  return RetVal;
-}
-
 static bool shouldConvertToIndirectCall(const CallBase *CB,
                                         const GlobalAddressSDNode *Func) {
   if (!Func)
@@ -1458,10 +1353,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
 
   SelectionDAG &DAG = CLI.DAG;
   SDLoc dl = CLI.DL;
-  SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
-  SDValue Chain = CLI.Chain;
+  const SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
   SDValue Callee = CLI.Callee;
-  bool &isTailCall = CLI.IsTailCall;
   ArgListTy &Args = CLI.getArgs();
   Type *RetTy = CLI.RetTy;
   const CallBase *CB = CLI.CB;
@@ -1492,9 +1385,34 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   unsigned VAOffset = 0;                  // current offset in the param array
 
   const unsigned UniqueCallSite = GlobalUniqueCallSite++;
-  SDValue TempChain = Chain;
-  Chain = DAG.getCALLSEQ_START(Chain, UniqueCallSite, 0, dl);
-  SDValue InGlue = Chain.getValue(1);
+  const SDValue CallChain = CLI.Chain;
+  const SDValue StartChain =
+      DAG.getCALLSEQ_START(CallChain, UniqueCallSite, 0, dl);
+  SDValue DeclareGlue = StartChain.getValue(1);
+
+  SmallVector<SDValue, 16> CallPrereqs{StartChain};
+
+  const auto DeclareScalarParam = [&](SDValue Symbol, unsigned Size) {
+    // PTX ABI requires integral types to be at least 32 bits in size. FP16 is
+    // loaded/stored using i16, so it's handled here as well.
+    const unsigned SizeBits = promoteScalarArgumentSize(Size * 8);
+    SDValue Declare =
+        DAG.getNode(NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
+                    {StartChain, Symbol, GetI32(SizeBits), DeclareGlue});
+    CallPrereqs.push_back(Declare);
+    DeclareGlue = Declare.getValue(1);
+    return Declare;
+  };
+
+  const auto DeclareArrayParam = [&](SDValue Symbol, Align Align,
+                                     unsigned Size) {
+    SDValue Declare = DAG.getNode(
+        NVPTXISD::DeclareArrayParam, dl, {MVT::Other, MVT::Glue},
+        {StartChain, Symbol, GetI32(Align.value()), GetI32(Size), DeclareGlue});
+    CallPrereqs.push_back(Declare);
+    DeclareGlue = Declare.getValue(1);
+    return Declare;
+  };
 
   // Args.size() and Outs.size() need not match.
   // Outs.size() will be larger
@@ -1555,43 +1473,23 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     assert((!IsByVal || TypeSize == ArgOuts[0].Flags.getByValSize()) &&
            "type size mismatch");
 
-    const std::optional<SDValue> ArgDeclare = [&]() -> std::optional<SDValue> {
+    const SDValue ArgDeclare = [&]() {
       if (IsVAArg) {
-        if (ArgI == FirstVAArg) {
-          VADeclareParam = DAG.getNode(
-              NVPTXISD::DeclareArrayParam, dl, {MVT::Other, MVT::Glue},
-              {Chain, ParamSymbol, GetI32(STI.getMaxRequiredAlignment()),
-               GetI32(0), InGlue});
-          return VADeclareParam;
-        }
-        return std::nullopt;
-      }
-      if (IsByVal || shouldPassAsArray(Arg.Ty)) {
-        // declare .param .align <align> .b8 .param<n>[<size>];
-        return DAG.getNode(NVPTXISD::DeclareArrayParam, dl,
-                           {MVT::Other, MVT::Glue},
-                           {Chain, ParamSymbol, GetI32(ArgAlign.value()),
-                            GetI32(TypeSize), InGlue});
+        if (ArgI == FirstVAArg)
+          VADeclareParam = DeclareArrayParam(
+              ParamSymbol, Align(STI.getMaxRequiredAlignment()), 0);
+        return VADeclareParam;
       }
+
+      if (IsByVal || shouldPassAsArray(Arg.Ty))
+        return DeclareArrayParam(ParamSymbol, ArgAlign, TypeSize);
+
       assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
-      // declare .param .b<size> .param<n>;
-
-      // PTX ABI requ...
[truncated]

Copy link

github-actions bot commented Jul 3, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream/v2i8-tmp2 branch 3 times, most recently from bdc2f75 to a33744a Compare July 3, 2025 18:33
@llvmbot llvmbot added the clang Clang issues not falling into any other category label Jul 3, 2025
@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream/v2i8-tmp2 branch from a33744a to 43cfe71 Compare July 3, 2025 21:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:NVPTX clang Clang issues not falling into any other category debuginfo
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants