Skip to content

[NVPTX][NFCI] Use DataLayout to determine short shared/local/const pointers #89404

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llvm/lib/Target/NVPTX/NVPTXFrameLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ void NVPTXFrameLowering::emitPrologue(MachineFunction &MF,
bool Is64Bit =
static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit();
unsigned CvtaLocalOpcode =
(Is64Bit ? NVPTX::cvta_local_yes_64 : NVPTX::cvta_local_yes);
(Is64Bit ? NVPTX::cvta_local_64 : NVPTX::cvta_local);
unsigned MovDepotOpcode =
(Is64Bit ? NVPTX::MOV_DEPOT_ADDR_64 : NVPTX::MOV_DEPOT_ADDR);
if (!MR.use_empty(NRI->getFrameRegister(MF))) {
Expand Down
51 changes: 26 additions & 25 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,6 @@ bool NVPTXDAGToDAGISel::allowUnsafeFPMath() const {
return TL->allowUnsafeFPMath(*MF);
}

bool NVPTXDAGToDAGISel::useShortPointers() const {
return TM.useShortPointers();
}

/// Select - Select instructions not customized! Used for
/// expanded, promoted and normal instructions.
void NVPTXDAGToDAGISel::Select(SDNode *N) {
Expand Down Expand Up @@ -768,22 +764,25 @@ void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) {
switch (SrcAddrSpace) {
default: report_fatal_error("Bad address space in addrspacecast");
case ADDRESS_SPACE_GLOBAL:
Opc = TM.is64Bit() ? NVPTX::cvta_global_yes_64 : NVPTX::cvta_global_yes;
Opc = TM.is64Bit() ? NVPTX::cvta_global_64 : NVPTX::cvta_global;
break;
case ADDRESS_SPACE_SHARED:
Opc = TM.is64Bit() ? (useShortPointers() ? NVPTX::cvta_shared_yes_6432
: NVPTX::cvta_shared_yes_64)
: NVPTX::cvta_shared_yes;
Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32
? NVPTX::cvta_shared_6432
: NVPTX::cvta_shared_64)
: NVPTX::cvta_shared;
Comment on lines +770 to +773
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like a common pattern. We could extract it into a helper function.

break;
case ADDRESS_SPACE_CONST:
Opc = TM.is64Bit() ? (useShortPointers() ? NVPTX::cvta_const_yes_6432
: NVPTX::cvta_const_yes_64)
: NVPTX::cvta_const_yes;
Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32
? NVPTX::cvta_const_6432
: NVPTX::cvta_const_64)
: NVPTX::cvta_const;
break;
case ADDRESS_SPACE_LOCAL:
Opc = TM.is64Bit() ? (useShortPointers() ? NVPTX::cvta_local_yes_6432
: NVPTX::cvta_local_yes_64)
: NVPTX::cvta_local_yes;
Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32
? NVPTX::cvta_local_6432
: NVPTX::cvta_local_64)
: NVPTX::cvta_local;
break;
}
ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0),
Expand All @@ -797,23 +796,25 @@ void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) {
switch (DstAddrSpace) {
default: report_fatal_error("Bad address space in addrspacecast");
case ADDRESS_SPACE_GLOBAL:
Opc = TM.is64Bit() ? NVPTX::cvta_to_global_yes_64
: NVPTX::cvta_to_global_yes;
Opc = TM.is64Bit() ? NVPTX::cvta_to_global_64 : NVPTX::cvta_to_global;
break;
case ADDRESS_SPACE_SHARED:
Opc = TM.is64Bit() ? (useShortPointers() ? NVPTX::cvta_to_shared_yes_3264
: NVPTX::cvta_to_shared_yes_64)
: NVPTX::cvta_to_shared_yes;
Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32
? NVPTX::cvta_to_shared_3264
: NVPTX::cvta_to_shared_64)
: NVPTX::cvta_to_shared;
break;
case ADDRESS_SPACE_CONST:
Opc = TM.is64Bit() ? (useShortPointers() ? NVPTX::cvta_to_const_yes_3264
: NVPTX::cvta_to_const_yes_64)
: NVPTX::cvta_to_const_yes;
Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32
? NVPTX::cvta_to_const_3264
: NVPTX::cvta_to_const_64)
: NVPTX::cvta_to_const;
break;
case ADDRESS_SPACE_LOCAL:
Opc = TM.is64Bit() ? (useShortPointers() ? NVPTX::cvta_to_local_yes_3264
: NVPTX::cvta_to_local_yes_64)
: NVPTX::cvta_to_local_yes;
Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32
? NVPTX::cvta_to_local_3264
: NVPTX::cvta_to_local_64)
: NVPTX::cvta_to_local;
break;
case ADDRESS_SPACE_PARAM:
Opc = TM.is64Bit() ? NVPTX::nvvm_ptr_gen_to_param_64
Expand Down
1 change: 0 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
bool useF32FTZ() const;
bool allowFMA() const;
bool allowUnsafeFPMath() const;
bool useShortPointers() const;

public:
static char ID;
Expand Down
6 changes: 5 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;

def True : Predicate<"true">;
def False : Predicate<"false">;

class hasPTX<int version>: Predicate<"Subtarget->getPTXVersion() >= " # version>;
class hasSM<int version>: Predicate<"Subtarget->getSmVersion() >= " # version>;
Expand All @@ -171,7 +172,10 @@ def hasSM90a : Predicate<"Subtarget->getFullSmVersion() == 901">;
def hasSHFL : Predicate<"!(Subtarget->getSmVersion() >= 70"
"&& Subtarget->getPTXVersion() >= 64)">;

def useShortPtr : Predicate<"useShortPointers()">;
def useShortPtrLocal : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_LOCAL) == 32">;
def useShortPtrShared : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32">;
def useShortPtrConst : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_CONST) == 32">;

def useFP16Math: Predicate<"Subtarget->allowFP16Math()">;
def hasBF16Math: Predicate<"Subtarget->hasBF16Math()">;

Expand Down
37 changes: 18 additions & 19 deletions llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -2407,46 +2407,45 @@ defm INT_PTX_LDG_G_v4f32_ELE
: VLDG_G_ELE_V4<"v4.f32 \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", Float32Regs>;


multiclass NG_TO_G<string Str, Intrinsic Intrin> {
def _yes : NVPTXInst<(outs Int32Regs:$result), (ins Int32Regs:$src),
multiclass NG_TO_G<string Str, Intrinsic Intrin, Predicate ShortPtr> {
def "" : NVPTXInst<(outs Int32Regs:$result), (ins Int32Regs:$src),
!strconcat("cvta.", Str, ".u32 \t$result, $src;"),
[(set Int32Regs:$result, (Intrin Int32Regs:$src))]>;
def _yes_64 : NVPTXInst<(outs Int64Regs:$result), (ins Int64Regs:$src),
def _64 : NVPTXInst<(outs Int64Regs:$result), (ins Int64Regs:$src),
!strconcat("cvta.", Str, ".u64 \t$result, $src;"),
[(set Int64Regs:$result, (Intrin Int64Regs:$src))]>;
def _yes_6432 : NVPTXInst<(outs Int64Regs:$result), (ins Int32Regs:$src),
def _6432 : NVPTXInst<(outs Int64Regs:$result), (ins Int32Regs:$src),
"{{ .reg .b64 %tmp;\n\t"
#" cvt.u64.u32 \t%tmp, $src;\n\t"
#" cvta." # Str # ".u64 \t$result, %tmp; }}",
[(set Int64Regs:$result, (Intrin Int32Regs:$src))]>,
Requires<[useShortPtr]>;
Requires<[ShortPtr]>;
}

multiclass G_TO_NG<string Str, Intrinsic Intrin> {
def _yes : NVPTXInst<(outs Int32Regs:$result), (ins Int32Regs:$src),
multiclass G_TO_NG<string Str, Intrinsic Intrin, Predicate ShortPtr> {
def "" : NVPTXInst<(outs Int32Regs:$result), (ins Int32Regs:$src),
!strconcat("cvta.to.", Str, ".u32 \t$result, $src;"),
[(set Int32Regs:$result, (Intrin Int32Regs:$src))]>;
def _yes_64 : NVPTXInst<(outs Int64Regs:$result), (ins Int64Regs:$src),
def _64 : NVPTXInst<(outs Int64Regs:$result), (ins Int64Regs:$src),
!strconcat("cvta.to.", Str, ".u64 \t$result, $src;"),
[(set Int64Regs:$result, (Intrin Int64Regs:$src))]>;
def _yes_3264 : NVPTXInst<(outs Int32Regs:$result), (ins Int64Regs:$src),
def _3264 : NVPTXInst<(outs Int32Regs:$result), (ins Int64Regs:$src),
"{{ .reg .b64 %tmp;\n\t"
#" cvta.to." # Str # ".u64 \t%tmp, $src;\n\t"
#" cvt.u32.u64 \t$result, %tmp; }}",
[(set Int32Regs:$result, (Intrin Int64Regs:$src))]>,
Requires<[useShortPtr]>;
Requires<[ShortPtr]>;
}

defm cvta_local : NG_TO_G<"local", int_nvvm_ptr_local_to_gen>;
defm cvta_shared : NG_TO_G<"shared", int_nvvm_ptr_shared_to_gen>;
defm cvta_global : NG_TO_G<"global", int_nvvm_ptr_global_to_gen>;
defm cvta_const : NG_TO_G<"const", int_nvvm_ptr_constant_to_gen>;

defm cvta_to_local : G_TO_NG<"local", int_nvvm_ptr_gen_to_local>;
defm cvta_to_shared : G_TO_NG<"shared", int_nvvm_ptr_gen_to_shared>;
defm cvta_to_global : G_TO_NG<"global", int_nvvm_ptr_gen_to_global>;
defm cvta_to_const : G_TO_NG<"const", int_nvvm_ptr_gen_to_constant>;
defm cvta_local : NG_TO_G<"local", int_nvvm_ptr_local_to_gen, useShortPtrLocal>;
defm cvta_shared : NG_TO_G<"shared", int_nvvm_ptr_shared_to_gen, useShortPtrShared>;
defm cvta_global : NG_TO_G<"global", int_nvvm_ptr_global_to_gen, False>;
defm cvta_const : NG_TO_G<"const", int_nvvm_ptr_constant_to_gen, useShortPtrConst>;

defm cvta_to_local : G_TO_NG<"local", int_nvvm_ptr_gen_to_local, useShortPtrLocal>;
defm cvta_to_shared : G_TO_NG<"shared", int_nvvm_ptr_gen_to_shared, useShortPtrShared>;
defm cvta_to_global : G_TO_NG<"global", int_nvvm_ptr_gen_to_global, False>;
defm cvta_to_const : G_TO_NG<"const", int_nvvm_ptr_gen_to_constant, useShortPtrConst>;

// nvvm.ptr.gen.to.param
def nvvm_ptr_gen_to_param : NVPTXInst<(outs Int32Regs:$result),
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/Target/NVPTX/NVPTXPeephole.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
//
// It will transform the following pattern
// %0 = LEA_ADDRi64 %VRFrame64, 4
// %1 = cvta_to_local_yes_64 %0
// %1 = cvta_to_local_64 %0
//
// into
// %1 = LEA_ADDRi64 %VRFrameLocal64, 4
Expand Down Expand Up @@ -76,8 +76,8 @@ static bool isCVTAToLocalCombinationCandidate(MachineInstr &Root) {
auto &MBB = *Root.getParent();
auto &MF = *MBB.getParent();
// Check current instruction is cvta.to.local
if (Root.getOpcode() != NVPTX::cvta_to_local_yes_64 &&
Root.getOpcode() != NVPTX::cvta_to_local_yes)
if (Root.getOpcode() != NVPTX::cvta_to_local_64 &&
Root.getOpcode() != NVPTX::cvta_to_local)
return false;

auto &Op = Root.getOperand(1);
Expand Down
3 changes: 1 addition & 2 deletions llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,7 @@ NVPTXTargetMachine::NVPTXTargetMachine(const Target &T, const Triple &TT,
: LLVMTargetMachine(T, computeDataLayout(is64bit, UseShortPointersOpt), TT,
CPU, FS, Options, Reloc::PIC_,
getEffectiveCodeModel(CM, CodeModel::Small), OL),
is64bit(is64bit), UseShortPointers(UseShortPointersOpt),
TLOF(std::make_unique<NVPTXTargetObjectFile>()),
is64bit(is64bit), TLOF(std::make_unique<NVPTXTargetObjectFile>()),
Subtarget(TT, std::string(CPU), std::string(FS), *this),
StrPool(StrAlloc) {
if (TT.getOS() == Triple::NVCL)
Expand Down
3 changes: 0 additions & 3 deletions llvm/lib/Target/NVPTX/NVPTXTargetMachine.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ namespace llvm {
///
class NVPTXTargetMachine : public LLVMTargetMachine {
bool is64bit;
// Use 32-bit pointers for accessing const/local/shared AS.
bool UseShortPointers;
std::unique_ptr<TargetLoweringObjectFile> TLOF;
NVPTX::DrvInterface drvInterface;
NVPTXSubtarget Subtarget;
Expand All @@ -46,7 +44,6 @@ class NVPTXTargetMachine : public LLVMTargetMachine {
}
const NVPTXSubtarget *getSubtargetImpl() const { return &Subtarget; }
bool is64Bit() const { return is64bit; }
bool useShortPointers() const { return UseShortPointers; }
NVPTX::DrvInterface getDrvInterface() const { return drvInterface; }
UniqueStringSaver &getStrPool() const {
return const_cast<UniqueStringSaver &>(StrPool);
Expand Down