Skip to content

[NVPTX] Improved support for grid_constant #97112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 136 additions & 37 deletions llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
// http://docs.nvidia.com/cuda/parallel-thread-execution/#state-spaces
//
// Kernel parameters are read-only and accessible only via ld.param
// instruction, directly or via a pointer. Pointers to kernel
// arguments can't be converted to generic address space.
// instruction, directly or via a pointer.
//
// Device function parameters are directly accessible via
// ld.param/st.param, but taking the address of one returns a pointer
Expand Down Expand Up @@ -54,8 +53,10 @@
// ...
// }
//
// 2. Convert pointers in a byval kernel parameter to pointers in the global
// address space. As #2, it allows NVPTX to emit more ld/st.global. E.g.,
// 2. Convert byval kernel parameters to pointers in the param address space
// (so that NVPTX emits ld/st.param). Convert pointers *within* a byval
// kernel parameter to pointers in the global address space. This allows
// NVPTX to emit ld/st.global.
//
// struct S {
// int *x;
Expand All @@ -68,22 +69,68 @@
//
// "b" points to the global address space. In the IR level,
//
// define void @foo({i32*, i32*}* byval %input) {
// %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
// %b = load i32*, i32** %b_ptr
// define void @foo(ptr byval %input) {
// %b_ptr = getelementptr {ptr, ptr}, ptr %input, i64 0, i32 1
// %b = load ptr, ptr %b_ptr
// ; use %b
// }
//
// becomes
//
// define void @foo({i32*, i32*}* byval %input) {
// %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
// %b = load i32*, i32** %b_ptr
// %b_global = addrspacecast i32* %b to i32 addrspace(1)*
// %b_generic = addrspacecast i32 addrspace(1)* %b_global to i32*
// %b_param = addrspacecast ptr %input to ptr addrspace(101)
// %b_ptr = getelementptr {ptr, ptr}, ptr addrspace(101) %b_param, i64 0, i32 1
// %b = load ptr, ptr addrspace(101) %b_ptr
// %b_global = addrspacecast ptr %b to ptr addrspace(1)
// ; use %b_global
// }
//
// Create a local copy of kernel byval parameters used in a way that *might* mutate
// the parameter, by storing it in an alloca. Mutations to "grid_constant" parameters
// are undefined behaviour, and don't require local copies.
//
// define void @foo(ptr byval(%struct.s) align 4 %input) {
// store i32 42, ptr %input
// ret void
// }
//
// becomes
//
// define void @foo(ptr byval(%struct.s) align 4 %input) #1 {
// %input1 = alloca %struct.s, align 4
// %input2 = addrspacecast ptr %input to ptr addrspace(101)
// %input3 = load %struct.s, ptr addrspace(101) %input2, align 4
// store %struct.s %input3, ptr %input1, align 4
// store i32 42, ptr %input1, align 4
// ret void
// }
//
// If %input were passed to a device function, or written to memory,
// conservatively assume that %input gets mutated, and create a local copy.
//
// Convert param pointers to grid_constant byval kernel parameters that are
// passed into calls (device functions, intrinsics, inline asm), or otherwise
// "escape" (into stores/ptrtoints) to the generic address space, using the
// `nvvm.ptr.param.to.gen` intrinsic, so that NVPTX emits cvta.param
// (available for sm70+)
//
// define void @foo(ptr byval(%struct.s) %input) {
// ; %input is a grid_constant
// %call = call i32 @escape(ptr %input)
// ret void
// }
//
// becomes
//
// define void @foo(ptr byval(%struct.s) %input) {
// %input1 = addrspacecast ptr %input to ptr addrspace(101)
// ; the following intrinsic converts pointer to generic. We don't use an addrspacecast
// ; to prevent generic -> param -> generic from getting cancelled out
// %input1.gen = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) %input1)
// %call = call i32 @escape(ptr %input1.gen)
// ret void
// }
//
// TODO: merge this pass with NVPTXInferAddressSpaces so that other passes don't
// cancel the addrspacecast pair this pass emits.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -166,19 +213,22 @@ INITIALIZE_PASS_END(NVPTXLowerArgs, "nvptx-lower-args",
// ones in parameter AS, so we can access them using ld.param.
// =============================================================================

// Replaces the \p OldUser instruction with the same in parameter AS.
// Only Load and GEP are supported.
static void convertToParamAS(Value *OldUser, Value *Param) {
Instruction *I = dyn_cast<Instruction>(OldUser);
assert(I && "OldUser must be an instruction");
// For Loads, replaces the \p OldUse of the pointer with a Use of the same
// pointer in parameter AS.
// For "escapes" (to memory, a function call, or a ptrtoint), cast the OldUse to
// generic using cvta.param.
static void convertToParamAS(Use *OldUse, Value *Param, bool GridConstant) {
Instruction *I = dyn_cast<Instruction>(OldUse->getUser());
assert(I && "OldUse must be in an instruction");
struct IP {
Use *OldUse;
Instruction *OldInstruction;
Value *NewParam;
};
SmallVector<IP> ItemsToConvert = {{I, Param}};
SmallVector<IP> ItemsToConvert = {{OldUse, I, Param}};
SmallVector<Instruction *> InstructionsToDelete;

auto CloneInstInParamAS = [](const IP &I) -> Value * {
auto CloneInstInParamAS = [GridConstant](const IP &I) -> Value * {
if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction)) {
LI->setOperand(0, I.NewParam);
return LI;
Expand All @@ -202,6 +252,43 @@ static void convertToParamAS(Value *OldUser, Value *Param) {
// Just pass through the argument, the old ASC is no longer needed.
return I.NewParam;
}

if (GridConstant) {
auto GetParamAddrCastToGeneric =
[](Value *Addr, Instruction *OriginalUser) -> Value * {
PointerType *ReturnTy =
PointerType::get(OriginalUser->getContext(), ADDRESS_SPACE_GENERIC);
Function *CvtToGen = Intrinsic::getDeclaration(
OriginalUser->getModule(), Intrinsic::nvvm_ptr_param_to_gen,
{ReturnTy, PointerType::get(OriginalUser->getContext(),
ADDRESS_SPACE_PARAM)});

// Cast param address to generic address space
Value *CvtToGenCall =
CallInst::Create(CvtToGen, Addr, Addr->getName() + ".gen",
OriginalUser->getIterator());
return CvtToGenCall;
};

if (auto *CI = dyn_cast<CallInst>(I.OldInstruction)) {
I.OldUse->set(GetParamAddrCastToGeneric(I.NewParam, CI));
return CI;
}
if (auto *SI = dyn_cast<StoreInst>(I.OldInstruction)) {
// byval address is being stored, cast it to generic
if (SI->getValueOperand() == I.OldUse->get())
SI->setOperand(0, GetParamAddrCastToGeneric(I.NewParam, SI));
return SI;
}
if (auto *PI = dyn_cast<PtrToIntInst>(I.OldInstruction)) {
if (PI->getPointerOperand() == I.OldUse->get())
PI->setOperand(0, GetParamAddrCastToGeneric(I.NewParam, PI));
return PI;
}
llvm_unreachable(
"Instruction unsupported even for grid_constant argument");
}

llvm_unreachable("Unsupported instruction");
};

Expand All @@ -213,8 +300,8 @@ static void convertToParamAS(Value *OldUser, Value *Param) {
// We've created a new instruction. Queue users of the old instruction to
// be converted and the instruction itself to be deleted. We can't delete
// the old instruction yet, because it's still in use by a load somewhere.
for (Value *V : I.OldInstruction->users())
ItemsToConvert.push_back({cast<Instruction>(V), NewInst});
for (Use &U : I.OldInstruction->uses())
ItemsToConvert.push_back({&U, cast<Instruction>(U.getUser()), NewInst});

InstructionsToDelete.push_back(I.OldInstruction);
}
Expand Down Expand Up @@ -272,6 +359,7 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
SmallVector<Load> Loads;
std::queue<LoadContext> Worklist;
Worklist.push({ArgInParamAS, 0});
bool IsGridConstant = isParamGridConstant(*Arg);

while (!Worklist.empty()) {
LoadContext Ctx = Worklist.front();
Expand Down Expand Up @@ -303,8 +391,14 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
continue;
}

// supported for grid_constant
if (IsGridConstant &&
(isa<CallInst>(CurUser) || isa<StoreInst>(CurUser) ||
isa<PtrToIntInst>(CurUser)))
continue;

llvm_unreachable("All users must be one of: load, "
"bitcast, getelementptr.");
"bitcast, getelementptr, call, store, ptrtoint");
}
}

Expand All @@ -317,49 +411,59 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,

void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
Argument *Arg) {
bool IsGridConstant = isParamGridConstant(*Arg);
Function *Func = Arg->getParent();
BasicBlock::iterator FirstInst = Func->getEntryBlock().begin();
Type *StructType = Arg->getParamByValType();
assert(StructType && "Missing byval type");

auto IsALoadChain = [&](Value *Start) {
auto AreSupportedUsers = [&](Value *Start) {
SmallVector<Value *, 16> ValuesToCheck = {Start};
auto IsALoadChainInstr = [](Value *V) -> bool {
auto IsSupportedUse = [IsGridConstant](Value *V) -> bool {
if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V))
return true;
// ASC to param space are OK, too -- we'll just strip them.
if (auto *ASC = dyn_cast<AddrSpaceCastInst>(V)) {
if (ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM)
return true;
}
// Simple calls and stores are supported for grid_constants
// writes to these pointers are undefined behaviour
if (IsGridConstant &&
(isa<CallInst>(V) || isa<StoreInst>(V) || isa<PtrToIntInst>(V)))
return true;
return false;
};

while (!ValuesToCheck.empty()) {
Value *V = ValuesToCheck.pop_back_val();
if (!IsALoadChainInstr(V)) {
if (!IsSupportedUse(V)) {
LLVM_DEBUG(dbgs() << "Need a "
<< (isParamGridConstant(*Arg) ? "cast " : "copy ")
<< "of " << *Arg << " because of " << *V << "\n");
(void)Arg;
return false;
}
if (!isa<LoadInst>(V))
if (!isa<LoadInst>(V) && !isa<CallInst>(V) && !isa<StoreInst>(V) &&
!isa<PtrToIntInst>(V))
llvm::append_range(ValuesToCheck, V->users());
}
return true;
};

if (llvm::all_of(Arg->users(), IsALoadChain)) {
if (llvm::all_of(Arg->users(), AreSupportedUsers)) {
// Convert all loads and intermediate operations to use parameter AS and
// skip creation of a local copy of the argument.
SmallVector<User *, 16> UsersToUpdate(Arg->users());
SmallVector<Use *, 16> UsesToUpdate;
for (Use &U : Arg->uses())
UsesToUpdate.push_back(&U);

Value *ArgInParamAS = new AddrSpaceCastInst(
Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
FirstInst);
for (Value *V : UsersToUpdate)
convertToParamAS(V, ArgInParamAS);
LLVM_DEBUG(dbgs() << "No need to copy " << *Arg << "\n");
for (Use *U : UsesToUpdate)
convertToParamAS(U, ArgInParamAS, IsGridConstant);
LLVM_DEBUG(dbgs() << "No need to copy or cast " << *Arg << "\n");

const auto *TLI =
cast<NVPTXTargetLowering>(TM.getSubtargetImpl()->getTargetLowering());
Expand All @@ -376,16 +480,11 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
// temporary copy. When a pointer might have escaped, conservatively replace
// all of its uses (which might include a device function call) with a cast
// to the generic address space.
// TODO: only cast byval grid constant parameters at use points that need
// generic address (e.g., merging parameter pointers with other address
// space, or escaping to call-sites, inline-asm, memory), and use the
// parameter address space for normal loads.
IRBuilder<> IRB(&Func->getEntryBlock().front());

// Cast argument to param address space
auto *CastToParam =
cast<AddrSpaceCastInst>(IRB.CreateAddrSpaceCast(
Arg, IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getName() + ".param"));
auto *CastToParam = cast<AddrSpaceCastInst>(IRB.CreateAddrSpaceCast(
Arg, IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getName() + ".param"));

// Cast param address to generic address space. We do not use an
// addrspacecast to generic here, because, LLVM considers `Arg` to be in the
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@ bool isParamGridConstant(const Value &V) {
if (const Argument *Arg = dyn_cast<Argument>(&V)) {
// "grid_constant" counts argument indices starting from 1
if (Arg->hasByValAttr() &&
argHasNVVMAnnotation(*Arg, "grid_constant", /*StartArgIndexAtOne*/true)) {
argHasNVVMAnnotation(*Arg, "grid_constant",
/*StartArgIndexAtOne*/ true)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: undo this nop change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clang-format 😅 (was going above 80 chars in the previous MR)

assert(isKernelFunction(*Arg->getParent()) &&
"only kernel arguments can be grid_constant");
return true;
Expand Down
Loading
Loading