Skip to content

[NVPTX] Improve device function byval parameter lowering #129188

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ set(NVPTXCodeGen_sources
NVPTXAtomicLower.cpp
NVPTXAsmPrinter.cpp
NVPTXAssignValidGlobalNames.cpp
NVPTXForwardParams.cpp
NVPTXFrameLowering.cpp
NVPTXGenericToNVVM.cpp
NVPTXISelDAGToDAG.cpp
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/NVPTX.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ FunctionPass *createNVPTXLowerUnreachablePass(bool TrapUnreachable,
bool NoTrapAfterNoreturn);
MachineFunctionPass *createNVPTXPeephole();
MachineFunctionPass *createNVPTXProxyRegErasurePass();
MachineFunctionPass *createNVPTXForwardParamsPass();

struct NVVMIntrRangePass : PassInfoMixin<NVVMIntrRangePass> {
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
Expand Down
169 changes: 169 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
//- NVPTXForwardParams.cpp - NVPTX Forward Device Params Removing Local Copy -//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// PTX supports 2 methods of accessing device function parameters:
//
// - "simple" case: If a parameters is only loaded, and all loads can address
// the parameter via a constant offset, then the parameter may be loaded via
// the ".param" address space. This case is not possible if the parameters
// is stored to or has it's address taken. This method is preferable when
// possible. Ex:
//
// ld.param.u32 %r1, [foo_param_1];
// ld.param.u32 %r2, [foo_param_1+4];
//
// - "move param" case: For more complex cases the address of the param may be
// placed in a register via a "mov" instruction. This "mov" also implicitly
// moves the param to the ".local" address space and allows for it to be
// written to. This essentially defers the responsibilty of the byval copy
// to the PTX calling convention.
//
// mov.b64 %rd1, foo_param_0;
// st.local.u32 [%rd1], 42;
// add.u64 %rd3, %rd1, %rd2;
// ld.local.u32 %r2, [%rd3];
//
// In NVPTXLowerArgs and SelectionDAG, we pessimistically assume that all
// parameters will use the "move param" case and the local address space. This
// pass is responsible for switching to the "simple" case when possible, as it
// is more efficient.
//
// We do this by simply traversing uses of the param "mov" instructions an
// trivially checking if they are all loads.
//
//===----------------------------------------------------------------------===//

#include "NVPTX.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/Support/ErrorHandling.h"

using namespace llvm;

static bool traverseMoveUse(MachineInstr &U, const MachineRegisterInfo &MRI,
SmallVectorImpl<MachineInstr *> &RemoveList,
SmallVectorImpl<MachineInstr *> &LoadInsts) {
switch (U.getOpcode()) {
case NVPTX::LD_f32:
case NVPTX::LD_f64:
case NVPTX::LD_i16:
case NVPTX::LD_i32:
case NVPTX::LD_i64:
case NVPTX::LD_i8:
case NVPTX::LDV_f32_v2:
case NVPTX::LDV_f32_v4:
case NVPTX::LDV_f64_v2:
case NVPTX::LDV_f64_v4:
case NVPTX::LDV_i16_v2:
case NVPTX::LDV_i16_v4:
case NVPTX::LDV_i32_v2:
case NVPTX::LDV_i32_v4:
case NVPTX::LDV_i64_v2:
case NVPTX::LDV_i64_v4:
case NVPTX::LDV_i8_v2:
case NVPTX::LDV_i8_v4: {
LoadInsts.push_back(&U);
return true;
}
case NVPTX::cvta_local:
case NVPTX::cvta_local_64:
case NVPTX::cvta_to_local:
case NVPTX::cvta_to_local_64: {
for (auto &U2 : MRI.use_instructions(U.operands_begin()->getReg()))
if (!traverseMoveUse(U2, MRI, RemoveList, LoadInsts))
return false;

RemoveList.push_back(&U);
return true;
}
default:
return false;
}
}

static bool eliminateMove(MachineInstr &Mov, const MachineRegisterInfo &MRI,
SmallVectorImpl<MachineInstr *> &RemoveList) {
SmallVector<MachineInstr *, 16> MaybeRemoveList;
SmallVector<MachineInstr *, 16> LoadInsts;

for (auto &U : MRI.use_instructions(Mov.operands_begin()->getReg()))
if (!traverseMoveUse(U, MRI, MaybeRemoveList, LoadInsts))
return false;

RemoveList.append(MaybeRemoveList);
RemoveList.push_back(&Mov);

const MachineOperand *ParamSymbol = Mov.uses().begin();
assert(ParamSymbol->isSymbol());

constexpr unsigned LDInstBasePtrOpIdx = 6;
constexpr unsigned LDInstAddrSpaceOpIdx = 2;
for (auto *LI : LoadInsts) {
(LI->uses().begin() + LDInstBasePtrOpIdx)
->ChangeToES(ParamSymbol->getSymbolName());
(LI->uses().begin() + LDInstAddrSpaceOpIdx)
->ChangeToImmediate(NVPTX::AddressSpace::Param);
}
return true;
}

static bool forwardDeviceParams(MachineFunction &MF) {
const auto &MRI = MF.getRegInfo();

bool Changed = false;
SmallVector<MachineInstr *, 16> RemoveList;
for (auto &MI : make_early_inc_range(*MF.begin()))
if (MI.getOpcode() == NVPTX::MOV32_PARAM ||
MI.getOpcode() == NVPTX::MOV64_PARAM)
Changed |= eliminateMove(MI, MRI, RemoveList);

for (auto *MI : RemoveList)
MI->eraseFromParent();

return Changed;
}

/// ----------------------------------------------------------------------------
/// Pass (Manager) Boilerplate
/// ----------------------------------------------------------------------------

namespace llvm {
void initializeNVPTXForwardParamsPassPass(PassRegistry &);
} // namespace llvm

namespace {
struct NVPTXForwardParamsPass : public MachineFunctionPass {
static char ID;
NVPTXForwardParamsPass() : MachineFunctionPass(ID) {
initializeNVPTXForwardParamsPassPass(*PassRegistry::getPassRegistry());
}

bool runOnMachineFunction(MachineFunction &MF) override;

void getAnalysisUsage(AnalysisUsage &AU) const override {
MachineFunctionPass::getAnalysisUsage(AU);
}
};
} // namespace

char NVPTXForwardParamsPass::ID = 0;

INITIALIZE_PASS(NVPTXForwardParamsPass, "nvptx-forward-params",
"NVPTX Forward Params", false, false)

bool NVPTXForwardParamsPass::runOnMachineFunction(MachineFunction &MF) {
return forwardDeviceParams(MF);
}

MachineFunctionPass *llvm::createNVPTXForwardParamsPass() {
return new NVPTXForwardParamsPass();
}
4 changes: 2 additions & 2 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2197,11 +2197,11 @@ static SDValue selectBaseADDR(SDValue N, SelectionDAG *DAG) {
if (N.getOpcode() == NVPTXISD::Wrapper)
return N.getOperand(0);

// addrspacecast(MoveParam(arg_symbol) to addrspace(PARAM)) -> arg_symbol
// addrspacecast(Wrapper(arg_symbol) to addrspace(PARAM)) -> arg_symbol
if (AddrSpaceCastSDNode *CastN = dyn_cast<AddrSpaceCastSDNode>(N))
if (CastN->getSrcAddressSpace() == ADDRESS_SPACE_GENERIC &&
CastN->getDestAddressSpace() == ADDRESS_SPACE_PARAM &&
CastN->getOperand(0).getOpcode() == NVPTXISD::MoveParam)
CastN->getOperand(0).getOpcode() == NVPTXISD::Wrapper)
return selectBaseADDR(CastN->getOperand(0).getOperand(0), DAG);

if (auto *FIN = dyn_cast<FrameIndexSDNode>(N))
Expand Down
16 changes: 12 additions & 4 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3376,10 +3376,18 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
assert(ObjectVT == Ins[InsIdx].VT &&
"Ins type did not match function type");
SDValue Arg = getParamSymbol(DAG, i, PtrVT);
SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
if (p.getNode())
p.getNode()->setIROrder(i + 1);
InVals.push_back(p);

SDValue P;
if (isKernelFunction(*F)) {
P = DAG.getNode(NVPTXISD::Wrapper, dl, ObjectVT, Arg);
P.getNode()->setIROrder(i + 1);
} else {
P = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
P.getNode()->setIROrder(i + 1);
P = DAG.getAddrSpaceCast(dl, ObjectVT, P, ADDRESS_SPACE_LOCAL,
ADDRESS_SPACE_GENERIC);
}
InVals.push_back(P);
}

if (!OutChains.empty())
Expand Down
21 changes: 3 additions & 18 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -2324,7 +2324,7 @@ def SDTCallArgProfile : SDTypeProfile<0, 2, [SDTCisInt<0>]>;
def SDTCallArgMarkProfile : SDTypeProfile<0, 0, []>;
def SDTCallVoidProfile : SDTypeProfile<0, 1, []>;
def SDTCallValProfile : SDTypeProfile<1, 0, []>;
def SDTMoveParamProfile : SDTypeProfile<1, 1, []>;
def SDTMoveParamProfile : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisInt<1>]>;
def SDTStoreRetvalProfile : SDTypeProfile<0, 2, [SDTCisInt<0>]>;
def SDTStoreRetvalV2Profile : SDTypeProfile<0, 3, [SDTCisInt<0>]>;
def SDTStoreRetvalV4Profile : SDTypeProfile<0, 5, [SDTCisInt<0>]>;
Expand Down Expand Up @@ -2688,29 +2688,14 @@ def DeclareScalarRegInst :
".reg .b$size param$a;",
[(DeclareScalarParam (i32 imm:$a), (i32 imm:$size), (i32 1))]>;

class MoveParamInst<ValueType T, NVPTXRegClass regclass, string asmstr> :
NVPTXInst<(outs regclass:$dst), (ins regclass:$src),
!strconcat("mov", asmstr, " \t$dst, $src;"),
[(set T:$dst, (MoveParam T:$src))]>;

class MoveParamSymbolInst<NVPTXRegClass regclass, Operand srcty, ValueType vt,
string asmstr> :
NVPTXInst<(outs regclass:$dst), (ins srcty:$src),
!strconcat("mov", asmstr, " \t$dst, $src;"),
[(set vt:$dst, (MoveParam texternalsym:$src))]>;

def MoveParamI64 : MoveParamInst<i64, Int64Regs, ".b64">;
def MoveParamI32 : MoveParamInst<i32, Int32Regs, ".b32">;

def MoveParamSymbolI64 : MoveParamSymbolInst<Int64Regs, i64imm, i64, ".b64">;
def MoveParamSymbolI32 : MoveParamSymbolInst<Int32Regs, i32imm, i32, ".b32">;

def MoveParamI16 :
NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$src),
"cvt.u16.u32 \t$dst, $src;", // ??? Why cvt.u16.u32 ?
[(set i16:$dst, (MoveParam i16:$src))]>;
def MoveParamF64 : MoveParamInst<f64, Float64Regs, ".f64">;
def MoveParamF32 : MoveParamInst<f32, Float32Regs, ".f32">;
def MOV64_PARAM : MoveParamSymbolInst<Int64Regs, i64imm, i64, ".b64">;
def MOV32_PARAM : MoveParamSymbolInst<Int32Regs, i32imm, i32, ".b32">;

class PseudoUseParamInst<NVPTXRegClass regclass, ValueType vt> :
NVPTXInst<(outs), (ins regclass:$src),
Expand Down
Loading