Skip to content

Commit a922a23

Browse files
committed
[NVPTX] Improve kernel byval parameter lowering
1 parent d1081f9 commit a922a23

File tree

13 files changed

+307
-330
lines changed

13 files changed

+307
-330
lines changed

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1909,6 +1909,21 @@ def int_nvvm_ptr_param_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty],
19091909
[IntrNoMem, IntrSpeculatable, IntrNoCallback],
19101910
"llvm.nvvm.ptr.param.to.gen">;
19111911

1912+
// Represents an explicit hole in the LLVM IR type system. It may be inserted by
1913+
// the compiler in cases where a pointer is of the wrong type. In the backend
1914+
// this intrinsic will be folded away and not equate to any instruction. It
1915+
// should not be used by any frontend and should only be considered well defined
1916+
// when added in the following cases:
1917+
//
1918+
// - NVPTXLowerArgs: When wrapping a byval pointer argument to a kernel
1919+
// function to convert the address space from generic (0) to param (101).
1920+
// This accounts for the fact that the parameter symbols will occupy this
1921+
// space when lowered during ISel.
1922+
//
1923+
def int_nvvm_internal_noop_addrspacecast :
1924+
DefaultAttrsIntrinsic<[llvm_anyptr_ty], [llvm_anyptr_ty],
1925+
[IntrNoMem, IntrSpeculatable, NoUndef<ArgIndex<0>>, NoUndef<RetIndex>]>;
1926+
19121927
// Move intrinsics, used in nvvm internally
19131928

19141929
def int_nvvm_move_i16 : Intrinsic<[llvm_i16_ty], [llvm_i16_ty], [IntrNoMem],

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -985,6 +985,9 @@ void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) {
985985
case ADDRESS_SPACE_LOCAL:
986986
Opc = TM.is64Bit() ? NVPTX::cvta_local_64 : NVPTX::cvta_local;
987987
break;
988+
case ADDRESS_SPACE_PARAM:
989+
Opc = TM.is64Bit() ? NVPTX::cvta_param_64 : NVPTX::cvta_param;
990+
break;
988991
}
989992
ReplaceNode(N, CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src));
990993
return;
@@ -1008,7 +1011,7 @@ void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) {
10081011
Opc = TM.is64Bit() ? NVPTX::cvta_to_local_64 : NVPTX::cvta_to_local;
10091012
break;
10101013
case ADDRESS_SPACE_PARAM:
1011-
Opc = TM.is64Bit() ? NVPTX::IMOV64r : NVPTX::IMOV32r;
1014+
Opc = TM.is64Bit() ? NVPTX::cvta_to_param_64 : NVPTX::cvta_to_param;
10121015
break;
10131016
}
10141017

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,6 +1014,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
10141014
{MVT::v2i32, MVT::v4i32, MVT::v8i32, MVT::v16i32,
10151015
MVT::v32i32, MVT::v64i32, MVT::v128i32},
10161016
Custom);
1017+
1018+
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
10171019
}
10181020

10191021
const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
@@ -1426,6 +1428,17 @@ static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG,
14261428

14271429
return MachinePointerInfo(ADDRESS_SPACE_LOCAL);
14281430
}
1431+
1432+
// Peel of an addrspacecast to generic and load directly from the specific
1433+
// address space.
1434+
if (Ptr->getOpcode() == ISD::ADDRSPACECAST) {
1435+
const auto *ASC = cast<AddrSpaceCastSDNode>(Ptr);
1436+
if (ASC->getDestAddressSpace() == ADDRESS_SPACE_GENERIC) {
1437+
Ptr = ASC->getOperand(0);
1438+
return MachinePointerInfo(ASC->getSrcAddressSpace());
1439+
}
1440+
}
1441+
14291442
return MachinePointerInfo();
14301443
}
14311444

@@ -2746,6 +2759,15 @@ static SDValue LowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG) {
27462759
return Op;
27472760
}
27482761

2762+
static SDValue lowerIntrinsicWOChain(SDValue Op, SelectionDAG &DAG) {
2763+
switch (Op->getConstantOperandVal(0)) {
2764+
default:
2765+
return Op;
2766+
case Intrinsic::nvvm_internal_noop_addrspacecast:
2767+
return Op.getOperand(1);
2768+
}
2769+
}
2770+
27492771
// In PTX 64-bit CTLZ and CTPOP are supported, but they return a 32-bit value.
27502772
// Lower these into a node returning the correct type which is zero-extended
27512773
// back to the correct size.
@@ -2889,6 +2911,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
28892911
return LowerGlobalAddress(Op, DAG);
28902912
case ISD::INTRINSIC_W_CHAIN:
28912913
return Op;
2914+
case ISD::INTRINSIC_WO_CHAIN:
2915+
return lowerIntrinsicWOChain(Op, DAG);
28922916
case ISD::INTRINSIC_VOID:
28932917
return LowerIntrinsicVoid(Op, DAG);
28942918
case ISD::BUILD_VECTOR:

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2395,18 +2395,10 @@ multiclass G_TO_NG<string Str> {
23952395
"cvta.to." # Str # ".u64 \t$result, $src;", []>;
23962396
}
23972397

2398-
defm cvta_local : NG_TO_G<"local">;
2399-
defm cvta_shared : NG_TO_G<"shared">;
2400-
defm cvta_global : NG_TO_G<"global">;
2401-
defm cvta_const : NG_TO_G<"const">;
2402-
2403-
defm cvta_to_local : G_TO_NG<"local">;
2404-
defm cvta_to_shared : G_TO_NG<"shared">;
2405-
defm cvta_to_global : G_TO_NG<"global">;
2406-
defm cvta_to_const : G_TO_NG<"const">;
2407-
2408-
// nvvm.ptr.param.to.gen
2409-
defm cvta_param : NG_TO_G<"param">;
2398+
foreach space = ["local", "shared", "global", "const", "param"] in {
2399+
defm cvta_#space : NG_TO_G<space>;
2400+
defm cvta_to_#space : G_TO_NG<space>;
2401+
}
24102402

24112403
def : Pat<(int_nvvm_ptr_param_to_gen i32:$src),
24122404
(cvta_param $src)>;

llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp

Lines changed: 43 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -265,18 +265,9 @@ static void convertToParamAS(Use *OldUse, Value *Param, bool HasCvtaParam,
265265
if (HasCvtaParam) {
266266
auto GetParamAddrCastToGeneric =
267267
[](Value *Addr, Instruction *OriginalUser) -> Value * {
268-
PointerType *ReturnTy =
269-
PointerType::get(OriginalUser->getContext(), ADDRESS_SPACE_GENERIC);
270-
Function *CvtToGen = Intrinsic::getOrInsertDeclaration(
271-
OriginalUser->getModule(), Intrinsic::nvvm_ptr_param_to_gen,
272-
{ReturnTy, PointerType::get(OriginalUser->getContext(),
273-
ADDRESS_SPACE_PARAM)});
274-
275-
// Cast param address to generic address space
276-
Value *CvtToGenCall =
277-
CallInst::Create(CvtToGen, Addr, Addr->getName() + ".gen",
278-
OriginalUser->getIterator());
279-
return CvtToGenCall;
268+
IRBuilder<> IRB(OriginalUser);
269+
Type *GenTy = IRB.getPtrTy(ADDRESS_SPACE_GENERIC);
270+
return IRB.CreateAddrSpaceCast(Addr, GenTy, Addr->getName() + ".gen");
280271
};
281272
auto *ParamInGenericAS =
282273
GetParamAddrCastToGeneric(I.NewParam, I.OldInstruction);
@@ -515,33 +506,34 @@ void copyByValParam(Function &F, Argument &Arg) {
515506
BasicBlock::iterator FirstInst = F.getEntryBlock().begin();
516507
Type *StructType = Arg.getParamByValType();
517508
const DataLayout &DL = F.getDataLayout();
518-
AllocaInst *AllocA = new AllocaInst(StructType, DL.getAllocaAddrSpace(),
519-
Arg.getName(), FirstInst);
509+
IRBuilder<> IRB(&*FirstInst);
510+
AllocaInst *AllocA = IRB.CreateAlloca(StructType, nullptr, Arg.getName());
520511
// Set the alignment to alignment of the byval parameter. This is because,
521512
// later load/stores assume that alignment, and we are going to replace
522513
// the use of the byval parameter with this alloca instruction.
523-
AllocA->setAlignment(F.getParamAlign(Arg.getArgNo())
524-
.value_or(DL.getPrefTypeAlign(StructType)));
514+
AllocA->setAlignment(
515+
Arg.getParamAlign().value_or(DL.getPrefTypeAlign(StructType)));
525516
Arg.replaceAllUsesWith(AllocA);
526517

527-
Value *ArgInParam = new AddrSpaceCastInst(
528-
&Arg, PointerType::get(Arg.getContext(), ADDRESS_SPACE_PARAM),
529-
Arg.getName(), FirstInst);
518+
Value *ArgInParam =
519+
IRB.CreateIntrinsic(Intrinsic::nvvm_internal_noop_addrspacecast,
520+
{IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg.getType()},
521+
&Arg, {}, Arg.getName());
522+
530523
// Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
531524
// addrspacecast preserves alignment. Since params are constant, this load
532525
// is definitely not volatile.
533526
const auto ArgSize = *AllocA->getAllocationSize(DL);
534-
IRBuilder<> IRB(&*FirstInst);
535527
IRB.CreateMemCpy(AllocA, AllocA->getAlign(), ArgInParam, AllocA->getAlign(),
536528
ArgSize);
537529
}
538530
} // namespace
539531

540532
static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg) {
541533
Function *Func = Arg->getParent();
542-
bool HasCvtaParam =
543-
TM.getSubtargetImpl(*Func)->hasCvtaParam() && isKernelFunction(*Func);
544-
bool IsGridConstant = HasCvtaParam && isParamGridConstant(*Arg);
534+
assert(isKernelFunction(*Func));
535+
const bool HasCvtaParam = TM.getSubtargetImpl(*Func)->hasCvtaParam();
536+
const bool IsGridConstant = HasCvtaParam && isParamGridConstant(*Arg);
545537
const DataLayout &DL = Func->getDataLayout();
546538
BasicBlock::iterator FirstInst = Func->getEntryBlock().begin();
547539
Type *StructType = Arg->getParamByValType();
@@ -558,9 +550,11 @@ static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg) {
558550
for (Use &U : Arg->uses())
559551
UsesToUpdate.push_back(&U);
560552

561-
Value *ArgInParamAS = new AddrSpaceCastInst(
562-
Arg, PointerType::get(StructType->getContext(), ADDRESS_SPACE_PARAM),
563-
Arg->getName(), FirstInst);
553+
IRBuilder<> IRB(&*FirstInst);
554+
Value *ArgInParamAS = IRB.CreateIntrinsic(
555+
Intrinsic::nvvm_internal_noop_addrspacecast,
556+
{IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getType()}, {Arg});
557+
564558
for (Use *U : UsesToUpdate)
565559
convertToParamAS(U, ArgInParamAS, HasCvtaParam, IsGridConstant);
566560
LLVM_DEBUG(dbgs() << "No need to copy or cast " << *Arg << "\n");
@@ -578,30 +572,31 @@ static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg) {
578572
// However, we're still not allowed to write to it. If the user specified
579573
// `__grid_constant__` for the argument, we'll consider escaped pointer as
580574
// read-only.
581-
if (HasCvtaParam && (ArgUseIsReadOnly || IsGridConstant)) {
575+
if (IsGridConstant || (HasCvtaParam && ArgUseIsReadOnly)) {
582576
LLVM_DEBUG(dbgs() << "Using non-copy pointer to " << *Arg << "\n");
583577
// Replace all argument pointer uses (which might include a device function
584578
// call) with a cast to the generic address space using cvta.param
585579
// instruction, which avoids a local copy.
586580
IRBuilder<> IRB(&Func->getEntryBlock().front());
587581

588-
// Cast argument to param address space
589-
auto *CastToParam = cast<AddrSpaceCastInst>(IRB.CreateAddrSpaceCast(
590-
Arg, IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getName() + ".param"));
582+
// Cast argument to param address space. Because the backend will emit the
583+
// argument already in the param address space, we need to use the noop
584+
// intrinsic, this had the added benefit of preventing other optimizations
585+
// from folding away this pair of addrspacecasts.
586+
auto *ParamSpaceArg =
587+
IRB.CreateIntrinsic(Intrinsic::nvvm_internal_noop_addrspacecast,
588+
{IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getType()},
589+
Arg, {}, Arg->getName() + ".param");
591590

592-
// Cast param address to generic address space. We do not use an
593-
// addrspacecast to generic here, because, LLVM considers `Arg` to be in the
594-
// generic address space, and a `generic -> param` cast followed by a `param
595-
// -> generic` cast will be folded away. The `param -> generic` intrinsic
596-
// will be correctly lowered to `cvta.param`.
597-
Value *CvtToGenCall = IRB.CreateIntrinsic(
598-
IRB.getPtrTy(ADDRESS_SPACE_GENERIC), Intrinsic::nvvm_ptr_param_to_gen,
599-
CastToParam, nullptr, CastToParam->getName() + ".gen");
591+
// Cast param address to generic address space.
592+
Value *GenericArg = IRB.CreateAddrSpaceCast(
593+
ParamSpaceArg, IRB.getPtrTy(ADDRESS_SPACE_GENERIC),
594+
Arg->getName() + ".gen");
600595

601-
Arg->replaceAllUsesWith(CvtToGenCall);
596+
Arg->replaceAllUsesWith(GenericArg);
602597

603598
// Do not replace Arg in the cast to param space
604-
CastToParam->setOperand(0, Arg);
599+
ParamSpaceArg->setOperand(0, Arg);
605600
} else
606601
copyByValParam(*Func, *Arg);
607602
}
@@ -715,12 +710,14 @@ static bool copyFunctionByValArgs(Function &F) {
715710
LLVM_DEBUG(dbgs() << "Creating a copy of byval args of " << F.getName()
716711
<< "\n");
717712
bool Changed = false;
718-
for (Argument &Arg : F.args())
719-
if (Arg.getType()->isPointerTy() && Arg.hasByValAttr() &&
720-
!(isParamGridConstant(Arg) && isKernelFunction(F))) {
721-
copyByValParam(F, Arg);
722-
Changed = true;
723-
}
713+
if (isKernelFunction(F)) {
714+
for (Argument &Arg : F.args())
715+
if (Arg.getType()->isPointerTy() && Arg.hasByValAttr() &&
716+
!isParamGridConstant(Arg)) {
717+
copyByValParam(F, Arg);
718+
Changed = true;
719+
}
720+
}
724721
return Changed;
725722
}
726723

llvm/lib/Target/NVPTX/NVPTXUtilities.cpp

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
#include "llvm/ADT/ArrayRef.h"
1717
#include "llvm/ADT/SmallVector.h"
1818
#include "llvm/ADT/StringRef.h"
19+
#include "llvm/IR/Argument.h"
1920
#include "llvm/IR/Constants.h"
2021
#include "llvm/IR/Function.h"
2122
#include "llvm/IR/GlobalVariable.h"
2223
#include "llvm/IR/Module.h"
2324
#include "llvm/Support/Alignment.h"
25+
#include "llvm/Support/ModRef.h"
2426
#include "llvm/Support/Mutex.h"
2527
#include <cstdint>
2628
#include <cstring>
@@ -228,17 +230,30 @@ static std::optional<uint64_t> getVectorProduct(ArrayRef<unsigned> V) {
228230
return std::accumulate(V.begin(), V.end(), 1, std::multiplies<uint64_t>{});
229231
}
230232

231-
bool isParamGridConstant(const Value &V) {
232-
if (const Argument *Arg = dyn_cast<Argument>(&V)) {
233-
// "grid_constant" counts argument indices starting from 1
234-
if (Arg->hasByValAttr() &&
235-
argHasNVVMAnnotation(*Arg, "grid_constant",
236-
/*StartArgIndexAtOne*/ true)) {
237-
assert(isKernelFunction(*Arg->getParent()) &&
238-
"only kernel arguments can be grid_constant");
233+
bool isParamGridConstant(const Argument &Arg) {
234+
assert(isKernelFunction(*Arg.getParent()) &&
235+
"only kernel arguments can be grid_constant");
236+
237+
if (!Arg.hasByValAttr())
238+
return false;
239+
240+
// Lowering an argument as a grid_constant violates the byval semantics (and
241+
// the C++ API) by reusing the same memory location for the argument across
242+
// multiple threads. If an argument doesn't read memory and its address is not
243+
// captured (its address is not compared with any value), then the tweak of
244+
// the C++ API and byval semantics is unobservable by the program and we can
245+
// lower the arg as a grid_constant.
246+
if (Arg.onlyReadsMemory()) {
247+
const auto CI = Arg.getAttributes().getCaptureInfo();
248+
if (!capturesAddress(CI) && !capturesFullProvenance(CI))
239249
return true;
240-
}
241250
}
251+
252+
// "grid_constant" counts argument indices starting from 1
253+
if (argHasNVVMAnnotation(Arg, "grid_constant",
254+
/*StartArgIndexAtOne*/ true))
255+
return true;
256+
242257
return false;
243258
}
244259

llvm/lib/Target/NVPTX/NVPTXUtilities.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ inline bool isKernelFunction(const Function &F) {
6363
return F.getCallingConv() == CallingConv::PTX_Kernel;
6464
}
6565

66-
bool isParamGridConstant(const Value &);
66+
bool isParamGridConstant(const Argument &);
6767

6868
inline MaybeAlign getAlign(const Function &F, unsigned Index) {
6969
return F.getAttributes().getAttributes(Index).getStackAlignment();

llvm/test/CodeGen/NVPTX/bug21465.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ define ptx_kernel void @_Z11TakesStruct1SPi(ptr byval(%struct.S) nocapture reado
1212
entry:
1313
; CHECK-LABEL: @_Z11TakesStruct1SPi
1414
; PTX-LABEL: .visible .entry _Z11TakesStruct1SPi(
15-
; CHECK: addrspacecast ptr %input to ptr addrspace(101)
15+
; CHECK: call ptr addrspace(101) @llvm.nvvm.internal.noop.addrspacecast.p101.p0(ptr %input)
1616
%b = getelementptr inbounds %struct.S, ptr %input, i64 0, i32 1
1717
%0 = load i32, ptr %b, align 4
1818
; PTX-NOT: ld.param.u32 {{%r[0-9]+}}, [{{%rd[0-9]+}}]

llvm/test/CodeGen/NVPTX/forward-ld-param.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ define void @test_ld_param_byval(ptr byval(i32) %a) {
6565
; CHECK-LABEL: test_ld_param_byval(
6666
; CHECK: {
6767
; CHECK-NEXT: .reg .b32 %r<2>;
68-
; CHECK-NEXT: .reg .b64 %rd<3>;
68+
; CHECK-NEXT: .reg .b64 %rd<2>;
6969
; CHECK-EMPTY:
7070
; CHECK-NEXT: // %bb.0:
7171
; CHECK-NEXT: ld.param.u32 %r1, [test_ld_param_byval_param_0];

0 commit comments

Comments
 (0)