Skip to content

Commit db6b9af

Browse files
AlexMaclean authored and Jaddyen committed
[NVPTX] Add intrinsic support for specialized prmt variants (llvm#140951)
1 parent e04975f commit db6b9af

File tree

6 files changed

+369
-48
lines changed

6 files changed

+369
-48
lines changed

llvm/docs/NVPTXUsage.rst

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,126 @@ all bits set to 0 except for %b bits starting at bit position %a. For the
624624
'``clamp``' variants, the values of %a and %b are clamped to the range [0, 32],
625625
which in practice is equivalent to using them as is.
626626

627+
'``llvm.nvvm.prmt``' Intrinsic
628+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
629+
630+
Syntax:
631+
"""""""
632+
633+
.. code-block:: llvm
634+
635+
declare i32 @llvm.nvvm.prmt(i32 %lo, i32 %hi, i32 %selector)
636+
637+
Overview:
638+
"""""""""
639+
640+
The '``llvm.nvvm.prmt``' intrinsic constructs a permutation of the bytes of the first two
641+
operands, selecting based on the third operand.
642+
643+
Semantics:
644+
""""""""""
645+
646+
The bytes in the first two source operands are numbered from 0 to 7:
647+
{%hi, %lo} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}. For each byte in the target
648+
register, a 4-bit selection value is defined.
649+
650+
The 3 lsbs of the selection value specify which of the 8 source bytes should be
651+
moved into the target position. The msb defines if the byte value should be
652+
copied, or if the sign (msb of the byte) should be replicated over all 8 bits
653+
of the target position (sign extend of the byte value); msb=0 means copy the
654+
literal value; msb=1 means replicate the sign.
655+
656+
These 4-bit selection values are pulled from the lower 16 bits of the %selector
657+
operand, with the least significant selection value corresponding to the least
658+
significant byte of the destination.
659+
660+
661+
'``llvm.nvvm.prmt.*``' Intrinsics
662+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
663+
664+
Syntax:
665+
"""""""
666+
667+
.. code-block:: llvm
668+
669+
declare i32 @llvm.nvvm.prmt.f4e(i32 %lo, i32 %hi, i32 %selector)
670+
declare i32 @llvm.nvvm.prmt.b4e(i32 %lo, i32 %hi, i32 %selector)
671+
672+
declare i32 @llvm.nvvm.prmt.rc8(i32 %lo, i32 %selector)
673+
declare i32 @llvm.nvvm.prmt.ecl(i32 %lo, i32 %selector)
674+
declare i32 @llvm.nvvm.prmt.ecr(i32 %lo, i32 %selector)
675+
declare i32 @llvm.nvvm.prmt.rc16(i32 %lo, i32 %selector)
676+
677+
Overview:
678+
"""""""""
679+
680+
The '``llvm.nvvm.prmt.*``' family of intrinsics constructs a permutation of the
681+
bytes of the first one or two operands, selecting based on the 2 least
682+
significant bits of the final operand.
683+
684+
Semantics:
685+
""""""""""
686+
687+
As with the generic '``llvm.nvvm.prmt``' intrinsic, the bytes in the first one
688+
or two source operands are numbered. The first source operand (%lo) is numbered
689+
{b3, b2, b1, b0}. In the case of the '``f4e``' and '``b4e``' variants, the
690+
second source operand (%hi) is numbered {b7, b6, b5, b4}.
691+
692+
Depending on the 2 least significant bits of the %selector operand, the result
693+
of the permutation is defined as follows:
694+
695+
+------------+----------------+--------------+
696+
| Mode | %selector[1:0] | Output |
697+
+------------+----------------+--------------+
698+
| '``f4e``' | 0 | {3, 2, 1, 0} |
699+
| +----------------+--------------+
700+
| | 1 | {4, 3, 2, 1} |
701+
| +----------------+--------------+
702+
| | 2 | {5, 4, 3, 2} |
703+
| +----------------+--------------+
704+
| | 3 | {6, 5, 4, 3} |
705+
+------------+----------------+--------------+
706+
| '``b4e``' | 0 | {5, 6, 7, 0} |
707+
| +----------------+--------------+
708+
| | 1 | {6, 7, 0, 1} |
709+
| +----------------+--------------+
710+
| | 2 | {7, 0, 1, 2} |
711+
| +----------------+--------------+
712+
| | 3 | {0, 1, 2, 3} |
713+
+------------+----------------+--------------+
714+
| '``rc8``' | 0 | {0, 0, 0, 0} |
715+
| +----------------+--------------+
716+
| | 1 | {1, 1, 1, 1} |
717+
| +----------------+--------------+
718+
| | 2 | {2, 2, 2, 2} |
719+
| +----------------+--------------+
720+
| | 3 | {3, 3, 3, 3} |
721+
+------------+----------------+--------------+
722+
| '``ecl``' | 0 | {3, 2, 1, 0} |
723+
| +----------------+--------------+
724+
| | 1 | {3, 2, 1, 1} |
725+
| +----------------+--------------+
726+
| | 2 | {3, 2, 2, 2} |
727+
| +----------------+--------------+
728+
| | 3 | {3, 3, 3, 3} |
729+
+------------+----------------+--------------+
730+
| '``ecr``' | 0 | {0, 0, 0, 0} |
731+
| +----------------+--------------+
732+
| | 1 | {1, 1, 1, 0} |
733+
| +----------------+--------------+
734+
| | 2 | {2, 2, 1, 0} |
735+
| +----------------+--------------+
736+
| | 3 | {3, 2, 1, 0} |
737+
+------------+----------------+--------------+
738+
| '``rc16``' | 0 | {1, 0, 1, 0} |
739+
| +----------------+--------------+
740+
| | 1 | {3, 2, 3, 2} |
741+
| +----------------+--------------+
742+
| | 2 | {1, 0, 1, 0} |
743+
| +----------------+--------------+
744+
| | 3 | {3, 2, 3, 2} |
745+
+------------+----------------+--------------+
746+
627747
TMA family of Intrinsics
628748
------------------------
629749

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -739,9 +739,24 @@ class NVVMBuiltin :
739739
}
740740

741741
let TargetPrefix = "nvvm" in {
742-
def int_nvvm_prmt : NVVMBuiltin,
743-
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty],
744-
[IntrNoMem, IntrSpeculatable]>;
742+
743+
// PRMT - permute
744+
745+
let IntrProperties = [IntrNoMem, IntrSpeculatable] in {
746+
def int_nvvm_prmt : NVVMBuiltin,
747+
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty]>;
748+
749+
foreach mode = ["f4e", "b4e"] in
750+
def int_nvvm_prmt_ # mode :
751+
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty]>;
752+
753+
// Note: these variants also have 2 source operands but only one will ever
754+
// be used so we eliminate the other operand in the IR (0 is used as the
755+
// placeholder in the backend).
756+
foreach mode = ["rc8", "ecl", "ecr", "rc16"] in
757+
def int_nvvm_prmt_ # mode :
758+
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty]>;
759+
}
745760

746761
def int_nvvm_nanosleep : NVVMBuiltin,
747762
DefaultAttrsIntrinsic<[], [llvm_i32_ty],

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 73 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,48 @@ def F16X2RT : RegTyInfo<v2f16, Int32Regs, ?, ?, supports_imm = 0>;
238238
def BF16X2RT : RegTyInfo<v2bf16, Int32Regs, ?, ?, supports_imm = 0>;
239239

240240

241+
// This class provides a basic wrapper around an NVPTXInst that abstracts the
242+
// specific syntax of most PTX instructions. It automatically handles the
243+
// construction of the asm string based on the provided dag arguments.
244+
// For example, the following asm-strings would be computed:
245+
//
246+
// * BasicFlagsNVPTXInst<(outs Int32Regs:$dst),
247+
// (ins Int32Regs:$a, Int32Regs:$b), (ins),
248+
// "add.s32">;
249+
// ---> "add.s32 \t$dst, $a, $b;"
250+
//
251+
// * BasicFlagsNVPTXInst<(outs Int32Regs:$d),
252+
// (ins Int32Regs:$a, Int32Regs:$b, Hexu32imm:$c),
253+
// (ins PrmtMode:$mode),
254+
// "prmt.b32${mode}">;
255+
// ---> "prmt.b32${mode} \t$d, $a, $b, $c;"
256+
//
257+
class BasicFlagsNVPTXInst<dag outs_dag, dag ins_dag, dag flags_dag, string asmstr,
258+
list<dag> pattern = []>
259+
: NVPTXInst<
260+
outs_dag,
261+
!con(ins_dag, flags_dag),
262+
!strconcat(
263+
asmstr,
264+
!if(!and(!empty(ins_dag), !empty(outs_dag)), "",
265+
!strconcat(
266+
" \t",
267+
!interleave(
268+
!foreach(i, !range(!size(outs_dag)),
269+
"$" # !getdagname(outs_dag, i)),
270+
"|"),
271+
!if(!or(!empty(ins_dag), !empty(outs_dag)), "", ", "),
272+
!interleave(
273+
!foreach(i, !range(!size(ins_dag)),
274+
"$" # !getdagname(ins_dag, i)),
275+
", "))),
276+
";"),
277+
pattern>;
278+
279+
class BasicNVPTXInst<dag outs, dag insv, string asmstr, list<dag> pattern = []>
280+
: BasicFlagsNVPTXInst<outs, insv, (ins), asmstr, pattern>;
281+
282+
241283
multiclass I3Inst<string op_str, SDPatternOperator op_node, RegTyInfo t,
242284
bit commutative, list<Predicate> requires = []> {
243285
defvar asmstr = op_str # " \t$dst, $a, $b;";
@@ -1581,24 +1623,6 @@ def Hexu32imm : Operand<i32> {
15811623
let PrintMethod = "printHexu32imm";
15821624
}
15831625

1584-
multiclass PRMT<ValueType T, RegisterClass RC> {
1585-
def rrr
1586-
: NVPTXInst<(outs RC:$d),
1587-
(ins RC:$a, Int32Regs:$b, Int32Regs:$c, PrmtMode:$mode),
1588-
!strconcat("prmt.b32${mode}", " \t$d, $a, $b, $c;"),
1589-
[(set T:$d, (prmt T:$a, T:$b, i32:$c, imm:$mode))]>;
1590-
def rri
1591-
: NVPTXInst<(outs RC:$d),
1592-
(ins RC:$a, Int32Regs:$b, Hexu32imm:$c, PrmtMode:$mode),
1593-
!strconcat("prmt.b32${mode}", " \t$d, $a, $b, $c;"),
1594-
[(set T:$d, (prmt T:$a, T:$b, imm:$c, imm:$mode))]>;
1595-
def rii
1596-
: NVPTXInst<(outs RC:$d),
1597-
(ins RC:$a, i32imm:$b, Hexu32imm:$c, PrmtMode:$mode),
1598-
!strconcat("prmt.b32${mode}", " \t$d, $a, $b, $c;"),
1599-
[(set T:$d, (prmt T:$a, imm:$b, imm:$c, imm:$mode))]>;
1600-
}
1601-
16021626
let hasSideEffects = false in {
16031627
// order is somewhat important here. signed/unsigned variants match
16041628
// the same patterns, so the first one wins. Having unsigned byte extraction
@@ -1612,7 +1636,31 @@ let hasSideEffects = false in {
16121636
defm BFI_B32 : BFI<"bfi.b32", i32, Int32Regs, i32imm>;
16131637
defm BFI_B64 : BFI<"bfi.b64", i64, Int64Regs, i64imm>;
16141638

1615-
defm PRMT_B32 : PRMT<i32, Int32Regs>;
1639+
def PRMT_B32rrr
1640+
: BasicFlagsNVPTXInst<(outs Int32Regs:$d),
1641+
(ins Int32Regs:$a, Int32Regs:$b, Int32Regs:$c),
1642+
(ins PrmtMode:$mode),
1643+
"prmt.b32$mode",
1644+
[(set i32:$d, (prmt i32:$a, i32:$b, i32:$c, imm:$mode))]>;
1645+
def PRMT_B32rri
1646+
: BasicFlagsNVPTXInst<(outs Int32Regs:$d),
1647+
(ins Int32Regs:$a, Int32Regs:$b, Hexu32imm:$c),
1648+
(ins PrmtMode:$mode),
1649+
"prmt.b32$mode",
1650+
[(set i32:$d, (prmt i32:$a, i32:$b, imm:$c, imm:$mode))]>;
1651+
def PRMT_B32rii
1652+
: BasicFlagsNVPTXInst<(outs Int32Regs:$d),
1653+
(ins Int32Regs:$a, i32imm:$b, Hexu32imm:$c),
1654+
(ins PrmtMode:$mode),
1655+
"prmt.b32$mode",
1656+
[(set i32:$d, (prmt i32:$a, imm:$b, imm:$c, imm:$mode))]>;
1657+
def PRMT_B32rir
1658+
: BasicFlagsNVPTXInst<(outs Int32Regs:$d),
1659+
(ins Int32Regs:$a, i32imm:$b, Int32Regs:$c),
1660+
(ins PrmtMode:$mode),
1661+
"prmt.b32$mode",
1662+
[(set i32:$d, (prmt i32:$a, imm:$b, i32:$c, imm:$mode))]>;
1663+
16161664
}
16171665

16181666

@@ -3265,25 +3313,25 @@ include "NVPTXIntrinsics.td"
32653313

32663314
def : Pat <
32673315
(i32 (bswap i32:$a)),
3268-
(INT_NVVM_PRMT $a, (i32 0), (i32 0x0123))>;
3316+
(PRMT_B32rii $a, (i32 0), (i32 0x0123), PrmtNONE)>;
32693317

32703318
def : Pat <
32713319
(v2i16 (bswap v2i16:$a)),
3272-
(INT_NVVM_PRMT $a, (i32 0), (i32 0x2301))>;
3320+
(PRMT_B32rii $a, (i32 0), (i32 0x2301), PrmtNONE)>;
32733321

32743322
def : Pat <
32753323
(i64 (bswap i64:$a)),
32763324
(V2I32toI64
3277-
(INT_NVVM_PRMT (I64toI32H_Sink $a), (i32 0), (i32 0x0123)),
3278-
(INT_NVVM_PRMT (I64toI32L_Sink $a), (i32 0), (i32 0x0123)))>,
3325+
(PRMT_B32rii (I64toI32H_Sink $a), (i32 0), (i32 0x0123), PrmtNONE),
3326+
(PRMT_B32rii (I64toI32L_Sink $a), (i32 0), (i32 0x0123), PrmtNONE))>,
32793327
Requires<[hasPTX<71>]>;
32803328

32813329
// Fall back to the old way if we don't have PTX 7.1.
32823330
def : Pat <
32833331
(i64 (bswap i64:$a)),
32843332
(V2I32toI64
3285-
(INT_NVVM_PRMT (I64toI32H $a), (i32 0), (i32 0x0123)),
3286-
(INT_NVVM_PRMT (I64toI32L $a), (i32 0), (i32 0x0123)))>;
3333+
(PRMT_B32rii (I64toI32H $a), (i32 0), (i32 0x0123), PrmtNONE),
3334+
(PRMT_B32rii (I64toI32L $a), (i32 0), (i32 0x0123), PrmtNONE))>;
32873335

32883336

32893337
////////////////////////////////////////////////////////////////////////////////

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,8 +1025,23 @@ class F_MATH_3<string OpcStr, NVPTXRegClass t_regclass,
10251025
// MISC
10261026
//
10271027

1028-
def INT_NVVM_PRMT : F_MATH_3<"prmt.b32 \t$dst, $src0, $src1, $src2;", Int32Regs,
1029-
Int32Regs, Int32Regs, Int32Regs, int_nvvm_prmt>;
1028+
class PRMT3Pat<Intrinsic prmt_intrinsic, PatLeaf prmt_mode>
1029+
: Pat<(prmt_intrinsic i32:$a, i32:$b, i32:$c),
1030+
(PRMT_B32rrr $a, $b, $c, prmt_mode)>;
1031+
1032+
class PRMT2Pat<Intrinsic prmt_intrinsic, PatLeaf prmt_mode>
1033+
: Pat<(prmt_intrinsic i32:$a, i32:$c),
1034+
(PRMT_B32rir $a, (i32 0), $c, prmt_mode)>;
1035+
1036+
def : PRMT3Pat<int_nvvm_prmt, PrmtNONE>;
1037+
def : PRMT3Pat<int_nvvm_prmt_f4e, PrmtF4E>;
1038+
def : PRMT3Pat<int_nvvm_prmt_b4e, PrmtB4E>;
1039+
1040+
def : PRMT2Pat<int_nvvm_prmt_rc8, PrmtRC8>;
1041+
def : PRMT2Pat<int_nvvm_prmt_ecl, PrmtECL>;
1042+
def : PRMT2Pat<int_nvvm_prmt_ecr, PrmtECR>;
1043+
def : PRMT2Pat<int_nvvm_prmt_rc16, PrmtRC16>;
1044+
10301045

10311046
def INT_NVVM_NANOSLEEP_I : NVPTXInst<(outs), (ins i32imm:$i), "nanosleep.u32 \t$i;",
10321047
[(int_nvvm_nanosleep imm:$i)]>,

llvm/test/CodeGen/NVPTX/bswap.ll

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ define i32 @bswap32(i32 %a) {
3333
; CHECK-EMPTY:
3434
; CHECK-NEXT: // %bb.0:
3535
; CHECK-NEXT: ld.param.b32 %r1, [bswap32_param_0];
36-
; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 291;
36+
; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 0x123U;
3737
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
3838
; CHECK-NEXT: ret;
3939
%b = tail call i32 @llvm.bswap.i32(i32 %a)
@@ -48,33 +48,43 @@ define <2 x i16> @bswapv2i16(<2 x i16> %a) #0 {
4848
; CHECK-EMPTY:
4949
; CHECK-NEXT: // %bb.0:
5050
; CHECK-NEXT: ld.param.b32 %r1, [bswapv2i16_param_0];
51-
; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 8961;
51+
; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 0x2301U;
5252
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
5353
; CHECK-NEXT: ret;
5454
%b = tail call <2 x i16> @llvm.bswap.v2i16(<2 x i16> %a)
5555
ret <2 x i16> %b
5656
}
5757

5858
define i64 @bswap64(i64 %a) {
59-
; CHECK-LABEL: bswap64(
60-
; CHECK: {
61-
; CHECK-NEXT: .reg .b32 %r<5>;
62-
; CHECK-NEXT: .reg .b64 %rd<3>;
63-
; CHECK-EMPTY:
64-
; CHECK-NEXT: // %bb.0:
65-
; CHECK-NEXT: ld.param.b64 %rd1, [bswap64_param_0];
59+
; PTX70-LABEL: bswap64(
60+
; PTX70: {
61+
; PTX70-NEXT: .reg .b32 %r<5>;
62+
; PTX70-NEXT: .reg .b64 %rd<3>;
63+
; PTX70-EMPTY:
64+
; PTX70-NEXT: // %bb.0:
65+
; PTX70-NEXT: ld.param.b64 %rd1, [bswap64_param_0];
6666
; PTX70-NEXT: { .reg .b32 tmp; mov.b64 {%r1, tmp}, %rd1; }
67-
; PTX70-NEXT: prmt.b32 %r2, %r1, 0, 291;
67+
; PTX70-NEXT: prmt.b32 %r2, %r1, 0, 0x123U;
6868
; PTX70-NEXT: { .reg .b32 tmp; mov.b64 {tmp, %r3}, %rd1; }
69-
; PTX70-NEXT: prmt.b32 %r4, %r3, 0, 291;
69+
; PTX70-NEXT: prmt.b32 %r4, %r3, 0, 0x123U;
7070
; PTX70-NEXT: mov.b64 %rd2, {%r4, %r2};
71-
; PTX71-NEXT: mov.b64 {%r1, _}, %rd1;
72-
; PTX71-NEXT: prmt.b32 %r2, %r1, 0, 291;
73-
; PTX71-NEXT: mov.b64 {_, %r3}, %rd1;
74-
; PTX71-NEXT: prmt.b32 %r4, %r3, 0, 291;
75-
; PTX71-NEXT: mov.b64 %rd2, {%r4, %r2};
76-
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
77-
; CHECK-NEXT: ret;
71+
; PTX70-NEXT: st.param.b64 [func_retval0], %rd2;
72+
; PTX70-NEXT: ret;
73+
;
74+
; PTX71-LABEL: bswap64(
75+
; PTX71: {
76+
; PTX71-NEXT: .reg .b32 %r<5>;
77+
; PTX71-NEXT: .reg .b64 %rd<3>;
78+
; PTX71-EMPTY:
79+
; PTX71-NEXT: // %bb.0:
80+
; PTX71-NEXT: ld.param.b64 %rd1, [bswap64_param_0];
81+
; PTX71-NEXT: mov.b64 {%r1, _}, %rd1;
82+
; PTX71-NEXT: prmt.b32 %r2, %r1, 0, 0x123U;
83+
; PTX71-NEXT: mov.b64 {_, %r3}, %rd1;
84+
; PTX71-NEXT: prmt.b32 %r4, %r3, 0, 0x123U;
85+
; PTX71-NEXT: mov.b64 %rd2, {%r4, %r2};
86+
; PTX71-NEXT: st.param.b64 [func_retval0], %rd2;
87+
; PTX71-NEXT: ret;
7888
%b = tail call i64 @llvm.bswap.i64(i64 %a)
7989
ret i64 %b
8090
}

0 commit comments

Comments
 (0)