Skip to content

Commit 0065343

Browse files
authored
[NVPTX] Improve device function byval parameter lowering (#129188)
PTX supports 2 methods of accessing device function parameters: - "simple" case: If a parameters is only loaded, and all loads can address the parameter via a constant offset, then the parameter may be loaded via the ".param" address space. This case is not possible if the parameters is stored to or has it's address taken. This method is preferable when possible. - "move param" case: For more complex cases the address of the param may be placed in a register via a "mov" instruction. This mov also implicitly moves the param to the ".local" address space and allows for it to be written to. This essentially defers the responsibilty of the byval copy to the PTX calling convention. The handling of these cases in the NVPTX backend for byval pointers has some major issues. We currently attempt to determine if a copy is necessary in NVPTXLowerArgs and either explicitly make an additional copy in the IR, or insert "addrspacecast" to move the param to the param address space. Unfortunately the criteria for determining which case is possible are not correct, leading to miscompilations (https://godbolt.org/z/Gq1fP7a3G). Further, the criteria for the "simple" case aren't enforceable in LLVM IR across other transformations and instruction selection, making deciding between the 2 cases in NVPTXLowerArgs brittle and buggy. This patch aims to fix these issues and improve address space related optimization. In NVPTXLowerArgs, we conservatively assume that all parameters will use the "move param" case and the local address space. Responsibility for switching to the "simple" case is given to a new MachineIR pass, NVPTXForwardParams, which runs once it has become clear whether or not this is possible. This ensures that the correct address space is known for the "move param" case allowing for optimization, while still using the "simple" case where ever possible.
1 parent d2cbd5f commit 0065343

File tree

14 files changed

+609
-198
lines changed

14 files changed

+609
-198
lines changed

llvm/lib/Target/NVPTX/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ set(NVPTXCodeGen_sources
1616
NVPTXAtomicLower.cpp
1717
NVPTXAsmPrinter.cpp
1818
NVPTXAssignValidGlobalNames.cpp
19+
NVPTXForwardParams.cpp
1920
NVPTXFrameLowering.cpp
2021
NVPTXGenericToNVVM.cpp
2122
NVPTXISelDAGToDAG.cpp

llvm/lib/Target/NVPTX/NVPTX.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ FunctionPass *createNVPTXLowerUnreachablePass(bool TrapUnreachable,
5252
bool NoTrapAfterNoreturn);
5353
MachineFunctionPass *createNVPTXPeephole();
5454
MachineFunctionPass *createNVPTXProxyRegErasurePass();
55+
MachineFunctionPass *createNVPTXForwardParamsPass();
5556

5657
struct NVVMIntrRangePass : PassInfoMixin<NVVMIntrRangePass> {
5758
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
//- NVPTXForwardParams.cpp - NVPTX Forward Device Params Removing Local Copy -//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// PTX supports 2 methods of accessing device function parameters:
10+
//
11+
// - "simple" case: If a parameters is only loaded, and all loads can address
12+
// the parameter via a constant offset, then the parameter may be loaded via
13+
// the ".param" address space. This case is not possible if the parameters
14+
// is stored to or has it's address taken. This method is preferable when
15+
// possible. Ex:
16+
//
17+
// ld.param.u32 %r1, [foo_param_1];
18+
// ld.param.u32 %r2, [foo_param_1+4];
19+
//
20+
// - "move param" case: For more complex cases the address of the param may be
21+
// placed in a register via a "mov" instruction. This "mov" also implicitly
22+
// moves the param to the ".local" address space and allows for it to be
23+
// written to. This essentially defers the responsibilty of the byval copy
24+
// to the PTX calling convention.
25+
//
26+
// mov.b64 %rd1, foo_param_0;
27+
// st.local.u32 [%rd1], 42;
28+
// add.u64 %rd3, %rd1, %rd2;
29+
// ld.local.u32 %r2, [%rd3];
30+
//
31+
// In NVPTXLowerArgs and SelectionDAG, we pessimistically assume that all
32+
// parameters will use the "move param" case and the local address space. This
33+
// pass is responsible for switching to the "simple" case when possible, as it
34+
// is more efficient.
35+
//
36+
// We do this by simply traversing uses of the param "mov" instructions an
37+
// trivially checking if they are all loads.
38+
//
39+
//===----------------------------------------------------------------------===//
40+
41+
#include "NVPTX.h"
42+
#include "llvm/ADT/SmallVector.h"
43+
#include "llvm/CodeGen/MachineFunctionPass.h"
44+
#include "llvm/CodeGen/MachineInstr.h"
45+
#include "llvm/CodeGen/MachineOperand.h"
46+
#include "llvm/CodeGen/MachineRegisterInfo.h"
47+
#include "llvm/CodeGen/TargetRegisterInfo.h"
48+
#include "llvm/Support/ErrorHandling.h"
49+
50+
using namespace llvm;
51+
52+
static bool traverseMoveUse(MachineInstr &U, const MachineRegisterInfo &MRI,
53+
SmallVectorImpl<MachineInstr *> &RemoveList,
54+
SmallVectorImpl<MachineInstr *> &LoadInsts) {
55+
switch (U.getOpcode()) {
56+
case NVPTX::LD_f32:
57+
case NVPTX::LD_f64:
58+
case NVPTX::LD_i16:
59+
case NVPTX::LD_i32:
60+
case NVPTX::LD_i64:
61+
case NVPTX::LD_i8:
62+
case NVPTX::LDV_f32_v2:
63+
case NVPTX::LDV_f32_v4:
64+
case NVPTX::LDV_f64_v2:
65+
case NVPTX::LDV_f64_v4:
66+
case NVPTX::LDV_i16_v2:
67+
case NVPTX::LDV_i16_v4:
68+
case NVPTX::LDV_i32_v2:
69+
case NVPTX::LDV_i32_v4:
70+
case NVPTX::LDV_i64_v2:
71+
case NVPTX::LDV_i64_v4:
72+
case NVPTX::LDV_i8_v2:
73+
case NVPTX::LDV_i8_v4: {
74+
LoadInsts.push_back(&U);
75+
return true;
76+
}
77+
case NVPTX::cvta_local:
78+
case NVPTX::cvta_local_64:
79+
case NVPTX::cvta_to_local:
80+
case NVPTX::cvta_to_local_64: {
81+
for (auto &U2 : MRI.use_instructions(U.operands_begin()->getReg()))
82+
if (!traverseMoveUse(U2, MRI, RemoveList, LoadInsts))
83+
return false;
84+
85+
RemoveList.push_back(&U);
86+
return true;
87+
}
88+
default:
89+
return false;
90+
}
91+
}
92+
93+
static bool eliminateMove(MachineInstr &Mov, const MachineRegisterInfo &MRI,
94+
SmallVectorImpl<MachineInstr *> &RemoveList) {
95+
SmallVector<MachineInstr *, 16> MaybeRemoveList;
96+
SmallVector<MachineInstr *, 16> LoadInsts;
97+
98+
for (auto &U : MRI.use_instructions(Mov.operands_begin()->getReg()))
99+
if (!traverseMoveUse(U, MRI, MaybeRemoveList, LoadInsts))
100+
return false;
101+
102+
RemoveList.append(MaybeRemoveList);
103+
RemoveList.push_back(&Mov);
104+
105+
const MachineOperand *ParamSymbol = Mov.uses().begin();
106+
assert(ParamSymbol->isSymbol());
107+
108+
constexpr unsigned LDInstBasePtrOpIdx = 6;
109+
constexpr unsigned LDInstAddrSpaceOpIdx = 2;
110+
for (auto *LI : LoadInsts) {
111+
(LI->uses().begin() + LDInstBasePtrOpIdx)
112+
->ChangeToES(ParamSymbol->getSymbolName());
113+
(LI->uses().begin() + LDInstAddrSpaceOpIdx)
114+
->ChangeToImmediate(NVPTX::AddressSpace::Param);
115+
}
116+
return true;
117+
}
118+
119+
static bool forwardDeviceParams(MachineFunction &MF) {
120+
const auto &MRI = MF.getRegInfo();
121+
122+
bool Changed = false;
123+
SmallVector<MachineInstr *, 16> RemoveList;
124+
for (auto &MI : make_early_inc_range(*MF.begin()))
125+
if (MI.getOpcode() == NVPTX::MOV32_PARAM ||
126+
MI.getOpcode() == NVPTX::MOV64_PARAM)
127+
Changed |= eliminateMove(MI, MRI, RemoveList);
128+
129+
for (auto *MI : RemoveList)
130+
MI->eraseFromParent();
131+
132+
return Changed;
133+
}
134+
135+
/// ----------------------------------------------------------------------------
136+
/// Pass (Manager) Boilerplate
137+
/// ----------------------------------------------------------------------------
138+
139+
namespace llvm {
140+
void initializeNVPTXForwardParamsPassPass(PassRegistry &);
141+
} // namespace llvm
142+
143+
namespace {
144+
struct NVPTXForwardParamsPass : public MachineFunctionPass {
145+
static char ID;
146+
NVPTXForwardParamsPass() : MachineFunctionPass(ID) {
147+
initializeNVPTXForwardParamsPassPass(*PassRegistry::getPassRegistry());
148+
}
149+
150+
bool runOnMachineFunction(MachineFunction &MF) override;
151+
152+
void getAnalysisUsage(AnalysisUsage &AU) const override {
153+
MachineFunctionPass::getAnalysisUsage(AU);
154+
}
155+
};
156+
} // namespace
157+
158+
char NVPTXForwardParamsPass::ID = 0;
159+
160+
INITIALIZE_PASS(NVPTXForwardParamsPass, "nvptx-forward-params",
161+
"NVPTX Forward Params", false, false)
162+
163+
bool NVPTXForwardParamsPass::runOnMachineFunction(MachineFunction &MF) {
164+
return forwardDeviceParams(MF);
165+
}
166+
167+
MachineFunctionPass *llvm::createNVPTXForwardParamsPass() {
168+
return new NVPTXForwardParamsPass();
169+
}

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2197,11 +2197,11 @@ static SDValue selectBaseADDR(SDValue N, SelectionDAG *DAG) {
21972197
if (N.getOpcode() == NVPTXISD::Wrapper)
21982198
return N.getOperand(0);
21992199

2200-
// addrspacecast(MoveParam(arg_symbol) to addrspace(PARAM)) -> arg_symbol
2200+
// addrspacecast(Wrapper(arg_symbol) to addrspace(PARAM)) -> arg_symbol
22012201
if (AddrSpaceCastSDNode *CastN = dyn_cast<AddrSpaceCastSDNode>(N))
22022202
if (CastN->getSrcAddressSpace() == ADDRESS_SPACE_GENERIC &&
22032203
CastN->getDestAddressSpace() == ADDRESS_SPACE_PARAM &&
2204-
CastN->getOperand(0).getOpcode() == NVPTXISD::MoveParam)
2204+
CastN->getOperand(0).getOpcode() == NVPTXISD::Wrapper)
22052205
return selectBaseADDR(CastN->getOperand(0).getOperand(0), DAG);
22062206

22072207
if (auto *FIN = dyn_cast<FrameIndexSDNode>(N))

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3376,10 +3376,18 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
33763376
assert(ObjectVT == Ins[InsIdx].VT &&
33773377
"Ins type did not match function type");
33783378
SDValue Arg = getParamSymbol(DAG, i, PtrVT);
3379-
SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
3380-
if (p.getNode())
3381-
p.getNode()->setIROrder(i + 1);
3382-
InVals.push_back(p);
3379+
3380+
SDValue P;
3381+
if (isKernelFunction(*F)) {
3382+
P = DAG.getNode(NVPTXISD::Wrapper, dl, ObjectVT, Arg);
3383+
P.getNode()->setIROrder(i + 1);
3384+
} else {
3385+
P = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
3386+
P.getNode()->setIROrder(i + 1);
3387+
P = DAG.getAddrSpaceCast(dl, ObjectVT, P, ADDRESS_SPACE_LOCAL,
3388+
ADDRESS_SPACE_GENERIC);
3389+
}
3390+
InVals.push_back(P);
33833391
}
33843392

33853393
if (!OutChains.empty())

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2324,7 +2324,7 @@ def SDTCallArgProfile : SDTypeProfile<0, 2, [SDTCisInt<0>]>;
23242324
def SDTCallArgMarkProfile : SDTypeProfile<0, 0, []>;
23252325
def SDTCallVoidProfile : SDTypeProfile<0, 1, []>;
23262326
def SDTCallValProfile : SDTypeProfile<1, 0, []>;
2327-
def SDTMoveParamProfile : SDTypeProfile<1, 1, []>;
2327+
def SDTMoveParamProfile : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisInt<1>]>;
23282328
def SDTStoreRetvalProfile : SDTypeProfile<0, 2, [SDTCisInt<0>]>;
23292329
def SDTStoreRetvalV2Profile : SDTypeProfile<0, 3, [SDTCisInt<0>]>;
23302330
def SDTStoreRetvalV4Profile : SDTypeProfile<0, 5, [SDTCisInt<0>]>;
@@ -2688,29 +2688,14 @@ def DeclareScalarRegInst :
26882688
".reg .b$size param$a;",
26892689
[(DeclareScalarParam (i32 imm:$a), (i32 imm:$size), (i32 1))]>;
26902690

2691-
class MoveParamInst<ValueType T, NVPTXRegClass regclass, string asmstr> :
2692-
NVPTXInst<(outs regclass:$dst), (ins regclass:$src),
2693-
!strconcat("mov", asmstr, " \t$dst, $src;"),
2694-
[(set T:$dst, (MoveParam T:$src))]>;
2695-
26962691
class MoveParamSymbolInst<NVPTXRegClass regclass, Operand srcty, ValueType vt,
26972692
string asmstr> :
26982693
NVPTXInst<(outs regclass:$dst), (ins srcty:$src),
26992694
!strconcat("mov", asmstr, " \t$dst, $src;"),
27002695
[(set vt:$dst, (MoveParam texternalsym:$src))]>;
27012696

2702-
def MoveParamI64 : MoveParamInst<i64, Int64Regs, ".b64">;
2703-
def MoveParamI32 : MoveParamInst<i32, Int32Regs, ".b32">;
2704-
2705-
def MoveParamSymbolI64 : MoveParamSymbolInst<Int64Regs, i64imm, i64, ".b64">;
2706-
def MoveParamSymbolI32 : MoveParamSymbolInst<Int32Regs, i32imm, i32, ".b32">;
2707-
2708-
def MoveParamI16 :
2709-
NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$src),
2710-
"cvt.u16.u32 \t$dst, $src;", // ??? Why cvt.u16.u32 ?
2711-
[(set i16:$dst, (MoveParam i16:$src))]>;
2712-
def MoveParamF64 : MoveParamInst<f64, Float64Regs, ".f64">;
2713-
def MoveParamF32 : MoveParamInst<f32, Float32Regs, ".f32">;
2697+
def MOV64_PARAM : MoveParamSymbolInst<Int64Regs, i64imm, i64, ".b64">;
2698+
def MOV32_PARAM : MoveParamSymbolInst<Int32Regs, i32imm, i32, ".b32">;
27142699

27152700
class PseudoUseParamInst<NVPTXRegClass regclass, ValueType vt> :
27162701
NVPTXInst<(outs), (ins regclass:$src),

0 commit comments

Comments
 (0)