[NVPTX] Improve kernel byval parameter lowering #136008

Merged
16 changes: 16 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1909,6 +1909,22 @@ def int_nvvm_ptr_param_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty],
[IntrNoMem, IntrSpeculatable, IntrNoCallback],
"llvm.nvvm.ptr.param.to.gen">;

// Represents an explicit hole in the LLVM IR type system. It may be inserted by
// the compiler in cases where a pointer is of the wrong type. In the backend
// this intrinsic will be folded away and not equate to any instruction. It
// should not be used by any frontend and should only be considered well defined
// when added in the following cases:
//
// - NVPTXLowerArgs: When wrapping a byval pointer argument to a kernel
// function to convert the address space from generic (0) to param (101).
// This accounts for the fact that the parameter symbols will occupy this
// space when lowered during ISel.
//
def int_nvvm_internal_addrspace_wrap :
DefaultAttrsIntrinsic<[llvm_anyptr_ty], [llvm_anyptr_ty],
[IntrNoMem, IntrSpeculatable, NoUndef<ArgIndex<0>>,
NoUndef<RetIndex>]>;

// Move intrinsics, used in nvvm internally

def int_nvvm_move_i16 : Intrinsic<[llvm_i16_ty], [llvm_i16_ty], [IntrNoMem],
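
To make the intended use concrete, here is a minimal sketch of the IR NVPTXLowerArgs produces with the new intrinsic (kernel and struct names are illustrative, modeled on the updated bug21465.ll test):

```llvm
; Illustrative only: after NVPTXLowerArgs, direct uses of a byval kernel
; argument read straight from the param address space (101) through the
; wrap intrinsic instead of an addrspacecast.
%struct.S = type { i32, i32 }

declare ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr)

define ptx_kernel void @takes_struct(ptr byval(%struct.S) align 4 %input, ptr %output) {
  %input.param = call ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr %input)
  %b = getelementptr inbounds %struct.S, ptr addrspace(101) %input.param, i64 0, i32 1
  %v = load i32, ptr addrspace(101) %b, align 4
  store i32 %v, ptr %output, align 4
  ret void
}
```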
5 changes: 4 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -985,6 +985,9 @@ void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) {
case ADDRESS_SPACE_LOCAL:
Opc = TM.is64Bit() ? NVPTX::cvta_local_64 : NVPTX::cvta_local;
break;
case ADDRESS_SPACE_PARAM:
Opc = TM.is64Bit() ? NVPTX::cvta_param_64 : NVPTX::cvta_param;
break;
}
ReplaceNode(N, CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src));
return;
@@ -1008,7 +1011,7 @@ void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) {
Opc = TM.is64Bit() ? NVPTX::cvta_to_local_64 : NVPTX::cvta_to_local;
break;
case ADDRESS_SPACE_PARAM:
Opc = TM.is64Bit() ? NVPTX::IMOV64r : NVPTX::IMOV32r;
Opc = TM.is64Bit() ? NVPTX::cvta_to_param_64 : NVPTX::cvta_to_param;
break;
}

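
A quick way to see the effect of the selection change above is a pair of stand-alone casts (illustrative functions, not taken from the test suite); with this patch both directions lower to real conversion instructions rather than plain register moves:

```llvm
; Illustrative only: each addrspacecast below now selects a cvta instruction.
define ptr @param_to_generic(ptr addrspace(101) %p) {
  ; selects cvta.param.u64 on 64-bit targets
  %g = addrspacecast ptr addrspace(101) %p to ptr
  ret ptr %g
}

define ptr addrspace(101) @generic_to_param(ptr %g) {
  ; selects cvta.to.param.u64 (previously a plain mov)
  %p = addrspacecast ptr %g to ptr addrspace(101)
  ret ptr addrspace(101) %p
}
```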
24 changes: 24 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1014,6 +1014,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
{MVT::v2i32, MVT::v4i32, MVT::v8i32, MVT::v16i32,
MVT::v32i32, MVT::v64i32, MVT::v128i32},
Custom);

setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
}

const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
@@ -1426,6 +1428,17 @@ static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG,

return MachinePointerInfo(ADDRESS_SPACE_LOCAL);
}

// Peel off an addrspacecast to generic and load directly from the specific
// address space.
if (Ptr->getOpcode() == ISD::ADDRSPACECAST) {
const auto *ASC = cast<AddrSpaceCastSDNode>(Ptr);
if (ASC->getDestAddressSpace() == ADDRESS_SPACE_GENERIC) {
Ptr = ASC->getOperand(0);
return MachinePointerInfo(ASC->getSrcAddressSpace());
}
}

return MachinePointerInfo();
}
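
For example (a sketch, not one of the tests in this patch), the peeling added here means a load whose address is a generic cast of a param pointer can be lowered as a direct ld.param:

```llvm
; Illustrative only: refinePtrAS strips the cast to generic and records
; address space 101 in the MachinePointerInfo, so this load can become a
; direct ld.param instead of cvta.param followed by a generic load.
define i32 @load_through_cast(ptr addrspace(101) %p) {
  %g = addrspacecast ptr addrspace(101) %p to ptr
  %v = load i32, ptr %g, align 4
  ret i32 %v
}
```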

@@ -2746,6 +2759,15 @@ static SDValue LowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG) {
return Op;
}

static SDValue lowerIntrinsicWOChain(SDValue Op, SelectionDAG &DAG) {
switch (Op->getConstantOperandVal(0)) {
default:
return Op;
case Intrinsic::nvvm_internal_addrspace_wrap:
return Op.getOperand(1);
}
}

// In PTX 64-bit CTLZ and CTPOP are supported, but they return a 32-bit value.
// Lower these into a node returning the correct type which is zero-extended
// back to the correct size.
@@ -2889,6 +2911,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return LowerGlobalAddress(Op, DAG);
case ISD::INTRINSIC_W_CHAIN:
return Op;
case ISD::INTRINSIC_WO_CHAIN:
return lowerIntrinsicWOChain(Op, DAG);
case ISD::INTRINSIC_VOID:
return LowerIntrinsicVoid(Op, DAG);
case ISD::BUILD_VECTOR:
16 changes: 4 additions & 12 deletions llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -2395,18 +2395,10 @@ multiclass G_TO_NG<string Str> {
"cvta.to." # Str # ".u64 \t$result, $src;", []>;
}

defm cvta_local : NG_TO_G<"local">;
defm cvta_shared : NG_TO_G<"shared">;
defm cvta_global : NG_TO_G<"global">;
defm cvta_const : NG_TO_G<"const">;

defm cvta_to_local : G_TO_NG<"local">;
defm cvta_to_shared : G_TO_NG<"shared">;
defm cvta_to_global : G_TO_NG<"global">;
defm cvta_to_const : G_TO_NG<"const">;

// nvvm.ptr.param.to.gen
defm cvta_param : NG_TO_G<"param">;
foreach space = ["local", "shared", "global", "const", "param"] in {
defm cvta_#space : NG_TO_G<space>;
defm cvta_to_#space : G_TO_NG<space>;
}

def : Pat<(int_nvvm_ptr_param_to_gen i32:$src),
(cvta_param $src)>;
89 changes: 43 additions & 46 deletions llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -265,18 +265,9 @@ static void convertToParamAS(Use *OldUse, Value *Param, bool HasCvtaParam,
if (HasCvtaParam) {
auto GetParamAddrCastToGeneric =
[](Value *Addr, Instruction *OriginalUser) -> Value * {
PointerType *ReturnTy =
PointerType::get(OriginalUser->getContext(), ADDRESS_SPACE_GENERIC);
Function *CvtToGen = Intrinsic::getOrInsertDeclaration(
OriginalUser->getModule(), Intrinsic::nvvm_ptr_param_to_gen,
{ReturnTy, PointerType::get(OriginalUser->getContext(),
ADDRESS_SPACE_PARAM)});

// Cast param address to generic address space
Value *CvtToGenCall =
CallInst::Create(CvtToGen, Addr, Addr->getName() + ".gen",
OriginalUser->getIterator());
return CvtToGenCall;
IRBuilder<> IRB(OriginalUser);
Type *GenTy = IRB.getPtrTy(ADDRESS_SPACE_GENERIC);
return IRB.CreateAddrSpaceCast(Addr, GenTy, Addr->getName() + ".gen");
};
auto *ParamInGenericAS =
GetParamAddrCastToGeneric(I.NewParam, I.OldInstruction);
@@ -515,33 +506,34 @@ void copyByValParam(Function &F, Argument &Arg) {
BasicBlock::iterator FirstInst = F.getEntryBlock().begin();
Type *StructType = Arg.getParamByValType();
const DataLayout &DL = F.getDataLayout();
AllocaInst *AllocA = new AllocaInst(StructType, DL.getAllocaAddrSpace(),
Arg.getName(), FirstInst);
IRBuilder<> IRB(&*FirstInst);
AllocaInst *AllocA = IRB.CreateAlloca(StructType, nullptr, Arg.getName());
// Set the alignment to the alignment of the byval parameter, because later
// loads and stores assume that alignment and we are going to replace all
// uses of the byval parameter with this alloca instruction.
AllocA->setAlignment(F.getParamAlign(Arg.getArgNo())
.value_or(DL.getPrefTypeAlign(StructType)));
AllocA->setAlignment(
Arg.getParamAlign().value_or(DL.getPrefTypeAlign(StructType)));
Arg.replaceAllUsesWith(AllocA);

Value *ArgInParam = new AddrSpaceCastInst(
&Arg, PointerType::get(Arg.getContext(), ADDRESS_SPACE_PARAM),
Arg.getName(), FirstInst);
Value *ArgInParam =
IRB.CreateIntrinsic(Intrinsic::nvvm_internal_addrspace_wrap,
{IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg.getType()},
&Arg, {}, Arg.getName());

// Be sure to propagate the alignment to this memcpy; LLVM doesn't know that
// the NVPTX param-space cast preserves alignment. Since params are constant,
// the copy is definitely not volatile.
const auto ArgSize = *AllocA->getAllocationSize(DL);
IRBuilder<> IRB(&*FirstInst);
IRB.CreateMemCpy(AllocA, AllocA->getAlign(), ArgInParam, AllocA->getAlign(),
ArgSize);
}
} // namespace

static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg) {
Function *Func = Arg->getParent();
bool HasCvtaParam =
TM.getSubtargetImpl(*Func)->hasCvtaParam() && isKernelFunction(*Func);
bool IsGridConstant = HasCvtaParam && isParamGridConstant(*Arg);
assert(isKernelFunction(*Func));
const bool HasCvtaParam = TM.getSubtargetImpl(*Func)->hasCvtaParam();
const bool IsGridConstant = HasCvtaParam && isParamGridConstant(*Arg);
const DataLayout &DL = Func->getDataLayout();
BasicBlock::iterator FirstInst = Func->getEntryBlock().begin();
Type *StructType = Arg->getParamByValType();
@@ -558,9 +550,11 @@ static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg) {
for (Use &U : Arg->uses())
UsesToUpdate.push_back(&U);

Value *ArgInParamAS = new AddrSpaceCastInst(
Arg, PointerType::get(StructType->getContext(), ADDRESS_SPACE_PARAM),
Arg->getName(), FirstInst);
IRBuilder<> IRB(&*FirstInst);
Value *ArgInParamAS = IRB.CreateIntrinsic(
Intrinsic::nvvm_internal_addrspace_wrap,
{IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getType()}, {Arg});

for (Use *U : UsesToUpdate)
convertToParamAS(U, ArgInParamAS, HasCvtaParam, IsGridConstant);
LLVM_DEBUG(dbgs() << "No need to copy or cast " << *Arg << "\n");
@@ -578,30 +572,31 @@ static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg) {
// However, we're still not allowed to write to it. If the user specified
// `__grid_constant__` for the argument, we'll consider the escaped pointer
// read-only.
if (HasCvtaParam && (ArgUseIsReadOnly || IsGridConstant)) {
if (IsGridConstant || (HasCvtaParam && ArgUseIsReadOnly)) {
LLVM_DEBUG(dbgs() << "Using non-copy pointer to " << *Arg << "\n");
// Replace all argument pointer uses (which might include a device function
// call) with a cast to the generic address space using cvta.param
// instruction, which avoids a local copy.
IRBuilder<> IRB(&Func->getEntryBlock().front());

// Cast argument to param address space
auto *CastToParam = cast<AddrSpaceCastInst>(IRB.CreateAddrSpaceCast(
Arg, IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getName() + ".param"));
// Cast the argument to the param address space. Because the backend will
// emit the argument already in the param address space, we must use the
// no-op wrap intrinsic; this has the added benefit of preventing other
// optimizations from folding away this pair of casts.
auto *ParamSpaceArg =
IRB.CreateIntrinsic(Intrinsic::nvvm_internal_addrspace_wrap,
{IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getType()},
Arg, {}, Arg->getName() + ".param");

// Cast param address to generic address space. We do not use an
// addrspacecast to generic here, because LLVM considers `Arg` to be in the
// generic address space, and a `generic -> param` cast followed by a `param
// -> generic` cast will be folded away. The `param -> generic` intrinsic
// will be correctly lowered to `cvta.param`.
Value *CvtToGenCall = IRB.CreateIntrinsic(
IRB.getPtrTy(ADDRESS_SPACE_GENERIC), Intrinsic::nvvm_ptr_param_to_gen,
CastToParam, nullptr, CastToParam->getName() + ".gen");
// Cast param address to generic address space.
Value *GenericArg = IRB.CreateAddrSpaceCast(
ParamSpaceArg, IRB.getPtrTy(ADDRESS_SPACE_GENERIC),
Arg->getName() + ".gen");

Arg->replaceAllUsesWith(CvtToGenCall);
Arg->replaceAllUsesWith(GenericArg);

// Do not replace Arg in the cast to param space
CastToParam->setOperand(0, Arg);
ParamSpaceArg->setOperand(0, Arg);
} else
copyByValParam(*Func, *Arg);
}
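
Putting these pieces together, the non-copy path now produces IR along the lines of the following sketch (the kernel and callee are hypothetical); ISel folds the wrap intrinsic away and selects a single cvta.param for the remaining cast:

```llvm
; Illustrative only: a read-only escaped byval argument is rewritten to the
; wrap intrinsic plus an ordinary addrspacecast back to generic, avoiding
; the local copy entirely.
declare ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr)
declare void @use(ptr)

define ptx_kernel void @escapes(ptr byval(i32) align 4 readonly %a) {
  %a.param = call ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr %a)
  %a.gen = addrspacecast ptr addrspace(101) %a.param to ptr
  call void @use(ptr %a.gen)
  ret void
}
```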
@@ -715,12 +710,14 @@ static bool copyFunctionByValArgs(Function &F) {
LLVM_DEBUG(dbgs() << "Creating a copy of byval args of " << F.getName()
<< "\n");
bool Changed = false;
for (Argument &Arg : F.args())
if (Arg.getType()->isPointerTy() && Arg.hasByValAttr() &&
!(isParamGridConstant(Arg) && isKernelFunction(F))) {
copyByValParam(F, Arg);
Changed = true;
}
if (isKernelFunction(F)) {
for (Argument &Arg : F.args())
if (Arg.getType()->isPointerTy() && Arg.hasByValAttr() &&
!isParamGridConstant(Arg)) {
copyByValParam(F, Arg);
Changed = true;
}
}
return Changed;
}

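
For contrast, the copy path (taken when the argument may be written to, or when cvta.param is unavailable) now looks roughly like this sketch; the only difference from before is that the memcpy source comes from the wrap intrinsic rather than an addrspacecast:

```llvm
; Illustrative only: copyByValParam still materializes a local copy, but the
; source pointer for the memcpy is the wrapped param-space pointer.
declare ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr)
declare void @llvm.memcpy.p0.p101.i64(ptr, ptr addrspace(101), i64, i1)

define ptx_kernel void @copied(ptr byval(i32) align 4 %a) {
  %a.local = alloca i32, align 4
  %a.param = call ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr %a)
  call void @llvm.memcpy.p0.p101.i64(ptr %a.local, ptr addrspace(101) %a.param, i64 4, i1 false)
  store i32 7, ptr %a.local, align 4
  ret void
}
```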
33 changes: 24 additions & 9 deletions llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -16,11 +16,13 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/ModRef.h"
#include "llvm/Support/Mutex.h"
#include <cstdint>
#include <cstring>
@@ -228,17 +230,30 @@ static std::optional<uint64_t> getVectorProduct(ArrayRef<unsigned> V) {
return std::accumulate(V.begin(), V.end(), 1, std::multiplies<uint64_t>{});
}

bool isParamGridConstant(const Value &V) {
if (const Argument *Arg = dyn_cast<Argument>(&V)) {
// "grid_constant" counts argument indices starting from 1
if (Arg->hasByValAttr() &&
argHasNVVMAnnotation(*Arg, "grid_constant",
/*StartArgIndexAtOne*/ true)) {
assert(isKernelFunction(*Arg->getParent()) &&
"only kernel arguments can be grid_constant");
bool isParamGridConstant(const Argument &Arg) {
assert(isKernelFunction(*Arg.getParent()) &&
"only kernel arguments can be grid_constant");

if (!Arg.hasByValAttr())
return false;

// Lowering an argument as a grid_constant violates the byval semantics (and
// the C++ API) by reusing the same memory location for the argument across
// multiple threads. If the argument is only read and its address is never
// captured (i.e., never compared with any other value), then this deviation
// from the C++ API and byval semantics is unobservable to the program and we
// can lower the argument as a grid_constant.
if (Arg.onlyReadsMemory()) {
const auto CI = Arg.getAttributes().getCaptureInfo();
if (!capturesAddress(CI) && !capturesFullProvenance(CI))
return true;
}
}

// "grid_constant" counts argument indices starting from 1
if (argHasNVVMAnnotation(Arg, "grid_constant",
/*StartArgIndexAtOne*/ true))
return true;

return false;
}

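
The practical consequence of the rewritten check is that grid_constant lowering no longer depends solely on the annotation: attributes alone can prove it safe. A sketch of a kernel argument that now qualifies without any "grid_constant" metadata (hypothetical kernel; assumes the recent captures(...) attribute syntax):

```llvm
; Illustrative only: the argument is only read and its address is never
; captured, so isParamGridConstant returns true and no local copy is needed.
define ptx_kernel void @implicitly_const(ptr byval(i32) align 4 readonly captures(none) %a, ptr %out) {
  %v = load i32, ptr %a, align 4
  store i32 %v, ptr %out, align 4
  ret void
}
```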
2 changes: 1 addition & 1 deletion llvm/lib/Target/NVPTX/NVPTXUtilities.h
@@ -63,7 +63,7 @@ inline bool isKernelFunction(const Function &F) {
return F.getCallingConv() == CallingConv::PTX_Kernel;
}

bool isParamGridConstant(const Value &);
bool isParamGridConstant(const Argument &);

inline MaybeAlign getAlign(const Function &F, unsigned Index) {
return F.getAttributes().getAttributes(Index).getStackAlignment();
2 changes: 1 addition & 1 deletion llvm/test/CodeGen/NVPTX/bug21465.ll
@@ -12,7 +12,7 @@ define ptx_kernel void @_Z11TakesStruct1SPi(ptr byval(%struct.S) nocapture reado
entry:
; CHECK-LABEL: @_Z11TakesStruct1SPi
; PTX-LABEL: .visible .entry _Z11TakesStruct1SPi(
; CHECK: addrspacecast ptr %input to ptr addrspace(101)
; CHECK: call ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr %input)
%b = getelementptr inbounds %struct.S, ptr %input, i64 0, i32 1
%0 = load i32, ptr %b, align 4
; PTX-NOT: ld.param.u32 {{%r[0-9]+}}, [{{%rd[0-9]+}}]
2 changes: 1 addition & 1 deletion llvm/test/CodeGen/NVPTX/forward-ld-param.ll
@@ -65,7 +65,7 @@ define void @test_ld_param_byval(ptr byval(i32) %a) {
; CHECK-LABEL: test_ld_param_byval(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<2>;
; CHECK-NEXT: .reg .b64 %rd<3>;
; CHECK-NEXT: .reg .b64 %rd<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_ld_param_byval_param_0];