[NVPTX] Improved support for grid_constant #97112

Conversation

akshayrdeodhar
Contributor

akshayrdeodhar commented Jun 28, 2024

  • Supports escaped grid_constant pointers less conservatively. Uses inside calls, ptrtoints, and stores where the pointer is a value operand are cast to the generic address space immediately before the escape, while other uses are kept in the param address space (see the sketch after this list).

  • Related to: [NVPTX] Basic support for "grid_constant" #96125
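The effect, as a minimal LLVM IR sketch (names are illustrative and the kernel/grid_constant annotations are omitted; the pattern mirrors the grid_const_partial_escape test added by this patch): the plain load stays in the param address space, and the pointer is converted to generic only immediately before the escaping call.

; Before the pass: one non-escaping load, one escaping call.
define void @grid_const_partial_escape(ptr byval(i32) align 4 %input, ptr %output) {
  %val = load i32, ptr %input
  store i32 %val, ptr %output
  %call = call i32 @escape(ptr %input)
  ret void
}

; After -nvptx-lower-args (sketch): the load goes through the param address
; space, and only the call operand is cast to generic via the intrinsic that
; the backend lowers to cvta.param on sm_70+.
define void @grid_const_partial_escape_lowered(ptr byval(i32) align 4 %input, ptr %output) {
  %input.param = addrspacecast ptr %input to ptr addrspace(101)
  %val = load i32, ptr addrspace(101) %input.param
  store i32 %val, ptr %output
  %input.gen = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) %input.param)
  %call = call i32 @escape(ptr %input.gen)
  ret void
}

declare i32 @escape(ptr)
declare ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101))

Writes through a grid_constant pointer remain undefined behaviour, so no local copy of the argument is created in either case.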

@llvmbot
Member

llvmbot commented Jun 28, 2024

@llvm/pr-subscribers-backend-nvptx

Author: Akshay Deodhar (akshayrdeodhar)

Changes
  • Supports escaped grid_constant pointers less conservatively. Uses inside calls, ptrtoints, and stores where the pointer is a value operand are cast to the generic address space immediately before the escape, while other uses are kept in the param address space.

Patch is 28.30 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/97112.diff

3 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp (+136-37)
  • (modified) llvm/lib/Target/NVPTX/NVPTXUtilities.cpp (+2-1)
  • (modified) llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll (+227-17)
diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
index e63c7a61c6f26..d5dffb8998a04 100644
--- a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -12,8 +12,7 @@
 // http://docs.nvidia.com/cuda/parallel-thread-execution/#state-spaces
 //
 // Kernel parameters are read-only and accessible only via ld.param
-// instruction, directly or via a pointer. Pointers to kernel
-// arguments can't be converted to generic address space.
+// instruction, directly or via a pointer.
 //
 // Device function parameters are directly accessible via
 // ld.param/st.param, but taking the address of one returns a pointer
@@ -54,8 +53,10 @@
 //      ...
 //    }
 //
-// 2. Convert pointers in a byval kernel parameter to pointers in the global
-//    address space. As #2, it allows NVPTX to emit more ld/st.global. E.g.,
+// 2. Convert byval kernel parameters to pointers in the param address space
+//    (so that NVPTX emits ld/st.param).  Convert pointers *within* a byval
+//    kernel parameter to pointers in the global address space. This allows
+//    NVPTX to emit ld/st.global.
 //
 //    struct S {
 //      int *x;
@@ -68,22 +69,68 @@
 //
 //    "b" points to the global address space. In the IR level,
 //
-//    define void @foo({i32*, i32*}* byval %input) {
-//      %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
-//      %b = load i32*, i32** %b_ptr
+//    define void @foo(ptr byval %input) {
+//      %b_ptr = getelementptr {ptr, ptr}, ptr %input, i64 0, i32 1
+//      %b = load ptr, ptr %b_ptr
 //      ; use %b
 //    }
 //
 //    becomes
 //
 //    define void @foo({i32*, i32*}* byval %input) {
-//      %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
-//      %b = load i32*, i32** %b_ptr
-//      %b_global = addrspacecast i32* %b to i32 addrspace(1)*
-//      %b_generic = addrspacecast i32 addrspace(1)* %b_global to i32*
+//      %b_param = addrspacecat ptr %input to ptr addrspace(101)
+//      %b_ptr = getelementptr {ptr, ptr}, ptr addrspace(101) %b_param, i64 0, i32 1
+//      %b = load ptr, ptr addrspace(101) %b_ptr
+//      %b_global = addrspacecast ptr %b to ptr addrspace(1)
 //      ; use %b_generic
 //    }
 //
+//    Create a local copy of kernel byval parameters used in a way that *might* mutate
+//    the parameter, by storing it in an alloca. Mutations to "grid_constant" parameters
+//    are undefined behaviour, and don't require local copies.
+//
+//    define void @foo(ptr byval(%struct.s) align 4 %input) {
+//       store i32 42, ptr %input
+//       ret void
+//    }
+//
+//    becomes
+//
+//    define void @foo(ptr byval(%struct.s) align 4 %input) #1 {
+//      %input1 = alloca %struct.s, align 4
+//      %input2 = addrspacecast ptr %input to ptr addrspace(101)
+//      %input3 = load %struct.s, ptr addrspace(101) %input2, align 4
+//      store %struct.s %input3, ptr %input1, align 4
+//      store i32 42, ptr %input1, align 4
+//      ret void
+//    }
+//
+//    If %input were passed to a device function, or written to memory,
+//    conservatively assume that %input gets mutated, and create a local copy.
+//
+//    Convert param pointers to grid_constant byval kernel parameters that are
+//    passed into calls (device functions, intrinsics, inline asm), or otherwise
+//    "escape" (into stores/ptrtoints) to the generic address space, using the
+//    `nvvm.ptr.param.to.gen` intrinsic, so that NVPTX emits cvta.param
+//    (available for sm70+)
+//
+//    define void @foo(ptr byval(%struct.s) %input) {
+//      ; %input is a grid_constant
+//      %call = call i32 @escape(ptr %input)
+//      ret void
+//    }
+//
+//    becomes
+//
+//    define void @foo(ptr byval(%struct.s) %input) {
+//      %input1 = addrspacecast ptr %input to ptr addrspace(101)
+//      ; the following intrinsic converts pointer to generic. We don't use an addrspacecast
+//      ; to prevent generic -> param -> generic from getting cancelled out
+//      %input1.gen = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) %input1)
+//      %call = call i32 @escape(ptr %input1.gen)
+//      ret void
+//    }
+//
 // TODO: merge this pass with NVPTXInferAddressSpaces so that other passes don't
 // cancel the addrspacecast pair this pass emits.
 //===----------------------------------------------------------------------===//
@@ -166,19 +213,22 @@ INITIALIZE_PASS_END(NVPTXLowerArgs, "nvptx-lower-args",
 // ones in parameter AS, so we can access them using ld.param.
 // =============================================================================
 
-// Replaces the \p OldUser instruction with the same in parameter AS.
-// Only Load and GEP are supported.
-static void convertToParamAS(Value *OldUser, Value *Param) {
-  Instruction *I = dyn_cast<Instruction>(OldUser);
-  assert(I && "OldUser must be an instruction");
+// For Loads, replaces the \p OldUse of the pointer with a Use of the same
+// pointer in parameter AS.
+// For "escapes" (to memory, a function call, or a ptrtoint), cast the OldUse to
+// generic using cvta.param.
+static void convertToParamAS(Use *OldUse, Value *Param, bool GridConstant) {
+  Instruction *I = dyn_cast<Instruction>(OldUse->getUser());
+  assert(I && "OldUse must be in an instruction");
   struct IP {
+    Use *OldUse;
     Instruction *OldInstruction;
     Value *NewParam;
   };
-  SmallVector<IP> ItemsToConvert = {{I, Param}};
+  SmallVector<IP> ItemsToConvert = {{OldUse, I, Param}};
   SmallVector<Instruction *> InstructionsToDelete;
 
-  auto CloneInstInParamAS = [](const IP &I) -> Value * {
+  auto CloneInstInParamAS = [GridConstant](const IP &I) -> Value * {
     if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction)) {
       LI->setOperand(0, I.NewParam);
       return LI;
@@ -202,6 +252,43 @@ static void convertToParamAS(Value *OldUser, Value *Param) {
       // Just pass through the argument, the old ASC is no longer needed.
       return I.NewParam;
     }
+
+    if (GridConstant) {
+      auto GetParamAddrCastToGeneric =
+          [](Value *Addr, Instruction *OriginalUser) -> Value * {
+        PointerType *ReturnTy =
+            PointerType::get(OriginalUser->getContext(), ADDRESS_SPACE_GENERIC);
+        Function *CvtToGen = Intrinsic::getDeclaration(
+            OriginalUser->getModule(), Intrinsic::nvvm_ptr_param_to_gen,
+            {ReturnTy, PointerType::get(OriginalUser->getContext(),
+                                        ADDRESS_SPACE_PARAM)});
+
+        // Cast param address to generic address space
+        Value *CvtToGenCall =
+            CallInst::Create(CvtToGen, Addr, Addr->getName() + ".gen",
+                             OriginalUser->getIterator());
+        return CvtToGenCall;
+      };
+
+      if (auto *CI = dyn_cast<CallInst>(I.OldInstruction)) {
+        I.OldUse->set(GetParamAddrCastToGeneric(I.NewParam, CI));
+        return CI;
+      }
+      if (auto *SI = dyn_cast<StoreInst>(I.OldInstruction)) {
+        // byval address is being stored, cast it to generic
+        if (SI->getValueOperand() == I.OldUse->get())
+          SI->setOperand(0, GetParamAddrCastToGeneric(I.NewParam, SI));
+        return SI;
+      }
+      if (auto *PI = dyn_cast<PtrToIntInst>(I.OldInstruction)) {
+        if (PI->getPointerOperand() == I.OldUse->get())
+          PI->setOperand(0, GetParamAddrCastToGeneric(I.NewParam, PI));
+        return PI;
+      }
+      llvm_unreachable(
+          "Instruction unsupported even for grid_constant argument");
+    }
+
     llvm_unreachable("Unsupported instruction");
   };
 
@@ -213,8 +300,8 @@ static void convertToParamAS(Value *OldUser, Value *Param) {
       // We've created a new instruction. Queue users of the old instruction to
       // be converted and the instruction itself to be deleted. We can't delete
       // the old instruction yet, because it's still in use by a load somewhere.
-      for (Value *V : I.OldInstruction->users())
-        ItemsToConvert.push_back({cast<Instruction>(V), NewInst});
+      for (Use &U : I.OldInstruction->uses())
+        ItemsToConvert.push_back({&U, cast<Instruction>(U.getUser()), NewInst});
 
       InstructionsToDelete.push_back(I.OldInstruction);
     }
@@ -272,6 +359,7 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
   SmallVector<Load> Loads;
   std::queue<LoadContext> Worklist;
   Worklist.push({ArgInParamAS, 0});
+  bool IsGridConstant = isParamGridConstant(*Arg);
 
   while (!Worklist.empty()) {
     LoadContext Ctx = Worklist.front();
@@ -303,8 +391,14 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
         continue;
       }
 
+      // supported for grid_constant
+      if (IsGridConstant &&
+          (isa<CallInst>(CurUser) || isa<StoreInst>(CurUser) ||
+           isa<PtrToIntInst>(CurUser)))
+        continue;
+
       llvm_unreachable("All users must be one of: load, "
-                       "bitcast, getelementptr.");
+                       "bitcast, getelementptr, call, store, ptrtoint");
     }
   }
 
@@ -317,14 +411,15 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
 
 void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
                                       Argument *Arg) {
+  bool IsGridConstant = isParamGridConstant(*Arg);
   Function *Func = Arg->getParent();
   BasicBlock::iterator FirstInst = Func->getEntryBlock().begin();
   Type *StructType = Arg->getParamByValType();
   assert(StructType && "Missing byval type");
 
-  auto IsALoadChain = [&](Value *Start) {
+  auto AreSupportedUsers = [&](Value *Start) {
     SmallVector<Value *, 16> ValuesToCheck = {Start};
-    auto IsALoadChainInstr = [](Value *V) -> bool {
+    auto IsSupportedUse = [IsGridConstant](Value *V) -> bool {
       if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V))
         return true;
       // ASC to param space are OK, too -- we'll just strip them.
@@ -332,34 +427,43 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
         if (ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM)
           return true;
       }
+      // Simple calls and stores are supported for grid_constants
+      // writes to these pointers are undefined behaviour
+      if (IsGridConstant &&
+          (isa<CallInst>(V) || isa<StoreInst>(V) || isa<PtrToIntInst>(V)))
+        return true;
       return false;
     };
 
     while (!ValuesToCheck.empty()) {
       Value *V = ValuesToCheck.pop_back_val();
-      if (!IsALoadChainInstr(V)) {
+      if (!IsSupportedUse(V)) {
         LLVM_DEBUG(dbgs() << "Need a "
                           << (isParamGridConstant(*Arg) ? "cast " : "copy ")
                           << "of " << *Arg << " because of " << *V << "\n");
         (void)Arg;
         return false;
       }
-      if (!isa<LoadInst>(V))
+      if (!isa<LoadInst>(V) && !isa<CallInst>(V) && !isa<StoreInst>(V) &&
+          !isa<PtrToIntInst>(V))
         llvm::append_range(ValuesToCheck, V->users());
     }
     return true;
   };
 
-  if (llvm::all_of(Arg->users(), IsALoadChain)) {
+  if (llvm::all_of(Arg->users(), AreSupportedUsers)) {
     // Convert all loads and intermediate operations to use parameter AS and
     // skip creation of a local copy of the argument.
-    SmallVector<User *, 16> UsersToUpdate(Arg->users());
+    SmallVector<Use *, 16> UsesToUpdate;
+    for (Use &U : Arg->uses())
+      UsesToUpdate.push_back(&U);
+
     Value *ArgInParamAS = new AddrSpaceCastInst(
         Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
         FirstInst);
-    for (Value *V : UsersToUpdate)
-      convertToParamAS(V, ArgInParamAS);
-    LLVM_DEBUG(dbgs() << "No need to copy " << *Arg << "\n");
+    for (Use *U : UsesToUpdate)
+      convertToParamAS(U, ArgInParamAS, IsGridConstant);
+    LLVM_DEBUG(dbgs() << "No need to copy or cast " << *Arg << "\n");
 
     const auto *TLI =
         cast<NVPTXTargetLowering>(TM.getSubtargetImpl()->getTargetLowering());
@@ -376,16 +480,11 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
     // temporary copy. When a pointer might have escaped, conservatively replace
     // all of its uses (which might include a device function call) with a cast
     // to the generic address space.
-    // TODO: only cast byval grid constant parameters at use points that need
-    // generic address (e.g., merging parameter pointers with other address
-    // space, or escaping to call-sites, inline-asm, memory), and use the
-    // parameter address space for normal loads.
     IRBuilder<> IRB(&Func->getEntryBlock().front());
 
     // Cast argument to param address space
-    auto *CastToParam =
-        cast<AddrSpaceCastInst>(IRB.CreateAddrSpaceCast(
-            Arg, IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getName() + ".param"));
+    auto *CastToParam = cast<AddrSpaceCastInst>(IRB.CreateAddrSpaceCast(
+        Arg, IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getName() + ".param"));
 
     // Cast param address to generic address space. We do not use an
     // addrspacecast to generic here, because, LLVM considers `Arg` to be in the
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
index e4b2ec868519c..80361744fd5b6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -210,7 +210,8 @@ bool isParamGridConstant(const Value &V) {
   if (const Argument *Arg = dyn_cast<Argument>(&V)) {
     // "grid_constant" counts argument indices starting from 1
     if (Arg->hasByValAttr() &&
-        argHasNVVMAnnotation(*Arg, "grid_constant", /*StartArgIndexAtOne*/true)) {
+        argHasNVVMAnnotation(*Arg, "grid_constant",
+                             /*StartArgIndexAtOne*/ true)) {
       assert(isKernelFunction(*Arg->getParent()) &&
              "only kernel arguments can be grid_constant");
       return true;
diff --git a/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
index 46f54e0e6f4d4..f6db9c429dba5 100644
--- a/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
@@ -1,4 +1,4 @@
-; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
 ; RUN: opt < %s -S -nvptx-lower-args --mtriple nvptx64-nvidia-cuda | FileCheck %s --check-prefixes OPT
 ; RUN: llc < %s -mcpu=sm_70 --mtriple nvptx64-nvidia-cuda | FileCheck %s --check-prefixes PTX
 
@@ -67,22 +67,22 @@ define void @multiple_grid_const_escape(ptr byval(%struct.s) align 4 %input, i32
 ; PTX:         mov.{{.*}} [[RD1:%.*]], multiple_grid_const_escape_param_0;
 ; PTX:         mov.{{.*}} [[RD2:%.*]], multiple_grid_const_escape_param_2;
 ; PTX:         mov.{{.*}} [[RD3:%.*]], [[RD2]];
-; PTX:         cvta.param.{{.*}} [[RD4:%.*]], [[RD3]];
-; PTX:         mov.u64 [[RD5:%.*]], [[RD1]];
-; PTX:         cvta.param.{{.*}} [[RD6:%.*]], [[RD5]];
+; PTX:         mov.{{.*}} [[RD4:%.*]], [[RD1]];
+; PTX:         cvta.param.{{.*}} [[RD5:%.*]], [[RD4]];
+; PTX:         cvta.param.{{.*}} [[RD6:%.*]], [[RD3]];
 ; PTX:         {
-; PTX:         st.param.b64 [param0+0], [[RD6]];
-; PTX:         st.param.b64 [param2+0], [[RD4]];
+; PTX:         st.param.b64 [param0+0], [[RD5]];
+; PTX:         st.param.b64 [param2+0], [[RD6]];
 ;
 ; OPT-LABEL: define void @multiple_grid_const_escape(
 ; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 [[INPUT:%.*]], i32 [[A:%.*]], ptr byval(i32) align 4 [[B:%.*]]) {
-; OPT-NOT:     alloca i32
 ; OPT:         [[B_PARAM:%.*]] = addrspacecast ptr [[B]] to ptr addrspace(101)
-; OPT:         [[B_PARAM_GEN:%.*]] = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) [[B_PARAM]])
-; OPT-NOT:     alloca [[STRUCT_S]]
 ; OPT:         [[INPUT_PARAM:%.*]] = addrspacecast ptr [[INPUT]] to ptr addrspace(101)
+; OPT-NOT:     alloca %struct.s
+; OPT:         [[A_ADDR:%.*]] = alloca i32, align 4
 ; OPT:         [[INPUT_PARAM_GEN:%.*]] = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) [[INPUT_PARAM]])
-; OPT:         [[CALL:%.*]] = call i32 @escape3(ptr [[INPUT_PARAM_GEN]], ptr {{.*}}, ptr [[B_PARAM_GEN]])
+; OPT:         [[B_PARAM_GEN:%.*]] = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) [[B_PARAM]])
+; OPT-NEXT:    [[CALL:%.*]] = call i32 @escape3(ptr [[INPUT_PARAM_GEN]], ptr [[A_ADDR]], ptr [[B_PARAM_GEN]])
 ;
   %a.addr = alloca i32, align 4
   store i32 %a, ptr %a.addr, align 4
@@ -111,17 +111,19 @@ define void @grid_const_memory_escape(ptr byval(%struct.s) align 4 %input, ptr %
 define void @grid_const_inlineasm_escape(ptr byval(%struct.s) align 4 %input, ptr %result) {
 ; PTX-LABEL: grid_const_inlineasm_escape(
 ; PTX-NOT      .local 
-; PTX:         cvta.param.u64 [[RD2:%.*]], {{.*}}
-; PTX:         add.{{.*}} [[RD3:%.*]], [[RD2]], 4;
-; PTX:         add.s64 [[RD1:%.*]], [[RD2]], [[RD3]];
+; PTX:         add.{{.*}} [[RD2:%.*]], [[RD1:%.*]], 4;
+; PTX:         cvta.param.u64 [[RD4:%.*]], [[RD2]]
+; PTX:         cvta.param.u64 [[RD3:%.*]], [[RD1]]
+; PTX:         add.s64 [[RD5:%.*]], [[RD3]], [[RD4]];
 ;
 ; OPT-LABEL: define void @grid_const_inlineasm_escape(
 ; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 [[INPUT:%.*]], ptr [[RESULT:%.*]]) {
 ; OPT-NOT:     alloca [[STRUCT_S]]
 ; OPT:         [[INPUT_PARAM:%.*]] = addrspacecast ptr [[INPUT]] to ptr addrspace(101)
-; OPT:         [[INPUT_PARAM_GEN:%.*]] = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) [[INPUT_PARAM]])
-; OPT:         [[TMP:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[INPUT_PARAM_GEN]], i32 0, i32 0
-; OPT:         [[TMP1:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[INPUT_PARAM_GEN]], i32 0, i32 1
+; OPT:         [[TMPPTR13:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr addrspace(101) [[INPUT_PARAM]], i32 0, i32 0
+; OPT:         [[TMPPTR22:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr addrspace(101) [[INPUT_PARAM]], i32 0, i32 1
+; OPT:         [[TMPPTR22_GEN:%.*]] = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) [[TMPPTR22]])
+; OPT:         [[TMPPTR13_GEN:%.*]] = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) [[TMPPTR13]])
 ; OPT:         [[TMP2:%.*]] = call i64 asm "add.s64 $0, $1, $2
 ;
   %tmpptr1 = getelementptr inbounds %struct.s, ptr %input, i32 0, i32 0
@@ -131,10 +133,200 @@ define void @grid_const_inlineasm_escape(ptr byval(%struct.s) align 4 %input, pt
   ret void
 }
 
+define void @grid_const_partial_escape(ptr byval(i32) %input, ptr %output) {
+; PTX-LABEL: grid_const_partial_escape(
+; PTX-NOT:     .local
+; PTX:         ld.param.{{.*}} [[R1:%.*]], [grid_const_partial_escape_param_0];
+; PTX:         add.{{.*}}
+; PTX:         cvta.param.u64 [[RD3:%.*]], {{%.*}}
+; PTX:         st.param.{{.*}} [param0+0], [[RD3]]
+; PTX:         call
+;
+; OPT-LABEL: define void @grid_const_partial_escape(
+; OPT-SAME: ptr byval(i32) align 4 [[INPUT:%.*]], ptr {{%.*}}) {
+; OPT-NOT:     alloca
+; OPT:         [[INPUT1:%.*]] = addrspacecast ptr [[INPUT]] to ptr addrspace(101)
+; OPT:         [[VAL:%.*]] = load i32, ptr addrspace(101) [[INPUT1]], align 4
+; OPT:         [[TWICE:%.*]] = add i32 [[VAL]], [[VAL]]
+; OPT:         store i32 [[TWICE]]
+; OPT:         [[INPUT1_GEN:%.*]] = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) [[INPUT1]])
+; OPT:         [[CALL:%.*]] = call i32 @escape(ptr [[INPUT1_GEN]])
+; OPT:         ret void
+;
+  %val = load i32, ptr %input
+  %twice = add i32 %val, %val
+  store i32 %twice, ptr %output
+  %call = call i32 @escape(ptr %input)
+  ret void
+}
+
+define i32 @grid_const_partial_escapemem(ptr byval(%struct.s) %input, ptr %output) {
+; PTX-LABEL: grid_const_partial_escapemem(
+; PTX:       {
+; PTX:         ld.param.{{.*}} [[R1:%.*]], [grid_const_partial_escapemem_param_0];
+; PTX:         ld.param.{{.*}} [[R2:%.*]], [grid_const_partial_escapemem_param_0+4];
+; PTX:         cvta.param.{{.*}} [[RD5:%.*]], {{%.*}};
+; PTX:         st.global.{{.*}} [{{.*}}], [[RD5]];
+; PTX:         add.s32 [[R3:%.*]], [[R1]], [[R2]] 
+; PTX:         st.param.{{.*}} [param0+0], [[RD5]]
+; PTX:         escape
+; OPT-LABEL: define i32 @grid_const_partial_escapemem(
+; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 [[INPUT:%.*]], ptr {{%.*}}) {
+; OPT-NO...
[truncated]

@akshayrdeodhar
Contributor Author

CC: @apaszke

@@ -210,7 +210,8 @@ bool isParamGridConstant(const Value &V) {
   if (const Argument *Arg = dyn_cast<Argument>(&V)) {
     // "grid_constant" counts argument indices starting from 1
     if (Arg->hasByValAttr() &&
-        argHasNVVMAnnotation(*Arg, "grid_constant", /*StartArgIndexAtOne*/true)) {
+        argHasNVVMAnnotation(*Arg, "grid_constant",
+                             /*StartArgIndexAtOne*/ true)) {
Member

Nit: undo this nop change?

Contributor Author

clang-format 😅 (was going above 80 chars in the previous MR)

akshayrdeodhar merged commit 9fa7c05 into llvm:main Jun 30, 2024
9 checks passed
lravenclaw pushed a commit to lravenclaw/llvm-project that referenced this pull request Jul 3, 2024
- Supports escaped grid_constant pointers less conservatively. Casts
uses inside Calls, PtrToInts, Stores where the pointer is a _value
operand_ to generic address space, immediately before the escape, while
keeping other uses in the param address space

- Related to: llvm#96125
@Artem-B
Member

Artem-B commented Oct 11, 2024

@akshayrdeodhar, @jholewinski -- is cvta.param intended to be used on any .param data? Can it be used in __device__ functions if we know we never write to the pointer? Or is it intended to be used on kernels only?

NVCC appears to use it for kernels only, but I can't tell whether that's an inherent restriction on the instruction, or whether it just happens that NVCC emits it only for parameters marked __grid_constant__, which can only be used on kernels.

PTX specs do not seem to specify it explicitly. They do link to "kernel function parameters" when they mention .param, but other than that it reads as if the instruction is applicable to all .param space data.
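For context, a hedged IR-level sketch of the only situation in which this pass currently generates the conversion (names are illustrative, kernel/grid_constant metadata elided): a grid_constant byval kernel argument whose address escapes into memory. Whether the underlying PTX instruction is architecturally restricted to kernel parameters is exactly the open question above.

; Sketch: the pass emits the intrinsic below only for grid_constant byval
; kernel parameters; the NVPTX backend then selects cvta.param (sm_70+).
define void @kernel(ptr byval(i32) align 4 %input, ptr %out) {
  %input.param = addrspacecast ptr %input to ptr addrspace(101)
  ; storing the parameter's address is an escape, so it is first converted
  ; to the generic address space
  %input.gen = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) %input.param)
  store ptr %input.gen, ptr %out
  ret void
}

declare ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101))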
