Skip to content

Commit e880f61

Browse files
committed
[NVPTX] Improve byval device parameter lowering
1 parent 9f28621 commit e880f61

13 files changed

+411
-117
lines changed

llvm/lib/Target/NVPTX/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ set(NVPTXCodeGen_sources
1616
NVPTXAtomicLower.cpp
1717
NVPTXAsmPrinter.cpp
1818
NVPTXAssignValidGlobalNames.cpp
19+
NVPTXForwardParams.cpp
1920
NVPTXFrameLowering.cpp
2021
NVPTXGenericToNVVM.cpp
2122
NVPTXISelDAGToDAG.cpp

llvm/lib/Target/NVPTX/NVPTX.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ FunctionPass *createNVPTXLowerUnreachablePass(bool TrapUnreachable,
5252
bool NoTrapAfterNoreturn);
5353
MachineFunctionPass *createNVPTXPeephole();
5454
MachineFunctionPass *createNVPTXProxyRegErasurePass();
55+
MachineFunctionPass *createNVPTXForwardParamsPass();
5556

5657
struct NVVMIntrRangePass : PassInfoMixin<NVVMIntrRangePass> {
5758
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
//- NVPTXForwardParams.cpp - NVPTX Forward Device Params Removing Local Copy -//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// PTX supports 2 methods of accessing device function parameters:
10+
//
11+
// - "simple" case: If a parameter is only loaded, and all loads can address
12+
// the parameter via a constant offset, then the parameter may be loaded via
13+
// the ".param" address space. This case is not possible if the parameter
14+
// is stored to or has its address taken. This method is preferable when
15+
// possible. Ex:
16+
//
17+
// ld.param.u32 %r1, [foo_param_1];
18+
// ld.param.u32 %r2, [foo_param_1+4];
19+
//
20+
// - "move param" case: For more complex cases the address of the param may be
21+
// placed in a register via a "mov" instruction. This "mov" also implicitly
22+
// moves the param to the ".local" address space and allows for it to be
23+
// written to. This essentially defers the responsibility of the byval copy
24+
// to the PTX calling convention.
25+
//
26+
// mov.b64 %rd1, foo_param_0;
27+
// st.local.u32 [%rd1], 42;
28+
// add.u64 %rd3, %rd1, %rd2;
29+
// ld.local.u32 %r2, [%rd3];
30+
//
31+
// In NVPTXLowerArgs and SelectionDAG, we pessimistically assume that all
32+
// parameters will use the "move param" case and the local address space. This
33+
// pass is responsible for switching to the "simple" case when possible, as it
34+
// is more efficient.
35+
//
36+
// We do this by simply traversing uses of the param "mov" instructions and
37+
// trivially checking if they are all loads.
38+
//
39+
//===----------------------------------------------------------------------===//
40+
41+
#include "NVPTX.h"
42+
#include "llvm/ADT/SmallVector.h"
43+
#include "llvm/CodeGen/MachineFunctionPass.h"
44+
#include "llvm/CodeGen/MachineInstr.h"
45+
#include "llvm/CodeGen/MachineOperand.h"
46+
#include "llvm/CodeGen/MachineRegisterInfo.h"
47+
#include "llvm/CodeGen/TargetRegisterInfo.h"
48+
#include "llvm/Support/ErrorHandling.h"
49+
50+
using namespace llvm;
51+
52+
/// Inspect one (possibly transitive) user \p U of a param "mov" result.
///
/// Scalar and vector load opcodes are recorded in \p LoadInsts. The local
/// cvta opcodes are looked through: every instruction using the register held
/// in the cvta's first operand is inspected recursively, and only once all of
/// them have been accepted is the cvta itself recorded in \p RemoveList.
///
/// \returns true when \p U, and everything reached through it, is one of the
/// opcodes above; false as soon as any other opcode is met. The two output
/// lists may already have grown by the time false is returned, so callers
/// should hand in scratch lists.
static bool traverseMoveUse(MachineInstr &U, const MachineRegisterInfo &MRI,
                            SmallVectorImpl<MachineInstr *> &RemoveList,
                            SmallVectorImpl<MachineInstr *> &LoadInsts) {
  // Step 1: sort the opcode into "load", "local cvta" or "anything else".
  bool IsLoad = false;
  switch (U.getOpcode()) {
  case NVPTX::LD_f32:
  case NVPTX::LD_f64:
  case NVPTX::LD_i16:
  case NVPTX::LD_i32:
  case NVPTX::LD_i64:
  case NVPTX::LD_i8:
  case NVPTX::LDV_f32_v2:
  case NVPTX::LDV_f32_v4:
  case NVPTX::LDV_f64_v2:
  case NVPTX::LDV_f64_v4:
  case NVPTX::LDV_i16_v2:
  case NVPTX::LDV_i16_v4:
  case NVPTX::LDV_i32_v2:
  case NVPTX::LDV_i32_v4:
  case NVPTX::LDV_i64_v2:
  case NVPTX::LDV_i64_v4:
  case NVPTX::LDV_i8_v2:
  case NVPTX::LDV_i8_v4:
    IsLoad = true;
    break;
  case NVPTX::cvta_local:
  case NVPTX::cvta_local_64:
  case NVPTX::cvta_to_local:
  case NVPTX::cvta_to_local_64:
    break;
  default:
    // Any other kind of user blocks the transformation.
    return false;
  }

  // Step 2a: loads are leaves of the traversal.
  if (IsLoad) {
    LoadInsts.push_back(&U);
    return true;
  }

  // Step 2b: a cvta is acceptable only if all users of its first-operand
  // register (presumably the converted address it defines) are acceptable.
  const auto FirstOpReg = U.operands_begin()->getReg();
  for (MachineInstr &CastUser : MRI.use_instructions(FirstOpReg)) {
    const bool Accepted =
        traverseMoveUse(CastUser, MRI, RemoveList, LoadInsts);
    if (!Accepted)
      return false;
  }

  RemoveList.push_back(&U);
  return true;
}
92+
93+
/// Try to make the param "mov" instruction \p Mov redundant.
///
/// All users of the register in \p Mov's first operand are walked with
/// traverseMoveUse(). If any user is rejected, nothing is modified and false
/// is returned. Otherwise the looked-through instructions and \p Mov itself
/// are appended to \p RemoveList (erasing them is left to the caller), and
/// every collected load is rewritten in place to address the parameter
/// symbol directly in the Param address space.
///
/// \returns true if the loads were rewritten, false if \p Mov must stay.
static bool eliminateMove(MachineInstr &Mov, const MachineRegisterInfo &MRI,
                          SmallVectorImpl<MachineInstr *> &RemoveList) {
  // Indices into a load's use operands of the two operands rewritten below.
  // NOTE(review): these are hard-coded and presumably mirror the LD/LDV
  // operand layout in NVPTXInstrInfo.td -- confirm if that layout changes.
  constexpr unsigned LDInstAddrSpaceOpIdx = 2;
  constexpr unsigned LDInstBasePtrOpIdx = 6;

  // Scratch lists: only committed once every user has been accepted.
  SmallVector<MachineInstr *, 16> PendingRemovals;
  SmallVector<MachineInstr *, 16> ParamLoads;

  const auto MovReg = Mov.operands_begin()->getReg();
  for (MachineInstr &User : MRI.use_instructions(MovReg)) {
    if (!traverseMoveUse(User, MRI, PendingRemovals, ParamLoads))
      return false;
  }

  RemoveList.append(PendingRemovals);
  RemoveList.push_back(&Mov);

  // The first use operand of the mov is expected to be the parameter symbol.
  const MachineOperand *ParamSymbol = Mov.uses().begin();
  assert(ParamSymbol->isSymbol());

  const auto UseOperand = [](MachineInstr *MI,
                             unsigned Idx) -> MachineOperand & {
    return *(MI->uses().begin() + Idx);
  };

  for (MachineInstr *Load : ParamLoads) {
    // Address the parameter by symbol instead of through the mov's register.
    UseOperand(Load, LDInstBasePtrOpIdx)
        .ChangeToES(ParamSymbol->getSymbolName());
    // Switch the load's address-space operand to Param.
    UseOperand(Load, LDInstAddrSpaceOpIdx)
        .ChangeToImmediate(NVPTX::AddressSpace::Param);
  }
  return true;
}
118+
119+
/// Run the param-forwarding transformation over \p MF.
///
/// Only the first basic block of the function is scanned for MOV32_PARAM /
/// MOV64_PARAM instructions; each one found is handed to eliminateMove().
/// Instructions made redundant are collected first and erased only after the
/// scan has finished.
///
/// \returns true if at least one param mov was eliminated.
static bool forwardDeviceParams(MachineFunction &MF) {
  const MachineRegisterInfo &MRI = MF.getRegInfo();

  const auto IsParamMov = [](const MachineInstr &MI) {
    const unsigned Opc = MI.getOpcode();
    return Opc == NVPTX::MOV32_PARAM || Opc == NVPTX::MOV64_PARAM;
  };

  SmallVector<MachineInstr *, 16> DeadInsts;
  bool AnyForwarded = false;

  for (MachineInstr &MI : make_early_inc_range(*MF.begin())) {
    if (!IsParamMov(MI))
      continue;
    if (eliminateMove(MI, MRI, DeadInsts))
      AnyForwarded = true;
  }

  // Erasure is deferred so the use lists walked above stay intact.
  for (MachineInstr *Dead : DeadInsts)
    Dead->eraseFromParent();

  return AnyForwarded;
}
134+
135+
/// ----------------------------------------------------------------------------
136+
/// Pass (Manager) Boilerplate
137+
/// ----------------------------------------------------------------------------
138+
139+
namespace llvm {
140+
void initializeNVPTXForwardParamsPassPass(PassRegistry &);
141+
} // namespace llvm
142+
143+
namespace {
144+
struct NVPTXForwardParamsPass : public MachineFunctionPass {
145+
static char ID;
146+
NVPTXForwardParamsPass() : MachineFunctionPass(ID) {
147+
initializeNVPTXForwardParamsPassPass(*PassRegistry::getPassRegistry());
148+
}
149+
150+
bool runOnMachineFunction(MachineFunction &MF) override;
151+
152+
void getAnalysisUsage(AnalysisUsage &AU) const override {
153+
MachineFunctionPass::getAnalysisUsage(AU);
154+
}
155+
};
156+
} // namespace
157+
158+
char NVPTXForwardParamsPass::ID = 0;
159+
160+
INITIALIZE_PASS(NVPTXForwardParamsPass, "nvptx-forward-params",
161+
"NVPTX Forward Params", false, false)
162+
163+
bool NVPTXForwardParamsPass::runOnMachineFunction(MachineFunction &MF) {
164+
return forwardDeviceParams(MF);
165+
}
166+
167+
MachineFunctionPass *llvm::createNVPTXForwardParamsPass() {
168+
return new NVPTXForwardParamsPass();
169+
}

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2197,11 +2197,11 @@ static SDValue selectBaseADDR(SDValue N, SelectionDAG *DAG) {
21972197
if (N.getOpcode() == NVPTXISD::Wrapper)
21982198
return N.getOperand(0);
21992199

2200-
// addrspacecast(MoveParam(arg_symbol) to addrspace(PARAM)) -> arg_symbol
2200+
// addrspacecast(Wrapper(arg_symbol) to addrspace(PARAM)) -> arg_symbol
22012201
if (AddrSpaceCastSDNode *CastN = dyn_cast<AddrSpaceCastSDNode>(N))
22022202
if (CastN->getSrcAddressSpace() == ADDRESS_SPACE_GENERIC &&
22032203
CastN->getDestAddressSpace() == ADDRESS_SPACE_PARAM &&
2204-
CastN->getOperand(0).getOpcode() == NVPTXISD::MoveParam)
2204+
CastN->getOperand(0).getOpcode() == NVPTXISD::Wrapper)
22052205
return selectBaseADDR(CastN->getOperand(0).getOperand(0), DAG);
22062206

22072207
if (auto *FIN = dyn_cast<FrameIndexSDNode>(N))

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3376,10 +3376,18 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
33763376
assert(ObjectVT == Ins[InsIdx].VT &&
33773377
"Ins type did not match function type");
33783378
SDValue Arg = getParamSymbol(DAG, i, PtrVT);
3379-
SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
3380-
if (p.getNode())
3381-
p.getNode()->setIROrder(i + 1);
3382-
InVals.push_back(p);
3379+
3380+
SDValue P;
3381+
if (isKernelFunction(*F)) {
3382+
P = DAG.getNode(NVPTXISD::Wrapper, dl, ObjectVT, Arg);
3383+
P.getNode()->setIROrder(i + 1);
3384+
} else {
3385+
P = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
3386+
P.getNode()->setIROrder(i + 1);
3387+
P = DAG.getAddrSpaceCast(dl, ObjectVT, P, ADDRESS_SPACE_LOCAL,
3388+
ADDRESS_SPACE_GENERIC);
3389+
}
3390+
InVals.push_back(P);
33833391
}
33843392

33853393
if (!OutChains.empty())

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2324,7 +2324,7 @@ def SDTCallArgProfile : SDTypeProfile<0, 2, [SDTCisInt<0>]>;
23242324
def SDTCallArgMarkProfile : SDTypeProfile<0, 0, []>;
23252325
def SDTCallVoidProfile : SDTypeProfile<0, 1, []>;
23262326
def SDTCallValProfile : SDTypeProfile<1, 0, []>;
2327-
def SDTMoveParamProfile : SDTypeProfile<1, 1, []>;
2327+
def SDTMoveParamProfile : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisInt<1>]>;
23282328
def SDTStoreRetvalProfile : SDTypeProfile<0, 2, [SDTCisInt<0>]>;
23292329
def SDTStoreRetvalV2Profile : SDTypeProfile<0, 3, [SDTCisInt<0>]>;
23302330
def SDTStoreRetvalV4Profile : SDTypeProfile<0, 5, [SDTCisInt<0>]>;
@@ -2688,29 +2688,14 @@ def DeclareScalarRegInst :
26882688
".reg .b$size param$a;",
26892689
[(DeclareScalarParam (i32 imm:$a), (i32 imm:$size), (i32 1))]>;
26902690

2691-
class MoveParamInst<ValueType T, NVPTXRegClass regclass, string asmstr> :
2692-
NVPTXInst<(outs regclass:$dst), (ins regclass:$src),
2693-
!strconcat("mov", asmstr, " \t$dst, $src;"),
2694-
[(set T:$dst, (MoveParam T:$src))]>;
2695-
26962691
class MoveParamSymbolInst<NVPTXRegClass regclass, Operand srcty, ValueType vt,
26972692
string asmstr> :
26982693
NVPTXInst<(outs regclass:$dst), (ins srcty:$src),
26992694
!strconcat("mov", asmstr, " \t$dst, $src;"),
27002695
[(set vt:$dst, (MoveParam texternalsym:$src))]>;
27012696

2702-
def MoveParamI64 : MoveParamInst<i64, Int64Regs, ".b64">;
2703-
def MoveParamI32 : MoveParamInst<i32, Int32Regs, ".b32">;
2704-
2705-
def MoveParamSymbolI64 : MoveParamSymbolInst<Int64Regs, i64imm, i64, ".b64">;
2706-
def MoveParamSymbolI32 : MoveParamSymbolInst<Int32Regs, i32imm, i32, ".b32">;
2707-
2708-
def MoveParamI16 :
2709-
NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$src),
2710-
"cvt.u16.u32 \t$dst, $src;", // ??? Why cvt.u16.u32 ?
2711-
[(set i16:$dst, (MoveParam i16:$src))]>;
2712-
def MoveParamF64 : MoveParamInst<f64, Float64Regs, ".f64">;
2713-
def MoveParamF32 : MoveParamInst<f32, Float32Regs, ".f32">;
2697+
def MOV64_PARAM : MoveParamSymbolInst<Int64Regs, i64imm, i64, ".b64">;
2698+
def MOV32_PARAM : MoveParamSymbolInst<Int32Regs, i32imm, i32, ".b32">;
27142699

27152700
class PseudoUseParamInst<NVPTXRegClass regclass, ValueType vt> :
27162701
NVPTXInst<(outs), (ins regclass:$src),

0 commit comments

Comments
 (0)