Commit 9fa7c05

[NVPTX] Improved support for grid_constant (#97112)
- Supports escaped grid_constant pointers less conservatively: uses inside Calls, PtrToInts, and Stores where the pointer is a _value operand_ are cast to the generic address space immediately before the escape, while other uses stay in the param address space.
- Related to: #96125
1 parent 56a636f commit 9fa7c05
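
For orientation, here is a condensed sketch of the transformation the commit message describes, adapted from the examples in the updated NVPTXLowerArgs.cpp header comment below (the struct layout and the @escape callee are illustrative only, not part of this commit). A grid_constant byval argument whose pointer escapes into a call is cast into the param address space and then converted to a generic pointer via the nvvm.ptr.param.to.gen intrinsic (so NVPTX can emit cvta.param), while non-escaping loads keep addressing the param space:

  %struct.s = type { i32 }
  declare i32 @escape(ptr)

  ; before the pass: %input is a grid_constant byval kernel argument
  define void @foo(ptr byval(%struct.s) %input) {
    %call = call i32 @escape(ptr %input)   ; pointer escapes into a call
    ret void
  }

  ; after the pass (roughly): only the escaping use is converted to generic
  define void @foo(ptr byval(%struct.s) %input) {
    %input.param = addrspacecast ptr %input to ptr addrspace(101)
    %input.gen = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) %input.param)
    %call = call i32 @escape(ptr %input.gen)
    ret void
  }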

File tree

3 files changed: +365 -55 lines changed

llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp

Lines changed: 136 additions & 37 deletions
@@ -12,8 +12,7 @@
 // http://docs.nvidia.com/cuda/parallel-thread-execution/#state-spaces
 //
 // Kernel parameters are read-only and accessible only via ld.param
-// instruction, directly or via a pointer. Pointers to kernel
-// arguments can't be converted to generic address space.
+// instruction, directly or via a pointer.
 //
 // Device function parameters are directly accessible via
 // ld.param/st.param, but taking the address of one returns a pointer
@@ -54,8 +53,10 @@
 // ...
 // }
 //
-// 2. Convert pointers in a byval kernel parameter to pointers in the global
-//    address space. As #2, it allows NVPTX to emit more ld/st.global. E.g.,
+// 2. Convert byval kernel parameters to pointers in the param address space
+//    (so that NVPTX emits ld/st.param). Convert pointers *within* a byval
+//    kernel parameter to pointers in the global address space. This allows
+//    NVPTX to emit ld/st.global.
 //
 //    struct S {
 //      int *x;
@@ -68,22 +69,68 @@
 //
 //    "b" points to the global address space. In the IR level,
 //
-//    define void @foo({i32*, i32*}* byval %input) {
-//      %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
-//      %b = load i32*, i32** %b_ptr
+//    define void @foo(ptr byval %input) {
+//      %b_ptr = getelementptr {ptr, ptr}, ptr %input, i64 0, i32 1
+//      %b = load ptr, ptr %b_ptr
 //      ; use %b
 //    }
 //
 //    becomes
 //
 //    define void @foo({i32*, i32*}* byval %input) {
-//      %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
-//      %b = load i32*, i32** %b_ptr
-//      %b_global = addrspacecast i32* %b to i32 addrspace(1)*
-//      %b_generic = addrspacecast i32 addrspace(1)* %b_global to i32*
+//      %b_param = addrspacecast ptr %input to ptr addrspace(101)
+//      %b_ptr = getelementptr {ptr, ptr}, ptr addrspace(101) %b_param, i64 0, i32 1
+//      %b = load ptr, ptr addrspace(101) %b_ptr
+//      %b_global = addrspacecast ptr %b to ptr addrspace(1)
 //      ; use %b_generic
 //    }
 //
+//    Create a local copy of kernel byval parameters used in a way that *might* mutate
+//    the parameter, by storing it in an alloca. Mutations to "grid_constant" parameters
+//    are undefined behaviour, and don't require local copies.
+//
+//    define void @foo(ptr byval(%struct.s) align 4 %input) {
+//      store i32 42, ptr %input
+//      ret void
+//    }
+//
+//    becomes
+//
+//    define void @foo(ptr byval(%struct.s) align 4 %input) #1 {
+//      %input1 = alloca %struct.s, align 4
+//      %input2 = addrspacecast ptr %input to ptr addrspace(101)
+//      %input3 = load %struct.s, ptr addrspace(101) %input2, align 4
+//      store %struct.s %input3, ptr %input1, align 4
+//      store i32 42, ptr %input1, align 4
+//      ret void
+//    }
+//
+//    If %input were passed to a device function, or written to memory,
+//    conservatively assume that %input gets mutated, and create a local copy.
+//
+//    Convert param pointers to grid_constant byval kernel parameters that are
+//    passed into calls (device functions, intrinsics, inline asm), or otherwise
+//    "escape" (into stores/ptrtoints) to the generic address space, using the
+//    `nvvm.ptr.param.to.gen` intrinsic, so that NVPTX emits cvta.param
+//    (available for sm70+)
+//
+//    define void @foo(ptr byval(%struct.s) %input) {
+//      ; %input is a grid_constant
+//      %call = call i32 @escape(ptr %input)
+//      ret void
+//    }
+//
+//    becomes
+//
+//    define void @foo(ptr byval(%struct.s) %input) {
+//      %input1 = addrspacecast ptr %input to ptr addrspace(101)
+//      ; the following intrinsic converts pointer to generic. We don't use an addrspacecast
+//      ; to prevent generic -> param -> generic from getting cancelled out
+//      %input1.gen = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) %input1)
+//      %call = call i32 @escape(ptr %input1.gen)
+//      ret void
+//    }
+//
 // TODO: merge this pass with NVPTXInferAddressSpaces so that other passes don't
 // cancel the addrspacecast pair this pass emits.
 //===----------------------------------------------------------------------===//
@@ -166,19 +213,22 @@ INITIALIZE_PASS_END(NVPTXLowerArgs, "nvptx-lower-args",
 // ones in parameter AS, so we can access them using ld.param.
 // =============================================================================
 
-// Replaces the \p OldUser instruction with the same in parameter AS.
-// Only Load and GEP are supported.
-static void convertToParamAS(Value *OldUser, Value *Param) {
-  Instruction *I = dyn_cast<Instruction>(OldUser);
-  assert(I && "OldUser must be an instruction");
+// For Loads, replaces the \p OldUse of the pointer with a Use of the same
+// pointer in parameter AS.
+// For "escapes" (to memory, a function call, or a ptrtoint), cast the OldUse to
+// generic using cvta.param.
+static void convertToParamAS(Use *OldUse, Value *Param, bool GridConstant) {
+  Instruction *I = dyn_cast<Instruction>(OldUse->getUser());
+  assert(I && "OldUse must be in an instruction");
   struct IP {
+    Use *OldUse;
     Instruction *OldInstruction;
     Value *NewParam;
   };
-  SmallVector<IP> ItemsToConvert = {{I, Param}};
+  SmallVector<IP> ItemsToConvert = {{OldUse, I, Param}};
   SmallVector<Instruction *> InstructionsToDelete;
 
-  auto CloneInstInParamAS = [](const IP &I) -> Value * {
+  auto CloneInstInParamAS = [GridConstant](const IP &I) -> Value * {
     if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction)) {
       LI->setOperand(0, I.NewParam);
       return LI;
@@ -202,6 +252,43 @@ static void convertToParamAS(Value *OldUser, Value *Param) {
       // Just pass through the argument, the old ASC is no longer needed.
       return I.NewParam;
     }
+
+    if (GridConstant) {
+      auto GetParamAddrCastToGeneric =
+          [](Value *Addr, Instruction *OriginalUser) -> Value * {
+        PointerType *ReturnTy =
+            PointerType::get(OriginalUser->getContext(), ADDRESS_SPACE_GENERIC);
+        Function *CvtToGen = Intrinsic::getDeclaration(
+            OriginalUser->getModule(), Intrinsic::nvvm_ptr_param_to_gen,
+            {ReturnTy, PointerType::get(OriginalUser->getContext(),
+                                        ADDRESS_SPACE_PARAM)});
+
+        // Cast param address to generic address space
+        Value *CvtToGenCall =
+            CallInst::Create(CvtToGen, Addr, Addr->getName() + ".gen",
+                             OriginalUser->getIterator());
+        return CvtToGenCall;
+      };
+
+      if (auto *CI = dyn_cast<CallInst>(I.OldInstruction)) {
+        I.OldUse->set(GetParamAddrCastToGeneric(I.NewParam, CI));
+        return CI;
+      }
+      if (auto *SI = dyn_cast<StoreInst>(I.OldInstruction)) {
+        // byval address is being stored, cast it to generic
+        if (SI->getValueOperand() == I.OldUse->get())
+          SI->setOperand(0, GetParamAddrCastToGeneric(I.NewParam, SI));
+        return SI;
+      }
+      if (auto *PI = dyn_cast<PtrToIntInst>(I.OldInstruction)) {
+        if (PI->getPointerOperand() == I.OldUse->get())
+          PI->setOperand(0, GetParamAddrCastToGeneric(I.NewParam, PI));
+        return PI;
+      }
+      llvm_unreachable(
+          "Instruction unsupported even for grid_constant argument");
+    }
+
     llvm_unreachable("Unsupported instruction");
   };
 
@@ -213,8 +300,8 @@ static void convertToParamAS(Value *OldUser, Value *Param) {
     // We've created a new instruction. Queue users of the old instruction to
     // be converted and the instruction itself to be deleted. We can't delete
     // the old instruction yet, because it's still in use by a load somewhere.
-    for (Value *V : I.OldInstruction->users())
-      ItemsToConvert.push_back({cast<Instruction>(V), NewInst});
+    for (Use &U : I.OldInstruction->uses())
+      ItemsToConvert.push_back({&U, cast<Instruction>(U.getUser()), NewInst});
 
     InstructionsToDelete.push_back(I.OldInstruction);
   }
@@ -272,6 +359,7 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
   SmallVector<Load> Loads;
   std::queue<LoadContext> Worklist;
   Worklist.push({ArgInParamAS, 0});
+  bool IsGridConstant = isParamGridConstant(*Arg);
 
   while (!Worklist.empty()) {
     LoadContext Ctx = Worklist.front();
@@ -303,8 +391,14 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
       continue;
     }
 
+    // supported for grid_constant
+    if (IsGridConstant &&
+        (isa<CallInst>(CurUser) || isa<StoreInst>(CurUser) ||
+         isa<PtrToIntInst>(CurUser)))
+      continue;
+
     llvm_unreachable("All users must be one of: load, "
-                     "bitcast, getelementptr.");
+                     "bitcast, getelementptr, call, store, ptrtoint");
   }
 }
 
@@ -317,49 +411,59 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
 
 void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
                                       Argument *Arg) {
+  bool IsGridConstant = isParamGridConstant(*Arg);
   Function *Func = Arg->getParent();
   BasicBlock::iterator FirstInst = Func->getEntryBlock().begin();
   Type *StructType = Arg->getParamByValType();
   assert(StructType && "Missing byval type");
 
-  auto IsALoadChain = [&](Value *Start) {
+  auto AreSupportedUsers = [&](Value *Start) {
     SmallVector<Value *, 16> ValuesToCheck = {Start};
-    auto IsALoadChainInstr = [](Value *V) -> bool {
+    auto IsSupportedUse = [IsGridConstant](Value *V) -> bool {
       if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V))
        return true;
      // ASC to param space are OK, too -- we'll just strip them.
      if (auto *ASC = dyn_cast<AddrSpaceCastInst>(V)) {
        if (ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM)
          return true;
      }
+      // Simple calls and stores are supported for grid_constants
+      // writes to these pointers are undefined behaviour
+      if (IsGridConstant &&
+          (isa<CallInst>(V) || isa<StoreInst>(V) || isa<PtrToIntInst>(V)))
+        return true;
       return false;
     };
 
     while (!ValuesToCheck.empty()) {
       Value *V = ValuesToCheck.pop_back_val();
-      if (!IsALoadChainInstr(V)) {
+      if (!IsSupportedUse(V)) {
         LLVM_DEBUG(dbgs() << "Need a "
                           << (isParamGridConstant(*Arg) ? "cast " : "copy ")
                           << "of " << *Arg << " because of " << *V << "\n");
         (void)Arg;
         return false;
       }
-      if (!isa<LoadInst>(V))
+      if (!isa<LoadInst>(V) && !isa<CallInst>(V) && !isa<StoreInst>(V) &&
+          !isa<PtrToIntInst>(V))
         llvm::append_range(ValuesToCheck, V->users());
     }
     return true;
   };
 
-  if (llvm::all_of(Arg->users(), IsALoadChain)) {
+  if (llvm::all_of(Arg->users(), AreSupportedUsers)) {
     // Convert all loads and intermediate operations to use parameter AS and
     // skip creation of a local copy of the argument.
-    SmallVector<User *, 16> UsersToUpdate(Arg->users());
+    SmallVector<Use *, 16> UsesToUpdate;
+    for (Use &U : Arg->uses())
+      UsesToUpdate.push_back(&U);
+
     Value *ArgInParamAS = new AddrSpaceCastInst(
         Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
         FirstInst);
-    for (Value *V : UsersToUpdate)
-      convertToParamAS(V, ArgInParamAS);
-    LLVM_DEBUG(dbgs() << "No need to copy " << *Arg << "\n");
+    for (Use *U : UsesToUpdate)
+      convertToParamAS(U, ArgInParamAS, IsGridConstant);
+    LLVM_DEBUG(dbgs() << "No need to copy or cast " << *Arg << "\n");
 
     const auto *TLI =
         cast<NVPTXTargetLowering>(TM.getSubtargetImpl()->getTargetLowering());
@@ -376,16 +480,11 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
   // temporary copy. When a pointer might have escaped, conservatively replace
   // all of its uses (which might include a device function call) with a cast
   // to the generic address space.
-  // TODO: only cast byval grid constant parameters at use points that need
-  // generic address (e.g., merging parameter pointers with other address
-  // space, or escaping to call-sites, inline-asm, memory), and use the
-  // parameter address space for normal loads.
   IRBuilder<> IRB(&Func->getEntryBlock().front());
 
   // Cast argument to param address space
-  auto *CastToParam =
-      cast<AddrSpaceCastInst>(IRB.CreateAddrSpaceCast(
-          Arg, IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getName() + ".param"));
+  auto *CastToParam = cast<AddrSpaceCastInst>(IRB.CreateAddrSpaceCast(
+      Arg, IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getName() + ".param"));
 
   // Cast param address to generic address space. We do not use an
   // addrspacecast to generic here, because, LLVM considers `Arg` to be in the

llvm/lib/Target/NVPTX/NVPTXUtilities.cpp

Lines changed: 2 additions & 1 deletion
@@ -210,7 +210,8 @@ bool isParamGridConstant(const Value &V) {
   if (const Argument *Arg = dyn_cast<Argument>(&V)) {
     // "grid_constant" counts argument indices starting from 1
     if (Arg->hasByValAttr() &&
-        argHasNVVMAnnotation(*Arg, "grid_constant", /*StartArgIndexAtOne*/true)) {
+        argHasNVVMAnnotation(*Arg, "grid_constant",
+                             /*StartArgIndexAtOne*/ true)) {
       assert(isKernelFunction(*Arg->getParent()) &&
              "only kernel arguments can be grid_constant");
       return true;
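
For reference, a minimal sketch of how an argument ends up marked as a grid_constant at the IR level, assuming the !nvvm.annotations shape used by the existing NVPTX tests (the kernel name and struct here are made up): the argument is listed by its 1-based index in a "grid_constant" entry, which is what argHasNVVMAnnotation looks up above.

  %struct.s = type { i32 }

  define void @kernel(ptr byval(%struct.s) align 4 %input) {
    ret void
  }

  !nvvm.annotations = !{!0}
  ; argument 1 (%input) is a grid_constant; indices start at 1
  !0 = !{ptr @kernel, !"kernel", i32 1, !"grid_constant", !1}
  !1 = !{i32 1}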

0 commit comments
