Skip to content

[NVPTX] Add intrinsic support for specialized prmt variants #140951

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions llvm/docs/NVPTXUsage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,126 @@ all bits set to 0 except for %b bits starting at bit position %a. For the
'``clamp``' variants, the values of %a and %b are clamped to the range [0, 32],
which in practice is equivalent to using them as is.

'``llvm.nvvm.prmt``' Intrinsic
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Syntax:
"""""""

.. code-block:: llvm

declare i32 @llvm.nvvm.prmt(i32 %lo, i32 %hi, i32 %selector)

Overview:
"""""""""

The '``llvm.nvvm.prmt``' constructs a permutation of the bytes of the first two
operands, selecting based on the third operand.

Semantics:
""""""""""

The bytes in the first two source operands are numbered from 0 to 7:
{%hi, %lo} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}. For each byte in the target
register, a 4-bit selection value is defined.

The 3 lsbs of the selection value specify which of the 8 source bytes should be
moved into the target position. The msb defines if the byte value should be
copied, or if the sign (msb of the byte) should be replicated over all 8 bits
of the target position (sign extend of the byte value); msb=0 means copy the
literal value; msb=1 means replicate the sign.

These 4-bit selection values are pulled from the lower 16-bits of the %selector
operand, with the least significant selection value corresponding to the least
significant byte of the destination.


'``llvm.nvvm.prmt.*``' Intrinsics
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Syntax:
"""""""

.. code-block:: llvm

declare i32 @llvm.nvvm.prmt.f4e(i32 %lo, i32 %hi, i32 %selector)
declare i32 @llvm.nvvm.prmt.b4e(i32 %lo, i32 %hi, i32 %selector)

declare i32 @llvm.nvvm.prmt.rc8(i32 %lo, i32 %selector)
declare i32 @llvm.nvvm.prmt.ecl(i32 %lo, i32 %selector)
declare i32 @llvm.nvvm.prmt.ecr(i32 %lo, i32 %selector)
declare i32 @llvm.nvvm.prmt.rc16(i32 %lo, i32 %selector)

Overview:
"""""""""

The '``llvm.nvvm.prmt.*``' family of intrinsics constructs a permutation of the
bytes of the first one or two operands, selecting based on the 2 least
significant bits of the final operand.

Semantics:
""""""""""

As with the generic '``llvm.nvvm.prmt``' intrinsic, the bytes in the first one
or two source operands are numbered. The first source operand (%lo) is numbered
{b3, b2, b1, b0}, in the case of the '``f4e``' and '``b4e``' variants, the
second source operand (%hi) is numbered {b7, b6, b5, b4}.

Depending on the 2 least significant bits of the %selector operand, the result
of the permutation is defined as follows:

+------------+----------------+--------------+
| Mode | %selector[1:0] | Output |
+------------+----------------+--------------+
| '``f4e``' | 0 | {3, 2, 1, 0} |
| +----------------+--------------+
| | 1 | {4, 3, 2, 1} |
| +----------------+--------------+
| | 2 | {5, 4, 3, 2} |
| +----------------+--------------+
| | 3 | {6, 5, 4, 3} |
+------------+----------------+--------------+
| '``b4e``' | 0 | {5, 6, 7, 0} |
| +----------------+--------------+
| | 1 | {6, 7, 0, 1} |
| +----------------+--------------+
| | 2 | {7, 0, 1, 2} |
| +----------------+--------------+
| | 3 | {0, 1, 2, 3} |
+------------+----------------+--------------+
| '``rc8``' | 0 | {0, 0, 0, 0} |
| +----------------+--------------+
| | 1 | {1, 1, 1, 1} |
| +----------------+--------------+
| | 2 | {2, 2, 2, 2} |
| +----------------+--------------+
| | 3 | {3, 3, 3, 3} |
+------------+----------------+--------------+
| '``ecl``' | 0 | {3, 2, 1, 0} |
| +----------------+--------------+
| | 1 | {3, 2, 1, 1} |
| +----------------+--------------+
| | 2 | {3, 2, 2, 2} |
| +----------------+--------------+
| | 3 | {3, 3, 3, 3} |
+------------+----------------+--------------+
| '``ecr``' | 0 | {0, 0, 0, 0} |
| +----------------+--------------+
| | 1 | {1, 1, 1, 0} |
| +----------------+--------------+
| | 2 | {2, 2, 1, 0} |
| +----------------+--------------+
| | 3 | {3, 2, 1, 0} |
+------------+----------------+--------------+
| '``rc16``' | 0 | {1, 0, 1, 0} |
| +----------------+--------------+
| | 1 | {3, 2, 3, 2} |
| +----------------+--------------+
| | 2 | {1, 0, 1, 0} |
| +----------------+--------------+
| | 3 | {3, 2, 3, 2} |
+------------+----------------+--------------+

TMA family of Intrinsics
------------------------

Expand Down
21 changes: 18 additions & 3 deletions llvm/include/llvm/IR/IntrinsicsNVVM.td
Original file line number Diff line number Diff line change
Expand Up @@ -739,9 +739,24 @@ class NVVMBuiltin :
}

let TargetPrefix = "nvvm" in {
def int_nvvm_prmt : NVVMBuiltin,
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty],
[IntrNoMem, IntrSpeculatable]>;

// PRMT - permute

let IntrProperties = [IntrNoMem, IntrSpeculatable] in {
def int_nvvm_prmt : NVVMBuiltin,
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty]>;

foreach mode = ["f4e", "b4e"] in
def int_nvvm_prmt_ # mode :
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty]>;

// Note: these variants also have 2 source operands but only one will ever
// be used so we eliminate the other operand in the IR (0 is used as the
// placeholder in the backend).
foreach mode = ["rc8", "ecl", "ecr", "rc16"] in
def int_nvvm_prmt_ # mode :
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty]>;
}

def int_nvvm_nanosleep : NVVMBuiltin,
DefaultAttrsIntrinsic<[], [llvm_i32_ty],
Expand Down
98 changes: 73 additions & 25 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,48 @@ def F16X2RT : RegTyInfo<v2f16, Int32Regs, ?, ?, supports_imm = 0>;
def BF16X2RT : RegTyInfo<v2bf16, Int32Regs, ?, ?, supports_imm = 0>;


// This class provides a basic wrapper around an NVPTXInst that abstracts the
// specific syntax of most PTX instructions. It automatically handles the
// construction of the asm string based on the provided dag arguments.
// For example, the following asm-strings would be computed:
//
// * BasicFlagsNVPTXInst<(outs Int32Regs:$dst),
// (ins Int32Regs:$a, Int32Regs:$b), (ins),
// "add.s32">;
// ---> "add.s32 \t$dst, $a, $b;"
//
// * BasicFlagsNVPTXInst<(outs Int32Regs:$d),
// (ins Int32Regs:$a, Int32Regs:$b, Hexu32imm:$c),
// (ins PrmtMode:$mode),
// "prmt.b32${mode}">;
// ---> "prmt.b32${mode} \t$d, $a, $b, $c;"
//
class BasicFlagsNVPTXInst<dag outs_dag, dag ins_dag, dag flags_dag, string asmstr,
list<dag> pattern = []>
: NVPTXInst<
outs_dag,
!con(ins_dag, flags_dag),
!strconcat(
asmstr,
!if(!and(!empty(ins_dag), !empty(outs_dag)), "",
!strconcat(
" \t",
!interleave(
!foreach(i, !range(!size(outs_dag)),
"$" # !getdagname(outs_dag, i)),
"|"),
!if(!or(!empty(ins_dag), !empty(outs_dag)), "", ", "),
!interleave(
!foreach(i, !range(!size(ins_dag)),
"$" # !getdagname(ins_dag, i)),
", "))),
";"),
pattern>;

class BasicNVPTXInst<dag outs, dag insv, string asmstr, list<dag> pattern = []>
: BasicFlagsNVPTXInst<outs, insv, (ins), asmstr, pattern>;


multiclass I3Inst<string op_str, SDPatternOperator op_node, RegTyInfo t,
bit commutative, list<Predicate> requires = []> {
defvar asmstr = op_str # " \t$dst, $a, $b;";
Expand Down Expand Up @@ -1581,24 +1623,6 @@ def Hexu32imm : Operand<i32> {
let PrintMethod = "printHexu32imm";
}

multiclass PRMT<ValueType T, RegisterClass RC> {
def rrr
: NVPTXInst<(outs RC:$d),
(ins RC:$a, Int32Regs:$b, Int32Regs:$c, PrmtMode:$mode),
!strconcat("prmt.b32${mode}", " \t$d, $a, $b, $c;"),
[(set T:$d, (prmt T:$a, T:$b, i32:$c, imm:$mode))]>;
def rri
: NVPTXInst<(outs RC:$d),
(ins RC:$a, Int32Regs:$b, Hexu32imm:$c, PrmtMode:$mode),
!strconcat("prmt.b32${mode}", " \t$d, $a, $b, $c;"),
[(set T:$d, (prmt T:$a, T:$b, imm:$c, imm:$mode))]>;
def rii
: NVPTXInst<(outs RC:$d),
(ins RC:$a, i32imm:$b, Hexu32imm:$c, PrmtMode:$mode),
!strconcat("prmt.b32${mode}", " \t$d, $a, $b, $c;"),
[(set T:$d, (prmt T:$a, imm:$b, imm:$c, imm:$mode))]>;
}

let hasSideEffects = false in {
// order is somewhat important here. signed/unsigned variants match
// the same patterns, so the first one wins. Having unsigned byte extraction
Expand All @@ -1612,7 +1636,31 @@ let hasSideEffects = false in {
defm BFI_B32 : BFI<"bfi.b32", i32, Int32Regs, i32imm>;
defm BFI_B64 : BFI<"bfi.b64", i64, Int64Regs, i64imm>;

defm PRMT_B32 : PRMT<i32, Int32Regs>;
def PRMT_B32rrr
: BasicFlagsNVPTXInst<(outs Int32Regs:$d),
(ins Int32Regs:$a, Int32Regs:$b, Int32Regs:$c),
(ins PrmtMode:$mode),
"prmt.b32$mode",
[(set i32:$d, (prmt i32:$a, i32:$b, i32:$c, imm:$mode))]>;
def PRMT_B32rri
: BasicFlagsNVPTXInst<(outs Int32Regs:$d),
(ins Int32Regs:$a, Int32Regs:$b, Hexu32imm:$c),
(ins PrmtMode:$mode),
"prmt.b32$mode",
[(set i32:$d, (prmt i32:$a, i32:$b, imm:$c, imm:$mode))]>;
def PRMT_B32rii
: BasicFlagsNVPTXInst<(outs Int32Regs:$d),
(ins Int32Regs:$a, i32imm:$b, Hexu32imm:$c),
(ins PrmtMode:$mode),
"prmt.b32$mode",
[(set i32:$d, (prmt i32:$a, imm:$b, imm:$c, imm:$mode))]>;
def PRMT_B32rir
: BasicFlagsNVPTXInst<(outs Int32Regs:$d),
(ins Int32Regs:$a, i32imm:$b, Int32Regs:$c),
(ins PrmtMode:$mode),
"prmt.b32$mode",
[(set i32:$d, (prmt i32:$a, imm:$b, i32:$c, imm:$mode))]>;

}


Expand Down Expand Up @@ -3265,25 +3313,25 @@ include "NVPTXIntrinsics.td"

def : Pat <
(i32 (bswap i32:$a)),
(INT_NVVM_PRMT $a, (i32 0), (i32 0x0123))>;
(PRMT_B32rii $a, (i32 0), (i32 0x0123), PrmtNONE)>;

def : Pat <
(v2i16 (bswap v2i16:$a)),
(INT_NVVM_PRMT $a, (i32 0), (i32 0x2301))>;
(PRMT_B32rii $a, (i32 0), (i32 0x2301), PrmtNONE)>;

def : Pat <
(i64 (bswap i64:$a)),
(V2I32toI64
(INT_NVVM_PRMT (I64toI32H_Sink $a), (i32 0), (i32 0x0123)),
(INT_NVVM_PRMT (I64toI32L_Sink $a), (i32 0), (i32 0x0123)))>,
(PRMT_B32rii (I64toI32H_Sink $a), (i32 0), (i32 0x0123), PrmtNONE),
(PRMT_B32rii (I64toI32L_Sink $a), (i32 0), (i32 0x0123), PrmtNONE))>,
Requires<[hasPTX<71>]>;

// Fall back to the old way if we don't have PTX 7.1.
def : Pat <
(i64 (bswap i64:$a)),
(V2I32toI64
(INT_NVVM_PRMT (I64toI32H $a), (i32 0), (i32 0x0123)),
(INT_NVVM_PRMT (I64toI32L $a), (i32 0), (i32 0x0123)))>;
(PRMT_B32rii (I64toI32H $a), (i32 0), (i32 0x0123), PrmtNONE),
(PRMT_B32rii (I64toI32L $a), (i32 0), (i32 0x0123), PrmtNONE))>;


////////////////////////////////////////////////////////////////////////////////
Expand Down
19 changes: 17 additions & 2 deletions llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -1025,8 +1025,23 @@ class F_MATH_3<string OpcStr, NVPTXRegClass t_regclass,
// MISC
//

def INT_NVVM_PRMT : F_MATH_3<"prmt.b32 \t$dst, $src0, $src1, $src2;", Int32Regs,
Int32Regs, Int32Regs, Int32Regs, int_nvvm_prmt>;
class PRMT3Pat<Intrinsic prmt_intrinsic, PatLeaf prmt_mode>
: Pat<(prmt_intrinsic i32:$a, i32:$b, i32:$c),
(PRMT_B32rrr $a, $b, $c, prmt_mode)>;

class PRMT2Pat<Intrinsic prmt_intrinsic, PatLeaf prmt_mode>
: Pat<(prmt_intrinsic i32:$a, i32:$c),
(PRMT_B32rir $a, (i32 0), $c, prmt_mode)>;

def : PRMT3Pat<int_nvvm_prmt, PrmtNONE>;
def : PRMT3Pat<int_nvvm_prmt_f4e, PrmtF4E>;
def : PRMT3Pat<int_nvvm_prmt_b4e, PrmtB4E>;

def : PRMT2Pat<int_nvvm_prmt_rc8, PrmtRC8>;
def : PRMT2Pat<int_nvvm_prmt_ecl, PrmtECL>;
def : PRMT2Pat<int_nvvm_prmt_ecr, PrmtECR>;
def : PRMT2Pat<int_nvvm_prmt_rc16, PrmtRC16>;


def INT_NVVM_NANOSLEEP_I : NVPTXInst<(outs), (ins i32imm:$i), "nanosleep.u32 \t$i;",
[(int_nvvm_nanosleep imm:$i)]>,
Expand Down
46 changes: 28 additions & 18 deletions llvm/test/CodeGen/NVPTX/bswap.ll
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ define i32 @bswap32(i32 %a) {
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [bswap32_param_0];
; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 291;
; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 0x123U;
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
; CHECK-NEXT: ret;
%b = tail call i32 @llvm.bswap.i32(i32 %a)
Expand All @@ -48,33 +48,43 @@ define <2 x i16> @bswapv2i16(<2 x i16> %a) #0 {
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [bswapv2i16_param_0];
; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 8961;
; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 0x2301U;
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
; CHECK-NEXT: ret;
%b = tail call <2 x i16> @llvm.bswap.v2i16(<2 x i16> %a)
ret <2 x i16> %b
}

define i64 @bswap64(i64 %a) {
; CHECK-LABEL: bswap64(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-NEXT: .reg .b64 %rd<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b64 %rd1, [bswap64_param_0];
; PTX70-LABEL: bswap64(
; PTX70: {
; PTX70-NEXT: .reg .b32 %r<5>;
; PTX70-NEXT: .reg .b64 %rd<3>;
; PTX70-EMPTY:
; PTX70-NEXT: // %bb.0:
; PTX70-NEXT: ld.param.b64 %rd1, [bswap64_param_0];
; PTX70-NEXT: { .reg .b32 tmp; mov.b64 {%r1, tmp}, %rd1; }
; PTX70-NEXT: prmt.b32 %r2, %r1, 0, 291;
; PTX70-NEXT: prmt.b32 %r2, %r1, 0, 0x123U;
; PTX70-NEXT: { .reg .b32 tmp; mov.b64 {tmp, %r3}, %rd1; }
; PTX70-NEXT: prmt.b32 %r4, %r3, 0, 291;
; PTX70-NEXT: prmt.b32 %r4, %r3, 0, 0x123U;
; PTX70-NEXT: mov.b64 %rd2, {%r4, %r2};
; PTX71-NEXT: mov.b64 {%r1, _}, %rd1;
; PTX71-NEXT: prmt.b32 %r2, %r1, 0, 291;
; PTX71-NEXT: mov.b64 {_, %r3}, %rd1;
; PTX71-NEXT: prmt.b32 %r4, %r3, 0, 291;
; PTX71-NEXT: mov.b64 %rd2, {%r4, %r2};
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
; CHECK-NEXT: ret;
; PTX70-NEXT: st.param.b64 [func_retval0], %rd2;
; PTX70-NEXT: ret;
;
; PTX71-LABEL: bswap64(
; PTX71: {
; PTX71-NEXT: .reg .b32 %r<5>;
; PTX71-NEXT: .reg .b64 %rd<3>;
; PTX71-EMPTY:
; PTX71-NEXT: // %bb.0:
; PTX71-NEXT: ld.param.b64 %rd1, [bswap64_param_0];
; PTX71-NEXT: mov.b64 {%r1, _}, %rd1;
; PTX71-NEXT: prmt.b32 %r2, %r1, 0, 0x123U;
; PTX71-NEXT: mov.b64 {_, %r3}, %rd1;
; PTX71-NEXT: prmt.b32 %r4, %r3, 0, 0x123U;
; PTX71-NEXT: mov.b64 %rd2, {%r4, %r2};
; PTX71-NEXT: st.param.b64 [func_retval0], %rd2;
; PTX71-NEXT: ret;
%b = tail call i64 @llvm.bswap.i64(i64 %a)
ret i64 %b
}
Expand Down
Loading