[NVPTX] Add idp2a, idp4a intrinsics #102763

Merged
merged 2 commits into from
Aug 14, 2024
71 changes: 71 additions & 0 deletions llvm/docs/NVPTXUsage.rst
@@ -287,6 +287,77 @@ The ``@llvm.nvvm.fence.proxy.tensormap_generic.*`` is a uni-directional fence us

The address operand ``addr`` and the operand ``size`` together specify the memory range ``[addr, addr+size)`` on which the ordering guarantees on the memory accesses across the proxies is to be provided. The only supported value for the ``size`` operand is ``128`` and must be an immediate. Generic Addressing is used unconditionally, and the address specified by the operand addr must fall within the ``.global`` state space. Otherwise, the behavior is undefined. For more information, see `PTX ISA <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar>`_.

Arithmetic Intrinsics
---------------------

'``llvm.nvvm.idp2a.[us].[us]``' Intrinsics
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Syntax:
"""""""

.. code-block:: llvm

    declare i32 @llvm.nvvm.idp2a.s.s(i32 %a, i32 %b, i1 immarg %is.hi, i32 %c)
    declare i32 @llvm.nvvm.idp2a.s.u(i32 %a, i32 %b, i1 immarg %is.hi, i32 %c)
    declare i32 @llvm.nvvm.idp2a.u.s(i32 %a, i32 %b, i1 immarg %is.hi, i32 %c)
    declare i32 @llvm.nvvm.idp2a.u.u(i32 %a, i32 %b, i1 immarg %is.hi, i32 %c)


Overview:
"""""""""

The '``llvm.nvvm.idp2a.[us].[us]``' intrinsics perform a 2-element vector dot
product followed by addition. They correspond directly to the ``dp2a`` PTX
instruction.

Semantics:
""""""""""

The 32-bit value in ``%a`` is split into two 16-bit values, which are extended
to 32 bits. The '``llvm.nvvm.idp2a.u.[us]``' variants use zero-extension, while
the '``llvm.nvvm.idp2a.s.[us]``' variants use sign-extension. Two bytes are
selected from ``%b``: if ``%is.hi`` is true, the two most significant bytes are
selected; otherwise, the two least significant bytes are selected. These bytes
are then extended to 32 bits. The '``llvm.nvvm.idp2a.[us].u``' variants use
zero-extension, while the '``llvm.nvvm.idp2a.[us].s``' variants use
sign-extension. The dot product of these two 2-element vectors is added to
``%c`` to produce the return value.
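
The ``idp2a`` semantics can be sketched as a small host-side reference model in
C++14 or later. This is only an illustration, not code from this patch:
``extend_field`` and ``dp2a_ref`` are hypothetical names, the two ``bool``
parameters stand in for the ``[us].[us]`` suffixes, and the pairing of the low
half of ``%a`` with the lower selected byte of ``%b`` is an assumption taken
from the ``dp2a`` description in the PTX ISA rather than from the text above.
Arithmetic is done in ``uint32_t`` so that it wraps modulo 2^32.

```cpp
#include <cstdint>

// Hypothetical helper (not an LLVM API): extract the `width`-bit field of `v`
// starting at bit `shift` and extend it to 32 bits. Signed results are
// returned as their two's-complement bit pattern in a uint32_t, so later
// arithmetic wraps modulo 2^32 instead of overflowing.
constexpr std::uint32_t extend_field(std::uint32_t v, unsigned shift,
                                     unsigned width, bool is_signed) {
  const std::uint32_t mask = (std::uint32_t{1} << width) - 1u; // width: 8 or 16
  const std::uint32_t field = (v >> shift) & mask;
  const std::uint32_t sign_bit = std::uint32_t{1} << (width - 1u);
  return (is_signed && (field & sign_bit) != 0u) ? (field | ~mask) : field;
}

// Reference model of llvm.nvvm.idp2a.<a>.<b>(a, b, is_hi, c).
// a_signed / b_signed correspond to an "s" in the first / second suffix slot.
constexpr std::uint32_t dp2a_ref(std::uint32_t a, std::uint32_t b, bool is_hi,
                                 std::uint32_t c, bool a_signed,
                                 bool b_signed) {
  // Select bytes 2..3 of b when is_hi is set, bytes 0..1 otherwise.
  const unsigned first_byte = is_hi ? 2u : 0u;
  std::uint32_t acc = c;
  for (unsigned i = 0; i < 2; ++i) {
    const std::uint32_t half = extend_field(a, 16u * i, 16u, a_signed);
    const std::uint32_t byte =
        extend_field(b, 8u * (first_byte + i), 8u, b_signed);
    acc += half * byte; // unsigned multiply-add, wraps modulo 2^32
  }
  return acc;
}
```

Under this model, ``dp2a_ref(0x00030002, 0x04030201, false, 100, false, false)``
evaluates to 108 (2*1 + 3*2 + 100), and the same call with ``is_hi`` set to
``true`` evaluates to 118 (2*3 + 3*4 + 100).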


'``llvm.nvvm.idp4a.[us].[us]``' Intrinsics
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Syntax:
"""""""

.. code-block:: llvm

    declare i32 @llvm.nvvm.idp4a.s.s(i32 %a, i32 %b, i32 %c)
    declare i32 @llvm.nvvm.idp4a.s.u(i32 %a, i32 %b, i32 %c)
    declare i32 @llvm.nvvm.idp4a.u.s(i32 %a, i32 %b, i32 %c)
    declare i32 @llvm.nvvm.idp4a.u.u(i32 %a, i32 %b, i32 %c)

Overview:
"""""""""

The '``llvm.nvvm.idp4a.[us].[us]``' intrinsics perform a 4-element vector dot
product followed by addition. They correspond directly to the ``dp4a`` PTX
instruction.

Semantics:
""""""""""

Each of the 4 bytes in both ``%a`` and ``%b`` is extended to a 32-bit integer,
forming two ``<4 x i32>`` vectors. For ``%a``, zero-extension is used in the
'``llvm.nvvm.idp4a.u.[us]``' variants, while sign-extension is used in the
'``llvm.nvvm.idp4a.s.[us]``' variants. Similarly, for ``%b``, zero-extension is
used in the '``llvm.nvvm.idp4a.[us].u``' variants, while sign-extension is used
in the '``llvm.nvvm.idp4a.[us].s``' variants. The dot product of these
4-element vectors is added to ``%c`` to produce the return value.
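
The ``idp4a`` semantics can likewise be sketched as a host-side reference model
in C++14 or later. As before, this is an illustrative sketch rather than code
from this patch: ``extend_byte`` and ``dp4a_ref`` are hypothetical names, and
the two ``bool`` parameters stand in for the ``[us].[us]`` suffixes. Arithmetic
is done in ``uint32_t`` so that it wraps modulo 2^32, which is assumed to match
the ``i32`` result.

```cpp
#include <cstdint>

// Hypothetical helper (not an LLVM API): extend byte `i` of `v`
// (0 = least significant) to 32 bits. Signed results are returned as their
// two's-complement bit pattern in a uint32_t.
constexpr std::uint32_t extend_byte(std::uint32_t v, unsigned i,
                                    bool is_signed) {
  const std::uint32_t byte = (v >> (8u * i)) & 0xFFu;
  return (is_signed && (byte & 0x80u) != 0u) ? (byte | 0xFFFFFF00u) : byte;
}

// Reference model of llvm.nvvm.idp4a.<a>.<b>(a, b, c).
// a_signed / b_signed correspond to an "s" in the first / second suffix slot.
constexpr std::uint32_t dp4a_ref(std::uint32_t a, std::uint32_t b,
                                 std::uint32_t c, bool a_signed,
                                 bool b_signed) {
  std::uint32_t acc = c;
  for (unsigned i = 0; i < 4; ++i) {
    // Multiply matching bytes of a and b; unsigned math wraps modulo 2^32.
    acc += extend_byte(a, i, a_signed) * extend_byte(b, i, b_signed);
  }
  return acc;
}
```

Under this model, ``dp4a_ref(0x01020304, 0x01010101, 10, false, false)``
evaluates to 20 (4 + 3 + 2 + 1 + 10). With ``a = 0xFF000000`` and
``b = 0x02000000``, the unsigned variant yields 510, while treating ``%a`` as
signed yields ``0xFFFFFFFE``, the 32-bit pattern for -2.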



Other Intrinsics
----------------

16 changes: 16 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1052,6 +1052,22 @@ let TargetPrefix = "nvvm" in {
DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty],
[IntrNoMem, IntrSpeculatable, Commutative]>;

//
// Dot Product
//
foreach a_type = ["s", "u"] in {
foreach b_type = ["s", "u"] in {
def int_nvvm_idp4a_ # a_type # _ # b_type :
DefaultAttrsIntrinsic<[llvm_i32_ty],
[llvm_i32_ty, llvm_i32_ty, llvm_i32_ty],
[IntrNoMem, IntrSpeculatable]>;
def int_nvvm_idp2a_ # a_type # _ # b_type :
DefaultAttrsIntrinsic<[llvm_i32_ty],
[llvm_i32_ty, llvm_i32_ty, llvm_i1_ty, llvm_i32_ty],
[IntrNoMem, IntrSpeculatable, ImmArg<ArgIndex<2>>]>;
}
}

//
// Convert
//
28 changes: 28 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -159,6 +159,7 @@ def do_SQRTF32_RN : Predicate<"usePrecSqrtF32()">;

def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;

def True : Predicate<"true">;
def False : Predicate<"false">;
@@ -3920,6 +3921,33 @@ let isTerminator = 1, isBranch = 1, isIndirectBranch = 1, isNotDuplicable = 1 in
}


foreach a_type = ["s", "u"] in {
foreach b_type = ["s", "u"] in {

def DOT4_ # a_type # b_type :
NVPTXInst<(outs Int32Regs:$dst),
(ins Int32Regs:$a, Int32Regs:$b, Int32Regs:$c),
"dp4a." # a_type # "32." # b_type # "32 \t$dst, $a, $b, $c;",
[(set Int32Regs:$dst,
(!cast<Intrinsic>("int_nvvm_idp4a_" # a_type # "_" # b_type)
(i32 Int32Regs:$a), (i32 Int32Regs:$b), (i32 Int32Regs:$c)))]>,
Requires<[hasDotInstructions]>;

foreach is_hi = [0, -1] in {
defvar lohi_suffix = !if(is_hi, "hi", "lo");

def DOT2_ # lohi_suffix # _ # a_type # b_type :
NVPTXInst<(outs Int32Regs:$dst),
(ins Int32Regs:$a, Int32Regs:$b, Int32Regs:$c),
"dp2a." # lohi_suffix # "." # a_type # "32." # b_type # "32 \t$dst, $a, $b, $c;",
[(set Int32Regs:$dst,
(!cast<Intrinsic>("int_nvvm_idp2a_" # a_type # "_" # b_type)
(i32 Int32Regs:$a), (i32 Int32Regs:$b), is_hi, (i32 Int32Regs:$c)))]>,
Requires<[hasDotInstructions]>;
}
}
}

include "NVPTXIntrinsics.td"

//-----------------------------------
3 changes: 3 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -90,6 +90,9 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
bool hasMemoryOrdering() const { return SmVersion >= 70 && PTXVersion >= 60; }
// Does SM & PTX support atomic relaxed MMIO operations ?
bool hasRelaxedMMIO() const { return SmVersion >= 70 && PTXVersion >= 82; }
bool hasDotInstructions() const {
return SmVersion >= 61 && PTXVersion >= 50;
}
unsigned int getFullSmVersion() const { return FullSmVersion; }
unsigned int getSmVersion() const { return getFullSmVersion() / 10; }
// GPUs with "a" suffix have include architecture-accelerated features that
222 changes: 222 additions & 0 deletions llvm/test/CodeGen/NVPTX/dot-product.ll
@@ -0,0 +1,222 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s -march=nvptx -mcpu=sm_61 | FileCheck %s
; RUN: llc < %s -march=nvptx64 -mcpu=sm_61 | FileCheck %s

target triple = "nvptx-nvidia-cuda"

declare i32 @llvm.nvvm.idp4a.s.s(i32, i32, i32)
declare i32 @llvm.nvvm.idp4a.s.u(i32, i32, i32)
declare i32 @llvm.nvvm.idp4a.u.s(i32, i32, i32)
declare i32 @llvm.nvvm.idp4a.u.u(i32, i32, i32)

define i32 @test_dp4a_u32_u32(i32 %a, i32 %b, i32 %c) {
; CHECK-LABEL: test_dp4a_u32_u32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_dp4a_u32_u32_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test_dp4a_u32_u32_param_1];
; CHECK-NEXT: ld.param.u32 %r3, [test_dp4a_u32_u32_param_2];
; CHECK-NEXT: dp4a.u32.u32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r4;
; CHECK-NEXT: ret;
%call = call i32 @llvm.nvvm.idp4a.u.u(i32 %a, i32 %b, i32 %c)
ret i32 %call
}

define i32 @test_dp4a_u32imm_u32imm(i32 %c) {
; CHECK-LABEL: test_dp4a_u32imm_u32imm(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<4>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_dp4a_u32imm_u32imm_param_0];
; CHECK-NEXT: mov.b32 %r2, 0;
; CHECK-NEXT: dp4a.u32.u32 %r3, %r2, %r2, %r1;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r3;
; CHECK-NEXT: ret;
%call = call i32 @llvm.nvvm.idp4a.u.u(i32 0, i32 0, i32 %c)
ret i32 %call
}

define i32 @test_dp4a_u32_s32(i32 %a, i32 %b, i32 %c) {
; CHECK-LABEL: test_dp4a_u32_s32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_dp4a_u32_s32_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test_dp4a_u32_s32_param_1];
; CHECK-NEXT: ld.param.u32 %r3, [test_dp4a_u32_s32_param_2];
; CHECK-NEXT: dp4a.u32.s32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r4;
; CHECK-NEXT: ret;
%call = call i32 @llvm.nvvm.idp4a.u.s(i32 %a, i32 %b, i32 %c)
ret i32 %call
}

define i32 @test_dp4a_s32_u32(i32 %a, i32 %b, i32 %c) {
; CHECK-LABEL: test_dp4a_s32_u32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_dp4a_s32_u32_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test_dp4a_s32_u32_param_1];
; CHECK-NEXT: ld.param.u32 %r3, [test_dp4a_s32_u32_param_2];
; CHECK-NEXT: dp4a.s32.u32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r4;
; CHECK-NEXT: ret;
%call = call i32 @llvm.nvvm.idp4a.s.u(i32 %a, i32 %b, i32 %c)
ret i32 %call
}

define i32 @test_dp4a_s32_s32(i32 %a, i32 %b, i32 %c) {
; CHECK-LABEL: test_dp4a_s32_s32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_dp4a_s32_s32_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test_dp4a_s32_s32_param_1];
; CHECK-NEXT: ld.param.u32 %r3, [test_dp4a_s32_s32_param_2];
; CHECK-NEXT: dp4a.s32.s32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r4;
; CHECK-NEXT: ret;
%call = call i32 @llvm.nvvm.idp4a.s.s(i32 %a, i32 %b, i32 %c)
ret i32 %call
}

declare i32 @llvm.nvvm.idp2a.s.s(i32, i32, i1 immarg, i32)
declare i32 @llvm.nvvm.idp2a.s.u(i32, i32, i1 immarg, i32)
declare i32 @llvm.nvvm.idp2a.u.s(i32, i32, i1 immarg, i32)
declare i32 @llvm.nvvm.idp2a.u.u(i32, i32, i1 immarg, i32)

define i32 @test_dp2a_lo_u32_u32(i32 %a, i32 %b, i32 %c) {
; CHECK-LABEL: test_dp2a_lo_u32_u32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_dp2a_lo_u32_u32_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test_dp2a_lo_u32_u32_param_1];
; CHECK-NEXT: ld.param.u32 %r3, [test_dp2a_lo_u32_u32_param_2];
; CHECK-NEXT: dp2a.lo.u32.u32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r4;
; CHECK-NEXT: ret;
%call = call i32 @llvm.nvvm.idp2a.u.u(i32 %a, i32 %b, i1 0, i32 %c)
ret i32 %call
}

define i32 @test_dp2a_lo_u32_s32(i32 %a, i32 %b, i32 %c) {
; CHECK-LABEL: test_dp2a_lo_u32_s32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_dp2a_lo_u32_s32_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test_dp2a_lo_u32_s32_param_1];
; CHECK-NEXT: ld.param.u32 %r3, [test_dp2a_lo_u32_s32_param_2];
; CHECK-NEXT: dp2a.lo.u32.s32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r4;
; CHECK-NEXT: ret;
%call = call i32 @llvm.nvvm.idp2a.u.s(i32 %a, i32 %b, i1 0, i32 %c)
ret i32 %call
}

define i32 @test_dp2a_lo_s32_u32(i32 %a, i32 %b, i32 %c) {
; CHECK-LABEL: test_dp2a_lo_s32_u32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_dp2a_lo_s32_u32_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test_dp2a_lo_s32_u32_param_1];
; CHECK-NEXT: ld.param.u32 %r3, [test_dp2a_lo_s32_u32_param_2];
; CHECK-NEXT: dp2a.lo.s32.u32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r4;
; CHECK-NEXT: ret;
%call = call i32 @llvm.nvvm.idp2a.s.u(i32 %a, i32 %b, i1 0, i32 %c)
ret i32 %call
}

define i32 @test_dp2a_lo_s32_s32(i32 %a, i32 %b, i32 %c) {
; CHECK-LABEL: test_dp2a_lo_s32_s32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_dp2a_lo_s32_s32_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test_dp2a_lo_s32_s32_param_1];
; CHECK-NEXT: ld.param.u32 %r3, [test_dp2a_lo_s32_s32_param_2];
; CHECK-NEXT: dp2a.lo.s32.s32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r4;
; CHECK-NEXT: ret;
%call = call i32 @llvm.nvvm.idp2a.s.s(i32 %a, i32 %b, i1 0, i32 %c)
ret i32 %call
}

define i32 @test_dp2a_hi_u32_u32(i32 %a, i32 %b, i32 %c) {
; CHECK-LABEL: test_dp2a_hi_u32_u32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_dp2a_hi_u32_u32_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test_dp2a_hi_u32_u32_param_1];
; CHECK-NEXT: ld.param.u32 %r3, [test_dp2a_hi_u32_u32_param_2];
; CHECK-NEXT: dp2a.hi.u32.u32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r4;
; CHECK-NEXT: ret;
%call = call i32 @llvm.nvvm.idp2a.u.u(i32 %a, i32 %b, i1 1, i32 %c)
ret i32 %call
}

define i32 @test_dp2a_hi_u32_s32(i32 %a, i32 %b, i32 %c) {
; CHECK-LABEL: test_dp2a_hi_u32_s32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_dp2a_hi_u32_s32_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test_dp2a_hi_u32_s32_param_1];
; CHECK-NEXT: ld.param.u32 %r3, [test_dp2a_hi_u32_s32_param_2];
; CHECK-NEXT: dp2a.hi.u32.s32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r4;
; CHECK-NEXT: ret;
%call = call i32 @llvm.nvvm.idp2a.u.s(i32 %a, i32 %b, i1 1, i32 %c)
ret i32 %call
}

define i32 @test_dp2a_hi_s32_u32(i32 %a, i32 %b, i32 %c) {
; CHECK-LABEL: test_dp2a_hi_s32_u32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_dp2a_hi_s32_u32_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test_dp2a_hi_s32_u32_param_1];
; CHECK-NEXT: ld.param.u32 %r3, [test_dp2a_hi_s32_u32_param_2];
; CHECK-NEXT: dp2a.hi.s32.u32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r4;
; CHECK-NEXT: ret;
%call = call i32 @llvm.nvvm.idp2a.s.u(i32 %a, i32 %b, i1 1, i32 %c)
ret i32 %call
}

define i32 @test_dp2a_hi_s32_s32(i32 %a, i32 %b, i32 %c) {
; CHECK-LABEL: test_dp2a_hi_s32_s32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_dp2a_hi_s32_s32_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test_dp2a_hi_s32_s32_param_1];
; CHECK-NEXT: ld.param.u32 %r3, [test_dp2a_hi_s32_s32_param_2];
; CHECK-NEXT: dp2a.hi.s32.s32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r4;
; CHECK-NEXT: ret;
%call = call i32 @llvm.nvvm.idp2a.s.s(i32 %a, i32 %b, i1 1, i32 %c)
ret i32 %call
}