
[NVVM] Upgrade nvvm.ptr.* intrinsics to addrspace cast #109710

Merged
63 changes: 0 additions & 63 deletions llvm/docs/NVPTXUsage.rst
@@ -127,69 +127,6 @@ Example: 64-bit PTX for CUDA Driver API: ``nvptx64-nvidia-cuda``
NVPTX Intrinsics
================

Address Space Conversion
------------------------

'``llvm.nvvm.ptr.*.to.gen``' Intrinsics
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Syntax:
"""""""

These are overloaded intrinsics. You can use these on any pointer types.

.. code-block:: llvm

declare ptr @llvm.nvvm.ptr.global.to.gen.p0.p1(ptr addrspace(1))
declare ptr @llvm.nvvm.ptr.shared.to.gen.p0.p3(ptr addrspace(3))
declare ptr @llvm.nvvm.ptr.constant.to.gen.p0.p4(ptr addrspace(4))
declare ptr @llvm.nvvm.ptr.local.to.gen.p0.p5(ptr addrspace(5))

Overview:
"""""""""

The '``llvm.nvvm.ptr.*.to.gen``' intrinsics convert a pointer in a non-generic
address space to a generic address space pointer.

Semantics:
""""""""""

These intrinsics modify the pointer value to be a valid generic address space
pointer.


'``llvm.nvvm.ptr.gen.to.*``' Intrinsics
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Syntax:
"""""""

These are overloaded intrinsics. You can use these on any pointer types.

.. code-block:: llvm

declare ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr)
declare ptr addrspace(3) @llvm.nvvm.ptr.gen.to.shared.p3.p0(ptr)
declare ptr addrspace(4) @llvm.nvvm.ptr.gen.to.constant.p4.p0(ptr)
declare ptr addrspace(5) @llvm.nvvm.ptr.gen.to.local.p5.p0(ptr)

Overview:
"""""""""

The '``llvm.nvvm.ptr.gen.to.*``' intrinsics convert a pointer in the generic
address space to a pointer in the target address space. Note that these
intrinsics are only useful if the target address space of the pointer is
known. It is not legal to use address space conversion
intrinsics to convert a pointer from one non-generic address space to another
non-generic address space.

Semantics:
""""""""""

These intrinsics modify the pointer value to be a valid pointer in the target
non-generic address space.


Reading PTX Special Registers
-----------------------------

12 changes: 12 additions & 0 deletions llvm/docs/ReleaseNotes.rst
@@ -69,6 +69,18 @@ Changes to the LLVM IR
* ``llvm.nvvm.rotate.right.b64``
* ``llvm.nvvm.rotate.b64``

* Remove the following intrinsics which can be replaced with an
``addrspacecast``:

* ``llvm.nvvm.ptr.gen.to.global``
* ``llvm.nvvm.ptr.gen.to.shared``
* ``llvm.nvvm.ptr.gen.to.constant``
* ``llvm.nvvm.ptr.gen.to.local``
* ``llvm.nvvm.ptr.global.to.gen``
* ``llvm.nvvm.ptr.shared.to.gen``
* ``llvm.nvvm.ptr.constant.to.gen``
* ``llvm.nvvm.ptr.local.to.gen``

Changes to LLVM infrastructure
------------------------------

50 changes: 12 additions & 38 deletions llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -30,10 +30,18 @@
// * llvm.nvvm.max.ui --> select(x ule y, x, y)
// * llvm.nvvm.max.ull --> ibid.
// * llvm.nvvm.h2f --> llvm.convert.to.fp16.f32
// * llvm.nvvm.bitcast.f2i --> bitcast
// * llvm.nvvm.bitcast.i2f --> ibid.
// * llvm.nvvm.bitcast.d2ll --> ibid.
// * llvm.nvvm.bitcast.ll2d --> ibid.
// * llvm.nvvm.bitcast.f2i --> bitcast
// * llvm.nvvm.bitcast.i2f --> ibid.
// * llvm.nvvm.bitcast.d2ll --> ibid.
// * llvm.nvvm.bitcast.ll2d --> ibid.
// * llvm.nvvm.ptr.gen.to.global --> addrspacecast
// * llvm.nvvm.ptr.gen.to.shared --> ibid.
// * llvm.nvvm.ptr.gen.to.constant --> ibid.
// * llvm.nvvm.ptr.gen.to.local --> ibid.
// * llvm.nvvm.ptr.global.to.gen --> ibid.
// * llvm.nvvm.ptr.shared.to.gen --> ibid.
// * llvm.nvvm.ptr.constant.to.gen --> ibid.
// * llvm.nvvm.ptr.local.to.gen --> ibid.

def llvm_global_ptr_ty : LLVMQualPointerType<1>; // (global)ptr
def llvm_shared_ptr_ty : LLVMQualPointerType<3>; // (shared)ptr
@@ -1602,40 +1610,6 @@ def int_nvvm_ldg_global_p : Intrinsic<[llvm_anyptr_ty],
[IntrReadMem, IntrArgMemOnly, IntrNoCallback, IntrWillReturn, NoCapture<ArgIndex<0>>],
"llvm.nvvm.ldg.global.p">;

// Use for generic pointers
// - These intrinsics are used to convert address spaces.
// - The input pointer and output pointer must have the same type, except for
// the address-space. (This restriction is not enforced here as there is
// currently no way to describe it).
// - This complements the llvm bitcast, which can be used to cast one type
// of pointer to another type of pointer, while the address space remains
// the same.
def int_nvvm_ptr_local_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty],
[llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable],
"llvm.nvvm.ptr.local.to.gen">;
def int_nvvm_ptr_shared_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty],
[llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable],
"llvm.nvvm.ptr.shared.to.gen">;
def int_nvvm_ptr_global_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty],
[llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable],
"llvm.nvvm.ptr.global.to.gen">;
def int_nvvm_ptr_constant_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty],
[llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable],
"llvm.nvvm.ptr.constant.to.gen">;

def int_nvvm_ptr_gen_to_global: DefaultAttrsIntrinsic<[llvm_anyptr_ty],
[llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable],
"llvm.nvvm.ptr.gen.to.global">;
def int_nvvm_ptr_gen_to_shared: DefaultAttrsIntrinsic<[llvm_anyptr_ty],
[llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable],
"llvm.nvvm.ptr.gen.to.shared">;
def int_nvvm_ptr_gen_to_local: DefaultAttrsIntrinsic<[llvm_anyptr_ty],
[llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable],
"llvm.nvvm.ptr.gen.to.local">;
def int_nvvm_ptr_gen_to_constant: DefaultAttrsIntrinsic<[llvm_anyptr_ty],
[llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable],
"llvm.nvvm.ptr.gen.to.constant">;

// Used in nvvm internally to help address space opt and ptx code generation
// This is for params that are passed to kernel functions by pointer by-val.
def int_nvvm_ptr_gen_to_param: Intrinsic<[llvm_anyptr_ty],
19 changes: 19 additions & 0 deletions llvm/lib/IR/AutoUpgrade.cpp
@@ -1275,6 +1275,16 @@ static bool upgradeIntrinsicFunction1(Function *F, Function *&NewFn,
else if (Name.consume_front("rotate."))
// nvvm.rotate.{b32,b64,right.b64}
Expand = Name == "b32" || Name == "b64" || Name == "right.b64";
else if (Name.consume_front("ptr.gen.to."))
// nvvm.ptr.gen.to.{local,shared,global,constant}
Expand = Name.starts_with("local") || Name.starts_with("shared") ||
Name.starts_with("global") || Name.starts_with("constant");
else if (Name.consume_front("ptr."))
// nvvm.ptr.{local,shared,global,constant}.to.gen
Expand =
(Name.consume_front("local") || Name.consume_front("shared") ||
Name.consume_front("global") || Name.consume_front("constant")) &&
Name.starts_with(".to.gen");
else
Expand = false;

@@ -2338,6 +2348,15 @@ static Value *upgradeNVVMIntrinsicCall(StringRef Name, CallBase *CI,
Value *ZExtShiftAmt = Builder.CreateZExt(CI->getOperand(1), Int64Ty);
Rep = Builder.CreateIntrinsic(Int64Ty, Intrinsic::fshr,
{Arg, Arg, ZExtShiftAmt});
} else if ((Name.consume_front("ptr.gen.to.") &&
(Name.starts_with("local") || Name.starts_with("shared") ||
Name.starts_with("global") || Name.starts_with("constant"))) ||
(Name.consume_front("ptr.") &&
(Name.consume_front("local") || Name.consume_front("shared") ||
Name.consume_front("global") ||
Name.consume_front("constant")) &&
Name.starts_with(".to.gen"))) {
Rep = Builder.CreateAddrSpaceCast(CI->getArgOperand(0), CI->getType());
} else {
Intrinsic::ID IID = shouldUpgradeNVPTXBF16Intrinsic(Name);
if (IID != Intrinsic::not_intrinsic &&
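
The name matching in the two AutoUpgrade.cpp hunks above can be modelled outside LLVM. What follows is a minimal standalone sketch, not the actual LLVM code: `consumeFront` and `startsWith` are hypothetical stand-ins for `StringRef::consume_front` and `StringRef::starts_with`, `isRemovedPtrConversion` is a name invented here, and the input is assumed to already have its `llvm.nvvm.` prefix stripped.

```cpp
#include <string_view>

// Hypothetical stand-in for llvm::StringRef::consume_front: if Name begins
// with Prefix, strip it and return true; otherwise leave Name untouched.
static bool consumeFront(std::string_view &Name, std::string_view Prefix) {
  if (Name.substr(0, Prefix.size()) != Prefix)
    return false;
  Name.remove_prefix(Prefix.size());
  return true;
}

// Hypothetical stand-in for llvm::StringRef::starts_with.
static bool startsWith(std::string_view Name, std::string_view Prefix) {
  return Name.substr(0, Prefix.size()) == Prefix;
}

// Returns true if Name (assumed to have "llvm.nvvm." already removed) is one
// of the eight conversion intrinsics that the upgrade turns into an
// addrspacecast. Overload suffixes such as ".p0.p1" are tolerated because
// only a prefix is checked.
bool isRemovedPtrConversion(std::string_view Name) {
  // nvvm.ptr.gen.to.{local,shared,global,constant}
  if (consumeFront(Name, "ptr.gen.to."))
    return startsWith(Name, "local") || startsWith(Name, "shared") ||
           startsWith(Name, "global") || startsWith(Name, "constant");
  // nvvm.ptr.{local,shared,global,constant}.to.gen
  if (consumeFront(Name, "ptr."))
    return (consumeFront(Name, "local") || consumeFront(Name, "shared") ||
            consumeFront(Name, "global") || consumeFront(Name, "constant")) &&
           startsWith(Name, ".to.gen");
  return false;
}
```

Under this sketch `ptr.gen.to.param` is not matched, which is consistent with `int_nvvm_ptr_gen_to_param` remaining in IntrinsicsNVVM.td.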
58 changes: 28 additions & 30 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -1109,38 +1109,38 @@ void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) {
AddrSpaceCastSDNode *CastN = cast<AddrSpaceCastSDNode>(N);
unsigned SrcAddrSpace = CastN->getSrcAddressSpace();
unsigned DstAddrSpace = CastN->getDestAddressSpace();
SDLoc DL(N);
assert(SrcAddrSpace != DstAddrSpace &&
"addrspacecast must be between different address spaces");

if (DstAddrSpace == ADDRESS_SPACE_GENERIC) {
// Specific to generic

if (TM.is64Bit() && TM.getPointerSizeInBits(SrcAddrSpace) == 32) {
SDValue CvtNone =
CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32);
SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_u64_u32, DL, MVT::i64,
Src, CvtNone);
Src = SDValue(Cvt, 0);
}

unsigned Opc;
switch (SrcAddrSpace) {
default: report_fatal_error("Bad address space in addrspacecast");
case ADDRESS_SPACE_GLOBAL:
Opc = TM.is64Bit() ? NVPTX::cvta_global_64 : NVPTX::cvta_global;
break;
case ADDRESS_SPACE_SHARED:
Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32
? NVPTX::cvta_shared_6432
: NVPTX::cvta_shared_64)
: NVPTX::cvta_shared;
Opc = TM.is64Bit() ? NVPTX::cvta_shared_64 : NVPTX::cvta_shared;
break;
case ADDRESS_SPACE_CONST:
Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32
? NVPTX::cvta_const_6432
: NVPTX::cvta_const_64)
: NVPTX::cvta_const;
Opc = TM.is64Bit() ? NVPTX::cvta_const_64 : NVPTX::cvta_const;
break;
case ADDRESS_SPACE_LOCAL:
Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32
? NVPTX::cvta_local_6432
: NVPTX::cvta_local_64)
: NVPTX::cvta_local;
Opc = TM.is64Bit() ? NVPTX::cvta_local_64 : NVPTX::cvta_local;
break;
}
ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0),
Src));
ReplaceNode(N, CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src));
return;
} else {
// Generic to specific
@@ -1153,30 +1153,28 @@ void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) {
Opc = TM.is64Bit() ? NVPTX::cvta_to_global_64 : NVPTX::cvta_to_global;
break;
case ADDRESS_SPACE_SHARED:
Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32
? NVPTX::cvta_to_shared_3264
: NVPTX::cvta_to_shared_64)
: NVPTX::cvta_to_shared;
Opc = TM.is64Bit() ? NVPTX::cvta_to_shared_64 : NVPTX::cvta_to_shared;
break;
case ADDRESS_SPACE_CONST:
Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32
? NVPTX::cvta_to_const_3264
: NVPTX::cvta_to_const_64)
: NVPTX::cvta_to_const;
Opc = TM.is64Bit() ? NVPTX::cvta_to_const_64 : NVPTX::cvta_to_const;
break;
case ADDRESS_SPACE_LOCAL:
Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32
? NVPTX::cvta_to_local_3264
: NVPTX::cvta_to_local_64)
: NVPTX::cvta_to_local;
Opc = TM.is64Bit() ? NVPTX::cvta_to_local_64 : NVPTX::cvta_to_local;
break;
case ADDRESS_SPACE_PARAM:
Opc = TM.is64Bit() ? NVPTX::nvvm_ptr_gen_to_param_64
: NVPTX::nvvm_ptr_gen_to_param;
Opc = TM.is64Bit() ? NVPTX::IMOV64rr : NVPTX::IMOV32rr;
break;
}
ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0),
Src));

SDNode *CVTA = CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src);
if (TM.is64Bit() && TM.getPointerSizeInBits(DstAddrSpace) == 32) {
Review comment (Member): Nice. It's a cleaner approach than having _3264 conversion per AS.

SDValue CvtNone =
CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32);
CVTA = CurDAG->getMachineNode(NVPTX::CVT_u32_u64, DL, MVT::i32,
SDValue(CVTA, 0), CvtNone);
}

ReplaceNode(N, CVTA);
return;
}
}
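
The reworked selection logic above replaces the per-address-space `_6432` and `_3264` opcodes with one `cvta` plus a separate width conversion when the non-generic pointer is 32 bits wide on a 64-bit target. The following is a rough standalone model of that ordering, assuming nothing beyond what the hunk shows: `Space`, `spaceName`, and `planAddrSpaceCast` are hypothetical names, the strings are illustrative PTX-style mnemonics rather than output of the NVPTX backend, and the param address space is left out.

```cpp
#include <string>
#include <vector>

// Non-generic address spaces covered by this sketch (param is omitted).
enum class Space { Global, Shared, Const, Local };

// PTX state-space spelling used in the cvta mnemonic.
static std::string spaceName(Space S) {
  switch (S) {
  case Space::Global: return "global";
  case Space::Shared: return "shared";
  case Space::Const:  return "const";
  case Space::Local:  return "local";
  }
  return "";
}

// Returns the illustrative instruction sequence for a cast between the
// generic address space and the non-generic space S.
//   ToGeneric       - true for specific-to-generic, false for the reverse
//   Is64Bit         - whether the target uses 64-bit generic pointers
//   SpecificPtrBits - pointer width of the non-generic address space
std::vector<std::string> planAddrSpaceCast(bool ToGeneric, Space S,
                                           bool Is64Bit,
                                           unsigned SpecificPtrBits) {
  const bool ShortPtr = Is64Bit && SpecificPtrBits == 32;
  const std::string Size = Is64Bit ? ".u64" : ".u32";
  std::vector<std::string> Seq;
  if (ToGeneric) {
    // Widen the 32-bit source first, then convert to a generic address.
    if (ShortPtr)
      Seq.push_back("cvt.u64.u32");
    Seq.push_back("cvta." + spaceName(S) + Size);
  } else {
    // Convert to the specific space first, then narrow the result.
    Seq.push_back("cvta.to." + spaceName(S) + Size);
    if (ShortPtr)
      Seq.push_back("cvt.u32.u64");
  }
  return Seq;
}
```

Keeping the width change as its own step means each address space needs only the plain 32-bit and 64-bit `cvta` opcodes, which is the simplification the review comment points out.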
4 changes: 0 additions & 4 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -174,10 +174,6 @@ def hasSM90a : Predicate<"Subtarget->getFullSmVersion() == 901">;
def hasSHFL : Predicate<"!(Subtarget->getSmVersion() >= 70"
"&& Subtarget->getPTXVersion() >= 64)">;

def useShortPtrLocal : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_LOCAL) == 32">;
def useShortPtrShared : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32">;
def useShortPtrConst : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_CONST) == 32">;

def useFP16Math: Predicate<"Subtarget->allowFP16Math()">;
def hasBF16Math: Predicate<"Subtarget->hasBF16Math()">;
