[NVPTX] Add tcgen05.cp/shift intrinsics #127669

Merged: 1 commit, Feb 19, 2025
llvm/docs/NVPTXUsage.rst (87 additions, 0 deletions)

@@ -1183,6 +1183,93 @@ operations.
For more information, refer to the PTX ISA
`<https://docs.nvidia.com/cuda/parallel-thread-execution/#tensorcore-5th-generation-instructions-tcgen05-fence>`_.

'``llvm.nvvm.tcgen05.shift``'
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Syntax:
"""""""

.. code-block:: llvm

declare void @llvm.nvvm.tcgen05.shift.down.cg1(ptr addrspace(6) %tmem_addr)
declare void @llvm.nvvm.tcgen05.shift.down.cg2(ptr addrspace(6) %tmem_addr)

Overview:
"""""""""

The '``@llvm.nvvm.tcgen05.shift.down.{cg1/cg2}``' intrinsics correspond to
the ``tcgen05.shift.cta_group::{1/2}.down`` PTX instructions. ``tcgen05.shift``
is an asynchronous instruction that initiates the downward shift of 32-byte
elements by one row, across all rows except the last. The address operand
``%tmem_addr`` specifies the base address of the matrix in the Tensor Memory
whose rows are to be shifted down.
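
For example, a minimal function issuing the shift for CTA group 1 might look
like the following sketch (the function name is illustrative, and
``%tmem_addr`` is assumed to point to a valid Tensor Memory allocation):

.. code-block:: llvm

  define void @shift_rows_down(ptr addrspace(6) %tmem_addr) {
    ; Initiate the down-shift; the operation completes asynchronously
    ; and must be synchronized before the shifted rows are read.
    call void @llvm.nvvm.tcgen05.shift.down.cg1(ptr addrspace(6) %tmem_addr)
    ret void
  }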

For more information, refer to the PTX ISA
`<https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-shift>`_.

'``llvm.nvvm.tcgen05.cp``'
^^^^^^^^^^^^^^^^^^^^^^^^^^

Syntax:
"""""""

.. code-block:: llvm

declare void @llvm.nvvm.tcgen05.cp.4x256b.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
declare void @llvm.nvvm.tcgen05.cp.128x256b.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
declare void @llvm.nvvm.tcgen05.cp.128x128b.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
declare void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
declare void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
declare void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_01_23.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)

declare void @llvm.nvvm.tcgen05.cp.4x256b.b6x16_p32.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
declare void @llvm.nvvm.tcgen05.cp.128x256b.b6x16_p32.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
declare void @llvm.nvvm.tcgen05.cp.128x128b.b6x16_p32.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
declare void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.b6x16_p32.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
declare void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.b6x16_p32.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
declare void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_01_23.b6x16_p32.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)

declare void @llvm.nvvm.tcgen05.cp.4x256b.b4x16_p64.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
declare void @llvm.nvvm.tcgen05.cp.128x256b.b4x16_p64.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
declare void @llvm.nvvm.tcgen05.cp.128x128b.b4x16_p64.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
declare void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.b4x16_p64.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
declare void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.b4x16_p64.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
declare void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_01_23.b4x16_p64.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)

Overview:
"""""""""

The '``@llvm.nvvm.tcgen05.cp.{shape}.{src_fmt}.{cg1/cg2}``' intrinsics
correspond to the ``tcgen05.cp.*`` family of PTX instructions.
The ``tcgen05.cp`` instruction initiates an asynchronous copy operation from
shared memory to the location specified by ``%tmem_addr`` in Tensor Memory.
The 64-bit register operand ``%sdesc`` is the matrix descriptor representing
the source matrix in shared memory that needs to be copied.

The valid shapes for the copy operation are:
{128x256b, 4x256b, 128x128b, 64x128b_warpx2_02_13, 64x128b_warpx2_01_23, 32x128b_warpx4}.

Shapes ``64x128b`` and ``32x128b`` require dedicated multicast qualifiers,
which are appended to the corresponding intrinsic names (for example,
``llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.cg1``).

Optionally, the data can be decompressed from the source format in shared
memory to the destination format in Tensor Memory during the copy operation.
Currently, only ``.b8x16`` is supported as the destination format. The valid
source formats are ``.b6x16_p32`` and ``.b4x16_p64``.

When the source format is ``.b6x16_p32``, a contiguous set of sixteen 6-bit
elements followed by four bytes of padding (``_p32``) in shared memory
(16 bytes in total) is decompressed into sixteen 8-bit elements (``.b8x16``)
in the Tensor Memory.

When the source format is ``.b4x16_p64``, a contiguous set of sixteen 4-bit
elements followed by eight bytes of padding (``_p64``) in shared memory
(16 bytes in total) is decompressed into sixteen 8-bit elements (``.b8x16``)
in the Tensor Memory.
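
As an illustration, the following sketch copies one ``128x256b`` matrix from
shared memory into Tensor Memory while decompressing ``.b6x16_p32`` data to
``.b8x16`` (the function name is an assumption, and ``%sdesc`` is assumed to
be a valid shared-memory matrix descriptor):

.. code-block:: llvm

  define void @copy_matrix(ptr addrspace(6) %tmem_addr, i64 %sdesc) {
    ; Initiate the asynchronous copy; %sdesc describes the source matrix
    ; in shared memory, %tmem_addr is the destination in Tensor Memory.
    call void @llvm.nvvm.tcgen05.cp.128x256b.b6x16_p32.cg1(ptr addrspace(6) %tmem_addr, i64 %sdesc)
    ret void
  }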

For more information on the decompression schemes, refer to the PTX ISA
`<https://docs.nvidia.com/cuda/parallel-thread-execution/#optional-decompression>`_.

For more information on the tcgen05.cp instruction, refer to the PTX ISA
`<https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-cp>`_.

Other Intrinsics
----------------
llvm/include/llvm/IR/IntrinsicsNVVM.td (32 additions, 0 deletions)

@@ -55,6 +55,14 @@ def llvm_tmem_ptr_ty : LLVMQualPointerType<6>; // (tensor memory)ptr
// MISC
//

// Helper class that concatenates list elements with
// a given separator 'sep' and returns the result.
// Empty strings in 'str_list' are skipped, so no stray separators are emitted.
class StrJoin<string sep, list<string> str_list> {
string ret = !foldl("", str_list, a, b,
!if(!eq(a, ""), b, !if(!eq(b, ""), a, !strconcat(a, sep, b))));
}
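
// For example (illustrative):
//   StrJoin<"_", ["4x256b", "", "cg1"]>.ret            yields "4x256b_cg1"
//   StrJoin<".", ["128x128b", "b6x16_p32", "cg2"]>.ret yields "128x128b.b6x16_p32.cg2"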

// Helper class that represents a 'fragment' of an NVPTX *MMA instruction.
// Geom: m<M>n<N>k<K>. E.g. m8n32k16
// Frag: [a|b|c|d] ([x1|x2|x4] for ldmatrix)
@@ -5140,6 +5148,11 @@ foreach cta_group = ["cg1", "cg2"] in {
[llvm_shared_ptr_ty, llvm_i16_ty], // mbar_ptr, cta_mask
[IntrConvergent, IntrInaccessibleMemOrArgMemOnly,
NoCapture<ArgIndex<0>>]>;

def int_nvvm_tcgen05_shift_down_ # cta_group : Intrinsic<[],
[llvm_tmem_ptr_ty], // tmem_addr
[IntrConvergent, IntrArgMemOnly,
NoCapture<ArgIndex<0>>]>;
}

// Tcgen05 wait_ld/st intrinsics
@@ -5154,4 +5167,23 @@ def int_nvvm_tcgen05_fence_before_thread_sync : Intrinsic<[], [],
def int_nvvm_tcgen05_fence_after_thread_sync : Intrinsic<[], [],
[IntrNoMem, IntrHasSideEffects]>;

// Tcgen05 cp intrinsics
foreach cta_group = ["cg1", "cg2"] in {
foreach src_fmt = ["", "b6x16_p32", "b4x16_p64"] in {
foreach shape = ["128x256b", "4x256b", "128x128b",
"64x128b_warpx2_02_13",
"64x128b_warpx2_01_23",
"32x128b_warpx4"] in {
defvar intr_suffix = StrJoin<"_", [shape, src_fmt, cta_group]>.ret;
defvar name_suffix = StrJoin<".", [shape, src_fmt, cta_group]>.ret;

def int_nvvm_tcgen05_cp_ # intr_suffix : Intrinsic<[],
[llvm_tmem_ptr_ty, // tmem_addr
llvm_i64_ty], // smem descriptor
[IntrConvergent, IntrInaccessibleMemOrArgMemOnly, NoCapture<ArgIndex<0>>],
"llvm.nvvm.tcgen05.cp." # name_suffix>;
}
}
}
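
// For instance (illustrative), the iteration with shape = "128x256b",
// src_fmt = "b6x16_p32" and cta_group = "cg1" defines
// int_nvvm_tcgen05_cp_128x256b_b6x16_p32_cg1, which is exposed to IR as
// llvm.nvvm.tcgen05.cp.128x256b.b6x16_p32.cg1.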

} // let TargetPrefix = "nvvm"
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (42 additions, 0 deletions)

@@ -7704,6 +7704,48 @@ defm TCGEN05_COMMIT_S64_CG2 : TCGEN05_COMMIT_INTR<Int64Regs, "shared", "2">;
defm TCGEN05_COMMIT_S32_CG1 : TCGEN05_COMMIT_INTR<Int32Regs, "shared", "1">;
defm TCGEN05_COMMIT_S32_CG2 : TCGEN05_COMMIT_INTR<Int32Regs, "shared", "2">;

multiclass TCGEN05_SHIFT_INTR<string num, Intrinsic Intr> {
def NAME : NVPTXInst<(outs),
(ins Int32Regs:$tmem_addr),
!strconcat("tcgen05.shift.cta_group::", num, ".down [$tmem_addr];"),
[(Intr Int32Regs:$tmem_addr)]>,
Requires<[hasTcgen05Instructions]>;
}
defm TCGEN05_SHIFT_CG1: TCGEN05_SHIFT_INTR<"1", int_nvvm_tcgen05_shift_down_cg1>;
defm TCGEN05_SHIFT_CG2: TCGEN05_SHIFT_INTR<"2", int_nvvm_tcgen05_shift_down_cg2>;

multiclass TCGEN05_CP_INTR<string shape, string src_fmt, string mc = ""> {
defvar dst_fmt = !if(!eq(src_fmt, ""), "", ".b8x16");
defvar fmt_asm = StrJoin<".", [dst_fmt, src_fmt]>.ret;
defvar fmt_intr = StrJoin<"_", [src_fmt]>.ret;

defvar shape_mc_asm = StrJoin<".", [shape, mc]>.ret;
defvar shape_mc_intr = !subst("::", "_", !subst(".", "_", shape_mc_asm));

defvar intr_prefix = StrJoin<"_", ["int_nvvm_tcgen05_cp", shape_mc_intr, fmt_intr]>.ret;
defvar IntrCG1 = !cast<Intrinsic>(intr_prefix # "_cg1");
defvar IntrCG2 = !cast<Intrinsic>(intr_prefix # "_cg2");

def NAME # _cg1 : NVPTXInst<(outs),
(ins Int32Regs:$tmem_addr, Int64Regs:$sdesc),
"tcgen05.cp.cta_group::1." # shape_mc_asm # fmt_asm # " [$tmem_addr], $sdesc;",
[(IntrCG1 Int32Regs:$tmem_addr, Int64Regs:$sdesc)]>,
Requires<[hasTcgen05Instructions]>;
def NAME # _cg2 : NVPTXInst<(outs),
(ins Int32Regs:$tmem_addr, Int64Regs:$sdesc),
"tcgen05.cp.cta_group::2." # shape_mc_asm # fmt_asm # " [$tmem_addr], $sdesc;",
[(IntrCG2 Int32Regs:$tmem_addr, Int64Regs:$sdesc)]>,
Requires<[hasTcgen05Instructions]>;
}
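
// As an illustration, TCGEN05_CP_INTR<"64x128b", "b6x16_p32", "warpx2::02_13">
// selects int_nvvm_tcgen05_cp_64x128b_warpx2_02_13_b6x16_p32_{cg1,cg2} and
// emits "tcgen05.cp.cta_group::{1,2}.64x128b.warpx2::02_13.b8x16.b6x16_p32".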

foreach src_fmt = ["", "b6x16_p32", "b4x16_p64"] in {
defm TCGEN05_CP_128x256b # src_fmt : TCGEN05_CP_INTR<"128x256b", src_fmt>;
defm TCGEN05_CP_4x256b # src_fmt : TCGEN05_CP_INTR<"4x256b", src_fmt>;
defm TCGEN05_CP_128x128b # src_fmt : TCGEN05_CP_INTR<"128x128b", src_fmt>;
defm TCGEN05_CP_64x128_1 # src_fmt : TCGEN05_CP_INTR<"64x128b", src_fmt, "warpx2::02_13">;
defm TCGEN05_CP_64x128_2 # src_fmt : TCGEN05_CP_INTR<"64x128b", src_fmt, "warpx2::01_23">;
defm TCGEN05_CP_32x128 # src_fmt : TCGEN05_CP_INTR<"32x128b", src_fmt, "warpx4">;
}
} // isConvergent

let hasSideEffects = 1 in {