[NVPTX] Further cleanup call isel #146411

Merged

AlexMaclean
Member

@AlexMaclean AlexMaclean commented Jun 30, 2025

This change continues the rewrite and cleanup of DAG ISel for formal arguments, return values, and function calls. It causes some incidental changes, mostly to instruction ordering and register naming, but also a couple of improvements that come from using scalar types earlier in the lowering.

@llvmbot
Member

llvmbot commented Jun 30, 2025

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

Changes

Patch is 315.09 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/146411.diff

15 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+150-153)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+10-6)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+23-35)
  • (modified) llvm/test/CodeGen/NVPTX/cmpxchg-sm60.ll (+540-540)
  • (modified) llvm/test/CodeGen/NVPTX/cmpxchg-sm70.ll (+540-540)
  • (modified) llvm/test/CodeGen/NVPTX/cmpxchg-sm90.ll (+540-540)
  • (modified) llvm/test/CodeGen/NVPTX/cmpxchg.ll (+120-120)
  • (modified) llvm/test/CodeGen/NVPTX/convert-int-sm20.ll (+3-3)
  • (modified) llvm/test/CodeGen/NVPTX/extractelement.ll (+9-12)
  • (modified) llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll (+4-4)
  • (modified) llvm/test/CodeGen/NVPTX/lower-args.ll (+2-2)
  • (modified) llvm/test/CodeGen/NVPTX/misched_func_call.ll (+6-6)
  • (modified) llvm/test/CodeGen/NVPTX/st-param-imm.ll (+30-30)
  • (modified) llvm/test/CodeGen/NVPTX/unaligned-param-load-store.ll (+96-96)
  • (modified) llvm/test/CodeGen/NVPTX/variadics-backend.ll (+1-1)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index d9192fbfceff1..a41b094faa8d6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -28,6 +28,7 @@
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineJumpTableInfo.h"
 #include "llvm/CodeGen/MachineMemOperand.h"
+#include "llvm/CodeGen/Register.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/SelectionDAGNodes.h"
 #include "llvm/CodeGen/TargetCallingConv.h"
@@ -390,35 +391,27 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
 /// and promote them to a larger size if they're not.
 ///
 /// The promoted type is placed in \p PromoteVT if the function returns true.
-static std::optional<MVT> PromoteScalarIntegerPTX(const EVT &VT) {
+static EVT promoteScalarIntegerPTX(const EVT VT) {
   if (VT.isScalarInteger()) {
-    MVT PromotedVT;
     switch (PowerOf2Ceil(VT.getFixedSizeInBits())) {
     default:
       llvm_unreachable(
           "Promotion is not suitable for scalars of size larger than 64-bits");
     case 1:
-      PromotedVT = MVT::i1;
-      break;
+      return MVT::i1;
     case 2:
     case 4:
     case 8:
-      PromotedVT = MVT::i8;
-      break;
+      return MVT::i8;
     case 16:
-      PromotedVT = MVT::i16;
-      break;
+      return MVT::i16;
     case 32:
-      PromotedVT = MVT::i32;
-      break;
+      return MVT::i32;
     case 64:
-      PromotedVT = MVT::i64;
-      break;
+      return MVT::i64;
     }
-    if (VT != PromotedVT)
-      return PromotedVT;
   }
-  return std::nullopt;
+  return VT;
 }
 
 // Check whether we can merge loads/stores of some of the pieces of a
@@ -1053,10 +1046,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     break;
 
     MAKE_CASE(NVPTXISD::RET_GLUE)
-    MAKE_CASE(NVPTXISD::DeclareParam)
+    MAKE_CASE(NVPTXISD::DeclareArrayParam)
     MAKE_CASE(NVPTXISD::DeclareScalarParam)
-    MAKE_CASE(NVPTXISD::DeclareRet)
-    MAKE_CASE(NVPTXISD::DeclareRetParam)
     MAKE_CASE(NVPTXISD::CALL)
     MAKE_CASE(NVPTXISD::LoadParam)
     MAKE_CASE(NVPTXISD::LoadParamV2)
@@ -1162,8 +1153,8 @@ SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
 }
 
 std::string NVPTXTargetLowering::getPrototype(
-    const DataLayout &DL, Type *retTy, const ArgListTy &Args,
-    const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign RetAlign,
+    const DataLayout &DL, Type *RetTy, const ArgListTy &Args,
+    const SmallVectorImpl<ISD::OutputArg> &Outs,
     std::optional<unsigned> FirstVAArg, const CallBase &CB,
     unsigned UniqueCallSite) const {
   auto PtrVT = getPointerTy(DL);
@@ -1172,22 +1163,22 @@ std::string NVPTXTargetLowering::getPrototype(
   raw_string_ostream O(Prototype);
   O << "prototype_" << UniqueCallSite << " : .callprototype ";
 
-  if (retTy->isVoidTy()) {
+  if (RetTy->isVoidTy()) {
     O << "()";
   } else {
     O << "(";
-    if (shouldPassAsArray(retTy)) {
-      assert(RetAlign && "RetAlign must be set for non-void return types");
-      O << ".param .align " << RetAlign->value() << " .b8 _["
-        << DL.getTypeAllocSize(retTy) << "]";
-    } else if (retTy->isFloatingPointTy() || retTy->isIntegerTy()) {
+    if (shouldPassAsArray(RetTy)) {
+      const Align RetAlign = getArgumentAlignment(&CB, RetTy, 0, DL);
+      O << ".param .align " << RetAlign.value() << " .b8 _["
+        << DL.getTypeAllocSize(RetTy) << "]";
+    } else if (RetTy->isFloatingPointTy() || RetTy->isIntegerTy()) {
       unsigned size = 0;
-      if (auto *ITy = dyn_cast<IntegerType>(retTy)) {
+      if (auto *ITy = dyn_cast<IntegerType>(RetTy)) {
         size = ITy->getBitWidth();
       } else {
-        assert(retTy->isFloatingPointTy() &&
+        assert(RetTy->isFloatingPointTy() &&
                "Floating point type expected here");
-        size = retTy->getPrimitiveSizeInBits();
+        size = RetTy->getPrimitiveSizeInBits();
       }
       // PTX ABI requires all scalar return values to be at least 32
       // bits in size.  fp16 normally uses .b16 as its storage type in
@@ -1195,7 +1186,7 @@ std::string NVPTXTargetLowering::getPrototype(
       size = promoteScalarArgumentSize(size);
 
       O << ".param .b" << size << " _";
-    } else if (isa<PointerType>(retTy)) {
+    } else if (isa<PointerType>(RetTy)) {
       O << ".param .b" << PtrVT.getSizeInBits() << " _";
     } else {
       llvm_unreachable("Unknown return type");
@@ -1256,7 +1247,7 @@ std::string NVPTXTargetLowering::getPrototype(
 
   if (FirstVAArg)
     O << (first ? "" : ",") << " .param .align "
-      << STI.getMaxRequiredAlignment() << " .b8 _[]\n";
+      << STI.getMaxRequiredAlignment() << " .b8 _[]";
   O << ")";
   if (shouldEmitPTXNoReturn(&CB, *nvTM))
     O << " .noreturn";
@@ -1442,6 +1433,21 @@ static ISD::NodeType getExtOpcode(const ISD::ArgFlagsTy &Flags) {
   return ISD::ANY_EXTEND;
 }
 
+static SDValue correctParamType(SDValue V, EVT ExpectedVT,
+                                ISD::ArgFlagsTy Flags, SelectionDAG &DAG,
+                                SDLoc dl) {
+  const EVT ActualVT = V.getValueType();
+  assert((ActualVT == ExpectedVT ||
+          (ExpectedVT.isInteger() && ActualVT.isInteger())) &&
+         "Non-integer argument type size mismatch");
+  if (ExpectedVT.bitsGT(ActualVT))
+    return DAG.getNode(getExtOpcode(Flags), dl, ExpectedVT, V);
+  if (ExpectedVT.bitsLT(ActualVT))
+    return DAG.getNode(ISD::TRUNCATE, dl, ExpectedVT, V);
+
+  return V;
+}
+
 SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
                                        SmallVectorImpl<SDValue> &InVals) const {
 
@@ -1505,9 +1511,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
          "Outs and OutVals must be the same size");
   // Declare the .params or .reg need to pass values
   // to the function
-  for (const auto [ArgI, Arg] : llvm::enumerate(Args)) {
-    const auto ArgOuts = AllOuts.take_while(
-        [ArgI = ArgI](auto O) { return O.OrigArgIndex == ArgI; });
+  for (const auto E : llvm::enumerate(Args)) {
+    const auto ArgI = E.index();
+    const auto Arg = E.value();
+    const auto ArgOuts =
+        AllOuts.take_while([&](auto O) { return O.OrigArgIndex == ArgI; });
     const auto ArgOutVals = AllOutVals.take_front(ArgOuts.size());
     AllOuts = AllOuts.drop_front(ArgOuts.size());
     AllOutVals = AllOutVals.drop_front(ArgOuts.size());
@@ -1515,6 +1523,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     const bool IsVAArg = (ArgI >= FirstVAArg);
     const bool IsByVal = Arg.IsByVal;
 
+    const SDValue ParamSymbol =
+        getCallParamSymbol(DAG, IsVAArg ? FirstVAArg : ArgI, MVT::i32);
+
     SmallVector<EVT, 16> VTs;
     SmallVector<uint64_t, 16> Offsets;
 
@@ -1525,38 +1536,43 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     assert(VTs.size() == Offsets.size() && "Size mismatch");
     assert((IsByVal || VTs.size() == ArgOuts.size()) && "Size mismatch");
 
-    Align ArgAlign;
-    if (IsByVal) {
-      // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
-      // so we don't need to worry whether it's naturally aligned or not.
-      // See TargetLowering::LowerCallTo().
-      Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
-      ArgAlign = getFunctionByValParamAlign(CB->getCalledFunction(), ETy,
-                                            InitialAlign, DL);
-      if (IsVAArg)
-        VAOffset = alignTo(VAOffset, ArgAlign);
-    } else {
-      ArgAlign = getArgumentAlignment(CB, Arg.Ty, ArgI + 1, DL);
-    }
+    const Align ArgAlign = [&]() {
+      if (IsByVal) {
+        // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
+        // so we don't need to worry whether it's naturally aligned or not.
+        // See TargetLowering::LowerCallTo().
+        const Align InitialAlign = ArgOuts[0].Flags.getNonZeroByValAlign();
+        const Align ByValAlign = getFunctionByValParamAlign(
+            CB->getCalledFunction(), ETy, InitialAlign, DL);
+        if (IsVAArg)
+          VAOffset = alignTo(VAOffset, ByValAlign);
+        return ByValAlign;
+      }
+      return getArgumentAlignment(CB, Arg.Ty, ArgI + 1, DL);
+    }();
 
     const unsigned TypeSize = DL.getTypeAllocSize(ETy);
     assert((!IsByVal || TypeSize == ArgOuts[0].Flags.getByValSize()) &&
            "type size mismatch");
 
-    const bool PassAsArray = IsByVal || shouldPassAsArray(Arg.Ty);
-    if (IsVAArg) {
-      if (ArgI == FirstVAArg) {
-        VADeclareParam = Chain =
-            DAG.getNode(NVPTXISD::DeclareParam, dl, {MVT::Other, MVT::Glue},
-                        {Chain, GetI32(STI.getMaxRequiredAlignment()),
-                         GetI32(ArgI), GetI32(1), InGlue});
+    const std::optional<SDValue> ArgDeclare = [&]() -> std::optional<SDValue> {
+      if (IsVAArg) {
+        if (ArgI == FirstVAArg) {
+          VADeclareParam = DAG.getNode(
+              NVPTXISD::DeclareArrayParam, dl, {MVT::Other, MVT::Glue},
+              {Chain, ParamSymbol, GetI32(STI.getMaxRequiredAlignment()),
+               GetI32(0), InGlue});
+          return VADeclareParam;
+        }
+        return std::nullopt;
+      }
+      if (IsByVal || shouldPassAsArray(Arg.Ty)) {
+        // declare .param .align <align> .b8 .param<n>[<size>];
+        return DAG.getNode(NVPTXISD::DeclareArrayParam, dl,
+                           {MVT::Other, MVT::Glue},
+                           {Chain, ParamSymbol, GetI32(ArgAlign.value()),
+                            GetI32(TypeSize), InGlue});
       }
-    } else if (PassAsArray) {
-      // declare .param .align <align> .b8 .param<n>[<size>];
-      Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, {MVT::Other, MVT::Glue},
-                          {Chain, GetI32(ArgAlign.value()), GetI32(ArgI),
-                           GetI32(TypeSize), InGlue});
-    } else {
       assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
       // declare .param .b<size> .param<n>;
 
@@ -1568,11 +1584,14 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
               ? promoteScalarArgumentSize(TypeSize * 8)
               : TypeSize * 8;
 
-      Chain =
-          DAG.getNode(NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
-                      {Chain, GetI32(ArgI), GetI32(PromotedSize), InGlue});
+      return DAG.getNode(NVPTXISD::DeclareScalarParam, dl,
+                         {MVT::Other, MVT::Glue},
+                         {Chain, ParamSymbol, GetI32(PromotedSize), InGlue});
+    }();
+    if (ArgDeclare) {
+      Chain = ArgDeclare->getValue(0);
+      InGlue = ArgDeclare->getValue(1);
     }
-    InGlue = Chain.getValue(1);
 
     // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
     // than 32-bits are sign extended or zero extended, depending on
@@ -1594,8 +1613,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       } else {
         StVal = ArgOutVals[I];
 
-        if (auto PromotedVT = PromoteScalarIntegerPTX(StVal.getValueType())) {
-          StVal = DAG.getNode(getExtOpcode(ArgOuts[I].Flags), dl, *PromotedVT,
+        auto PromotedVT = promoteScalarIntegerPTX(StVal.getValueType());
+        if (PromotedVT != StVal.getValueType()) {
+          StVal = DAG.getNode(getExtOpcode(ArgOuts[I].Flags), dl, PromotedVT,
                               StVal);
         }
       }
@@ -1619,12 +1639,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     unsigned J = 0;
     for (const unsigned NumElts : VectorInfo) {
       const int CurOffset = Offsets[J];
-      EVT EltVT = VTs[J];
+      EVT EltVT = promoteScalarIntegerPTX(VTs[J]);
       const Align PartAlign = commonAlignment(ArgAlign, CurOffset);
 
-      if (auto PromotedVT = PromoteScalarIntegerPTX(EltVT))
-        EltVT = *PromotedVT;
-
       // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
       // scalar store. In such cases, fall back to byte stores.
       if (NumElts == 1 && !IsVAArg && PartAlign < DAG.getEVTAlign(EltVT)) {
@@ -1695,27 +1712,26 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   }
 
   GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
-  MaybeAlign RetAlign = std::nullopt;
 
   // Handle Result
   if (!Ins.empty()) {
-    RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
-
-    // Declare
-    //  .param .align N .b8 retval0[<size-in-bytes>], or
-    //  .param .b<size-in-bits> retval0
-    const unsigned ResultSize = DL.getTypeAllocSizeInBits(RetTy);
-    if (!shouldPassAsArray(RetTy)) {
-      const unsigned PromotedResultSize = promoteScalarArgumentSize(ResultSize);
-      Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, {MVT::Other, MVT::Glue},
-                          {Chain, GetI32(PromotedResultSize), InGlue});
-      InGlue = Chain.getValue(1);
-    } else {
-      Chain = DAG.getNode(
-          NVPTXISD::DeclareRetParam, dl, {MVT::Other, MVT::Glue},
-          {Chain, GetI32(RetAlign->value()), GetI32(ResultSize / 8), InGlue});
-      InGlue = Chain.getValue(1);
-    }
+    const SDValue RetDeclare = [&]() {
+      const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32);
+      const unsigned ResultSize = DL.getTypeAllocSizeInBits(RetTy);
+      if (shouldPassAsArray(RetTy)) {
+        const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
+        return DAG.getNode(NVPTXISD::DeclareArrayParam, dl,
+                           {MVT::Other, MVT::Glue},
+                           {Chain, RetSymbol, GetI32(RetAlign.value()),
+                            GetI32(ResultSize / 8), InGlue});
+      }
+      const auto PromotedResultSize = promoteScalarArgumentSize(ResultSize);
+      return DAG.getNode(
+          NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
+          {Chain, RetSymbol, GetI32(PromotedResultSize), InGlue});
+    }();
+    Chain = RetDeclare.getValue(0);
+    InGlue = RetDeclare.getValue(1);
   }
 
   const bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs);
@@ -1760,7 +1776,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     // The prototype is embedded in a string and put as the operand for a
     // CallPrototype SDNode which will print out to the value of the string.
     std::string Proto =
-        getPrototype(DL, RetTy, Args, CLI.Outs, RetAlign,
+        getPrototype(DL, RetTy, Args, CLI.Outs,
                      HasVAArgs ? std::optional(FirstVAArg) : std::nullopt, *CB,
                      UniqueCallSite);
     const char *ProtoStr = nvTM->getStrPool().save(Proto).data();
@@ -1773,11 +1789,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   if (ConvertToIndirectCall) {
     // Copy the function ptr to a ptx register and use the register to call the
     // function.
-    EVT DestVT = Callee.getValueType();
-    MachineRegisterInfo &RegInfo = DAG.getMachineFunction().getRegInfo();
+    const MVT DestVT = Callee.getValueType().getSimpleVT();
+    MachineRegisterInfo &MRI = DAG.getMachineFunction().getRegInfo();
     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-    unsigned DestReg =
-        RegInfo.createVirtualRegister(TLI.getRegClassFor(DestVT.getSimpleVT()));
+    Register DestReg = MRI.createVirtualRegister(TLI.getRegClassFor(DestVT));
     auto RegCopy = DAG.getCopyToReg(DAG.getEntryNode(), dl, DestReg, Callee);
     Callee = DAG.getCopyFromReg(RegCopy, dl, DestReg, DestVT);
   }
@@ -1810,7 +1825,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets, 0);
     assert(VTs.size() == Ins.size() && "Bad value decomposition");
 
-    assert(RetAlign && "RetAlign is guaranteed to be set");
+    const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
 
     // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
     // 32-bits are sign extended or zero extended, depending on whether
@@ -1818,17 +1833,15 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     const bool ExtendIntegerRetVal =
         RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
 
-    const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, *RetAlign);
+    const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
     unsigned I = 0;
     for (const unsigned VectorizedSize : VectorInfo) {
-      EVT TheLoadType = VTs[I];
+      EVT TheLoadType = promoteScalarIntegerPTX(VTs[I]);
       EVT EltType = Ins[I].VT;
-      const Align EltAlign = commonAlignment(*RetAlign, Offsets[I]);
+      const Align EltAlign = commonAlignment(RetAlign, Offsets[I]);
 
-      if (auto PromotedVT = PromoteScalarIntegerPTX(TheLoadType)) {
-        TheLoadType = *PromotedVT;
-        EltType = *PromotedVT;
-      }
+      if (TheLoadType != VTs[I])
+        EltType = TheLoadType;
 
       if (ExtendIntegerRetVal) {
         TheLoadType = MVT::i32;
@@ -1898,13 +1911,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       continue;
     }
 
-    SDValue Ret = DAG.getNode(
-        NVPTXISD::ProxyReg, dl,
-        {ProxyRegOps[I].getSimpleValueType(), MVT::Other, MVT::Glue},
-        {Chain, ProxyRegOps[I], InGlue});
-
-    Chain = Ret.getValue(1);
-    InGlue = Ret.getValue(2);
+    SDValue Ret =
+        DAG.getNode(NVPTXISD::ProxyReg, dl, ProxyRegOps[I].getSimpleValueType(),
+                    {Chain, ProxyRegOps[I]});
 
     const EVT ExpectedVT = Ins[I].VT;
     if (!Ret.getValueType().bitsEq(ExpectedVT)) {
@@ -1914,14 +1923,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   }
 
   for (SDValue &T : TempProxyRegOps) {
-    SDValue Repl = DAG.getNode(NVPTXISD::ProxyReg, dl,
-                               {T.getSimpleValueType(), MVT::Other, MVT::Glue},
-                               {Chain, T.getOperand(0), InGlue});
+    SDValue Repl = DAG.getNode(NVPTXISD::ProxyReg, dl, T.getSimpleValueType(),
+                               {Chain, T.getOperand(0)});
     DAG.ReplaceAllUsesWith(T, Repl);
     DAG.RemoveDeadNode(T.getNode());
-
-    Chain = Repl.getValue(1);
-    InGlue = Repl.getValue(2);
   }
 
   // set isTailCall to false for now, until we figure out how to express
@@ -3292,11 +3297,17 @@ bool NVPTXTargetLowering::splitValueIntoRegisterParts(
 // Name of the symbol is composed from its index and the function name.
 // Negative index corresponds to special parameter (unsized array) used for
 // passing variable arguments.
-SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx,
-                                            EVT v) const {
+SDValue NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int I,
+                                            EVT T) const {
   StringRef SavedStr = nvTM->getStrPool().save(
-      getParamName(&DAG.getMachineFunction().getFunction(), idx));
-  return DAG.getExternalSymbol(SavedStr.data(), v);
+      getParamName(&DAG.getMachineFunction().getFunction(), I));
+  return DAG.getExternalSymbol(SavedStr.data(), T);
+}
+
+SDValue NVPTXTargetLowering::getCallParamSymbol(SelectionDAG &DAG, int I,
+                                                EVT T) const {
+  const StringRef SavedStr = nvTM->getStrPool().save("param" + Twine(I));
+  return DAG.getExternalSymbol(SavedStr.data(), T);
 }
 
 SDValue NVPTXTargetLowering::LowerFormalArguments(
@@ -3393,8 +3404,11 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
         const unsigned PackingAmt =
             LoadVT.isVector() ? LoadVT.getVectorNumElements() : 1;
 
-        const EVT VecVT = EVT::getVectorVT(
-            F->getContext(), LoadVT.getScalarType(), NumElts * PackingAmt);
+        const EVT VecVT =
+            NumElts == 1
+                ? LoadVT
+                : EVT::getVectorVT(F->getContext(), LoadVT.getScalarType(),
+                                   NumElts * PackingAmt);
 
         SDValue VecAddr = DAG.getObjectPtrOffset(
             dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));
@@ -3408,22 +3422,16 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
         if (P....
[truncated]
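
For readers skimming the truncated diff, the `promoteScalarIntegerPTX` change may be easier to follow outside LLVM. The sketch below is a standalone model under stated assumptions, not the patch's code: `powerOf2Ceil` and `promoteWidth` are hypothetical names, and they operate on plain bit widths instead of `EVT`. It mirrors the new shape of the helper, which returns the promoted type directly (or the input unchanged) rather than a `std::optional` that is only set when promotion occurs.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for llvm::PowerOf2Ceil, valid for n >= 1:
// returns the smallest power of two that is >= n.
uint64_t powerOf2Ceil(uint64_t n) {
  uint64_t p = 1;
  while (p < n)
    p <<= 1;
  return p;
}

// Hypothetical model of the reworked helper on bare bit widths.
// Widths 2..8 round up to 8; 1, 16, 32, and 64 are their own buckets.
// The caller compares the result with the input to see if promotion happened.
uint64_t promoteWidth(uint64_t bits) {
  assert(bits >= 1 && bits <= 64 &&
         "this sketch only models scalar integers of 1 to 64 bits");
  switch (powerOf2Ceil(bits)) {
  case 1:
    return 1;
  case 2:
  case 4:
  case 8:
    return 8;
  case 16:
    return 16;
  case 32:
    return 32;
  case 64:
    return 64;
  default:
    // Unreachable for inputs in [1, 64]; kept so every path returns.
    return bits;
  }
}
```

With this shape, a call site that used to write `if (auto P = Promote(VT)) VT = *P;` can assign unconditionally, as in `EltVT = promoteScalarIntegerPTX(VTs[J])` in the diff, and only compare against the original type where it needs to know whether an extend is required.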

@AlexMaclean AlexMaclean requested a review from Prince781 June 30, 2025 19:38
@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream/cleanup-isel-call branch from b16ad3d to 7f46a18 Compare June 30, 2025 20:21
Contributor

@Prince781 Prince781 left a comment

LGTM

@AlexMaclean AlexMaclean merged commit 475cd8d into llvm:main Jul 1, 2025
7 checks passed
yzhang93 added a commit to iree-org/llvm-project that referenced this pull request Jul 2, 2025