Skip to content

Commit 0c9f01c

Browse files
committed
[NVPTX] Add support for specialized prmt variants
1 parent a0b6cfd commit 0c9f01c

File tree

6 files changed

+326
-48
lines changed

6 files changed

+326
-48
lines changed

llvm/docs/NVPTXUsage.rst

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,126 @@ all bits set to 0 except for %b bits starting at bit position %a. For the
661661
'``clamp``' variants, the values of %a and %b are clamped to the range [0, 32],
662662
which in practice is equivalent to using them as is.
663663

664+
'``llvm.nvvm.prmt``' Intrinsic
665+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
666+
667+
Syntax:
668+
"""""""
669+
670+
.. code-block:: llvm
671+
672+
declare i32 @llvm.nvvm.prmt(i32 %a, i32 %b, i32 %c)
673+
674+
Overview:
675+
"""""""""
676+
677+
The '``llvm.nvvm.prmt``' constructs a permutation of the bytes of the first two
678+
operands, selecting based on the third operand.
679+
680+
Semantics:
681+
""""""""""
682+
683+
The bytes in the first two source operands are numbered from 0 to 7:
684+
{%b, %a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}. For each byte in the target
685+
register, a 4-bit selection value is defined.
686+
687+
The 3 lsbs of the selection value specify which of the 8 source bytes should be
688+
moved into the target position. The msb defines if the byte value should be
689+
copied, or if the sign (msb of the byte) should be replicated over all 8 bits
690+
of the target position (sign extend of the byte value); msb=0 means copy the
691+
literal value; msb=1 means replicate the sign.
692+
693+
These 4-bit selection values are pulled from the lower 16-bits of the third
694+
operand, with the least significant selection value corresponding to the least
695+
significant byte of the destination.
696+
697+
698+
'``llvm.nvvm.prmt.*``' Intrinsics
699+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
700+
701+
Syntax:
702+
"""""""
703+
704+
.. code-block:: llvm
705+
706+
declare i32 @llvm.nvvm.prmt.f4e(i32 %a, i32 %b, i32 %c)
707+
declare i32 @llvm.nvvm.prmt.b4e(i32 %a, i32 %b, i32 %c)
708+
709+
declare i32 @llvm.nvvm.prmt.rc8(i32 %a, i32 %c)
710+
declare i32 @llvm.nvvm.prmt.ecl(i32 %a, i32 %c)
711+
declare i32 @llvm.nvvm.prmt.ecr(i32 %a, i32 %c)
712+
declare i32 @llvm.nvvm.prmt.rc16(i32 %a, i32 %c)
713+
714+
Overview:
715+
"""""""""
716+
717+
The '``llvm.nvvm.prmt.*``' family of intrinsics constructs a permutation of the
718+
bytes of the first one or two operands, selecting based on the 2 least
719+
significant bits of the final operand.
720+
721+
Semantics:
722+
""""""""""
723+
724+
As with the generic '``llvm.nvvm.prmt``' intrinsic, the bytes in the first one
725+
or two source operands are numbered. The first source operand (%a) is numbered
726+
{b3, b2, b1, b0}, in the case of the '``f4e``' and '``b4e``' variants, the
727+
second source operand (%b) is numbered {b7, b6, b5, b4}.
728+
729+
Depending on the 2 least significant bits of the final operand, the result of
730+
the permutation is defined as follows:
731+
732+
+------------+---------+--------------+
733+
| Mode | %c[1:0] | Output |
734+
+------------+---------+--------------+
735+
| '``f4e``' | 0 | {3, 2, 1, 0} |
736+
| +---------+--------------+
737+
| | 1 | {4, 3, 2, 1} |
738+
| +---------+--------------+
739+
| | 2 | {5, 4, 3, 2} |
740+
| +---------+--------------+
741+
| | 3 | {6, 5, 4, 3} |
742+
+------------+---------+--------------+
743+
| '``b4e``' | 0 | {5, 6, 7, 0} |
744+
| +---------+--------------+
745+
| | 1 | {6, 7, 0, 1} |
746+
| +---------+--------------+
747+
| | 2 | {7, 0, 1, 2} |
748+
| +---------+--------------+
749+
| | 3 | {0, 1, 2, 3} |
750+
+------------+---------+--------------+
751+
| '``rc8``' | 0 | {0, 0, 0, 0} |
752+
| +---------+--------------+
753+
| | 1 | {1, 1, 1, 1} |
754+
| +---------+--------------+
755+
| | 2 | {2, 2, 2, 2} |
756+
| +---------+--------------+
757+
| | 3 | {3, 3, 3, 3} |
758+
+------------+---------+--------------+
759+
| '``ecl``' | 0 | {3, 2, 1, 0} |
760+
| +---------+--------------+
761+
| | 1 | {3, 2, 1, 1} |
762+
| +---------+--------------+
763+
| | 2 | {3, 2, 2, 2} |
764+
| +---------+--------------+
765+
| | 3 | {3, 3, 3, 3} |
766+
+------------+---------+--------------+
767+
| '``ecr``' | 0 | {0, 0, 0, 0} |
768+
| +---------+--------------+
769+
| | 1 | {1, 1, 1, 0} |
770+
| +---------+--------------+
771+
| | 2 | {2, 2, 1, 0} |
772+
| +---------+--------------+
773+
| | 3 | {3, 2, 1, 0} |
774+
+------------+---------+--------------+
775+
| '``rc16``' | 0 | {1, 0, 1, 0} |
776+
| +---------+--------------+
777+
| | 1 | {3, 2, 3, 2} |
778+
| +---------+--------------+
779+
| | 2 | {1, 0, 1, 0} |
780+
| +---------+--------------+
781+
| | 3 | {3, 2, 3, 2} |
782+
+------------+---------+--------------+
783+
664784
TMA family of Intrinsics
665785
------------------------
666786

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -745,9 +745,23 @@ class NVVMBuiltin :
745745
}
746746

747747
let TargetPrefix = "nvvm" in {
748-
def int_nvvm_prmt : NVVMBuiltin,
749-
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty],
750-
[IntrNoMem, IntrSpeculatable]>;
748+
749+
// PRMT - permute
750+
751+
let IntrProperties = [IntrNoMem, IntrSpeculatable] in {
752+
def int_nvvm_prmt : NVVMBuiltin,
753+
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty]>;
754+
755+
foreach mode = ["f4e", "b4e"] in
756+
def int_nvvm_prmt_ # mode :
757+
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty]>;
758+
759+
// Note: these variants also have 2 source operands but only one will ever
760+
// be used so we eliminate the other operand in the IR.
761+
foreach mode = ["rc8", "ecl", "ecr", "rc16"] in
762+
def int_nvvm_prmt_ # mode :
763+
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty]>;
764+
}
751765

752766
def int_nvvm_nanosleep : NVVMBuiltin,
753767
DefaultAttrsIntrinsic<[], [llvm_i32_ty],

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1622,24 +1622,6 @@ def Hexu32imm : Operand<i32> {
16221622
let PrintMethod = "printHexu32imm";
16231623
}
16241624

1625-
multiclass PRMT<ValueType T, RegisterClass RC> {
1626-
def rrr
1627-
: NVPTXInst<(outs RC:$d),
1628-
(ins RC:$a, Int32Regs:$b, Int32Regs:$c, PrmtMode:$mode),
1629-
!strconcat("prmt.b32${mode}", " \t$d, $a, $b, $c;"),
1630-
[(set T:$d, (prmt T:$a, T:$b, i32:$c, imm:$mode))]>;
1631-
def rri
1632-
: NVPTXInst<(outs RC:$d),
1633-
(ins RC:$a, Int32Regs:$b, Hexu32imm:$c, PrmtMode:$mode),
1634-
!strconcat("prmt.b32${mode}", " \t$d, $a, $b, $c;"),
1635-
[(set T:$d, (prmt T:$a, T:$b, imm:$c, imm:$mode))]>;
1636-
def rii
1637-
: NVPTXInst<(outs RC:$d),
1638-
(ins RC:$a, i32imm:$b, Hexu32imm:$c, PrmtMode:$mode),
1639-
!strconcat("prmt.b32${mode}", " \t$d, $a, $b, $c;"),
1640-
[(set T:$d, (prmt T:$a, imm:$b, imm:$c, imm:$mode))]>;
1641-
}
1642-
16431625
let hasSideEffects = false in {
16441626
// order is somewhat important here. signed/unsigned variants match
16451627
// the same patterns, so the first one wins. Having unsigned byte extraction
@@ -1653,7 +1635,31 @@ let hasSideEffects = false in {
16531635
defm BFI_B32 : BFI<"bfi.b32", i32, Int32Regs, i32imm>;
16541636
defm BFI_B64 : BFI<"bfi.b64", i64, Int64Regs, i64imm>;
16551637

1656-
defm PRMT_B32 : PRMT<i32, Int32Regs>;
1638+
def PRMT_B32rrr
1639+
: BasicFlagsNVPTXInst<(outs Int32Regs:$d),
1640+
(ins Int32Regs:$a, Int32Regs:$b, Int32Regs:$c),
1641+
(ins PrmtMode:$mode),
1642+
"prmt.b32$mode",
1643+
[(set i32:$d, (prmt i32:$a, i32:$b, i32:$c, imm:$mode))]>;
1644+
def PRMT_B32rri
1645+
: BasicFlagsNVPTXInst<(outs Int32Regs:$d),
1646+
(ins Int32Regs:$a, Int32Regs:$b, Hexu32imm:$c),
1647+
(ins PrmtMode:$mode),
1648+
"prmt.b32$mode",
1649+
[(set i32:$d, (prmt i32:$a, i32:$b, imm:$c, imm:$mode))]>;
1650+
def PRMT_B32rii
1651+
: BasicFlagsNVPTXInst<(outs Int32Regs:$d),
1652+
(ins Int32Regs:$a, i32imm:$b, Hexu32imm:$c),
1653+
(ins PrmtMode:$mode),
1654+
"prmt.b32$mode",
1655+
[(set i32:$d, (prmt i32:$a, imm:$b, imm:$c, imm:$mode))]>;
1656+
def PRMT_B32rir
1657+
: BasicFlagsNVPTXInst<(outs Int32Regs:$d),
1658+
(ins Int32Regs:$a, i32imm:$b, Int32Regs:$c),
1659+
(ins PrmtMode:$mode),
1660+
"prmt.b32$mode",
1661+
[(set i32:$d, (prmt i32:$a, imm:$b, i32:$c, imm:$mode))]>;
1662+
16571663
}
16581664

16591665

@@ -3306,25 +3312,25 @@ include "NVPTXIntrinsics.td"
33063312

33073313
def : Pat <
33083314
(i32 (bswap i32:$a)),
3309-
(INT_NVVM_PRMT $a, (i32 0), (i32 0x0123))>;
3315+
(PRMT_B32rii $a, (i32 0), (i32 0x0123), PrmtNONE)>;
33103316

33113317
def : Pat <
33123318
(v2i16 (bswap v2i16:$a)),
3313-
(INT_NVVM_PRMT $a, (i32 0), (i32 0x2301))>;
3319+
(PRMT_B32rii $a, (i32 0), (i32 0x2301), PrmtNONE)>;
33143320

33153321
def : Pat <
33163322
(i64 (bswap i64:$a)),
33173323
(V2I32toI64
3318-
(INT_NVVM_PRMT (I64toI32H_Sink $a), (i32 0), (i32 0x0123)),
3319-
(INT_NVVM_PRMT (I64toI32L_Sink $a), (i32 0), (i32 0x0123)))>,
3324+
(PRMT_B32rii (I64toI32H_Sink $a), (i32 0), (i32 0x0123), PrmtNONE),
3325+
(PRMT_B32rii (I64toI32L_Sink $a), (i32 0), (i32 0x0123), PrmtNONE))>,
33203326
Requires<[hasPTX<71>]>;
33213327

33223328
// Fall back to the old way if we don't have PTX 7.1.
33233329
def : Pat <
33243330
(i64 (bswap i64:$a)),
33253331
(V2I32toI64
3326-
(INT_NVVM_PRMT (I64toI32H $a), (i32 0), (i32 0x0123)),
3327-
(INT_NVVM_PRMT (I64toI32L $a), (i32 0), (i32 0x0123)))>;
3332+
(PRMT_B32rii (I64toI32H $a), (i32 0), (i32 0x0123), PrmtNONE),
3333+
(PRMT_B32rii (I64toI32L $a), (i32 0), (i32 0x0123), PrmtNONE))>;
33283334

33293335

33303336
////////////////////////////////////////////////////////////////////////////////

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,8 +1028,23 @@ class F_MATH_3<string OpcStr, NVPTXRegClass t_regclass,
10281028
// MISC
10291029
//
10301030

1031-
def INT_NVVM_PRMT : F_MATH_3<"prmt.b32 \t$dst, $src0, $src1, $src2;", Int32Regs,
1032-
Int32Regs, Int32Regs, Int32Regs, int_nvvm_prmt>;
1031+
class PRMT3Pat<Intrinsic prmt_intrinsic, PatLeaf prmt_mode>
1032+
: Pat<(prmt_intrinsic i32:$a, i32:$b, i32:$c),
1033+
(PRMT_B32rrr $a, $b, $c, prmt_mode)>;
1034+
1035+
class PRMT2Pat<Intrinsic prmt_intrinsic, PatLeaf prmt_mode>
1036+
: Pat<(prmt_intrinsic i32:$a, i32:$c),
1037+
(PRMT_B32rir $a, (i32 0), $c, prmt_mode)>;
1038+
1039+
def : PRMT3Pat<int_nvvm_prmt, PrmtNONE>;
1040+
def : PRMT3Pat<int_nvvm_prmt_f4e, PrmtF4E>;
1041+
def : PRMT3Pat<int_nvvm_prmt_b4e, PrmtB4E>;
1042+
1043+
def : PRMT2Pat<int_nvvm_prmt_rc8, PrmtRC8>;
1044+
def : PRMT2Pat<int_nvvm_prmt_ecl, PrmtECL>;
1045+
def : PRMT2Pat<int_nvvm_prmt_ecr, PrmtECR>;
1046+
def : PRMT2Pat<int_nvvm_prmt_rc16, PrmtRC16>;
1047+
10331048

10341049
def INT_NVVM_NANOSLEEP_I : NVPTXInst<(outs), (ins i32imm:$i), "nanosleep.u32 \t$i;",
10351050
[(int_nvvm_nanosleep imm:$i)]>,

llvm/test/CodeGen/NVPTX/bswap.ll

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ define i32 @bswap32(i32 %a) {
3333
; CHECK-EMPTY:
3434
; CHECK-NEXT: // %bb.0:
3535
; CHECK-NEXT: ld.param.b32 %r1, [bswap32_param_0];
36-
; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 291;
36+
; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 0x123U;
3737
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
3838
; CHECK-NEXT: ret;
3939
%b = tail call i32 @llvm.bswap.i32(i32 %a)
@@ -48,33 +48,43 @@ define <2 x i16> @bswapv2i16(<2 x i16> %a) #0 {
4848
; CHECK-EMPTY:
4949
; CHECK-NEXT: // %bb.0:
5050
; CHECK-NEXT: ld.param.b32 %r1, [bswapv2i16_param_0];
51-
; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 8961;
51+
; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 0x2301U;
5252
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
5353
; CHECK-NEXT: ret;
5454
%b = tail call <2 x i16> @llvm.bswap.v2i16(<2 x i16> %a)
5555
ret <2 x i16> %b
5656
}
5757

5858
define i64 @bswap64(i64 %a) {
59-
; CHECK-LABEL: bswap64(
60-
; CHECK: {
61-
; CHECK-NEXT: .reg .b32 %r<5>;
62-
; CHECK-NEXT: .reg .b64 %rd<3>;
63-
; CHECK-EMPTY:
64-
; CHECK-NEXT: // %bb.0:
65-
; CHECK-NEXT: ld.param.b64 %rd1, [bswap64_param_0];
59+
; PTX70-LABEL: bswap64(
60+
; PTX70: {
61+
; PTX70-NEXT: .reg .b32 %r<5>;
62+
; PTX70-NEXT: .reg .b64 %rd<3>;
63+
; PTX70-EMPTY:
64+
; PTX70-NEXT: // %bb.0:
65+
; PTX70-NEXT: ld.param.b64 %rd1, [bswap64_param_0];
6666
; PTX70-NEXT: { .reg .b32 tmp; mov.b64 {%r1, tmp}, %rd1; }
67-
; PTX70-NEXT: prmt.b32 %r2, %r1, 0, 291;
67+
; PTX70-NEXT: prmt.b32 %r2, %r1, 0, 0x123U;
6868
; PTX70-NEXT: { .reg .b32 tmp; mov.b64 {tmp, %r3}, %rd1; }
69-
; PTX70-NEXT: prmt.b32 %r4, %r3, 0, 291;
69+
; PTX70-NEXT: prmt.b32 %r4, %r3, 0, 0x123U;
7070
; PTX70-NEXT: mov.b64 %rd2, {%r4, %r2};
71-
; PTX71-NEXT: mov.b64 {%r1, _}, %rd1;
72-
; PTX71-NEXT: prmt.b32 %r2, %r1, 0, 291;
73-
; PTX71-NEXT: mov.b64 {_, %r3}, %rd1;
74-
; PTX71-NEXT: prmt.b32 %r4, %r3, 0, 291;
75-
; PTX71-NEXT: mov.b64 %rd2, {%r4, %r2};
76-
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
77-
; CHECK-NEXT: ret;
71+
; PTX70-NEXT: st.param.b64 [func_retval0], %rd2;
72+
; PTX70-NEXT: ret;
73+
;
74+
; PTX71-LABEL: bswap64(
75+
; PTX71: {
76+
; PTX71-NEXT: .reg .b32 %r<5>;
77+
; PTX71-NEXT: .reg .b64 %rd<3>;
78+
; PTX71-EMPTY:
79+
; PTX71-NEXT: // %bb.0:
80+
; PTX71-NEXT: ld.param.b64 %rd1, [bswap64_param_0];
81+
; PTX71-NEXT: mov.b64 {%r1, _}, %rd1;
82+
; PTX71-NEXT: prmt.b32 %r2, %r1, 0, 0x123U;
83+
; PTX71-NEXT: mov.b64 {_, %r3}, %rd1;
84+
; PTX71-NEXT: prmt.b32 %r4, %r3, 0, 0x123U;
85+
; PTX71-NEXT: mov.b64 %rd2, {%r4, %r2};
86+
; PTX71-NEXT: st.param.b64 [func_retval0], %rd2;
87+
; PTX71-NEXT: ret;
7888
%b = tail call i64 @llvm.bswap.i64(i64 %a)
7989
ret i64 %b
8090
}

0 commit comments

Comments
 (0)